首页 > 解决方案 > 是否可以使用 Numpy 实现这个版本的矩阵乘法?

问题描述

我希望快速评估下面的函数,它在高层次上类似于矩阵乘法。对于大型矩阵,下面的实现比矩阵的 numpy 乘法慢几个数量级,这让我相信有更好的方法来实现这个使用 numpy. 有没有办法使用 numpy 函数而不是 for 循环来实现这一点?我正在使用的矩阵在每个维度上都有 10K-100K 的元素,因此非常需要这种优化。

一种方法是使用 3D numpy 数组,但这证明太大而无法存储。我还研究了似乎不合适的 np.vectorize 。

非常感谢您的指导。

编辑:感谢大家的精彩见解和答案。非常感谢您的意见。将日志移到循环之外大大提高了运行时间,有趣的是,k查找很重要。如果可以的话,我有一个跟进:如果内部循环表达式变为 ,你能看到一种加速的方法C[i,j] += A[i,k] * np.log(A[i,k] + B[k,j])吗?可以像以前一样将日志移出,但前提A[i,k]是要取幂,这很昂贵并且消除了移出日志的收益。

import numpy as np
import numba
from numba import njit, prange

@numba.jit(fastmath=True, parallel=True)
def f(A, B):
    
    C = np.zeros((A.shape[0], B.shape[1]))

    for i in prange(A.shape[0]):
        for j in prange(B.shape[1]):
            for k in prange(A.shape[1]):
                
                C[i,j] += np.log(A[i,k] + B[k,j])
                #matrix mult. would be: C[i,j] += A[i,k] * B[k,j]

    return C

#A = np.random.rand(100000, 100000)
#B = np.random.rand(100000, 100000)
#f(A, B)

标签: pythonarraysnumpymatrix-multiplicationnumba

解决方案


由于log(a) + log(b) == log(a * b),您可以通过用乘法替换加法并仅在最后执行对数来节省大量对数计算,这应该可以节省大量时间。

import numpy as np
import numba as nb

@nb.njit(fastmath=True, parallel=True)
def f(A, B):
    C = np.ones((A.shape[0], B.shape[1]), A.dtype)
    for i in nb.prange(A.shape[0]):
        for j in nb.prange(B.shape[1]):
            # Accumulate product
            for k in nb.prange(A.shape[1]):
                C[i,j] *= (A[i,k] + B[k,j])
    # Apply logarithm at the end
    return np.log(C)

# For comparison
@nb.njit(fastmath=True, parallel=True)
def f_orig(A, B):
    C = np.zeros((A.shape[0], B.shape[1]), A.dtype)
    for i in nb.prange(A.shape[0]):
        for j in nb.prange(B.shape[1]):
            for k in nb.prange(A.shape[1]):
                C[i,j] += np.log(A[i,k] + B[k,j])
    return C

# Test
np.random.seed(0)
a, b = np.random.random((1000, 100)), np.random.random((100, 2000))
print(np.allclose(f(a, b), f_orig(a, b)))
# True
%timeit f(a, b)
# 36.2 ms ± 2.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit f_orig(a, b)
# 296 ms ± 3.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

推荐阅读