Does MATLAB optimize diag(A*B)?
Problem description
Say I have two very big matrices A (M-by-N) and B (N-by-M). I need the diagonal of A*B. Computing the full A*B requires M*M*N multiplications, while computing only its diagonal requires M*N multiplications, since there's no need to compute the elements that end up outside the diagonal.
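Writing out one diagonal entry makes the count explicit: entry i of the diagonal is the inner product of row i of A with column i of B, so each of the M diagonal entries needs N multiplications.

```latex
(AB)_{ii} = \sum_{k=1}^{N} A_{ik}\,B_{ki}, \qquad i = 1,\dots,M
\quad\Longrightarrow\quad
\underbrace{M \cdot N}_{\text{diagonal only}} \;\text{ vs. }\; \underbrace{M \cdot M \cdot N}_{\text{all entries } (AB)_{ij}}
```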
Does MATLAB realize this and optimize diag(A*B) on the fly, automagically, or am I better off using a for loop in this case?
Recommended answer
One can also implement diag(A*B) as sum(A.*B',2) for real matrices (for complex matrices, use sum(A.*B.',2), since ' is the conjugate transpose). Let's benchmark this along with all the other implementations/solutions suggested for this question.
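A quick sanity check of this identity, as a sketch with arbitrary small sizes: both expressions give the same M-by-1 vector for real inputs. It also shows why B.' is the safer transpose: with complex data, B' conjugates B and the result no longer matches diag(A*B).

```matlab
% Sketch: the row-wise sum of an element-wise product reproduces diag(A*B)
% without forming the M-by-M product. Sizes are arbitrary choices.
M = 6; N = 4;
A = rand(M, N);
B = rand(N, M);

d_full = diag(A*B);          % forms all M*M entries, keeps M of them
d_sum  = sum(A .* B.', 2);   % M*N multiplications, M-by-1 result
fprintf('real inputs, max abs difference: %g\n', max(abs(d_full - d_sum)));

% With complex data, ' is the conjugate transpose, so sum(A.*B',2)
% conjugates B and no longer matches diag(A*B); the plain transpose .' does.
Ac = A + 1i*rand(M, N);
Bc = B + 1i*rand(N, M);
err_plain  = max(abs(diag(Ac*Bc) - sum(Ac .* Bc.', 2)));
err_ctrans = max(abs(diag(Ac*Bc) - sum(Ac .* Bc', 2)));
fprintf('complex, B.'': %g   complex, B'': %g\n', err_plain, err_ctrans);
```

For real matrices both transposes are identical, so the benchmarks below (which use real data) are unaffected by this choice.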
The different methods implemented as functions are listed below for benchmarking purposes:
Sum-multiplication method-1
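function out = sum_mult_method1(A,B)
% out(i) = sum_k A(i,k)*B(k,i): element-wise product, then sum along each row.
% B.' (plain transpose) rather than B' (conjugate transpose) keeps this correct for complex B.
out = sum(A.*B.',2);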
Sum-multiplication method-2
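function out = sum_mult_method2(A,B)
% Same products laid out as N-by-M, summed down each column, then transposed.
% The explicit dimension 1 keeps the result M-by-1 even when N == 1.
out = sum(A.'.*B,1).';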
For-loop method
function out = for_loop_method(A,B)
% One inner product per diagonal entry: M*N multiplications in total.
M = size(A,1);
out = zeros(M,1);
for i=1:M
    out(i) = A(i,:) * B(:,i);
end
Full/Direct-multiplication method
function out = direct_mult_method(A,B)
% Forms the full M-by-M product (M*M*N multiplications), then keeps its diagonal.
out = diag(A*B);
Bsxfun-method
function out = bsxfun_method(A,B)
% Element-wise product via bsxfun, then row sums; equivalent to method 1.
out = sum(bsxfun(@times,A,B.'),2);
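Before timing, it is worth confirming that all five methods return the same vector. The sketch below uses anonymous functions (the names sum1, sum2, bsx and direct are illustrative, not from the original answer) in place of the function files above so that it runs as one script. Integer inputs from randi keep every product and sum exact in double precision, so isequal can be used instead of a tolerance.

```matlab
% Sketch: check that the five methods agree before benchmarking them.
direct = @(A,B) diag(A*B);
sum1   = @(A,B) sum(A .* B.', 2);
sum2   = @(A,B) sum(A.' .* B, 1).';
bsx    = @(A,B) sum(bsxfun(@times, A, B.'), 2);

M = 50; N = 4;                 % arbitrary test sizes
A = randi(9, M, N);
B = randi(9, N, M);

% For-loop method written inline (anonymous functions cannot contain loops).
loop_out = zeros(M, 1);
for r = 1:M
    loop_out(r) = A(r, :) * B(:, r);
end

ref  = direct(A, B);
outs = {sum1(A,B), sum2(A,B), bsx(A,B), loop_out};
same = cellfun(@(x) isequal(x, ref), outs);   % also checks M-by-1 shape
disp(same)
```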
Benchmarking code
% The five method functions above must be on the MATLAB path
% (e.g. saved as separate .m files) for this script to run.
num_runs = 1000;
M_arr = [100 200 500 1000];
N = 4;
% Warm up tic/toc.
tic();
elapsed = toc();
tic();
elapsed = toc();
for k2 = 1:numel(M_arr)
    M = M_arr(k2);
    fprintf('\n')
    disp(strcat('*** Benchmarking sizes are M =',num2str(M),' and N = ',num2str(N)));
    A = randi(9,M,N);
    B = randi(9,N,M);
    disp('1. Sum-multiplication method-1');
    tic
    for k = 1:num_runs
        out1 = sum_mult_method1(A,B);
    end
    toc
    clear out1
    disp('2. Sum-multiplication method-2');
    tic
    for k = 1:num_runs
        out2 = sum_mult_method2(A,B);
    end
    toc
    clear out2
    disp('3. For-loop method');
    tic
    for k = 1:num_runs
        out3 = for_loop_method(A,B);
    end
    toc
    clear out3
    disp('4. Direct-multiplication method');
    tic
    for k = 1:num_runs
        out4 = direct_mult_method(A,B);
    end
    toc
    clear out4
    disp('5. Bsxfun method');
    tic
    for k = 1:num_runs
        out5 = bsxfun_method(A,B);
    end
    toc
    clear out5
end
Results
*** Benchmarking sizes are M =100 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.015242 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.015180 seconds.
3. For-loop method
Elapsed time is 0.192021 seconds.
4. Direct-multiplication method
Elapsed time is 0.065543 seconds.
5. Bsxfun method
Elapsed time is 0.054149 seconds.
*** Benchmarking sizes are M =200 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.009138 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.009428 seconds.
3. For-loop method
Elapsed time is 0.435735 seconds.
4. Direct-multiplication method
Elapsed time is 0.148908 seconds.
5. Bsxfun method
Elapsed time is 0.030946 seconds.
*** Benchmarking sizes are M =500 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.033287 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.026405 seconds.
3. For-loop method
Elapsed time is 0.965260 seconds.
4. Direct-multiplication method
Elapsed time is 2.832855 seconds.
5. Bsxfun method
Elapsed time is 0.034923 seconds.
*** Benchmarking sizes are M =1000 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.026068 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.032850 seconds.
3. For-loop method
Elapsed time is 1.775382 seconds.
4. Direct-multiplication method
Elapsed time is 13.764870 seconds.
5. Bsxfun method
Elapsed time is 0.044931 seconds.
Intermediate conclusions
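Looks like the sum-multiplication methods are the best approaches, though the bsxfun approach seems to be catching up with them as M increases from 100 to 1000. The for-loop and direct-multiplication methods fall far behind as M grows: the direct method goes from 0.066 s at M = 100 to 13.8 s at M = 1000, far faster than linear growth, which is consistent with MATLAB forming the full M-by-M product rather than optimizing diag(A*B).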
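One rough way to quantify the scaling is to fit a line to log(time) versus log(M), using the timings reported above for the for-loop and direct-multiplication methods. With only four points and wall-clock noise the slopes are approximate, but a slope near 1 indicates roughly linear growth in M and a slope of 2 or more indicates at least quadratic growth.

```matlab
% Sketch: log-log slope of run time vs. M, from the timings reported above
% (N = 4, 1000 runs per method).
M        = [100 200 500 1000];
t_loop   = [0.192021 0.435735 0.965260 1.775382];
t_direct = [0.065543 0.148908 2.832855 13.764870];

p_loop   = polyfit(log(M), log(t_loop), 1);    % p(1) is the fitted slope
p_direct = polyfit(log(M), log(t_direct), 1);

fprintf('for-loop slope: %.2f\n', p_loop(1));
fprintf('direct slope:   %.2f\n', p_direct(1));
```

On these numbers the for-loop slope comes out close to 1 (about M*N work, as expected) while the direct-multiplication slope is above 2.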
Next, larger sizes were benchmarked with just the sum-multiplication and bsxfun methods. The sizes were:
M_arr = [1000 2000 5000 10000 20000 50000];
The results were:
*** Benchmarking sizes are M =1000 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.030390 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.032334 seconds.
5. Bsxfun method
Elapsed time is 0.047377 seconds.
*** Benchmarking sizes are M =2000 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.040111 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.045132 seconds.
5. Bsxfun method
Elapsed time is 0.060762 seconds.
*** Benchmarking sizes are M =5000 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.099986 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.103213 seconds.
5. Bsxfun method
Elapsed time is 0.117650 seconds.
*** Benchmarking sizes are M =10000 and N =4
1. Sum-multiplication method-1
Elapsed time is 0.375604 seconds.
2. Sum-multiplication method-2
Elapsed time is 0.273726 seconds.
5. Bsxfun method
Elapsed time is 0.226791 seconds.
*** Benchmarking sizes are M =20000 and N =4
1. Sum-multiplication method-1
Elapsed time is 1.906839 seconds.
2. Sum-multiplication method-2
Elapsed time is 1.849166 seconds.
5. Bsxfun method
Elapsed time is 1.344905 seconds.
*** Benchmarking sizes are M =50000 and N =4
1. Sum-multiplication method-1
Elapsed time is 5.159177 seconds.
2. Sum-multiplication method-2
Elapsed time is 5.081211 seconds.
5. Bsxfun method
Elapsed time is 3.866018 seconds.
Alternative benchmarking code (with timeit)
% timeit calls each function several times and returns the median time,
% so no explicit run count is needed here.
M_arr = [1000 2000 5000 10000 20000 50000 100000 200000 500000 1000000];
N = 4;
timeall = zeros(5,numel(M_arr));   % rows follow the method numbering above
for k2 = 1:numel(M_arr)
    M = M_arr(k2);
    A = rand(M,N);
    B = rand(N,M);
    f = @() sum_mult_method1(A,B);
    timeall(1,k2) = timeit(f);
    clear f
    f = @() sum_mult_method2(A,B);
    timeall(2,k2) = timeit(f);
    clear f
    f = @() bsxfun_method(A,B);
    timeall(5,k2) = timeit(f);
    clear f
end
figure,
hold on
plot(M_arr,timeall(1,:),'-ro')
plot(M_arr,timeall(2,:),'-ko')
plot(M_arr,timeall(5,:),'-.b')
legend('sum-method1','sum-method2','bsxfun-method')
xlabel('M ->')
ylabel('Time(sec) ->')
Plot (image not available): time in seconds vs. M for sum-method1, sum-method2 and bsxfun-method.
Final conclusions
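It seems the sum-multiplication methods are best up to a certain point, around the M = 5000 mark; beyond that, bsxfun seems to have a slight upper hand. In the tic/toc runs above, bsxfun is still slower at M = 5000 but takes roughly 17 to 27 percent less time than the faster sum method from M = 10000 onward.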
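The crossover can be read directly off the larger-M timings reported above. This sketch compares bsxfun against the faster of the two sum-multiplication methods at each M; in these numbers bsxfun first wins at M = 10000, so the crossover lies somewhere between 5000 and 10000.

```matlab
% Sketch: find where bsxfun first beats both sum-multiplication methods,
% using the larger-M tic/toc timings reported above (N = 4).
M     = [1000 2000 5000 10000 20000 50000];
t_s1  = [0.030390 0.040111 0.099986 0.375604 1.906839 5.159177];
t_s2  = [0.032334 0.045132 0.103213 0.273726 1.849166 5.081211];
t_bsx = [0.047377 0.060762 0.117650 0.226791 1.344905 3.866018];

best_sum = min(t_s1, t_s2);           % faster sum method at each M
bsx_wins = t_bsx < best_sum;          % element-wise comparison per M
first_M  = M(find(bsx_wins, 1));      % first size where bsxfun is fastest
ratio    = best_sum ./ t_bsx;         % > 1 means bsxfun is faster

fprintf('bsxfun first fastest at M = %d\n', first_M);
fprintf('M = %6d: best sum time / bsxfun time = %.2f\n', [M; ratio]);
```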
Future work
One could also vary N (fixed at 4 in all the runs above) and study how the implementations mentioned here perform.
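A possible starting point for that study, as a hedged sketch: it times the method-1 expression and the bsxfun expression over a few values of N at a fixed M. The values of M, N_arr and num_runs are arbitrary assumptions, not taken from the benchmarks above, and plain tic/toc is used so the sketch does not depend on timeit.

```matlab
% Sketch: average time per call for two methods as N varies at fixed M.
% M, N_arr and num_runs are arbitrary illustrative choices.
M = 2000;
N_arr = [1 4 16 64];
num_runs = 50;
sum1 = @(A,B) sum(A .* B.', 2);
bsx  = @(A,B) sum(bsxfun(@times, A, B.'), 2);
t = zeros(2, numel(N_arr));   % row 1: sum method, row 2: bsxfun method

for k = 1:numel(N_arr)
    A = rand(M, N_arr(k));
    B = rand(N_arr(k), M);
    tic
    for r = 1:num_runs
        out = sum1(A, B);
    end
    t(1, k) = toc / num_runs;
    tic
    for r = 1:num_runs
        out = bsx(A, B);
    end
    t(2, k) = toc / num_runs;
end

fprintf('N = %3d: sum %.2e s, bsxfun %.2e s\n', [N_arr; t]);
```

Timings from a short loop like this are noisy; for publishable numbers, timeit (as in the alternative benchmarking code) is the more robust choice.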