A practical example of how to save and load a model in PyTorch. We are going to look at how to continue training from a checkpoint and how to load the model for inference.
The goal of this article is to show you how to save a model, load it to continue training after the last saved epoch, and use it to make a prediction. If you are reading this article, I assume you are familiar with the basics of deep learning and PyTorch.
Have you ever experienced a situation where you spent hours or days training your model and then it stopped in the middle? Or you were not satisfied with your model's performance and wanted to train it again? There are multiple reasons why we might need a flexible way to save and load our model.
Most free cloud services such as Kaggle and Google Colab have idle time-outs that will disconnect your notebook, and the notebook will also be interrupted once it reaches its time limit. Unless you train for a small number of epochs with a GPU, the process takes time. Being able to save the model gives you a huge advantage and saves the day. To stay flexible, I am going to save both the latest checkpoint and the best checkpoint.
Fashion_MNIST_data will be used as our dataset, and we'll write a complete flow from importing the data to making predictions. In this exercise, I am going to use a Kaggle notebook.
Step 1: Setting up
- By default, in Kaggle, the notebook you are working on is called __notebook__.ipynb
- Create two directories, one to store the latest checkpoint and one for the best model:
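For example, a minimal way to create them from the notebook (the directory names here are my assumption; any paths work as long as they match the checkpoint paths used in Step 5):

```python
import os

# Hypothetical directory names; use any paths you like,
# as long as they match the checkpoint paths used later.
os.makedirs("./checkpoint", exist_ok=True)
os.makedirs("./best_model", exist_ok=True)
```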
Step 2: Importing libraries and creating helper functions
Importing libraries
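A minimal set of imports covering all the sketches in this article (assuming torchvision is used for the dataset):

```python
import shutil

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, transforms
```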
Saving function
save_ckp is created to save a checkpoint, both the latest one and the best one. This gives us flexibility: we may be interested in the state of the latest checkpoint or in the best checkpoint.
In our case, we want to save a checkpoint that allows us to continue training from where we left off. Here is the information the checkpoint needs to contain (a sketch of save_ckp follows the list):
- epoch: a measure of the number of times all the training vectors are used once to update the weights.
- valid_loss_min: the minimum validation loss; we need this so that when we continue training, we can compare against it rather than starting from np.Inf.
- state_dict: the model's learned parameters, i.e. the weight and bias tensors for each layer.
- optimizer: you need to save the optimizer's state, especially when you are using Adam. Adam is an adaptive learning rate method, which means it computes individual learning rates for different parameters, and we need those if we want to continue training from where we left off [2].
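A sketch of save_ckp under those assumptions: the checkpoint is a plain dict, and the best model is simply a copy of the latest checkpoint file:

```python
def save_ckp(state, is_best, checkpoint_path, best_model_path):
    """Save the latest checkpoint; if it is the best so far, copy it aside.

    state: dict containing epoch, valid_loss_min, model and optimizer state
    is_best: True when the validation loss reached a new minimum
    """
    # Always save the latest checkpoint
    torch.save(state, checkpoint_path)
    # If this is the best model seen so far, keep a separate copy
    if is_best:
        shutil.copyfile(checkpoint_path, best_model_path)
```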
Loading function
load_ckp is created for loading the model. It takes (a sketch follows the list):
- location of the saved checkpoint
- the model instance that you want to load the state to
- the optimizer
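A matching sketch of load_ckp, assuming the checkpoint dict keys used in save_ckp above:

```python
def load_ckp(checkpoint_fpath, model, optimizer):
    """Load a checkpoint and restore model/optimizer state.

    Returns the model, the optimizer, the epoch to resume from,
    and the minimum validation loss recorded so far.
    """
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    valid_loss_min = checkpoint["valid_loss_min"]
    return model, optimizer, checkpoint["epoch"], valid_loss_min
```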
Step 3: Importing the Fashion_MNIST_data dataset and creating data loaders
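A sketch of the data pipeline, assuming torchvision's built-in FashionMNIST and an 80/20 train/validation split (the split ratio and batch size are my assumptions):

```python
transform = transforms.ToTensor()

train_data = datasets.FashionMNIST("data", train=True, download=True, transform=transform)
test_data = datasets.FashionMNIST("data", train=False, download=True, transform=transform)

# Hold out 20% of the training set for validation (assumed ratio)
num_train = len(train_data)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(0.2 * num_train)
valid_idx, train_idx = indices[:split], indices[split:]

train_loader = DataLoader(train_data, batch_size=64, sampler=SubsetRandomSampler(train_idx))
valid_loader = DataLoader(train_data, batch_size=64, sampler=SubsetRandomSampler(valid_idx))
test_loader = DataLoader(test_data, batch_size=64)

loaders = {"train": train_loader, "valid": valid_loader, "test": test_loader}
```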
Step 4: Defining and creating a model
I am using a simple network from [1].
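A sketch matching the architecture printed below: five fully connected layers with dropout, flattening the 28x28 input image inside forward. Returning raw logits (rather than log-probabilities) is my choice here, so that cross-entropy loss can be applied directly in Step 5:

```python
class FashionClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 64)
        self.fc5 = nn.Linear(64, 10)
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        # Flatten the 28x28 image to a 784-dimensional vector
        x = x.view(x.shape[0], -1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.dropout(F.relu(self.fc2(x)))
        x = self.dropout(F.relu(self.fc3(x)))
        x = self.dropout(F.relu(self.fc4(x)))
        return self.fc5(x)

model = FashionClassifier()
print(model)
```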
Output:
FashionClassifier(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=128, bias=True)
  (fc4): Linear(in_features=128, out_features=64, bias=True)
  (fc5): Linear(in_features=64, out_features=10, bias=True)
  (dropout): Dropout(p=0.2)
)
Step 5: Training the network and saving the model
The train function gives us the ability to set the number of epochs, the checkpoint paths, and other parameters; the learning rate is set when we create the optimizer.
Define loss function and optimizer
Below, we use an Adam optimizer and cross-entropy loss, since we are looking at class scores as output. We calculate the loss and perform back-propagation.
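For example (the learning rate of 0.001 matches the optimizer state printed in Step 6; cross-entropy is applied to the raw logits from the model sketch above):

```python
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Move the model to GPU if one is available
use_cuda = torch.cuda.is_available()
if use_cuda:
    model = model.cuda()
```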
Define train method
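A sketch of the train function under the assumptions above. The exact signature is my reconstruction; the essential part is that every epoch saves the latest checkpoint and, when the validation loss reaches a new minimum, also saves the best model:

```python
def train(start_epoch, n_epochs, valid_loss_min_input, loaders, model,
          optimizer, criterion, use_cuda, checkpoint_path, best_model_path):
    valid_loss_min = valid_loss_min_input

    for epoch in range(start_epoch, n_epochs + 1):
        train_loss, valid_loss = 0.0, 0.0

        # Training phase
        model.train()
        for data, target in loaders["train"]:
            if use_cuda:
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * data.size(0)

        # Validation phase
        model.eval()
        with torch.no_grad():
            for data, target in loaders["valid"]:
                if use_cuda:
                    data, target = data.cuda(), target.cuda()
                output = model(data)
                loss = criterion(output, target)
                valid_loss += loss.item() * data.size(0)

        train_loss /= len(loaders["train"].sampler)
        valid_loss /= len(loaders["valid"].sampler)
        print("Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}".format(
            epoch, train_loss, valid_loss))

        # Save the latest checkpoint every epoch
        checkpoint = {
            "epoch": epoch + 1,  # the epoch to resume from
            "valid_loss_min": valid_loss,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_ckp(checkpoint, False, checkpoint_path, best_model_path)

        # Save the best model whenever validation loss reaches a new minimum
        if valid_loss <= valid_loss_min:
            print("Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...".format(
                valid_loss_min, valid_loss))
            save_ckp(checkpoint, True, checkpoint_path, best_model_path)
            valid_loss_min = valid_loss

    return model
```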
Train the model
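Starting from scratch, with the checkpoint paths pointing at the directories created in Step 1 (the file names are my assumption, chosen to match the directory listings shown below):

```python
trained_model = train(1, 3, np.Inf, loaders, model, optimizer, criterion,
                      use_cuda,
                      "./checkpoint/current_checkpoint.pt",
                      "./best_model/best_model.pt")
```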
Output:
Epoch: 1 	Training Loss: 0.000010 	Validation Loss: 0.000044
Validation loss decreased (inf --> 0.000044). Saving model ...
Epoch: 2 	Training Loss: 0.000007 	Validation Loss: 0.000040
Validation loss decreased (0.000044 --> 0.000040). Saving model ...
Epoch: 3 	Training Loss: 0.000007 	Validation Loss: 0.000040
Validation loss decreased (0.000040 --> 0.000040). Saving model ...
Let’s focus on a few parameters we used above:
- start_epoch: the epoch number training starts from
- n_epochs: the epoch number training ends at
- valid_loss_min_input: the initial minimum validation loss; we pass np.Inf here since there is no previous checkpoint yet
- checkpoint_path: full path where the state of the latest checkpoint is saved
- best_model_path: full path where the state of the best model is saved
Verify that the model is saved
- List all files in the best_model directory:
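For example, assuming the directory layout from Step 1:

```python
import os

os.listdir("./best_model")
```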
Output:
best_model.pt
- List all files in the checkpoint directory:
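And similarly:

```python
os.listdir("./checkpoint")
```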
Output:
current_checkpoint.pt
Step 6: Loading the model
Reconstruct the model
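We instantiate the same architecture again; the weights of this fresh, untrained model will be overwritten by load_ckp below:

```python
# A new instance of the same architecture defined in Step 4
model = FashionClassifier()
print(model)
```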
Output:
FashionClassifier(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=128, bias=True)
  (fc4): Linear(in_features=128, out_features=64, bias=True)
  (fc5): Linear(in_features=64, out_features=10, bias=True)
  (dropout): Dropout(p=0.2)
)
Define the optimizer and checkpoint file path
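The optimizer must be constructed before its state can be restored; the path points at the latest checkpoint saved in Step 5 (the variable name ckp_path is my own):

```python
optimizer = optim.Adam(model.parameters(), lr=0.001)
ckp_path = "./checkpoint/current_checkpoint.pt"
```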
Load the model using the load_ckp function
I printed out the values that we get from load_ckp just to make sure everything is correct.
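For example (printing valid_loss_min twice, raw and formatted, to match the output below):

```python
model, optimizer, start_epoch, valid_loss_min = load_ckp(ckp_path, model, optimizer)

print("model = ", model)
print("optimizer = ", optimizer)
print("start_epoch = ", start_epoch)
print("valid_loss_min = ", valid_loss_min)
print("valid_loss_min = {:.6f}".format(valid_loss_min))
```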
Output:
model =  FashionClassifier(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=128, bias=True)
  (fc4): Linear(in_features=128, out_features=64, bias=True)
  (fc5): Linear(in_features=64, out_features=10, bias=True)
  (dropout): Dropout(p=0.2)
)
optimizer =  Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)
start_epoch =  4
valid_loss_min =  3.952759288949892e-05
valid_loss_min = 0.000040
After we load all the information we need, we can continue training: start_epoch = 4. Previously, we trained the model for epochs 1 to 3.
Step 7: Continue Training and/or Inference
Continue training
We can continue training our model with the train function, providing the checkpoint values we got from the load_ckp function above.
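For example, training for three more epochs (4 to 6), seeded with the start_epoch and valid_loss_min we just loaded:

```python
trained_model = train(start_epoch, 6, valid_loss_min, loaders, model,
                      optimizer, criterion, use_cuda,
                      "./checkpoint/current_checkpoint.pt",
                      "./best_model/best_model.pt")
```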
Output:
Epoch: 4 	Training Loss: 0.000006 	Validation Loss: 0.000040
Epoch: 5 	Training Loss: 0.000006 	Validation Loss: 0.000037
Validation loss decreased (0.000040 --> 0.000037). Saving model ...
Epoch: 6 	Training Loss: 0.000006 	Validation Loss: 0.000036
Validation loss decreased (0.000037 --> 0.000036). Saving model ...
- Notice that the epochs now run from 4 to 6 (start_epoch = 4).
- The validation loss continues from the last training checkpoint:
- at epoch 3, the minimum validation loss was 0.000040
- here, the minimum validation loss starts at 0.000040 rather than inf
Inference
Remember that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results [3].
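A sketch of the evaluation loop on the test set (the accuracy formatting matches the output below):

```python
model.eval()  # set dropout (and batch norm, if any) to evaluation mode
correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        if use_cuda:
            images, labels = images.cuda(), labels.cuda()
        outputs = model(images)
        # The predicted class is the index of the highest logit
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print("Accuracy of the network on 10000 test images: {:.2f}%".format(
    100 * correct / total))
```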
Output:
Accuracy of the network on 10000 test images: 86.58%
Here is my notebook in Kaggle:

References:
- [1] S. David, Saving and Loading Models in PyTorch (2019), https://www.kaggle.com/davidashraf/saving-and-loading-models-in-pytorch
- [2] J. Rachit, Saving and Loading Your Model to Resume Training in PyTorch (2019), https://medium.com/analytics-vidhya/saving-and-loading-your-model-to-resume-training-in-pytorch-cb687352fa61
- [3] I. Matthew, Saving and Loading Models (2017), https://pytorch.org/tutorials/beginner/saving_loading_models.html