Slow performance of numpy linalg.norm when used in a for loop to compute distances

Problem description

My apologies in advance if this question seems basic.

Given:

import numpy as np
import time

r,q = int(9), int(5)      # sample example to print expected results
#r,q = int(4e4), int(3e4) # VERY time consuming

query_pose = np.random.randn(q,3)
cnn_poses  = [ np.random.randn(3) for ri in range(r) ]
DIST_TH    = 2.0 # meters

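Goal:

I want to compute the TopN matches and the corresponding RankN matches, where a match is a pose whose Euclidean distance (np.linalg.norm) falls below a certain threshold (DIST_TH = 2.0 # meters). I use the results for further computation and plotting.

Right now, I have a tedious and time-consuming for loop, as follows: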

topNs = np.zeros( r )
rankN = np.zeros( r )
ranked_cnn_poses = np.full( ( r, q, 3 ), np.nan )

bt = time.time()
for qi in range(q):
    qp  = query_pose[qi]
    dist_list = list( np.linalg.norm( np.array(qp) - np.array(cnn_poses), axis=1 ) )
    
    for di, dv in enumerate(dist_list):
        if any( dv <= DIST_TH for di, dv in enumerate( dist_list[ :(di+1) ] ) ):    
            topNs[di] += 1
    for di, dv in enumerate(dist_list):
        if dv <= DIST_TH:
            rankN[di] += 1
            ranked_cnn_poses[di, qi, :] = np.array(cnn_poses[di])           
            break  # only one match
    
et = time.time()
print(f">> Took {(et-bt):.3f} s")

With r,q = int(4e4), int(3e4), it takes a very long time to run:

>> Took 1849.151 s

Expected results:

recall_N=topNs/q
print('#'*100)
print(f"Top: {topNs.shape}: {topNs}")
print('#'*100)
print(f"rank: {rankN.shape}: {rankN}")
print('#'*100)
print(f"Recall: {recall_N.shape}: {recall_N}")
print('#'*100)

The output obviously varies because of np.random.randn, but it looks like this:

q: 0
>> d: [2.9, 3.2, 2.0, 1.7, 2.0, 2.5, 2.0, 0.5, 3.2]
---------------------------------------------------
q: 1
>> d: [0.9, 1.9, 0.8, 1.1, 1.2, 0.8, 1.3, 1.7, 1.7]
---------------------------------------------------    
q: 2
>> d: [3.7, 3.1, 2.7, 2.2, 1.9, 3.3, 3.4, 1.7, 4.4]
---------------------------------------------------    
q: 3
>> d: [1.7, 1.5, 0.6, 1.5, 0.9, 1.3, 1.7, 1.4, 2.9]
---------------------------------------------------   
q: 4
>> d: [2.6, 4.1, 2.8, 2.5, 3.3, 2.6, 1.8, 2.7, 1.1]
---------------------------------------------------

Top: (9,): [2. 2. 2. 3. 4. 4. 5. 5. 5.]
######################################################
rank: (9,): [2. 0. 0. 1. 1. 0. 1. 0. 0.]
######################################################
Recall: (9,): [0.4 0.4 0.4 0.6 0.8 0.8 1.  1.  1. ]
######################################################

Question:

Is there a more time-efficient way to solve this problem?

Cheers,

Tags: python, numpy-ndarray, euclidean-distance

Solution
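One possible direction, offered as a minimal sketch rather than a benchmarked answer: the loop in the question only ever needs, for each query, the index of the first pose within DIST_TH. Once those indices are known, rankN is a histogram of them and topNs is its running sum, so the two inner Python loops can be dropped. The function name first_match_stats and the chunk parameter below are made up for illustration; chunking is there only to keep the intermediate (chunk, r, 3) array at a manageable size.

```python
import numpy as np

def first_match_stats(query_pose, cnn_poses, dist_th, chunk=256):
    """Hypothetical helper: return (topNs, rankN, first_idx).

    first_idx[qi] is the index of the first cnn pose whose distance to
    query qi is <= dist_th, or -1 if no pose is that close.
    """
    queries = np.asarray(query_pose, dtype=float)   # shape (q, 3)
    cnn = np.asarray(cnn_poses, dtype=float)        # shape (r, 3)
    q, r = len(queries), len(cnn)
    first_idx = np.full(q, -1, dtype=np.intp)

    # Process the queries in blocks so the (c, r, 3) temporary stays small.
    for start in range(0, q, chunk):
        block = queries[start:start + chunk]                          # (c, 3)
        dist = np.linalg.norm(block[:, None, :] - cnn[None, :, :], axis=2)
        mask = dist <= dist_th                                        # (c, r)
        hit = mask.any(axis=1)
        # argmax on a boolean row returns the position of the first True.
        first_idx[start:start + chunk] = np.where(hit, mask.argmax(axis=1), -1)

    matched = first_idx[first_idx >= 0]
    rankN = np.bincount(matched, minlength=r).astype(float)
    topNs = np.cumsum(rankN)   # a query counts for every index >= its first match
    return topNs, rankN, first_idx


# Small example using the same shapes as in the question.
r, q = 9, 5
query_pose = np.random.randn(q, 3)
cnn_poses = [np.random.randn(3) for _ in range(r)]
DIST_TH = 2.0  # meters

topNs, rankN, first_idx = first_match_stats(query_pose, cnn_poses, DIST_TH)
recall_N = topNs / q

# Fill ranked_cnn_poses the same way the original loop does.
cnn = np.asarray(cnn_poses)
qi = np.flatnonzero(first_idx >= 0)
ranked_cnn_poses = np.full((r, q, 3), np.nan)
ranked_cnn_poses[first_idx[qi], qi] = cnn[first_idx[qi]]

print(f"Top: {topNs.shape}: {topNs}")
print(f"rank: {rankN.shape}: {rankN}")
print(f"Recall: {recall_N.shape}: {recall_N}")
```

The sample output in the question is consistent with this relationship: the running sum of rank [2 0 0 1 1 0 1 0 0] is [2 2 2 3 4 4 5 5 5], which is the Top row. Note that at r,q = int(4e4), int(3e4) the ranked_cnn_poses array alone holds 3.6e9 float64 values (about 29 GB), so it may be more practical to keep only first_idx and look poses up on demand.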

