多维张量的前 K 个指数 [英] Top K indices of a multi-dimensional tensor

查看:35
本文介绍了多维张量的前 K 个指数的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个二维张量,我想获得前 k 个值的索引.我知道 pytorch 的 topk 功能.pytorch 的 topk 函数的问题在于,它计算某个维度上的 topk 值.我想获得两个维度的 topk 值.

I have a 2D tensor and I want to get the indices of the top k values. I know about pytorch's topk function. The problem with pytorch's topk function is, it computes the topk values over some dimension. I want to get topk values over both dimensions.

例如下面的张量

a = torch.tensor([[4, 9, 7, 4, 0],
        [8, 1, 3, 1, 0],
        [9, 8, 4, 4, 8],
        [0, 9, 4, 7, 8],
        [8, 8, 0, 1, 4]])

pytorch 的 topk 函数会给我以下信息.

pytorch's topk function will give me the following.

values, indices = torch.topk(a, 3)

print(indices)
# tensor([[1, 2, 0],
#        [0, 2, 1],
#        [0, 1, 4],
#        [1, 4, 3],
#        [1, 0, 4]])

但我想得到以下内容

tensor([[0, 1],
        [2, 0],
        [3, 1]])

这是二维张量中 9 的索引.

This is the indices of 9 in the 2D tensor.

是否有任何方法可以使用 pytorch 实现此目的?

Is there any approach to achieve this using pytorch?

推荐答案

v, i = torch.topk(a.flatten(), 3)
print (np.array(np.unravel_index(i.numpy(), a.shape)).T)

输出:

[[3 1]
 [2 0]
 [0 1]]

  1. 展平并找到顶部 k
  2. 使用 unravel_index
  3. 将一维索引转换为二维索引
  1. Flatten and find top k
  2. Convert 1D indices to 2D using unravel_index

这篇关于多维张量的前 K 个指数的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆