首页 > 解决方案 > Cumsum 重启

问题描述

每次超过阈值 10000 时,我都想对数据进行分类。

我试过这个没有运气:

# data which is an array of floats

diff = np.diff(np.cumsum(data)//10000, prepend=0)

indices = (np.argwhere(diff > 0)).flatten()

问题是所有的箱子都不包含 10000,这是我的目标。


预期产出

input_data = [4000, 5000, 6000, 2000, 8000, 3000]
# (4000+5000+6000 >= 10000. Index 2)
# (2000+8000 >= 10000. Index 4)
Output: [2, 4]

我想知道是否有任何替代 for 循环?

标签: pythonnumpy

解决方案


不确定如何对它进行矢量化,如果它甚至可以,因为通过获取累积总和,您将在每次超过阈值时传播余数。所以这可能是一个很好的例子numba,它将代码编译到 C 级别,允许一个循环但高性能的方法:

from numba import njit, int32

@njit('int32[:](int32[:], uintc)')
def windowed_cumsum(a, thr):
    indices = np.zeros(len(a), int32) 
    window = 0
    ix = 0
    for i in range(len(a)):
        window += a[i]
        if window >= thr:
            indices[ix] = i
            ix += 1
            window = 0
    return indices[:ix]

显式签名意味着提前编译,尽管这会在输入数组上强制执行特定的 dtype。示例数组的推断 dtype 是 of int32,但如果情况并非总是如此,或者对于更灵活的解决方案,您始终可以忽略签名中的 dtypes,这仅意味着该函数将在第一次执行时编译。

input_data = np.array([4000, 5000, 6000, 2000, 8000, 3000])

windowed_cumsum(input_data, 10000)
# array([2, 4])

@jdehesa 还提出了一个有趣的观点,即对于与 bin 数量相比非常长的数组,更好的选择可能是将索引附加到列表中。因此,这是一种使用列表的替代方法(也在 no-python 模式下),以及不同场景下的时间:

from numba import njit, int32

@njit
def windowed_cumsum_list(a, thr):
    indices = []
    window = 0
    for i in range(len(a)):
        window += a[i]
        if window >= thr:
            indices.append(i)
            window = 0
    return indices

a = np.random.randint(0,10,10_000)

%timeit windowed_cumsum(a, 20)
# 16.1 µs ± 232 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit windowed_cumsum_list(a, 20)
# 65.5 µs ± 623 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit windowed_cumsum(a, 2000)
# 7.38 µs ± 167 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit windowed_cumsum_list(a, 2000)
# 7.1 µs ± 103 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

因此,在大多数情况下,使用 numpy 似乎是一个更快的选择,因为即使在第二种情况下,使用长度10000数组和生成20的 bin 索引数组,两者的性能相似,尽管出于内存效率的原因,后者可能更方便在某些情况下。


推荐阅读