首页 > 解决方案 > 如何在 PyTorch 中计算点集和线之间的成对距离?

问题描述

点集A是一个Nx3矩阵,从两个相同大小的B点集我们可以得到它们之间的线。现在我想计算从 in 中的每个点到中的每条线的距离。is和is ,那么这些线来自具有相应行的点,矩阵也是如此。基本方法计算如下:CMx3BCABCBMx3CMx3BCMx3

D = torch.zeros((N, M), dtype=torch.float32)
for i in range(N):
    p = A[i]  # 1x3
    for j in range(M):
        p1 = B[j] # 1x3
        p2 = C[j] # 1x3
        D[i,j] = torch.norm(torch.cross(p1 - p2, p - p1)) / torch.norm(p1 - p2) 

有没有更快的方法来完成这项工作?谢谢。

标签: pythonpytorchtorch

解决方案


您可以通过这样做来删除for循环(它应该以内存为代价加速,除非MN很小):

diff_B_C = B - C
diff_A_C = A[:, None] - C
norm_lines = torch.norm(diff_B_C, dim=-1)
cross_result = torch.cross(diff_B_C[None, :].expand(N, -1, -1), diff_A_C, dim=-1)
norm_cross = torch.norm(cross_result, dim=-1)
D = norm_cross / norm_lines

当然,你不需要一步一步来。我只是想清楚变量名。

注意:如果您不提供dimto torch.cross,它将使用第一个dim=3如果N=3(来自文档)会给出错误的结果:

如果没有给出 dim ,则默认为找到的第一个尺寸为 3 的尺寸。

如果你想知道,你可以在这里查看我为什么选择expand而不是repeat.


推荐阅读