python - 是否可以使用 Numpy 实现这个版本的矩阵乘法?
问题描述
我希望快速评估下面的函数,它在高层次上类似于矩阵乘法。对于大型矩阵,下面的实现比矩阵的 numpy 乘法慢几个数量级,这让我相信有更好的方法来实现这个使用 numpy. 有没有办法使用 numpy 函数而不是 for 循环来实现这一点?我正在使用的矩阵在每个维度上都有 10K-100K 的元素,因此非常需要这种优化。
一种方法是使用 3D numpy 数组,但这证明太大而无法存储。我还研究了似乎不合适的 np.vectorize 。
非常感谢您的指导。
编辑:感谢大家的精彩见解和答案。非常感谢您的意见。将日志移到循环之外大大提高了运行时间,有趣的是,k
查找很重要。如果可以的话,我有一个跟进:如果内部循环表达式变为 ,你能看到一种加速的方法C[i,j] += A[i,k] * np.log(A[i,k] + B[k,j])
吗?可以像以前一样将日志移出,但前提A[i,k]
是要取幂,这很昂贵并且消除了移出日志的收益。
import numpy as np
import numba
from numba import njit, prange
@numba.jit(fastmath=True, parallel=True)
def f(A, B):
C = np.zeros((A.shape[0], B.shape[1]))
for i in prange(A.shape[0]):
for j in prange(B.shape[1]):
for k in prange(A.shape[1]):
C[i,j] += np.log(A[i,k] + B[k,j])
#matrix mult. would be: C[i,j] += A[i,k] * B[k,j]
return C
#A = np.random.rand(100000, 100000)
#B = np.random.rand(100000, 100000)
#f(A, B)
解决方案
由于log(a) + log(b) == log(a * b)
,您可以通过用乘法替换加法并仅在最后执行对数来节省大量对数计算,这应该可以节省大量时间。
import numpy as np
import numba as nb
@nb.njit(fastmath=True, parallel=True)
def f(A, B):
C = np.ones((A.shape[0], B.shape[1]), A.dtype)
for i in nb.prange(A.shape[0]):
for j in nb.prange(B.shape[1]):
# Accumulate product
for k in nb.prange(A.shape[1]):
C[i,j] *= (A[i,k] + B[k,j])
# Apply logarithm at the end
return np.log(C)
# For comparison
@nb.njit(fastmath=True, parallel=True)
def f_orig(A, B):
C = np.zeros((A.shape[0], B.shape[1]), A.dtype)
for i in nb.prange(A.shape[0]):
for j in nb.prange(B.shape[1]):
for k in nb.prange(A.shape[1]):
C[i,j] += np.log(A[i,k] + B[k,j])
return C
# Test
np.random.seed(0)
a, b = np.random.random((1000, 100)), np.random.random((100, 2000))
print(np.allclose(f(a, b), f_orig(a, b)))
# True
%timeit f(a, b)
# 36.2 ms ± 2.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit f_orig(a, b)
# 296 ms ± 3.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
推荐阅读
- ios - 无法从 Crashlytics 重置 Firebase 链接
- android - 从主要活动捆绑到片段
- firebase - 提交带有单选按钮的表单以将它们存储在 firebase 调查表中
- javascript - 输入焦点的浏览器自动填充问题
- excel - 调用时显示单独定位的用户窗体跳转
- meteor - 模板助手中的异常:TypeError:无法读取未定义的属性“mergedSchema”
- docker - Docker:安装 apt-utils 时遇到问题
- android - 如何在 React Native for Android 中为文本输入定义插入符号颜色?
- wildfly-swarm - Wildfly Swarm 2018.5.0 不启动 HTTPS 监听
- python - 将参数传递给python中的装饰器