首页 > 解决方案 > 优化张量乘法

问题描述

我有一个我正在尝试优化的实时图像处理程序,这一切都归结为矩阵乘法。考虑我在初始化阶段计算的 3 个张量:

  1. A = np.arange(35 * 51 * 59).reshape([35, 51, 59])
  2. B = np.arange(37 * 51 * 51 * 59).reshape([37, 51, 51, 59])
  3. C = np.arange(59 * 27).reshape([59, 27])

每一帧,我都会以第四张量的形式获得一个新数据:

目前,我正在计算D = np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C)D我想要的结果在哪里,这是程序的主要瓶颈。为了优化它,我试图遵循两个方向。

首先我尝试提出一个张量T,我可以预先计算一个函数A, B, C, D,然后它会全部沸腾为D = np.tensordot(M, T, axes=..). 我没有成功。我花了很多时间,这甚至可能吗?

此外,程序本身是用 MATLAB 编写的。由于它没有内置的张量乘法函数(einsumtensordot等效函数),我目前正在使用该tprod工具箱,并且正在执行以下操作:

temp1 = etprod('dcb', A, 'abc', M, 'adc');
temp2 = etprod('dbc', B, 'abcd', temp1, 'adb');
D = etprod('cdb', C, 'ab', temp2, 'acd');

由于 MATLAB 中的默认点积函数(用于 2D 矩阵)要快得多etprod,因此我想A, B, C, D以一种能够使用默认函数处理多个 2D 矩阵的方式将其重塑为 2D 数组,而无需手动编写for循环。我也没有成功。

有什么想法吗?谢谢!

标签: matlabnumpydot-productnumpy-einsum

解决方案


如果使用不同的 M 值多次执行此操作,我们可以定义

D0 = np.einsum('xft,fr->tpr',A, B, C)

整个操作可以分解为二进制步骤:

D0=np.einsum('xtf,ytpf->xyptf',A,B)
D0=np.einsum('xyptf,fr->xyftpr',D0,C)
D=np.einsum('tprxfy,xfy->tpr',D0,M)

最后的运算使用 D0 和 M,可以编码为矩阵向量运算。在 Matlab 中它会是

D=reshape(D0.[],numel(M))*M(:);

然后可以根据需要重新排序。我们可以把这个顺序写成 (((A,B),C),M)

但是,使用 ((M,C),A,B) 可能会更好

D=np.einsum('xyf,fr->xyfr',M,C)
D0=np.einsum('xyfr,xtf->ytfr',D,A)
D=np.einsum('ytfr,ytpf->tpr',D,B)

这种操作排序的中间数组只有 4 个索引而不是 6 个索引。如果每个操作都比单个操作快得多,这可能是一个优势。


推荐阅读