PyTorch 张量:基于旧张量和指数的新张量 [英] PyTorch tensors: new tensor based on old tensor and indices

查看:50
本文介绍了PyTorch 张量:基于旧张量和指数的新张量的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我是张量的新手,对这个问题很头疼:

I'm new to tensors and having a headache over this problem:

我有一个大小为 k 的索引张量,其值在 0 到 k-1 之间:

I have an index tensor of size k with values between 0 and k-1:

tensor([0,1,2,0])

和以下矩阵:

tensor([[[0, 9],
     [1, 8],
     [2, 3],
     [4, 9]]])

我想创建一个新的张量,其中包含索引中指定的行,按顺序排列.所以我想要:

I want to create a new tensor which contains the rows specified in index, in that order. So I want:

tensor([[[0, 9],
     [1, 8],
     [2, 3],
     [0, 9]]])

外部张量我或多或少会像这样执行此操作:

Outside tensors I'd do this operation more or less like this:

new_matrix = [matrix[i] for i in index]

如何在 PyTorch 中对张量执行类似的操作?

How do I do something similar in PyTorch on tensors?

推荐答案

您使用 花式索引:

from torch import tensor

index = tensor([0,1,2,0])
t = tensor([[[0, 9],
     [1, 8],
     [2, 3],
     [0, 9]]])

result = t[:, index, :]

得到

tensor([[[0, 9],
         [1, 8],
         [2, 3],
         [0, 9]]])

注意 t.shape == (1, 4, 2) 并且你想在 second 轴上索引;所以我们将它应用在第二个参数中,并通过 :s 即 [:, index, :] 保持其余的相同.

Note that t.shape == (1, 4, 2) and you want to index on the second axis; so we apply it in the second argument and keep the rest the same via :s i.e. [:, index, :].

这篇关于PyTorch 张量:基于旧张量和指数的新张量的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆