Deep learning has proven to be an effective approach for object recognition in image data. One of the quintessential tasks in this domain is handwritten digit recognition, typically demonstrated through the MNIST dataset. In this article, you will learn how to develop a deep learning model that achieves high accuracy on the MNIST digit classification task using the LeNet5 architecture in PyTorch.
By the end of this article, you will understand how to:
- Load the MNIST dataset using torchvision.
- Develop and evaluate a baseline neural network model for handwriting recognition.
- Implement a convolutional neural network (CNN) specifically for the MNIST dataset.
- Utilize the LeNet5 model for digit classification.
Let’s get started!
Overview
This tutorial is organized into six parts:
- The MNIST Handwritten Digit Recognition Problem
- Loading the MNIST Dataset in PyTorch
- Building a Baseline Model with Multilayer Perceptrons
- Creating a Simple Convolutional Neural Network for MNIST
- Implementing the LeNet5 Architecture for MNIST
- Training the Model
The MNIST Handwritten Digit Recognition Problem
The MNIST challenge serves as a classic example for assessing the capabilities of convolutional neural networks (CNNs). Developed by Yann LeCun, Corinna Cortes, and Christopher Burges, the MNIST dataset is derived from scanned documents available from the National Institute of Standards and Technology (NIST). Each image in the dataset is a 28×28 pixel grayscale representation of a handwritten digit (0-9), with a structured split of 60,000 training and 10,000 testing samples.
The primary goal is to classify the digits accurately, with state-of-the-art models achieving accuracy levels of approximately 99.8%.
Loading the MNIST Dataset in PyTorch
The torchvision library, a companion package to PyTorch, provides ready-made loaders for datasets such as MNIST. Here’s how you can load the training and testing data:
import torchvision.transforms as transforms
from torchvision import datasets
# Load the MNIST dataset
train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
This snippet downloads the dataset into the specified directory if it is not already there. To verify the number of training and testing samples, run:
print("Number of training samples: ", len(train_set))
print("Number of testing samples: ", len(test_set))
You’ll see:
Number of training samples: 60000
Number of testing samples: 10000
Each sample consists of an image and its corresponding label. To inspect the data type and size of the first training sample, use:
print("Data type of the first training sample: ", train_set[0][0].type())
print("Size of the first training sample: ", train_set[0][0].size())
The output will be:
Data type of the first training sample: torch.FloatTensor
Size of the first training sample: torch.Size([1, 28, 28])
This indicates that the first image is a single-channel (grayscale) image of 28×28 pixels: the leading 1 is the channel dimension, followed by height and width.
Building a Baseline Model with Multilayer Perceptrons
Before approaching convolutional networks, it helps to establish a baseline using a simple multilayer perceptron (MLP). This straightforward model gives a reference point against which the more advanced architectures can be compared.
You can create a basic MLP model as follows:
import torch.nn as nn
import torch.optim as optim

class Baseline(nn.Module):
    def __init__(self):
        super(Baseline, self).__init__()
        self.flatten = nn.Flatten()  # [N, 1, 28, 28] -> [N, 784]
        self.layer1 = nn.Linear(784, 784)  # Input to hidden layer
        self.activation1 = nn.ReLU()
        self.layer2 = nn.Linear(784, 10)  # Hidden to output layer

    def forward(self, x):
        x = self.flatten(x)  # Linear layers expect flat vectors, not images
        x = self.activation1(self.layer1(x))
        x = self.layer2(x)  # Raw scores (logits), one per digit class
        return x
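As a quick sanity check, the sketch below builds a layer stack with the same sizes as the baseline and pushes a random batch through it. It is an assumed equivalent written with `nn.Sequential`, and it includes an `nn.Flatten` step so that image-shaped input can be fed in directly. The batch is random noise, used only to confirm shapes and the parameter count.

```python
import torch
import torch.nn as nn

# Assumed equivalent of the baseline: flatten, one hidden layer of 784 units,
# then 10 output scores.
mlp = nn.Sequential(
    nn.Flatten(),          # [N, 1, 28, 28] -> [N, 784]
    nn.Linear(784, 784),
    nn.ReLU(),
    nn.Linear(784, 10),
)

dummy_batch = torch.rand(4, 1, 28, 28)  # random values, not real digits
logits = mlp(dummy_batch)

# Weights plus biases: (784*784 + 784) + (784*10 + 10)
n_params = sum(p.numel() for p in mlp.parameters())

print(logits.shape)  # torch.Size([4, 10])
print(n_params)      # 623290
```

Almost all of those parameters sit in the first layer, which is one reason a convolutional model can be much smaller for the same task.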
Creating a Simple Convolutional Neural Network for MNIST
With a better understanding of the basics, let’s build a convolutional neural network (CNN) for the MNIST dataset. CNNs are designed to work effectively with image data, preserving spatial hierarchies.
Here’s an example of how to define a simple CNN architecture:
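import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)  # First convolutional layer: 28x28 -> 24x24
        self.pool = nn.MaxPool2d(kernel_size=2)  # Pooling layer: 24x24 -> 12x12
        self.fc1 = nn.Linear(32 * 12 * 12, 128)  # Fully connected layer
        self.fc2 = nn.Linear(128, 10)  # Output layer, added so the model emits 10 class scores

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # Apply conv, activation, then pooling
        x = x.view(-1, 32 * 12 * 12)  # Flatten
        x = F.relu(self.fc1(x))  # Fully connected layer with activation
        x = self.fc2(x)  # Raw scores (logits) for the 10 digit classes
        return x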
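The `32 * 12 * 12` input size of the fully connected layer follows from the shapes produced by the convolution and pooling steps. A minimal check, using a blank dummy image instead of real data, might look like this:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(1, 32, kernel_size=5)  # no padding: 28 - 5 + 1 = 24
pool = nn.MaxPool2d(kernel_size=2)      # halves each side: 24 / 2 = 12

dummy = torch.zeros(1, 1, 28, 28)       # one blank 28x28 grayscale image
features = pool(conv(dummy))
flat_size = features[0].numel()         # values per image after flattening

print(features.shape)  # torch.Size([1, 32, 12, 12])
print(flat_size)       # 4608
```

If you change the kernel size or add padding, this number changes, and the first `nn.Linear` layer has to be updated to match.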
Implementing the LeNet5 Architecture for MNIST
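The LeNet5 model, developed specifically for digit recognition, stacks three convolutional layers with average pooling in between, followed by two fully connected layers. The original design expects 32×32 inputs, so the first convolution here uses a padding of 2 to accommodate the 28×28 MNIST images. Here’s how to implement it in PyTorch: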
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2)
        self.act1 = nn.Tanh()
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0)
        self.act2 = nn.Tanh()
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(16, 120, kernel_size=5, stride=1, padding=0)
        self.act3 = nn.Tanh()
        self.flat = nn.Flatten()
        self.fc1 = nn.Linear(1 * 1 * 120, 84)
        self.act4 = nn.Tanh()
        self.fc2 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.act1(self.conv1(x))
        x = self.pool1(x)
        x = self.act2(self.conv2(x))
        x = self.pool2(x)
        x = self.act3(self.conv3(x))
        x = self.act4(self.fc1(self.flat(x)))
        x = self.fc2(x)
        return x
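The `1 * 1 * 120` input size of `fc1` can be derived by tracing the image side length through each layer with the standard output-size formula, floor((n + 2p - k) / s) + 1. The helper `conv_out` below is a hypothetical name introduced for this illustration; the kernel, stride, and padding values are taken from the class above.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a convolution or pooling layer on a square input."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28
trace = [size]
# (kernel, stride, padding) for conv1, pool1, conv2, pool2, conv3
layers = [(5, 1, 2), (2, 2, 0), (5, 1, 0), (2, 2, 0), (5, 1, 0)]
for kernel, stride, padding in layers:
    size = conv_out(size, kernel, stride, padding)
    trace.append(size)

print(trace)  # [28, 28, 14, 10, 5, 1]
```

The final side length of 1, combined with the 120 output channels of `conv3`, gives the 120 values per image that `fc1` receives.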
Training the Model
Train the model using a DataLoader to efficiently manage batching. Set the loss function and optimizer before beginning the training loop:
import torch
from torch.utils.data import DataLoader

# Set up DataLoaders for training and testing
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False)

model = LeNet5()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()  # Appropriate for multi-class classification

for epoch in range(10):  # Train for 10 epochs
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()  # Clear gradients from the previous step
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Validation on the test set after each epoch
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # No gradients needed for evaluation
        for inputs, labels in test_loader:
            outputs = model(inputs)
            predicted = torch.argmax(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f'Epoch [{epoch + 1}/10]: Accuracy: {accuracy:.2f}%')
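To make the loss and accuracy steps of the loop more concrete, here is a small sketch on made-up scores for three samples and three classes (real MNIST outputs have ten). It illustrates that `nn.CrossEntropyLoss` takes raw logits and applies log-softmax internally, and that accuracy is computed from the `argmax` of those logits. All numbers are invented for illustration.

```python
import torch
import torch.nn as nn

# Made-up logits for 3 samples and 3 classes, with their true labels.
logits = torch.tensor([[2.0, 0.5, 0.1],
                       [0.2, 0.1, 3.0],
                       [1.0, 2.0, 0.0]])
labels = torch.tensor([0, 2, 0])

# Built-in loss on raw logits (mean over the batch by default).
loss = nn.CrossEntropyLoss()(logits, labels)

# The same quantity by hand: negative log-softmax at the true class, averaged.
manual = -torch.log_softmax(logits, dim=1)[torch.arange(3), labels].mean()

# Accuracy, computed the same way as in the validation loop.
predicted = torch.argmax(logits, 1)
accuracy = 100 * (predicted == labels).sum().item() / labels.size(0)

print(torch.allclose(loss, manual))  # True
print(predicted.tolist())            # [0, 2, 1]
print(f"{accuracy:.2f}%")            # 66.67%
```

Because the loss already includes the softmax, the models in this tutorial return raw scores from their last linear layer, with no softmax layer of their own.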
Summary
In this tutorial, you learned how to implement the LeNet5 model for handwritten digit recognition using the MNIST dataset in PyTorch. You covered:
- How to load the MNIST dataset using torchvision.
- How to develop and evaluate a baseline neural network model.
- How to implement and run a convolutional neural network for image classification.
Armed with this knowledge, you can now delve deeper into deep learning projects and explore the possibilities of building more advanced models using PyTorch.