May 31, 2019 · 3 min read

Where Did I Put My Loss Values?

How to add the metrics and other key values generated while training a deep learning model in PyTorch to your checkpoints.

Originally published on Medium

Checkpointing

Saving and loading PyTorch models is super easy and intuitive.

Save model weights:

torch.save(model.state_dict(), PATH)

Load model weights:

model = TheModelClass()
model.load_state_dict(torch.load(PATH))

The official PyTorch tutorial also explains how to save the optimizer state and offers other good tips on this topic.
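For completeness, here is the dictionary-style checkpoint that the tutorial describes, which bundles the optimizer state with the weights (epoch and loss are placeholders for values you track yourself):

torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}, PATH)

# restoring
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']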

But what about all the other parameters?

The saving and loading examples above work well when you want to load your model for inference at test time, or when you are fine-tuning a pre-trained model.
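One detail worth remembering when loading weights for inference: switch the model into evaluation mode, otherwise layers such as dropout and batch normalization will still behave as they do during training.

model = TheModelClass()
model.load_state_dict(torch.load(PATH))
model.eval()  # use running batch-norm statistics and disable dropout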

Now let’s say you are training a very deep model on a very large dataset. It is going to take a lot of time, and if you are running on cloud instances, it also costs a lot of money. A good way to save money is to use spot instances, which come at a significant discount, but then chances are you will lose your instance somewhere in the middle of training. Resuming training from the model weights and optimizer state alone is not enough. You will also need all the metrics you measured: the loss and accuracy values for each set, the best top-k accuracy on the validation set, the number of epochs, samples, or iterations completed up to the stopping point, and any other special value you keep track of.

Those are needed in order to keep reporting the learning progress curves after you resume training, and for on-the-fly decisions such as reducing the learning rate when the validation loss stops improving.
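As a concrete example of such a decision, PyTorch’s built-in ReduceLROnPlateau scheduler reduces the learning rate when a monitored metric stops improving; since the scheduler is stateful, a full checkpoint should include its state as well. A minimal sketch:

from torch.optim.lr_scheduler import ReduceLROnPlateau

scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

# inside the training loop, after evaluating on the validation set:
scheduler.step(val_loss)  # cuts the learning rate once val_loss plateaus

# the scheduler keeps internal counters, so save and restore them too:
state = scheduler.state_dict()
scheduler.load_state_dict(state)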

Fully resuming a learning experiment also matters for reproducibility, whether you are publishing a paper or a code base. In addition, being able to start from any point on the convergence curve and continue the learning process, with or without changing hyperparameters, is crucial when researching a new task or field.


I use PyTorch and work on large video datasets, such as Kinetics. While looking for a full checkpoint solution for my research, I started working on such a checkpoint handler for PyTorch and recently released a package to the Python Package Index (PyPI): https://pypi.org/project/pytorchcheckpoint/ along with a GitHub repository: https://github.com/bomri/pytorch-checkpoint.

To install the package:

pip install pytorchcheckpoint

Then, somewhere at the beginning of your training code, instantiate the class:

from pytorchcheckpoint.checkpoint import CheckpointHandler
checkpoint_handler = CheckpointHandler()

Now, in addition to saving your model weights and optimizer state, you can add any other value at any step of the learning process.

For example, in order to save the number of classes you can run:

# saving
checkpoint_handler.store_var(var_name='n_classes', value=1000)

# restoring
n_classes = checkpoint_handler.get_var(var_name='n_classes')

In addition, you can store values and metrics:

  • per set: training/validation/test
  • for each number of epochs/samples/iterations

For example, the top-1 accuracy of the training and validation sets for each epoch can be stored like this:

# train set - top1
checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=0, value=80)
checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=1, value=85)
checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=2, value=90)
checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=3, value=91)

# valid set - top1
checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=0, value=70)
checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=1, value=75)
checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=2, value=80)
checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=3, value=85)
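In a real training loop, those values would of course come from evaluation rather than constants. A rough sketch, where train_one_epoch and evaluate_top1 are hypothetical helpers you would write yourself:

for epoch in range(n_epochs):
    train_one_epoch(model, train_loader)             # hypothetical training step
    train_top1 = evaluate_top1(model, train_loader)  # hypothetical: returns top-1 accuracy
    valid_top1 = evaluate_top1(model, valid_loader)

    checkpoint_handler.store_running_var_with_header(header='train', var_name='top1', iteration=epoch, value=train_top1)
    checkpoint_handler.store_running_var_with_header(header='valid', var_name='top1', iteration=epoch, value=valid_top1)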

Saving and loading the full checkpoint is done with a single line of code:

# save checkpoint
checkpoint_handler.save_checkpoint(checkpoint_path=path, iteration=25, model=model)

# load checkpoint
checkpoint_handler = checkpoint_handler.load_checkpoint(path)
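Putting it all together, a resume-after-spot-termination flow could look roughly like this, using only the calls shown above (train_one_epoch is again a hypothetical helper):

# save periodically inside the training loop to guard against terminations
for epoch in range(start_epoch, n_epochs):
    train_one_epoch(model, train_loader)  # hypothetical training step
    checkpoint_handler.store_var(var_name='epoch', value=epoch)
    checkpoint_handler.save_checkpoint(checkpoint_path=path, iteration=epoch, model=model)

# on restart: reload the handler and pick up where you left off
checkpoint_handler = checkpoint_handler.load_checkpoint(path)
start_epoch = checkpoint_handler.get_var(var_name='epoch') + 1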

You can check the pytorch-checkpoint README for more useful examples.

So the next time you fire up a training process, make sure you have those loss values at hand for a rainy day.