首页 > 解决方案 > Numpy where 与数组比较

问题描述

我有一个名为的数组Y,其中包含类标签。我想找到与列表实验室指定的多个值匹配的 Y 的所有索引。

在这种情况下:

Y = np.array([1,2,3,1,2,3,1,2,3,1,2,3])
labs = [2,3]

我怎么能做这样的事情np.where(Y == labs)返回

array([1,2,4,5,7,8,10,11])

我知道一种可能性是遍历列表实验室并进行元素比较。但我正在寻找一种更加基于 pythonic/numpy 的解决方案,它可以避免循环。

标签: numpyindexingpython-3.7

解决方案


您可以在此处的 [ np.where(..)numpy-doc]上使用np.isin(..)[numpy-doc]

>>> np.where(np.isin(Y, L))[0]
array([ 1,  2,  4,  5,  7,  8, 10, 11])

将为我们.isin(Y, L)提供一个数组,True其中False的项目Y与 中的元素匹配L

>>> np.isin(Y, labs)
array([False,  True,  True, False,  True,  True, False,  True,  True,
       False,  True,  True])

并且np.where(..)我们将Trues 映射到相应的索引。

正如@hpaulj所说,对于 small Ls,我们可以这样写:

np.any([Y == li for li in labs],axis=0)

在这里,对于 中的每个元素labs,我们将检查是否Y是该元素,并使用np.any(..)在它们之间创建“逻辑 OR 链”以将其折叠为布尔值。


推荐阅读