python - 如何有效地计算pytorch中两组不同大小的3D张量的距离矩阵?
问题描述
我有形状为 BxCxHxW 的张量 X 和形状为 NxCxHxW 的 Y。B 是批量大小,C 是通道,H 是高度,W 是宽度,对于任何批次,N 都是恒定的。基本上我想要一组 B 图像和另一组 N 图像之间的距离的 BxN 距离矩阵。
我尝试使用 torch.cdist 通过将 X 重塑为 1xBx(C*H*W) 并将 Y 重塑为 1xNx(C*H*W) 通过解压缩尺寸并展平最后 3 个通道来使用,但我进行了健全性检查并得到了错误的答案用这种方法。
我想要L2距离。
解决方案
根据 的文档页面torch.cdist
,两个输入和输出的形状如下:x1
: (B, P, M)
、x2
:(B, R, M)
和output
: (B, P, R)
。
为了匹配您的情况:B=1
, P=B
, R=N
, while M=C*H*W
(即展平)。正如你刚才解释的那样。
所以你基本上是为了:
>>> torch.cdist(X[None].flatten(2), Y[None].flatten(2))
如果你不服气,你可以用下面的方法检查:
>>> dist = []
>>> for x in X:
... for y in Y:
... dist.append((x-y).norm())
并将torch.cdist
结果与torch.tensor(dist).reshape(len(X), len(Y))
.
推荐阅读
- javascript - 如何处理覆盖固定元素的Android键盘?
- python - 如何通过在 python 中使用 selenium webdriver 从 html 表中的特定行获取数据
- python - 使用 lambda 从另一个函数传递参数
- html - 为什么div有内容时位置会发生变化?
- oracle - 为 Oracle 数据库禁用 VS 2019 SSRS 数据源构建按钮
- javascript - React 路由器(多页面应用)
- symfony - 如何在树枝中使用 svg 精灵和符号?
- c# - 如何在c#中将带有前导0的字符串数字转换为双精度数
- qml - QML:如何注册用户自定义快捷方式
- javascript - 如何在 JavaScript 中附加函数