python - 我怎么能在某些条件下获得 numpy 数组索引
问题描述
我遇到了这样的问题:假设我有这样的数组:
a = np.array([[1,2,3,4,5,4,3,2,1],])
label = np.array([[1,0,1,0,0,1,1,0,1],])
我需要获取a
元素值为label
1 的位置的索引,并且 的值a
是导致为 1 的最大数量label
。
可能会混淆,在上面的例子中,label
1的索引是:0、2、5、6、8,它们对应的值a
是:1、3、4、3、1,其中4是最大的,因此我需要得到 5 的结果,它是 4 中数字的索引a
。我怎么能用 numpy 做到这一点?
解决方案
获取1s
索引说 as idx
,然后a
用它索引,获取max
索引,最后通过索引将其追溯到原始顺序idx
-
idx = np.flatnonzero(label==1)
out = idx[a[idx].argmax()]
样品运行 -
# Assuming inputs to be 1D
In [18]: a
Out[18]: array([1, 2, 3, 4, 5, 4, 3, 2, 1])
In [19]: label
Out[19]: array([1, 0, 1, 0, 0, 1, 1, 0, 1])
In [20]: idx = np.flatnonzero(label==1)
In [21]: idx[a[idx].argmax()]
Out[21]: 5
对于a
整数和 andlabel
的数组0s
,1s
我们可以进一步优化,因为我们可以a
根据其中的值范围进行缩放,就像这样 -
(label*(a.max()-a.min()+1) + a).argmax()
此外,如果a
只有正数,它将简化为 -
(label*(a.max()+1) + a).argmax()
正整数的时序较大a
-
In [115]: np.random.seed(0)
...: a = np.random.randint(0,10,(100000))
...: label = np.random.randint(0,2,(100000))
In [117]: %%timeit
...: idx = np.flatnonzero(label==1)
...: out = idx[a[idx].argmax()]
1000 loops, best of 3: 592 µs per loop
In [116]: %timeit (label*(a.max()-a.min()+1) + a).argmax()
1000 loops, best of 3: 357 µs per loop
# @coldspeed's soln
In [120]: %timeit np.ma.masked_where(~label.astype(bool), a).argmax()
1000 loops, best of 3: 1.63 ms per loop
# won't work with negative numbers in a
In [119]: %timeit (label*(a.max()+1) + a).argmax()
1000 loops, best of 3: 292 µs per loop
# @klim's soln (won't work with negative numbers in a)
In [121]: %timeit np.argmax(a * (label == 1))
1000 loops, best of 3: 229 µs per loop
推荐阅读
- sql - 更新后 Wikidata 查询抛出 StackOverflowError
- c# - UWP 可以写入下载文件夹中多种类型的文件,但不能对 SQLite db 文件执行相同操作
- android - 如何将带有降价语法的文本共享到电报?
- python - 在数据框中创建新列
- html - Get_Template_Part 部分未显示 CSS
- python - 有什么方法可以轻松交换字符串中的两个字符?
- android - 活动转换 API 调用意图
- javascript - 使用 Express 将 BigQuery 流式传输到前端
- google-cloud-platform - 气流 trigger_rule=none_failed 不起作用
- c# - 使用 lambda 更改变量值时遇到错误