Save and Load Your PyTorch Models


A deep learning model is a mathematical abstraction of data, built from a large number of parameters. Training those parameters can take hours, days, or even weeks, but once training is complete the model can be used for inference, that is, applying it to new data. To avoid retraining every time you need predictions, it’s crucial to know how to save the trained model to disk and load it back later. This post will guide you through saving and loading your PyTorch models. By the end, you will understand:

  • The difference between states and parameters in a PyTorch model.
  • How to save model states.
  • How to load model states.

Let’s get started.

Overview

This article is organized into three main sections:

  1. Building an Example Model
  2. Understanding the Composition of a PyTorch Model
  3. Accessing the state_dict of a Model, and using it to save and load models

Building an Example Model

We’ll start by creating a simple model for the Iris dataset. The dataset is loaded via scikit-learn, and the targets are integer labels (0, 1, and 2). We will train a small neural network on this multiclass classification problem. The model uses a log softmax activation at the output so it can be trained with the negative log likelihood loss function; this combination is equivalent to using cross-entropy loss on the raw logits, as verified in the short check after the training code below.

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load data into NumPy arrays
data = load_iris()
X, y = data["data"], data["target"]

# Convert NumPy arrays into PyTorch tensors
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

# Split the data
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, shuffle=True)

# Define the PyTorch model
class Multiclass(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 8)
        self.act = nn.ReLU()
        self.output = nn.Linear(8, 3)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.act(self.hidden(x))
        x = self.logsoftmax(self.output(x))
        return x

model = Multiclass()

# Specify the loss function and optimizer
loss_fn = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Prepare for model training
n_epochs = 100
batch_size = 5
batch_start = torch.arange(0, len(X_train), batch_size)

# Training loop
for epoch in range(n_epochs):
    for start in batch_start:
        X_batch = X_train[start:start + batch_size]
        y_batch = y_train[start:start + batch_size]
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Evaluate the model on the test set
y_pred = model(X_test)
acc = (torch.argmax(y_pred, 1) == y_test).float().mean()
print("Accuracy: %.2f" % acc)

Understanding the Composition of a PyTorch Model

A PyTorch model is effectively an object in Python, consisting of various deep learning constructs such as layers and activation functions. It knows how to connect these components to produce outputs from input tensors. While the architecture and algorithms are fixed upon creation, the model has trainable parameters that are updated during training to enhance accuracy.
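
Since the model is a Python object, you can simply print it to see how these components are connected:

print(model)

which produces a summary of the layers:

Multiclass(
  (hidden): Linear(in_features=4, out_features=8, bias=True)
  (act): ReLU()
  (output): Linear(in_features=8, out_features=3, bias=True)
  (logsoftmax): LogSoftmax(dim=1)
)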

To access the model’s parameters, you can utilize model.parameters(), which yields a generator referencing each layer’s trainable parameters in the form of PyTorch tensors. This allows for copying or modifying them, as demonstrated below:

# Create a new model
newmodel = Multiclass()
# Copy parameters from the original model
with torch.no_grad():
    for new_tensor, old_tensor in zip(newmodel.parameters(), model.parameters()):
        new_tensor.copy_(old_tensor)

# Validate the new model's performance
y_pred = newmodel(X_test)
acc = (torch.argmax(y_pred, 1) == y_test).float().mean()
print("Accuracy: %.2f" % acc)

However, some models also carry non-trainable state. A batch normalization layer, for example, normalizes the output of the previous layer and keeps running statistics (a mean and a variance) that are accumulated during training and applied at inference time. These statistics are part of the model’s state, but because the optimizer never updates them, they are not exposed through model.parameters().
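
The Iris model above has no such layers, but a standalone batch normalization layer illustrates the distinction:

bn = nn.BatchNorm1d(8)

# Only the affine weight and bias are trainable parameters
print([name for name, _ in bn.named_parameters()])  # ['weight', 'bias']

# The running statistics are buffers: part of the model state,
# updated during training, but never touched by the optimizer
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']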

Accessing state_dict of a Model

To retrieve all parameters (both trainable and non-trainable) of a model, you can use the state_dict() function. Here’s how you can inspect the state of the model:

import pprint
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(model.state_dict())

The output is an OrderedDict, so every tensor is stored under a stable, human-readable key and can be mapped back to the layer it came from.
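
For the Iris model above, each key combines a layer’s attribute name with the name of the tensor inside it:

for name, tensor in model.state_dict().items():
    print(name, tensor.shape)
# hidden.weight torch.Size([8, 4])
# hidden.bias torch.Size([8])
# output.weight torch.Size([3, 8])
# output.bias torch.Size([3])

The activation layers (ReLU and LogSoftmax) hold no parameters or buffers, so they contribute nothing to the state_dict.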

For saving and loading models, you can fetch these states, serialize them, and save to disk. During inference, you can create a model instance and load the saved states.

Saving and Loading Model States

import pickle

# Save model
with open("iris-model.pickle", "wb") as fp:
    pickle.dump(model.state_dict(), fp)

# Create new model and load states
newmodel = Multiclass()
with open("iris-model.pickle", "rb") as fp:
    newmodel.load_state_dict(pickle.load(fp))

# Validate the new model
y_pred = newmodel(X_test)
acc = (torch.argmax(y_pred, 1) == y_test).float().mean()
print("Accuracy: %.2f" % acc)

For simplicity and efficiency, the recommended method is to use the PyTorch API to save and load states:

# Save model
torch.save(model.state_dict(), "iris-model.pth")

# Create new model and load states
newmodel = Multiclass()
newmodel.load_state_dict(torch.load("iris-model.pth"))

# Validate the new model
y_pred = newmodel(X_test)
acc = (torch.argmax(y_pred, 1) == y_test).float().mean()
print("Accuracy: %.2f" % acc)

The .pth file is essentially a zip archive of pickled files created by PyTorch, which lets it store the tensor data together with metadata describing how to restore it.
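
You can confirm this with Python’s standard zipfile module, assuming the file was produced by torch.save as above:

import zipfile

# A .pth file saved by a modern version of PyTorch is a valid zip archive
with zipfile.ZipFile("iris-model.pth") as zf:
    print(zf.namelist())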

If you prefer to save the entire model (though it’s not recommended due to potential code dependency issues), you can do so as follows:

# Save the entire model
torch.save(model, "iris-model-full.pth")

# Load the model (recent PyTorch releases default to weights_only=True,
# which refuses to unpickle arbitrary objects, so it is disabled here)
newmodel = torch.load("iris-model-full.pth", weights_only=False)

# Validate the new model
y_pred = newmodel(X_test)
acc = (torch.argmax(y_pred, 1) == y_test).float().mean()
print("Accuracy: %.2f" % acc)

Remember, Python needs the class definition of Multiclass available in order to unpickle the full model, so always keep the model code alongside the saved file.

Complete Example

Here’s the code that combines all the steps of creating, training, and saving a model:

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load data into NumPy arrays
data = load_iris()
X, y = data["data"], data["target"]

# Convert NumPy arrays into PyTorch tensors
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

# Split the data
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, shuffle=True)

# Define the PyTorch model
class Multiclass(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 8)
        self.act = nn.ReLU()
        self.output = nn.Linear(8, 3)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.act(self.hidden(x))
        x = self.logsoftmax(self.output(x))
        return x

model = Multiclass()

# Specify the loss function and optimizer
loss_fn = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Prepare model and training parameters
n_epochs = 100
batch_size = 5
batch_start = torch.arange(0, len(X_train), batch_size)

# Training loop
for epoch in range(n_epochs):
    for start in batch_start:
        X_batch = X_train[start:start + batch_size]
        y_batch = y_train[start:start + batch_size]
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Save model
torch.save(model.state_dict(), "iris-model.pth")

To load the model and run it for inference, for example in a fresh script, you can use the following self-contained code:

import torch
import torch.nn as nn
from sklearn.datasets import load_iris

# Load data into NumPy arrays
data = load_iris()
X, y = data["data"], data["target"]

# Convert NumPy arrays into PyTorch tensors
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

# Define the PyTorch model
class Multiclass(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 8)
        self.act = nn.ReLU()
        self.output = nn.Linear(8, 3)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.act(self.hidden(x))
        x = self.logsoftmax(self.output(x))
        return x

# Create new model and load states
model = Multiclass()
model.load_state_dict(torch.load("iris-model.pth"))
model.eval()  # switch to evaluation mode before inference

# Run model for inference on the full dataset
with torch.no_grad():
    y_pred = model(X)
acc = (torch.argmax(y_pred, 1) == y).float().mean()
print("Accuracy: %.2f" % acc)

Further Reading

For more insights into saving and loading models in PyTorch, check out the official PyTorch Saving and Loading Models tutorial.

Summary

In this article, you learned how to save and load your trained PyTorch models effectively and the importance of distinguishing between parameters and states in a PyTorch model. You also discovered how to save all necessary model states to disk and rebuild a working model from those saved states.


