PyTorch Fluency: The User-Level Reference
Companion document. This chapter is the user-level counterpart to AI_SYSTEMS_PLAN/DEEP_DIVES/04_PYTORCH_INTERNALS.md. That document explains how the dispatcher routes a torch.add call, how the autograd engine builds and walks the backward graph, and how torch.compile lowers Python into Inductor-compiled CUDA. This document is what you, the working AI engineer, must know to write training and inference code fluently at the keyboard, with no time to spelunk source. If you ever ask "why did this fail / why is this slow / what's the right pattern?", the answer is here.
All code targets PyTorch 2.4+. Every block is runnable as-is when collected into a single .py file with the imports shown at the top. This is a reference, not a tutorial: it is dense by design, and you will return to it.
0. Imports used throughout
import os
import math
import json
import random
import time
from pathlib import Path
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
These are the imports assumed by every snippet below. When a topic needs more (e.g. torch.distributed), the additional import appears in that section.
1. Tensors: the substrate
A tensor is a multidimensional array bundled with three pieces of metadata that decide whether your code runs at all and whether it runs fast: shape, dtype, device. Treat those three as a single triple, `(shape, dtype, device)`, and ask it of every tensor you see.
1.1 Creation
# From Python data-slow path; only use for tiny configs / tests
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
# Pre-allocated constants-fast path
zeros = torch.zeros(4, 8) # default dtype float32, device cpu
ones = torch.ones(4, 8, dtype=torch.bfloat16, device="cuda")
empty = torch.empty(4, 8) # uninitialized-only use if you will overwrite
full = torch.full((4, 8), fill_value=-1.0)
# Random
torch.manual_seed(0)
g = torch.randn(4, 8) # standard normal, float32
u = torch.rand(4, 8) # uniform [0, 1)
i = torch.randint(low=0, high=10, size=(4, 8)) # int64 by default
# Sequences
ar = torch.arange(0, 10, step=2) # tensor([0,2,4,6,8])
ls = torch.linspace(0.0, 1.0, steps=5) # 5 evenly-spaced
# Like-shaped-inherit shape/dtype/device of an existing tensor
zl = torch.zeros_like(g)
rl = torch.randn_like(g)
# From NumPy-shares memory on CPU; mutating one mutates the other
arr = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
t = torch.from_numpy(arr)
arr[0, 0] = 99.0
assert t[0, 0].item() == 99.0 # shared storage
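Pitfall. torch.tensor(data) always copies. For Python data it infers the dtype (int → int64, float → float32, the default dtype); for NumPy input it keeps the array's dtype. For NumPy inputs use torch.from_numpy if you want guaranteed zero-copy, or torch.as_tensor, which shares memory when the dtype and device already match and copies otherwise.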
Pitfall. torch.empty is uninitialized, not zeros. Reading before you write yields garbage and, on CUDA, sometimes NaNs that propagate silently.
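One way to check the copy semantics described above is to mutate the NumPy array and see which tensors follow. This is a minimal sketch; the variable names are illustrative only.

```python
import numpy as np
import torch

arr = np.zeros(3, dtype=np.float32)
copied = torch.tensor(arr)        # always copies
shared = torch.from_numpy(arr)    # always shares memory with the array
maybe = torch.as_tensor(arr)      # shares here: dtype and device already match
recast = torch.as_tensor(arr, dtype=torch.float64)  # dtype differs, so this copies

arr[0] = 7.0                      # mutate the NumPy side only
print(copied[0].item(), shared[0].item(), maybe[0].item(), recast[0].item())
```

Only the tensors that share storage with the array see the write.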
1.2 dtype
The dtypes you actually use in 2026:
| dtype | bits | use |
|---|---|---|
| torch.float32 | 32 | "Default." Optimizer master weights, anything CPU, small models. |
| torch.float16 | 16 | Inference on older GPUs (Volta/Turing). FP16 training only with a GradScaler to handle the narrow exponent range. |
| torch.bfloat16 | 16 | Modern default for training compute on Ampere/Hopper/AMD MI. Same exponent as FP32, no scaler needed. |
| torch.float64 | 64 | Scientific code. Almost never in deep learning. |
| torch.int64 | 64 | Index tensors (token IDs, class labels, gather indices). |
| torch.int32 | 32 | Sometimes for indices on memory-tight workloads. |
| torch.bool | 8 | Masks. |
| torch.uint8 | 8 | Raw image bytes pre-normalization. |
The key intuition: dtype is the lever between memory/throughput and numerical headroom. Halving precision halves memory and roughly doubles tensor-core throughput on supported GPUs. BF16 keeps FP32's 8-bit exponent, and therefore its dynamic range, but has only 7 explicit mantissa bits (8 bits of precision counting the implicit leading one). That range is why it has displaced FP16 for training: you almost never overflow or underflow.
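The range-versus-precision trade-off can be read directly from torch.finfo. The sketch below prints the limits for the three float dtypes and casts one value that exceeds FP16's range; the chosen value 70000.0 is just an illustration.

```python
import torch

for dt in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dt)
    print(f"{str(dt):16s} bits={info.bits:2d} max={info.max:.3e} eps={info.eps:.3e}")

big = torch.tensor(70000.0)
fp16_overflows = torch.isinf(big.to(torch.float16)).item()     # above FP16's max of 65504
bf16_survives = torch.isfinite(big.to(torch.bfloat16)).item()  # well inside BF16's range
```

BF16 trades precision (a larger eps) for range; FP16 does the opposite.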
x = torch.randn(4, 8) # float32
y = x.to(torch.bfloat16) # cast (out-of-place)
z = x.float() # alias for .to(torch.float32)
b = x.bfloat16() # alias for .to(torch.bfloat16)
# Mixing dtypes in an op promotes; do it on purpose, not by accident
a = torch.ones(3, dtype=torch.float32)
b = torch.ones(3, dtype=torch.bfloat16)
(a + b).dtype # float32 (BF16 promotes up)
Pitfall. nn.CrossEntropyLoss requires target to be int64 (or floats for "soft" targets). Passing int32 raises a confusing error inside the loss; cast at dataset boundary: labels = labels.long().
1.3 device
device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
x_cpu = torch.randn(4, 8)
x_gpu = x_cpu.to(device) # H2D copy
x_back = x_gpu.cpu() # D2H copy
# All operands of an op must be on the same device, period
# torch.randn(3, device="cpu") + torch.randn(3, device="cuda") # RuntimeError: tensors on different devices
The cardinal rule of device discipline: move data to device exactly once per training step, at the boundary, and keep it there. Re-uploading per-op is the classic way to silently make your training 100x slower than necessary. We return to this in §15.
mps is Apple Silicon. It works for most ops but lags CUDA on coverage; assume you'll occasionally need .cpu() fallback for an unsupported op.
2. Shape ops: the daily grammar
Most "PyTorch programming" is shape choreography. Internalize these eight verbs.
x = torch.arange(24).reshape(2, 3, 4) # (B=2, S=3, H=4)
# reshape vs view
v = x.view(2, 12) # no copy-requires contiguous memory
r = x.reshape(2, 12) # may copy if needed
# transpose: swap exactly two dims
t = x.transpose(1, 2) # (2, 4, 3); NOT contiguous after this
# permute: arbitrary reordering
p = x.permute(2, 0, 1) # (4, 2, 3)
# squeeze / unsqueeze: drop / add a length-1 dim
y = torch.zeros(1, 3, 1, 5)
y.squeeze().shape # (3, 5)
y.squeeze(0).shape # (3, 1, 5) -only dim 0
x.unsqueeze(0).shape # (1, 2, 3, 4)
x.unsqueeze(-1).shape # (2, 3, 4, 1)
# expand vs repeat
a = torch.tensor([[1, 2, 3]]) # (1, 3)
a.expand(4, 3) # (4, 3) view, no memory copy-broadcasts
a.repeat(4, 1) # (4, 3) actual copy of data
2.1 view vs reshape: the rule
view never copies: it requires that the requested shape can be expressed over the tensor's existing strides, which is always true for a contiguous tensor and only sometimes true for a non-contiguous one. If it cannot, it raises. reshape is permissive: it returns a view whenever view would succeed, and otherwise copies the data into a new contiguous tensor.
x = torch.randn(2, 3, 4)
x.is_contiguous() # True
x.view(6, 4) # works
xt = x.transpose(0, 1) # (3, 2, 4)-strides are now non-standard
xt.is_contiguous() # False
# xt.view(6, 4) # RuntimeError: view size is not compatible with input tensor's size and stride
xt.reshape(6, 4) # works (silently copies)
xt.contiguous().view(6, 4) # works (explicit copy)
Why care? reshape's implicit copy can be a hidden allocation in a hot loop. view is loud about needing contiguity, which is usually what you want. Use view by default; reach for reshape only when you've decided a copy is acceptable.
Pitfall (silent slowness). A transpose followed by a series of elementwise ops on a non-contiguous tensor often runs much slower than the contiguous version because the kernel can't vectorize cleanly. If a tensor will be used many times after a transpose, call .contiguous() once.
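Strides and data pointers make the view/copy distinction visible. The following sketch, with illustrative variable names, compares storage addresses before and after view and reshape.

```python
import torch

x = torch.randn(2, 3, 4)
xt = x.transpose(0, 1)                     # (3, 2, 4), same storage, permuted strides

print(x.stride(), xt.stride())
view_shares = x.view(6, 4).data_ptr() == x.data_ptr()          # view: same storage
reshape_copied = xt.reshape(6, 4).data_ptr() != xt.data_ptr()  # reshape had to copy
xc = xt.contiguous()                       # pay for the copy once, explicitly
```

After `.contiguous()` the tensor has standard row-major strides again, so later views on it are free.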
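2.2 expand vs repeat
expand returns a view with stride 0 along the broadcasted dim. It costs no memory, but in-place writes are unsafe because multiple "logical" elements share one storage cell: PyTorch rejects whole-tensor in-place ops on such a view with a RuntimeError, and any write that does go through changes every logical copy at once. repeat actually copies, which costs memory but produces an independently writable tensor.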
mask = torch.tensor([1, 0, 1]) # (3,)
mask.expand(4, 3) # (4, 3), view-zero copy
# mask.expand(4, 3).add_(1) # RuntimeError: more than one element refers to a single memory location
mask.repeat(4, 1) # (4, 3), real tensor
99% of "broadcast-style" needs are satisfied by expand plus the broadcasting rules in §3.2; you almost never need repeat.
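The stride-0 claim is easy to inspect. In this small sketch a write to the source tensor shows up in every row of the expanded view but not in the repeated copy.

```python
import torch

row = torch.tensor([[1, 2, 3]])      # (1, 3)
ex = row.expand(4, 3)                # view over the same 3 elements
rp = row.repeat(4, 1)                # 12 freshly allocated elements

print(ex.stride(), rp.stride())      # stride 0 on the expanded dim
row[0, 0] = 99                       # write to the source tensor
print(ex[3, 0].item(), rp[3, 0].item())
```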
3. Indexing and broadcasting
3.1 Indexing
x = torch.arange(24).reshape(2, 3, 4)
# Integer indexing
x[0] # (3, 4)
x[0, 1] # (4,)
x[0, 1, 2] # scalar tensor
# Slicing
x[:, 1:, :] # (2, 2, 4)
x[..., -1] # (2, 3)-ellipsis fills all leading dims
# Boolean mask
mask = x > 10
x[mask] # 1-D tensor of all selected values
# Fancy / advanced indexing-produces a copy
idx = torch.tensor([0, 2])
x[:, idx, :] # (2, 2, 4)-selects rows 0 and 2 of dim 1
# gather: per-row picks along a dim
logits = torch.randn(4, 10) # (B=4, V=10)
targets = torch.tensor([3, 7, 0, 9]) # (B,)
picked = logits.gather(dim=1, index=targets.unsqueeze(1)).squeeze(1) # (4,)
# picked[i] == logits[i, targets[i]]
# scatter: write per-row values along a dim
one_hot = torch.zeros(4, 10).scatter_(dim=1, index=targets.unsqueeze(1), value=1.0)
gather and scatter are the workhorses for "indexed reads/writes" in vectorized code: top-k decoding, sparse updates, label smoothing. The dim argument says "along which axis are we indexing". For gather, the output has exactly the shape of the index tensor; for scatter, the index tensor has the shape of the values being written.
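As a worked instance of gather, the sketch below computes a classification loss by picking each row's target log-probability and compares it against F.cross_entropy. It is an illustration of the indexing pattern, not a replacement for the fused loss.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)                 # (B, V)
targets = torch.tensor([3, 7, 0, 9])        # (B,)

log_probs = F.log_softmax(logits, dim=-1)   # (B, V)
# index has shape (B, 1), so the gathered output is (B, 1) before the squeeze
picked = log_probs.gather(dim=1, index=targets.unsqueeze(1)).squeeze(1)
manual_loss = -picked.mean()
reference = F.cross_entropy(logits, targets)
```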
3.2 Broadcasting rules
Two tensors are broadcastable if, when their shapes are right-aligned, every dim is either equal, or one of them is 1, or missing. The output shape is the elementwise max.
# Right-align shapes:
# (B, 1, H)
# ( S, H)
# After right-align: (B, 1, H) and (1, S, H)
# Pairwise: B vs 1 -> B, 1 vs S -> S, H vs H -> H. Output: (B, S, H).
q = torch.randn(4, 1, 16) # (B=4, 1, H=16)
k = torch.randn( 8, 16) # ( S=8, H=16)
out = q + k # (4, 8, 16)
# A common scaffold: a per-feature bias added to a (B, S, H) hidden state
h = torch.randn(4, 8, 16)
b = torch.randn(16) # (H,)-broadcasts to (1, 1, 16)
h2 = h + b # (4, 8, 16)
# Failing case: shapes incompatible
# torch.randn(3, 4) + torch.randn(2, 4) # RuntimeError: 3 vs 2 in dim 0, and neither is 1
The mental drill: right-align, then check pairwise. If you can recite the output shape before you press enter, you'll never write a broadcasting bug.
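The right-align-then-compare drill can be written out in a few lines of plain Python. `broadcast_shape` is a hypothetical helper written for this illustration; PyTorch ships its own `torch.broadcast_shapes` for real use.

```python
def broadcast_shape(a: tuple, b: tuple) -> tuple:
    """Hypothetical helper: apply the right-align-then-compare rule to two shapes."""
    n = max(len(a), len(b))
    a = (1,) * (n - len(a)) + tuple(a)   # missing leading dims behave like 1
    b = (1,) * (n - len(b)) + tuple(b)
    out = []
    for da, db in zip(a, b):
        if da != db and da != 1 and db != 1:
            raise ValueError(f"cannot broadcast {da} against {db}")
        out.append(db if da == 1 else da)  # the non-1 side wins
    return tuple(out)

print(broadcast_shape((4, 1, 16), (8, 16)))
print(broadcast_shape((4, 8, 16), (16,)))
```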
Pitfall. Broadcasting silently allocates: (B, S, H) + (B, S, H) is an in-place fusable op; (B, S, H) + (H,) is too; but (B, 1, H) * (1, S, H) → (B, S, H) materializes the full Cartesian product. In tight loops this is the difference between fitting in cache and thrashing.
4. Math ops
4.1 Elementwise and reductions
x = torch.randn(4, 8)
# Elementwise-all return new tensors; in-place variants end with _
x.abs(); x.exp(); x.log(); x.sqrt(); x.sigmoid()
x.add_(1.0) # in-place, modifies x
# Reductions-the dim argument is everything
x.sum() # scalar
x.sum(dim=0) # (8,)-collapses dim 0
x.sum(dim=1) # (4,)
x.sum(dim=1, keepdim=True) # (4, 1)-preserves the dim, ready to broadcast back
x.mean(dim=-1) # last dim; idiomatic
x.max(dim=-1) # returns a (values, indices) named tuple
x.argmax(dim=-1) # just indices
The keepdim=True pattern is constantly used: reduce, then broadcast back.
# LayerNorm by hand-uses keepdim to broadcast mean/std back to original shape
def layer_norm(x, eps=1e-5):
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)
    return (x - mu) / torch.sqrt(var + eps)
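To confirm the reduce-then-broadcast pattern gives the standard result, this sketch restates the hand-written normalization under a different name (`layer_norm_manual`, chosen here to keep the block self-contained) and compares it with F.layer_norm without learnable weight or bias.

```python
import torch
import torch.nn.functional as F

def layer_norm_manual(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    mu = x.mean(dim=-1, keepdim=True)                  # (..., 1)
    var = x.var(dim=-1, keepdim=True, unbiased=False)  # population variance
    return (x - mu) / torch.sqrt(var + eps)            # broadcasts back to x's shape

torch.manual_seed(0)
x = torch.randn(4, 8)
ours = layer_norm_manual(x)
ref = F.layer_norm(x, normalized_shape=(8,), eps=1e-5)  # no weight/bias applied
```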
4.2 Matrix multiplication
A = torch.randn(4, 8)
B = torch.randn(8, 16)
A @ B # (4, 16)
torch.matmul(A, B) # same; supports batching
# Batched matmul-bmm is strict, matmul is permissive
Ab = torch.randn(32, 4, 8)
Bb = torch.randn(32, 8, 16)
torch.bmm(Ab, Bb) # (32, 4, 16); requires exactly 3-D
torch.matmul(Ab, Bb) # same; also accepts broadcasting on leading dims
# Matmul broadcasts leading dims-useful for multi-head attention
Q = torch.randn(2, 4, 8, 16) # (B=2, H=4, S=8, D=16)
K = torch.randn(2, 4, 8, 16)
scores = Q @ K.transpose(-2, -1) # (2, 4, 8, 8)
4.3 einsum: when index notation is the right tool
einsum lets you write the contraction in index notation, which is far more readable for anything beyond a 2-D matmul.
# Plain matmul: A_ij B_jk -> C_ik
torch.einsum("ij,jk->ik", A, B)
# Batched matmul: A_bij B_bjk -> C_bik
torch.einsum("bij,bjk->bik", Ab, Bb)
# Multi-head attention scores in one line
# Q: (b, h, s, d), K: (b, h, t, d) -> (b, h, s, t)
torch.einsum("bhsd,bhtd->bhst", Q, K)
# Outer product
u, v = torch.randn(4), torch.randn(5)
torch.einsum("i,j->ij", u, v) # (4, 5)
# Trace
M = torch.randn(8, 8)
torch.einsum("ii->", M) # scalar
# Diagonal
torch.einsum("ii->i", M) # (8,)
When to reach for which. Use @ / matmul for plain 2-D and 3-D matmuls; reach for einsum the moment you have more than three indices or you'd otherwise need a permute/transpose/reshape dance to set up a matmul. It's just as fast in modern PyTorch (it lowers to optimized BLAS), and the index notation reads like the math.
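The equivalence between the einsum form and the transpose-then-matmul form can be checked numerically. The shapes below are arbitrary, with the key length deliberately different from the query length so a swapped index would show up as a shape error.

```python
import torch

torch.manual_seed(0)
Q = torch.randn(2, 4, 8, 16)   # (b, h, s, d)
K = torch.randn(2, 4, 6, 16)   # (b, h, t, d), with t != s on purpose

via_einsum = torch.einsum("bhsd,bhtd->bhst", Q, K)
via_matmul = Q @ K.transpose(-2, -1)
```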
5. Autograd as a user
The ten-second mental model: every op on a tensor with requires_grad=True records a node in a dynamic graph whose leaves are the parameters. Calling .backward() on a scalar walks that graph in reverse, accumulating gradients into the leaves' .grad fields. The graph is rebuilt every forward pass-that's "dynamic" / "define-by-run." For the actual engine (function objects, the variable-version mechanism, hooks, the autograd graph traversal), see AI_SYSTEMS_PLAN/DEEP_DIVES/04_PYTORCH_INTERNALS.md §3.
5.1 Basic mechanics
x = torch.tensor([2.0, 3.0], requires_grad=True)
y = (x ** 2).sum()
y.backward()
x.grad # tensor([4., 6.]) -> dy/dx = 2x
# Gradients accumulate. You MUST zero them between steps.
y2 = (x ** 3).sum()
y2.backward()
x.grad # tensor([4.+12.=16., 6.+27.=33.])
x.grad.zero_() # or model.zero_grad() / optimizer.zero_grad()
5.2 detach, no_grad, inference_mode
Three different ways to "step out of" autograd. They are not interchangeable.
# detach-produce a tensor that shares storage but has no grad history
x = torch.randn(4, requires_grad=True)
y = x.detach() # y.requires_grad == False
# no_grad-context manager: ops inside don't record graph, but tensors created
# can still later have requires_grad set
with torch.no_grad():
    z = x * 2 # z.requires_grad == False
# inference_mode: stricter and faster than no_grad; disables version counter
# bumping. Tensors created inside an inference_mode block are tagged "inference"
# and cannot be saved for backward in any later autograd computation.
with torch.inference_mode():
    z = x * 2
# Later, outside the block: z * (something requiring grad) -> RuntimeError
Rule of thumb. inference_mode for serving / evaluation loops where you will never need grads on these tensors. no_grad for evaluation that lives inside a training script and may interact with grad-requiring tensors. detach for "give me this value as a constant from here on" (e.g., target networks in RL, EMA teachers in self-distillation).
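The differences between the three mechanisms show up in a handful of tensor flags. The sketch below is illustrative; the multiplication at the end is chosen because its backward would need to save the inference tensor, which is what triggers the error.

```python
import torch

x = torch.randn(4, requires_grad=True)

d = x.detach()                       # same storage, no history
with torch.no_grad():
    n = x * 2                        # ordinary tensor, just not tracked
with torch.inference_mode():
    i = x * 2                        # tagged as an inference tensor

n.requires_grad_(True)               # allowed: n can rejoin autograd later
try:
    (i * x).sum().backward()         # i would have to be saved for backward
    inference_tensor_rejected = False
except RuntimeError:
    inference_tensor_rejected = True
```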
5.3 The two patterns you actually use
# Forward+backward+step (training)
optimizer.zero_grad(set_to_none=True) # set_to_none=True is faster: frees grad memory
loss = loss_fn(model(x), y)
loss.backward()
optimizer.step()
# Evaluation (no graph at all)
model.eval()
with torch.inference_mode():
    pred = model(x)
set_to_none=True (the default since 2.0) replaces grads with None instead of zeroing in place-first backward after each step re-allocates, but it skips a zeroing kernel and reduces optimizer memory pressure. Always use it.
6. The nn.Module pattern
Every model is a tree of nn.Modules. The contract is:
- Subclass nn.Module.
- In __init__, call super().__init__(), then create child modules and parameters as attributes (self.linear = nn.Linear(...)). Module discovery is by attribute assignment.
- In forward, write the computation. Don't call forward directly; call the module instance, `model(x)`, which runs any registered hooks around forward.
class MLP(nn.Module):
    def __init__(self, d_in: int, d_hidden: int, d_out: int, p_drop: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_in, d_hidden)
        self.act = nn.GELU()
        self.drop = nn.Dropout(p_drop)
        self.fc2 = nn.Linear(d_hidden, d_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.drop(self.act(self.fc1(x))))

model = MLP(64, 256, 10)
y = model(torch.randn(8, 64)) # (8, 10)
6.1 Parameter discovery
for name, p in model.named_parameters():
    print(name, tuple(p.shape), p.requires_grad)
# fc1.weight (256, 64) True
# fc1.bias (256,) True
# fc2.weight (10, 256) True
# fc2.bias (10,) True
n_params = sum(p.numel() for p in model.parameters())
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
Parameter is a special Tensor subclass: assigning one as an attribute of a module registers it as a parameter. Plain tensors don't get registered. If you need a tensor that should move with .to(device) and be saved by state_dict() but is not trained, use register_buffer:
class CausalSelfAttention(nn.Module):
    def __init__(self, max_seq: int, d: int):
        super().__init__()
        self.qkv = nn.Linear(d, 3 * d)
        # Causal mask is a buffer: moves with .to(device), saved/loaded, but not trained
        mask = torch.tril(torch.ones(max_seq, max_seq, dtype=torch.bool))
        self.register_buffer("causal_mask", mask, persistent=True)
6.2 state_dict / load_state_dict
state_dict() returns an OrderedDict[str, Tensor] containing every parameter and persistent buffer, keyed by dotted path. load_state_dict() consumes the same.
sd = model.state_dict() # in-memory snapshot
torch.save(sd, "mlp.pt")
# Reload
model2 = MLP(64, 256, 10)
model2.load_state_dict(torch.load("mlp.pt", map_location="cpu", weights_only=True)) # tensors only: safe unpickling
# Strict vs lax loading
missing, unexpected = model2.load_state_dict(sd, strict=False)
map_location="cpu" on load is best practice: it deserializes onto CPU regardless of where the tensors lived when saved. Move to GPU after loading via .to(device). This avoids the "saved on GPU 7, loading machine has 4 GPUs" foot-gun.
6.3 Composition helpers
# Sequential: a hard-coded straight pipe
seq = nn.Sequential(
    nn.Linear(64, 256),
    nn.GELU(),
    nn.Linear(256, 10),
)

# ModuleList: like a Python list, but its modules are properly registered
class Stack(nn.Module):
    def __init__(self, n_layers, d):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(d, d) for _ in range(n_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = F.gelu(layer(x))
        return x

# ModuleDict: same idea, dict-shaped, useful for branched architectures
class MultiHead(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.heads = nn.ModuleDict({
            "classifier": nn.Linear(d, 10),
            "regressor": nn.Linear(d, 1),
        })

    def forward(self, x, head: str):
        return self.heads[head](x)
Pitfall. Storing modules in a plain list or dict (instead of ModuleList / ModuleDict) silently breaks parameter discovery-model.parameters() won't see them, the optimizer won't update them, state_dict() won't save them. Symptom: training "runs" but the model never improves.
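The failure is easy to reproduce by counting parameters. `BrokenStack`, `WorkingStack`, and `count_params` are hypothetical names used only for this sketch.

```python
import torch.nn as nn

class BrokenStack(nn.Module):
    def __init__(self, n_layers: int, d: int):
        super().__init__()
        # Plain list: the layers exist but are never registered
        self.layers = [nn.Linear(d, d) for _ in range(n_layers)]

class WorkingStack(nn.Module):
    def __init__(self, n_layers: int, d: int):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(d, d) for _ in range(n_layers)])

def count_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())

broken, working = BrokenStack(3, 8), WorkingStack(3, 8)
print(count_params(broken), count_params(working))
```

A quick parameter count right after building a model is a cheap guard against this class of bug.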
7. Common layers, in the parameterization you'll actually use
# Linear: y = xW^T + b, weight shape (out, in), bias shape (out,)
fc = nn.Linear(in_features=128, out_features=512, bias=True)
# Embedding: lookup table, weight shape (num_embeddings, embedding_dim)
emb = nn.Embedding(num_embeddings=50_000, embedding_dim=768)
ids = torch.randint(0, 50_000, (4, 32)) # (B, S)
h = emb(ids) # (4, 32, 768)
# LayerNorm: normalizes over the last `normalized_shape` dims; learnable (gamma, beta)
ln = nn.LayerNorm(normalized_shape=768) # over last dim of size 768
# Dropout: zeros each elem with prob p AT TRAINING TIME ONLY (no-op in .eval())
drop = nn.Dropout(p=0.1)
# GELU: smooth ReLU; standard activation in transformers
act = nn.GELU()
# Multi-head self-attention. The `batch_first=True` form matches every modern
# codebase: input is (B, S, E). Without it the input is (S, B, E).
attn = nn.MultiheadAttention(embed_dim=768, num_heads=12, batch_first=True)
x = torch.randn(4, 32, 768)
out, weights = attn(query=x, key=x, value=x, need_weights=False) # (4, 32, 768)
LayerNorm parameter shape note. nn.LayerNorm(768) creates two learnables of shape (768,): gamma (weight) and beta (bias). The op normalizes the last dim and then applies gamma * x + beta elementwise-this is the standard transformer LayerNorm.
Dropout in eval. model.eval() flips a flag (self.training = False) that propagates to children. Dropout becomes identity; BatchNorm switches from batch stats to running stats. Forgetting .eval() for inference is a top-five real-world bug-your "validation accuracy" will be lower than the true number, and randomly different every run.
MultiheadAttention quirks. It bundles input and output projections, so embed_dim is the model dim, not per-head. num_heads must divide embed_dim. For causal language modeling, pass is_causal=True and a triangular attn_mask; PyTorch is conservative about applying causality without an explicit mask.
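A minimal sketch of the causal call follows, assuming a boolean mask in which True marks positions that may not be attended to. The check at the end perturbs only the last token and confirms that earlier outputs are unaffected; sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
S, E = 6, 16
attn = nn.MultiheadAttention(embed_dim=E, num_heads=4, batch_first=True)
x = torch.randn(2, S, E)

# True above the diagonal: position i may attend only to positions j <= i
causal = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1)

out, _ = attn(x, x, x, attn_mask=causal, is_causal=True, need_weights=False)

# Perturb only the last token; earlier outputs must not change under a causal mask
x2 = x.clone()
x2[:, -1] += 1.0
out2, _ = attn(x2, x2, x2, attn_mask=causal, is_causal=True, need_weights=False)
```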
8. Loss functions
# Multiclass classification (logits in, class indices out). DO NOT softmax first.
logits = torch.randn(8, 10) # (B, V)
targets = torch.randint(0, 10, (8,)) # (B,) int64
loss_fn = nn.CrossEntropyLoss(ignore_index=-100) # -100 == "this position is padding"
loss = loss_fn(logits, targets) # scalar
# Sequence prediction: flatten (B, S, V) -> (B*S, V), and (B, S) -> (B*S,)
B, S, V = 4, 32, 50_000
seq_logits = torch.randn(B, S, V)
seq_targets = torch.randint(0, V, (B, S))
loss = loss_fn(seq_logits.reshape(-1, V), seq_targets.reshape(-1))
# Regression
mse = nn.MSELoss()
mse(torch.randn(8, 1), torch.randn(8, 1))
# Binary classification-BCE-with-logits is numerically stable; never use BCE+sigmoid
bce = nn.BCEWithLogitsLoss()
bce(torch.randn(8), torch.empty(8).uniform_().round())
Numerical stability: why "with logits" matters. BCEWithLogitsLoss and CrossEntropyLoss use the log-sum-exp trick internally:
$$\log \sum_i e^{x_i} = m + \log \sum_i e^{x_i - m}, \qquad m = \max_i x_i$$
Subtracting the maximum keeps the largest exponent at zero, so exp never overflows for large positive logits and the sum inside the log is never zero. If you instead apply F.softmax, take the log, and call F.nll_loss, or apply sigmoid and then binary_cross_entropy, you compute the probability first and take its log afterwards; a probability that rounds to 0 becomes log(0) = -inf, and on BF16 / FP16 you'll see infs and NaNs in production. Always use the fused "with logits" loss.
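The effect of the max-subtraction is visible with a single large logit. This sketch compares a naive log-sum-exp, the stabilized form, and PyTorch's built-in torch.logsumexp; the input values are chosen only to force float32 overflow.

```python
import torch

x = torch.tensor([1000.0, 999.0, 0.0])

naive = torch.log(torch.exp(x).sum())           # exp(1000) overflows float32
m = x.max()
stable = m + torch.log(torch.exp(x - m).sum())  # largest exponent is exp(0) = 1
builtin = torch.logsumexp(x, dim=0)
print(naive.item(), stable.item(), builtin.item())
```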
ignore_index=-100 is PyTorch's default for CrossEntropyLoss and the label convention used across the Hugging Face ecosystem: any position labeled `-100` is excluded from the loss (denominator and numerator). Use it for padding and for "instruction tokens we don't want to teach on" in supervised fine-tuning.
9. Optimizers
# SGD with momentum-still the right answer for many vision tasks
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4, nesterov=True)
# Adam-adaptive, used to be default
opt = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-8)
# AdamW-Adam with **decoupled** weight decay; the modern transformer default
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1)
Adam vs AdamW is not cosmetic. In Adam, weight decay is folded into the gradient and then divided by the adaptive denominator, so the effective decay on each weight depends on that weight's gradient history rather than on the coefficient you set. AdamW applies weight decay as a separate, decoupled param -= lr * wd * param step. For transformers this is consistently better at the same hyperparameters; it is the default.
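A single optimizer step with a zero loss gradient isolates the weight-decay path. In this sketch (`one_step` is a hypothetical helper, and the hyperparameters are chosen for readable numbers) only decay can move the weight.

```python
import torch

def one_step(opt_cls) -> float:
    p = torch.nn.Parameter(torch.tensor([1.0]))
    opt = opt_cls([p], lr=0.1, weight_decay=0.1)
    p.grad = torch.zeros_like(p)      # zero loss gradient: only weight decay acts
    opt.step()
    return p.item()

after_adam = one_step(torch.optim.Adam)    # decay enters the gradient and is normalized
after_adamw = one_step(torch.optim.AdamW)  # decay is a plain multiplicative shrink
print(after_adam, after_adamw)
```

Adam moves the weight by roughly lr (to about 0.9) because the decay term is rescaled by the adaptive denominator, while AdamW shrinks it by lr * wd (to 0.99).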
9.1 Parameter groups
You almost always want two groups: weights with decay, biases / norms / embeddings without.
def build_param_groups(model, weight_decay=0.1):
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # Heuristic: 1-D params (biases, LayerNorm gamma/beta) -> no decay
        if p.ndim < 2 or name.endswith(".bias") or "norm" in name.lower() or "embed" in name.lower():
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

opt = torch.optim.AdamW(build_param_groups(model), lr=3e-4, betas=(0.9, 0.95))
This split is the common recipe in GPT-style transformer training scripts. Without it you regularize biases and LayerNorm gains toward zero, which is at best wasteful and at worst destabilizing.
9.2 The closure pattern
A few optimizers (LBFGS) take a closure callable that recomputes the loss. You can ignore this for SGD/Adam/AdamW.
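For completeness, here is what the closure pattern looks like on a toy quadratic. The problem, the variable names, and the choice of line_search_fn="strong_wolfe" are assumptions made for this sketch.

```python
import torch

target = torch.tensor([1.0, -2.0])
w = torch.zeros(2, requires_grad=True)
opt = torch.optim.LBFGS([w], lr=1.0, max_iter=20, line_search_fn="strong_wolfe")

def closure():
    # LBFGS may call this several times per step to re-evaluate the loss
    opt.zero_grad()
    loss = ((w - target) ** 2).sum()
    loss.backward()
    return loss

for _ in range(5):
    opt.step(closure)      # the optimizer drives the forward/backward itself
```

The closure must zero the gradients, recompute the loss, call backward, and return the loss.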
10. Learning-rate schedulers
# Linear warmup over the first 1000 steps
warmup = torch.optim.lr_scheduler.LinearLR(opt, start_factor=1e-6, end_factor=1.0, total_iters=1000)
# Cosine decay from peak to zero over total_steps - warmup_steps
total_steps = 100_000
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps - 1000)
# Compose: warmup first, then cosine
sched = torch.optim.lr_scheduler.SequentialLR(
    opt, schedulers=[warmup, cosine], milestones=[1000]
)
# OneCycle: "warm up then decay" packaged into one scheduler; popular for training from scratch
sched_oc = torch.optim.lr_scheduler.OneCycleLR(
    opt, max_lr=3e-3, total_steps=total_steps, pct_start=0.1
)
The standard transformer recipe is linear warmup → cosine decay: ramp from ~0 to peak LR over the first 1–10% of steps, then cosine-anneal to zero (or to peak/10) over the rest. The warmup tames the early "Adam wobble" where second-moment estimates are noisy and gradients are large.
Step the scheduler exactly once per optimizer step, after opt.step(), never before it.
For epoch-based schedulers (the older default), you'd step once per epoch instead. Modern recipes are step-based.
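The step ordering and the resulting warmup-then-cosine shape can be seen on a toy schedule. This sketch uses a 5-step warmup and a 10-step cosine phase with SGD at lr=1.0 purely so the numbers are easy to read; the zero gradient stands in for a real backward pass.

```python
import torch

p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([p], lr=1.0)
warmup = torch.optim.lr_scheduler.LinearLR(opt, start_factor=0.1, end_factor=1.0, total_iters=5)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)
sched = torch.optim.lr_scheduler.SequentialLR(opt, schedulers=[warmup, cosine], milestones=[5])

lrs = []
for step in range(15):
    p.grad = torch.zeros_like(p)   # stand-in for loss.backward()
    opt.step()                     # 1) update parameters with the current LR
    sched.step()                   # 2) then advance the schedule
    lrs.append(opt.param_groups[0]["lr"])

print([round(v, 3) for v in lrs])
```

The recorded learning rates rise during warmup, peak at the base LR at the milestone, and then decay toward zero.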
11. Dataset and DataLoader
Dataset is "give me item i"; DataLoader is "stream batches with workers." Together they decouple data from training.
11.1 Subclassing Dataset
class JsonlDataset(Dataset):
    """Reads a JSONL file lazily; each line is one example."""

    def __init__(self, path: str):
        self.path = Path(path)
        # Pre-index byte offsets for O(1) random access without loading file
        self.offsets = []
        with open(self.path, "rb") as f:
            offset = 0
            for line in f:
                self.offsets.append(offset)
                offset += len(line)

    def __len__(self) -> int:
        return len(self.offsets)

    def __getitem__(self, idx: int) -> dict:
        with open(self.path, "rb") as f:
            f.seek(self.offsets[idx])
            line = f.readline()
        return json.loads(line)
This pattern-pre-index, lazy-read-handles JSONLs of any size without memory pressure. For tiny datasets just load into memory and index a list.
11.2 DataLoader, the fast data path
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4, # parallel data-loading processes
    pin_memory=True, # page-locked host memory -> async H2D copy
    prefetch_factor=2, # each worker pre-fetches 2 batches in advance
    persistent_workers=True, # don't kill workers between epochs
    drop_last=True, # drop the last partial batch (training only)
    collate_fn=None, # custom batcher if items are heterogeneous
)
Each knob earns its keep:
- num_workers > 0 starts worker processes that call __getitem__ in parallel. Set to 2–8; too high and the IPC overhead dominates.
- pin_memory=True allocates batches in non-pageable host memory, which lets tensor.to(device, non_blocking=True) overlap with computation. Always on if you train on GPU.
- prefetch_factor is per-worker; total queued batches is num_workers * prefetch_factor.
- persistent_workers=True avoids re-spawning workers each epoch, which saves seconds on small epochs and hours over a long run.
- drop_last=True for training (a partial batch skews BatchNorm and other per-batch statistics). For eval, drop_last=False and account for the smaller last batch when averaging.
- collate_fn turns a list of __getitem__ results into a batch tensor. The default is torch.utils.data.default_collate, which works for tensors and dicts of tensors. Override it for variable-length sequences (padding) or images of mixed sizes.
# Custom collate for variable-length token sequences
def pad_collate(batch: list[dict], pad_id: int = 0) -> dict:
    max_len = max(len(item["input_ids"]) for item in batch)
    input_ids, labels, attn_mask = [], [], []
    for item in batch:
        ids = item["input_ids"]
        n = len(ids)
        pad = max_len - n
        input_ids.append(ids + [pad_id] * pad)
        labels.append(item["labels"] + [-100] * pad) # -100 = ignore in CE
        attn_mask.append([1] * n + [0] * pad)
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "labels": torch.tensor(labels, dtype=torch.long),
        "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
    }
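An alternative sketch of the same padding logic uses torch.nn.utils.rnn.pad_sequence instead of manual list arithmetic. `pad_collate_tensors` is a hypothetical name, and the dict keys are assumed to match the items a dataset would yield.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate_tensors(batch: list[dict], pad_id: int = 0) -> dict:
    """Hypothetical variant of a padding collate built on pad_sequence."""
    ids = [torch.as_tensor(item["input_ids"], dtype=torch.long) for item in batch]
    labels = [torch.as_tensor(item["labels"], dtype=torch.long) for item in batch]
    input_ids = pad_sequence(ids, batch_first=True, padding_value=pad_id)
    # -100 matches the ignore_index used by the loss
    padded_labels = pad_sequence(labels, batch_first=True, padding_value=-100)
    lengths = torch.tensor([len(t) for t in ids])
    # Build the mask from lengths, not from `input_ids != pad_id`,
    # so a real token that equals pad_id is not masked by mistake
    attention_mask = (torch.arange(input_ids.size(1))[None, :] < lengths[:, None]).long()
    return {"input_ids": input_ids, "labels": padded_labels, "attention_mask": attention_mask}

batch = pad_collate_tensors([
    {"input_ids": [5, 6, 7], "labels": [6, 7, 8]},
    {"input_ids": [9], "labels": [10]},
])
```

Either version is passed to the loader the same way, for example `DataLoader(dataset, batch_size=32, collate_fn=pad_collate_tensors)`.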
12. The honest training loop
Below is a complete, copy-pasteable, production-shaped training loop. Every line earns its place; nothing is illustrative-only. Annotate this in your head until you can write it from blank.
import torch, math, time
from pathlib import Path

def train(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    *,
    device: str = "cuda",
    epochs: int = 10,
    lr: float = 3e-4,
    weight_decay: float = 0.1,
    grad_clip: float = 1.0,
    warmup_steps: int = 1000,
    total_steps: Optional[int] = None,
    use_amp: bool = True,
    amp_dtype: torch.dtype = torch.bfloat16,
    ckpt_dir: str = "./checkpoints",
    patience: int = 3,
    log_every: int = 50,
):
    Path(ckpt_dir).mkdir(parents=True, exist_ok=True)
    model.to(device) # §15: move once
    opt = torch.optim.AdamW(build_param_groups(model, weight_decay), lr=lr, betas=(0.9, 0.95))
    if total_steps is None:
        total_steps = len(train_loader) * epochs
    warmup = torch.optim.lr_scheduler.LinearLR(opt, 1e-6, 1.0, total_iters=warmup_steps)
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1, total_steps - warmup_steps))
    sched = torch.optim.lr_scheduler.SequentialLR(opt, [warmup, cosine], milestones=[warmup_steps])
    # Mixed precision: BF16 needs no scaler; FP16 needs GradScaler
    use_scaler = use_amp and amp_dtype == torch.float16 and device.startswith("cuda")
    # torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler
    scaler = torch.amp.GradScaler("cuda", enabled=use_scaler)
    best_val = float("inf")
    epochs_since_best = 0
    global_step = 0
    for epoch in range(epochs):
        model.train() # §21: dropout/BN ON
        t0 = time.time()
        for batch_idx, batch in enumerate(train_loader):
            # ---- 1. Move to device with non_blocking for overlap with compute
            inputs = batch["input_ids"].to(device, non_blocking=True)
            labels = batch["labels"].to(device, non_blocking=True)
            # ---- 2. Forward in autocast region
            with torch.autocast(device_type=device.split(":")[0],
                                dtype=amp_dtype, enabled=use_amp):
                logits = model(inputs) # (B, S, V) say
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    labels.reshape(-1),
                    ignore_index=-100,
                )
            # ---- 3. Backward (scaled if FP16)
            opt.zero_grad(set_to_none=True) # §5: cheap reset
            if use_scaler:
                scaler.scale(loss).backward()
                scaler.unscale_(opt) # unscale before clip
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                scaler.step(opt)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                opt.step()
            sched.step()
            global_step += 1
            if batch_idx % log_every == 0:
                lr_now = opt.param_groups[0]["lr"]
                print(f"epoch {epoch} step {global_step} loss {loss.item():.4f} lr {lr_now:.2e}")
        # ---- 4. Validation
        val_loss = evaluate(model, val_loader, device, amp_dtype, use_amp)
        print(f"epoch {epoch} val_loss {val_loss:.4f} time {time.time()-t0:.1f}s")
        # ---- 5. Checkpoint best, early-stop on patience
        if val_loss < best_val:
            best_val = val_loss
            epochs_since_best = 0
            save_checkpoint(model, opt, sched, scaler, epoch, global_step,
                            best_val, Path(ckpt_dir) / "best.pt")
        else:
            epochs_since_best += 1
            if epochs_since_best >= patience:
                print(f"early stop at epoch {epoch}")
                break
        save_checkpoint(model, opt, sched, scaler, epoch, global_step,
                        best_val, Path(ckpt_dir) / "last.pt")

@torch.inference_mode()
def evaluate(model, loader, device, amp_dtype, use_amp):
    model.eval()
    total, n = 0.0, 0
    for batch in loader:
        inputs = batch["input_ids"].to(device, non_blocking=True)
        labels = batch["labels"].to(device, non_blocking=True)
        with torch.autocast(device_type=device.split(":")[0], dtype=amp_dtype, enabled=use_amp):
            logits = model(inputs)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                                   labels.reshape(-1), ignore_index=-100)
        # Weight each batch's mean loss by its number of sequences
        total += loss.item() * inputs.size(0)
        n += inputs.size(0)
    return total / max(n, 1)
Treat the features below as things you deliberately leave out when you write a smaller loop, not as "advanced" extras; none of them is optional in a real training run:
- AdamW with parameter groups (§9.1)
- Linear warmup → cosine decay (§10)
- Mixed precision (§13)
- Gradient clipping by global norm (caps gradient explosions)
- set_to_none=True zeroing
- Checkpoint best + last
- Validation in inference_mode with model.eval()
- Early stopping with patience
- non_blocking=True H2D copies (works with a pin_memory=True loader)
13. Mixed precision
Training in FP32 wastes memory and tensor-core throughput. The two production options:
13.1 BF16 (modern default)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    logits = model(inputs)
    loss = F.cross_entropy(logits, labels)
loss.backward()
optimizer.step()
That's it. No scaler. Inside the context, eligible ops (matmul, conv, attention, …) run in BF16; sensitive ops (reductions, softmax) stay in FP32. The parameters themselves remain FP32 (autocast casts per op, not the stored weights), so optimizer updates stay stable. BF16 has the same 8-bit exponent as FP32, so gradient magnitudes don't underflow.
13.2 FP16 (legacy / Volta / Turing)¶
scaler = torch.amp.GradScaler("cuda")  # torch.cuda.amp.GradScaler() is deprecated in 2.4+
with torch.autocast(device_type="cuda", dtype=torch.float16):
    logits = model(inputs)
    loss = F.cross_entropy(logits, labels)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)   # unscale first so clipping sees true gradient norms
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)       # skips the update if grads contain inf/NaN
scaler.update()
GradScaler multiplies the loss by a large factor before backward (so tiny gradients don't underflow to zero in FP16's narrow dynamic range), then unscales before the optimizer step. If a NaN/Inf is detected, it skips the step and halves the scale; otherwise it slowly grows the scale.
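To make the underflow argument concrete, here is a minimal CPU-only sketch that pushes one gradient-sized value through FP16 with and without scaling. The value 1e-8 and the factor 65536 (GradScaler's documented default initial scale) are illustrative choices, not measurements from a real run.

```python
import torch

tiny_grad = torch.tensor(1e-8)   # illustrative gradient magnitude (assumption)
scale = 65536.0                  # 2**16

# Cast straight to FP16: below the smallest FP16 subnormal (~6e-8), so it becomes 0.
direct_fp16 = tiny_grad.to(torch.float16)

# Scale, cast to FP16, then unscale in FP32: the order GradScaler arranges.
scaled_fp16 = (tiny_grad * scale).to(torch.float16)
recovered = scaled_fp16.to(torch.float32) / scale

# BF16 shares FP32's exponent range, so the same value survives without scaling.
direct_bf16 = tiny_grad.to(torch.bfloat16)

print(direct_fp16.item(), recovered.item(), direct_bf16.item())
```

The first number prints as zero, the other two stay close to 1e-8, which is the whole case for "FP16 needs a scaler, BF16 does not."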
Decision rule. Ampere+ (RTX 30/40, A100, H100): use BF16. Volta/Turing (V100, T4, RTX 20): use FP16 + GradScaler. AMD MI200+: use BF16. CPU autocast exists but is rarely worth it.
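The decision rule can be encoded as a small helper. This is a sketch under stated assumptions: `pick_amp_dtype` is a hypothetical name, and the check uses NVIDIA compute capability (major version 8 or higher means Ampere or newer), so it does not cover the AMD case mentioned above.

```python
import torch

def pick_amp_dtype():
    """Hypothetical helper: return (amp_dtype, needs_grad_scaler)."""
    if not torch.cuda.is_available():
        # CPU autocast uses BF16 and never needs a scaler.
        return torch.bfloat16, False
    major, _minor = torch.cuda.get_device_capability()
    if major >= 8:
        # Ampere or newer: BF16, no scaler.
        return torch.bfloat16, False
    # Volta / Turing: FP16 plus GradScaler.
    return torch.float16, True

amp_dtype, needs_scaler = pick_amp_dtype()
print(amp_dtype, needs_scaler)
```

Returning the scaler flag alongside the dtype keeps the two decisions from drifting apart in the training script.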
Pitfall. Don't wrap the backward in autocast-backward inherits dtypes from the saved forward tensors. Just wrap the forward and the loss.
14. Checkpointing¶
A checkpoint that lets you resume bit-exactly must save more than weights:
def save_checkpoint(model, opt, sched, scaler, epoch, step, best_val, path):
    torch.save({
        "epoch": epoch,
        "step": step,
        "best_val": best_val,
        "model": model.state_dict(),
        "optimizer": opt.state_dict(),
        "scheduler": sched.state_dict(),
        "scaler": scaler.state_dict() if scaler is not None else None,
        "rng_torch_cpu": torch.get_rng_state(),
        "rng_torch_cuda": torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None,
        "rng_numpy": np.random.get_state(),
        "rng_python": random.getstate(),
    }, path)

def load_checkpoint(path, model, opt, sched, scaler, device):
    # weights_only=False because the NumPy / Python RNG states are not plain tensors.
    # It unpickles arbitrary objects, so only load checkpoints you wrote yourself.
    ck = torch.load(path, map_location="cpu", weights_only=False)
    model.load_state_dict(ck["model"])
    model.to(device)
    opt.load_state_dict(ck["optimizer"])
    sched.load_state_dict(ck["scheduler"])
    if scaler is not None and ck.get("scaler") is not None:
        scaler.load_state_dict(ck["scaler"])
    torch.set_rng_state(ck["rng_torch_cpu"])
    if torch.cuda.is_available() and ck.get("rng_torch_cuda") is not None:
        torch.cuda.set_rng_state_all(ck["rng_torch_cuda"])
    np.random.set_state(ck["rng_numpy"])
    random.setstate(ck["rng_python"])
    return ck["epoch"], ck["step"], ck["best_val"]
Why everything? Optimizer state (Adam moments): without it the moment estimates restart from zero and your loss curve has a visible jolt. Scheduler state: otherwise the LR schedule restarts from step 0 (with warmup + cosine, that means warming up all over again). Scaler state: preserves the loss-scale value. RNG states: without them, dropout masks and dataloader shuffles diverge and the loss curve does too. Epoch/step: for the LR scheduler and for bookkeeping.
After loading, the next batch's loss should match what the original run produced at that step (modulo CUDA non-determinism, see §19). If it doesn't, you missed something.
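The "next loss should match" check can be rehearsed on CPU before trusting it on a real run. The sketch below is a toy under stated assumptions: a tiny dropout MLP, an in-memory buffer instead of a file, and only the CPU RNG state; the helper names `make_model_and_opt` and `train_step` are hypothetical.

```python
import io
import torch
import torch.nn as nn
import torch.nn.functional as F

def make_model_and_opt():
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(0.5), nn.Linear(16, 2))
    opt = torch.optim.AdamW(model.parameters(), lr=1e-2)
    return model, opt

def train_step(model, opt, x, y):
    model.train()
    opt.zero_grad(set_to_none=True)
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    opt.step()
    return loss.item()

torch.manual_seed(0)
x, y = torch.randn(32, 8), torch.randint(0, 2, (32,))

# Train a little, then checkpoint weights + optimizer + RNG into memory.
model, opt = make_model_and_opt()
for _ in range(3):
    train_step(model, opt, x, y)
buf = io.BytesIO()
torch.save({"model": model.state_dict(),
            "optimizer": opt.state_dict(),
            "rng_torch_cpu": torch.get_rng_state()}, buf)

# Reference: what the uninterrupted run does next.
original = [train_step(model, opt, x, y) for _ in range(2)]

# Resume into fresh objects and replay the same two steps.
buf.seek(0)
ck = torch.load(buf, map_location="cpu", weights_only=False)  # trusted, self-written
model2, opt2 = make_model_and_opt()
model2.load_state_dict(ck["model"])
opt2.load_state_dict(ck["optimizer"])
torch.set_rng_state(ck["rng_torch_cpu"])   # restore last: building model2 consumed RNG
resumed = [train_step(model2, opt2, x, y) for _ in range(2)]

print(original, resumed)
```

The first resumed loss exercises the weights and the dropout RNG; the second also depends on the restored Adam moments, so dropping the optimizer entry from the checkpoint should make it diverge.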
Pitfall. torch.save pickles. Saving a model directly (instead of model.state_dict()) couples the checkpoint to the exact class layout and import path; refactor the class and the checkpoint stops loading. Always save state_dict().
15. Device transfer discipline¶
The single most common source of "PyTorch is slow" complaints from beginners.
The rule. Move to device exactly once per tensor's life: at the dataloader-to-train-step boundary for inputs/labels, and at model.to(device) for parameters. Never inside the model's forward.
# WRONG: uploads constants every step
class Bad(nn.Module):
    def forward(self, x):
        scale = torch.tensor(2.0).to(x.device)  # CPU->GPU every call
        return x * scale

# RIGHT: buffer registered once, moves with the module
class Good(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.tensor(2.0))

    def forward(self, x):
        return x * self.scale
pin_memory + non_blocking. Together they enable async H2D copies that overlap with the previous step's compute:
# DataLoader(..., pin_memory=True)
inputs = batch["x"].to(device, non_blocking=True) # async; returns immediately
# ... GPU is busy with prior step's backward; copy happens in parallel ...
out = model(inputs) # synchronizes when needed
The "where does this live?" question. Whenever a RuntimeError: Expected all tensors to be on the same device fires, the fix is always: print the devices.
# compare device *types*; `device` may be a str like "cuda:0" or a torch.device
print({n: p.device for n, p in model.named_parameters() if p.device.type != torch.device(device).type})
print(inputs.device, labels.device)
It's almost always a forgotten tensor literal in forward (a torch.zeros(...) you didn't register_buffer), or a tensor created in __getitem__ returning on CPU when you expected it elsewhere.
16. `torch.compile` - the user-level view¶
What it does (briefly; see 04_PYTORCH_INTERNALS.md for the dispatcher / Dynamo / Inductor pipeline): traces your forward into an FX graph, fuses ops, and generates one or a few CUDA kernels per "graph region." Typical speedups: 1.3–2.5× on transformer training, more on inference.
16.1 Modes¶
model = torch.compile(model, mode="default") # balanced
model = torch.compile(model, mode="reduce-overhead") # cudagraphs, lower kernel-launch overhead-best for inference / small batches
model = torch.compile(model, mode="max-autotune") # spends time at first call to autotune; best steady-state perf
16.2 Graph breaks-what to avoid¶
A "graph break" is when the tracer hits Python that it can't capture and falls back to eager. Each break costs you fusion across the break boundary. Common offenders:
- .item() / .cpu() / print in forward: pulls a value back to Python, forces synchronization.
- Data-dependent control flow on tensor values: if x.sum() > 0: .... Use torch.where or rewrite with masking.
- Custom Python objects with non-tensor attributes used in shape arithmetic.
- Calling .numpy() or third-party libs that aren't traceable.
Diagnose with `TORCH_LOGS="graph_breaks" python script.py`; the output tells you the file/line of each break.
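As an illustration of the torch.where rewrite suggested in the list above, the sketch below shows a data-dependent branch next to a branch-free equivalent. The function names are hypothetical and the example only checks eager equivalence; compiling is left as a commented-out line so the snippet runs without a compiler toolchain.

```python
import torch

def branchy(x: torch.Tensor) -> torch.Tensor:
    # Python `if` on a tensor value: the result of x.sum() must be read back
    # to the host to decide which branch to run.
    if x.sum() > 0:
        return x * 2
    return x - 1

def branchless(x: torch.Tensor) -> torch.Tensor:
    # Both candidates are computed; torch.where selects on-device.
    return torch.where(x.sum() > 0, x * 2, x - 1)

pos, neg = torch.ones(4), -torch.ones(4)
same_on_pos = torch.equal(branchy(pos), branchless(pos))
same_on_neg = torch.equal(branchy(neg), branchless(neg))
print(same_on_pos, same_on_neg)

# To make any remaining break a hard error instead of a silent fallback:
# compiled = torch.compile(branchless, fullgraph=True)
```

The branch-free version does strictly more arithmetic, which is usually a good trade when the branches are cheap elementwise ops.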
16.3 When to skip compile¶
- Tiny models (overhead exceeds gain).
- During first-pass debugging: compile errors are harder to read than eager stack traces.
- Highly dynamic shapes where recompilation thrashes (mitigate with dynamic=True).
The pragmatic flow: write eager, get correct, then model = torch.compile(model) and measure.
17. Distributed Data Parallel (user-level)¶
DDP runs one process per GPU, each with a full model replica. Each step: independent forward + backward; gradients are all-reduced across processes; each replica runs its own optimizer step on the (now identical) gradients. (For the math and bandwidth analysis, see AI_SYSTEMS_PLAN/DEEP_DIVES/06_DISTRIBUTED_TRAINING.md.)
The minimum viable DDP script:
# train_ddp.py
import os
import torch
import torch.nn.functional as F  # needed for F.cross_entropy below
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def main():
    # torchrun sets these env vars for us
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    model = MyModel().to(device)          # MyModel / MyDataset are placeholders
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    train_set = MyDataset(...)
    sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True)
    loader = DataLoader(train_set, batch_size=32, sampler=sampler,
                        num_workers=4, pin_memory=True, persistent_workers=True)
    opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

    for epoch in range(10):
        sampler.set_epoch(epoch)  # ensures different shuffling per epoch
        for batch in loader:
            inputs = batch["x"].to(device, non_blocking=True)
            labels = batch["y"].to(device, non_blocking=True)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                loss = F.cross_entropy(model(inputs), labels)
            opt.zero_grad(set_to_none=True)
            loss.backward()       # DDP all-reduces gradients during backward
            opt.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
Launch:
torchrun --nproc_per_node=4 --nnodes=1 train_ddp.py
# multi-node:
torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 \
--rdzv_backend=c10d --rdzv_endpoint=$MASTER_IP:29500 train_ddp.py
Five things that always matter in DDP:
- One process per GPU, set with torch.cuda.set_device(local_rank). Mixing local ranks and devices is the most common DDP bug.
- DistributedSampler must be on the train loader, and you must call sampler.set_epoch(epoch) each epoch; otherwise every epoch shuffles identically.
- Save checkpoints from rank 0 only (if rank == 0: torch.save(...)) and call dist.barrier() after to keep ranks in lockstep.
- Don't .to(device) after wrapping in DDP. Order is model.to(device) → DDP(model).
- Effective batch size = per_device_batch_size * world_size. Scale LR accordingly (linear scaling rule for SGD; for AdamW, sub-linear, often sqrt(world_size)).
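The batch-size and learning-rate arithmetic from the last item can be written down as a short sketch. `scale_lr` is a hypothetical helper, and both rules are starting-point heuristics to validate with a sweep, not guarantees.

```python
import math

def scale_lr(base_lr: float, world_size: int, rule: str = "linear") -> float:
    """Hypothetical helper: scale a single-GPU learning rate to world_size GPUs."""
    if rule == "linear":   # common heuristic for SGD
        return base_lr * world_size
    if rule == "sqrt":     # gentler heuristic often used with AdamW
        return base_lr * math.sqrt(world_size)
    raise ValueError(f"unknown rule: {rule}")

per_device_batch_size, world_size = 32, 4
effective_batch = per_device_batch_size * world_size    # 32 * 4 = 128
sgd_lr = scale_lr(0.1, world_size, "linear")            # 0.1 * 4
adamw_lr = scale_lr(3e-4, world_size, "sqrt")           # 3e-4 * 2
print(effective_batch, sgd_lr, adamw_lr)
```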
18. `torch.utils.checkpoint` - gradient checkpointing¶
Activations are the dominant memory cost in deep transformer training: every forward saves intermediate tensors needed for backward. Gradient checkpointing trades compute for memory: you don't save activations, you re-run forward during backward.
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    def forward(self, x):
        # ... attention + MLP ...
        return x

class GCStack(nn.Module):
    def __init__(self, blocks: list[nn.Module], use_checkpoint: bool = True):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # keep only the block input; recompute internals during backward
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x
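The trade. Memory: with the "checkpoint every block" pattern above you keep only each block's input plus the internals of the one block currently being recomputed, instead of every intermediate of all N blocks. The classic O(sqrt(N)) bound comes from a different layout: checkpointing in segments of about sqrt(N) layers. In practice the saving is typically 30–50% of activation memory. Compute: one extra forward pass during backward, so roughly +33% wall time per step. You buy memory; you pay time.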
When to use. When you can't fit the model + activations in GPU memory at your target batch size, and using a smaller batch or higher gradient accumulation isn't acceptable. Production LLM training almost always uses it for some layers.
use_reentrant=False is the modern (non-deprecated) implementation-use it.
Pitfall. The recomputed forward must reproduce the original one. checkpoint saves and restores RNG state by default (preserve_rng_state=True), at some extra overhead, so plain dropout is handled. Anything else that differs between the forward and the "re-forward" (for example, Python-side state mutated inside the block) is not protected.
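A quick way to convince yourself that checkpointing changes memory and time but not the math is to compare gradients with and without it. This is a small CPU sketch with made-up layer sizes; the `run` helper is hypothetical.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
blocks = nn.ModuleList([nn.Sequential(nn.Linear(16, 16), nn.GELU()) for _ in range(4)])
x = torch.randn(8, 16)

def run(use_checkpoint: bool):
    """Forward + backward through the stack; return a copy of every gradient."""
    for p in blocks.parameters():
        p.grad = None
    h = x
    for block in blocks:
        if use_checkpoint:
            h = checkpoint(block, h, use_reentrant=False)  # recomputed in backward
        else:
            h = block(h)
    h.sum().backward()
    return [p.grad.clone() for p in blocks.parameters()]

plain = run(use_checkpoint=False)
ckpt = run(use_checkpoint=True)
max_diff = max((a - b).abs().max().item() for a, b in zip(plain, ckpt))
print(max_diff)
```

If the difference is not at rounding-error level, something inside the block behaves differently on the second forward, which is exactly the pitfall described above.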
19. Reproducibility-what you can and can't guarantee¶
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # For CUBLAS determinism (must be set BEFORE first CUDA op)
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # disables autotuner
What you can guarantee, given the above:
- Bit-exact reruns on the same hardware, same PyTorch + CUDA version, same number of workers, same GPU count, same input order.
What you cannot guarantee:
- Bit-exactness across different GPU models (Ampere ≠ Hopper). Different SM counts → different reduction trees → different rounding.
- Bit-exactness across different worker counts in the DataLoader when __getitem__ uses randomness: each worker has its own RNG stream, so random augmentations change even though the sampler's index order does not.
- Bit-exactness with cudnn.benchmark=True: it picks the fastest kernel per shape, which can vary across runs.
- Bit-exactness across PyTorch versions: kernel implementations change.
In practice you set the seeds, accept "matches loss curve to 3 decimals" as reproducible, and only chase bit-exact when debugging. torch.use_deterministic_algorithms(True) will raise if you call an op without a deterministic implementation; warn_only=True softens this to a warning.
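Seeding the process does not by itself pin down the DataLoader. The sketch below follows the pattern from PyTorch's reproducibility notes: a dedicated torch.Generator for shuffling plus a worker_init_fn that reseeds NumPy and random inside each worker. The toy dataset and the `make_loader` helper are assumptions for illustration; num_workers=0 keeps the demo portable, in which case worker_init_fn is simply never called.

```python
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id: int) -> None:
    # Inside a worker, torch.initial_seed() is already base_seed + worker_id.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

def make_loader(seed: int) -> DataLoader:
    g = torch.Generator()
    g.manual_seed(seed)                      # controls the shuffle order
    ds = TensorDataset(torch.arange(16))     # toy data (assumption)
    return DataLoader(ds, batch_size=4, shuffle=True, num_workers=0,
                      worker_init_fn=seed_worker, generator=g)

order_a = [batch[0].tolist() for batch in make_loader(seed=0)]
order_b = [batch[0].tolist() for batch in make_loader(seed=0)]
print(order_a == order_b)
```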
20. Hugging Face `transformers` - the bridge¶
Most production models are loaded from the HF hub, not trained from scratch. The interface is small.
from transformers import AutoTokenizer, AutoModelForCausalLM

# Gated repo: accept the license on the Hub first. Any other hub id works too.
name = "meta-llama/Meta-Llama-3-8B"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(
    name,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # spreads across visible GPUs
)
model.eval()

prompt = "Write a haiku about Adam vs SGD:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.inference_mode():
    out = model.generate(
        **inputs,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.05,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(out[0], skip_special_tokens=True))
Things to know:
- AutoModel → bare backbone that returns hidden states, with no task head; AutoModelForCausalLM → adds an LM head; AutoModelForSequenceClassification → adds a classification head. Pick the one matching your task.
- from_pretrained is the integration point: it downloads, instantiates the right class, loads weights, and respects torch_dtype and device_map. The returned object is a plain nn.Module subclass, so every PyTorch idiom in this chapter applies.
- model.generate is a sampling loop wrapper; do_sample=False gives greedy decoding, num_beams=N gives beam search. For production serving, use vLLM or TGI rather than `generate`; the HF generate is fine for prototypes and evaluation.
- Tokenizers return input_ids (token ids) and attention_mask (1 for real, 0 for pad). Always pass both to the model.
- For training: Trainer is the high-level API; under the hood it's the same nn.Module + AdamW + AMP loop we wrote in §12. When Trainer doesn't fit, drop to a custom loop: the model is just an nn.Module.
# Custom training, treating the HF model as a plain module
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
for batch in loader:
    out = model(input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                labels=batch["labels"].to(device))
    out.loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
When labels is passed, HF causal-LM models internally shift and compute cross-entropy; the returned out.loss is a scalar tensor ready for backward.
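For intuition, the shift-and-cross-entropy step can be sketched in plain PyTorch. This is an approximation of the behaviour described above, not a copy of the library's code (details may differ across transformers versions); `causal_lm_loss` is a hypothetical name and the tensors are random stand-ins.

```python
import torch
import torch.nn.functional as F

def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Position t predicts token t+1; labels equal to -100 are ignored."""
    shift_logits = logits[:, :-1, :].contiguous()   # drop the last position
    shift_labels = labels[:, 1:].contiguous()       # drop the first token
    return F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)),
                           shift_labels.view(-1), ignore_index=-100)

torch.manual_seed(0)
B, S, V = 2, 5, 11                      # batch, sequence length, vocab (toy sizes)
logits = torch.randn(B, S, V)
labels = torch.randint(0, V, (B, S))
labels[0, -2:] = -100                   # pretend the tail of sample 0 is padding
loss = causal_lm_loss(logits, labels)
print(loss.shape, loss.item())
```

This is why you pass labels identical to input_ids (with pads set to -100) and never shift them yourself.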
21. Common pitfalls-the bug bestiary¶
A consolidated list of the bugs that cost real teams real days. Recognize them on sight.
1. Forgetting optimizer.zero_grad(). Gradients accumulate from prior step. Symptom: loss explodes by step 2.
2. Mixing CPU and CUDA tensors. Usually a tensor literal in forward. Fix with register_buffer or .to(x.device).
3. Forgetting model.eval() for inference. Dropout still drops; BatchNorm still updates running stats. Validation accuracy is randomly worse than reality.
4. Forgetting model.train() after eval. Symmetric of #3: your training "stops working" mid-run because you eval'd and never flipped back.
5. Stale requires_grad after a copy. t = old_param.detach().clone() produces a tensor with requires_grad=False. If you wanted a learnable parameter, wrap in nn.Parameter.
6. Non-contiguous tensors causing slowness. Persistent transpose without .contiguous(). Symptom: a particular layer is 5× slower than it should be.
7. Using view after transpose without .contiguous(). Crashes with "view size is not compatible." Either call .contiguous() or use reshape.
8. .item() in a hot path. Forces a CUDA sync: the GPU has to finish all queued work for you to read one number. Logging loss.item() once per N steps is fine; doing it every step throttles training measurably.
9. In-place op on a leaf tensor that requires grad. param.data.add_(...) is fine; param.add_(...) outside no_grad raises. Wrap parameter mutations in with torch.no_grad():.
10. CrossEntropyLoss with float labels (when you wanted hard labels). Crashes about the dtype. Cast: labels = labels.long().
11. Softmax → CE. Numerically unstable, and CE applies log-softmax itself. Use nn.CrossEntropyLoss directly on logits.
12. Saving full model instead of state_dict(). Refactor breaks the checkpoint.
13. Forgetting sampler.set_epoch(epoch) in DDP. Same shuffle every epoch → silently degraded training.
14. num_workers=0 in production. Single-threaded data loading; GPU starves.
15. Wrong dtype on indices. torch.gather indices and class-index targets for cross-entropy need int64. int32 works on some ops (nn.Embedding accepts it), fails on others, with bad error messages.
16. with torch.no_grad(): around the forward you later backprop through. No graph was recorded → "element 0 of tensors does not require grad."
17. .to(device) after wrapping in DDP instead of before. DDP wraps a CPU model, then ranks all see CPU params.
18. nn.MultiheadAttention without batch_first=True. Default is (S, B, E). If your (B, S, E) input "works" without it, the module is silently treating your batch axis as the sequence axis. The loss looks reasonable for a while because the shapes still line up, then you debug for a day.
19. Validation loss computed under train() mode. Same as #3.
20. Loss has a non-tensor scalar in it. loss = F.cross_entropy(logits, labels) + 0.01 is fine; loss = F.cross_entropy(...) + my_python_var where my_python_var is a numpy scalar can break autograd on some versions. Keep losses as tensors throughout.
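Two of the pitfalls above (softmax followed by a log-based loss, and view after transpose) reproduce in a few CPU lines. The logit values are chosen to be extreme on purpose; treat this as a demonstration sketch rather than a test suite.

```python
import torch
import torch.nn.functional as F

# Softmax then log: the small probability underflows to exactly 0, log gives -inf.
logits = torch.tensor([[1000.0, 0.0]])
target = torch.tensor([1])
probs = torch.softmax(logits, dim=-1)
naive = F.nll_loss(torch.log(probs), target)   # inf
stable = F.cross_entropy(logits, target)       # finite: log-softmax works on logits

# view after transpose: the transposed tensor is not contiguous, so view fails.
t = torch.arange(6).reshape(2, 3).t()          # shape (3, 2), strides (1, 3)
try:
    t.view(6)
    view_failed = False
except RuntimeError:
    view_failed = True
flat = t.reshape(6)                             # reshape copies when it has to

print(naive.item(), stable.item(), view_failed, flat.tolist())
```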
22. Practical exercises (with answer code)¶
These are intentionally compressed: read the prompt, attempt mentally, then read the answer.
Exercise 1-Implement a tiny transformer block¶
Build a pre-LN transformer block with multi-head self-attention and an MLP, suitable for (B, S, D) input. No external libs.
class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int, p_drop: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=p_drop, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(d_ff, d_model),
            nn.Dropout(p_drop),
        )

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Pre-LN: normalize before each sublayer; residuals around each
        h = self.ln1(x)
        # Do not set is_causal just because a mask was passed: that hint is only
        # valid when attn_mask really is the causal mask (it may be a padding mask).
        attn_out, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.ln2(x))
        return x

# Sanity check:
blk = TransformerBlock(d_model=64, n_heads=4, d_ff=256)
x = torch.randn(2, 16, 64)
mask = torch.triu(torch.ones(16, 16), diagonal=1).bool()  # causal upper-triangular True = masked
y = blk(x, attn_mask=mask)
assert y.shape == (2, 16, 64)
Exercise 2-JSONL Dataset with tokenization¶
Each line is {"text": "..."}. Tokenize on the fly with a HF tokenizer; pad/truncate to a fixed length.
from transformers import AutoTokenizer

class TextJsonlDataset(Dataset):
    def __init__(self, path: str, tokenizer_name: str = "gpt2", max_len: int = 256):
        self.path = Path(path)
        self.tok = AutoTokenizer.from_pretrained(tokenizer_name)
        if self.tok.pad_token is None:
            self.tok.pad_token = self.tok.eos_token
        self.max_len = max_len
        # Index the byte offset of every line once, so __getitem__ can seek
        self.offsets = []
        with open(self.path, "rb") as f:
            offset = 0
            for line in f:
                self.offsets.append(offset)
                offset += len(line)

    def __len__(self):
        return len(self.offsets)

    def __getitem__(self, idx):
        with open(self.path, "rb") as f:
            f.seek(self.offsets[idx])
            obj = json.loads(f.readline())
        enc = self.tok(
            obj["text"],
            max_length=self.max_len,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].squeeze(0)  # (max_len,)
        attn = enc["attention_mask"].squeeze(0)
        labels = input_ids.clone()
        labels[attn == 0] = -100  # mask pads from loss
        return {"input_ids": input_ids, "attention_mask": attn, "labels": labels}
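The byte-offset index is the part of this dataset worth understanding in isolation, because it is what gives O(1) random access without loading the file. Below is a tokenizer-free sketch of just that mechanism on a temporary file; `build_offsets` and `read_record` are hypothetical helper names.

```python
import json
import tempfile
from pathlib import Path

def build_offsets(path: Path) -> list:
    """One pass over the file, recording the byte position where each line starts."""
    offsets, pos = [], 0
    with open(path, "rb") as f:          # binary mode: offsets are in bytes
        for line in f:
            offsets.append(pos)
            pos += len(line)
    return offsets

def read_record(path: Path, offsets: list, idx: int) -> dict:
    with open(path, "rb") as f:
        f.seek(offsets[idx])             # jump straight to the line
        return json.loads(f.readline())

rows = [{"text": "short"}, {"text": "a much longer line"}, {"text": "naïve ünïcode"}]
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "toy.jsonl"
    path.write_text("\n".join(json.dumps(r, ensure_ascii=False) for r in rows) + "\n",
                    encoding="utf-8")
    offsets = build_offsets(path)
    third = read_record(path, offsets, 2)
    first = read_record(path, offsets, 0)

print(offsets, third, first)
```

Opening the file in binary mode matters: multi-byte UTF-8 characters would make character counts and byte positions disagree.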
Exercise 3-Convert an FP32 training step to BF16 AMP¶
Given a vanilla loop, add autocast cleanly.
# Before
def step_fp32(model, batch, opt, device):
    x = batch["x"].to(device); y = batch["y"].to(device)
    opt.zero_grad(set_to_none=True)
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    opt.step()
    return loss.item()

# After
def step_bf16(model, batch, opt, device):
    x = batch["x"].to(device, non_blocking=True)
    y = batch["y"].to(device, non_blocking=True)
    opt.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = F.cross_entropy(model(x), y)   # forward + loss inside autocast
    loss.backward()                           # backward outside
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
    return loss.item()
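Note: no GradScaler for BF16. The parameters (and therefore the optimizer state) stay FP32; autocast only casts the inputs of eligible ops. The gradient clip is an extra safeguard, not part of the AMP conversion itself.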
Exercise 4-Debug "expected CUDA tensor got CPU"¶
Given:
class Buggy(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.fc = nn.Linear(d, d)
        # Not a buffer: stays on CPU forever. 1-D on purpose: a 0-dim CPU tensor
        # is treated like a Python scalar and would NOT trigger the error.
        self.scale = torch.full((d,), 0.1)

    def forward(self, x):
        return self.fc(x) * self.scale

model = Buggy(64).cuda()
model(torch.randn(2, 64, device="cuda"))  # RuntimeError: Expected all tensors to be on the same device ...
Fix. Register the constant as a buffer so .to(device) moves it:
class Fixed(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.fc = nn.Linear(d, d)
        self.register_buffer("scale", torch.full((d,), 0.1))

    def forward(self, x):
        return self.fc(x) * self.scale
Diagnostic recipe applicable to any "device mismatch" bug:
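for n, p in model.named_parameters():
    print(n, p.device)
for n, b in model.named_buffers():
    print(n, b.device)
Whichever printed CPU when everything else is CUDA is the culprit. If everything prints CUDA, the offender is a tensor stored as a plain attribute (like self.scale in Buggy) or one created inside forward: neither shows up in these two iterators.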
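Tensors stored as ordinary attributes are invisible to named_parameters() and named_buffers(), so a third scan helps. The sketch below walks each submodule's instance dictionary; `find_plain_tensor_attrs` and the `Leaky` module are hypothetical names made up for this illustration.

```python
import torch
import torch.nn as nn

def find_plain_tensor_attrs(model: nn.Module):
    """Hypothetical helper: list tensors that are neither parameters nor buffers."""
    found = []
    for mod_name, mod in model.named_modules():
        # Parameters and buffers live in mod._parameters / mod._buffers,
        # so any Tensor directly in vars(mod) is a plain attribute.
        for attr, value in vars(mod).items():
            if isinstance(value, torch.Tensor):
                found.append((f"{mod_name}.{attr}".lstrip("."), value.device))
    return found

class Leaky(nn.Module):
    def __init__(self, d: int):
        super().__init__()
        self.fc = nn.Linear(d, d)
        self.register_buffer("good_scale", torch.full((d,), 0.1))  # moves with .to()
        self.bad_scale = torch.full((d,), 0.1)                      # does not

hits = find_plain_tensor_attrs(Leaky(4))
print(hits)
```

Anything this reports is a candidate for register_buffer (or nn.Parameter, if it should be learned).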
Exercise 5-Add gradient checkpointing to a stack¶
Take a stack that OOMs at seq_len=8192 and make it fit at +33% wall time.
from torch.utils.checkpoint import checkpoint

class CkptStack(nn.Module):
    def __init__(self, blocks: list[nn.Module]):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        for b in self.blocks:
            if self.training:
                x = checkpoint(b, x, use_reentrant=False)
            else:
                x = b(x)
        return x
Notes: use_reentrant=False selects the modern, recommended implementation (see §18). We only checkpoint when training (no need to save activations for backward at eval). This roughly halves activation memory in a long-context transformer.
Exercise 6-Distributed training in 30 lines¶
Write the absolute minimum DDP training script.
# train_min_ddp.py - torchrun --nproc_per_node=2 train_min_ddp.py
import os, torch, torch.nn as nn, torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def main():
    rank = int(os.environ["RANK"]); local = int(os.environ["LOCAL_RANK"]); ws = int(os.environ["WORLD_SIZE"])
    dist.init_process_group("nccl", rank=rank, world_size=ws)
    torch.cuda.set_device(local); dev = torch.device("cuda", local)
    model = nn.Linear(64, 10).to(dev)
    model = DDP(model, device_ids=[local])
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    ds = TensorDataset(torch.randn(10_000, 64), torch.randint(0, 10, (10_000,)))
    sampler = DistributedSampler(ds, num_replicas=ws, rank=rank, shuffle=True)
    loader = DataLoader(ds, batch_size=64, sampler=sampler, num_workers=2, pin_memory=True)
    for epoch in range(3):
        sampler.set_epoch(epoch)
        for x, y in loader:
            x, y = x.to(dev, non_blocking=True), y.to(dev, non_blocking=True)
            with torch.autocast("cuda", dtype=torch.bfloat16):
                loss = F.cross_entropy(model(x), y)
            opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
        if rank == 0: print(f"epoch {epoch} loss {loss.item():.4f}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()
Run with torchrun --nproc_per_node=$N train_min_ddp.py. This is the smallest correct DDP training I can write-every removed line breaks something.
Cross-references¶
- PyTorch internals (dispatcher, autograd engine, torch.compile pipeline): AI_SYSTEMS_PLAN/DEEP_DIVES/04_PYTORCH_INTERNALS.md.
- Distributed training math (ring all-reduce, ZeRO, FSDP, parallelism strategies): AI_SYSTEMS_PLAN/DEEP_DIVES/06_DISTRIBUTED_TRAINING.md.
- CUDA / GPU programming model: see the GPU deep dive in the same directory.
Closing: the one-page mental model¶
If everything in this document collapsed into a single page, it would be this:
- Every tensor has (shape, dtype, device). Know all three at every line.
- nn.Module discovers parameters by attribute assignment. Use ModuleList / ModuleDict / register_buffer correctly.
- The training step is: zero grads → forward in autocast → backward → clip → optimizer step → scheduler step.
- Move data to device once per step at the boundary. pin_memory=True + non_blocking=True. Never .to(device) inside forward.
- BF16 is the default precision on modern GPUs. No GradScaler needed.
- Save state_dict() plus optimizer + scheduler + RNG. Load with map_location="cpu" then .to(device).
- model.train() for training, model.eval() + torch.inference_mode() for evaluation.
- AdamW with parameter groups (no decay on biases / LayerNorm / embeddings) + linear warmup + cosine decay is the universal recipe.
- DDP: one process per GPU, DistributedSampler + set_epoch, save from rank 0 only.
- Use nn.CrossEntropyLoss / nn.BCEWithLogitsLoss on raw logits, never softmax-then-CE.
Internalize those ten and ninety percent of the daily PyTorch code you'll ever write writes itself.