python - 有效地返回数组中第一个满足条件的值的索引
问题描述
我需要在满足条件的一维 NumPy 数组或 Pandas 数字系列中找到第一个值的索引。数组很大,索引可能在数组的开头或结尾附近,或者根本不满足条件。我无法提前判断哪个更有可能。如果不满足条件,则返回值应为-1
。我考虑了几种方法。
尝试 1
# func(arr) returns a Boolean array
idx = next(iter(np.where(func(arr))[0]), -1)
但这通常太慢,因为在整个func(arr)
数组上应用矢量化函数而不是在满足条件时停止。具体来说,当在数组的开始附近满足条件时,它是昂贵的。
尝试 2
np.argmax
稍微快一点,但无法识别何时从未满足条件:
np.random.seed(0)
arr = np.random.rand(10**7)
assert next(iter(np.where(arr > 0.999999)[0]), -1) == np.argmax(arr > 0.999999)
%timeit next(iter(np.where(arr > 0.999999)[0]), -1) # 21.2 ms
%timeit np.argmax(arr > 0.999999) # 17.7 ms
np.argmax(arr > 1.0)
返回0
,即条件不满足时的实例。
尝试 3
# func(arr) returns a Boolean scalar
idx = next((idx for idx, val in enumerate(arr) if func(arr)), -1)
但是当在数组末尾附近满足条件时,这太慢了。这可能是因为生成器表达式因大量__next__
调用而产生了昂贵的开销。
这是否总是一种折衷方案,或者对于 generic 是否有办法func
有效地提取第一个索引?
基准测试
对于基准测试,假设func
在值大于给定常数时找到索引:
# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0
import numpy as np
np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9
n = 0.999999
# Start of array benchmark
%timeit next(iter(np.where(arr > m)[0]), -1) # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1) # 2.5 µs
# End of array benchmark
%timeit next(iter(np.where(arr > n)[0]), -1) # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1) # 39.2 ms
解决方案
numba
numba
可以优化这两种情况。从语法上讲,您只需要构造一个带有简单for
循环的函数:
from numba import njit
@njit
def get_first_index_nb(A, k):
for i in range(len(A)):
if A[i] > k:
return i
return -1
idx = get_first_index_nb(A, 0.9)
Numba 通过 JIT(“及时”)编译代码和利用CPU 级优化来提高性能。没有装饰器的常规 for
循环@njit
通常会比您已经尝试过的方法慢,以应对条件迟到的情况。
对于 Pandas 数字系列df['data']
,您可以简单地将 NumPy 表示提供给 JIT 编译的函数:
idx = get_first_index_nb(df['data'].values, 0.9)
概括
由于numba
允许函数作为参数,并且假设传递的函数也可以进行 JIT 编译,因此您可以找到一种方法来计算满足任意 条件的第nfunc
个索引。
@njit
def get_nth_index_count(A, func, count):
c = 0
for i in range(len(A)):
if func(A[i]):
c += 1
if c == count:
return i
return -1
@njit
def func(val):
return val > 0.9
# get index of 3rd value where func evaluates to True
idx = get_nth_index_count(arr, func, 3)
对于倒数第三个值,您可以提供反向 ,arr[::-1]
并否定来自 的结果len(arr) - 1
,这- 1
是考虑 0 索引的必要条件。
性能基准测试
# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0
np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9
n = 0.999999
@njit
def get_first_index_nb(A, k):
for i in range(len(A)):
if A[i] > k:
return i
return -1
def get_first_index_np(A, k):
for i in range(len(A)):
if A[i] > k:
return i
return -1
%timeit get_first_index_nb(arr, m) # 375 ns
%timeit get_first_index_np(arr, m) # 2.71 µs
%timeit next(iter(np.where(arr > m)[0]), -1) # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1) # 2.5 µs
%timeit get_first_index_nb(arr, n) # 204 µs
%timeit get_first_index_np(arr, n) # 44.8 ms
%timeit next(iter(np.where(arr > n)[0]), -1) # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1) # 39.2 ms
推荐阅读
- ruby - 如何更改 GEM 包管理器版本,而不是特定的 gem/module/package
- android - 在非标准构建类型的情况下,androidTest 依赖项的未解析符号
- r - 如何展平非原子函数结果,以便可以将其分配为 dplyr mutate 步骤的一部分?
- dictionary - 以字符串为键和任何值的字典
- python - 什么是线性池化层?
- angular - 我如何从角度组件调用 HTML 标记函数
- javascript - 如何从网站下载嵌入式视频?
- installation - LoadRunner 2020 安装中的致命错误
- python - 获取 IndexError:列表索引超出范围错误
- c# - 为什么图像没有输入变量?