首页 > 解决方案 > 检查 numpy 数组是否在另一个 numpy 数组中并创建一个掩码

问题描述

给定 numpy 数组

a = [[[0 0] [1 0] [2 0]]
     [[0 1] [1 1] [2 1]]
     [[0 2] [1 2] [2 2]]]

和列表 b

b = [[1, 0], [2, 0]]

我怎样才能得到面具c

c = [[False True True]
     [False False False]
     [False False False]]

标签: pythonnumpy

解决方案


您可以使用 numpy 广播功能将 b 中的每个数字对与 b 上的所有数字对进行比较,如下所示

## np.newaxis add a new dimension at that position. missing dimension (i.e 
## dimension with size 1) will repeat to match size of corresponding dimension

a = np.asarray([[[0, 0], [1, 0], [2, 0]],
     [[0, 1], [1, 1], [2, 1]],
     [[0, 2], [1, 2], [2, 2]]])[:,:,np.newaxis,:]

b = np.array([[1, 0], [2, 0]])[np.newaxis,:,:]

(a == b).all(axis=3).any(axis=2)

结果

array([[False,  True,  True],
       [False, False, False],
       [False, False, False]])

推荐阅读