Managing a PyTorch Training Process with Checkpoints and Early Stopping

Training large deep learning models can be time-intensive. Unexpected interruptions during training can result in lost progress, while prolonged training beyond a certain point may yield diminishing returns. This guide demonstrates how to manage PyTorch training loops effectively by using checkpoints and early stopping.

Key Takeaways:

  • Understand the importance of checkpointing in neural network training.
  • Learn to save and restore model checkpoints in PyTorch.
  • Implement early stopping techniques for efficient training.

Checkpointing Neural Network Models

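Checkpointing means saving the state of a process so that it can be resumed from that point later if needed. In deep learning, a checkpoint typically stores the model weights, and often the optimizer state as well, which can be used to resume training or to make predictions. PyTorch does not ship a dedicated utility for this kind of training checkpoint (torch.utils.checkpoint implements a different technique, activation checkpointing, which trades extra computation for lower memory use). Still, torch.save(), torch.load(), and each module's state_dict() are all you need to implement one.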
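To make "model weights" concrete, the sketch below builds the same small network used in the training example later in this guide and lists what its state_dict() contains: an ordered mapping from parameter names to tensors. Layers without learnable parameters, such as ReLU and Sigmoid, contribute no entries.

```python
import torch.nn as nn

# Same architecture as the training example later in this guide
model = nn.Sequential(
    nn.Linear(8, 12),
    nn.ReLU(),
    nn.Linear(12, 12),
    nn.ReLU(),
    nn.Linear(12, 1),
    nn.Sigmoid(),
)

# state_dict() maps "<layer index>.<parameter name>" to a tensor
state = model.state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# 0.weight (12, 8)
# 0.bias (12,)
# 2.weight (12, 12)
# 2.bias (12,)
# 4.weight (1, 12)
# 4.bias (1,)
```

Because a checkpoint saved this way holds only tensors keyed by name, it can only be loaded into a model whose layers produce the same names and shapes.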

Implementing Checkpointing

Here’s a simple implementation to save and load model weights:

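import torch

def checkpoint(model, filename):
    # Save only the learnable parameters and buffers, not the whole Python object
    torch.save(model.state_dict(), filename)

def resume(model, filename):
    # Load the saved tensors into an existing model with the same architecture
    model.load_state_dict(torch.load(filename))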
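A quick way to confirm these two functions work is a round trip: save one model, load the file into a second, independently initialized model of the same architecture, and compare the parameters. The sketch below repeats both functions so that it runs on its own; the temporary directory and the toy nn.Linear(4, 2) model are illustrative assumptions.

```python
import os
import tempfile

import torch
import torch.nn as nn

def checkpoint(model, filename):
    torch.save(model.state_dict(), filename)

def resume(model, filename):
    model.load_state_dict(torch.load(filename))

# Two independently initialized copies of the same architecture
original = nn.Linear(4, 2)
restored = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "demo.pth")
    checkpoint(original, path)
    resume(restored, path)

# After loading, every tensor in `restored` matches `original`
same = all(
    torch.equal(a, b)
    for a, b in zip(original.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```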

Training Example

The following example demonstrates training a binary classification model on the “electricity” dataset. It includes dataset preparation, model definition, and training with checkpointing:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split, default_collate
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import LabelEncoder

# Load and preprocess data
data = fetch_openml("electricity", version=1, parser="auto")
X = data['data'].astype('float').values
y = LabelEncoder().fit_transform(data['target'])
X, y = torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32).reshape(-1, 1)

# Split data into training (70%) and test (30%) sets; the two lengths must sum to len(X)
n_train = int(0.7 * len(X))
trainset, testset = random_split(TensorDataset(X, y), [n_train, len(X) - n_train])
loader = DataLoader(trainset, shuffle=True, batch_size=32)
X_test, y_test = default_collate(testset)

# Define the model
model = nn.Sequential(
    nn.Linear(8, 12),
    nn.ReLU(),
    nn.Linear(12, 12),
    nn.ReLU(),
    nn.Linear(12, 1),
    nn.Sigmoid(),
)

# Training loop with checkpointing
n_epochs = 100
loss_fn = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

for epoch in range(n_epochs):
    model.train()
    for X_batch, y_batch in loader:
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

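    # Evaluate on the test set; no_grad() avoids building an autograd graph
    model.eval()
    with torch.no_grad():
        y_pred = model(X_test)
    acc = (y_pred.round() == y_test).float().mean()
    print(f"Epoch {epoch}: Accuracy = {float(acc)*100:.2f}%")

    # Save a checkpoint after every epoch (one file per epoch)
    checkpoint(model, f"epoch-{epoch}.pth")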
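Since checkpoints exist to survive interruptions, the write itself can be made more robust: if the process dies partway through torch.save(), the file on disk may be left truncated. One common approach, sketched here with a hypothetical helper named atomic_checkpoint, is to write to a temporary file in the same directory and then move it over the target with os.replace(), which POSIX guarantees to be atomic when both paths are on the same filesystem. This is an optional refinement, not part of the original checkpoint() function.

```python
import os
import tempfile

import torch
import torch.nn as nn

def atomic_checkpoint(model, filename):
    """Hypothetical helper: save a state_dict without leaving a partial file behind."""
    directory = os.path.dirname(os.path.abspath(filename))
    # Create the temporary file next to the target so the rename stays on one filesystem
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            torch.save(model.state_dict(), f)
        os.replace(tmp_path, filename)  # swap in the finished file in one step
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)  # discard the incomplete temporary file
        raise

# Usage: write one checkpoint into a scratch directory and read it back
with tempfile.TemporaryDirectory() as tmpdir:
    target = os.path.join(tmpdir, "epoch-0.pth")
    model = nn.Linear(3, 1)
    atomic_checkpoint(model, target)
    files = sorted(os.listdir(tmpdir))
    loaded = torch.load(target)

print(files)  # ['epoch-0.pth']
```

In the training loop, atomic_checkpoint(model, f"epoch-{epoch}.pth") would be a drop-in replacement for the checkpoint() call; only the finished file is ever visible under the final name.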

Restoring Checkpoints

To resume training from a specific checkpoint, you can load the saved weights before starting the loop:

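# Set start_epoch > 0 to resume; e.g. start_epoch = 8 loads "epoch-7.pth"
start_epoch = 0
if start_epoch > 0:
    resume_epoch = start_epoch - 1
    resume(model, f"epoch-{resume_epoch}.pth")

for epoch in range(start_epoch, n_epochs):
    # Same training and evaluation steps as in the loop above
    ...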
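Instead of setting start_epoch by hand, you can derive it from the files already on disk. The hypothetical helper below scans a directory for files named epoch-<n>.pth, the naming scheme used in the training loop above, and returns the next epoch to run together with the checkpoint to load. It assumes that no unrelated files follow the same pattern. Parsing the number matters: a plain string sort would rank "epoch-9.pth" after "epoch-10.pth".

```python
import os
import re
import tempfile

# Matches the "epoch-{epoch}.pth" names written by the training loop
CKPT_PATTERN = re.compile(r"^epoch-(\d+)\.pth$")

def find_start_epoch(directory="."):
    """Hypothetical helper: return (start_epoch, checkpoint path or None)."""
    epochs = []
    for name in os.listdir(directory):
        match = CKPT_PATTERN.match(name)
        if match:
            epochs.append(int(match.group(1)))
    if not epochs:
        return 0, None  # nothing saved yet: start from scratch
    last = max(epochs)  # numeric maximum, not lexicographic
    return last + 1, os.path.join(directory, f"epoch-{last}.pth")

# Demonstration with empty placeholder files
with tempfile.TemporaryDirectory() as tmpdir:
    for name in ["epoch-2.pth", "epoch-10.pth", "notes.txt"]:
        open(os.path.join(tmpdir, name), "w").close()
    start_epoch, path = find_start_epoch(tmpdir)

print(start_epoch, os.path.basename(path))  # 11 epoch-10.pth
```

With these two values, call resume(model, path) when path is not None and run the loop over range(start_epoch, n_epochs).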

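When resuming, make sure the training and test sets are exactly the same as in the original run. random_split() partitions the data randomly, so a fresh split after a restart would move examples the model has already trained on into the test set and make the reported accuracy look better than it is. Either fix the random seed used for the split, or save the split indices and reuse them.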
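One way to keep the split identical across restarts is to pass a seeded torch.Generator to random_split(), so that the same seed always produces the same partition. In the sketch below, the helper name split_dataset, the seed value 42, and the toy dataset are arbitrary choices for illustration. Saving trainset.indices and testset.indices alongside your checkpoints is an alternative, though either approach assumes the dataset itself is loaded in the same order each time.

```python
import torch
from torch.utils.data import TensorDataset, random_split

def split_dataset(dataset, train_frac=0.7, seed=42):
    """Hypothetical helper: the same seed always yields the same partition."""
    n_train = int(train_frac * len(dataset))
    lengths = [n_train, len(dataset) - n_train]  # always sums to len(dataset)
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, lengths, generator=generator)

# Toy dataset standing in for the electricity data
dataset = TensorDataset(torch.arange(100, dtype=torch.float32).reshape(-1, 1))

train_a, test_a = split_dataset(dataset)
train_b, test_b = split_dataset(dataset)  # e.g. after restarting the script
print(train_a.indices == train_b.indices)  # True
```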


Checkpointing with Optimizer States

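If your optimizer keeps internal state, you should save that state as well. Adam, for example, maintains running averages of each parameter's gradients and of their squares, and SGD with momentum keeps a momentum buffer. If only the model weights are restored, this state starts over from zero and the first updates after resuming differ from those of an uninterrupted run. (Plain SGD without momentum, as used above, has no such state.) You can store both state dicts in a single file by extending the checkpoint functions: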

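def checkpoint(model, optimizer, filename):
    # Bundle both state dicts into a single file
    torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict()}, filename)

def resume(model, optimizer, filename):
    # The optimizer must already be constructed over this model's parameters
    state = torch.load(filename)  # named `state` to avoid shadowing checkpoint()
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])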
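It can also be convenient to store the epoch number in the same file, so that resuming does not depend on parsing file names. The sketch below is a hypothetical variant of the two functions above (the names save_training_state and load_training_state and the extra "epoch" key are illustrative additions). It checks the round trip with Adam, whose per-parameter moment estimates are exactly the state that would otherwise be lost.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

def save_training_state(model, optimizer, filename, epoch):
    """Hypothetical variant of checkpoint() that also records the epoch."""
    torch.save({
        "epoch": epoch,  # last completed epoch
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }, filename)

def load_training_state(model, optimizer, filename):
    """Hypothetical variant of resume(); returns the epoch to continue from."""
    state = torch.load(filename)
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["epoch"] + 1

torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = optim.Adam(model.parameters(), lr=0.01)

# One update so Adam creates its per-parameter state (step count, moment estimates)
loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "ckpt.pth")
    save_training_state(model, optimizer, path, epoch=3)

    # Fresh objects, as after restarting the script
    new_model = nn.Linear(4, 1)
    new_optimizer = optim.Adam(new_model.parameters(), lr=0.01)
    start_epoch = load_training_state(new_model, new_optimizer, path)

print(start_epoch)  # 4
```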

Early Stopping with Checkpointing

Checkpointing can also be used to keep the best model seen during training, which is not necessarily the one from the last epoch. Track the best accuracy so far and overwrite a single checkpoint file whenever it improves:

best_accuracy = -1
for epoch in range(n_epochs):
    model.train()
    for X_batch, y_batch in loader:
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        y_pred = model(X_test)
    # Convert to a Python float so best_accuracy holds a plain number
    acc = float((y_pred.round() == y_test).float().mean()) * 100
    print(f"Epoch {epoch}: Accuracy = {acc:.2f}%")

    if acc > best_accuracy:
        best_accuracy = acc
        checkpoint(model, optimizer, "best_model.pth")

After training, restore the best model:

resume(model, optimizer, "best_model.pth")
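If you only need the best weights within a single run and want to avoid writing to disk on every improvement, an alternative is to keep a copy of the best state_dict in memory. A deep copy is required: state_dict() returns tensors that share storage with the live parameters, so later optimizer steps would overwrite a shallow copy in place. The snippet below uses a toy model and a manual in-place update, standing in for an optimizer step, to show the difference.

```python
import copy

import torch
import torch.nn as nn

model = nn.Linear(3, 1)

# Snapshot the "best" weights; deepcopy gives them their own storage
best_state = copy.deepcopy(model.state_dict())
shallow_state = model.state_dict()  # shares storage with the parameters

# Simulate a later optimizer step that changes the weights in place
with torch.no_grad():
    model.weight.add_(1.0)

followed_update = torch.equal(shallow_state["weight"], model.weight)
snapshot_kept = not torch.equal(best_state["weight"], model.weight)
print(followed_update, snapshot_kept)  # True True

# Restore the snapshot, as you would after the training loop
model.load_state_dict(best_state)
```

Keep in mind that an in-memory copy does not survive a crash, so it complements rather than replaces on-disk checkpoints.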

Implementing Early Stopping

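Early stopping ends training once the monitored metric has stopped improving for a set number of epochs, often called the patience. This saves compute and helps limit overfitting, since training does not continue long past the point where held-out performance peaks. For simplicity, these examples monitor the test set; in practice, a separate validation split should drive model selection and early stopping so that the final test accuracy remains an unbiased estimate. In the loop below, training stops when more than early_stop_thresh epochs have passed since the last new best accuracy: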

early_stop_thresh = 5
best_accuracy = -1
best_epoch = -1

for epoch in range(n_epochs):
    model.train()
    for X_batch, y_batch in loader:
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

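    model.eval()
    with torch.no_grad():
        y_pred = model(X_test)
    acc = float((y_pred.round() == y_test).float().mean()) * 100
    print(f"Epoch {epoch}: Accuracy = {acc:.2f}%")

    if acc > best_accuracy:
        best_accuracy = acc
        best_epoch = epoch
        checkpoint(model, optimizer, "best_model.pth")
    elif epoch - best_epoch > early_stop_thresh:
        # No new best accuracy for more than early_stop_thresh epochs
        print(f"Early stopping at epoch {epoch}")
        break

# The final epoch is not the best one, so reload the saved weights
resume(model, optimizer, "best_model.pth")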
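The patience logic can also be factored out of the loop into a small helper, which keeps the training loop shorter and makes the stopping rule easy to test in isolation. The EarlyStopping class below is a hypothetical sketch, not a PyTorch API; it mirrors the loop above by stopping once more than `patience` epochs have passed since the last improvement. The accuracy values are made up purely to exercise the rule.

```python
class EarlyStopping:
    """Hypothetical helper mirroring the early-stopping loop above."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best_score = float("-inf")
        self.best_epoch = -1

    def step(self, epoch, score):
        """Record this epoch's score; return (improved, should_stop)."""
        if score > self.best_score:
            self.best_score = score
            self.best_epoch = epoch
            return True, False
        return False, epoch - self.best_epoch > self.patience


# Made-up accuracy curve: improves until epoch 2, then plateaus
scores = [60.0, 65.0, 70.0, 69.0, 69.5, 68.0, 70.0, 69.0, 68.5, 69.9]
stopper = EarlyStopping(patience=5)
stopped_at = None
for epoch, acc in enumerate(scores):
    improved, should_stop = stopper.step(epoch, acc)
    if improved:
        pass  # in the real loop: checkpoint(model, optimizer, "best_model.pth")
    if should_stop:
        stopped_at = epoch
        break

print(stopped_at, stopper.best_epoch)  # 8 2
```

Note that a score equal to the current best (70.0 at epoch 6) does not count as an improvement, matching the strict `>` comparison in the loop above.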

Summary

In this guide, we covered:

  • The concept and benefits of checkpointing in PyTorch.
  • How to implement and use checkpoints for resuming training.
  • Techniques to integrate early stopping for efficient model training.

By combining these techniques, you can protect long training runs against interruptions, stop training once it no longer improves, and keep the best-performing weights rather than simply the last ones. Adapt them as needed to your project’s specific requirements.
