
06 - Inference and Saving

What this session is

About 30 minutes. This covers the other half of the job: using a trained model. That means loading saved weights, running predictions, and the eval-mode + no-grad pattern.

Inference vs training

Training: forward pass + loss + backward pass + optimizer step. Slow, memory-hungry, uses gradients.

Inference: forward pass only. Fast, cheap, no gradients.

The difference matters because a model spends most of its lifetime doing inference: users send requests, and the model answers them. Optimizing inference is its own discipline (page 12).

The basic pattern

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

model = MLP()                              # the same class you trained with
model.load_state_dict(torch.load("mnist_mlp.pt", map_location=device))
model = model.to(device)
model.eval()                               # IMPORTANT: switch to eval mode

# An input
x = torch.randn(1, 1, 28, 28, device=device)   # one fake "image"

with torch.no_grad():                       # IMPORTANT: skip gradient tracking
    logits = model(x)
    probs = torch.softmax(logits, dim=1)
    pred = logits.argmax(dim=1)

print(f"predicted: {pred.item()}, confidence: {probs.max().item():.4f}")

Three things you must remember:

  1. Load the weights - load_state_dict into a freshly constructed model with the same architecture.
  2. model.eval() - disables dropout and makes batch norm use its stored running statistics instead of the current batch's.
  3. with torch.no_grad(): - disables gradient tracking. Faster, uses less memory.

Forgetting any of these rarely raises an error. Instead you get random predictions, outputs that change from run to run, or wasted memory.
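To see what eval mode and no_grad actually change, here is a small sketch using a toy model with a dropout layer (the model is a hypothetical stand-in, not the page's MLP). In training mode two forward passes on the same input differ; in eval mode they match; and only outside no_grad is the output attached to the autograd graph.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model with dropout, standing in for any model whose layers
# behave differently in train and eval mode (hypothetical example).
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))
x = torch.randn(4, 8)

# Training mode: dropout zeroes random activations, so repeated passes differ.
model.train()
with torch.no_grad():
    a, b = model(x), model(x)
print("train mode, outputs identical:", torch.equal(a, b))

# Eval mode: dropout does nothing, so repeated passes give identical results.
model.eval()
with torch.no_grad():
    c, d = model(x), model(x)
print("eval mode, outputs identical:", torch.equal(c, d))

# Without no_grad, autograd tracks the output (it requires grad).
print("tracked outside no_grad:", model(x).requires_grad)
with torch.no_grad():
    print("tracked inside no_grad:", model(x).requires_grad)
```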

Predict on a real image

from PIL import Image
from torchvision import transforms

# The same transform you used during training
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

img = Image.open("my_digit.png")
x = transform(img).unsqueeze(0).to(device)    # add batch dim → (1, 1, 28, 28)

with torch.no_grad():
    logits = model(x)
    pred = logits.argmax(dim=1).item()

print(f"predicted digit: {pred}")

Key point: inference preprocessing must match training. Same resize, same normalization constants, same color space. Mismatched preprocessing is one of the most common sources of silent bugs in ML - the model still produces a prediction, just a bad one.
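One mismatch the transform above does not catch: MNIST digits are light strokes on a dark background, while a digit drawn in a paint program is usually dark on light. A minimal sketch of a fix, using a hypothetical helper named match_mnist_polarity and an assumed brightness threshold of 128 (not a standard value), inverts mostly-bright images before the usual transform:

```python
from PIL import Image, ImageOps, ImageStat


def match_mnist_polarity(img: Image.Image) -> Image.Image:
    """Hypothetical helper: return a grayscale image with a dark background.

    MNIST digits are light on dark. If the image is mostly bright
    (mean above the assumed threshold of 128), invert it.
    """
    gray = img.convert("L")                          # single-channel, 0..255
    mean_brightness = ImageStat.Stat(gray).mean[0]   # 0 = black, 255 = white
    return ImageOps.invert(gray) if mean_brightness > 128 else gray
```

Usage: x = transform(match_mnist_polarity(Image.open("my_digit.png"))).unsqueeze(0).to(device). A mean-brightness check is crude; it works for a single centered digit on a plain background, which is the MNIST setting.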

Batching for speed

If you have many inputs, predict on them in batches - much faster than one-at-a-time:

paths = ["digit_a.png", "digit_b.png", "digit_c.png"]   # hypothetical file names
images = [transform(Image.open(p)) for p in paths]
batch = torch.stack(images).to(device)         # (N, 1, 28, 28)

with torch.no_grad():
    logits = model(batch)
    preds = logits.argmax(dim=1)

for path, p in zip(paths, preds.tolist()):
    print(path, p)

Batching keeps the GPU busy: one large forward pass uses the hardware far better than many tiny ones. For latency-sensitive online inference, you might still process single inputs; for throughput-sensitive batch jobs, batch as much as memory allows.
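When N is too large to fit in memory at once, one option is to split the inputs into fixed-size chunks and concatenate the results. A minimal sketch, with a hypothetical helper name predict_in_batches and an arbitrary default batch size of 256:

```python
import torch
import torch.nn as nn


def predict_in_batches(model: nn.Module, inputs: torch.Tensor,
                       batch_size: int = 256) -> torch.Tensor:
    """Hypothetical helper: argmax predictions, batch_size rows at a time.

    Chunking caps peak memory while keeping each forward pass large
    enough to use the hardware well.
    """
    device = next(model.parameters()).device
    model.eval()
    preds = []
    with torch.no_grad():
        for chunk in torch.split(inputs, batch_size):   # splits along dim 0
            logits = model(chunk.to(device))
            preds.append(logits.argmax(dim=1).cpu())    # collect on CPU
    return torch.cat(preds)
```

The last chunk may be smaller than batch_size; torch.split handles that. Tune batch_size to the largest value that fits in memory.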

Save more than weights

model.state_dict() holds only the model's parameters and buffers (such as batch-norm running statistics). For a fully recoverable training session, save more:

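torch.save({
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss.item(),                     # store a plain float, not a graph-attached tensor
}, "checkpoint.pt")

# Resume: construct the model and optimizer exactly as before, then restore their state.
ckpt = torch.load("checkpoint.pt", map_location=device)
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
start_epoch = ckpt["epoch"] + 1
model.train()                                # back to training mode before continuing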

For production deployment, save just the model weights. For "I want to resume training tomorrow," save the full checkpoint.
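The checkpoint snippet above assumes a live training loop. A self-contained round trip, using a toy nn.Linear model, Adam, and a temporary file as stand-ins (all assumptions, not the page's MLP setup), shows that resuming restores both the weights and the optimizer's internal state:

```python
import os
import tempfile

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One training step, so Adam has running moment estimates worth saving.
x, y = torch.randn(8, 4), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save({
    "epoch": 0,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": loss.item(),
}, path)

# "Tomorrow": fresh objects with the same architecture and optimizer settings.
model2 = nn.Linear(4, 2)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=1e-3)
ckpt = torch.load(path, map_location="cpu")
model2.load_state_dict(ckpt["model_state_dict"])
optimizer2.load_state_dict(ckpt["optimizer_state_dict"])
start_epoch = ckpt["epoch"] + 1
```

Saving only model.state_dict() would restore the weights but reset Adam's moment estimates, so the first steps after resuming would behave differently from an uninterrupted run.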

TorchScript and ONNX (briefly)

For shipping models, two portability formats:

  • TorchScript - torch.jit.script(model) or torch.jit.trace(model, example_input) produces a serializable, deployment-friendly version of the model that can be loaded and run from C++ (LibTorch) without Python.
  • ONNX - an open standard format. torch.onnx.export(model, ...) writes a file that many runtimes can read (ONNX Runtime, TensorRT, browsers).
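For a feel of what tracing involves, here is a minimal TorchScript round trip, assuming a small stand-in model (not the page's MLP) and a temporary file path. Tracing records the operations executed on example_input, so it suits models without data-dependent control flow; torch.jit.script is the option for models that have it.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Small stand-in model (hypothetical); fixed control flow, so tracing is safe.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
model.eval()

example_input = torch.randn(1, 1, 28, 28)
with torch.no_grad():
    traced = torch.jit.trace(model, example_input)   # records ops run on this input

path = os.path.join(tempfile.mkdtemp(), "model_traced.pt")
traced.save(path)               # one file containing code + weights
loaded = torch.jit.load(path)   # no MLP class definition needed to load it

with torch.no_grad():
    x = torch.randn(5, 1, 28, 28)                    # a different batch size
    same = torch.allclose(model(x), loaded(x))
print("traced output matches eager:", same)
```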

Beyond beginner scope; mentioned because deployment paths sometimes need them.

For most cases, deploying a PyTorch model directly (page 12) is fine.

Inference performance: the gotchas

A few things that catch people:

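  • First inference is slow. The first call pays one-time setup costs (CUDA context creation, lazy kernel loading, memory-allocator warm-up). Warm up with a dummy forward pass before timing.
  • GPU work is asynchronous. Kernel launches return to Python immediately and the GPU runs them later, so a naive timer may measure only the launch. Call torch.cuda.synchronize() before reading the clock, both when starting and when stopping.
  • with torch.no_grad(): matters even for small inferences. Skipping the autograd graph saves memory and usually speeds things up.
  • torch.set_num_threads(1) for CPU inference can speed up small models by avoiding thread-coordination overhead, especially when several inference processes share one machine.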
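An alternative to hand-rolled timing loops is torch.utils.benchmark, which, per its documentation, runs a short warm-up and synchronizes CUDA for you. A sketch comparing one and four CPU threads on a small stand-in model (the model and thread counts are assumptions); note that Timer uses a single thread unless num_threads is set:

```python
import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark

# Small stand-in model (hypothetical), roughly the size of the MNIST MLP.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))
model.eval()
x = torch.randn(1, 1, 28, 28)


def forward_no_grad(m, inp):
    with torch.no_grad():
        return m(inp)


results = {}
for threads in (1, 4):
    timer = benchmark.Timer(
        stmt="forward_no_grad(model, x)",
        globals={"forward_no_grad": forward_no_grad, "model": model, "x": x},
        num_threads=threads,        # defaults to 1 if omitted
    )
    measurement = timer.timeit(200)  # runs stmt 200 times after a short warm-up
    results[threads] = measurement.mean
    print(f"{threads} thread(s): {measurement.mean * 1e6:.1f} us / call")
```

Which setting wins depends on the machine and model size; measure rather than assume.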

A complete example

"""
Load a trained MNIST MLP and predict on a single image.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


def main(image_path: str):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = MLP().to(device)
    model.load_state_dict(torch.load("mnist_mlp.pt", map_location=device))
    model.eval()

    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    img = Image.open(image_path)
    x = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits, dim=1)
        pred = logits.argmax(dim=1).item()
        confidence = probs.max().item()

    print(f"predicted: {pred} (confidence: {confidence:.4f})")


if __name__ == "__main__":
    import sys
    main(sys.argv[1])

Run: python infer.py my_digit.png.

Exercise

  1. Train a model from page 05. Save its weights with torch.save(model.state_dict(), "mnist_mlp.pt").

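  2. Save the complete example above as infer.py. Get a digit image (download one from Google or draw one in Paint, save as PNG) and run a prediction. MNIST digits are white on a black background, so invert a black-on-white drawing first or the prediction will likely be wrong.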

  3. Measure speed:

    import time
    # Assumes model, x and device from infer.py are in scope.
    with torch.no_grad():
        for _ in range(3):                # warm up
            model(x)
        if device == "cuda":
            torch.cuda.synchronize()      # wait for queued GPU work
        t0 = time.perf_counter()
        for _ in range(1000):
            model(x)
        if device == "cuda":
            torch.cuda.synchronize()
        elapsed_ms = (time.perf_counter() - t0) * 1000
    print(f"{elapsed_ms / 1000:.3f} ms / inference")   # 1000 iterations

  4. Stretch: load multiple images into one batch. Time a forward pass on the batch vs looping over single images. The batch is usually much faster per image - that's why batching matters.

What you might wonder

"Why map_location=device in torch.load?" Loads tensors directly to the target device. Without it, PyTorch tries to load to the device they were saved on, which fails if you trained on GPU but are inferring on CPU.

"What's torch.compile?" PyTorch 2's JIT compiler. model = torch.compile(model) can give 1.5x-3x speedup. Sometimes flaky; experiment carefully. Mentioned for awareness.

"Should I use .half() or .bfloat16() for inference?" For modern GPUs (Volta and newer), yes - half-precision inference is ~2x faster with negligible quality drop for most models. model = model.half() then x = x.half(). Test accuracy afterward; some models tolerate half-precision better than others.

"What about quantization (INT8, INT4)?" Even more aggressive than half-precision. Used heavily for LLM inference (page 12). Beyond beginner scope on this page.

Done

  • Load weights into a model architecture.
  • Use model.eval() + with torch.no_grad():.
  • Preprocess inference inputs the same way as training inputs.
  • Batch inference for throughput.
  • Save/load full training checkpoints.

Next: Transformers and tokenization →
