python - 为什么 np.linalg.norm(..., axis=1) 比写出向量范数的公式慢?
问题描述
要将矩阵的行标准化X
为单位长度,我通常使用:
X /= np.linalg.norm(X, axis=1, keepdims=True)
尝试针对算法优化此操作时,我很惊讶地发现在我的机器上写出规范化的速度大约快 40%:
X /= np.sqrt(X[:,0]**2+X[:,1]**2+X[:,2]**2)[:,np.newaxis]
X /= np.sqrt(sum(X[:,i]**2 for i in range(X.shape[1])))[:,np.newaxis]
怎么会?性能损失在np.linalg.norm()
哪里?
import numpy as np
X = np.random.randn(10000,3)
%timeit X/np.linalg.norm(X,axis=1, keepdims=True)
# 276 µs ± 4.55 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit X/np.sqrt(X[:,0]**2+X[:,1]**2+X[:,2]**2)[:,np.newaxis]
# 169 µs ± 1.38 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit X/np.sqrt(sum(X[:,i]**2 for i in range(X.shape[1])))[:,np.newaxis]
# 185 µs ± 4.17 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
我在支持 OpenBLAS 的 MacbookPro 2015 上(1) python3.6 + numpy v1.17.2
观察到这一点。(2) python3.9 + numpy v1.19.3
我不认为这是这篇文章的副本,它解决了矩阵范数,而这篇文章是关于向量的 L2 范数。
解决方案
row-wise L2-norm的源代码归结为以下代码行:
def norm(x, keepdims=False):
x = np.asarray(x)
s = x**2
return np.sqrt(s.sum(axis=(1,), keepdims=keepdims))
简化代码假定实值x
并利用np.add.reduce(s, ...)
等价于s.sum(...)
.
因此,OP 问题与询问为什么np.sum(x,axis=1)
慢于sum(x[:,i] for i in range(x.shape[1]))
:
%timeit X.sum(axis=1, keepdims=False)
# 131 µs ± 1.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit sum(X[:,i] for i in range(X.shape[1]))
# 36.7 µs ± 91.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
这个问题已经在这里回答了。简而言之,减少 ( .sum(axis=1)
) 伴随着开销成本,这些开销通常在浮点精度和速度方面得到回报(例如缓存机制、并行性),但在仅减少三列的特殊情况下则不会。在这种情况下,与实际计算相比,开销相对较大。
如果X
有更多列,情况会发生变化。现在,numpy-boosted 规范化比使用 python for 循环的归约要快得多:
X = np.random.randn(10000,100)
%timeit X/np.linalg.norm(X,axis=1, keepdims=True)
# 3.36 ms ± 132 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit X/np.sqrt(sum(X[:,i]**2 for i in range(X.shape[1])))[:,np.newaxis]
# 5.92 ms ± 168 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
在这里可以找到另一个相关的 SO 线程:numpy ufuncs vs. for loop。
问题仍然是为什么 numpy 没有明确地处理常见的减少特殊情况(例如对具有低轴维度的矩阵的列或行的求和)。可能是因为这种优化的效果往往在很大程度上取决于目标机器,并大大增加了代码的复杂性。
推荐阅读
- chrome-canary - 如何让 WebGPU 在 Chrome Canary 97 中运行?
- python - 我想通过输入拆分将其分为命令、键和值,并输入到字典中
- c# - 如何使用 IdentityServer4 配置 Google 身份验证以避免外部身份验证错误?
- html - 仅使用 CSS 检查单选按钮后切换 div 内容
- mongodb - mongodb如何找到12级以上的餐厅
- php - 是否可以限制用户可以使用 woocommerce 进行的变化量
- python - 在 django 上处理选择多个
- java - 用给定的一组线段计算最大距离,形成粘三角形
- javascript - 所有大写字母和所有非字母字符的正则表达式是什么
- c# - C#中具有列表值的字典的类型转换