deep-learning - 如何在 PyTorch 中有效地计算批量成对距离
问题描述
我有形状的张量 X 和形状的BxNxD
Y BxNxD
。
我想计算批次中每个元素的成对距离,即我是一个BxMxN
张量。
我该怎么做呢?
这里有一些关于这个话题的讨论:https ://github.com/pytorch/pytorch/issues/9406 ,但我不明白,因为有很多实现细节,而没有突出显示实际的解决方案。
一种天真的方法是使用此处讨论的非批量成对距离的答案:https ://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065 ,即
import torch
import numpy as np
B = 32
N = 128
M = 256
D = 3
X = torch.from_numpy(np.random.normal(size=(B, N, D)))
Y = torch.from_numpy(np.random.normal(size=(B, M, D)))
def pairwise_distances(x, y=None):
x_norm = (x**2).sum(1).view(-1, 1)
if y is not None:
y_t = torch.transpose(y, 0, 1)
y_norm = (y**2).sum(1).view(1, -1)
else:
y_t = torch.transpose(x, 0, 1)
y_norm = x_norm.view(1, -1)
dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
return torch.clamp(dist, 0.0, np.inf)
out = []
for b in range(B):
out.append(pairwise_distances(X[b], Y[b]))
print(torch.stack(out).shape)
我怎样才能在不循环 B 的情况下做到这一点?谢谢
解决方案
我有一个类似的问题,并花了一些时间来找到最简单和最快的解决方案。现在您可以使用 PyTorch cdist计算批量距离,这将为您提供BxMxN
张量:
torch.cdist(Y, X)
此外,如果您只想计算两个矩阵的每对行之间的距离,它也很有效。
推荐阅读
- typescript - 在打字稿中打印类型对象
- javascript - 如何在 Ant Table List 上正确设置切换开关?
- list - 在 Python 3 中混合 n 个大小为 m 的列表以创建一个矩阵(m*m*m*..n 次)
- c# - 我无法从 C# 向 GMAIL 发送电子邮件
- python - Pandas/xarray - 基于另一个数据帧动态水平移动值
- sql - MS ACCESS 报告中的求和函数太复杂
- android - 我有空字段,但改造不发送到后端
- django-rest-framework - DRF (ListAPIView) - 如何在查询集中应用搜索过滤器
- spring-boot - 就绪探针发出停止服务错误
- html - 如何在不损失质量的情况下使图像变小?