首页 > 解决方案 > 一系列矩阵的快速乘法

问题描述

最快的运行方式是什么:

    reduce(lambda x,y : x@y, ls)

在蟒蛇?

对于矩阵列表ls。我没有 Nvidia GPU,但我确实有很多 CPU 内核可以使用。我以为我可以使该过程并行工作(将其拆分为log迭代),但似乎对于小 ( 1000x1000) 矩阵,这实际上是最糟糕的。这是我尝试过的代码:

from multiprocessing import Pool
import numpy as np
from itertools import zip_longest

def matmul(x):
    if x[1] is None:
        return x[0]
    return x[1]@x[0]

def fast_mul(ls):
    while True:
        
        n = len(ls)
        if n == 0:
            raise Exception("Splitting Error")
        if n == 1:
            return ls[0]
        if n == 2:
            return ls[1]@ls[0]

        with Pool(processes=(n//2+1)) as pool:
            ls = pool.map(matmul, list(zip_longest(*[iter(ls)]*2)))
    

标签: pythonnumpymatrix-multiplication

解决方案


有一个功能可以做到这一点:np.linalg.multi_dot,据说为最佳评估顺序进行了优化:

np.linalg.multi_dot(ls)

事实上,文档说的非常接近你原来的措辞:

认为multi_dot是:

def multi_dot(arrays): return functools.reduce(np.dot, arrays)

您也可以尝试np.einsum,这将允许您将最多 25 个矩阵相乘:

from string import ascii_lowercase

ls = [...]
index = ','.join(ascii_lowercase[x:x + 2] for x in range(len(ls)))
index += f'->{index[0]}{index[-1]}'
np.einsum(index, *ls)

定时

简单案例:

ls = np.random.rand(100, 1000, 1000) - 0.5

%timeit reduce(lambda x, y : x @ y, ls)
4.3 s ± 76.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
4.35 s ± 84.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
4.86 s ± 68.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
5.24 s ± 66.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

更复杂的情况:

ls = [x.T if i % 2 else x for i, x in enumerate(np.random.rand(100, 2000, 500) - 0.5)]

%timeit reduce(lambda x, y : x @ y, ls)
7.94 s ± 96.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
7.91 s ± 33.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
9.38 s ± 111 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
2.03 s ± 52.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

请注意,在简单的情况下,由 完成的前期工作multi_dot有负面的好处(更令人惊讶的是,lambda它比原始运算符工作得更快),但在不太简单的情况下节省了 75% 的时间。

因此,为了完整起见,这里有一个不那么不方形的情况:

ls = [x.T if i % 2 else x for i, x in enumerate(np.random.rand(100, 400, 300) - 0.5)]

%timeit reduce(lambda x, y : x @ y, ls)
245 ms ± 8.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
245 ms ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
284 ms ± 12.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
638 ms ± 12.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

所以实际上,对于大多数一般情况,您的原始reduce通话实际上与您需要的一样好。我唯一的建议是使用operator.matmul而不是 lambda。


推荐阅读