
05 - Training Loop

What this session is

About an hour. Train the MLP from page 04 to recognize MNIST digits. By the end you'll have written a full training loop, with the same basic shape as nearly every PyTorch training loop you'll encounter.

The pattern

Every training loop is:

For each epoch (pass over the data):
    For each batch:
        1. Forward pass - compute predictions
        2. Compute loss - how wrong are we?
        3. Backward pass - compute gradients
        4. Optimizer step - adjust weights

That's it. The rest is bookkeeping.

Load MNIST

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),                          # PIL image → tensor
    transforms.Normalize((0.1307,), (0.3081,)),     # standardize: (x - mean) / std
])

train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_ds  = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=512)

Three pieces:

  • Dataset - knows how to load and transform one example.
  • DataLoader - wraps the dataset, batches it, optionally shuffles.
  • Transform - preprocessing applied to each example.
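To see how the three pieces fit together without downloading anything, here is a minimal sketch using a hypothetical ToyImages dataset of random 1x28x28 "images" standing in for MNIST. The class name, the normalize function, and its constants are illustrative assumptions, not part of torchvision. It shows that a Dataset only has to implement __len__ and __getitem__, that a transform is just a callable applied per example, and that DataLoader does the batching.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyImages(Dataset):
    """Hypothetical stand-in for MNIST: n random 1x28x28 images with labels 0-9."""

    def __init__(self, n=1000, transform=None):
        g = torch.Generator().manual_seed(0)
        self.images = torch.rand(n, 1, 28, 28, generator=g)   # pixel values in [0, 1)
        self.labels = torch.randint(0, 10, (n,), generator=g)
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        x = self.images[idx]                  # ONE example, shape (1, 28, 28)
        if self.transform is not None:
            x = self.transform(x)             # per-example preprocessing
        return x, self.labels[idx]

def normalize(x):
    # A transform is any callable; these constants are arbitrary for the toy data.
    return (x - 0.5) / 0.25

ds = ToyImages(n=1000, transform=normalize)
loader = DataLoader(ds, batch_size=64, shuffle=True)

x, y = next(iter(loader))
print(x.shape, y.shape)   # torch.Size([64, 1, 28, 28]) torch.Size([64])
print(len(loader))        # 16: fifteen full batches of 64, then a final batch of 40
```

The DataLoader stacks the individual (image, label) pairs into batch tensors; the training loop never sees single examples.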

torchvision provides MNIST out of the box. First run downloads ~10MB; subsequent runs use the cached copy.
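The two numbers passed to transforms.Normalize are the mean and standard deviation of all training-set pixels after ToTensor() scales them to [0, 1]. A sketch of how such statistics can be computed from any loader is below. The function name pixel_mean_std is hypothetical, and the demo uses random data so it runs offline; run over MNIST's training set loaded with only ToTensor(), it should give values close to the 0.1307 and 0.3081 used above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def pixel_mean_std(loader):
    """Mean and (population) std over every pixel of every image the loader yields."""
    total, total_sq, count = 0.0, 0.0, 0
    for x, _ in loader:
        x = x.double()                        # double precision avoids drift over many pixels
        total += x.sum().item()
        total_sq += (x ** 2).sum().item()
        count += x.numel()
    mean = total / count
    std = (total_sq / count - mean ** 2) ** 0.5
    return mean, std

# Synthetic stand-in: uniform [0, 1) pixels, so mean is near 0.5 and std near 0.2887.
g = torch.Generator().manual_seed(0)
images = torch.rand(500, 1, 28, 28, generator=g)
labels = torch.zeros(500, dtype=torch.long)
loader = DataLoader(TensorDataset(images, labels), batch_size=100)

mean, std = pixel_mean_std(loader)
print(f"mean={mean:.4f} std={std:.4f}")
```

Compute these statistics on the training set only, then reuse them for the test set, so no information from test data leaks into preprocessing.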

The full training script

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Reproducibility
torch.manual_seed(42)

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")

# Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_ds  = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=512)

# Model
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)             # flatten 28x28 → 784
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

model = MLP().to(device)

# Loss + optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Train
for epoch in range(3):
    model.train()
    total_loss = 0
    correct = 0
    n = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        # 1. Forward
        logits = model(x)

        # 2. Loss
        loss = criterion(logits, y)

        # 3. Backward
        optimizer.zero_grad()                  # clear gradients from last step
        loss.backward()                        # compute gradients

        # 4. Optimizer step
        optimizer.step()

        total_loss += loss.item() * x.size(0)
        correct += (logits.argmax(dim=1) == y).sum().item()
        n += x.size(0)

    print(f"epoch {epoch}: train loss {total_loss/n:.4f}, acc {correct/n:.4f}")

# Test
model.eval()
correct = 0
n = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        correct += (logits.argmax(dim=1) == y).sum().item()
        n += x.size(0)

print(f"test accuracy: {correct/n:.4f}")

Run it. Timing depends on your hardware (roughly 30 seconds on CPU or about 5 seconds on GPU), and you should see ~97% test accuracy. Your first trained model.

What each line is doing

x.view(x.size(0), -1) - flatten the 28x28 images into 784-length vectors. The -1 tells PyTorch to infer that dimension from the others (here 1*28*28 = 784). x.size(0) is the batch dimension.
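A quick shape check makes the flattening concrete. This sketch uses a random tensor shaped like an MNIST batch and also shows nn.Flatten, a built-in layer that does the same reshaping (by default it flattens every dimension after the first).

```python
import torch
import torch.nn as nn

x = torch.randn(64, 1, 28, 28)       # (batch, channels, height, width), like an MNIST batch

flat = x.view(x.size(0), -1)         # keep the batch dim, infer the rest: 1*28*28 = 784
print(flat.shape)                    # torch.Size([64, 784])

flat2 = nn.Flatten()(x)              # start_dim=1 by default, so the batch dim is preserved
print(torch.equal(flat, flat2))      # True
```

.view() requires a contiguous tensor; if you ever hit an error about that (for example after a transpose), .reshape() performs the same operation and copies only when it must.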

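nn.CrossEntropyLoss - the standard loss for classification. Internally it combines log-softmax and negative log-likelihood in one numerically stable step, so it expects raw logits (not probabilities) and integer class labels.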
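The decomposition can be checked directly: on random logits, nn.CrossEntropyLoss matches F.log_softmax followed by F.nll_loss. The sketch also shows what happens if you softmax the outputs yourself first, a common bug; the loss silently changes because softmax gets applied twice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10)              # raw scores: batch of 8, 10 classes
targets = torch.randint(0, 10, (8,))     # class indices, not one-hot vectors

ce = nn.CrossEntropyLoss()(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(ce, manual))        # True

# Bug: feeding probabilities instead of logits applies softmax twice.
double_softmax = nn.CrossEntropyLoss()(F.softmax(logits, dim=1), targets)
print(ce.item(), double_softmax.item())  # different values; no error is raised
```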

optimizer = torch.optim.Adam(...) - the optimizer. Adam is one of the most widely used optimizers in modern ML. Other options: SGD (stochastic gradient descent - classic but needs more tuning), AdamW (Adam with decoupled weight decay).

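optimizer.zero_grad() - clear gradients from the last batch. PyTorch accumulates gradients by default; you must clear them explicitly each step. Forget this and every backward() adds onto the previous gradients, so each update uses an ever-growing sum of old gradients instead of the current batch's.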
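Accumulation is easy to observe on a single tensor. In this sketch the gradient of (w * x).sum() with respect to w is always x, yet calling backward() three times without clearing makes w.grad grow by x each time. Setting .grad to None mirrors what optimizer.zero_grad() does by default in recent PyTorch versions (older versions filled it with zeros instead).

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)
x = torch.tensor([3.0, 4.0])

history = []
for step in range(3):
    loss = (w * x).sum()          # d(loss)/dw = x = [3, 4] on every step
    loss.backward()               # ADDS [3, 4] into w.grad rather than overwriting it
    history.append(w.grad.clone())
# history values: [3, 4], then [6, 8], then [9, 12]

w.grad = None                     # reset, as zero_grad() does
(w * x).sum().backward()
print(w.grad)                     # tensor([3., 4.])
```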

loss.backward() - autograd walks the computation graph backward, computing gradients for every parameter that participated in computing loss. Stores them in parameter.grad.

optimizer.step() - applies the gradient update. For Adam: a per-parameter step scaled by running averages of the gradient and its square; for plain SGD: param = param - lr * param.grad.
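The SGD formula can be verified against torch.optim.SGD itself. This sketch computes param - lr * param.grad by hand from a tiny linear model's gradient, then lets the optimizer take its step and compares. It assumes plain SGD (no momentum, no weight decay), which are the defaults.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
lr = 0.1
model = torch.nn.Linear(3, 1)
x, y = torch.randn(4, 3), torch.randn(4, 1)

loss = F.mse_loss(model(x), y)
loss.backward()

# By hand: the subtraction creates a new tensor, so `expected` won't change when step() runs.
expected = model.weight.detach() - lr * model.weight.grad

opt = torch.optim.SGD(model.parameters(), lr=lr)   # momentum=0, weight_decay=0 by default
opt.step()                                         # updates model.weight in place
print(torch.allclose(model.weight.detach(), expected))   # True
```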

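model.train() vs model.eval() - switches the model's internal mode. Affects layers like dropout and batch norm that behave differently during training vs inference. Always set the right mode. The MLP on this page has neither layer, so switching doesn't change its outputs here, but it's a habit worth keeping for models that do.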
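Dropout is the easiest layer to watch change behavior. In training mode nn.Dropout zeroes each element with probability p and scales the survivors by 1/(1-p); in eval mode it passes input through unchanged. A sketch on a row of ones:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Dropout(p=0.5)
x = torch.ones(1, 10)

layer.train()
print(layer(x))    # a random mix of 0.0 and 2.0 (survivors scaled by 1/(1-0.5) = 2)

layer.eval()
print(layer(x))    # tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
```

Calling .train() or .eval() on a whole model sets the mode recursively on every submodule, which is why one call at the top of the loop is enough.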

with torch.no_grad(): - disables autograd. Inference doesn't need gradients; this skips the bookkeeping and uses less memory.
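You can see what "skips the bookkeeping" means by checking requires_grad on a result computed with and without the context manager. Outside it, autograd records the operations so backward() can run later; inside it, no graph is built.

```python
import torch

w = torch.randn(3, requires_grad=True)
x = torch.randn(3)

y_tracked = (w * x).sum()
print(y_tracked.requires_grad, y_tracked.grad_fn is not None)     # True True

with torch.no_grad():
    y_untracked = (w * x).sum()
print(y_untracked.requires_grad, y_untracked.grad_fn is None)     # False True
```

PyTorch also provides torch.inference_mode(), a stricter variant intended for pure inference; no_grad() is the more general choice and is what this page uses.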

What the loss tells you

Training loss going down = model is fitting the training data. A plateau means training has stopped making progress: the model may be near the limit of its capacity, or the optimizer may be stuck (try different hyperparameters, starting with the learning rate).

Test accuracy is what you actually care about - performance on data the model didn't see during training. If train loss keeps dropping but test accuracy stops improving, you're overfitting - memorizing the training data.

Mitigations: more data, regularization (dropout, weight decay), smaller model, early stopping.
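Early stopping is simple enough to sketch in a few lines: track the best validation loss so far and stop once it hasn't improved for a set number of epochs (the "patience"). The EarlyStopper class, its parameters, and the loss values below are hypothetical illustrations, not a PyTorch API.

```python
class EarlyStopper:
    """Hypothetical helper: signals a stop when val loss hasn't improved for `patience` epochs."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta      # how much lower counts as a real improvement
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record this epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStopper(patience=2)
val_losses = [0.50, 0.40, 0.35, 0.36, 0.37, 0.30]   # made-up curve for illustration
stop_epoch = None
for epoch, vl in enumerate(val_losses):
    if stopper.step(vl):
        stop_epoch = epoch
        break
print(stop_epoch, stopper.best)   # 4 0.35
```

In a real loop you would typically also save the model's state_dict whenever the best loss improves, so you can restore the best weights rather than the last ones.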

Hyperparameters

The numbers you set that aren't learned:

  • Learning rate (lr=1e-3) - how big a step the optimizer takes. Too high → unstable. Too low → slow. 1e-3 is a common, reliable starting point for Adam.
  • Batch size (64) - larger = smoother gradients, more memory; smaller = noisier gradients, sometimes generalizes better.
  • Epochs - how many passes over the data. More = better fit (until overfitting).
  • Architecture - depth, width, normalization layers.

These need tuning. For MNIST, defaults work. For real problems, expect to iterate.

Save the model

torch.save(model.state_dict(), "mnist_mlp.pt")

Load later:

model = MLP().to(device)
model.load_state_dict(torch.load("mnist_mlp.pt", map_location=device))   # map_location lets GPU-saved weights load on CPU
model.eval()
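A save/load round trip can be checked end to end. This sketch uses a temporary file and a small nn.Sequential stand-in (an assumption, to keep it self-contained; the MLP class works the same way), loads the weights into a freshly initialized model of the same architecture, and confirms both produce identical outputs. map_location="cpu" remaps any tensors that were saved from a GPU.

```python
import os
import tempfile
import torch
import torch.nn as nn

def make_model():
    # Stand-in architecture; a state_dict only loads into a model with matching layer names/shapes.
    return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))

torch.manual_seed(0)
trained = make_model()

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "model.pt")
    torch.save(trained.state_dict(), path)

    fresh = make_model()                               # different random init...
    fresh.load_state_dict(torch.load(path, map_location="cpu"))   # ...overwritten by saved weights

fresh.eval()
x = torch.randn(2, 1, 28, 28)
with torch.no_grad():
    same = torch.equal(trained(x), fresh(x))
print(same)   # True
```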

What this scales to

The training loop pattern above carries over to huge models: the four steps stay the same, and mainly the model definition and data change. Add a few wrinkles for big-model training (mixed precision, gradient accumulation, distributed, checkpointing) and you have the basic shape of a real LLM training script.
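Gradient accumulation, one of those wrinkles, reuses PyTorch's default accumulation on purpose: run backward() on several micro-batches before a single optimizer step, to mimic a batch that wouldn't fit in memory. The sketch below (not taken from any particular framework) checks that scaling each micro-batch loss by 1/accum_steps reproduces the full-batch gradient. That equivalence assumes a mean-reduced loss and equal-sized micro-batches.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 3)
criterion = nn.CrossEntropyLoss()               # mean reduction by default
x, y = torch.randn(32, 10), torch.randint(0, 3, (32,))

# Reference: gradient from one full batch of 32.
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 8, each loss divided by 4, one backward() each.
accum_steps = 4
model.zero_grad()
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    loss = criterion(model(xb), yb) / accum_steps
    loss.backward()                             # gradients sum into .grad
# In training, optimizer.step() and zero_grad() would run here, once per accum_steps micro-batches.

print(torch.allclose(model.weight.grad, full_grad, atol=1e-6))   # True
```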

Going deeper

You can run a training loop now. This section is the part nobody teaches and everybody needs: what to do when the loss won't go down. This is the daily reality of training - the loop runs, no error, but the model isn't learning. Here's the diagnosis tree, with what each failure looks like.

Read the loss curve - it tells you what's wrong

Before fixing anything, look at the loss over steps. The shape is the diagnosis:

HEALTHY:           loss falls fast, then slowly flattens
  3.2 -> 2.1 -> 1.4 -> 1.1 -> 0.95 -> 0.88 ...

FLAT (not learning): barely moves
  3.2 -> 3.19 -> 3.20 -> 3.18 ...        # learning rate too low, or a bug

EXPLODING (diverging): shoots up, often to NaN
  3.2 -> 5.1 -> 19.4 -> 412 -> nan       # learning rate too high

ZIGZAG (unstable): bounces around, no trend
  3.2 -> 1.1 -> 4.5 -> 0.9 -> 5.0 ...    # learning rate slightly too high, or bad data

OVERFITTING: train loss falls, val loss rises
  train 0.1, val 2.3 (and climbing)      # memorizing, not generalizing

Always print and watch both training and validation loss; it's the single most informative thing you can do. The curve's shape points directly at the cause below.
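The script on this page only measures test accuracy at the end. A sketch of a reusable evaluation function you could call after every epoch to get a validation loss alongside the training loss is below; evaluate is a hypothetical helper name, and the demo runs on random data. It weights each batch's mean loss by the batch size so a short final batch doesn't skew the average, and it switches the model back to train() before returning, since forgetting that is a classic bug.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader, criterion, device="cpu"):
    """Average loss and accuracy over a loader, without updating the model."""
    model.eval()
    total_loss, correct, n = 0.0, 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            total_loss += criterion(logits, y).item() * x.size(0)   # undo the per-batch mean
            correct += (logits.argmax(dim=1) == y).sum().item()
            n += x.size(0)
    model.train()                      # hand the model back in training mode
    return total_loss / n, correct / n

torch.manual_seed(0)
X, Y = torch.randn(100, 20), torch.randint(0, 4, (100,))
model = nn.Linear(20, 4)
criterion = nn.CrossEntropyLoss()
val_loader = DataLoader(TensorDataset(X, Y), batch_size=32)   # batches of 32, 32, 32, 4

val_loss, val_acc = evaluate(model, val_loader, criterion)
print(f"val loss {val_loss:.4f}, val acc {val_acc:.4f}")
```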

The loss-won't-decrease checklist (in order)

When loss is flat, work this list top to bottom - the causes are roughly in order of likelihood:

  1. Did you forget optimizer.zero_grad()? Gradients accumulate in PyTorch by default. Without zeroing them each step, they pile up and training goes haywire. Missing zero_grad() is the #1 silent training bug.
  2. Learning rate wrong. Flat loss = LR too low (try 10x higher). Exploding/NaN loss = LR too high (try 10x lower). LR is the highest-leverage knob; if in doubt, sweep it (1e-2, 1e-3, 1e-4). Most "it won't learn" problems are LR.
  3. Did you call loss.backward() and optimizer.step()? Forgetting either means no gradients flow or no weights update - the model literally never changes. Print a weight before and after a step to confirm it moved.
  4. Is the model in train() mode? If model.eval() is left on, dropout is disabled and batch norm uses its stored running statistics instead of per-batch ones; model.train() must be set for training. Forgetting to switch back after validation is a classic.
  5. Data/label mismatch. Are inputs and labels actually aligned? A shuffling bug or off-by-one between x and y makes the task impossible - loss stays at random-guess level. Sanity check: can the model overfit a single batch? (Next section - the best test there is.)
  6. Wrong loss function. CrossEntropyLoss expects raw logits (it applies log-softmax internally) - passing it already-softmaxed values, or using the wrong loss for the task, quietly breaks learning.
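Items 1-3 can be checked mechanically in one step. A sketch of a hypothetical check_one_step helper is below: it snapshots every parameter, runs a single zero_grad / forward / backward / step, and reports any parameter that received no gradient or didn't change. An empty list means the plumbing is connected; it says nothing about data or labels (items 5-6).

```python
import torch
import torch.nn as nn

def check_one_step(model, criterion, optimizer, x, y):
    """Hypothetical sanity check: run one training step, list parameters that didn't participate."""
    model.train()
    before = {name: p.detach().clone() for name, p in model.named_parameters()}
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    problems = []
    for name, p in model.named_parameters():
        if p.requires_grad and p.grad is None:
            problems.append(f"{name}: no gradient (not connected to the loss?)")
    optimizer.step()
    for name, p in model.named_parameters():
        if p.requires_grad and torch.equal(before[name], p.detach()):
            problems.append(f"{name}: unchanged after optimizer.step()")
    return problems

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(784, 64), nn.ReLU(), nn.Linear(64, 10))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
x, y = torch.randn(16, 784), torch.randint(0, 10, (16,))

problems = check_one_step(model, nn.CrossEntropyLoss(), optimizer, x, y)
print(problems)   # [] when everything is wired correctly
```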

The single best debugging test: overfit one batch

Before training on the full dataset, prove the loop can learn at all by making it memorize one tiny batch:

model.train()
x, y = next(iter(train_loader))     # grab ONE batch
x, y = x.to(device), y.to(device)
for step in range(200):             # train on JUST this batch, repeatedly
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    if step % 20 == 0: print(step, loss.item())

Typical output (exact numbers will vary):

0   2.31
20  0.84
40  0.12
60  0.01        # loss -> ~0 = the loop WORKS, the model CAN learn

If the loss drops to near zero, your training loop, model, loss function, and optimizer are all wired correctly - any remaining problem is data or hyperparameters. If it can't even memorize one batch, you have a fundamental bug (one of the checklist items) - and you've isolated it to the loop, not the data, in 10 seconds. Overfitting one batch is the first thing experienced people do when training misbehaves. It separates "the plumbing is broken" from "the plumbing works, tune it."

What you'll see: the loss is NaN

Loss suddenly becomes nan and stays there - training is dead:

step 40  loss 1.82
step 41  loss 14.3
step 42  loss nan        # diverged - everything downstream is now nan

Causes, in order: learning rate too high (most common - lower it 10x), exploding gradients (add gradient clipping: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)), a log(0) or divide-by-zero in a custom loss, or bad input data (a nan/inf in the dataset - check with torch.isnan(x).any()). Once loss is nan, every subsequent step is nan (nan propagates), so catch it early - the moment loss spikes before nan is your clue it's divergence, not data.
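A sketch of where the two guards go in a single training step: check that the loss is finite before calling backward(), and clip between backward() and step(). The model, data, and the RuntimeError choice are illustrative assumptions; the inputs are deliberately scaled up to produce large gradients. clip_grad_norm_ rescales all gradients in place so their combined L2 norm is at most max_norm, and returns the norm measured before clipping.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 10) * 100, torch.randn(8, 1)    # huge inputs -> huge gradients

optimizer.zero_grad()
loss = F.mse_loss(model(x), y)
if not torch.isfinite(loss):
    raise RuntimeError(f"non-finite loss: {loss.item()}")   # stop before nan reaches the weights
loss.backward()

pre_clip = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
post_clip = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
print(f"grad norm before {pre_clip.item():.1f}, after {post_clip.item():.4f}")
optimizer.step()
```

Logging the returned pre-clip norm every step is also a cheap early warning: a norm that suddenly jumps by orders of magnitude usually precedes the nan.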

Try it (with what you'll see)

  1. Take your working training loop. Delete optimizer.zero_grad(). Watch the loss behave erratically or stall. Restore it. Feel the #1 bug.

  2. Run the overfit-one-batch test on your model. Confirm loss -> ~0 in a couple hundred steps. (If it doesn't, you've found a real bug - work the checklist.)

  3. Crank the learning rate 100x. Watch the loss spike or explode (with Adam it may stall at a high, noisy value instead of reaching nan). Then drop it 100x below normal and watch the loss fall far more slowly. See both failure modes bracket the healthy range.

  4. Print a model weight (next(model.parameters())[0,0].item()) before and after one optimizer.step(). Confirm it changed - proof the update is happening.

Exercise

  1. Run the script. Get ~97% test accuracy.

  2. Tweak hyperparameters:
       • Set lr=1e-2. Does it still train? (Expect a noisier, less stable loss than at 1e-3; push the rate higher and training eventually breaks down.)
       • Set lr=1e-5. (Trains, but slowly.)
       • Increase epochs to 10. Does test accuracy improve? Plateau? Degrade (overfit)?

  3. Architecture changes:
       • Add a third hidden layer of size 128.
       • Increase hidden sizes to 256, 128.
       • Watch how parameter count and accuracy change.

  4. Visualize: plot the per-epoch training loss using matplotlib. (Or log it with torch.utils.tensorboard.SummaryWriter, which writes to runs/ by default: pip install tensorboard, then tensorboard --logdir runs/.)

  5. Stretch: instead of an MLP, try a small CNN. A 2-layer CNN typically beats this MLP, reaching >98% accuracy. Look up nn.Conv2d. Don't worry if it doesn't work the first time; convolutions take some adjusting.
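If you want a starting point for the stretch exercise, here is one possible 2-conv-layer layout (an untuned sketch, not the only or best design). The shape comments trace a 1x28x28 image through the network. It takes the same (batch, 1, 28, 28) input as the MLP, so it can be swapped into the training script without other changes. The parameter-count line also answers the "watch parameter count" exercise.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SmallCNN(nn.Module):
    """One possible small CNN for 1x28x28 inputs; layer sizes are illustrative, not tuned."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)    # 1x28x28  -> 16x28x28
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)   # 16x14x14 -> 32x14x14
        self.fc = nn.Linear(32 * 7 * 7, 10)                        # 32x7x7 flattened -> 10 logits

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)   # 28x28 -> 14x14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)   # 14x14 -> 7x7
        x = x.view(x.size(0), -1)                    # flatten to 32*7*7 = 1568
        return self.fc(x)

model = SmallCNN()
out = model(torch.randn(64, 1, 28, 28))
print(out.shape)                                     # torch.Size([64, 10])
print(sum(p.numel() for p in model.parameters()))    # 20490
```

For comparison, the page's MLP has 109,386 parameters; convolutions reuse the same small filters across the whole image, which is why this model is about five times smaller.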

What you might wonder

"Why Adam over SGD?" Adam adapts the learning rate per-parameter; converges faster on most problems without tuning. SGD with momentum can outperform on specific architectures (vision CNNs) when carefully tuned. For getting started: Adam.

"How big should my batch be?" For GPU work, "as big as fits in memory" is a common heuristic. Common sizes: 32, 64, 128, 256. Larger batches give smoother gradients but you might need more epochs to converge.

"What does loss.item() do?" Extracts a Python float from a 0-dim tensor. Detached from the graph (no gradient tracking). Use when you want a number for logging.

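A small illustration of .item(): the loss below is a 0-dim tensor still attached to the autograd graph, while .item() hands back a plain Python float. This is why the training script accumulates total_loss with loss.item(): summing the tensors themselves would keep every step's graph reachable and waste memory. Note that .item() only works on tensors with exactly one element.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
loss = (w - 0.5) ** 2                   # 0-dim tensor, part of the autograd graph

value = loss.item()
print(type(value), value)               # <class 'float'> 2.25
print(loss.requires_grad)               # True: the tensor itself is still tracked
```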
"Why is my loss not going down?" Common causes:

  • Learning rate too high (loss NaN) or too low (loss flat).
  • Wrong loss function for your task.
  • Bug in the model (wrong shapes - print them).
  • Data not normalized.

Print loss every step initially. If it's not going down within ~100 steps on MNIST, something's wrong.

Done

  • The four-step training loop pattern.
  • DataLoader for batched iteration.
  • nn.CrossEntropyLoss + Adam optimizer.
  • optimizer.zero_grad() / loss.backward() / optimizer.step().
  • model.train() vs model.eval().
  • Save/load weights.

Next: Inference and saving →
