首页 > 解决方案 > scipy 的 numba,特别是 cdist

问题描述

我正在尝试加快两个点云之间的比较,我有一些代码需要一个小时才能完成。我已经把它屠杀了,并试图实现 numba。该代码与 scipy cdist 函数除外。这是我第一次使用 numba 测试,我哪里出错了?

from numba import jit

@jit(nopython=True)
def near_dist_top(T, B):
    xi = [i[0] for i in T]
    yi = [i[1] for i in T]
    zi = [i[2] for i in T]
    XB = B

    insert_params = []

    for i in range(len(T)):
        XA = [T[i]]
        disti = cdist(XA, XB, metric='euclidean').min()
        insert_params.append((xi[i], yi[i], zi[i], disti))
        # print("Top: " + str(i) + " of " + str(len(T)))
        print(i)
    return insert_params
    print(XB)

@@@ 编辑@@@

T 和 B 都是坐标列表

(580992.507, 4275268.8321, 192.4599), (580992.507, 4275268.8391, 192.4209), (580992.507, 4275268.8391, 192.4209)

嗯,numba 是否处理列表,它是否需要是一个 numpy 数组,cdist 是否会处理一个 numpy 数组...?

错误

numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Untyped global name 'cdist': cannot determine Numba type of <class 'function'>

File "scratch_28.py", line 132:
def near_dist_top(T, B):
    <source elided>
        XA = [T[i]]
        disti = cdist(XA, XB, metric='euclidean').min()
        ^

标签: python-3.xscipynumba

解决方案


推荐阅读