How to multiply N matrices without a FOR loop? (Slices of 3D array)
I have a 2x2xN 3D array which, for my purposes, is essentially a stack of N 2x2 matrices, and I want to matrix-multiply all of them so that I get the following result:
N = 14;
M = rand(2,2,N);
Z = M(:,:,1)*M(:,:,2)* ... *M(:,:,N);
size(Z) == [2 2]
I can do it with a for loop, but I am looking for a single line approach, something like:
prod(M,3);
but using mtimes, so that it performs matrix multiplication along the 3rd dimension (not the element-wise product).
I also converted M into an Nx1 cell array of 2x2 matrices, but that approach did not work for the multiplication either.
8 Comments
Jan
on 7 Dec 2017
Edited: Jan
on 7 Dec 2017
Stephen's comment is very good.
To estimate the effect of optimizing the code, the typical input sizes matter: is it really a [2 x 2 x N] array, and what values of N do you have? For larger numbers of rows and columns, the main work is done by mtimes, and the loop overhead does not matter much. mtimes calls optimized BLAS or ATLAS functions, so there is no room for further improvement there. But I do not know whether these library functions handle tiny 2x2 matrices with unrolled loops, so perhaps a C-Mex function could be more efficient.
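To get a feeling for how much the slice size matters, a rough timing sketch along the following lines might help. It times a plain loop with timeit for a few slice sizes; the function names timeSliceSizes and loopProd, the sizes tried, and the scaling of the random data are illustrative assumptions, not something from this thread.

```matlab
function timeSliceSizes
% Rough sketch: time the plain loop for several slice sizes (sizes are arbitrary).
N = 1000;                              % number of slices, chosen for illustration
for m = [2 10 100]                     % slices are m-by-m
    M = rand(m, m, N) * (2 / m);       % assumed scaling: row sums near 1, so the product stays moderate
    t = timeit(@() loopProd(M));       % timeit calls the function several times and returns a typical time
    fprintf('m = %3d: %.3g s total, %.3g s per multiplication\n', m, t, t / (N - 1));
end
end

function out = loopProd(M)
% Hypothetical helper: sequential product M(:,:,1) * M(:,:,2) * ... * M(:,:,end)
out = M(:, :, 1);
for k = 2:size(M, 3)
    out = out * M(:, :, k);
end
end
```

If the time per multiplication is nearly the same for m = 2 and m = 10, the cost is dominated by overhead rather than by arithmetic, which is the case where unrolled code has a chance to help.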
Answers (5)
Jan
on 7 Dec 2017
Edited: Jan
on 7 Dec 2017
If you really have 2x2 sub matrices to accumulate, try a C-Mex function:
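#include "mex.h"

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
  const mwSize *size;
  mwSize N, ndim;
  double *p, *q, q11, q12, q21, q22, t11, t21;

  /* Validate the type before touching the data */
  if (nrhs != 1 || !mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) ||
      mxIsSparse(prhs[0])) {
    mexErrMsgIdAndTxt("JSimon:CumMProd2x2:BadInput1",
                      "1st input must be a real, full double [2 x 2 x N] array.");
  }

  /* A plain [2 x 2] matrix has only 2 dimensions, so size[2] must not be read */
  size = mxGetDimensions(prhs[0]);
  ndim = mxGetNumberOfDimensions(prhs[0]);
  N    = (ndim > 2) ? size[2] : 1;
  if (size[0] != 2 || size[1] != 2 || ndim > 3 || N == 0) {
    mexErrMsgIdAndTxt("JSimon:CumMProd2x2:BadInput1",
                      "1st input must be a [2 x 2 x N] array with N >= 1.");
  }

  /* Start with the first slice (column-major order) */
  p   = mxGetPr(prhs[0]);
  q11 = p[0];
  q21 = p[1];
  q12 = p[2];
  q22 = p[3];

  while (--N) {  /* Unrolled 2x2 matrix multiplication: q = q * next slice */
    p  += 4;
    t11 = q11 * p[0] + q12 * p[1];
    t21 = q21 * p[0] + q22 * p[1];
    q12 = q11 * p[2] + q12 * p[3];
    q22 = q21 * p[2] + q22 * p[3];
    q11 = t11;
    q21 = t21;
  }

  /* Copy the accumulated product to the output */
  plhs[0] = mxCreateDoubleMatrix(2, 2, mxREAL);
  q    = mxGetPr(plhs[0]);
  q[0] = q11;
  q[1] = q21;
  q[2] = q12;
  q[3] = q22;
  return;
}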
[EDITED] This is tested now. The speed is very interesting:
function speed
x = rand(2, 2, 1000);
tic; for k = 1:1000, y = CumMProd2x2(x); end; toc
tic; for k = 1:1000, y = CumMProd2x2_AB(x); end; toc
tic
for k = 1:1000  % Jos (10584)
    iif = @(varargin) varargin{2*find([varargin{1:2:end}], 1, 'first')}() ;
    mprodf = @(F,M,n) iif (n < 2, M(:,:,1), true, @() F(F,M,n-1) * M(:,:,n)) ;
    out = mprodf(mprodf, x, size(x, 3));
end
toc
end

function out = CumMProd2x2_AB(M)  % Andrei Bobrov
s = size(M, 3);
out = M(:,:,1);
for ii = 2:s
    out = out * M(:,:,ii);
end
end
R2016b/64/Win7:
Elapsed time is 0.011403 seconds. C-mex
Elapsed time is 3.884977 seconds. Loop
Elapsed time is 96.038754 seconds. Recursive anonymous function
I was surprised that Andrei's loop is so slow, although it is clearly the nicest and cleanest solution. Let's try to unroll the loops as in the C code:
function out = CumMProd2x2_unroll(M)
q11 = M(1);
q21 = M(2);
q12 = M(3);
q22 = M(4);
c = 1;  % linear index of the first element of the current slice
for ii = 2:size(M, 3)
    c = c + 4;
    t11 = q11 * M(c) + q12 * M(c+1);
    t21 = q21 * M(c) + q22 * M(c+1);
    q12 = q11 * M(c+2) + q12 * M(c+3);
    q22 = q21 * M(c+2) + q22 * M(c+3);
    q11 = t11;
    q21 = t21;
end
out = [q11, q12; q21, q22];
end
This is about 63 times faster than the direct approach "out * M(:,:,ii)" (3.88 s above):
Elapsed time is 0.061287 seconds. Unrolled
Evidently MATLAB calls highly optimized libraries for the matrix multiplication, which treat the tiny input with the same heavy machinery as a 1000x1000 matrix, so the overhead presumably dominates.
But this unrolled version is so ugly that I would hesitate to use it in production code. For x = rand(2, 2, 100000) I get these timings for 1000 iterations:
Elapsed time is 1.377695 seconds. C-mex
Elapsed time is 2.872356 seconds. M with unrolled mtimes
Only a factor of 2! Another example showing that loops are not so bad in MATLAB compared to C.
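Since the unrolled arithmetic and mtimes may round differently, it seems worth checking that both give the same product up to rounding before trusting the timings. The following is a minimal sketch of such a check; it repeats the unrolled update inline so that it is self-contained, and the tolerance 1e-10 is an assumed value rather than a derived bound.

```matlab
% Sketch: compare the plain mtimes loop with the unrolled 2x2 update on the same data.
x = rand(2, 2, 1000);

% Reference: sequential loop with mtimes
ref = x(:, :, 1);
for k = 2:size(x, 3)
    ref = ref * x(:, :, k);
end

% Unrolled 2x2 update q = q * p, written out element by element
q = x(:, :, 1);
for k = 2:size(x, 3)
    p = x(:, :, k);
    q = [q(1,1)*p(1,1) + q(1,2)*p(2,1), q(1,1)*p(1,2) + q(1,2)*p(2,2); ...
         q(2,1)*p(1,1) + q(2,2)*p(2,1), q(2,1)*p(1,2) + q(2,2)*p(2,2)];
end

% Compare with a relative tolerance instead of isequal (tolerance is an assumption)
relErr = norm(ref - q, 'fro') / norm(ref, 'fro');
fprintf('relative difference: %.3g\n', relErr);
assert(relErr < 1e-10, 'Results differ by more than rounding error');
```

A relative comparison is used because the product of many random matrices can become very small or very large in absolute terms.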
2 Comments
Jos (10584)
on 7 Dec 2017
haha, I really liked my anonymous function approach, and did expect it to perform poorly, but that poor ... haha
Andrei Bobrov
on 6 Dec 2017
s = size(M);          % semicolon added to suppress the output
out = M(:,:,1);
for ii = 2:s(3)
    out = out*M(:,:,ii);
end
5 Comments
Jan
on 7 Dec 2017
+1: This is the nicest solution. That the multiplication of 2x2 matrices is much faster with a hard-coded algorithm is not a problem of this solution.
Although the C-Mex approach is faster, it would be very hard to generalize it to inputs other than 2x2xN arrays.
Matt J
on 7 Dec 2017
Jan wrote: "Although the C-Mex approach is faster, it would be very hard to generalize it to inputs other than 2x2xN arrays."
Just wanted to note that, while my solution based on MTIMESX is not as fast as Jan's for the 2x2xN case, it is applicable to arbitrary MxMxN arrays.
Matt J
on 6 Dec 2017
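The following is not a one-line solution (for that, just put it in a function file) and requires MTIMESX from the File Exchange. However, I do see a speed-up of a few times over a conventional for-loop. It multiplies neighbouring pairs of slices in each pass, so the number of slices roughly halves every iteration: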
out=M;
while size(out,3)>1
    n=size(out,3);
    if mod(n,2)
        n=n-1;
        A=out(:,:,1:2:n);
        B=out(:,:,2:2:n);
        out=cat(3,mtimesx(A,B),out(:,:,n+1));
    else
        A=out(:,:,1:2:n);
        B=out(:,:,2:2:n);
        out=mtimesx(A,B);
    end
end
5 Comments
James Tursa
on 7 Dec 2017
Edited: James Tursa
on 7 Dec 2017
Side Note: By default, MTIMESX calls BLAS library routines for the matrix multiply so that it matches the result of a MATLAB for-loop in m-code, whereas MTIMESX with the 'SPEED' option will use hand-coded inline matrix multiply code for slices up to 5x5, which may not match the MATLAB for-loop result exactly.
Some time back I had a beta version of MTIMESX that implemented the matrix equivalents of 'prod' and 'cumprod'. Maybe it is time I dust that off and finish the implementation/testing so I can publish it.
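To make clear what a matrix equivalent of 'cumprod' would return, here is a minimal loop-based sketch. It is not MTIMESX code and says nothing about its interface; the name cumMatProd is hypothetical, and square slices are assumed.

```matlab
function C = cumMatProd(M)
% Hypothetical helper: C(:,:,k) = M(:,:,1) * M(:,:,2) * ... * M(:,:,k)
% Assumes M is an m-by-m-by-N array with square slices.
C = M;                                      % page 1 is already the correct partial product
for k = 2:size(M, 3)
    C(:, :, k) = C(:, :, k-1) * M(:, :, k); % extend the previous partial product by one factor
end
end
```

The last page, C(:,:,end), is then the matrix equivalent of 'prod' asked for in the question.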
Matt J
on 7 Dec 2017
That is strange, since I still see significant speed-up even with
mtimesx MATLAB
Steven Lord
on 17 Sep 2020
If you're using release R2020b or later, take a look at the pagemtimes function introduced in that release.
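As a rough sketch of how pagemtimes could be used here: it multiplies corresponding pages of two arrays, so it does not reduce along the 3rd dimension by itself, but it can take the place of MTIMESX in a pairwise reduction. The function name pageProd is hypothetical, and I have not timed this against the other approaches in this thread.

```matlab
function out = pageProd(M)
% Hypothetical helper: M(:,:,1) * M(:,:,2) * ... * M(:,:,end) using pagemtimes (R2020b or later).
out = M;
while size(out, 3) > 1
    n = size(out, 3);
    m = n - mod(n, 2);                                    % largest even number of pages
    P = pagemtimes(out(:, :, 1:2:m), out(:, :, 2:2:m));   % multiply neighbouring pairs, order preserved
    if m < n
        P = cat(3, P, out(:, :, n));                      % carry the unpaired last page along
    end
    out = P;
end
end
```

Because matrix multiplication is associative, grouping the factors in pairs gives the same product as the sequential loop up to rounding, and the number of passes grows only with log2(N).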
0 Comments
Jos (10584)
on 6 Dec 2017
Here is one using recursion without a for-loop; not faster though, and somewhat mysterious, but just nice :) ...
M = randi(5,[2 2 4]) ; % data
iif = @(varargin) varargin{2*find([varargin{1:2:end}], 1, 'first')}() ;
mprodf = @(F,M,n) iif (n < 2, M(:,:,1), true, @() F(F,M,n-1) * M(:,:,n)) ;
out = mprodf(mprodf,M,size(M,3)) % voila, it works!
3 Comments
Jos (10584)
on 7 Dec 2017
It is the inline version of this recursive m-file:
function X = mprod(M,n)
% X = mprod(M) returns M(:,:,1) * M(:,:,2) * ... * M(:,:,end)
% where M is a 3D array
if nargin==1
    X = mprod(M,size(M,3)) ;
elseif n < 2
    X = M(:,:,1) ;
else
    X = mprod(M,n-1) * M(:,:,n) ;
end