首页 > 解决方案 > Matlab中没有循环的张量乘法

问题描述

我有一个 3d 数组 A,例如 A=rand(N,N,K)。

我需要一个数组 B st

B(n,m) = norm(A(:,:,n)*A(:,:,m)' - A(:,:,m)*A(:,:,n)','fro')^2 for all indices n,m in 1:K.

这是循环代码:

B = zeros(K,K);    
for n=1:K
       for m=1:K
           B(n,m) = norm(A(:,:,n)*A(:,:,m)' - A(:,:,m)*A(:,:,n)','fro')^2;
       end
end

我不想循环播放 1:K。

我可以创建一个大小为 N K x N K st的数组 An_x_mt

An_x_mt equals A(:,:,n)*A(:,:,m)' for all n,m in 1:K by
An_x_mt = Ar*Ac_t; 

Ac_t=reshape(permute(A,[2 1 3]),size(A,1),[]); 
Ar=Ac_t';

如何创建大小为 N K x N K st的数组 Am_x_nt

Am_x_nt equals A(:,:,m)*A(:,:,n)' for all n,m in 1:K

这样我就可以做到

B = An_x_mt  - Am_x_nt
B = reshape(B,N,N,[]);
B = reshape(squeeze(sum(sum(B.^2,1),2)),K,K);

谢谢

标签: matlabloopstensor

解决方案


对于那些不能/不会使用 mmx 并希望坚持使用纯 Matlab 代码的人,您可以这样做。mat2cell 和 cell2mat 函数是你的朋友:

[N,~,nmat]=size(A);
Atc = reshape(permute(A,[2 1 3]),N,[]); % A', N x N*nmat
Ar = Atc'; % A, N*nmat x N
Anmt_2d = Ar*Atc; % An*Am'
Anmt_2d_cell = mat2cell(Anmt_2d,N*ones(nmat,1),N*ones(nmat,1));
Amnt_2d_cell = Anmt_2d_cell'; % ONLY products transposed, NOT their factors
Amnt_2d = cell2mat(Amnt_2d_cell); % Am*An'
Anm = Anmt_2d - Amnt_2d;
Anm = Anm.^2;
Anm_cell = mat2cell(Anm,N*ones(nmat,1),N*ones(nmat,1));
d = cellfun(@(c) sum(c(:)), Anm_cell); % squared Frobenius norm of each product; nmat x nmat

或者,在计算 Anmt_2d_cell 和 Amnt_2d_cell 之后,您可以将它们转换为 3d,其中第 3 维编码 (n,m) 和 (m,n) 索引,然后在 3d 中进行其余计算。您将需要来自此处的 permn() 实用程序https://www.mathworks.com/matlabcentral/fileexchange/7147-permn-vnk

Anmt_3d = cat(3,Anmt_2d_cell);
Amnt_3d = cat(3,Amnt_2d_cell);
Anm_3d = Anmt_3d - Amnt_3d;
Anm_3d = Anm_3d.^2;
Anm = squeeze(sum(sum(Anm_3d,1),2));
d = zeros(nmat,nmat);
nm=permn(1:nmat, 2); % all permutations (n,m) with repeat, by-row order
d(sub2ind([nmat,nmat],nm(:,1),nm(:,2))) = Anm;

出于某种原因,第二个选项(3D 阵列)要快两倍。

希望这会有所帮助。


推荐阅读