如何在pytorch中动态索引张量? [英] How to dynamically index the tensor in pytorch?
问题描述
例如,我得到了张量:
tensor = torch.rand(12, 512, 768)
我得到了一个索引列表,说是:
And I got an index list, say it is:
[0,2,3,400,5,32,7,8,321,107,100,511]
在给定索引列表的情况下,我希望从维度2的512个元素中选择1个元素.然后,张量的大小将变为(12,1,768)
.
I wish to select 1 element out of 512 elements on dimension 2 given the index list. And then the tensor's size would become (12, 1, 768)
.
有办法吗?
推荐答案
还有一种方法是仅使用PyTorch并使用 indexing 和 torch.split
:
There is also a way just using PyTorch and avoiding the loop using indexing and torch.split
:
tensor = torch.rand(12, 512, 768)
# create tensor with idx
idx_list = [0,2,3,400,5,32,7,8,321,107,100,511]
# convert list to tensor
idx_tensor = torch.tensor(idx_list)
# indexing and splitting
list_of_tensors = tensor[:, idx_tensor, :].split(1, dim=1)
当调用 tensor [:, idx_tensor,:]
时,您将得到一个形状为张量的张量:
(12,len_of_idx_list,768)
.
第二维取决于索引的数量.
When you call tensor[:, idx_tensor, :]
you will get a tensor of shape:
(12, len_of_idx_list, 768)
.
Where the second dimension depends on your number of indices.
使用 torch.split
将该张量分成形状为(12,1,768)
.
Using torch.split
this tensor is split into a list of tensors of shape: (12, 1, 768)
.
所以最终 list_of_tensors
包含以下形状的张量:
So finally list_of_tensors
contains tensors of the shape:
[torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768]),
torch.Size([12, 1, 768])]
这篇关于如何在pytorch中动态索引张量?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!