首页 > 解决方案 > 使用 numpy 计算 3D 矩阵乘法的有效方法

问题描述

如何使用 numpy 有效地编写和计算这个乘法:

 for k in range(K):
    for i in range(SIZE):
       for j in range(SIZE):
          for i_b in range(B_SIZE):
             for j_b in range(B_SIZE):
                for k_b in range(k+1):
                   data[k, i * w + i_b, j * h + j_b] += arr1[k_b, i_b, j_b] * arr2[k_b, i, j]

例如:

SIZE, B_SIZE = 32, 8
arr1.shape -> (8, 8, 8)
arr2.shape -> (8, 32, 32)
data.shape -> (K, 256, 256)

谢谢你。

标签: numpymatrixoptimizationmatrix-multiplicationnumpy-ndarray

解决方案


您可以将Numba用于这种非平凡的情况,并重新设计循环以有效地使用 CPU缓存。这是一个例子:

import numba as nb

@nb.njit
def compute(data, arr1, arr2):
    for k in range(K):
        for k_b in range(k+1):
            for i in range(SIZE):
                for j in range(SIZE):
                    tmp = arr2[k_b, i, j]
                    for i_b in range(B_SIZE):
                        for j_b in range(B_SIZE):
                            data[k, i * w + i_b, j * h + j_b] += arr1[k_b, i_b, j_b] * tmp

如果您执行此操作一次,则可以通过提供数组的类型来预编译Numba 代码。如果K很大,那么您可以使用and use不是. 这应该是几个数量级的脂肪。@nb.njit(parallel=True)for k in nb.prange(K)for k in range(K)


推荐阅读