如何将pytorch中的标签转换为onehot [英] How to transform labels in pytorch to onehot
本文介绍了如何将pytorch中的标签转换为onehot的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
如何给 target_transform
一个函数将标签更改为 onehot 编码?
How to give target_transform
a function for changing the labels to onehot encoding?
例如torchvision中的MNIST数据集:
For example, the MNIST dataset in torchvision:
train_dataset = torchvision.datasets.MNIST(root='./mnist_data/',
train=True,
download=True,
transform=train_transform,
target_transform=<????>)
尝试了 F.onehot()
但没有奏效.
Tried F.onehot()
but it didn't work.
推荐答案
我就是这样实现的.不确定是否有更清洁的方法.
This is how I implemented it. Not sure if there's a cleaner way.
train_dataset = torchvision.datasets.MNIST(root='./data/', train=True,
transform=torchvision.transforms.ToTensor(),
target_transform=torchvision.transforms.Compose([
lambda x:torch.LongTensor([x]), # or just torch.tensor
lambda x:F.one_hot(x,10)]),
download=True)
它需要是一个
索引张量
?即 int64不能使用
torchvision.ToTensor
因为它不是图像Can't use
torchvision.ToTensor
because it's not an image还有
torch.LongTensor
和torch.tensor
与int
输入的行为不同Also
torch.LongTensor
andtorch.tensor
behave differently withint
input需要提供类的数量
这篇关于如何将pytorch中的标签转换为onehot的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!
查看全文