python - 如何在 Python 中高效地计算矩阵乘积的稀疏值?
问题描述
我想以一种在内存使用和计算时间上有效的方式计算矩阵乘积的某些值。问题是中间矩阵有两个非常大的维度,可能无法存储。
具有示例值的尺寸:
N = 7 # very large
K = 3
M = 10 # very large
L = 8 # very very large
'a' 是形状 (N,K)
的矩阵 'b' 是形状 (K,N) 的矩阵
a = np.arange(N*K).reshape(N,K)
b = np.arange(K*M).reshape(K,M)
rows 是一个索引数组,其值在 range(N) 和长度 L 内
cols 是一个索引数组,其值在 range(M) 和长度 L 内
rows = [0,0,1,2,3,3,4,6]
cols = [0,9,5,8,2,8,3,6]
我需要以下内容,但由于其大小,无法计算(a @ b)
具有形状 (MxN) 作为中间结果的矩阵:
values = (a @ b)[rows, cols]
另一种实现可能涉及对 a[rows] 和 b[:,cols] 进行切片,创建形状为 (L,K) 和 (K,L) 的矩阵,但这些矩阵也太大了。Numpy 在进行花式切片时复制值
values = np.einsum("ij,ji->i", a[rows], b[:,cols])
提前致谢
解决方案
一种可能性是直接计算结果。也许还有一些其他技巧可以在没有巨大临时数组的情况下使用 BLAS 例程,但这也可以。
例子
import numpy as np
import numba as nb
import time
@nb.njit(fastmath=True,parallel=True)
def sparse_mult(a,b_Trans,inds):
res=np.empty(inds.shape[0],dtype=a.dtype)
for x in nb.prange(inds.shape[0]):
i=inds[x,0]
j=inds[x,1]
sum=0.
for k in range(a.shape[1]):
sum+=a[i,k]*b_Trans[j,k]
res[x]=sum
return res
#-------------------------------------------------
K=int(1e3)
N=int(1e5)
M=int(1e5)
L=int(1e7)
a = np.arange(N*K).reshape(N,K).astype(np.float64)
b = np.arange(K*M).reshape(K,M).astype(np.float64)
inds=np.empty((L,2),dtype=np.uint64)
inds[:,0] = np.random.randint(low=0,high=N,size=L) #rows
inds[:,1] = np.random.randint(low=0,high=M,size=L) #cols
#prepare
#-----------------------------------------------
#sort inds for better cache usage
inds=inds[np.argsort(inds[:,1]),:]
# transpose b for easy SIMD-usage
# we wan't a real transpose here not a view
b_T=np.copy(np.transpose(b))
#calculate results
values=sparse_mult(a,b_T,inds)
计算步骤,包括准备(排序、b 矩阵的转置)应在 60 秒内运行。
推荐阅读
- c# - c# 使用正则表达式提取句子
- python - 为什么 LSTM 的导数设置为零
- ios - 故事板全局色调在 Xcode 11.3.1 中未正确显示
- flutter - 如何设置颤动页面的高度
- android - 如何获取选定的弹出菜单项位置值并传递它
- python - 如何使用日志在 python 中启用或禁用许多打印语句?
- android - xml中登录字段的条件渲染
- python - Python findall 使用带有逻辑 AND 的 xpath 模式
- python - 无法导入 tpot:ImportError:无法导入名称“ensure_object”
- java - 在java中创建泛型方法