numpy - 使用 numpy 计算 3D 矩阵乘法的有效方法
问题描述
如何使用 numpy 有效地编写和计算这个乘法:
for k in range(K):
for i in range(SIZE):
for j in range(SIZE):
for i_b in range(B_SIZE):
for j_b in range(B_SIZE):
for k_b in range(k+1):
data[k, i * w + i_b, j * h + j_b] += arr1[k_b, i_b, j_b] * arr2[k_b, i, j]
例如:
SIZE, B_SIZE = 32, 8
arr1.shape -> (8, 8, 8)
arr2.shape -> (8, 32, 32)
data.shape -> (K, 256, 256)
谢谢你。
解决方案
您可以将Numba用于这种非平凡的情况,并重新设计循环以有效地使用 CPU缓存。这是一个例子:
import numba as nb
@nb.njit
def compute(data, arr1, arr2):
for k in range(K):
for k_b in range(k+1):
for i in range(SIZE):
for j in range(SIZE):
tmp = arr2[k_b, i, j]
for i_b in range(B_SIZE):
for j_b in range(B_SIZE):
data[k, i * w + i_b, j * h + j_b] += arr1[k_b, i_b, j_b] * tmp
如果您执行此操作一次,则可以通过提供数组的类型来预编译Numba 代码。如果K
很大,那么您可以使用and use而不是. 这应该是几个数量级的脂肪。@nb.njit(parallel=True)
for k in nb.prange(K)
for k in range(K)