Efficient and Pythonic way to compute the Euclidean distance to the nearest nonzero element for every nonzero element in a NumPy 2D array

Question

I have a two-dimensional NumPy array of shape M × N in which many values are 0 and the rest are nonzero.
Here is an example of such a matrix:

A = np.array([[0, 0, 0, 1, 0, 2, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 3, 0], [1, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 6, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0]])

Printed, it looks like this:

A = [[0 0 0 1 0 2 0 0]
     [0 0 0 0 0 0 0 0]
     [0 0 0 0 1 0 0 0]
     [0 1 0 0 0 0 0 0]
     [0 0 0 1 0 0 3 0]
     [1 0 0 0 0 0 0 0]
     [0 0 0 0 6 0 0 0]
     [0 0 1 0 0 0 0 0]]

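My task is to find, for each nonzero element of the 2D array (A) (here 1, 2, 1, 1, 1, 3, 1, 6 and 1, in row-major order), the Euclidean distance to the nearest nonzero element other than itself, and then build a list (L) of the computed distances.
The following invariant must hold: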

if np.count_nonzero(A) < 2:
    assert len(L) == 0
else:
    assert np.count_nonzero(A) == len(L)

For the array A above, the list L (rounded to two decimals) is:

L = [2, 2, 2.24, 2.24, 2.24, 2.83, 2.24, 2.24, 2.24]
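
To make one of these entries concrete, here is a minimal sketch of the distance computation for a single pair. The names `p`, `q` and `d` are illustrative only; the coordinates are (row, column) positions read off the example array, where (0, 5) is one of several equally near neighbours of (2, 4).

```python
import math

# (row, column) positions of two nonzero cells in the example array A
p = (2, 4)  # the 1 in the third row
q = (0, 5)  # the 2 in the first row

# Euclidean distance: sqrt((2 - 0)**2 + (4 - 5)**2) = sqrt(5)
d = math.hypot(p[0] - q[0], p[1] - q[1])
print(round(d, 2))  # 2.24
```

This is the third entry of L; the other entries follow the same formula applied to each cell's nearest nonzero neighbour.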

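I wrote the code below to solve this, and I believe it works correctly, but it has two problems: it is a naive, brute-force, non-vectorized solution with O(M² × N²) time complexity, and it is neither clear nor concise; in other words, it is not Pythonic.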

import math

import numpy as np
import scipy.spatial.distance

def get_distance_list(A):
    L = []
    for (m, n), a_mn in np.ndenumerate(A):
        # skip this element if its value is 0
        if a_mn == 0:
            continue
        d_min = math.inf
        for (k, l), a_kl in np.ndenumerate(A):
            # skip this element if its value is 0 or if it's me
            if a_kl == 0 or (m, n) == (k, l):
                continue
            d = scipy.spatial.distance.euclidean((m, n), (k, l))
            d_min = min(d_min, d)
        # in case there are less than two nonzero values in the matrix,
        # the returned list must be empty, so only add the distance
        # if it's different than the default value of +inf
        if d_min != math.inf:
            L.append(d_min)
    return L

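Do you know of a built-in function (perhaps in NumPy, SciPy, scikit-learn, etc.) that could replace the one I wrote, or of a faster, vectorized, and more Pythonic way to solve this problem?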

Tags: python, arrays, numpy, matrix

Answer


I think scipy.spatial.KDTree is a good fit for this.

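import numpy as np
from scipy.spatial import KDTree

# (row, column) coordinates of every nonzero element of A
nonzeros = np.transpose(np.nonzero(A))
t = KDTree(nonzeros)
# k=2: the first hit is the point itself (distance 0), the second is its nearest neighbour
dists, nns = t.query(nonzeros, 2)

for (i, j), d in zip(nns, dists[:,1]):
    print(nonzeros[i], "is closest to", nonzeros[j], "with distance", d)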

Output:

[0 3] is closest to [0 5] with distance 2.0
[0 5] is closest to [0 3] with distance 2.0
[2 4] is closest to [0 5] with distance 2.23606797749979
[3 1] is closest to [4 3] with distance 2.23606797749979
[4 3] is closest to [3 1] with distance 2.23606797749979
[4 6] is closest to [2 4] with distance 2.8284271247461903
[5 0] is closest to [3 1] with distance 2.23606797749979
[6 4] is closest to [4 3] with distance 2.23606797749979
[7 2] is closest to [6 4] with distance 2.23606797749979
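
The same idea can be wrapped into a function that returns the list L the question asks for. The following is only a sketch: it reuses the question's function name `get_distance_list`, and the early return for fewer than two nonzero elements is an assumption added here so that the invariant from the question holds (with fewer than two points there is no second neighbour to report).

```python
import numpy as np
from scipy.spatial import KDTree


def get_distance_list(A):
    """Distance from each nonzero cell of A to its nearest other nonzero cell."""
    # (row, column) coordinates of the nonzero cells, in row-major order
    nonzeros = np.transpose(np.nonzero(A))
    # assumption: with fewer than two nonzero cells the list must be empty
    if len(nonzeros) < 2:
        return []
    # k=2: column 0 is the point itself (distance 0), column 1 its nearest neighbour
    dists, _ = KDTree(nonzeros).query(nonzeros, k=2)
    return dists[:, 1].tolist()


A = np.array([[0, 0, 0, 1, 0, 2, 0, 0],
              [0, 0, 0, 0, 0, 0, 0, 0],
              [0, 0, 0, 0, 1, 0, 0, 0],
              [0, 1, 0, 0, 0, 0, 0, 0],
              [0, 0, 0, 1, 0, 0, 3, 0],
              [1, 0, 0, 0, 0, 0, 0, 0],
              [0, 0, 0, 0, 6, 0, 0, 0],
              [0, 0, 1, 0, 0, 0, 0, 0]])

L = get_distance_list(A)
print(np.round(L, 2))
```

The distances come back in the same row-major order as np.nonzero(A), so L lines up with the order in which the nonzero elements appear in the array.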
