Matrix multiplication speed-up trick in MATLAB

April 23rd, 2013 | Categories: Making MATLAB faster, matlab, programming

I was recently working on some MATLAB code with Manchester University’s David McCormick. Buried deep within this code was a function that was called many, many times, taking up a significant share of the overall run time. We managed to speed up an important part of this function by almost a factor of two (on his machine) simply by inserting two brackets: a new personal record in overall application performance improvement per number of keystrokes.

The code in question is hugely complex, but the trick we used is really very simple. Consider the following MATLAB code:

>> a=rand(4000);
>> c=12.3;
>> tic;res1=c*a*a';toc
Elapsed time is 1.472930 seconds.

With just two brackets inserted, the same calculation runs quite a bit faster on my Ivy Bridge quad-core desktop:

>> tic;res2=c*(a*a');toc
Elapsed time is 0.907086 seconds.

So, what’s going on? We think that in the first version of the code, MATLAB evaluates the expression from left to right. It first calculates c*a to form a temporary matrix (let’s call it temp) and then goes on to find temp*a'. Because temp and a are different matrices, that second step is a general matrix-matrix multiplication. In the second version, we think that MATLAB calculates a*a' first. In doing so, it takes advantage of the fact that the product of a matrix with its own transpose is symmetric, so only about half of the result needs to be computed. This, we believe, is where the speedup comes from.
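The following is a minimal sketch of the symmetry argument. It uses pure Python with plain nested lists rather than MATLAB or a real BLAS, so absolute speeds mean nothing here and only the operation counts matter. `general_product` fills every entry of X·Yᵀ, as a general multiplication of temp by a' would. The hypothetical `scaled_gram` computes only the upper triangle of c·A·Aᵀ and mirrors it into the lower triangle. This is the idea behind the BLAS routine DSYRK (C := alpha·A·Aᵀ + beta·C, updating one triangle). Whether MATLAB actually dispatches to DSYRK for c*(a*a') is an assumption on our part, not something we have confirmed.

```python
def general_product(x, y):
    """Return (x @ y^T, multiply count), computing every entry. x, y are lists of rows."""
    rows, cols, inner = len(x), len(y), len(x[0])
    out = [[0.0] * cols for _ in range(rows)]
    mults = 0
    for i in range(rows):
        for j in range(cols):
            s = 0.0
            for k in range(inner):
                s += x[i][k] * y[j][k]
                mults += 1
            out[i][j] = s
    return out, mults


def scaled_gram(c, a):
    """Return (c * a @ a^T, multiply count), computing only the upper triangle.

    Hypothetical helper in the spirit of BLAS DSYRK: the scalar c is applied
    once per computed entry instead of to a full temporary copy of a.
    The count covers inner-loop multiplications only (not the scaling by c).
    """
    n, inner = len(a), len(a[0])
    out = [[0.0] * n for _ in range(n)]
    mults = 0
    for i in range(n):
        for j in range(i, n):          # upper triangle only (j >= i)
            s = 0.0
            for k in range(inner):
                s += a[i][k] * a[j][k]
                mults += 1
            out[i][j] = c * s
            out[j][i] = out[i][j]      # mirror into the lower triangle
    return out, mults


if __name__ == "__main__":
    import random
    random.seed(0)
    n = 60
    a = [[random.random() for _ in range(n)] for _ in range(n)]
    c = 12.3
    # Mimic the left-to-right c*a*a': scale a into temp, then a general product.
    temp = [[c * v for v in row] for row in a]
    res1, m1 = general_product(temp, a)
    res2, m2 = scaled_gram(c, a)
    print(m1, m2)  # n*n*n versus n*(n+1)/2*n inner-loop multiplications
    asym = max(abs(res2[i][j] - res2[j][i]) for i in range(n) for j in range(n))
    print(asym)    # exactly 0.0: symmetry holds by construction
```

Run as a script, it reports 216000 versus 109800 inner-loop multiplications for n = 60. The mirrored result has an asymmetry of exactly 0.0, because each off-diagonal pair is written from a single computed value.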

Here is another demonstration of this phenomenon:

>> a=rand(4000);
>> b=rand(4000);
>> tic;a*a';toc 
Elapsed time is 0.887524 seconds.
>> tic;a*b;toc  
Elapsed time is 1.473208 seconds.
>> tic;b*b';toc
Elapsed time is 0.966085 seconds.

Note that the two multiplications that produce a symmetric result (a*a' and b*b') are faster than the general, non-symmetric one (a*b).
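A rough operation count suggests a speedup of this size is plausible, assuming MATLAB fills only one triangle of the symmetric result. For n×n inputs, a general product takes n³ multiply-adds, while a one-triangle product takes n·n(n+1)/2. With n = 4000 and the timings above:

```latex
\begin{aligned}
\text{general } (a*b):&\quad n^3 = 4000^3 = 6.4\times 10^{10}\\
\text{one triangle } (a*a'):&\quad n\cdot\frac{n(n+1)}{2} = 4000\cdot\frac{4000\cdot 4001}{2} \approx 3.2\times 10^{10}\\
\text{measured ratios:}&\quad \frac{0.887524}{1.473208}\approx 0.60,\qquad \frac{0.966085}{1.473208}\approx 0.66
\end{aligned}
```

The ideal ratio is about 0.5, and the measured ones are somewhat higher. That is unsurprising, since raw operation counts ignore memory traffic and any cost of mirroring the computed triangle into the other half. This is a plausible explanation, not one we have measured.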

  1. April 23rd, 2013 at 18:48

    This is just awesome!
    Now I have to check all my old files and find where I can apply the trick ;)

  2. Martin Cohen
    April 24th, 2013 at 05:25

    I recall a variation on this when computing a triple matrix product from my long-ago numerical analysis days.

    The basic idea is that if you want to compute a matrix product A*B*C, where A is n by m, B is m by p, and C is p by q, you count the operations using the result that nops(u by v, v by w) = uvw, i.e. multiplying a u-by-v matrix by a v-by-w matrix takes uvw scalar multiplications.

    For A*(B*C), the number of ops is mpq for B*C and nmq for A times that, for a total of mpq + nmq = mq(n+p).

    For (A*B)*C, the number of ops is nmp for A*B and npq for that times C, for a total of nmp + npq = np(m+q).

    Differently sized matrices can result in vastly different operation counts.

    It becomes an interesting optimization problem when the number of matrices being multiplied gets large.

    I would be surprised if this is not in Knuth.

  3. Stephen
    April 24th, 2013 at 10:05

    Another option is to form ca = sqrt(abs(c))*a and then form ca*ca', since c*(a*a') requires n^2 scalar multiplies (if length(a)=n), which are unnecessary. If c isn’t positive, you could throw in a -1 multiplication at the end (which is probably faster than floating-point multiplication).

  4. per isakson
    April 24th, 2013 at 11:10

    The JIT/Accelerator for MATLAB does not take advantage of the fact pointed out by Martin Cohen (AFAIK from experiments). I find that remarkable, but am not sure whether that tells more about me than about The MathWorks.

  5. April 24th, 2013 at 12:58

    Nice. I was skeptical, thinking that (a*a) might also be faster, because only one matrix is involved. But I was wrong: it is as slow as (a*b). (a*b') is slow too. I checked, just in case there was some reason why it’s better for the second argument to be transposed.

  6. April 24th, 2013 at 15:12

    I wondered what happened on the free software front. Happily, the speedup occurs in Octave too. Unfortunately, dot(a,a.T) in NumPy doesn’t get the speedup, and compiling with Theano didn’t help. Presumably one could call the right Fortran routine directly from within Python, but that isn’t an “insert two brackets” fix :-(.

  7. Nick Higham
    April 24th, 2013 at 20:10

    Martin’s point about the associative law is a very good one. It’s well known and is something we emphasize in our Masters level Numerical Linear Algebra course in Manchester. It underlies some important techniques, such as the forward versus reverse mode versions of automatic differentiation. What seems to be going on in Mike’s example is different, and down to the vagaries of code optimization rather than the mathematics.

  8. April 25th, 2013 at 16:35

    Octave explicitly spots this case and uses DSYRK from BLAS to work out half of the result matrix, then copies the answers across to the other half. Here’s a numpy version of that:
    http://homepages.inf.ed.ac.uk/imurray2/code/tdot/tdot.py
    Using a fast (A*A') primitive could speed up things like distance computations and finding covariances.

  9. Mike Croucher
    April 25th, 2013 at 17:07

    Iain, do you use MKL with NumPy? My experience is that an MKL-linked NumPy is much faster than the competition (on my Intel hardware, at least).

  10. April 25th, 2013 at 18:16

    I’ve tried MKL as linked by Anaconda. It doesn’t hugely beat OpenBLAS or a tuned compile of ATLAS for me, although the ranking of the three changes between my laptop and desktop.

    The main advantage of MKL and OpenBLAS for me is being able to set the number of threads with environment variables: use all CPUs for interactive use, 1 thread for embarrassingly parallel production runs.

    I’ve tried tdot.py with both MKL (Anaconda Python) and OpenBLAS (Ubuntu Python). My code uses DSYRK from whichever BLAS NumPy is using. I get similar times for tdot(A) with both, with both beating dot(A,A.T).

  11. Mike Croucher
    April 26th, 2013 at 09:50

    Nice to see that the Julia people have been looking at this idea too:

    https://groups.google.com/forum/#!topic/julia-dev/Vrph5SGXnAk

  12. William Heath
    May 1st, 2013 at 09:15

    You can test Mike’s hypothesis by typing
    >> max(max(abs(res1-res1')))
    and
    >> max(max(abs(res2-res2')))
    Only the latter returns zero, i.e. the symmetry has been recognised.
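Martin Cohen's operation count in comment 2 extends to longer chains, which is the "interesting optimization problem" he mentions. Below is a minimal Python sketch of the standard dynamic-programming solution to matrix-chain ordering. Each pairwise product is costed with his u·v·w scalar multiplications for a (u by v)·(v by w) product. `chain_order` is a hypothetical name, and the matrix sizes in the example are made up for illustration.

```python
from functools import lru_cache


def chain_order(dims):
    """Cheapest parenthesisation of a chain of matrix products.

    dims has length k+1 for k matrices: matrix i is dims[i] by dims[i+1].
    Multiplying (u by v) by (v by w) is taken to cost u*v*w.
    Returns (minimum scalar multiplications, parenthesised expression).
    """
    names = [chr(ord("A") + i) for i in range(len(dims) - 1)]

    @lru_cache(maxsize=None)
    def best(i, j):
        # Cheapest way to form the product of matrices i..j (inclusive).
        if i == j:
            return 0, names[i]
        options = []
        for split in range(i, j):
            left_cost, left_expr = best(i, split)
            right_cost, right_expr = best(split + 1, j)
            cost = left_cost + right_cost + dims[i] * dims[split + 1] * dims[j + 1]
            options.append((cost, f"({left_expr}*{right_expr})"))
        return min(options)

    return best(0, len(dims) - 2)


if __name__ == "__main__":
    # A is 10x100, B is 100x5, C is 5x50 (illustrative sizes).
    print(chain_order([10, 100, 5, 50]))  # (7500, '((A*B)*C)')
```

For those sizes his formulas give np(m+q) = 10·5·150 = 7500 for (A*B)*C and mq(n+p) = 100·50·15 = 75000 for A*(B*C), and the search picks the cheaper one. The memoised recursion visits O(k²) subchains with O(k) split points each, so it runs in O(k³) time for k matrices.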
