Plotting metrics during training shows how a neural network is actually learning. For example, if training accuracy worsens over time instead of improving, the optimization may be unstable, for instance because the learning rate is too high. This guide explains how to track and plot performance metrics in PyTorch so you can understand your model’s behavior.
What You’ll Learn:
- Key metrics to track during training.
- How to plot training and validation metrics.
- How to interpret these plots to assess performance.
1. Collecting Metrics During Training
Training a model with gradient descent involves three steps:
- Forward pass: Compute the model's predictions and the loss.
- Backward pass: Calculate gradients.
- Update: Adjust parameters using gradients.
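In PyTorch, each of these steps is a single call. The following minimal sketch runs one gradient-descent step on a toy linear model. The random data, the `nn.Linear(8, 1)` model, and the plain `optim.SGD` optimizer with `lr=0.01` are illustrative assumptions, not part of this guide's main example.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy regression data (assumption): 64 samples with 8 features each
X = torch.randn(64, 8)
y = torch.randn(64, 1)

model = nn.Linear(8, 1)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 1. Forward pass: compute predictions and the loss
loss_before = loss_fn(model(X), y)

# 2. Backward pass: clear old gradients, then compute new ones
optimizer.zero_grad()
loss_before.backward()

# 3. Update: move each parameter a small step against its gradient
optimizer.step()

# Re-evaluate without tracking gradients to see the effect of the update
with torch.no_grad():
    loss_after = loss_fn(model(X), y)

print(f"loss before: {loss_before.item():.4f}, after: {loss_after.item():.4f}")
```

`optimizer.zero_grad()` comes before `backward()` because PyTorch accumulates gradients across calls; skipping it would mix gradients from earlier batches into the update.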
Key Metrics to Track:
- Regression Problems: Metrics like Mean Squared Error (MSE), Root Mean Squared Error (RMSE), and Mean Absolute Error (MAE) are common.
- Classification Problems: Accuracy, precision, recall (also called the true positive rate), and F1 score are easier to interpret than the Cross-Entropy loss itself.
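Each of these metrics is a short formula. As a framework-independent sketch, the functions below compute them in plain Python on made-up predictions. The function names `regression_metrics` and `binary_classification_metrics` are hypothetical helpers; in practice, `sklearn.metrics` provides tested implementations such as `mean_squared_error` and `f1_score`.

```python
import math

def regression_metrics(y_true, y_pred):
    """Return MSE, RMSE and MAE for two equal-length sequences of numbers."""
    errors = [p - t for t, p in zip(y_true, y_pred)]
    mse = sum(e * e for e in errors) / len(errors)
    mae = sum(abs(e) for e in errors) / len(errors)
    # RMSE is in the same units as the target, which makes it easier to read
    return {"mse": mse, "rmse": math.sqrt(mse), "mae": mae}

def binary_classification_metrics(y_true, y_pred):
    """Return accuracy, precision, recall (true positive rate) and F1 for 0/1 labels."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    accuracy = (tp + tn) / len(y_true)
    # Guard against division by zero when a class is never predicted or never present
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

# Made-up example values
reg = regression_metrics([3.0, 5.0, 2.0], [2.5, 5.0, 4.0])
clf = binary_classification_metrics([1, 0, 1, 1, 0], [1, 0, 0, 1, 1])
print(reg)
print(clf)
```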
Example: Tracking MSE for Regression
Here’s an example using PyTorch and the California housing dataset:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load the data and split it 70/30 into training and test sets
data = fetch_california_housing()
X, y = data.data, data.target
X_train_raw, X_test_raw, y_train, y_test = train_test_split(X, y, train_size=0.7, shuffle=True)

# Standardize features using statistics from the training set only
scaler = StandardScaler()
scaler.fit(X_train_raw)
X_train = scaler.transform(X_train_raw)
X_test = scaler.transform(X_test_raw)

# Convert data to PyTorch tensors; targets become column vectors of shape (N, 1)
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32).reshape(-1, 1)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32).reshape(-1, 1)

# Define model, loss, and optimizer
model = nn.Sequential(
    nn.Linear(8, 24), nn.ReLU(),
    nn.Linear(24, 12), nn.ReLU(),
    nn.Linear(12, 6), nn.ReLU(),
    nn.Linear(6, 1),
)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
n_epochs = 100
batch_size = 32
batch_start = torch.arange(0, len(X_train), batch_size)
mse_history = []  # one MSE value per mini-batch

for epoch in range(n_epochs):
    for start in batch_start:
        X_batch = X_train[start:start+batch_size]
        y_batch = y_train[start:start+batch_size]
        # Forward pass
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        mse_history.append(float(loss))
        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
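Because `mse_history` in the loop above records one value per mini-batch, it ends up with `n_epochs` times `len(batch_start)` entries, which makes for a noisy plot. A minimal sketch, assuming you want one point per epoch: the hypothetical helper `per_epoch_means` groups the flat list into consecutive chunks of `batches_per_epoch` values and averages each chunk.

```python
def per_epoch_means(batch_history, batches_per_epoch):
    """Average a flat per-batch history into one value per epoch.

    Assumes every epoch contributed exactly `batches_per_epoch` entries,
    which holds when the set of mini-batches is the same every epoch.
    """
    if batches_per_epoch <= 0:
        raise ValueError("batches_per_epoch must be positive")
    means = []
    for i in range(0, len(batch_history), batches_per_epoch):
        chunk = batch_history[i:i + batches_per_epoch]
        means.append(sum(chunk) / len(chunk))
    return means

# Made-up values: 2 epochs of 3 batches each
print(per_epoch_means([4.0, 3.0, 2.0, 1.5, 1.0, 0.5], 3))  # [3.0, 1.0]
```

With the loop above, `per_epoch_means(mse_history, len(batch_start))` would give 100 values, one per epoch.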
Tracking Additional Metrics
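Extend the loop so that, after every epoch, it evaluates the model on the test set and records the mean training MSE together with the test MSE and MAE. Use this loop in place of the previous one, with a freshly created model and optimizer; otherwise it simply continues training the already-trained model: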
mae_fn = nn.L1Loss()  # mean absolute error
train_mse_history, test_mse_history, test_mae_history = [], [], []

for epoch in range(n_epochs):
    # Training step: record the mean of the per-batch MSE values for this epoch
    model.train()
    epoch_mse = []
    for start in batch_start:
        X_batch = X_train[start:start+batch_size]
        y_batch = y_train[start:start+batch_size]
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        epoch_mse.append(float(loss))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    train_mse_history.append(sum(epoch_mse) / len(epoch_mse))

    # Validation step: evaluate on the whole test set without tracking gradients
    model.eval()
    with torch.no_grad():
        y_pred = model(X_test)
        mse = loss_fn(y_pred, y_test)
        mae = mae_fn(y_pred, y_test)
    test_mse_history.append(float(mse))
    test_mae_history.append(float(mae))
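The validation step follows a common pattern: switch to evaluation mode, disable gradient tracking, compute the metrics, then return to the previous mode. One way to package it is sketched below with a hypothetical `evaluate` helper; the random data and small untrained network are stand-ins (assumptions) for the housing example.

```python
import torch
import torch.nn as nn

def evaluate(model, X, y, metric_fns):
    """Return {name: float} for each metric, computed without gradients.

    Restores the model's previous train/eval mode afterwards.
    """
    was_training = model.training
    model.eval()  # affects layers such as dropout and batch norm
    with torch.no_grad():
        y_pred = model(X)
        results = {name: float(fn(y_pred, y)) for name, fn in metric_fns.items()}
    model.train(was_training)  # put the model back into its original mode
    return results

# Toy stand-ins (assumption): random data and a small untrained model
torch.manual_seed(0)
X_val = torch.randn(20, 8)
y_val = torch.randn(20, 1)
net = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 1))

scores = evaluate(net, X_val, y_val, {"mse": nn.MSELoss(), "mae": nn.L1Loss()})
print(scores)
```

In the loop above, the validation step could then become `scores = evaluate(model, X_test, y_test, {"mse": loss_fn, "mae": mae_fn})`, followed by appending `scores["mse"]` and `scores["mae"]` to the history lists.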
2. Plotting the Training History
Visualize collected metrics using matplotlib:
import matplotlib.pyplot as plt
import numpy as np

# RMSE is the square root of MSE, so it shares the target's units (as MAE does)
plt.plot(np.sqrt(train_mse_history), label="Train RMSE")
plt.plot(np.sqrt(test_mse_history), label="Test RMSE")
plt.plot(test_mae_history, label="Test MAE")
plt.xlabel("Epochs")
plt.ylabel("Error")
plt.legend()
plt.show()
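`plt.show()` needs a display. On a remote machine or in a script, one option is to select matplotlib's non-interactive `Agg` backend and write the figure to a file instead. The sketch below uses made-up history values and a hypothetical output filename `training_history.png`, and numbers the x-axis from epoch 1.

```python
import math

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render to files, no window
import matplotlib.pyplot as plt

# Made-up histories standing in for train_mse_history / test_mse_history
train_mse_history = [1.2, 0.8, 0.6, 0.5, 0.45]
test_mse_history = [1.3, 0.9, 0.7, 0.65, 0.66]
epochs = range(1, len(train_mse_history) + 1)

fig, ax = plt.subplots()
ax.plot(epochs, [math.sqrt(v) for v in train_mse_history], label="Train RMSE")
ax.plot(epochs, [math.sqrt(v) for v in test_mse_history], label="Test RMSE")
ax.set_xlabel("Epoch")
ax.set_ylabel("RMSE")
ax.legend()
fig.savefig("training_history.png", dpi=100)
plt.close(fig)  # release the figure's memory once it has been saved
```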
Interpreting the Plots
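- Convergence Speed: A steeper downward slope in the early epochs indicates faster convergence.
- Overfitting: Training metrics keep improving while validation metrics stall or worsen, so the gap between the curves widens.
- Plateauing: Flat curves suggest the model has converged, or that training has stopped making progress (for example, because the learning rate is too low).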
For regression, metrics like MSE and MAE should decrease over time. In classification, accuracy should increase, and loss should decrease.
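These rules of thumb can also be checked in code. A minimal sketch, using a hypothetical `diagnose` function whose thresholds (`window`, `tol`) are arbitrary and would need tuning for a real problem: it compares how much the training and validation losses changed over the last few epochs and labels the trend.

```python
def diagnose(train_hist, val_hist, window=5, tol=1e-3):
    """Classify the recent trend of two loss histories (lower is better).

    Heuristic sketch: looks only at the change over the last `window` epochs.
    Returns "insufficient data", "diverging", "overfitting", "plateau",
    "improving", or "mixed".
    """
    if min(len(train_hist), len(val_hist)) <= window:
        return "insufficient data"
    train_change = train_hist[-1] - train_hist[-1 - window]
    val_change = val_hist[-1] - val_hist[-1 - window]
    if train_change > tol:
        return "diverging"  # training loss is rising, e.g. learning rate too high
    if train_change < -tol and val_change > tol:
        return "overfitting"  # training improves while validation worsens
    if abs(train_change) <= tol and abs(val_change) <= tol:
        return "plateau"  # both curves are essentially flat
    if val_change < -tol:
        return "improving"  # validation loss is still going down
    return "mixed"

# Made-up histories: validation loss turns upward after epoch 4
train = [1.0, 0.8, 0.6, 0.5, 0.4, 0.35, 0.3, 0.25]
val = [1.1, 0.9, 0.7, 0.65, 0.66, 0.7, 0.75, 0.8]
print(diagnose(train, val, window=3))  # overfitting
```

With the histories from the loop above, you might call `diagnose(train_mse_history, test_mse_history)` after each epoch and inspect the plot whenever it reports something other than "improving".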
Summary
In this guide, you learned how to:
- Track relevant metrics during training.
- Visualize training and validation performance.
- Interpret plots to diagnose issues like overfitting or slow convergence.
With these tools, you can spot problems such as overfitting, stalled training, or a poorly chosen learning rate early, and adjust your training setup before wasting further epochs.