pytorch:如何在二维张量的每一行中找到第一个非零元素的索引? [英] Pytorch: How can I find indices of first nonzero element in each row of a 2D tensor?

查看:341
本文介绍了pytorch:如何在二维张量的每一行中找到第一个非零元素的索引?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个二维张量,每行中都有一些非零元素,像这样:

I have a 2D tensor with some nonzero element in each row like this:

import torch
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                    [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)

我想要一个张量,在每行中包含第一个非零元素的索引:

I want a tensor containing the index of first nonzero element in each row:

indices = tensor([2],
                 [3])

如何在Pytorch中计算它?

How can I calculate it in Pytorch?

推荐答案

我可以为我的问题找到一个棘手的答案:

I could find a tricky answer for my question:

  tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
                     [0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
  idx = reversed(torch.Tensor(range(1,8)))
  print(idx)

  tmp2= torch.einsum("ab,b->ab", (tmp, idx))

  print(tmp2)

  indices = torch.argmax(tmp2, 1, keepdim=True)
  print(indeces)

结果是:

tensor([7., 6., 5., 4., 3., 2., 1.])
tensor([[0., 0., 5., 0., 3., 0., 0.],
       [0., 0., 0., 4., 3., 0., 0.]])
tensor([[2],
        [3]])

这篇关于pytorch:如何在二维张量的每一行中找到第一个非零元素的索引?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆