
05 - Training Loop

What this session is

About an hour. Train the MLP from page 04 to recognize MNIST digits. By the end you'll have written a full training loop - the same basic shape as nearly every PyTorch training loop.

The pattern

Every training loop is:

For each epoch (pass over the data):
    For each batch:
        1. Forward pass - compute predictions
        2. Compute loss - how wrong are we?
        3. Backward pass - compute gradients
        4. Optimizer step - adjust weights

That's it. The rest is bookkeeping.
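To see the four steps with no data loading in the way, here is a minimal sketch on a synthetic problem: fitting y = 2x + 1 with a single nn.Linear layer. The data, the MSE loss, plain SGD, the learning rate, and the step count are illustrative assumptions, not part of the MNIST script. Every step uses the whole dataset as one batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Synthetic data (illustrative only): 100 points on the line y = 2x + 1
x = torch.linspace(-1, 1, 100).unsqueeze(1)   # shape (100, 1)
y = 2 * x + 1

model = nn.Linear(1, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for step in range(200):
    pred = model(x)                 # 1. forward pass
    loss = criterion(pred, y)       # 2. compute loss
    optimizer.zero_grad()           #    clear old gradients
    loss.backward()                 # 3. backward pass
    optimizer.step()                # 4. optimizer step
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.6f}")
print(model.weight.item(), model.bias.item())   # should approach 2 and 1
```

The MNIST script has exactly this structure, with an outer loop over epochs and an inner loop over batches.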

Load MNIST

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),                          # PIL image → tensor
    transforms.Normalize((0.1307,), (0.3081,)),     # standardize: (x - mean) / std
])

train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_ds  = datasets.MNIST(root="./data", train=False, download=True, transform=transform)

train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=512)

Three pieces:
  • Dataset - knows how to load and transform one example.
  • DataLoader - wraps the dataset, batches it, and optionally shuffles.
  • Transform - preprocessing applied to each example.

torchvision provides MNIST out of the box. First run downloads ~10MB; subsequent runs use the cached copy.
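torchvision hides the Dataset details. The sketch below shows what the three pieces do, using a hypothetical in-memory SquaresDataset (not part of torchvision) so it runs without downloading anything. A Dataset only needs __len__ and __getitem__; the DataLoader handles batching.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    """Hypothetical toy dataset: example i is (i as a float tensor, label i*i)."""
    def __init__(self, n, transform=None):
        self.n = n
        self.transform = transform

    def __len__(self):
        return self.n                       # DataLoader uses this to know the dataset size

    def __getitem__(self, i):
        x = torch.tensor([float(i)])
        if self.transform is not None:
            x = self.transform(x)           # per-example preprocessing
        return x, i * i

ds = SquaresDataset(10, transform=lambda x: x / 10.0)
loader = DataLoader(ds, batch_size=4, shuffle=False)

batches = list(loader)
for xb, yb in batches:
    print(xb.shape, yb.tolist())
# → torch.Size([4, 1]) [0, 1, 4, 9]
#   torch.Size([4, 1]) [16, 25, 36, 49]
#   torch.Size([2, 1]) [64, 81]
```

Note the last batch is smaller: by default the DataLoader keeps the leftover examples rather than dropping them.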

The full training script

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Reproducibility
torch.manual_seed(42)

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")

# Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_ds = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_ds  = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=512)

# Model
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)             # flatten 28x28 → 784
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

model = MLP().to(device)

# Loss + optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Train
for epoch in range(3):
    model.train()
    total_loss = 0
    correct = 0
    n = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        # 1. Forward
        logits = model(x)

        # 2. Loss
        loss = criterion(logits, y)

        # 3. Backward
        optimizer.zero_grad()                  # clear gradients from last step
        loss.backward()                        # compute gradients

        # 4. Optimizer step
        optimizer.step()

        total_loss += loss.item() * x.size(0)
        correct += (logits.argmax(dim=1) == y).sum().item()
        n += x.size(0)

    print(f"epoch {epoch}: train loss {total_loss/n:.4f}, acc {correct/n:.4f}")

# Test
model.eval()
correct = 0
n = 0
with torch.no_grad():
    for x, y in test_loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        correct += (logits.argmax(dim=1) == y).sum().item()
        n += x.size(0)

print(f"test accuracy: {correct/n:.4f}")

Run. After ~30 seconds (on CPU) or ~5 seconds (on GPU), you should see ~97% test accuracy. Your first trained model.

What each line is doing

x.view(x.size(0), -1) - flatten each 28x28 image into a 784-length vector. x.size(0) is the batch dimension; the -1 tells PyTorch to infer the remaining size (1 × 28 × 28 = 784).
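A quick shape check on a fake batch (random values standing in for MNIST images). Using x.size(0) instead of a hardcoded 64 matters: 60,000 = 937 × 64 + 32, so the last training batch has only 32 images.

```python
import torch

x = torch.randn(64, 1, 28, 28)        # fake batch: (batch, channel, height, width)
flat = x.view(x.size(0), -1)          # -1 lets PyTorch infer 1 * 28 * 28 = 784
print(flat.shape)                     # → torch.Size([64, 784])

# torch.flatten(x, start_dim=1) gives the same result, and also works on
# non-contiguous tensors where .view() would raise an error.
print(torch.equal(flat, torch.flatten(x, start_dim=1)))   # → True
```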

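nn.CrossEntropyLoss - the standard loss for classification. It takes raw logits (not probabilities) and internally applies log-softmax followed by negative log-likelihood, which is more numerically stable than a separate softmax. That's why the model's forward ends with fc3 and no softmax.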
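A small check of that decomposition on random logits (the batch size, class count, and labels are arbitrary). All three numbers are the mean over the batch of -log(probability assigned to the correct class).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 10)               # raw outputs for 5 examples, 10 classes
targets = torch.tensor([3, 0, 9, 1, 7])   # integer class labels

ce = nn.CrossEntropyLoss()(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# By hand: softmax, pick each example's correct-class probability, take -log, average
by_hand = -F.softmax(logits, dim=1)[torch.arange(5), targets].log().mean()

print(torch.allclose(ce, manual), torch.allclose(ce, by_hand, atol=1e-6))   # → True True
```

If the model applied its own softmax first, CrossEntropyLoss would effectively apply it twice and training would suffer.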

optimizer = torch.optim.Adam(...) - the optimizer. Adam is one of the most widely used optimizers in modern ML. Other options: SGD (stochastic gradient descent - classic, but usually needs more tuning) and AdamW (Adam with decoupled weight decay).

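optimizer.zero_grad() - clear the gradients left over from the last batch. PyTorch accumulates gradients by default (each backward() adds to .grad), so you must clear them explicitly every step. Forget this and each update uses the sum of all previous batches' gradients instead of the current batch's.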
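Accumulation is easy to see on a single scalar. This sketch clears the gradient by setting .grad to None, which is what optimizer.zero_grad() does by default in PyTorch 2.x.

```python
import torch

w = torch.tensor(3.0, requires_grad=True)

(w ** 2).backward()           # d(w^2)/dw = 2w = 6
after_one = w.grad.item()
print(after_one)              # → 6.0

(w ** 2).backward()           # no zeroing: the new 6 is ADDED to the stored 6
after_two = w.grad.item()
print(after_two)              # → 12.0

w.grad = None                 # clear, as optimizer.zero_grad() does by default
(w ** 2).backward()
after_reset = w.grad.item()
print(after_reset)            # → 6.0
```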

loss.backward() - autograd walks the computation graph backward from loss, computing the gradient of loss with respect to every parameter that participated in computing it, and adds the result to each parameter's .grad attribute.
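A one-parameter check that autograd matches the derivative you'd compute by hand. The tiny model prediction = w*x + b with squared error is an illustrative assumption; x and t are plain tensors, so they get no gradient.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)
x, t = torch.tensor(3.0), torch.tensor(10.0)   # input and target: no gradients needed

loss = (w * x + b - t) ** 2    # prediction 7, error -3, loss 9
loss.backward()

# By hand: dL/dw = 2 * (w*x + b - t) * x = -18, dL/db = 2 * (w*x + b - t) = -6
print(w.grad, b.grad)          # → tensor(-18.) tensor(-6.)
print(x.grad)                  # → None (x doesn't require grad)
```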

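optimizer.step() - applies the update using the gradients stored in .grad. For plain SGD: param = param - lr * param.grad. For Adam: a similar step, but rescaled per parameter using running averages of each gradient and of its square.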
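For plain SGD (no momentum, no weight decay), you can confirm that step() is exactly param - lr * grad. The two-element parameter and lr=0.1 are arbitrary choices for the demo.

```python
import torch

p = torch.nn.Parameter(torch.tensor([1.0, 2.0]))
optimizer = torch.optim.SGD([p], lr=0.1)      # plain SGD: no momentum, no weight decay

loss = (p ** 2).sum()          # gradient is 2p = [2., 4.]
optimizer.zero_grad()
loss.backward()

expected = p.detach() - 0.1 * p.grad          # param - lr * grad, computed by hand (new tensor)
optimizer.step()

print(p.detach())                             # → tensor([0.8000, 1.6000])
print(torch.allclose(p.detach(), expected))   # → True
```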

model.train() vs model.eval() - switches the model's internal mode. Affects layers like dropout and batch norm that behave differently during training vs inference. Always set the right mode.
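The MLP on this page has no dropout or batch norm, so switching modes doesn't change its output, but it's a good habit. Here is a standalone nn.Dropout (p=0.5 is an arbitrary choice) showing the difference: in training mode it zeroes about half the inputs and scales the survivors by 1/(1-p) = 2; in eval mode it passes inputs through unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()                 # training mode: random zeros, survivors scaled by 2
out_train = drop(x)
print(sorted(set(out_train.tolist())))       # → [0.0, 2.0]
print((out_train == 0).float().mean())       # roughly 0.5

drop.eval()                  # eval mode: dropout is a no-op
out_eval = drop(x)
print(torch.equal(out_eval, x))              # → True
```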

with torch.no_grad(): - disables autograd. Inference doesn't need gradients; this skips the bookkeeping and uses less memory.
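You can see the effect through requires_grad on the result: outside no_grad the output is attached to a graph; inside it, nothing is recorded.

```python
import torch

w = torch.randn(3, requires_grad=True)

tracked = (w * 2).sum()
print(tracked.requires_grad)      # → True: ops were recorded so backward() can run

with torch.no_grad():
    untracked = (w * 2).sum()
print(untracked.requires_grad)    # → False: no graph built, nothing extra kept in memory
```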

What the loss tells you

Training loss going down = the model is fitting the training data. A plateau can mean the model has reached its capacity, or that the optimizer is stuck (try different hyperparameters).

Test accuracy is what you actually care about - performance on data the model didn't see during training. If train loss keeps dropping but test accuracy stops improving, you're overfitting - memorizing the training data.

Mitigations: more data, regularization (dropout, weight decay), smaller model, early stopping.
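Early stopping is the easiest of these to sketch. Below is a hypothetical helper (not a PyTorch API) that says "stop" once the best validation accuracy is at least patience epochs old. It assumes you track accuracy on a validation split held out from the training data (for example with torch.utils.data.random_split), since the test set should stay untouched until the end. The accuracy history is made-up illustration data.

```python
def should_stop(val_accuracies, patience=3):
    """Hypothetical helper: True when the last `patience` epochs brought no
    improvement over the best validation accuracy seen so far."""
    if len(val_accuracies) <= patience:
        return False
    # max() returns the first index among ties, so equalling the best is not "improving"
    best_epoch = max(range(len(val_accuracies)), key=lambda i: val_accuracies[i])
    return len(val_accuracies) - 1 - best_epoch >= patience

history = [0.950, 0.965, 0.971, 0.970, 0.969, 0.971]   # made-up per-epoch accuracies
print(should_stop(history[:4]))   # → False (best was 1 epoch ago)
print(should_stop(history[:5]))   # → False (2 epochs ago)
print(should_stop(history))       # → True  (3 epochs without beating 0.971)
```

In practice you would also save the model's state_dict whenever a new best appears, so you can restore the best weights after stopping.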

Hyperparameters

The numbers you set that aren't learned:
  • Learning rate (lr=1e-3) - how big a step the optimizer takes. Too high → unstable. Too low → slow. 1e-3 is a common starting point for Adam (it's also PyTorch's default).
  • Batch size (64) - larger = smoother gradients, more memory; smaller = noisier gradients, sometimes generalizes better.
  • Epochs - how many passes over the data. More = better fit (until overfitting).
  • Architecture - depth, width, normalization layers.

These need tuning. For MNIST, defaults work. For real problems, expect to iterate.
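Batch size and epochs together set how many optimizer steps you actually take. The arithmetic for this page's settings (MNIST's standard 60,000 / 10,000 split) is below; len(train_loader) should report the same 938, since the DataLoader keeps the final partial batch by default.

```python
import math

train_size, test_size = 60_000, 10_000    # MNIST's standard train/test split
train_batch, test_batch = 64, 512         # batch sizes used in the script

steps_per_epoch = math.ceil(train_size / train_batch)
last_batch = train_size - (steps_per_epoch - 1) * train_batch
test_steps = math.ceil(test_size / test_batch)

print(steps_per_epoch, last_batch)   # → 938 32
print(test_steps)                    # → 20
print(3 * steps_per_epoch)           # → 2814 optimizer steps over 3 epochs
```

Doubling the batch size roughly halves the steps per epoch, which is one reason larger batches can need more epochs.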

Save the model

torch.save(model.state_dict(), "mnist_mlp.pt")

Load later:

model = MLP().to(device)
model.load_state_dict(torch.load("mnist_mlp.pt", map_location=device))   # map_location: lets GPU-saved weights load on a CPU-only machine
model.eval()
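A self-contained round trip, writing to a temporary directory. The nn.Sequential is a stand-in for MLP() (an assumption to keep the snippet short). A state_dict maps parameter names to tensors, and the names come from the module structure: for the page's MLP they would be fc1.weight, fc1.bias, and so on. So the class you load into must define the same layers.

```python
import os
import tempfile
import torch
import torch.nn as nn

def build():
    # Stand-in for MLP(): same layer sizes, but built with nn.Sequential
    return nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 64),
                         nn.ReLU(), nn.Linear(64, 10))

model = build()
path = os.path.join(tempfile.mkdtemp(), "mnist_mlp.pt")
torch.save(model.state_dict(), path)          # saves only the tensors, keyed by name

print(list(model.state_dict().keys()))
# → ['0.weight', '0.bias', '2.weight', '2.bias', '4.weight', '4.bias']

restored = build()                            # fresh, randomly initialized weights
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()

same = all(torch.equal(a, b) for a, b in zip(model.state_dict().values(),
                                             restored.state_dict().values()))
print(same)                                   # → True
```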

What this scales to

The training loop pattern above stays the same for huge models; what mostly changes is the model definition (and the data). Add a few wrinkles for big-model training (mixed precision, gradient accumulation, distributed training, checkpointing) and you have the outline of a real LLM training script.
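Gradient accumulation is the simplest of those wrinkles, and it relies on the accumulate-by-default behavior of .grad. A minimal sketch, assuming a mean-reduced loss, equal-size micro-batches, and no batch-dependent layers such as batch norm: splitting a batch of 8 into two micro-batches of 4, dividing each loss by the number of micro-batches, and calling backward() on each yields the same gradient as one full batch. The tiny linear model and random data are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
criterion = nn.MSELoss()                  # mean over the examples in a batch
x, y = torch.randn(8, 4), torch.randn(8, 1)

# Reference: one full batch of 8
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 2 micro-batches of 4, with no zero_grad() in between
model.zero_grad()
accum_steps = 2
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    loss = criterion(model(xb), yb) / accum_steps   # scale so the sum equals the full-batch mean
    loss.backward()                                  # gradients add up in .grad
# In a real loop, optimizer.step() and optimizer.zero_grad() would go here.
accumulated_grad = model.weight.grad.clone()

print(torch.allclose(accumulated_grad, full_grad, atol=1e-6))   # → True
```

This lets you emulate a batch that doesn't fit in memory, at the cost of more forward/backward passes per optimizer step.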

Exercise

  1. Run the script. Get ~97% test accuracy.

  2. Tweak hyperparameters:
      • Set lr=1e-2. Does it still train? (The step may be too large: watch for a noisy loss, lower accuracy, or even NaN.)
      • Set lr=1e-5. (Trains, but slowly.)
      • Increase epochs to 10. Does test accuracy improve? Plateau? Degrade (overfit)?

  3. Architecture changes:
      • Add a third hidden layer of size 128.
      • Increase the hidden sizes to 256 and 128.
      • Watch how the parameter count and accuracy change.

  4. Visualize: plot the per-epoch training loss using matplotlib. (Or use TensorBoard: pip install tensorboard, then tensorboard --logdir runs/.)

  5. Stretch: instead of an MLP, try a small CNN. A 2-layer CNN can beat this MLP, reaching >98% accuracy. Look up nn.Conv2d. Don't worry if it doesn't work on the first try; convolutions take some adjusting.
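For the architecture exercise, counting parameters is a one-liner over model.parameters(). The make_mlp helper below is hypothetical (it builds the same kind of network as the page's MLP class with configurable hidden sizes); each nn.Linear(n_in, n_out) contributes n_in * n_out weights plus n_out biases.

```python
import torch.nn as nn

def make_mlp(hidden_sizes):
    """Hypothetical helper: 784 → hidden sizes → 10, ReLU between layers."""
    sizes = [28 * 28, *hidden_sizes, 10]
    layers = []
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        layers.append(nn.Linear(n_in, n_out))
        if i < len(sizes) - 2:      # no ReLU after the output layer: logits stay raw
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)

def count_params(model):
    return sum(p.numel() for p in model.parameters())

print(count_params(make_mlp([128, 64])))    # → 109386 (same layer sizes as the page's MLP)
print(count_params(make_mlp([256, 128])))   # → 235146
```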

What you might wonder

"Why Adam over SGD?" Adam adapts the learning rate per-parameter; converges faster on most problems without tuning. SGD with momentum can outperform on specific architectures (vision CNNs) when carefully tuned. For getting started: Adam.

"How big should my batch be?" For GPU work, "as big as fits in memory" is a common heuristic. Common sizes: 32, 64, 128, 256. Larger batches give smoother gradients but you might need more epochs to converge.

"What does loss.item() do?" Extracts a Python float from a 0-dim tensor. Detached from the graph (no gradient tracking). Use when you want a number for logging.
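This is why the training loop accumulates loss.item() rather than loss: adding the tensor itself keeps a running total that stays attached to autograd history. The tiny one-parameter loss is an illustrative assumption.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)
loss = (w - 5) ** 2

print(type(loss.item()), loss.item())   # → <class 'float'> 9.0
print(loss.requires_grad)               # → True: the tensor is part of the graph

total_bad = 0
total_bad += loss           # result is a Tensor still attached to the autograd graph
total_good = 0
total_good += loss.item()   # result is a plain Python float
print(type(total_bad).__name__, type(total_good).__name__)   # → Tensor float
```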

"Why is my loss not going down?" Common causes:
  • Learning rate too high (loss NaN) or too low (loss flat).
  • Wrong loss function for your task.
  • Bug in the model (wrong shapes - print them).
  • Data not normalized.

Print loss every step initially. If it's not going down within ~100 steps on MNIST, something's wrong.
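A related debugging trick (not part of the original checklist): train on one small batch only. A working model and loop should memorize 16 examples easily, so the loss should fall toward zero; if it doesn't, the bug is in the model or loop rather than the data. The random inputs and labels below stand in for a real batch, and the layer sizes, learning rate, and step count are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Random stand-ins for one real batch of 16 flattened images and labels
x = torch.randn(16, 784)
y = torch.randint(0, 10, (16,))

model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for step in range(200):
    loss = criterion(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step == 0:
        first = loss.item()
    if step % 50 == 0:
        print(step, round(loss.item(), 4))   # should shrink steadily

last = loss.item()
print(f"loss went from {first:.3f} to {last:.4f}")
```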

Done

  • The four-step training loop pattern.
  • DataLoader for batched iteration.
  • nn.CrossEntropyLoss + Adam optimizer.
  • optimizer.zero_grad() / loss.backward() / optimizer.step().
  • model.train() vs model.eval().
  • Save/load weights.

Next: Inference and saving →
