首页 > 解决方案 > 如何有效地计算pytorch中两组不同大小的3D张量的距离矩阵?

问题描述

我有形状为 BxCxHxW 的张量 X 和形状为 NxCxHxW 的 Y。B 是批量大小,C 是通道,H 是高度,W 是宽度,对于任何批次,N 都是恒定的。基本上我想要一组 B 图像和另一组 N 图像之间的距离的 BxN 距离矩阵。

我尝试使用 torch.cdist 通过将 X 重塑为 1xBx(C*H*W) 并将 Y 重塑为 1xNx(C*H*W) 通过解压缩尺寸并展平最后 3 个通道来使用,但我进行了健全性检查并得到了错误的答案用这种方法。

我想要L2距离。

标签: pythonpytorchdistance

解决方案


根据 的文档页面torch.cdist,两个输入和输出的形状如下:x1: (B, P, M)x2:(B, R, M)output: (B, P, R)

为了匹配您的情况:B=1, P=B, R=N, while M=C*H*W展平)。正如你刚才解释的那样。

所以你基本上是为了:

>>> torch.cdist(X[None].flatten(2), Y[None].flatten(2))

如果你不服气,你可以用下面的方法检查:

>>> dist = []
>>> for x in X:
...    for y in Y:
...       dist.append((x-y).norm())

并将torch.cdist结果与torch.tensor(dist).reshape(len(X), len(Y))​​ .


推荐阅读