按 2d 张量中的值索引 pytorch 4d 张量 [英] Index pytorch 4d tensor by values in 2d tensor

查看:42
本文介绍了按 2d 张量中的值索引 pytorch 4d 张量的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有两个 pytorch 张量:

I have two pytorch tensors:

  • X 带形状 (A, B, C, D)
  • I 带形状 (A, B)
  • X with shape (A, B, C, D)
  • I with shape (A, B)

I 中的值是 [0, C) 范围内的整数.

Values in I are integers in range [0, C).

获得形状为 (A, B, D) 的张量 Y 的最有效方法是什么,例如:

What is the most efficient way to get tensor Y with shape (A, B, D), such that:

Y[i][j][k] = X[i][j][ I[i][j] ][k]

推荐答案

您可能想要使用 torch.gather 用于索引和 expand 调整 I 到需要的大小:

You probably want to use torch.gather for the indexing and expand to adjust I to the required size:

eI = I[..., None, None].expand(-1, -1, 1, X.size(3))  # make eI the same for the last dimension
Y = torch.gather(X, dim=2, index=eI).squeeze()

测试代码:

A = 3 
B = 4 
C = 5 
D = 7

X = torch.rand(A, B, C, D)
I = torch.randint(0, C, (A, B), dtype=torch.long)

eI = I[..., None, None].expand(-1, -1, 1, X.size(3))
Y = torch.gather(X, dim=2, index=eI).squeeze()

# manually gather
refY = torch.empty(A, B, D)
for i in range(A):
    for j in range(B):
        refY[i, j, :] = X[i, j, I[i,j], :]

(refY == Y).all()
# Out[]: tensor(1, dtype=torch.uint8)

这篇关于按 2d 张量中的值索引 pytorch 4d 张量的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆