首页 > 解决方案 > 我怎么能在某些条件下获得 numpy 数组索引

问题描述

我遇到了这样的问题:假设我有这样的数组: a = np.array([[1,2,3,4,5,4,3,2,1],]) label = np.array([[1,0,1,0,0,1,1,0,1],]) 我需要获取a元素值为label1 的位置的索引,并且 的值a是导致为 1 的最大数量label

可能会混淆,在上面的例子中,label1的索引是:0、2、5、6、8,它们对应的值a是:1、3、4、3、1,其中4是最大的,因此我需要得到 5 的结果,它是 4 中数字的索引a。我怎么能用 numpy 做到这一点?

标签: pythonarraysnumpyindexing

解决方案


获取1s索引说 as idx,然后a用它索引,获取max索引,最后通过索引将其追溯到原始顺序idx-

idx = np.flatnonzero(label==1)
out = idx[a[idx].argmax()]

样品运行 -

# Assuming inputs to be 1D
In [18]: a
Out[18]: array([1, 2, 3, 4, 5, 4, 3, 2, 1])

In [19]: label
Out[19]: array([1, 0, 1, 0, 0, 1, 1, 0, 1])

In [20]: idx = np.flatnonzero(label==1)

In [21]: idx[a[idx].argmax()]
Out[21]: 5

对于a整数和 andlabel的数组0s1s我们可以进一步优化,因为我们可以a根据其中的值范围进行缩放,就像这样 -

(label*(a.max()-a.min()+1) + a).argmax()

此外,如果a只有正数,它将简化为 -

(label*(a.max()+1) + a).argmax()

正整数的时序较大a -

In [115]: np.random.seed(0)
     ...: a = np.random.randint(0,10,(100000))
     ...: label = np.random.randint(0,2,(100000))

In [117]: %%timeit
     ...: idx = np.flatnonzero(label==1)
     ...: out = idx[a[idx].argmax()]
1000 loops, best of 3: 592 µs per loop

In [116]: %timeit (label*(a.max()-a.min()+1) + a).argmax()
1000 loops, best of 3: 357 µs per loop

# @coldspeed's soln
In [120]: %timeit np.ma.masked_where(~label.astype(bool), a).argmax()
1000 loops, best of 3: 1.63 ms per loop

# won't work with negative numbers in a
In [119]: %timeit (label*(a.max()+1) + a).argmax()
1000 loops, best of 3: 292 µs per loop

# @klim's soln (won't work with negative numbers in a)
In [121]: %timeit np.argmax(a * (label == 1))
1000 loops, best of 3: 229 µs per loop

推荐阅读