05 - Training Loop¶
What this session is¶
About an hour. Train the MLP from page 04 to recognize MNIST digits. By the end you'll have written a full training loop - the same shape as nearly every PyTorch training loop you'll come across.
The pattern¶
Every training loop is:
For each epoch (pass over the data):
    For each batch:
        1. Forward pass - compute predictions
        2. Compute loss - how wrong are we?
        3. Backward pass - compute gradients
        4. Optimizer step - adjust weights
That's it. The rest is bookkeeping.
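To see the four steps without any MNIST bookkeeping, here is a minimal sketch on a toy regression problem. The data (`y = 3x + 1` plus noise), the batch size of 16, and the use of plain SGD are assumptions made for illustration, not part of this page's script.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy data: y = 3x + 1 plus a little noise.
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 3 * x + 1 + 0.01 * torch.randn_like(x)

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for epoch in range(50):                      # each epoch = one pass over the data
    for i in range(0, len(x), 16):           # each batch = 16 examples
        xb, yb = x[i:i + 16], y[i:i + 16]
        pred = model(xb)                     # 1. forward pass
        loss = criterion(pred, yb)           # 2. compute loss
        optimizer.zero_grad()                # clear gradients left over from the last step
        loss.backward()                      # 3. backward pass
        optimizer.step()                     # 4. optimizer step
        losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.6f}")
```

The loop body is the same four calls you will write for MNIST; only the data and the model differ.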
Load MNIST¶
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
transform = transforms.Compose([
    transforms.ToTensor(),                       # PIL image → tensor
    transforms.Normalize((0.1307,), (0.3081,)),  # standardize: (x - mean) / std
])
train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_ds = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=512)
Three pieces:
- Dataset - knows how to load and transform one example.
- DataLoader - wraps the dataset, batches it, optionally shuffles.
- Transform - preprocessing applied to each example.
torchvision provides MNIST out of the box. First run downloads ~10MB; subsequent runs use the cached copy.
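To see what a DataLoader hands you without downloading anything, the sketch below wraps random tensors in a `TensorDataset`. The random "images" and labels are a hypothetical stand-in for MNIST, and no transform is applied here.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for MNIST: 100 random "images" with random labels.
images = torch.randn(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))

dataset = TensorDataset(images, labels)          # indexable: dataset[i] -> (image, label)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Collect the shape of every batch the loader yields in one pass.
batch_shapes = [(xb.shape, yb.shape) for xb, yb in loader]
for x_shape, y_shape in batch_shapes:
    print(tuple(x_shape), tuple(y_shape))
```

With 100 examples and a batch size of 32 you get three full batches and one final batch of 4; the loader keeps the short last batch unless you pass `drop_last=True`.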
The full training script¶
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Reproducibility
torch.manual_seed(42)
# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")
# Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_ds = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=512)
# Model
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten 28x28 → 784
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)
model = MLP().to(device)
# Loss + optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Train
for epoch in range(3):
    model.train()
    total_loss = 0
    correct = 0
    n = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        # 1. Forward
        logits = model(x)
        # 2. Loss
        loss = criterion(logits, y)
        # 3. Backward
        optimizer.zero_grad()  # clear gradients from last step
        loss.backward()        # compute gradients
        # 4. Optimizer step
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        correct += (logits.argmax(dim=1) == y).sum().item()
        n += x.size(0)
    print(f"epoch {epoch}: train loss {total_loss/n:.4f}, acc {correct/n:.4f}")

# Test
model.eval()
correct = 0
n = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        correct += (logits.argmax(dim=1) == y).sum().item()
        n += x.size(0)
print(f"test accuracy: {correct/n:.4f}")
Run it. After roughly 30 seconds on CPU or 5 seconds on GPU (exact timings depend on your hardware), you should see ~97% test accuracy. Your first trained model.
What each line is doing¶
x.view(x.size(0), -1) - flatten the 28x28 images into 784-length vectors. The -1 infers the dimension. x.size(0) is the batch dimension.
nn.CrossEntropyLoss - standard loss for classification. Internally it combines log-softmax and negative log-likelihood in one numerically stable operation, so it expects raw logits rather than probabilities.
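A quick check of that decomposition, as a sketch: the logits and labels below are made up, and the two-step version uses `F.log_softmax` followed by `F.nll_loss`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)            # 4 examples, 10 classes (raw scores)
targets = torch.tensor([3, 0, 9, 1])   # hypothetical class labels

ce = nn.CrossEntropyLoss()(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

print(ce.item(), manual.item())        # the two values agree
```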
optimizer = torch.optim.Adam(...) - the optimizer. Adam is one of the most widely used optimizers in modern ML. Other options: SGD (stochastic gradient descent - classic but needs more tuning), AdamW (Adam with decoupled weight decay).
optimizer.zero_grad() - clear gradients from the last batch. PyTorch accumulates gradients by default; you must clear them explicitly each step. Forget this and every step adds the new gradients to all the old ones, so the updates no longer reflect the current batch.
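The accumulation is easy to see on a single scalar. This sketch clears the gradient by hand with `w.grad.zero_()`, which plays the role `optimizer.zero_grad()` plays in the training loop; the toy function `3 * w` is an assumption for illustration.

```python
import torch

w = torch.tensor(1.0, requires_grad=True)

(3 * w).backward()
first = w.grad.item()          # d(3w)/dw = 3

(3 * w).backward()             # no clearing: the new gradient is ADDED to the old one
accumulated = w.grad.item()

w.grad.zero_()                 # clear the stored gradient
(3 * w).backward()
after_zero = w.grad.item()

print(first, accumulated, after_zero)  # 3.0 6.0 3.0
```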
loss.backward() - autograd walks the computation graph backward, computing gradients for every parameter that participated in computing loss. Stores them in parameter.grad.
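optimizer.step() - applies the gradient update. For SGD: param = param - lr * param.grad. Adam makes the same kind of update but scales each parameter's step using running averages of its gradient and squared gradient.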
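The SGD rule can be checked against the library directly. This is a sketch assuming plain SGD with no momentum or weight decay; the toy loss (sum of squares) is made up for the comparison.

```python
import torch

torch.manual_seed(0)
lr = 0.1
w_opt = torch.randn(3, requires_grad=True)
w_manual = w_opt.detach().clone().requires_grad_(True)   # identical starting values

optimizer = torch.optim.SGD([w_opt], lr=lr)

# Same toy loss for both copies.
(w_opt ** 2).sum().backward()
(w_manual ** 2).sum().backward()

optimizer.step()                          # library update
with torch.no_grad():
    w_manual -= lr * w_manual.grad        # param = param - lr * param.grad

print(w_opt.data, w_manual.data)          # both copies end up at the same values
```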
model.train() vs model.eval() - switches the model's internal mode. Affects layers like dropout and batch norm that behave differently during training vs inference. Always set the right mode.
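The MLP on this page has no dropout or batch norm, so the mode switch has no visible effect on it. A standalone dropout layer shows the difference; this is a sketch with an assumed drop probability of 0.5.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
train_out = drop(x)     # some entries zeroed, the survivors scaled by 1 / (1 - p) = 2

drop.eval()
eval_out = drop(x)      # identity: dropout is switched off

print(train_out.unique().tolist(), eval_out.unique().tolist())
```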
with torch.no_grad(): - disables autograd. Inference doesn't need gradients; this skips the bookkeeping and uses less memory.
What the loss tells you¶
Training loss going down = model is fitting the training data. Plateauing means we've hit the model's capacity (or the optimizer is stuck - try different hyperparams).
Test accuracy is what you actually care about - performance on data the model didn't see during training. If train loss keeps dropping but test accuracy stops improving, you're overfitting - memorizing the training data.
Mitigations: more data, regularization (dropout, weight decay), smaller model, early stopping.
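Of these, early stopping is the simplest to sketch in a few lines. The helper name `early_stopping_epoch`, the `patience` parameter, and the loss values are all hypothetical; the rule here is "stop once the validation loss has failed to improve for `patience` epochs in a row".

```python
def early_stopping_epoch(val_losses, patience=2):
    """Return the epoch index at which training would stop, or None if it never stops."""
    best = float("inf")
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss          # new best: reset the counter
            bad_epochs = 0
        else:
            bad_epochs += 1      # no improvement this epoch
            if bad_epochs >= patience:
                return epoch
    return None

# Hypothetical validation losses: improve, then start rising.
history = [0.90, 0.60, 0.45, 0.47, 0.52, 0.61]
print(early_stopping_epoch(history, patience=2))  # 4
```

In a real loop you would also keep a copy of the weights from the best epoch, since the model at the stopping point is already slightly past its best.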
Hyperparameters¶
The numbers you set that aren't learned:
- Learning rate (lr=1e-3) - how big a step the optimizer takes. Too high → unstable. Too low → slow. 1e-3 is a great starting point for Adam.
- Batch size (64) - larger = smoother gradients, more memory; smaller = noisier gradients, sometimes generalizes better.
- Epochs - how many passes over the data. More = better fit (until overfitting).
- Architecture - depth, width, normalizations.
These need tuning. For MNIST, defaults work. For real problems, expect to iterate.
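The "too high → unstable, too low → slow" behaviour shows up even on the simplest possible problem. This sketch runs plain gradient descent on `f(w) = w**2`; the function and the three learning rates are chosen for illustration and are not recommendations for MNIST.

```python
def final_loss(lr, steps=50):
    """Gradient descent on f(w) = w**2 starting from w = 1."""
    w = 1.0
    for _ in range(steps):
        grad = 2 * w          # df/dw
        w = w - lr * grad
    return w ** 2

for lr in (1.5, 0.1, 1e-4):
    print(f"lr={lr}: final loss {final_loss(lr):.3e}")
```

With `lr=1.5` each step overshoots and the loss blows up, with `lr=0.1` it shrinks toward zero, and with `lr=1e-4` it has barely moved from its starting value of 1 after 50 steps.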
Save the model¶
Load later:
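A minimal sketch of both steps, assuming the common `state_dict` approach. The file name `mnist_mlp.pt` is hypothetical, and a small `nn.Linear` stands in for the trained MLP so the example runs on its own.

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)                      # stand-in for the trained MLP

path = os.path.join(tempfile.mkdtemp(), "mnist_mlp.pt")  # hypothetical file name
torch.save(model.state_dict(), path)         # save only the weights

restored = nn.Linear(4, 2)                   # rebuild the same architecture first
restored.load_state_dict(torch.load(path))   # then load the weights into it
restored.eval()                              # switch to inference mode

x = torch.randn(3, 4)
same = torch.allclose(model(x), restored(x))
print(same)
```

Saving the `state_dict` rather than the whole model object means the class definition has to be available when you load, but the file does not depend on how your code was laid out when you saved it.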
What this scales to¶
The training loop pattern above stays the same for huge models - what changes is mainly the model definition and the data pipeline. Add a few wrinkles for big-model training (mixed precision, gradient accumulation, distributed, checkpointing) and you have what a real LLM training script looks like.
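Of those wrinkles, gradient accumulation is the easiest to show in isolation. This sketch, on made-up data with an assumed split into 4 equal micro-batches, checks that summing the gradients of scaled micro-batch losses reproduces the full-batch gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 1)
criterion = nn.MSELoss()
x, y = torch.randn(32, 8), torch.randn(32, 1)   # hypothetical data

# Reference: gradient from one full batch of 32.
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 8, each loss divided by the number of micro-batches.
accum_steps = 4
model.zero_grad()
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    loss = criterion(model(xb), yb) / accum_steps
    loss.backward()                  # gradients add up across micro-batches
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-5))
```

In a training loop you would call `optimizer.step()` and `optimizer.zero_grad()` once per group of micro-batches rather than once per batch.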
Going deeper¶
You can run a training loop now. This section is the part nobody teaches and everybody needs: what to do when the loss won't go down. This is the daily reality of training - the loop runs, no error, but the model isn't learning. Here's the diagnosis tree, with what each failure looks like.
Read the loss curve - it tells you what's wrong¶
Before fixing anything, look at the loss over steps. The shape is the diagnosis:
HEALTHY: loss falls fast, then slowly flattens
3.2 -> 2.1 -> 1.4 -> 1.1 -> 0.95 -> 0.88 ...
FLAT (not learning): barely moves
3.2 -> 3.19 -> 3.20 -> 3.18 ... # learning rate too low, or a bug
EXPLODING (diverging): shoots up, often to NaN
3.2 -> 5.1 -> 19.4 -> 412 -> nan # learning rate too high
ZIGZAG (unstable): bounces around, no trend
3.2 -> 1.1 -> 4.5 -> 0.9 -> 5.0 ... # learning rate slightly too high, or bad data
OVERFITTING: train loss falls, val loss rises
train 0.1, val 2.3 (and climbing) # memorizing, not generalizing
Always print and watch both training and validation loss; it is the single most informative thing you can do. The curve's shape points directly at the causes below.
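The script earlier on this page prints training loss and test accuracy but no validation loss. One way to get it is a small helper like the sketch below. The name `mean_loss` is hypothetical, the random data only exercises the function, and it assumes the model and the batches already live on the same device.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def mean_loss(model, loader, criterion):
    """Average per-example loss over a loader, without updating the model."""
    was_training = model.training
    model.eval()
    total, n = 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            total += criterion(model(x), y).item() * x.size(0)
            n += x.size(0)
    model.train(was_training)        # restore whatever mode the model was in
    return total / n

# Hypothetical validation split: random data, just to exercise the helper.
torch.manual_seed(0)
val_data = TensorDataset(torch.randn(50, 20), torch.randint(0, 3, (50,)))
val_loader = DataLoader(val_data, batch_size=16)
model = nn.Linear(20, 3)

val_loss = mean_loss(model, val_loader, nn.CrossEntropyLoss())
print(f"val loss {val_loss:.4f}")
```

Calling it once per epoch, next to the training-loss print, gives you the two curves to compare.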
The loss-won't-decrease checklist (in order)¶
When loss is flat, work this list top to bottom - the causes are roughly in order of likelihood:
- Did you forget optimizer.zero_grad()? Gradients accumulate in PyTorch by default. Without zeroing them each step, they pile up and training goes haywire. Missing zero_grad() is the #1 silent training bug.
- Learning rate wrong. Flat loss = LR too low (try 10x higher). Exploding/NaN loss = LR too high (try 10x lower). LR is the highest-leverage knob; if in doubt, sweep it (1e-2, 1e-3, 1e-4). Most "it won't learn" problems are LR.
- Did you call loss.backward() and optimizer.step()? Forgetting either means no gradients flow or no weights update - the model literally never changes. Print a weight before and after a step to confirm it moved.
- Is the model in train() mode? model.eval() left on disables dropout and freezes batchnorm; model.train() must be set for training. Forgetting to switch back after validation is a classic.
- Data/label mismatch. Are inputs and labels actually aligned? A shuffling bug or off-by-one between x and y makes the task impossible - loss stays at random-guess level. Sanity check: can the model overfit a single batch? (Next section - the best test there is.)
- Wrong loss function. CrossEntropyLoss expects raw logits (it applies softmax internally) - passing it already-softmaxed values, or using the wrong loss for the task, quietly breaks learning.
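The last item in the checklist is easy to demonstrate. In this sketch a hand-written, confidently correct prediction for a 3-class problem is scored twice: once as raw logits and once after an extra softmax. The numbers are made up for illustration.

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.tensor([[10.0, 0.0, 0.0]])   # confident, correct prediction
target = torch.tensor([0])

right = criterion(logits, target)                          # raw logits: loss near 0
wrong = criterion(torch.softmax(logits, dim=1), target)    # softmax applied twice

print(f"logits: {right.item():.4f}  softmaxed: {wrong.item():.4f}")
```

With the extra softmax the inputs are squeezed into the range 0 to 1, so for three classes the loss cannot fall much below about 0.55 however confident the model gets. Nothing errors; the loss just stalls at a floor.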
The single best debugging test: overfit one batch¶
Before training on the full dataset, prove the loop can learn at all by making it memorize one tiny batch:
model.train()
x, y = next(iter(train_loader))        # grab ONE batch
x, y = x.to(device), y.to(device)      # same device as the model
for step in range(200):                # train on JUST this batch, repeatedly
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    if step % 20 == 0:
        print(step, loss.item())
If the loss drops to near zero, your training loop, model, loss function, and optimizer are all wired correctly - any remaining problem is data or hyperparameters. If it can't even memorize one batch, you have a fundamental bug (one of the checklist items) - and you've isolated it to the loop, not the data, in 10 seconds. Overfitting one batch is the first thing experienced people do when training misbehaves. It separates "the plumbing is broken" from "the plumbing works, tune it."
What you'll see: the loss is NaN¶
Loss suddenly becomes nan and stays there - training is dead.
Causes, in order: learning rate too high (most common - lower it 10x), exploding gradients (add gradient clipping: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)), a log(0) or divide-by-zero in a custom loss, or bad input data (a nan/inf in the dataset - check with torch.isnan(x).any()). Once loss is nan, every subsequent step is nan (nan propagates), so catch it early - the moment loss spikes before nan is your clue it's divergence, not data.
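The checks and the clipping call named above fit into a single training step roughly like this. It is a sketch: the badly scaled random inputs are made up to provoke large gradients, and the clipping threshold of 1.0 is the one quoted in the text rather than a tuned value.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
x = 100 * torch.randn(16, 10)        # hypothetical badly scaled inputs -> large gradients
y = torch.randn(16, 1)

# Bad data check: a NaN in the input poisons everything downstream.
if torch.isnan(x).any():
    raise ValueError("input contains NaN")

loss = criterion(model(x), y)
if not torch.isfinite(loss):
    raise RuntimeError("loss is NaN/inf: stop and inspect before stepping")

loss.backward()
# clip_grad_norm_ rescales the gradients in place and returns the norm from before clipping.
norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
norm_after = torch.sqrt(sum((p.grad ** 2).sum() for p in model.parameters()))

print(f"grad norm before {norm_before.item():.1f}, after {norm_after.item():.4f}")
```

Logging the pre-clip norm every step is a cheap way to spot the spike that usually precedes a nan.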
Try it (with what you'll see)¶
- Take your working training loop. Delete optimizer.zero_grad(). Watch the loss behave erratically or stall. Restore it. Feel the #1 bug.
- Run the overfit-one-batch test on your model. Confirm loss -> ~0 in a couple hundred steps. (If it doesn't, you've found a real bug - work the checklist.)
- Crank the learning rate 100x. Watch the loss explode to nan. Then drop it 100x below normal and watch the loss go flat. See both failure modes bracket the healthy range.
- Print a model weight (next(model.parameters())[0,0].item()) before and after one optimizer.step(). Confirm it changed - proof the update is happening.
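For the last item, the before/after check can look like the sketch below. The tiny `nn.Linear` model, the SGD optimizer, and the random batch are stand-ins so the snippet runs by itself; in your script you would use your own `model`, `optimizer`, and a real batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))   # hypothetical batch

before = next(model.parameters())[0, 0].item()        # one weight, before the update

optimizer.zero_grad()
loss = nn.CrossEntropyLoss()(model(x), y)
loss.backward()
optimizer.step()

after = next(model.parameters())[0, 0].item()         # the same weight, after the update
print(before, after, "changed" if before != after else "NOT changed")
```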
Exercise¶
- Run the script. Get ~97% test accuracy.
- Tweak hyperparameters:
  - Set lr=1e-2. Does it train? (It may train less smoothly or plateau at lower accuracy; push the rate high enough and the loss goes NaN.)
  - Set lr=1e-5. (Trains, but slowly.)
- Increase epochs to 10. Does test accuracy improve? Plateau? Degrade (overfit)?
- Architecture changes:
  - Add a third hidden layer of size 128.
  - Increase hidden sizes to 256, 128.
- Watch parameter count + accuracy change.
- Visualize: plot the per-epoch training loss using matplotlib. (Or use TensorBoard - pip install tensorboard, then tensorboard --logdir runs/.)
- Stretch: instead of an MLP, try a small CNN. A 2-layer CNN beats this MLP at >98% accuracy. Look up nn.Conv2d. Don't worry if it doesn't work first try - convolutions take some adjusting.
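For the parameter-count part of the exercise, a short helper is enough. This is a sketch: `count_parameters` is a hypothetical name, and the two models are written with `nn.Sequential` for brevity, using the same layer sizes as the MLP in the script and the wider variant suggested above.

```python
import torch.nn as nn

def count_parameters(model):
    """Total number of trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Same layer sizes as the MLP in the script above.
baseline = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, 10),
)
# Wider variant from the exercise: hidden sizes 256 and 128.
wider = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 256), nn.ReLU(),
    nn.Linear(256, 128), nn.ReLU(),
    nn.Linear(128, 10),
)

print(count_parameters(baseline))  # 109386
print(count_parameters(wider))     # 235146
```

Most of the parameters sit in the first layer (784 inputs times the first hidden width), which is why widening that layer roughly doubles the total.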
What you might wonder¶
"Why Adam over SGD?" Adam adapts the learning rate per-parameter; converges faster on most problems without tuning. SGD with momentum can outperform on specific architectures (vision CNNs) when carefully tuned. For getting started: Adam.
"How big should my batch be?" For GPU work, "as big as fits in memory" is a common heuristic. Common sizes: 32, 64, 128, 256. Larger batches give smoother gradients but you might need more epochs to converge.
"What does loss.item() do?" Extracts a Python float from a 0-dim tensor. Detached from the graph (no gradient tracking). Use when you want a number for logging.
"Why is my loss not going down?" Common causes:
- Learning rate too high (loss NaN) or too low (loss flat).
- Wrong loss function for your task.
- Bug in the model (wrong shapes - print them).
- Data not normalized.
Print loss every step initially. If it's not going down within ~100 steps on MNIST, something's wrong.
Done¶
- The four-step training loop pattern.
- DataLoader for batched iteration.
- nn.CrossEntropyLoss + Adam optimizer.
- optimizer.zero_grad() / loss.backward() / optimizer.step().
- model.train() vs model.eval().
- Save/load weights.