寻找两个pytorch张量的不交集 [英] Finding non-intersection of two pytorch tensors
本文介绍了寻找两个pytorch张量的不交集的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
提前感谢大家的帮助!我在 PyTorch 中尝试做的事情类似于 numpy 的 setdiff1d
.例如给定以下两个张量:
Thanks everyone in advance for your help! What I'm trying to do in PyTorch is something like numpy's setdiff1d
. For example given the below two tensors:
t1 = torch.tensor([1, 9, 12, 5, 24]).to('cuda:0')
t2 = torch.tensor([1, 24]).to('cuda:0')
预期的输出应该是(已排序或未排序):
The expected output should be (sorted or unsorted):
torch.tensor([9, 12, 5])
理想情况下,操作是在 GPU 上完成的,GPU 和 CPU 之间没有来回.非常感谢!
Ideally the operations are done on GPU and no back and forth between GPU and CPU. Much appreciated!
推荐答案
如果您不想离开 cuda,解决方法可能是:
if you don't want to leave cuda, a workaround could be:
t1 = torch.tensor([1, 9, 12, 5, 24], device = 'cuda')
t2 = torch.tensor([1, 24], device = 'cuda')
indices = torch.ones_like(t1, dtype = torch.uint8, device = 'cuda')
for elem in t2:
indices = indices & (t1 != elem)
intersection = t1[indices]
这篇关于寻找两个pytorch张量的不交集的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!
查看全文