寻找两个pytorch张量的不交集 [英] Finding non-intersection of two pytorch tensors

查看:348
本文介绍了寻找两个pytorch张量的不交集的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

提前感谢大家的帮助!我在 PyTorch 中尝试做的事情类似于 numpy 的 setdiff1d.例如给定以下两个张量:

Thanks everyone in advance for your help! What I'm trying to do in PyTorch is something like numpy's setdiff1d. For example given the below two tensors:

t1 = torch.tensor([1, 9, 12, 5, 24]).to('cuda:0')
t2 = torch.tensor([1, 24]).to('cuda:0')

预期的输出应该是(已排序或未排序):

The expected output should be (sorted or unsorted):

torch.tensor([9, 12, 5])

理想情况下,操作是在 GPU 上完成的,GPU 和 CPU 之间没有来回.非常感谢!

Ideally the operations are done on GPU and no back and forth between GPU and CPU. Much appreciated!

推荐答案

如果您不想离开 cuda,解决方法可能是:

if you don't want to leave cuda, a workaround could be:

t1 = torch.tensor([1, 9, 12, 5, 24], device = 'cuda')
t2 = torch.tensor([1, 24], device = 'cuda')
indices = torch.ones_like(t1, dtype = torch.uint8, device = 'cuda')
for elem in t2:
    indices = indices & (t1 != elem)  
intersection = t1[indices]  

这篇关于寻找两个pytorch张量的不交集的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆