如何找到张量对象中每一行的最大索引? [英] How to find the max index for each row in a tensor object?

查看:57
本文介绍了如何找到张量对象中每一行的最大索引?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

所以我正在创建一个 pytorch 模型,对于前向传递,我正在应用我的前向传递方法来获取包含每个类的预测分数的分数张量.这个张量的形状是 [100, 10].现在,我想通过将其与包含实际分数的 y 进行比较来获得准确度.该张量的形状为 [100].为了比较两者,我将使用 torch.mean(scores == y) 并计算有多少是相同的.

So I'm creating a pytorch model and for the forward pass, I'm applying my forward pass method to get the scores tensor which contains the prediction scores for each class. The shape of this tensor is [100, 10]. Now, I want to get the accuracy by comparing it to y which contains the actual scores. This tensor has the shape [100]. To compare the two I'll be using torch.mean(scores == y) and I'll count how many are the same.

问题是我需要转换分数张量,以便每一行只包含每一行中最高值的索引.例如,如果张量看起来像这样,

The problem is that I need to convert the scores tensor so that each row simply contains the index of the highest value in each row. For example if the tensor looked like this,

tensor(
    [[0.3232, -0.2321, 0.2332, -0.1231, 0.2435, 0.6728],

    [0.2323, -0.1231, -0.5321, -0.1452, 0.5435, 0.1722],

    [0.9823, -0.1321, -0.6433, 0.1231, 0.023, 0.0711]]
)

然后我希望它被转换成这样.

Then I'd want it to be converted so that it looks like this.

tensor([5, 4, 0])

我怎么能这样做?

推荐答案

使用 argmax 和所需的 dim(又名轴)

Use argmax with desired dim (a.k.a. axis)

a = tensor(
    [[0.3232, -0.2321, 0.2332, -0.1231, 0.2435, 0.6728],
    [0.2323, -0.1231, -0.5321, -0.1452, 0.5435, 0.1722],
    [0.9823, -0.1321, -0.6433, 0.1231, 0.023, 0.0711]]
)

a.argmax(1)
# tensor([ 5,  4,  0])

这篇关于如何找到张量对象中每一行的最大索引?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆