首页 > 解决方案 > 如何更快地计算 (3, 2000) ndarray 中的选择?

问题描述

有没有办法加快以下两行代码的速度?

choice = np.argmax(cust_profit, axis=0) 
taken = np.array([np.sum(choice == i) for i in range(n_pr)])
%timeit np.argmax(cust_profit, axis=0)
37.6 µs ± 222 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit np.array([np.sum(choice == i) for i in range(n_pr)])
40.2 µs ± 206 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
n_pr == 2
cust_profit.shape == (n_pr+1, 2000)

解决方案:

%timeit np.unique(choice, return_counts=True)
53.7 µs ± 190 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit np.histogram(choice, bins=np.arange(n_pr + 2))
70.5 µs ± 205 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit np.bincount(choice)
7.4 µs ± 17.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

这些微秒让我担心,因为这段代码位于两层 scipy.optimize.minimize(method='Nelder-Mead') 下,位于双嵌套循环中,因此 40µs 等于 4 小时。而且我认为将其全部包含在基因搜索中。

标签: pythonnumpy

解决方案


第一行似乎很简单。除非您可以对数据或类似的东西进行排序,否则您将陷入np.argmax. 第二行可以通过使用 numpy 而不是 vanilla python 来实现它来加速:

v, counts = np.unique(choice, return_counts=True)

或者:

counts = np.histogram(choice, bins=np.arange(n_pr + 2))

histogram还存在一个针对整数优化的版本:

count = np.bincount(choice)

如果您想保证 bin 包含 的所有可能值,则后两个选项会更好choice,无论它们是否实际存在于数组中。

话虽如此,您可能不应该担心需要几微秒的事情。


推荐阅读