首页 > 解决方案 > 找到两个 pytorch 张量的不相交

问题描述

提前感谢大家的帮助!我在 PyTorch 中尝试做的事情类似于 numpy's setdiff1d。例如给定以下两个张量:

t1 = torch.tensor([1, 9, 12, 5, 24]).to('cuda:0')
t2 = torch.tensor([1, 24]).to('cuda:0')

预期的输出应该是(排序的或未排序的):

torch.tensor([9, 12, 5])

理想情况下,操作在 GPU 上完成,GPU 和 CPU 之间没有来回。非常感激!

标签: pythonnumpypytorch

解决方案


我遇到了同样的问题,但是在使用更大的数组时,建议的解决方案太慢了。以下简单的解决方案适用于 CPU 和 GPU,并且比其他建议的解决方案要快得多:

combined = torch.cat((t1, t2))
uniques, counts = combined.unique(return_counts=True)
difference = uniques[counts == 1]
intersection = uniques[counts > 1]

推荐阅读