如何将pytorch中的标签转换为onehot [英] How to transform labels in pytorch to onehot

查看:70
本文介绍了如何将pytorch中的标签转换为onehot的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

如何给 target_transform 一个函数将标签更改为 onehot 编码?

How to give target_transform a function for changing the labels to onehot encoding?

例如torchvision中的MNIST数据集:

For example, the MNIST dataset in torchvision:

train_dataset = torchvision.datasets.MNIST(root='./mnist_data/', 
                                           train=True,
                                           download=True,
                                           transform=train_transform,
                                           target_transform=<????>)

尝试了 F.onehot() 但没有奏效.

Tried F.onehot() but it didn't work.

推荐答案

我就是这样实现的.不确定是否有更清洁的方法.

This is how I implemented it. Not sure if there's a cleaner way.

train_dataset = torchvision.datasets.MNIST(root='./data/', train=True,
                                 transform=torchvision.transforms.ToTensor(),
                                 target_transform=torchvision.transforms.Compose([
                                 lambda x:torch.LongTensor([x]), # or just torch.tensor
                                 lambda x:F.one_hot(x,10)]),
                                 download=True)

  • 它需要是一个索引张量?即 int64

    不能使用 torchvision.ToTensor 因为它不是图像

    Can't use torchvision.ToTensor because it's not an image

    还有 torch.LongTensortorch.tensorint 输入的行为不同

    Also torch.LongTensor and torch.tensor behave differently with int input

    需要提供类的数量

    这篇关于如何将pytorch中的标签转换为onehot的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆