将张量转换为一个热编码的索引张量 [英] converting tensor to one hot encoded tensor of indices

查看:23
本文介绍了将张量转换为一个热编码的索引张量的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有形状 (1,1,128,128,128) 的标签张量,其中值的范围可能为 0,24.我想使用 nn.fucntional.one_hot 函数

I have my label tensor of shape (1,1,128,128,128) in which the values might range from 0,24. I want to convert this to one hot encoded tensor, using the nn.fucntional.one_hot function

n = 24
one_hot = torch.nn.functional.one_hot(indices, n)

但这需要一个指数张量,老实说,我不确定如何获得这些指数.我唯一的张量是上述形状的标签张量,它包含的值范围为 1-24,而不是索引

but this expects a tensor of indices, honestly, I am not sure how to get those. The only tensor I have is the label tensor of the shape described above and it contains values ranging from 1-24, not the indices

如何从张量中获取索引张量?提前致谢.

How can I get a tensor of indices from my tensor? Thanks in advance.

推荐答案

如果你得到的错误是这个:

If the error you are getting is this one:

Traceback (most recent call last):
    File "<stdin>", line 1, in <module>
RuntimeError: one_hot is only applicable to index tensor.

也许你只需要转换成int64:

import torch

# random Tensor with the shape you said
indices = torch.Tensor(1, 1, 128, 128, 128).random_(1, 24)
# indices.shape => torch.Size([1, 1, 128, 128, 128])
# indices.dtype => torch.float32

n = 24
one_hot = torch.nn.functional.one_hot(indices.to(torch.int64), n)
# one_hot.shape => torch.Size([1, 1, 128, 128, 128, 24])
# one_hot.dtype => torch.int64

您也可以使用 indices.long().

这篇关于将张量转换为一个热编码的索引张量的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆