首页 > 解决方案 > 检查numpy数组的子维度中的元素是否在另一个numpy数组的子维度中的快速方法

问题描述

假设我们有两个大小为 (N,M) 的整数 numpy 数组 A 和 B。我想检查每个 i in N, is A[i,:]in B[i,:]

一个for循环的实现是:

for i in range(N):
    C[i] = np.isin(A[i,:],B[i,:])

然而,这对于大型阵列来说相当慢。有没有更快的方法来实现这个?(例如矢量化?)

谢谢!

标签: pythonnumpy

解决方案


这是一种基于每行偏移的矢量化方法,如在Vectorized searchsorted numpy's solution-

# https://stackoverflow.com/a/40588862/ @Divakar
def searchsorted2d(a,b):
    m,n = a.shape
    max_num = np.maximum(a.max() - a.min(), b.max() - b.min()) + 1
    r = max_num*np.arange(a.shape[0])[:,None]
    p = np.searchsorted( (a+r).ravel(), (b+r).ravel() ).reshape(m,-1)
    return p - n*(np.arange(m)[:,None])

def numpy_isin2D(A,B):
    sB = np.sort(B,axis=1)
    idx = searchsorted2d(sB,A)
    idx[idx==sB.shape[1]] = 0
    return np.take_along_axis(sB, idx, axis=1) == A

样品运行 -

In [351]: A
Out[351]: 
array([[5, 0, 3, 3],
       [7, 3, 5, 2],
       [4, 7, 6, 8],
       [8, 1, 6, 7],
       [7, 8, 1, 5]])

In [352]: B
Out[352]: 
array([[8, 4, 3, 0, 3, 5],
       [0, 2, 3, 8, 1, 3],
       [3, 3, 7, 0, 1, 0],
       [4, 7, 3, 2, 7, 2],
       [0, 0, 4, 5, 5, 6]])

In [353]: numpy_isin2D(A,B)
Out[353]: 
array([[ True,  True,  True,  True],
       [False,  True, False,  True],
       [False,  True, False, False],
       [False, False, False,  True],
       [False, False, False,  True]])

推荐阅读