python - 如何在 PyTorch 中计算点集和线之间的成对距离?
问题描述
点集A
是一个Nx3
矩阵,从两个相同大小的B
点集我们可以得到它们之间的线。现在我想计算从 in 中的每个点到中的每条线的距离。is和is ,那么这些线来自具有相应行的点,矩阵也是如此。基本方法计算如下:C
Mx3
BC
A
BC
B
Mx3
C
Mx3
BC
Mx3
D = torch.zeros((N, M), dtype=torch.float32)
for i in range(N):
p = A[i] # 1x3
for j in range(M):
p1 = B[j] # 1x3
p2 = C[j] # 1x3
D[i,j] = torch.norm(torch.cross(p1 - p2, p - p1)) / torch.norm(p1 - p2)
有没有更快的方法来完成这项工作?谢谢。
解决方案
您可以通过这样做来删除for
循环(它应该以内存为代价加速,除非M
和N
很小):
diff_B_C = B - C
diff_A_C = A[:, None] - C
norm_lines = torch.norm(diff_B_C, dim=-1)
cross_result = torch.cross(diff_B_C[None, :].expand(N, -1, -1), diff_A_C, dim=-1)
norm_cross = torch.norm(cross_result, dim=-1)
D = norm_cross / norm_lines
当然,你不需要一步一步来。我只是想清楚变量名。
注意:如果您不提供dim
to torch.cross
,它将使用第一个dim=3
如果N=3
(来自文档)会给出错误的结果:
如果没有给出 dim ,则默认为找到的第一个尺寸为 3 的尺寸。
如果你想知道,你可以在这里查看我为什么选择expand
而不是repeat
.
推荐阅读
- python - 如何使用 UNISWAP API 获取代币价格
- openrefine - 在 OpenRefine 的大文本中删除某些行中字符之前的所有内容
- three.js - 3D 模型中的三框工具提示
- common-lisp - 如何将动态生成的 Lisp 系统保存在外部文件中?
- java - 我是否需要端口转发才能在同一网络上的两台机器之间进行通信?
- amazon-rds - 设置 Amazon RDS 数据库的推荐安全设置是什么
- azure-storage - Azure Blob 存储同步:获取接触文件的列表
- javascript - 猫鼬总是保存错误的日期
- flutter - 下面的构造函数在 dart 中是如何工作的,我已经提取了小部件,flutter 为我的小部件提供了下面的构造函数
- java - Javadoc {@value} 不适用于常量