pytorch:如何在二维张量的每一行中找到第一个非零元素的索引? [英] Pytorch: How can I find indices of first nonzero element in each row of a 2D tensor?
本文介绍了pytorch:如何在二维张量的每一行中找到第一个非零元素的索引?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
我有一个二维张量,每行中都有一些非零元素,像这样:
I have a 2D tensor with some nonzero element in each row like this:
import torch
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
我想要一个张量,在每行中包含第一个非零元素的索引:
I want a tensor containing the index of first nonzero element in each row:
indices = tensor([2],
[3])
如何在Pytorch中计算它?
How can I calculate it in Pytorch?
推荐答案
我可以为我的问题找到一个棘手的答案:
I could find a tricky answer for my question:
tmp = torch.tensor([[0, 0, 1, 0, 1, 0, 0],
[0, 0, 0, 1, 1, 0, 0]], dtype=torch.float)
idx = reversed(torch.Tensor(range(1,8)))
print(idx)
tmp2= torch.einsum("ab,b->ab", (tmp, idx))
print(tmp2)
indices = torch.argmax(tmp2, 1, keepdim=True)
print(indeces)
结果是:
tensor([7., 6., 5., 4., 3., 2., 1.])
tensor([[0., 0., 5., 0., 3., 0., 0.],
[0., 0., 0., 4., 3., 0., 0.]])
tensor([[2],
[3]])
这篇关于pytorch:如何在二维张量的每一行中找到第一个非零元素的索引?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!
查看全文