python - 对有滞后的列求和的最快方法
问题描述
给定一个方阵,我想将每一行移动其行号并对列求和。例如:
array([[0, 1, 2], array([[0, 1, 2],
[3, 4, 5], -> [3, 4, 5], -> array([0, 1+3, 2+4+6, 5+7, 8]) = array([0, 4, 12, 12, 8])
[6, 7, 8]]) [6, 7, 8]])
我有 4 个解决方案 - fast
、和,它们的作用完全相同,并按速度排名slow
:slower
slowest
def fast(A):
n = A.shape[0]
retval = np.zeros(2*n-1)
for i in range(n):
retval[i:(i+n)] += A[i, :]
return retval
def slow(A):
n = A.shape[0]
indices = np.arange(n)
indices = indices + indices[:,None]
return np.bincount(indices.ravel(), A.ravel())
def slower(A):
r, _ = A.shape
retval = np.c_[A, np.zeros((r, r), dtype=A.dtype)].ravel()[:-r].reshape(r, -1)
return retval.sum(0)
def slowest(A):
n = A.shape[0]
retval = np.zeros(2*n-1)
indices = np.arange(n)
indices = indices + indices[:,None]
np.add.at(retval, indices, A)
return retval
令人惊讶的是,非矢量化解决方案是最快的。这是我的基准:
A = np.random.randn(1000,1000)
%timeit fast(A)
# 1.85 ms ± 20 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit slow(A)
# 3.28 ms ± 9.55 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit slower(A)
# 4.07 ms ± 18.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit slowest(A)
# 58.4 ms ± 993 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
是否存在更快的解决方案?如果没有,有人可以解释为什么实际上fast
是最快的吗?
编辑
略微加速slow
:
def slow(A):
n = A.shape[0]
indices = np.arange(2*n-1)
indices = np.lib.stride_tricks.as_strided(indices, A.shape, (8,8))
return np.bincount(indices.ravel(), A.ravel())
以与 Pierre 相同的方式绘制运行时(以 2**15 作为上限 - 由于某种原因slow
无法处理此大小)
slow
对于.numba
的数组,比任何解决方案(不使用 )略快 仍然是数组的最佳选择。100x100
sum_antidiagonals
1000x1000
解决方案
这是一种有时比您的“ ”版本更快的方法,但仅限于数组fast()
的有限范围n
(大约在 30 到 1000 之间) 。即使使用,n x n
循环 ( fast()
)在大型数组上也很难被击败numba
,并且实际上渐近收敛到 simple 的时间a.sum(axis=0)
,这表明这与大型数组的效率差不多(感谢您的循环!)
该方法(我将称之为sum_antidiagonals()
)用于np.add.reduce()
条纹版本a
和来自相对较小的 1D 阵列的合成蒙版,该阵列被条纹以创建 2D 阵列的错觉(不消耗更多内存)。
此外,它不限于方形数组(但fast()
也可以很容易地适应这种泛化,请参见fast_g()
本文底部)。
def sum_antidiagonals(a):
assert a.flags.c_contiguous
r, c = a.shape
s0, s1 = a.strides
z = np.lib.stride_tricks.as_strided(
a, shape=(r, c+r-1), strides=(s0 - s1, s1), writeable=False)
# mask
kern = np.r_[np.repeat(False, r-1), np.repeat(True, c), np.repeat(False, r-1)]
mask = np.fliplr(np.lib.stride_tricks.as_strided(
kern, shape=(r, c+r-1), strides=(1, 1), writeable=False))
return np.add.reduce(z, where=mask)
请注意,它不限于方数组:
>>> sum_antidiagonals(np.arange(15).reshape(5,3))
array([ 0, 4, 12, 21, 30, 24, 14])
解释
为了理解它是如何工作的,让我们用一个例子来检查这些条带数组。
给定一个初始数组a
:(3, 2)
a = np.arange(6).reshape(3, 2)
>>> a
array([[0, 1],
[2, 3],
[4, 5]])
# after calculating z in the function
>>> z
array([[0, 1, 2, 3],
[1, 2, 3, 4],
[2, 3, 4, 5]])
你可以看到它几乎是我们想要的sum(axis=0)
,除了上下三角形是不需要的。我们真正想要总结的是:
array([[0, 1, -, -],
[-, 2, 3, -],
[-, -, 4, 5]])
输入掩码,我们可以从一维内核开始构建它:
kern = np.r_[np.repeat(False, r-1), np.repeat(True, c), np.repeat(False, r-1)]
>>> kern
array([False, False, True, True, False, False])
我们使用了一个有趣的 slice: (1, 1)
,这意味着我们重复同一行,但每次滑动一个元素:
>>> np.lib.stride_tricks.as_strided(
... kern, shape=(r, c+r-1), strides=(1, 1), writeable=False)
array([[False, False, True, True],
[False, True, True, False],
[ True, True, False, False]])
然后我们只需将其向左/向右翻转,并将其where
用作np.add.reduce()
.
速度
b = np.random.normal(size=(1000, 1000))
# check equivalence with the OP's fast() function:
>>> np.allclose(fast(b), sum_antidiagonals(b))
True
%timeit sum_antidiagonals(b)
# 1.83 ms ± 840 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit fast(b)
# 2.07 ms ± 15.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
在这种情况下,它会快一点,但只有大约 10%。
在 300x300 阵列上,sum_antidiagonals()
比fast()
.
然而!
尽管放在一起z
并且mask
非常快(在np.add.reduce()
上面的 1000x1000 示例中,之前的整个设置只需要 46 µs),但总和本身是O[r (r+c)]
,即使只需要O[r c]
实际的加法(其中mask == True
)。因此,对于方形阵列,大约要多做 2 倍的操作。
在 10K x 10K 阵列上,这赶上了我们:
fast
需要 95 毫秒,而sum_antidiagonals
需要 208 毫秒。
通过尺寸范围进行比较
我们将使用可爱的perfplot
包通过以下范围比较多种方法的速度n
:
perfplot.show(
setup=lambda n: np.random.normal(size=(n, n)),
kernels=[just_sum_0, fast, fast_g, nb_fast_i, nb_fast_ij, sum_antidiagonals],
n_range=[2 ** k for k in range(3, 16)],
equality_check=None, # because of just_sum_0
xlabel='n',
relative_to=1,
)
观察
- 如您所见,
sum_antidiagonals()
速度优势fast()
被限制在n
大约 30 到 1000 之间的范围内。 - 它永远不会击败
numba
版本。 just_sum_0()
,这只是简单的总和axis=0
(因此是一个很好的底线基准,几乎不可能被击败),对于大型阵列来说只是稍微快一点。这一事实表明,fast()
它与大型阵列一样快。- 令人惊讶的是,
numba
在一定大小后会减损(即在前几次运行以“烧入”LLVM 编译之后)。我不确定为什么会这样,但它似乎对大型阵列很重要。
其他功能的完整代码
(包括fast
对非方形数组的简单概括)
from numba import njit
@njit
def nb_fast_ij(a):
# numba loves loops...
r, c = a.shape
z = np.zeros(c + r - 1, dtype=a.dtype)
for i in range(r):
for j in range(c):
z[i+j] += a[i, j]
return z
@njit
def nb_fast_i(a):
r, c = a.shape
z = np.zeros(c + r - 1, dtype=a.dtype)
for i in range(r):
z[i:i+c] += a[i, :]
return z
def fast_g(a):
# generalizes fast() to non-square arrays, also returns the same dtype
r, c = a.shape
z = np.zeros(c + r - 1, dtype=a.dtype)
for i in range(r):
z[i:i+c] += a[i]
return z
def fast(A):
# the OP's code
n = A.shape[0]
retval = np.zeros(2*n-1)
for i in range(n):
retval[i:(i+n)] += A[i, :]
return retval
def just_sum_0(a):
# for benchmarking comparison
return a.sum(axis=0)
推荐阅读
- mongodb - loopback4 中是否有任何选项可以使用 MongoDB 文档验证
- mysql - 客户端之间通过 Google App Engine 上的 websocket 服务器进行实时通信
- mathjax - MathJax 3:粗体 \text{} 可能吗?
- javascript - JavaScript 将总和保存到字段
- node.js - 将选项传递给 Sequelize 钩子不起作用
- css - 在顶部固定一行单元格作为样本
- python - 在浏览器重新加载时破折号重新加载熊猫数据框
- python - 没有从 __new__ 中调用方法
- c++ - 在哪里声明仅在一个函数中使用的局部变量更好?
- javascript - 优化随机颜色的生成