如何在pytorch中动态索引张量? [英] How to dynamically index the tensor in pytorch?

查看:67
本文介绍了如何在pytorch中动态索引张量?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

例如,我得到了张量:

tensor = torch.rand(12, 512, 768)

我得到了一个索引列表,说是:

And I got an index list, say it is:

[0,2,3,400,5,32,7,8,321,107,100,511]

在给定索引列表的情况下,我希望从维度2的512个元素中选择1个元素.然后,张量的大小将变为(12,1,768).

I wish to select 1 element out of 512 elements on dimension 2 given the index list. And then the tensor's size would become (12, 1, 768).

有办法吗?

推荐答案

还有一种方法是仅使用PyTorch并使用 indexing torch.split :

There is also a way just using PyTorch and avoiding the loop using indexing and torch.split:

tensor = torch.rand(12, 512, 768)

# create tensor with idx
idx_list = [0,2,3,400,5,32,7,8,321,107,100,511]
# convert list to tensor
idx_tensor = torch.tensor(idx_list) 

# indexing and splitting
list_of_tensors = tensor[:, idx_tensor, :].split(1, dim=1)

当调用 tensor [:, idx_tensor,:] 时,您将得到一个形状为张量的张量:
(12,len_of_idx_list,768).
第二维取决于索引的数量.

When you call tensor[:, idx_tensor, :] you will get a tensor of shape:
(12, len_of_idx_list, 768).
Where the second dimension depends on your number of indices.

使用 torch.split 将该张量分成形状为(12,1,768).

Using torch.split this tensor is split into a list of tensors of shape: (12, 1, 768).

所以最终 list_of_tensors 包含以下形状的张量:

So finally list_of_tensors contains tensors of the shape:

[torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768]),
 torch.Size([12, 1, 768])]

这篇关于如何在pytorch中动态索引张量?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆