pytorch 等效 tf.gather [英] pytorch equivalent tf.gather

查看:22
本文介绍了pytorch 等效 tf.gather的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我在将一些代码从 tensorflow 移植到 pytorch 时遇到了一些问题.

I'm having some trouble porting some code over from tensorflow to pytorch.

所以我有一个尺寸为 10x30 的矩阵,代表 10 个示例,每个示例具有 30 个特征.然后我有另一个维度为 10x5 的矩阵,其中包含第一个矩阵中每个示例的 5 个最接近示例的索引.我想使用第二个矩阵中包含的索引收集"第一个矩阵中每个示例的 5 个最接近的示例,留下形状为 10x5x30 的 3d 张量.

So I have a matrix with dimensions 10x30 representing 10 examples each with 30 features. Then I have another matrix with dimensions 10x5 containing indices of the the 5 closest examples for each examples in the first matrix. I want to 'gather' using the indices contained in the second matrix the 5 closet examples for each example in the first matrix leaving me with a 3d tensor of shape 10x5x30.

在 tensorflow 中,这是通过 tf.gather(matrix1, matrix2) 完成的.有谁知道我如何在 pytorch 中做到这一点?

In tensorflow this is done with tf.gather(matrix1, matrix2). Does anyone know how i could do this in pytorch?

推荐答案

这个怎么样?

matrix1 = torch.randn(10, 30)
matrix2 = torch.randint(high=10, size=(10, 5))
gathered = matrix1[matrix2]

它使用了用整数数组进行索引的技巧.

It uses the trick of indexing with an array of integers.

这篇关于pytorch 等效 tf.gather的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆