Matlab中不带循环的张量乘法 [英] Tensor multiplication w/o looping in Matlab

查看:131
本文介绍了Matlab中不带循环的张量乘法的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个3d数组A,例如A = rand(N,N,K).

I have a 3d array A, e.g. A=rand(N,N,K).

我需要一个数组B.

B(n,m) = norm(A(:,:,n)*A(:,:,m)' - A(:,:,m)*A(:,:,n)','fro')^2 for all indices n,m in 1:K.

这是循环代码:

B = zeros(K,K);    
for n=1:K
       for m=1:K
           B(n,m) = norm(A(:,:,n)*A(:,:,m)' - A(:,:,m)*A(:,:,n)','fro')^2;
       end
end

我不想循环到1:K.

我可以创建大小为N K x N K s.t的数组An_x_mt.

I can create an array An_x_mt of size NK x NK s.t.

An_x_mt equals A(:,:,n)*A(:,:,m)' for all n,m in 1:K by
An_x_mt = Ar*Ac_t; 

Ac_t=reshape(permute(A,[2 1 3]),size(A,1),[]); 
Ar=Ac_t';

如何创建尺寸也为N K x N K s.t的数组Am_x_nt.

How do I create an array Am_x_nt also of size NK x NK s.t.

Am_x_nt equals A(:,:,m)*A(:,:,n)' for all n,m in 1:K

这样我就可以做

B = An_x_mt  - Am_x_nt
B = reshape(B,N,N,[]);
B = reshape(squeeze(sum(sum(B.^2,1),2)),K,K);

Thx

推荐答案

对于那些不能/不会使用mmx并希望坚持使用纯Matlab代码的人,可以按照以下方法进行操作. mat2cell和cell2mat函数是您的朋友:

For those who can't/won't use mmx and want to stick to pure Matlab code, here's how you could do it. mat2cell and cell2mat functions are your friends:

[N,~,nmat]=size(A);
Atc = reshape(permute(A,[2 1 3]),N,[]); % A', N x N*nmat
Ar = Atc'; % A, N*nmat x N
Anmt_2d = Ar*Atc; % An*Am'
Anmt_2d_cell = mat2cell(Anmt_2d,N*ones(nmat,1),N*ones(nmat,1));
Amnt_2d_cell = Anmt_2d_cell'; % ONLY products transposed, NOT their factors
Amnt_2d = cell2mat(Amnt_2d_cell); % Am*An'
Anm = Anmt_2d - Amnt_2d;
Anm = Anm.^2;
Anm_cell = mat2cell(Anm,N*ones(nmat,1),N*ones(nmat,1));
d = cellfun(@(c) sum(c(:)), Anm_cell); % squared Frobenius norm of each product; nmat x nmat

或者,在计算Anmt_2d_cell和Amnt_2d_cell之后,您可以使用第3维将它们转换为3d,对(n,m)和(m,n)索引进行编码,然后在3d中进行其余的计算.您可能需要从此处 https://www.mathworks. com/matlabcentral/fileexchange/7147-permn-vnk

Alternatively, after computing Anmt_2d_cell and Amnt_2d_cell, you could convert them to 3d with the 3rd dimension encoding the (n,m) and (m,n) indices and then do the rest of the computations in 3d. You would need the permn() utility from here https://www.mathworks.com/matlabcentral/fileexchange/7147-permn-v-n-k

Anmt_3d = cat(3,Anmt_2d_cell);
Amnt_3d = cat(3,Amnt_2d_cell);
Anm_3d = Anmt_3d - Amnt_3d;
Anm_3d = Anm_3d.^2;
Anm = squeeze(sum(sum(Anm_3d,1),2));
d = zeros(nmat,nmat);
nm=permn(1:nmat, 2); % all permutations (n,m) with repeat, by-row order
d(sub2ind([nmat,nmat],nm(:,1),nm(:,2))) = Anm;

由于某种原因,第二个选项(3D阵列)的速度快了两倍.

For some reason, the 2nd option (3D arrays) is twice faster.

希望这会有所帮助.

这篇关于Matlab中不带循环的张量乘法的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆