05 - Training Loop¶
What this session is¶
About an hour. Train the MLP from page 04 to recognize MNIST digits. By the end you'll have written a full training loop - the same shape as every PyTorch training loop in existence.
The pattern¶
Every training loop is:
For each epoch (pass over the data):
For each batch:
1. Forward pass - compute predictions
2. Compute loss - how wrong are we?
3. Backward pass - compute gradients
4. Optimizer step - adjust weights
That's it. The rest is bookkeeping.
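Here's that skeleton as runnable code - a toy nn.Linear on synthetic data, just to show the shape of one step (the real script is below):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)                      # toy stand-in for a real model
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(32, 10)                       # fake batch
y = torch.randint(0, 2, (32,))                # fake labels

logits = model(x)                             # 1. forward pass
loss = criterion(logits, y)                   # 2. compute loss
optimizer.zero_grad()
loss.backward()                               # 3. backward pass
optimizer.step()                              # 4. optimizer step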
Load MNIST¶
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
transform = transforms.Compose([
    transforms.ToTensor(),                       # PIL image → tensor
    transforms.Normalize((0.1307,), (0.3081,)),  # standardize: (x - mean) / std
])
train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_ds = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=512)
Three pieces:
- Dataset - knows how to load and transform one example.
- DataLoader - wraps the dataset, batches it, optionally shuffles.
- Transform - preprocessing applied to each example.
torchvision provides MNIST out of the box. First run downloads ~10MB; subsequent runs use the cached copy.
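Quick sanity check - grab one batch and look at the shapes (assuming the batch_size=64 loader above):

x, y = next(iter(train_loader))
print(x.shape)  # torch.Size([64, 1, 28, 28]) - 64 one-channel 28x28 images
print(y.shape)  # torch.Size([64])            - one integer label per image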
The full training script¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Reproducibility
torch.manual_seed(42)

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")

# Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_ds = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=512)

# Model
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten 28x28 → 784
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

model = MLP().to(device)

# Loss + optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Train
for epoch in range(3):
    model.train()
    total_loss = 0
    correct = 0
    n = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        # 1. Forward
        logits = model(x)

        # 2. Loss
        loss = criterion(logits, y)

        # 3. Backward
        optimizer.zero_grad()  # clear gradients from last step
        loss.backward()        # compute gradients

        # 4. Optimizer step
        optimizer.step()

        total_loss += loss.item() * x.size(0)
        correct += (logits.argmax(dim=1) == y).sum().item()
        n += x.size(0)

    print(f"epoch {epoch}: train loss {total_loss/n:.4f}, acc {correct/n:.4f}")

# Test
model.eval()
correct = 0
n = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        correct += (logits.argmax(dim=1) == y).sum().item()
        n += x.size(0)

print(f"test accuracy: {correct/n:.4f}")
Run. After ~30 seconds (on CPU) or ~5 seconds (on GPU), you should see ~97% test accuracy. Your first trained model.
What each line is doing¶
x.view(x.size(0), -1) - flatten the 28x28 images into 784-length vectors. The -1 infers the dimension. x.size(0) is the batch dimension.
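In isolation (a random tensor standing in for a batch of images):

import torch
x = torch.randn(64, 1, 28, 28)   # fake batch
flat = x.view(x.size(0), -1)     # keep dim 0, infer the rest: 1*28*28 = 784
print(flat.shape)                # torch.Size([64, 784])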
nn.CrossEntropyLoss - the standard loss for classification. Internally it's log-softmax plus negative log-likelihood, fused into one numerically stable op.
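You can check the decomposition yourself (random logits, made-up targets):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)               # 4 examples, 10 classes
targets = torch.tensor([3, 7, 0, 1])      # arbitrary class indices
a = F.cross_entropy(logits, targets)
b = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(a, b))               # True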
optimizer = torch.optim.Adam(...) - the optimizer. Adam is the most-used optimizer for modern ML. Other options: SGD (stochastic gradient descent - classic but needs more tuning), AdamW (Adam with corrected weight decay).
optimizer.zero_grad() - clear gradients from the last batch. PyTorch accumulates gradients by default; you must clear them explicitly each step. Forget this and each batch's gradients pile on top of the previous ones, corrupting every update.
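Two backward passes with no zero_grad between them show the accumulation (a throwaway scalar parameter):

import torch

w = torch.tensor(1.0, requires_grad=True)
(2 * w).backward()
print(w.grad)        # tensor(2.)
(2 * w).backward()   # no zero_grad in between...
print(w.grad)        # tensor(4.) - added to, not replaced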
loss.backward() - autograd walks the computation graph backward, computing gradients for every parameter that participated in computing loss. Stores them in parameter.grad.
optimizer.step() - applies the gradient update. For plain SGD that's param = param - lr * param.grad; Adam adds per-parameter running averages of the gradient and its square on top.
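For intuition, the SGD update written by hand (a toy model; ignores momentum):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                    # toy model for illustration
model(torch.randn(8, 4)).sum().backward()  # populate .grad

lr = 1e-3
with torch.no_grad():                      # the update itself shouldn't be traced
    for p in model.parameters():
        p -= lr * p.grad                   # param = param - lr * grad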
model.train() vs model.eval() - switches the model's internal mode. Affects layers like dropout and batch norm that behave differently during training vs inference. Always set the right mode.
with torch.no_grad(): - disables autograd. Inference doesn't need gradients; this skips the bookkeeping and uses less memory.
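The effect, directly:

import torch

x = torch.ones(3, requires_grad=True)
y = x * 2
print(y.requires_grad)   # True - autograd is tracking this op
with torch.no_grad():
    z = x * 2
print(z.requires_grad)   # False - no graph was built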
What the loss tells you¶
Training loss going down = the model is fitting the training data. A plateau means the model has hit its capacity (or the optimizer is stuck - try different hyperparameters).
Test accuracy is what you actually care about - performance on data the model didn't see during training. If train loss keeps dropping but test accuracy stops improving, you're overfitting - memorizing the training data.
Mitigations: more data, regularization (dropout, weight decay), smaller model, early stopping.
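A sketch of two of those mitigations applied to this model - dropout in the forward pass, weight decay via AdamW. The p=0.5 and weight_decay=1e-2 values are illustrative, not tuned:

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLPWithDropout(nn.Module):    # hypothetical variant of the MLP above
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.drop = nn.Dropout(p=0.5)   # active in train(), a no-op in eval()

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.drop(F.relu(self.fc1(x)))
        x = self.drop(F.relu(self.fc2(x)))
        return self.fc3(x)

model = MLPWithDropout()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)

This is also why model.train() vs model.eval() matters: dropout must be on during training and off at test time.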
Hyperparameters¶
The numbers you set that aren't learned:
- Learning rate (lr=1e-3) - how big a step the optimizer takes. Too high → unstable. Too low → slow. 1e-3 is a great starting point for Adam.
- Batch size (64) - larger = smoother gradients, more memory; smaller = noisier gradients, sometimes generalizes better.
- Epochs - how many passes over the data. More = better fit (until overfitting).
- Architecture - depth, width, normalizations.
These need tuning. For MNIST, defaults work. For real problems, expect to iterate.
Save the model¶
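Save just the weights (the state dict), not the whole object - the standard pattern. The filename here is arbitrary:

torch.save(model.state_dict(), "mlp_mnist.pt")   # "mlp_mnist.pt" is a made-up name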
Load later:
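model = MLP()                                     # rebuild the same architecture first
model.load_state_dict(torch.load("mlp_mnist.pt"))
model.eval()                                      # inference mode before using it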
What this scales to¶
The training loop pattern above is identical for huge models - the only thing that changes is the model definition. Add a few wrinkles for big-model training (mixed precision, gradient accumulation, distributed training, checkpointing) and you have the shape of a real LLM training script.
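As a taste, gradient accumulation is only a few extra lines - a sketch on toy data, where accum_steps is a made-up setting that simulates a 4x larger batch:

import torch
import torch.nn as nn

model = nn.Linear(10, 2)                         # stand-ins for the real model and data
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batches = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(8)]

accum_steps = 4
optimizer.zero_grad()
for i, (x, y) in enumerate(batches):
    loss = criterion(model(x), y) / accum_steps  # scale so summed grads match one big batch
    loss.backward()                              # grads accumulate across mini-batches
    if (i + 1) % accum_steps == 0:
        optimizer.step()                         # update once per accum_steps batches
        optimizer.zero_grad()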
Exercise¶
- Run the script. Get ~97% test accuracy.
- Tweak hyperparameters:
  - Set lr=1e-2. Does it train? (Probably loss goes NaN - too high.)
  - Set lr=1e-5. (Trains, but slowly.)
- Increase epochs to 10. Does test accuracy improve? Plateau? Degrade (overfit)?
- Architecture changes:
  - Add a third hidden layer of size 128.
  - Increase hidden sizes to 256 and 128.
- Watch the parameter count and accuracy change.
- Visualize: plot the per-epoch training loss using matplotlib. (Or use TensorBoard - pip install tensorboard, then tensorboard --logdir runs/.)
- Stretch: instead of an MLP, try a small CNN. A 2-layer CNN beats this MLP at >98% accuracy. Look up nn.Conv2d. Don't worry if it doesn't work on the first try - convolutions take some adjusting.
What you might wonder¶
"Why Adam over SGD?" Adam adapts the learning rate per-parameter; converges faster on most problems without tuning. SGD with momentum can outperform on specific architectures (vision CNNs) when carefully tuned. For getting started: Adam.
"How big should my batch be?" For GPU work, "as big as fits in memory" is a common heuristic. Common sizes: 32, 64, 128, 256. Larger batches give smoother gradients but you might need more epochs to converge.
"What does loss.item() do?"
Extracts a Python float from a 0-dim tensor. Detached from the graph (no gradient tracking). Use when you want a number for logging.
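In short:

import torch

loss = (torch.randn(5, requires_grad=True) ** 2).mean()
print(loss)         # tensor(..., grad_fn=<MeanBackward0>) - still in the graph
print(loss.item())  # a plain Python float, detached from the graph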
"Why is my loss not going down?" Common causes: - Learning rate too high (loss NaN) or too low (loss flat). - Wrong loss function for your task. - Bug in the model (wrong shapes - print them). - Data not normalized.
Print loss every step initially. If it's not going down within ~100 steps on MNIST, something's wrong.
Done¶
- The four-step training loop pattern.
- DataLoader for batched iteration.
- nn.CrossEntropyLoss + Adam optimizer.
- optimizer.zero_grad() / loss.backward() / optimizer.step().
- model.train() vs model.eval().
- Save/load weights.