Checkpointing MATLAB Programs

February 19th, 2014 | Categories: Condor, matlab, programming, random numbers, tutorials

I occasionally get emails from researchers saying something like this:

‘My MATLAB code takes a week to run and the cleaner/cat/my husband keeps switching off my machine before it’s completed. Could you help me make the code go faster please, so that I can get my results in between these events?’

While I am more than happy to try to optimise the code in question, what these users really need is some sort of checkpointing scheme. Checkpointing is also important for users of high performance computing systems that limit the length of each individual job.

The solution – Checkpointing (or ‘Assume that your job will frequently be killed’)

The basic idea behind checkpointing is to periodically save your program’s state so that, if it is interrupted, it can start again where it left off rather than from the beginning. In order to demonstrate some of the principles involved, I’m going to need some code that’s sufficiently simple that it doesn’t cloud what I want to discuss. Let’s add up some numbers using a for-loop.

%addup.m
%This is not the recommended way to sum integers in MATLAB -- we only use it here to keep things simple
%This version does NOT use checkpointing

mysum=0;
for count=1:100
    mysum = mysum + count;
    pause(1);           %Let's pretend that this is a complicated calculation
    fprintf('Completed iteration %d \n',count);
end

fprintf('The sum is %f \n',mysum);

Using a for-loop to perform an addition like this is something that I’d never usually suggest in MATLAB but I’m using it here because it is so simple that it won’t get in the way of understanding the checkpointing code.

If you run this program in MATLAB, it will take about 100 seconds thanks to that pause statement, which is acting as a proxy for some real work. Try interrupting it by pressing CTRL-C and then restarting it. As you might expect, it will always start from the beginning:

>> addup
Completed iteration 1
Completed iteration 2
Completed iteration 3
Operation terminated by user during addup (line 6)

>> addup
Completed iteration 1
Completed iteration 2
Completed iteration 3
Operation terminated by user during addup (line 6)

This is no big deal when your calculation only takes 100 seconds but is going to be a major problem when the calculation represented by that pause statement becomes something like an hour rather than a second.

Let’s now look at a version of the above that makes use of checkpointing.

%addup_checkpoint.m
if exist( 'checkpoint.mat','file' ) % If a checkpoint file exists, load it
    fprintf('Checkpoint file found - Loading\n');
    load('checkpoint.mat')

else %otherwise, start from the beginning
    fprintf('No checkpoint file found - starting from beginning\n');
    mysum=0;
    countmin=1;
end

for count = countmin:100
    mysum = mysum + count;
    pause(1);           %Let's pretend that this is a complicated calculation

    %save checkpoint
    countmin = count+1;  %If we load this checkpoint, we want to start on the next iteration
    fprintf('Saving checkpoint\n');
    save('checkpoint.mat');

    fprintf('Completed iteration %d \n',count);
end
fprintf('The sum is %f \n',mysum);

Before you run the above code, the checkpoint file checkpoint.mat does not exist and so the calculation starts from the beginning. After every iteration, a checkpoint file is created which contains every variable in the MATLAB workspace. If the program is restarted, it will find the checkpoint file and continue where it left off. Our code now deals with interruptions a lot more gracefully.

>> addup_checkpoint
No checkpoint file found - starting from beginning
Saving checkpoint
Completed iteration 1 
Saving checkpoint
Completed iteration 2 
Saving checkpoint
Completed iteration 3 
Operation terminated by user during addup_checkpoint (line 16)

>> addup_checkpoint
Checkpoint file found - Loading
Saving checkpoint
Completed iteration 4 
Saving checkpoint
Completed iteration 5 
Saving checkpoint
Completed iteration 6 
Operation terminated by user during addup_checkpoint (line 16)

Note that we’ve had to change the program logic slightly. Our original loop counter was

for count = 1:100

In the checkpointed example, however, we’ve had to introduce the variable countmin:

for count = countmin:100

This allows us to start the loop from whatever value of countmin was in our last checkpoint file. Such minor modifications are often necessary when converting code to use checkpointing and you should carefully check that the introduction of checkpointing does not introduce bugs in your code.
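One way to do that check without pressing CTRL-C by hand is to simulate the interruption. The sketch below is a hypothetical test script (the file name test_checkpoint.mat and the stopAfters variable are my own choices, not part of the original code). It runs the checkpointed loop twice, "killing" the first attempt with a break after 40 iterations, and then checks that the resumed attempt still gives 5050, the same answer as the non-checkpointed addup.m.

```matlab
% Simulate an interruption and check that resuming gives the same answer
ckpt = 'test_checkpoint.mat';          % hypothetical file used only by this test
if exist(ckpt, 'file'), delete(ckpt); end

stopAfters = [40 Inf];                 % attempt 1 is "killed" after 40 iterations; attempt 2 finishes
results = zeros(1, 2);
for attempt = 1:2
    % --- same start-up logic as addup_checkpoint.m ---
    if exist(ckpt, 'file')
        c = load(ckpt);
        mysum = c.mysum;
        countmin = c.countmin;
    else
        mysum = 0;
        countmin = 1;
    end
    % --- main loop with an artificial interruption ---
    done = 0;
    for count = countmin:100
        mysum = mysum + count;
        countmin = count + 1;          % restart point for the next attempt
        save(ckpt, 'mysum', 'countmin');
        done = done + 1;
        if done >= stopAfters(attempt)
            break;                     % pretend the machine was switched off here
        end
    end
    results(attempt) = mysum;
end
delete(ckpt);

assert(results(1) == sum(1:40));       % partial sum at the point of "failure"
assert(results(2) == sum(1:100));      % resumed run matches the uninterrupted answer
```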

Don’t checkpoint too often

Writing even a small checkpoint file takes time, which matters when the calculation itself is cheap. Consider our original addup code but without the pause command.

%addup_nopause.m
%This version does NOT use checkpointing
mysum=0;
for count=1:100
    mysum = mysum + count;
    fprintf('Completed iteration %d \n',count);
end
fprintf('The sum is %f \n',mysum);

On my machine, this code takes 0.0046 seconds to execute. Compare this to the checkpointed version, again with the pause statement removed.

%addup_checkpoint_nopause.m

if exist( 'checkpoint.mat','file' ) % If a checkpoint file exists, load it
    fprintf('Checkpoint file found - Loading\n');
    load('checkpoint.mat')

else %otherwise, start from the beginning
    fprintf('No checkpoint file found - starting from beginning\n');
    mysum=0;
    countmin=1;
end

for count = countmin:100
    mysum = mysum + count;

    %save checkpoint
    countmin = count+1;  %If we load this checkpoint, we want to start on the next iteration
    fprintf('Saving checkpoint\n');
    save('checkpoint.mat');

    fprintf('Completed iteration %d \n',count);
end
fprintf('The sum is %f \n',mysum);

This checkpointed version takes 0.85 seconds to execute on the same machine, over 180 times slower than the original! The problem is that the time it takes to save a checkpoint is long compared to the calculation time.
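You can estimate that ratio for your own code and use it to pick a checkpoint interval. The following is a rough sketch, assuming one iteration of your real calculation takes about iterTime seconds and that you are willing to spend at most a fraction maxOverhead of the run time saving checkpoints; both values are hypothetical placeholders you would replace with your own. Saving once every k iterations costs roughly saveTime/(k*iterTime) of the run time, so the sketch takes the smallest whole k that keeps that below maxOverhead.

```matlab
% Rough estimate of a checkpoint interval (parameter values are assumptions)
iterTime = 1.0;          % seconds per iteration of the real calculation (hypothetical)
maxOverhead = 0.01;      % spend at most ~1% of the run time saving checkpoints

tic
save('timing_test.mat'); % time one save of the current workspace
saveTime = toc;
delete('timing_test.mat');

% Saving once every k iterations costs about saveTime/(k*iterTime) of the run time
k = max(1, ceil(saveTime / (maxOverhead * iterTime)));
fprintf('One save took %f seconds; checkpoint every %d iterations\n', saveTime, k);
```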

If we make a modification so that we only checkpoint every 25 iterations, code execution time comes down to 0.05 seconds:

%Checkpoint every 25 iterations

if exist( 'checkpoint.mat','file' ) % If a checkpoint file exists, load it
    fprintf('Checkpoint file found - Loading\n');
    load('checkpoint.mat')

else %otherwise, start from the beginning
    fprintf('No checkpoint file found - starting from beginning\n');
    mysum=0;
    countmin=1;
end

for count = countmin:100
    mysum = mysum + count;
    countmin = count+1;  %If we load this checkpoint, we want to start on the next iteration

    if mod(count,25)==0
        %save checkpoint   
        fprintf('Saving checkpoint\n');
        save('checkpoint.mat');
    end

    fprintf('Completed iteration %d \n',count);
end
fprintf('The sum is %f \n',mysum);

Of course, the issue now is that we might lose up to 24 iterations of work if our program is interrupted between checkpoints. Additionally, in this particular case, the mod command used to decide whether or not to checkpoint is more expensive than simply performing the calculation, but hopefully that isn’t going to be the case when working with real-world calculations.

In practice, we have to strike a balance: checkpoint often enough that we don’t stand to lose too much work, but not so often that our program runs too slowly.
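Counting iterations isn’t the only way to strike that balance. When iterations vary a lot in length, an alternative (not used in the original code) is to checkpoint whenever a fixed amount of wall-clock time has passed since the last save. Here is a minimal sketch using tic and toc; checkpointInterval is a hypothetical parameter you would tune for your own job, and the pause-free addup loop stands in for the real work.

```matlab
% Checkpoint on elapsed wall-clock time rather than on iteration count
checkpointInterval = 600;    % seconds between checkpoints (hypothetical value)

if exist('checkpoint.mat', 'file')
    fprintf('Checkpoint file found - Loading\n');
    load('checkpoint.mat')
else
    fprintf('No checkpoint file found - starting from beginning\n');
    mysum = 0;
    countmin = 1;
end

lastCheckpoint = tic;        % timer ID marking the most recent checkpoint
for count = countmin:100
    mysum = mysum + count;
    countmin = count + 1;    % if we load this checkpoint, start on the next iteration
    if toc(lastCheckpoint) >= checkpointInterval
        fprintf('Saving checkpoint\n');
        save('checkpoint.mat', 'mysum', 'countmin');
        lastCheckpoint = tic;    % restart the timer
    end
end
fprintf('The sum is %f \n', mysum);
```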

Checkpointing code that involves random numbers

Extra care needs to be taken when running code that involves random numbers. Consider a modification of our checkpointed adding program that creates a sum of random numbers.

%addup_checkpoint_rand.m
%Adding random numbers the slow way, in order to demo checkpointing
%This version has a bug

if exist( 'checkpoint.mat','file' ) % If a checkpoint file exists, load it
    fprintf('Checkpoint file found - Loading\n');
    load('checkpoint.mat')

else %otherwise, start from the beginning
    fprintf('No checkpoint file found - starting from beginning\n');
    mysum=0;
    countmin=1;
    rng(0);     %Seed the random number generator for reproducible results
end

for count = countmin:100
    mysum = mysum + rand();
    countmin = count+1;  %If we load this checkpoint, we want to start on the next iteration
    pause(1); %pretend this is a complicated calculation

    %save checkpoint
    fprintf('Saving checkpoint\n');
    save('checkpoint.mat');

    fprintf('Completed iteration %d \n',count);
end
fprintf('The sum is %f \n',mysum);

In the above, we set the seed of the random number generator to 0 at the beginning of the calculation. This ensures that we always get the same set of random numbers and allows us to get reproducible results. As such, the sum should always come out to be 52.799447 to the number of decimal places used in the program.

The above code has a subtle bug that you won’t find if your testing is confined to interrupting with CTRL-C and then restarting within the same interactive MATLAB session. Proceed that way, and you’ll get exactly the sum you’d expect: 52.799447. If, on the other hand, you test your code by doing the following:

  • Run for a few iterations
  • Interrupt with CTRL-C
  • Restart MATLAB
  • Run the code again, ensuring that it starts from the checkpoint

You’ll get a different result. This is not what we want!

The root cause of this problem is that we are not saving the state of the random number generator in our checkpoint file. Thus, when we restart MATLAB, all information concerning this state is lost. If we don’t restart MATLAB between interruptions, the state of the random number generator is safely tucked away behind the scenes.

Assume, for example, that you stop the calculation running after the third iteration. The random numbers you’d have consumed would be (to 4 d.p.):

0.8147
0.9058
0.1270

Your checkpoint file will contain the variables mysum, count and countmin but will contain nothing about the state of the random number generator. In English, this state is something like ‘The next random number will be the 4th one in the sequence defined by a starting seed of 0.’

When we restart MATLAB, the random number generator goes back to its default seed of 0. Since that is also the seed we set explicitly in our code, we’ll be using the right sequence, but we’ll be starting from the very beginning of it again. That is, the 4th, 5th and 6th iterations of the summation will use the first 3 numbers in the stream, thus double counting them, and so our checkpointing procedure will alter the results of the calculation.
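You can see the double counting directly at the command line. In this sketch, calling rng(0,'twister') stands in for restarting MATLAB, on the assumption that your session starts with MATLAB’s default generator settings (Mersenne Twister, seed 0).

```matlab
% Uninterrupted run: iterations 1-3, then iteration 4
rng(0, 'twister');
firstThree = rand(1, 3);         % 0.8147 0.9058 0.1270 (to 4 d.p.)
fourth = rand();                 % the number iteration 4 ought to use

% "Restart MATLAB" without restoring the generator state
rng(0, 'twister');
fourthAfterRestart = rand();     % the stream starts again from the beginning

assert(fourthAfterRestart == firstThree(1));  % iteration 4 reuses the 1st number
assert(fourthAfterRestart ~= fourth);         % so the final sum changes
```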

In order to fix this, we need to additionally save the state of the random number generator when we save a checkpoint, and restore that state when we restart. Here’s the code:

%addup_checkpoint_rand_correct.m
%Adding random numbers the slow way, in order to demo checkpointing

if exist( 'checkpoint.mat','file' ) % If a checkpoint file exists, load it
    fprintf('Checkpoint file found - Loading\n');
    load('checkpoint.mat')

    %use the saved RNG state
    stream = RandStream.getGlobalStream;
    stream.State = savedState;

else % otherwise, start from the beginning
    fprintf('No checkpoint file found - starting from beginning\n');
    mysum=0;
    countmin=1;
    rng(0);     %Seed the random number generator for reproducible results
end

for count = countmin:100
    mysum = mysum + rand();
    countmin = count+1;  %If we load this checkpoint, we want to start on the next iteration
    pause(1); %pretend this is a complicated calculation

    %save the state of the random number generator
    stream = RandStream.getGlobalStream;
    savedState = stream.State;
    %save checkpoint
    fprintf('Saving checkpoint\n');
    save('checkpoint.mat');

    fprintf('Completed iteration %d \n',count);
end
fprintf('The sum is %f \n',mysum);
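Steve L suggests in the comments below using the rng function instead of RandStream.getGlobalStream, because the structure returned by rng records which generator is in use as well as its state. A sketch of that variant follows. The file name and the variable savedRng are my own, and because it saves only the three variables it needs, it cannot resume from a checkpoint.mat written by the earlier versions (delete any old file first).

```matlab
%addup_checkpoint_rng.m (hypothetical variant of addup_checkpoint_rand_correct.m)
%Saves the output of rng, which holds the generator type as well as its state

if exist('checkpoint.mat', 'file') % If a checkpoint file exists, load it
    fprintf('Checkpoint file found - Loading\n');
    load('checkpoint.mat')
    rng(savedRng);      %restore generator type, seed and state
else %otherwise, start from the beginning
    fprintf('No checkpoint file found - starting from beginning\n');
    mysum = 0;
    countmin = 1;
    rng(0);             %Seed the random number generator for reproducible results
end

for count = countmin:100
    mysum = mysum + rand();
    countmin = count + 1;  %If we load this checkpoint, we want to start on the next iteration
    pause(1); %pretend this is a complicated calculation

    savedRng = rng;     %snapshot of the global generator settings
    fprintf('Saving checkpoint\n');
    save('checkpoint.mat', 'mysum', 'countmin', 'savedRng');

    fprintf('Completed iteration %d \n', count);
end
fprintf('The sum is %f \n', mysum);
```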

Ensuring that the checkpoint save completes

Events that terminate our code can occur extremely quickly (a power cut, for example), so there is a risk that the machine is switched off while our checkpoint file is being written. How can we ensure that the file is complete?

The solution, which I found on the MATLAB checkpointing page of the Liverpool University Condor Pool site, is to first write a temporary file and then rename it. That is, instead of

save('checkpoint.mat')

we do

save('checkpoint_tmp.mat');   %write the new checkpoint to a temporary file first
if strcmp(computer,'PCWIN64') || strcmp(computer,'PCWIN')
    %We are running on a windows machine
    system( 'move /y checkpoint_tmp.mat checkpoint.mat' );
else
    %We are running on Linux or Mac
    system( 'mv checkpoint_tmp.mat checkpoint.mat' );
end

As the author of that page explains ‘The operating system should guarantee that the move command is “atomic” (in the sense that it is indivisible i.e. it succeeds completely or not at all) so that there is no danger of receiving a corrupt “half-written” checkpoint file from the job.’
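If you checkpoint from several places, it may be tidier to wrap the write-then-rename step in a function. The sketch below is a hypothetical helper (the name save_checkpoint_atomically is my own). It uses ispc rather than comparing the output of computer, builds the temporary file name from the target name, and, unlike the snippet above, checks the exit status returned by system so that a failed rename raises an error instead of passing silently. Because save inside a function only sees that function’s own variables, the checkpoint data is passed in as a struct and written with save’s -struct option.

```matlab
function save_checkpoint_atomically(filename, checkpoint)
%SAVE_CHECKPOINT_ATOMICALLY  Write a checkpoint to a temporary file, then rename it.
%   filename   - final checkpoint file; assumed to end in '.mat'
%   checkpoint - scalar struct; each field is saved as a separate variable

    % e.g. 'checkpoint.mat' -> 'checkpoint_tmp.mat' in the same folder
    [folder, name, ext] = fileparts(filename);
    tmpname = fullfile(folder, [name '_tmp' ext]);

    % Step 1: write everything to the temporary file
    save(tmpname, '-struct', 'checkpoint');

    % Step 2: replace the old checkpoint with a single move/rename
    if ispc
        cmd = sprintf('move /y "%s" "%s"', tmpname, filename);
    else
        cmd = sprintf('mv "%s" "%s"', tmpname, filename);
    end
    [status, output] = system(cmd);   % output captures e.g. "1 file(s) moved."
    if status ~= 0
        error('Checkpoint rename failed: %s', output);
    end
end
```

With this helper, the save and rename in the final code below would become a single call such as save_checkpoint_atomically('checkpoint.mat', struct('mysum', mysum, 'countmin', countmin, 'savedState', savedState)).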

Only checkpoint what is necessary

So far, we’ve been saving the entire MATLAB workspace in our checkpoint files and this hasn’t been a problem since our workspace hasn’t contained much. In general, however, the workspace might contain all manner of intermediate variables that we simply don’t need in order to restart where we left off. Saving the stuff that we might not need can be expensive.
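Before deciding what to leave out, it can help to see what the workspace actually holds. This short sketch lists variables largest first using whos, with sizes in MB taken as 2^20 bytes. In the example that follows, the randoms matrix (10000 x 10000 doubles) would top the list at about 763 MB.

```matlab
% List the variables in the current workspace, largest first
vars = whos;                                   % struct array with fields name, bytes, class, ...
[~, order] = sort([vars.bytes], 'descend');
for idx = order
    fprintf('%-15s %12.1f MB\n', vars(idx).name, vars(idx).bytes / 2^20);
end
```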

For the sake of illustration, let’s skip 100 million random numbers before adding one to our sum. For reasons only known to ourselves, we store these numbers in an intermediate variable which we never do anything with. This array isn’t particularly large at 763 megabytes, but its existence slows down our checkpointing considerably. The correct result of this variation of the calculation is 41.251376 if we set the starting seed to 0, something we can use to test our new checkpoint strategy.

Here’s the code

% A demo of how slow checkpointing can be if you include large intermediate variables

if exist( 'checkpoint.mat','file' ) % If a checkpoint file exists, load it
    fprintf('Checkpoint file found - Loading\n');
    load('checkpoint.mat')
    %use the saved RNG state
    stream = RandStream.getGlobalStream;
    stream.State = savedState;
else %otherwise, start from the beginning
    fprintf('No checkpoint file found - starting from beginning\n');
    mysum=0;
    countmin=1;
    rng(0);     %Seed the random number generator for reproducible results
end

for count = countmin:100
    %Create and store 100 million random numbers for no particular reason
    randoms = rand(10000);
    mysum = mysum + rand();
    countmin = count+1;  %If we load this checkpoint, we want to start on the next iteration
    fprintf('Completed iteration %d \n',count);
    
    if mod(count,25)==0
        %save the state of the random number generator
        stream = RandStream.getGlobalStream;
        savedState = stream.State;
        %save and time checkpoint
        tic
        save('checkpoint_tmp.mat');
        if strcmp(computer,'PCWIN64') || strcmp(computer,'PCWIN')
            %We are running on a windows machine
            system( 'move /y checkpoint_tmp.mat checkpoint.mat' );
        else
            %We are running on Linux or Mac
            system( 'mv checkpoint_tmp.mat checkpoint.mat' );
        end
        timing = toc;
        fprintf('Checkpoint save took %f seconds\n',timing);
    end
    
end
fprintf('The sum is %f \n',mysum);

On my Windows 7 Desktop, each checkpoint save takes around 17 seconds:

Completed iteration 25 
        1 file(s) moved. 
Checkpoint save took 17.269897 seconds

It is not necessary to include that huge random matrix in a checkpoint file. If we are explicit in what we require, we can reduce the time taken to checkpoint significantly. Here, we change

save('checkpoint_tmp.mat');

to

save('checkpoint_tmp.mat','mysum','countmin','savedState');

This has a dramatic effect on checkpointing time:

Completed iteration 25 
        1 file(s) moved. 
Checkpoint save took 0.033576 seconds
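A related habit, sketched below, is to collect the checkpoint variables in a struct and to load the file back into a struct as well. The -struct option of save writes each field as a separate variable, and calling load with an output argument returns a struct instead of creating variables in the workspace, so the list of variables a restart depends on appears in one place. The stand-in values and the file name checkpoint_demo.mat are hypothetical, and the temporary-file rename is left out for brevity.

```matlab
% Stand-in state so that the sketch runs on its own
rng(0);
randoms = rand(1000);                 % large intermediate we do NOT want to save
mysum = sum(rand(1, 25));             % hypothetical partial sum after 25 iterations
countmin = 26;
stream = RandStream.getGlobalStream;
savedState = stream.State;

% Save: each field of the struct becomes a variable in the MAT-file
checkpoint = struct('mysum', mysum, 'countmin', countmin, 'savedState', savedState);
save('checkpoint_demo.mat', '-struct', 'checkpoint');

% Restart: load into a struct, then copy out exactly what is needed
c = load('checkpoint_demo.mat');
mysum = c.mysum;
countmin = c.countmin;
stream.State = c.savedState;
```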

Here’s the final piece of code, which uses everything discussed in this article:

%Final checkpointing demo

if exist( 'checkpoint.mat','file' ) % If a checkpoint file exists, load it
    fprintf('Checkpoint file found - Loading\n');
    load('checkpoint.mat')
    %use the saved RNG state
    stream = RandStream.getGlobalStream;
    stream.State = savedState;
else %otherwise, start from the beginning
    fprintf('No checkpoint file found - starting from beginning\n');
    mysum=0;
    countmin=1;
    rng(0);     %Seed the random number generator for reproducible results
end

for count = countmin:100
    %Create and store 100 million random numbers for no particular reason
    randoms = rand(10000);
    mysum = mysum + rand();
    countmin = count+1;  %If we load this checkpoint, we want to start on the next iteration
    fprintf('Completed iteration %d \n',count);
    
    if mod(count,25)==0 %checkpoint every 25th iteration
        %save the state of the random number generator
        stream = RandStream.getGlobalStream;
        savedState = stream.State;
        %save and time checkpoint
        tic
        %only save the variables that are strictly necessary
        save('checkpoint_tmp.mat','mysum','countmin','savedState');
        %Ensure that the save completed
        if strcmp(computer,'PCWIN64') || strcmp(computer,'PCWIN')
            %We are running on a windows machine
            system( 'move /y checkpoint_tmp.mat checkpoint.mat' );
        else
            %We are running on Linux or Mac
            system( 'mv checkpoint_tmp.mat checkpoint.mat' );
        end
        timing = toc;
        fprintf('Checkpoint save took %f seconds\n',timing);
    end
    
end
fprintf('The sum is %f \n',mysum);

Parallel checkpointing

If your code includes parallel regions using constructs such as parfor or spmd, you might have to do more work to checkpoint correctly. I haven’t considered any of the potential issues that may arise in such code in this article.

Checkpointing checklist

Here’s a reminder of everything you need to consider:

  • Test to ensure that the introduction of checkpointing doesn’t alter results
  • Don’t checkpoint too often
  • Take care when checkpointing code that involves random numbers – you need to explicitly save the state of the random number generator.
  • Take measures to ensure that the checkpoint save is completed
  • Only checkpoint what is necessary
  • Code that includes parallel regions might require extra care

Comments

  1. February 19th, 2014 at 20:00

    Useful guide!

    A pathological gotcha I’ve seen is with lots of embarrassingly parallel jobs running across a cluster with shared NFS disk space. If all jobs checkpoint every N iterations, they can thrash the network share and might lock up for ages (especially if NFS isn’t set up well, which seems common!). Even if the jobs aren’t quite in sync to start with, they can fall into lock-step and jam pretty quickly if writing a checkpoint takes some time. Randomizing whether to checkpoint or not, e.g. “if rand < 1/mean_interval”, fixed this problem for me.

  2. Mike Croucher
    February 20th, 2014 at 10:39

    Thanks for the tip, Iain — that’s not something I’ve come across yet. Much of the embarrassingly parallel work with MATLAB I help support is with a Condor pool — no shared file system — so we wouldn’t suffer from that problem.

  3. Steve L
    February 20th, 2014 at 19:52

    Rather than using RandStream.getGlobalStream (which at least in your last example could hide a bug [*]) you could use RNG as a safeguard. See example 1 on the documentation page for RNG to see what I mean.

    There’s one more tool that may be of use, at least in the Ctrl-C case (not in the power failure case), and that is the onCleanup object. Its whole purpose is to die, and do something “with its last breath.”

    [*] You’re assuming the State property of the random number generator is the only thing that needs to be saved. But what if you’re running code that for some reason uses a different random number generator algorithm, say changing to ‘multFibonacci’ instead of the default ‘twister’? The output from RNG avoids this by containing information about both the “status” of the generator (avoiding the heavily-overloaded term “state” here) and which generator is being used.
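The randomized checkpointing suggested in the first comment can be sketched as follows. One caution in the light of the section on random numbers: drawing the decision with a bare rand would consume numbers from the calculation’s own global stream and change its results, so this sketch (the names meanInterval and decider are my own) draws from a separate RandStream. Seeding that stream with 'shuffle' is an assumption intended to give different jobs different checkpoint schedules; jobs started at the same moment could still coincide, so a job-specific seed may be better on a real cluster.

```matlab
% Randomized checkpointing without disturbing the global random number stream
meanInterval = 25;                                    % average iterations between checkpoints
decider = RandStream('mt19937ar', 'Seed', 'shuffle'); % private stream used only for the decision

rng(0);                                   % the calculation's own, reproducible stream
mysum = 0;
for count = 1:100
    mysum = mysum + rand();               % uses the global stream, exactly as before
    if rand(decider) < 1/meanInterval     % true on about 1 in meanInterval iterations
        fprintf('Checkpoint at iteration %d\n', count);   % a real job would save here
    end
end
fprintf('The sum is %f \n', mysum);       % unaffected by when the checkpoints happen
```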