python - 一系列矩阵的快速乘法
问题描述
最快的运行方式是什么:
reduce(lambda x,y : x@y, ls)
在蟒蛇?
对于矩阵列表ls
。我没有 Nvidia GPU,但我确实有很多 CPU 内核可以使用。我以为我可以使该过程并行工作(将其拆分为log
迭代),但似乎对于小 ( 1000x1000
) 矩阵,这实际上是最糟糕的。这是我尝试过的代码:
from multiprocessing import Pool
import numpy as np
from itertools import zip_longest
def matmul(x):
if x[1] is None:
return x[0]
return x[1]@x[0]
def fast_mul(ls):
while True:
n = len(ls)
if n == 0:
raise Exception("Splitting Error")
if n == 1:
return ls[0]
if n == 2:
return ls[1]@ls[0]
with Pool(processes=(n//2+1)) as pool:
ls = pool.map(matmul, list(zip_longest(*[iter(ls)]*2)))
解决方案
有一个功能可以做到这一点:np.linalg.multi_dot
,据说为最佳评估顺序进行了优化:
np.linalg.multi_dot(ls)
事实上,文档说的非常接近你原来的措辞:
认为
multi_dot
是:def multi_dot(arrays): return functools.reduce(np.dot, arrays)
您也可以尝试np.einsum
,这将允许您将最多 25 个矩阵相乘:
from string import ascii_lowercase
ls = [...]
index = ','.join(ascii_lowercase[x:x + 2] for x in range(len(ls)))
index += f'->{index[0]}{index[-1]}'
np.einsum(index, *ls)
定时
简单案例:
ls = np.random.rand(100, 1000, 1000) - 0.5
%timeit reduce(lambda x, y : x @ y, ls)
4.3 s ± 76.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
4.35 s ± 84.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
4.86 s ± 68.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
5.24 s ± 66.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
更复杂的情况:
ls = [x.T if i % 2 else x for i, x in enumerate(np.random.rand(100, 2000, 500) - 0.5)]
%timeit reduce(lambda x, y : x @ y, ls)
7.94 s ± 96.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
7.91 s ± 33.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
9.38 s ± 111 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
2.03 s ± 52.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
请注意,在简单的情况下,由 完成的前期工作multi_dot
有负面的好处(更令人惊讶的是,lambda
它比原始运算符工作得更快),但在不太简单的情况下节省了 75% 的时间。
因此,为了完整起见,这里有一个不那么不方形的情况:
ls = [x.T if i % 2 else x for i, x in enumerate(np.random.rand(100, 400, 300) - 0.5)]
%timeit reduce(lambda x, y : x @ y, ls)
245 ms ± 8.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.matmul, ls)
245 ms ± 12.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit reduce(np.dot, ls)
284 ms ± 12.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.linalg.multi_dot(ls)
638 ms ± 12.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
所以实际上,对于大多数一般情况,您的原始reduce
通话实际上与您需要的一样好。我唯一的建议是使用operator.matmul
而不是 lambda。
推荐阅读
- javascript - 如何使用javascript在名称属性上添加一个?
- mysql - mysql中同一张表的联合
- type-inference - 证明任意嵌套的 Vect 别名是可显示的
- c# - 将默认光标更改为自定义光标
- laravel - 将多个相同的名称转换为一个名称,并且值在 laravel 中得到总和
- azure - 如何从 DevOps api 获取团队头像
- openlayers - 开放层。矢量瓷砖,边缘的样式特征
- dragonruby-game-toolkit - 如何在 DragonRuby Game Toolkit 中的“sprite”上渲染“solid”?
- go - 如何更正“if”构造中的类型比较
- vim - Vim 视觉模式在 cygwin 中表现异常