首页 > 解决方案 > 如何在numpy中找到最近的邻居?

问题描述

有两个数组 u 和 v。

u.shape = (N,d) v.shape = (q,d)

我需要为每个 q 找到 u 中每个 d 的最接近值的索引。

例如:

u = [[5,3],
     [3,4],
     [3,2],
     [8,7]] , shape (4,2)
v = [[1,3],
     [2,4]] , shape (2,2)

我发现很多人说我们可以做到:

v = v.expand_dims(v,axis=1) # reshape to (2,1,2) for broadcast

result = np.argmin(abs(v-u),axis=1) # (u-v).shape = (2,4,2)

当然它找到了最接近的值的索引。但!当有两个最接近的值时,我需要采用“第二个”的索引。

在这种情况下:

v-u = [[[-4,  0],
        [-2, -1],
        [-2,  1],
        [-7, -4]],

       [[-3,  1],
        [-1,  0],
        [-1,  2],
        [-6, -3]]])

沿着axis=1,(uv)[0,:,0]有两个-2,(uv)[1,:,0]有两个-1如果我们直接用:

result = np.argmin(abs(v-u),axis=1)

结果将是:

array([[1, 0],
       [1, 1]], dtype=int64)

它返回对应于第一次出现的索引,但我需要第二个,即

array([[2, 0],
       [2, 1]], dtype=int64)

任何人都可以帮忙吗?谢谢!

标签: pythonnumpy

解决方案


如果最多可以有 2 个最小值,则可以检索最后一个最小值的索引。

去做吧:

  • 沿轴 1反向abs(vu) ,
  • 计算argmin,得到一个“reversed_index”(实际上是反向数组中的索引),
  • 使用u.shape[0] - 1 - <reversed_index>公式映射回“原始”索引 (在 4 行的情况下,反向索引 == 3对应于 原始索引 == 0

整个代码是:

u.shape[0] - 1 - np.argmin(abs(v-u)[:,::-1,:],axis=1)

当最小值可能超过2 个时,另一种选择是为一输入数组编写一个专门的argmin版本,如果它们更多,则返回第二个最小值的索引:

def argmin2(arr):
    ind = arr.argpartition(1)[:2]
    return ind[0] if arr[ind[0]] < arr[ind[1]] else ind[1]

然后沿轴1将其应用于abs(vu)

np.apply_along_axis(argmin2, 1, abs(v-u))

推荐阅读