python - numpy linalg.norm 在 for 循环中使用时获取距离的性能缓慢
问题描述
如果这个问题看起来很基本,我提前致以诚挚的歉意。
鉴于:
import numpy as np
import time
r,q = int(9), int(5) # sample example to print expected results
#r,q = int(4e4), int(3e4) # VERY time consuming
query_pose = np.random.randn(q,3)
cnn_poses = [ np.random.randn(3) for ri in range(r) ]
DIST_TH = 2.0 # meters
目标:
我想计算TopN
匹配和RankN
相应的匹配,因为它们的欧几里德距离(np.linalg.norm
)低于某个阈值(DIST_TH = 2.0 # meters
)。我将结果用于进一步计算和绘图。
现在,我有一个乏味且耗时for loop
的操作如下:
topNs = np.zeros( r )
rankN = np.zeros( r )
ranked_cnn_poses = np.full( ( r, q, 3 ), np.nan )
bt = time.time()
for qi in range(q):
qp = query_pose[qi]
dist_list = list( np.linalg.norm( np.array(qp) - np.array(cnn_poses), axis=1 ) )
for di, dv in enumerate(dist_list):
if any( dv <= DIST_TH for di, dv in enumerate( dist_list[ :(di+1) ] ) ):
topNs[di] += 1
for di, dv in enumerate(dist_list):
if dv <= DIST_TH:
rankN[di] += 1
ranked_cnn_poses[di, qi, :] = np.array(cnn_poses[di])
break # only one match
et = time.time()
print(f">> Took {(et-bt):.3f} s")
如果使用以下命令运行需要很长时间r,q = int(4e4), int(3e4)
:
>> Took 1849.151 s
预期成绩:
recall_N=topNs/q
print('#'*100)
print(f"Top: {topNs.shape}: {topNs}")
print('#'*100)
print(f"rank: {rankN.shape}: {rankN}")
print('#'*100)
print(f"Recall: {recall_N.shape}: {recall_N}")
print('#'*100)
它明显因 而变化np.random.randn
,但它看起来像:
q: 0
>> d: [2.9, 3.2, 2.0, 1.7, 2.0, 2.5, 2.0, 0.5, 3.2]
---------------------------------------------------
q: 1
>> d: [0.9, 1.9, 0.8, 1.1, 1.2, 0.8, 1.3, 1.7, 1.7]
---------------------------------------------------
q: 2
>> d: [3.7, 3.1, 2.7, 2.2, 1.9, 3.3, 3.4, 1.7, 4.4]
---------------------------------------------------
q: 3
>> d: [1.7, 1.5, 0.6, 1.5, 0.9, 1.3, 1.7, 1.4, 2.9]
---------------------------------------------------
q: 4
>> d: [2.6, 4.1, 2.8, 2.5, 3.3, 2.6, 1.8, 2.7, 1.1]
---------------------------------------------------
Top: (9,): [2. 2. 2. 3. 4. 4. 5. 5. 5.]
######################################################
rank: (9,): [2. 0. 0. 1. 1. 0. 1. 0. 0.]
######################################################
Recall: (9,): [0.4 0.4 0.4 0.6 0.8 0.8 1. 1. 1. ]
######################################################
问题:
有没有其他时间有效的方法来解决这个问题?
干杯,
解决方案
推荐阅读
- json - 指定扩展的私钥已存在。重复使用该密钥或先将其删除
- javascript - 使用 vue-router 仅在特定页面上使用动画
- javascript - 如何使用 AsyncStorage 存储和检索多个数据
- laravel - 提交注册表单时出现服务器错误 500
- java - 为什么 session.getAttribute 返回 null?
- getstream-io - 关注汇总提要
- react-native - 远程调试 JS
- docker - 为什么我不能将本地 Docker 映像推送到 Docker Hub 存储库?
- python - 从 Bokeh 中的数据表中单击选择的图表
- go - CI/CD 构建因 go -ldflags 失败