
AI Systems

GPU programming, framework internals, distributed training, inference.



AI Systems Engineering-A 24-Week Beginner-to-Advanced Mastery Roadmap

Authoring lens: Principal AI Systems Engineer / ML Performance Architect. Target outcome: A graduate of this curriculum is capable of (a) writing competitive GPU kernels in CUDA or Triton, (b) implementing distributed training (FSDP, tensor parallelism) on multi-node clusters, (c) building production-grade inference servers with paged KV-cache and continuous batching, and (d) operating GPU fleets at scale with cost, observability, and safety controls in place.

Crucially: this curriculum sits underneath ML research. It is not "learn ML in 24 weeks." It teaches the systems infrastructure that makes modern ML possible: the GPUs, kernels, frameworks, schedulers, and serving stacks. A graduate cannot necessarily train a frontier model from scratch, but they can make one run 3× faster and serve a million users without melting.


Why This Curriculum Exists

The companion curricula in this repository (RUST_TUTORIAL_PLAN, GO_LEARNIN_PLAN, LINUX, CONTAINER_INTERNALS_PLAN, KUBERNETES_PLAN) build the systems-engineering foundation. They are necessary but not sufficient for an AI-engineer career.

The gap they leave: the AI-systems-specific layer. GPU programming, accelerator runtimes, distributed training internals, inference serving, ML-on-Kubernetes patterns. This curriculum closes that gap, and is intentionally readable by a working backend/SRE engineer who has never written a CUDA kernel.

It is also designed to age gracefully. Specific tools (vLLM v0.x, PyTorch 2.x APIs) will shift in 2–4 years. The concepts-memory hierarchy on GPUs, attention math, parallelism patterns, scheduling theory for inference-are durable. Each module marks which is which, so refreshes target the ephemeral and not the spine.


Repository Layout

| File | Purpose |
| --- | --- |
| 00_PRELUDE_AND_PHILOSOPHY.md | The shape of "AI systems" as a discipline; cost model; reading list; what's durable vs ephemeral. |
| 01_MONTH_FOUNDATIONS.md | Weeks 1–4. Compute hierarchy, linear algebra, tensors, autograd, training loops. Beginner ramp. |
| 02_MONTH_GPU_PROGRAMMING.md | Weeks 5–8. GPU architecture, CUDA, memory optimization, Triton. |
| 03_MONTH_FRAMEWORK_INTERNALS.md | Weeks 9–12. PyTorch dispatcher, torch.compile, JAX/XLA, custom ops. |
| 04_MONTH_DISTRIBUTED_TRAINING.md | Weeks 13–16. NCCL, DDP/FSDP, tensor + pipeline parallelism, FP8. |
| 05_MONTH_INFERENCE_SYSTEMS.md | Weeks 17–20. KV-cache, paged attention, continuous batching, quantization, speculative decoding. |
| 06_MONTH_INFRASTRUCTURE_CAPSTONE.md | Weeks 21–24. ML-on-K8s, observability, safety/eval infra, capstone defense. |
| APPENDIX_A_HARDENING_AND_OBSERVABILITY.md | GPU profiling, cost dashboards, fleet ops, model monitoring. |
| APPENDIX_B_BUILD_FROM_SCRATCH.md | Reference implementations: attention, layer-norm, optimizer, dataloader, paged-cache. |
| APPENDIX_C_CONTRIBUTING.md | Contribution paths to PyTorch, JAX, Triton, vLLM, Hugging Face. |
| CAPSTONE_PROJECTS.md | Three tracks: mini-vLLM, FSDP-from-scratch, fused attention kernel. |
| DEEP_DIVES/ | Eleven self-contained reference chapters (~96K words total). Authored to let the reader master each topic without the underlying papers. See DEEP_DIVES/README.md for the index. |

The DEEP_DIVES/ Companion

The monthly modules are survey + lab. The deep dives are reference text. Eleven chapters totaling ~96,000 words, each authored to be a self-contained mastery resource for one major topic-derive everything, no external paper required:

  1. `01_GPU_ARCHITECTURE.md` - pair with Month 2 §5.
  2. `02_CUDA_PROGRAMMING.md` - pair with Month 2 §6–7.
  3. `03_TRITON.md` - pair with Month 2 §8.
  4. `04_PYTORCH_INTERNALS.md` - pair with Month 3 §9–10.
  5. `05_JAX_XLA.md` - pair with Month 3 §11.
  6. `06_DISTRIBUTED_TRAINING.md` - pair with Month 4 (all weeks).
  7. `07_ATTENTION_TRANSFORMER.md` - pair with Month 5 §17.
  8. `08_INFERENCE_SERVING.md` - pair with Month 5 §18.
  9. `09_QUANTIZATION.md` - pair with Month 5 §19.
  10. `10_SPECULATIVE_DISAGGREGATION.md` - pair with Month 5 §20.
  11. `11_NUMERICS_AND_MIXED_PRECISION.md` - orthogonal; reference everywhere, anchor in Month 4 §16.

Each chapter is layered: intuition → mechanism → math → numbers → diagrams → code → pitfalls → six worked exercises.


How Each Week Is Structured

  1. Conceptual Core-the why, with a mental model.
  2. Mechanical Detail-the how, down to source files in PyTorch/JAX/vLLM/Triton/CUTLASS where relevant.
  3. Lab-a hands-on exercise that cannot be completed without internalizing the concept.
  4. Idiomatic & Diagnostic Drill-nsys, ncu, torch.profiler, dcgm, plus one shape-of-good-code review.
  5. Production Slice-a cost, observability, or reliability micro-task that compounds into a publishable template.

Each week is sized for ~12–16 focused hours.


Progression Strategy

Foundations (beginner) ──► GPU Programming ──► Framework Internals
        │                        │                       │
        └────────────┬───────────┴───────────────────────┘
            Distributed Training
             Inference Systems
       Infrastructure & Capstone

The first month is the beginner ramp. From week 5 onward the difficulty climbs steeply. By week 12 you should be reading framework source comfortably; by week 16 distributed-training papers; by week 20 OSDI/SOSP-tier inference papers.


Prerequisites

Hard prerequisites (without these, the curriculum will not stick):

  • Python fluency: classes, decorators, context managers, async basics.
  • Linux fluency: command line, basic systemd, nvidia-smi, cgroup awareness (Linux curriculum weeks 1–10 minimum).
  • C familiarity: pointers, memory layout, make/cmake. You don't need to be an expert; you need to read it.
  • Linear algebra: matrix multiplication, dot product, gradient. The first week refreshes these to the level the curriculum needs.

Soft prerequisites (helpful, not required):

  • Container fluency (CONTAINER_INTERNALS_PLAN weeks 1–3).
  • Kubernetes basics (KUBERNETES_PLAN Month 1).
  • Working knowledge of one ML framework (PyTorch preferred). If you have none, plan to spend 2–3 extra weeks before week 1 doing fast.ai's Practical Deep Learning part 1.

Hardware:

  • Weeks 1–4: any laptop.
  • Weeks 5–8: access to at least one NVIDIA GPU (RTX 30/40 series fine; cloud T4/L4/A10G fine). RunPod, Lambda Labs, Vast.ai are budget-friendly.
  • Weeks 13–16: access to at least 2 GPUs on one node, and ideally 4–8 across two nodes for one week of distributed-training labs. Cloud-rented; ~$200–500 budget.
  • Weeks 17–20: a single A100 or H100 (cloud-rented, ~$2/hour) for two of the four labs. Smaller GPUs work for the others.
  • Weeks 21–24: depends on capstone choice.

If hardware budget is tight: do everything you can on Colab + a single rented L4 ($0.50–0.80/hour). The curriculum's designs are sized to fit.


Capstone Tracks (pick one in Month 6)

  1. Inference Engine Track-Build a mini-vLLM in Python+CUDA: paged KV-cache, continuous batching, FP8 weight quantization. Benchmark within 2× of production vLLM on a 7B model.
  2. Training Systems Track-Implement FSDP-style sharded data parallelism from scratch using NCCL collectives. Train a small transformer on 4–8 GPUs across 2 nodes; demonstrate scaling efficiency.
  3. GPU Kernel Track-Author a fused attention kernel in Triton (and optionally CUTLASS) competitive with FlashAttention-2 for one shape regime. Profile, document, contribute upstream.

Details in CAPSTONE_PROJECTS.md.


What This Curriculum Does Not Cover

To set expectations honestly:

  • Foundational ML theory-backprop derivation, optimization theory, generalization bounds. Use a separate course (Bishop, Goodfellow, fast.ai).
  • Model architecture research-designing new attention variants, scaling laws investigation. This is the research scientist track; this curriculum is the systems engineer track.
  • Computer vision / RL / classical ML pipelines-the focus is on the transformer-LLM stack because that's where the systems-engineering pressure is in 2026 and likely 2030. CV and RL infra share many primitives; the deltas are well-trodden.
  • Prompt engineering, agents, RAG architecture-application-layer concerns. Important, well-covered elsewhere.
  • AI ethics, governance, policy-flagged in week 23 but not deep-dived.

A Note on the Field's Velocity

This is the fastest-moving area in software in 2026. The curriculum copes by:

  1. Anchoring each week to at least one peer-reviewed paper that is unlikely to be invalidated.
  2. Distinguishing algorithmic content (durable: attention math, ZeRO, paged attention) from API content (ephemeral: vLLM v0.x flags, PyTorch 2.x compile modes).
  3. Pointing at source repositories that are likely to remain canonical (PyTorch core, JAX, Triton, vLLM, FlashAttention).

When the curriculum says "in 2026 this is true," it is dated for a reason. Re-evaluate yearly.

Prelude-The Shape of the Discipline

Sit with this document for an evening before week 1. It is the only place in the curriculum where we step back from mechanics and define what "AI systems engineering" actually is.


1. Two Disciplines, One Field

Modern AI is built by two cooperating disciplines that often share a name:

|  | ML Research | AI Systems Engineering |
| --- | --- | --- |
| Question | What should we compute? | How do we compute it efficiently? |
| Output | New architectures, training recipes, benchmark wins. | Faster kernels, larger models possible, lower-cost inference, higher reliability. |
| Optimizes | Loss curves. | Tokens/sec/dollar. |
| Reads | NeurIPS, ICML, ICLR. | OSDI, SOSP, MLSys, ASPLOS, ISCA. |
| Writes in | Python + PyTorch high-level. | CUDA, Triton, C++, Rust, the framework's internals. |

This curriculum trains the second discipline. You finish able to take any model the research team hands you and: make it train on your hardware, make it run in production, make it cheap, make it observable.

The economic pressure in 2026 is overwhelmingly on the second discipline. Every frontier-lab paper costs millions in compute; every percent of efficiency saves real money; every inference architecture decision compounds across millions of users. The half-life of a research idea is a year; the half-life of the systems infrastructure that serves it is a decade.


2. The Five-Axis Cost Model

A working AI systems engineer reasons along five axes simultaneously:

| Axis | Question to ask |
| --- | --- |
| FLOPs | How many floating-point ops does this op cost? Is it compute-bound? |
| Bytes | How much data moves between memory tiers (HBM ↔ SRAM, host ↔ device, node ↔ node)? Is it memory-bound? |
| Arithmetic intensity | FLOPs / Bytes. The ratio that determines whether tensor cores or HBM bandwidth is the limit. |
| Parallelism | Across what axis are we splitting work-batch, sequence, head, layer? What synchronization cost? |
| Failure | What happens on OOM, NaN, NCCL timeout, preemption, datacenter outage? |

Beginner ML courses teach axis 1 only ("this model is N billion parameters"). The cost model in production is dominated by axes 2 and 3-which is exactly why FlashAttention exists, why paged attention exists, why mixed precision exists.

The single most important number in modern AI systems engineering is the arithmetic intensity at which a hardware platform crosses from memory-bound to compute-bound. For an H100, that crossover is roughly 295 FLOP/byte (BF16). Below it, you're starving the tensor cores; above it, HBM doesn't matter. Memorize this. Most performance work is moving operations across that crossover.
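The number falls straight out of the roofline inputs: using the H100 figures quoted later in Week 1, ~989 TFLOP/s of BF16 compute divided by ~3.35 TB/s of HBM bandwidth gives 989e12 / 3.35e12 ≈ 295 FLOP per byte.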


3. The Roofline Model

Every performance discussion in this curriculum will use the roofline model (Williams, Waterman, Patterson, 2009-the most useful single paper in computer architecture):

Performance (FLOP/s) = min(
    Peak compute,
    Peak bandwidth × Arithmetic intensity
)

Plot this on log-log axes: a horizontal line at peak compute (the "roof") and a slanted line at peak bandwidth (the "ramp"). Every kernel is a point. If you're under the ramp, you're bandwidth-limited-buy bandwidth (or recompute to reduce bytes moved). If you're under the roof but past the ramp's knee, you're compute-limited-get faster math (tensor cores, lower precision).
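A minimal sketch of the formula in code, using the H100 peaks quoted in Week 1 as default assumptions (substitute your own hardware's numbers):

    # Roofline: attainable FLOP/s for a kernel at a given arithmetic intensity.
    def attainable_flops(intensity_flop_per_byte: float,
                         peak_flops: float = 989e12,   # H100 SXM BF16 dense (Week 1)
                         peak_bw: float = 3.35e12      # H100 HBM, bytes/s
                         ) -> float:
        return min(peak_flops, peak_bw * intensity_flop_per_byte)

    # The knee sits at peak_flops / peak_bw ~= 295 FLOP/byte:
    for ai in (1, 10, 100, 295, 1000):
        print(f"AI={ai:>5} FLOP/B -> {attainable_flops(ai) / 1e12:8.1f} TFLOP/s")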

By week 5, you will sketch this from memory.


4. The Reading List

Foundational papers (read in order, ideally during weeks 1–4):

  1. Williams et al., Roofline: An Insightful Visual Performance Model (2009). The lens.
  2. Vaswani et al., Attention Is All You Need (2017). The architecture that defines this era.
  3. Dao et al., FlashAttention (2022) and FlashAttention-2 (2023). The most important systems paper of the LLM era.
  4. Kwon et al., Efficient Memory Management for Large Language Model Serving with PagedAttention (SOSP 2023). The vLLM paper.
  5. Rajbhandari et al., ZeRO: Memory Optimizations Toward Training Trillion Parameter Models (SC 2020).
  6. Shoeybi et al., Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism (2019).
  7. Yu et al., Orca: A Distributed Serving System for Transformer-Based Generative Models (OSDI 2022). Continuous batching's origin.
  8. Leviathan, Kalman, Matias, Fast Inference from Transformers via Speculative Decoding (ICML 2023).
  9. Frantar, Ashkboos, Hoefler, Alistarh, GPTQ: Accurate Post-Training Quantization (ICLR 2023). And Lin et al., AWQ (MLSys 2024).

Books:

  • Programming Massively Parallel Processors (Hwu, Kirk, Hajj, 4e). The CUDA textbook.
  • Computer Architecture: A Quantitative Approach (Hennessy & Patterson, 6e). Chapters 4–5 on data-level parallelism. Required if you want to be more than a recipe-follower.
  • Deep Learning (Goodfellow, Bengio, Courville). Chapters 6–8 for the math vocabulary.
  • Designing Machine Learning Systems (Chip Huyen). For the production framing.

Source repositories to bookmark:

  • pytorch/pytorch - the framework. Particularly aten/, torch/csrc/, torch/_inductor/.
  • openai/triton - the GPU DSL.
  • NVIDIA/cutlass - high-performance GEMM templates.
  • vllm-project/vllm - the canonical inference server.
  • Dao-AILab/flash-attention - the kernel.
  • pytorch/FBGEMM, NVIDIA/TransformerEngine - quantization and FP8.
  • google/jax, openxla/xla - the JAX/XLA stack.
  • microsoft/DeepSpeed, NVIDIA/Megatron-LM - training stacks.

Adjacent canon (you must know):

  • The roofline paper, mentioned above.
  • The Linux + Container + Kubernetes curricula for the substrate this all runs on.


5. What's Durable, What's Ephemeral

A 2030 reader will need most of the conceptual content. They will need to refresh much of the API content. The curriculum flags this on a per-week basis. The general pattern:

| Durable (10+ year half-life) | Ephemeral (2–4 year half-life) |
| --- | --- |
| Roofline model | Specific GPU's peak FLOPS |
| Memory hierarchy on GPUs | Ada/Hopper/Blackwell-specific instructions |
| Attention math | The nth FlashAttention variant |
| Parallelism patterns (DP/TP/PP) | Specific FSDP / DeepSpeed APIs |
| Continuous batching theory | vLLM's specific scheduler |
| Quantization theory (INT8, FP8 representability) | AWQ vs GPTQ vs SmoothQuant winner |
| Roofline-aware kernel design | Triton vs CUTLASS vs Mojo trajectories |
| The dispatcher pattern | PyTorch 2.x's exact dispatcher API |

Bias your study toward the left column. The right column is where you'll cite "as of 2026" timestamps in your code comments.


6. Curriculum Philosophy

  1. Paper first, framework docs second. When the curriculum says "study paged attention," it means open the SOSP 2023 paper. Then read the vLLM source. The framework docs are a tertiary source.
  2. Profile before you optimize. Every performance lab requires a nsys or ncu capture before any change. The most common beginner failure mode is "optimizing" code that wasn't actually slow.
  3. One artifact per phase. End of each month produces a benchmarked, profile-attached, reviewable artifact. The capstone is not a surprise; it's the natural assembly of the monthly artifacts.

7. What AI Systems Are Not For

A graduate of this curriculum should be able to argue these points without sounding evangelical:

  • Small datasets, simple problems. Linear/tree models on tabular data still win much of the time. Don't reach for a transformer because it's the new tool.
  • Hard real-time inference. Modern LLMs have unpredictable latency tails. For sub-millisecond hard deadlines, use a small distilled model (or no LLM).
  • Privacy-critical workloads on shared GPUs. Multi-tenancy on a single GPU has unsolved isolation problems (timing channels, memory residue). Use dedicated hardware or a confidential-computing GPU.
  • Greenfield projects with no production traffic. If you're not yet load-bound, the ML systems infrastructure is overkill. Use a managed inference API (Anthropic, OpenAI, Bedrock, Vertex) until you can justify the lift.

The signal that AI systems engineering is the right tool: you have a cost, latency, sovereignty, or scale constraint that ranks above iteration speed.


8. AI-Assisted Workflows (in an AI Systems curriculum)

The recursive irony is unavoidable: you will use LLMs to learn how to build systems for LLMs. Three rules:

  1. Never accept generated CUDA without ncu profiling. Models hallucinate index math. The kernel may compile and produce mostly-correct outputs while having a 10× perf cliff or a subtle race.
  2. Never accept generated NCCL / distributed code without a `-race`-equivalent stress test. PyTorch's DDP has its own timing assumptions; NCCL has its own deadlock modes. AI-generated all-reduce patterns are the single most common source of "works on 2 GPUs, hangs on 8."
  3. Always read the underlying paper. When the model summarizes paged attention or speculative decoding, the summary is plausibly wrong in ways that matter for implementation. The paper is short. Read it.

You are now ready for Week 1. Open 01_MONTH_FOUNDATIONS.md.

Month 1-Foundations: Compute Hierarchy, Tensors, Autograd, Training Loops

Goal: by the end of week 4 you can (a) sketch the memory and compute hierarchy from CPU register to multi-node cluster and put numerical bandwidth/latency on each step, (b) implement matrix multiplication three ways with measured performance differences, (c) implement reverse-mode automatic differentiation from scratch, and (d) write an honest training loop that handles checkpointing, mixed precision, and metrics.

This is the beginner ramp. If you already do all four, skim and proceed to Month 2. If you don't, this is the hardest month-concepts here are referenced everywhere else.


Weeks

Week 1 - The Compute Hierarchy and the Cost Model

1.1 Conceptual Core

  • Modern AI computation crosses seven memory tiers, each ~10× slower and ~10× larger than the one above. Performance is determined by which tier you're operating in:
  • Registers (~1 KB per core, 0-cycle latency, ~10 TB/s).
  • L1 / shared memory / SRAM (~64 KB per SM/core, ~5-cycle, ~5 TB/s).
  • L2 cache (~50 MB on H100, ~50-cycle).
  • HBM / VRAM (~80 GB on H100, ~500-cycle, ~3 TB/s).
  • Host DRAM (~TB, ~PCIe-cycle, ~64 GB/s over PCIe Gen5).
  • Local NVMe (~TB+, ms latency, ~10 GB/s).
  • Network / cluster (TB to PB, microseconds–milliseconds, depends on fabric).
  • The same model trained across 1,024 GPUs is the same algorithm applied at every layer of this hierarchy: keep the hot data near the compute.
  • The defining trick of modern ML systems: most operations are memory-bound. The ALUs are starved. Performance work is moving data better, not computing faster.

1.2 Mechanical Detail

  • Floating-point formats in 2026:
  • FP32 (single, 32-bit): 8-bit exponent, 23-bit mantissa. Default training precision historically.
  • FP16 (half, 16-bit): 5-bit exponent, 10-bit mantissa. Limited dynamic range-overflows training.
  • BF16 (bfloat16, 16-bit): 8-bit exponent (matches FP32), 7-bit mantissa. The standard training format on modern GPUs/TPUs.
  • FP8 (8-bit, two variants): E4M3 (4 exp, 3 mantissa) and E5M2 (5 exp, 2 mantissa). H100/H200 native; the future of training.
  • INT8 / INT4: integer quantization formats for inference.
  • Hardware peak numbers worth memorizing (2026, single GPU):
  • H100 SXM: ~989 TFLOPS BF16 dense, ~1979 TFLOPS FP8 dense, ~3.35 TB/s HBM.
  • H200: same compute, ~4.8 TB/s HBM.
  • B200 (Blackwell): roughly 2.5× H100 BF16, 5× FP8 with sparsity, ~8 TB/s HBM.
  • A100: ~312 TFLOPS BF16, ~2 TB/s HBM. Still the workhorse.
  • Arithmetic intensity = FLOPs per byte loaded. The crossover point between memory-bound and compute-bound is hardware-specific. For H100 BF16: ~295 FLOP/byte. Below it: memory-bound. Above: compute-bound. This is the most important number in your career.

1.3 Lab-"Roofline Sketch"

Write a small program (Python+NumPy, or any) that:

  1. Performs C = A @ B for square matrices N=64, 256, 1024, 4096.
  2. Times each. Computes achieved FLOPS (= 2·N³ / time).
  3. Computes the bytes moved (= 3·N²·sizeof(dtype)).
  4. Plots achieved FLOPS vs arithmetic intensity on log-log axes.
  5. Overlays the theoretical roofline of your laptop CPU (look up its peak FLOPS and DRAM bandwidth).

You should see the small N points sit under the bandwidth ramp and the large N points approach the compute roof. Keep the plot-every subsequent lab will produce another.
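A minimal sketch of the measurement loop (steps 1–3), assuming NumPy's default float64:

    import time
    import numpy as np

    for n in (64, 256, 1024, 4096):
        a, b = np.random.rand(n, n), np.random.rand(n, n)
        a @ b                                    # warm up BLAS threads
        t0 = time.perf_counter()
        c = a @ b
        dt = time.perf_counter() - t0
        flops = 2 * n**3 / dt                    # achieved FLOP/s
        intensity = 2 * n**3 / (3 * n * n * 8)   # FLOPs per byte, 8-byte float64
        print(f"N={n:5d}  {flops / 1e9:8.2f} GFLOP/s  AI={intensity:7.1f} FLOP/B")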

1.4 Idiomatic & Diagnostic Drill

  • Install htop, numactl, perf. Pin the matmul to a single CPU core (taskset -c 0); observe perf change. Pin to a NUMA node (numactl --cpunodebind=0 --membind=0); observe again.

1.5 Production Slice

  • Most production AI workloads run on cloud GPUs metered per second. Build a one-page "GPU economics cheat sheet" with cost per hour for: A100 (40GB and 80GB), H100, H200, B200 across at least three providers (AWS, GCP, Lambda, RunPod). Update yearly.

Week 2 - Linear Algebra Refresh, BLAS, NumPy

2.1 Conceptual Core

  • The only linear algebra you need for systems work, for now: matrix-vector multiply, matrix-matrix multiply (GEMM), elementwise ops, reductions. Almost every neural-network primitive decomposes into these.
  • GEMM (C = αAB + βC) is the most-studied operation in scientific computing. BLAS Level 3 routines (sgemm, dgemm, hgemm) are heavily optimized. Modern hardware defines its peak FLOPS by sgemm performance.
  • The naive triple-loop matmul achieves <5% of peak. Tiled, blocked, vectorized matmuls achieve >90%. Understanding the gap is week 2's whole content.

2.2 Mechanical Detail

  • NumPy uses BLAS underneath. numpy.dot(A, B) with appropriately-built NumPy hits the BLAS path; a Python triple-loop is ~1000× slower.
  • OpenBLAS / Intel MKL / Apple Accelerate are the dominant CPU BLAS implementations. Intel oneMKL is fastest on Intel; OpenBLAS is portable.
  • GPU BLAS: NVIDIA cuBLAS, AMD rocBLAS. Wrapped by every framework.
  • Einsum notation (numpy.einsum("ij,jk->ik", A, B)) is the lingua franca of multi-dimensional tensor ops. It generalizes matmul, batch matmul, transpose, sum, contraction. Learn it.
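A few einsum equivalences worth memorizing, as a quick sketch (each line computes the same thing as the comment):

    import numpy as np

    A = np.random.rand(4, 5); B = np.random.rand(5, 6)
    X = np.random.rand(8, 4, 5); Y = np.random.rand(8, 5, 6)

    np.einsum("ij,jk->ik", A, B)     # matmul:           A @ B
    np.einsum("bij,bjk->bik", X, Y)  # batch matmul:     X @ Y
    np.einsum("ij->ji", A)           # transpose:        A.T
    np.einsum("ij->i", A)            # row sums:         A.sum(axis=1)
    np.einsum("ij,ij->", A, A)       # full contraction: (A * A).sum()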

2.3 Lab-"Three Matmuls"

Implement 1024×1024 matmul three ways:

  1. Naive triple-loop in Python (will take ~minutes; that's the point).
  2. Naive in NumPy with explicit loops-only marginal speedup.
  3. numpy.dot-measure speedup over (1).

You should see ~10,000× speedup between (1) and (3). Internalize why. Read Goto and van de Geijn's "Anatomy of a High-Performance Matrix Multiplication" if you want the deep version (recommended).

2.4 Idiomatic & Diagnostic Drill

  • python -c 'import numpy; numpy.show_config()' - see which BLAS your NumPy is linked against. Reinstall with conda install numpy (which pulls MKL on Linux/Windows) and re-benchmark; observe.

2.5 Production Slice

  • Add a requirements.txt to your project with versions pinned. NumPy/BLAS bugs in version drift have cost real money-pin everything.

Week 3 - Tensors, Autograd, the Gradient Tape

3.1 Conceptual Core

  • A tensor is an N-dimensional array with a dtype, a shape, a device, and a computation graph attached (if it requires grad).
  • Automatic differentiation has two modes:
  • Forward-mode (efficient when outputs ≫ inputs): propagate derivatives alongside values.
  • Reverse-mode / backpropagation (efficient when inputs ≫ outputs, the ML case): build a graph during forward, traverse it backward.
  • PyTorch implements dynamic (define-by-run) reverse-mode AD via a graph built from Function nodes. JAX implements functional AD via tracing.
  • The single most useful thing about reverse-mode is the VJP (vector-Jacobian product): given output gradients, propagate to input gradients without ever materializing the Jacobian matrix.
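A quick way to see the VJP in action, as a PyTorch sketch: torch.autograd.grad computes vᵀJ without ever building J.

    import torch

    x = torch.randn(5, requires_grad=True)
    y = x ** 2                        # Jacobian dy/dx is diag(2x), never materialized
    v = torch.ones_like(y)            # the "vector" in vector-Jacobian product
    (g,) = torch.autograd.grad(y, x, grad_outputs=v)
    print(torch.allclose(g, 2 * x))   # True: g = v^T J = 2x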

3.2 Mechanical Detail

  • A PyTorch tensor with requires_grad=True records every op into a graph. loss.backward() traverses the graph, calling each op's backward function.
  • The graph is built per-iteration (this is what "dynamic" means). At backward(), the graph is consumed and discarded (unless retain_graph=True).
  • torch.no_grad() disables graph building-used during inference and during certain training tricks (target networks in RL, EMA updates).
  • detach() creates a tensor that shares storage but is severed from the graph.
  • Custom autograd: torch.autograd.Function lets you define forward/backward pairs. The escape hatch when you need a custom op (week 11–12).
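A minimal torch.autograd.Function sketch, using a hand-written ReLU purely for illustration:

    import torch

    class MyReLU(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)        # stash inputs needed by backward
            return x.clamp(min=0)

        @staticmethod
        def backward(ctx, grad_out):
            (x,) = ctx.saved_tensors
            return grad_out * (x > 0)       # dL/dx = dL/dy * 1[x > 0]

    x = torch.randn(4, requires_grad=True)
    MyReLU.apply(x).sum().backward()        # x.grad now holds the masked gradient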

3.3 Lab-"Autograd From Scratch"

Implement reverse-mode AD in ~100 lines of pure Python (no PyTorch). Support:

  • A Tensor class wrapping a NumPy array with a grad field.
  • __add__, __mul__, __matmul__, relu, sum. Each records its inputs and a backward function.
  • A backward() method that topologically sorts and traverses the graph.
  • Test on a tiny MLP: define f = x @ W1 + b1; g = relu(f); h = g @ W2 + b2; loss = h.sum(). Verify the gradients match a torch.autograd reference within float-precision.

This is Andrej Karpathy's micrograd exercise. Do it before reading his code; then read his code and compare.

3.4 Idiomatic & Diagnostic Drill

  • Run a real training step and inspect tensor.grad_fn. Walk the graph manually: loss.grad_fn.next_functions[0][0].next_functions....

3.5 Production Slice

  • The most common beginner bug: forgetting optimizer.zero_grad(), accumulating gradients across iterations. Add a unit test to your training loop scaffolding that asserts gradients are zeroed at the start of every step.
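A sketch of such a test, assuming a hypothetical train_step(model, optimizer, batch) helper that runs one full step:

    import torch

    def test_grads_zeroed_at_step_start(model, optimizer, batch):
        # Poison the gradients, then run one step. A correct loop calls
        # optimizer.zero_grad() before backward, so the poison must not leak in.
        for p in model.parameters():
            p.grad = torch.full_like(p, 1e6)
        train_step(model, optimizer, batch)   # hypothetical helper under test
        for p in model.parameters():
            assert p.grad is None or p.grad.abs().max() < 1e5, \
                "stale gradients leaked into this step"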

Week 4 - The Honest Training Loop

4.1 Conceptual Core

  • A "real" training loop has ~15 concerns most tutorials skip:
  • Data loading (parallel, prefetched, shuffled, sharded).
  • Forward + backward + optimizer step.
  • Mixed-precision casting (AMP).
  • Gradient accumulation (effective batch > per-step batch).
  • Gradient clipping (NaN-prevention).
  • LR schedule.
  • Checkpointing (model, optimizer, scheduler, RNG state).
  • Resumption (idempotent, exact reproduction).
  • Validation loop (no grad, EMA where applicable).
  • Metrics logging (loss, throughput, GPU util, memory).
  • Determinism flags (torch.use_deterministic_algorithms for debugging).
  • Error handling (NaN detection, OOM recovery).
  • Multi-GPU coordination (preview; full in Month 4).
  • Early stopping and best-checkpoint tracking.
  • Run metadata (commit, config, hardware fingerprint).

4.2 Mechanical Detail

  • torch.utils.data.Dataset + DataLoader: the canonical data path. num_workers > 0 for parallel loading; pin_memory=True for faster H2D copies; prefetch_factor for tuning lookahead.
  • torch.cuda.amp.autocast + GradScaler for mixed-precision FP16 training. For BF16, autocast alone is sufficient (no scaling needed; the dynamic range is enough).
  • torch.utils.checkpoint.checkpoint for gradient checkpointing-trade compute (extra forward pass) for memory (don't store activations). Essential for fitting large models.
  • Determinism: setting torch.manual_seed, numpy.random.seed, random.seed, torch.use_deterministic_algorithms(True), and CUBLAS_WORKSPACE_CONFIG=:4096:8 is not always enough-some kernels are inherently non-deterministic. Document what you achieve.
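A skeleton of the AMP + accumulation + clipping core from the list above, as a sketch (BF16 path, so no GradScaler; loader, model, optimizer, and scheduler are assumed to exist):

    import torch

    accum_steps, clip_norm = 4, 1.0
    for step, (x, y) in enumerate(loader):
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            loss = model(x.cuda(), y.cuda()) / accum_steps   # scale for accumulation
        loss.backward()                                      # grads accumulate in FP32
        if (step + 1) % accum_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad(set_to_none=True)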

4.3 Lab-"Train Something Small, Right"

Train a 1-layer transformer (or 2-layer MLP if transformer is too far) on TinyShakespeare or MNIST. Required:

  • Dataset class + DataLoader with num_workers=4, pin_memory=True.
  • AMP autocast (BF16 on Ampere+, FP16 with GradScaler on older).
  • LR schedule (warmup + cosine).
  • Checkpoint every N steps; able to resume from any checkpoint and produce identical loss thereafter (within 1e-5).
  • Per-step metrics: loss, tokens/sec, GPU memory, GPU util%.
  • Final report: train + val loss curves, throughput, peak memory, total cost in $ (compute hours × $/hr).
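For the warmup + cosine requirement, one common shape of the schedule (a sketch; the warmup and total step counts are placeholders):

    import math
    import torch

    def warmup_cosine(step, warmup=500, total=10_000):
        if step < warmup:
            return step / max(1, warmup)               # linear warmup to 1.0
        t = (step - warmup) / max(1, total - warmup)   # 0 -> 1 over the decay phase
        return 0.5 * (1.0 + math.cos(math.pi * t))     # cosine decay to 0

    # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_cosine)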

4.4 Idiomatic & Diagnostic Drill

  • torch.profiler.profile for one training step. Read the output. Identify what fraction of time is spent in: forward, backward, optimizer step, data loading, idle.

4.5 Production Slice

  • Ship the training script as a reproducible Docker image. Pin Python, PyTorch, CUDA, NumPy versions. Embed git rev-parse HEAD into the run metadata. This is the foundation for every later lab.

Month 1 Capstone Deliverable

A foundations/ directory containing:

  1. roofline-sketch/ (week 1)-measured roofline plot for your laptop and one GPU.
  2. three-matmuls/ (week 2)-naive vs NumPy timings, with a markdown writeup.
  3. micrograd/ (week 3)-your AD from scratch.
  4. honest-training-loop/ (week 4)-the full training scaffold, reproducible.

By the end of Month 1 you should be comfortable in PyTorch-not a wizard. Comfort is what Month 2's GPU work assumes.


  • The roofline paper.
  • Karpathy's micrograd and nanoGPT (the latter is the right level of implementation depth for the year ahead).
  • The PyTorch autograd documentation, end to end.
  • Goodfellow et al., Deep Learning, chapters 6 and 8.

Month 2-GPU Programming: Architecture, CUDA, Memory, Triton

Goal: by the end of week 8 you can (a) describe the GPU's hierarchical execution model from grid down to warp lane, (b) write CUDA kernels that achieve >70% of peak BW or compute on memory-bound and compute-bound problems respectively, (c) use shared memory and tensor cores correctly, and (d) write equivalent kernels in Triton with within-2× performance and 5× less code.

Deep-dive companions (read in tandem):

  • Week 5 → `DEEP_DIVES/01_GPU_ARCHITECTURE.md` - full SM/memory/tensor-core derivation, occupancy theory, NVLink topology.
  • Week 6–7 → `DEEP_DIVES/02_CUDA_PROGRAMMING.md` - six-stage tiled GEMM with code, mma.sync PTX, complete buildable BF16 GEMM at 60–70% cuBLAS.
  • Week 8 → `DEEP_DIVES/03_TRITON.md` - block-level model, autotune, six annotated kernels including a simplified flash-attention.


Weeks

Week 5 - GPU Hardware Architecture

5.1 Conceptual Core

  • A modern NVIDIA GPU (Hopper H100 used here as the canonical example) consists of:
  • 132 SMs (Streaming Multiprocessors), each functionally an independent processor.
  • 80 GB HBM3 at ~3 TB/s.
  • 50 MB L2 shared across SMs.
  • Per-SM resources: 256 KB register file, 256 KB combined L1/shared, 128 FP32 cores, 4 tensor cores.
  • The GPU runs threads in groups of 32 called warps. Threads in a warp execute the same instruction in lockstep (SIMT-single-instruction multi-thread). Branching within a warp causes divergence-warp executes both branches serially, masking off the inactive lanes.
  • A block is a group of warps (up to 32 warps = 1024 threads) that share an SM, share L1/shared memory, and can synchronize via __syncthreads().
  • A grid is the set of blocks for a kernel launch. Blocks within a grid do not share state (other than HBM) and do not synchronize.

5.2 Mechanical Detail

  • Warp scheduling: each SM has 4 warp schedulers; each cycle, each scheduler picks one ready warp from up to 16 resident warps and issues an instruction. Latency is hidden by switching warps-not by speculation as on CPUs. This is why you want many resident warps ("occupancy").
  • Memory tiers per-SM:
  • Registers: 256 per thread × 1024 threads = 256K total; spill to local memory (slow).
  • Shared memory: 228 KB per SM on H100 (configurable split with L1).
  • L1 cache: shares budget with shared memory.
  • Tensor cores: specialized matmul-accumulate units. On H100 each TC does an 8x4 BF16 GEMM per cycle; aggregated across the chip ~989 TFLOPS BF16. They demand specific data layouts (16x16 tiles in registers); the wmma and mma.sync PTX instructions expose them.
  • Async copy (cp.async): copy from HBM to shared memory in the background; thread continues other work, syncs later. Foundation for software pipelining (used in FlashAttention).

5.3 Lab-"Inspect Your Hardware"

  1. Run nvidia-smi and nvidia-smi -q. Read every line.
  2. Compile and run NVIDIA's deviceQuery sample. It prints all the numbers above for your specific GPU.
  3. Compile and run bandwidthTest (CUDA samples). Compare measured PCIe and HBM bandwidth to spec.
  4. Compute: at the measured HBM BW and compute peak of your GPU, what is the arithmetic intensity break-even? Sketch the roofline.

5.4 Idiomatic & Diagnostic Drill

  • Install Nsight Systems (nsys) and Nsight Compute (ncu). They are the tools. Familiarize with the GUI; learn nsys profile -o trace.qdrep ./prog and ncu --set full ./prog.

5.5 Production Slice

  • Document your GPU fleet's exact SKUs in a HARDWARE.md: model, HBM size, peak BW, peak FLOPS BF16/FP16/FP8, TDP. Cluster ops decisions hinge on this.

Week 6 - Your First CUDA Kernels

6.1 Conceptual Core

  • A CUDA program has host code (runs on CPU, in C++) and device code (runs on GPU, in C++ with CUDA extensions). They share a binary, but compile via different toolchains (nvcc orchestrates).
  • A kernel is a __global__ function called from the host with a launch configuration: kernel<<<gridDim, blockDim, sharedBytes, stream>>>(args).
  • Within a kernel, each thread sees threadIdx, blockIdx, blockDim, gridDim to compute its own work.

6.2 Mechanical Detail

  • Vector add is the canonical first kernel:
    __global__ void vadd(const float* a, const float* b, float* c, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) c[i] = a[i] + b[i];
    }
    // launch: vadd<<<(n+255)/256, 256>>>(a, b, c, n);
    
  • Memory transfer:
  • cudaMalloc, cudaMemcpy, cudaFree - the basics.
  • Pinned memory (cudaMallocHost): host memory that can DMA directly. Faster H2D/D2H copies.
  • Unified memory (cudaMallocManaged): single pointer accessible from both; the runtime migrates pages on demand. Easy but unpredictable; avoid for performance code.
  • Async (cudaMemcpyAsync + streams): overlap copy with compute. Essential.
  • Streams: queues of kernels and copies. Operations within a stream are sequential; across streams, concurrent. Default stream is special (synchronizes with all others); use explicit non-default streams in production.
  • Error handling: every CUDA call returns an error code. Wrap with a macro (CUDA_CHECK(...)) that aborts on error. The most-skipped step in beginner code; the source of every "silent corrupt output" bug.

6.3 Lab-"Kernel Speedrun"

Write three kernels in CUDA C++:

  1. Vector add: SAXPY (y = a*x + y). Time vs cuBLAS axpy.
  2. Reduction: sum a million floats. Compare your naive version (one global atomic) with a hierarchical version (block-level reduction in shared memory, then global). Expect ~100× difference.
  3. Naive matmul: 1024×1024 BF16. Compare to cuBLAS-expect to be 50-100× slower. Don't get discouraged; you'll close most of the gap in week 7.

For each: measure runtime with cudaEvent_t timing; compute achieved throughput; mark on the roofline.

6.4 Idiomatic & Diagnostic Drill

  • Run each kernel under ncu --set full. Read the "GPU speed of light" section: it tells you the % of peak compute and BW you achieved. Memorize the report layout.

6.5 Production Slice

  • Wrap CUDA error checking and timing in a small C++ header you'll reuse all month. This is your cuda_utils.cuh.

Week 7 - Memory Optimization: Coalescing, Shared Memory, Tensor Cores

7.1 Conceptual Core

  • The naive matmul of week 6 is slow because:
  • Uncoalesced memory access: adjacent threads read non-adjacent addresses. Each warp issues many memory transactions instead of one.
  • No data reuse: each element of A is loaded N times from HBM.
  • No tensor cores: scalar FP32 ops, not 16×16 BF16 GEMM blocks.
  • Three optimizations, each ~5-10×:
  • Coalesce-make threads in a warp read adjacent addresses.
  • Tile in shared memory-a block cooperatively loads a 32×32 tile of A and B into shared memory; each thread computes its output using shared data. Each element of A loaded once per block.
  • Tensor cores-use wmma (or nvcuda::wmma) intrinsics to issue 16×16 GEMM blocks. ~10× over CUDA cores at BF16.

7.2 Mechanical Detail

  • Coalescing rule: a warp's 32 memory accesses to a 128-byte aligned, contiguous range = 1 transaction. Strided or scattered = up to 32 transactions.
  • Shared memory bank conflicts: shared memory is divided into 32 banks (4-byte stride). If two threads in a warp access the same bank but different addresses, conflict-serialized. Common with column-major access. Fix: pad arrays (shared_mem[32][33] not [32][32]).
  • Double buffering: while the SM computes on tile N, asynchronously load tile N+1 with cp.async. The compute hides the load latency. This is software pipelining.
  • Tensor core usage (CUDA C++):
    using namespace nvcuda::wmma;
    fragment<matrix_a, 16, 16, 16, half, row_major> a_frag;
    fragment<matrix_b, 16, 16, 16, half, col_major> b_frag;
    fragment<accumulator, 16, 16, 16, float> c_frag;
    load_matrix_sync(a_frag, A_smem, 16);
    load_matrix_sync(b_frag, B_smem, 16);
    mma_sync(c_frag, a_frag, b_frag, c_frag);
    store_matrix_sync(C_gmem, c_frag, N, mem_row_major);
    

7.3 Lab-"Climb the Roofline"

Take your week 6 naive matmul and progressively optimize:

  1. Coalesce loads (transpose access pattern). Re-time.
  2. Tile in shared memory with 32×32 blocks. Re-time.
  3. Double-buffer with cp.async. Re-time.
  4. Use tensor cores with BF16. Re-time.

You should reach 30–60% of cuBLAS perf. Document each step's improvement and the residual gap. Read NVIDIA's cutlass examples for the production-grade version.

7.4 Idiomatic & Diagnostic Drill

  • For each version, capture an ncu report. The metrics that matter: sm__cycles_active.avg.pct_of_peak_sustained_elapsed (tensor core utilization), dram__bytes_read.sum (HBM traffic), l1tex__t_sectors_pipe_lsu_mem_global_op_ld.sum (load coalescing).

7.5 Production Slice

  • Production CUDA kernels go through cuBLAS, cuDNN, or CUTLASS, not from-scratch CUDA, in 99% of cases. Read CUTLASS's examples/ directory; understand its template-based GEMM and how it's tuned per architecture.

Week 8 - Triton: GPU Kernels From Python

8.1 Conceptual Core

  • Triton (Tillet et al., 2019; OpenAI's project) is a Python-embedded DSL that compiles to PTX. It targets the same execution model as CUDA but at a block (not thread) abstraction: you write the work for a block of threads, the compiler handles intra-block parallelism.
  • The promise: 80% of CUTLASS performance with 20% of the code. The reality (as of 2026): often achieves it for memory-bound kernels (attention, layer norm); sometimes leaves 2-3× on the table for compute-bound GEMMs vs hand-tuned CUTLASS.
  • Triton is the dominant kernel authoring DSL in the open-source LLM ecosystem (FlashAttention's reference impl, vLLM's custom kernels, Liger Kernel, Unsloth-all Triton).

8.2 Mechanical Detail

  • A Triton kernel is a @triton.jit Python function. Inside, you operate on blocks (vectors / tiles of values) rather than individual scalars:
    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offsets = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offsets < n
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)
    
  • Memory ops: tl.load(ptr, mask), tl.store(ptr, val, mask).
  • Math ops: elementwise, dot (tl.dot - uses tensor cores!), reductions.
  • Autotuning: @triton.autotune(configs=[...]) sweeps block sizes / num_warps / num_stages; picks fastest at runtime.
  • The compiler handles: vectorization, software pipelining, register allocation, cp.async insertion, tensor-core mapping.
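Launching the add_kernel above from the host side is plain Python; a sketch, with one program per BLOCK-sized chunk:

    import torch
    import triton

    x = torch.randn(10_000, device="cuda")
    y = torch.randn_like(x)
    out = torch.empty_like(x)
    BLOCK = 1024
    grid = (triton.cdiv(x.numel(), BLOCK),)   # ceil-divide: one program per chunk
    add_kernel[grid](x, y, out, x.numel(), BLOCK=BLOCK)
    assert torch.allclose(out, x + y)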

8.3 Lab-"Three Triton Kernels"

  1. Elementwise add (the Hello World).
  2. Softmax with online maximum subtraction (numerical stability). Compare to torch.softmax perf.
  3. Naive matmul in Triton with autotuning. Compare to cuBLAS-you should reach 70-90% of peak for square BF16 matmul on common shapes.
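For lab 2, the core pattern is one program per row with the row max subtracted before exp. A minimal sketch, assuming each row fits in a single block:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def softmax_kernel(x_ptr, out_ptr, n_cols, row_stride, BLOCK: tl.constexpr):
        row = tl.program_id(0)
        offs = tl.arange(0, BLOCK)
        mask = offs < n_cols
        x = tl.load(x_ptr + row * row_stride + offs, mask=mask, other=-float("inf"))
        x = x - tl.max(x, axis=0)          # subtract row max: exp can't overflow
        num = tl.exp(x)
        out = num / tl.sum(num, axis=0)
        tl.store(out_ptr + row * row_stride + offs, out, mask=mask)

    def softmax(x):
        rows, cols = x.shape
        out = torch.empty_like(x)
        softmax_kernel[(rows,)](x, out, cols, x.stride(0),
                                BLOCK=triton.next_power_of_2(cols))
        return out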

8.4 Idiomatic & Diagnostic Drill

  • TRITON_PRINT_AUTOTUNING=1 to see autotune traces. triton.compiler.compile(...).asm['ptx'] to inspect generated PTX. The PTX-level view becomes useful in week 16+ when you debug kernel choice.

8.5 Production Slice

  • Build a small library of "kernels you'll need later": fused softmax, layer norm, RMSNorm, fused dropout. Each ≤100 lines, autotuned, benchmarked vs PyTorch reference.

Month 2 Capstone Deliverable

A gpu-programming/ directory:

  1. hardware-survey/ - your hardware tier.
  2. cuda-kernels/ - vector add, reduction, naive matmul, optimized matmul.
  3. triton-kernels/ - three kernels with autotune, benchmark plots vs PyTorch baseline.
  4. A KERNEL_LOG.md documenting each optimization step's ncu deltas.

This is the artifact that will impress GPU-engineer interviewers.


  • Programming Massively Parallel Processors, Hwu/Kirk/Hajj, chapters 1–6.
  • The CUDA C++ Programming Guide, sections 1–5.
  • The Triton paper (MAPL 2019).
  • NVIDIA's "GPU Performance Background" technical blog posts.
  • The CUTLASS README and the gemm_universal example.

Month 3-Framework Internals: PyTorch, torch.compile, JAX/XLA, Custom Ops

Goal: by the end of week 12 you can (a) read PyTorch's dispatcher source and trace an op from Python through ATen to a CUDA kernel, (b) explain torch.compile's graph capture and Inductor backend, (c) read JAX/XLA HLO and reason about XLA optimizations, and (d) ship a custom CUDA kernel as a PyTorch extension callable from Python.

Deep-dive companions (read in tandem):

  • Weeks 9–10, 12 → `DEEP_DIVES/04_PYTORCH_INTERNALS.md` - full layered architecture trace, dispatcher mechanics, autograd engine, complete torch.compile pipeline (Dynamo + AOTAutograd + Inductor), modern custom-op path with Triton, CUDA caching allocator algorithm.
  • Week 11 → `DEEP_DIVES/05_JAX_XLA.md` - pure-functional model, jaxpr tracing with annotated examples, full XLA pipeline, GSPMD with Megatron-MLP propagation walkthrough.


Weeks

Week 9 - PyTorch Internals: Tensor, Dispatcher, ATen

9.1 Conceptual Core

  • PyTorch is a layered system:
  • Python frontend-torch.* namespace, what users write.
  • Dispatcher-routes ops to backend implementations based on device, dtype, layout, autograd state, and other "keys."
  • ATen-the C++ tensor library. Each op (add, matmul, softmax) has device-specific implementations (CPU, CUDA, MPS, XPU).
  • Backends-cuBLAS, cuDNN, OneDNN, custom kernels.
  • Every Python tensor op is, fundamentally, a dispatcher call. a + b → torch.add(a, b) → aten::add → CPU/CUDA add kernel. Understanding this is the foundation for the rest of the month.
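You can watch this happen from Python. A sketch using TorchDispatchMode, which intercepts every op at the dispatcher boundary:

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class LogOps(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            print(func)                           # e.g. aten.add.Tensor
            return func(*args, **(kwargs or {}))

    a, b = torch.randn(4), torch.randn(4)
    with LogOps():
        c = a + b       # prints the aten op the dispatcher resolved this to
        d = c.relu()    # aten.relu.default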

9.2 Mechanical Detail

  • Read aten/src/ATen/core/dispatch/Dispatcher.h and DispatchKey.h. The DispatchKey enum names every backend, every layer (autograd, autocast, named tensors, vmap, ...).
  • Dispatch keys stack: a tensor's "key set" determines which dispatcher entries fire and in what order. AutocastCUDA → AutogradCUDA → CUDA, for example.
  • torch::Library macro registers ops:
    TORCH_LIBRARY_IMPL(aten, CUDA, m) {
        m.impl("add.Tensor", &my_add_cuda);
    }
    
  • The Python tensor object is a thin wrapper around at::Tensor, which is a thin wrapper around c10::TensorImpl, which holds a c10::Storage and view metadata (sizes, strides, offset, dtype, device).
  • Strides are critical. A "tensor view" (transpose, slice, narrow) shares storage but rewrites strides. The dispatcher and most ops handle strided tensors transparently; some kernels require contiguous (tensor.contiguous()).

9.3 Lab-"Trace an Op"

  1. From Python, run a + b for two CUDA tensors. Use TORCH_SHOW_DISPATCH_TRACE=1 (or torch._C._dispatch_print_registrations()) to see the dispatcher's path.
  2. Read `aten/src/ATen/native/cuda/BinaryOps.cu` - find the actual CUDA kernel for add.
  3. Trace torch.matmul(a, b) similarly. Note that for BF16 it routes to cuBLAS.
  4. Document the call chain in TRACE.md.

9.4 Idiomatic & Diagnostic Drill

  • torch.profiler.profile(activities=[CPU, CUDA]) with record_shapes=True and with_stack=True. Read the table; identify any op spending more than 5% of total time.

9.5 Production Slice

  • Add torch.cuda.synchronize() discipline: every benchmark must sync before timing. CUDA is asynchronous; without sync, you'll measure queue insertion, not execution.
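A small helper that bakes the discipline in, as a sketch:

    import time
    import torch

    def cuda_time(fn, *args, warmup=3, iters=10):
        for _ in range(warmup):
            fn(*args)
        torch.cuda.synchronize()          # drain the queue before starting the clock
        t0 = time.perf_counter()
        for _ in range(iters):
            fn(*args)
        torch.cuda.synchronize()          # measure execution, not enqueue
        return (time.perf_counter() - t0) / iters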

Week 10 - torch.compile, TorchDynamo, Inductor

10.1 Conceptual Core

  • torch.compile (PyTorch 2.0+) is a JIT compiler that captures Python+PyTorch into a graph and compiles it to optimized kernels. The pipeline:
  • TorchDynamo-Python frame evaluation hook; captures bytecode into FX graphs, handles graph breaks for unsupported ops.
  • AOTAutograd-runs both forward and backward through Dynamo, partitions into a joint graph, decomposes high-level ops into a small "core ATen" set.
  • Inductor-the default backend. Lowers the FX graph to Triton kernels (for CUDA) or C++/OpenMP (for CPU). Schedules with kernel fusion.
  • The user-visible promise: ~30-50% speedup on training, more for inference, with one decorator. The reality: graph breaks and silent fallbacks make this a discipline, not a free lunch.

10.2 Mechanical Detail

  • Graph breaks: any operation Dynamo can't trace falls back to eager. Common causes: data-dependent control flow on tensor values, print, custom Python objects, certain if patterns.
  • torch._dynamo.explain(model)(input) - shows graph breaks with reasons.
  • TORCH_COMPILE_DEBUG=1 - dumps every stage of compilation. Massive output; useful when debugging perf regressions.
  • Inductor codegen: TORCH_LOGS=output_code shows the generated Triton kernels. Read these-they're surprisingly readable and often reveal optimization opportunities you can replicate by hand.
  • Modes: mode="reduce-overhead" (CUDA graphs), mode="max-autotune" (heavy autotuning), default. Choose for the workload.
  • Caching: compiled artifacts cached in ~/.cache/torch_inductor. First run is slow; subsequent calls are fast.
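A tiny demonstration of a graph break and its diagnosis (a sketch; the data-dependent if is the deliberate culprit):

    import torch

    def f(x):
        if x.sum() > 0:       # control flow on a tensor value => graph break
            return x * 2
        return x - 1

    print(torch._dynamo.explain(f)(torch.randn(8)))
    # The report lists the captured graphs and the break reason at the `if`.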

10.3 Lab-"Compile and Compare"

Take your honest-training-loop from Month 1. Add model = torch.compile(model). Measure:

  1. First-step time (compilation cost).
  2. Steady-state step time vs uncompiled.
  3. With TORCH_LOGS="recompiles": how many recompilations occurred? Why?
  4. With mode="max-autotune": extra speed vs default? Worth the compile time?

Triage any graph breaks; report in COMPILE_LOG.md.

10.4 Idiomatic & Diagnostic Drill

  • The "guard" system: every compiled artifact carries assumptions about input shapes, dtypes, requires_grad. A mismatched call recompiles. Dynamic shapes are a special hell-investigate dynamic=True for serving workloads.

10.5 Production Slice

  • For inference, torch.compile + CUDA graphs (mode="reduce-overhead") is the production path. Document the compile-warmup procedure for your serving stack.

Week 11 - JAX, XLA, HLO

11.1 Conceptual Core

  • JAX is a library for composable function transformations: jit, grad, vmap, pmap, shard_map. Underneath, every transformed function is traced into a jaxpr and compiled by XLA.
  • XLA is a domain-specific compiler for linear algebra. Its IR is HLO (High Level Operations)-a small, well-defined op set with parametric shapes. XLA optimizes via fusion, layout assignment, sharding propagation.
  • JAX is favored for: research clarity (functional purity), TPU support (XLA is the TPU compiler), large-scale training (sharding via pjit/shard_map is more elegant than PyTorch's parallelism stack).
  • PyTorch is favored for: ecosystem maturity, eager-mode debugging, NVIDIA hardware-tuning depth.
  • Both are correct answers; the curriculum requires fluency in both because real production stacks use both.

11.2 Mechanical Detail

  • jax.jit traces the function with abstract Tracer arguments, builds a jaxpr, compiles with XLA. The compiled artifact is cached by input shapes/dtypes/static args.
  • jax.grad is reverse-mode AD that operates on jaxprs-purely functional. Closures and side effects don't survive grad.
  • jax.vmap vectorizes a function across a new axis. The classic example: a function that operates on one example becomes a batched function.
  • pjit / shard_map (modern unified jit with sharding annotations): distribute computation across devices.
  • HLO inspection: jax.jit(f).lower(args).compiler_ir(dialect="hlo") returns the HLO text. Read this-same value as reading Inductor's Triton.

11.3 Lab-"JAX Equivalent"

Re-implement your Month 1 training loop in JAX:

  • Pure-functional model (no nn.Module mutation).
  • optax for the optimizer.
  • jax.jit the train step.
  • Add jax.vmap somewhere meaningfully (e.g., per-example metric computation).
  • Compare end-to-end throughput with the PyTorch baseline.

11.4 Idiomatic & Diagnostic Drill

  • Inspect the HLO. Identify a fused op produced by XLA that PyTorch+Inductor produced (or didn't produce) for the equivalent code.

11.5 Production Slice

  • For TPU-bound work, JAX is the right tool. Document the JAX install path on Cloud TPU; rent a v4-8 for two hours; run a small training job; capture the HLO and the XLA cost analysis.

Week 12 - Custom Operators: From CUDA Kernel to torch.ops

12.1 Conceptual Core

  • When PyTorch / JAX don't have a fast-enough op for your needs, you write one. The standard path:
  • Implement the kernel (CUDA, Triton, or C++).
  • Wrap with the framework's extension API.
  • Register with the dispatcher.
  • Define the autograd backward (forward + backward = autograd.Function).
  • Optionally support torch.compile via abstract-shape registration.

12.2 Mechanical Detail

  • PyTorch C++ extension (the recommended modern path):
  • setup.py with torch.utils.cpp_extension.CUDAExtension.
  • C++/CUDA source with pybind11 bindings.
  • Built at install time; loadable as import myop.
  • torch.library API (PyTorch 2.x) for dispatcher integration without C++:
    @torch.library.custom_op("myns::myop", mutates_args=())
    def myop(x: torch.Tensor) -> torch.Tensor:
        return _my_triton_kernel(x)
    
    @myop.register_fake
    def _(x):
        return torch.empty_like(x)  # for compile/dynamo
    
  • Backward registration: torch.library.register_autograd("myns::myop", backward_fn).
  • Triton-as-custom-op: torch.compile recognizes Triton kernels and integrates them into the compiled graph without a graph break-the modern preferred path.
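A sketch of the backward half of the torch.library path, assuming for illustration only that myop computes x²:

    import torch

    # Assumption: myop(x) = x * x, so d(myop)/dx = 2x.
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 2 * x

    torch.library.register_autograd("myns::myop", backward,
                                    setup_context=setup_context)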

12.3 Lab-"RMSNorm From Scratch"

RMSNorm is used in modern LLMs (Llama, Qwen). Implement it three ways:

  1. PyTorch: pure tensor ops.
  2. Triton custom op: a fused kernel that reads input, computes RMS, normalizes, scales-all in one pass over HBM.
  3. CUDA C++ extension: same kernel in CUDA C++ with a pybind11 binding.

For each: forward + backward, autograd-correct (numerical-grad test), benchmarked vs the others on (B, S, H) = (8, 4096, 4096) BF16. Your fused Triton version should beat PyTorch by 3-5×.

12.4 Idiomatic & Diagnostic Drill

  • Test your custom op under torch.compile. Verify it doesn't break the graph (check torch._dynamo.explain).

12.5 Production Slice

  • Custom ops in production must ship binary artifacts compatible with the user's PyTorch version. Use torch.ops.load_library for shared-library loading; pin PyTorch ABI.

Month 3 Capstone Deliverable

A framework-internals/ directory:

  1. dispatcher-trace/ (week 9)-the annotated walk through ATen.
  2. compile-bench/ (week 10)-torch.compile measurements + graph-break triage.
  3. jax-baseline/ (week 11)-JAX training loop matching the PyTorch baseline; HLO analysis.
  4. rmsnorm-fused/ (week 12)-three implementations, benchmark plot, autograd tests.

By end of month you should be comfortable reading framework source-the literacy that distinguishes systems engineers from framework users.


  • The torch.compile design doc on pytorch.org/docs/.
  • The Inductor design doc.
  • The JAX "How JAX primitives work" guide.
  • The XLA HLO operation semantics page.
  • The PyTorch dispatcher tutorial in pytorch/pytorch/wiki.

Month 4-Distributed Training: NCCL, DDP, FSDP, Tensor & Pipeline Parallelism, FP8

Goal: by the end of week 16 you can (a) explain the ring-allreduce algorithm and predict its bandwidth, (b) train a model on 8 GPUs across 2 nodes with FSDP achieving >85% scaling efficiency, (c) implement tensor-parallel attention by hand, and (d) reason about 3D parallelism schedules and FP8 training stability.

Deep-dive companions (read in tandem):

  • Weeks 13–15 (all) → `DEEP_DIVES/06_DISTRIBUTED_TRAINING.md` - derivation of all 5 all-reduce algorithms (with ring-allreduce bandwidth-optimality proof), full ZeRO-1/2/3 memory-math table, Megatron column→row partition derivations, ASCII pipeline schedules (GPipe, 1F1B, Interleaved 1F1B, Zero Bubble) with bubble formulas, 3D parallelism worked examples for 8B/70B/405B.
  • Week 16 → `DEEP_DIVES/11_NUMERICS_AND_MIXED_PRECISION.md` - IEEE-754 derivation, FP16/BF16/FP8 layouts, full loss-scaling derivation including dynamic GradScaler, FP8 with delayed scaling, Adam-low-precision pitfall, catastrophic cancellation, transformer numerical-stability tricks.


Weeks

Week 13 - Communication Primitives: NCCL, Allreduce, Topology

13.1 Conceptual Core

  • Distributed training reduces to collective communication: at certain points, every GPU's tensor must combine with every other GPU's. The fundamental collectives:
  • Allreduce (sum tensors across all ranks; result on all ranks). The workhorse-used for gradient sync in data parallelism.
  • Allgather (concatenate per-rank tensors; result on all ranks). Used for sharded ops.
  • Reduce-scatter (sum then shard; each rank gets a piece). Used in ZeRO/FSDP.
  • Broadcast (one rank → all). Initialization, parameter sync.
  • All-to-all (every rank exchanges with every other). Used in MoE expert routing.
  • NCCL (NVIDIA Collective Communication Library) is the canonical implementation on NVIDIA GPUs. AMD/RCCL is the equivalent. Both implement the same API.
  • The algorithms matter:
  • Ring-allreduce: the gold standard. Each rank sends its chunk to the next, accumulating. After 2(N-1) steps (N = ranks), all ranks have the sum. Bandwidth-optimal at scale.
  • Tree-allreduce: lower latency at small messages, scales poorly.
  • Hierarchical: ring within a node (NVLink), ring across nodes (InfiniBand). NCCL chooses automatically.

13.2 Mechanical Detail

  • NVLink: GPU-to-GPU interconnect within a node. ~900 GB/s on H100 (NVLink 4). Matters because intra-node allreduce moves at NVLink speed; inter-node moves at NIC speed (typically 200-400 Gbps InfiniBand or 800 Gbps for newer fabrics-i.e., 25-100 GB/s, 10-30× slower).
  • Topology: 8-GPU H100 nodes typically use NVSwitch-full bisection bandwidth between any pair. Cross-node uses NDR/HDR InfiniBand or RoCE. The nvidia-smi topo -m command shows the matrix.
  • torch.distributed wraps NCCL. init_process_group(backend='nccl'), dist.all_reduce(tensor, op=ReduceOp.SUM). Synchronous by default; async via async_op=True returns a Work handle.
  • Ring-allreduce bandwidth analysis: for tensor of size B bytes across N ranks, each rank sends 2(N-1)/N · B bytes. Time ≈ 2(N-1)/N · B / link_bandwidth. The 2× hides the gradient sync's true cost.
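Worked example with the numbers above: a 1 GB gradient bucket across 8 ranks moves 2·(7/8)·1 GB = 1.75 GB per rank, so at an effective 50 GB/s inter-node link the allreduce floor is ≈ 35 ms, while the same bucket intra-node at 900 GB/s NVLink is ≈ 2 ms.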

13.3 Lab-"Allreduce Bench"

On at least 2 GPUs (single node fine), run an allreduce benchmark:

  1. torch.distributed.all_reduce on tensors from 1 KB to 1 GB.
  2. Compute achieved bandwidth (= 2(N-1)/N · message_size / time).
  3. Plot bandwidth vs message size; identify the message size at which BW saturates (the "knee").
  4. If you have access: run on 8 GPUs via single node (NVLink) and compare to 8 GPUs across 2 nodes (InfiniBand). Document the gap.
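A sketch of the harness (launch with torchrun --nproc_per_node=2 bench.py; sizes and iteration counts are placeholders):

    import time
    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="nccl")
    rank, world = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    for size_bytes in [2**p for p in range(10, 31, 4)]:    # 1 KB .. 1 GB
        t = torch.ones(size_bytes // 4, device="cuda")      # FP32 elements
        for _ in range(3):
            dist.all_reduce(t)                              # warmup
        torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(10):
            dist.all_reduce(t)
        torch.cuda.synchronize()
        dt = (time.perf_counter() - t0) / 10
        bw = 2 * (world - 1) / world * size_bytes / dt      # algorithm bandwidth
        if rank == 0:
            print(f"{size_bytes:>12} B  {bw / 1e9:7.1f} GB/s")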

13.4 Idiomatic & Diagnostic Drill

  • NCCL_DEBUG=INFO produces verbose NCCL output. Read one full session's output; identify the chosen algorithm (ring/tree/CollNet) and the topology NCCL inferred.

13.5 Production Slice

  • Document your cluster's NIC count, IB rail, GPU-NIC affinity. NCCL's perf hinges on NCCL_IB_HCA, NCCL_SOCKET_IFNAME, NCCL_TOPO_FILE settings-wrong defaults can halve throughput.

Week 14 - Data Parallelism: DDP, ZeRO, FSDP

14.1 Conceptual Core

  • DDP (Distributed Data Parallel): every rank holds the full model. Each step, ranks process different micro-batches; gradients are allreduced before optimizer step. Simple, memory-inefficient (model + optimizer states replicated N times).
  • ZeRO (Zero Redundancy Optimizer; Rajbhandari et al., SC 2020): observation that DDP wastes memory by replicating optimizer states-Adam's FP32 momentum + variance alone cost 8 bytes/param, 4× a BF16 model's 2 bytes/param, before counting the FP32 master weights. Three stages:
  • ZeRO-1: shard optimizer states across ranks. Saves N× on optimizer memory.
  • ZeRO-2: also shard gradients.
  • ZeRO-3: also shard parameters. Fetches them just-in-time per layer via allgather; reduces gradients via reduce-scatter.
  • FSDP (Fully Sharded Data Parallel; PyTorch's implementation of ZeRO-3): the modern standard for training models that don't fit in single-GPU memory under DDP.

14.2 Mechanical Detail

  • FSDP wraps modules. Each FSDP unit (typically a transformer layer) shards parameters across all ranks. Forward:
  • Allgather parameters of unit i (just before its forward).
  • Compute forward of unit i.
  • Free non-local parameters.
  • Move to unit i+1.
  • Backward is symmetric, with reduce-scatter for gradients.
  • Mixed precision in FSDP: parameters in BF16 for compute, gradient reduction in FP32 for stability, optimizer states in FP32. Configurable via MixedPrecision (see the sketch after this list).
  • Activation checkpointing: instead of storing all activations for backward, store only at unit boundaries; recompute the forward when needed. Trades ~30% extra compute for ~3× less activation memory. Essential for long-context LLM training.
  • CPU offload: optimizer states or even parameters can spill to CPU RAM (slower but allows larger models). Used in low-memory regimes; avoid in production HPC.
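A minimal wrapping sketch of the mechanics above (FSDP1-style API; nn.TransformerEncoderLayer stands in for your transformer block; run under torchrun):

import functools, os, torch, torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP, MixedPrecision, BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

layer_cls = nn.TransformerEncoderLayer                   # one FSDP unit per layer
model = nn.TransformerEncoder(
    layer_cls(d_model=1024, nhead=16), num_layers=8
).cuda()

model = FSDP(
    model,
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls={layer_cls}
    ),
    mixed_precision=MixedPrecision(                      # BF16 compute, FP32 grad reduce
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    ),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,     # overlap comm with backward
)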

14.3 Lab-"FSDP a Small Model"

On 4-8 GPUs (single node fine): 1. Train a 1B-parameter transformer in FSDP. Use transformer_auto_wrap_policy. 2. Compare memory and throughput: DDP on a model small enough to fit, FSDP on that same model, and FSDP on a larger model that would OOM under DDP. 3. Add activation checkpointing; re-measure. 4. Add CPU offload; observe the speed cost. 5. Compute scaling efficiency (throughput_8gpu / (8 × throughput_1gpu)).

14.4 Idiomatic & Diagnostic Drill

  • torch.profiler with with_stack=True on an FSDP step. Identify the allgather and reduce-scatter calls; measure their fraction of step time.

14.5 Production Slice

  • FSDP's BackwardPrefetch.BACKWARD_PRE overlaps backward compute with next-layer's allgather. Verify it's enabled; without it, large models leave 20-30% perf on the table.

Week 15 - Tensor Parallelism and Pipeline Parallelism

15.1 Conceptual Core

  • Tensor parallelism (TP): shard individual layers across ranks. The classic Megatron-LM approach (Shoeybi et al., 2019):
  • Column-parallel linear: shard weight matrix by output dim. Each rank computes a slice of output; allgather to combine if needed.
  • Row-parallel linear: shard by input dim. Each rank computes a partial output; allreduce to sum.
  • Attention: shard heads across ranks. Each rank computes its assigned heads; results allreduced before output projection.
  • TP requires fast (intra-node, NVLink) communication. Across nodes, the latency dominates.
  • Pipeline parallelism (PP): split layers across ranks. Each rank holds layers L_i to L_j. Activations flow forward; gradients flow backward.
  • Naive PP wastes most ranks' time (bubble). GPipe (Huang et al., 2018) microbatches to fill the pipeline. 1F1B schedule (interleaved forward and backward) further reduces bubble.
  • 3D parallelism: combine TP (intra-node) + PP (inter-node) + DP (replicating the TP×PP grid). Each GPU owns a unique (TP rank, PP stage, DP replica) coordinate. Training recipes at GPT-3, PaLM, and Llama-3 405B scale all combine these axes.

15.2 Mechanical Detail

  • TP communication: per-layer allreduce in row-parallel; allgather in column-parallel. Latency-sensitive; do TP within a node where NVLink bandwidth dominates.
  • PP communication: only at stage boundaries; activations forward + gradients backward. Bandwidth-friendly; do PP across nodes.
  • DP: gradient sync once per step; very bandwidth-friendly; DP takes the outermost axis, replicating the TP×PP grid across the remaining nodes.
  • The choice of which parallelism along which axis is governed by:
  • TP degree ≤ GPUs-per-node (NVLink scope).
  • PP degree determined by memory needs (each stage holds layers + activations).
  • DP fills the rest.
  • Megatron-LM and DeepSpeed are the two open-source 3D-parallelism stacks. PyTorch's pipeline-parallelism API (torch.distributed.pipelining, available since ~2.4) and DTensor (torch.distributed.tensor) are converging on a unified path.

15.3 Lab-"Implement Tensor-Parallel Attention"

By hand, in pure PyTorch + torch.distributed: 1. Implement the Megatron-style tensor-parallel multi-head attention: column-parallel QKV projection, sharded heads, row-parallel output projection. 2. Verify numerically against a single-GPU reference for correctness (allclose to atol=1e-3). 3. Benchmark on 4 GPUs vs 1-GPU baseline. Compute scaling efficiency.
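A minimal sketch of the two Megatron building blocks to start from (assumes an initialized torch.distributed process group; bias terms and the autograd-correct communication wrappers-differentiable all-reduce/identity-are omitted):

import torch
import torch.distributed as dist

class ColumnParallelLinear(torch.nn.Module):
    """Shard W by the output dim: each rank computes its slice of the output."""
    def __init__(self, d_in, d_out):
        super().__init__()
        shard = d_out // dist.get_world_size()
        self.w = torch.nn.Parameter(torch.empty(d_in, shard).normal_(std=d_in ** -0.5))
    def forward(self, x):
        return x @ self.w                  # local output slice; no communication here

class RowParallelLinear(torch.nn.Module):
    """Shard W by the input dim: each rank consumes its slice of the input."""
    def __init__(self, d_in, d_out):
        super().__init__()
        shard = d_in // dist.get_world_size()
        self.w = torch.nn.Parameter(torch.empty(shard, d_out).normal_(std=d_in ** -0.5))
    def forward(self, x_shard):
        y_partial = x_shard @ self.w       # partial result from this rank's slice
        dist.all_reduce(y_partial)         # sum partials across the TP group
        return y_partial

Column-parallel QKV followed by row-parallel output projection yields one forward all-reduce per attention block-the Megatron pattern the lab asks you to reproduce.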

15.4 Idiomatic & Diagnostic Drill

  • Use nsys profile to see the timeline of allreduce vs compute. Tensor parallelism's signature is short, frequent allreduces.

15.5 Production Slice

  • Capture the topology decision: for your hardware (e.g., 8x H100 per node, 16 nodes), what TP/PP/DP degrees do you choose for a 70B model? Document the math.

Week 16 - Mixed Precision, FP8, Numerical Stability at Scale

16.1 Conceptual Core

  • Modern training uses multiple precisions simultaneously:
  • Compute (matmul) in BF16 / FP8.
  • Master weights in FP32.
  • Optimizer states in FP32 (Adam: momentum + variance).
  • Gradients accumulated in FP32 to avoid loss-scale issues.
  • FP8 training (Hopper / Blackwell): two formats-E4M3 (more mantissa, less range, used for activations/weights) and E5M2 (more range, used for gradients). NVIDIA TransformerEngine library handles the casting.
  • The challenges of low-precision training:
  • Loss scaling (FP16): scale the loss by a power of 2 before backward to prevent gradient underflow; unscale before the optimizer step. Done automatically by GradScaler (see the sketch after this list).
  • Per-tensor scaling (FP8): each tensor needs its own scale factor (the max abs value); recomputed every step. This is delicate; TransformerEngine handles it.
  • Numerical stability: occasional NaNs from low-precision overflow. Detect and handle (skip step or reduce LR).
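A minimal FP16 loss-scaling loop with torch.cuda.amp (model, data, and hyperparameters are synthetic stand-ins; the clipping step shows why unscale_ exists):

import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(1024, 1024).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = GradScaler()                              # dynamic loss scale

for step in range(100):
    x = torch.randn(32, 1024, device="cuda")
    opt.zero_grad(set_to_none=True)
    with autocast(dtype=torch.float16):            # FP16 compute region
        loss = model(x).float().pow(2).mean()
    scaler.scale(loss).backward()                  # scaled loss -> scaled grads
    scaler.unscale_(opt)                           # true grads, e.g. for clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(opt)                               # step is skipped on inf/NaN
    scaler.update()                                # grow/shrink the scale factor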

16.2 Mechanical Detail

  • NVIDIA TransformerEngine (transformer-engine Python package): drop-in replacement layers (te.Linear, te.LayerNorm, te.TransformerLayer) that use FP8 internally with auto-scaling.
  • Activation memory dominates training memory at long contexts. With BF16 activations, a 32K-context sequence at 8B params can need ~80 GB activations alone. FP8 halves this; FlashAttention reduces it further.
  • Gradient accumulation steps × per-step batch = effective batch size. Memory is per-step batch; convergence is governed by effective batch.
  • Communication / compute overlap: with FSDP, you want allgather of layer N+1 to overlap with compute of layer N. With TP, you want the backward allreduce to overlap with the next layer's forward of the next microbatch. Profile to verify; absent overlap, you're leaving 20-50% of step time on the table.

16.3 Lab-"FP8 Train a Small Model"

On at least one H100/H200/B200 (you may need to rent for a day): 1. Take your week 14 FSDP setup. Replace all linear layers with te.Linear. Wrap blocks with te.fp8_autocast. 2. Train the same model in BF16 vs FP8. Compare: - Throughput. - Memory. - Loss curve (the test of stability-FP8 should match BF16 within noise). 3. Document any NaN events and recovery actions.

If H100+ is unavailable, do this lab in BF16 + torch.cuda.amp, comparing against FP32. The instability dynamics are similar at lower stakes.

16.4 Idiomatic & Diagnostic Drill

  • Track per-tensor scale factors across training steps. Sudden scale-factor drops indicate impending NaN; sudden rises indicate underflow risk on the next step.

16.5 Production Slice

  • Build a "training health" dashboard: loss, grad norm, parameter norm, scale factors, throughput, memory. Alert on grad-norm spikes (data quality issue or instability) and on plateaued loss without grad-norm decrease (saturated learning).

Month 4 Capstone Deliverable

A distributed-training/ directory: 1. allreduce-bench/ (week 13)-bandwidth measurements + topology doc. 2. fsdp-scaling/ (week 14)-scaling efficiency study. 3. tp-attention/ (week 15)-hand-rolled TP attention + benchmark. 4. fp8-train/ (week 16)-FP8 vs BF16 comparison.

A PARALLELISM_GUIDE.md decision matrix: given (model size, GPU count, NVLink topology, IB bandwidth), what 3D-parallelism degrees do you pick?


  • The ZeRO paper (SC 2020).
  • The Megatron-LM paper.
  • GPipe (NeurIPS 2019) and PipeDream (SOSP 2019) for pipeline parallelism.
  • The NCCL design doc.
  • Narayanan et al., Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM (SC 2021)-the canonical 3D-parallelism paper.
  • Micikevicius et al., FP8 Formats for Deep Learning (2022).

Month 5-Inference Systems: KV-Cache, Paged Attention, Continuous Batching, Quantization, Speculative Decoding

Goal: by the end of week 20 you can (a) explain why LLM inference is bandwidth-bound and how the KV-cache changes everything, (b) implement a paged KV-cache from scratch, (c) reason about quantization (INT8/INT4/FP8) tradeoffs and pick a scheme for a workload, and (d) implement a basic speculative-decoding loop.

This month is the commercial heart of AI systems engineering. Frontier-lab inference economics are dominated by these techniques.

Deep-dive companions (read in tandem): - Week 17 → DEEP_DIVES/07_ATTENTION_TRANSFORMER.md - attention math from first principles, RoPE complex-number derivation, KV-cache memory math with worked Llama-3-70B example, full FlashAttention online-softmax derivation with inductive proof. - Week 18 → DEEP_DIVES/08_INFERENCE_SERVING.md - cost-model derivation, PagedAttention block-pool algorithm, Orca scheduler pseudocode, vLLM architecture, chunked prefill, prefix caching, disaggregation. - Week 19 → DEEP_DIVES/09_QUANTIZATION.md - number-format derivations, AWQ identity proof with numerical example, GPTQ derived from Optimal Brain Surgeon with Cholesky efficiency, SmoothQuant α derivation, FP8 with delayed scaling, Marlin kernel. - Week 20 → DEEP_DIVES/10_SPECULATIVE_DISAGGREGATION.md - speculative decoding rejection-sampling proof, speedup formula, geometric acceptance model, tree speculation, DistServe/Mooncake/Splitwise architectures, full production-stack composition.


Weeks

Week 17 - LLM Inference, the KV-Cache, Attention Math

17.1 Conceptual Core

  • LLM inference has two distinct phases:
  • Prefill: process all input tokens in parallel. Compute-bound (a big matmul). Latency dominated by sequence length × hidden dim.
  • Decode: generate tokens one at a time, each requiring a forward pass attending to all previous tokens. Severely memory-bound. Latency dominated by reading model weights + KV-cache from HBM.
  • The KV-cache: at decode step t, the model needs all previous keys and values to attend to. Recomputing them at every step makes a generation O(t²) overall; caching makes it O(t). The cache is large: for a 70B model with 8K context, ~10 GB per request.
  • Why decode is memory-bound: each generated token reads ~70 GB of weights (a 70B model at 8-bit precision) and ~10 GB of KV-cache → ~80 GB from HBM. At ~3 TB/s HBM, that's ~25 ms minimum, regardless of compute. This is why decode rarely uses tensor cores efficiently.

17.2 Mechanical Detail

  • Standard attention: O = softmax(QK^T / √d_k) V. For a single decode step on the t-th token, Q is (1, d), K and V are (t, d).
  • The KV-cache layout matters:
  • (num_layers, 2, batch, num_heads, max_seq_len, head_dim) - naive contiguous. Wasteful: pre-allocates max_seq_len for every request, even short ones.
  • Paged (vLLM): (num_layers, 2, num_blocks, num_heads, block_size, head_dim) plus per-request page tables. Fragmentation gone; sharing across requests possible (prefix caching).
  • FlashAttention (Dao et al., 2022 → v3 by 2024): tiles attention so the full t×t score matrix never materializes in HBM. Streaming: process Q in chunks with an online softmax. Reduces HBM access from O(t²) to O(t·d). Critical for long contexts.
  • Multi-Query Attention (MQA) and Grouped-Query Attention (GQA): share K/V heads across query heads. MQA: a single KV head per layer. GQA: groups-e.g., 8 KV heads shared across 64 query heads, an 8× cache shrinkage in the Llama-2-70B configuration. Modern open models (Llama-3, Qwen) all use GQA.

17.3 Lab-"Decode From Scratch"

  1. Implement greedy decoding for a small Hugging Face model (Llama-3-8B works on a single A100; smaller for L4): prefill once and capture the KV-cache, then loop forward(token, kv_cache) → next_token, appending each new token's K/V to the cache (a minimal sketch follows this list).
  2. Measure tokens/sec. Compute the achieved HBM BW (model weights × tokens / time).
  3. Replace standard attention with flash_attn_with_kvcache. Re-measure.
  4. Document the decode-vs-prefill latency split for a 1K-prefill, 512-decode request.
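A minimal version of step 1, using gpt2 as a stand-in checkpoint (the Hugging Face past_key_values return value is the cache):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").cuda().eval()

ids = tok("The KV-cache exists because", return_tensors="pt").input_ids.cuda()
with torch.no_grad():
    out = model(ids, use_cache=True)                 # prefill: whole prompt at once
    past = out.past_key_values
    next_id = out.logits[:, -1].argmax(-1, keepdim=True)
    new_ids = [next_id]
    for _ in range(64):                              # decode: one token per step
        out = model(next_id, past_key_values=past, use_cache=True)
        past = out.past_key_values                   # grows by one K/V entry per layer
        next_id = out.logits[:, -1].argmax(-1, keepdim=True)
        new_ids.append(next_id)

print(tok.decode(torch.cat(new_ids, dim=-1)[0]))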

17.4 Idiomatic & Diagnostic Drill

  • nsys profile a decode step. The expected pattern: long matmuls (model weights load) interspersed with small attention kernels. Tensor cores ~10-20% utilized.

17.5 Production Slice

  • Inference cost is dominated by decode tokens × hardware-hour. Build a one-page estimator: given (model size, GPU type, batch size, concurrent requests), what is your tokens-per-dollar?
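A back-of-envelope sketch of such an estimator for the decode-bound regime; every constant below is an assumption to replace with your own measurements:

# all numbers are illustrative assumptions, not measurements
params_gb  = 70      # 70B model at 8-bit weights ~ 70 GB read per step
kv_gb      = 10      # KV-cache bytes read per decode step (workload-dependent)
hbm_tbs    = 3.35    # H100 SXM HBM3 peak bandwidth, TB/s
gpu_hr_usd = 4.00    # hypothetical rental price per GPU-hour
batch      = 32      # concurrent sequences amortizing each weight read

step_s = (params_gb + kv_gb) / (hbm_tbs * 1000)   # seconds per decode step
tok_per_s = batch / step_s
usd_per_mtok = gpu_hr_usd / (tok_per_s * 3600) * 1e6
print(f"{tok_per_s:,.0f} tok/s -> ${usd_per_mtok:.2f} per million output tokens")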

Week 18 - Paged Attention, Continuous Batching, vLLM

18.1 Conceptual Core

  • The two ideas that made open-source LLM serving competitive with closed APIs:
  • Paged attention (Kwon et al., SOSP 2023): manage the KV-cache like virtual memory. Fixed-size blocks (e.g., 16 tokens each) allocated from a pool; per-request page tables map logical token positions to physical blocks. Eliminates fragmentation; enables prefix sharing.
  • Continuous batching (Yu et al., OSDI 2022 / Orca): instead of batching requests at start, dynamically schedule per-step. Finished requests leave the batch immediately; new requests join. Decode batches stay full.
  • vLLM combines both. Result: ~5-20× higher throughput vs naive HuggingFace generation.

18.2 Mechanical Detail

  • Paged attention kernel: takes Q (current step), K/V cache pool, and per-request block tables. Each query attends to its blocks, gathered via the block table. The kernel handles non-contiguous K/V-modest perf cost (~10%) for huge memory savings.
  • Continuous batching scheduler (Orca's "iteration-level scheduling"; a toy block-pool sketch follows this list):
  • Maintain a queue of pending requests.
  • Each iteration, pick a batch satisfying memory budget (sum of KV-cache sizes ≤ available pool).
  • Execute one decode step; emit any finished tokens; release any finished requests' blocks.
  • Loop.
  • Prefill / decode disaggregation (an active research area; production-deployed in 2024-2026 by major labs):
  • Prefill is compute-bound; benefits from large batches and TP.
  • Decode is memory-bound; benefits from many concurrent requests sharing weights.
  • Run them on different hardware tiers. The KV-cache handoff protocol between prefill and decode workers is non-trivial.
  • Prefix caching / speculative prefix caching: identical prompt prefixes share blocks. Critical for common system-prompt patterns and multi-turn dialogs.
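A toy sketch of the block-pool bookkeeping the scheduler relies on (real vLLM adds reference counting for prefix sharing and copy-on-write; names are illustrative):

class BlockPool:
    """Paged KV-cache bookkeeping: a free-block stack plus per-request page tables."""
    def __init__(self, num_blocks: int, block_size: int = 16):
        self.block_size = block_size
        self.free = list(range(num_blocks))         # physical block ids
        self.tables = {}                            # request_id -> [block ids]

    def admit(self, req_id, prompt_len: int) -> bool:
        need = -(-prompt_len // self.block_size)    # ceil division
        if need > len(self.free):
            return False                            # scheduler must queue or preempt
        self.tables[req_id] = [self.free.pop() for _ in range(need)]
        return True

    def append_token(self, req_id, seq_len: int) -> None:
        # seq_len = length after appending; allocate only on spill into a new block
        if seq_len > len(self.tables[req_id]) * self.block_size:
            self.tables[req_id].append(self.free.pop())

    def release(self, req_id) -> None:
        self.free.extend(self.tables.pop(req_id))   # blocks return to the pool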

18.3 Lab-"vLLM Internals"

  1. Install vLLM. Serve a 7B model. Run a load test (benchmark_serving.py) at various concurrency levels.
  2. Read vllm/core/scheduler.py and vllm/attention/backends/flash_attn.py end-to-end. Annotate the scheduler's iteration loop.
  3. Build a mini-scheduler in Python (not for prod; for understanding): manages a fixed pool of KV blocks, schedules decode steps, evicts on memory pressure. Use real model forward via vLLM's lower-level APIs or HuggingFace.
  4. Compare throughput of your mini-scheduler vs vLLM proper. The gap is likely 5-20×-that gap is your education.

18.4 Idiomatic & Diagnostic Drill

  • vLLM exposes Prometheus metrics. Capture: requests-running, requests-waiting, GPU-cache-usage, time-to-first-token (TTFT), time-per-output-token (TPOT). These are the SLOs for LLM serving.

18.5 Production Slice

  • Tune vLLM for your workload: gpu-memory-utilization, max-num-batched-tokens, enable-prefix-caching, swap-space. Each is a real lever. Document chosen values + rationale in a SERVING.md.

Week 19 - Quantization: INT8, INT4, FP8, AWQ, GPTQ, SmoothQuant

19.1 Conceptual Core

  • Quantization reduces precision of weights (and sometimes activations) to shrink memory and increase throughput.
  • The categories:
  • Weight-only quantization (W8A16, W4A16): weights INT8 or INT4; activations BF16. Cuts decode HBM traffic by 2-4×-huge for memory-bound decode.
  • Weight + activation quantization (W8A8, W4A8): both quantized. Tensor cores can be invoked at lower precision; throughput wins but stability harder.
  • Per-tensor / per-channel / per-group: granularity of the scale factor. Per-group (e.g., group size 128) is the modern standard-balances size and accuracy.
  • The methods:
  • Post-Training Quantization (PTQ): quantize a trained model in minutes-to-hours. Methods: AWQ, GPTQ, SmoothQuant. The 2026 standard for production.
  • Quantization-Aware Training (QAT): train with simulated quantization. More accurate, much more expensive. Used for highest-accuracy scenarios.
  • AWQ (Lin et al., MLSys 2024): observation that activation outlier channels matter more; scale them up before quantization to preserve accuracy. Standard for INT4 weight-only.
  • GPTQ (Frantar et al., ICLR 2023): optimal-brain-surgeon-style quantization; one layer at a time, calibration-data driven.
  • FP8 for inference (Hopper+): native hardware support. E4M3 for weights/activations, E5M2 for gradients (training only). Production-deployed at major labs.

19.2 Mechanical Detail

  • Storage format:
  • INT8: 1 byte per weight + 1 scale per group.
  • INT4 (packed): 2 weights per byte + scale. Needs a "dequant" kernel that unpacks (see the sketch after this list).
  • Compute:
  • W4A16: dequantize to BF16 just before matmul (the gemv in decode is memory-bound anyway, so dequant doesn't hurt). The matmul itself runs in BF16 on tensor cores.
  • W8A8: matmul runs in INT8 tensor cores (mma.s8). Higher throughput, requires careful scale handling.
  • Library landscape: bitsandbytes (W8/W4 with LoRA), AutoAWQ, AutoGPTQ, llama.cpp's GGUF formats (k-quants), TensorRT-LLM (production NVIDIA path), Marlin (W4A16 fast kernels).
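A minimal per-group symmetric INT4 quantize/dequantize sketch (codes stored unpacked in int8; the 2-per-byte packing and the fused dequant kernel are omitted):

import torch

def quantize_w4(w: torch.Tensor, group: int = 128):
    # w: (out_features, in_features); symmetric per-group scales, codes in [-8, 7]
    out_f, in_f = w.shape
    wg = w.reshape(out_f, in_f // group, group)
    scale = wg.abs().amax(dim=-1, keepdim=True) / 7.0
    q = (wg / scale).round().clamp_(-8, 7).to(torch.int8)
    return q.reshape(out_f, in_f), scale.squeeze(-1)

def dequantize_w4(q: torch.Tensor, scale: torch.Tensor, group: int = 128):
    out_f, in_f = q.shape
    qg = q.reshape(out_f, in_f // group, group).float()
    return (qg * scale.unsqueeze(-1)).reshape(out_f, in_f)

w = torch.randn(256, 1024)
q, s = quantize_w4(w)
print((w - dequantize_w4(q, s)).abs().mean())    # mean absolute quantization error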

19.3 Lab-"Quantize and Compare"

On a 7B-13B model: 1. Run baseline BF16 inference. Capture TTFT, TPOT, model size, throughput. 2. Quantize with AWQ (W4A16). Re-measure. Eval on a small held-out set (e.g., MMLU 200-question subset, or perplexity on Wikitext) for accuracy. 3. Quantize with FP8 (if on Hopper+). Re-measure. 4. Optionally: GPTQ comparison, AWQ INT8 comparison. 5. Build a tradeoff matrix: throughput, memory, perplexity / accuracy.

19.4 Idiomatic & Diagnostic Drill

  • Quantization must handle calibration: gather activation statistics on representative inputs. Document your calibration set and how you chose it (random docs ≠ production traffic).

19.5 Production Slice

  • Quantization in production must be reproducible. The quantized weights are a new artifact that must be versioned, signed, and SBOM'd just like training artifacts. Treat as such.

Week 20 - Speculative Decoding, Disaggregation, Inference Frontiers

20.1 Conceptual Core

  • Speculative decoding (Leviathan et al., ICML 2023): use a small "draft" model to generate K candidate tokens; verify them with the large "target" model in one parallel pass; accept the longest prefix the target agrees with. Gains ~2-3× tokens/sec when the draft model agrees often.
  • Why it works: target model verification is one prefill of K tokens, not K decode steps. Prefill is compute-bound; can use tensor cores efficiently. Decode steps are memory-bound; saving them is gold.
  • Variants:
  • Vanilla speculative: separate draft model.
  • Self-speculative / Medusa: multiple decoding heads on the same model.
  • EAGLE / EAGLE-2: train auxiliary heads to predict multiple tokens.
  • Lookahead decoding: no draft model; uses n-gram patterns from the model's own generation.
  • Prefill/decode disaggregation: covered in week 18. Scale prefill and decode workers out separately, with KV-cache transfer between them via fast networking (RDMA).

20.2 Mechanical Detail

  • Speculative loop:
    while not done:
        drafts = draft_model.generate(K)              # K candidate tokens
        target_logits = target_model.forward(drafts)  # parallel verify
        accepted = longest_prefix_where(target_logits agree with drafts)
        emit accepted tokens
    
  • Acceptance rate depends on draft-target agreement. Llama-3-8B drafting Llama-3-70B: ~70-80% agreement on typical prompts. To first order, acceptance rate × draft length ≈ expected gain; the exact expectation is geometric (see the sketch after this list).
  • Disaggregation requires KV-cache transfer. Approaches: RDMA (Mooncake), shared-memory pools (ZSpread), distributed object stores. State-of-the-art papers from 2024-2025: Mooncake, DistServe, Splitwise.
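A quick arithmetic sketch of the expected gain under the i.i.d.-acceptance model from the speculative-decoding paper (a constant per-token acceptance probability a is the assumption):

def expected_tokens(a: float, K: int) -> float:
    # expected tokens emitted per target verify pass: (1 - a^(K+1)) / (1 - a)
    return (1 - a ** (K + 1)) / (1 - a)

for a in (0.6, 0.7, 0.8):
    print(a, [round(expected_tokens(a, K), 2) for K in (1, 2, 4, 8)])
# at a=0.7: K=4 yields ~2.8 tokens/verify, K=8 only ~3.2 -- gains flatten fast,
# and the real sweet spot must also subtract the draft model's own latency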

20.3 Lab-"Speculative Decoding"

  1. Pair a small model (1B) drafting a larger model (7-13B).
  2. Implement vanilla speculative decoding: draft-then-verify.
  3. Measure: acceptance rate, tokens/sec gain, vs baseline single-model decoding.
  4. Tune K (draft length); sweep; identify the sweet spot for your workload.

20.4 Idiomatic & Diagnostic Drill

  • Speculative decoding's wins disappear under high concurrency (the target model's batch is already big and fills tensor cores). Profile under varying concurrency; document the regime where speculation helps.

20.5 Production Slice

  • The current 2026 inference frontier is disaggregation + speculation + paged caching + prefix sharing + quantization, all simultaneously. Document a hypothetical architecture combining all five, with the quantitative contribution of each. This is the design exercise for any senior inference role.

Month 5 Capstone Deliverable

An inference-systems/ directory: 1. decode-from-scratch/ (week 17)-KV-cache + FlashAttention. 2. mini-vllm/ (week 18)-your mini paged scheduler vs real vLLM. 3. quantization-bench/ (week 19)-AWQ / FP8 / BF16 tradeoff matrix. 4. speculative/ (week 20)-speculative decoding harness with sweep.

A LLM_SERVING.md documenting: the scheduling model, the cost-per-token calculation, and the roadmap to incorporate disaggregation.


  • PagedAttention / vLLM (SOSP 2023). Read twice.
  • Orca (OSDI 2022).
  • FlashAttention v1 → v2 → v3 papers. v3 (Hopper FP8) is the 2024 paper.
  • AWQ (MLSys 2024) and GPTQ (ICLR 2023).
  • Speculative Decoding (Leviathan et al., ICML 2023).
  • Mooncake (FAST 2025), DistServe (OSDI 2024) for disaggregation state-of-the-art.

Month 6-ML Infrastructure, Safety & Eval Infra, Capstone Defense

Goal: by the end of week 24 you have integrated the prior five months into one coherent capstone artifact, you can operate ML workloads on Kubernetes idiomatically, you understand the eval/safety infrastructure that gates production model deploys, and you can defend every design decision in a senior-level interview.


Weeks

Week 21 - ML on Kubernetes: KServe, KubeRay, Volcano, GPU Operators

21.1 Conceptual Core

  • Kubernetes is the dominant control plane for ML workloads in production. Three dimensions:
  • Training orchestration-schedule large multi-GPU jobs, handle preemption, gang-schedule (all-or-nothing). Tools: Volcano, KubeRay, Kueue, JobSet.
  • Inference serving-model deployment, autoscaling, traffic routing, A/B. Tools: KServe, Seldon Core, vLLM-on-K8s, Ray Serve, Triton Inference Server.
  • GPU resource management-driver installation, device plugins, MIG partitioning, time-slicing. Tools: NVIDIA GPU Operator, AMD GPU Operator.

21.2 Mechanical Detail

  • NVIDIA GPU Operator: deploys driver, container toolkit, device plugin, DCGM metrics exporter, MIG manager, all as DaemonSets. Pod requests nvidia.com/gpu: 1 and the device plugin allocates.
  • MIG (Multi-Instance GPU): A100/H100 hardware partitioning. One A100 → up to 7 isolated GPU slices. Useful for many small workloads on big GPUs; not for training. Configured via the GPU Operator's MIG manager.
  • Volcano / Kueue: gang scheduling-a 64-GPU job won't start until 64 GPUs are simultaneously available. Naive K8s default scheduler will partial-schedule and deadlock.
  • KubeRay: operator for Ray clusters. Ray is the de facto standard for distributed Python compute (used heavily by AI labs for data preprocessing, RLHF rollouts, hyperparameter sweeps).
  • KServe + vLLM: the canonical inference stack. KServe InferenceService CRD wraps vLLM (or Triton, or TGI) with autoscaling, canary, transformer pre/post-processing.

21.3 Lab-"Train and Serve on K8s"

  1. Bring up a small GPU-enabled cluster (kind+nvidia, or a 2-node cloud cluster with 1-2 GPUs each).
  2. Install GPU Operator. Verify kubectl describe node shows nvidia.com/gpu: N.
  3. Install Volcano. Submit a 4-GPU gang-scheduled training job (a small FSDP run from week 14).
  4. Install KServe + vLLM runtime. Deploy a 7B model. Hit it with a load test. Demonstrate autoscaling.
  5. Document the YAML for each in a deployable repo.

21.4 Idiomatic & Diagnostic Drill

  • DCGM metrics (DCGM_FI_DEV_GPU_UTIL, DCGM_FI_DEV_FB_USED, DCGM_FI_PROF_PIPE_TENSOR_ACTIVE) exported to Prometheus. Read them.

21.5 Production Slice

  • A real production GPU fleet has cost, capacity, utilization, and reliability dashboards. Build a Grafana dashboard covering all four and keep it bookmarked.

Week 22 - Observability, Cost, Eval Pipelines, MLOps

22.1 Conceptual Core

  • ML observability has two layers above the system observability you already learned:
  • Model observability: prediction distribution drift, input feature drift, output quality (toxicity, refusals, hallucinations). Tools: Arize, Fiddler, WhyLabs, Langfuse.
  • Eval pipelines: continuous evaluation on benchmarks (MMLU, HumanEval, internal eval sets). Tools: lm-evaluation-harness, OpenAI Evals, Inspect AI, internal bespoke harnesses.
  • Cost observability: per-team / per-product / per-feature attribution. GPU-hours × $/hour, plus fixed-cost amortization. OpenCost (week 22 of K8s curriculum) plus model-aware tagging.

22.2 Mechanical Detail

  • Eval-as-CI: every model checkpoint runs through a fixed eval suite. Regressions block promotion. The "tests" of ML.
  • Tracing for LLM applications (vs traditional traces): a single user request fans out to multiple LLM calls, embeddings, retrievals, tool uses. OTel + Langfuse / LangSmith capture the call tree with prompts, responses, latencies, costs.
  • Drift detection:
  • Input drift: PSI (Population Stability Index), KS test on feature distributions (PSI sketch after this list).
  • Output drift: change in label/output distribution. For LLMs: monitor refusal rate, response length, toxicity scores.
  • Concept drift: relationship between input and label changes. Hardest to detect.
  • A/B and canary: traffic-split with measured metrics. KServe's canary support handles the routing; the metric aggregation is yours to build.
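A minimal PSI sketch over quantile bins (the 0.1/0.2 thresholds are the common rule of thumb, not a standard; the synthetic data is illustrative):

import numpy as np

def psi(expected: np.ndarray, actual: np.ndarray, bins: int = 10) -> float:
    # Population Stability Index over quantile bins of the reference sample
    cuts = np.quantile(expected, np.linspace(0, 1, bins + 1))
    cuts[0], cuts[-1] = -np.inf, np.inf
    p = np.histogram(expected, cuts)[0] / len(expected)
    q = np.histogram(actual, cuts)[0] / len(actual)
    p, q = np.clip(p, 1e-6, None), np.clip(q, 1e-6, None)   # avoid log(0)
    return float(np.sum((p - q) * np.log(p / q)))

baseline = np.random.normal(size=10_000)          # e.g., last week's prompt lengths
shifted = np.random.normal(0.3, 1.0, 10_000)      # drifted distribution
print(psi(baseline, shifted))                     # ~< 0.1 stable, > 0.2 major drift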

22.3 Lab-"Eval and Drift Pipeline"

  1. Build a CI pipeline: on every model push, run lm-evaluation-harness on a fixed subset (MMLU 500-question, HumanEval pass@1).
  2. Compare against a baseline; fail the pipeline on >2% regression.
  3. Wire production traffic samples into a drift dashboard: input length distribution, output length distribution, refusal rate, fraction of failed JSON-mode outputs.
  4. Synthetic drift: shift the input distribution (longer prompts) and verify the dashboard catches it.

22.4 Idiomatic & Diagnostic Drill

  • Cost/quality Pareto: every eval run captures both quality scores and inference cost. The dashboard is cost vs quality per model-the unit of decision-making for production model selection.

22.5 Production Slice

  • Document an "incident response for model regressions" runbook: detection → roll-back via traffic-split → investigate → fix → re-promote. The same shape as a software incident, with model-specific steps.

Week 23 - Safety, Red-Teaming, Alignment Infrastructure

23.1 Conceptual Core

  • Production AI systems require a safety layer above the model: input filters, output filters, refusal handling, abuse detection. The model is one component of a safety-enforcing system.
  • The dominant patterns:
  • Input classification: detect prompt-injection, jailbreak attempts, content-policy violations before invoking the model. Cheap classifier or small LLM.
  • Output classification: detect policy-violating output before returning. Same pattern.
  • Constrained decoding: structural constraints during generation (JSON schema, regex, grammar). Reduces "the model said something invalid" failure modes.
  • Refusal handling: tasteful refusals. The hardest part-over-refusing on benign prompts is nearly as damaging as under-refusing on harmful ones.
  • Red-teaming infrastructure: continuous adversarial probing of deployed models. Scaled red-teaming is itself an LLM workload.

23.2 Mechanical Detail

  • Constrained decoding tools: outlines (CFG-based), guidance, jsonformer, vLLM's guided_decoding (uses xgrammar / outlines under the hood). Performance: constraints add ~10-30% latency overhead; usually worth it.
  • Safety classifiers: Llama Guard, ShieldGemma, NVIDIA NeMo Guardrails, Anthropic's content moderation API patterns. Latency is the constraint; usually deploy as a small parallel model.
  • Audit logging for AI: every inference request logged with input, output, classifier decisions, model version, request ID. Required for:
  • Compliance (regulators increasingly want this).
  • Debugging.
  • Eval-from-production (resampling production traffic for offline eval).
  • Red-teaming harnesses: PyRIT (Microsoft), Garak, internal bespoke. Run nightly; failures are P1 issues.

23.3 Lab-"A Safety Layer"

Take your week 21 vLLM deployment. Add: 1. Input classifier (Llama Guard or a small custom classifier)-block obvious prompt injections. 2. Output classifier-block policy-violating outputs. 3. Constrained-decoding mode for any structured-output endpoint. 4. Audit logging to a separate, append-only store. 5. A nightly red-teaming job that fires 1000 adversarial prompts; measures failure rate; alerts on regression.

23.4 Idiomatic & Diagnostic Drill

  • The cost of safety: measure latency overhead and quality impact (does the safety layer cause false-positive refusals on benign prompts?). Track both.

23.5 Production Slice

  • Safety infrastructure is itself a software system. It needs versioning, eval, regression testing, on-call. Treat the safety classifier with the same MLOps rigor as the main model.

Week 24 - Capstone Integration & Defense

24.1 Conceptual Core

The final week is integration, not new content. Bring your chosen capstone (see CAPSTONE_PROJECTS.md) to defensible quality.

24.2 Final Hardening Checklist

  • Reproducible training/inference runs: pinned PyTorch/CUDA/driver versions, seed everything, document determinism guarantees.
  • Benchmarks: throughput, latency, cost, scaling efficiency. All committed.
  • Profiles: nsys and ncu reports for at least one kernel hot path; flame graphs for at least one end-to-end run.
  • Observability: GPU util, memory, communication overhead, request-level metrics-all in Prometheus or equivalent.
  • Cost: a documented cost-per-token (inference) or cost-per-step (training).
  • Safety (if inference): input/output classification, constrained decoding, audit logging.
  • Eval: regression suite that gates merges; baseline + thresholds documented.
  • Repro environment: a Dockerfile + a make demo target that brings up the artifact end-to-end on a fresh machine.
  • Defensible decisions: ADRs (≥3) for the non-obvious choices.
  • Threat model: input attacks, infrastructure attacks, supply-chain risk.

24.3 Lab-"Defend the Design"

Schedule a 60-minute mock review (peer or recorded). Walk through: 1. The architecture diagram. 2. The roofline analysis: where does your system sit on the roofline? What's bound by what? 3. One slide per non-obvious decision (e.g., "why FSDP-2 over DeepSpeed Stage-3", "why AWQ over GPTQ", "why your batching policy"). 4. A live demo of the end-to-end artifact. 5. A live demo of one production-quality concern: cost, observability, safety, or fault tolerance.

The deliverable is the defense, not the slides. If you cannot answer: - "What is your worst-case tail latency under 10× concurrent load?" - "What happens when a GPU fails mid-training?" - "What is your cost per million output tokens?" - "How would you scale this to 10× the model size?" ...you have not yet finished the curriculum.

24.4 Production Slice

  • Tag the capstone repo v1.0.0. Write a CHANGELOG. Write a README aimed at the next engineer who picks it up. Write a blog post (publish or shelve) explaining the most interesting technical decision. That blog post is the artifact recruiters and hiring managers actually read.

Month 6 Deliverable

The chosen capstone (per CAPSTONE_PROJECTS.md), running, defensible, observable, cost-attributed.

You are done. The next steps are no longer pedagogical; they are professional.


  • Designing Machine Learning Systems, Chip Huyen.
  • Building Machine Learning Powered Applications, Emmanuel Ameisen-for the production framing.
  • The KServe and KubeRay docs.
  • The OpenAI Evals and Anthropic published evals documentation.
  • Anthropic's Constitutional AI paper (durable framing for safety design).
  • The Llama Guard / ShieldGemma papers.

Appendix A-Hardening, Observability, and Fleet Operations

Cumulative reference for the production-readiness work distributed through the curriculum.


A.1 GPU Profiling Toolkit

| Tool | Use case |
|---|---|
| nvidia-smi | Quick health check: util, memory, temperature, ECC errors. |
| nvidia-smi dmon -s pucvmet | Streaming per-second telemetry. |
| dcgmi | NVIDIA Data Center GPU Manager-production fleet metrics. |
| nsys profile | System-level timeline: CUDA kernels, NCCL calls, CPU, OS, NVTX ranges. |
| ncu --set full | Per-kernel deep dive: SM occupancy, memory throughput, tensor-core util. |
| nvprof | Legacy. Avoid; use nsys + ncu instead. |
| torch.profiler | Framework-aware: per-op timing, memory allocations, stack traces, CUDA + CPU. |
| torch.cuda.memory._record_memory_history | Memory-allocation timeline. Indispensable for OOM debugging. |
| py-spy / austin | Sampling profilers for the Python side. |

A complete perf debugging session uses three tools: torch.profiler for the framework view, nsys for the timeline, ncu for the kernel deep-dive.


A.2 GPU Fleet Observability

Required Prometheus metrics for any production GPU fleet:

DCGM_FI_DEV_GPU_UTIL                  # 0-100, time GPU was busy
DCGM_FI_DEV_FB_USED                   # framebuffer (HBM) bytes used
DCGM_FI_DEV_FB_FREE
DCGM_FI_DEV_GPU_TEMP                  # alert > 85°C sustained
DCGM_FI_DEV_POWER_USAGE               # watts
DCGM_FI_DEV_PCIE_LINK_GEN_CURRENT     # surprise downgrades happen
DCGM_FI_PROF_PIPE_TENSOR_ACTIVE       # tensor-core active fraction
DCGM_FI_PROF_DRAM_ACTIVE              # HBM active fraction
DCGM_FI_PROF_NVLINK_RX_BYTES          # interconnect traffic
DCGM_FI_PROF_PCIE_RX_BYTES
DCGM_FI_DEV_XID_ERRORS                # any non-zero is bad
DCGM_FI_DEV_ECC_DBE_VOL_TOTAL         # double-bit ECC = retire the GPU

Plus framework-level: - pytorch_distributed_* for NCCL collectives (use NCCL's NVTX ranges). - vLLM / serving metrics: TTFT, TPOT, requests-running, requests-waiting, kv-cache-usage-perc, num-preemptions.

The four golden signals for an inference fleet: 1. Throughput (tokens/sec/GPU). 2. Tail latency (p99 TTFT, p99 TPOT). 3. Cost (inferred from utilization × $/hour). 4. Quality (eval scores from canary traffic).


A.3 Training Job Hardening

For multi-day training runs:

  • Checkpointing: every N minutes (not just every N steps); resumable from any checkpoint with bit-exact loss continuation.
  • Failure handling: NCCL timeouts retry with exponential backoff before failing the job. NaN detection skips (or fails) the step. Persistent NaNs trigger checkpoint reload.
  • Preemption: cloud spot instances may preempt with seconds of notice. SIGTERM handler triggers checkpoint + clean exit.
  • Health probes: fail-fast on ECC errors, NVLink degradation, throttling. Tools: NVIDIA's GPU diagnostic (dcgmi diag).
  • Run state: persist the run's metadata (commit, config, hardware fingerprint) alongside checkpoints.
  • Cost control: hard timeouts, budget caps, on-call alerting on overruns.
  • Reproducibility: pinned versions, seeded RNGs, documented determinism guarantees.

A.4 Inference Fleet Hardening

  • Request validation: prompt length caps, content-type checks, rate limiting per user/tenant. Reject obvious abuse before invoking the model.
  • Timeouts at every layer: client→gateway, gateway→model, model→GPU. Cascading timeout discipline: each layer's timeout < parent's timeout − queue time budget.
  • Backpressure: when GPU memory or queue depth saturates, return 429 (rate-limited) immediately rather than queueing forever. The model cannot catch up; let the client retry with backoff.
  • Graceful degradation: when the primary model fails, fall back to a smaller, cheaper model (with a quality marker in the response). If the fallback is down too, return 503.
  • Multi-region: GPU fleets often consolidate in one region. Plan for regional outages; document the RTO/RPO for inference.
  • Model-version pinning: production traffic targets a specific model digest, not a tag. Promote via canary (small % traffic) → roll forward or rollback.

A.5 Cost Control

Inference cost ≈ tokens × time-per-token × $/GPU-hour / batch-size.

Levers: - Batch size: bigger batch → lower $/token, higher latency. Workload-dependent sweet spot. - Model size: each step up the 7B → 13B → 70B ladder costs roughly 2× then 5× compute per inference. Quality may not justify it; eval rigorously. - Quantization: INT4 W4A16 cuts decode HBM traffic ~4×. Roughly proportional throughput gain. - GPU class: H100 vs A100-H100 is ~2× faster for ~1.5× cost. Often cheaper per token. - Region/spot: spot inference is risky; spot training is normal. Spot decisions are cluster-wide. - Caching: prefix caching (system prompts) and full-response caching (FAQ-style) save dramatically. Build hit-rate dashboards.

Budget: $100/month per developer for hands-on learning; $1000-5000 for the capstone month, depending on track.


A.6 Model Observability (Beyond System Metrics)

  • Output distribution drift: monitor response length, refusal rate, fraction of failed JSON-mode outputs, sentiment, toxicity scores.
  • Input distribution drift: prompt length, language detection, topic distribution.
  • Eval-from-production: weekly resample N production requests, send to an eval harness (LLM-as-judge or human eval), surface quality trends.
  • Per-feature observability: tag every request with the calling product/feature; surface metrics per-feature so a single product's regression is detectable.
  • A/B telemetry: when running a canary (5% traffic to model B, 95% to model A), every metric needs to support per-variant breakdowns.

A.7 The ai-systems-baseline/ Template

ai-systems-baseline/
  Dockerfile.cuda                     # pinned CUDA + cuDNN + Python + PyTorch
  pyproject.toml / requirements.txt   # pinned deps, including torch
  scripts/
    bench-throughput.sh
    bench-latency.sh
    profile-nsys.sh
    profile-ncu.sh
    eval-regression.py
  configs/
    nccl.env                          # NCCL_*  tuning
    pytorch.env                       # TORCH_* tuning
    vllm.yaml                         # for serving track
  observability/
    dcgm-exporter-config.yaml
    grafana/
      gpu-fleet.json
      training-health.json
      inference-fleet.json
  ci/
    test.yml
    bench.yml                         # regression-tracked perf
    eval.yml                          # quality regression
    profile-on-pr.yml                 # generate ncu reports on PR
  runbooks/
    nan-during-training.md
    ecc-error.md
    nccl-timeout.md
    inference-latency-spike.md
    oom-on-load.md
  RELEASE_CHECKLIST.md
  THREAT_MODEL.md
  COST_MODEL.md

Every AI workload you ship after week 24 should be built from this template.

Appendix B-Build-From-Scratch Reference

A working AI systems engineer should have implemented each of the following at least once. These are the building blocks from which every modern transformer is assembled.


B.1 Tokenizer (BPE)

When: foundational; everything downstream assumes one.

Design: - Train a byte-pair-encoding (BPE) tokenizer on a small corpus (TinyShakespeare or a Wikipedia subset). - Implement train, encode, decode. Save/load merges. - Test round-trip on held-out text.

Reference: Karpathy's minbpe repo. ~300 lines of Python, well-commented.

Lab outcome: internalize the input layer of every modern LLM. Understand why context-length and tokenization are coupled.


B.2 Multi-Head Attention

When: the heart of the transformer.

Design: - Standard Q, K, V = x @ W_qkv; O = softmax(Q K^T / √d) V; out = O @ W_out. - With causal mask for autoregressive use. - With GQA (group K/V heads). - With KV-cache for decode.

Reference: nanoGPT's model.py. PyTorch reference; ~50 lines.

Lab outcome: understand the operation that defines this era. Read every variant (sliding window, ALiBi, RoPE) against this baseline.


B.3 RMSNorm

When: replaces LayerNorm in modern LLMs (Llama, Qwen, etc.).

Design:

def rmsnorm(x, weight, eps=1e-6):
    # reciprocal RMS over the last dim; eps inside the sqrt for stability
    inv_rms = (x.pow(2).mean(-1, keepdim=True) + eps).rsqrt()
    return x * inv_rms * weight
Triton-fused version: Month 3, week 12. Be able to write both.


B.4 Rotary Position Embedding (RoPE)

When: every modern open LLM since 2022.

Design: - Precompute cos/sin tables for all positions × half the head dim. - Apply to Q and K after the projection, before attention. - Context-extension variants: linear (position) interpolation, NTK-aware scaling, YaRN.

Reference: Su et al., RoFormer: Enhanced Transformer with Rotary Position Embedding (2021).
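A minimal sketch of the interleaved-pair formulation (some implementations, e.g. Llama's, rotate half-blocks instead; the pairing convention must match the checkpoint):

import torch

def rope_tables(max_pos: int, head_dim: int, base: float = 10000.0):
    # theta_i = base^(-2i/d) for each dimension pair
    inv_freq = base ** (-torch.arange(0, head_dim, 2).float() / head_dim)
    angles = torch.outer(torch.arange(max_pos).float(), inv_freq)  # (pos, d/2)
    return angles.cos(), angles.sin()

def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    # x: (batch, heads, seq, head_dim); rotate each (even, odd) dimension pair
    x1, x2 = x[..., 0::2], x[..., 1::2]
    seq = x.shape[-2]
    c, s = cos[:seq], sin[:seq]                 # broadcast over batch and heads
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * c - x2 * s
    out[..., 1::2] = x1 * s + x2 * c
    return out

cos, sin = rope_tables(4096, 128)
q = torch.randn(1, 8, 512, 128)
print(apply_rope(q, cos, sin).shape)            # applied to Q (and K) pre-attention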


B.5 Mixture-of-Experts (MoE) Layer

When: increasingly used (Mixtral, DeepSeek-V3, GPT-MoE).

Design: - Top-k gating: gate_logits = x @ W_gate; topk_weights, topk_idx = topk(gate_logits). - Dispatch each token to its top-k experts (typically k=2 of 8 experts). - Combine outputs weighted by gate. - Load balancing loss to prevent expert collapse.
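A dense-loop sketch of the top-k gating and dispatch just described (real implementations use batched scatter/gather and capacity limits; the loop form is for clarity, and the load-balancing loss is omitted):

import torch
import torch.nn.functional as F

def moe_forward(x, experts, w_gate, k=2):
    # x: (tokens, d); experts: list of modules; w_gate: (d, num_experts)
    gate_logits = x @ w_gate
    topk_w, topk_idx = gate_logits.topk(k, dim=-1)     # per-token expert choice
    topk_w = F.softmax(topk_w, dim=-1)                 # renormalize over chosen k
    out = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        tok, slot = (topk_idx == e).nonzero(as_tuple=True)
        if tok.numel():                                # tokens routed to expert e
            out[tok] += topk_w[tok, slot, None] * expert(x[tok])
    return out

d, n_exp = 64, 8
experts = [torch.nn.Linear(d, d) for _ in range(n_exp)]
w_gate = torch.randn(d, n_exp)
print(moe_forward(torch.randn(32, d), experts, w_gate).shape)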

Reference: Shazeer et al., Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer (2017). Fedus et al., Switch Transformers (2021).

Lab outcome: MoE is the primary architectural lever for parameter scaling without proportional FLOP scaling.


B.6 FlashAttention (Tiled Attention)

When: any context-length-bound workload.

Design (sketch): - Tile Q in blocks of B_q. - For each Q-tile, stream K and V in blocks of B_k. - Maintain running max, running sum, running output (online softmax). - Output block when done.
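A single-query (decode-style) sketch of the online-softmax accumulation at the core of this design; full FlashAttention also tiles Q and fuses everything into one kernel:

import torch

def attention_online(q, K, V, block=128):
    # q: (d,), K/V: (t, d); numerically stable streaming softmax over K-blocks
    scale = q.shape[-1] ** -0.5
    m = torch.tensor(float("-inf"))      # running max of scores
    l = torch.tensor(0.0)                # running sum of exp(scores)
    acc = torch.zeros_like(V[0])         # running weighted sum of V rows
    for i in range(0, K.shape[0], block):
        s = (K[i:i + block] @ q) * scale
        m_new = torch.maximum(m, s.max())
        alpha = torch.exp(m - m_new)     # rescale previous accumulators
        p = torch.exp(s - m_new)
        l = l * alpha + p.sum()
        acc = acc * alpha + p @ V[i:i + block]
        m = m_new
    return acc / l

K, V = torch.randn(1000, 64), torch.randn(1000, 64)
q = torch.randn(64)
ref = torch.softmax((K @ q) * 64 ** -0.5, 0) @ V
print(torch.allclose(attention_online(q, K, V), ref, atol=1e-5))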

Implementation: in Triton, ~150 lines. Beating FA-2 / FA-3 is hard; matching them on common shapes is achievable.

Reference: Dao et al., FlashAttention (2022) and FlashAttention-2 (2023).


B.7 Optimizer (AdamW)

When: foundational.

Design:

m = β1*m + (1-β1)*g
v = β2*v + (1-β2)*g²
m_hat = m / (1 - β1^t)
v_hat = v / (1 - β2^t)
p = p - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * p)
- Per-parameter state (m, v)-2x model size in FP32. - foreach and fused variants in modern PyTorch fuse the per-parameter loops.

Variants worth understanding: Lion, AdEMAMix, Sophia, Muon (the Distributed Shampoo lineage). Most aim to halve optimizer-state memory or improve convergence.


B.8 Data Loader

When: every training run.

Design: - Memory-mapped binary token files (one giant np.uint16 or np.uint32 array). - Random sampling: pick a position uniformly, slice context_length tokens. - Cross-worker / cross-rank determinism: identical seed → identical sample order. - Multi-process via torch.utils.data.DataLoader(num_workers=N).
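A minimal memmap sampler following the design above (file path and uint16 dtype are assumptions; nanoGPT uses the same pattern):

import numpy as np
import torch

def get_batch(path, batch_size, context_len, step, seed=0):
    # memory-mapped token file; deterministic sampling given (seed, step)
    data = np.memmap(path, dtype=np.uint16, mode="r")
    rng = np.random.default_rng((seed, step))          # same on every rank/worker
    ix = rng.integers(0, len(data) - context_len - 1, size=batch_size)
    x = torch.stack([torch.from_numpy(data[i:i + context_len].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + context_len].astype(np.int64)) for i in ix])
    return x, y                                        # inputs and next-token targets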

Reference: nanoGPT's data prep + loader. Production scale: Mosaic StreamingDataset, FFCV.


B.9 Paged KV-Cache (Mini)

When: serving track capstone.

Design: - Block pool: [num_blocks, num_heads, block_size, head_dim] for K and same for V. - Free-block stack. - Per-request page table: list of physical block indices. - On decode step, gather indexed blocks for attention. - On request completion, return blocks to free pool.

Reference: vLLM's core/block_manager.py. ~400 lines.

Lab outcome: this is the mini-vLLM capstone's foundation.


B.10 Continuous Batching Scheduler (Mini)

When: serving track.

Design: - Iteration-level scheduling loop:

while True:
    batch = pick_runnable_requests()    # respect KV-budget
    logits = model.forward(batch)        # one decode step
    tokens = sample(logits)
    for req, tok in zip(batch, tokens):
        req.append(tok)
        if req.is_done(): finalize(req); free_blocks(req)
- Admission policy when memory tight: preempt the longest-running request, or evict + restart.

Reference: Yu et al., Orca (OSDI 2022).


B.11 Speculative Decoding Loop

When: high-priority inference latency wins.

Design:

while not done:
    drafts = draft_model.generate_k(state, k=K)
    target_logits = target_model.parallel_verify(state, drafts)  # one forward of K tokens
    accepted = longest_match(target_logits, drafts)
    state.append(accepted)
    if len(accepted) < K:
        state.append(sample(target_logits[len(accepted)]))  # fallback
- Tune K per workload.

Reference: Leviathan et al., Speculative Decoding (ICML 2023).


B.12 NCCL-Collective From Scratch (Bonus)

When: training-systems track capstone.

Design: - Implement ring-allreduce yourself using point-to-point send/recv (NCCL or MPI). - Compare bandwidth to dist.all_reduce.

Lab outcome: internalize why ring-allreduce achieves bandwidth-optimal scaling.


Difficulty Ranking

| Tier | Builds |
|---|---|
| Warmup | Tokenizer (BPE), AdamW, RMSNorm |
| Intermediate | Multi-head attention with KV-cache, RoPE, Data loader, MoE |
| Advanced | FlashAttention in Triton, Paged KV-cache, Continuous batching scheduler |
| Expert | NCCL ring-allreduce, mini-vLLM end-to-end, FP8 training stable for >100k steps |

Pick at least one from each tier. Ship with benchmarks, profiling artifacts, and a writeup.

Appendix C-Contributing to the AI Systems Ecosystem

The AI systems ecosystem is largely on GitHub, with friendly maintainers, fast review cycles, and high impact per merged PR. The path from "user" to "contributor" is shorter here than almost anywhere else in software.


C.1 The Project Map (2026)

| Project | Entry bar | Scope | Notes |
|---|---|---|---|
| pytorch/pytorch | High | The framework | Big org; per-subsystem reviewers; specific contribution guides per area. |
| pytorch/ao | Medium | Quantization + sparsity | Fast-moving; welcoming. |
| huggingface/transformers | Low–Medium | Model implementations | Highest velocity; small fixes merged in days. |
| huggingface/accelerate | Medium | Training launcher | Welcoming; growing. |
| huggingface/text-generation-inference | Medium | Production inference (Rust+Python) | Welcoming. |
| vllm-project/vllm | Medium | Inference server | High velocity, friendly maintainers. The single most strategic project on this list. |
| Dao-AILab/flash-attention | High | The attention kernel | Tight ownership; deep expertise required. |
| openai/triton | Medium–High | The DSL | Compiler-shaped contributions; high learning curve. |
| NVIDIA/cutlass | High | GEMM templates | Deep CUDA + template metaprogramming. |
| NVIDIA/TransformerEngine | Medium | FP8 + transformer ops | Active; contributions welcome. |
| google/jax | Medium | The functional framework | Smaller team; high standards. |
| openxla/xla | High | The compiler | Compiler expertise required. |
| microsoft/DeepSpeed | Medium | Training stack | Active. |
| NVIDIA/Megatron-LM | Medium | Training stack | Reference impl for many parallelism patterns. |
| ray-project/ray | Medium | Distributed Python | Big org; many subteams. |
| pytorch/torchtitan | Medium | Reference training | New, growing, well-curated. |
| lm-evaluation-harness (EleutherAI) | Low | Evaluation | New benchmarks always welcome. |
| EleutherAI/gpt-neox | Medium | Training stack | Stable, smaller community. |

C.2 First-Issue On-Ramps

Easy

  • huggingface/transformers: docstring fixes, model-config consistency, new model contributions (the model-add guide is excellent).
  • vllm-project/vllm: bug reports with minimal repro; small kernel optimizations; new model architecture additions.
  • lm-evaluation-harness: add a new task; fix a metric.

Medium

  • pytorch/ao: a new quantization recipe, a kernel optimization, integration with a new model.
  • vllm-project/vllm: support a new attention backend or scheduling policy. Specific issues labeled good first issue are tractable.
  • openai/triton: fix a specific autotuning regression; add a new tutorial.
  • huggingface/accelerate: a new launcher integration, FSDP-2 support edge cases.

Hard

  • pytorch/pytorch core (especially aten/, dispatcher, autograd): high stakes; deep familiarity required; cycle time is weeks.
  • Dao-AILab/flash-attention: kernel-level changes; deep CUDA expertise.
  • NVIDIA/cutlass: template-heavy C++; deep architecture expertise.
  • openxla/xla: compiler internals.

C.3 The Workflow (typical)

  1. Find an issue: filter by good first issue / help wanted labels.
  2. Comment to claim, ideally with a one-paragraph plan.
  3. Discuss design: for non-trivial changes, the maintainers will steer you. Listen.
  4. Implement, write tests (every project has its own conventions; mimic existing tests).
  5. Open the PR: small, focused, with a clear description and reproduction case.
  6. Address review: usually 1-3 cycles. Merge.

Cycle time: days for HF Transformers, vLLM, and lm-eval; weeks for PyTorch core and FlashAttention.


C.4 The Highest-Leverage Contributions

In the AI systems ecosystem, the contributions that earn outsized recognition tend to be:

  1. A new fast kernel (Triton, CUTLASS, or hand-CUDA) for a common operation. Examples: rotary embedding, RMSNorm fused with subsequent matmul, GQA attention for a specific head shape. Liger Kernel, Unsloth are open-source examples.
  2. An integration: support a new model architecture in vLLM, a new optimizer in DeepSpeed, a new quantization scheme in pytorch/ao.
  3. A measured perf win: identify a slow path with nsys / ncu evidence, propose a fix, ship the patch with before/after numbers.
  4. A new benchmark or eval: meaningful evals are scarce and load-bearing for the field.
  5. A reproduction + study: take a published-but-not-quite-reproducible technique, ship a clean reference implementation. (Mistral / DeepSeek architecture studies have been notable examples.)

C.5 The Indirect Path: Open-Source Repos as Portfolio

If contributing-to-the-frameworks isn't yielding fast review, build your own open-source repo and ship it well. Specific patterns that have worked:

  • A clear, single-purpose tool: e.g., a small inference server with a particular optimization (paged + speculative + 4-bit), benchmarks, blog post.
  • An educational reference impl: nanoGPT-class-minimal, readable, MIT-licensed.
  • A reproduction: of a recent paper, with clean code and notes on what worked and didn't.

Each of the above can demonstrate AI-systems fluency to a hiring manager more efficiently than chasing a single merged PR in PyTorch.


C.6 Calibration

A reasonable goal for a curriculum graduate:

  • By end of week 23: a PR open against vLLM, transformers, accelerate, lm-eval, or pytorch/ao.
  • By end of capstone: that PR merged, or a public-facing capstone with measurable performance.
  • 6 months post-curriculum: a substantive contribution-a new kernel, a new model integration, a measured perf win, or an established open-source artifact.

The ecosystem moves fast. Patient, persistent contributors become trusted; trusted contributors become reviewers; reviewers become maintainers. The path is shorter here than in any adjacent ecosystem-and the ratio of "interesting work to do" to "qualified people doing it" is the highest in software in 2026.

Capstone Projects-Three Tracks, One Choice

Pick one. The work performed here is what you describe in interviews and link from a portfolio.


Track 1-Inference Engine: A Mini-vLLM

Outcome: an LLM inference server you wrote, with paged KV-cache, continuous batching, and at least one of (FP8 weights, AWQ INT4, speculative decoding). Benchmarked within 2× of production vLLM on a 7B model.

Functional spec

  • HTTP API: POST /v1/completions and POST /v1/chat/completions (subset of OpenAI's API).
  • Server-sent-events streaming output.
  • Continuous batching with paged KV-cache.
  • One quantization scheme (your choice: AWQ W4A16 with Marlin kernel, or FP8 weights via TransformerEngine).
  • Optionally: prefix caching, speculative decoding.
  • Health, readiness, metrics endpoints.

Non-functional spec

  • Throughput within 2× of production vLLM for a 7B model on the same hardware. (vLLM is the bar; matching it is implausible in 24 weeks. Within 2× is achievable and impressive.)
  • TTFT p99 < 1s for 1K-token prompts under steady-state load.
  • TPOT p99 < 30 ms after first token.
  • Memory: stable under 8-hour load; no leaks.

Architecture sketch

HTTP server (FastAPI/Axum)
  └─► Request queue ──► Scheduler (Python or Rust)
                          ├─► Block manager (page table, free list)
                          └─► Model runner (Python/Triton/CUDA)
                                  └─► Token streamer

Test rigor

  • Unit tests for the block manager (allocate/free/leak detection).
  • Integration: warmup load, sustained load, mixed-prompt-length load.
  • Correctness: outputs match a reference HF implementation for greedy decoding.
  • Stress: kill -9 the process under load; restart; verify recovery.

Hardening pass

  • pprof-style metrics; nsys profile of one full request lifecycle committed to the repo.
  • ncu profile of the attention kernel.
  • Cost/quality matrix.

Acceptance criteria

  • Public repo with build + run + benchmark scripts.
  • A README with: architecture diagram, benchmark table, profiling artifacts, "what's next" section.
  • A blog post explaining one non-obvious decision (e.g., your block size choice, your eviction policy).

Skills exercised

  • All months. Heaviest on Months 2 (kernels), 3 (framework integration), 5 (serving).

Track 2-Training Systems: FSDP From Scratch

Outcome: a working sharded data-parallel training implementation, written from scratch on top of torch.distributed primitives. Trains a small transformer on 4–8 GPUs across 1–2 nodes with documented scaling efficiency.

Functional spec

  • A MyFSDP wrapper that:
  • Shards parameters across ranks.
  • Allgathers parameters before forward; frees after.
  • Reduces gradients via reduce-scatter.
  • Supports activation checkpointing.
  • Mixed-precision (BF16 compute, FP32 master).
  • Gradient accumulation.
  • Resumable checkpoints.
  • A reference training script that uses it to train a ~500M parameter transformer on a tokenized corpus.

Non-functional spec

  • Scaling efficiency ≥85% on 8 GPUs (single node) vs single-GPU baseline.
  • Scaling efficiency ≥75% on 16 GPUs across 2 nodes.
  • Resumed run produces identical loss (within 1e-4) compared to a continuous run.
  • Throughput within 30% of PyTorch's native FSDP-2 on the same workload.

Test rigor

  • Numerical correctness: 4-rank MyFSDP matches single-rank reference for one full training step (allclose at 1e-3 in BF16).
  • Memory measurement: peak HBM matches model_size / num_ranks + activation_overhead.
  • Failure injection: kill one rank mid-epoch; observe NCCL timeout; document the recovery path.

Hardening pass

  • NCCL tuning (NCCL_IB_HCA, NCCL_SOCKET_IFNAME, NCCL_TOPO_FILE) documented.
  • nsys profile showing allgather/reduce-scatter overlap with compute.
  • Cost calculation (GPU-hours × $/hr × scaling-efficiency).

Acceptance criteria

  • Public repo with infra-as-code (Terraform/Ansible) for bringing up a 2-node cluster, plus the FSDP code, plus the training script.
  • A SCALING_REPORT.md with the efficiency numbers, the optimization journey (each tuning step's effect), and one comparison against native FSDP-2.

Skills exercised

  • All months. Heaviest on Months 3 (framework), 4 (distributed).

Track 3-GPU Kernel Track: A Competitive Fused Attention

Outcome: a fused attention kernel in Triton (and optionally CUTLASS), competitive with FlashAttention-2 for at least one common shape regime, complete with profiling, autograd, and a tested PyTorch integration.

Functional spec

  • A Triton kernel implementing causal flash-attention (forward + backward).
  • Configurable for: BF16 / FP16, head dim 64/128, GQA support.
  • Drop-in replacement for F.scaled_dot_product_attention for the supported shape range.
  • Numerically equivalent to the reference (allclose at 1e-3 BF16).

Non-functional spec

  • Within 1.5× of FlashAttention-2 forward+backward time at one chosen shape (e.g., B=4, H=32, S=4096, D=128).
  • Validated on at least two GPU classes (e.g., A100 + H100 if both accessible; A100 + RTX 4090 acceptable).
  • Compiles via torch.compile without graph breaks when used in a small transformer.

Test rigor

  • Correctness: random input testing against reference attention; gradient testing against torch.autograd.gradcheck (FP32 reference).
  • Perf: ncu reports for forward and backward; attached to the repo.
  • Edge cases: short sequences, variable-length (padding-aware), large batch.
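A minimal correctness-harness sketch; my_attention is a hypothetical stand-in for the kernel's Python binding, and the tolerances mirror the spec above:

import torch
import torch.nn.functional as F

def check_forward(B=2, H=4, S=512, D=64):
    q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
               for _ in range(3))
    ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out = my_attention(q, k, v, causal=True)          # hypothetical binding
    torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)

def check_backward(B=1, H=2, S=256, D=64):
    q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16,
                           requires_grad=True) for _ in range(3))
    grad_out = torch.ones(B, H, S, D, device="cuda", dtype=torch.bfloat16)
    my_attention(q, k, v, causal=True).backward(grad_out)
    grads = [t.grad.clone() for t in (q, k, v)]
    for t in (q, k, v):
        t.grad = None
    F.scaled_dot_product_attention(q, k, v, is_causal=True).backward(grad_out)
    for g, t in zip(grads, (q, k, v)):
        torch.testing.assert_close(g, t.grad, atol=2e-2, rtol=2e-2)  # BF16 grads are noisy

For the strict gradcheck pass, run it on the FP32 reference path: gradcheck needs high precision and will reject a BF16 kernel's quantization noise.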

Hardening pass

  • Autotune configs documented with rationale.
  • Performance-regression CI (lock benchmark numbers; alert on >5% regression).

Acceptance criteria

  • Public repo with the kernel, tests, benchmarks, and a clear README.
  • One submitted PR (even if not merged) to a real project: vLLM (as a backend), Liger Kernel, or pytorch/ao.
  • A blog post analyzing one design choice in the kernel-block sizes, software pipelining stages, register-pressure tradeoffs.

Skills exercised

  • All months, but most concentrated on Months 2 (GPU programming), 3 (framework integration), and the inference math from Month 5.

Cross-Track Requirements

Regardless of track:

  • ai-systems-baseline/ template (Appendix A) integrated.
  • ADRs: ≥3 for non-obvious decisions.
  • Threat model: at minimum, one page covering input attacks, supply-chain risk, and infrastructure failure modes.
  • Cost model: per the workload, what's the steady-state $/hour and $/output-unit?
  • Defense readiness: a 60-minute walkthrough you can deliver to a peer or hiring manager.

The track choice signals career direction:

  • Track 1 (Inference) → inference-engineer roles at frontier labs, latency-sensitive serving teams, model-serving startups.
  • Track 2 (Training) → training-infra roles at frontier labs, large-scale-training teams, framework engineering teams.
  • Track 3 (Kernels) → GPU performance engineering, compiler/runtime teams (NVIDIA, OpenAI, Meta), specialized inference accelerator teams.

Pick based on where you want the next interview loop, not on what looks easiest.

Deep Dive 01-NVIDIA GPU Architecture and Memory Hierarchy

A self-contained reference chapter. Reader prerequisites: Python, basic C, basic linear algebra. Everything else is built up here. Canonical chip: NVIDIA H100 (Hopper, SM_90a). Other generations are referenced with explicit version tags.


Table of Contents

  1. Why GPUs Exist: Throughput Machines vs Latency Machines
  2. Execution Models: SIMD, SIMT, MIMD
  3. Why Deep Learning Maps onto Throughput Hardware
  4. The Streaming Multiprocessor (SM)-Anatomy of the Compute Unit
  5. Warp Scheduling, Dual-Issue, and Scoreboarding
  6. Branch Divergence and Independent Thread Scheduling
  7. The Memory Hierarchy: Registers to NVMe (with H100 numbers)
  8. Tensor Cores: WMMA, mma.sync, Fragments, Precisions
  9. 2:4 Structured Sparsity and Async Tensor Cores
  10. Async Copy and the Tensor Memory Accelerator (TMA)
  11. Occupancy Theory-Derivation From First Principles
  12. NVLink, NVSwitch, and Multi-GPU Topologies
  13. Ada and Blackwell Deltas (with explicit uncertainty)
  14. AMD CDNA / MI300X Contrast
  15. Five Worked Practical Exercises

1. Why GPUs Exist: Throughput Machines vs Latency Machines

A CPU is a latency machine. Its design goal is: take one sequential thread of instructions and complete it as fast as possible. To do that, a modern CPU core spends most of its silicon area on machinery that has nothing to do with arithmetic:

  • Out-of-order execution (reorder buffer, register renaming, ~hundreds of in-flight instructions).
  • Aggressive branch prediction (TAGE-style predictors with multi-KB history tables).
  • Several MB of private and shared cache.
  • Speculative loads, memory disambiguation, prefetchers.

On a recent server CPU, perhaps 5–10% of the die is the ALUs that actually do arithmetic. The other 90% exists to hide latency for one thread.

A GPU is a throughput machine. Its design goal is: given many independent units of work, finish all of them as fast as possible. Per-task latency is irrelevant; what matters is operations completed per second across the whole chip. To do that, a GPU does the inverse: it spends silicon on ALUs and, instead of hiding latency with prediction and out-of-order, hides latency with parallelism. When one warp stalls on memory, the SM switches to another warp. As long as you have enough independent warps in flight, every cycle is doing useful arithmetic.

Concretely on H100:

              Latency machine (CPU core)        Throughput machine (H100 SM)
              -----------------------------     -----------------------------
ALUs/unit     ~few wide vector lanes            128 FP32 + 64 FP64 + 64 INT32
                                                 + 4 Tensor Cores per SM
Threads in    1 (SMT: 2)                        Up to 64 warps = 2048 threads
flight                                           per SM, x132 SMs = 270k threads
Branch pred   Massive                            Minimal
OoO window    ~512 micro-ops                    None (in-order issue per warp)
Cache/core    ~1–4 MB private                   256 KB register file +
                                                 228 KB combined L1/smem

The two designs are both correct-just for different problem shapes. CPU wins when you have one critical-path thread. GPU wins when you have a problem expressible as "do this same operation to a million data items, mostly independently."

Roofline picture. Any program is bounded either by compute (FLOPs/s) or by memory bandwidth (bytes/s). The crossover point is arithmetic intensity-FLOPs per byte loaded. H100 has roughly:

  • BF16 dense compute: ~989 TFLOPS
  • HBM3 bandwidth: ~3 TB/s = ~3 × 10^12 B/s

Crossover intensity = 989e12 / 3e12 ≈ 330 FLOPs/byte (BF16 case). Below that, you are memory-bound; above, compute-bound. A GEMM of large enough size sits well above. Element-wise activations sit far below. This single inequality predicts most of GPU programming pain.
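The same arithmetic in a few lines, using the approximate H100 figures above (a sanity-check sketch, not measured numbers):

PEAK_BF16 = 989e12      # ~989 TFLOPS dense BF16
HBM_BW    = 3e12        # ~3 TB/s HBM3

ridge = PEAK_BF16 / HBM_BW                # ~330 FLOPs/byte

def attainable(intensity):
    # Roofline: you get min(compute roof, bandwidth x intensity).
    return min(PEAK_BF16, HBM_BW * intensity)

print(ridge)                      # ~329.7
print(attainable(1.25) / 1e12)    # elementwise op: ~3.75 TFLOPS, memory-bound
print(attainable(1000.0) / 1e12)  # large GEMM: 989 TFLOPS, compute-bound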


2. Execution Models: SIMD, SIMT, MIMD

Three classical taxonomies for parallel hardware:

  • SIMD (Single Instruction, Multiple Data). One instruction fetch operates on a vector register of N lanes. Examples: x86 AVX-512 (16 FP32 lanes), ARM SVE. The programmer writes vector code explicitly; lanes are not addressable as separate threads. Branches require predication or scalar fallback.

  • MIMD (Multiple Instruction, Multiple Data). Each core runs its own independent instruction stream. Examples: any multi-core CPU. Maximum flexibility, maximum hardware overhead per core.

  • SIMT (Single Instruction, Multiple Threads-NVIDIA's term). 32 lanes execute the same instruction in lockstep, but each lane is programmed as if it were a thread: it has its own registers, its own program counter (logically), its own stack. The compiler/programmer writes scalar code; the hardware groups 32 of them into a warp that issues together.

SIMT is essentially "SIMD with a thread illusion." It buys you:

  1. Programmer ergonomics. You write if (tid % 2 == 0) ... else ... and it works. Under the hood the warp executes both sides with masks (this is divergence; see §6).
  2. Per-lane addressing. A warp can load 32 different addresses in one instruction-a gather. SIMD can too, but only via dedicated gather instructions; in SIMT it is the natural mode.
  3. Dynamic parallelism scaling. You launch a grid of millions of threads and the hardware schedules warps onto SMs. The same source compiles for a 24-SM laptop GPU and a 132-SM H100.

The cost: lanes within a warp must execute the same instruction each cycle. If they branch differently, only the lanes on the active path do useful work. This is the fundamental SIMT tradeoff and the source of much GPU performance lore.

SIMD:     [op] -> [lane0 lane1 ... lane15]   one PC, one mask register
SIMT:     [op] -> [thread0 thread1 ... thread31]   per-lane PC since Volta
                                                    plus an "active mask"
MIMD:     [op0]->core0   [op1]->core1   ...        N independent PCs

3. Why Deep Learning Maps onto Throughput Hardware

Take a single forward pass through a transformer block. The expensive ops are:

  • Linear layers (y = x W): a GEMM. For batch B, sequence S, model dim D, this is (B·S × D) × (D × D) = O(B·S·D²) FLOPs against O(B·S·D + D²) bytes of weights+activations. Arithmetic intensity grows with D.
  • Attention (Q Kᵀ, softmax, ·V): two GEMMs and a softmax. The GEMMs are again high-intensity for non-tiny shapes.
  • LayerNorm / RMSNorm / activations: element-wise; arithmetic intensity ~1. Memory-bound.

GEMMs are the canonical GPU-friendly workload because:

  1. Massive independent parallelism. Each output element is an independent dot product. A 4096×4096 output has 16.7M independent dot products-enough to fill an H100's 270k threads many times over.
  2. High arithmetic intensity at scale. For (M×K)·(K×N), FLOPs ≈ 2·M·N·K and bytes ≈ 2·(M·K + K·N + M·N) (in BF16). Intensity scales with min(M,N,K), so big matmuls are compute-bound.
  3. Regular memory access. Tiles map cleanly onto shared memory. Vectorized 128-bit loads stay coalesced.
  4. Exploitable structure. Tensor Cores accept fixed-shape tiles (e.g., 16×16×16 BF16) and produce one output in a few cycles. The math literally is "outer product and accumulate," which is what a transformer wants.

Training adds backward and optimizer steps but the dominant cost is still GEMM. So the entire ML training stack-PyTorch, cuBLAS, cuDNN, FlashAttention, Triton, CUTLASS-is essentially elaborate machinery for feeding GEMMs to Tensor Cores efficiently.
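A quick calculator for the intensity formula in point 2 (BF16 assumed; a sketch that ignores cache reuse):

def gemm_intensity(M, N, K, bytes_per_elem=2):
    flops = 2 * M * N * K                                # one FMA = 2 FLOPs
    bytes_moved = bytes_per_elem * (M*K + K*N + M*N)     # read A, B; write C
    return flops / bytes_moved

print(gemm_intensity(4096, 4096, 4096))   # ~1365 FLOPs/B -> compute-bound
print(gemm_intensity(1, 4096, 4096))      # ~1 FLOP/B     -> memory-bound (decode-style GEMV)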


4. The Streaming Multiprocessor (SM)-Anatomy of the Compute Unit

The SM is the smallest unit of "GPU." An H100 has 132 SMs (H100 SXM5; the PCIe variant has 114). All your CUDA threads run inside SMs. A thread block is assigned to exactly one SM and stays there for its lifetime.

4.1 H100 SM block diagram (canonical)

+--------------------------------------------------------------+
|                       Streaming Multiprocessor (H100)        |
|                                                              |
|  +-------------------+   +-------------------+               |
|  |  Sub-partition 0  |   |  Sub-partition 1  |  ... x4       |
|  |                   |   |                   |               |
|  |  Warp Scheduler   |   |  Warp Scheduler   |               |
|  |  Dispatch Unit    |   |  Dispatch Unit    |               |
|  |                   |   |                   |               |
|  |  Register File    |   |  Register File    |               |
|  |   16384 x 32-bit  |   |   16384 x 32-bit  |               |
|  |   = 64 KB         |   |   = 64 KB         |               |
|  |                   |   |                   |               |
|  |  32 FP32 ALUs     |   |  32 FP32 ALUs     |               |
|  |  16 FP64 ALUs     |   |  16 FP64 ALUs     |               |
|  |  16 INT32 ALUs    |   |  16 INT32 ALUs    |               |
|  |  8  LD/ST units   |   |  8  LD/ST units   |               |
|  |  4  SFUs          |   |  4  SFUs          |               |
|  |  1  Tensor Core   |   |  1  Tensor Core   |               |
|  +-------------------+   +-------------------+               |
|                                                              |
|  Total per SM: 128 FP32, 64 FP64, 64 INT32, 4 Tensor Cores   |
|                                                              |
|  +--------------------------------------------------------+  |
|  |  Combined L1 Data Cache + Shared Memory  (228 KB)      |  |
|  |  (configurable: up to 228 KB shared, rest as L1)       |  |
|  +--------------------------------------------------------+  |
|                                                              |
|  Tensor Memory Accelerator (TMA)   |  Async barrier engines  |
+--------------------------------------------------------------+

The SM is divided into four sub-partitions, also called processing blocks. Each has its own:

  • Warp scheduler-picks one ready warp per cycle.
  • Dispatch unit-issues the picked warp's instruction to the right execution unit.
  • Register file slice-16384 × 32-bit registers = 64 KB. (4 × 64 KB = 256 KB per SM total.)
  • A 1/4 share of the FP32/FP64/INT/SFU/LD-ST/Tensor-Core resources.

Threads within a warp always live in one sub-partition. That is why the warp is exactly 32 threads: it matches the sub-partition's lane count for the most common ops.

4.2 The execution units

  • FP32 cores ("CUDA cores"): 128/SM on H100. Do scalar FP32 add/multiply/FMA each cycle.
  • FP64 cores: 64/SM. Hopper's FP64 ratio is much higher than gaming GPUs (Ada has only 4 FP64 per SM as a throttled path).
  • INT32 cores: 64/SM. Address arithmetic and integer kernels.
  • SFU (Special Function Unit): transcendentals (rsqrt, exp2, log2, sin, cos) at reduced throughput (typically 1/4 of FP32). PyTorch GELU/SiLU eventually lower to SFU instructions.
  • LD/ST units: issue loads/stores to L1, shared, L2, HBM.
  • Tensor Cores: 4/SM on Hopper, 4th-generation. Each does a small matrix multiply per cycle on tiles.

4.3 Register file-the fastest, most precious resource

  • Capacity: 256 KB per SM = 65536 × 32-bit registers.
  • Latency: effectively 0 cycles (read-after-write hazards are tracked but throughput is one operand bundle/cycle).
  • Bandwidth: enormous-every FMA reads 3 operands and writes 1, on every functional unit, every cycle.

Each thread can use up to 255 32-bit registers (a hard CUDA limit). The compiler (ptxas) decides how many registers a kernel actually uses, controllable via __launch_bounds__ or the -maxrregcount compiler flag. Higher register count per thread → fewer threads resident on the SM (the register file is a fixed size) → lower occupancy. This is the central tension we'll formalize in §11.

A warp's "register footprint" is regs_per_thread × 32. For a warp using 64 regs/thread, that is 2048 registers = 8 KB. The 64 KB sub-partition register file therefore holds at most 64 KB / 8 KB = 8 warps' worth, which caps that sub-partition's warp residency.

4.4 Generation deltas (SM-level)

Feature A100 (Ampere, SM_80) H100 (Hopper, SM_90) RTX 4090 (Ada, SM_89) B100/B200 (Blackwell, SM_100)
FP32/SM 64 128 128 128 (approximate)
Tensor Cores/SM 4 (3rd gen) 4 (4th gen) 4 (4th gen) 4 (5th gen, FP4-capable)
Register file/SM 256 KB 256 KB 256 KB 256 KB (publicly stated)
L1+smem/SM 192 KB 228 KB 128 KB ~256 KB (approximate)
FP8 tensor cores no yes (E4M3 / E5M2) yes yes (also FP4 / FP6)
TMA hardware no yes no yes (enhanced)
Thread block clusters no yes (SM-to-SM smem) no yes
2nd-gen Transformer Engine no no (1st gen) no yes

Numbers marked "approximate" for Blackwell because as of authoring NVIDIA had not published every microarchitectural detail with full precision-verify with the latest H100/B200 whitepapers when relying on exact figures.


5. Warp Scheduling, Dual-Issue, and Scoreboarding

Now the dynamic picture: how an SM picks what to execute each cycle.

5.1 The basic loop

Each cycle, in each sub-partition:

  1. The warp scheduler scans all resident warps in this sub-partition (up to 16; 64 total per SM).
  2. It selects warps whose next instruction has all source operands ready (no outstanding dependencies). This is scoreboarding.
  3. It issues one (Hopper: sometimes two) instruction(s) to the appropriate functional unit.

Key fact: the SM does no out-of-order execution within a warp. Instructions from a single warp are issued in program order. Latency is hidden by switching among warps, not by reordering one warp.

5.2 Scoreboarding in detail

Each register has a "scoreboard bit" tracking whether a long-latency operation (memory load, transcendental, tensor MMA) is still writing it. When the scheduler considers issuing instruction I, it checks all of I's source registers' scoreboard bits. If any are set, the warp is not ready; the scheduler picks a different warp. When the long op completes, the bit clears.

This is why a kernel with many independent warps hides memory latency for free. The math:

  • Suppose every load takes 400 cycles to HBM.
  • A warp that just issued a load will be unable to issue its next dependent instruction for ~400 cycles.
  • If the sub-partition has 8 resident warps and each issues a load every ~50 cycles, then on average there is always at least one warp ready to issue.
  • Functional units stay busy; no stall is observed.

The exact threshold is Little's Law: parallelism_required = latency × throughput. For an SM that issues 1 instruction/cycle/sub-partition with 400-cycle memory latency, you need ~400 in-flight instructions per sub-partition to hide the latency. Each warp can have a few in flight at once (independent loads), so a handful of warps suffice.

5.3 Dual-issue on Hopper

In some cases a Hopper sub-partition can issue two instructions from the same warp in one cycle, provided they target different functional units and have no dependency. Example: an FP32 FMA and an INT32 address calculation can co-issue. This is not superscalar OoO-it is constrained dual-issue from the same in-order warp.

5.4 Warp scheduling timeline (ASCII)

cycle:        0   1   2   3   4   5   6   7   8   9  ...
warp 0:       I0          (LOAD pending..............)
warp 1:           I0  I1  I2
warp 2:                       I0          (LOAD pending...
warp 3:                           I0  I1
warp 0:                                      [load done] I1
issue slot:   W0  W1  W1  W1  W2  W3  W3  --  W0  ...
                                          ^empty slot: W2 still waiting on its load

The scheduler's job is to keep that bottom row never empty. If it goes empty, the SM is stalled and you are leaving FLOPs on the table.

5.5 Resident warp limits (H100)

  • 64 warps maximum per SM (16 per sub-partition × 4 sub-partitions).
  • 32 thread blocks maximum per SM.
  • 2048 threads maximum per SM.
  • Constrained additionally by registers (256 KB/SM) and shared memory (≤228 KB/SM).

The actual resident count is min of all these constraints. §11 walks the math.


6. Branch Divergence and Independent Thread Scheduling

6.1 The classical (pre-Volta) story

A warp has one program counter. If lanes diverge:

if (tid % 2 == 0) {
    A();  // even lanes
} else {
    B();  // odd lanes
}

The hardware executes A() with the even-lane mask 0x55555555, then B() with the odd-lane mask 0xAAAAAAAA, then reconverges. Both branches run sequentially; throughput is halved. Worst case (32 different paths) is a 32× slowdown-the warp serializes.

6.2 Independent Thread Scheduling (Volta and later)

Since Volta (SM_70), each lane has its own program counter and call stack. The hardware can interleave divergent paths and even let lanes from the same warp synchronize among themselves. This enables fine-grained algorithms (per-lane locks, producer/consumer within a warp) that were impossible before.

The performance picture is unchanged: at any one cycle the sub-partition can only issue one path. The benefit is correctness/expressiveness, not raw throughput. You still want lanes within a warp to follow the same path most of the time.

6.3 Practical rules

  • Branch on warp-aligned quantities when possible (e.g., on warp_id, not tid). All 32 lanes go the same way; no divergence.
  • Hoist invariant work out of divergent branches.
  • Use __ballot_sync, __any_sync, __all_sync for warp-wide voting instead of an explicit branch+reduce, as sketched after this list.
  • Predication (compiler-generated) handles short branches with no actual divergence cost.
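A small CUDA illustration of the warp-vote rule above: count positive elements per warp with a single ballot instead of a divergent branch-and-reduce (kernel and buffer names are illustrative):

__global__ void count_positive(const float* x, int* warp_counts, int n)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    bool pred = (tid < n) && (x[tid] > 0.0f);   // fold the bounds check into the predicate

    // Every lane votes; bit i of the result is lane i's predicate.
    unsigned mask = __ballot_sync(0xFFFFFFFFu, pred);

    // One lane per warp writes the popcount of the vote.
    if ((threadIdx.x & 31) == 0)
        warp_counts[tid / 32] = __popc(mask);
}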

7. The Memory Hierarchy: Registers to NVMe (with H100 Numbers)

This section is the core of practical GPU programming. Every performance decision is a memory decision.

7.1 The hierarchy (H100 specific)

Tier Capacity Latency Bandwidth (peak) Scope
Registers 256 KB / SM ~1 cycle enormous (per-lane) per-thread
L1 / Shared mem 228 KB / SM ~20–30 cycles ~10s of TB/s aggreg per-block
L2 cache 50 MB ~150–250 cycles ~5–7 TB/s aggreg device-wide
HBM3 80 GB ~400–600 cycles ~3 TB/s device-wide
Host DRAM system-dep. ~µs (many k cycles) ~50 GB/s (PCIe 5) host
NVMe SSD TB-scale ~10s of µs ~5–14 GB/s (PCIe 5) host (block dev)

Caveats: cycle counts are nominal-they vary with bank conflicts, DRAM row state, and contention. Treat them as orders of magnitude. Bandwidth numbers are peak achievable on H100 SXM5 in well-tuned kernels; everyday kernels see less.

Aggregate L1/smem bandwidth: each SM can issue ~128 B/cycle of shared loads, at ~1.6 GHz, × 132 SMs ≈ 27 TB/s peak in well-balanced cases. Treat as approximate-verify with NVIDIA H100 datasheet.

7.2 ASCII picture

   per-thread               per-block               device-wide        host
  +-----------+         +----------------+        +------------+    +-------+
  | Registers |         |  Shared Memory |        |     L2     |    | DRAM  |
  |  ~1 cyc   |  -->    |  (config'd     |  -->   |   ~150 cyc |--> | µs    |
  | 256 KB/SM |         |   from L1)     |        |   50 MB    |    | (PCIe |
  +-----------+         |  ~25 cyc       |        +------------+    |  5)   |
        ^               |  228 KB/SM     |              |           +-------+
        |               +----------------+              |               |
        |                       ^                       v               v
        |                       |                  +---------+      +------+
        +----stmts-------------(L1 hit ~30 cyc)<---|  HBM3   |      | NVMe |
                                                   |  ~500   |      |  10s |
                                                   |  cyc    |      |  µs  |
                                                   |  3 TB/s |      | GB/s |
                                                   |  80 GB  |      +------+
                                                   +---------+

7.3 Cascade of misses

A single load instruction in a CUDA kernel does this:

  1. Coalesce check. The 32 lanes' addresses are inspected. If they fall within a small number of 128-byte sectors, the request becomes 1–4 memory transactions. If they scatter, it becomes up to 32. This is the single biggest determinant of memory performance.

  2. L1 lookup. Each transaction consults L1. If it hits, ~30 cycles, done.

  3. L2 lookup on L1 miss. Sent to L2 (which is shared device-wide and partitioned across the chip). Hit: ~150–250 cycles round trip.

  4. HBM on L2 miss. Goes to HBM3. ~400–600 cycles. Possibly more if DRAM page must be opened.

  5. (Unified memory only) host fault. If using managed memory and the page lives in host DRAM, a page fault crosses PCIe-microseconds, i.e. thousands of cycles. Catastrophic for kernel throughput.

  6. NVMe. Only via explicit pinning + cudaMemcpy or via GPUDirect Storage. Tens of microseconds or worse.

The practical implication: once you miss to HBM, you have ~500 cycles to hide. Once you miss to host, you have ~10000+ cycles to hide and almost certainly cannot. Keep working sets in registers and shared memory.

7.4 Coalescing: a worked picture

Suppose 32 threads each load float a = arr[tid] where arr is 4-byte aligned. The 32 addresses are arr[0], arr[1], ..., arr[31]-contiguous, 128 bytes total. The hardware bundles this into one 128-byte transaction, perfectly coalesced.

Now suppose float a = arr[tid * 32]. The addresses are arr[0], arr[32], arr[64], ..., each 128 bytes apart. That is 32 transactions of 128 bytes-32× the memory traffic for the same useful data.

COALESCED:    [t0 t1 t2 ... t31]  -> one 128B sector
              ^^^^^^^^^^^^^^^^^^

UNCOALESCED:  t0 ........... t1 ............ t2 ............ ...
              ^^^^                                                 -> 32 sectors
                              ^^^^
                                                ^^^^

7.5 Shared memory: the workhorse

Shared memory is software-managed, on-SM SRAM. Latency is ~25 cycles, comparable to a register hit, and all 32 threads in a warp can read independent addresses simultaneously. This is what makes it the staging area for tile-based algorithms.

It is organized into 32 banks, each 4 bytes wide. A warp's 32 accesses are conflict-free if every lane hits a different bank. If two lanes hit the same bank (in different rows), the access serializes-a bank conflict. Common pitfall: 2D tiles of size 32×32 with a stride of 32 produce systematic conflicts; the fix is padding to stride 33, as sketched after the diagram below.

banks:    0   1   2   3  ...  31
           |   |   |   |       |
words:   [B0][B1][B2][B3] ... [B31]   <- a warp reading these is conflict-free
         [B0][B1][B2][B3] ... [B31]   <- next row, same banks
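The padding fix in context-the classic shared-memory transpose, sketched assuming a square matrix whose side is a multiple of 32:

__global__ void transpose_tile(float* out, const float* in, int width)
{
    __shared__ float tile[32][33];     // stride 33, not 32: the padding column

    int x = blockIdx.x * 32 + threadIdx.x;
    int y = blockIdx.y * 32 + threadIdx.y;
    tile[threadIdx.y][threadIdx.x] = in[y * width + x];    // coalesced read
    __syncthreads();

    x = blockIdx.y * 32 + threadIdx.x;                     // swap block coordinates
    y = blockIdx.x * 32 + threadIdx.y;
    out[y * width + x] = tile[threadIdx.x][threadIdx.y];   // column read: conflict-free
}

Without the +1 padding, the column read tile[threadIdx.x][threadIdx.y] would land all 32 lanes in the same bank and serialize 32-way.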

7.6 L2 cache (50 MB)

On H100 the L2 is split into two partitions joined by a high-bandwidth crossbar. It serves all SMs. Useful properties:

  • Persistent access policies (cudaAccessPolicyWindow) let you mark a buffer as "keep this in L2 with high priority"-relevant for KV-caches that fit in 50 MB.
  • L2 hit rate matters: a 50 MB working set that fits in L2 effectively turns HBM bandwidth into L2 bandwidth (~5–7 TB/s) for that data.

7.7 HBM3 (80 GB, ~3 TB/s)

HBM is stacked DRAM connected by a wide silicon interposer. Peak bandwidth is ~3 TB/s (H100 SXM5; H100 PCIe is lower). This is the single number that governs memory-bound kernel performance.

Two consequences:

  • Largest model that fits: 80 GB / (params × bytes-per-param). At BF16 (2 B/param), that's ~40B parameters of weights alone, before activations and gradients.
  • Memory-bound kernel ceiling: an elementwise op on a tensor of size N (BF16, so 2N bytes read + 2N bytes written = 4N bytes) on H100 takes at minimum 4N / 3e12 seconds. For N = 1e9, that's ~1.3 ms-and you cannot go faster regardless of compute.

7.8 Host DRAM and NVMe

Crossing PCIe 5 x16 to host DRAM is ~50 GB/s-60× slower than HBM. This is why model loading is slow, why pinned memory matters (avoids an extra copy through pageable DRAM), and why model.to(device) for a 70 GB model takes ~1.5 s of pure transfer at best.

NVMe via GPUDirect Storage can hit 5–14 GB/s on PCIe 5 SSDs, bypassing the CPU. It is the right tool for streaming datasets but not for hot tensors.


8. Tensor Cores: WMMA, mma.sync, Fragments, Precisions

8.1 Why they exist

A standard FP32 FMA does 2 FLOPs (multiply + add) per cycle per lane. An H100 SM at ~1.8 GHz with 128 FP32 lanes does ~460 GFLOPS of FP32-and 132 SMs gives ~60 TFLOPS FP32. Respectable.

The same SM has 4 Tensor Cores. Each Tensor Core, at BF16, executes a small dense matmul per cycle-far more FLOPs per cycle than the FP32 path because the operation is fused. At full chip: ~989 TFLOPS BF16, ~17× faster than the FP32 path. At FP8: ~1979 TFLOPS dense, ~34× faster.

Ignoring Tensor Cores leaves 90%+ of the chip's arithmetic on the floor. This is non-negotiable for ML.

8.2 What a Tensor Core operation actually is

Conceptually: D = A · B + C where A, B, C, D are small fixed-shape tiles. On Hopper at BF16, a common shape is 16×16×16: A is 16×16, B is 16×16, C and D are 16×16. So each Tensor Core MMA does:

  • 16×16×16 = 4096 multiply-accumulates per instruction
  • = 8192 FLOPs per instruction

Across the chip that is 8192 FLOPs × 4 Tensor Cores × 132 SMs ≈ 4.3 MFLOPs per cycle, or ~7.8 PFLOPS at 1.8 GHz if every Tensor Core issued a full 16×16×16 MMA every cycle. The published ~989 TFLOPS BF16 is lower because a full tile takes several cycles to issue; the datasheet number bakes in the real per-Tensor-Core issue rate.

8.3 The PTX mma.sync family

PTX is NVIDIA's virtual ISA. The instructions you actually emit (or that ptxas emits) for tensor cores are of the form:

mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 d, a, b, c;

Decoded:

  • m16n8k16 - D is 16×8, A is 16×16, B is 16×8 (the shared K dim is 16). MMA shapes vary: m16n8k8, m16n8k16, m16n8k32 (for 8-bit types), etc.
  • row.col - A is row-major in the fragment, B is column-major. Layout matters.
  • f32.bf16.bf16.f32 - D type, A type, B type, C type. Here: BF16 inputs, FP32 accumulator-the standard mixed-precision recipe.

Each thread in the warp owns a fragment of the tile-a few registers' worth of A, B, C, D each. The 32 threads collaboratively hold the entire 16×16 tile distributed across their register files. The hardware knows the distribution; you must load fragments using matching ldmatrix instructions or via the WMMA API.

8.4 The WMMA C++ API

wmma::fragment<...> is the high-level handle:

#include <mma.h>
#include <cuda_bf16.h>                     // __nv_bfloat16
using namespace nvcuda::wmma;

// Warp-collective 16x16x16 BF16 MMA with an FP32 accumulator.
fragment<matrix_a, 16, 16, 16, __nv_bfloat16, row_major> a_frag;
fragment<matrix_b, 16, 16, 16, __nv_bfloat16, col_major> b_frag;
fragment<accumulator, 16, 16, 16, float> c_frag;

fill_fragment(c_frag, 0.0f);                        // zero the accumulator
load_matrix_sync(a_frag, A_smem_ptr, 16);           // leading dimension = 16 elements
load_matrix_sync(b_frag, B_smem_ptr, 16);
mma_sync(c_frag, a_frag, b_frag, c_frag);           // c += a * b
store_matrix_sync(D_global_ptr, c_frag, 16, mem_row_major);

Each *_sync call is warp-collective-all 32 lanes must participate. The fragment objects live in registers.

8.5 Supported precisions and what they buy

For Hopper Tensor Cores (4th gen), peak dense throughput approximately scales as:

Type Bits H100 dense TFLOPS (approx) Notes
FP64 64 ~67 Dedicated FP64 Tensor Core path
TF32 19 ~495 NVIDIA's drop-in for FP32 GEMM
BF16 16 ~989 Standard ML training dtype
FP16 16 ~989 Older default; less dynamic range
FP8 8 ~1979 E4M3 (forward), E5M2 (gradients)
INT8 8 ~1979 Quantized inference

(Numbers are H100 SXM5 dense, no sparsity. Verify exact figures with the H100 datasheet.) Each precision halving roughly doubles throughput because each Tensor Core can pack twice as many ops in the same die area-that is the raison d'être of FP8 for inference and FP4 for Blackwell.

FP8 detail. Two formats exist:

  • E4M3: 1 sign + 4 exponent + 3 mantissa. Range ~±448. Used for forward activations and weights-needs precision more than range.
  • E5M2: 1 sign + 5 exponent + 2 mantissa. Range ~±57344. Used for gradients in training-gradients have huge dynamic range.

A "Transformer Engine" library (NVIDIA's TE, plus cuDNN/cuBLAS support) automatically picks per-tensor scaling factors to keep values in range, accumulates in FP32, and chooses E4M3 vs E5M2 based on tensor role.

TF32. A funny format: 19 bits total (1 sign + 8 exponent + 10 mantissa), padded to look like FP32 in registers. The Tensor Core silently truncates the mantissa. Net effect: code that reads "FP32 GEMM" runs at TF32 speed (~495 TFLOPS) on Ampere/Hopper unless you pass torch.backends.cuda.matmul.allow_tf32 = False.
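The PyTorch switches in question (these flags exist under torch.backends; their defaults have shifted across PyTorch releases, so set them explicitly when the distinction matters):

import torch

torch.backends.cuda.matmul.allow_tf32 = True   # FP32 matmuls may use TF32 tensor cores
torch.backends.cudnn.allow_tf32 = True         # same permission for cuDNN convolutions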

8.6 Data movement is everything

Tensor Cores are so fast that they will starve unless data flows from HBM → L2 → smem → registers fast enough. A typical tile-based GEMM kernel structure:

Persistent loop over output tiles (each block owns one output tile):
  Loop over K dimension in tile-sized chunks:
    1. cp.async (or TMA) loads A-tile and B-tile from HBM into shared memory.
    2. ldmatrix loads fragments from shared memory into registers.
    3. mma.sync accumulates into the output fragment (in registers).
    4. Overlap: while step 3 runs, step 1 of the next K-chunk is already
       in flight via async copy.
  Write output tile from registers to HBM.

The whole CUTLASS/cuBLAS/Triton design space is variations on this skeleton: tile sizes, number of K-stages buffered, smem layout (swizzled to avoid bank conflicts), warp specialization (some warps load, others compute).


9. 2:4 Structured Sparsity and Async Tensor Cores

9.1 The 2:4 sparsity hardware path

Since Ampere (3rd-gen Tensor Cores) and continued in Hopper/Blackwell, Tensor Cores can natively skip multiplications when the weight matrix is 2:4 structured-sparse: in every contiguous group of 4 elements along the K dimension, exactly 2 are zero.

This is a hardware-enforced pattern. The benefit:

  • The compressed weight stores only the 2 nonzero values + a 4-bit metadata mask per group of 4 (so 50% memory).
  • The Tensor Core fetches the 2 nonzeros + the matching 2 lanes of the activation, multiplies, accumulates. The two skipped multiplications are physically not done.
  • Effective throughput doubles: H100 BF16 dense is ~989 TFLOPS, BF16 with 2:4 sparsity is ~1979 TFLOPS.

Caveat: the model must actually have 2:4 sparsity. Tools (NVIDIA ASP) prune dense models to this pattern with QAT-style fine-tuning. Quality recovery is usually possible but not free.

4 contiguous K-elements:    [w0  w1  w2  w3]
2:4 enforced:               [w0   0   0  w3]  <- two zeros in known positions
compressed in memory:       [w0  w3] + 4-bit metadata "10..01"

9.2 Async (warpgroup) Tensor Core operations on Hopper

Hopper introduced the warpgroup MMA (wgmma.mma_async): a Tensor Core operation that operates on a warpgroup of 4 warps (128 threads) and runs asynchronously with respect to the issuing thread. Properties:

  • Inputs A, B can come directly from shared memory (not just registers)-saves register pressure.
  • The accumulator C lives in registers.
  • The instruction returns immediately; you wgmma.wait_group to synchronize before reading C.
  • The Tensor Core can be working while the warp issues additional instructions (e.g., the next cp.async to fetch the next K-tile).

This is the missing piece that lets a Hopper kernel truly overlap data movement with tensor compute. On Ampere, the MMA was synchronous: while it ran, the warp was blocked. On Hopper, the warp can do other work, including issuing more MMAs and queueing more loads. This is why Hopper kernels often use warp specialization: dedicate some warps to loading (issuing TMAs), others to computing (issuing wgmmas), and let them communicate via shared memory and barriers.


10. Async Copy and the Tensor Memory Accelerator (TMA)

10.1 cp.async (Ampere onward)

The instruction cp.async.cg.shared.global copies bytes from global memory to shared memory without going through registers and without blocking the issuing thread.

cp.async.cg.shared.global   [smem_ptr], [global_ptr], 16;
cp.async.commit_group;          // package outstanding cp.async into group
... do other work ...
cp.async.wait_group 0;          // wait for all groups to finish
__syncthreads();                // ensure smem is visible to all threads

Key semantics:

  • commit_group bundles all outstanding cp.async instructions issued by this thread since the last commit into a numbered group.
  • wait_group N waits until at most N groups remain in flight.
  • This pattern enables pipelined / multi-stage smem buffers: while compute uses stage k, stage k+1 is being filled.

A typical 3-stage pipeline:

stage k:    [load HBM->smem]   [load HBM->smem]   [load HBM->smem]
stage k-1:                     [compute on smem]  [compute on smem]
stage k-2:                                        [(done)]
                ^                  ^                  ^
              cycle T            cycle T+1          cycle T+2

10.2 The Tensor Memory Accelerator (Hopper)

cp.async is per-thread. A 128×128 BF16 tile is 32 KB; loading it takes thousands of cp.async instructions across a warp, each doing address arithmetic. That is a lot of instruction-issue overhead.

The TMA is dedicated hardware that takes a single descriptor describing a multi-dimensional tensor (base pointer, dimensions, strides, element type, swizzle pattern) and a tile coordinate, and asynchronously moves the entire tile between HBM and shared memory. One instruction triggers a multi-KB transfer.

// pseudo-PTX
cp.async.bulk.tensor.2d.shared::cluster.global   [smem_ptr], [tma_descr, {x, y}], [mbarrier];
mbarrier.try_wait(mbarrier);

Properties:

  • One thread issues the TMA on behalf of the whole block.
  • An mbarrier (memory barrier object in shared memory) is signaled when bytes arrive.
  • Multi-dimensional indexing is in hardware: the TMA computes addresses for a 2D, 3D, 4D, or 5D tile correctly, including out-of-bounds zero-fill.
  • Swizzle patterns (interleavings of elements within smem) are applied for free, eliminating bank conflicts in the subsequent ldmatrix.
  • TMA also supports HBM→HBM and multicast within a thread block cluster.

10.3 Thread block clusters

Hopper added a new level above thread block: the cluster (up to 16 blocks). Blocks in the same cluster can directly access each other's shared memory (Distributed Shared Memory, DSM) and synchronize. The TMA can multicast a tile to all blocks in the cluster, broadcasting input data with a single HBM read. This is how very large GEMMs amortize input bandwidth across many SMs.

Thread hierarchy (Hopper):
  thread  (1)
    -> warp (32 threads)
      -> warpgroup (4 warps = 128 threads)
        -> block (1..1024 threads, all on one SM)
          -> cluster (up to 16 blocks, on adjacent SMs)
            -> grid (entire kernel launch)

11. Occupancy Theory-Derivation From First Principles

Occupancy = (resident warps per SM) / (maximum warps per SM). On H100 the denominator is 64.

11.1 What constrains residency

Each thread block, once placed on an SM, consumes:

  • R = registers per thread × threads per block (in 32-bit registers)
  • S = shared memory per block (bytes)
  • T = threads per block

The SM has hard limits:

  • R_max = 65536 registers (256 KB / 4 B)
  • S_max = 228 KB (configurable; often slightly less is usable)
  • T_max = 2048 threads
  • B_max = 32 blocks
  • W_max = 64 warps

The number of resident blocks is

B_resident = min( floor(R_max / R),
                  floor(S_max / S),
                  floor(T_max / T),
                  B_max )

Then resident warps = B_resident × ceil(T / 32), and occupancy = resident_warps / W_max.
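The derivation as executable arithmetic; a sketch that ignores allocation granularity (real hardware rounds register and smem allocations up to fixed-size chunks, so actual residency can be slightly lower):

import math

def h100_occupancy(regs_per_thread, smem_per_block, threads_per_block):
    R_MAX, S_MAX, T_MAX, B_MAX, W_MAX = 65536, 228 * 1024, 2048, 32, 64
    b = min(R_MAX // (regs_per_thread * threads_per_block),
            S_MAX // max(smem_per_block, 1),
            T_MAX // threads_per_block,
            B_MAX)
    warps = b * math.ceil(threads_per_block / 32)
    return b, warps, warps / W_MAX

print(h100_occupancy(96, 16 * 1024, 256))   # (2, 16, 0.25): worked example 1 below
print(h100_occupancy(64, 24 * 1024, 256))   # (4, 32, 0.50): exercise 1 in section 15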

11.2 Worked example 1-register-limited

A kernel with:

  • 256 threads per block
  • 96 registers per thread
  • 16 KB shared memory per block

Compute each constraint:

  • R = 96 × 256 = 24576 registers/block. floor(65536 / 24576) = 2 blocks.
  • S = 16384 bytes/block. floor(228·1024 / 16384) = floor(14.25) = 14 blocks.
  • T = 256. floor(2048 / 256) = 8 blocks.
  • B_max = 32 blocks.

B_resident = min(2, 14, 8, 32) = 2 blocks. Warps = 2 × (256/32) = 16. Occupancy = 16/64 = 25%. The bottleneck is registers.

To raise occupancy you'd reduce regs/thread (via __launch_bounds__(256, 4), which tells ptxas "I want at least 4 blocks of 256 threads resident, please compile within that register budget"). The compiler will spill some registers to local memory to comply-possibly hurting performance more than the occupancy gain helps.

11.3 Worked example 2-shared-memory-limited

Same kernel but bumped to 80 KB shared memory per block (e.g., a big tile):

  • S = 80 KB/block. floor(228 / 80) = 2 blocks.
  • R = 96 × 256 = 24576. floor(65536 / 24576) = 2 blocks.
  • T limit: 8 blocks.

B_resident = 2 blocks, occupancy = 25%, but now both registers and shared memory are at the 2-block limit. Reducing shared memory wouldn't help unless you also reduce registers.

11.4 Worked example 3-block-count-limited

256 threads/block, 32 regs/thread, 4 KB smem/block:

  • R = 8192. floor(65536/8192) = 8 blocks.
  • S = 4096. floor(228·1024/4096) = 57 blocks.
  • T: 8 blocks.
  • B_max: 32 blocks.

B_resident = min(8, 57, 8, 32) = 8 blocks. Warps = 8 × 8 = 64. Occupancy = 100%.

11.5 Reverse engineering-"given X warps, what register budget?"

If you want at least 8 warps resident in one sub-partition (= 8 blocks of 32 threads each, or equivalently 1 block of 256 threads spreading 2 warps across each sub-partition-be careful with the per-sub-partition accounting), and the sub-partition has 16K registers:

regs/thread × warps × 32 ≤ 16384  ⟹  regs/thread ≤ 16384 / (8 × 32) = 64.

So a 64-reg/thread budget gives exactly 8 warps per sub-partition = 32 warps per SM = 50% occupancy.

11.6 When low occupancy is fine

Occupancy is only a means to an end (latency hiding). It is not the same as performance. Kernels can run at full HBM bandwidth or full Tensor Core throughput at 25% occupancy if:

  • Each warp has lots of independent instructions (high ILP) → less reliance on warp-level parallelism for latency hiding.
  • The bottleneck is HBM bandwidth, not arithmetic, and the few warps already saturate it.
  • The kernel uses async copies / TMA so memory operations don't tie up warps.

CUTLASS and FlashAttention typically run at modest occupancy (30–50%) because their warps are very busy doing useful Tensor Core work. Chasing occupancy by spilling registers usually loses performance. Always profile.

11.7 When low occupancy hurts

  • Kernels with serialized memory waits and no async copy: only way to hide HBM latency is to have many warps.
  • Kernels with frequent __syncthreads() and short between-sync work: a single block doesn't hide much; you need many concurrent blocks.
  • Memory-bound kernels with low ILP per warp.

12. NVLink, NVSwitch, and Multi-GPU Topologies

12.1 Why we need it

A single H100 has 80 GB. A 70B-parameter LLM at BF16 is 140 GB of weights, plus optimizer states (3–5× weights for AdamW), activations, gradients. Training requires multi-GPU. Inference of a 405B model also requires multi-GPU. The interconnect determines how close the multi-GPU system is to a single big GPU.

12.2 The bandwidth ladder

  • PCIe 5 x16: ~64 GB/s per direction (~128 GB/s bidirectional). Your CPU↔GPU path, and the only GPU↔GPU path on non-NVLink machines.
  • NVLink 4 (Hopper): 18 links per H100, each 50 GB/s bidirectional, = ~900 GB/s aggregate per GPU (sum of all 18 links, both directions). About 7× a single PCIe 5 x16.
  • NVLink 3 (Ampere): 12 links × 50 GB/s = 600 GB/s aggregate per A100.
  • NVLink 5 (Blackwell): roughly doubled bandwidth vs NVLink 4 (publicly stated as ~1.8 TB/s aggregate per B200; verify with NVIDIA datasheet).

The 900 GB/s is aggregate to all NVLink peers; if you have 8 peers, each link to a single peer is 900/8 ≈ 112 GB/s, still ~2× PCIe.

12.3 NVSwitch

In a DGX H100 (8 GPUs), you can wire each GPU's 18 NVLinks point-to-point-but with 8 GPUs that's only ~2 links per pair, asymmetric. NVSwitch is a chip that acts as a non-blocking crossbar: every GPU has all 18 links going into NVSwitches, and the switches route any-to-any at full bandwidth. A DGX H100 has 4 NVSwitches.

    GPU0 -+        +- GPU4
    GPU1 -+--NVSw-+- GPU5
    GPU2 -+        +- GPU6
    GPU3 -+        +- GPU7

    Each GPU sees a flat 900 GB/s aggregate to *any* combination of peers.

12.4 NVL8, NVL36, NVL72

NVIDIA's GB200 ("Grace + Blackwell") rack systems chain many GPUs into one NVLink domain via external NVLink switches:

  • NVL8-8 GPUs in one server (DGX/HGX). The classic configuration.
  • NVL36-36 GPUs (typically 18 GB200 superchips × 2 Blackwells each). Single NVLink domain across one rack half.
  • NVL72-72 GPUs in one NVLink domain (GB200 NVL72 rack). Total NVLink bandwidth is staggering (~130 TB/s aggregate). All 72 GPUs see each other as if they were on the same node.

The point of large NVLink domains is to make tensor parallelism and expert parallelism tractable across more GPUs without falling off a bandwidth cliff onto InfiniBand (which is ~50 GB/s per port, ~20× slower than NVLink).

12.5 Reading nvidia-smi topo -m

The topology matrix shows how every GPU pair is connected. Cell legend:

  • X - self.
  • NV# - NVLink, where # is the number of links between the pair (more links = higher bandwidth).
  • PIX - same PCIe switch (no host bridge in between).
  • PXB - multiple PCIe switches, no CPU.
  • PHB - through a host bridge (CPU root complex).
  • NODE - traverses a NUMA node boundary.
  • SYS - traverses the CPU socket interconnect (the worst PCIe path).

Example partial output:

        GPU0   GPU1   GPU2   GPU3   GPU4   GPU5   GPU6   GPU7
GPU0     X     NV18   NV18   NV18   NV18   NV18   NV18   NV18
GPU1   NV18     X     NV18   NV18   NV18   NV18   NV18   NV18
...

NV18 means 18 NVLinks between every pair-i.e., a fully-connected NVSwitch fabric. Conversely if you saw SYS between two GPUs, you would know that GPU↔GPU traffic goes over PCIe and across CPU sockets, the slowest possible path.

12.6 Collective communication primitives

NCCL sits on top of NVLink/NVSwitch/IB and provides AllReduce, AllGather, ReduceScatter, Broadcast, AlltoAll. The two performance numbers to know:

  • Within an NVLink domain, AllReduce bandwidth ≈ NVLink bandwidth × (n-1)/n for ring and tree algorithms-an efficiency factor that approaches 1 as n grows.
  • Cross-domain (over IB) AllReduce is bottlenecked by per-node IB bandwidth-typically 4× 400 Gb/s = ~200 GB/s in modern clusters, ~5× slower than NVLink.

This mismatch is why hierarchical algorithms (intra-node first, then inter-node) dominate.


13. Ada and Blackwell Deltas (with explicit uncertainty)

13.1 Ada Lovelace (RTX 40 series, SM_89)

Ada is the consumer/workstation generation contemporary with Hopper. Per-SM compute and Tensor Cores look similar to Hopper on the surface but Ada is missing key datacenter features:

  • No TMA hardware-you can use cp.async (Ampere-style) but not the multi-dim tensor descriptors.
  • No thread block clusters-no distributed shared memory.
  • No async warpgroup MMA (wgmma)-MMAs are synchronous as on Ampere.
  • No HBM-uses GDDR6X (~1 TB/s on 4090 vs 3 TB/s HBM3 on H100).
  • Throttled FP64-Ada has 2 FP64 cores per SM vs Hopper's 64. Crippling for HPC, irrelevant for ML.
  • Has FP8 Tensor Cores-same E4M3/E5M2 as Hopper.
  • Has 2:4 sparsity.

So Ada is a fine inference card but is a different architecture from Hopper for kernel programming. A FlashAttention kernel written for Hopper (using TMA + wgmma + clusters) needs significant fallback code on Ada.

13.2 Blackwell (B100 / B200 / GB200, SM_100)

Blackwell is the post-Hopper datacenter generation. Public commitments include:

  • 5th-generation Tensor Cores with native FP4 (E2M1) and FP6 support. FP4 enables ~2× throughput vs FP8-quoted as approximately 10–20 PFLOPS dense FP4 on B200 (verify with NVIDIA Blackwell whitepaper; numbers vary by SKU and dense-vs-sparse).
  • 2nd-generation Transformer Engine-extended micro-scaling formats (MXFP8, MXFP6, MXFP4) where each small block of values shares an exponent; enables FP4 inference and FP4-ish training without catastrophic accuracy loss.
  • NVLink 5-~1.8 TB/s aggregate per GPU, ~2× NVLink 4.
  • HBM3e-higher capacity (192 GB on B200) and bandwidth (~8 TB/s; verify) vs Hopper's 80 GB / 3 TB/s.
  • GB200 superchip-1 Grace CPU + 2 Blackwells on one board with NVLink-C2C between Grace and the GPUs (much higher than PCIe).
  • Two-die package-B200 is a multi-die GPU (two compute dies joined by a high-bandwidth on-package interconnect, presented to software as a single GPU).
  • NVL72 racks-72 Blackwells in one NVLink domain.

Uncertainty disclosure. As of authoring, NVIDIA had not published every Blackwell microarchitectural detail with the same fidelity as Hopper's whitepaper. Treat exact TFLOPS numbers, register-file sizes, and shared-memory capacities for Blackwell as approximate-verify with the latest official Blackwell datasheet/whitepaper before relying on them in code. The shape of the architecture (more dies, FP4, NVLink 5, 2nd-gen TE) is committed; precise numbers may shift between announcement and shipping silicon.

13.3 What this means practically

Concern A100 H100 Ada (4090) B100/B200
Best ML training option yes (legacy) yes no (consumer) yes (current)
FP8 inference no yes yes yes
FP4 inference no no no yes
TMA / wgmma kernels no yes no yes (extended)
HBM capacity 40/80 GB 80 GB 24 GB GDDR6X ~192 GB HBM3e
NVLink bw aggregate 600 GB/s 900 GB/s none (or limited) ~1.8 TB/s

14. AMD CDNA / MI300X Contrast

A short, accurate comparison so you know what's the same and what isn't.

14.1 Vocabulary

NVIDIA term AMD CDNA equivalent
Streaming Multiprocessor (SM) Compute Unit (CU)
CUDA core (FP32) SIMD lane (CU has 4 SIMD16 vector units)
Warp (32 threads) Wavefront (64 threads on CDNA-twice as wide)
Tensor Core Matrix Core (MFMA)
NVLink Infinity Fabric (xGMI)
L2 cache L2 (per-XCD on MI300)
HBM HBM (same physical tech)

14.2 MI300X (CDNA 3) specifics

  • 8 XCDs (Accelerator Compute Dies) per package, 304 CUs total-vs H100's 132 SMs. The CUs are smaller individually; aggregate FP64 and matrix throughput are competitive.
  • 192 GB HBM3 (vs H100's 80 GB). Single biggest practical advantage for large-model inference: a 70B BF16 model fits on one MI300X with room for KV cache; on H100 you need 2 GPUs.
  • 5.3 TB/s HBM3 bandwidth (vs H100's ~3 TB/s).
  • Matrix Cores support FP64, FP32, TF32-equivalent, BF16, FP16, INT8, FP8. FP4 is not supported on MI300X (it is reportedly added in MI355).
  • Wavefront = 64 lanes, so divergence dynamics are coarser; tiles and vectorization need adjustment.
  • Infinity Fabric between MI300Xs is roughly comparable to NVLink 4 in per-link bandwidth, but the topology is different (e.g., 8-GPU all-to-all in MI300X servers).

14.3 Software story

ROCm (HIP, rocBLAS, MIOpen, RCCL) targets CDNA. HIP is a near-source-compatible CUDA dialect: cudaMalloc → hipMalloc, with a translation tool (hipify). Kernels written in plain CUDA usually port; kernels written for Hopper-specific features (TMA, wgmma, clusters) do not-those are NVIDIA-exclusive.

Practical rule: MI300X is excellent silicon, often the best choice for very large model inference where 192 GB > 80 GB matters; ecosystem maturity for training and exotic kernels still trails CUDA's. Triton and PyTorch both work on ROCm but with thinner kernel coverage than on CUDA.


15. Five Worked Practical Exercises

Exercise 1-Occupancy on H100

A kernel uses 64 registers per thread, 24 KB of shared memory per block, and launches 256 threads per block. Compute occupancy on H100.

Solution.

Per-block resource use:

  • Registers: 64 × 256 = 16,384 registers.
  • Shared memory: 24 × 1024 = 24,576 bytes.
  • Threads: 256, i.e. 8 warps.

SM limits (H100): 65,536 registers, 228 KB smem (= 233,472 B), 2048 threads, 32 blocks, 64 warps.

Per-resource block caps:

  • Registers: 65,536 / 16,384 = 4 blocks.
  • Shared memory: 233,472 / 24,576 ≈ 9.5 → 9 blocks.
  • Threads: 2048 / 256 = 8 blocks.
  • Block hard cap: 32.

B_resident = min(4, 9, 8, 32) = 4 blocks.

Warps resident: 4 × 8 = 32 warps. Occupancy = 32 / 64 = 50%.

The bottleneck is registers. To raise occupancy you'd compile with __launch_bounds__(256, 6) and accept the spills, or refactor the kernel to use fewer registers (e.g., smaller register-resident tile).


Exercise 2-Roofline on a memory-bound kernel

You write an elementwise BF16 kernel: y = silu(x). Tensor x has 1 billion elements. On H100 (3 TB/s HBM, ~989 TFLOPS BF16), what is the minimum runtime, and what limits it?

Solution.

Bytes moved: read x (2 B/elem) + write y (2 B/elem) = 4 B/elem × 1e9 elem = 4e9 B = 4 GB.

Min time (HBM-bound) = 4e9 / 3e12 ≈ 1.33 ms.

FLOPs done: SiLU ≈ a multiply, an exp, an add, a divide ≈ ~4 FLOPs/elem (the exp via SFU is more expensive). Call it 5 FLOPs × 1e9 = 5e9 FLOPs. At 989 TFLOPS, that's 5e9 / 9.89e14 = ~5 µs of compute.

Compute is negligible (5 µs) vs HBM (1.33 ms). The kernel is memory-bound. Arithmetic intensity = 5 FLOPs / 4 B = 1.25 FLOPs/B, well below the H100 BF16 ridge of ~330 FLOPs/B. No amount of Tensor Core wizardry helps; the ceiling is HBM. Fusing this op into a neighboring GEMM (so x and y are read/written through registers in a fused kernel) is the only way past 1.33 ms.


Exercise 3-Tile size for a tensor-core GEMM

You write a CUDA kernel that loads a 128×64 BF16 tile of A and a 64×128 BF16 tile of B per K-step. How much shared memory does double-buffering require, and is that compatible with 50% occupancy at 256 threads/block?

Solution.

Single tile: A is 128 × 64 × 2 B = 16,384 B = 16 KB. B is 64 × 128 × 2 B = 16 KB. Total per stage: 32 KB.

Double-buffered (2 stages): 64 KB per block.

H100 has 228 KB smem/SM. Blocks per SM by smem: floor(228 / 64) = 3 blocks.

For 50% occupancy = 32 warps = 8 blocks of 4 warps each = 8 blocks of 128 threads each, or 4 blocks of 8 warps each = 4 blocks of 256 threads each.

At 256 threads/block (8 warps), 50% needs 4 blocks resident. Smem allows only 3. So double-buffered 64 KB/block is incompatible with 50% occupancy at 256 threads/block-you get at most 3 × 8 = 24 warps = 37.5%.

Options: (a) Shrink the tile (e.g., 128×32 + 32×128 = 16 KB/stage, 32 KB double-buffered → 7 blocks fit, more than enough). (b) Keep the large tile, spend even more smem on a 3-stage pipeline that keeps the Tensor Cores fed, and accept the lower occupancy-often the right call on Hopper. (c) Use thread block clusters to share input tiles across multiple blocks via TMA multicast, reducing per-block smem.


Exercise 4-Coalescing analysis

A kernel does y[tid] = x[tid * stride] with stride = 8, BF16. How many HBM transactions per warp, and what's the effective bandwidth utilization?

Solution.

Each lane reads x[tid*8]. Lane addresses (in bytes from base): 0, 16, 32, 48, ..., 16·31 = 496. So the warp's 32 reads span 0..496 + 1 element = 498 B.

A 128-byte sector covers 64 BF16 elements. Lane 0 reads byte 0 (sector 0). Lane 8 reads byte 128 (sector 1). Lane 16 reads byte 256 (sector 2). Lane 24 reads byte 384 (sector 3).

So the warp's loads span 4 sectors, but uses only 32 elements × 2 B = 64 B of the 4 × 128 B = 512 B fetched. Useful data ratio = 64/512 = 12.5%.

Effective bandwidth utilization = 12.5% of HBM peak. To fix: reorganize x so accesses are contiguous (stride 1), or use a transposed kernel that achieves contiguous access via shared-memory staging.


Exercise 5-Why is FlashAttention faster than naive attention?

Naive attention computes Q Kᵀ → softmax → · V by materializing the full N×N attention matrix in HBM. For N = 8192, BF16, what is the HBM traffic, and how does FlashAttention's tile-based approach reduce it?

Solution.

Naive (per attention head):

  • Q is N×d (d=128). Read Q: N·d·2 = 2 MB.
  • K is N×d. Read K: 2 MB.
  • Compute Q Kᵀ → S of shape N×N. N² × 2 B = 8192² × 2 = 128 MB. Write S to HBM.
  • Read S (128 MB). Softmax. Write back P (128 MB). Together: 256 MB.
  • Read P (128 MB), read V (N×d = 2 MB), write O (N×d = 2 MB).

Total HBM bytes ≈ 2 + 2 + 128 + 128 + 128 + 128 + 2 + 2 ≈ 520 MB per head, dominated by the N×N matrix shuffling. For 32 heads: ~16 GB.

FlashAttention:

  • Tile Q in row-blocks of size Br, tile K and V in column-blocks of size Bc, where Br·d + 2·Bc·d fits in shared memory.
  • For each Q tile, stream all K, V tiles through it, computing partial softmax statistics (rowmax, rowsum) online and accumulating O directly.
  • The N×N matrix S never materializes in HBM-it lives only as Br×Bc tiles in shared memory.

HBM bytes for FlashAttention:

  • Read Q once: 2 MB.
  • Read K and V once per Q row-block. Strictly that is N/Br re-reads, i.e. O(N²·d/Br) bytes; the standard analysis bounds total HBM accesses at O(N²·d²/M) for on-chip memory size M. With a reasonable Br the multiplier over one pass is small, so call it 2 + 2 = 4 MB here.
  • Write O: 2 MB.

Total ≈ ~10 MB per head, ~50× less HBM traffic than naive. Since attention at N=8192 is HBM-bound, this directly translates to ~50× wall-clock speedup at this scale.
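The byte counting as a script; the ideal single-pass flash total comes out ≈8 MB, and the ~10 MB above allows for the modest K/V re-reads:

N, d, BYTES = 8192, 128, 2      # per head, BF16
MB = 1 << 20

qkv = N * d * BYTES / MB        # one pass over Q, K, V, or O: 2 MB each
s   = N * N * BYTES / MB        # the N x N score matrix: 128 MB

naive = 4 * qkv + 4 * s         # Q,K,V in + O out; S out+in; P out+in
flash = 4 * qkv                 # Q, K, V read once; O written once (ideal)
print(naive, flash, naive / flash)   # 520.0  8.0  65.0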

The lesson: the GPU memory hierarchy is the algorithm. FlashAttention's mathematical content (online softmax) exists because keeping intermediate state in shared memory-instead of HBM-is the only way to make attention fast at long context. Tile size, smem budget, and the SM architecture chose the algorithm.


Closing Notes

What you should now hold without external references:

  1. The GPU is a throughput machine that hides latency by parallelism, not prediction.
  2. The SM has 4 sub-partitions; each holds up to 16 warps; 64-warp / 2048-thread / 256 KB-register / 228 KB-smem caps on H100.
  3. The memory hierarchy is registers → smem/L1 (~25 cycles) → L2 (~150 cyc) → HBM (~500 cyc) → host (~µs) → NVMe (~10s of µs), and every algorithmic decision is a memory decision.
  4. Tensor Cores do D = A·B + C on small fixed tiles; 4 per SM; precisions BF16, FP16, FP8, INT8 (Hopper), FP4 (Blackwell); structured 2:4 sparsity doubles throughput.
  5. TMA + warpgroup MMA + clusters are Hopper's mechanism for decoupling HBM motion from Tensor Core compute-kernels overlap them via warp specialization.
  6. Occupancy = resident_warps / 64 on H100, derived from min(register, smem, thread, block) caps. Higher is not always faster.
  7. NVLink 4 = 900 GB/s aggregate per H100; NVSwitch makes 8-GPU domains flat; NVL72 makes 72-GPU domains flat for Blackwell.
  8. Ada lacks TMA / wgmma / clusters / HBM. Blackwell adds FP4, NVLink 5, HBM3e, multi-die.
  9. AMD MI300X has 192 GB HBM3 / 304 CUs / 64-lane wavefronts-different shape, similar physics.
  10. Numbers I marked approximate (especially Blackwell figures): always verify with the current NVIDIA datasheet/whitepaper before relying on them in code or capacity planning.

Deep Dive 02: CUDA Programming-From First Kernel to Optimized GEMM

A self-contained reference chapter for the AI Systems curriculum (Month 2). Read this end-to-end and you will be able to write, optimize, and profile CUDA kernels at a level sufficient to read CUTLASS, FlashAttention, and Triton-generated PTX with comprehension. No external dependency required for the core material.

Anti-fabrication note: numbers marked with ~ are realistic ballparks, not invented exact figures. Anything stronger than that is either a hard architectural fact (e.g., warp size = 32) or you should verify against deviceQuery / Nsight on your specific GPU.


Table of Contents

  1. The CUDA programming model
  2. Indexing, launch configuration, and the SM/warp execution model
  3. Memory transfer: host ↔ device, pinned, managed, zero-copy
  4. Streams, events, synchronization
  5. Error handling discipline
  6. Memory access patterns: coalescing, derived from first principles
  7. Shared memory and bank conflicts
  8. Building blocks: vector add, reduction, prefix sum, histogram
  9. Tiled GEMM walkthrough-naive → tensor-core → double-buffered
  10. Tensor cores via nvcuda::wmma
  11. mma.sync PTX inline assembly
  12. Cooperative groups and thread-block clusters
  13. Profiling discipline
  14. A complete, build-and-run-ready BF16 GEMM at 2048×2048
  15. Practical exercises with answer sketches

1. The CUDA Programming Model

1.1 Intuition

A CUDA program runs on two machines at once: a host (CPU) and a device (GPU). The host owns control flow and orchestrates the device the way a conductor cues a section of an orchestra: it allocates device memory, copies data in, launches a function (a kernel) that runs on tens of thousands of GPU threads in parallel, then copies results back. There is no shared address space by default; the two memories are physically distinct (host DRAM and device HBM/GDDR). Modern Unified Memory blurs this, but the mental model of "two machines" is still the right starting point.

A kernel is a function executed N times in parallel by N CUDA threads. The threads are organized into a grid of blocks, and each block is a 1D, 2D, or 3D array of threads. Threads within a block can cooperate (shared memory, barriers); threads in different blocks generally cannot-they may not even be co-resident on hardware at the same time.

1.2 Function-space qualifiers

Three qualifiers tell the compiler where a function runs and where it can be called from:

Qualifier Runs on Callable from Notes
__global__ Device Host (and device on CC ≥ 3.5 via dynamic parallelism) This is a kernel. Must return void.
__device__ Device Device Inlined aggressively.
__host__ Host Host Default if unmarked.
__host__ __device__ Both Both Same source compiled twice.
__device__ float square(float x) { return x * x; }     // GPU helper
__global__ void apply(float* y, const float* x, int n) // kernel
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = square(x[i]);
}
int main() { /* host code */ }                         // implicit __host__

1.3 Launch syntax

A kernel is launched with the chevron syntax:

kernel<<<gridDim, blockDim, dynamicSharedBytes, stream>>>(args...);
  • gridDim: a dim3 giving the number of blocks along x, y, z.
  • blockDim: a dim3 giving the number of threads per block along x, y, z.
  • dynamicSharedBytes: bytes of extern __shared__ memory to allocate per block. Defaults to 0.
  • stream: a cudaStream_t to enqueue into. Defaults to the default stream (stream 0), which has special legacy semantics described in §4.

Total threads launched = gridDim.x * gridDim.y * gridDim.z * blockDim.x * blockDim.y * blockDim.z. Hard limits (verify with deviceQuery for your GPU): blockDim.x*y*z ≤ 1024, gridDim.x ≤ 2^31-1.
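
These limits can be queried at runtime; a minimal sketch (device 0 assumed):

cudaDeviceProp p;
cudaGetDeviceProperties(&p, 0);
printf("maxThreadsPerBlock=%d  maxGridDim.x=%d  SMs=%d\n",
       p.maxThreadsPerBlock, p.maxGridSize[0], p.multiProcessorCount);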

1.4 The SM/warp execution model

A GPU is a collection of Streaming Multiprocessors (SMs). On H100 there are 132 SMs and 4 warp schedulers per SM; on A100 there are 108 SMs and also 4 schedulers per SM. Each SM has a register file (~256 KB on Ampere/Hopper), an L1/shared scratchpad (~228 KB on H100, configurable), warp schedulers, INT/FP32/FP64 pipes, and tensor cores.

When you launch a kernel, the GPU's hardware work distributor assigns whole blocks to SMs. Each SM can hold multiple blocks resident simultaneously, up to limits dictated by registers per thread, shared memory per block, and hardware caps (max 2048 resident threads per SM on most architectures, ~32–64 resident warps depending on SM).

Once on an SM, a block is sliced into warps of 32 threads each. The warp is the unit of scheduling and instruction issue. All 32 threads in a warp execute the same instruction at the same time-this is SIMT (Single Instruction Multiple Threads). When threads diverge (different sides of a branch), the warp executes both paths serially, masking inactive lanes. From Volta onward, threads have independent program counters (Independent Thread Scheduling), so the model is more nuanced, but warp-level ops still operate on full warps with explicit masks.

The warp size of 32 is hard-baked into hardware, PTX, and tooling. Treat it as a constant of the universe.

1.5 Why this matters

Three rules fall directly out of the model and you will rederive them constantly:

  1. Block size should be a multiple of 32. Any other choice wastes lanes in the trailing warp. 128 or 256 are good defaults.
  2. Memory accesses should be warp-coalesced (§6). The 32 threads of a warp issue one load instruction; if their addresses fall in a single 128-byte aligned segment, the memory subsystem services it in one transaction.
  3. Shared memory has 32 banks (§7). Patterns that map 32 threads of a warp to 32 distinct banks run at full speed; collisions serialize.

Micro-exercise 1.1

Why must a kernel return void? Answer sketch: the chevron launch is asynchronous-by the time the host instruction stream proceeds past the launch, the kernel has likely not run. There is no synchronous return value to capture. Outputs go through pointers to device memory.


2. Indexing & Launch Configuration

2.1 Built-in variables

Inside a kernel you have:

  • gridDim: dimensions of the grid (in blocks).
  • blockDim: dimensions of a block (in threads).
  • blockIdx: this block's coordinate within the grid.
  • threadIdx: this thread's coordinate within its block.
  • warpSize: always 32.

2.2 Computing a global thread ID

For a 1D grid of 1D blocks:

int gtid = blockIdx.x * blockDim.x + threadIdx.x;

For a 2D grid of 2D blocks (e.g., per-pixel image kernel):

int x = blockIdx.x * blockDim.x + threadIdx.x;
int y = blockIdx.y * blockDim.y + threadIdx.y;
if (x < width && y < height) { /* ... */ }

The bounds check matters because grid sizes are typically rounded up:

dim3 block(16, 16);                                              // 256 threads
dim3 grid((width + 15) / 16, (height + 15) / 16);                // ceil-div
saxpy2d<<<grid, block>>>(img, width, height);

2.3 Choosing block size-heuristics

  • Multiple of 32 (warp size). Otherwise lanes are wasted.
  • 128 or 256 is the sweet spot for memory-bound kernels.
  • 256–1024 is typical for compute-bound kernels-but bigger blocks reduce the number of resident blocks per SM, which reduces the scheduler's ability to hide latency by switching warps. Run the occupancy calculator or use cudaOccupancyMaxPotentialBlockSize to pick an optimal pair given per-thread register and shared-memory usage.
int minGrid, blockSize;
cudaOccupancyMaxPotentialBlockSize(&minGrid, &blockSize,
                                   (void*)myKernel, /*smem*/0, /*N*/0);
int grid = (n + blockSize - 1) / blockSize;
myKernel<<<grid, blockSize>>>(...);

2.4 Grid-stride loops

Decoupling problem size from grid size is the right idiom: launch a fixed grid (typically 2 * SMcount * blocksPerSM), let each thread loop:

__global__ void saxpy_grid_stride(int n, float a, const float* x, float* y) {
    int stride = gridDim.x * blockDim.x;
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += stride) {
        y[i] = a * x[i] + y[i];
    }
}

This kernel handles any n without recompiling and tends to be faster on arrays whose size doesn't divide nicely into the grid.
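
A sketch of sizing that fixed grid from device properties (the occupancy API accounts for the kernel's actual register and shared-memory usage; d_x, d_y are illustrative device pointers):

cudaDeviceProp p;
cudaGetDeviceProperties(&p, 0);
int blocksPerSM = 0;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocksPerSM,
                                              saxpy_grid_stride, 256, /*smem*/0);
int grid = 2 * p.multiProcessorCount * blocksPerSM;   // 2× oversubscription for load balance
saxpy_grid_stride<<<grid, 256>>>(n, a, d_x, d_y);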

Micro-exercise 2.1

A kernel accesses a 1D array of 1,000,003 floats (a prime). What block and grid sizes do you launch? Answer sketch: block=256, grid=ceil(1000003/256)=3907. Since 1000003 mod 256 = 67, the last block has 67 active threads: two full warps plus a final warp with 67 mod 32 = 3 active lanes. Add the if (i < n) guard.

nvcc invocation

nvcc -O3 -arch=sm_90 -lineinfo saxpy.cu -o saxpy   # Hopper (H100)
nvcc -O3 -arch=sm_80 saxpy.cu -o saxpy             # Ampere (A100)
-lineinfo keeps source-line mappings for the profilers without the performance penalty of a full debug build (-G).


3. Memory Transfer

3.1 The four ways to get bytes onto the GPU

API Allocator Locality Transfer cost Use case
cudaMalloc + cudaMemcpy Device DRAM Device PCIe / NVLink Default for hot data.
cudaMallocHost Pinned host RAM Host (page-locked) Faster cudaMemcpy Staging buffer for async copies.
cudaMallocManaged (UVM) Either Migrates on access Page-fault driven Prototyping, pointer sharing.
Zero-copy (cudaHostAllocMapped) Pinned host Device sees host directly PCIe per access Tiny, rare GPU reads of host data.

3.2 The classic pattern

const int N = 1 << 20;
size_t bytes = N * sizeof(float);

float *h_x = (float*)malloc(bytes);
float *d_x;  cudaMalloc(&d_x, bytes);

cudaMemcpy(d_x, h_x, bytes, cudaMemcpyHostToDevice);
kernel<<<grid, block>>>(d_x, N);
cudaMemcpy(h_x, d_x, bytes, cudaMemcpyDeviceToHost);

cudaFree(d_x);  free(h_x);

cudaMemcpy on the default stream is host-blocking for D2H and H2D when the host pointer is pageable; the host thread does not return until the copy is enqueued and (for D2H) complete. This is the most common source of "my kernel is slow"-the kernel is fine, the copy is dominating.

3.3 Pinned (page-locked) host memory

The OS may swap pageable memory; the GPU's DMA engine therefore must stage through a driver-pinned bounce buffer. Allocating already-pinned memory cuts that out:

float* h_x;
cudaMallocHost(&h_x, bytes);   // pinned
// ... fill h_x ...
cudaMemcpyAsync(d_x, h_x, bytes, cudaMemcpyHostToDevice, stream);

Practical effect: PCIe Gen4 x16 has a theoretical peak of ~32 GB/s unidirectional. With pageable memory you typically see ~6–12 GB/s; with pinned you see ~22–28 GB/s. Pinning is mandatory for cudaMemcpyAsync to actually overlap with kernels-otherwise the driver silently falls back to sync.

Caveat: pinned memory is a scarce OS resource. Don't pin gigabytes unnecessarily.

3.4 Unified Memory (cudaMallocManaged)

float* x;
cudaMallocManaged(&x, bytes);
for (int i = 0; i < N; ++i) x[i] = i;        // host writes
kernel<<<grid, block>>>(x, N);                // device reads-pages migrate
cudaDeviceSynchronize();
printf("%f\n", x[0]);                          // host reads-pages migrate back
cudaFree(x);

The driver fault-handles page migrations on demand. On Pascal+ this is fully demand-paged at 4 KB granularity; on pre-Pascal the entire allocation migrated on touch. UVM is great for prototyping but page faults are expensive (~tens of μs each), so production code typically uses explicit cudaMemPrefetchAsync to push data ahead of time, or just falls back to explicit cudaMemcpy.
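
A minimal sketch of the prefetch idiom just mentioned (stream 0 used for brevity; x and bytes as above):

int dev = 0;
cudaGetDevice(&dev);
cudaMemPrefetchAsync(x, bytes, dev, 0);               // push pages to the GPU before the kernel
kernel<<<grid, block>>>(x, N);
cudaMemPrefetchAsync(x, bytes, cudaCpuDeviceId, 0);   // pull results back before the host reads
cudaDeviceSynchronize();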

3.5 Performance ballparks (H100 + PCIe5; verify on your hardware)

  • HBM3 device-to-device (cudaMemcpy(D2D)): ~2.5–3 TB/s.
  • PCIe Gen5 x16 H↔D pinned: ~50–55 GB/s.
  • Pageable H↔D: ~half that, plus driver overhead.
  • Zero-copy reads from device: bound by PCIe per access, latency ~1 μs.

Micro-exercise 3.1

You repeatedly run a kernel on the same input. Where should the input live? Answer: in device memory (cudaMalloc), copied once. If the input is streamed from disk each iteration, use cudaMemcpyAsync from a pinned staging buffer overlapped with compute (§4).


4. Streams, Events, Synchronization

4.1 Streams as queues

A stream is a FIFO queue of GPU work. Operations enqueued into the same stream execute in order; operations in different streams may execute concurrently (subject to hardware resources).

cudaStream_t s1, s2;
cudaStreamCreate(&s1);
cudaStreamCreate(&s2);

cudaMemcpyAsync(dA, hA, bytes, cudaMemcpyHostToDevice, s1);
kernelA<<<g, b, 0, s1>>>(dA);

cudaMemcpyAsync(dB, hB, bytes, cudaMemcpyHostToDevice, s2);
kernelB<<<g, b, 0, s2>>>(dB);
// kernelA and kernelB may run concurrently if the SMs have room

4.2 The default stream's special semantics

Stream 0 (the default / NULL / legacy stream) is implicitly synchronizing: a launch into the default stream waits for all prior work in every stream of the same device, and all subsequent stream work waits for it. This is hostile to concurrency.

Two ways out:

  1. Compile with --default-stream per-thread so each host thread gets its own non-blocking default stream.
  2. Always use explicit non-default streams created with cudaStreamCreate (or cudaStreamCreateWithFlags(&s, cudaStreamNonBlocking) to make a stream not synchronize against the default stream).

4.3 Events for timing and dependencies

cudaEvent_t start, stop;
cudaEventCreate(&start);
cudaEventCreate(&stop);

cudaEventRecord(start, stream);
kernel<<<g, b, 0, stream>>>(...);
cudaEventRecord(stop, stream);

cudaEventSynchronize(stop);             // host waits for stop
float ms;
cudaEventElapsedTime(&ms, start, stop); // resolution ~0.5 μs

Events also express cross-stream dependencies:

cudaEventRecord(eA, sA);                // record after kernelA
cudaStreamWaitEvent(sB, eA, 0);         // sB cannot proceed until eA
kernelB<<<g, b, 0, sB>>>(...);          // depends on kernelA

4.4 Overlapping copy and compute

The classic three-stream pipeline for a chunked workload:

const int CHUNK = 1 << 20;
const int NSTREAMS = 3;
cudaStream_t s[NSTREAMS];
for (int i = 0; i < NSTREAMS; ++i) cudaStreamCreate(&s[i]);

for (int chunk = 0; chunk < nChunks; ++chunk) {
    int k = chunk % NSTREAMS;
    cudaMemcpyAsync(d_in + chunk*CHUNK, h_in + chunk*CHUNK,
                    CHUNK*sizeof(float), cudaMemcpyHostToDevice, s[k]);
    process<<<grid, block, 0, s[k]>>>(d_in + chunk*CHUNK,
                                       d_out + chunk*CHUNK, CHUNK);
    cudaMemcpyAsync(h_out + chunk*CHUNK, d_out + chunk*CHUNK,
                    CHUNK*sizeof(float), cudaMemcpyDeviceToHost, s[k]);
}
for (int i = 0; i < NSTREAMS; ++i) cudaStreamSynchronize(s[i]);

Why three streams: a GPU typically has two DMA engines (one each direction) plus the compute engine. Three streams let H2D, kernel, and D2H run concurrently across chunks. With pinned host buffers, total wall time collapses from T_h2d + T_compute + T_d2h toward max(T_h2d, T_compute, T_d2h).

Micro-exercise 4.1

You time a kernel as cudaEventRecord(start), kernel launch, cudaEventRecord(stop), cudaEventElapsedTime(&ms, start, stop), and get a suspiciously small number. What did you forget? Answer: cudaEventSynchronize(stop) before reading the elapsed time. Without it, the event may not have completed.


5. Error Handling Discipline

CUDA reports errors in two channels:

  1. The return code of every API call.
  2. A sticky per-thread error state for asynchronous failures, queryable via cudaGetLastError() (which clears it) or cudaPeekAtLastError() (which doesn't).

Kernel launches return success if the launch configuration is valid; an in-kernel out-of-bounds access surfaces only on the next sync point. So:

#define CUDA_CHECK(call) do {                                                 \
    cudaError_t _e = (call);                                                  \
    if (_e != cudaSuccess) {                                                  \
        fprintf(stderr, "CUDA error %s:%d: %s\n",                             \
                __FILE__, __LINE__, cudaGetErrorString(_e));                  \
        std::abort();                                                         \
    }                                                                         \
} while (0)

#define CUDA_LAUNCH_CHECK() do {                                              \
    CUDA_CHECK(cudaPeekAtLastError());                                        \
    CUDA_CHECK(cudaDeviceSynchronize());                                      \
} while (0)

Use:

CUDA_CHECK(cudaMalloc(&d_x, bytes));
CUDA_CHECK(cudaMemcpy(d_x, h_x, bytes, cudaMemcpyHostToDevice));
mykernel<<<g, b>>>(d_x);
CUDA_LAUNCH_CHECK();    // catches both launch and runtime kernel errors

In release builds you may want to skip cudaDeviceSynchronize (it kills async pipelining); make CUDA_LAUNCH_CHECK peek-only and rely on compute-sanitizer during development:

compute-sanitizer --tool memcheck ./myprog
compute-sanitizer --tool racecheck ./myprog

Common pitfalls

  • Sticky errors: after a device-side fault (e.g., an illegal memory access) the context is corrupted and every subsequent call returns the same error; cudaGetLastError clears only the error variable, not the broken context-restart the process or call cudaDeviceReset. Benign failures (e.g., an invalid launch configuration) are non-sticky and clear normally.
  • Asynchronous errors: a kernel writing OOB will not surface until you next sync. Always sync in tests.
  • Multi-GPU: CUDA errors are per-thread per-context. Setting the wrong device with cudaSetDevice produces silent wrongness, not an error.

6. Memory Access Patterns: Coalescing

6.1 The coalescing rule, derived

Global memory is served by the L2 cache and HBM in 32-byte and 128-byte sectors. When a warp issues a load instruction, the hardware computes the union of the 32 thread addresses and issues the minimum number of sector fetches to cover them.

The ideal case: 32 threads of a warp load 32 consecutive 4-byte floats starting at a 128-byte aligned address. One 128-byte transaction services the entire warp. Achieved bandwidth = peak.

The worst case (for 4-byte loads): 32 threads load 32 floats, each falling in a different 128-byte segment. The hardware issues 32 separate fetches to deliver 128 bytes of useful data-with 32-byte sectors that is 1024 bytes moved for 128 useful (8× waste), and with full 128-byte lines 32×. Achieved bandwidth ≈ peak/8 to peak/32.

6.2 Concrete: coalesced vs strided

__global__ void coalesced(const float* x, float* y, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = x[i] * 2.0f;
}

__global__ void strided(const float* x, float* y, int n, int stride) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    int j = i * stride;          // stride 32 → each thread in a different sector
    if (j < n) y[j] = x[j] * 2.0f;
}

For stride=1, achieved ~80–95% of HBM peak (verify; H100 HBM3 peak ~3 TB/s, so realistic ~2.4–2.8 TB/s). For stride=32, achieved ~3–10% of peak because each warp issues 32 sector loads where the coalesced version issues 1. Concretely, on a 1 GiB array:

  • Coalesced runtime ≈ 1 GiB / 2.5 TB/s ≈ 0.4 ms.
  • Strided runtime ≈ 8–32× that, i.e. ~3–13 ms. (Approximate-L2 hits and 32-byte sector granularity dampen the worst case.)

6.3 Alignment

Device pointers from cudaMalloc are aligned to at least 256 bytes. But if you offset a pointer by, say, 1 element, each warp's 128-byte load spans two 128-byte lines and needs an extra transaction-roughly 25% extra traffic in sector terms, up to 2× transactions in the worst case. This is why CUTLASS aligns leading dimensions of GEMM tiles.

6.4 Vectorized loads

float4 (16 bytes) and int4 let a single thread issue a 16-byte load, so a warp moves 32 × 16 = 512 bytes in 4 transactions of 128 B. Bandwidth is identical to scalar coalesced loads, but the instruction count drops 4×, often improving compute-bound kernels.

__global__ void copy_vec(const float4* x, float4* y, int n4) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n4) y[i] = x[i];      // 16 B per thread, 512 B per warp
}

Caveat: n must be divisible by 4 and pointers 16-byte aligned (true for cudaMalloc).
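
A sketch of the standard remainder handling (copy_tail is a hypothetical scalar kernel with the same body as §6.2's coalesced example):

int n4 = n / 4;                                        // complete float4 groups
copy_vec<<<(n4 + 255)/256, 256>>>(reinterpret_cast<const float4*>(x),
                                  reinterpret_cast<float4*>(y), n4);
if (int tail = n - 4*n4)                               // 1–3 leftover elements
    copy_tail<<<1, 32>>>(x + 4*n4, y + 4*n4, tail);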

Micro-exercise 6.1

A 2D row-major matrix A[M][N] is processed with threadIdx.x indexing columns. Is A[row][threadIdx.x] coalesced? Answer: yes-consecutive threads access consecutive columns, which are consecutive in memory in row-major. If threadIdx.x indexed rows instead (A[threadIdx.x][col]) that's a stride-N access and bad.


7. Shared Memory and Bank Conflicts

7.1 The bank model

Shared memory is divided into 32 banks. Each bank serves one 4-byte word per cycle. A warp accessing 32 distinct banks completes in one cycle. Two threads hitting the same bank with different addresses serialize — that's a 2-way bank conflict, and the warp issue takes 2 cycles. Up to 32-way is possible (worst case: every thread the same bank).

Bank index = (address / 4) mod 32. So:

__shared__ float s[32];
float v = s[threadIdx.x];   // thread t reads bank t-no conflict
__shared__ float s[32 * 32];
float v = s[threadIdx.x * 32];  // all threads bank 0-32-way conflict

7.2 The transpose problem

A naive 32×32 shared-memory tile transpose:

__shared__ float tile[32][32];
tile[threadIdx.y][threadIdx.x] = in[...];   // store: coalesced, no conflict
__syncthreads();
out[...] = tile[threadIdx.x][threadIdx.y];  // load:  stride-32, 32-way conflict

The store has threadIdx.x varying along the inner dimension (banks 0..31, no conflict). The load reads tile[threadIdx.x][threadIdx.y]: fixing threadIdx.y and varying threadIdx.x reads tile[0][y], tile[1][y], ..., whose addresses differ by 32 floats = 128 bytes = the same bank. 32-way conflict.

7.3 The padding fix

Add a phantom column:

__shared__ float tile[32][33];   // 33, not 32

Now tile[k][y] is at k*33 + y floats from base. Bank index = (k*33 + y) mod 32 = (k + y) mod 32 since 33 mod 32 = 1. As k varies 0..31 with y fixed, the bank stride is 1-every thread hits a different bank. Conflict gone, at the cost of 32 floats × 4 bytes = 128 B of waste per tile.

__global__ void transpose(const float* in, float* out, int N) {
    __shared__ float tile[32][33];
    int x = blockIdx.x * 32 + threadIdx.x;
    int y = blockIdx.y * 32 + threadIdx.y;
    if (x < N && y < N) tile[threadIdx.y][threadIdx.x] = in[y*N + x];
    __syncthreads();
    int xt = blockIdx.y * 32 + threadIdx.x;   // swap block coords
    int yt = blockIdx.x * 32 + threadIdx.y;
    if (xt < N && yt < N) out[yt*N + xt] = tile[threadIdx.x][threadIdx.y];
}

7.4 Dynamic shared memory

If the size is known only at launch, use extern __shared__:

__global__ void k(int n) {
    extern __shared__ float buf[];
    // ... use buf[0..n-1]
}
k<<<grid, block, n*sizeof(float)>>>(n);

The third chevron argument supplies the byte count.

7.5 Capacity

H100 has up to ~228 KB of combined L1+shared per SM, configurable via cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, …). A100 has ~164 KB. Default per-block limit is 48 KB unless opted up.
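
Opting a kernel above the 48 KB default looks like this (a sketch, reusing the kernel k from §7.4):

int smemBytes = 100 * 1024;                            // 100 KB, beyond the 48 KB default cap
cudaFuncSetAttribute(k, cudaFuncAttributeMaxDynamicSharedMemorySize, smemBytes);
k<<<grid, block, smemBytes>>>(smemBytes / (int)sizeof(float));   // buf now holds 25,600 floats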

Micro-exercise 7.1

A warp does s[threadIdx.x * 2] on a __shared__ float s[64]. Conflict? Answer: threads 0..15 hit banks 0,2,4,...,30 (no conflicts within), threads 16..31 hit banks 0,2,4,...,30 again-but in shared-memory bank-conflict arithmetic, threads 16..31 collide with threads 0..15 pairwise. 2-way conflict.


8. Building Blocks

8.1 Vector add

__global__ void vadd(const float* a, const float* b, float* c, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) c[i] = a[i] + b[i];
}

This is purely bandwidth-bound: 3 × 4 = 12 bytes touched per output, ~1 FLOP per output → arithmetic intensity 1/12 FLOP/byte. On H100 (~3 TB/s, ~67 TFLOP/s FP32), the roofline says ~250 GFLOP/s-this kernel achieves near-peak HBM bandwidth, not near-peak FLOPs.

8.2 Reduction-naive

Sum N floats. Naive in-block reduction:

__global__ void reduce_naive(const float* x, float* out, int n) {
    extern __shared__ float sdata[];
    int tid = threadIdx.x;
    int i = blockIdx.x * blockDim.x + tid;
    sdata[tid] = (i < n) ? x[i] : 0.0f;
    __syncthreads();
    for (int s = 1; s < blockDim.x; s *= 2) {
        if (tid % (2*s) == 0) sdata[tid] += sdata[tid + s];
        __syncthreads();
    }
    if (tid == 0) atomicAdd(out, sdata[0]);
}

Problems: divergence inside the warp (tid % (2*s)) wastes lanes; bank conflicts on sdata accesses.

8.3 Reduction-sequential addressing

for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (tid < s) sdata[tid] += sdata[tid + s];
    __syncthreads();
}

Now active threads are 0..s-1, contiguous, no divergence within active warps. Bank conflicts gone (stride-1).

8.4 Reduction-warp shuffle

The last 6 levels of reduction (s = 32 down to 1) execute within a single warp. Warp shuffle (__shfl_down_sync) lets a thread read another lane's register without going through shared memory:

__device__ float warp_reduce_sum(float v) {
    for (int off = 16; off > 0; off >>= 1)
        v += __shfl_down_sync(0xffffffff, v, off);
    return v;     // lane 0 holds the full warp sum
}

__global__ void reduce_shfl(const float* x, float* out, int n) {
    __shared__ float sm[32];                       // one slot per warp in block
    int tid = threadIdx.x;
    int lane = tid & 31;
    int wid  = tid >> 5;
    int i = blockIdx.x * blockDim.x + tid;
    float v = (i < n) ? x[i] : 0.0f;
    v = warp_reduce_sum(v);
    if (lane == 0) sm[wid] = v;
    __syncthreads();
    v = (tid < blockDim.x / 32) ? sm[lane] : 0.0f;
    if (wid == 0) v = warp_reduce_sum(v);
    if (tid == 0) atomicAdd(out, v);
}

Performance evolution (approximate, 256-thread block, 1 GiB input, H100):

Variant Time Notes
Naive divergence ~3.0 ms warp divergence + bank conflicts
Sequential addr. ~1.2 ms clean shared-mem reduction
Warp shuffle ~0.55 ms near HBM-bound
+ grid-stride loop ~0.40 ms 1 launch, fewer atomics

8.5 Cooperative groups (optional cleaner API)

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void reduce_cg(const float* x, float* out, int n) {
    auto block = cg::this_thread_block();
    auto warp  = cg::tiled_partition<32>(block);
    int i = block.group_index().x * block.size() + block.thread_rank();
    float v = (i < n) ? x[i] : 0.0f;
    v = cg::reduce(warp, v, cg::plus<float>());
    if (warp.thread_rank() == 0) atomicAdd(out, v);
}

8.6 Inclusive prefix sum (Hillis–Steele within a block)

__global__ void prefix_sum(float* x, int n) {
    extern __shared__ float s[];
    int tid = threadIdx.x;
    s[tid] = (tid < n) ? x[tid] : 0.0f;
    __syncthreads();
    for (int off = 1; off < blockDim.x; off <<= 1) {
        float v = (tid >= off) ? s[tid - off] : 0.0f;
        __syncthreads();
        s[tid] += v;
        __syncthreads();
    }
    if (tid < n) x[tid] = s[tid];
}

For full-array scan you do block scans, scan the per-block totals, then add back. Or use thrust::inclusive_scan.
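
For reference, the library route (a sketch; h_x is assumed to be a host std::vector<float>):

#include <thrust/device_vector.h>
#include <thrust/scan.h>

thrust::device_vector<float> d(h_x.begin(), h_x.end());
thrust::inclusive_scan(d.begin(), d.end(), d.begin());  // in-place full-array scan on device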

8.7 Histogram with privatization

The naive histogram has every thread atomicAdd(&hist[bin], 1) on global — a contention disaster. Privatize per-block in shared memory, merge at end:

__global__ void hist_priv(const int* data, int* hist, int n, int B) {
    extern __shared__ int sh[];
    for (int i = threadIdx.x; i < B; i += blockDim.x) sh[i] = 0;
    __syncthreads();
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    int stride = gridDim.x * blockDim.x;
    for (int i = tid; i < n; i += stride) atomicAdd(&sh[data[i]], 1);
    __syncthreads();
    for (int i = threadIdx.x; i < B; i += blockDim.x)
        atomicAdd(&hist[i], sh[i]);
}

Speedup over naive global atomics: typically 10–50×, larger as bin counts shrink.
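
Usage ties back to §7.4's dynamic shared memory-the third chevron argument must cover the per-block private histogram (names as in the kernel above):

const int B = 256;                                      // bin count
hist_priv<<<grid, 256, B * sizeof(int)>>>(d_data, d_hist, n, B);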


9. Tiled GEMM Walkthrough

We compute C = A · B with A: MxK, B: KxN, C: MxN, all row-major. We'll evolve the kernel through six stages, each with its expected speedup over the naive baseline on a square problem like 4096×4096.

9.1 Stage 0-Naive

Each thread computes one output element:

__global__ void gemm_naive(const float* A, const float* B, float* C,
                           int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < M && col < N) {
        float acc = 0.f;
        for (int k = 0; k < K; ++k) acc += A[row*K + k] * B[k*N + col];
        C[row*N + col] = acc;
    }
}

Bandwidth analysis: each thread reads 2K floats and writes 1, performing 2K FLOPs → arithmetic intensity 2K / 8K bytes = 0.25 FLOP/byte. Bandwidth-bound at ~750 GFLOP/s on H100 (HBM ~3 TB/s × 0.25). Reality is worse: there is no reuse-every element of A and B is re-fetched O(N) times-and each thread's K-loop walks a column of the row-major B (stride N), thrashing the caches.

Baseline: ~1× (by definition). Achieves on the order of ~1–3% of cuBLAS.

9.2 Stage 1-Coalesced

The usual first fix is to map consecutive threadIdx.x values to consecutive columns so that writes to C coalesce. The naive code above is already laid out this way: threadIdx.x → col, so C[row*N + col] is coalesced, and B[k*N + col]-for fixed k, varying col-is contiguous, so it is already coalesced too. The real fix is reusing each loaded byte across threads, which only shared memory delivers.

Speedup over naive: small (~1.2–1.5×) just from cache effects.

9.3 Stage 2-Shared-memory tiled

Each block computes a BM × BN tile of C. It cooperatively loads tiles of A and B into shared memory, then iterates. With BM = BN = BK = 32:

template<int BM, int BN, int BK>
__global__ void gemm_smem(const float* A, const float* B, float* C,
                          int M, int N, int K) {
    __shared__ float sA[BM][BK];
    __shared__ float sB[BK][BN];
    int row = blockIdx.y * BM + threadIdx.y;
    int col = blockIdx.x * BN + threadIdx.x;
    float acc = 0.f;
    for (int kt = 0; kt < K; kt += BK) {
        sA[threadIdx.y][threadIdx.x] =
            (row < M && kt + threadIdx.x < K) ? A[row*K + kt + threadIdx.x] : 0.f;
        sB[threadIdx.y][threadIdx.x] =
            (kt + threadIdx.y < K && col < N) ? B[(kt + threadIdx.y)*N + col] : 0.f;
        __syncthreads();
        #pragma unroll
        for (int k = 0; k < BK; ++k) acc += sA[threadIdx.y][k] * sB[k][threadIdx.x];
        __syncthreads();
    }
    if (row < M && col < N) C[row*N + col] = acc;
}

Each tile of A and B is loaded once from HBM and reused 32× from shared memory. Arithmetic intensity rises to ~16 FLOP/byte. On H100 FP32 (~67 TFLOP/s peak), this can reach ~30–40% of cuBLAS FP32. Speedup over naive: ~10–15×.

Note: in the K-loop product sA[ty][k] * sB[k][tx], all 32 lanes of a warp (fixed ty for a 32×32 block) read the same word of sA[ty][k]-a broadcast, which shared memory serves conflict-free. For thread mappings where lanes vary along the row index, the stride-BK access does conflict, and padding sA[BM][BK+1] fixes it.

9.4 Stage 3-Register tiling (1D and 2D)

Have each thread compute multiple outputs, holding partial sums in registers. With TM = TN = 8, each thread now produces an 8×8 micro-tile, so a BM=128, BN=128 block uses 128*128/(8*8) = 256 threads-convenient.

Sketch (key inner loop):

constexpr int BM=128, BN=128, BK=8, TM=8, TN=8;
__shared__ float sA[BM][BK];
__shared__ float sB[BK][BN];
float regA[TM], regB[TN], acc[TM][TN] = {0};
// ... load sA, sB cooperatively ...
__syncthreads();
#pragma unroll
for (int k = 0; k < BK; ++k) {
    #pragma unroll
    for (int i = 0; i < TM; ++i) regA[i] = sA[ty*TM + i][k];
    #pragma unroll
    for (int j = 0; j < TN; ++j) regB[j] = sB[k][tx*TN + j];
    #pragma unroll
    for (int i = 0; i < TM; ++i)
        #pragma unroll
        for (int j = 0; j < TN; ++j) acc[i][j] += regA[i] * regB[j];
}

Each register-resident regA[i] * regB[j] is one FMA. With TM=TN=8 we generate 64 FMAs per k step from 8+8=16 register loads → 4 FMAs/load, matching SM dispatch ratios. Arithmetic intensity now in the hundreds of FLOP/byte. Expected: ~50–70% of cuBLAS FP32. Speedup over naive: ~30–40×.

9.5 Stage 4-Tensor cores (wmma, BF16 in / FP32 out)

On Volta+ each SM has tensor cores that compute matrix-multiply-accumulate on small fragments per warp. The fundamental fragment shape supported across generations for BF16 is m=16, n=16, k=16 with FP32 accumulator (we'll call this the "16x16x16" shape). Each warp does one tensor-core op per mma_sync.

#include <mma.h>
using namespace nvcuda::wmma;

__global__ void gemm_wmma(const __nv_bfloat16* A, const __nv_bfloat16* B,
                          float* C, int M, int N, int K) {
    int warpM = (blockIdx.y * blockDim.y + threadIdx.y);
    int warpN = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
    fragment<matrix_a, 16,16,16, __nv_bfloat16, row_major> a_frag;
    fragment<matrix_b, 16,16,16, __nv_bfloat16, row_major> b_frag;
    fragment<accumulator, 16,16,16, float> c_frag;
    fill_fragment(c_frag, 0.0f);
    for (int k = 0; k < K; k += 16) {
        load_matrix_sync(a_frag, A + warpM*16*K + k, K);
        load_matrix_sync(b_frag, B + k*N + warpN*16, N);
        mma_sync(c_frag, a_frag, b_frag, c_frag);
    }
    store_matrix_sync(C + warpM*16*N + warpN*16, c_frag, N, mem_row_major);
}

This is the simplest possible tensor-core GEMM-one warp produces one 16×16 output tile and walks the K-dimension with no shared-memory tiling. Useful as a teaching example, but it misses the data-reuse win. To get serious throughput, combine wmma with shared-memory tiling (stage 5+): load tiles of A and B into shared memory, iterate wmma fragments out of those tiles, accumulate per-warp.

Expected speedup once integrated with shared-memory tiling and BF16 inputs: ~5–10× over the FP32 register-tiled stage 3, simply because BF16 tensor-core peak on H100 (~990 TFLOP/s sparse, ~495 TFLOP/s dense) dwarfs FP32 (~67 TFLOP/s).

9.6 Stage 5-Double-buffered with cp.async (Ampere+)

cp.async (PTX: cp.async.cg.shared.global) initiates a global → shared copy that doesn't block the issuing thread; it commits asynchronously. You overlap the next tile's load with the current tile's compute.

#include <cuda/pipeline>
#include <cooperative_groups/memcpy_async.h>
namespace cg = cooperative_groups;

extern __shared__ char smem_raw[];
auto sA = reinterpret_cast<__nv_bfloat16(*)[BK]>(smem_raw);
auto sB = reinterpret_cast<__nv_bfloat16(*)[BN]>(smem_raw + sizeof(*sA)*BM*2);

__pipeline_memcpy_async(&sA[buf][...], &A[...], bytes);
__pipeline_memcpy_async(&sB[buf][...], &B[...], bytes);
__pipeline_commit();
__pipeline_wait_prior(0);     // wait for the buffer we'll consume next
__syncthreads();
// ... wmma on sA[buf], sB[buf], while next iter prefetches into sA[1-buf] ...

By keeping two shared-memory buffers and pre-issuing the next load while the math fragment runs, the kernel overlaps HBM latency with tensor-core compute. Final speedup on H100 with proper pipelining and Hopper-specific TMA/wgmma: cuBLAS-class performance (~70–95% of cuBLAS for square problems > 1024).

9.7 Cumulative speedup table (4096³ BF16, H100-approximate)

Stage Approx. % of cuBLAS Notes
0. Naive FP32 ~1–3% bandwidth-bound, strided B
1. Coalesced FP32 ~3–5% minor
2. Shared-mem tiled FP32 ~25–40% classic GEMM tiling
3. Register tiled FP32 ~50–70% OK for FP32
4. wmma BF16 (no tiling) N/A demo only
4'. wmma BF16 + tiling ~40–60% of BF16 cuBLAS
5. + cp.async doublebuf ~70–90% of BF16 cuBLAS
6. Hopper wgmma + TMA ~90–98% what CUTLASS does

10. Tensor Cores via nvcuda::wmma

10.1 Fragment types

fragment<USE, M, N, K, T, LAYOUT> frag;
  • USE: matrix_a, matrix_b, or accumulator.
  • M, N, K: shape; the canonical, supported-everywhere combination is 16,16,16 for FP16/BF16 → FP32. Other shapes (8,32,16; 32,8,16) exist on some archs.
  • T: element type. __half, __nv_bfloat16 for inputs; float, __half, int32_t for accumulators.
  • LAYOUT: row_major or col_major (only for A and B fragments).

A fragment is opaque across threads of a warp-internally it's a strided slice of registers. The element-by-element layout is unspecified; you only interact with it through the API.

10.2 The four operations

fill_fragment(frag, value);
load_matrix_sync(frag, ptr, leadingDim);
mma_sync(d, a, b, c);                       // d = a*b + c
store_matrix_sync(ptr, frag, leadingDim, layout);

load_matrix_sync requires the pointer to be 16-byte aligned and the leading dim to be a multiple of 8 (FP16/BF16). The "sync" suffix is a warp-wide collective; all 32 threads must reach it with the same args.

10.3 Supported precisions (canonical)

Inputs A,B Accum Notes
__half __half/float Volta+
__nv_bfloat16 float Ampere+
TF32 (FP32 fed in) float Ampere+, via wmma_precision_tf32
int8 int32 Turing+
FP8 (E4M3, E5M2) float Hopper+ (use wgmma/CUTLASS in practice)

10.4 Layout interplay

For canonical 16x16x16 BF16: a warp's 32 threads cooperatively hold a 16×16 fragment (256 elements / 32 lanes = 8 elements/lane). The layout is opaque, but load_matrix_sync knows how to materialize it from row- or col-major memory with the leading-dim arg. Always pass the true leading dim of the parent matrix-passing the tile width is a common bug.
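
A sketch of the bug the previous paragraph warns about (A is M×K row-major, a_frag as in §9.5):

// Loading the 16×16 tile whose top-left element is (row, k):
load_matrix_sync(a_frag, A + row*K + k, /*ldm=*/K);     // correct: the parent matrix's row stride
// load_matrix_sync(a_frag, A + row*K + k, 16);         // WRONG: tile width garbles the fragment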

Micro-exercise 10.1

You have an FP32 matrix and want to use BF16 tensor cores. What do you do? Answer: either pre-convert with a kernel that emits BF16 (__float2bfloat16), or use TF32 fragments (wmma::precision::tf32) which take FP32 input and truncate the mantissa internally.


11. mma.sync PTX Inline Assembly

wmma is a high-level wrapper. CUTLASS and FlashAttention drop down to mma.sync PTX for tile sizes the wrapper doesn't expose, layout flexibility, or to interleave multiple MMAs to hide latency.

A canonical Ampere BF16 16×8×16 MMA (note: 16×8×16, not 16×16×16-the PTX shape is finer):

__device__ inline void mma_m16n8k16_bf16(
    float d[4],
    const uint32_t a[4],     // packed BF16, 8 elements per uint32 pair
    const uint32_t b[2],
    const float c[4])
{
    asm volatile(
        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
        "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
        : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
          "r"(b[0]), "r"(b[1]),
          "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
}

Reading the mnemonic:

  • mma.sync: synchronous warp-collective MMA.
  • aligned: all threads must execute together with consistent operands.
  • m16n8k16: shape-16×8 output tile, K-dim 16.
  • row.col: A is row-major, B is col-major (in fragment layout terms).
  • f32.bf16.bf16.f32: D, A, B, C types.

Each lane carries part of the fragment; the layout is documented in PTX ISA. CUTLASS abstracts this with cutlass::arch::Mma templates so you don't write the asm by hand for production code, but it's important to recognize it when reading library code.

Hopper introduces wgmma.mma_async (warp-group MMA), which operates over 128 threads (4 warps) and runs asynchronously, freeing the warp group to issue more work while the tensor core churns. CUTLASS 3.x and cuBLAS use this internally for peak Hopper performance.


12. Cooperative Groups & Thread-Block Clusters

12.1 Cooperative groups

The cooperative_groups namespace provides typed handles for thread collectives:

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void k() {
    auto block = cg::this_thread_block();
    auto warp  = cg::tiled_partition<32>(block);
    auto half  = cg::tiled_partition<16>(warp);
    // collective ops:
    int v = cg::reduce(warp, lane_value, cg::plus<int>());
    block.sync();      // == __syncthreads()
}

Useful when you want to write code that doesn't hard-code warp size 32 (in case of future architectures or simulating smaller tiles).

12.2 Grid-wide synchronization

If you launch with cudaLaunchCooperativeKernel, you can call cg::this_grid().sync() for a barrier across all blocks of the grid. The constraint: the entire grid must fit on the GPU concurrently. Useful for multi-pass algorithms that previously required separate kernel launches.
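
The launch itself goes through a dedicated API rather than the chevrons (a sketch; k, d_x, n are illustrative, and the call returns an error if the grid cannot be co-resident):

void* args[] = { (void*)&d_x, (void*)&n };
cudaLaunchCooperativeKernel((void*)k, grid, block, args, /*smem*/0, /*stream*/0);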

12.3 Thread-block clusters (Hopper)

Hopper introduces a new level above the block: the cluster, a group of nearby blocks that share Distributed Shared Memory (DSMEM) and can synchronize collectively.

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void __cluster_dims__(2, 2, 1) cluster_k() {
    auto cluster = cg::this_cluster();
    cluster.sync();             // sync all blocks in the cluster
    int rank = cluster.block_rank();
    // distributed shared memory:
    __shared__ float buf[1024];
    auto neighbor = cluster.map_shared_rank(buf, /*rank*/0);
    // now `neighbor` points into block-0's shared memory
}

This unlocks new patterns like cross-block tile sharing in GEMM without going through HBM, which is part of how Hopper hits its tensor-core peak.


13. Profiling Discipline

13.1 Timing with events (the right way)

cudaEvent_t s, e;
cudaEventCreate(&s); cudaEventCreate(&e);
// warmup
for (int i = 0; i < 3; ++i) kernel<<<g,b>>>(...);
cudaDeviceSynchronize();
cudaEventRecord(s);
for (int i = 0; i < N_ITERS; ++i) kernel<<<g,b>>>(...);
cudaEventRecord(e);
cudaEventSynchronize(e);
float ms;  cudaEventElapsedTime(&ms, s, e);
double avg_ms = ms / N_ITERS;

Always:

  • Warm up (first launch is JIT-y and miss-y).
  • Loop and average.
  • Sync before reading.
  • Record events in the same stream the kernel runs in.

13.2 Common pitfalls

  • Forgetting to sync before stopping the host clock. Kernel launches return immediately; CPU chrono::now() measures launch overhead, not kernel time. Use events.
  • Default-stream interactions make cross-stream timing measure serialized work even when you wrote concurrent code. Check with Nsight Systems' timeline view.
  • JIT compile on first launch can add seconds. Pre-build for your arch (-arch=sm_90) and ignore the first iteration.
  • Power state: GPUs throttle. Pin clocks (nvidia-smi -pm 1 -lgc …) for reproducible benchmarks.

13.3 nsys (Nsight Systems)-system-wide timeline

Captures a Gantt chart of CUDA API calls, kernels, and memcopies across all streams. Best first stop:

nsys profile --stats=true -o report ./myprog
nsys-ui report.nsys-rep        # GUI

Use to answer: are H2D, kernel, D2H actually overlapped? Is the default stream serializing things?

13.4 ncu (Nsight Compute)-per-kernel deep dive

ncu --set full -k mygemm -o profile ./myprog
ncu-ui profile.ncu-rep

Tells you achieved occupancy, achieved DRAM throughput, L1/L2 hit rates, shared-memory bank conflicts, warp stall reasons. The "Speed of Light" section tells you, for each kernel, how close to peak compute and peak memory you are-if both are far from 100%, you have latency-hiding problems (too few resident warps, dependencies).

13.5 Reading roofline

For any kernel, compute:

  • Arithmetic intensity (AI) = FLOPs / bytes touched.
  • Achievable throughput = min(peak_compute, AI × peak_BW).

If AI < peak_compute / peak_BW (the "ridge point"), the kernel is memory-bound and you should think about reuse, not FLOP throughput. For H100 BF16 tensor cores: ridge ≈ 990 TFLOP/s ÷ 3 TB/s ≈ 330 FLOP/byte. Almost nothing reaches that without aggressive shared-memory tiling.
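
Plugging §8.1's vector add into the formula (numbers approximate, as throughout):

double peak_flops = 67e12;                            // H100 FP32, ~TFLOP/s
double peak_bw    = 3e12;                             // HBM3, bytes/s
double ai         = 1.0 / 12.0;                       // vadd: 1 FLOP per 12 bytes
double bound      = fmin(peak_flops, ai * peak_bw);   // ≈ 2.5e11 FLOP/s → memory-bound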


14. A Complete BF16 GEMM at 2048×2048

Below is a single self-contained file targeting Ampere+ (sm_80). It includes shared-memory tiling, wmma tensor cores, and a simple double-buffered cp.async pipeline. On A100 we observe ~60–70% of cuBLAS BF16 for M=N=K=2048. (Actual depends heavily on tile choice and hardware-verify with ncu.)

gemm_bf16.cu:

// Build:
//   nvcc -O3 -arch=sm_80 -lineinfo -lcublas gemm_bf16.cu -o gemm_bf16
// Run:
//   ./gemm_bf16

#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <mma.h>
#include <cublas_v2.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <vector>

#define CUDA_CHECK(call) do {                                    \
    cudaError_t e_ = (call);                                     \
    if (e_ != cudaSuccess) {                                     \
        fprintf(stderr, "CUDA %s:%d: %s\n", __FILE__, __LINE__,  \
                cudaGetErrorString(e_));                         \
        std::exit(1);                                            \
    }                                                            \
} while (0)

using nvcuda::wmma::fragment;
using nvcuda::wmma::matrix_a;
using nvcuda::wmma::matrix_b;
using nvcuda::wmma::accumulator;
using nvcuda::wmma::row_major;
using nvcuda::wmma::col_major;
using nvcuda::wmma::load_matrix_sync;
using nvcuda::wmma::mma_sync;
using nvcuda::wmma::store_matrix_sync;
using nvcuda::wmma::fill_fragment;
using nvcuda::wmma::mem_row_major;

// Tile config: BM × BN output tile per block; BK along K.
// One warp computes a WM × WN sub-tile.
constexpr int BM = 128, BN = 128, BK = 32;
constexpr int WM = 64,  WN = 64;
constexpr int WMMA_M = 16, WMMA_N = 16, WMMA_K = 16;
constexpr int WARPS_PER_BLOCK = (BM/WM) * (BN/WN);   // 2*2 = 4
constexpr int THREADS_PER_BLOCK = WARPS_PER_BLOCK * 32; // 128

// cp.async 16-byte copy; source and destination must be 16-byte aligned.
__device__ __forceinline__ void cp_async_16(void* smem_ptr, const void* gmem_ptr) {
#if __CUDA_ARCH__ >= 800
    unsigned int smem_int = __cvta_generic_to_shared(smem_ptr);
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n"
                 :: "r"(smem_int), "l"(gmem_ptr));
#else
    *reinterpret_cast<int4*>(smem_ptr) = *reinterpret_cast<const int4*>(gmem_ptr);
#endif
}
__device__ __forceinline__ void cp_async_commit() {
#if __CUDA_ARCH__ >= 800
    asm volatile("cp.async.commit_group;\n" ::);
#endif
}
__device__ __forceinline__ void cp_async_wait_all() {
#if __CUDA_ARCH__ >= 800
    asm volatile("cp.async.wait_all;\n" ::);
#endif
}

__global__ void gemm_bf16_kernel(const __nv_bfloat16* __restrict__ A,
                                 const __nv_bfloat16* __restrict__ B,
                                 float* __restrict__ C,
                                 int M, int N, int K)
{
    // Two shared-mem buffers per matrix for double-buffering.
    __shared__ __nv_bfloat16 sA[2][BM][BK];
    __shared__ __nv_bfloat16 sB[2][BK][BN];

    int tid       = threadIdx.x;
    int warp_id   = tid / 32;
    int lane_id   = tid % 32;
    int warp_row  = warp_id / (BN / WN);   // 0 or 1
    int warp_col  = warp_id % (BN / WN);

    int block_row = blockIdx.y * BM;
    int block_col = blockIdx.x * BN;

    // Each thread brings 16B = 8 BF16 from A and 8 from B per K-tile.
    // 128 threads × 8 = 1024 elements per tile; BM*BK = 128*32 = 4096 → 4 iters,
    // BK*BN = 32*128 = 4096 → 4 iters. We unroll inline below.

    auto load_tile = [&](int buf, int kt) {
        // Load sA[buf][BM][BK] from A[block_row..+BM][kt..+BK]
        constexpr int A_TILE = BM * BK;       // 4096
        constexpr int A_PER_THREAD = A_TILE / THREADS_PER_BLOCK / 8;  // 4
        #pragma unroll
        for (int i = 0; i < A_PER_THREAD; ++i) {
            int idx = (i * THREADS_PER_BLOCK + tid) * 8;
            int r = idx / BK;
            int c = idx % BK;
            const void* gptr = &A[(block_row + r) * K + (kt + c)];
            void* sptr = &sA[buf][r][c];
            cp_async_16(sptr, gptr);
        }
        constexpr int B_TILE = BK * BN;       // 4096
        constexpr int B_PER_THREAD = B_TILE / THREADS_PER_BLOCK / 8;  // 4
        #pragma unroll
        for (int i = 0; i < B_PER_THREAD; ++i) {
            int idx = (i * THREADS_PER_BLOCK + tid) * 8;
            int r = idx / BN;
            int c = idx % BN;
            const void* gptr = &B[(kt + r) * N + (block_col + c)];
            void* sptr = &sB[buf][r][c];
            cp_async_16(sptr, gptr);
        }
        cp_async_commit();
    };

    // Per-warp output fragments.
    constexpr int FRAG_M = WM / WMMA_M;   // 4
    constexpr int FRAG_N = WN / WMMA_N;   // 4
    fragment<accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag[FRAG_M][FRAG_N];
    #pragma unroll
    for (int i = 0; i < FRAG_M; ++i)
        #pragma unroll
        for (int j = 0; j < FRAG_N; ++j) fill_fragment(c_frag[i][j], 0.0f);

    // Prefetch first tile.
    int kt = 0;
    int buf = 0;
    load_tile(buf, kt);
    cp_async_wait_all();
    __syncthreads();

    for (kt = BK; kt < K; kt += BK) {
        // Issue next tile prefetch into the other buffer.
        int next_buf = buf ^ 1;
        load_tile(next_buf, kt);

        // Compute on current buffer.
        #pragma unroll
        for (int kk = 0; kk < BK; kk += WMMA_K) {
            fragment<matrix_a, WMMA_M, WMMA_N, WMMA_K, __nv_bfloat16, row_major> a_frag[FRAG_M];
            fragment<matrix_b, WMMA_M, WMMA_N, WMMA_K, __nv_bfloat16, row_major> b_frag[FRAG_N];
            #pragma unroll
            for (int i = 0; i < FRAG_M; ++i) {
                int row = warp_row * WM + i * WMMA_M;
                load_matrix_sync(a_frag[i], &sA[buf][row][kk], BK);
            }
            #pragma unroll
            for (int j = 0; j < FRAG_N; ++j) {
                int col = warp_col * WN + j * WMMA_N;
                load_matrix_sync(b_frag[j], &sB[buf][kk][col], BN);
            }
            #pragma unroll
            for (int i = 0; i < FRAG_M; ++i)
                #pragma unroll
                for (int j = 0; j < FRAG_N; ++j)
                    mma_sync(c_frag[i][j], a_frag[i], b_frag[j], c_frag[i][j]);
        }

        cp_async_wait_all();
        __syncthreads();
        buf = next_buf;
    }

    // Tail compute on last loaded buffer.
    #pragma unroll
    for (int kk = 0; kk < BK; kk += WMMA_K) {
        fragment<matrix_a, WMMA_M, WMMA_N, WMMA_K, __nv_bfloat16, row_major> a_frag[FRAG_M];
        fragment<matrix_b, WMMA_M, WMMA_N, WMMA_K, __nv_bfloat16, row_major> b_frag[FRAG_N];
        #pragma unroll
        for (int i = 0; i < FRAG_M; ++i)
            load_matrix_sync(a_frag[i], &sA[buf][warp_row*WM + i*WMMA_M][kk], BK);
        #pragma unroll
        for (int j = 0; j < FRAG_N; ++j)
            load_matrix_sync(b_frag[j], &sB[buf][kk][warp_col*WN + j*WMMA_N], BN);
        #pragma unroll
        for (int i = 0; i < FRAG_M; ++i)
            #pragma unroll
            for (int j = 0; j < FRAG_N; ++j)
                mma_sync(c_frag[i][j], a_frag[i], b_frag[j], c_frag[i][j]);
    }

    // Store back.
    #pragma unroll
    for (int i = 0; i < FRAG_M; ++i) {
        #pragma unroll
        for (int j = 0; j < FRAG_N; ++j) {
            int row = block_row + warp_row * WM + i * WMMA_M;
            int col = block_col + warp_col * WN + j * WMMA_N;
            store_matrix_sync(&C[row*N + col], c_frag[i][j], N, mem_row_major);
        }
    }
    (void)lane_id;
}

// Reference using cuBLAS for correctness + perf comparison.
static void cublas_gemm(cublasHandle_t h,
                        const __nv_bfloat16* dA, const __nv_bfloat16* dB,
                        float* dC, int M, int N, int K) {
    float alpha = 1.f, beta = 0.f;
    // cuBLAS is column-major. To compute C = A*B with row-major inputs,
    // call cublasGemmEx with op_N, op_N and swap the operand order:
    //   C^T = B^T * A^T, treated as col-major.
    cublasGemmEx(h, CUBLAS_OP_N, CUBLAS_OP_N,
                 N, M, K,
                 &alpha,
                 dB, CUDA_R_16BF, N,
                 dA, CUDA_R_16BF, K,
                 &beta,
                 dC, CUDA_R_32F, N,
                 CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

int main() {
    const int M = 2048, N = 2048, K = 2048;
    size_t aBytes = (size_t)M*K*sizeof(__nv_bfloat16);
    size_t bBytes = (size_t)K*N*sizeof(__nv_bfloat16);
    size_t cBytes = (size_t)M*N*sizeof(float);

    std::vector<__nv_bfloat16> hA(M*K), hB(K*N);
    std::vector<float> hC(M*N), hRef(M*N);
    for (int i = 0; i < M*K; ++i) hA[i] = __float2bfloat16((float)((i*7) % 13) * 0.01f);
    for (int i = 0; i < K*N; ++i) hB[i] = __float2bfloat16((float)((i*5) % 11) * 0.01f);

    __nv_bfloat16 *dA, *dB; float *dC, *dRef;
    CUDA_CHECK(cudaMalloc(&dA, aBytes));
    CUDA_CHECK(cudaMalloc(&dB, bBytes));
    CUDA_CHECK(cudaMalloc(&dC, cBytes));
    CUDA_CHECK(cudaMalloc(&dRef, cBytes));
    CUDA_CHECK(cudaMemcpy(dA, hA.data(), aBytes, cudaMemcpyHostToDevice));
    CUDA_CHECK(cudaMemcpy(dB, hB.data(), bBytes, cudaMemcpyHostToDevice));

    dim3 grid(N/BN, M/BM);
    dim3 block(THREADS_PER_BLOCK);

    // Warmup
    for (int i = 0; i < 3; ++i)
        gemm_bf16_kernel<<<grid, block>>>(dA, dB, dC, M, N, K);
    CUDA_CHECK(cudaDeviceSynchronize());

    cudaEvent_t s, e;
    cudaEventCreate(&s); cudaEventCreate(&e);
    const int ITERS = 50;
    cudaEventRecord(s);
    for (int i = 0; i < ITERS; ++i)
        gemm_bf16_kernel<<<grid, block>>>(dA, dB, dC, M, N, K);
    cudaEventRecord(e);
    cudaEventSynchronize(e);
    float ms; cudaEventElapsedTime(&ms, s, e);
    double avg_ms = ms / ITERS;
    double tflops = 2.0 * M * N * K / (avg_ms * 1e-3) / 1e12;
    printf("Custom: %.3f ms,  %.2f TFLOP/s\n", avg_ms, tflops);

    cublasHandle_t h; cublasCreate(&h);
    for (int i = 0; i < 3; ++i) cublas_gemm(h, dA, dB, dRef, M, N, K);
    CUDA_CHECK(cudaDeviceSynchronize());
    cudaEventRecord(s);
    for (int i = 0; i < ITERS; ++i) cublas_gemm(h, dA, dB, dRef, M, N, K);
    cudaEventRecord(e);
    cudaEventSynchronize(e);
    cudaEventElapsedTime(&ms, s, e);
    double cb_ms = ms / ITERS;
    double cb_tflops = 2.0 * M * N * K / (cb_ms * 1e-3) / 1e12;
    printf("cuBLAS: %.3f ms,  %.2f TFLOP/s\n", cb_ms, cb_tflops);
    printf("Ratio:  %.1f%% of cuBLAS\n", 100.0 * cb_tflops / tflops > 100 ? 100.0 : 100.0 * tflops / cb_tflops);

    // Correctness spot-check (sample 16 elements).
    CUDA_CHECK(cudaMemcpy(hC.data(),  dC,   cBytes, cudaMemcpyDeviceToHost));
    CUDA_CHECK(cudaMemcpy(hRef.data(),dRef, cBytes, cudaMemcpyDeviceToHost));
    double max_rel = 0;
    for (int i = 0; i < M*N; i += (M*N/16)) {
        double r = std::abs(hC[i]-hRef[i]) / (std::abs(hRef[i]) + 1e-6);
        if (r > max_rel) max_rel = r;
    }
    printf("Max sampled rel error vs cuBLAS: %.3e\n", max_rel);

    cublasDestroy(h);
    cudaFree(dA); cudaFree(dB); cudaFree(dC); cudaFree(dRef);
    return 0;
}

Build:

nvcc -O3 -arch=sm_80 -lineinfo -lcublas gemm_bf16.cu -o gemm_bf16
./gemm_bf16

Annotations on the design:

  • Tile: 128×128 output × 32 K-block per iteration. Each block has 4 warps, each warp owns a 64×64 output sub-tile, materialized as a 4×4 grid of 16×16 tensor-core fragments.
  • Threads per block: 128 = 4 warps × 32. Modest-leaves room for 4–8 blocks per SM (occupancy-bounded by shared-memory: each block uses 2 × (128×32 + 32×128) × 2 B = 32 KB of smem).
  • cp.async double-buffering: while the tensor cores compute on sA[buf] / sB[buf], the next K-tile is fetched into the other buffer with no thread blocking. cp.async.wait_all + __syncthreads() flips the buffer when both the compute and the prefetch are done.
  • No bank-conflict fix: at 32-wide BF16 rows the natural layout already hits 32 different banks per warp load through load_matrix_sync. If you change BK you may need to add padding.
  • cuBLAS comparison: cuBLAS is column-major, so we swap the operand order to compute the row-major equivalent, a standard trick.

What this kernel doesn't do (that CUTLASS / cuBLAS do): swizzled shared-mem layout to fully avoid bank conflicts at all tile shapes; multi-stage cp.async pipeline (3+ buffers) for deeper latency hiding; wgmma on Hopper; epilogue fusion (bias, ReLU, scale); split-K for skinny problems. Implementing each of those is what closes the remaining 30–40% gap.


15. Practical Exercises

Exercise 1-Saturating PCIe

Write a benchmark that measures cudaMemcpy H→D bandwidth for transfer sizes 1 KB to 1 GB (powers of 2), with and without pinned host memory, and plots the curve. Answer sketch: allocate once outside the timing loop; cudaMallocHost for pinned; warm up; time with events; expect pinned to plateau near PCIe peak (~25 GB/s Gen4 / ~50 GB/s Gen5) starting ~1 MB, while pageable plateaus lower.

Exercise 2-Fix a coalescing bug

The kernel below is 8× slower than expected. Why?

__global__ void scale(float* x, int N) {
    int i = threadIdx.x * gridDim.x * blockDim.x / blockDim.x + blockIdx.x;
    if (i < N) x[i] *= 2.f;
}
Answer: the index simplifies to i = threadIdx.x * gridDim.x + blockIdx.x-consecutive threads stride by gridDim.x, so each warp's loads from x land in different sectors. Fix: i = blockIdx.x * blockDim.x + threadIdx.x.

Exercise 3-Eliminate bank conflicts

Profile the transpose kernel (§7.2) without padding using ncu and verify shared_load_transactions_per_request ≈ 32. Add the [32][33] padding and verify it drops to 1. Answer sketch: ncu --set full -k transpose ./prog; look at the "Memory Workload Analysis → Shared Memory" pane.

Exercise 4-Reduction with fewer atomics

Modify reduce_shfl (§8.4) to do a two-stage reduction: per-block partial into an array partials[gridDim.x], then a second kernel reduces partials. Compare against the single-kernel atomicAdd version. Answer sketch: saves the contended atomics; for n = 2^28 the two-kernel version is often ~10–20% faster on H100 because atomics hit L2 contention.

Exercise 5-Naive vs tiled GEMM

Write FP32 gemm_naive (§9.1) and gemm_smem (§9.3) for M=N=K=2048. Time both. Compare to cuBLAS. Answer sketch: expect naive ~1–2% of cuBLAS, tiled ~25–40%, with cuBLAS BF16 itself being ~10× cuBLAS FP32 on tensor-core-capable hardware. Use the tiled FP32 number as the FP32 roofline reference.

Exercise 6-wmma correctness

Implement a wmma-based 16×16×16 GEMM (§9.5) for M=N=K=64 and verify against a CPU reference. Catch: BF16 has 7 mantissa bits-accept 1e-2 relative error, not 1e-6. Answer sketch: allocate BF16 inputs / FP32 output; CPU computes in float; compare element-wise; a relative-error budget of 5e-2 per element is realistic for random inputs in [0,1].


Closing thoughts

You've now seen, end-to-end, how a CUDA kernel becomes a tensor-core GEMM: the launch model and execution machinery, the memory hierarchy and the access patterns it punishes or rewards, the shared-memory bank discipline, the building-block kernels every CUDA author reimplements once, and a six-stage GEMM evolution that mirrors how CUTLASS got built. From here:

  • Read CUTLASS's cute layouts-you now have the model to understand tensor algebra over fragments.
  • Read FlashAttention's CUDA-block-tiled GEMM + softmax fused, with exactly the cp.async pipeline pattern you implemented in §14.
  • Read PyTorch's aten/src/ATen/native/cuda/*.cu-production CUDA written for correctness first, perf second.
  • For Hopper-class peak, learn wgmma, TMA, and Distributed Shared Memory; CUTLASS 3.x is the canonical reference.

When in doubt: profile with ncu, look at the Speed-of-Light, and let the achieved-vs-peak gap tell you whether your kernel is bandwidth-bound, compute-bound, or latency-bound. Optimization without measurement is just guessing; with measurement, every step here is mechanical.

— end of deep dive 02 —

Triton: A Deep Dive into the Block-Level GPU DSL

A self-contained reference. After working through this chapter you should be able to read, write, debug, autotune, and integrate Triton kernels into a PyTorch project without consulting any other source. We assume you know Python, basic linear algebra, what a GPU is, and roughly what CUDA does (threads, blocks, shared memory, warps). We do not assume you have ever written a CUDA kernel by hand.


Table of contents

  1. Why Triton exists
  2. The programming model
  3. Memory operations: load, store, mask, broadcast
  4. Math operations: elementwise, tl.dot, reductions, transcendentals
  5. Compile-time constants and specialization
  6. Autotuning
  7. Numerical stability patterns (online softmax derivation, Welford, log-sum-exp)
  8. Six fully-annotated real kernels
  9. The compilation pipeline (Python -> MLIR -> PTX)
  10. Integration with PyTorch (torch.library, torch.compile, autograd)
  11. Common pitfalls
  12. Triton vs CUTLASS vs hand-rolled CUDA
  13. Six exercises with worked answer sketches
  14. Closing notes

1. Why Triton exists

1.1 The CUDA per-thread model

CUDA exposes the GPU as a hierarchy of threads organised into warps (32 threads), then blocks (one or more warps), then a grid of blocks. The programmer writes code from the perspective of a single thread:

// CUDA: each thread handles a single output element of c = a + b
__global__ void vec_add(const float* a, const float* b, float* c, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) c[i] = a[i] + b[i];
}

This per-thread view is faithful to the hardware (an SM really does dispatch warps, each a SIMD-32 unit), but for a kernel author it forces a constant mental gear shift. When you think "tile of A times tile of B equals tile of C", you are reasoning block-wise. CUDA makes you re-derive that block-level operation as a sequence of per-thread loads, per-thread __shared__ writes, __syncthreads() barriers, per-thread MMA inputs, accumulator registers, mask handling at the trailing edge of the matrix, etc. Every one of those steps is a place to make a memory-coalescing, bank-conflict, or register-pressure mistake. Worse: changing the tile size invalidates almost all of that hand-crafted indexing.

1.2 The Triton thesis (MAPL 2019)

Tillet, Kung, and Cox (MAPL 2019, "Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations") observe that almost every high-performance GPU kernel for ML is naturally written as a loop over blocks of the output, where each iteration computes one output tile from input tiles. They propose: let the programmer write that loop directly, with block-shaped tensors as first-class values, and let a compiler lower the block program to a per-thread CUDA-like form -- choosing vector widths, shared-memory layouts, double-buffered pipelining, swizzling, and tensor core instructions automatically.

Concretely, the human writes (paraphrased):

Program instance pid is responsible for output rows pid*BM ... pid*BM+BM-1. Allocate a block-tensor acc of shape [BM, BN]. Loop over the K dimension in blocks of BK: load a [BM, BK] tile of A and a [BK, BN] tile of B, do acc += dot(a, b). After the loop, store acc to C.

The compiler turns that into hundreds of lines of well-vectorised PTX with correct shared-memory layout, asynchronous copies, and mma instructions.

1.3 What Triton gives up, and why that is fine

Triton intentionally cannot express:

  • arbitrary intra-block thread-divergent control flow,
  • explicit __syncthreads(),
  • user-controlled shared memory,
  • warp-level shuffles as a primary interface (some primitives exist, but they are not the model).

For ML kernels this is rarely a loss: 95% of useful kernels are tiled elementwise / reduce / matmul / softmax compositions, and for those the compiler's choices are at least as good as a moderately experienced CUDA author and far less error-prone.

1.4 Where the productivity comes from

Three places:

  1. Indexing is automatic. You write tl.arange(0, BLOCK) and broadcast; the compiler generates the per-thread index arithmetic.
  2. Shared memory is automatic. tl.load of a 2D block lowers to coalesced global -> shared copies with the right swizzle for downstream tl.dot.
  3. Tile-size sweeps are cheap. Because block sizes are constexpr, the compiler specialises and you can autotune over a config grid in one decorator.

The empirical claim of the original paper, repeated through every release: hand-written Triton matmul reaches ~80% of cuBLAS on common square shapes and frequently beats cuBLAS on awkward shapes (small K, non-multiple-of-16 dims) where the vendor library has no specialised kernel.


2. The programming model

2.1 @triton.jit and the program-instance abstraction

A Triton kernel is a Python function decorated with @triton.jit:

import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements,
               BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)                    # 0..num_programs-1
    offsets = pid * BLOCK + tl.arange(0, BLOCK)    # vector of indices
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

The decorator does not run the function on the host. It captures the Python AST, lowers it to Triton IR (an MLIR dialect), then to PTX, on first call. The body of the function executes on the GPU.

Inside the body, you are one program instance. There is no threadIdx. The total number of program instances is set when you launch:

n = x.numel()
grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)   # 1D grid
add_kernel[grid](x, y, out, n, BLOCK=1024)

grid is either a tuple (1D/2D/3D) or a callable that receives the compile-time meta-parameters and returns a tuple. triton.cdiv(a, b) is ceil-division.

A program instance is roughly equivalent to a CUDA block, not a thread. Inside it, the compiler will allocate num_warps * 32 actual threads and distribute the block-level work across them. You influence that by passing num_warps= at launch (or via autotune).

2.2 Block-tensors as first-class values

tl.arange(0, N) and any expression built from it -- arithmetic, broadcasting, tl.load, tl.dot, reductions -- produces a block tensor whose shape is known at compile time. These tensors live in registers (with the compiler spilling to shared memory or even global as a last resort).

offs_m = tl.arange(0, BM)                 # shape [BM]
offs_n = tl.arange(0, BN)                 # shape [BN]
offs_2d = offs_m[:, None] * BN + offs_n[None, :]   # shape [BM, BN]

x[:, None] and x[None, :] are NumPy-style broadcasts; the compiler generates the per-thread index arithmetic that makes a 2D logical tile out of registers held by 32-wide warps.

2.3 Side-by-side: vector add

Same kernel in CUDA:

__global__ void add_kernel(const float* x, const float* y, float* out, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) out[idx] = x[idx] + y[idx];
}
// host:
add_kernel<<<(n + 255) / 256, 256>>>(x, y, out, n);

The Triton version is structurally similar but works at the vector level: one program does 1024 elements at once; the compiler decides how to assign those 1024 lanes across 32 threads (probably 4 elements per thread, with 128-bit loads). You did not have to choose. The CUDA author who wants the same vectorisation has to use float4 casts and do four-way unrolling by hand.

Exercise (warm-up). Modify add_kernel so it computes out = a*x + y (axpy). What changes? Answer: add an a scalar argument and replace the store value with a*x + y. Nothing else.
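For completeness, a sketch of the axpy version (only the scalar argument and the stored expression differ from add_kernel above; the wrapper is unchanged apart from passing a):

import triton
import triton.language as tl

@triton.jit
def axpy_kernel(a, x_ptr, y_ptr, out_ptr, n_elements,
                BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, a * x + y, mask=mask)   # the only real change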


3. Memory operations

3.1 tl.load and tl.store

Signatures (simplified, current as of 2026 Triton; verify if you target an older release):

tl.load(pointer, mask=None, other=0.0, cache=None, eviction_policy=None,
        volatile=False)
tl.store(pointer, value, mask=None, cache=None, eviction_policy=None)

Key semantics:

  • pointer is either a scalar pointer (rare) or a block tensor of pointers formed by base_ptr + offsets. The compiler infers the block shape from offsets.
  • mask, when provided, is a same-shape boolean tensor. Lanes where mask=False skip the memory access entirely (no fault, no traffic).
  • other is the value substituted for masked-off lanes on load. Default 0.0. Pick this carefully: if you mask in a tl.dot reduction, the masked lanes still participate arithmetically, so other=0.0 makes them contribute zero -- which is what you want.

3.2 Boundary masking: why it always matters

Take a vector add over n=1000 elements with BLOCK=128. You will launch ceil(1000/128) = 8 programs covering 1024 elements. The last program must mask off the trailing 24 lanes or it will read past the end of the tensor and either segfault or corrupt data:

offsets = pid * BLOCK + tl.arange(0, BLOCK)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask, other=0.0)

In CUDA the equivalent is the if (idx < n) guard around a single thread. The Triton version is more efficient because the entire warp issues one masked load instruction; in CUDA each thread independently decides to load or not.

For 2D blocks, you need a 2D mask:

offs_m = pid_m * BM + tl.arange(0, BM)             # [BM]
offs_n = pid_n * BN + tl.arange(0, BN)             # [BN]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
ptrs = X + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
x = tl.load(ptrs, mask=mask, other=0.0)

This is the canonical pattern. Get the strides right and the mask right and you have a correct boundary-aware tile load.

3.3 Broadcasting

tl.arange returns a 1D tensor; expressions broadcast like NumPy:

A = tl.zeros([BM, BN], dtype=tl.float32)           # [BM, BN] of zero
v = tl.arange(0, BN)                               # [BN]
A2 = A + v[None, :]                                # broadcast across rows

The compiler chooses the warp/lane layout so the broadcast costs nothing (it is a register-relabelling, not a memory op). This is one of the places Triton silently saves you a lot of CUDA boilerplate.

3.4 Cache and eviction hints

cache="ca" (cache all), cache="cg" (cache global, skip L1), eviction_policy="evict_last" | "evict_first" map to PTX cache-modifier suffixes. Useful for reused tiles (the K-tile of A re-read by every column block) vs. one-shot tiles. As a first pass, leave them at default.

Exercise. Why does tl.load(ptr, mask=mask, other=0.0) in a tl.dot reduction not introduce numerical error from the masked lanes? Answer: because 0.0 is the additive identity, and tl.dot is a sum-of-products, masked-off lanes contribute exactly zero to the accumulator.


4. Math operations

4.1 Elementwise

All standard arithmetic (+ - * /), comparison (< > == !=), and tl.exp, tl.log, tl.sqrt, tl.rsqrt, tl.sin, tl.cos, tl.tanh, tl.sigmoid, tl.where(cond, a, b) operate elementwise on block tensors and follow the broadcasting rules of section 3.

tl.where is your branchless conditional. There is no per-lane if/else in Triton; you use where for value selection and mask for memory access gating.
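As a tiny illustration, a leaky-ReLU kernel (the kernel name and slope parameter are ours, not from the text) that uses mask for the memory boundary and tl.where for the value choice:

import triton
import triton.language as tl

@triton.jit
def leaky_relu_kernel(x_ptr, out_ptr, n, slope,
                      BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n                        # gates the memory accesses
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.where(x > 0, x, slope * x)      # branchless per-lane selection
    tl.store(out_ptr + offs, y, mask=mask)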

4.2 tl.dot and tensor cores

acc = tl.dot(a, b, acc=acc, allow_tf32=True)
# a: [M, K], b: [K, N], acc: [M, N]; result shape [M, N]

tl.dot is the only operation in Triton that lowers to the GPU's tensor-core MMA instructions on Volta+. It requires:

  • M, N, K dimensions to be at least 16 (often 16/16/16 minimum; on Hopper larger MMAs are used). Sub-16 dims lower to FMA loops, which is fine but slow.
  • dtypes that the hardware supports: fp16, bf16, tf32 (the fp32 path goes through tf32 on A100 if allow_tf32=True, which is the default in PyTorch ML), fp8 on Hopper.
  • acc to be fp32 (recommended) for numerical safety. The accumulator stays in registers across iterations of the K loop.

Without tl.dot you cannot use tensor cores from Triton. Memorise this.

4.3 Reductions

m = tl.max(x, axis=1)        # reduce along axis 1
s = tl.sum(x, axis=0)
mn = tl.min(x, axis=0)
am = tl.argmax(x, axis=1)    # newer Triton versions

Reductions return a tensor with the reduced axis removed. They lower to warp-level shuffle reductions where possible and shared-memory reductions across warps. You do not write either by hand.

4.4 Transcendentals

tl.exp, tl.log, tl.sqrt, tl.rsqrt, tl.sin, tl.cos, tl.erf, tl.tanh, tl.sigmoid exist. Two notes:

  • On NVIDIA they typically lower to the fast PTX intrinsics (exp.approx.f32 etc.) for fp32, which are ~1 ulp accurate, fine for ML.
  • tl.exp2 is often slightly faster than tl.exp because the hardware natively computes base-2 exponential. You can rewrite softmax using exp2(x * log2_e) if you really need the cycles; usually not worth it.

Exercise. Implement a fused GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). Sketch: one elementwise kernel, mask on boundaries; everything is a tl.tanh/*/+. About 12 lines.
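One possible realisation of that sketch (a hedged example assuming a contiguous fp32 input; 0.7978845608 is sqrt(2/pi)):

import triton
import triton.language as tl

@triton.jit
def gelu_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    inner = 0.7978845608 * (x + 0.044715 * x * x * x)   # sqrt(2/pi) * (x + 0.044715 x^3)
    y = 0.5 * x * (1.0 + tl.tanh(inner))
    tl.store(out_ptr + offs, y, mask=mask)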


5. Compile-time constants: tl.constexpr

5.1 Why constexpr exists

Block sizes (BM, BN, BK, BLOCK) influence the shape of every block tensor in the kernel. Shape determines:

  • register allocation,
  • shared-memory tile size,
  • MMA instruction shape selection,
  • loop unrolling.

The compiler must know these at codegen time. So they are passed as tl.constexpr parameters:

@triton.jit
def kernel(..., BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr): ...

When you launch with kernel[grid](..., BM=128, BN=128, BK=32), Triton hashes the constexpr values and the dtypes/strides/etc. into a cache key. First call with a new key triggers compilation; subsequent calls reuse the cached PTX.

5.2 What else can be constexpr

Anything that should be specialised: a bool selecting causal vs. non-causal attention, an int choosing the reduction axis, an enum-like dtype. Use constexpr aggressively for branches that you want the compiler to delete in the specialised version.

@triton.jit
def attn(..., CAUSAL: tl.constexpr):
    if CAUSAL:
        # this branch is compiled away when CAUSAL=False
        mask = offs_m[:, None] >= offs_n[None, :]
        scores = tl.where(mask, scores, float("-inf"))

That if CAUSAL is not a runtime branch -- it is Python-level dead-code elimination at JIT time.

5.3 The cost: recompilation

Each unique (constexpr-value, dtype, contiguity) tuple is a separate kernel. Cache lives in ~/.triton/cache. If your constexpr space is unbounded (e.g. you set BLOCK=n per call) you will recompile constantly. Pick a finite menu and stick to it.


6. Autotuning

6.1 The decorator

@triton.autotune(
    configs=[
        triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
        triton.Config({'BM': 128, 'BN': 64,  'BK': 32}, num_warps=4, num_stages=4),
        triton.Config({'BM': 64,  'BN': 128, 'BK': 32}, num_warps=4, num_stages=4),
        triton.Config({'BM': 64,  'BN': 64,  'BK': 32}, num_warps=2, num_stages=5),
        triton.Config({'BM': 256, 'BN': 64,  'BK': 32}, num_warps=8, num_stages=3),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(...): ...

key= lists the runtime arguments whose values define a tuning bucket. On first call with a particular (M, N, K) triple, Triton runs every config, measures wall time, picks the winner, caches it. Subsequent calls with the same (M, N, K) skip the search.

6.2 Knobs

  • Block sizes (BM, BN, BK for matmul; BLOCK_SIZE for elementwise) control register pressure and shared-memory tile size. Bigger usually means more reuse but more spills.
  • num_warps controls how many 32-thread warps form the program instance. More warps means more parallelism per SM but each warp gets fewer registers.
  • num_stages controls software pipelining depth: how many K-iterations worth of tiles are simultaneously in flight (cp.async on Ampere+, TMA on Hopper). 2-5 is typical. Too many overflows shared memory; too few starves the MMA units.

6.3 Practical autotune

  • Limit configs to a dozen or so. The first call already pays the worst-case cost (you run every config), so 30 configs with 200ms each is six seconds of warmup.
  • key=['M', 'N', 'K'] is right for matmul; for elementwise, key=['n'] or no key (single bucket) is fine.
  • Use prune_configs_by={'early_config_prune': fn, 'perf_model': fn} if the search is huge. Usually unnecessary.
  • Set TRITON_PRINT_AUTOTUNING=1 to see the table of configs and their timings. This is invaluable for sanity-checking.

6.4 Caching

Compiled kernels are cached on disk under ~/.triton/cache/<hash>/. The hash is over: function source, argument dtypes, constexpr values, GPU arch. Move to a different GPU and you recompile. This is fine.


7. Numerical stability patterns

Three patterns dominate ML kernels: online softmax, Welford for running mean/variance, log-sum-exp.

7.1 The naive softmax problem

softmax(x)_i = exp(x_i) / sum_j exp(x_j). If any x_i > 88.7 (fp32) or >11.1 (fp16), exp(x_i) overflows. The standard fix:

softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))

This requires two passes over x: one to find max(x), one to compute the numerator/denominator. For attention, where x is the score row of length seq_len, two passes mean two reads of a tensor that does not fit in registers if seq_len is large. Bandwidth-bound.

7.2 Online softmax: the derivation

We want a one-pass algorithm. Suppose we have processed a prefix of length k and we know:

  • m_k = max(x_1..x_k),
  • l_k = sum_{j=1..k} exp(x_j - m_k).

The softmax over the prefix is exp(x_j - m_k) / l_k.

Now we see x_{k+1}. The new max is m_{k+1} = max(m_k, x_{k+1}). We need l_{k+1} = sum_{j=1..k+1} exp(x_j - m_{k+1}).

Split the sum:

l_{k+1} = sum_{j=1..k} exp(x_j - m_{k+1})  +  exp(x_{k+1} - m_{k+1})
        = sum_{j=1..k} exp(x_j - m_k) * exp(m_k - m_{k+1})  +  exp(x_{k+1} - m_{k+1})
        = l_k * exp(m_k - m_{k+1})  +  exp(x_{k+1} - m_{k+1})

That is the recurrence. The factor exp(m_k - m_{k+1}) is <= 1 and rescales the running denominator whenever the max grows. The recurrence generalises trivially to blocks of new elements rather than single elements: you compute the local max and local sum of the new block, then combine with the running (m, l) exactly as above. This is the heart of flash-attention.

In code (block version):

m_block = tl.max(x_block, axis=0)            # local max
p_block = tl.exp(x_block - m_block)
l_block = tl.sum(p_block, axis=0)            # local sum
m_new = tl.maximum(m, m_block)
alpha = tl.exp(m - m_new)
beta  = tl.exp(m_block - m_new)
l_new = alpha * l + beta * l_block
m, l = m_new, l_new
# ...and rescale any running output by alpha

7.3 Welford for variance

For RMSNorm forward you only need mean(x^2), which is a single pass. For LayerNorm forward you need both mean(x) and var(x). The naive "sum of squares minus square of sum" formula is catastrophic in fp32 for large vectors. Welford is the one-pass numerically stable form. For block combination:

n_combined = n_a + n_b
delta      = mean_b - mean_a
mean_combined = mean_a + delta * n_b / n_combined
M2_combined   = M2_a + M2_b + delta^2 * n_a * n_b / n_combined
var = M2_combined / n_combined

You will see this exact pattern in the LayerNorm tutorial kernels.
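A quick numerical check of the combine formulas (plain NumPy; the per-chunk (n, mean, M2) statistics are computed directly here for brevity, where a real Welford loop would accumulate them incrementally):

import numpy as np

def combine(n_a, mean_a, M2_a, n_b, mean_b, M2_b):
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    M2 = M2_a + M2_b + delta**2 * n_a * n_b / n
    return n, mean, M2

x = np.random.randn(10_000) + 1e6          # large offset stresses the naive formula
a, b = x[:6_000], x[6_000:]
n, mean, M2 = combine(a.size, a.mean(), ((a - a.mean())**2).sum(),
                      b.size, b.mean(), ((b - b.mean())**2).sum())
assert np.isclose(mean, x.mean()) and np.isclose(M2 / n, x.var())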

7.4 Log-sum-exp

logsumexp(x) = log(sum exp(x)) is m + log(sum exp(x - m)) for m = max(x). Same trick as softmax. Useful inside cross-entropy loss kernels where you want -x_target + logsumexp(x) directly.
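A short demonstration of why the shift matters (NumPy; the naive form overflows to inf, with a RuntimeWarning):

import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])
naive = np.log(np.sum(np.exp(x)))           # exp overflows -> inf
m = x.max()
stable = m + np.log(np.sum(np.exp(x - m)))  # ~1002.41, finite
print(naive, stable)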


8. Six annotated kernels

We progress from the trivial to flash-attention. Every kernel below is a runnable example. Imports assumed:

import torch
import triton
import triton.language as tl

8.1 Vector add

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n,
               BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

def add(x, y):
    out = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)
    add_kernel[grid](x, y, out, n, BLOCK=1024)
    return out

This kernel is bandwidth-bound; it should approach peak HBM bandwidth on any reasonable GPU. If you measure 80% of peak you are doing fine; the remaining 20% is launch overhead and tail effects.
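A sketch of how one might verify that claim, using the add wrapper above with triton.testing.do_bench (which handles warmup and synchronisation; tensor size is arbitrary):

import torch
import triton

x = torch.randn(1 << 24, device="cuda")
y = torch.randn(1 << 24, device="cuda")
ms = triton.testing.do_bench(lambda: add(x, y))
bytes_moved = 3 * x.numel() * x.element_size()   # read x, read y, write out
print(f"{bytes_moved / (ms * 1e-3) / 1e9:.0f} GB/s effective")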

8.2 Naive matmul

A first, unoptimised matmul: each program computes one [BM, BN] output tile via an inner loop over K, with no L2 swizzle and no autotuning. Useful as a teaching baseline; do not ship it.

@triton.jit
def matmul_naive(A, B, C, M, N, K,
                 sa_m, sa_k, sb_k, sb_n, sc_m, sc_n,
                 BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BM + tl.arange(0, BM)
    offs_n = pid_n * BN + tl.arange(0, BN)
    offs_k = tl.arange(0, BK)
    a_ptrs = A + offs_m[:, None] * sa_m + offs_k[None, :] * sa_k
    b_ptrs = B + offs_k[:, None] * sb_k + offs_n[None, :] * sb_n
    acc = tl.zeros([BM, BN], dtype=tl.float32)
    for k in range(0, K, BK):
        a_mask = (offs_m[:, None] < M) & ((k + offs_k)[None, :] < K)
        b_mask = ((k + offs_k)[:, None] < K) & (offs_n[None, :] < N)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BK * sa_k
        b_ptrs += BK * sb_k
    c_ptrs = C + offs_m[:, None] * sc_m + offs_n[None, :] * sc_n
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc.to(C.dtype.element_ty), mask=c_mask)

Already this is much simpler than the equivalent CUDA, which would need explicit shared-memory tiles, __syncthreads(), and per-thread MMA inputs.

8.3 Tiled matmul with autotune

The canonical Triton matmul. The two changes from 8.2 are:

  1. L2-cache-friendly program-id swizzle (group along M to reuse B tiles).
  2. @triton.autotune over a config grid.

def get_configs():
    return [
        triton.Config({'BM':128,'BN':256,'BK':32,'GROUP_M':8}, num_warps=8, num_stages=3),
        triton.Config({'BM':128,'BN':128,'BK':32,'GROUP_M':8}, num_warps=4, num_stages=4),
        triton.Config({'BM':128,'BN':64, 'BK':32,'GROUP_M':8}, num_warps=4, num_stages=4),
        triton.Config({'BM':64, 'BN':128,'BK':32,'GROUP_M':8}, num_warps=4, num_stages=4),
        triton.Config({'BM':64, 'BN':64, 'BK':32,'GROUP_M':8}, num_warps=2, num_stages=5),
        triton.Config({'BM':128,'BN':128,'BK':64,'GROUP_M':8}, num_warps=8, num_stages=3),
    ]

@triton.autotune(configs=get_configs(), key=['M', 'N', 'K'])
@triton.jit
def matmul_tiled(A, B, C, M, N, K,
                 sa_m, sa_k, sb_k, sb_n, sc_m, sc_n,
                 BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
                 GROUP_M: tl.constexpr):
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BM)
    num_pid_n = tl.cdiv(N, BN)
    # L2-friendly swizzle: walk in (GROUP_M, num_pid_n) super-blocks
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BM + tl.arange(0, BM)) % M
    offs_bn = (pid_n * BN + tl.arange(0, BN)) % N
    offs_k  = tl.arange(0, BK)
    a_ptrs = A + offs_am[:, None] * sa_m + offs_k[None, :] * sa_k
    b_ptrs = B + offs_k[:, None] * sb_k + offs_bn[None, :] * sb_n

    acc = tl.zeros([BM, BN], dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BK)):
        k_remaining = K - k * BK
        a_mask = offs_k[None, :] < k_remaining
        b_mask = offs_k[:, None] < k_remaining
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        acc = tl.dot(a, b, acc=acc)
        a_ptrs += BK * sa_k
        b_ptrs += BK * sb_k

    offs_cm = pid_m * BM + tl.arange(0, BM)
    offs_cn = pid_n * BN + tl.arange(0, BN)
    c_ptrs = C + offs_cm[:, None] * sc_m + offs_cn[None, :] * sc_n
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, acc.to(C.dtype.element_ty), mask=c_mask)

def matmul(a, b):
    assert a.shape[1] == b.shape[0]
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    grid = lambda meta: (triton.cdiv(M, meta['BM']) * triton.cdiv(N, meta['BN']),)
    matmul_tiled[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c

Why the swizzle? Without it, neighbouring program ids cover neighbouring N columns of C with the same row of A. With it, neighbours share both A and B reuse, which keeps the L2 hit rate high on big problems. This trick alone is often worth 20-40% on large square matmul.

On A100/H100 this kernel reaches roughly 80% of cuBLAS for square matmul with M,N,K multiples of 128, and it frequently beats cuBLAS on irregular shapes (e.g. K=80, common in attention with head_dim=80). Do not take exact numbers on faith; benchmark on your hardware.

8.4 Fused softmax (rowwise, single pass-per-row)

For rows that fit entirely in one block (N <= BLOCK_N):

@triton.jit
def softmax_kernel(X, Y, sx, sy, N,
                   BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < N
    x = tl.load(X + row * sx + cols, mask=mask, other=-float("inf"))
    x = x - tl.max(x, axis=0)
    num = tl.exp(x)
    den = tl.sum(num, axis=0)
    y = num / den
    tl.store(Y + row * sy + cols, y, mask=mask)

def softmax(x):
    M, N = x.shape
    BLOCK = triton.next_power_of_2(N)
    y = torch.empty_like(x)
    softmax_kernel[(M,)](x, y, x.stride(0), y.stride(0), N, BLOCK=BLOCK,
                         num_warps=4 if BLOCK <= 1024 else 8)
    return y

Notes:

  • other=-inf for masked-off lanes ensures they cannot become the max and do not contribute to the sum (exp(-inf) = 0).
  • This whole row lives in registers, so we have a true one-pass softmax; no online-softmax recurrence needed.
  • For very wide rows (N > 64K), you would use a tiled version with the online-softmax recurrence from section 7.

8.5 RMSNorm forward + backward

RMSNorm: y_i = x_i / sqrt(mean(x^2) + eps) * w_i.

Forward.

@triton.jit
def rmsnorm_fwd(X, W, Y, RSTD, sx, sy, N, eps,
                BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < N
    x = tl.load(X + row * sx + cols, mask=mask, other=0.0).to(tl.float32)
    var = tl.sum(x * x, axis=0) / N
    rstd = 1.0 / tl.sqrt(var + eps)
    tl.store(RSTD + row, rstd)             # save for backward
    w = tl.load(W + cols, mask=mask, other=0.0)
    y = x * rstd * w
    tl.store(Y + row * sy + cols, y.to(Y.dtype.element_ty), mask=mask)

Backward. Given dy, compute dx, dw. The math:

y_i = x_i * w_i * r,   where r = (mean(x^2) + eps)^{-1/2}
dr/dx_j = -(1/N) * x_j * r^3                          (since d/dx_j var = 2 x_j / N)
dy_i/dx_j = w_i * r * delta_{ij} + x_i * w_i * dr/dx_j
         = w_i * r * delta_{ij} - (1/N) * x_i * w_i * x_j * r^3

dx_j = sum_i dy_i * dy_i/dx_j
     = w_j * r * dy_j - (r^3 / N) * x_j * sum_i (x_i * w_i * dy_i)
     = r * (w_j * dy_j - x_j * (r^2 / N) * S),   where S = sum_i x_i * w_i * dy_i

So we need a single reduction S per row, then a fused elementwise pass:

@triton.jit
def rmsnorm_bwd_dx(X, W, DY, RSTD, DX, sx, sdy, sdx, N,
                   BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < N
    x  = tl.load(X  + row * sx  + cols, mask=mask, other=0.0).to(tl.float32)
    w  = tl.load(W                  + cols, mask=mask, other=0.0).to(tl.float32)
    dy = tl.load(DY + row * sdy + cols, mask=mask, other=0.0).to(tl.float32)
    r  = tl.load(RSTD + row).to(tl.float32)
    S  = tl.sum(x * w * dy, axis=0)
    dx = r * (w * dy - x * (r * r / N) * S)
    tl.store(DX + row * sdx + cols, dx, mask=mask)

dW is reduced across rows; do it in a second small kernel (or fold it into the same kernel with atomics if you have many rows; a pure reduce-then-store is usually faster).

8.6 Causal flash-attention forward (simplified)

We compute O = softmax(Q K^T / sqrt(d)) V, with a causal mask, in tiles. The trick is to never materialise the seq_len x seq_len score matrix -- we keep an online (m, l) per output row block and stream over key/value blocks. Backward is more complex; we omit it.

Shapes: Q, K, V are [B, H, S, D]; we run one program per (batch, head, M-block of Q rows). For brevity we drop the batch/head dims; assume Q, K, V are 2D [S, D] and that the launcher iterates over BH.

@triton.jit
def flash_attn_fwd(Q, K, V, O, L, M,            # M, L for backward
                   sq_s, sq_d, sk_s, sk_d, sv_s, sv_d, so_s, so_d,
                   S, D: tl.constexpr,
                   BM: tl.constexpr, BN: tl.constexpr,
                   CAUSAL: tl.constexpr):
    pid_m = tl.program_id(0)
    offs_m = pid_m * BM + tl.arange(0, BM)
    offs_d = tl.arange(0, D)
    offs_n = tl.arange(0, BN)

    # Load Q-tile once and keep it in registers
    q_ptrs = Q + offs_m[:, None] * sq_s + offs_d[None, :] * sq_d
    q_mask = offs_m[:, None] < S
    q = tl.load(q_ptrs, mask=q_mask, other=0.0)

    # Online softmax state
    m_i = tl.full([BM], -float("inf"), dtype=tl.float32)
    l_i = tl.zeros([BM], dtype=tl.float32)
    acc = tl.zeros([BM, D], dtype=tl.float32)

    scale = 1.0 / tl.sqrt(tl.full([], D, dtype=tl.float32))

    # Causal: only iterate up to the last K-block that can contribute.
    n_end = (pid_m + 1) * BM if CAUSAL else S
    for start_n in range(0, n_end, BN):
        cur_n = start_n + offs_n
        k_ptrs = K + cur_n[None, :] * sk_s + offs_d[:, None] * sk_d
        v_ptrs = V + cur_n[:, None] * sv_s + offs_d[None, :] * sv_d
        k_mask = cur_n[None, :] < S
        v_mask = cur_n[:, None] < S
        k = tl.load(k_ptrs, mask=k_mask, other=0.0)
        v = tl.load(v_ptrs, mask=v_mask, other=0.0)

        # Scores [BM, BN]
        s = tl.dot(q, k) * scale
        if CAUSAL:
            causal_mask = offs_m[:, None] >= cur_n[None, :]
            s = tl.where(causal_mask, s, float("-inf"))
        # also mask off out-of-range cur_n
        s = tl.where(cur_n[None, :] < S, s, float("-inf"))

        # Online softmax update
        m_new = tl.maximum(m_i, tl.max(s, axis=1))
        alpha = tl.exp(m_i - m_new)
        p = tl.exp(s - m_new[:, None])
        l_i = alpha * l_i + tl.sum(p, axis=1)
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
        m_i = m_new

    o = acc / l_i[:, None]
    o_ptrs = O + offs_m[:, None] * so_s + offs_d[None, :] * so_d
    tl.store(o_ptrs, o.to(O.dtype.element_ty), mask=q_mask)
    # Save log-sum-exp = m_i + log(l_i) for backward
    tl.store(L + offs_m, m_i + tl.log(l_i), mask=offs_m < S)

This is not the fastest flash-attention -- the production versions split the K loop differently, fuse the dropout, use TMA on Hopper, and have a dedicated backward kernel. But the structure -- Q-tile in registers, streaming K/V tiles, online (m, l), accumulator rescaled by alpha -- is the same and is the form you should be able to derive from scratch.

Two subtle points the code makes explicit:

  1. The accumulator update is acc = acc * alpha + dot(p, v). The alpha scaling corrects every previous K-block's contribution to the new denominator. Without it the answer is wrong.
  2. The causal mask is applied before the softmax exp, by setting out-of-bounds positions to -inf. After exp, those become 0, which is the additive identity for the running sum and for the dot(p, v) accumulation.

9. The compilation pipeline

9.1 The path

Python @triton.jit function
   |  AST capture
   v
Triton IR  (an MLIR dialect)
   |  high-level optimisation:
   |    layout selection, broadcasting elimination,
   |    reduction lowering, masked-load elision
   v
TritonGPU IR  (an MLIR dialect, HW-aware)
   |  GPU-specific:
   |    shared-memory allocation, pipelining (cp.async / TMA),
   |    swizzling, MMA selection
   v
LLVM IR
   |
   v
PTX (NVIDIA) or AMDGCN (AMD) or other backend
   |
   v
SASS (assembled by ptxas at load time)

You almost never need to look at the intermediate IRs. You frequently want to look at the PTX.

9.2 Inspecting compiled artefacts

kernel = matmul_tiled[grid](...)        # first call, compiles
# Get the compiled handle:
fn = matmul_tiled.warmup(*args, grid=grid)   # one approach
# Or, after a real call:
print(matmul_tiled.cache.values())
# For a specific compiled binary:
binary = next(iter(matmul_tiled.cache.values()))
print(binary.asm['ttir'])    # Triton IR
print(binary.asm['ttgir'])   # TritonGPU IR
print(binary.asm['llir'])    # LLVM IR
print(binary.asm['ptx'])     # PTX
print(binary.asm['cubin'])   # binary cubin (bytes)

(API has shifted across releases; the exact attribute names may differ in your version. Verify with dir(binary.asm).)

9.3 What to look for in PTX

  • ld.global.v4.f32 / ld.global.v8.f16 -- vectorised loads. If your kernel emits scalar ld.global.f32, your indexing is not coalesced.
  • mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 -- tensor-core MMA. No such instruction means you are not using tensor cores.
  • cp.async.cg.shared.global (Ampere) or cp.async.bulk.tensor (Hopper) -- asynchronous global-to-shared copies, used by software pipelining.
  • bar.sync -- barriers between pipeline stages. Too many usually means num_stages is too low and you are stalling.

9.4 Useful environment variables

  • TRITON_PRINT_AUTOTUNING=1 -- prints the timing table for autotune.
  • TRITON_CACHE_DIR=/path -- override the default cache location.
  • TRITON_DEBUG=1 -- prints extra diagnostics.
  • TRITON_INTERPRET=1 -- runs the kernel on the CPU in pure Python emulation. Excruciatingly slow but the only way to single-step a Triton kernel under pdb. Worth the cost when you have a correctness bug you cannot localise.

10. Integration with PyTorch

10.1 Calling a Triton kernel from PyTorch

The simplest path: a Python wrapper around the kernel, callable from ordinary user code. The wrappers in section 8 (add, matmul, softmax) are exactly this. PyTorch tensors are passed directly; their .data_ptr() becomes the kernel's pointer argument, dtype and strides are read from the tensor metadata.

10.2 Autograd

Subclass torch.autograd.Function:

class RMSNormFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w, eps):
        M, N = x.shape
        y = torch.empty_like(x)
        rstd = torch.empty(M, device=x.device, dtype=torch.float32)
        BLOCK = triton.next_power_of_2(N)
        rmsnorm_fwd[(M,)](x, w, y, rstd, x.stride(0), y.stride(0), N, eps,
                          BLOCK=BLOCK, num_warps=4)
        ctx.save_for_backward(x, w, rstd)
        ctx.N, ctx.BLOCK = N, BLOCK
        return y

    @staticmethod
    def backward(ctx, dy):
        x, w, rstd = ctx.saved_tensors
        dx = torch.empty_like(x)
        M = x.shape[0]
        rmsnorm_bwd_dx[(M,)](x, w, dy, rstd, dx,
                             x.stride(0), dy.stride(0), dx.stride(0),
                             ctx.N, BLOCK=ctx.BLOCK, num_warps=4)
        # dW: reduce x*rstd*dy over the row dim
        dw = (x * rstd[:, None] * dy).sum(dim=0)   # let PyTorch handle it
        return dx, dw, None

def rmsnorm(x, w, eps=1e-6):
    return RMSNormFn.apply(x, w, eps)

The dW reduction is small enough that letting PyTorch handle it is fine; fusing it into a Triton kernel is an optimisation, not a correctness step.

10.3 torch.library ops

For interaction with torch.compile, register your Triton kernel as a custom op so the compiler treats it as opaque and does not try to trace through @triton.jit:

import torch
from torch.library import custom_op, register_fake

@custom_op("mylib::rmsnorm", mutates_args=())
def rmsnorm_op(x: torch.Tensor, w: torch.Tensor, eps: float) -> torch.Tensor:
    return RMSNormFn.apply(x, w, eps)

@register_fake("mylib::rmsnorm")
def _(x, w, eps):
    return torch.empty_like(x)        # shape/stride meta, no compute

The register_fake (formerly register_meta) provides shape inference for torch.compile's symbolic shape tracing. Without it torch.compile will fail to trace through your op.

10.4 torch.compile

torch.compile already emits Triton kernels for many fusions (it has a backend called Inductor). Your handwritten Triton kernels coexist: Inductor will compile around them and treat them as black boxes (provided you registered them via torch.library). If you write a really good kernel that beats Inductor on a hot path -- common for attention variants and custom norms -- this is the way to ship it.

Two cautions:

  • torch.compile expects shape-stable callees. If your custom op recompiles on every shape it slows the graph; pre-warm with the shapes you care about.
  • mutates_args=() is critical for op safety. If you do mutate arguments (e.g. an in-place kernel) declare them or you will silently corrupt the graph cache.

11. Common pitfalls

Shape-dependent recompilation. Every unique constexpr value triggers a new compile. Pin block sizes to a finite menu. Autotune key buckets your shapes; that is fine. What is not fine is BLOCK = triton.next_power_of_2(N) when N varies continuously -- you will have a compile per N. Solution: quantise (BLOCK = max(64, triton.next_power_of_2(N)) -- ok if N has a small set of values; not ok in general).

Mask correctness. Two failure modes:

  • Forgetting to mask. Boundary out-of-bounds reads on Ampere often appear to "work" because they read zeros from a happy memory region; on Hopper with TMA they hard-fault. Always mask.
  • other= mismatch. For a softmax row, other=0.0 makes masked lanes the max if all real values are negative -- wrong. Use other=-inf for max-reductions, other=0.0 for sum-reductions and dot accumulators.

num_stages too high. Each pipeline stage holds a full [BM, BK] and [BK, BN] tile in shared memory. Running out of shared memory causes silent fallback to fewer stages or a hard compile error, depending on version. If autotune timing for a config is much worse than a config with smaller BM*BK + BK*BN, you are spilling.

Register-pressure spills. Same root cause: tiles too big for registers. Symptoms: ptxas reports spills (ptxas info: ... 2048 bytes stack frame, ... bytes spill stores), kernel runs at fraction of expected speed. Reduce BM, BN, or split the K loop.

Forgetting .to(dtype) on store. acc is fp32, your output buffer is fp16. tl.store(C_ptr, acc, ...) will implicitly cast in modern Triton but the cast may not be the rounding mode you expect (RTNE vs. RTZ). Be explicit: acc.to(C.dtype.element_ty).

Confusing tl.program_id axes with grid axes. tl.program_id(0) is the first grid dim, not the X coordinate of anything physical. With a 2D grid (grid_m, grid_n), tl.program_id(0) ranges 0..grid_m-1 and tl.program_id(1) 0..grid_n-1. Be consistent.

Autotune triggering on every call. Happens if your key= does not include all shape-dependent args (e.g. you forgot K). The cache key collides across distinct shapes and the autotuner re-runs. Always include every shape-relevant runtime arg in key=.

Stride bugs. Always pass strides explicitly; never assume contiguous. A view of a transposed tensor will have non-trivial strides; using T without checking is a frequent source of garbage outputs.

Thinking tl.dot is matrix multiplication. It is block matrix multiplication -- one tile from the operand-A block and one tile from operand-B block. The K-loop is yours.


12. Triton vs CUTLASS vs hand-rolled CUDA

Rough rules of thumb. None are universal; benchmark your case.

Use case Pick Why
Standard square matmul, fp16/bf16, big shapes cuBLAS Vendor-tuned per arch; near-peak.
GEMM with epilogue fusion (bias + GeLU + dropout) CUTLASS or Triton cuBLAS does no fusion.
Custom fused norm / softmax / attention Triton Productivity is decisive; perf is good.
Bleeding-edge attention (FA-3 style) Hand CUDA / CUTLASS TMA, warp specialisation, async fence patterns -- still ahead of Triton in 2025/2026.
Quirky shapes (small K, irregular masking) Triton Beats vendor libs frequently.
Sparse / structured-sparsity kernels CUTLASS Has sparse MMA primitives.
Research prototype, weeks-not-months Triton Always.

CUTLASS is a C++ template library from NVIDIA: more verbose than Triton, more flexible than cuBLAS, exposes the actual MMA shapes and async pipeline primitives. Hand-rolled CUDA gets you the last 5-15% on niche ops; the engineering cost is large.

A reasonable rule: write everything in Triton first, profile, replace the top-1 hot kernel with CUTLASS or hand CUDA only if numbers force you to.


13. Exercises

Six problems, ordered by difficulty. Try each before reading the sketch.

13.1 Fused bias + ReLU + scale

Write a kernel y = relu(x + b) * s where x is [M, N], b is [N] (broadcast across rows), s is a scalar. One launch, one pass.

Sketch. 1D grid of M*ceil(N/BN) programs, or 2D grid of (M, ceil(N/BN)). Load x tile [1, BN], load b [BN] once per program, broadcast add, tl.where(z > 0, z, 0.0) * s, store. ~25 lines. Time should be HBM-bandwidth-bound.
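One possible realisation, under the 2D-grid variant of that sketch (names and the BN default are ours; assumes row-major x):

import torch
import triton
import triton.language as tl

@triton.jit
def bias_relu_scale_kernel(x_ptr, b_ptr, y_ptr, s, N, sx, sy,
                           BN: tl.constexpr):
    row = tl.program_id(0)
    pid_n = tl.program_id(1)
    cols = pid_n * BN + tl.arange(0, BN)
    mask = cols < N
    x = tl.load(x_ptr + row * sx + cols, mask=mask)
    b = tl.load(b_ptr + cols, mask=mask)          # bias tile, broadcast across rows
    z = x + b
    tl.store(y_ptr + row * sy + cols, tl.where(z > 0, z, 0.0) * s, mask=mask)

def bias_relu_scale(x, b, s, BN=1024):
    M, N = x.shape
    y = torch.empty_like(x)
    bias_relu_scale_kernel[(M, triton.cdiv(N, BN))](
        x, b, y, s, N, x.stride(0), y.stride(0), BN=BN)
    return y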

13.2 Rowwise top-k mask (k=1)

Given x [M, N], output y of same shape with y[i, j] = x[i, j] if x[i, j] is the rowwise max, else 0. Single pass per row.

Sketch. Like softmax: load row into registers, m = tl.max(x, axis=0), y = tl.where(x == m, x, 0.0), store. Care with ties (you will mark all of them; usually fine). Width must fit in BLOCK.

13.3 LayerNorm forward with Welford

Write a LayerNorm forward that computes mean and variance with one pass using Welford's online algorithm. Compare numerical accuracy to the naive formula on a [1, 65536] input of magnitude 1e6.

Sketch. Tile the row into chunks of BLOCK, accumulate (n, mean, M2) across chunks with the combine formulas in section 7.3. After the loop, var = M2 / N, rstd = rsqrt(var + eps), second pass to compute and store y. Or, since the row fits in registers when N <= 64K, use the naive single-block reduction path -- in fp32 that is also stable for any realistic magnitude. Welford pays off when you must tile.

13.4 Causal-attention forward, drop the causal mask

Modify the flash-attention kernel in 8.6 so it accepts a CAUSAL: tl.constexpr parameter and, when False, iterates over the full K range. Verify: when CAUSAL=False, output equals softmax(QK^T/sqrt(d)) V to within fp16 precision.

Sketch. The kernel as written already has CAUSAL. The change is the n_end = (pid_m + 1) * BM if CAUSAL else S line and the tl.where causal mask gated on if CAUSAL:. Confirm that the if is on CAUSAL (a constexpr) so it is compiled away when False.

13.5 Matmul autotune key bug

A user complains: "every call recompiles, even at the same shape." Their decorator is @triton.autotune(configs=[...], key=['M', 'N']) for an M x K by K x N matmul. What is wrong?

Sketch. K is missing from the key. The autotuner caches only by (M, N), so (M, N, K1) and (M, N, K2) collide and the autotuner reruns benchmarks each time K changes. Add K to key.

13.6 Spotting tensor-core usage in PTX

Take the matmul kernel from 8.3, compile it for fp16 inputs, dump the PTX, and find the MMA instructions. What is the M-N-K shape of each mma.sync? What does it imply about the tile sizes Triton chose internally?

Sketch. Look for mma.sync.aligned.m16n8k16 (Ampere fp16) or wgmma.mma_async (Hopper). On Ampere fp16, the per-warp MMA shape is 16x8x16, so a [128, 128] accumulator across 4 warps is built from many of these MMAs in a tile schedule chosen by the compiler. The takeaway: the fundamental hardware tile is small; Triton's BM, BN, BK are block-level tiles assembled out of many MMAs.


14. Closing notes

You now have:

  • the conceptual model (block-level program, per-instance abstraction),
  • the API surface (tl.load/store/dot/sum/max, masks, constexpr, autotune),
  • the numerical-stability primitives (online softmax, Welford, log-sum-exp),
  • six worked kernels covering the 90% case (elementwise, matmul, softmax, norm, attention),
  • the integration story for PyTorch (autograd.Function, torch.library, torch.compile),
  • a debugging path (PTX inspection, TRITON_INTERPRET, autotune logs),
  • a comparative map against CUTLASS and hand CUDA.

Two final discipline rules. Always benchmark with triton.testing.do_bench, which does the warmup and torch.cuda.synchronize for you -- hand-timing the first call folds the autotune search and compilation into your measurement and gives you misleading numbers. Always check correctness against a PyTorch reference first, with torch.testing.assert_close(out_triton, out_torch, rtol=1e-2, atol=1e-2) for fp16, before you start chasing performance. A faster wrong kernel is not a kernel.

The single strongest piece of advice anyone has given about Triton: write the dumb version first, autotune it, look at the PTX once to make sure tensor cores fired, then move on. The framework is designed to reward that workflow. The hours you would have spent on per-thread indexing in CUDA become hours you spend on actual algorithms.

PyTorch Internals: A Deep Dive Reference

A self-contained chapter for the AI Systems curriculum, Month 3. Target reader: a backend/SRE engineer who already writes PyTorch model code, wants to understand what happens beneath tensor.add_() and model.compile(), and is willing to read C++-flavoured pseudocode. Goal: after this chapter you should be able to read the PyTorch source tree, debug a dispatch-related bug, write a custom op that survives torch.compile, and reason about performance from first principles.


0. How To Read This Chapter

PyTorch is a stack of layers. Understanding it means understanding which layer owns which decision. The chapter walks down the stack on the way in (Python -> ATen -> dispatcher -> backend), then back up (autograd, AMP, compile) because higher layers are easier to follow once you know the substrate.

For each topic you will see four passes:

  1. Intuition -- what mental model is correct.
  2. Mechanism -- the actual data structures and control flow.
  3. Minimal code -- the smallest example that exercises the mechanism.
  4. Dispatch trace -- "if I were the dispatcher, what would I do step by step." This is the most underrated reasoning tool in PyTorch -- once you can simulate the dispatcher in your head, almost every weird bug becomes obvious.

All code examples target PyTorch 2.4+ on Linux/CUDA. Where source paths are referenced, they are relative to the pytorch/pytorch repo root.


1. The Layered Architecture

1.1 Intuition

When you write c = a + b in Python you are at the top of a five-layer cake. The layers exist because each one solves a different problem:

Layer Language Job
Python frontend (torch.*) Python Ergonomics, autograd surfaces, nn.Module
Pybind shim (torch._C) C++/pybind11 Convert PyObject -> C++ args, hold the GIL boundary
ATen (at::Tensor, ops) C++ The op API. Defines what add means type-erased over backends
Dispatcher (c10) C++ Pick which kernel to run (Autograd? CUDA? Autocast?)
Backend kernels C++/CUDA/Triton/MPS/etc. Actually compute the bytes

The dispatcher is the keystone. ATen does not call kernels directly. ATen says "here is add(Tensor, Tensor) -> Tensor, dispatcher, please find the right implementation given these tensor properties." This indirection is what lets autograd, autocast, vmap, FakeTensor, meta tensors, and quantization all hook in at the same place.

1.2 Mechanism

Top-to-bottom for c = a + b where a, b are CUDA float32, requires_grad=True:

Python:           c = a + b
                    -> Tensor.__add__(self, other)
                    -> torch._C._TensorBase.add(self, other)   # via pybind
C++ (ATen):       at::add(self, other)
                    -> at::_ops::add_Tensor::call(self, other)  # codegen'd
Dispatcher:       Dispatcher::singleton().call(op_handle, stack)
                    -> picks kernel by computed DispatchKeySet
                    1. Autograd kernel (records grad_fn, then redispatches)
                    2. AMP/Autocast kernel (maybe casts, then redispatches)
                    3. CUDA kernel (the real one)
CUDA:             vectorized elementwise add launches; returns Tensor

The order of layers is not arbitrary. Autograd wraps autocast wraps the backend, because:

  • Autograd needs to see the original op so it can record the right grad_fn.
  • Autocast needs to decide casts before the backend sees the dtypes.
  • The backend just computes.

You will see the same pattern reappear: any new cross-cutting concern (functionalization, batching for vmap, fake-tensor tracing) becomes a new DispatchKey somewhere in the stack.

1.3 ASCII trace of a + b

            +---------------------------+
 Python --> | Tensor.__add__            |
            +-------------+-------------+
                          |
                          v   torch._C (pybind)
            +---------------------------+
            | at::add(Tensor, Tensor)   |   ATen
            +-------------+-------------+
                          |
                          v
            +---------------------------+
            | OperatorHandle::call      |   Dispatcher
            |  computes DispatchKeySet  |
            +-------------+-------------+
                          |
              redispatches through keys:
                          |
                 +--------+--------+
                 | Autograd kernel |  records AddBackward; redispatch w/o Autograd key
                 +--------+--------+
                          |
                 +--------+--------+
                 | Autocast kernel |  (if active) cast inputs; redispatch w/o Autocast key
                 +--------+--------+
                          |
                 +--------+--------+
                 | CUDA kernel     |  vectorized_elementwise_kernel<<<...>>>(...)
                 +-----------------+

Memorise that picture. Almost every "why is this slow / wrong / weird" question maps to a layer in it.


2. Tensor Representation

2.1 The four-level wrapping

A torch.Tensor you hold in Python is a thin handle. Underneath:

torch.Tensor                 # Python object, subclass of torch._C._TensorBase
    -> at::Tensor            # C++ value type, ~one pointer wide
        -> c10::TensorImpl   # the heap object: dtype, sizes, strides, key set
            -> c10::Storage  # owns the bytes (or shares them with views)
                -> c10::DataPtr -> raw void* + Device + Allocator

at::Tensor is essentially intrusive_ptr<TensorImpl>. Copying a tensor in C++ bumps a refcount; it does not copy bytes. That is why Tensor a = b; is cheap -- they share the same TensorImpl (and hence the same Storage).

2.2 Fields you must know

TensorImpl (in c10/core/TensorImpl.h) carries roughly:

Field Type Meaning
storage_ c10::Storage The byte buffer. Shared between views.
sizes_ SmallVector<int64_t> Shape.
strides_ SmallVector<int64_t> How many elements (not bytes) to step per dim.
storage_offset_ int64_t Where this tensor starts inside the storage, in elements.
dtype_ caffe2::TypeMeta float32, bfloat16, int64, ...
device_ c10::Device (DeviceType::CUDA, index=0), (CPU, -1), etc.
layout_ c10::Layout Strided, Sparse, SparseCsr, Mkldnn.
key_set_ DispatchKeySet Bitset of dispatch keys (Autograd, CUDA, ...).
requires_grad_ bool Lives via AutogradMeta, not directly here.
autograd_meta_ unique_ptr<AutogradMetaInterface> grad, grad_fn, version counter.

Storage (in c10/core/Storage.h) carries:

Field Type Meaning
data_ptr_ c10::DataPtr Owning pointer + device + allocator deleter.
size_bytes_ size_t Capacity in bytes.
allocator_ c10::Allocator* Where to ask for memory. On CUDA this is the caching allocator (Section 12).
resizable_ bool Can resize_ grow the buffer?

Note: Storage does not know dtype or shape. It is just bytes. Two views of the same Storage can in principle even disagree on dtype (e.g., a.view(torch.int32) reinterprets bits).

2.3 The single most important invariant

For a strided tensor:

address_of_element(i0, i1, ..., in) =
    storage.data_ptr
    + dtype_size * (storage_offset + sum_k(ik * stride_k))

That's it. Sizes, strides, storage_offset, dtype, base pointer. Five things determine where every element lives. Views are just other tensors that share storage but have different (sizes, strides, storage_offset). Contiguous is a property of the strides relative to the sizes, not of memory itself.
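You can check the invariant from Python (a quick sketch; a.reshape(-1) walks the storage in order because a itself is contiguous):

import torch

a = torch.arange(24, dtype=torch.float32).view(2, 3, 4)
t = a.transpose(0, 2)                  # a view: same storage, new strides
i, j, k = 3, 1, 0
offset = t.storage_offset() + i * t.stride(0) + j * t.stride(1) + k * t.stride(2)
assert t[i, j, k] == a.reshape(-1)[offset]   # strides locate the element in storage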

2.4 Why decouple view from storage

If shape lived inside storage you would copy bytes for transpose, narrow, unsqueeze. Decoupling lets these be O(1) metadata-only ops. The cost: kernels must respect arbitrary strides, or you must contiguous() first. PyTorch favours the first for "shape ops" and the second for "compute ops" -- compute kernels typically demand contiguous (or one of a few canonical memory formats) input.


3. Strides and Views

3.1 Stride arithmetic

For a contiguous (2, 3, 4) float32 tensor:

sizes   = [2, 3, 4]
strides = [12, 4, 1]            # elements (not bytes)
element (i, j, k) -> offset = 12*i + 4*j + 1*k

Strides for contiguous (row-major) tensors are: stride[i] = prod(sizes[i+1:]).

3.2 Three view ops with no copy

import torch
a = torch.arange(24, dtype=torch.float32).view(2, 3, 4)   # contiguous
print(a.stride())           # (12, 4, 1)

b = a.transpose(0, 2)        # swap dim 0 and dim 2
print(b.shape, b.stride())   # torch.Size([4, 3, 2]), (1, 4, 12)
print(b.is_contiguous())     # False
print(b.data_ptr() == a.data_ptr())   # True -- same storage

c = a.narrow(1, 1, 2)        # along dim 1, start=1, length=2
print(c.shape, c.stride(), c.storage_offset())
# torch.Size([2, 2, 4]), (12, 4, 1), 4

d = a.unsqueeze(0)            # add a leading dim of size 1, stride 0 (or any)
print(d.shape, d.stride())   # torch.Size([1, 2, 3, 4]), (24, 12, 4, 1) or similar

What changed in each:

Op sizes strides storage_offset storage
transpose(0,2) (4,3,2) (1,4,12) 0 shared
narrow(1,1,2) (2,2,4) (12,4,1) 1*4 = 4 shared
unsqueeze(0) (1,2,3,4) (24,12,4,1) 0 shared

The transposed tensor's storage looks the same byte-for-byte; we just re-described how to walk it. That is the entire trick.

3.3 is_contiguous and contiguous()

is_contiguous() returns true iff strides equal the canonical strides for the sizes:

strides[i] == prod(sizes[i+1:])    for all i (and ==1 for the last)

Many fused/elementwise CUDA kernels assume contiguous input so they can do unit-stride vectorised loads. If you pass them a transposed tensor they would either:

  • error out with a stride check, or
  • silently fall back to a strided kernel that is 5-10x slower.

So shape-bending operations are often followed by .contiguous() before a heavy op:

y = x.transpose(1, 2).contiguous()   # make a real copy with canonical strides
z = some_fused_kernel(y)

contiguous() allocates new storage and copies. Skipping it when needed wastes throughput; calling it when not needed wastes memory.

3.4 Memory format: channels_last

For 4D image tensors, two memory layouts are common:

  • torch.contiguous_format (NCHW): strides (C*H*W, H*W, W, 1).
  • torch.channels_last (NHWC): strides (C*H*W, 1, W*C, C).

Both have shape (N, C, H, W) -- only strides differ. CUDNN and many fused conv kernels are faster on channels_last for FP16/BF16 because it matches tensor-core memory access patterns. You opt in with:

x = x.to(memory_format=torch.channels_last)
model = model.to(memory_format=torch.channels_last)

Now x.is_contiguous() is false but x.is_contiguous(memory_format=torch.channels_last) is true. Kernels that understand the format will not insert a copy; kernels that do not will materialize a contiguous tensor on the way in.
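A quick check of those stride formulas (the shape here is arbitrary):

import torch

x = torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last)   # N,C,H,W = 2,3,4,5
print(x.stride())                     # (60, 1, 15, 3) = (C*H*W, 1, W*C, C)
assert not x.is_contiguous()
assert x.is_contiguous(memory_format=torch.channels_last)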

3.5 The dispatcher's view of strides

The dispatcher does not know or care about strides. Strides live in TensorImpl. The kernel underneath cares. This is why a stride bug is almost always a kernel bug, never a dispatcher bug.


4. The Dispatcher

4.1 Intuition

The dispatcher is a polymorphism mechanism more powerful than virtual functions. A virtual call dispatches on one type. The PyTorch dispatcher dispatches on a set of features that come from all tensor inputs combined: the union of their dispatch keys plus thread-local state (autocast active? grad enabled?).

Think of it as: every op is a function pointer table indexed by DispatchKey. Every call computes a key set, picks the highest priority key in the set, looks up the kernel, runs it. The kernel may "redispatch" -- remove its own key from the set and ask the dispatcher to do it again -- to chain effects.

4.2 The DispatchKey enum

Defined in c10/core/DispatchKey.h. The keys form a priority order. Roughly (highest first):

... functorch / vmap / FuncTorchBatched ...
PythonTLSSnapshot
Python                    # __torch_dispatch__
Functionalize             # for export / aot
... per-backend autograd ...
AutogradOther / AutogradCPU / AutogradCUDA / AutogradXPU / AutogradMeta
... AMP / autocast keys ...
AutocastCPU
AutocastCUDA
AutocastXPU
... tracing / profiling ...
... backend keys (lowest, where real work happens) ...
CPU
CUDA
MPS
XPU
Meta                      # shape-only "fake" tensors
SparseCPU / SparseCUDA
QuantizedCPU / QuantizedCUDA

Each tensor has a DispatchKeySet -- a 64-bit bitset over these keys. A typical CUDA tensor with requires_grad=True has {AutogradCUDA, CUDA}. If you enter a with torch.autocast("cuda"): block, the thread-local "included" set adds AutocastCUDA, so calls inside that block effectively dispatch on {AutogradCUDA, AutocastCUDA, CUDA}.

4.3 Computing the key for a call

Per call:

ks = empty
for each tensor argument t:
    ks |= t.key_set()
ks |= local_include_set()      # e.g. AutocastCUDA inside autocast()
ks &= ~local_exclude_set()     # e.g. inference_mode excludes Autograd
top = ks.highest_priority_key()
kernel = op_table[op_id][top]
kernel(args)

This is conceptually ~50 lines of C++ in aten/src/ATen/core/dispatch/Dispatcher.{h,cpp}. Real implementation has fast paths and caching but the model is exact.

4.4 Redispatch

A kernel can run, then ask the dispatcher to re-run the same op but with its own key removed. That is how chaining works. In pseudocode for the autograd kernel for add:

Tensor add_autograd(const Tensor& a, const Tensor& b) {
    // 1. Make output by redispatching to lower keys (skip Autograd).
    auto out = at::redispatch::add(
        c10::DispatchKeySet(c10::DispatchKey::AutogradCUDA).remove_from(ks),
        a, b);
    // 2. If grad mode is on and any input requires grad, build the graph node.
    if (compute_requires_grad(a, b)) {
        auto node = std::make_shared<AddBackward0>();
        node->set_next_edges(collect_next_edges(a, b));
        set_history(out, node);
    }
    return out;
}

So autograd doesn't do the math. It records bookkeeping, then asks the layer below to do the math.

4.5 TORCH_LIBRARY and TORCH_LIBRARY_IMPL

Two macros in C++. The first declares ops in a namespace; the second registers a backend implementation for some keys.

#include <torch/library.h>

// Declare the op. Namespace "myops". The schema is a Python-style typed signature.
TORCH_LIBRARY(myops, m) {
    m.def("triple(Tensor x) -> Tensor");
}

// Implement for CPU.
Tensor triple_cpu(const Tensor& x) {
    return x * 3;
}
TORCH_LIBRARY_IMPL(myops, CPU, m) {
    m.impl("triple", triple_cpu);
}

// Implement for CUDA.
Tensor triple_cuda(const Tensor& x) {
    return x * 3;
}
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
    m.impl("triple", triple_cuda);
}

You can also write m.impl("aten::add.Tensor", &my_add) to override a built-in op for your library / dispatch key. This is how custom backends, quantized ops, and out-of-tree devices plug in without touching ATen.

4.6 Layered keys: why autograd > autocast > backend

Imagine the opposite order. If autocast were above autograd, then when autograd records grad_fn, the recorded op would already have the cast applied -- so the backward would see the cast and might run in low precision unintentionally. Putting autograd on top means it sees the user-level op and the user-level dtypes. Conversely autocast sits above the backend so the cast happens before the kernel chooses its codepath.

The general rule: cross-cutting concerns that transform the call (cast, batch, fake-tensor-ify) sit above the backend; concerns that observe and record (autograd) sit above the transformers so their recording is faithful to user intent.
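
You can observe the consequence directly (a minimal sketch; requires a CUDA device): the matmul output under autocast is half precision, but the gradient reaching the FP32 leaf is FP32, because autograd recorded the cast and the cast's backward converts the gradient back up.

import torch
a = torch.randn(4, 4, device="cuda", requires_grad=True)   # FP32 leaf
with torch.autocast("cuda", dtype=torch.float16):
    y = a @ a
print(y.dtype)          # torch.float16: the cast happened below autograd
y.sum().backward()
print(a.grad.dtype)     # torch.float32: the recorded cast's backward restores FP32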

4.7 __torch_dispatch__ and __torch_function__

Two extension points worth knowing:

  • __torch_function__: a Python-level override. If you define a subclass of Tensor with this method, torch.add(my_tensor, ...) will route through your method before doing anything else. Used by libraries like torch.compile's subclass tracing, torch.func, and pretty-printing-only wrappers.
  • __torch_dispatch__: a post-dispatcher override. Called from inside the dispatcher at the Python key. You see the canonical op (aten.add.Tensor) with already-resolved overloads. Used for FakeTensor, FunctionalTensor, LoggingTensor, and is the right hook for "I want to intercept everything below the API level."

If you only ever debug models you may never write either, but you will see them mentioned in stack traces and dynamo logs.


5. ATen Op Registration: native_functions.yaml

ATen's op surface is defined declaratively in aten/src/ATen/native/native_functions.yaml. Each entry looks roughly like:

- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  variants: function, method
  dispatch:
    CPU: add_cpu
    CUDA: add_cuda
    SparseCPU, SparseCUDA: add_sparse
    MkldnnCPU: mkldnn_add
  autogen: add.out
  tags: pointwise, canonical

This declares:

  • Schema: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor. This is the canonical Python-typed signature. The .Tensor suffix is the overload name (vs add.Scalar).
  • Variants: generate at::add(...) (function form) and Tensor::add(...) (method form).
  • Dispatch table: which C++ function implements which key.
  • Autogen: also generate the add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) variant from the out= pattern.

A codegen tool (driven from torchgen/, output mostly under build/aten/src/ATen/) reads this YAML and emits:

  1. Op symbol headers: at::_ops::add_Tensor callables.
  2. Function variants: free functions in at:: and methods on at::Tensor.
  3. Default implementations: e.g. in-place and out= variants generated as "call at::add, then copy_ the result into the destination" helpers.
  4. Autograd derivative bindings (combined with derivatives.yaml): AddBackward0::apply etc.
  5. Python bindings: THPVariable_add etc. via tools/autograd/.

When you read PyTorch source and cannot find at::add's definition: it's generated. Look at aten/src/ATen/native/BinaryOps.cpp for the kernels and at the YAML for the contract.
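
You can call the schema directly through torch.ops and see the overload names in action (a small sketch using the add.Tensor and add.Scalar overloads discussed above):

import torch
a = torch.ones(2)
b = torch.full((2,), 3.0)
print(torch.ops.aten.add.Tensor(a, b, alpha=2))   # tensor([7., 7.]) = 1 + 2*3
print(torch.ops.aten.add.Scalar(a, 3.0))          # tensor([4., 4.]), the sibling overload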

derivatives.yaml (in tools/autograd/) is the sibling file. Each entry binds an op to its VJP:

- name: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  self: grad
  other: maybe_multiply(grad, alpha)

That tiny snippet is the autograd of add. The codegen turns it into an AddBackward0 Function class that returns (grad, alpha*grad).


6. The Autograd Engine

6.1 Intuition

Autograd in PyTorch is a dynamic tape: the graph is built on every forward pass, executed once in reverse, and discarded. Nothing is compiled or reused across iterations. This buys flexibility (control flow, dynamic shapes) at the cost of allocating one graph node per differentiable op per forward.

6.2 Mechanism

Each differentiable op produces an output Tensor whose grad_fn is a Node (subclass of torch::autograd::Node, formerly Function). The Node holds:

  • next_edges_: list of (Node*, input_nr) pairs pointing at the Nodes that produced this op's inputs.
  • saved tensors / scalars needed for the backward (e.g., for mul, both inputs).
  • apply(grads_out) -> grads_in: the VJP.

grad_fn is null on leaves; leaves with requires_grad=True instead have an AccumulateGrad Node which writes into tensor.grad.

When you call loss.backward():

1. Engine seeds the gradient for `loss` (default: 1.0 if scalar).
2. It performs a reverse topological traversal starting at loss.grad_fn.
3. For each Node in topo order, call Node.apply(grad_outputs).
   Result: grad_inputs, one per next_edge.
4. Send each grad_input to the corresponding next_edge's Node, accumulating.
5. When a Node's incoming grad count is satisfied, schedule it.

The engine is multi-threaded across devices: there is one worker per device that owns Nodes for that device (torch/csrc/autograd/engine.cpp). CPU work runs on the calling thread.

6.3 Trace of a tiny graph

import torch
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a * b
d = c + a
d.backward()
print(a.grad, b.grad)   # tensor(4.) tensor(2.)

Forward graph (as the dispatcher's autograd kernels build it):

        AccumulateGrad(a)            AccumulateGrad(b)
              ^   ^                        ^
              |   |                        |
              |   +--------+   +-----------+
              |            |   |
              |         MulBackward0      <- saves a, b
              |            |
              |            v
              |            c
              |            |
              +--------+   |
                       |   |
                  AddBackward0      <- needs alpha=1 only
                       |
                       v
                       d

Backward execution starting from d with seed 1:

AddBackward0(grad=1) -> grad_c = 1, grad_a_partial = 1
MulBackward0(grad=1)  -> grad_a_partial2 = b = 3, grad_b = a = 2
AccumulateGrad(a): a.grad = 1 + 3 = 4
AccumulateGrad(b): b.grad = 2

Note how a had two paths to it (through c and direct) and the engine summed contributions at AccumulateGrad. That summing is what next_edges plus accumulation buys you for free.
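
You can verify this wiring from Python: next_functions is the public face of next_edges_ (a quick inspection sketch):

import torch
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a * b
d = c + a
print(d.grad_fn)                  # <AddBackward0 ...>
print(d.grad_fn.next_functions)   # ((<MulBackward0 ...>, 0), (<AccumulateGrad ...>, 0))
print(c.grad_fn.next_functions)   # ((<AccumulateGrad ...>, 0), (<AccumulateGrad ...>, 0))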

6.4 VJP definition per op

For each forward op y = f(x1, ..., xn), the VJP is grads_in = J^T @ grad_y. PyTorch implements VJPs op-by-op so it never materialises a Jacobian. For mul(a, b):

y = a*b
dy/da = b
dy/db = a
VJP given grad_y:
    grad_a = grad_y * b
    grad_b = grad_y * a

These are themselves PyTorch ops that go through the dispatcher. If you call backward(create_graph=True), they are recorded in turn -- enabling double backward (computing gradients of gradients). With the default create_graph=False, the engine runs them without building a new graph.

6.5 Custom torch.autograd.Function

You write one when no built-in op covers your forward, or when you have a faster handwritten backward. The shape:

import torch

class FusedScaleClamp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, lo, hi):
        ctx.save_for_backward(x)
        ctx.scale = scale
        ctx.lo, ctx.hi = lo, hi
        y = torch.clamp(x * scale, min=lo, max=hi)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        (x,) = ctx.saved_tensors
        s, lo, hi = ctx.scale, ctx.lo, ctx.hi
        # grad flows only where the clamp didn't saturate
        scaled = x * s
        mask = (scaled > lo) & (scaled < hi)
        grad_x = grad_y * s * mask.to(grad_y.dtype)
        # No grads w.r.t. python scalars
        return grad_x, None, None, None

# Use it
x = torch.randn(8, requires_grad=True)
y = FusedScaleClamp.apply(x, 2.0, -1.0, 1.0)
y.sum().backward()
print(x.grad)

Three rules to keep in mind:

  1. The number of returned grads in backward must equal the number of inputs to forward. Use None for non-differentiable inputs (Python scalars, ints).
  2. Anything you save_for_backward must be a tensor; non-tensor context goes on ctx.<attr>.
  3. If your forward calls only differentiable PyTorch ops, you usually don't need a custom Function -- just write the function. Custom Functions are for when you sidestep autograd (e.g., calling a Triton kernel, or wanting a fused/cheaper backward).

6.6 Versioning and inplace

Each tensor carries a version counter, shared among views of the same data. Inplace ops bump it. Saved tensors record the version they were saved at. On backward, the engine checks: if the version is now higher, you mutated a tensor that was needed for grad and the engine raises RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation. This is the famous error. The fix is almost always: don't += into something whose value the backward needs.
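
A minimal repro of the error (sigmoid saves its own output for the backward, so mutating that output trips the version check):

import torch
x = torch.randn(3, requires_grad=True)
y = x.sigmoid()       # backward needs y itself, so autograd saves it
y.add_(1)             # bumps y's version counter past the saved version
y.sum().backward()    # RuntimeError: one of the variables needed for gradient
                      # computation has been modified by an inplace operation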


7. requires_grad, no_grad, inference_mode

7.1 requires_grad

A per-tensor flag (lives in AutogradMeta). When true, ops involving the tensor produce outputs whose grad_fn is set, and the tensor itself appears in the graph (via AccumulateGrad if it is a leaf).

Default: false for plain tensors, true for nn.Parameter.

7.2 no_grad

A thread-local override. Inside with torch.no_grad(): (or @torch.no_grad()), GradMode::is_enabled() returns false. The autograd kernel for each op checks this flag and, if grad mode is off, skips recording -- it just redispatches to the layer below. The op still runs through the autograd dispatch key (because the inputs still have AutogradCUDA in their key set), it just doesn't build graph nodes.

Use case: evaluation. You still get inference correctness; you save the bookkeeping cost of building backward graphs.

7.3 inference_mode

A stronger thread-local mode introduced for inference workloads. Inside a with torch.inference_mode(): block:

  1. The Autograd dispatch keys are excluded from the key set entirely. The dispatcher no longer enters the autograd kernel at all -- it goes straight to the backend kernel.
  2. Outputs are marked as inference tensors. Their version counter is disabled. They cannot later be used in autograd.

Why is it faster than no_grad?

Cost                                no_grad                  inference_mode
Enter autograd dispatch kernel      yes                      no
Allocate AutogradMeta for outputs   yes (cheap but nonzero)  no
Bump version counter on inplace     yes                      no
Kernel lookup via dispatcher        once per op              once per op

You skip a whole layer of dispatch and a pile of small allocations. Benchmarks show ~5-15% overhead reduction on small ops where dispatch cost dominates.

The price: outputs cannot be used in autograd later. If you accidentally pass an inference tensor into a training graph, you get RuntimeError: Inference tensors cannot be saved for backward.
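
A short repro (the exact message can vary by version):

import torch
with torch.inference_mode():
    t = torch.ones(3)
print(t.is_inference())    # True
w = torch.ones(3, requires_grad=True)
y = t * w                  # RuntimeError: Inference tensors cannot be saved for backward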

7.4 When to use which

  • Training loop: neither -- autograd must record the graph.
  • Eval loop you might re-enter training from: torch.no_grad().
  • Pure inference server: torch.inference_mode(). Wrap once at the top of the request handler.

8. AMP / Autocast

8.1 Intuition

Mixed precision exists because matmul on tensor cores is much faster in FP16/BF16 than FP32, but reductions (loss, softmax denominator, layer norm stats) want FP32 to avoid catastrophic cancellation. Autocast classifies ops:

  • lower-precision allowed (matmul, conv, linear): cast inputs down before running.
  • must stay in FP32 (loss functions, softmax, layer norm in some cases): leave inputs alone, or cast up.
  • promote (add of mixed dtypes): cast all inputs to the highest dtype present.

8.2 Mechanism: it's just another dispatch key

When you enter with torch.autocast("cuda", dtype=torch.float16):, the thread-local include set adds AutocastCUDA. Now any op called on a CUDA tensor inside the block has AutocastCUDA in its key set. The autocast kernel for that op runs before the backend kernel. It looks up the op's autocast policy:

// pseudocode for an autocast-lower op like matmul
Tensor matmul_autocastCUDA(const Tensor& a, const Tensor& b) {
    auto target = at::autocast::current_dtype(c10::DeviceType::CUDA);  // e.g. half
    auto a_cast = cached_cast(target, a);
    auto b_cast = cached_cast(target, b);
    return at::redispatch::matmul(/*remove AutocastCUDA*/, a_cast, b_cast);
}

For a "must stay FP32" op, the autocast kernel casts up instead. cached_cast keeps a small thread-local cache so if you matmul the same weight twice in a forward you don't re-cast it.

The lists live in aten/src/ATen/autocast_mode.cpp (and FP16/BF16 specific lists). They are explicit: you can read which ops are "lower", "fp32", "promote".
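
A quick way to watch the policy in action (a minimal sketch; requires a CUDA device): matmul is on the lower-precision list, softmax on the FP32 list.

import torch
a = torch.randn(8, 8, device="cuda")
b = torch.randn(8, 8, device="cuda")
with torch.autocast("cuda", dtype=torch.bfloat16):
    mm = a @ b
    sm = torch.softmax(a, dim=-1)
print(mm.dtype)   # torch.bfloat16
print(sm.dtype)   # torch.float32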

8.3 GradScaler vs autocast-only

The danger of FP16 is gradient underflow: tiny gradients become 0. The fix is loss scaling: multiply the loss by a big number S before backward, divide grads by S before the optimizer step. If any grad becomes Inf/NaN, skip the step and shrink S; otherwise gradually grow S.

scaler = torch.amp.GradScaler("cuda")   # spelled torch.cuda.amp.GradScaler() on older releases
for x, y in loader:
    opt.zero_grad(set_to_none=True)
    with torch.autocast("cuda", dtype=torch.float16):
        out = model(x)
        loss = loss_fn(out, y)
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()

BF16 has the same exponent range as FP32, so underflow is not an issue and you do not need GradScaler:

for x, y in loader:
    opt.zero_grad(set_to_none=True)
    with torch.autocast("cuda", dtype=torch.bfloat16):
        out = model(x)
        loss = loss_fn(out, y)
    loss.backward()
    opt.step()

Rule of thumb on Ampere or later: BF16 unless you have a specific reason. Older GPUs (Volta, Turing) lack good BF16, so FP16 + GradScaler.


9. torch.compile Pipeline

torch.compile(model) returns a wrapped callable. Under the hood it composes three pieces: TorchDynamo (capture), AOTAutograd (joint forward+backward capture), Inductor (codegen).

+------------------+      +----------------+      +--------------+      +------------+
|  Python module   | ---> |  TorchDynamo   | ---> | AOTAutograd  | ---> |  Inductor  |
+------------------+      +----------------+      +--------------+      +------------+
                           captures FX graph      joint fwd/bwd         Triton/C++
                           + guards               traced into           kernels
                                                  core ATen             + scheduling

9.1 TorchDynamo

Source: torch/_dynamo/.

Dynamo is a Python-level tracer that hooks into CPython's frame evaluation API (PEP 523). It registers an alternative frame evaluator. When Python is about to execute a function decorated by torch.compile, Dynamo intercepts the bytecode, symbolically executes it, and produces:

  1. An FX graph of tensor ops (torch.fx.GraphModule).
  2. A set of guards -- runtime predicates over inputs that, if true, mean it's safe to reuse this graph (e.g., "input 0 is torch.float32", "input 0's shape is (B, 1024) for some int B", "this Python int equals 7").
  3. A residual bytecode that calls the compiled graph and does anything Dynamo couldn't handle.

The key idea: instead of tracing a Tensor program, Dynamo traces Python bytecode, with FakeTensors standing in for real tensors. Every PyTorch op invocation gets recorded into the FX graph. Every Python-level operation that isn't a tensor op (a list append, an if x.shape[0] > 16:) is either:

  • specialised into a guard ("we recorded this branch when the shape was 32; if at runtime it's not, recompile"), or
  • becomes a graph break.

Graph breaks

A graph break happens when Dynamo encounters something it cannot symbolically execute. Examples:

  • print(x) -- has a side effect Dynamo doesn't model.
  • Calling into a third-party C extension that isn't a torch op.
  • A try/except whose handling Dynamo isn't sure how to capture.
  • Data-dependent control flow on a tensor without using torch.cond.
  • Mutating a global Python object.

When a break happens, Dynamo:

  1. Compiles the graph it has so far.
  2. Falls back to the Python interpreter for the offending statement.
  3. Resumes tracing from the next statement -- producing a second graph after the break.

Each graph compiles separately and you pay the call/launch overhead between them. One graph break = lost optimization opportunity. Many = torch.compile may even be slower than eager.

Inspecting
import torch

@torch.compile
def f(x):
    y = x.sin()
    print("hi")          # graph break
    return y.cos()

print(torch._dynamo.explain(f)(torch.randn(4)))
# prints: number of graphs, number of breaks, reasons, locations

torch._dynamo.explain is your first debugging tool. If it says "1 graph, 0 breaks", you're golden. If it says "5 graphs, 4 breaks", read each reason and fix.

Other useful env knobs:

TORCH_LOGS="dynamo"             # what dynamo is tracing
TORCH_LOGS="graph_breaks"       # only the breaks
TORCH_LOGS="recompiles"         # why a graph re-compiled at runtime
TORCH_LOGS="output_code"        # the generated kernels (see Inductor)

9.2 AOTAutograd

Source: torch/_functorch/aot_autograd.py and friends.

Once Dynamo hands an FX graph to the compiler backend, AOTAutograd does two things:

  1. Joint trace of forward + backward. It runs the forward FX graph through make_fx with grad enabled, then calls .backward() to also trace the backward. The output is one FX graph containing both, plus a partition that assigns nodes to "forward" or "backward" subgraphs (so they can run separately at runtime).
  2. Decomposition to core ATen. Higher-level ops (e.g., torch.nn.functional.layer_norm) are decomposed into their constituent core ops (mean, var, mul, add, ...). This shrinks the op surface Inductor must handle from thousands to ~250 core ops.

The decomposition table lives in torch/_decomp/. You can see what an op decomposes to:

from torch._decomp import core_aten_decompositions
table = core_aten_decompositions()
for k, v in list(table.items())[:5]:
    print(k)

After AOTAutograd you have two FX graphs in core ATen: forward_graph and backward_graph. They are pure functions of inputs (and saved-for-backward tensors). Now Inductor compiles each.

9.3 Inductor

Source: torch/_inductor/.

Inductor is a lowering compiler. It takes an FX graph in core ATen, builds an internal IR (Inductor IR) that represents loops over tensors, fuses adjacent pointwise ops into bigger loops, schedules reductions, and emits target code.

Two backends:

  • CUDA / ROCm: emits Triton kernels. Triton handles the GPU-specific tiling and memory hierarchy; Inductor decides what to fuse and what shapes to specialise on.
  • CPU: emits C++ with OpenMP pragmas, optionally with vector ISA intrinsics (AVX2/AVX-512). Compiles via the system compiler at runtime.

Pipeline inside Inductor:

core ATen FX graph
    -> lowering        (each op -> Inductor IR ops; eg. add -> Pointwise(...))
    -> scheduler       (group nodes that can fuse; assign to kernels)
    -> codegen         (emit Triton or C++)
    -> compile         (call Triton's autotuner / call cc -O3)
    -> wrapper         (Python wrapper that calls each kernel in order)

Fusion is the big win. A handwritten layer norm in eager is 5+ kernels (mean, var, sub, mul, add). Inductor often emits one fused Triton kernel. Same for activation+linear-bias, embedding+layernorm, etc.

Inspecting generated code
TORCH_LOGS="output_code" python my_script.py

You get the Triton (or C++) source dumped. It's worth reading at least once -- you'll see something like:

@triton.jit
def triton_poi_fused_add_mul_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    a = tl.load(in_ptr0 + xindex, xmask)
    b = tl.load(in_ptr1 + xindex, xmask)
    c = a * 2.0 + b
    tl.store(out_ptr0 + xindex, c, xmask)

That's a fused a*2 + b. Two separate eager ops collapsed into one launch.

You can also dump the FX graph after AOTAutograd:

TORCH_LOGS="aot_graphs"

9.4 Guards and recompilation

A compiled artifact is keyed by (op graph, guards). At each call, Dynamo evaluates the guards over the actual inputs. If all hold, run the cached compiled artifact. If any fail, recompile.

Common guards:

  • Type guards: input is a Tensor, dtype is float32, device is cuda:0.
  • Shape guards: rank is 3, sizes are (2, ?, 1024). The ? may be symbolic (a free variable) or static (a specific int) depending on the dynamic-shape mode.
  • requires_grad guards: input had requires_grad=True. (Recompile if you switch eval/train without telling it.)
  • Python guards: a constant Python int equals N, a list has length M, a particular nn.Module instance is the same object.

Dynamic vs static shapes:

  • torch.compile(model) (default in 2.4+): tries dynamic shapes when it sees the same shape vary; specialises when it doesn't.
  • torch.compile(model, dynamic=False): always specialise on shapes. Faster code, more recompiles.
  • torch.compile(model, dynamic=True): assume dynamic from the start. Fewer recompiles, sometimes slower per-iter.

If you see frequent recompiles (TORCH_LOGS="recompiles"), the usual culprits are:

  1. Variable batch size with dynamic=False.
  2. Variable sequence length without marking it dynamic (see the sketch after this list).
  3. Calling with different requires_grad settings (eval vs train without .eval()/.train()).
  4. Using Python lists/tuples whose length varies.
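
For culprit 2, you can declare a dimension dynamic up front with torch._dynamo.mark_dynamic (a sketch; exact recompile behaviour may vary across versions):

import torch

@torch.compile
def f(x):
    return x.sin()

x = torch.randn(8, 128)
torch._dynamo.mark_dynamic(x, 0)    # dim 0 (batch) will vary across calls
f(x)                                # compiles once with a symbolic batch dim
f(torch.randn(16, 128))             # should reuse the artifact, not recompile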

9.5 Modes

torch.compile(model, mode="default")
torch.compile(model, mode="reduce-overhead")
torch.compile(model, mode="max-autotune")
  • default: balanced. Compile time low-ish, runtime good.
  • reduce-overhead: enables CUDA graphs around the compiled region. CUDA graphs eliminate per-op CUDA launch overhead by recording a sequence and replaying it as one submission. Big win for small batches and lots of small ops. Constraints: shapes must be static across replays, and tensors must live at the same addresses (Inductor handles this with persistent input buffers; you may need to .clone() inputs or warm up).
  • max-autotune: Triton autotunes block sizes per kernel, multiple template variants for matmul, longest compile time, often best runtime.

9.6 End-to-end example

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1024, 4096)
        self.l2 = nn.Linear(4096, 1024)
    def forward(self, x):
        return self.l2(torch.relu(self.l1(x)))

model = Net().cuda().to(torch.bfloat16)
model = torch.compile(model, mode="reduce-overhead")

x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)
# Warm-up: compiles + records cuda graph
for _ in range(3):
    y = model(x)

# Steady-state: every call is a single CUDA graph replay
for _ in range(1000):
    y = model(x)

What happened on first call:

  1. Dynamo hooks forward, traces it into FX (3 nodes: linear, relu, linear).
  2. AOTAutograd skips backward (no grad needed), decomposes linear -> matmul + add.
  3. Inductor lowers, fuses relu with the second matmul's bias-add prologue if possible, generates two Triton matmul kernels and a fused activation/bias kernel.
  4. CUDA graph captures the launch sequence on call 2.
  5. Calls 3+ replay the graph.

10. Custom Op Registration (Modern Path)

You want to register a new op so that:

  • It has an autograd rule.
  • It survives torch.compile (Dynamo and Inductor know what to do).
  • It works under FakeTensor / meta tracing.

The modern API is torch.library, available in 2.4+. Avoid the older torch.autograd.Function-only path when going through torch.compile; Dynamo will graph-break on it.

10.1 Skeleton

import torch

# 1. Declare and implement the op for real backends
@torch.library.custom_op("mylib::myadd", mutates_args=())
def myadd(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y       # or call into a Triton kernel here

# 2. Tell the compiler/dynamo how shapes propagate (the "fake" / meta impl)
@myadd.register_fake
def _(x, y):
    # Must match real op's output shape/dtype/device, with no real compute
    return torch.empty_like(x)

# 3. Register an autograd rule
def myadd_setup_context(ctx, inputs, output):
    # Save what backward needs
    pass  # nothing for plain add

def myadd_backward(ctx, grad):
    return grad, grad   # dx, dy

myadd.register_autograd(myadd_backward, setup_context=myadd_setup_context)

Now mylib::myadd is a first-class op. You can call torch.ops.mylib.myadd(x, y) and it goes through the dispatcher like any built-in.

10.2 What each piece is for

  • custom_op: the user-facing implementation. Runs in eager mode.
  • register_fake: Dynamo / FakeTensor / torch.export use this to symbolically execute your op. It must allocate output tensors with correct shape/dtype/device but no real values. Without this, Dynamo will graph-break at your op.
  • register_autograd: the VJP. Mirrors torch.autograd.Function.backward. Setup context can save tensors via ctx.save_for_backward(...).
  • mutates_args: tuple of arg names that are mutated in place. The compiler needs to know this for correctness when reordering / re-using buffers.
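
PyTorch ships a test helper that exercises all of these registrations together: torch.library.opcheck. A sketch, assuming the myadd op from 10.1 is registered:

import torch
x, y = torch.randn(3), torch.randn(3)
torch.library.opcheck(torch.ops.mylib.myadd.default, (x, y))   # checks fake, autograd, schema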

10.3 A Triton kernel as a custom op (worked example)

import torch
import triton
import triton.language as tl

@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    a = tl.load(x_ptr + offs, mask=mask)
    b = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, a + b, mask=mask)

def _add_launch(x, y):
    out = torch.empty_like(x)
    N = x.numel()
    BLOCK = 1024
    grid = ((N + BLOCK - 1) // BLOCK,)
    _add_kernel[grid](x, y, out, N, BLOCK=BLOCK)
    return out

@torch.library.custom_op("mylib::triton_add", mutates_args=())
def triton_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    assert x.is_cuda and y.is_cuda and x.shape == y.shape
    return _add_launch(x.contiguous(), y.contiguous())

@triton_add.register_fake
def _(x, y):
    return torch.empty_like(x)

def _bwd(ctx, g):
    return g, g

triton_add.register_autograd(_bwd)

# Use it
x = torch.randn(4096, device="cuda", requires_grad=True)
y = torch.randn(4096, device="cuda", requires_grad=True)
z = torch.ops.mylib.triton_add(x, y)
z.sum().backward()
print(x.grad.shape, y.grad.shape)

# It also works under torch.compile thanks to register_fake
def f(a, b):
    return torch.ops.mylib.triton_add(a, b).relu()
g = torch.compile(f)
g(x, y)

The key win: Dynamo treats triton_add as an opaque op (it doesn't try to look inside the Triton kernel). It uses register_fake to know the shape, and register_autograd to know how grads flow. Inductor will not fuse with surrounding ops -- but it also won't graph-break.

If you do want Inductor to fuse with surrounding ops, write the kernel in core ATen (let Inductor codegen its own Triton). Custom Triton ops are for cases where you have a hand-tuned kernel that beats codegen.


11. C++ Extension Path

When you need raw C++/CUDA, the supported flow is torch.utils.cpp_extension. There are two flavours:

  • JIT (load, load_inline): compiles on first import. Great for development.
  • AOT (setup.py with CUDAExtension / CppExtension): produces a wheel.

11.1 Minimal setup.py skeleton

my_ext/
  setup.py
  src/
    binding.cpp
    kernel.cu

src/kernel.cu:

#include <torch/extension.h>

__global__ void scale_kernel(const float* in, float* out, float s, int N) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < N) out[i] = in[i] * s;
}

torch::Tensor scale_cuda(torch::Tensor x, double s) {
    TORCH_CHECK(x.is_cuda(), "x must be cuda");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32, "x must be float32");
    auto y = torch::empty_like(x);
    int N = x.numel();
    int block = 256;
    int grid = (N + block - 1) / block;
    scale_kernel<<<grid, block>>>(x.data_ptr<float>(), y.data_ptr<float>(),
                                  static_cast<float>(s), N);
    return y;
}

src/binding.cpp:

#include <torch/extension.h>

torch::Tensor scale_cuda(torch::Tensor x, double s);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("scale", &scale_cuda, "scale(x, s) = x * s on CUDA");
}

setup.py:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="my_ext",
    ext_modules=[
        CUDAExtension(
            name="my_ext",
            sources=["src/binding.cpp", "src/kernel.cu"],
            extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3"]},
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)

Build and use:

pip install -e .
import torch, my_ext
x = torch.randn(1024, device="cuda")
y = my_ext.scale(x, 2.5)

11.2 ABI considerations

PyTorch is built with a specific C++ ABI (the C++11 GCC ABI on Linux). Your extension must be compiled against the same PyTorch headers and the same compiler ABI flag. Practical rules:

  • Always build the extension on the machine where it will run, against the installed PyTorch wheel, or publish per-PyTorch-version wheels.
  • Match the CUDA toolkit major version to PyTorch's CUDA major (e.g. CUDA 12.1 PyTorch -> CUDA 12.x toolkit).
  • Avoid passing C++ exceptions across the pybind boundary. Use TORCH_CHECK for user errors -- it raises Python RuntimeError.
  • Don't statically link C++ standard library; let the system one be used.
  • For wheels, use manylinux2014 or newer base images; build separate wheels per (PyTorch version, CUDA version, Python version) tuple.

11.3 Registering the C++ op into the dispatcher

If you want it to be a real dispatcher op (so autograd, autocast, etc. integrate), use TORCH_LIBRARY (Section 4.5) instead of (or in addition to) the pybind binding. That gives you torch.ops.myext.scale(x, s) and full dispatch behavior.


12. The CUDA Caching Allocator

12.1 Why caching

cudaMalloc and cudaFree are slow (often 100us+ each) and synchronous on the default stream. A naive implementation would call them per tensor. PyTorch instead routes all CUDA tensor allocations through a caching allocator (c10/cuda/CUDACachingAllocator.cpp).

12.2 Mechanism

Conceptually:

CachingAllocator state per device:
  large_blocks:   sorted-by-size list of free blocks >= 1MB
  small_blocks:   sorted-by-size list of free blocks <  1MB
  active_blocks:  {ptr -> Block}     # currently held by a Tensor

allocate(size, stream):
    round size up (small to nearest 512B; large to nearest 2MB)
    search the appropriate free list for a block of >= size on a compatible stream
    if found: split off the suffix as a free block, return prefix
    else:
        cudaMalloc a fresh segment (geometric growth: 2MB, 20MB, 200MB, ...)
        carve a block out of it, return it

free(ptr):
    mark Block free
    record the stream we last used it on; only reusable on that stream
        unless cuda events confirm cross-stream safety
    coalesce with neighbors in the same segment if both free

Two keys to internalize:

  1. free() does not call cudaFree. It returns the block to the pool. From the driver's perspective the memory is still allocated (see the sketch after this list).
  2. Stream-aware reuse. Memory freed on stream A cannot be reused on stream B until events confirm the prior work has finished. This is why multi-stream code can OOM where single-stream code wouldn't: the allocator is conservatively keeping blocks pinned to the original stream.
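
You can watch point 1 from Python (a minimal sketch; requires a CUDA device): freeing a tensor drops allocated but not reserved.

import torch
x = torch.empty(1024, 1024, device="cuda")    # ~4 MiB of FP32
print(torch.cuda.memory_allocated())          # includes the 4 MiB block
del x                                         # block returns to the pool
print(torch.cuda.memory_allocated())          # drops back down
print(torch.cuda.memory_reserved())           # unchanged: the segment stays cached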

12.3 empty_cache

torch.cuda.empty_cache()

Walks the free lists and actually cudaFrees segments that contain only free blocks. Returns memory to the driver -- visible to other processes (e.g., another container sharing the GPU). Does not shrink anything currently in use, and does not improve performance for your own process (the caching allocator was already going to reuse those blocks). Use it when you need to release memory across processes; do not sprinkle it through your training loop.

12.4 Reading memory_summary

print(torch.cuda.memory_summary(device=0, abbreviated=False))

You get a table like:

|---------------------|------------|------------|------------|------------|
|                     | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------|------------|------------|------------|------------|
| Allocated memory    |    ...     |    ...     |    ...     |    ...     |
| Active memory       |    ...     |    ...     |    ...     |    ...     |
| GPU reserved memory |    ...     |    ...     |    ...     |    ...     |
| Non-releasable mem  |    ...     |    ...     |    ...     |    ...     |

Definitions:

  • Allocated: bytes currently held by Tensors.
  • Active: allocated plus blocks whose free is still pending on a stream, so they cannot yet be reused or coalesced.
  • Reserved: total cudaMalloc'd. Reserved - Allocated = sitting in the cache.
  • Non-releasable: free blocks that cannot be returned to the driver because their segment still has at least one in-use block.

12.5 Fragmentation

The classic failure mode: you have 3GB free in the cache, but it's split into 300 blocks of ~10MB each, none big enough to satisfy a 100MB request. The allocator does coalesce neighbors, but only within the same segment. Mitigations:

  1. PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True (PyTorch 2.0+). The allocator uses CUDA virtual memory APIs (cuMemMap/cuMemAddressReserve) to grow a single backing segment instead of allocating many. Drastically reduces fragmentation for variable-shape workloads.
  2. PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:N: don't split blocks larger than N MB, reducing tiny suffix blocks scattered around.
  3. Avoid pathological patterns: alternating very large and very small allocations on the same stream.

If you OOM at "tried to allocate 1GiB but only 500MiB free although 4GiB reserved", that is fragmentation. Check memory_summary; consider expandable segments.


13. Profiling Internals

torch.profiler.profile(...) (in torch/profiler/) records per-op entry/exit by registering callbacks at the dispatcher. Each time the dispatcher enters or exits an op, it calls every registered observer. The profiler is one such observer; so are autograd hooks and record_function.

13.1 Minimal use

import torch
from torch.profiler import profile, record_function, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    with_stack=False,
) as prof:
    with record_function("forward"):
        y = model(x)
    with record_function("loss"):
        loss = (y - target).pow(2).mean()
    with record_function("backward"):
        loss.backward()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
prof.export_chrome_trace("trace.json")    # open with chrome://tracing or perfetto

Open the trace and you get a flame-graph-like view: CPU dispatcher events on top, CUDA kernel events on the bottom with launch arrows. The gaps between kernels are launch overhead.

13.2 What each column means

  • Self CPU: time the op itself spent in CPU (dispatcher + Python -> C++ + kernel launch).
  • CPU total: includes children. A linear op's total includes its matmul and add children.
  • Self CUDA / CUDA total: same on GPU, measured via CUDA events.
  • # of Calls: how many times this op key (with these shapes) was hit.

13.3 Reading patterns

  • "Self CPU >> Self CUDA, kernels short": you are launch-overhead bound. Try torch.compile(mode="reduce-overhead") for CUDA graphs, or batch up small ops.
  • "Self CUDA dominates, one kernel is 80% of it": profile that kernel; consider a different algorithm (FlashAttention, a fused MoE), or see if Inductor will generate a better one with max-autotune.
  • "CUDA gaps with no work": host is too slow producing input. DataLoader is the usual suspect; bump num_workers, prefetch, pin memory.
  • "memcpy dominating": you're moving data CPU<->GPU per step. Pin host memory, pre-load to GPU, or use non_blocking=True with pinned source.

13.4 Autograd hooks for finer questions

hooks = []
for name, p in model.named_parameters():
    h = p.register_hook(lambda g, n=name: print(n, g.norm()))
    hooks.append(h)
loss.backward()
for h in hooks: h.remove()

Hook fires from the engine's worker thread when the grad is computed. Useful for "which parameter's grad is NaN".


14. Practical Exercises (with answers)

Exercise 1: stride sleuthing

You have:

import torch
a = torch.arange(60).reshape(3, 4, 5)
b = a.permute(2, 0, 1)

Without running this, what are b.shape, b.stride(), b.storage_offset()? Is b.is_contiguous()? Why?

Answer. a.shape=(3,4,5), a.stride()=(20,5,1). Permute remaps dims by index: new dim 0 = old dim 2, new dim 1 = old dim 0, new dim 2 = old dim 1. So b.shape=(5,3,4), b.stride()=(1,20,5), b.storage_offset()=0. Not contiguous because canonical strides for (5,3,4) would be (12,4,1) and ours are (1,20,5).

Exercise 2: dispatch trace

You write:

with torch.autocast("cuda", dtype=torch.bfloat16):
    with torch.no_grad():
        y = a @ b

where a, b are CUDA fp32 tensors with requires_grad=True. List the keys in the dispatch set for @, the order of kernels invoked, and the dtype of y.

Answer. - Per-tensor key set: {AutogradCUDA, CUDA}. - Local include from autocast: adds AutocastCUDA. - Local exclude from no_grad: does not exclude Autograd keys (that's inference_mode's job). However, the autograd kernel checks GradMode::is_enabled() and, finding it false, just redispatches without recording.

So the order is: AutogradCUDA kernel (skips recording, redispatches) -> AutocastCUDA kernel (casts a,b to bf16, redispatches) -> CUDA kernel (runs bf16 matmul). y.dtype is bfloat16. y.requires_grad is False.

Exercise 3: detect a graph break

@torch.compile
def f(x, n):
    if n.item() > 0:
        return x.sin()
    else:
        return x.cos()

Why is this slow / breaky, and how do you fix?

Answer. n.item() materialises a tensor value to a Python int. Dynamo cannot symbolically execute that branch -- it triggers a graph break (or specialisation / recompile per value of n). The fix: use torch.where (data-dependent on tensor) or torch.cond (if you really need control flow):

@torch.compile
def f(x, n):
    return torch.where(n > 0, x.sin(), x.cos())

Exercise 4: a custom op that survives compile

Write a custom op clip_norm(x, max_norm) that scales x so its L2 norm is at most max_norm. Make sure it works under torch.compile.

Answer.

import torch

@torch.library.custom_op("mylib::clip_norm", mutates_args=())
def clip_norm(x: torch.Tensor, max_norm: float) -> torch.Tensor:
    n = x.norm()
    scale = (max_norm / (n + 1e-12)).clamp(max=1.0)
    return x * scale

@clip_norm.register_fake
def _(x, max_norm):
    return torch.empty_like(x)

# autograd: easiest would be to leave it to the implementation
# since it uses only autograd-aware ops; but custom_op disables
# autograd-through-the-implementation by default. So:
def _setup(ctx, inputs, output):
    x, max_norm = inputs
    n = x.norm()
    scale = (max_norm / (n + 1e-12)).clamp(max=1.0)
    ctx.save_for_backward(scale)      # tensors go through save_for_backward

def _bwd(ctx, g):
    (scale,) = ctx.saved_tensors
    # d(x*scale)/dx = scale (treating scale as constant for simplicity)
    return g * scale, None

clip_norm.register_autograd(_bwd, setup_context=_setup)

Note we approximate the gradient by treating scale as a constant -- that's the standard / desired behaviour for gradient clipping.

Exercise 5: why does this OOM?

Training fine yesterday, today OOMs at the same batch size. Memory summary shows Reserved=22GiB, Allocated=8GiB. Tried empty_cache, no change. What are two most likely causes and one mitigation?

Answer. Causes: 1. Fragmentation: the 14 GiB cache is split into too many small free blocks for a big request to find a contiguous free run. Mitigation: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True (PyTorch 2.0+). 2. Stream-pinned blocks: a multi-stream codepath freed memory on stream A; the allocator can't yet hand it to stream B. Mitigation: synchronize, or unify streams.

empty_cache does nothing here because all 14GiB of free cache is non-releasable (segments still have allocated blocks).

Exercise 6: inference_mode vs no_grad latency

You have a hot inference path that runs ~50 small ops per request (lots of LayerNorm, GELU, small matmul). You measure 8% latency drop switching no_grad -> inference_mode. Why? Where would the gain be much smaller?

Answer. inference_mode excludes Autograd dispatch keys entirely, so each op skips the autograd kernel layer (a function call, a check, and an output AutogradMeta allocation). Saving ~hundreds of nanoseconds per op times ~50 ops times batch is a measurable percentage when individual ops are short.

The gain shrinks toward zero as ops get bigger: a single 4096x4096 fp16 matmul takes milliseconds, dwarfing the dispatch cost. The win is in launch-overhead-bound regimes. For one big op per request, prefer profiling to see if it's worth bothering.


15. A Coherent Mental Model To Keep

If you remember nothing else, remember these seven sentences:

  1. A Tensor is a (storage, sizes, strides, storage_offset, dtype, device) tuple.
  2. Views share storage; contiguous() materialises.
  3. The dispatcher picks a kernel from a key set assembled from inputs and thread-local mode; layers redispatch by removing their own key.
  4. Autograd is a layer in the dispatcher that, when grad mode is on, builds a tape; loss.backward() traverses it.
  5. inference_mode is faster than no_grad because it removes the Autograd layer entirely.
  6. Autocast is just a dispatcher layer that casts inputs before the backend runs.
  7. torch.compile is Dynamo (capture Python -> FX) + AOTAutograd (joint forward/backward in core ATen) + Inductor (Triton/C++ codegen with fusion); guards govern when the compiled artifact is reused.

Everything else -- caching allocator, profiler, custom ops -- hangs off these. Once you can simulate the dispatcher in your head and reason about the compile pipeline, PyTorch internals stop feeling like a foreign country and become a place you live.


Appendix A: Source Tree Map

A high-confidence cheat sheet for navigating the repo:

Path Contents
c10/core/ TensorImpl, Storage, DispatchKey, Device, Layout. The smallest, most stable foundation.
c10/cuda/ CUDACachingAllocator, CUDA stream/event wrappers.
aten/src/ATen/core/ The dispatcher (Dispatcher.cpp), op registration (library.cpp).
aten/src/ATen/native/ Op kernels (CPU / generic). BinaryOps.cpp, ReduceOps.cpp, etc.
aten/src/ATen/native/cuda/ CUDA kernels.
aten/src/ATen/native/native_functions.yaml The op-schema source of truth.
aten/src/ATen/autocast_mode.cpp Autocast policies (which ops are FP16/BF16/FP32/promote).
tools/autograd/derivatives.yaml Op-by-op VJPs for codegen.
torch/csrc/autograd/ Engine, Function (Node), saved variables.
torch/_dynamo/ TorchDynamo (frame eval, symbolic execution, guards).
torch/_functorch/aot_autograd.py AOTAutograd.
torch/_decomp/ Decompositions to core ATen.
torch/_inductor/ Inductor lowering, scheduler, codegen (codegen/triton.py, codegen/cpp.py).
torch/library.py Modern custom-op API (custom_op, register_fake, register_autograd).
torch/utils/cpp_extension.py JIT and AOT C++/CUDA extension build helpers.
torch/profiler/ Profiler frontend; backend in torch/csrc/profiler/.

When in doubt, git grep for the op name in aten/src/ATen/native/ -- the kernel is almost always there.

Appendix B: Useful Environment Variables

TORCH_LOGS=dynamo,graph_breaks,recompiles,aot_graphs,output_code
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb:128
TORCH_SHOW_DISPATCH_TRACE=1            # prints kernel choice per op (verbose)
TORCH_USE_CUDA_DSA=1                   # device-side assertions for shape/index errors
CUDA_LAUNCH_BLOCKING=1                  # serialises kernel launches; better stack traces
TORCHINDUCTOR_CACHE_DIR=/tmp/inductor   # control compile cache location
TORCHINDUCTOR_MAX_AUTOTUNE=1            # equivalent to mode="max-autotune"

End of chapter.

Deep Dive 05-JAX and XLA

Reading contract. This chapter is a self-contained reference. After reading and working the exercises you should be able to (a) read and write idiomatic JAX, (b) reason about what happens when @jax.jit is applied to a function, (c) inspect jaxprs and HLO, (d) shard a computation across a multi-host TPU/GPU cluster using Mesh + PartitionSpec, and (e) pick between jit with sharding, shard_map, and (legacy) pmap for a given workload. We do not punt to the JAX docs.


Table of contents

  1. Why JAX exists
  2. Functional purity: the unit of compilation
  3. PyTrees and jax.tree_util
  4. Stateless PRNGs (PRNGKey)
  5. Tracing and jaxprs
  6. jax.jit: caching, recompilation, static args
  7. jax.grad, value_and_grad, jvp, vjp
  8. jax.vmap and per-example gradients
  9. Device parallelism: pmap (legacy) vs jit + sharding (modern)
  10. jax.shard_map: when you want manual control
  11. Structured loops: lax.scan, lax.fori_loop, lax.while_loop
  12. XLA: HLO IR, compilation pipeline, fusion, layout, GSPMD
  13. TPU vs GPU under XLA
  14. Module systems on top: Equinox and Flax
  15. jax.experimental.pallas (Triton-like kernel DSL)
  16. Practical exercises (with worked answers)
  17. Cheat-sheet appendix

1. Why JAX exists

JAX is a library for numerical computing whose central thesis is:

Numerical programs are functions of arrays. If you keep them pure, you can compose program transformations on them-autodiff, vectorization, parallelization, just-in-time compilation-and you can lower the result through a single optimizing array compiler (XLA) to CPU, GPU, or TPU.

That sentence has every load-bearing word in JAX's design. Let us unpack it against the contrast that most readers carry: PyTorch.

1.1 PyTorch's design (so we have a foil)

PyTorch is eager and imperative: a tensor operation runs immediately on the host's accelerator queue. The autograd graph is built dynamically as a side effect of forward computation-every tensor with requires_grad=True allocates a node in a graph stored on the tensor itself. Modules are stateful objects (nn.Module) that own their parameters, buffers, and (transitively) the optimizer state. To go fast you typically (a) call into eager kernels, (b) use torch.compile (TorchDynamo + Inductor) which traces Python bytecode and lowers to fused kernels, or (c) drop into custom CUDA / Triton.

PyTorch optimizes for debuggability and programmer ergonomics in idiomatic Python. You can print a tensor, set a Python breakpoint inside a forward pass, mutate a list of layers conditionally, and it all works.

1.2 JAX's design choices

JAX takes an almost orthogonal set of choices:

  1. Pure functions are the unit. A JAX-compilable function takes arrays in, returns arrays out, and has no side effects: no mutation of outer Python state, no in-place tensor edits, no random state hidden in a global. Everything that would be state is threaded explicitly through arguments and return values.

  2. Composable transformations. Once a function is pure, JAX can give you several function-to-function transformations:

  • jax.jit -- trace and compile via XLA.
  • jax.grad -- return a function computing the gradient.
  • jax.vmap -- return a function that runs the original over a new batch axis.
  • jax.pmap / jit with sharding -- return a function that runs across devices.

  These compose: jit(vmap(grad(f))) is meaningful and well-defined. This composability is the soul of JAX.

  3. XLA as the default backend. Where PyTorch's "real" backend is a constellation of cuDNN/cuBLAS calls in eager mode and Inductor-generated Triton at compile time, JAX always lowers to HLO (XLA's IR) and lets the XLA compiler emit device code. This means TPU and GPU share most of the toolchain. (XLA was originally a TPU compiler at Google, and that lineage shows.)

  4. TPU first-class. Unlike PyTorch where TPU support is delivered through torch_xla as an extra layer, JAX speaks XLA natively. A JAX program written for one TPU pod core scales to thousands of cores essentially by adding sharding annotations.

1.3 Trade-offs

The JAX bargain:

  • You give up in-place mutation, easy printing inside compiled code, dynamic shapes that change every call, and Python-level control flow that depends on traced values.
  • You get a uniform compilation pipeline, world-class autodiff that composes with everything, free vectorization (vmap), free distribution (jit + sharding), and predictable performance because compilation is explicit.

For research workloads with stable shapes and heavy linear algebra (transformers, diffusion, scientific computing), this trade is excellent. For workloads with ragged shapes, dynamic graphs of varying topology, or heavy host-side branching (some RL systems, classical NLP pipelines), it can be painful.


2. Functional purity: the unit of compilation

A pure function in JAX's sense:

  • Outputs depend only on inputs. No global reads (other than constants closed over at trace time).
  • Has no observable side effects. No global writes, no I/O, no mutation of arguments.
  • Same input → same output, every call.
import jax
import jax.numpy as jnp

# Pure
def loss(params, x, y):
    pred = x @ params["W"] + params["b"]
    return jnp.mean((pred - y) ** 2)

# Impure-mutates a global counter
counter = 0
def bad(x):
    global counter
    counter += 1
    return x * 2

jax.jit(bad) will appear to work but counter will be incremented exactly once per trace, not once per call. The bug surfaces silently. This is a recurring pattern: JAX does not police purity at runtime; it assumes it. If you violate it, you get correct-looking output and incorrect semantics.
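
You can watch the bug happen (a minimal sketch; the second call hits the compiled cache, so the Python body never re-executes):

import jax
import jax.numpy as jnp

counter = 0
def bad(x):
    global counter
    counter += 1
    return x * 2

f = jax.jit(bad)
f(jnp.ones(3))
f(jnp.ones(3))
print(counter)   # 1, not 2: the side effect ran only during tracing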

The discipline imposed by purity buys two enormous things:

  1. Trivial reverse-mode AD. With no side effects, the chain rule is just structural induction over the jaxpr. There is nothing to "undo."
  2. Trivial parallelization. Pure functions are referentially transparent; you can run them on any device, in any order, multiple times, without changing program meaning.

2.1 Where state lives

If the model has weights, optimizer moments, RNG state, batch-norm running stats, those are values that the user threads through arguments:

def train_step(params, opt_state, rng, batch):
    rng, sub = jax.random.split(rng)
    grads = jax.grad(loss)(params, batch, sub)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, rng

Compare with PyTorch where optimizer.step() mutates param.data and param.grad in place. JAX surfaces the mutation as new return values. The train_step itself remains pure.


3. PyTrees and jax.tree_util

A neural network has hundreds or thousands of parameters. Threading them as positional arguments is unworkable. JAX solves this by making arbitrary nested Python containers first-class.

A PyTree is, recursively, either:

  • a leaf (an array, a scalar, anything not a registered container), or
  • a container node (by default tuple, list, dict, None, namedtuple) with PyTree children.

Custom dataclasses can be registered with jax.tree_util.register_pytree_node (or, for dataclasses, @jax.tree_util.register_dataclass / Equinox's automatic registration).

3.1 Why this matters

Every JAX transformation that takes a function f(x) -> y and produces f'(x) -> y' operates over PyTrees: x and y may be arbitrarily nested. The transformations preserve structure.

Example: jax.grad applied to a function whose first argument is params = {"layer1": {"W": ..., "b": ...}, "layer2": {...}} returns a gradient PyTree with the same shape as params. You never write flatten_params glue code by hand.

3.2 The core API

from jax import tree_util as tu

leaves, treedef = tu.tree_flatten(params)   # list of arrays + structure spec
params2 = tu.tree_unflatten(treedef, leaves) # rebuild

# Map a function over every leaf:
doubled = jax.tree.map(lambda x: 2 * x, params)

# Map across two PyTrees that share structure:
sum_tree = jax.tree.map(lambda a, b: a + b, params, grads)

In recent JAX, the public surface is jax.tree.map, jax.tree.leaves, jax.tree.structure, etc. The underlying jax.tree_util module remains.

3.3 PyTree as the universal interface

Internally, every JAX transformation: 1. Flattens its inputs to a flat list of leaves + a tree-structure description. 2. Operates on the flat list (where everything is just an array). 3. Reconstructs the output structure on the way out.

This is why you can pass a dict of arrays to jit, grad, vmap, pmap and they all "do the right thing"-they each call tree_flatten and treat leaves uniformly.

3.4 Custom PyTree

import dataclasses
@jax.tree_util.register_dataclass
@dataclasses.dataclass
class GRUCell:
    W: jax.Array
    U: jax.Array
    b: jax.Array

Now GRUCell instances are valid PyTrees. jax.grad will return a GRUCell of gradients.

Contrast with PyTorch. PyTorch's nn.Module is a class with a state_dict() method. JAX's analog is "any PyTree." The "module" abstraction is built on top (Flax, Equinox), not into the core.

3.5 Exercise

Given params = {"a": jnp.zeros((3,)), "b": [jnp.ones((2,2)), jnp.ones((2,))]}, write the call that returns {"a": shape (3,), "b": [shape (2,2), shape (2,)]}.

jax.tree.map(lambda x: x.shape, params)

4. Stateless PRNGs (PRNGKey)

Random numbers are state. State is impure. JAX therefore cannot have a global RNG (well, it could, but it would break composition with jit/vmap/pmap).

Solution: an explicit, immutable key, threaded by the user.

key = jax.random.PRNGKey(42)         # an array of shape (2,) uint32 (historically)
key, subkey = jax.random.split(key)  # 2 new keys; old key conceptually consumed
x = jax.random.normal(subkey, (1024,))

Three rules:

  1. Never reuse a key. random.normal(key, ...) is a pure function of key; passing the same key gives identical samples.
  2. Always split before consuming. split(key, n) returns n fresh keys.
  3. Threading is the user's job. Every function that consumes randomness takes a key argument.

4.1 Why this design

  • Reproducibility. Two runs with the same starting key produce bit-identical results, even across vmap/pmap/sharding.
  • Composability. vmap(f) over a batched-keys argument gives per-example randomness with no surprise. With a global RNG, vmap could not say what the per-example samples should be.
  • Determinism under parallelism. Each device gets its own key derived deterministically from the master key. No race conditions, no per-device RNG state to seed.

4.2 Idiom: per-step splitting

def train_step(params, rng, batch):
    rng, dropout_key = jax.random.split(rng)
    logits = model(params, batch, dropout_key)
    ...
    return params, rng  # return the *new* rng for the next step

You can split into many keys at once: keys = jax.random.split(rng, num=8) and vmap over them.

4.3 Key types

Modern JAX has typed keys and selectable PRNG implementations (e.g., the threefry and rbg algorithms). For day-to-day work PRNGKey(seed) is fine; for typed keys and per-implementation selection see the jax.random.key API. The conceptual model-explicit, splittable, stateless-is unchanged.
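A minimal sketch of the newer typed-key API (names per recent JAX releases; treat exact details as version-dependent):

key = jax.random.key(0)                  # typed key array (new style)
k1, k2 = jax.random.split(key)           # splitting works exactly as before
x = jax.random.normal(k1, (4,))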


5. Tracing and jaxprs

This is the conceptual heart of JAX. Internalize it and the rest follows.

5.1 What @jax.jit actually does

When you call a jit-decorated function for the first time with concrete arguments:

  1. JAX inspects each argument's shape and dtype (and static-argnum python values).
  2. It calls your Python function with abstract Tracer objects in place of those arguments-objects that record every operation performed on them but do not compute values.
  3. The resulting trace is a jaxpr (JAX expression): a small typed IR of primitive operations.
  4. The jaxpr is lowered to HLO and compiled by XLA for the target device.
  5. The compiled executable is cached, keyed by (function identity, abstract input signature, static-arg values).
  6. The actual concrete arguments are run through the executable.

Subsequent calls with the same abstract signature skip steps 2–5 and just run the cached executable.
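You can observe trace-then-cache directly: a Python side effect in the function body fires at trace time only, never on cached calls. A minimal demonstration:

import jax
import jax.numpy as jnp

@jax.jit
def g(x):
    print("tracing")         # Python side effect: runs during tracing only
    return x * 2

g(jnp.ones((3,)))            # prints "tracing", compiles, runs
g(jnp.ones((3,)))            # cache hit: no print, no compile
g(jnp.ones((4,)))            # new shape -> new trace: prints again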

5.2 A worked jaxpr

import jax
import jax.numpy as jnp

def f(x, y):
    a = x * y
    b = jnp.sin(a)
    return jnp.sum(b)

print(jax.make_jaxpr(f)(jnp.ones((3,)), jnp.ones((3,))))

You will see something like:

{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = mul a b
    d:f32[3] = sin c
    e:f32[] = reduce_sum[axes=(0,)] d
  in (e,) }

Reading it:

  • a:f32[3] b:f32[3] - two float32 inputs of shape (3,).
  • The let block names intermediate values.
  • mul, sin, reduce_sum are JAX primitives-the leaves of the jaxpr.
  • in (e,) - the output tuple.

A jaxpr is pure, typed, and closed: every variable is bound, every operation is a primitive, every shape and dtype is known. This is the input XLA receives.

5.3 Concrete vs Abstract vs Traced

Three kinds of values flow through JAX code:

  • Concrete arrays (jax.Array): real data on a device. Default outside jit.
  • Abstract arrays (ShapedArray, ConcreteArray): metadata only-shape, dtype, optional weak type. Used for tracing.
  • Tracers (Tracer): wrappers presented to the user's Python function during tracing. They look like arrays (they have .shape, .dtype, support +, *, jnp.sin, etc.) but every operation on them appends a node to the jaxpr.

A common pitfall: writing Python control flow on a tracer.

@jax.jit
def f(x):
    if x > 0:           # ConcretizationTypeError: a tracer cannot be branched on
        return x
    else:
        return -x

The condition x > 0 is a tracer because x is. Python's if requires a concrete bool. You must use jax.lax.cond(x > 0, ..., ...) (or jnp.where for elementwise selection), which becomes part of the jaxpr.
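Both fixes, sketched for a scalar and an elementwise case respectively:

@jax.jit
def f_scalar(x):             # x is a scalar
    return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)

@jax.jit
def f_elementwise(x):        # x is an array
    return jnp.where(x > 0, x, -x)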

5.4 Static arguments

If x is a Python scalar that determines the shape of arrays-e.g. number of layers-make it static:

from functools import partial
@partial(jax.jit, static_argnums=(1,))
def make_noise(rng, n):     # n is static
    return jax.random.normal(rng, (n,))

n is now treated as part of the cache key, not as a tracer. Different n values trigger different compilations.

5.5 Print-debugging

print(x) inside a jitted function prints the tracer (useful for inspection), not the value. For per-call debug prints use jax.debug.print("x = {}", x) which compiles into a host callback.


6. jax.jit: cache, recompilation, costs

6.1 The cache key

The compilation cache is keyed (essentially) by:

Component What changes the key
Function identity The Python function object
Abstract input signature shape and dtype of every leaf in the input PyTree
Static argument values python == equality of static_argnums / static_argnames values
PyTree structure the treedef of inputs
Device / sharding context target backend & sharding spec

So:

  • Calling f(x_f32_3x4) then f(x_f32_3x4) → 1 compile, 2 calls.
  • Calling f(x_f32_3x4) then f(x_f32_3x5) → 2 compiles (different shape).
  • Calling f(x_f32_3x4) then f(x_f64_3x4) → 2 compiles (different dtype).
  • Calling f({'a': x, 'b': y}) then f({'b': y, 'a': x}) → 1 compile (dicts have stable PyTree order).
  • Changing the value of a non-static argument → 0 compiles.
  • Changing the value of a static argument → 1 compile per distinct value.

6.2 Recompilation costs

Compilation is not free: HLO optimization plus device codegen can cost hundreds of milliseconds to tens of seconds for transformer-scale models. If your training loop accidentally retraces every step, the program runs but at compile-bound throughput. Symptoms:

  • First step: 5 s. Second step: 5 s. Third step: 5 s. (Should be: 5 s, 50 ms, 50 ms.)
  • Memory growth in the HLO module cache.

Detection:

jax.config.update("jax_log_compiles", True)  # logs every compile

or set JAX_LOG_COMPILES=1. You should see one line per training-loop function, not one per step.

Common causes of accidental retraces:

  1. Padding-by-actual-length: each batch has a slightly different sequence length, so shapes vary. Fix: pad to a small set of bucket lengths (see the sketch after this list).
  2. Passing Python ints that the function uses to construct shapes-make them static_argnums.
  3. Passing different PyTree structures (e.g., a dict with an optional key).
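A minimal bucketing sketch (the bucket sizes and the (B, L) layout are illustrative assumptions):

def pad_to_bucket(ids, buckets=(128, 256, 512, 1024)):
    # Pad a (B, L) batch up to the smallest bucket >= L, so jit sees at most
    # len(buckets) distinct shapes instead of one shape per batch.
    L = ids.shape[1]
    target = next(b for b in buckets if b >= L)   # assumes L <= max(buckets)
    return jnp.pad(ids, ((0, 0), (0, target - L)))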

6.3 jit is lazy, dispatch is async

jit-compiled calls return immediately with futures (a jax.Array backed by a pending computation). Use jax.block_until_ready(x) or x.block_until_ready() when timing.
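Correct timing therefore looks like this (a sketch; the matmul is arbitrary):

import time

f = jax.jit(lambda a: a @ a)
x = jnp.ones((4096, 4096))
f(x).block_until_ready()             # warm-up: pay the compile cost once

t0 = time.perf_counter()
y = f(x)                             # returns a future almost immediately
y.block_until_ready()                # wait for the device to finish
print(f"step: {time.perf_counter() - t0:.4f} s")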

6.4 Ahead-of-time lowering

lowered = jax.jit(f).lower(jnp.ones((4,)), jnp.ones((4,)))
print(lowered.as_text())            # StableHLO MLIR text
print(lowered.compiler_ir(dialect="hlo"))  # HLO module
compiled = lowered.compile()
print(compiled.cost_analysis())     # FLOPs, bytes, etc., for benchmarking

This is the inspection toolchain you will use repeatedly.


7. jax.grad and friends

7.1 Reverse-mode AD on a jaxpr

jax.grad(f) returns a function g such that g(x) equals df/dx evaluated at x. Mechanically:

  1. Trace f to a jaxpr (the primal jaxpr).
  2. Walk the jaxpr forward, recording residuals where needed.
  3. Construct a transposed / reverse jaxpr that computes the cotangent: for each primitive, JAX has a registered differentiation rule (a linearization plus a transpose rule, which together yield the VJP).
  4. Return that as a callable jaxpr (typically jitted).

Functional purity is what makes step 3 trivial: each primitive has a local linearization, and the chain rule is just composition because there are no hidden state edges to break it.

By convention, grad(f) differentiates with respect to the first argument and expects a scalar output.

def loss(params, x, y):
    return jnp.mean((x @ params["W"] - y) ** 2)

grad_fn = jax.grad(loss)
g = grad_fn(params, x, y)   # g has the same PyTree shape as params

For multiple argnums: jax.grad(loss, argnums=(0, 1)) returns a tuple of grads.

7.2 value_and_grad

You usually want both the loss value and the gradient. jax.value_and_grad(loss) returns (loss, grads) from a single trace-no double work.

(loss_val, grads) = jax.value_and_grad(loss)(params, x, y)
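Putting 7.1-7.2 together, a minimal jitted training step might look like this (a sketch reusing the loss above; plain SGD stands in for a real optimizer):

@jax.jit
def train_step(params, x, y, lr=1e-2):
    loss_val, grads = jax.value_and_grad(loss)(params, x, y)
    new_params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss_val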

7.3 Higher-order

jax.grad(jax.grad(f)) is well-defined (it traces grad(f) and differentiates that trace). For Hessians, use jax.hessian(f) (which is jacfwd(jacrev(f)) under the hood for a typical scalar function).

7.4 Forward and reverse mode primitives

Two lower-level operators:

  • jax.jvp(f, primals, tangents) - forward-mode: computes f(primals) and the directional derivative J · tangents in one pass. Good when input dim ≪ output dim.
  • jax.vjp(f, *primals) - reverse-mode: returns (f(primals), vjp_fn) where vjp_fn(cotangent) computes cotangent · J. This is what grad is built on.
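A minimal sketch of both operators (shapes arbitrary):

f = lambda x: jnp.sin(x) * x

# Forward mode: primal output plus the directional derivative J @ v.
primal_out, tangent_out = jax.jvp(f, (jnp.ones((3,)),), (jnp.ones((3,)),))

# Reverse mode: primal output plus a function mapping a cotangent to u @ J.
out, vjp_fn = jax.vjp(f, jnp.ones((3,)))
(x_cotangent,) = vjp_fn(jnp.ones((3,)))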

Rule of thumb:

  • Few outputs, many inputs (training loss → loss is scalar): reverse mode (grad/vjp).
  • Few inputs, many outputs (sensitivity of a vector-valued function to a small parameter): forward mode (jvp).
  • For Jacobians: jax.jacrev (reverse) for tall Jacobians, jax.jacfwd (forward) for wide Jacobians.

7.5 Custom derivatives

@jax.custom_vjp
def stable_softmax(x):
    z = x - jnp.max(x)
    return jnp.exp(z) / jnp.sum(jnp.exp(z))

def fwd(x):
    y = stable_softmax(x)
    return y, y                          # primal output + residuals
def bwd(y, g):
    return (y * (g - jnp.sum(g * y)),)   # softmax VJP for 1-D x: y * (g - <g, y>)
stable_softmax.defvjp(fwd, bwd)

Use this when (a) you have a hand-derived gradient that is more numerically stable than the autodiff one, or (b) you want to break a gradient (stop_gradient is a single-line alternative for that).

7.6 Contrast with PyTorch

In PyTorch, loss.backward() walks the dynamically-built graph attached to the tensor, populates .grad fields by mutation, and frees the graph. In JAX, grad(loss) builds a new function that returns a new PyTree of gradients. There is no .grad attribute, no graph to free, no optimizer.zero_grad() needed-purity makes "zero out before backward" unnecessary because nothing is mutated.


8. jax.vmap: vectorization is a transformation

vmap(f) returns a function that runs f over an extra leading axis without you writing the batched code. It is not a Python for loop. It rewrites the jaxpr to push the batch dimension through every primitive.

8.1 The basic interface

def dot(a, b):                     # a:(d,) b:(d,) -> ()
    return jnp.sum(a * b)

batched_dot = jax.vmap(dot)        # (B,d), (B,d) -> (B,)
batched_dot(jnp.ones((32, 5)), jnp.ones((32, 5)))

8.2 in_axes / out_axes

Specify which axis of each argument is the batched axis:

# Batch over a's first axis, broadcast b:
jax.vmap(dot, in_axes=(0, None))   # (B,d), (d,) -> (B,)

# Batch over a's last axis, b's first axis:
jax.vmap(dot, in_axes=(-1, 0))

None means "this argument is not batched-broadcast it." out_axes controls where the batch axis appears in the output (default 0).

For PyTree arguments, in_axes is itself a PyTree (or a single int, applied uniformly).

8.3 How vmap works under the hood

For each primitive p, JAX has registered a batching rule: given how each input is batched, what is the batched output and along which axis? vmap walks the jaxpr applying these rules. Most primitives have rules that turn into a single fatter primitive call-e.g., vmap(dot) becomes a matmul, not a Python loop. That is why vmap is fast: it produces good HLO.

8.4 Composes with grad

The canonical example: per-example gradients.

In standard training, grad(loss) gives the gradient of the mean loss-a single PyTree summed across the batch. Sometimes you want one gradient per example (for influence functions, differential privacy, gradient clipping per example, etc.).

def per_example_loss(params, x, y):     # x:(d,), y:()-single example
    pred = x @ params["W"] + params["b"]
    return (pred - y) ** 2

per_example_grad = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))
# per_example_grad(params, X, Y) returns grads with leading axis B

Read it carefully:

  • grad(per_example_loss) is a function that, given a single example, returns a single gradient PyTree.
  • vmap(...) lifts that over a batch axis on (x, y) while sharing params.
  • The result has the same PyTree structure as params but with an added leading batch dimension on every leaf.

This is two lines of JAX. The PyTorch equivalent is torch.func.vmap(torch.func.grad(...)) (the functorch API, since folded into torch.func).

8.5 vmap for the inference case

A model that runs on a single example can be batched by writing the model for one example and vmap-ing it. This is sometimes cleaner than handling broadcasting in the model definition. In practice, most JAX models are written batched (because matmul already handles a batch dimension for free), but vmap is invaluable for non-trivial axes (e.g., MoE expert routing, beam search, per-head attention).


9. Device parallelism: pmap (legacy) vs jit + sharding (modern)

JAX has been through a small evolution here. Understand both because real codebases mix them.

9.1 pmap (the original)

pmap is "parallel map": it `jit - compiles a function and runs it on multiple devices, with one shard of the leading axis per device.

@jax.pmap                          # 8 devices: leading axis must be 8
def step(params, batch):
    ...
    return new_params

# params must be replicated across devices:
params = jax.tree.map(lambda x: jnp.broadcast_to(x, (8,) + x.shape), params)
batch = ...                        # shape (8, B // 8, ...): leading axis = device count
new_params = step(params, batch)

Cross-device communication is via collectives inside the function: jax.lax.psum(x, axis_name="i"), pmean, all_gather, etc., where axis_name is set by pmap(..., axis_name="i").

pmap is single-program multiple-data with explicit sharding by the user, restricted to one batch dimension and one mesh axis. It composes-pmap(vmap(grad(f))) works-but it is awkward when you need 2D meshes, host coordination across many TPU hosts, or per-tensor sharding choices.

9.2 pjit (intermediate) and the unified jit (modern)

Around 2022 JAX introduced pjit: a jit that took a Mesh and PartitionSpecs and let XLA's GSPMD partitioner shard arbitrary tensors across an arbitrary mesh. This was the right abstraction. By 2024, pjit and jit were unified: today jax.jit natively understands sharding, and pjit is an alias.

The modern path:

import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

devices = jax.devices()                     # e.g., 8 GPUs or a 2x4 TPU slice
mesh = Mesh(np.array(devices).reshape(2, 4), axis_names=("data", "model"))

# Place an array sharded:
def shard(x, spec):
    return jax.device_put(x, NamedSharding(mesh, spec))

x   = shard(x_host,    P("data", None))     # batch sharded over 2 devices
W   = shard(W_host,    P(None, "model"))    # output dim sharded over 4 devices
b   = shard(b_host,    P("model"))

@jax.jit
def forward(x, W, b):
    y = x @ W + b
    # Optionally constrain intermediate shardings:
    y = jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P("data", "model")))
    return y

What happens under the hood:

  1. jit traces forward with abstract sharded inputs.
  2. The jaxpr lowers to HLO with sharding annotations on the inputs, outputs, and any with_sharding_constraint points.
  3. GSPMD (the partitioner inside XLA) propagates sharding through the whole HLO module, deciding per-op how each tensor is laid out.
  4. GSPMD inserts collectives (all-reduce, all-gather, reduce-scatter, all-to-all) where needed.
  5. XLA emits one program per device; each device runs only its slice.

This is GSPMD: General SPMD partitioner. The user writes a single-device-shaped program with annotations on a few key tensors; the compiler figures out the rest.

9.3 Mesh, PartitionSpec, NamedSharding

  • Mesh: a logical N-dimensional grid of devices with named axes. Common patterns:
  • 1D: ("data",) - pure data parallelism.
  • 2D: ("data", "model") - DP × tensor parallelism (Megatron-style).
  • 3D: ("data", "fsdp", "model") or ("pp", "data", "model") - pipeline + DP + TP.
  • PartitionSpec (alias P): for an array with shape (d0, d1, …, dn), a P(spec0, spec1, …) says how each axis is partitioned over the mesh. Each spec_i is either:
  • None: replicated along this array axis.
  • "name": sharded along the mesh axis "name".
  • A tuple ("a", "b"): sharded along the product of mesh axes a and b.
  • NamedSharding(mesh, P(...)): binds a PartitionSpec to a concrete mesh.

Examples: - P("data", None) on (B, D): sharded batch, replicated features (FSDP-like for activations). - P(None, "model") on (D_in, D_out): replicated D_in, sharded D_out (tensor-parallel weight). - P(("data", "fsdp")) on (D,) with mesh ("data", "fsdp"): sharded over both axes simultaneously.

9.4 with_sharding_constraint

Inside a jit-ed function, you can pin an intermediate sharding:

y = jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P("data", "model")))

This is a hint to GSPMD: "after this point, y must be sharded this way." Use it to:

  • Break ambiguity when GSPMD picks a sharding you don't want.
  • Force a re-shard at a known boundary (e.g., between tensor-parallel attention and tensor-parallel MLP).

9.5 Output sharding

jit with sharding can take in_shardings and out_shardings arguments to specify how inputs and outputs are sharded. Default: inferred from the actual input shardings and propagated by GSPMD.

forward_p = jax.jit(
    forward,
    in_shardings=(NamedSharding(mesh, P("data", None)),
                  NamedSharding(mesh, P(None, "model")),
                  NamedSharding(mesh, P("model"))),
    out_shardings=NamedSharding(mesh, P("data", "model")),
)

9.6 When pmap is still useful

pmap survives for:

  • Quick single-axis SPMD where setting up a Mesh feels heavy.
  • Legacy code.
  • A few research patterns that depend on pmap's tight coupling between the Python-side leading axis and the device axis.

For new code at scale, prefer jit + sharding.


10. jax.shard_map: when you want manual control

jit + sharding is implicit: you annotate, GSPMD figures out collectives. shard_map is explicit: you write what each device sees, and you call collectives by hand.

from functools import partial
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

@partial(shard_map, mesh=mesh, in_specs=(P("data", None), P(None, "model")),
         out_specs=P("data", "model"))
def matmul(x_local, W_local):
    # x_local and W_local have the *local* per-device shapes.
    # The contraction dim is replicated here, so no collective is needed;
    # contracting over a sharded dim would require an explicit jax.lax.psum.
    y_partial = x_local @ W_local
    return y_partial   # already matches out_specs

Inside shard_map you operate on the local shard. Collectives (jax.lax.psum, jax.lax.all_gather, jax.lax.all_to_all, jax.lax.ppermute) reference the mesh axis names and run across the corresponding device subset.

When to use shard_map over jit+sharding:

  • You need an algorithm that GSPMD does not synthesize well (custom ring all-reduce, expert-parallel routing, sequence parallelism with overlap).
  • You want predictable collective placement for performance debugging.
  • You are implementing a low-level kernel (e.g., a custom attention with sequence sharding and explicit all_to_all).

When to stick with jit+sharding:

  • Standard transformer training. GSPMD does an excellent job.
  • You want one piece of code that retargets between mesh shapes without rewriting.

The mental model: jit+sharding is "declare the partitioning, let the compiler handle parallelism"; shard_map is "I am writing SPMD by hand, shoulder-to-shoulder with the hardware."


11. Structured loops: lax.scan, lax.fori_loop, lax.while_loop

A Python for loop inside a jit-traced function unrolls into the jaxpr. For 4 iterations that is fine. For 1024 iterations the jaxpr (and the resulting HLO, and the compile time) explodes. Use structured control flow.

11.1 jax.lax.scan - the workhorse

scan is a stateful map-reduce. Signature (informal):

def f(carry, x):                # carry: state, x: per-step input
    new_carry = ...
    y = ...
    return new_carry, y

final_carry, ys = jax.lax.scan(f, init_carry, xs)

It is O(T) in compile size regardless of T, because the loop body is traced once and reused.

Use for:

  • RNN forward passes.
  • Sampling loops where each step depends on the last.
  • Any reduction with an explicit state where you want the per-step outputs materialized (ys).

scan differentiates correctly through the loop. The default reverse-mode AD over scan stores residuals for every step (O(T) memory); for long sequences, combine it with jax.checkpoint (rematerialization) to trade compute for memory, and use unroll= to trade compile size for runtime.
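A minimal scan: a running sum that also materializes the per-step outputs:

def step(carry, x):
    carry = carry + x
    return carry, carry                  # (new state, per-step output)

total, running = jax.lax.scan(step, 0.0, jnp.arange(5.0))
# total == 10.0; running == [0., 1., 3., 6., 10.]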

11.2 jax.lax.fori_loop

def body(i, state): ...
final = jax.lax.fori_loop(0, N, body, init)

Like scan but does not stack per-step outputs (no ys). Slightly cheaper, less expressive. AD support is limited if N is dynamic-for differentiable loops prefer scan.

11.3 jax.lax.while_loop

def cond(state): ...
def body(state): ...
final = jax.lax.while_loop(cond, body, init)

Truly dynamic iteration count. Cannot be reverse-mode differentiated in the general case (the number of iterations is data-dependent). Use for inference-only loops with data-dependent termination (e.g., autoregressive sampling until EOS).

11.4 jax.lax.cond and jax.lax.switch

For data-dependent branching:

y = jax.lax.cond(pred, true_fn, false_fn, x)
y = jax.lax.switch(idx, [fn0, fn1, fn2], x)

Both branches are traced and compiled; only one runs at execution. This means both branches must return the same PyTree structure with matching shapes/dtypes.


12. XLA: the compiler under the hood

XLA (Accelerated Linear Algebra) is the compiler. JAX is one front-end; TensorFlow and PyTorch/XLA are others. We focus on what JAX programmers need to know.

12.1 HLO: the IR

HLO ("High Level Optimizer") is XLA's intermediate representation. An HLO module is a collection of computations; each computation is a list of instructions; each instruction has a name, an opcode, operand references, and a typed shape.

Modern XLA actually uses StableHLO (an MLIR dialect) at the boundary, then lowers internally to HLO. The two are largely isomorphic for our purposes.

A small HLO snippet for jnp.sum(jnp.sin(x * y)):

HloModule jit_f, entry_computation_layout={(f32[3]{0}, f32[3]{0})->(f32[])}

ENTRY main {
  x = f32[3]{0} parameter(0)
  y = f32[3]{0} parameter(1)
  prod = f32[3]{0} multiply(x, y)
  s    = f32[3]{0} sine(prod)
  zero = f32[] constant(0)
  ROOT sum = f32[] reduce(s, zero), dimensions={0}, to_apply=add_f32
}

You can dump HLO from JAX:

hlo = jax.jit(f).lower(jnp.ones((3,)), jnp.ones((3,))).compiler_ir(dialect="hlo")
print(hlo.to_string())

or get the post-optimization HLO via compiled.as_text() / compiled.runtime_executable(). Setting the env var XLA_FLAGS=--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text will dump every compiled module.

12.2 The HLO op set you should recognize

Op Meaning
parameter(i) The i-th input
constant(...) A literal
add, multiply, subtract, divide Elementwise
sine, exponential, log, tanh, ... Elementwise unary
compare Elementwise comparison
convert Dtype cast
broadcast(..., dimensions=...) Reshape/expand to a larger shape
reshape Same data, new shape
transpose(..., dimensions=...) Permute axes
slice, dynamic-slice Static / dynamic slicing
concatenate Concatenate along a dim
reduce(operand, init, dimensions=, to_apply=) Reduction with a reducer computation
dot(a, b, lhs_contracting_dims=..., rhs_contracting_dims=..., lhs_batch_dims=..., rhs_batch_dims=...) Generalized matmul
convolution Generalized conv
gather, scatter Indirect read / write
select Elementwise where
tuple, get-tuple-element Tuple constructors / accessors
while, conditional Structured control flow
all-reduce, all-gather, reduce-scatter, all-to-all, collective-permute Cross-device collectives
custom-call Escape hatch to a hand-written kernel (cuDNN, custom CUDA, Pallas)

dot is the linchpin. It is fully general: contracting dims are reduced, batch dims are kept, the rest are output. A standard (M, K) × (K, N) → (M, N) matmul has lhs_contracting=[1], rhs_contracting=[0], no batch dims. Attention's (B, H, S, D) × (B, H, T, D) has lhs_contracting=[3], rhs_contracting=[3], lhs_batch=[0,1], rhs_batch=[0,1].

12.3 The compilation pipeline

Roughly:

  jaxpr
    │  (lowered by JAX)
  StableHLO (MLIR)
    │  (XLA front-end)
  HLO
    │  (XLA HLO passes)
  Optimized HLO
    │  (Backend: GPU / TPU / CPU)
  Device code (PTX/SASS via LLVM-NVPTX,  TPU machine code,  x86 LLVM)

The HLO passes do the heavy lifting:

  1. Algebraic simplification. x * 1 → x, concat(slice, slice) → original, fold constants.
  2. Layout assignment. Pick physical memory layouts (which dim is fastest-varying, tile shapes, padding) per buffer to match the target hardware.
  3. Sharding propagation (GSPMD). From annotated tensors, infer shardings everywhere; insert collectives.
  4. Operator fusion. Combine adjacent elementwise ops, plus a producer reduction or matmul, into a single kernel-this is the single biggest performance win on GPU. JAX programs that look like 100 small numpy ops often compile to a handful of fused kernels.
  5. Memory scheduling. Order ops to minimize peak memory; insert rematerialization if needed.
  6. Lowering. GPU: emit LLVM IR (or call into cuBLAS/cuDNN for some patterns), then to PTX. TPU: emit TPU-specific IR for the matrix unit and vector unit, schedule across HBM/VMEM.

12.4 Fusion in detail

Fusion is the reason JAX feels fast. Consider:

def f(x, y, z):
    return jnp.tanh(x * y + z)

In eager numpy this is three kernel launches: multiply, add, tanh-each reads from and writes to HBM. XLA fuses them into one kernel: read x, y, z once, do all the elementwise math in registers, write the result once. On memory-bound workloads (most elementwise + small reductions) this is a 3–10× speedup.

XLA's GPU backend also fuses producer reductions and consumer elementwise (and vice versa) and a matmul with epilogue elementwise (and prologue elementwise on its inputs) where profitable. The resulting fused kernel is what you see as a single HLO fusion instruction in the post-optimization dump.

Caveat: the optimizer can hide work. If a value's result is unused, dead-code elimination removes the computation entirely, and fusion can absorb or reorder what remains-so a computation you expected to observe may never run. Use jax.debug.print to observe values inside compiled code, and jax.disable_jit() to fall back to eager execution while debugging.

12.5 Layout assignment

Layout = how a logical shape (N, H, W, C) maps onto physical memory. On TPUs, tiling matters enormously: the matrix unit prefers (128, 128) tiles aligned in particular ways; XLA inserts paddings and chooses layouts so dot products hit fast paths. On GPUs, the choice is less critical (memory is more uniform) but still matters for shared memory and tensor cores.

You can sometimes coax XLA with shape choices (sizes that are multiples of 128 on TPU, multiples of 8 with bf16 for tensor cores on GPU).

12.6 GSPMD: how a single jit targets thousands of devices

GSPMD's input: an HLO module with sharding annotations on some subset of values (parameters, outputs, with_sharding_constraint points). Its output: an HLO module where every value has a sharding, with collectives inserted at boundaries.

The propagation is bidirectional and uses cost models. Key rules:

  • A dot between A: P("data", None) (sharded on batch) and B: P(None, "model") (sharded on output dim) produces a result P("data", "model") with no collective-purely local matmul.
  • A dot between A: P("data", "k") and B: P("k", "model") (sharded on the contraction dim) requires an all-reduce over the k axis after the local matmul.
  • An elementwise op with mismatched shardings inserts an all-gather or a reduce-scatter to align them.
  • A sequence-axis split followed by self-attention typically needs an all-to-all to switch from sequence-sharded to head-sharded.

GSPMD picks the lowest-cost combination. For most transformer workloads with a sensible mesh and a sensible PartitionSpec, the result is close to what an expert hand-writer would do. Where it isn't, you reach for with_sharding_constraint or shard_map.

12.7 Sharding propagation example

mesh = Mesh(devices.reshape(2, 4), ("data", "model"))
W1 = shard(W1_host, P(None, "model"))   # (D, 4D), sharded out dim
W2 = shard(W2_host, P("model", None))   # (4D, D), sharded in dim
x  = shard(x_host,  P("data", None))    # (B, D), sharded batch

@jax.jit
def mlp(x, W1, W2):
    h = jax.nn.gelu(x @ W1)
    y = h @ W2
    return y

GSPMD will:

  1. Compute x @ W1: x is P("data", None), W1 is P(None, "model"), so the result is P("data", "model") (no collective).
  2. gelu: elementwise, sharding unchanged.
  3. h @ W2: h is P("data", "model"), W2 is P("model", None). The contraction is on the "model" axis, so an all-reduce across "model" is inserted after the local matmul, yielding P("data", None).

That is the standard Megatron tensor-parallel MLP, derived automatically from three PartitionSpecs.


13. TPU vs GPU under JAX/XLA

XLA was conceived at Google with TPUs as its primary target, then extended to GPU and CPU. That history influences the runtime behavior.

Mark the TPU specifics here as "broadly true, version-dependent." TPU generations differ (v3/v4/v5p/v5e/Trillium…), and details have shifted.

13.1 Architectural differences (high level)

Aspect GPU (NVIDIA) TPU
Compute units SMs with FP32/FP16 cores + tensor cores Matrix multiply unit (systolic array, 128×128 typical) + vector unit
Memory hierarchy Registers → shared mem → L2 → HBM Registers → VMEM (per core) → HBM
Numeric formats FP32, FP16, BF16, FP8, INT8 Primarily BF16 / FP32 / INT8; FP8 in newer gens
Interconnect NVLink within a node, IB / Ethernet across nodes Native ICI (inter-chip interconnect) in pod topology-2D/3D torus
Sweet spot Heterogeneous workloads, irregular ops, custom CUDA Big regular dense matmuls, large pods

13.2 Compiler differences

  • Fusion granularity. TPU XLA tends to produce very large fused regions-sometimes the entire transformer block becomes one or two ops. GPU XLA fuses aggressively but is bounded by SM resource limits.
  • Layout. TPUs are pickier about layout (the matrix unit's tile size is fixed). XLA pads aggressively to align-small tensors can have substantial padding overhead. Sizes that are multiples of 128 (sometimes 256) are friendly.
  • Collectives. TPU pods have a 2D or 3D torus; XLA's collective scheduler is tightly tuned for it. GPU collectives go through NCCL (or XLA's own GPU collectives), and topology is less regular.
  • Async dispatch. Both backends launch async; on TPU the compiler often overlaps collectives with compute aggressively (because the cost model is well-understood).

13.3 Why XLA was TPU-first

TPUs are useless without a compiler-there is no eager kernel library equivalent to cuDNN. Every TPU program goes through XLA. So XLA had to be excellent at TPU codegen, sharding, and pod-scale collectives from day one. JAX inherited all of that. Running JAX on TPU is "the original path"; running JAX on GPU shares the front end but uses a different (also mature) backend.

13.4 Practical implications

  • A JAX program that runs well on 1 GPU often runs well on 1 TPU core with no changes.
  • A JAX program that scales to 8 GPUs via jit+sharding scales to a 2048-core TPU pod by enlarging the Mesh-same code.
  • TPU memory per core is typically smaller than a top-end GPU's HBM. You will lean more on FSDP-style sharding and rematerialization on TPU.

14. Module systems: Equinox and Flax

JAX core has no nn.Module. Two libraries dominate.

14.1 Flax (flax.linen and the newer flax.nnx)

flax.linen (the long-standing API):

from flax import linen as nn
class MLP(nn.Module):
    hidden: int
    out: int
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden)(x)
        x = nn.relu(x)
        return nn.Dense(self.out)(x)

model = MLP(64, 10)
params = model.init(rng, dummy_input)        # returns a PyTree of params
y = model.apply(params, x)

Idioms:

  • Modules are dataclasses; calling them constructs thunks, not param-owning objects.
  • model.init(rng, x) traces the module to create a parameter PyTree.
  • model.apply(params, x) runs the forward pass-params are an argument, never owned.

This is JAX-functional through and through. Optimizer state is also a PyTree, threaded through train_step.

flax.nnx is a newer API that gives you a more PyTorch-like stateful feel while preserving JAX semantics under the hood. Pick linen for stable production code and the largest ecosystem; pick nnx if you prefer the ergonomics.

14.2 Equinox

Equinox treats modules as PyTrees of dataclasses directly:

import equinox as eqx
class MLP(eqx.Module):
    l1: eqx.nn.Linear
    l2: eqx.nn.Linear
    def __init__(self, key):
        k1, k2 = jax.random.split(key)
        self.l1 = eqx.nn.Linear(784, 128, key=k1)
        self.l2 = eqx.nn.Linear(128, 10,  key=k2)
    def __call__(self, x):
        return self.l2(jax.nn.relu(self.l1(x)))

model = MLP(jax.random.PRNGKey(0))
y = model(x)                                  # call the model directly
grads = jax.grad(loss)(model, x, y_true)      # model itself is the param tree

Idioms:

  • A module is its parameters. There is no separate params PyTree.
  • eqx.partition(model, eqx.is_array) separates trainable arrays from non-array fields when needed (e.g., for optax).

Pick Equinox if you like the "model is data" mental model and small dependency. Pick Flax for the bigger ecosystem (pretrained checkpoints, integrations, MaxText/Pax).

14.3 Optimizer: optax

Both work with optax, the standard JAX optimizer library. optax exposes optimizers as (init, update) function pairs that operate on PyTrees:

import optax
opt = optax.adamw(1e-3)
opt_state = opt.init(params)
grads = jax.grad(loss)(params, ...)
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

Note: pure functions, no mutation. optax chains transformations (optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(...))) which makes complex schedules trivially composable.


15. jax.experimental.pallas: Triton-like kernels in JAX

When the compiler's automatic codegen is not enough-typically for fused attention, custom flash-attention variants, or quantized kernels-Pallas lets you write kernels in a Python DSL that lowers to:

  • Triton on GPU, and
  • Mosaic on TPU.

Sketch:

from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        grid=(x.shape[0] // 128,),
        in_specs=[pl.BlockSpec((128,), lambda i: (i,)),
                  pl.BlockSpec((128,), lambda i: (i,))],
        out_specs=pl.BlockSpec((128,), lambda i: (i,)),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

Mental model: Pallas kernels look like CUDA/Triton kernels (you reason about blocks and refs/pointers), but they integrate as a single HLO custom-call in your JAX program, with full vmap and jit composition. Use it sparingly-only when the surrounding XLA-generated code leaves serious performance on the table-and benchmark against the un-Pallas baseline.

This is the JAX answer to PyTorch's "drop into Triton". It is a young area; expect API motion. The big production win so far is fused-attention kernels.


16. Practical exercises (with worked answers)

How to use these. Read the question. Pause. Predict. Then read the answer. If you predicted wrong, re-read the relevant section.

Exercise 1-What gets recompiled?

You have:

@jax.jit
def f(x, y):
    return x @ y + 1.0

f(jnp.ones((4, 8)),  jnp.ones((8, 16)))     # call A
f(jnp.ones((4, 8)),  jnp.ones((8, 16)))     # call B
f(jnp.ones((4, 8)),  jnp.ones((8, 32)))     # call C
f(jnp.ones((4, 8)).astype(jnp.bfloat16),
  jnp.ones((8, 16)).astype(jnp.bfloat16))   # call D

Which calls trigger compilation?

Answer. A compiles. B reuses A's cache (identical abstract signature). C compiles a new entry (different shape on y). D compiles a new entry (different dtype on both). Total: 3 compilations, 4 calls.

Exercise 2-Why did my training loop slow to a crawl?

Symptom: every step takes ~2 s, no GPU utilization between steps. JAX_LOG_COMPILES=1 shows a compile log every step.

Likely cause. Variable-shape inputs. Most often: padding sequences to their actual lengths inside the data loader, so each batch has shape (B, L) with a different L.

Fixes (any of):

  1. Pad to a fixed maximum length (cheapest, slight wasted compute).
  2. Pad to a small set of bucket lengths (e.g., 128/256/512/1024)-at most that many compilations.
  3. Mark sequence length as static_argnums only if you genuinely have a tiny number of distinct values.

Exercise 3-Per-example gradient norms

You want, for each example in a batch, the L2 norm of its per-example gradient. Write it.

def per_example_loss(params, x, y):
    return ((x @ params["W"] - y) ** 2).mean()    # scalar per example

per_grad = jax.vmap(jax.grad(per_example_loss),
                    in_axes=(None, 0, 0))         # share params, batch x and y

def per_example_norm(params, X, Y):
    grads = per_grad(params, X, Y)                # PyTree, leaves shape (B, ...)
    flat  = jax.tree.leaves(grads)
    sq    = sum(jnp.sum(g.reshape(g.shape[0], -1) ** 2, axis=1) for g in flat)
    return jnp.sqrt(sq)                           # shape (B,)

Reading: vmap(grad(...)) produces a function that returns gradient PyTrees with a leading batch axis; we then compute the per-row norm. No Python loop.

Exercise 4-A Megatron-style MLP, by PartitionSpec

You have an MLP Linear(D, 4D) → gelu → Linear(4D, D) and an 8-device mesh ("data", "model") with shape (2, 4). Specify shardings such that:

  • The batch is data-parallel.
  • Both Linear layers are tensor-parallel along the hidden axis.
  • The output of the MLP is data-sharded, model-replicated (so a downstream layer-norm sees a complete vector per example).

Answer.

mesh = Mesh(devices.reshape(2, 4), ("data", "model"))
x_spec   = P("data", None)          # (B, D)
W1_spec  = P(None, "model")         # (D, 4D)
W2_spec  = P("model", None)         # (4D, D)
y_spec   = P("data", None)          # (B, D)

GSPMD will compute x @ W1 locally (no collective), gelu locally, then h @ W2 with an all-reduce over "model" on the contraction axis to produce y_spec. This is exactly Megatron-LM's "column-parallel then row-parallel" pattern, derived from PartitionSpecs.

Exercise 5-Why won't this differentiate?

@jax.jit
def f(x):
    n = 0
    while jnp.linalg.norm(x) > 1.0:
        x = x / 2
        n += 1
    return x, n

You wrap it in jax.grad(lambda x: f(x)[0].sum()) and get an error. Why?

Answer. The Python while runs at trace time and depends on a traced value (jnp.linalg.norm(x) > 1.0 is a tracer). You will hit a ConcretizationTypeError even before differentiation. The fix is jax.lax.while_loop, but: while_loop does not support reverse-mode AD because the iteration count is data-dependent. If you need a differentiable variant, use jax.lax.scan with a known maximum number of iterations, masking the unused steps.

Exercise 6-Reading a jaxpr

What does this code do?

def g(x):
    return jnp.where(x > 0, x, 0.5 * x)

print(jax.make_jaxpr(g)(jnp.array([-1.0, 2.0, -3.0])))

Answer. Approximately:

{ lambda ; a:f32[3]. let
    b:bool[3] = gt a 0.0
    c:f32[3]  = mul a 0.5
    d:f32[3]  = select_n b c a
  in (d,) }

It is a leaky-ReLU-ish function (slope 0.5 on the negative side). The jaxpr makes the elementwise nature explicit and shows that where becomes a select_n (3-way select primitive) rather than a Python branch. It will compile to a single fused HLO kernel (compare, multiply, select).


17. Cheat sheet

17.1 The four core transformations

Transformation Maps f to Mental model
jax.jit(f) A function that traces, lowers to HLO, compiles, caches, runs "Make it fast"
jax.grad(f) A function returning df/dx "Differentiate"
jax.vmap(f) A function with an extra batch axis "Vectorize"
jax.pmap(f) / jit(..., in_shardings=...) A function running across devices "Parallelize"

They compose. jit(vmap(grad(f))) is well-defined and idiomatic.

17.2 Common errors and their meaning

Error Cause Fix
ConcretizationTypeError Python control flow on a tracer Use jax.lax.cond/where, or make the variable static
TracerArrayConversionError Tried to convert a tracer to numpy or to use it as a Python int Push the work into JAX-land or restructure
Repeated compiles Variable shapes / dtypes / static args / pytree structures Stabilize shapes, use bucketing, audit static_argnums
OOM during compile Long Python for loop being unrolled Use lax.scan
Silent wrong answer Side effect (mutation, global) inside jitted function Make it pure

17.3 Inspection toolbox

jax.make_jaxpr(f)(x)                                 # show jaxpr
jax.jit(f).lower(x).as_text()                        # StableHLO MLIR
jax.jit(f).lower(x).compiler_ir(dialect="hlo")       # HLO
jax.jit(f).lower(x).compile().as_text()              # post-opt HLO
jax.jit(f).lower(x).compile().cost_analysis()        # FLOPs, bytes
jax.config.update("jax_log_compiles", True)          # see every compile
JAX_TRACEBACK_FILTERING=off                          # (env) full Python tracebacks
XLA_FLAGS=--xla_dump_to=/tmp/xla --xla_dump_hlo_as_text   # (env) dump every module

17.4 When to reach for which loop primitive

Need Primitive
Static iteration count, want per-step outputs, want AD jax.lax.scan
Static iteration count, no per-step outputs, no AD on the loop count jax.lax.fori_loop
Data-dependent iteration count, no AD over the loop jax.lax.while_loop
Data-dependent branch, both branches valid jax.lax.cond
Choose among N branches by index jax.lax.switch

17.5 Sharding spec recipes (mesh ("data", "model"))

Goal Inputs Weights Note
Pure DP (replicated weights) P("data", ...) P(None, ...) Standard data parallel
FSDP-style P("data", ...) P("data", ...) (gathered before use) Combine with with_sharding_constraint
Tensor parallel (Megatron MLP) P("data", None) W1: P(None, "model"), W2: P("model", None) All-reduce after W2
2D parallelism P("data", None) TP weights as above ("data", "model") mesh

17.6 Mental discipline

  • State-as-argument. Anything that "carries over"-params, opt state, RNG, batchnorm running stats-is an argument and a return value, never a global.
  • Trace once, run many. Every jit-decorated function should be traced a small number of times across the entire program lifetime.
  • Annotate sharding sparsely. Annotate the inputs and key intermediate shardings; let GSPMD figure out the rest.
  • Profile with the post-opt HLO. What you wrote in Python and what runs on the device can diverge dramatically due to fusion. Read the post-optimization HLO before optimizing.
  • Read the jaxpr when surprised. It is short and exact.

Closing

JAX is a small core (pure functions over PyTrees, traced into jaxprs, lowered to XLA HLO) with a large amount of leverage on top (composable transformations, GSPMD, Pallas, the Flax/Equinox ecosystem). Its design exacts a discipline-purity, explicit state, structured control flow, deliberate sharding-and pays back with a programming model that scales seamlessly from a single laptop GPU to a 4096-chip TPU pod with the same source code.

If you internalize four ideas from this chapter, make them: (1) purity makes transformations compose, (2) jit is trace-then-cache, and the cache key is the abstract signature, (3) HLO is the IR; fusion and GSPMD are where the magic lives, (4) Mesh + PartitionSpec is how you tell the compiler about your hardware, and the rest is propagation. Everything else in JAX is a refinement of those.

Distributed Training: Mathematics and Engineering

A self-contained reference for understanding how modern foundation models are trained across hundreds-to-thousands of GPUs. By the end, you should be able to read ZeRO, Megatron-LM, GPipe, and PipeDream from summary alone, derive the memory and communication costs of any sharding scheme, and design a 3D-parallel configuration for a given model and cluster.


1. Why Distributed Training: The Memory Wall

1.1 Intuition

A single H100 80GB GPU has 80 GB of HBM. A single A100 has 40 or 80 GB. Modern foundation models exceed this by orders of magnitude:

  • GPT-3 175B in FP16: 350 GB just for parameters.
  • Llama-3 405B in BF16: 810 GB just for parameters.

But parameters are only one of four memory bills. Training requires:

  1. Parameters (W)-the weights themselves.
  2. Gradients (∇W)-same shape as parameters.
  3. Optimizer states (O)-Adam keeps m (first moment) and v (second moment), both same shape as parameters.
  4. Activations (A)-intermediate forward outputs needed for backward; scales with batch × sequence × hidden, not with parameters.

Distributed training is, fundamentally, the engineering discipline of partitioning these four buckets across many devices while keeping the arithmetic correct and the communication cheap.

1.2 Memory Decomposition (the Standard Accounting)

Let Φ be the number of parameters in the model (e.g., Φ = 7 × 10⁹ for a 7B). Let:

  • b_p = bytes per parameter in storage precision (FP16/BF16 = 2, FP32 = 4).
  • b_g = bytes per gradient (typically same as b_p).
  • b_o = bytes per optimizer state slot (Adam in FP32 keeps m, v, plus a master FP32 copy of weights → 12 bytes per param under standard mixed precision).

Standard mixed-precision Adam memory per replica, in bytes:

M_param  = 2Φ        (BF16 parameters)
M_grad   = 2Φ        (BF16 gradients; often kept in FP32 → 4Φ)
M_opt    = 12Φ       (FP32 master weights 4Φ + FP32 m 4Φ + FP32 v 4Φ)
M_static = 16Φ       (with BF16 gradients; 18Φ with FP32 gradients)

The "16Φ" figure is what ZeRO calls K=12 + 4 = 16 (12 bytes optimizer + 4 for FP16 params and grads). Variations exist; the principle is what matters.

For Φ = 70B with M_static ≈ 16Φ: 1.12 TB of static memory. No single 80GB GPU comes close. Hence: shard, replicate, or pipeline.
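The accounting is easy to mechanize; a small sketch using this section's byte counts:

def static_memory_gb(phi, fp32_grads=False):
    # 2 B BF16 params + 2 B (or 4 B) grads + 12 B FP32 optimizer state, per param
    bytes_per_param = 2 + (4 if fp32_grads else 2) + 12
    return phi * bytes_per_param / 1e9

static_memory_gb(70e9)      # ≈ 1120 GB - the 1.12 TB figure above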

1.3 Activation Memory

Activations dominate at long sequence lengths. For a Transformer with L layers, hidden size h, batch B, sequence S, the activation memory per layer without recomputation is approximately:

A_per_layer ≈ S · B · h · (10 + 24/t + 5·a·S/(h·t))    bytes

where t is the tensor-parallel degree and a is the number of attention heads (Korthikanti et al. 2022 give the precise formula). The last term-5·a·B·S²/t bytes once multiplied out-grows quadratically in sequence length and is what motivates sequence parallelism (Section 10).

With selective activation checkpointing (recomputing attention but storing the rest), this drops to roughly S · B · h · (10 + 24/t).

1.4 The Three Axes of Parallelism

Each of the four memory buckets can be partitioned along three orthogonal axes:

Axis What it splits What it costs
Data parallel (DP) Batch All-reduce of gradients per step
Tensor parallel (TP) Hidden dim of each layer All-reduce inside each layer (×4 per Transformer block)
Pipeline parallel (PP) Layer depth Point-to-point sends between stages + bubble

ZeRO/FSDP is a refinement of DP that also shards parameters/gradients/optimizer states. Sequence parallel splits along the sequence dim. The full design space is rich; the rest of this chapter walks it.


2. Communication Primitives

Distributed training is impossible without a vocabulary of collective operations. We define each precisely, in terms of input shape, output shape, and the data each rank ends with.

2.1 Point-to-Point

send(tensor, dst) / recv(tensor, src): rank src sends a buffer of shape T to rank dst; after the call, dst holds a copy of that buffer. Used in pipeline parallelism between adjacent stages.

Cost model: α + M/β where α is link latency and β is link bandwidth, M is message size in bytes.

2.2 Collectives

Let there be N ranks 0, …, N-1. Each rank holds a tensor x_i.

broadcast(root): after the call, every rank holds x_root. Input shape on root: T; output shape on every rank: T.

reduce(op, root): after the call, root holds op(x_0, x_1, …, x_{N-1}) elementwise (for op = sum, min, max). Other ranks: undefined. Input/output shape: T.

all_reduce(op): every rank ends with op(x_0, …, x_{N-1}). Equivalent to reduce followed by broadcast. Input shape per rank: T; output shape per rank: T. Total data on each rank stays T.

all_gather: each rank i holds a chunk x_i of shape T. After the call, every rank holds the concatenation [x_0 ‖ x_1 ‖ … ‖ x_{N-1}] of shape N · T. Total bytes received per rank: (N-1)·T.

reduce_scatter(op): each rank i holds a tensor of shape N · T partitioned into N chunks. After the call, rank i holds the elementwise reduction of the i-th chunks across all ranks, shape T. Output shape per rank: T.

all_to_all: each rank i holds N chunks of shape T, where chunk j is destined for rank j. After the call, rank i holds chunks i from each rank 0..N-1, concatenated, shape N · T. Used in expert-parallel MoE for the dispatch/combine step.

A clean identity: all_reduce = reduce_scatter + all_gather. This decomposition is the basis for ring all-reduce and for ZeRO-2/FSDP.
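The identity is easy to verify with a toy host-side simulation (numpy stands in for N real ranks; shapes illustrative):

import numpy as np

def all_reduce_via_rs_ag(xs):
    # xs: one array per simulated rank, each of length N * c.
    N = len(xs)
    chunks = [np.split(x, N) for x in xs]
    # reduce-scatter: rank i ends with the sum of the i-th chunks
    reduced = [sum(chunks[r][i] for r in range(N)) for i in range(N)]
    # all-gather: every rank concatenates the reduced chunks
    return [np.concatenate(reduced) for _ in range(N)]

xs = [np.arange(8) + r for r in range(4)]
assert all((out == sum(xs)).all() for out in all_reduce_via_rs_ag(xs))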

2.3 Algorithm Trees for All-Reduce

We model communication time as T = α · L + M / β where L is the number of sequential link traversals (latency-limited) and M is the bytes that traverse the slowest link (bandwidth-limited). We want both small.

2.3.1 Naive (gather-to-root, scatter-from-root)

Every rank sends its M bytes to rank 0 sequentially; rank 0 reduces and broadcasts. Time:

T_naive = (N-1) · (α + M/β)   (gather, sequential through root's NIC)
        + (N-1) · (α + M/β)   (broadcast, same)
        ≈ 2(N-1)·α + 2(N-1)·M/β

Latency: O(N). Bandwidth: O(N·M). Root's link is a bottleneck; worst-case algorithm. Useful only for tiny N and tiny M.

2.3.2 Binary Tree

Build a binary tree over ranks. Reduce phase: leaves send to parents, parents reduce, propagate to root. Broadcast phase: reverse. Each rank participates in log₂(N) rounds of reduce + log₂(N) of broadcast.

Per-round, every active rank sends M bytes:

T_tree ≈ 2·log₂(N)·α + 2·log₂(N)·M/β

Latency: O(log N). Bandwidth: O(M·log N). Better latency, worse bandwidth than naive for large M. Used for small-message all-reduces (e.g., scalar metrics).

2.3.3 Recursive Doubling

For all-gather (dual of all-reduce): pair up ranks at distances 1, 2, 4, … N/2, each pair exchanges its current data. After log₂(N) rounds, each rank holds all N chunks. Time per round: α + M_round/β where M_round doubles.

Total bandwidth cost: M + 2M + 4M + … + (N/2)M = (N-1)·M, but spread across log₂(N) rounds. For all-reduce via this method: same as tree.

2.3.4 Recursive Halving + Recursive Doubling (Rabenseifner)

For large messages: do reduce-scatter in log₂(N) recursive-halving rounds, then all-gather in log₂(N) recursive-doubling rounds.

  • Reduce-scatter: in round k, each rank exchanges M / 2^k bytes with its partner at distance 2^(k-1).
  • After log₂(N) rounds: each rank holds M/N bytes that are the reduction of the corresponding slice.
  • All-gather: reverse, doubling the chunk size each round.

Total bytes sent per rank:

M/2 + M/4 + … + M/N  +  M/N + 2M/N + … + M/2  =  2(N-1)/N · M

Latency: 2·log₂(N)·α. Bandwidth-optimal in M, latency-logarithmic. This is what NCCL uses for inter-node all-reduce when the message and topology are right.

2.3.5 Ring All-Reduce (Patarasuk & Yuan, 2009)

Arrange the N ranks in a logical ring. Split the buffer into N chunks of size M/N.

Phase A: reduce-scatter (N-1 steps). In step k:

  • Rank i sends chunk (i - k) mod N to rank (i+1) mod N.
  • Rank i receives chunk (i - k - 1) mod N from rank (i-1) mod N and adds it to its local copy.

After N-1 steps, rank i owns the fully reduced chunk indexed (i+1) mod N (or some rotation; conventions vary).

Phase B: all-gather (N-1 steps). Same ring traversal: each rank passes its fully reduced chunk forward, eventually every rank holds all N reduced chunks.

Bandwidth analysis. In each step, every rank sends and receives one chunk of M/N bytes. Total steps: 2(N-1). Total bytes per rank (send or receive): 2(N-1) · M/N = 2(N-1)/N · M.

T_ring = 2(N-1)·α + 2(N-1)/N · M / β

Latency: O(N). Bandwidth: O(2(N-1)/N · M) → O(2M) as N→∞. This is bandwidth-optimal: you cannot all-reduce M bytes for fewer than 2(N-1)/N · M bytes per link without violating information-theoretic bounds. The latency is the price; for large messages on high-bandwidth links, latency vanishes relatively, and ring is the right choice.

Why bandwidth-optimal? Each rank's data must influence every other rank, and each rank must end up with the same M bytes. The min cuts of the ring topology force at least 2(N-1)/N · M bytes through each link.

Engineering note. Ring is the default intra-node algorithm in NCCL because NVLink's full-duplex topology fits the ring perfectly. For inter-node, NCCL may switch to tree (for small messages) or "double-binary tree" (for medium messages).

2.4 Summary Table

Algorithm Latency Bandwidth (per rank) Best regime
Naive 2(N-1)·α 2(N-1)·M Never
Binary tree 2 log₂N · α 2 log₂N · M Tiny messages
Recursive doubling log₂N · α log₂N · M All-gather, small-mid
Rabenseifner 2 log₂N · α 2(N-1)/N · M Large M, low N
Ring all-reduce 2(N-1)·α 2(N-1)/N · M Large M, intra-node
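The table is handy to mechanize when sizing a job; a toy cost model of it (the α and β defaults are placeholder assumptions):

import math

def allreduce_ms(algo, N, M_bytes, alpha_s=5e-6, beta_Bps=600e9):
    hops, bytes_per_rank = {
        "naive":        (2 * (N - 1),      2 * (N - 1) * M_bytes),
        "tree":         (2 * math.log2(N), 2 * math.log2(N) * M_bytes),
        "rabenseifner": (2 * math.log2(N), 2 * (N - 1) / N * M_bytes),
        "ring":         (2 * (N - 1),      2 * (N - 1) / N * M_bytes),
    }[algo]
    return (hops * alpha_s + bytes_per_rank / beta_Bps) * 1e3

allreduce_ms("ring", 8, 14e9)    # ≈ 41 ms - cf. the 7B example in Section 4.5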

3. NCCL Specifics

NCCL (NVIDIA Collective Communications Library) is what every PyTorch/JAX distributed run uses on NVIDIA GPUs. Understanding its choices is the difference between 30% and 70% of peak.

3.1 Algorithm Selection

NCCL maintains a cost model and picks among:

  • Ring: bandwidth-optimal; default for large messages, intra-node, NVLink-rich.
  • Tree (double-binary tree): better latency for small messages, especially inter-node.
  • CollNet: uses InfiniBand SHARP (Mellanox switch-based aggregation) for in-network reduction. Subtracts host-side bandwidth.
  • NVLS (NVLink SHARP, on H100/Hopper systems with NVSwitch v3): in-NVSwitch reduction, like CollNet but for NVLink.

The choice is made per-collective based on (count, datatype, topology) using internal tuning tables.

3.2 Reading NCCL_DEBUG=INFO

A typical line:

NCCL INFO Channel 00/02 : 0[1c000] -> 1[1d000] via P2P/IPC
NCCL INFO Channel 00 : 0[1c000] -> 1[1d000] via NET/IB/0/GDRDMA

What to extract:

  • Channel count (02 here): NCCL uses multiple parallel rings; more channels = more concurrent SM use.
  • Transport (P2P/IPC, NET/IB, SHM): tells you whether traffic is staying on NVLink, going through shared memory, or hitting the network.
  • GDRDMA: GPUDirect RDMA, meaning NIC reads directly from GPU memory-desirable.
  • Ring 00 : 0 1 2 3 …: the ring order. For two-node 8-GPU jobs, NCCL builds rings that traverse intra-node NVLink before crossing IB.

If you see Tree lines for a large all-reduce, NCCL has decided the message is small enough or topology is sparse enough; sometimes wrong-you can override.

3.3 Tuning Knobs

Variable Effect
NCCL_DEBUG=INFO Verbose init logs, ring topology
NCCL_DEBUG_SUBSYS=ALL Even more (COLL, INIT, NET)
NCCL_IB_HCA=mlx5_0,mlx5_1 Which IB cards to use (rail affinity)
NCCL_SOCKET_IFNAME=eth0 Which Ethernet interface for bootstrap
NCCL_TOPO_FILE=topo.xml Override auto-detected topology
NCCL_ALGO=Ring (or Tree, CollNet, NVLS) Force algorithm
NCCL_PROTO=LL,LL128,Simple Force protocol; LL = low-latency (small msgs), Simple = bulk
NCCL_NTHREADS, NCCL_MAX_NCHANNELS Channel parallelism
NCCL_P2P_DISABLE=1 Disable peer-to-peer (debug)
NCCL_IB_GID_INDEX RoCE-specific; pick the right GID
NCCL_BUFFSIZE Per-channel buffer size; affects pipelining

3.4 Rail Affinity

Modern multi-NIC nodes use rails: GPU 0 talks to NIC 0, GPU 1 to NIC 1, etc. NCCL needs to know this; otherwise GPU 0 might end up sending through NIC 3 over PCIe, halving inter-node bandwidth.

Verify: nvidia-smi topo -m shows the GPU↔NIC connectivity matrix (PXB, PHB, NV2, etc.). Make sure NCCL's chosen NCCL_IB_HCA list aligns with the rails.


4. Data Parallelism (DDP)

4.1 The Pattern

Every rank holds the full model. Each step:

  1. Local forward on local micro-batch.
  2. Local backward.
  3. All-reduce of gradients across all ranks (sum, then divide by N).
  4. Local optimizer step (identical on every rank → models stay in sync).

This requires that initial weights are identical (broadcast at init) and that gradient all-reduce is exact. Both are easy to guarantee.

4.2 Memory Cost

Per rank (mixed precision, Adam), with Φ parameters:

  • Params (BF16): 2Φ
  • Grads (FP32 for stability, often): 4Φ - or 2Φ in BF16
  • Optimizer master weights (FP32): 4Φ
  • Optimizer m (FP32): 4Φ
  • Optimizer v (FP32): 4Φ

Total static: ~16-18Φ per rank. DDP does not reduce model state memory: every rank pays full price. With Φ = 7B, that's ~112 GB on each rank. Already too big for one H100 80GB.

4.3 PyTorch DDP Implementation Details

The all-reduce isn't done as one giant call. PyTorch DDP uses gradient bucketing:

  • Gradients are grouped into buckets of ~25MB (configurable via bucket_cap_mb).
  • As each parameter's gradient is computed during backward, it's added to its bucket.
  • When a bucket is full, NCCL's allReduce is launched asynchronously, overlapping with continuing backward computation on other parameters.
  • Sync point at end of backward: wait for all bucket all-reduces to finish.

This overlap is critical: without it, you'd have backward → idle → all-reduce → idle. With it, all-reduce hides under backward.

Pseudocode of the hook:

on_grad_ready(param):
    bucket = bucket_for(param)
    bucket.add(param.grad)
    if bucket.full() or last_param_in_bucket(param):
        bucket.handle = ncclAllReduce(bucket.buffer, async=True)
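For reference, the user-facing API is small (a sketch; assumes the process group is initialized and that model, inputs, and optimizer already exist on this rank):

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# torch.distributed.init_process_group("nccl") has already run;
# `model` lives on this rank's GPU, `optimizer` wraps its parameters.
ddp_model = DDP(model, bucket_cap_mb=25)   # bucket size for gradient all-reduce

loss = ddp_model(inputs).sum()             # illustrative forward + loss
loss.backward()                            # bucket all-reduces overlap backward
optimizer.step()                           # identical update on every rank
optimizer.zero_grad()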

4.4 Scaling Efficiency

Throughput per GPU vs. one GPU:

efficiency = T_compute / (T_compute + T_overhead)
T_overhead ≈ T_allreduce_unhidden

If T_allreduce is fully hidden under backward, efficiency → 100%. In practice, the last bucket can't be hidden (it's the final gradients). And small messages near the start have high latency relative to size.

Gotcha 1: small models, fast nodes. If T_compute per step is small (say 50ms) and the all-reduce of the final bucket is 20ms, you're at 70% efficiency before any other loss.

Gotcha 2: stragglers. All-reduce is barrier-like; one slow rank slows everyone. NCCL's default 30-min timeout will kill the job.

Gotcha 3: find_unused_parameters=True. Adds an extra pass and an extra all-reduce of a bitmask; avoid if possible.

4.5 Micro-example

Llama-2 7B, 8×H100, batch 8M tokens. Φ = 7e9, BF16 grads = 14 GB. Ring all-reduce on 8 GPUs at 600 GB/s NVLink:

T = 2(8-1)/8 · 14 GB / 600 GB/s = 1.75 · 23.3 ms ≈ 40.8 ms

If a forward+backward step takes 400 ms, the 40.8 ms hides ~90%; effective overhead ~4-8 ms. Efficiency: 98-99%. This is what makes 7B "easy" on 8 GPUs.


5. ZeRO (Rajbhandari et al., SC 2020)

The Zero Redundancy Optimizer observes that DDP duplicates the four memory buckets N times across N ranks. ZeRO shards them.

5.1 The Insight

Recall standard mixed-precision Adam memory: ~16Φ per rank. Of that, 12Φ is optimizer states (FP32 master, m, v). DDP keeps a full copy on every rank. Why? Because the optimizer step happens locally, and you need the states to take a step.

But: you only need the optimizer state of parameter p on the rank that owns p. If we shard:

  • Stage 1: shard optimizer states.
  • Stage 2: also shard gradients.
  • Stage 3: also shard parameters.

5.2 ZeRO-1: Optimizer State Sharding

Partition the parameter index space into N shards: rank i owns shard i's optimizer state (master weights + m + v, all FP32).

Per-step protocol:

  1. Forward: full BF16 weights on every rank (still replicated).
  2. Backward: full BF16 gradients on every rank.
  3. All-reduce gradients (same as DDP).
  4. Each rank applies the optimizer to its shard of params (using its shard of m, v, master weights).
  5. All-gather updated BF16 weights so every rank has the full updated model.

Memory per rank:

Params (BF16):    2Φ
Grads (BF16):     2Φ          (allocated full, but could shard after all-reduce)
Optimizer (FP32): 12Φ / N
Total ≈ 4Φ + 12Φ/N

For Φ=70B, N=64: 4·70 + 12·70/64 = 280 + 13.1 = 293 GB. Still bad. ZeRO-1 alone isn't enough for big models.

Communication. ZeRO-1 adds an all_gather of params after the optimizer step. Cost: (N-1)/N · 2Φ bytes (BF16). This is in addition to the gradient all-reduce. Net: roughly 1.5× DDP communication.

5.3 ZeRO-2: Optimizer + Gradient Sharding

Replace gradient all-reduce with reduce-scatter: each rank ends up with the reduced gradient of its shard only. Now only sharded gradients need to be stored.

Per-step protocol:

  1. Forward: full params (replicated).
  2. Backward: full grads computed locally.
  3. Reduce-scatter gradients → each rank has reduced grads for its shard only.
  4. Local optimizer step on the shard.
  5. All-gather updated params.

Memory per rank:

Params (BF16):    2Φ
Grads (BF16):     2Φ / N      ← sharded after reduce-scatter
Optimizer (FP32): 12Φ / N
Total ≈ 2Φ + 14Φ/N

For Φ=70B, N=64: 140 + 14·70/64 = 140 + 15.3 = 155 GB. Still too much.

Communication. reduce_scatter + all_gather = all_reduce algorithmically, so total bytes are the same as DDP. ZeRO-2 is essentially free in communication vs DDP-same bytes, less memory.

5.4 ZeRO-3: Full Sharding

Also shard parameters. Now no rank holds the full model. Forward must gather the params for the current layer, do the matmul, then free them.

Per-step protocol (per layer):

  1. Forward layer l: all_gather the BF16 params of layer l. Do the matmul. Free the gathered params (keep only your shard).
  2. Backward layer l: all_gather BF16 params again (or keep them from forward). Compute grads. reduce_scatter grads → each rank holds only its grad shard. Free the grad buffer.
  3. After the full backward: each rank applies the optimizer to its param shard. No final all-gather needed; the next forward gathers as needed.

Memory per rank (everything sharded):

Params (BF16):    2Φ / N
Grads (BF16):     2Φ / N
Optimizer (FP32): 12Φ / N
Total ≈ 16Φ / N

For Φ=70B, N=64: 16·70/64 = 17.5 GB of static state. Now there's room for activations!

Communication cost. ZeRO-3 adds, per step:

  • One all_gather of params per layer (forward) ≈ (N-1)/N · 2Φ total over the model.
  • One all_gather of params per layer (backward, can be elided with smart reuse).
  • One reduce_scatter of grads per layer ≈ (N-1)/N · 2Φ.

Roughly 1.5× DDP communication. The tradeoff: you save N× memory at the cost of 1.5× bytes on the wire. Worth it when you can't fit otherwise.

5.5 Memory Math Summary

For one rank, parameters Φ, mixed-precision Adam (K=12 for opt states):

Scheme Params Grads Opt states Total
DDP 2Φ 2Φ 12Φ 16Φ
ZeRO-1 2Φ 2Φ 12Φ/N 4Φ + 12Φ/N
ZeRO-2 2Φ 2Φ/N 12Φ/N 2Φ + 14Φ/N
ZeRO-3 2Φ/N 2Φ/N 12Φ/N 16Φ/N

(Note: some implementations keep grads in FP32 → 4Φ, shifting all rows. The structure is identical.)
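A quick sanity-check helper for this table (a sketch; phi_b is the parameter count in billions, so results come out in GB):

def zero_static_gb(phi_b, n, stage, grad_bytes=2, opt_k=12):
    """Per-rank static memory in GB: BF16 params, grads, Adam opt states."""
    params = 2 * phi_b
    grads = grad_bytes * phi_b
    opt = opt_k * phi_b
    if stage >= 1: opt /= n      # ZeRO-1: shard optimizer states
    if stage >= 2: grads /= n    # ZeRO-2: also shard gradients
    if stage >= 3: params /= n   # ZeRO-3: also shard parameters
    return params + grads + opt

for s in range(4):
    print(s, zero_static_gb(70, 64, s))
# 0 1120.0   1 293.1   2 155.3   3 17.5  -- matches the 70B/64-rank numbers above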

5.6 ZeRO-Offload, ZeRO-Infinity

  • ZeRO-Offload: move optimizer states (and optionally gradients) to CPU memory and run the optimizer step on CPU. Trades GPU↔CPU PCIe bandwidth (~32 GB/s) for GPU memory.
  • ZeRO-Infinity: extends to NVMe; uses overlap so prefetched param shards arrive in time for the layer that needs them. Enables training models that exceed total GPU memory of the cluster.

These are useful when you can't add GPUs but have RAM and SSDs.


6. FSDP-PyTorch's ZeRO-3

PyTorch's Fully Sharded Data Parallel is the in-tree implementation of ZeRO-3, with a few production refinements.

6.1 Wrapping Discipline

You don't shard parameter-by-parameter-too fine; the all-gather overhead per matmul would crush you. You shard at the unit level, where a "unit" is a meaningful submodule like a Transformer block.

import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# LlamaDecoderLayer is the model's transformer block class
# (e.g., from transformers.models.llama.modeling_llama).
policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)
model = FSDP(model, auto_wrap_policy=policy, ...)

This wraps each LlamaDecoderLayer as an FSDP unit. The whole model is also wrapped (the outer FSDP). Each unit holds one all-gather buffer and one reduce-scatter buffer.

6.2 The Per-Unit Dance

For each unit during forward:

  1. Pre-forward hook: all_gather params of this unit (BF16) into a flat buffer.
  2. Forward compute: standard matmul/attention.
  3. Post-forward hook: free the all-gathered params; only the local shard remains.

For each unit during backward:

  1. Pre-backward hook: all_gather params again (since they were freed).
  2. Backward compute: produces gradients in the all-gathered shape.
  3. Post-backward hook: reduce_scatter gradients across ranks; each rank ends with its shard's gradients. Free the all-gathered params.

After full backward, the optimizer steps on local shards. No final all-gather (next forward does it lazily).

6.3 Prefetching

The naive schedule has bubbles: you can't compute layer l+1 while waiting for its all-gather. Fix: prefetch.

FSDP(..., backward_prefetch=BackwardPrefetch.BACKWARD_PRE)

BACKWARD_PRE issues the all-gather of the previous unit (in backward order, so the next-to-be-computed unit) before the current unit's backward runs. This overlaps:

backward(l)         |==compute==|  
                              \
all_gather(l-1)              |==comm==|
backward(l-1)                          |==compute==|

BACKWARD_POST is the conservative variant; BACKWARD_PRE is what you want when memory allows (it temporarily holds two all-gather buffers).

There's also forward_prefetch=True for forward, similar idea.

6.4 Mixed Precision

from torch.distributed.fsdp import MixedPrecision

mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16,
)
FSDP(model, mixed_precision=mp_policy, ...)

  • param_dtype=BF16: all-gathered params are BF16; matmuls run in BF16.
  • reduce_dtype=FP32: gradients are reduce-scattered in FP32 to avoid catastrophic cancellation (small grads summed across many ranks). This costs 2× bandwidth on grads but stabilizes training.
  • Optimizer in FP32: ZeRO-style master weights, m, v all FP32 on the local shard.

6.5 Activation Checkpointing

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing, checkpoint_wrapper, CheckpointImpl
)
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=functools.partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    ),
    auto_wrap_policy=policy,
)

For each wrapped block, only the block's input is saved; intermediate activations are recomputed during backward. Memory: one B · S · h tensor per block instead of every intermediate inside each block (roughly an order of magnitude more, plus attention's S² terms). Compute: ~+33% (one extra forward).

NO_REENTRANT is the modern impl; supports arbitrary autograd graphs.

6.6 CPU Offload

FSDP(..., cpu_offload=CPUOffload(offload_params=True))

Param shards live on CPU; before all-gather, they're brought back to GPU. Use only when out of GPU memory; PCIe is slow.

6.7 The FlatParameter

Internally, FSDP flattens all params in a unit into a single 1D tensor. Reasons:

  • One all-gather per unit (not one per param tensor).
  • Predictable memory layout.
  • Easier sharding math: just split the flat tensor into N equal slices (with padding).

Optimizer sees this flat parameter; optim.step operates on the local slice. State_dict round-tripping un-flattens.

6.8 FSDP2 (per-parameter sharding)

The successor (in-tree as of PyTorch 2.4+) abandons FlatParameter in favor of per-parameter sharding via DTensor. Cleaner semantics, better composition with TP/PP, slightly higher dispatch cost. The collective behavior is the same.


7. Tensor Parallelism (Megatron-LM)

When a single layer's weight matrix is too big for one GPU, or when matmul throughput is the limit, split the weight matrix itself.

7.1 The Two Splits

A linear layer computes Y = X · W where X ∈ ℝ^{B×K}, W ∈ ℝ^{K×N}, Y ∈ ℝ^{B×N}.

Column-parallel. Split W along the output (column) dimension:

W = [W_1 | W_2 | … | W_t]    where W_i ∈ ℝ^{K × N/t}

Each rank i holds W_i and a copy of X. Computes Y_i = X · W_i locally, shape B × N/t. Output is sharded along columns.

  • Forward output: sharded-Y_i on rank i. To get full Y, concatenate: all_gather along last dim. But if the next op is row-parallel, no gather needed (see below).
  • Backward of input: ∂L/∂X = ∂L/∂Y · W^T. Each rank computes ∂L/∂X_i = ∂L/∂Y_i · W_i^T. To get full ∂L/∂X (for the previous layer that has full X), need all-reduce of ∂L/∂X.

Row-parallel. Split W along the input (row) dimension:

W = [W_1; W_2; …; W_t]   (stacked along rows) where W_i ∈ ℝ^{K/t × N}

Each rank holds W_i and a corresponding shard X_i ∈ ℝ^{B × K/t} of the input. Computes Y_i = X_i · W_i ∈ ℝ^{B × N} (full output dim, but a partial sum).

  • Forward output: full shape but partial sum. Need all-reduce to get the true Y = ∑_i Y_i.
  • Backward of input: ∂L/∂X_i = ∂L/∂Y · W_i^T. Each rank already has ∂L/∂Y (because forward all-reduced it). Output ∂L/∂X_i is sharded-exactly the input format the previous (column-parallel) layer expects. No comm.

7.2 The Megatron Pattern

The genius of Megatron-LM: chain a column-parallel followed by a row-parallel so that the sharded output of the first is exactly the sharded input of the second. The all-gather between them is avoided.

MLP block (h → 4h → h):

Y = GeLU(X · W_1) · W_2

W_1: column-parallel (split 4h along output)
   → Each rank computes X · W_1_i, shape B×4h/t. NO all-gather.
GeLU: elementwise, sharded fine.
W_2: row-parallel (split 4h along input)
   → Each rank has its shard of GeLU output, computes (GeLU_i) · W_2_i.
   → All-reduce the result to get full Y.

Forward comm: one all-reduce at the end of MLP. Backward comm: one all-reduce at the start of MLP (input gradient).
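A sketch of this pattern with raw collectives (assumes dist is initialized; w1_shard (h × 4h/t) and w2_shard (4h/t × h) are this rank's pre-sharded weights, and tp_group is the TP process group-all placeholders):

import torch
import torch.distributed as dist
import torch.nn.functional as F

def megatron_mlp(x, w1_shard, w2_shard, tp_group):
    # Column-parallel: each rank computes its slice of the 4h dim. No comm.
    h1 = F.gelu(x @ w1_shard)        # (B, S, 4h/t)
    # Row-parallel: full output dim, but a partial sum over the sharded inner dim.
    y = h1 @ w2_shard                # (B, S, h), partial
    # One all-reduce completes the block's forward.
    dist.all_reduce(y, op=dist.ReduceOp.SUM, group=tp_group)
    return y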

Self-attention block. With a heads, head dim d_h, hidden h = a · d_h:

Q = X · W_Q,  K = X · W_K,  V = X · W_V       (h → h)
A = softmax(Q K^T / √d_h) · V                  (per-head)
Y = A · W_O                                    (h → h)

W_Q, W_K, W_V: column-parallel. Each rank gets a/t heads.
   → Q_i, K_i, V_i shape B×S×(h/t).
   → Per-head attention is local-no comm.
W_O: row-parallel.
   → All-reduce at end.

Forward comm: one all-reduce at end of attention. Backward comm: one all-reduce at start of attention.

7.3 Total Comm per Transformer Block

Per block (attention + MLP), TP communication is:

  • 2 all-reduces in forward (one at end of attention, one at end of MLP).
  • 2 all-reduces in backward (one at start of MLP's backward = ∂L/∂X for X feeding MLP, one at start of attention's backward).

4 all-reduces per block per step.

Each all-reduce is on activations of size B · S · h BF16 = 2BSh bytes. With ring all-reduce on t ranks: 2(t-1)/t · 2BSh bytes per all-reduce.

For a Transformer with L blocks, total TP comm per step:

T_TP_comm ≈ 4L · 2(t-1)/t · 2BSh / β

For Llama-3 70B (h=8192, L=80), B=4, S=8192, t=8: 4·80 · 1.75 · 2·4·8192·8192 bytes ≈ 4·80·1.75·512 MB ≈ 287 GB per step crossing the TP fabric. On 600 GB/s NVLink: ~480 ms per step. This is why TP wants intra-node only: cross-node TP at 50 GB/s (IB) would be 10× slower.

In summary: TP's comm is on the critical path of every layer. It cannot overlap with compute (the next layer can't start until this all-reduce finishes-well, except with sequence parallelism tricks). So bandwidth matters, immensely.

7.4 Placement Rule of Thumb

TP degree ≤ GPUs per node. On an 8×H100 DGX, TP=8 fits one node. TP=16 across two nodes through IB → ruinous.

7.5 Embedding TP

Vocabulary is large (e.g., 128K tokens × 8K hidden = 1B params just for embedding). TP splits along vocab:

  • Each rank holds V/t rows of the embedding table.
  • Embedding lookup: each rank looks up the rows for tokens in its slice, zeros elsewhere; all-reduce.
  • Output projection (the LM head, often tied to embedding): row-parallel; same structure.

The softmax over vocab requires care: subtract per-row max before exp; the per-row max is computed via all-reduce(max).

7.6 Micro-example

A row-parallel W: K=8192 → N=8192, t=8. Each rank holds W_i: 1024 × 8192. Forward:

  • Input shard X_i: B × 1024 (output of the previous column-parallel layer).
  • Local matmul: Y_i = X_i · W_i, shape B × 8192.
  • All-reduce: Y = ∑_i Y_i, shape B × 8192.

Each rank's matmul costs 2 · B · 1024 · 8192 FLOPs (the full 2 · B · 8192 · 8192 divided by t=8). Each rank's all-reduce sends 2(t-1)/t · 2 · B · 8192 bytes (BF16 output). Compare compute time vs comm time to assess overlap potential.


8. Pipeline Parallelism

Splits the model along the depth axis. Layers 1..L/4 on stage 0, L/4+1..L/2 on stage 1, etc.

8.1 Naive Pipeline

Stage 0: forward layers 1..L/P → send activations to stage 1 → wait → receive grads from stage 1 → backward → done. Each stage waits for the others; only one stage active at a time. Utilization: 1/P. Useless without microbatching.

8.2 GPipe (Huang et al., 2018)

Split each minibatch into M microbatches. Pipeline them: stage 0 starts microbatch 1, then microbatch 2, then 3, etc. After all forwards complete, run backwards in reverse.

ASCII schedule for P=4, M=4:

time →
S0: F1 F2 F3 F4 .  .  .  B4 B3 B2 B1
S1: .  F1 F2 F3 F4 .  .  .  B4 B3 B2 B1
S2: .  .  F1 F2 F3 F4 .  .  .  B4 B3 B2 B1
S3: .  .  .  F1 F2 F3 F4 B4 B3 B2 B1 .  .  .

The leading and trailing dots are bubbles: idle time. With P stages and M microbatches:

  • Total useful time per stage: M · t_F + M · t_B = M · (t_F + t_B).
  • Total wall time: (M + P - 1) · (t_F + t_B).
  • Bubble fraction: (P-1) / (M + P - 1).

Bubble = (P-1)/(M+P-1). For P=4, M=4: 3/7 ≈ 43%. For M=16: 3/19 ≈ 16%. For M=64: 3/67 ≈ 4.5%. Make M large.

GPipe memory issue: stage 0 must keep activations from M microbatches alive (it does forward 1..M before any backward). Activation memory: O(M · per_microbatch_act). Use activation checkpointing to bring it down.

8.3 1F1B (PipeDream)

Interleave forwards and backwards: as soon as a microbatch reaches the last stage, start its backward immediately. Backward propagates back; meanwhile, new forwards continue feeding in. Each stage maintains the invariant that it does one forward then one backward (1F1B).

ASCII for P=4, M=8:

S0: F1 F2 F3 F4 F5 B1 F6 B2 F7 B3 F8 B4 .. B5 B6 B7 B8
S1: .  F1 F2 F3 F4 F5 B1 F6 B2 F7 B3 F8 B4 B5 B6 B7 B8
S2: .  .  F1 F2 F3 F4 F5 B1 F6 B2 F7 B3 F8 B4 B5 B6 B7 B8
S3: .  .  .  F1 F2 B1 F3 B2 F4 B3 F5 B4 F6 B5 F7 B6 F8 B7 B8

Bubble fraction is the same: (P-1)/(M+P-1). But peak activation memory per stage drops dramatically. Stage 0 only holds activations for in-flight microbatches at any moment ≈ P (warm-up forwards before first backward). Memory: O(P) not O(M).

This is why every modern pipeline uses 1F1B-style scheduling.

8.4 Interleaved 1F1B (Megatron-LM)

Idea: each stage owns multiple non-contiguous chunks of layers. With v "virtual stages" per physical stage:

  • Layers 1..L/(P·v) on stage 0's chunk 1.
  • Layers L/(P·v)+1..2L/(P·v) on stage 1's chunk 1.
  • ...
  • Layers L/v + 1..L/v + L/(P·v) on stage 0's chunk 2.
  • ...

Each microbatch traverses P · v chunks (with comm between every chunk).

The bubble shrinks: (P-1)/(v · M + P - 1). With v=4, M=8, P=4: 3/(32+3) = 8.6% vs 3/11 = 27% for vanilla 1F1B. Cost: more comm (more sends/recvs between chunks), more scheduling complexity.

8.5 Zero Bubble Pipeline (Qi et al., 2023)

Split the backward into two parts:

  • B-input (B_X): compute ∂L/∂X (gradient with respect to layer input)-needed by previous stage.
  • B-weight (B_W): compute ∂L/∂W (gradient with respect to weights)-only needed locally for optimizer.

B_W doesn't block the previous stage. Schedule B_X as soon as possible; defer B_W to fill bubbles. With careful scheduling, achievable bubble approaches zero.

Tradeoff: scheduling complexity increases; some implementations require equal t_F = t_{B_X} and divisible structure.

8.6 Memory and Throughput Comparison

Per-stage activation memory (with M microbatches, P stages):

Schedule Bubble Activation memory per stage
Naive (P-1)/P 1 microbatch
GPipe (P-1)/(M+P-1) M microbatches
1F1B (P-1)/(M+P-1) ~P microbatches
Interleaved 1F1B (v) (P-1)/(vM+P-1) ~P/v microbatches
Zero Bubble ~0 similar to 1F1B
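The bubble column in code, for quick what-if checks:

def bubble(P, M, v=1):
    """Bubble fraction (P-1)/(v·M + P - 1); v=1 covers GPipe and 1F1B."""
    return (P - 1) / (v * M + P - 1)

print(bubble(4, 4))       # 0.43  (GPipe/1F1B, P=4, M=4)
print(bubble(4, 16))      # 0.16
print(bubble(4, 8, v=4))  # 0.086 (interleaved, v=4)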

8.7 PP Communication

Between adjacent stages: send/recv of activations (forward) and gradients (backward). Per microbatch, per stage boundary: 2 · B_micro · S · h bytes each way (BF16 activations forward, BF16 gradients backward). Unlike TP, this is a point-to-point call, not a collective. Scales fine across IB. Pipeline parallelism is the right axis for crossing nodes.

8.8 PP Caveats

  • LayerNorm/RMSNorm statistics are local-fine.
  • Loss computation happens on the last stage; it must send the loss scalar back if you want it everywhere.
  • Optimizer state is local to each stage-no comm.
  • Pipeline imbalance: if stages have unequal compute (e.g., embeddings on stage 0 are cheap, last stage has LM head + loss), the slowest stage sets pace. Solution: assign more layers to faster stages.

9. 3D Parallelism: DP × TP × PP

9.1 The Decision Matrix

Given a model with parameters Φ, layers L, hidden h, and a cluster of G GPUs in nodes of g_n GPUs each, choose (t, p, d) such that t · p · d = G:

  • t = TP degree.
  • p = PP degree.
  • d = DP (or FSDP) degree.

Constraints:

  1. t ≤ g_n-TP must stay intra-node (NVLink/NVSwitch).
  2. p divides L-pipeline stages need integer layer counts.
  3. Per-GPU memory must fit: parameters + activations + optimizer state.
  4. Global batch size constraints: d × micro_batch × M_microbatches = total_batch. Global batch determines convergence; can't grow without bound.

9.2 Decision Heuristic

  1. Compute params per stage: Φ / (t · p) (TP shards within stage; PP shards across stages).
  2. Compute static memory per GPU: 16 · Φ / (t · p) for DP-replicated optimizer, or 16 · Φ / (t · p · d) for FSDP within DP (i.e., shard within the DP dim).
  3. Add activation memory per GPU. With activation checkpointing and sequence parallelism, this is much smaller.
  4. Pick smallest t that fits, then smallest p, then d fills the rest.

Why prefer small t and p?

  • TP comm is on the critical path of every layer.
  • PP creates bubbles unless you can make M huge (which inflates batch).
  • DP comm overlaps cleanly with backward.

So the priority is: maximize DP, use TP only as needed for memory/compute, use PP only when a single layer is too big for a node-level group, OR when TP+DP can't fit the model.
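Step 2 of the heuristic as a one-liner (a sketch; 16 bytes/param for mixed-precision Adam as in Section 5, phi_b in billions):

def static_gb_per_gpu(phi_b, t, p, d, fsdp=True):
    """Static model-state memory per GPU in GB for a (t, p, d) layout."""
    return 16 * phi_b / (t * p * (d if fsdp else 1))

print(static_gb_per_gpu(8, 1, 1, 64))    # 2.0  -- Example A below
print(static_gb_per_gpu(405, 8, 8, 16))  # ~6.3 -- Example C below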

9.3 Worked Examples

Example A: 8B model on 64 H100 (8 nodes × 8 GPUs)

Φ = 8B. Static memory at 16Φ = 128 GB.

Try t=1, p=1, d=64 with FSDP-3:

  • Per-GPU static: 16Φ/64 = 2 GB. Easy fit.
  • Activations: with B_micro=4, S=8192, BF16, ~10 GB. Plenty of room.
  • TP comm: none. PP comm: none. Just FSDP all-gather/reduce-scatter.
  • Pick this. Simple is fast.

Example B: 70B model on 64 H100

Φ = 70B. Static at 16Φ = 1.12 TB.

Try t=1, p=1, d=64 FSDP-3: per-GPU = 17.5 GB. Fits, but barely (activations + workspace might push it over).

Try t=8, p=1, d=8 (TP within node, DP across nodes):

  • Per-GPU params (sharded by TP, replicated across DP) = 2Φ/t = 2·70/8 = 17.5 GB BF16.
  • With FSDP across DP: divide by 8 more → 2.2 GB.
  • TP comm hot inside each node, DP comm reasonable.
  • Likely faster than pure FSDP-3 because activations are sharded too (with sequence parallelism) and matmuls are bigger per rank.

Try t=8, p=2, d=4:

  • Stage holds 35B. Per-GPU params = 2·35/8 = 8.75 GB.
  • Activations on stage 0 have to live for ~p=2 microbatches → smaller pressure.
  • Bubble (P-1)/(M+P-1) = 1/(M+1); with M=8, ~11% bubble. Not great.
  • Probably not worth it for 70B; pick the previous.

Example C: 405B model on 1024 H100 (128 nodes × 8 GPUs)

Φ = 405B. Static at 16Φ = 6.48 TB. Need ~6.5 TB / 80 GB = ~81 GPUs minimum just for state, but activation+workspace really demands more.

Try t=8, p=8, d=16:

  • Stage holds Φ/p = 50.6B params.
  • Per-GPU params = 2·50.6 / 8 = 12.6 GB.
  • With FSDP across the d=16 dim: 12.6 / 16 = 0.79 GB. Good.
  • Activation memory per stage with checkpointing + SP: typically 5-15 GB per GPU.
  • TP comm intra-node. PP comm node-to-node (reasonable, ~few GB per microbatch). DP/FSDP comm across the d=16 dim (across more nodes).
  • 8 stages → bubble (P-1)/(M+P-1). With M=64, bubble = 7/71 ≈ 10%. Acceptable.

This is roughly the published configuration shape for Llama 3.1 405B.

9.4 The Order of Ranks

A subtle but critical point: rank assignment order matters. Convention (Megatron):

rank = (pp_rank · tp_size · dp_size) + (dp_rank · tp_size) + tp_rank

So the innermost dimension is TP. This places TP groups within consecutive ranks, which on a node maps to neighbors connected by NVLink. DP comes next, then PP outermost. Get this wrong and you get TP all-reduces over IB → catastrophe.

torch.distributed.new_group lets you create the three groups (TP, DP, PP) with explicit rank lists.
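A sketch of that construction (build_groups is a hypothetical helper; every rank must execute every new_group call, in the same order, and world_size must equal t·d·p):

import torch.distributed as dist

def build_groups(t, d, p):
    rank = dist.get_rank()
    tp_group = dp_group = pp_group = None
    # TP groups: consecutive ranks (innermost dimension).
    for pp in range(p):
        for dp in range(d):
            ranks = [pp * d * t + dp * t + i for i in range(t)]
            g = dist.new_group(ranks)
            if rank in ranks: tp_group = g
    # DP groups: fixed (tp_rank, pp_rank), varying dp_rank.
    for pp in range(p):
        for tp in range(t):
            ranks = [pp * d * t + dp * t + tp for dp in range(d)]
            g = dist.new_group(ranks)
            if rank in ranks: dp_group = g
    # PP groups: fixed (tp_rank, dp_rank), varying pp_rank (outermost).
    for dp in range(d):
        for tp in range(t):
            ranks = [pp * d * t + dp * t + tp for pp in range(p)]
            g = dist.new_group(ranks)
            if rank in ranks: pp_group = g
    return tp_group, dp_group, pp_group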


10. Sequence Parallelism (Korthikanti et al., 2022)

10.1 The Problem

Even with TP, certain ops are not parallelized: LayerNorm, dropout, residual adds. These keep the full activation tensor of shape B × S × h on every TP rank. As S grows (long context training), this dominates memory.

10.2 The Idea

Split these ops along the sequence dimension. Each TP rank holds B × S/t × h for LayerNorm/Dropout. The activation memory drops by t.

But TP linears need the full sequence (matmul is in (B·S) × h). So at the boundary: all-gather to convert from sequence-parallel (B × S/t × h) to TP-parallel (B × S × h) on each rank, do the column-parallel linear, etc.

The all-gather replaces the all-reduce that was at the start of the column-parallel block. Comm volume: same total (all-reduce = reduce-scatter + all-gather; SP just moves the boundary). But now the LayerNorm/dropout/residual sit in the cheap SP region.

Net: same comm, much less activation memory. That's the deal.

10.3 Where Each Lives

[SP region]       LayerNorm    | activations B × S/t × h (sharded along S)
   ↓ all-gather to TP region
[TP region]       Linear (column-parallel) → Linear (row-parallel)
   ↓ reduce-scatter back to SP region
[SP region]       Residual + LayerNorm

reduce_scatter at the TP→SP boundary does double duty: aggregates the row-parallel partial sum and shards along sequence in one op.

10.4 Synergy with Long Context

Without SP, the attention activation term (∝ 5 · a · B · S² bytes per layer, the 5as/h piece of the per-layer activation footprint) grows quadratically with S. With SP and activation checkpointing, S=128K becomes feasible on commodity-cluster sizes.

Modern LLM training stacks (Megatron-LM, NeMo, Pax, MaxText) ship SP on by default for any TP > 1.


11. Mixed Precision and FP8

11.1 Why Lower Precision

Compute throughput on H100:

  • FP32 (TF32): 989 TFLOPS
  • BF16/FP16: 1979 TFLOPS
  • FP8: 3958 TFLOPS

Memory and bandwidth scale similarly. The training imperative: keep math in low precision; only escalate to FP32 where needed for stability.

11.2 FP16 with Loss Scaling

FP16 dynamic range is small: ~6e-5 to 6.5e4. Many gradients are below 6e-5 → underflow to zero → no learning.

Trick: loss scaling. Multiply the loss by s before backward; gradients are s × larger and fit in FP16 range. Before the optimizer step, divide gradients by s.

Dynamic loss scaling (the production version):

s = initial_scale (e.g., 2^16)
on each step:
    grads = backward(s · loss)
    if any(isnan or isinf in grads):
        s = s / 2          ← overflow: halve the scale, skip step
    else:
        grads = grads / s
        optimizer.step()
        if no overflow for `growth_interval` steps (e.g., 2000):
            s = s · 2       ← grow when stable

This keeps s near the maximum that doesn't overflow. PyTorch's GradScaler implements exactly this.
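For reference, the same loop with the built-in scaler (model, loader, loss_fn, opt are placeholders):

import torch

scaler = torch.cuda.amp.GradScaler(init_scale=2**16, growth_interval=2000)
for batch in loader:                         # placeholder dataloader
    opt.zero_grad(set_to_none=True)
    with torch.autocast("cuda", dtype=torch.float16):
        loss = loss_fn(model(batch))         # placeholder model/loss
    scaler.scale(loss).backward()            # backward on s * loss
    scaler.step(opt)                         # unscales grads; skips step on inf/nan
    scaler.update()                          # halve on overflow, grow when stable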

11.3 BF16

BF16 has FP32's dynamic range (8 exponent bits) but only 7 mantissa bits. For gradient magnitude purposes, equivalent to FP32. No loss scaling needed.

The penalty is reduced precision in the matmul accumulator, but H100 accumulates BF16 matmuls in FP32 internally. Net effect: free dynamic range.

This is why every modern training pipeline (Llama, Mistral, GPT-4 leaks, etc.) uses BF16 not FP16.

11.4 FP8

Two FP8 formats:

  • E4M3 (4 exponent bits, 3 mantissa bits): range ~±448. Higher precision, lower range. Use for forward activations and weights.
  • E5M2 (5 exponent bits, 2 mantissa bits): range ~±57344. Lower precision, higher range. Use for backward gradients (which need range).

Standard recipe (NVIDIA TransformerEngine):

forward: weights and activations cast to E4M3 (per-tensor scaled).
matmul: BF16 accumulator, output dequantized to BF16.
backward: incoming gradient cast to E5M2.
optimizer: FP32 master weights, FP32 m, v.

11.5 Per-Tensor Scaling

Each tensor T gets a scale s_T. The FP8 representation is T_fp8 = round(T / s_T). To dequantize: T_bf16 = T_fp8 · s_T. The scale is updated each step using the amax (maximum absolute value) of the previous N steps.

s = amax / fp8_max_representable, with a safety margin.

TransformerEngine maintains an amax history per tensor (forward weights, forward activations, backward grads, etc.); the scale is the EMA or window-max over the history.
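An illustrative numerics-only sketch of delayed per-tensor scaling (assumes the torch.float8_e4m3fn dtype from recent PyTorch; real stacks use TransformerEngine's fused kernels, not this):

import torch

E4M3_MAX = 448.0

def quant_e4m3(t, amax_window):
    amax = max(amax_window)                       # delayed: amaxes from past steps
    scale = amax / E4M3_MAX                       # s = amax / fp8_max_representable
    t_fp8 = (t / scale).to(torch.float8_e4m3fn)   # T_fp8 = round(T / s)
    return t_fp8, scale

def dequant(t_fp8, scale):
    return t_fp8.to(torch.bfloat16) * scale       # T ≈ T_fp8 · s

x = torch.randn(4096, 4096) * 3.0
window = [x.abs().max().item()]                   # amax history (window of N steps)
x8, s = quant_e4m3(x, window)
print((dequant(x8, s).float() - x).abs().max())   # quantization error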

11.6 Engineering Notes

  • FP8 wins compute-bound layers (large matmuls). Layers that are bandwidth-bound or have small compute don't benefit.
  • Loss curves with FP8 should track BF16 to within noise. If they diverge, scaling is wrong.
  • delayed scaling (using last step's amax) is faster than current scaling (compute amax during this step's pass) but slightly less stable.

12. Verifying Communication Overlap

A correctly implemented DDP/FSDP overlaps gradient comm with backward compute. Verifying this is a profiling exercise.

12.1 What "Overlap" Looks Like in a Trace

In torch.profiler or NSight, you'll see two streams active simultaneously:

  • Compute stream: matmuls, attention.
  • NCCL stream: all-reduce/all-gather/reduce-scatter kernels.

A healthy backward looks like:

Compute: |MM L80|MM L79|MM L78|MM L77|...
NCCL:           |AR  L80   |AR L79   |...

Each all-reduce kernel runs on the NCCL stream while the next matmul runs on compute. Total wall time ≈ max(compute, comm), not sum.

12.2 Failure Modes

No overlap. Compute and NCCL are serialized. Causes:

  • Bucket too small → all-reduce of a tiny bucket is latency-dominated, no compute to hide it.
  • find_unused_parameters=True forcing extra sync.
  • The all-gather in FSDP-3 must complete before forward proceeds (this is intentional-there's no way to hide it without prefetch).

Negative overlap. Comm kernel slows compute kernel because they share SMs (NCCL kernels use SMs to copy data). On H100, NCCL has dedicated copy engines for some collectives → less interference.

12.3 FSDP Prefetch in the Trace

With BACKWARD_PRE, you should see:

backward(L80):  |compute|
all_gather(L79):           |comm   |    ← issued before backward(L79)
backward(L79):                       |compute|
all_gather(L78):                              |comm |

The all-gather of unit l-1 runs while unit l's backward runs. Without BACKWARD_PRE, the all-gather of l-1 only starts after unit l's backward finishes → critical path doubled.

12.4 Measuring with Numbers

Run two configs:

  1. Real training: measure step time T_step.
  2. Training with communication replaced by a no-op (e.g., a --bench-no-comm benchmark flag, if your harness provides one): measure step time T_compute.

Overlap quality: T_step / T_compute. Ideal: 1.0. Practical: 1.05–1.15 for a healthy DDP/FSDP setup.


13. Fault Tolerance

A 1024-GPU job over a week has many opportunities to fail. Hardware errors, NCCL hangs, ECC on a HBM page, OOM from a long sequence-any one kills the job.

13.1 NCCL Timeouts

NCCL has a watchdog. If a collective hasn't completed within the configured timeout (set via PyTorch's dist.init_process_group(timeout=...); the default for NCCL is 30 min), the rank aborts.

Causes:

  • One rank deadlocked (e.g., infinite loop in dataloader).
  • One rank slower (bad GPU; watch for GPU temperature throttling).
  • Barriers mismatched: rank 0 calling all_reduce(x) while rank 1 calls all_reduce(y).

Detection:

  • Set TORCH_NCCL_ASYNC_ERROR_HANDLING=1: throws a Python exception on timeout instead of segfaulting.
  • TORCH_NCCL_DESYNC_DEBUG=1: dumps per-rank state when something hangs.
  • Monitor with dcgmi for GPU faults.

13.2 Checkpointing

Full checkpoint. Gather the full model and optimizer state to rank 0; save one big file. Simple, but: huge memory spike (rank 0 needs RAM for the full model), slow at scale.

Sharded checkpoint (FSDP, DCP). Each rank saves its shard; metadata describes how to reassemble. Modern PyTorch: torch.distributed.checkpoint.save and load. Allows resharding (load a 64-rank checkpoint onto 128 ranks; reshape happens automatically).
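A minimal sketch of the DCP path (PyTorch 2.2+-style API; the path and dict keys are placeholders, and FSDP models typically go through the torch.distributed.checkpoint state-dict helpers first):

import torch.distributed.checkpoint as dcp

state = {"model": model.state_dict(), "optim": optimizer.state_dict()}
dcp.save(state, checkpoint_id="/ckpt/step_1000")   # each rank writes its shard

# On resume (possibly with a different world size): load in place, then apply.
state = {"model": model.state_dict(), "optim": optimizer.state_dict()}
dcp.load(state, checkpoint_id="/ckpt/step_1000")   # reshards automatically
model.load_state_dict(state["model"])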

Cadence. Every N steps; N chosen so that wasted-compute-on-failure is acceptable. With ~5% failure rate per day, every 30-60 min is reasonable. Async checkpointing (writing in background while training continues) helps.

Atomic writes. Write to ckpt.tmp then rename. If a node crashes mid-write, no partial corruption.
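The rename trick in code (os.replace is atomic within a single filesystem):

import os
import torch

def atomic_save(obj, path):
    tmp = path + ".tmp"
    torch.save(obj, tmp)
    os.replace(tmp, path)   # a crash mid-write leaves only the stale .tmp behind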

13.3 Elastic Training

Cluster topology changes (nodes go away). Solutions:

  • torchrun --nnodes=2:8: rendezvous backend allows joining/leaving.
  • On rank loss: surviving ranks re-form the process group. Resume from latest checkpoint with the new world size (sharded checkpoints make this resharding transparent).

In practice: most teams just use static configs and re-launch on failure. Elastic is mostly for ML-as-a-service platforms.

13.4 Determinism Considerations

NCCL is deterministic run-to-run (same inputs, same topology → same outputs), but floating-point all-reduce is not associative: different ring traversal orders yield slightly different sums. With fixed seeds and deterministic ops, training is reproducible to the bit on the same world size-but generally not on a different world size.

For deterministic training: fix the world size, use deterministic ops, and accept that bit-determinism does not survive changes in scale.


14. Cluster Topology

Performance depends on physical wiring. Here's what each layer does and why it matters.

14.1 Inside the Node

NVLink (GPU↔GPU links within a node). On H100: 18 NVLink-4 links per GPU at 25 GB/s each → 450 GB/s per direction, 900 GB/s bidirectional aggregate (for the NVLink fabric, though per-pair bandwidth is lower).

NVSwitch (the chip that connects NVLinks). DGX H100: 4× NVSwitch chips form a non-blocking fabric: any GPU can talk to any other at full NVLink speed. Without NVSwitch (e.g., HGX boards with direct NVLink), GPU 0 ↔ GPU 1 might be 600 GB/s but GPU 0 ↔ GPU 5 might be 300 GB/s (one hop).

NVLink SHARP (NVLS). NVSwitch v3 (H100) does in-network reduction: ranks send data to switch, switch reduces, sends result back. Halves bandwidth requirement on the link (instead of 2(N-1)/N · M, more like (N-1)/N · M).

PCIe. CPU↔GPU and GPU↔NIC paths. Gen5 x16 = 64 GB/s. Bottleneck for CPU offload (ZeRO-Offload).

14.2 Between Nodes

InfiniBand (or RoCE). H100 generation: NDR (400 Gb/s = 50 GB/s) per port. DGX H100 has 8× NDR cards: one per GPU (rail-aligned).

Rail topology. Each GPU has its own NIC. The cluster fabric is structured as R rails, where rail r connects only NIC r of each node to a top-of-rack switch. Cross-rail traffic must go through a higher-tier switch. NCCL, properly configured, keeps each rank's traffic on its own rail.

Fat tree / spine-leaf. Typical large-cluster topology: leaf switches connect ~32 nodes (= 256 GPUs); spine switches connect leaves. Bisection bandwidth determined by spine fan-out.

Hop count. Same-node: 0 hops (NVSwitch). Same-rack: 1 hop (leaf). Cross-rack: 3 hops (leaf-spine-leaf). Each hop adds ~1µs of latency and shares bandwidth with other tenants.

14.3 GPU↔NIC Affinity

nvidia-smi topo -m output:

        GPU0   GPU1   GPU2   GPU3   GPU4   GPU5   GPU6   GPU7   NIC0  NIC1  NIC2  NIC3  CPU
GPU0     X    NV12   NV12   NV12   NV12   NV12   NV12   NV12   PXB   SYS   SYS   SYS   0
NIC0    PXB   SYS    SYS    SYS    SYS    SYS    SYS    SYS    X     ...

Legend: NV12 = 12 NVLinks. PXB = PCIe through bridge (good). SYS = through CPU socket (bad).

For NCCL: ensure NCCL_IB_HCA lists, in order, the NICs aligned with each GPU rank (mlx5_0 for GPU 0, etc.).

14.4 Why Topology Is in This Chapter

Algorithm choice depends on topology. Ring all-reduce on NVLink: optimal. Ring all-reduce that crosses spine switches every step: catastrophic (latency dominated, contention with other tenants).

Decision: TP within node, DP/PP across nodes. NCCL's auto-topology gets this right most of the time, but verify.


15. Practical Exercises

Six worked problems. Work through them yourself before reading the answers.

Exercise 1: All-reduce Bandwidth

You have 8 GPUs in a ring at 600 GB/s NVLink each. You all-reduce a tensor of 14 GB (BF16 gradients of a 7B model). How long does the all-reduce take, ignoring latency?

Answer. Ring all-reduce sends 2(N-1)/N · M = 2·7/8 · 14 = 24.5 GB per rank. At 600 GB/s: 24.5 / 600 = 40.8 ms.

If, instead of ring, NCCL chose tree (suboptimal here): 2 log₂(8) · M = 6 · 14 = 84 GB per rank → 140 ms. Big difference; this is why NCCL picks ring on intra-node.

Exercise 2: Peak Memory for FSDP + Activation Checkpointing

Compute peak per-GPU memory for a 70B model under FSDP-3 + activation checkpointing on 64 H100 80GB GPUs. Assume BF16 params/grads, FP32 optimizer (Adam, K=12), batch 4 per GPU, sequence 8192, hidden 8192, 80 layers, BF16 activations. Use the simplified bound: with checkpointing, per-layer activation ~ B · S · h · 12 bytes (rough).

Answer.

  • Static (FSDP-3): 16Φ / N = 16 · 70 / 64 = 17.5 GB.
  • Activation per layer (post-checkpoint, rough 12-byte bound): 4 · 8192 · 8192 · 12 ≈ 3.2 GB. With one stored input per layer × 80 layers: ~256 GB. This is way too much if each layer's input is stored at that size.

Recheck: with selective activation checkpointing, you store only the input of each transformer block (one tensor per block). That's 4 · 8192 · 8192 · 2 = 537 MB per layer × 80 = 43 GB. Plus the recompute buffer for the layer currently being recomputed (transient).

So total: 17.5 (state) + 43 (stored block inputs) + ~10 (workspace, kv cache, recompute buffer) ≈ 70 GB. Tight on 80 GB. Solutions: increase N, decrease batch, add TP=2.

Exercise 3: Pipeline Bubble

You run GPipe with P=8 stages. What's the minimum number of microbatches to keep bubble below 5%?

Answer. Bubble = (P-1)/(M+P-1) ≤ 0.05. Solve: (M + 7) ≥ 7/0.05 = 140. So M ≥ 133. With per-microbatch batch of 1 and 133 microbatches, you need batch ≥ 133 per DP replica. May not be feasible if your global batch budget is small.

With 1F1B, same bubble. With interleaved 1F1B (v=4): 7/(4M+7) ≤ 0.05 → M ≥ 33.25, so M ≥ 34. Much better.

Exercise 4: TP Communication Cost

A Transformer block with h=8192, B=4, S=8192, runs with t=8 TP. Compute total bytes communicated per block per step (forward + backward).

Answer. 4 all-reduces per block × 2(t-1)/t · 2BSh bytes per all-reduce.

2BSh = 2 · 4 · 8192 · 8192 = 512 MB.

Per all-reduce: 1.75 · 512 = 896 MB. Times 4: 3.58 GB per block per step.

With L=80 blocks: ~287 GB per step of TP comm. On 600 GB/s NVLink: ~480 ms. This is why TP comm dominates and why you want overlap (sequence parallelism + compute).

Exercise 5: 3D Parallel Sizing

Llama-3 405B on 1024 H100 80GB. Choose (t, p, d). Show your reasoning.

Answer.

  • t = 8 (max, intra-node).
  • Try p = 8. Stage params = 405/8 = 50.6B. Per-GPU after TP: 2 · 50.6 / 8 = 12.6 GB (BF16).
  • d = 1024 / (8·8) = 16. With ZeRO/FSDP across DP: per-GPU state = 16Φ/(t·p·d) = 16 · 405 / 1024 ≈ 6.3 GB. Fits comfortably; activations + workspace ~30 GB; total ~36 GB; within 80 GB.
  • With PP=8 and global batch ~1024 sequences, microbatch=8 per stage × 16 DP replicas = 128 microbatches. Bubble = 7/(128+7) = 5.2%. With interleaved 1F1B v=4: 7/(512+7) ≈ 1.3%. Excellent.

(t=8, p=8, d=16) or thereabouts. Real systems may vary p based on per-stage compute and act memory.

Exercise 6: ZeRO Stage Tradeoff

Your model is 13B; you have 16 A100 40GB GPUs. Standard mixed-precision Adam. Which ZeRO stage is the minimum that fits?

Answer.

  • DDP: 16Φ = 16·13 = 208 GB per rank - won't fit.
  • ZeRO-1: 4Φ + 12Φ/N = 52 + 9.75 = 61.75 GB - won't fit on 40 GB.
  • ZeRO-2: 2Φ + 14Φ/N = 26 + 11.4 = 37.4 GB - borderline; with activations and workspace, will OOM.
  • ZeRO-3: 16Φ/N = 13 GB - fits with room for activations.

ZeRO-3 / FSDP-3 is required. With activation checkpointing and a moderate batch, you'll stay around 30-35 GB peak.


Concluding Synthesis

The whole apparatus is one idea repeated at three resolutions:

  1. Memory is the binding constraint. Single-GPU memory grows ~2× per generation; model size grows ~10× per year. The gap is what distributed training closes.

  2. Communication is the price. Every form of partitioning (DP grad sync, ZeRO param gather, TP activation all-reduce, PP activation send) trades memory for bandwidth.

  3. Topology determines algorithm. Ring is bandwidth-optimal but latency-linear. Tree is latency-logarithmic but bandwidth-suboptimal. Inside a node (NVLink), bandwidth is plentiful → ring; across nodes (IB), pick based on message size and tier.

The mature configuration for a frontier model:

  • TP = node size (e.g., 8), so TP all-reduces are NVLink-confined.
  • PP across nodes, so cross-node traffic is point-to-point (cheaper than all-reduce).
  • DP fills the rest, with FSDP/ZeRO-3 to shard state within each DP rank.
  • Sequence parallel on by default for any TP > 1, to keep activation memory in check.
  • Activation checkpointing to make long context tractable.
  • BF16 + FP8 in the matmul-heavy layers; FP32 master weights.

When you can derive each line of this recipe from first principles-why TP wants NVLink, why FSDP-3 needs prefetch, why PP wants 1F1B, why ZeRO trades 1.5× comm for N× memory-you have mastered distributed training.


Appendix: Notation

  • Φ - number of model parameters.
  • N - DP world size (or generic rank count for collectives).
  • t, p, d - TP, PP, DP degrees.
  • L - number of Transformer layers.
  • h - hidden dimension.
  • S - sequence length.
  • B - batch size.
  • M - number of microbatches (PP), or message size in bytes (collectives).
  • α - link latency (s).
  • β - link bandwidth (B/s).
  • K - Adam optimizer state multiplier (= 12 for FP32 m, v, master weights).

Appendix: Paper References (for further reading; the math above is self-contained)

  • Patarasuk & Yuan, "Bandwidth Optimal All-reduce Algorithms for Clusters of Workstations," 2009.
  • Rajbhandari et al., "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models," SC 2020.
  • Shoeybi et al., "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism," 2019.
  • Narayanan et al., "Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM," SC 2021.
  • Huang et al., "GPipe: Efficient Training of Giant Neural Networks Using Pipeline Parallelism," NeurIPS 2019.
  • Narayanan et al., "PipeDream: Generalized Pipeline Parallelism for DNN Training," SOSP 2019.
  • Korthikanti et al., "Reducing Activation Recomputation in Large Transformer Models," 2022.
  • Zhao et al., "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel," 2023.
  • Qi et al., "Zero Bubble Pipeline Parallelism," 2023.
  • Micikevicius et al., "Mixed Precision Training," ICLR 2018.
  • NVIDIA, TransformerEngine documentation (FP8 training recipes).

These give you the published numerical results and ablations; the algorithms themselves you now own.

Deep Dive 07: Attention, the Transformer, and FlashAttention

A self-contained reference. By the end of this chapter you should be able to: derive scaled dot-product attention from first principles, implement causal multi-head attention from scratch matching F.scaled_dot_product_attention, reason about KV-cache memory for any decoder-only LLM, derive the online softmax that powers FlashAttention, and explain why FA-2 and FA-3 each roughly doubled throughput.


Table of contents

  1. From "predict the next token" to a transformer
  2. Scaled dot-product attention-full derivation
  3. Multi-head attention
  4. Causal masking
  5. MQA and GQA-why decode is bandwidth-bound
  6. Position encodings: Sinusoidal, Learned, RoPE, ALiBi, Sliding window, YaRN/PI/NTK
  7. The transformer block: pre-norm, residuals, FFN
  8. LayerNorm vs RMSNorm
  9. Activations: GeLU, SwiGLU
  10. KV-cache: math, layouts, paged attention
  11. Attention complexity: O(S^2) is the enemy
  12. FlashAttention-derivation of the online softmax and tiled algorithm
  13. FlashAttention-2 deltas
  14. FlashAttention-3 deltas
  15. Decode-time variant: flash_attn_with_kvcache
  16. Practical exercises

Notation used throughout:

B = batch size
S = sequence length (often L_q or L_k for query/key length separately)
H = number of attention heads
d = model hidden size (sometimes d_model)
d_h = per-head dimension; usually d / H
d_k = key dimension per head (= d_h in standard attention)
d_v = value dimension per head (= d_h in standard attention)
H_q = number of query heads, H_kv = number of K/V heads (for GQA/MQA)
V = vocab size
N = batch * sequence dimension when flattened


1. From "predict the next token" to a transformer

1.1 The autoregressive language modeling setup

A language model is a probability distribution over token sequences. Given a vocabulary of size V and a sequence x_1, x_2, ..., x_T of token IDs in {0, 1, ..., V-1}, the model factorizes the joint probability via the chain rule:

P(x_1, x_2, ..., x_T) = prod_{t=1..T} P(x_t | x_1, ..., x_{t-1})

We call P(x_t | x_<t) the next-token distribution. Training a decoder-only transformer is exactly fitting a parametric model p_theta(x_t | x_<t) by minimizing the negative log likelihood:

L(theta) = -(1/T) sum_{t=1..T} log p_theta(x_t | x_<t)

For a batch this is just averaged. The gradient signal at every position t trains the model to predict its own next token from a left context.

The structure we need:

  • Input: a sequence of T tokens, embedded as vectors of dimension d. So x in R^{T x d}.
  • Output: a sequence of T vectors in R^{T x d} (one per position). These get projected to V-dim logits and softmaxed to give the next-token distribution at each position.
  • Constraint: the output at position t must depend only on inputs x_<= t (causal); otherwise the loss leaks the answer.
  • Inductive bias: every output position should be a function of all previous positions, not just the immediately preceding one. We do not want the gradient to have to traverse hundreds of recurrent steps.

The transformer's answer is: at every position, aggregate information from all earlier positions in parallel via attention, then mix it position-wise with an MLP, then repeat L layers deep.

1.2 Why attention: the "gather" intuition

Imagine T = 1024 tokens and we are computing the new representation for position t. Some earlier positions are highly relevant ("the antecedent of the pronoun two paragraphs back"); most are not. Conceptually we want a soft-lookup:

  • Each position t emits a query q_t-what it is looking for.
  • Each position s emits a key k_s-what it advertises.
  • Each position s emits a value v_s-the payload it would contribute.
  • The new representation at t is a weighted sum of v_s, where the weight is high when q_t and k_s match.

A natural similarity is the dot product q_t . k_s. To turn unbounded scores into a convex combination we softmax across s. That gives:

a_{t,s} = softmax_s( q_t . k_s )         # attention weights at row t
h_t = sum_s a_{t,s} v_s                  # output at position t

This is parallel across all (t, s), so we batch it as one matrix multiply. This is the entire idea. Everything below is making it numerically stable, making it scale to many heads, making it causal, encoding position, and making it fit in memory at long context.


2. Scaled dot-product attention-full derivation

2.1 The formula

Given matrices Q in R^{S x d_k}, K in R^{S x d_k}, V in R^{S x d_v}:

Attention(Q, K, V) = softmax( Q K^T / sqrt(d_k) ) V

Step by step:

  1. S = Q K^T # raw scores, shape (S, S)
  2. S_scaled = S / sqrt(d_k) # divide by sqrt(d_k)
  3. P = softmax(S_scaled, axis=-1) # row-wise softmax, shape (S, S)
  4. O = P V # output, shape (S, d_v)

Each row P[t, :] is a probability distribution over key positions. Each row O[t, :] = sum_s P[t, s] * V[s, :]. So O[t, :] is the convex combination of value rows weighted by how well key s matches query t.

2.2 Why divide by sqrt(d_k): the variance derivation

This is not aesthetic. It is required to keep the softmax in its useful regime as d_k grows.

Assume each component of q in R^{d_k} and each component of k in R^{d_k} are independent random variables with mean 0 and variance 1. (This is the post-normalization, post-init regime that pre-norm transformers operate in.)

The dot product is

q . k = sum_{i=1..d_k} q_i k_i

The expected value is E[q.k] = sum E[q_i] E[k_i] = 0 (independence + zero mean). The variance:

Var(q.k) = sum_{i=1..d_k} Var(q_i k_i)
         = sum_{i=1..d_k} ( E[q_i^2] E[k_i^2] - 0 )
         = sum_{i=1..d_k} (1 * 1)
         = d_k

So q.k has standard deviation sqrt(d_k). Without scaling, individual scores have magnitudes that grow as sqrt(d_k). For modern d_k (64 or 128), this puts the softmax inputs in the regime where one entry dominates and the rest are crushed to zero. Two consequences:

  1. The softmax becomes nearly one-hot. Attention degenerates to a hard argmax look-up, which is hard to train (the gradient through softmax is approximately p_i (delta_{ij} - p_j); when p is one-hot, almost all entries are zero or saturated).
  2. The forward and backward become numerically fragile. Tiny perturbations to a single near-max score flip which key wins.

Dividing by sqrt(d_k):

Var(q.k / sqrt(d_k)) = Var(q.k) / d_k = 1

so the scaled scores have standard deviation 1 regardless of d_k. The softmax stays in a regime with meaningful gradients. That is the entire argument.

A common confusion: why sqrt(d_k) and not d_k? Because we want standard deviation to be 1, not variance. Variance scales linearly with d_k, so the standard deviation scales as sqrt(d_k).

2.3 Numerically stable softmax

Real implementations never compute softmax(z) as exp(z) / sum(exp(z)) naively because exp can overflow. The standard trick:

softmax(z)_i = exp(z_i - max(z)) / sum_j exp(z_j - max(z))

Subtracting the max is mathematically a no-op (top and bottom both get multiplied by exp(-max(z))) but keeps every exponent <= 0, so all values are in (0, 1]. Hold this idea-the same identity is what makes FlashAttention's online softmax possible (Section 12).
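A two-line demonstration of why the subtraction matters:

import torch

z = torch.tensor([1000.0, 1001.0, 1002.0])
naive = torch.exp(z) / torch.exp(z).sum()                 # exp overflows to inf
stable = torch.exp(z - z.max()) / torch.exp(z - z.max()).sum()
print(naive)    # tensor([nan, nan, nan])
print(stable)   # tensor([0.0900, 0.2447, 0.6652])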

2.4 Tensor shapes

Walk through one attention layer.

Input:   x in R^{B x S x d}
Project: Q = x W_Q, K = x W_K, V = x W_V (each W is d x d), so Q, K, V in R^{B x S x d}
Reshape: view as (B, S, H, d_h) where d_h = d / H; transpose to (B, H, S, d_h)
Scores:  Q K^T over the last two dims gives (B, H, S, S)
Softmax: along the last axis, gives (B, H, S, S)
Apply V: (B, H, S, S) x (B, H, S, d_h) -> (B, H, S, d_h)
Reshape: transpose back to (B, S, H, d_h), view as (B, S, d)
Project: O = h W_O, W_O in R^{d x d}, output (B, S, d)


3. Multi-head attention

3.1 Why multi-head

A single attention head computes one scoring function q.k. There are many useful relations between tokens-syntactic dependency, coreference, positional adjacency, semantic similarity-and forcing one head to encode all of them in a single d-dim space is a bottleneck. Multi-head says: split the d-dim space into H subspaces, each of dim d_h = d / H, and run an independent attention per subspace. Concatenate the H outputs, project.

The total parameter count and FLOPs do not change: one big d x d projection is identical to H independent (d x d_h) projections concatenated. The computational difference is in the QK^T step: instead of one (S, d)x(d, S) matmul of cost O(S^2 d), you do H independent (S, d_h)x(d_h, S) matmuls of cost O(S^2 d_h) each, totalling O(H * S^2 * d_h) = O(S^2 d). Same FLOPs.

What changes is the expressive structure. Each head's scores are softmax-normalized independently, so head i can spend all its probability mass on syntactic neighbors while head j attends globally to topical anchors.

3.2 Fused QKV projection

Implementations almost always fuse the three projections into one matmul:

qkv = x @ W_qkv        # W_qkv in R^{d x 3d}
Q, K, V = qkv.split(d, dim=-1)

This is one big GEMM (general matrix multiply) instead of three small ones. On GPUs, one large matmul beats three small ones because of launch overhead and tensor-core utilization. The math is identical: the columns of W_qkv are just [W_Q | W_K | W_V] concatenated.

For GQA (Section 5) the fusion is asymmetric: W_qkv has shape d x (H_q + 2*H_kv) * d_h, so the K and V slices are narrower than Q.

3.3 Pseudocode for multi-head attention

import math
import torch
import torch.nn.functional as F

def mha(x, W_qkv, W_o, H):
    B, S, d = x.shape
    d_h = d // H
    qkv = x @ W_qkv                                 # (B, S, 3d)
    q, k, v = qkv.split(d, dim=-1)                  # each (B, S, d)
    q = q.view(B, S, H, d_h).transpose(1, 2)        # (B, H, S, d_h)
    k = k.view(B, S, H, d_h).transpose(1, 2)
    v = v.view(B, S, H, d_h).transpose(1, 2)
    scores = q @ k.transpose(-1, -2) / math.sqrt(d_h)   # (B, H, S, S)
    causal = torch.triu(torch.ones(S, S, dtype=torch.bool, device=x.device),
                        diagonal=1)                 # True strictly above diagonal
    scores = scores.masked_fill(causal, float('-inf'))
    p = F.softmax(scores, dim=-1)                   # (B, H, S, S)
    out = p @ v                                     # (B, H, S, d_h)
    out = out.transpose(1, 2).reshape(B, S, d)      # (B, S, d)
    return out @ W_o                                # (B, S, d)
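A quick check that this matches PyTorch's fused kernel (loose tolerances because the fused path may reorder reductions):

def mha_ref(x, W_qkv, W_o, H):
    B, S, d = x.shape
    d_h = d // H
    q, k, v = (x @ W_qkv).split(d, dim=-1)
    q = q.view(B, S, H, d_h).transpose(1, 2)
    k = k.view(B, S, H, d_h).transpose(1, 2)
    v = v.view(B, S, H, d_h).transpose(1, 2)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    return out.transpose(1, 2).reshape(B, S, d) @ W_o

x = torch.randn(2, 16, 64)
W_qkv, W_o = torch.randn(64, 192) / 8, torch.randn(64, 64) / 8
torch.testing.assert_close(mha(x, W_qkv, W_o, 4), mha_ref(x, W_qkv, W_o, 4),
                           rtol=1e-4, atol=1e-4)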

4. Causal masking

4.1 Why we need it

In autoregressive training, we feed the model the entire sequence x_1..T once, compute the output at every position in parallel, and demand that position t predict x_{t+1}. If position t's output is allowed to depend on x_>t, the loss is zero by trivial copying-the model has not learned anything.

We therefore need: output at position t is a function of x_1..t only. Concretely, in the (S, S) attention matrix, row t may have nonzero entries only in columns 1..t. Columns t+1..S must be zeroed.

4.2 Implementation: -inf before softmax

You cannot zero out the probabilities after softmax-softmax normalizes, so zeroing some entries breaks the convex-combination invariant. Instead you set the scores (pre-softmax) at masked positions to -inf:

M[t, s] = 0          if s <= t
M[t, s] = -inf       if s > t

scores += M
P = softmax(scores, dim=-1)

Because exp(-inf) = 0, those positions contribute nothing to the normalization sum or the output. The remaining positions still form a valid probability distribution.

In code:

mask = torch.full((S, S), float('-inf'))
mask = torch.triu(mask, diagonal=1)   # zeros on/below diagonal, -inf above
scores = scores + mask                # broadcast over (B, H)

torch.triu(..., diagonal=1) sets the strictly upper triangle to whatever the original matrix had and zero elsewhere-but here we want -inf in the upper tri and 0 in the lower, which is what the snippet above produces because the source matrix is all -inf and triu keeps the strict upper.

4.3 The mask in matrix form

For S = 4:

M = [ [  0  -inf -inf -inf ]
      [  0    0  -inf -inf ]
      [  0    0    0  -inf ]
      [  0    0    0    0  ] ]

After softmax, the attention probability matrix is lower triangular. Each row sums to 1. Row t spreads probability over columns 1..t.


5. MQA and GQA-why decode is bandwidth-bound

5.1 Standard MHA

Standard multi-head attention has H query heads and H K/V heads-every query head has its own private K and V. KV-cache size scales with H.

5.2 Multi-Query Attention (MQA)

Shazeer 2019. All H query heads share a single K head and a single V head. So while Q has shape (B, H, S, d_h), K and V have shape (B, 1, S, d_h). The attention scores Q @ K^T broadcast K across the H query heads: (B, H, S, d_h) @ (B, 1, d_h, S) -> (B, H, S, S).

KV-cache shrinks by a factor of H. For Llama-3-70B (H=64) this would be a 64x reduction. The trade-off: model quality drops measurably because all heads must agree on what to store. MQA was the right answer for PaLM and early-era inference engines but has been largely superseded.

5.3 Grouped-Query Attention (GQA)

Ainslie et al. 2023. Compromise between MHA (H KV heads) and MQA (1 KV head). Pick H_kv such that H_q is a multiple of H_kv. Group every G = H_q / H_kv query heads to share one K/V head.

  • Llama-3-70B: H_q = 64, H_kv = 8, so G = 8.
  • Llama-3-8B: H_q = 32, H_kv = 8, so G = 4.

KV-cache shrinks by G compared to MHA, while quality matches MHA almost exactly in published benchmarks. This is now the default for serious LLMs (Llama 2/3, Mistral, Qwen, DeepSeek, ...).

5.4 Why this matters for inference

During decode (generating one token at a time after the prompt is processed), each step does:

  1. Read Q for the current token (1 token, fast).
  2. Read K and V for all S tokens in the cache (the entire history).
  3. Compute attention.
  4. Append new K and V to the cache.

Step 2 reads the full KV-cache from HBM every single decode step. With S = 8192 tokens, MHA, BF16, 80 layers, 64 heads, d_h = 128:

KV bytes per token per layer per head = 2 (K and V) * 128 * 2 (BF16) = 512
Total KV bytes = 80 * 64 * 8192 * 512 = 21.5 GB

Reading 21.5 GB at H100 HBM bandwidth (~3.35 TB/s) takes ~6.4 ms per token just for KV. Decode is dominated by this read, not by the FLOPs of the matmul. Decode is memory-bandwidth-bound.

GQA cuts the KV by 8x. MQA cuts it by 64x. Now you understand why every serious LLM ships GQA: it directly multiplies decode throughput by ~8.

The exact KV-cache memory formula (for one request):

bytes = 2 * num_layers * H_kv * seq_len * d_h * dtype_bytes
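The formula in code, reproducing the numbers from Section 5.4:

def kv_cache_bytes(num_layers, h_kv, seq_len, d_h, dtype_bytes=2):
    return 2 * num_layers * h_kv * seq_len * d_h * dtype_bytes

print(kv_cache_bytes(80, 64, 8192, 128) / 1e9)  # 21.5 GB -- MHA, as above
print(kv_cache_bytes(80, 8, 8192, 128) / 1e9)   # 2.7 GB  -- GQA with H_kv=8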

6. Position encodings

The attention operation softmax(QK^T / sqrt(d_k)) V is permutation-equivariant in the sequence dimension: if you shuffle the rows of Q, K, V identically, the output rows shuffle the same way. Without a position signal the model cannot tell "the cat sat" from "sat cat the". We need to inject position.

6.1 Sinusoidal (Vaswani 2017)

For position p (0-indexed) and embedding dimension i, define

PE[p, 2i]   = sin(p / 10000^{2i/d})
PE[p, 2i+1] = cos(p / 10000^{2i/d})

Then x_p_with_pos = embed(x_p) + PE[p].
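The table in code (d assumed even):

import torch

def sinusoidal_pe(S_max, d, base=10000.0):
    pos = torch.arange(S_max, dtype=torch.float32).unsqueeze(1)  # (S_max, 1)
    j = torch.arange(0, d, 2, dtype=torch.float32)               # 2i = 0, 2, 4, ...
    angle = pos / base ** (j / d)                                # p / 10000^{2i/d}
    pe = torch.zeros(S_max, d)
    pe[:, 0::2] = torch.sin(angle)
    pe[:, 1::2] = torch.cos(angle)
    return pe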

Why this form? Two properties:

  1. Each dimension is a sinusoid with a different frequency, ranging from wavelength 2pi (i=0) up to wavelength 10000 * 2pi (i = d/2).
  2. PE[p + k] is a linear function of PE[p] (for fixed k), because sin and cos satisfy

     sin(a + b) = sin(a) cos(b) + cos(a) sin(b)
     cos(a + b) = cos(a) cos(b) - sin(a) sin(b)

     so the model can learn linear projections that compute relative offsets.

Why it generalizes mediocrely: the model still has to learn to use the relative-position structure, and the additive coupling means position info gets mixed with content via the projection matrices. In practice sinusoidal extrapolation to 2x the trained context degrades quickly.

6.2 Learned absolute position embeddings

Instead of the closed-form PE, allocate a learnable matrix P in R^{S_max x d} and add P[p] to embed(x_p). Used by GPT-2, BERT, OPT.

Pros: simple, often matches or beats sinusoidal at trained lengths. Cons: cannot extrapolate at all-there are no learned vectors for positions beyond S_max. Hard limit on context length.

6.3 RoPE-Rotary Position Embedding (Su et al. 2021)

The most-used position encoding in modern LLMs (Llama, Mistral, Qwen, DeepSeek, Gemma, ...). Worth deriving in full.

6.3.1 Goal

We want the dot product q_m . k_n to depend only on the relative offset n - m (and on the contents of the tokens), not on absolute positions m and n separately. Concretely, we want a function f such that the modified query at position m, q'_m = R_m q_m, and modified key at position n, k'_n = R_n k_n, satisfy

q'_m . k'_n = g(q_m, k_n, n - m)

for some function g. Note: rotation R_m means apply some linear map that depends on position m.

6.3.2 Complex-number formulation

Pair up the d_h components of q in R^{d_h} into d_h/2 pairs. Treat each pair (q_{2i}, q_{2i+1}) as a complex number z_i = q_{2i} + i * q_{2i+1}. Same for k.

For a fixed angular frequency theta_i, define the rotation by position m as

z_i^{(m)} = z_i * exp(i * m * theta_i)

i.e. multiply the complex number by a unit-magnitude complex of angle m * theta_i. Equivalently, in R^2,

[ q'_{2i}   ]   [ cos(m theta_i)  -sin(m theta_i) ] [ q_{2i}   ]
[ q'_{2i+1} ] = [ sin(m theta_i)   cos(m theta_i) ] [ q_{2i+1} ]

This is a 2D rotation by angle m * theta_i in the (2i, 2i+1) plane.

The key calculation: the inner product of two rotated complex numbers z_a^{(m)} and z_b^{(n)} is

< z_a * exp(i m theta) , z_b * exp(i n theta) >
    = Re( (z_a exp(i m theta)) * conj(z_b exp(i n theta)) )
    = Re( z_a conj(z_b) * exp(i (m - n) theta) )

Crucially this depends on m - n only, not on m and n separately. Summed across all pairs i (each with its own theta_i),

q'_m . k'_n = sum_i Re( z_{q,i} conj(z_{k,i}) * exp(i (m-n) theta_i) )
            = g(q_m, k_n, m - n)

This is exactly the relative-position-only similarity we wanted.

6.3.3 Choice of frequencies

Following sinusoidal precedent,

theta_i = base^{-2i / d_h}     for i = 0, 1, ..., d_h/2 - 1

with base = 10000 typically. So low-i pairs rotate fast (tracking local relative offsets) and high-i pairs rotate slow (tracking global offsets).

6.3.4 Implementation

In practice you precompute two tables of shape (S_max, d_h/2):

cos_table[m, i] = cos(m * theta_i)
sin_table[m, i] = sin(m * theta_i)

Then the rotation at position m is applied componentwise to a query/key of shape (..., d_h):

import torch

def apply_rope(x, cos_table, sin_table, positions):
    # x: (B, H, S, d_h)
    # positions: (S,) integer positions
    cos = cos_table[positions]          # (S, d_h/2), broadcasts over (B, H)
    sin = sin_table[positions]          # (S, d_h/2)
    x1 = x[..., 0::2]                   # even-indexed (B, H, S, d_h/2)
    x2 = x[..., 1::2]                   # odd-indexed
    rot1 = x1 * cos - x2 * sin          # real part of the 2D rotation
    rot2 = x1 * sin + x2 * cos          # imaginary part
    # re-interleave the pairs: (..., d_h/2, 2) -> (..., d_h)
    out = torch.stack([rot1, rot2], dim=-1).flatten(-2)
    return out
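
To see the relative-position property of 6.3.2 numerically, here is a short sketch (the table builder and the check are illustrative and assume the apply_rope above):

import torch

def build_rope_tables(s_max, d_h, base=10000.0):
    # theta_i = base^{-2i/d_h} for i = 0..d_h/2 - 1
    i = torch.arange(d_h // 2, dtype=torch.float32)
    theta = base ** (-2.0 * i / d_h)
    angles = torch.outer(torch.arange(s_max, dtype=torch.float32), theta)
    return torch.cos(angles), torch.sin(angles)      # each (s_max, d_h/2)

cos_t, sin_t = build_rope_tables(s_max=512, d_h=64)
q, k = torch.randn(1, 1, 1, 64), torch.randn(1, 1, 1, 64)

def dot_at(m, n):
    qm = apply_rope(q, cos_t, sin_t, torch.tensor([m]))
    kn = apply_rope(k, cos_t, sin_t, torch.tensor([n]))
    return (qm * kn).sum().item()

print(dot_at(10, 3), dot_at(110, 103))   # same offset 7 -> nearly identical values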

Three important rules:

  • RoPE is applied to Q and K, not to V.
  • RoPE is applied after the linear projections W_Q, W_K, before the attention scores are computed.
  • RoPE is applied per head, not per model-dim.

Inference with RoPE on a KV-cache: the K stored in the cache is already rotated. When you append a new token, you rotate its K with the current position and append. No re-rotation of past K is needed.

6.3.5 Why RoPE became dominant
  • No additional parameters (just trig tables).
  • Relative by construction, and generalizes (with the extension techniques in 6.6) to longer contexts than trained.
  • Composes cleanly with FlashAttention because it is applied before attention, not as an additive bias inside the softmax.

6.4 ALiBi-Attention with Linear Biases (Press et al. 2022)

Even simpler: modify the score matrix directly with a position-dependent bias.

scores[t, s] = q_t . k_s / sqrt(d_k) + slope_h * (s - t)

Here slope_h is a head-specific negative slope (precomputed). The penalty grows linearly with how far back the key is. Different heads have different slopes, so some attend close, some attend far.

Slopes for H heads are typically chosen as a geometric sequence:

slope_h = 2^{-8 h / H}   for h = 1, ..., H

Pros: zero-shot extrapolation-performance degrades smoothly when going to contexts longer than trained. No extra parameters.

Cons: less expressive than RoPE in practice. Largely supplanted by RoPE in flagship models, but BLOOM and a few others used ALiBi.

6.5 Sliding window attention (Mistral)

Restricts each token to attend only to the previous W tokens (e.g. W = 4096). This is a modification of the mask, not a position encoding, but it is closely related because it is how you get "positional" locality.

Mask:

M[t, s] = -inf  if s > t                    (causal)
M[t, s] = -inf  if t - s >= W               (out of window)
M[t, s] = 0     otherwise

In stacked layers, the receptive field grows: token t at layer L can indirectly see tokens up to t - L*W away (information propagates through the stack like a CNN). Mistral 7B uses W = 4096 with 32 layers, giving an effective receptive field of ~131K tokens.

KV-cache benefit: only the most recent W K/V need to be stored, capping KV memory at O(W) instead of O(S).

6.6 Context extension: YaRN, NTK-aware, Position Interpolation

A model trained at context length L_train often needs to be extended to 4L_train or 8L_train at deploy time. Three families, all working in the RoPE frequency domain:

Position Interpolation (PI) (Chen et al. 2023): scale all positions by L_train / L_target. Geometrically, every token's RoPE rotation is slowed by the scale factor, so the maximum rotation angle the model sees is unchanged. Works but degrades quality-high-frequency dimensions lose discriminative power.

NTK-aware scaling (bloc97 / community 2023): instead of scaling all frequencies uniformly, change the RoPE base such that high-frequency dimensions are barely affected and only low-frequency dimensions are interpolated. Concretely

base' = base * (L_target / L_train)^(d_h / (d_h - 2))

Better preservation of local relative-position info than PI.

YaRN (Peng et al. 2023): a more careful per-frequency-band schedule. Bands are split into "extrapolation" (high freq, kept as-is), "interpolation" (low freq, scaled), and a transition region. Adds a small temperature correction to the attention softmax. Empirically the strongest of the three with comparable fine-tuning.

In all three, you typically fine-tune for a small number of steps on long-context data after applying the schedule.
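
A minimal sketch contrasting PI and NTK-aware scaling on the frequency table (the 4x extension and all names here are illustrative):

import torch

def rope_thetas(d_h, base=10000.0):
    i = torch.arange(d_h // 2, dtype=torch.float32)
    return base ** (-2.0 * i / d_h)

d_h, L_train, L_target = 128, 8192, 32768
scale = L_target / L_train                       # 4x extension

# Position Interpolation: slow every rotation uniformly
theta_pi = rope_thetas(d_h) / scale

# NTK-aware: change the base so high-frequency dims are barely touched
base_ntk = 10000.0 * scale ** (d_h / (d_h - 2))
theta_ntk = rope_thetas(d_h, base=base_ntk)

# Highest-frequency dim (i=0): PI slows it 4x, NTK-aware leaves it intact
print(theta_pi[0].item(), theta_ntk[0].item())   # 0.25 vs 1.0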


7. The transformer block

7.1 The structure

A decoder-only transformer layer has two sub-blocks: attention and FFN (also called MLP). Each sub-block has a residual connection and a normalization. The two competing wirings are post-norm (Vaswani 2017) and pre-norm (used by every modern LLM).

Post-norm (original):

h = LayerNorm( x + Attention(x) )
y = LayerNorm( h + FFN(h) )

Pre-norm:

h = x + Attention( LayerNorm(x) )
y = h + FFN( LayerNorm(h) )

The difference matters. In pre-norm, the residual stream x flows from input to output unmodified by any normalization-the layer adds a correction computed from a normalized view of x. This means deep pre-norm transformers are easier to train: gradients flow through the residuals without being scaled by repeated normalizers. Post-norm transformers required learning-rate warmup to train at depth and were fragile beyond ~12 layers without tricks.

Modern stacks (Llama, Mistral, GPT-NeoX) all use pre-norm + RMSNorm.

7.2 ASCII diagram of a pre-norm block

     x  -----------------------------------+----------+
     |                                     |          |
     v                                     |          |
  RMSNorm                                  |          |
     |                                     |          |
     v                                     |          |
  Attention(Q=K=V=norm(x), causal,         |          |
            with RoPE on Q and K)          |          |
     |                                     |          |
     v                                     |          |
     +-------------(add residual)----------+          |
     |                                                |
     v                                                |
     h                                                |
     |                                                |
     v                                                |
  RMSNorm                                             |
     |                                                |
     v                                                |
  FFN  (e.g. SwiGLU: down(silu(gate(x)) * up(x)))     |
     |                                                |
     v                                                |
     +---------(add residual)-------------------------+
     |
     v
     y

7.3 Pseudocode

def block(x, params):
    h = x + attention(rmsnorm(x, params.norm1), params.attn)
    y = h + ffn(rmsnorm(h, params.norm2), params.ffn)
    return y

def transformer(tokens, params):
    x = embed(tokens, params.embed)
    for layer_params in params.layers:
        x = block(x, layer_params)
    x = rmsnorm(x, params.final_norm)
    logits = x @ params.embed.weight.T   # tied embeddings, often
    return logits

8. LayerNorm vs RMSNorm

8.1 LayerNorm

Ba et al. 2016. For a vector x in R^d:

mean = (1/d) sum_i x_i
var  = (1/d) sum_i (x_i - mean)^2
y    = (x - mean) / sqrt(var + eps)
out  = gamma * y + beta

Two reductions across the feature axis (mean and variance), two learnable parameters per dim (gamma is scale, beta is shift), one elementwise subtract, one elementwise divide.

8.2 RMSNorm

Zhang & Sennrich 2019. Drops the mean centering:

rms = sqrt( (1/d) sum_i x_i^2 + eps )
out = (x / rms) * weight

One reduction across the feature axis (sum of squares), one learnable parameter per dim (weight, the scale), no shift, no mean subtraction.
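
As a PyTorch sketch matching the formula above (a minimal version, not the Llama source):

import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))   # scale only, no shift

    def forward(self, x):
        # single reduction: mean of squares over the feature axis
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return (x / rms) * self.weight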

8.3 Why RMSNorm is enough

Empirical observation (Zhang & Sennrich, then countless replications): in pre-norm transformers, the mean centering of LayerNorm contributes little to model quality. The crucial operation is the variance normalization-bounding the magnitude of x so that the subsequent linear layer sees inputs of controlled scale. Mean subtraction is redundant given the high dimensionality and the fact that the gamma/beta parameters can absorb shifts.

Computational benefits:

  • One reduction instead of two-about 30-40% faster.
  • Half the parameters (no beta).
  • Slightly better numerics (the mean subtraction can subtract two similar values, losing precision).

Every Llama-family model and most modern open-weights LLMs use RMSNorm.


9. Activation functions in the FFN

9.1 The plain FFN

Vaswani 2017 used a two-layer MLP per position:

ffn(x) = W_2 ( gelu( W_1 x ) )

Sizes: W_1 is d x d_ff, W_2 is d_ff x d, with d_ff = 4d typically. Two matmuls. Non-linearity is GeLU (or ReLU originally).

9.2 GeLU

Gaussian Error Linear Unit, Hendrycks & Gimpel 2016:

gelu(x) = x * Phi(x)

where Phi is the standard normal CDF. Approximated as:

gelu(x) ~= 0.5 x (1 + tanh( sqrt(2/pi) (x + 0.044715 x^3) ))

Smoother than ReLU near 0, approximately ReLU for large |x|. Empirically better than ReLU for transformers.

9.3 SwiGLU

Shazeer 2020. A gated FFN: instead of one input matmul + nonlinearity, project the input through two matrices, multiply them elementwise after nonlinearity-ing one of them, then project down.

gate(x) = silu( W_gate x )      # silu(x) = x * sigmoid(x)
up(x)   = W_up   x
ffn(x)  = W_down ( gate(x) * up(x) )

Three weight matrices instead of two. To keep parameter count approximately equal to a plain GeLU FFN with d_ff = 4d, SwiGLU implementations use d_ff = (8/3) d ~= 2.67 d. Llama-3 uses d_ff ~= (8/3) d rounded to a friendly multiple.

The cost: 3 matmuls instead of 2 (about +50% FFN compute). The benefit: empirically better quality at fixed parameter count, and better still at fixed compute when tuned. It is now the default-Llama, Mistral, Qwen all use SwiGLU.

The "GLU family" is parameterized by which nonlinearity wraps the gate: GLU (sigmoid), ReGLU (relu), GeGLU (gelu), SwiGLU (silu). SwiGLU won.


10. KV-cache

10.1 Why it exists

In autoregressive decode, you generate one new token at a time. Naively each step would be:

for t in 1..T_gen:
    full_input = prompt + generated_so_far     # length t
    run full transformer forward on full_input
    sample the next token

The forward pass on length-S input does O(S^2 d) work in attention. So generating T tokens is O(T^3 d) total. This is catastrophic.

Observation: at decode step t+1, the K and V matrices for positions 1..t are exactly the same as they were at step t. Only one new K/V pair is added at position t+1. So we can cache K and V across steps.

10.2 The decode loop with KV-cache

K_cache, V_cache = empty
# Prefill: process the prompt of length P in one big forward
Q, K, V = project(prompt)
K_cache, V_cache = K, V
output = attention(Q, K_cache, V_cache, causal=True)
# take the last position's logits, sample x_{P+1}

# Decode: one token at a time
for t in P+1..P+T_gen:
    x_t = embedded(generated[t])               # 1 token
    q, k, v = project(x_t)                     # each (B, H, 1, d_h)
    K_cache = cat(K_cache, k, dim=seq)         # grow by 1
    V_cache = cat(V_cache, v, dim=seq)
    out = attention(q, K_cache, V_cache, causal=True)  # implicit: q sees all K
    logits = unembed(out)
    sample the next token

Per-step work in decode: project 1 token (O(d^2)), do attention with Q of length 1 against K/V of length t (O(t * d)), FFN on 1 token (O(d^2)). Linear in t, not quadratic. Generating T tokens is O(T^2 d) total-quadratic in T, not cubic.
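
A single-head, batch-1 toy of the cached decode step (a shape-level sketch of the loop above, not a production kernel; no mask is needed because q is the latest position):

import torch

d_h = 64
K_cache = torch.randn(1, 10, d_h)   # filled by prefill (10-token prompt)
V_cache = torch.randn(1, 10, d_h)

def decode_step(q, k, v, K_cache, V_cache):
    # q, k, v: (1, 1, d_h) -- projections of the single new token
    K_cache = torch.cat([K_cache, k], dim=1)            # grow cache by 1
    V_cache = torch.cat([V_cache, v], dim=1)
    scores = q @ K_cache.transpose(-2, -1) / d_h**0.5   # (1, 1, S+1)
    out = scores.softmax(dim=-1) @ V_cache              # (1, 1, d_h)
    return out, K_cache, V_cache

out, K_cache, V_cache = decode_step(
    torch.randn(1, 1, d_h), torch.randn(1, 1, d_h), torch.randn(1, 1, d_h),
    K_cache, V_cache)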

10.3 KV-cache memory

For a single request:

bytes = 2 * num_layers * H_kv * seq_len * d_h * dtype_bytes

Where:

  • 2 covers K and V.
  • num_layers: number of transformer blocks.
  • H_kv: K/V heads (= H_q in MHA, smaller in GQA, 1 in MQA).
  • seq_len: how many tokens are in the cache.
  • d_h: per-head dim.
  • dtype_bytes: 2 for FP16/BF16, 1 for FP8/INT8.

10.4 Worked example: Llama-3-70B at 8K context, BF16

num_layers = 80
H_kv = 8 (GQA: H_q = 64, group size 8)
seq_len = 8192
d_h = 128
dtype_bytes = 2

bytes = 2 * 80 * 8 * 8192 * 128 * 2 = 2,684,354,560 ~= 2.5 GiB per request

(If it were MHA with H = 64 instead of GQA H_kv = 8, this would be ~20 GiB per request, which is why GQA exists.)

Per token added to the cache:

bytes_per_token = 2 * 80 * 8 * 128 * 2 = 327,680 bytes ~= 320 KiB per token

So generating 8K new tokens adds ~2.5 GiB to the cache. At HBM bandwidth 3.35 TB/s on H100, just reading that 2.5 GiB once costs ~0.75 ms.

For a 32K context the KV is 10 GiB per request; for 128K it is 40 GiB per request. KV-cache, not weights, becomes the dominant memory consumer at long context.

Same model, MQA hypothetical (H_kv = 1):

bytes = 2 * 80 * 1 * 8192 * 128 * 2 = 320 MiB per request

Same model, MHA hypothetical (H_kv = 64):

bytes = 2 * 80 * 64 * 8192 * 128 * 2 = 20 GiB per request

The 8x reduction MHA -> GQA is exactly what makes long-context inference practical.

10.5 Layout: contiguous vs paged

Contiguous layout. Allocate one big tensor per request of shape (num_layers, 2, H_kv, max_seq_len, d_h). Simple, but you must reserve max_seq_len up front. If max_seq_len = 8K but a request stops at 200 tokens, you wasted (1 - 200/8192) ~= 97% of that allocation.

In a serving system with many concurrent requests of varying lengths, contiguous layout means you must either:

  • Pre-size for the worst case and accept massive waste, or
  • Refuse to add a new request unless you have full max-len space free.

This caps concurrency catastrophically-a 40 GiB GPU might be unable to host more than 4 requests at 8K even though the live tokens occupy <10% of the reserved cache.

Paged attention (vLLM, Kwon et al. 2023). Treat the KV-cache as a virtual memory: divide it into fixed-size blocks (e.g., 16 tokens per block). Each request keeps a block table-a list of physical block IDs it owns. When a request grows, it allocates a new block from a free pool. When it finishes, blocks return to the pool.

This is exactly the OS virtual-memory abstraction applied to attention. Benefits:

  • Internal fragmentation is bounded by one block per request (~16 tokens of waste, instead of 8K).
  • Many more concurrent requests fit in the same HBM.
  • Copy-on-write block sharing for prefix caching: if requests share a system prompt, they can share its KV blocks.

The cost: the attention kernel must be paged-aware-it indirects through the block table instead of a contiguous slab. PagedAttention kernels (vLLM, FA-3, SGLang) are non-trivial; the attention loop has to gather K/V from non-contiguous memory. Modern serving engines all use paged KV.


11. Attention complexity-why long context is hard

For a sequence of length S with model dim d (single head):

QK^T:     (S x d) @ (d x S)  -> (S x S),   FLOPs ~ S^2 * d
softmax:  on S x S,                         FLOPs ~ S^2
@ V:      (S x S) @ (S x d)  -> (S x d),   FLOPs ~ S^2 * d
Total compute:          O(S^2 * d)
Score-matrix memory:    O(S^2)

For S = 32K, the score matrix per head is 32K * 32K = 1 G entries. In BF16 that is 2 GB. Per head. Per layer. For one request. Materializing the score matrix in HBM is the binding constraint at long context-long before you run out of FLOPs, you run out of memory to hold the intermediate softmax matrix.

This is what FlashAttention solves.


12. FlashAttention

Dao et al. 2022. Two ideas working together: (a) tile the attention computation so that only small blocks of Q, K, V are in fast memory at any time, never the full S x S score matrix; (b) compute softmax online over the K tiles so each output row's normalization stays correct without seeing all scores at once.

12.1 GPU memory hierarchy

  • HBM (high-bandwidth memory): 80 GB on an H100 (40-141 GB across the A100-to-H200 range), ~3.35 TB/s on H100 SXM. Big, slow by GPU standards.
  • SRAM / shared memory / register file: ~256 KB per SM, ~20 TB/s. Tiny, fast.

Standard attention reads/writes the S x S score matrix to HBM at every step (because it doesn't fit in SRAM). The S x S matrix dominates HBM traffic. FlashAttention's goal: do all attention work for a Q tile inside SRAM, never spilling the score matrix.

12.2 The online softmax-full derivation

The non-trivial part is doing softmax incrementally across K tiles. Suppose K is split into tiles K_1, K_2, ..., K_T and we want to compute, for a fixed Q tile (call its rows q):

P = softmax( [s_1 ; s_2 ; ... ; s_T] )    where s_j = q K_j^T / sqrt(d_k)
O = P [V_1 ; V_2 ; ... ; V_T]

We process one tile (K_j, V_j) at a time, maintaining for each row of Q three running statistics:

m   = current running max over scores seen so far
l   = current running sum of exp(score - m) over scores seen so far
O_  = current running unnormalized output (sum of exp(score - m) * v)

After processing all T tiles, the final output is O_ / l. Per row.

Now the recurrence. Suppose we have processed tiles 1..j-1 and have state (m, l, O_). We process tile j with scores s_j (a block of K_BLOCK columns) and values V_j.

Step 1: compute the local max of the new tile.

m_local = max( s_j )                               # scalar per row of q

Step 2: update the running max.

m_new = max( m, m_local )                          # scalar per row

Step 3: rescale the old running sum and output to be relative to m_new. The reason: the old l was computed as sum exp(score - m). To put it on the m_new scale, we multiply by exp(m - m_new):

correction = exp( m - m_new )                      # scalar per row, in (0, 1]
l_new_partial = l * correction
O_new_partial = O_ * correction                    # vector per row

Step 4: compute the new tile's contribution on the m_new scale.

p_j = exp( s_j - m_new )                           # block of weights, in (0, 1]
l_new = l_new_partial + sum( p_j )                 # add new tile's mass
O_new = O_new_partial + p_j @ V_j                  # add new tile's contribution

Step 5: store (m_new, l_new, O_new) as the new state.

After all tiles are processed, divide:

O_final = O_ / l                                   # the actual softmaxed output

12.2.1 Why this is mathematically equal to the all-at-once softmax

Let s_1, ..., s_T be the per-tile score blocks and V_1, ..., V_T the per-tile value blocks. Let m_global = max over all entries of all s_j. The all-at-once softmax gives:

p_global_j = exp(s_j - m_global) / Z,    Z = sum_j sum_entries exp(s_j - m_global)
O = sum_j p_global_j @ V_j
  = (1/Z) sum_j exp(s_j - m_global) @ V_j

We need to show that the online algorithm produces exactly this.

Inductive claim: after processing tiles 1..j with running state (m, l, O_),

O_ = sum_{k=1..j} exp(s_k - m) @ V_k
l  = sum_{k=1..j} sum_entries exp(s_k - m)

Base case j = 0: m = -inf, l = 0, O_ = 0. The empty sums are zero. With m = -inf the correction factor exp(m - m') evaluates to exp(-inf) = 0, which harmlessly rescales the empty state; the tile-1 update then sets m = m_1 and starts the sums fresh.

Inductive step: assume the claim holds for j-1 with running max m. We process tile j with new tile max m_j = max(s_j), new running max m' = max(m, m_j). The correction factor is c = exp(m - m').

After the update, the new running output is:

O_'  = O_ * c + exp(s_j - m') @ V_j
     = c * sum_{k=1..j-1} exp(s_k - m) @ V_k + exp(s_j - m') @ V_j
     = sum_{k=1..j-1} exp(m - m') exp(s_k - m) @ V_k + exp(s_j - m') @ V_j
     = sum_{k=1..j-1} exp(s_k - m') @ V_k + exp(s_j - m') @ V_j
     = sum_{k=1..j} exp(s_k - m') @ V_k

Identical algebra for l. So the inductive claim holds with m replaced by m'.

By induction, after T tiles, m equals the global max m_global, and:

O_ = sum_{k=1..T} exp(s_k - m_global) @ V_k
l  = sum_{k=1..T} sum_entries exp(s_k - m_global) = Z

so O_ / l = O. Online softmax = batched softmax exactly. No approximation. The numerical-stability trick (subtract the running max) is the same trick as ordinary stable softmax, just done incrementally.

12.3 The tiled algorithm

# Inputs: Q (S_q x d), K (S_k x d), V (S_k x d).
# Tile sizes: B_q (Q rows per tile), B_k (K rows per tile).
# Output: O (S_q x d).

for q_tile in range(0, S_q, B_q):                      # outer: over Q
    Q_tile = Q[q_tile : q_tile + B_q]                  # (B_q, d), load to SRAM
    m = full((B_q,), -inf)                             # running max, in SRAM
    l = zeros((B_q,))                                  # running sum
    O_ = zeros((B_q, d))                               # running output

    for k_tile in range(0, S_k, B_k):                  # inner: over K, V
        K_tile = K[k_tile : k_tile + B_k]              # (B_k, d), load
        V_tile = V[k_tile : k_tile + B_k]              # (B_k, d), load
        S = Q_tile @ K_tile.T / sqrt(d)                # (B_q, B_k), in SRAM
        if causal: S = mask(S, q_tile, k_tile)
        m_local = rowmax(S)                            # (B_q,)
        m_new = maximum(m, m_local)                    # (B_q,)
        correction = exp(m - m_new)                    # (B_q,)
        P = exp(S - m_new[:, None])                    # (B_q, B_k)
        l = l * correction + rowsum(P)                 # (B_q,)
        O_ = O_ * correction[:, None] + P @ V_tile     # (B_q, d)
        m = m_new

    O[q_tile : q_tile + B_q] = O_ / l[:, None]         # final normalize, store to HBM

ASCII picture:

Q (S_q x d)               K (S_k x d)           V (S_k x d)        O (S_q x d)
+-----+                  +---+---+---+        +---+---+---+        +-----+
| Q_1 |  outer loop -->  |K_1|K_2|K_3|  ...   |V_1|V_2|V_3|        | O_1 |
+-----+                  +---+---+---+        +---+---+---+        +-----+
| Q_2 |                                                            | O_2 |
+-----+                                                            +-----+
| ... |                                                            | ... |
+-----+                                                            +-----+

For each Q_i:
  for k_tile = 1..T:
    score block S_ik = Q_i K_k^T / sqrt(d)             [stays in SRAM]
    update (m, l, O_) with online softmax
  emit O_i = O_ / l                                    [single write to HBM]
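
The algorithm translates nearly line-for-line into NumPy. A minimal single-head, non-causal sketch, checked against the all-at-once softmax (tile sizes illustrative):

import numpy as np

def flash_forward(Q, K, V, B_q=64, B_k=64):
    S_q, d = Q.shape
    O = np.empty_like(Q)
    for qs in range(0, S_q, B_q):                 # outer: over Q tiles
        Qt = Q[qs:qs + B_q]
        m = np.full(Qt.shape[0], -np.inf)         # running max
        l = np.zeros(Qt.shape[0])                 # running sum
        O_ = np.zeros_like(Qt)                    # running output
        for ks in range(0, K.shape[0], B_k):      # inner: over K/V tiles
            Kt, Vt = K[ks:ks + B_k], V[ks:ks + B_k]
            S = Qt @ Kt.T / np.sqrt(d)
            m_new = np.maximum(m, S.max(axis=1))
            corr = np.exp(m - m_new)              # rescale old state
            P = np.exp(S - m_new[:, None])
            l = l * corr + P.sum(axis=1)
            O_ = O_ * corr[:, None] + P @ Vt
            m = m_new
        O[qs:qs + B_q] = O_ / l[:, None]          # final normalize
    return O

Q, K, V = (np.random.randn(256, 32) for _ in range(3))
S = Q @ K.T / np.sqrt(32)
P = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(flash_forward(Q, K, V), ref)   # online == all-at-once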

12.4 Why this saves memory and bandwidth

Memory: the only on-the-fly working set is one Q tile (B_q x d), one K tile (B_k x d), one V tile (B_k x d), one score block (B_q x B_k), and the running state (B_q x (d + 2)). All small. The full S x S score matrix is never materialized anywhere.

Total memory for activations is O(B_q * d + B_k * d + B_q * B_k), with B_q and B_k chosen to fit in SRAM (e.g., 64 or 128). The dominant memory is the inputs and the output, both O(S * d). So overall O(S * d) memory instead of O(S^2).

Bandwidth: standard attention's HBM traffic is dominated by reading and writing the S x S score matrix-O(S^2) reads + O(S^2) writes. FA reads each of Q, K, V from HBM once (O(S * d)) and writes O once (O(S * d)). Score blocks live in SRAM and never see HBM. Effective HBM bandwidth drops from O(S^2) per pass to O(S * d) per pass, an S/d-fold reduction.

For S = 8192, d = 128 per head, that is a 64x reduction in HBM traffic. Attention is bandwidth-bound on real GPUs, so this translates almost directly into ~5-10x wall-clock speedup at long context.

12.5 Backward pass

For the backward, you also recompute the attention block-by-block (you don't store the S x S matrix during the forward, so you can't read it back). The trick: store only m and l per row from the forward-they are O(S) total. In the backward, recompute scores tile-by-tile using the stored m and l, derive dQ, dK, dV. Total backward FLOPs are roughly 2-3x forward FLOPs, dominated by the recomputation. But memory stays O(S) vs O(S^2) for naive backward.


13. FlashAttention-2 deltas

Dao 2023. FA-1 was already great but had several inefficiencies that FA-2 addressed. FA-2 is roughly 2x faster than FA-1 on A100, getting close to GEMM efficiency.

The key ideas:

Better work partition. FA-1 parallelized over (batch * heads), with the seq-length loop being sequential within each (batch, head) thread block. For long sequences with few heads (large S, small batch * heads), GPU utilization was poor-too few thread blocks. FA-2 parallelizes over (batch * heads * seq_q), so the Q-tile outer loop is also parallel. This makes long-context attention scale to all SMs.

Reduced non-matmul work. GPU tensor cores execute matmul-shaped work at peak throughput; everything else (exp, max, divide) runs on much slower CUDA cores. FA-1 spent more time on these "non-matmul" operations than necessary, particularly on the per-block rescaling. FA-2 reorders the algorithm so the running statistics are updated less frequently and the final divide-by-l is deferred to a single pass at the end. The overall ratio of matmul to non-matmul FLOPs improves from ~70% to ~95%.

Causal masking optimization. The lower-triangular mask means the upper-right triangle of the attention matrix is zero. FA-2 skips entire K tiles that are guaranteed to be fully masked (where k_tile_start > q_tile_end), saving roughly half the work on causal attention.

Forward and backward both improved. The new partition and reduced non-matmul applies to both passes; backward gets ~2x as well.

Net: ~2x throughput vs FA-1, getting attention up to 50-70% of GPU theoretical peak, vs ~25-40% for FA-1.


14. FlashAttention-3 deltas (Hopper-targeted, 2024)

Shah, Bikshandi, Dao et al. 2024. Targets H100 specifically. Not just an algorithmic improvement; it leverages new hardware features:

TMA (Tensor Memory Accelerator). H100 has dedicated hardware for async bulk memory copies between HBM and SRAM. FA-3 uses TMA to overlap loading the next K/V tile with computing on the current one. This hides HBM latency in the inner loop.

Asynchronous WGMMA tensor cores. H100's WGMMA instruction issues matmul work to the tensor cores asynchronously-the warp can keep computing other things (softmax, normalization) while a previous matmul is still finishing. FA-3 schedules QK^T, softmax, and PV simultaneously on different warps of the same warp group. This is "warp specialization": some warps fetch, some compute matmul, some compute softmax, all running concurrently. The pipeline is fully filled.

FP8 support. H100 has FP8 tensor cores at ~2x the throughput of BF16. FA-3 supports FP8 for Q, K, V with online quantization scales. Critical for serving models like Llama-3 in FP8 at maximum throughput.

Net: ~2x over FA-2 on H100 for BF16, ~5x for FP8. FA-3 gets attention to 75% of peak BF16 theoretical and 80%+ of peak FP8 theoretical.

FA-3 is Hopper-only because the techniques rely on H100-specific hardware. On A100 you still want FA-2.


15. The decode-time variant: flash_attn_with_kvcache

Decode is a different beast from prefill. In decode:

  • Q has length 1 (one new token).
  • K and V are read from the existing KV-cache.
  • The new K and V need to be appended to the cache.
  • Causal masking is implicit: the new Q sees all cached K (which are all earlier positions by construction).

A naive decode would be three separate kernels: append K, append V, attention. Three HBM round-trips. The fused flash_attn_with_kvcache:

  1. Take new Q (B, H_q, 1, d_h), new K and V (B, H_kv, 1, d_h), the existing KV-cache, and a per-batch sequence-length array.
  2. Inside one kernel, write the new K, V into the cache at the correct slots (using the seq-len array).
  3. Run flash attention with Q against the now-updated K, V cache.
  4. Return output.

The win: only one HBM read of the cache, only one write of the new K/V slot, no separate launch overhead. On long contexts this is the difference between 200 and 350 tokens/sec/request on a single H100.

For paged KV layouts, there is a paged variant that takes a block table and gathers K/V from non-contiguous physical blocks. This is the kernel that production engines (vLLM, SGLang, TRT-LLM) actually call during decode.


16. Practical exercises

Six problems. Solve all six with pencil and paper, then in code.

Exercise 1: Derive the online softmax for 3 blocks

Given scores split into three blocks s = [s^(1), s^(2), s^(3)] with corresponding values V = [V^(1), V^(2), V^(3)]. Set initial state m = -inf, l = 0, O_ = 0. Step through the algorithm:

  1. After processing s^(1): write m_1, l_1, O_1 in terms of s^(1), V^(1).
  2. After processing s^(2): write m_2, l_2, O_2 using m_1, l_1, O_1 and s^(2), V^(2). Show the correction factor exp(m_1 - m_2).
  3. After processing s^(3): write m_3, l_3, O_3.
  4. Show that O_3 / l_3 equals the all-at-once softmax(s) @ V.

This exercise tests that you can keep the inductive invariant straight across multiple correction steps. The algebra is identical to the proof in Section 12.2.1 but explicit for T = 3.

Exercise 2: KV-cache size for various models

Compute the KV-cache size in bytes per request for the following configurations at seq_len = 4096, BF16 (2 bytes/element). Use the formula: bytes = 2 * num_layers * H_kv * seq_len * d_h * dtype_bytes.

Model         num_layers   H_q   H_kv   d_h   KV cache (MiB)
Llama-3-8B    32           32    8      128   ?
Llama-3-70B   80           64    8      128   ?
Mistral 7B    32           32    8      128   ?
GPT-3 (MHA)   96           96    96     128   ?

(Spoiler: 512, 1280, 512, and 18432 MiB respectively.) Notice the GPT-3 result: 18 GiB of KV per request at 4K context with MHA. This is why GQA is non-negotiable for serving.

Repeat at seq_len = 32768 and explain in one sentence why GPT-3-style MHA is infeasible at long context for serving.

Exercise 3: Implement causal-masked MHA in PyTorch

Implement the function

def my_attention(Q, K, V, causal=True):
    # Q, K, V: (B, H, S, d_h)
    # returns: (B, H, S, d_h)
    ...

matching torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=causal) to within 1e-5 in FP32 (in BF16 expect agreement only to roughly 1e-2).

Edge cases to test:

  • S = 1 (the decode-step shape).
  • S = 1 query against a length-N cache.
  • Different lengths for query and key (use the lower-triangular mask appropriately).

Exercise 4: Implement RoPE and verify the relative-position property

Implement apply_rope(x, positions) with d_h = 64 and base = 10000. Then numerically verify:

  1. For random q, k and various m, n, the value of apply_rope(q, m) . apply_rope(k, n) depends only on (m - n), not on m and n separately.
  2. Specifically, vary m, n simultaneously by the same shift and show the dot product is invariant to within float precision.

Exercise 5: GQA broadcasting

In GQA, K and V have shape (B, H_kv, S, d_h) but Q has shape (B, H_q, S, d_h). Implement the GQA attention so that each group of G = H_q / H_kv query heads shares one K/V head. Verify that GQA with H_kv = H_q reduces exactly to standard MHA.

Hint: one approach is to repeat-interleave K and V along the head axis by G; another is to reshape Q to (B, H_kv, G, S, d_h) and broadcast. The reshape approach saves memory.

Exercise 6: Decode-time KV growth and HBM bandwidth bound

Take Llama-3-70B (GQA, 80 layers, H_kv = 8, d_h = 128, BF16).

  1. Compute the KV-cache size at S = 0, 1024, 4096, 16384, 65536.
  2. Assuming HBM bandwidth = 3.35 TB/s, compute the minimum time per decode step due to reading the entire cache.
  3. Plot tokens/sec vs context length implied by this lower bound.
  4. Compare to actual measured decode throughput from a real engine (e.g., vLLM benchmarks). Where does the gap come from?

Expected answers (rough):

  • At S = 4096: KV is ~1.25 GiB; bandwidth-bound time per step is ~0.4 ms; that is a ceiling of ~2700 tok/s/request. Real systems hit 50-150 tok/s at this point because they are not running solo decode at full bandwidth-they batch multiple requests.
  • At S = 65536: KV is ~20 GiB; ceiling drops to ~170 tok/s/request.

The exercise drives home that decode tok/s drops linearly with context length, and that the dominant cost is HBM bandwidth on KV, not FLOPs.


Closing notes

What you should now be able to derive without notes:

  1. Why divide by sqrt(d_k) (variance argument, Section 2.2).
  2. The exact KV-cache memory formula for any decoder model (Section 10.3).
  3. Why GQA exists and what factor it saves (Section 5).
  4. RoPE's relative-position property by complex number rotation (Section 6.3).
  5. The online softmax algorithm and its proof of equivalence (Section 12.2).
  6. The asymptotic argument: standard attention is O(S^2 d) compute and O(S^2) memory; FlashAttention is O(S^2 d) compute and O(S * d) memory with O(S * d) effective HBM traffic.

What you should know from memory but cannot derive:

  • Concrete per-version deltas of FA-1 → FA-2 → FA-3 (Sections 13, 14).
  • Specific architecture choices of Llama-3 / Mistral (GQA group sizes, SwiGLU dim ratio of 8/3, RoPE base 10000 or extended bases).

Cross-reference the inference-side material in Month 5 of the AI Systems Plan: /home/voseghale/projects/self_dev/AI_SYSTEMS_PLAN/. KV-cache memory formulas, paged attention, and FlashAttention's wall-time implications are the single biggest driver of inference engineering decisions, and they all flow from the math in this chapter.

Deep Dive 08-LLM Inference Serving Systems

Paged attention, continuous batching, vLLM architecture, prefill/decode disaggregation, and the algebra that makes them necessary.

Self-contained. Reading the SOSP and OSDI papers afterward should feel like consolidation, not first contact.


0. Why this chapter exists

A trained LLM is, fundamentally, a function: given a prompt, produce a continuation. In a research notebook, this is one line. In production, it is a distributed system with throughput, latency, fairness, and memory-management problems that look more like an operating system than like deep learning.

This chapter rebuilds that system from first principles. We start by deriving-from arithmetic intensity alone-why serving an LLM is fundamentally a memory-bandwidth problem, why batching is the universal lever, and why the naive batching strategies fail. From there the architecture of vLLM (paged attention + continuous batching) is forced rather than chosen, and so are the more recent designs (chunked prefill, prefix caching, prefill/decode disaggregation).

A reader who finishes this chapter should be able to:

  1. Predict, within a factor of two, the throughput of a given model on a given GPU at a given batch size.
  2. Explain why doubling batch size at decode is nearly free in compute but expensive in KV memory.
  3. Sketch a paged-attention kernel and a continuous-batching scheduler from memory.
  4. Tune vLLM (or any descendant) without poking at flags blindly.
  5. Recognize when the right answer is "split prefill and decode onto different machines."

We will derive every claim. References to SOSP'23 (vLLM / PagedAttention, Kwon et al.), OSDI'22 (Orca, Yu et al.), OSDI'24 (DistServe, Zhong et al.), and the Sarathi-Serve work (Agrawal et al.) are pointers, not load-bearing-the math here stands alone.


1. The two phases of LLM inference

A decoder-only Transformer (GPT-style) has two operational regimes that look like the same forward pass but have very different performance characteristics. Understanding this is the foundation of everything else.

1.1 Prefill

When a request arrives with prompt of length P, the model must compute hidden states (and KV-cache entries) for all P positions before any new token can be sampled. Crucially, all P positions can be processed in parallel in a single forward pass:

  • The attention mask is causal but the matmuls are still over P × d and P × P matrices.
  • For each layer, you do roughly O(P · d²) for the QKV projection, O(P² · d) for attention scores, O(P · d²) for the FFN.
  • Per-token work is large; per-request work is P × per-token work.

Prefill is compute-bound on modern hardware. The GPU can saturate its tensor cores because the matmuls are tall (lots of rows)-there is plenty of arithmetic per byte of weight read.

1.2 Decode

After prefill, generation is autoregressive: produce token P+1, append to context, produce P+2, etc. Each step does a forward pass over a single new token (the previously sampled one), reusing the KV-cache for all earlier positions.

Per-step:

  • QKV projection on a 1 × d vector: O(d²) FLOPs, but reads the full weight matrix parameters.
  • Attention: query is 1 × d, keys/values are S × d_kv (where S is the current context length). Compute O(S · d), memory read O(S · d_kv)-proportional to context length.
  • FFN: 1 × d against d × 4d and 4d × d weights: O(d²) FLOPs, O(d²) memory.

Decode is memory-bandwidth-bound. We are reading huge weight tensors from HBM and doing a tiny dot-product against a single token vector.

1.3 The cost model (the equation that runs the chapter)

Let us define, for a model:

  • W = total parameter footprint in bytes (dtype already accounted for).
  • K(s) = KV-cache size in bytes for a sequence of length s.
  • B_HBM = HBM bandwidth, e.g., 3.35 TB/s for an H100 SXM.
  • F_peak = peak tensor-core throughput, e.g., ~990 TFLOP/s BF16 dense on H100.
  • b = batch size (number of sequences in this iteration).

Decode step time, when memory-bound (we will validate this assumption shortly), is:

$$ T_\text{decode-step}(b) \;\approx\; \frac{W + b \cdot K(s)}{B_\text{HBM}} $$

Read this carefully. The model weights W are read once per step-every sequence in the batch sees the same weights. But each sequence has its own KV-cache, and those do scale with b.

When compute-bound (prefill, or very large batch decode), the time becomes (b · per-token-FLOPs) / F_peak instead. The crossover batch size between regimes is what determines whether batching helps.

1.4 A worked example: Llama-3-70B on H100

  • Parameters: 70B × 2 bytes (BF16) = 140 GB. Doesn't fit on one H100 (80 GB)-assume 2× H100 with tensor parallelism, so each GPU holds ~70 GB.
  • KV-cache per token (Llama-3-70B): 80 layers × 8 KV heads × 128 head dim × 2 (K and V) × 2 bytes = 327,680 bytes/token ≈ 320 KB/token.
  • HBM bandwidth per GPU: 3.35 TB/s.
  • For a sequence at length 2048: KV size ≈ 2048 × 320 KB ≈ 640 MB.

At batch=1, decode step time ≈ (70 GB + 0.64 GB) / 3.35 TB/s ≈ 21 ms per token (each GPU streams its 70 GB shard in parallel). (Real implementations achieve roughly this.)

At batch=32, weights are still 70 GB but KV totals 32 × 0.64 GB = 20.5 GB. Step time ≈ (70 + 20.5) / 3.35 ≈ 27 ms. We've gone from 1 token / 21 ms to 32 tokens / 27 ms-almost a 24× throughput improvement for a 30% latency hit. That is why batching is the central lever.

Hold this picture: weight bytes amortize across the batch, KV bytes do not.
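
The two-term cost model as code (a sketch using the H100 / Llama-3-70B numbers above):

def decode_step_ms(weight_gb, kv_gb_per_seq, batch, bw_tbps=3.35):
    # weights are read once per step; each sequence reads its own KV
    return (weight_gb + batch * kv_gb_per_seq) / bw_tbps   # GB / (TB/s) = ms

for b in (1, 32):
    t = decode_step_ms(70, 0.64, b)
    print(f"batch={b}: {t:.1f} ms/step, {b / t * 1000:.0f} tok/s")
# batch=1:  21.1 ms/step, 47 tok/s
# batch=32: 27.0 ms/step, 1185 tok/s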


2. Why decode is memory-bound (rigorous derivation)

Every performance optimization in this chapter falls out of a single number: arithmetic intensity, the ratio of FLOPs done per byte read.

2.1 The roofline crossover

GPUs have a roofline: at low arithmetic intensity, you are bandwidth-limited (achieved_FLOPs = intensity × B_HBM). At high intensity, you are compute-limited (achieved_FLOPs = F_peak). The crossover happens at:

$$ I_\text{crossover} = \frac{F_\text{peak}}{B_\text{HBM}} $$

For H100 SXM BF16: 990 TFLOP/s ÷ 3.35 TB/s ≈ 295 FLOP/byte. (Numbers vary; 280–310 is the right ballpark.)

Below ~295 FLOP/byte, you are memory-bound. Above, compute-bound.

2.2 Decode arithmetic intensity at batch=1

Take the FFN block-typically the dominant compute. Llama-3-70B FFN is d=8192, intermediate d_ff ≈ 28672 (SwiGLU). Per token:

  • FLOPs: 2 · d · d_ff · 3 (the 3 is for SwiGLU's three matmuls) ≈ 2 · 8192 · 28672 · 3 ≈ 1.4 GFLOP.
  • Bytes read (weights): d · d_ff · 3 · 2 (BF16) ≈ 1.4 GB of weights for this one block per layer.

Intensity = 1.4 GFLOP / 1.4 GB = 1 FLOP/byte. That is 295× below the crossover. We are deep in the memory-bound regime.

Adding batch b keeps the bytes constant (still one weight read per step) and multiplies the FLOPs by b. So intensity scales linearly with batch:

$$ I_\text{decode}(b) \approx b \cdot 1\,\text{FLOP/byte} $$

We need b ≈ 295 to hit the compute roof. In practice nobody runs decode at batch=295 because (a) KV memory runs out long before that, and (b) you hit other limits-kernel launch overhead, attention scaling, etc. But the direction is right: bigger batch ⇒ closer to compute-bound ⇒ better hardware utilization.

2.3 Attention is the wrinkle

The FFN is straightforwardly batch-friendly because all tokens share weights. Attention is not. Each sequence reads its own KV-cache. So in the attention block:

  • FLOPs: O(b · S · d) (for each of the b sequences, dot-product query against S KV entries).
  • Bytes: O(b · S · d_kv) (each sequence loads its own KV).

Intensity is independent of batch size for attention: ≈ d / d_kv, which is roughly the head-replication factor in GQA. For Llama-3-70B with 64 query heads and 8 KV heads, that's 8 FLOP/byte. Still way below 295.

This is why long contexts hurt so much-attention's bandwidth cost grows with sequence length but its arithmetic intensity does not. Even with infinite batching, attention stays memory-bound. (Flash-attention helps with on-chip SRAM tiling but cannot change the off-chip KV reads.)

2.4 Takeaway equation

$$ \boxed{\;T_\text{step}(b, S) \approx \underbrace{\frac{W}{B_\text{HBM}}}_{\text{weight read, amortized}} + \underbrace{\frac{b \cdot K_\text{per-token} \cdot S}{B_\text{HBM}}}_{\text{KV read, scales with } b \cdot S}\;} $$

Two terms: weights (great news for batching) and KV (bad news for batching). The whole architecture of vLLM is about managing the second term as aggressively as possible.


3. The big idea: batching saves bandwidth

Let's stare at the equation again with the question: "what does batching buy?"

  • Per-token throughput at batch b: b / T_step(b, S).
  • At small batch, T_step ≈ W/B_HBM (weights dominate). Throughput grows linearly with b.
  • At large batch, T_step ≈ b · K_per-token · S / B_HBM (KV dominates). Throughput plateaus.

The plateau happens when the KV term equals the weight term:

$$ b^* \approx \frac{W}{K_\text{per-token} \cdot S} $$

For Llama-3-70B at 2K context, b* ≈ 70 GB / 640 MB ≈ 110. In principle we want batch ~100. In practice we are limited by total KV memory.

KV budget: each GPU has 80 GB. Weights take 70 GB. That leaves ~10 GB for KV-cache. At 320 KB/token, that is 30K tokens total across the batch. At 2K context per sequence, that's at most 15 concurrent sequences.

So the physics says we want batch=100, but the capacity says batch=15. The gap between these is exactly the design pressure that produced paged attention. Every byte of KV memory wasted is a sequence we can't serve.
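
The same arithmetic as a few lines of Python (constants from the example above):

W_gb, kv_mb_per_tok, S = 70, 0.32, 2048

b_star = W_gb * 1024 / (kv_mb_per_tok * S)        # KV term == weight term
print(round(b_star))                              # ~110: what the physics wants

kv_budget_gb = 80 - 70                            # HBM minus weights
total_tokens = kv_budget_gb * 1024 / kv_mb_per_tok
print(int(total_tokens // S))                     # 15: what capacity allows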

3.1 Why weight-amortization scales the way it does

Subtle but important: the weight read is once per iteration, not once per request. So as long as you can pack b sequences into one forward pass, the weight bandwidth cost is fixed. This is what makes "continuous batching" so powerful-every iteration we re-fill the batch up to capacity, so the GPU never sees a half-empty batch.


4. KV-cache management-the central problem

A sequence of length S needs S · K_per-token bytes of KV-cache. Two facts make this hard:

  1. S is unknown at admission time. Generation continues until an EOS token or max_tokens. We may stop after 5 tokens or 5000.
  2. S grows during the sequence's lifetime. The KV-cache for a sequence is mutated in place at every decode step.

4.1 The naive approach: contiguous per-request buffer

For each request, allocate a contiguous buffer of size max_seq_len · K_per-token upfront. Fill it as tokens arrive.

Problems:

  • Internal fragmentation. A 100-token output in a 4096-token buffer wastes 97% of the allocation.
  • External fragmentation. Buffers of varying sizes leave holes the allocator can't fill.
  • No prefix sharing. Two requests with the same system prompt store identical KV bytes twice.
  • Kills batch size. With 10 GB of KV budget and 4K-token reservations, you fit 10 GB / (4096 · 320 KB) ≈ 8 sequences.

Numerically: in a 70B-on-80GB setup, naive contiguous allocation caps batch at ~8. The physics said 100. We are leaving an order of magnitude of throughput on the floor.

4.2 What we want from a KV allocator

  • Fine-grained allocation-give a sequence one chunk at a time, as it grows.
  • No fragmentation-chunks are uniform size, so any free chunk fits any need.
  • Cheap relocation-but actually, no relocation: chunks are virtually addressed.
  • Sharing-distinct sequences with identical prefix-content share the same physical chunks.
  • Cheap GC-reference counting is enough.

This is, almost word for word, the description of a virtual-memory system. Which is exactly the analogy Kwon et al. drew.


5. Paged attention (the heart of vLLM)

PagedAttention treats KV-cache like an OS treats process memory: it is virtually contiguous but physically paged.

5.1 The block pool

HBM (minus weights and activations) is carved up into uniform physical blocks. Each block holds the KV-cache for B_block consecutive token positions in one sequence:

block_size_bytes = B_block * num_kv_heads * head_dim * 2 (K and V) * dtype_bytes

For Llama-3-70B with B_block = 16, BF16: 16 · 8 · 128 · 2 · 2 = 65,536 bytes = 64 KB per block per layer. Across 80 layers: 5.12 MB per block (this is the per-sequence cost of one block of KV across the whole stack). With 10 GB of KV budget that's ~2000 blocks total-enough for 2000 · 16 = 32K tokens distributed however we like.

B_block = 16 is the typical choice-large enough to amortize indirection overhead, small enough to keep wasted memory under one-block-per-sequence.

The block pool is just an array P[0..N-1] of these fixed-size buffers, plus a free-list F of unused indices.

5.2 The per-request page table

Each in-flight request has a block table-a list of physical block indices, in logical order:

request 7's KV at logical position 145
  → block 145 // 16 = 9 in the table
  → physical block index = block_table[7][9]
  → byte offset within block = (145 % 16) * per_token_kv_bytes

This is byte-for-byte the same as a page table mapping virtual addresses to physical frames. The block table is small (one int per 16 tokens-8K context = 512 entries = 2 KB).
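
The translation step as code (a two-line sketch; per_token_bytes is illustrative):

def kv_location(block_table, pos, block_size=16, per_token_bytes=512):
    # logical token position -> (physical block id, byte offset within block)
    return block_table[pos // block_size], (pos % block_size) * per_token_bytes

table = [42, 17, 99, 23, 7, 55, 61, 13, 88, 31]   # illustrative physical ids
print(kv_location(table, 145))                    # entry 145 // 16 = 9 -> (31, 512)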

5.3 The block manager

The block manager owns the free list and exposes:

class BlockManager:
    def can_allocate(self, request) -> bool:
        # check if free_blocks >= ceil(request.context_len / block_size)
        needed = -(-request.context_len // self.block_size)  # ceil division
        return len(self.free_list) >= needed

    def allocate(self, request):
        # pop blocks from free_list, build initial block_table
        needed = -(-request.context_len // self.block_size)
        request.block_table = [self.free_list.pop() for _ in range(needed)]
        for block in request.block_table:
            self.refcount[block] = 1

    def append_slot(self, request):
        # decode added a token; if last block is full, allocate a new one
        if request.last_block_full():
            block = self.free_list.pop()
            request.block_table.append(block)
            self.refcount[block] = 1

    def free(self, request):
        for block in request.block_table:
            self.refcount[block] -= 1
            if self.refcount[block] == 0:
                self.free_list.push(block)

    def fork(self, parent_request, child_request):
        # for prefix sharing or beam search:
        # child reuses parent's blocks; bump refcount
        for block in parent_request.block_table:
            self.refcount[block] += 1
        child_request.block_table = list(parent_request.block_table)

Reference counting is doing the heavy lifting for prefix sharing-see §10.

5.4 The paged attention kernel

The naive attention kernel reads K and V from contiguous tensors. The paged kernel takes the block table as an extra input and does an indirect read:

for each query token q in this iteration:
    for each block_idx in block_table[q.request_id]:
        load K_block, V_block from P[block_idx]   # gather from non-contiguous memory
        accumulate q · K_block^T into scores
        fold exp(scores) · V_block into a running output   # online softmax, FlashAttention-style
    normalize the running output by the running softmax sum

The gather is the cost of indirection. Measured overhead in the original paper: ~10% slowdown vs a contiguous-tensor kernel. This is the price of fragmentation-freedom, and it's a steal: you typically triple your effective batch size, so you net 2.5× throughput after eating the 10% kernel overhead.

Implementation tricks that keep the cost down:

  • Block size is a compile-time constant (e.g., 16) so loops unroll.
  • Within a block, KV is contiguous → vectorized load.
  • Block table is fetched once per query, kept in registers/shared memory.
  • Flash-attention-style on-chip tiling still applies; the page table just controls where the K/V tiles come from.

5.5 Memory waste analysis

The only wasted memory is in the last partial block of each sequence. In the worst case, a sequence of length 17 with block_size=16 occupies 2 blocks but uses only 17/32 of them. Across b sequences, total waste is at most b · block_size · K_per-token bytes-bounded, small.

Compare to naive: waste is b · (max_seq_len - actual_len) · K_per-token, unbounded relative to actual usage. The difference is 1-2 orders of magnitude in real workloads.

5.6 Diagram: block table

Logical KV view (request 7, 50 tokens so far, block_size=16):

  positions: [0..15]   [16..31]  [32..47]  [48..49]
  block #:     0          1         2         3 (partial)

block_table[7] = [42, 17, 99, 23]
                  │   │   │   │
                  ▼   ▼   ▼   ▼
Physical block pool:
  [ ... | block 17: tokens 16-31 | ... | block 23: tokens 48-49 (partial) | ... | block 42: tokens 0-15 | ... | block 99: tokens 32-47 | ... ]

Physically scattered, logically contiguous. Just like virtual memory.

5.7 Exercise (do it now, mentally)

How many blocks does an 8K-context sequence need at block_size=16? 512. How much KV memory is that for Llama-3-70B BF16? 512 · 80 layers · 64 KB/layer/block = 2.6 GB. Sanity-check against §1.4: 8192 · 320 KB/token = 2.6 GB. Match.


6. Continuous batching (Orca's iteration-level scheduling)

Paged attention solves the spatial problem (memory). Continuous batching solves the temporal problem (when do we run which sequence?).

6.1 The naive (request-level) batching strategy

The dumb scheduler:

batch = first N requests in queue
run them all in parallel until ALL are done
return results

This wastes massive amounts of GPU time. Suppose 8 requests batch together, 7 finish at token 50 but one runs to token 2000. For 1950 decode steps, the GPU runs a batch of one. We just gave back 7/8 of our throughput.

6.2 The Orca insight: schedule per iteration, not per request

What if the scheduler runs between every decode step?

loop forever:
    1. select a batch from runnable requests, subject to memory + compute budget
    2. run one forward pass over the batch
    3. for each sequence in batch:
         - sample next token
         - if EOS or max_tokens: mark finished, free its blocks
         - else: append token, request stays runnable
    4. admit new requests if there's spare capacity

Finished requests leave immediately. Newly-arrived requests join immediately. The batch is continuously refilled-hence "continuous batching."

6.3 The scheduler in pseudocode

def schedule_iteration(scheduler):
    # 1. Take all currently-running requests (those with prefilled KV)
    running = scheduler.running_queue

    # 2. Drop any that finished last step
    running = [r for r in running if not r.finished]

    # 3. Try to grow the batch by admitting waiting requests
    waiting = scheduler.waiting_queue
    while waiting and scheduler.can_admit(waiting[0]):
        req = waiting.popleft()
        scheduler.block_manager.allocate(req)  # for prefill
        running.append(req)

    # 4. Each running req consumes one decode step (or one prefill chunk - see §8)
    for r in running:
        scheduler.block_manager.append_slot(r)  # may allocate a block

    # 5. If memory pressure, evict some running reqs (preempt - see §9)
    while scheduler.over_budget():
        victim = scheduler.choose_victim()
        scheduler.preempt(victim)
        running.remove(victim)
        scheduler.waiting_queue.appendleft(victim)  # priority for re-admission

    return running

The interesting part: step 3 (admission) and step 5 (preemption) are policy. FCFS by default. You can plug in priority, fairness, deadline-aware, etc.-Orca's contribution was the mechanism; the policy is yours.

6.4 Why this is so much better

  • No long-tail blocking. A 5-token request and a 5000-token request that arrived together don't share fate.
  • Constant-batch GPU utilization. As long as the queue is non-empty, the scheduler keeps the running batch near capacity.
  • Bounded latency overhead. The scheduler runs in microseconds; the forward pass takes tens of milliseconds. Scheduler overhead is <1%.

In production, continuous batching alone (without paged attention) typically gives 2-4× throughput vs request-level batching. Combined with paged attention: 5-10×.

6.5 The "iteration" subtlety

In Orca's paper an "iteration" is a single forward pass that processes one new token per running sequence, plus prefill on any newly-admitted sequences. In modern vLLM, prefill and decode mix more flexibly-see §8.


7. Putting it together: vLLM architecture

We now have all the parts. Here is the system, in dependency order.

7.1 Components

                      ┌─────────────────┐
   HTTP/OpenAI API ─→ │   API Server    │
                      └────────┬────────┘
                               │ enqueue request
                      ┌─────────────────┐
                      │     Engine      │  (main loop, blocking)
                      │  ┌───────────┐  │
                      │  │ Scheduler │  │  picks batch each iter
                      │  └─────┬─────┘  │
                      │        │        │
                      │  ┌─────▼─────┐  │
                      │  │  Block    │  │  paged-KV alloc
                      │  │  Manager  │  │
                      │  └─────┬─────┘  │
                      │        │        │
                      │  ┌─────▼─────┐  │
                      │  │   Model   │  │  forward pass
                      │  │ Executor  │  │  (paged attn kernel)
                      │  └─────┬─────┘  │
                      │        │        │
                      │  ┌─────▼─────┐  │
                      │  │  Sampler  │  │  next token per seq
                      │  └─────┬─────┘  │
                      └────────┼────────┘
                               │ output tokens
                      ┌─────────────────┐
                      │ Output Processor│  detokenize, stream
                      └─────────────────┘

API server. HTTP front-end, often OpenAI-compatible. Translates requests into the engine's internal format, exposes streaming responses.

Engine. The orchestrator. Owns the scheduler, block manager, model executor, sampler, output processor. Single-threaded main loop.

Scheduler. Decides which requests run this iteration. Maintains running and waiting queues. Enforces memory budget, scheduling policy, max-batch-size, max-batched-tokens.

Block manager. Owns the physical block pool, free list, per-request block tables, refcounts. Exposes allocate, append_slot, free, fork, swap_out, swap_in.

Model executor. Runs the forward pass on whatever batch the scheduler hands it. Internally calls paged-attention kernels with block tables. May span multiple GPUs (tensor or pipeline parallel).

Sampler. Given final-layer logits, samples the next token per sequence. Supports temperature, top-k, top-p, repetition penalties, beam, etc.

Output processor. Detokenizes, accumulates output, streams partial completions back via the API server.

7.2 The engine main loop

class Engine:
    def __init__(self, model, scheduler, executor, sampler, out_proc):
        self.model = model
        self.scheduler = scheduler
        self.executor = executor
        self.sampler = sampler
        self.out_proc = out_proc

    def step(self):
        # 1. Build the next batch
        batch = self.scheduler.schedule_iteration()
        if not batch:
            return  # nothing to do, idle

        # 2. Forward pass: prefill chunks + decode steps mixed
        logits = self.executor.forward(
            batch,
            block_tables=self.scheduler.block_manager.tables_for(batch),
        )

        # 3. Sample next tokens
        next_tokens = self.sampler.sample(logits, batch.sampling_params)

        # 4. For each request, append token, check stop conditions
        for req, tok in zip(batch, next_tokens):
            req.append_token(tok)
            if self.is_done(req, tok):
                req.finish()
                self.scheduler.block_manager.free(req)
            self.out_proc.emit(req, tok)

    def run(self):
        while True:
            self.handle_new_requests()
            self.step()

The whole system is a tight loop around step(). Iteration time is typically 10-50 ms. The scheduler runs at the same cadence-a key reason latency is bounded.

7.3 Where everything plugs in

  • Tensor parallelism lives inside executor.forward - each GPU runs a shard of the layers, NCCL all-reduces between shards. Transparent to the scheduler.
  • Pipeline parallelism turns the loop into a pipelined sequence of micro-batches. More involved.
  • Quantization (INT8, INT4, FP8) changes weight bytes but not the architecture.
  • LoRA / multi-tenant adapters plug into the executor: it composes the base weights with the active adapter at forward time.

8. Mixing prefill and decode in continuous batching

A subtle but very important point. Prefill and decode have different compute profiles:

  • A prefill of 1000 tokens: ~1000× the FLOPs of a single decode step on the same model.
  • A decode step: tiny FLOPs but reads all weights.

If you run a "pure prefill" iteration, decode requests stall (TPOT spikes). If you run a "pure decode" iteration, prefill requests stall (TTFT spikes). Naively alternating wastes hardware.

8.1 Mixed batches

Modern engines run mixed iterations: a single forward pass that includes prefill tokens for newly-admitted sequences and decode tokens for already-running ones. The kernel is careful about the per-token positions and the per-sequence KV layouts, but the high-level idea is:

  • The batch has N_decode sequences each contributing 1 token, plus N_prefill chunks contributing many tokens.
  • Total tokens in this iteration: T = N_decode + sum(prefill_chunk_lengths).
  • Forward pass is one giant matmul over T tokens.

8.2 Chunked prefill (Sarathi-Serve)

Problem: a 4000-token prefill in one chunk takes much longer than a 1-token decode. Mixing them naively means the iteration is gated by the prefill length-TPOT spikes for everyone.

Sarathi-Serve's fix: bound each iteration's total token budget. Split long prefills into chunks of, say, 512 tokens. Each iteration processes min(remaining_prefill, chunk_size) prefill tokens for that request, alongside 1 decode token per running request. Total token count per iteration is roughly constant ⇒ wall-clock per iteration is roughly constant ⇒ TPOT is bounded.

This is a knob: max_num_batched_tokens in vLLM. Setting it to, say, 2048 says "every iteration's prefill+decode work, in token units, sums to at most 2048." Lowers TPOT variance dramatically; slightly hurts throughput on pure-prefill workloads.

8.3 Pseudocode for chunked prefill

def build_iteration_batch(scheduler, max_batched_tokens):
    batch = []
    budget = max_batched_tokens

    # 1. Decode steps for all running sequences (1 token each)
    for r in scheduler.running:
        batch.append(DecodeStep(r))
        budget -= 1

    # 2. Fill the rest with prefill chunks
    for r in scheduler.waiting + scheduler.partially_prefilled:
        if budget <= 0: break
        chunk_len = min(r.remaining_prefill, budget)
        if chunk_len < MIN_CHUNK and budget < r.remaining_prefill:
            continue  # don't bother with tiny chunks unless we'll finish
        batch.append(PrefillChunk(r, chunk_len))
        budget -= chunk_len

    return batch

The scheduler is balancing two competing pressures: keep TPOT low (cap the iteration size) and keep TTFT low (admit prefills aggressively). Both pressures are knobs the operator sets per-deployment.

8.4 Why mixed batches are bandwidth-good

The weight read is still once per iteration. Whether the iteration has 32 decode tokens or 32 prefill tokens or 16+16, the weight bytes are constant. So mixing prefill into spare decode capacity is free in bandwidth. It's only paid in extra FLOPs-which we have plenty of when memory-bound.

This is why chunked prefill is such a clean throughput win: it converts "spare bandwidth" (idle decode iteration) into "useful prefill progress."
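A quick numeric sanity check of this claim, using the same hypothetical Llama-3-70B / 2× H100 figures as the exercises in §15 (35 GB of weights per GPU, 1.675 TB/s effective bandwidth, 320 MB of KV per 2K-context sequence post-TP). The sketch deliberately ignores the prefill chunk's own (small) KV writes:

# Iteration time in the memory-bound regime is set by bytes moved,
# not by how the token budget splits between prefill and decode.
W_BYTES = 35e9        # weight bytes read once per iteration, per GPU
BW = 1.675e12         # effective HBM bandwidth, bytes/s
KV_PER_SEQ = 320e6    # KV read per decode sequence (2K context, post-TP)

def iter_time(n_decode_seqs):
    return (W_BYTES + n_decode_seqs * KV_PER_SEQ) / BW

# 32 decode sequences, with or without a 512-token prefill chunk riding
# along, move (approximately) the same bytes: ~27 ms either way.
print(iter_time(32) * 1e3)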


9. Eviction and preemption

What happens when a new admission would push KV-cache over budget, and we have no spare blocks? The scheduler evicts a victim sequence.

9.1 Two strategies

Swap (CPU offload). Move the victim's KV-cache blocks from HBM to CPU pinned memory. When the sequence is rescheduled, swap them back. Cost: PCIe bandwidth one-way. For Llama-3-70B, swapping a 2K-context sequence is 2K · 320 KB = 640 MB; at ~25 GB/s of PCIe bandwidth that's ≈ 25 ms one-way, 50 ms round-trip. Real, but bounded.

Recompute. Drop the KV-cache. When the sequence is rescheduled, re-run prefill on its prompt + already-generated tokens. Cost: a full prefill, which can be cheap (chunked prefill is fast) for short contexts but expensive for long ones.

9.2 Choosing the right strategy

  • Swap wins when (a) PCIe bandwidth is plentiful and (b) prefill is expensive (long context).
  • Recompute wins when (a) the context so far is short (small KV, so the re-run prefill is cheap) and (b) you don't want to manage a CPU-side swap buffer.

vLLM exposes swap_space (GiB of CPU RAM reserved)-set to nonzero to enable swap. Otherwise it recomputes.

9.3 Victim selection

Default policy: evict the most recently admitted (LIFO)-protects long-running sequences that have already invested compute. Other policies: lowest priority, longest-remaining, fairness-based.

def choose_victim(running):
    # vLLM default: most recently scheduled
    return running[-1]

The choice matters for tail latency but not for throughput.
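One of the alternative policies mentioned above, sketched with hypothetical request attributes (priority and admission_index are illustrative names, not vLLM fields):

def choose_victim_priority(running):
    # Evict the lowest-priority request; among equals, the most
    # recently admitted (falling back to the LIFO default).
    return min(running, key=lambda r: (r.priority, -r.admission_index))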

9.4 Preemption pseudocode

def preempt(req, mode):
    if mode == "swap":
        for block in req.block_table:
            cpu_block = self.swap_pool.allocate()
            copy_hbm_to_cpu(block, cpu_block)
            req.swapped_blocks.append(cpu_block)
        self.block_manager.free(req)  # return HBM blocks
        req.state = "swapped"
    elif mode == "recompute":
        self.block_manager.free(req)
        req.state = "waiting"   # will re-prefill on resume
        req.kv_dropped = True

Preemption is rare in well-tuned systems. If it happens often, you've over-admitted; lower max_num_seqs.


10. Prefix caching

Most production LLM workloads share prefixes:

  • System prompts: every chat call from an app starts with the same 500-token system message.
  • Multi-turn conversations: turn N+1 has all of turn N's prefix.
  • RAG: long retrieved-document prompts repeat across queries about the same doc.
  • Few-shot prompts: shared examples across requests.

If two sequences share a prefix of length L, they have identical KV-cache for those L positions (because attention is causal-KV at position i only depends on tokens 0..i). Storing them twice is waste.

10.1 The mechanism

Block-level content hashing. When a prefill produces a full block (16 tokens):

block_hash = hash(parent_block_hash, tokens[block_start:block_end])

This is a cumulative hash-the block at position k*16 depends on the hashes of all earlier blocks. So two requests with an identical first-L-token prefix produce identical block hashes for all blocks fully inside that prefix.

The block manager keeps a hash → block_idx map. On a hit:

def maybe_share_block(req, block_idx_in_seq, block_hash):
    if block_hash in self.hash_table:
        existing_block = self.hash_table[block_hash]
        self.refcount[existing_block] += 1
        req.block_table[block_idx_in_seq] = existing_block
        return True   # reused
    return False

Reference counting handles cleanup: when a block's refcount drops to zero (all referring sequences finished), it goes back on the free list.
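A minimal sketch of the cumulative hashing scheme. Python's built-in hash stands in for whatever stable hash a real engine uses; block_hashes is an illustrative helper, not vLLM's API:

def block_hashes(token_ids, block_size=16):
    """Cumulative content hash for every full block of a token sequence."""
    hashes, parent = [], None
    for start in range(0, len(token_ids) - block_size + 1, block_size):
        block = tuple(token_ids[start:start + block_size])
        parent = hash((parent, block))   # chains in all earlier blocks
        hashes.append(parent)
    return hashes

# Two prompts sharing a 32-token prefix get identical hashes for the
# first two blocks-those blocks can be shared-and diverge afterward.
a = block_hashes(list(range(48)))
b = block_hashes(list(range(32)) + [99] * 16)
assert a[:2] == b[:2] and a[2] != b[2]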

10.2 Copy-on-write

Subtle: a shared block is read-only-the sampled tokens diverge after the prefix. But the prefix blocks are never written to during decode (KV at those positions is fixed). So no CoW is needed for prefix sharing per se.

For beam search or speculative decoding, where two children of one parent diverge mid-block, CoW is needed: when child writes, allocate a new block, copy the parent's prefix portion, then write. Same block manager primitive-just a different caller.

10.3 Hit rates in production

This is one of those features where the empirical numbers are larger than you'd expect. In real chat workloads (multiple tenants, shared system prompts, multi-turn):

  • System prompts shared across all users: hit rates above 50% on prefix tokens are routine.
  • Multi-turn conversations: each turn re-uses N-1 turns' worth of KV. Effective compute reduction at decode is small (KV is already there) but prefill cost goes to near zero-you only prefill the new user message.
  • Mass RAG over a fixed corpus: prompts share the retrieved-doc tokens, hit rates can exceed 70%.

Because prefill is the expensive part, prefix caching often dominates the throughput win for chat-style workloads. Enable it (enable_prefix_caching=True) by default.

10.4 Eviction policy for prefix blocks

A prefix block with refcount=0 doesn't have to be freed-it might be useful later. vLLM (and similar systems) keep evicted-but-cached blocks on a separate "evictor" list (LRU). When new allocation needs space, evict from there first. This way the prefix cache is bounded by KV memory, not by active sequences.
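A minimal sketch of such an evictor, assuming an OrderedDict as the LRU structure (illustrative, not vLLM's actual class):

from collections import OrderedDict

class PrefixEvictor:
    """Holds refcount-0 blocks in LRU order instead of freeing them."""
    def __init__(self):
        self.cached = OrderedDict()              # block_hash -> block_idx

    def add(self, block_hash, block_idx):        # refcount just hit 0
        self.cached[block_hash] = block_idx
        self.cached.move_to_end(block_hash)      # mark most-recently-used

    def reuse(self, block_hash):                 # prefix-cache hit
        return self.cached.pop(block_hash, None)

    def evict_one(self):                         # allocator needs a block
        _, block_idx = self.cached.popitem(last=False)   # oldest first
        return block_idx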


11. Speculative decoding (preview-full treatment in deep dive 10)

Decoding one token per step is wasteful when the model is memory-bound: we read all weights to do a tiny matmul. Idea: have a small "draft" model produce K candidate tokens cheaply, then verify them in one big-model forward pass.

draft_tokens  = draft_model.generate(K)             # cheap: K small-model steps
target_logits = target_model.forward(draft_tokens)  # ONE big-model step
# accept tokens that match the target's argmax (or pass rejection sampling);
# keep the prefix of accepted tokens, resume from the first reject

Why it works: the big model's forward pass over K tokens is roughly the same wall-clock as over 1 token (memory-bound-weights dominate). So we get up to K tokens per big-model step, at the cost of one cheap draft pass and one big verification pass.

11.1 Integration with paged attention

The verification pass is a "K-token decode"-we extend each sequence's KV-cache by up to K positions in one go. Block manager must allocate up to K new slots upfront (or fewer if blocks have room). After verification, if only k < K are accepted, the manager must truncate the block table and reset positions for the rejected suffix.

def speculative_step(req, draft_tokens, target_logits, block_manager):
    K = len(draft_tokens)
    prev_len = req.num_tokens              # KV length before this step

    # Append slots for up to K candidate tokens
    for _ in range(K):
        block_manager.append_slot(req)

    # Verify: accept the longest prefix the target model agrees with
    accepted = []
    for draft, target in zip(draft_tokens, target_logits):
        if accept(draft, target):          # argmax match or rejection-sampling test
            accepted.append(draft)
        else:
            break

    # Truncate KV to len(accepted) + 1 (the corrected token from the target)
    block_manager.truncate(req, prev_len + len(accepted) + 1)
    return accepted + [target_argmax_at_first_reject]

Realistic acceptance rates with a good draft model: 60-80%, giving ~2-3× decode throughput.
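Where that 2-3× comes from: if each draft token is accepted independently with probability β, the expected number of tokens emitted per big-model step is the standard geometric-series result from the speculative-decoding literature, (1 - β^(K+1)) / (1 - β). A quick check (ignoring the draft model's own cost):

def expected_tokens_per_step(beta, K):
    # Expected accepted-prefix length + 1 corrected/bonus token.
    return (1 - beta ** (K + 1)) / (1 - beta)

print(expected_tokens_per_step(0.6, 4))   # ≈ 2.31
print(expected_tokens_per_step(0.8, 4))   # ≈ 3.36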

This is a preview; deep dive 10 derives the rejection sampler, draft-model selection, and tree-based variants (Medusa, EAGLE).


12. Prefill/decode disaggregation

The most recent (and perhaps most consequential) shift in inference architecture: stop running prefill and decode on the same workers.

12.1 The problem with co-location

Recall §1: prefill is compute-bound, decode is memory-bound. They want different things from the hardware:

  • Prefill is happy with low memory bandwidth, high FLOPs. Likes large batches of tokens (which it has-many prompt tokens). Hits the FLOP roof easily.
  • Decode is starved for bandwidth, indifferent to FLOPs. Wants large batches of sequences but each is just 1 token.

When you mix them on the same GPU:

  • The mixed batch cannot be optimal for either. Sarathi-Serve's chunked prefill (§8) is the best you can do with co-location, but you're still leaving throughput on the table.
  • A long prefill iteration spikes TPOT for all the decode-only requests in that batch.
  • You can't pick different hardware for the two phases (e.g., use H100 for prefill, A100 for decode).

12.2 The disaggregation idea

Run two pools:

  • Prefill workers: optimized for compute. Run pure prefill, no decode. Smaller KV-cache footprint (sequences leave after prefill).
  • Decode workers: optimized for memory bandwidth and capacity. Run pure decode batches.

Workflow:

  1. Request arrives at router.
  2. Router sends it to a prefill worker.
  3. Prefill worker runs prefill, produces final-layer hidden states + KV-cache for the prompt.
  4. KV-cache is transferred to a decode worker.
  5. Decode worker generates tokens until completion.

12.3 The KV-cache transfer

This is the load-bearing engineering piece. KV-cache for a 2K-token Llama-3-70B prompt is 640 MB (§1.4). Moving that over PCIe takes ~25 ms; over NVLink, single-digit ms; over RDMA (200 Gbps InfiniBand), ~25 ms.

Key tricks:

  • Layer-by-layer streaming: send layer L's KV while prefill computes layer L+1. Hides most of the transfer behind computation.
  • Zero-copy + RDMA: GPUDirect-RDMA writes the KV-cache straight from prefill GPU's HBM to decode GPU's HBM, without CPU involvement.
  • Block-aligned transfers: send paged-attention blocks as units; integrate with block manager on the receiving side.

The transfer cost is amortized over all the decode tokens: if the prompt generates 200 output tokens at ~30 ms each, 25 ms of transfer is 25 / (25 + 6000) ≈ 0.4% overhead. Cheap.

12.4 Why it helps

DistServe (Zhong et al., OSDI'24) and Mooncake show:

  • Prefill workers run at much higher GPU utilization (80%+ of FLOP roof) than co-located workers.
  • Decode workers run at higher batch sizes (no prefill stealing KV memory) → better bandwidth amortization.
  • The two pools can be sized independently to match workload mix.
  • Different SLO classes can be honored independently-TTFT comes from prefill pool, TPOT from decode pool.

End-to-end throughput improvements: 1.5-3× over co-located continuous batching, depending on workload.

12.5 When not to disaggregate

  • Tiny deployments (1-2 GPUs): the network overhead and operational complexity dominate.
  • Workloads where prefill is small (chat with short messages, mostly decode)-co-located handles fine.
  • Workloads where prefill dominates (long prompts, short outputs)-you may want all workers to be prefill-shaped.

Disaggregation pays off when you have enough scale to dedicate GPUs and a workload mix that genuinely has both regimes active.

12.6 Splitwise/Mooncake variants

  • Splitwise: similar idea, distinguishes "prompt" and "token" phases. Deployed in Microsoft production.
  • Mooncake (Moonshot AI's Kimi): pushes further, treats KV-cache as a first-class distributed object with its own caching tier (HBM → DRAM → SSD), and routes requests to maximize prefix-cache hits at the global level.

13. Performance metrics and SLOs

You cannot tune what you cannot measure. The standard metrics:

13.1 TTFT-Time to First Token

End-to-end time from request arrival to first generated token. Components:

  • Queue wait: how long the request sits in waiting_queue before scheduling.
  • Prefill latency: time to process the prompt. Roughly P · per-token-prefill-time for prompt length P.
  • First-decode latency: one decode step time.

For a chat UI, TTFT determines "feels responsive." Typical SLO: p99 < 1 second for short prompts.

13.2 TPOT-Time Per Output Token

Once decoding starts, average decode step time. Determines streaming output speed. Typical SLO: p99 < 100 ms per token (10 tok/s minimum) for chat. For high-quality experiences: 30-50 ms (20-30 tok/s).

TPOT is dominated by per-iteration time, which is dominated by (W + b·KV) / B_HBM. Tuning levers: batch size (lower → lower TPOT, lower throughput) and chunked prefill chunk size (lower → less interference).

13.3 Throughput

Total output tokens generated per second across all concurrent requests. The headline number on benchmarks. Limited by hardware: at saturation, throughput ≈ b_max · B_HBM / (W + b_max · KV_max).

These three metrics trade off:

  • Higher batch → higher throughput, higher TPOT.
  • Lower batch → lower TPOT, lower throughput.
  • More aggressive prefill admission → lower TTFT, higher TPOT spikes.

The SLO formulation looks like:

maximize throughput
s.t. TTFT_p99 ≤ X ms
     TPOT_p99 ≤ Y ms

And the operator's job is to pick the levers that satisfy the SLO at maximum throughput.

13.4 Other useful metrics

  • Goodput: throughput counting only requests that met SLO. Better than raw throughput for capacity planning.
  • Queue length over time: leading indicator of saturation.
  • KV utilization: used_blocks / total_blocks. Should be high (80-95%) at saturation.
  • Preemption rate: ideally <1% of requests.
  • Prefix cache hit rate: workload-dependent; higher is free throughput.

14. Tuning vLLM (and descendants)

The names are vLLM-specific but the concepts transfer. (Flag names may have evolved between versions; if a flag isn't exactly where I describe it, an equivalent exists.)

14.1 gpu_memory_utilization (e.g., 0.9)

Fraction of GPU memory the engine may use. The rest is reserved for system overhead, CUDA workspace, NCCL buffers. Default 0.9.

  • Higher (0.95): more KV blocks → larger batch → higher throughput. Risk of OOM under transient spikes.
  • Lower (0.85): more headroom, fewer KV blocks → smaller batch.

Rule of thumb: start at 0.9. If you see OOMs in production, drop to 0.85. If KV utilization is consistently < 70%, raise to 0.93.

14.2 max_num_batched_tokens (e.g., 4096)

Per-iteration token budget (sum over decode tokens + prefill chunk lengths). The chunked-prefill cap.

  • Higher: longer prefill chunks → faster TTFT but spikier TPOT.
  • Lower: smoother TPOT but TTFT suffers on long prompts.

Rule of thumb: set so one iteration's wall-clock at this token count is roughly your target TPOT. For Llama-3-70B on H100, ~2048-4096 keeps iterations under 50 ms.

14.3 max_num_seqs (e.g., 256)

Hard cap on concurrent in-flight sequences (decode batch size cap).

  • Higher: more concurrent users, higher throughput at saturation.
  • Lower: lower TPOT, fewer preemptions, less scheduling overhead.

Rule of thumb: set just above the highest batch size you actually achieve given the KV budget. If the KV budget supports batch=64, set max_num_seqs=80: it caps absurd over-admission without constraining real load.

14.4 block_size (e.g., 16)

Tokens per paged-attention block.

  • Smaller (8): less last-block waste, more page-table indirection overhead.
  • Larger (32): less indirection, more waste.

Rule of thumb: 16 is the universal sweet spot. Don't change unless you have unusual workloads (very short or very long sequences) and you've measured.

14.5 enable_prefix_caching (bool)

Turn block-level content hashing on. Default: on, in modern vLLM.

  • Cost: small overhead per block to compute and look up hashes; some KV memory used to retain evicted-but-cached blocks.
  • Benefit: massive speedup on shared-prefix workloads (chat, RAG, multi-turn).

Rule of thumb: leave on. Only disable if you have measured that it hurts your workload (rare-pathological case is purely random prompts).

14.6 swap_space (GiB)

CPU RAM reserved for swap-out of evicted KV-cache. Default: 4 GiB or so.

  • Nonzero: enables swap-mode preemption.
  • Zero: forces recompute-mode preemption.

Rule of thumb: set to 4-8 GiB if you have generous host memory and long contexts; 0 otherwise. Recompute is fine for most workloads.

14.7 tensor_parallel_size, pipeline_parallel_size

How to shard the model across GPUs.

  • TP (intra-layer): splits each layer across N GPUs, all-reduce after each. Low latency, NVLink-bandwidth-hungry.
  • PP (inter-layer): splits the layer stack across stages, micro-batching. Higher latency, PCIe-tolerant.

Rule of thumb: TP first, up to NVLink island size (usually 8). Add PP only when you're past one island and need to scale further. For Llama-3-70B on 2× H100: TP=2.

14.8 max_model_len (e.g., 8192)

Maximum context length. Caps the longest sequence you'll serve. KV memory budget per sequence is implied by this.

  • Higher: more flexible but reserves more pessimistic block budget.
  • Lower: tighter packing, more concurrent users.

Rule of thumb: set to your real max prompt+output length, not the model's architectural max. If you only ever serve 4K-context chats, set max_model_len=4096 and reap the batch-size benefits.


15. Practical exercises

These are designed to be done on paper. Solutions follow the spirit of §1-they're all derivable from the cost model.

Exercise 1-KV-cache pool size

You're running Llama-3-70B on a single 80 GB H100, KV-cache in BF16 (assume a 1-GPU deployment is somehow possible-say the weights are quantized to 4 bits, taking the 140 GB of BF16 weights down to ~35 GB). After weights and overhead, 40 GB is left for KV-cache. With block_size=16 and the model's KV-per-token of 320 KB across all layers:

  • How many blocks are in the pool?
  • How many tokens of total KV-cache?
  • At max_model_len=4096, how many concurrent sequences (worst case, fully-grown)?

Answer sketch: Block size in bytes (across layers) = 16 × 320 KB = 5 MB. Pool size = 40 GB / 5 MB = 8000 blocks. Tokens = 8000 × 16 = 128K tokens. Concurrent fully-grown = 128K / 4096 = 32 sequences.

Exercise 2-Decode step time

Llama-3-70B BF16 on 2× H100 (TP=2, 35 GB weights per GPU, 1.675 TB/s effective bandwidth per GPU due to TP overhead):

  • TPOT at batch=1, S=2K?
  • TPOT at batch=32, S=2K?
  • At what batch does TPOT double vs batch=1?

Answer sketch: KV per sequence = 640 MB. Per-GPU work ≈ (35 GB + b × 320 MB) / 1.675 TB/s (KV halved by TP).

  • b=1: (35 + 0.32)/1.675 ≈ 21 ms.
  • b=32: (35 + 10.24)/1.675 ≈ 27 ms.
  • TPOT doubles vs b=1 when b × 320 MB ≈ 35 GB → b ≈ 110.

Exercise 3-Throughput wall

Same setup. What batch size maximizes decode throughput, given a 80 GB GPU?

Answer sketch: KV memory budget per GPU = 80 - 35 - overhead ≈ 40 GB. KV per sequence at S=2K = 320 MB (post-TP). Max batch = 40 / 0.32 = 125. T_step at b=125 ≈ (35 + 40)/1.675 ≈ 45 ms, so throughput ≈ 125 / 45 ms ≈ 2800 tokens/sec/GPU. Beyond that you OOM.
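These answer sketches are mechanical enough to script. A minimal calculator for the decode cost model of Exercises 2-3 (all constants are the exercises' assumptions, not measurements):

W_BYTES = 35e9        # weight bytes per GPU (TP=2)
BW = 1.675e12         # effective bandwidth per GPU, bytes/s
KV_PER_SEQ = 320e6    # KV per sequence at S=2K, post-TP

def tpot(b):
    return (W_BYTES + b * KV_PER_SEQ) / BW    # seconds per decode step

def throughput(b):
    return b / tpot(b)                        # tokens/sec/GPU

print(tpot(1) * 1e3, tpot(32) * 1e3)          # ≈ 21 ms, ≈ 27 ms
print(throughput(125))                        # ≈ 2800 tok/s at the KV ceiling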

Exercise 4-Naive vs paged

Same model, max_model_len=4096. Compare concurrent-sequence capacity:

  • Naive contiguous allocation per request, each reserving 4K KV.
  • Paged with block_size=16, sequences average actual length 800.

Answer sketch: Naive KV per seq = 4096 × 320 KB = 1.28 GB. Capacity = 40 GB / 1.28 GB ≈ 31. Paged: each seq actually uses ceil(800/16) = 50 blocks = 250 MB plus at most 16-token waste. Capacity = 40 GB / 250 MB ≈ 160. 5× more concurrent users.

Exercise 5-TTFT under load

Llama-3-70B prefill rate: 5000 tokens/sec/GPU at batch=1 (compute-bound, hits ~50% of FLOP roof). Prompts arrive Poisson at rate λ = 10/sec, mean prompt length 1000 tokens.

  • Prefill server load? Under what λ does the queue blow up?
  • What's expected TTFT (M/M/1 approximation) at λ=10, λ=15, λ=18?

Answer sketch: Service rate = 5000 tok/s ÷ 1000 tok/req = 5 req/s per GPU. Wait-that's already overloaded at λ = 10. We need TP=2 or 2 prefill servers, giving μ = 10 req/s. Then ρ = λ/μ: at λ = 10, ρ = 1 and the M/M/1 queue diverges; λ = 15 and λ = 18 are infeasible without a third server. At λ = 8, ρ = 0.8 and the expected wait = ρ/(μ(1-ρ)) = 0.8/(10 × 0.2) = 0.4 s. Lesson: prefill capacity must be over-provisioned, especially when tail-aware.

Exercise 6-Disaggregation worth it?

Workload: 50% requests have 4K prompts and 100 output tokens (prefill-heavy); 50% have 100-token prompts and 1000 output tokens (decode-heavy). Two GPUs.

  • In co-located continuous batching, what's the bottleneck regime?
  • In disaggregated (1 prefill, 1 decode), what bottleneck shifts?

Answer sketch: Total prefill work per request avg = 0.5 × 4000 + 0.5 × 100 = 2050 tokens. Total decode tokens per request avg = 0.5 × 100 + 0.5 × 1000 = 550 tokens. In co-located, both compete for KV. The 4K-prompt requests admit slowly under tight KV, increasing TTFT. Disaggregated: prefill GPU runs at 5000 tok/s, handles 5000/2050 ≈ 2.4 req/s. Decode GPU runs at maybe 2000 tok/s, handles 2000/550 ≈ 3.6 req/s. Bottleneck is prefill-add a second prefill GPU. Co-located would have over-provisioned both regimes simultaneously.


16. The shape of the field

The vLLM + Orca synthesis (paged attention + continuous batching) is now table stakes-every modern inference engine (TGI, TensorRT-LLM, SGLang, LMDeploy) has both. The frontier is now:

  • Better scheduling: SLO-aware, multi-tenant, fairness, deadline scheduling. SGLang's RadixAttention and constrained decoding fit here.
  • Better KV management: hierarchical KV (HBM → DRAM → SSD) à la Mooncake. Distributed prefix caches.
  • Disaggregation everywhere: prefill/decode is just the start. Some systems are exploring further splits-embedding compute, sampling, even per-layer routing.
  • Speculative & parallel decoding: Medusa, EAGLE, lookahead decoding-all integrate with paged attention via the same block-manager primitives.
  • Quantization at every layer: FP8 weights, FP8 KV-cache, INT4 weight-only-all expand the practical batch envelope by shrinking the bandwidth footprint.
  • MoE serving: routes the bandwidth equation through expert-selection, with its own batching considerations (every token activates a different expert subset → batching gains weaker).

Each of these slots into the same architecture: scheduler → block manager → executor → sampler. The mechanism is stable; the policies and the kernels evolve.


17. Cheat sheet

The two phases:

Prefill  : compute-bound.  P tokens in parallel.  T ≈ P · per-tok-prefill ≈ P · 200 µs (70B/H100).
Decode   : bandwidth-bound. 1 token per step.   T ≈ (W + b·KV)/B_HBM.

The cost model:

T_step(b, S) ≈ (W + b · K_per_tok · S) / B_HBM        [memory-bound regime]
Throughput(b) ≈ b / T_step(b, S)
                      ↑ peaks where KV-term ≈ weight-term
                      ↑ in practice capped by KV memory

The architecture:

API → Engine.step():
  1. Scheduler picks batch (running ∪ admit_some_waiting)
  2. BlockManager allocates / appends slots
  3. Executor.forward(batch, block_tables)
  4. Sampler.sample(logits)
  5. For each: append token, finish-or-continue, free blocks
  → loop

The big wins:

Continuous batching     : 2-4×    (vs request-level)
+ Paged attention       : 5-10×   (vs naive contiguous)
+ Prefix caching        : 2-5×    (workload-dependent, chat/RAG)
+ Chunked prefill       : smoother TPOT, modest throughput
+ Speculative decoding  : 2-3×    (decode only)
+ Disaggregation        : 1.5-3×  (mixed workloads, scale)

The tuning levers:

gpu_memory_utilization        ← KV pool size
max_num_batched_tokens        ← TPOT smoothness
max_num_seqs                  ← concurrent batch cap
block_size                    ← leave at 16
enable_prefix_caching         ← leave on
swap_space                    ← if long ctx + spare host RAM
tensor_parallel_size          ← within-NVLink-island
pipeline_parallel_size        ← across-island
max_model_len                 ← match real workload, not arch max

The metrics:

TTFT  = queue + prefill + first_decode
TPOT  ≈ T_step
Throughput = sum of output tokens / wall time
Goodput = throughput counting only SLO-meeting requests

18. Further reading

If you want to consolidate (not learn from scratch-that's what this chapter is for):

  • Kwon et al., "Efficient Memory Management for Large Language Model Serving with PagedAttention," SOSP 2023. The vLLM paper. Read for the kernel details and original block-manager design.
  • Yu et al., "Orca: A Distributed Serving System for Transformer-Based Generative Models," OSDI 2022. Iteration-level scheduling. Predates vLLM, still the cleanest exposition of the scheduling idea.
  • Agrawal et al., "Taming Throughput-Latency Tradeoff in LLM Inference with Sarathi-Serve," OSDI 2024. Chunked prefill and the prefill/decode interference analysis.
  • Zhong et al., "DistServe: Disaggregating Prefill and Decoding for Goodput-optimized LLM Serving," OSDI 2024. The disaggregation argument and KV-transfer engineering.
  • Mooncake (Moonshot AI, 2024-2025). KV-as-distributed-object, hierarchical KV cache.
  • The FlashAttention series (Dao et al., 2022, 2023, 2024)-orthogonal but always relevant for understanding the attention kernel; covered in deep dive 06.

The papers are dense in evaluation (benchmark suites, microbenchmarks). With this chapter as scaffolding, you should be able to skim those evaluations as confirmation rather than as primary learning.


19. Closing

Inference serving is not a deep-learning problem dressed up as a systems problem. It is a systems problem, full stop, with a deep-learning function as the workload. The hardness comes from:

  • Variable-length, growing, unpredictable workloads.
  • Bandwidth-bound primary cost model.
  • Multi-tenant SLOs across heterogeneous request shapes.
  • Memory hierarchies that span HBM, DRAM, and PCIe.

PagedAttention and continuous batching solved the spatial and temporal problems. Chunked prefill smoothed the regime transitions. Prefix caching exploited the workload's redundancy. Disaggregation made the regimes physical. Speculative decoding cracked the per-step bandwidth ceiling.

Each of these is one good idea, and the system that implements them is more or less forced once you accept the cost model in §1-3. That is the point of the chapter: when you understand why the equation looks the way it does, every architectural choice in vLLM and its descendants reads as inevitable.

The next deep dive (09) covers training systems-the same kind of cost-model-first analysis, applied to gradient computation, all-reduce, and pipeline schedules. After that, 10 covers speculative decoding properly, 11 covers MoE, and 12 covers the multi-modal serving extensions. They will all build on the same foundation: a clear-eyed accounting of bytes, FLOPs, and time.

Deep Dive 09: Quantization Theory and Practice for LLMs

"FP16 is a courtesy. INT4 is the contract."-Anonymous inference engineer, c. 2024.

This chapter is the self-contained reference for everything an AI systems engineer needs to know about quantization for large-language-model inference (and a sketch of training). It is written so that, after reading it carefully, you should be able to:

  1. Re-derive every algorithm presented (AWQ, GPTQ, SmoothQuant, FP8 scaling) without consulting the original papers.
  2. Estimate the on-device memory footprint of any model under any precision scheme to within ~1%.
  3. Reason about why a given scheme is fast (or not) on a given hardware target.
  4. Design and run a defensible quantization evaluation for a production deployment.

We assume the reader has internalized the previous deep dives, especially:

  • DD 03 (the GPU memory hierarchy and the HBM/L2/SMEM/register pyramid),
  • DD 06 (KV-cache management and the prefill/decode split),
  • DD 07 (attention kernels: FlashAttention and Marlin's structural cousins).

If you have not, the single most important fact to anchor on is this:

Decode is memory-bandwidth-bound. Every output token requires reading every weight from HBM. Halving the bytes per weight roughly doubles tokens-per-second.

Quantization is therefore the highest-leverage inference optimization. Nothing else-not better attention kernels, not speculative decoding, not better schedulers-buys as much throughput per engineering hour as moving from FP16 weights to INT4. This chapter is about how that miracle is implemented without destroying model quality.


Table of Contents

  1. Why quantize at all
  2. Number-format theory: floats, integers, and the bits in between
  3. Quantization fundamentals: affine maps, symmetric vs. asymmetric
  4. Granularity: per-tensor, per-channel, per-group
  5. Round-to-nearest (RTN) and why INT8 just works
  6. Why INT4 RTN fails: outliers and heavy tails
  7. AWQ-Activation-aware Weight Quantization, derived
  8. GPTQ-Hessian-aware column-wise quantization, derived
  9. SmoothQuant-redistributing difficulty for W8A8
  10. Activation quantization: static vs. dynamic
  11. FP8 inference on H100
  12. FP8 training (brief)
  13. On-the-fly dequantization and the Marlin kernel
  14. Mixed-precision inference
  15. Calibration set design
  16. Evaluation discipline
  17. Practical exercises
  18. Cheat sheet and further reading

1. Why quantize at all

1.1 The arithmetic intensity argument

For a transformer decoder generating one token at a time with batch size 1, the compute required is roughly 2 × P FLOPs (one multiply and one add per parameter), where P is the parameter count. The memory traffic required is B × P bytes, where B is bytes per parameter. The arithmetic intensity is therefore:

AI = (2 × P) FLOPs / (B × P) bytes = 2/B FLOPs per byte

Some concrete numbers:

Precision B (bytes/param) AI (FLOP/byte)
FP32 4 0.5
BF16/FP16 2 1.0
FP8 1 2.0
INT4 0.5 4.0

Compare these to the roofline arithmetic intensity of an H100, which is (989 TFLOPS BF16) / (3.35 TB/s HBM3) ≈ 295 FLOP/byte. Decode at every precision listed above is at least two orders of magnitude below the roofline-squarely memory-bound.

When you are memory-bound, throughput scales as 1/B. Cutting weight bytes in half doubles your tokens-per-second. Cutting them by 4× quadruples it. Quantization is leverage you don't get anywhere else in the stack.
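To make the 1/B scaling concrete, here is the batch-1, weights-only upper bound for a hypothetical 70B model on one H100 (a sketch that ignores KV-cache traffic and kernel inefficiency, so real numbers land lower):

P = 70e9                          # parameters
BW = 3.35e12                      # H100 HBM3 bandwidth, bytes/s
for name, B in [("FP16", 2), ("FP8", 1), ("INT4", 0.5)]:
    t = B * P / BW                # seconds per token at batch=1
    print(f"{name}: {1 / t:.0f} tok/s upper bound")
# FP16 ≈ 24, FP8 ≈ 48, INT4 ≈ 96: throughput scales as 1/B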

1.2 The capacity argument

A 70B-parameter model needs, in raw weights:

  • FP32: 70 × 10^9 × 4 = 280 GB
  • FP16: 140 GB
  • INT8: 70 GB
  • INT4 (group=128 with FP16 scales): ~36 GB

The H100 80 GB SXM has 80 GB of HBM, of which ~10 GB is reserved for KV cache, activations, and the CUDA runtime. INT4 is the only way to fit a 70B model on a single H100. This is not a footnote-it is the dominant operational reason quantization is mandatory in 2024–2026 production deployments.
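A footprint estimator in the spirit of the "within ~1%" goal stated in the introduction; the group-size term anticipates the effective-bits accounting of §4.4 (function name and defaults are illustrative):

def weight_footprint_gb(params, bits, group_size=None, scale_bits=16, zp_bits=0):
    """Raw weight bytes under a given scheme (excludes KV, activations)."""
    bits_per_elem = bits
    if group_size:   # per-group scale and optional zero-point overhead
        bits_per_elem += (scale_bits + zp_bits) / group_size
    return params * bits_per_elem / 8 / 1e9

print(weight_footprint_gb(70e9, 16))                  # 140.0  FP16
print(weight_footprint_gb(70e9, 8))                   #  70.0  INT8
print(weight_footprint_gb(70e9, 4, group_size=128))   # ~36.1  INT4 g=128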

1.3 The energy argument

For server-class inference, the number that matters is tokens per joule. Memory traffic dominates that as well: an HBM3 read costs roughly 50–100× more energy than an FMA on tensor cores. Quantization raises tokens-per-joule by directly cutting the dominant cost term.

1.4 The cost argument

If you serve 1B tokens/day from a fleet of H100s, doubling tokens-per-GPU halves your fleet size and roughly halves your inference COGS. There is no other lever in the stack with this multiplier.


2. Number-format theory

A floating-point number is (-1)^s × m × 2^e with three fields packed into a fixed bit width: a 1-bit sign s, an exponent field, and a mantissa (significand) field. IEEE 754 introduces:

  • An exponent bias so the exponent field is unsigned.
  • Subnormals (denormals) for graceful underflow.
  • Special values ±∞ and NaN.

For all formats below let Eb be exponent bits and Mb be mantissa bits, exponent bias bias = 2^(Eb-1) - 1. A normal number is (-1)^s × (1.m_2) × 2^(E - bias).

2.1 FP32 (binary32)

Field Bits
Sign 1
Exponent 8
Mantissa 23
Total 32
  • Exponent bias: 127.
  • Smallest positive normal: 2^-126 ≈ 1.18 × 10^-38.
  • Largest finite: (2 - 2^-23) × 2^127 ≈ 3.40 × 10^38.
  • Decimal precision: log10(2^24) ≈ 7.22 digits.

This is the historical default for ML training. It is wasteful for both training and inference because neural networks empirically tolerate aggressive precision loss.

2.2 FP16 (binary16, IEEE half-precision)

Field Bits
Sign 1
Exponent 5
Mantissa 10
  • Bias: 15.
  • Smallest normal: 2^-14 ≈ 6.10 × 10^-5.
  • Largest: (2 - 2^-10) × 2^15 ≈ 65 504.
  • Decimal precision: ~3.3 digits.

The killer problem for FP16 in training is the exponent range: gradients can underflow to zero. Loss-scaling (multiplying loss by 2^k before backprop, dividing gradients by 2^k after) was the workaround until BF16 displaced it.

2.3 BF16 (bfloat16, Google Brain)

Field Bits
Sign 1
Exponent 8
Mantissa 7
  • Bias: 127 (same as FP32).
  • Range matches FP32 (~1.18e-38 to ~3.39e38).
  • Decimal precision: log10(2^8) ≈ 2.4 digits.

BF16 was specifically designed to be a drop-in replacement for FP32 in deep learning: the upper 16 bits of an FP32 number, full-stop. No loss-scaling needed. Almost all modern training (since A100 / TPU v3) uses BF16. It is the de-facto FP16 of 2024+.

2.4 FP8 E4M3

Field Bits
Sign 1
Exponent 4
Mantissa 3
  • Bias: 7 (per the OFP8/NVIDIA spec; some specs use 8).
  • Range: ~2^-9 (smallest subnormal, ~1.95e-3) up to ~448 (the spec replaces Inf with the largest finite value to extend range).
  • Used for forward activations and weights in inference and training.

The mantissa is wider than E5M2 for better precision in the heart of the distribution, at the cost of a smaller exponent range.

2.5 FP8 E5M2

Field Bits
Sign 1
Exponent 5
Mantissa 2
  • Bias: 15 (matches FP16).
  • Range matches FP16 (~6e-5 to ~65 504).
  • Used for gradients in FP8 training.

The wider exponent range trades two mantissa bits for the dynamic range that backprop requires.

2.6 INT8

  • Signed two's-complement 8-bit integer.
  • Range: -128 to 127.
  • No exponent. Dynamic range is fixed at log2(256) = 8 bits.

To represent a real number x you need an external scale (typically FP32 or FP16): x ≈ scale × q.

2.7 INT4

  • Signed two's-complement 4-bit integer.
  • Range: -8 to 7 (or sometimes -7 to 7 with one redundant code).
  • Always packed: two INT4 values per byte. Hardware tensor cores load them packed and unpack on the fly.

INT4 is not natively addressable-you cannot take a pointer to a single INT4. Software must pack/unpack at storage boundaries, and kernels must implement dequant logic in registers or shared memory. This is the central engineering challenge that Marlin (§13) solves.

2.8 Comparison table

Format Bits Sign Exp Mantissa Range (approx) ~Decimal precision Typical use
FP32 32 1 8 23 1.2e-38 .. 3.4e38 7.2 Reference training, master weights
BF16 16 1 8 7 matches FP32 2.4 Training, activations
FP16 16 1 5 10 6e-5 .. 6.5e4 3.3 Legacy training, inference
FP8 E4M3 8 1 4 3 2e-3 .. 448 ~1 FP8 inference fwd, training fwd
FP8 E5M2 8 1 5 2 6e-5 .. 6.5e4 <1 FP8 training gradients
INT8 8 1 - - -128 .. 127 (× scale) n/a INT8 PTQ, W8A8
INT4 4 1 - - -8 .. 7 (× scale) n/a W4A16 inference

The thing to internalize: floats trade range for precision via the exponent. Integers have no exponent, so they need an external scale to be useful. That external scale is the entire conceptual core of integer quantization.
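All of the per-format constants above can be recomputed from (Eb, Mb) alone, which is a good way to internalize the tradeoff. A sketch that follows the plain IEEE conventions (it does not model E4M3's replace-Inf-with-448 extension):

import math

def format_stats(e_bits, m_bits):
    bias = 2 ** (e_bits - 1) - 1
    min_normal = 2.0 ** (1 - bias)
    max_finite = (2 - 2.0 ** -m_bits) * 2.0 ** (2 ** e_bits - 2 - bias)
    digits = math.log10(2 ** (m_bits + 1))       # decimal digits of precision
    return min_normal, max_finite, digits

print(format_stats(8, 23))   # FP32: 1.18e-38, 3.40e38, 7.2 digits
print(format_stats(5, 10))   # FP16: 6.10e-05, 65504.0, 3.3 digits
print(format_stats(8, 7))    # BF16: FP32-range, 2.4 digits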


3. Quantization fundamentals

3.1 Affine quantization

The most general scheme:

q     = round(x / scale) + zero_point          # quantize
x_hat = scale × (q - zero_point)               # dequantize

Here:

  • x ∈ ℝ is the original real-valued tensor element.
  • q ∈ ℤ is the integer code.
  • scale ∈ ℝ_{>0} (typically stored FP16 or FP32).
  • zero_point ∈ ℤ (typically the same width as q, so it fits in the same dtype).

The quantization error per element is e = x - x_hat, bounded by |e| ≤ scale / 2 if rounding is correct (round-to-nearest, ties-to-even).

3.2 Symmetric quantization

Set zero_point = 0. Then:

q     = round(x / scale)
x_hat = scale × q

The integer range [Q_min, Q_max] should be symmetric around 0. For INT8 we typically use Q_max = 127 and Q_min = -127 (forfeiting one code at -128 for symmetry; some implementations use the full -128..127). Then:

scale = max(|x|) / Q_max

Symmetric is the right default for weights, which are empirically near-zero-centered after standard initialization and training.

3.3 Asymmetric quantization

Allow nonzero zero_point. The mapping that uses the full integer range is:

scale       = (x_max - x_min) / (Q_max - Q_min)
zero_point  = round(Q_min - x_min / scale)

Then for any input x:

q = clip( round(x / scale) + zero_point , Q_min , Q_max )

Asymmetric is the right default for activations, especially post-ReLU/post-GELU activations that have a hard one-sided floor at 0. Forcing symmetric quantization on a [0, x_max] activation throws away half your codes.

3.4 Derivation of scale and zero_point (asymmetric)

We want a linear (affine) map from [x_min, x_max] → [Q_min, Q_max]:

q(x) = a × x + b

with constraints q(x_min) = Q_min, q(x_max) = Q_max. Two equations, two unknowns:

a × x_min + b = Q_min
a × x_max + b = Q_max

Subtract:

a × (x_max - x_min) = Q_max - Q_min
a = (Q_max - Q_min) / (x_max - x_min) = 1 / scale

so scale = (x_max - x_min) / (Q_max - Q_min). Substitute back:

b = Q_min - a × x_min = Q_min - x_min / scale

i.e. zero_point = round(Q_min - x_min / scale).

3.5 Symmetric as a special case

Set x_min = -x_max (after taking x_max ← max(|x|)). Then Q_min = -Q_max gives b = 0, i.e. zero_point vanishes, and scale = 2 × x_max / (2 × Q_max) = x_max / Q_max = max(|x|) / Q_max. Consistent.

3.6 The MAD-vs-MSE choice

When you compute max(|x|) you are picking an L_inf-optimal scale. You can also pick a scale that minimizes MSE under a Gaussian/Laplacian assumption-this is the basis of percentile clipping (e.g., set scale based on the 99.9th percentile of |x|, clipping outliers). For weights this rarely matters; for activations it can matter a lot, and it is one knob that AWQ/SmoothQuant indirectly tune.
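The whole of §3 in one runnable sketch (numpy; the half-step slack in the assert allows for zero-point rounding at the range edges):

import numpy as np

def affine_quantize(x, bits=8, symmetric=False):
    Q_max = 2 ** (bits - 1) - 1
    Q_min = -Q_max if symmetric else -(2 ** (bits - 1))
    if symmetric:
        scale, zp = np.abs(x).max() / Q_max, 0
    else:
        scale = (x.max() - x.min()) / (Q_max - Q_min)
        zp = round(Q_min - x.min() / scale)
    q = np.clip(np.round(x / scale) + zp, Q_min, Q_max)
    return q.astype(np.int32), scale, zp

def dequantize(q, scale, zp):
    return scale * (q - zp)

x = np.random.randn(4096).astype(np.float32)
for sym in (True, False):
    q, s, zp = affine_quantize(x, symmetric=sym)
    assert np.abs(x - dequantize(q, s, zp)).max() <= s  # ≤ s/2 in the interior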


4. Granularity

A single scale per tensor is the cheapest. But it is also the most error-prone, because a single outlier element forces the scale large, wasting precision on the rest.

4.1 Per-tensor

One scale (and optionally one zero_point) for the entire tensor. Storage cost is negligible. Quality is poor for INT4 weights because the dynamic range of weights varies wildly across rows of W.

4.2 Per-channel (per-row, per-output-channel)

For a weight matrix W ∈ ℝ^{out × in} representing y = W x, each output channel (each row of W) gets its own scale. This is the standard for weight quantization because:

  1. Each output is a linear combination of all inputs; the scale of row i only affects output i.
  2. There is no cross-channel arithmetic that would cause scale mismatches in the matmul itself.
  3. Dequantization at output time is y_i ≈ scale_i × (W_q · x)_i, a cheap final multiply per output element.

For activations, per-channel-per-token is also possible but expensive to apply in the matmul.

4.3 Per-group (block-wise)

A compromise: split each row of W into contiguous groups of size G (typically G = 128 or 64). Each group has its own scale and (optionally) zero_point.

Why 128? It matches the K-dimension tile size of standard tensor-core GEMMs. A column-major dequantization can dequantize one 128-wide tile, multiply, accumulate, then move to the next tile-the scale is constant within the inner loop.

Per-group is the dominant scheme for INT4 weight quantization in 2024+. It is what AWQ, GPTQ, and Marlin all target.

4.4 Effective bits for INT4 group=128

Each element costs 4 bits. Each group of 128 elements has one FP16 scale = 16 bits (and possibly one FP16 or INT4 zero_point-let's count both).

bits_per_element = 4 + 16/128 + 16/128
                 = 4 + 0.125 + 0.125
                 = 4.25 bits/element     (with FP16 zp)

# or, symmetric (no zp):
                 = 4 + 16/128
                 = 4.125 bits/element

This is why INT4 group=128 is sometimes quoted as "~4.13 bits/element". The exact number depends on whether scales are FP16/BF16/FP32 and whether zero_points are FP16 or packed integers. In practice, 4.25 bits/element is a safe planning number.

4.5 Granularity comparison

Granularity Quality Storage cost Kernel cost Typical use
Per-tensor Worst Negligible Cheapest INT8 W8A8 (with smoothing)
Per-channel (row) Good 1 scale per row 1 mul per output INT8 weights
Per-group Best 1 scale per G elems 1 mul per group inside the inner loop INT4 weights

5. Round-to-nearest (RTN)

The simplest possible PTQ algorithm:

import torch

def rtn_quantize_per_group(W, group_size=128, bits=4):
    """Symmetric RTN: one scale per contiguous group of each row."""
    Q_max = 2 ** (bits - 1) - 1                        # 7 for INT4
    rows, cols = W.shape
    W_q = torch.empty(rows, cols, dtype=torch.int8)    # container for the codes
    scales = torch.empty(rows, cols // group_size, dtype=W.dtype)
    for i in range(rows):                              # rows = output channels
        for g in range(cols // group_size):
            block = W[i, g * group_size:(g + 1) * group_size]
            s = (block.abs().max() / Q_max).clamp(min=1e-12)   # avoid div-by-0
            scales[i, g] = s
            q = torch.round(block / s).clamp(-Q_max, Q_max)
            W_q[i, g * group_size:(g + 1) * group_size] = q.to(torch.int8)
    return W_q, scales

For INT8 per-channel symmetric RTN on weights, this is essentially good enough for most dense LLMs: degradation is typically <0.1 PPL on well-calibrated benchmarks for 7B+ models. The reason is that the dynamic range of any single output channel of a trained W_proj rarely exceeds ~2^7, so 8 bits + a per-channel scale captures it.

For INT4 per-group RTN, the story is more painful. We will see why next.


6. Why INT4 RTN fails

6.1 The error model

Per-element rounding error is uniform on [-scale/2, scale/2], so its variance is scale^2 / 12. The matmul output is a sum of such errors:

y_i = Σ_j W_ij × x_j ≈ Σ_j (W_ij + e_ij) × x_j
e_y = Σ_j e_ij × x_j

If e_ij are independent zero-mean with variance σ_w^2 = scale^2 / 12, then:

Var(e_y) = Σ_j σ_w^2 × x_j^2 = σ_w^2 × ||x||^2

So the output error scales with the squared norm of the activation. A single outlier x_j with large magnitude dominates ||x||^2, and therefore dominates the output error of every output channel.
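This variance identity is easy to verify by simulation, which also shows the outlier domination numerically (numpy; the constants are arbitrary):

import numpy as np
rng = np.random.default_rng(0)

n, trials, scale = 4096, 2000, 0.01
x = rng.standard_normal(n)
x[7] = 100.0                              # one outlier input channel

# iid rounding errors, uniform on [-scale/2, scale/2], variance scale^2/12
e = rng.uniform(-scale / 2, scale / 2, size=(trials, n))
e_y = e @ x                               # output error per trial

print(e_y.var())                          # empirical
print(scale ** 2 / 12 * (x ** 2).sum())   # predicted: sigma_w^2 * ||x||^2
# The outlier alone contributes 100^2 = 10,000 of ||x||^2 ≈ 14,100.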

6.2 Heavy-tailed activations in LLMs

It is an empirical, well-replicated fact that LLM activations-specifically the inputs to the down-projection of MLP blocks and the inputs to attention output projections-have heavy-tailed per-channel distributions. A handful of channels (sometimes called systematic outliers or emergent features) carry magnitudes 10× to 100× larger than the median channel. These are not bugs; they appear during training and are load-bearing for the network's function.

6.3 Why outliers break INT4 weight quantization

Even though we are quantizing weights, the quality metric we ultimately care about is output error. By the equation above, the impact of weight error on output is multiplied by the corresponding input (activation) channel. If weight column j is multiplied by an input channel with 100× the typical magnitude, then any error in that column is amplified 100× at the output.

INT4 weight RTN treats all columns equally-it applies the same per-row, per-group scale logic regardless of which input channel a weight column will be multiplied by. The columns paired with outlier activation channels get just as much rounding noise as the rest, and that noise blows up into a 100×-larger output error.

Two responses are possible:

  1. AWQ: protect the weights paired with outlier activation channels by giving them more precision (effectively scaling them up before quantization, then compensating).
  2. GPTQ: don't try to protect anything-instead, after quantizing each weight, update the remaining weights to compensate for the error you just introduced.

These are the two great pillars of modern weight-only INT4 PTQ.

A third response, addressing W8A8 (where activations are also quantized), is SmoothQuant: shift the magnitude out of the activations and into the weights, smoothing the activation distribution to make it INT8-friendly.


7. AWQ

Lin, Tang, Tang, Yang, Xiao, Dang, Han, "AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration", MLSys 2024.

7.1 Insight

AWQ rests on three observations:

  1. Not all weights matter equally. Profiling a calibrated network shows that ~1% of weight channels carry most of the importance, measured by the magnitude of activations they multiply.
  2. Keeping just 1% of channels in FP16 (mixed-precision) recovers most of the accuracy of full FP16. This is the empirical proof that the rest of the channels are quantization-tolerant.
  3. Mixed-precision is a kernel headache. You don't actually want to ship an INT4-mostly-with-1%-FP16 weight matrix because the GEMM kernel would have to handle two layouts.

AWQ's innovation: instead of keeping salient channels in FP16, scale them up before quantization so they get more INT4 precision, and scale the matching activation channels down to compensate. The matmul output is mathematically unchanged, but the salient weights now occupy a more quantization-friendly part of the INT4 grid.

7.2 The math

For a linear layer with output y = W x where W ∈ ℝ^{m × n} and x ∈ ℝ^n, introduce a per-input-channel diagonal scaling matrix S = diag(s) with s ∈ ℝ^n_{>0}:

y = W x
  = W (S S^{-1}) x
  = (W S) (S^{-1} x)
  = W' x'

where W' = W S (each column of W is scaled by the corresponding s_j) and x' = S^{-1} x (each input element is divided by s_j). The product is exactly unchanged; we have only shifted magnitude between the two operands.
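The exactness of this identity is worth convincing yourself of numerically (any positive s works):

import numpy as np
rng = np.random.default_rng(1)

W = rng.standard_normal((16, 64))
x = rng.standard_normal(64)
s = np.exp(rng.standard_normal(64))   # arbitrary positive per-channel scales

y = W @ x
y_scaled = (W * s) @ (x / s)          # (W S)(S^{-1} x) with S = diag(s)
assert np.allclose(y, y_scaled)       # identical before quantization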

Now apply INT4 quantization to W' instead of W:

W_q = Q(W S)
y_hat = (1/scale_row) × dequant(W_q) × (S^{-1} x)
      ≈ W' × x'
      = W x

The key question: how should we choose s?

7.3 Choosing s: the AWQ heuristic

For weight column j, its contribution to the output is (column_j of W) × x_j. If x_j is large in magnitude (a salient activation channel), then errors in column_j of W get amplified at the output.

If we set s_j to be large for salient channels, then W'[:, j] = W[:, j] × s_j is also large. The per-row, per-group scale used by INT4 quantization is determined by max |W'[:, j']| over j' in the group. By inflating salient columns, we ensure that even the small entries of W in salient columns are quantized at a finer absolute resolution (relative to their original magnitude).

Why does this work? Imagine a group of 128 weights where one column is salient (paired with a 100× outlier activation) and the rest are normal. Without AWQ, the scale of this group is set by the largest absolute weight in it, which might be a non-salient one-and our salient column gets only INT4 resolution for what really should be INT5 or INT6 precision. With AWQ, the salient column has been pre-scaled by, say, 2× or 4×, so its weights now dominate the group's scale-setting max|W'|. Implicitly, the salient column gets 2 or 4 effective levels of additional precision.

The cost is borne by S^{-1} x-but the activation multiplier compensates exactly, if it's done in higher precision. AWQ keeps activations in FP16, so the compensation is essentially free.

7.4 The AWQ algorithm, formal

Input:  W ∈ ℝ^{m × n}, calibration activations X ∈ ℝ^{n × N}
        (collected across N tokens from a small calibration set)
Output: W_q (INT4 group=128), per-group scales

Step 1. Compute per-input-channel activation magnitude:
        a_j = (1/N) Σ_t |X[j, t]|         for j = 1..n
        (mean absolute value per channel, optionally restricted to top-k tokens)

Step 2. Choose per-channel scaling vector s:
        s_j = a_j^α
        where α ∈ [0, 1] is a hyperparameter (typical: 0.5–0.7)

Step 3. Optionally normalize s so that geometric mean is 1:
        s_j ← s_j / geomean(s)
        (purely numerical hygiene; the math is invariant)

Step 4. Form W' = W · diag(s)

Step 5. Quantize W' with RTN per-group, group_size = 128:
        compute (W'_q, scales_per_group)

Step 6. Store W'_q, scales, and s.
        At inference: y = (W'_q dequantized) × (x / s)
        (the divide-by-s is fused into the previous layer's output, so it is free at runtime)

Step 7. (Optional) Search α: for each candidate α ∈ {0.0, 0.1, ..., 1.0},
        run forward pass on calibration set, compute output MSE, pick the α with
        minimum MSE. This is a 1-D grid search per layer.

The grid search in Step 7 is what makes AWQ "activation-aware" rather than just "activation-magnitude-scaled". It picks the α that empirically minimizes per-layer reconstruction error on real data.
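A compact sketch of Steps 1-7 on a single layer (numpy; rtn_groupwise and awq_quantize are illustrative helpers, and real AWQ implementations add a clipping search and kernel-aware weight packing on top of this):

import numpy as np

def rtn_groupwise(W, group_size=128, Q_max=7):
    """Symmetric per-group RTN; returns dequantized weights."""
    W_hat = np.empty_like(W)
    for g in range(W.shape[1] // group_size):
        sl = slice(g * group_size, (g + 1) * group_size)
        s = np.maximum(np.abs(W[:, sl]).max(axis=1, keepdims=True) / Q_max, 1e-12)
        W_hat[:, sl] = np.clip(np.round(W[:, sl] / s), -Q_max, Q_max) * s
    return W_hat

def awq_quantize(W, X, alphas=np.linspace(0, 1, 11), group_size=128):
    """W: (out, in). X: (in, tokens) calibration activations."""
    a = np.abs(X).mean(axis=1) + 1e-8              # Step 1: per-channel magnitude
    y_ref = W @ X
    best = (np.inf, None, None)
    for alpha in alphas:                           # Step 7: 1-D grid search
        s = a ** alpha                             # Step 2
        s /= np.exp(np.log(s).mean())              # Step 3: geomean-normalize
        W_hat = rtn_groupwise(W * s, group_size)   # Steps 4-5
        mse = ((W_hat @ (X / s[:, None]) - y_ref) ** 2).mean()
        if mse < best[0]:
            best = (mse, alpha, (W_hat, s))        # Step 6: keep weights and s
    return best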

7.5 Worked tiny example

Let W ∈ ℝ^{2 × 4} and consider one row of the weight matrix:

W[0, :] = [ 0.10,  0.05,  0.02,  0.08 ]

and a calibration mean-abs activation vector:

a = [ 1.0,  1.0,  10.0,  1.0 ]

Channel 2 is a 10× outlier.

Without AWQ, RTN INT4 with Q_max = 7 and one group covering all 4 elements:

max|W[0,:]| = 0.10
scale = 0.10 / 7 ≈ 0.01429
W_q = round(W / scale) = round([7.0, 3.5, 1.4, 5.6]) = [7, 4, 1, 6]    # ties-to-even may differ
W_dequant = scale × W_q = [0.1000, 0.0571, 0.0143, 0.0857]
err = W - W_dequant ≈ [0.000, -0.007, +0.006, -0.006]

The error in channel 2 is +0.006. The output contribution of this column to the output is:

err_y_from_col2 = err_col2 × a_col2 = 0.006 × 10.0 = 0.060

versus columns 0,1,3 each contributing roughly 0.007 × 1.0 = 0.007. Channel 2 dominates output error by ~10×.

With AWQ, choose α = 0.5:

s = a^0.5 = [1.0, 1.0, 3.162, 1.0]                              # before normalization
geomean = (1 × 1 × 3.162 × 1)^0.25 = 3.162^0.25 ≈ 1.333
s ← s / 1.333 = [0.750, 0.750, 2.372, 0.750]                    # after normalization
W' = W × s = [0.075, 0.0375, 0.0474, 0.060]                     # channel 2 is now near the top
max|W'[0,:]| = 0.075
scale' = 0.075 / 7 ≈ 0.01071
W'_q = round(W'/scale') = round([7.0, 3.5, 4.43, 5.6]) ≈ [7, 4, 4, 6]
W'_dequant = scale' × W'_q = [0.0750, 0.0429, 0.0429, 0.0643]
err' = W' - W'_dequant = [0, -0.0054, +0.0045, -0.0043]
err in original W space: err_W = err' / s = [0, -0.0072, +0.0019, -0.0057]
err_y_from_col2 = err_W[2] × a[2] = 0.0019 × 10.0 = 0.019

Output error from the outlier column dropped from 0.060 to 0.019-roughly 3×. Errors in non-outlier channels grew slightly (because we shrank their effective precision), but they were 10× smaller to begin with, so total output MSE drops substantially.

The general principle: AWQ trades a little precision in non-outlier columns for a lot of precision where it counts.

7.6 Practical notes on AWQ

  • The scales s are absorbed into the previous layer's output. For a transformer block, the per-channel scaling of the MLP down-projection's input is folded into the up-projection's output. There is no runtime divide.
  • AWQ is weight-only-activations stay BF16. So there's no activation quantization error, only weight quantization error.
  • Typical degradation: less than 0.5 PPL on standard perplexity benchmarks for 7B+ dense models. Smaller models (1-3B) are more sensitive and can lose 1-2 PPL.
  • The AWQ kernel + Marlin (§13) is the highest-throughput W4A16 kernel as of writing.

8. GPTQ

Frantar, Ashkboos, Hoefler, Alistarh, "GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers", ICLR 2023.

8.1 Insight

Where AWQ asks "which weights should I protect?", GPTQ asks "after I round one weight, how should I update the rest to undo the damage?". It descends from Hassibi & Stork's Optimal Brain Surgeon (1993) for neural-network pruning.

8.2 The Optimal Brain Surgeon background

Suppose you have a trained network with loss L(w) minimized at w*. You want to perturb w (e.g., set one weight to a specific quantized value) while minimizing the increase in L. Locally, expand L to second order around w*:

L(w* + δ) ≈ L(w*) + g^T δ + (1/2) δ^T H δ

At the minimum g = 0, so δL ≈ (1/2) δ^T H δ.

If we constrain the perturbation to satisfy e_q^T δ = -w*_q (i.e., we are forcing the q-th weight to become 0, equivalently δ_q = -w*_q), then the optimal δ minimizing the quadratic subject to this linear constraint is, by Lagrange multipliers:

δ* = - (w*_q / [H^{-1}]_{qq}) × H^{-1} e_q

Equivalently, after we set weight q to its target value w_q^new, every other weight w_i should be updated by:

δw_i = -(w_q* - w_q^new) × [H^{-1}]_{iq} / [H^{-1}]_{qq}

This is the OBS update rule. The increase in loss it causes is:

δL = (1/2) × (w_q* - w_q^new)^2 / [H^{-1}]_{qq}

Notice that this δL is the minimum possible loss increase given that you're forced to change w_q to w_q^new.

8.3 Layer-wise reformulation for quantization

GPTQ does not run on the global loss L. Instead, it considers each linear layer in isolation and minimizes the layer-wise reconstruction error:

E(W_q) = ‖W X - W_q X‖_F^2

where X ∈ ℝ^{n × N} is a batch of N calibration activations going into this layer. Expanding:

E(W_q) = trace[(W - W_q) X X^T (W - W_q)^T]
       = Σ_i (Δw_i)^T H (Δw_i)

where Δw_i = (W - W_q)[i, :] is the error in row i, and:

H = 2 × X X^T   ∈ ℝ^{n × n}

is the Hessian of the layer-wise reconstruction loss with respect to a single row of W. The factor of 2 from differentiating the squared norm is conventional. Crucially:

  • H only depends on the input activations X, not the weights.
  • H is the same for every row of W. So we precompute it once per layer.
  • The problem decouples by row: each row of W is quantized independently.

8.4 The greedy column-by-column algorithm

For one row w ∈ ℝ^n of W, we want to choose integer q ∈ ℤ^n minimizing (w - dequant(q))^T H (w - dequant(q)). This is an integer-quadratic-programming problem (NP-hard in general).

GPTQ's approximation: quantize one column at a time, in order, and compensate the rest using OBS.

For a single row, ignoring grouping for clarity (a runnable NumPy rendering of the pseudocode):

import numpy as np

def gptq_row(w, H, scale, damp=0.01):
    """Greedy per-column quantization of one weight row with OBS compensation.

    w: [n] weights, H: [n, n] layer Hessian, scale: [n] per-column step sizes.
    """
    n = w.shape[0]
    H_inv = np.linalg.inv(H + damp * np.diag(H).mean() * np.eye(n))  # damped inverse for stability
    w_remaining = w.copy()            # current "live" weight vector
    w_q = np.zeros(n, dtype=int)      # quantized weights (integer codes)

    for j in range(n):
        # Step 1: quantize column j (using whatever scale rule)
        q_j = np.round(w_remaining[j] / scale[j])
        w_q_value = scale[j] * q_j    # the dequantized value we'll use

        # Step 2: error introduced by rounding this column
        err = w_remaining[j] - w_q_value

        # Step 3: OBS update-push the error into columns j+1..n-1
        w_remaining[j+1:] -= err * H_inv[j, j+1:] / H_inv[j, j]

        # Record q
        w_q[j] = q_j

    return w_q

After all n columns are processed, w_q is the final quantized row. The quantization scale scale[j] can be per-column (rare), per-group, or per-channel.

8.5 Why the greedy scheme is good

At each step we are solving the optimal-update problem for the column we just quantized, given that all already-quantized columns are frozen. We do not re-update earlier columns-that would un-quantize them. The remaining error after the loop is bounded by the sum over all greedy steps of the residual that can't be absorbed into later columns; in practice this is small for transformer weights.

The greedy choice of column order matters less than you'd think (left-to-right is fine), but a more robust variant called act-order GPTQ sorts columns by descending diagonal of H (i.e., by activation magnitude) so that the high-impact columns are quantized first and have the most "downstream slack" to absorb errors.

8.6 Cholesky-based efficient implementation

Computing and storing H^{-1} ∈ ℝ^{n × n} is O(n^2) storage and O(n^3) for the inverse. For large hidden dimensions (e.g., n = 8192, the hidden size of a 70B-class Llama) that's 256 MB just for the inverse and a few seconds for the inversion-fine for an offline calibration but not free.

The trick: we only ever access the upper-triangular part of H^{-1} (we only update columns j+1..n-1 from column j). The Cholesky factorization of a positive-definite matrix H = L L^T lets us compute H^{-1} = L^{-T} L^{-1} cheaply, and crucially:

  • The inverse of the Cholesky factor, L^{-1}, is itself lower-triangular.
  • The rows of L^{-T} we need are obtained sequentially.

GPTQ's trick is to perform Cholesky decomposition of H^{-1} once, then walk through its upper triangle column by column. The inner OBS update uses the precomputed Cholesky rows. This converts the algorithm from O(n^3) per layer to O(n^2) with cache-friendly access patterns.

8.7 The block (lazy-batch) trick

A further optimization: process columns in blocks of B (typically B = 128) instead of one-at-a-time. Within a block, do full updates. Between blocks, accumulate the lazy update for distant columns and apply it once when we move to the next block. This gives an order-of-magnitude wall-clock speedup because the compensations within a block fit in L1/L2 cache.

The block size B = 128 also matches the per-group scale boundary, so a single block uses one scale value (inside the block) and the algorithm naturally computes scales group-by-group.

8.8 GPTQ pseudocode, full

def gptq_layer(W, X, group_size=128, bits=4, percdamp=0.01, B=128):
    # W: [out_features, in_features]
    # X: [in_features, N]   (calibration activations)
    n = W.shape[1]
    Q_max = 2 ** (bits - 1) - 1                          # 7 for symmetric INT4
    H = 2.0 * X @ X.T                                    # [n, n]

    # Damping for numerical stability
    diag_mean = trace(H) / n
    H += percdamp * diag_mean * eye(n)

    # Cholesky of H^{-1}
    H_inv = cholesky_inverse(H)
    L = cholesky(H_inv, upper=True)                      # H^{-1} = L^T L  (some conventions vary)

    Q = zeros_like(W, dtype=int)
    scales = zeros((W.shape[0], n // group_size))

    for blk in range(0, n, B):                           # process columns in blocks of B
        block_end = min(blk + B, n)
        Wblk = W[:, blk:block_end].clone()
        Lblk = L[blk:block_end, blk:block_end]           # B × B

        for j in range(block_end - blk):                 # handles a final partial block too
            col = blk + j
            # 1. Determine scale for this group if at group boundary
            if col % group_size == 0:
                gstart = col
                gend = min(col + group_size, n)
                # gather pre-update weights for this group
                gblock = Wblk[:, j:j + group_size] if (gend - gstart) <= (B - j) else W[:, gstart:gend]
                s = gblock.abs().max(dim=1).values / Q_max   # symmetric per-channel scale for this group
                scales[:, col // group_size] = s

            # 2. Quantize column
            w_col = Wblk[:, j]
            q = (w_col / s).round().clamp(-Q_max, Q_max)
            w_q = s * q
            err = (w_col - w_q) / Lblk[j, j]
            Q[:, col] = q

            # 3. OBS update within block
            Wblk[:, j+1:] -= err.unsqueeze(1) * Lblk[j, j+1:].unsqueeze(0)

        # 4. Lazy update: apply block residual to all columns to the right
        W[:, block_end:] -= (Wblk - W[:, blk:block_end]) @ L_offblock_relevant
        # (in practice: applied via the appropriate slice of L)

    return Q, scales

The percdamp (typically 0.01) adds a small multiple of the identity to H before inversion, guaranteeing positive-definiteness. Without damping, H can be near-singular when activations have small variance in some channels.

8.9 GPTQ vs. AWQ, head-to-head

| Aspect | AWQ | GPTQ |
| --- | --- | --- |
| Conceptual basis | Activation-aware scaling | Hessian-aware error compensation |
| Calibration cost | Cheap (a few hundred forwards) | More expensive (Cholesky per layer) |
| Output kernel | Standard W4A16 + dequant + scale | Standard W4A16 + dequant |
| Reordering required | No | Optional (act-order) |
| Quality on 7B-70B dense | Excellent | Excellent |
| Quality on small (<3B) | Slightly better in practice | Slightly worse |
| Production dominant? | Yes (Marlin path is fastest) | Yes (long-standing default) |

Both methods produce INT4 group=128 weight matrices that are compatible with the same Marlin kernel-only the calibration procedure differs. In production you typically try both, pick the one with lower perplexity / better downstream eval.


9. SmoothQuant

Xiao, Lin, Seznec, Demouth, Han, "SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models", ICML 2023.

9.1 Why W8A8 is hard

W4A16 keeps activations in BF16, so there is no activation quantization error. W8A8 quantizes both weights and activations, enabling INT8 tensor-core matmul (which is twice the throughput of BF16 on most modern GPUs). But:

  • Weights are well-behaved → INT8 is easy.
  • Activations have heavy-tailed per-channel distributions → INT8 is hard because per-tensor scale is dragged up by outlier channels, and per-channel-per-token scaling is runtime-expensive.

The asymmetry: weight outliers are static (you can per-channel calibrate them once), but activation outliers are dynamic per token and per channel.

9.2 SmoothQuant insight

Use the same W = (W S)(S^{-1}) identity as AWQ, but with a different aim: redistribute magnitude from the activations to the weights such that both become INT8-friendly.

y = W x = (W diag(s)) (diag(s)^{-1} x) = W' x'

If s_j is large for outlier activation channels, then x'_j = x_j / s_j is small-the activation outliers are dampened. The price is that W'_{:, j} = W_{:, j} × s_j is larger-but weights had headroom to absorb that.

9.3 The migration strength α

How much magnitude to migrate? The choice is a hyperparameter α ∈ [0, 1] controlling how aggressively activation outliers are dampened:

s_j = max|x_j|^α / max|w_:, j|^(1-α)

The intuition:

  • α = 0 → s_j = 1 / max|w_:, j| → all magnitude pushed into activations (bad).
  • α = 1 → s_j = max|x_j| → all magnitude pushed out of activations (also bad-weights now have outliers).
  • α ≈ 0.5 → balanced.

In practice α = 0.5 is a good default; α = 0.85 has been reported for some Llama-class architectures whose activation outliers are particularly severe.

9.4 Where SmoothQuant is applied

The transformation is applied offline before INT8 calibration, at three points per transformer block:

  1. The input to the QKV projection (smoothing the LayerNorm output).
  2. The input to the MLP up-projection (smoothing the LayerNorm output).
  3. The input to the attention output projection-sometimes, depending on the architecture.

Importantly, the s scaling is fused into the previous layer's parameters:

  • For Pre-LN architectures, fold s into the LayerNorm scale (γ) before the attention/MLP block.
  • For other architectures, fold into the previous linear layer's weight columns.

Either way, no runtime compute is added-the smoothing is purely a calibration-time rewrite of the network parameters.

9.5 SmoothQuant pseudocode

def smoothquant_layer(prev_norm, layer_input_W, X_calib, alpha=0.5):
    # X_calib: [n, N] - calibration activations into this layer
    # layer_input_W: weight matrix [out, n] of the linear immediately following prev_norm
    # prev_norm: the LayerNorm preceding the linear, with per-channel scale γ ∈ ℝ^n

    a = X_calib.abs().max(dim=1).values                 # [n] per-channel activation max
    w_max = layer_input_W.abs().max(dim=0).values       # [n] per-column weight max
    s = (a ** alpha) / (w_max ** (1 - alpha))           # [n] migration scales

    # Fold into LayerNorm (γ is prev_norm.weight in PyTorch):
    prev_norm.weight /= s
    # Equivalent: divide the output of LayerNorm by s, which scales the input into the linear.

    # Fold into the linear's input columns:
    layer_input_W *= s                                  # broadcast over the output dim

9.6 W8A8 result

With SmoothQuant, a transformer can typically be quantized to INT8 weights and INT8 activations with <1 PPL degradation, enabling INT8 tensor-core throughput (2× BF16 on Hopper, more on Ada). It's a different point in the design space from AWQ/GPTQ-those target W4A16 (weight-only INT4); SmoothQuant targets W8A8 (both INT8).

In production, modern stacks often combine these:

  • Use SmoothQuant-style activation smoothing.
  • Apply GPTQ or AWQ for weight quantization to INT4.
  • Keep activations in BF16 (W4A16 path)-the smoothing helps even when activations stay in BF16, by making the post-smoothing weight distribution easier to quantize.

This combination is what tools like AutoAWQ and llm-compressor offer out of the box.


10. Activation quantization

If you are running W8A8 or any scheme that quantizes activations, you must decide when to compute the activation scale.

10.1 Static activation quantization

Calibrate once over a representative dataset; freeze a per-tensor (or per-channel) scale and zero_point; reuse them at runtime.

  • Pros: zero runtime overhead. The scale and zero_point are baked into the kernel.
  • Cons: bad accuracy when the runtime activation distribution differs from calibration. Especially bad with long context, long-tailed inputs, or out-of-distribution prompts.

10.2 Dynamic activation quantization

At each forward pass (each token, each layer), compute the scale on the fly:

scale_x = max(|x|) / Q_max

This is a per-token reduction over the channel dimension. On a GPU, it's a reduce_max over n elements, fused into the kernel.

  • Pros: handles per-token variation automatically. Robust.
  • Cons: costs a reduction. For decode (n ~ 4096-16384), this is a few microseconds-small compared to the matmul, but not free.
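
A minimal PyTorch sketch of the dynamic path, unfused for clarity (a production kernel folds the row-max reduction and the cast into the preceding op):

import torch

def quantize_per_token_int8(x: torch.Tensor):
    """Dynamic per-token symmetric INT8 quantization of activations.

    x: [num_tokens, hidden] in BF16/FP16. Returns INT8 codes plus one
    FP32 scale per token (row).
    """
    q_max = 127
    scale = x.abs().amax(dim=-1, keepdim=True).float() / q_max   # [num_tokens, 1]
    scale = scale.clamp(min=1e-8)                                # guard all-zero rows
    x_q = torch.round(x.float() / scale).clamp(-q_max, q_max).to(torch.int8)
    return x_q, scale

x = torch.randn(4, 4096, dtype=torch.bfloat16)
x_q, s = quantize_per_token_int8(x)
x_hat = x_q.float() * s                      # dequantize to check the round-trip error
print((x.float() - x_hat).abs().max())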

10.3 Per-token vs. per-tensor activation scales

Per-token (also called per-row) dynamic scaling is the gold standard. Per-tensor would require a global reduction across all tokens in a batch, which is impractical.

10.4 Practical guidance

  • For W8A8 inference: dynamic per-token activation scales, even if it costs 5% throughput. The accuracy gap to static is large.
  • For W4A16: no activation quantization at all-keep BF16. This is the "free" axis.
  • For FP8: per-tensor amax-based scaling, recomputed periodically (see §11).

11. FP8 inference

11.1 Hardware

NVIDIA H100 (Hopper) introduced FP8 tensor cores supporting:

  • E4M3 for activations and weights (forward path).
  • E5M2 for gradients (backward path, training).
  • 2× BF16 throughput on H100 SXM (~989 TFLOPS BF16 → ~1979 TFLOPS FP8).
  • Native dequant-on-load: tensor cores accept FP8 inputs and output FP32 accumulators.

Ada (RTX 4090, L40S) has FP8 support with similar speedup. Blackwell (B100, B200) extends to FP4.

11.2 Per-tensor scaling factor

Unlike INT8, where scale must be applied to recover the real value, FP8 is a floating-point format and represents a real number directly within its limited range. The catch: the limited range (~448 for E4M3) means real-world activations and weights need to be prescaled into the FP8 representable range.

The scaling factor is typically per-tensor and stored as FP32:

x_fp8 = quantize_e4m3(x_fp32 × scale_x)
w_fp8 = quantize_e4m3(w_fp32 × scale_w)        # weights quantized once at load
matmul_output (FP32 accumulator) = (1 / (scale_x × scale_w)) × Σ (x_fp8 × w_fp8)

The output is an FP32 accumulator; you then re-quantize to FP8 (with a new per-tensor scale) to feed the next layer.
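
Recent PyTorch releases expose the FP8 dtypes, so the scale-then-cast step can be sketched without TransformerEngine. The matmul below is emulated by upcasting (invoking the actual FP8 tensor-core GEMM requires lower-level APIs), so treat it as a numerics illustration only:

import torch

E4M3_MAX = 448.0                       # max representable magnitude of float8_e4m3fn

def to_fp8_e4m3(t: torch.Tensor):
    """Per-tensor scale into the E4M3 range, then cast. Returns (fp8, scale)."""
    amax = t.abs().max().clamp(min=1e-12)
    scale = E4M3_MAX / amax            # multiply by scale before casting
    t_fp8 = (t.float() * scale).to(torch.float8_e4m3fn)
    return t_fp8, scale

x = torch.randn(16, 1024)
w = torch.randn(1024, 1024) * 0.02
x8, sx = to_fp8_e4m3(x)
w8, sw = to_fp8_e4m3(w)

# Emulated FP8 matmul: upcast, accumulate in FP32, undo both scales.
y = (x8.float() @ w8.float()) / (sx * sw)
print((y - x @ w).abs().max())         # small quantization error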

11.3 Calibration: amax tracking

The scale factor for each tensor is set by tracking the maximum absolute value seen during calibration:

amax = max(|tensor|) across calibration set
scale = fp8_max / amax        # fp8_max ≈ 448 for E4M3

The challenge: amax is determined per tensor, and "per tensor" includes activations whose distribution depends on inputs. For weights this is static; for activations it must be tracked.

11.4 Delayed scaling

Recomputing amax every forward pass costs a reduction. NVIDIA TransformerEngine (TE) introduces delayed scaling:

  • Maintain a moving-window history of recent amax values (e.g., last 16 forward passes).
  • The scale used at step t is computed from the history through step t-1.
  • Rationale: amax changes slowly; using the previous amax for the current step is a good approximation.

This avoids the overhead of synchronous reduction inside each forward pass. The scale update is essentially an asynchronous bookkeeping operation.
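
A toy version of the bookkeeping, in the spirit of TransformerEngine's delayed scaling but not its API (class and method names here are invented for illustration):

from collections import deque

class DelayedScaler:
    """Amax-history bookkeeping for one FP8 tensor."""

    def __init__(self, history_len=16, fp8_max=448.0):
        self.amax_history = deque(maxlen=history_len)   # rolling window of recent amax
        self.fp8_max = fp8_max
        self.scale = 1.0

    def pre_forward(self):
        # The scale used at step t comes from the history through step t-1.
        if self.amax_history:
            self.scale = self.fp8_max / max(self.amax_history)
        return self.scale

    def post_forward(self, tensor_amax: float):
        # Recording amax can happen asynchronously; it is off the critical path.
        self.amax_history.append(tensor_amax)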

11.5 NVIDIA TransformerEngine

TE is the NVIDIA-blessed library for FP8 training and inference. Key features:

  • Drop-in replacement for nn.Linear, nn.LayerNorm, attention layers.
  • Manages amax history and scale computation transparently.
  • Supports FP8 GEMM, FP8 attention (FA-3 with FP8 KV-cache), and mixed FP8/BF16 layers.
  • Integrates with PyTorch autograd: forward in E4M3, backward in E5M2.

In production, FP8 inference for dense LLMs is currently the right call for prefill-heavy workloads (because you get 2× tensor-core throughput) and for models that fit the H100's FP8-friendly architecture (Llama, Mistral, Qwen). For memory-bound decode, INT4 is still the higher-leverage choice because it halves memory traffic again.

11.6 FP8 vs. INT8

Both are 8-bit. Why prefer FP8?

  • FP8 has a built-in exponent → it handles dynamic range natively.
  • INT8 requires explicit scale storage and scale-aware kernels.
  • FP8 supports backprop (E5M2 covers gradient range); INT8 does not naturally.
  • FP8 is trickier on hardware that doesn't support it (Ampere). INT8 is universal.

In 2024+: FP8 is rapidly displacing INT8 for forward-pass quantization on H100/Blackwell. INT8 (W8A8 with SmoothQuant) is still common for older hardware and edge deployments.


12. FP8 training

A full treatment is in DD 14 (distributed training). The TL;DR:

  • Forward: weights, activations, optimizer-state copies → E4M3.
  • Backward: gradients → E5M2 (needs the wider exponent for the small magnitudes that arise from chain rule).
  • Master weights: kept in BF16 or FP32 to accumulate updates without underflow.
  • Per-tensor amax with delayed scaling: as above.
  • Loss scaling is no longer needed (FP8's exponent suffices for stable forward; E5M2 suffices for backward).

The training story is qualitatively similar to mixed-precision BF16 training but with a tighter scaling regime. NVIDIA's published recipes for FP8 Llama training show parity with BF16 within ~0.05 PPL on standard pretraining benchmarks at large scales.


13. On-the-fly dequantization and the Marlin kernel

W4A16 inference requires solving an unusual GEMM problem: the A operand is BF16, the B operand is INT4 packed as 4-bit values with FP16 per-group scales, and the output is BF16. Tensor cores don't natively accept INT4-with-scales. So the kernel must:

  1. Load packed INT4 weights from HBM.
  2. Unpack to BF16 in registers or shared memory.
  3. Multiply by per-group FP16 scale.
  4. Feed the resulting BF16 tile into a tensor-core BF16 matmul.

The naive implementation runs at FP16 GEMM throughput (no win) because dequantization is in series with the matmul. Marlin solves this.

13.1 Marlin kernel

Frantar, Castro, Chen, Ashkboos, Alistarh, "Marlin: Mixed-precision Auto-Regressive Linear kernels", 2024. (And the open-source marlin repository.)

Marlin is the highest-throughput open-source W4A16 GEMM as of writing. Key design decisions:

1. SMEM-based dequantization, double-buffered. While one warp dequantizes the next tile of weights from packed INT4 in SMEM into BF16 in another SMEM region, another warp consumes the previously-dequantized BF16 tile for tensor-core matmul. This hides dequant latency behind matmul latency.

2. Tensor-core BF16 matmul on dequantized weights. The matmul is a standard BF16 × BF16 → FP32 tensor-core GEMM. No new tensor-core type needed.

3. K-axis tile size aligned with group size. The inner-K tile is 128 (matching group_size = 128) so the per-group scale is loaded once per K-tile and held in registers throughout that tile's matmul.

4. Fused output dequantization. The per-group FP16 scale is multiplied into the BF16 weight tile before it goes into the tensor core. The tensor core itself sees pre-scaled BF16-no post-multiply pass.

5. Asynchronous HBM loads via cp.async. HBM → SMEM transfers use asynchronous copy instructions (cp.async on Ampere+, with the bulk/TMA variants on Hopper) to maximize bandwidth and hide latency behind SMEM-resident dequant + matmul.

13.2 Marlin performance, qualitatively

  • Decode (memory-bound): ~3× FP16 GEMM throughput. This is the regime where halving weight bytes halves runtime.
  • Prefill (compute-bound): approaches FP16 GEMM throughput. Once arithmetic dominates, the dequant savings are gone, but we don't lose much either.
  • Crossover batch size: roughly 16–32 tokens per forward pass on H100 with Llama-class models. Below this batch, you're memory-bound and INT4 wins big. Above this, INT4 is roughly tied with BF16 on time, but you still save the 4× HBM footprint (which lets you run bigger models or longer contexts).

13.3 Why this matters

The Marlin kernel is what makes W4A16 inference practical. Without it, you'd have a quantized model that loaded 4× faster but ran the matmul at FP16 speed-wasting half the available speedup. With it, decode genuinely doubles or triples.

The Marlin design is now widely copied: vLLM, TensorRT-LLM, and SGLang all integrate Marlin or Marlin-derived W4A16 kernels for their INT4 paths.


14. Mixed-precision inference

Not every layer is equally quantization-tolerant. Empirically:

  • Embedding layers and the LM head are sensitive (small numerical noise → wrong tokens).
  • LayerNorm parameters are tiny but multiplicative; quantizing them is rarely worth it.
  • Early MLP layers and early attention QKV projections are sometimes more fragile than later ones.
  • The down-projection of the MLP block is often the most outlier-sensitive.

Mixed-precision inference exploits this: quantize most layers aggressively (INT4), but keep the few sensitive ones in higher precision (BF16 or INT8).

14.1 LLM.int8() (Dettmers et al., 2022)

The seminal mixed-precision INT8 scheme. Splits each matmul into:

  • The outlier columns (~1% of input channels) → kept in FP16.
  • The regular columns (~99%) → quantized to INT8.
  • Two separate matmuls; results summed.

Implementation in bitsandbytes. Slower than pure-INT8 but enabled the first 175B models to fit on 8×A100s.

14.2 Per-layer schemes in vLLM

vLLM's quantization config supports per-layer precision selection. Typical recipe:

default: int4_awq_g128
overrides:
  - lm_head: bf16
  - model.embed_tokens: bf16
  - model.layers.0.mlp.down_proj: int8     # extra-sensitive, bumped to INT8

The user provides this from a YAML or programmatic API; vLLM dispatches each layer to the appropriate kernel.

14.3 The lookup-free heuristic

For pure W4A16 with AWQ or GPTQ, you usually don't need mixed precision at all on 7B+ dense models. The single biggest win from mixed precision is on:

  • Models smaller than 3B.
  • Models with known outlier-sensitivity (some MoE experts).
  • Long-context regimes where activation tails grow.

If you're shipping a 70B Llama at INT4 AWQ, just go full INT4-the quality is fine.


15. Calibration set design

Calibration is the process of collecting activation statistics (X for GPTQ, mean-abs activations for AWQ, amax for FP8) on a representative dataset.

15.1 Size

  • 100–1000 examples is typical. Below 100, statistics are noisy. Above 1000, marginal gains are small.
  • Each example should be a sequence of typical length (e.g., 512–2048 tokens).
  • Total calibration tokens: ~100K–1M.

15.2 Distribution

The calibration set should match the deployment distribution. Quantization is an empirical approximation-if you calibrate on Wikipedia and deploy on chat dialogues, the activation distribution differs and quality suffers.

For general-purpose chat models, common calibration sets:

  • WikiText-2 / WikiText-103: standard, English encyclopedic prose.
  • C4: web-scrape, more diverse.
  • Pile: research-grade pretraining mix.
  • In-domain prompts: a sample of real production queries (best, when available).

For domain-specific deployments (code, medical, legal), use in-domain calibration. The accuracy difference from generic-vs-in-domain calibration can be 10–30% on downstream evals.

15.3 Length matching

If your deployment uses 8K context, calibrate on 8K-token sequences (or longer than typical). Activation distributions shift with sequence length-the longer the context, the heavier the activation tails in attention layers.

15.4 Mode matching

Calibrate in inference mode with the same prefill/decode pattern your serving stack uses. In particular:

  • Calibrate with a chat template wrapping if you ship with one.
  • Calibrate with the system prompt prepended.
  • Calibrate with the same tokenizer settings (BOS/EOS handling).

Small mismatches here are common bug sources.

15.5 Practical workflow

1. Collect 256 prompts from production logs (or a representative public set).
2. Render them through your full chat template / tokenizer pipeline.
3. Truncate or pad to a common length matching your typical workload.
4. Run forward passes on the original FP16/BF16 model, capturing activations (see the hook sketch after this list).
5. Feed activations to GPTQ/AWQ/SmoothQuant.
6. Save the resulting quantized model.
7. Evaluate (next section).
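
A minimal sketch of step 4 using standard PyTorch forward hooks; the helper is hypothetical, and production tools (AutoAWQ, llm-compressor) wrap the same idea with far more bookkeeping:

import torch

def capture_inputs(model, layer_names, calib_batches):
    """Record the inputs flowing into selected linear layers during calibration."""
    captured = {name: [] for name in layer_names}
    handles = []
    for name, module in model.named_modules():
        if name in layer_names:
            def hook(mod, inputs, output, name=name):
                captured[name].append(inputs[0].detach().float().cpu())
            handles.append(module.register_forward_hook(hook))
    with torch.no_grad():
        for batch in calib_batches:        # pre-tokenized, chat-template-rendered
            model(batch)
    for h in handles:
        h.remove()                         # always detach hooks when done
    return {name: torch.cat(acts) for name, acts in captured.items()}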

16. Evaluation discipline

Quantization is an empirical engineering activity. You ship what passes evaluation. Two orthogonal axes:

16.1 Perplexity

Cheap (~1 GPU-hour for 7B on WikiText-2), sensitive (tenths of a PPL are detectable), but only a proxy. Run perplexity on a held-out set similar to your calibration distribution but disjoint:

  • WikiText-2 test split is the de-facto baseline.
  • C4 validation slice is also common.

What to look for:

  • INT4 AWQ/GPTQ on 7B+ dense: <0.5 PPL absolute increase from BF16.
  • INT4 on smaller models (<3B): up to 1-2 PPL.
  • W8A8 with SmoothQuant: <1 PPL.
  • FP8 inference: ~0.1 PPL or less.

If you see >2 PPL, something is wrong: bad calibration data, wrong group size, an unfused fold of s somewhere, or a layer that needs to stay BF16.

16.2 Downstream task evaluation

Perplexity is necessary but not sufficient. Always also measure on a battery of downstream benchmarks:

  • MMLU (Massive Multitask Language Understanding): broad knowledge.
  • HumanEval: code generation correctness (pass@1).
  • GSM8K: grade-school math reasoning.
  • HellaSwag / WinoGrande / ARC-Challenge: common-sense reasoning.
  • TruthfulQA: hallucination resistance.

A quantized model can have unchanged perplexity but drop 3% on MMLU-usually a sign that the long-tail factual recall has been damaged by precision loss in the LM head or final layers.

What to look for:

  • Quality-grade INT4: <1% absolute drop on MMLU and GSM8K.
  • Acceptable INT4: <2% drop.
  • W8A8: typically indistinguishable on benchmarks.

16.3 The full evaluation loop

1. Compute baseline metrics on the FP16 reference model.
2. Quantize with method X (AWQ, GPTQ, etc.) at precision P.
3. Recompute all metrics.
4. Compare: tabulate (metric, baseline, quantized, delta).
5. If any delta exceeds your acceptance threshold, iterate:
     - Try the other method.
     - Try larger group size (smaller is better; 64 vs 128 is a common knob).
     - Try mixed-precision (keep one or two fragile layers BF16).
     - Try a different α for AWQ/SmoothQuant.
6. Lock the recipe in version control along with calibration data hash.

The recipe lockdown is non-negotiable: quantization is sensitive enough that "I re-ran AWQ and got 0.3 PPL different" is a real and frequent failure mode. Pin everything: the model checkpoint hash, the calibration sample IDs, the random seed, the version of the quantization tool.

16.4 The ship-criteria

Don't ship a quantization scheme without:

  1. ≤ X PPL increase on at least two domains.
  2. ≤ Y% drop on at least three downstream tasks.
  3. A throughput measurement showing the expected speedup is realized in your serving stack.
  4. A latency tail check (p99 first-token, p99 inter-token)-quantization shifts kernel performance in non-uniform ways.
  5. A quality A/B on real production traffic for a small fraction of users, before full rollout.

17. Practical exercises

These are exercises with full worked solutions where appropriate. Grab a calculator and attempt each one before reading its solution.

Exercise 17.1: Memory footprint of a 70B model under various schemes

A Llama-3-70B has ~70.6B parameters. Compute the storage in GB under:

(a) FP32. (b) BF16. (c) INT8 per-channel symmetric. (d) INT4 group=128 with FP16 scale + FP16 zero_point. (e) INT4 group=64 with FP16 scale (symmetric, no zp).

Solution.

Let P = 70.6 × 10^9.

(a) FP32: 4 × P = 282.4 GB.

(b) BF16: 2 × P = 141.2 GB.

(c) INT8 per-channel: 1 × P for the integers, plus one FP16 scale per output channel. Output channels for Llama-3-70B sum to roughly ~10^6 across all linear layers, contributing ~2 MB-negligible. So ≈ 70.6 GB.

(d) INT4 group=128 with FP16 scale + FP16 zp: 4 × P / 8 = 35.3 GB for integers, plus (2 + 2) bytes × P / 128 = 4P/128 bytes ≈ 2.2 GB for scales+zp. Total ≈ 37.5 GB.

(e) INT4 group=64, FP16 scale, no zp: 35.3 GB for ints, plus 2 × P / 64 ≈ 2.2 GB for scales. Total ≈ 37.5 GB. (Smaller group means more scales, but no zp roughly compensates.)

The lesson: scale+zp overhead is meaningful at small group sizes. The crossover where scale overhead equals weight savings is around group = 8 (everything is then a scale).

Exercise 17.2: Effective bits per element

You're storing INT4 group=64, with FP32 scale and INT4 zp. Compute the effective bits per element.

Solution.

bits_per_elem = 4 (int) + 32/64 (scale) + 4/64 (zp)
              = 4 + 0.5 + 0.0625
              = 4.5625 bits/element

vs. INT4 group=128 FP16 scale only at 4.125. Group=64 with FP32 scale almost loses you back the savings of going INT4 in the first place-a useful reminder to use FP16 scales (or even INT8 packed scales in some implementations).
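
The bookkeeping generalizes to a one-liner that is handy when comparing storage recipes:

def effective_bits(int_bits, group_size, scale_bits=16, zp_bits=0):
    """Storage cost per weight element for grouped quantization."""
    return int_bits + (scale_bits + zp_bits) / group_size

print(effective_bits(4, 64, scale_bits=32, zp_bits=4))    # 4.5625  (this exercise)
print(effective_bits(4, 128, scale_bits=16))              # 4.125   (INT4 g128, FP16 scale)
print(effective_bits(4, 128, scale_bits=16, zp_bits=16))  # 4.25    (adding an FP16 zero-point)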

Exercise 17.3: Tracing AWQ on a 4-element row

Given:

W[0, :] = [0.20, -0.10, 0.05, -0.30]
calibration mean-abs activations a = [1.0, 1.0, 8.0, 1.0]

Compute AWQ INT4 quantization with α = 0.5, group_size = 4, Q_max = 7. Compare error vs. plain RTN.

Solution.

Plain RTN:

max|W| = 0.30
scale = 0.30 / 7 ≈ 0.04286
W/scale = [4.667, -2.333, 1.167, -7.0]
round → [5, -2, 1, -7]
W_dequant = [0.2143, -0.0857, 0.0429, -0.3000]
err = [-0.0143, -0.0143, 0.0071, 0.0]

Output error, per-channel:
err × a = [-0.0143, -0.0143, 0.0571, 0.0]
sum |err × a| = 0.0857

AWQ:

s = a^0.5 = [1.0, 1.0, 2.828, 1.0]
geomean = (1 × 1 × 2.828 × 1)^(1/4) ≈ 1.297
s ← s / 1.297 ≈ [0.771, 0.771, 2.181, 0.771]
W' = W × s = [0.1543, -0.0771, 0.1090, -0.2314]
max|W'| = 0.2314
scale' = 0.2314 / 7 ≈ 0.03306
W'/scale' = [4.668, -2.333, 3.298, -7.0]
round → [5, -2, 3, -7]
W'_dequant = [0.1653, -0.0661, 0.0992, -0.2314]
err in W'-space = [-0.0110, -0.0110, 0.0098, 0.0]
err in W-space (divide by s) = [-0.01427, -0.01427, 0.00450, 0.0]

Output error per-channel:
err × a = [-0.01427, -0.01427, 0.0360, 0.0]
sum |err × a| = 0.0645

AWQ reduced the channel-2 output error from 0.0571 to 0.0360, a ~37% reduction on the salient channel. Total summed error fell from 0.0857 to 0.0645. The improvement is modest but consistent; with realistic LLM activation tails (10×–100× outliers, not 8×) and across millions of weights, these savings compound to the difference between a usable INT4 model and a broken one.

Exercise 17.4: GPTQ on 2 columns

Suppose:

w = [0.5, 0.3]    (1 row, 2 cols, just for illustration)
H = [[2.0, 0.5],
     [0.5, 1.0]]
H^{-1} = (1/(2×1 - 0.5×0.5)) × [[ 1.0, -0.5],
                                  [-0.5,  2.0]]
       = (1/1.75) × [[1.0, -0.5], [-0.5, 2.0]]
       ≈ [[0.5714, -0.2857],
          [-0.2857, 1.1429]]

Quantize using INT2 (Q_max = 1, codes ∈ {-1, 0, 1}) with per-column scale = 0.5. Trace GPTQ.

Solution.

Column 0:

w[0] = 0.5
q[0] = round(0.5 / 0.5) = 1
w_q[0] = 0.5 × 1 = 0.5
err = 0.5 - 0.5 = 0.0
No update needed for column 1 (err is 0 here, lucky).

Column 1 (no update from column 0):

w[1] = 0.3
q[1] = round(0.3 / 0.5) = 1   (rounds to nearest; 0.6 rounds up)
w_q[1] = 0.5
err = 0.3 - 0.5 = -0.2
No further columns to update.

Result: w_q = [0.5, 0.5]. Layer reconstruction error = Δ H Δ^T with Δ = (0, −0.2), i.e., 0.04 × H[1,1] = 0.04 × 1.0 = 0.04.

Now repeat with a non-trivial first-column error. Suppose w = [0.7, 0.3]:

Column 0: q[0] = round(0.7/0.5) = 1, w_q[0] = 0.5, err = 0.2.
OBS update for column 1:
  δw_1 = -err × H^{-1}[0,1] / H^{-1}[0,0]
       = -0.2 × (-0.2857 / 0.5714)
       = -0.2 × (-0.5)
       = +0.1
  w[1] ← 0.3 + 0.1 = 0.4

Column 1: q[1] = round(0.4/0.5) = 1, w_q[1] = 0.5, err = -0.1.
No further columns.

Final w_q = [0.5, 0.5]

Without the OBS update, w[1] would have been 0.3, which still rounds to 1 (0.3 / 0.5 = 0.6 rounds up), giving the same w_q = [0.5, 0.5] in this tiny case. But in a real layer with many columns and a non-pathological Hessian, the OBS updates do change the rounding decisions of later columns and lower reconstruction error noticeably.

Exercise 17.5: SmoothQuant α derivation

Given activation max-abs a = [1, 1, 100] and weight column-max-abs w = [10, 10, 10]. Compute s for α ∈ {0.0, 0.5, 1.0}. After smoothing, what are the new activation and weight max-abs?

Solution.

For α = 0.0: s = a^0 / w^1 = [1, 1, 1] / [10, 10, 10] = [0.1, 0.1, 0.1].
  • New x' = x / s, max-abs [10, 10, 1000]. New w' = w × s, max-abs [1, 1, 1].
  • The activation outlier got worse (100 → 1000). Bad.

For α = 0.5: s = a^0.5 / w^0.5 = [1, 1, 10] / [√10, √10, √10] ≈ [0.316, 0.316, 3.162].
  • New x' max-abs: [1/0.316, 1/0.316, 100/3.162] ≈ [3.16, 3.16, 31.62].
  • New w' max-abs: [10×0.316, 10×0.316, 10×3.162] ≈ [3.16, 3.16, 31.62].
  • Activation and weight max are now equal-perfectly balanced.

For α = 1.0: s = a^1 / w^0 = [1, 1, 100] / [1, 1, 1] = [1, 1, 100].
  • New x' max-abs: [1, 1, 1]. New w' max-abs: [10, 10, 1000].
  • Now the weights have the outlier. Bad.

Conclusion: α = 0.5 makes activations and weights equally hard. The empirical optimum varies by architecture but lives in this neighborhood.

Exercise 17.6: End-to-end throughput estimate

You're serving Llama-3-70B at INT4 AWQ on H100 SXM (80 GB, 3.35 TB/s HBM). At batch size 1 (decode), estimate tokens-per-second.

Solution.

Weight footprint: ~37.5 GB (from Ex. 17.1). KV cache per token at FP16 (Llama-3-70B uses GQA with 8 KV heads): 2 (k+v) × 8 (KV heads) × 128 (head_dim) × 80 (layers) × 2 (bytes) ≈ 0.33 MB/token. At 4096 context: ~1.3 GB of KV cache. HBM traffic per decode step: 37.5 GB of weights plus the KV reads ≈ 37.5 GB (weight reads dominate).

Assume Marlin achieves 80% of peak HBM bandwidth in this regime: effective 0.8 × 3.35 = 2.68 TB/s.

Tokens/sec ≈ 2.68 TB/s / 37.5 GB ≈ 71.5 tokens/sec for the weight-traffic component alone.

Adding KV-cache reads (which scale with context length) and overhead, real-world numbers for vLLM-Marlin Llama-70B at INT4 are typically in the 50–80 tokens/sec range at batch 1, 4096 context-broadly consistent with this back-of-envelope. The same model at BF16 streams ~4× the weight bytes per token (141 GB, which no longer fits on one 80 GB H100), so per-GPU decode throughput drops to roughly a quarter; multi-GPU BF16 deployments land in the ~20–35 tokens/sec range as aggregate bandwidth claws some of that back.


18. Cheat sheet

18.1 Default recipes

| Goal | Recipe |
| --- | --- |
| Maximum throughput, dense LLM | INT4 AWQ group=128 + Marlin kernel |
| Quality-first, dense LLM | INT4 GPTQ group=128 with act-order |
| W8A8 for older HW (A100, V100) | SmoothQuant + GPTQ INT8 weights + dynamic INT8 act |
| H100 prefill-bound | FP8 E4M3 with TransformerEngine |
| Small model (<3B) | INT8 PTQ or mixed-precision INT4 |
| Long-context with KV cache pressure | INT4 weights + FP8 KV cache |

18.2 Rules of thumb

  • INT4 group=128 ≈ 4.13–4.25 effective bits/element.
  • Halving weight bytes ≈ doubles decode throughput (memory-bound regime).
  • Marlin gives ~3× FP16 throughput on decode, near-parity on prefill.
  • AWQ ~ GPTQ in quality on 7B+ dense; AWQ slightly faster to calibrate, GPTQ slightly more flexible.
  • SmoothQuant α defaults: 0.5 generic, 0.85 for outlier-heavy Llamas.
  • Calibrate with 100–1000 examples matching deployment distribution.
  • Always evaluate perplexity AND downstream tasks before shipping.
  • FP8 ≠ INT8: FP8 has built-in dynamic range; INT8 needs explicit per-channel/per-group scale.

18.3 Common pitfalls

  • Forgetting to fold the AWQ s into the previous layer's parameters → runtime divide that wasn't budgeted.
  • Mismatch between calibration tokenization and deployment tokenization → activation distribution shift, accuracy drop.
  • Using a too-small group size with FP32 scales → effective bits balloons, savings disappear.
  • Static activation quantization with prompts that don't match calibration → silent quality regression.
  • No mixed-precision exception for the LM head → top-k token decisions degrade; benchmarks tank.
  • Skipping downstream eval and trusting perplexity → 0.5 PPL OK but MMLU drops 4 points.
  • Comparing pre-quant and post-quant on different decoding settings → spurious differences.
  • Forgetting the K-axis tile alignment → custom kernels run at half-speed because dequant doesn't fuse with matmul.

18.4 Further reading

  • Lin et al., AWQ, MLSys 2024.
  • Frantar et al., GPTQ, ICLR 2023.
  • Xiao et al., SmoothQuant, ICML 2023.
  • Frantar et al., Marlin, 2024.
  • Dettmers et al., LLM.int8(), NeurIPS 2022.
  • NVIDIA, FP8 Formats for Deep Learning (white paper, 2022).
  • NVIDIA TransformerEngine documentation.
  • The vLLM and TensorRT-LLM source trees-read the actual quantization configs and W4A16 kernels. They are the operational ground truth.

Closing notes

Quantization is the rare topic in AI systems where theory, kernels, and economics align cleanly: a few hundred lines of clean math (affine maps, OBS updates, magnitude redistribution) translate directly into 2–4× throughput, 4× memory savings, and tens of percent of inference cost reduction. It is also the topic where sloppy execution most reliably destroys model quality-quietly, without crashing-because every step is an empirical approximation.

The discipline that distinguishes a competent quantization engineer from a great one is paranoid evaluation. Anyone can run AutoAWQ on a Hugging Face checkpoint. The engineer who validates the result on three benchmarks, two domains, and a production traffic A/B is the one whose model actually ships and stays shipped.

Master this chapter and you have the leverage to halve your inference bill on any LLM workload. Master the failure modes and you have the discipline to do so without the customer noticing-except for their latency improving.

End of Deep Dive 09.

Deep Dive 10-Speculative Decoding and Prefill/Decode Disaggregation

"Decode is sequential. Prefill is parallel. Treating them as one workload was always a compromise-the inference frontier of 2024–2026 is a stack of techniques that finally separates them."

This chapter is a self-contained reference on the two most consequential serving-side ideas of the last two years: speculative decoding, which trades a small amount of extra compute for a large reduction in serial steps, and prefill/decode disaggregation, which physically separates the two phases of LLM inference onto different worker pools. By the end you should be able to derive every formula on a whiteboard, sketch the architecture diagrams from memory, write pseudocode for the speculative loop and the disaggregated request flow, and reason about when these techniques help, when they hurt, and how they compose with the rest of a production inference stack.

The chapter assumes the reader is comfortable with transformer inference at the level of Deep Dives 1–9 of this curriculum: KV-cache, paged attention, continuous batching, chunked prefill, prefix caching, weight quantization. We build directly on those primitives.


1. The Latency Problem

Before we can argue for any new technique, we need to be precise about what is slow and why.

1.1 Two latencies that matter

For any chat-style LLM application, two user-visible latencies dominate the experience:

  • TTFT-Time To First Token. The wall-clock time from the user pressing Enter to the first streamed token appearing. This is dominated by prefill: the model must read all input tokens, populate the KV-cache, and emit the first output token.
  • TPOT-Time Per Output Token. Sometimes called inter-token latency or ITL. The wall-clock interval between consecutive streamed tokens after the first. This is dominated by decode: each output token requires one forward pass.

Total response latency for an output of length N is approximately:

total_latency ≈ TTFT + (N − 1) × TPOT

For a 500-token reply with TTFT = 300 ms and TPOT = 30 ms, total latency is roughly 300 + 499 × 30 ≈ 15.3 s. The decode phase contributes nearly 15 of those 15.3 seconds. Decode dominates total latency for any reply longer than a handful of tokens.

1.2 Why decode is sequential

Each decode step depends on the previous token: token t+1 is sampled from p(· | x_{1..t}), and computing that distribution requires the hidden states from token t. There is no way, at the level of a vanilla autoregressive model, to compute token t+2 before token t+1 is sampled. So decode is intrinsically a serial loop:

for i in 1..N:
    h_i = forward(model, token_{i-1}, kv_cache)   # one full forward pass per token
    token_i = sample(h_i)
    kv_cache.append(K_i, V_i)                     # cache this step's key/value projections

Total decode time for N tokens at batch size 1 is N × T_decode_step.

1.3 Why decode is memory-bound

Each decode step performs a forward pass over the full model with input length 1. The arithmetic is:

  • For each linear layer of weight W ∈ R^{d_out × d_in}, the operation is y = W · x for x ∈ R^{d_in} - a single matrix-vector multiply.
  • FLOPs: 2 · d_in · d_out.
  • Memory traffic: read W (≈ d_in · d_out · bytes_per_param), read x, write y. Dominant term is reading W.
  • Arithmetic intensity ≈ 2 · d_in · d_out / (d_in · d_out · bytes_per_param) = 2 / bytes_per_param.

For FP16 weights (bytes_per_param = 2), arithmetic intensity is ≈ 1 FLOP/byte. An H100 has roughly 3 TB/s of HBM bandwidth and ~1000 TFLOPs of FP16 tensor-core throughput, giving a balance point (the "ridge" of the roofline) at ~330 FLOPs/byte. Decode at batch=1 is two orders of magnitude below the roofline ridge-it is severely memory-bound. The tensor cores are starved; we are reading weights faster than we can use them.

Now consider the same model running at batch size B. The matrix-vector multiply becomes a matrix-matrix multiply of shape (d_out × d_in) · (d_in × B). Weights are still read once; arithmetic scales with B. Arithmetic intensity becomes 2B / bytes_per_param. At B ≈ 256, FP16 decode finally crosses into compute-bound territory. This is the entire reason continuous batching exists.
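
The roofline arithmetic is worth scripting once; the sketch below reuses the ballpark H100 figures from this subsection (assumptions, not measurements) to locate the crossover:

def decode_intensity(batch, bytes_per_param=2):
    """Batched decode reads each weight once and performs 2·B FLOPs with it."""
    return 2 * batch / bytes_per_param

HBM_BW = 3.0e12        # bytes/s - the ~3 TB/s ballpark used above
FP16_FLOPS = 1.0e15    # ~1000 TFLOPS FP16 tensor-core, same ballpark
ridge = FP16_FLOPS / HBM_BW                   # ≈ 333 FLOPs/byte

for B in (1, 16, 64, 256, 512):
    ai = decode_intensity(B)
    regime = "memory-bound" if ai < ridge else "compute-bound"
    print(f"B={B:4d}  {ai:6.1f} FLOPs/byte  ({regime})")
# B=1 sits two orders of magnitude below the ridge; the crossover lands
# around B ≈ 300, the same ballpark as the estimate above.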

1.4 Why prefill is compute-bound

Prefill processes a prompt of length L in a single forward pass. The same matmul becomes (d_out × d_in) · (d_in × L). Arithmetic intensity scales with L. For typical chat prompts (L ≈ 200 − 4000), prefill is firmly compute-bound; the GPU is well-utilized; tensor cores are saturated.

1.5 The optimization tension

Continuous batching solves decode throughput but does not help individual latency: with B = 64, the single-step decode time T_decode_step is roughly the same as at B = 1 (slightly higher because we are now compute-bound), and a request still pays N × T_decode_step for its N tokens. Per-request latency is bound below by a factor that batching does not touch.

That is the lever speculative decoding pulls on: it reduces the number of sequential target-model forward passes per output token, without changing the model. Disaggregation, complementarily, removes the second-order penalty that co-located batching imposes on TTFT and TPOT by tuning prefill and decode hardware independently.


2. Speculative Decoding-Setup and Core Algorithm

The technique was published concurrently in two 2023 papers:

  • Leviathan, Kalman, Matias-Fast Inference from Transformers via Speculative Decoding, ICML 2023.
  • Chen, Borgeaud, Irving, et al. (DeepMind)-Accelerating Large Language Model Decoding with Speculative Sampling, 2023.

Both arrive at the same algorithm with essentially the same correctness proof. We follow the Leviathan et al. notation.

2.1 The two-model setup

We have two language models over the same vocabulary V:

  • Target model M_q, the large model whose distribution q(· | context) we want to sample from. Call its single-step latency T_target.
  • Draft model M_p, a small model with distribution p(· | context). Single-step latency T_draft, with T_draft ≪ T_target. Typically T_draft / T_target ∈ [0.05, 0.2]; a 7B drafting a 70B sits around 0.1.

The two models share the tokenizer (this matters; cross-tokenizer speculation is a research topic but is messier).

2.2 The speculative step

One speculative step produces between 1 and K + 1 accepted tokens by the following procedure. Let x_{1..t} be the current generated context.

SPECULATIVE_STEP(x_{1..t}):

    # 1. Draft K tokens autoregressively with M_p.
    for i in 1..K:
        p_i = M_p(x_{1..t+i-1})              # distribution over V
        ~x_{t+i} ~ p_i                       # sample draft token
    # Now we have draft tokens ~x_{t+1}, ..., ~x_{t+K} and their probs p_1, ..., p_K.

    # 2. Verify all K positions in ONE forward pass of M_q.
    #    Feed the sequence x_{1..t}, ~x_{t+1}, ..., ~x_{t+K} as if it were prefill.
    #    Get back K+1 distributions q_1, ..., q_{K+1}.
    q_1, ..., q_{K+1} = M_q(x_{1..t}, ~x_{t+1..t+K})

    # 3. Accept-reject loop using rejection sampling.
    n = 0
    for i in 1..K:
        r ~ Uniform(0, 1)
        if r < min(1, q_i(~x_{t+i}) / p_i(~x_{t+i})):
            n += 1                           # accept ~x_{t+i}
        else:
            break                            # reject; stop here

    # 4. Sample one extra "free" token at position t+n+1.
    if n < K:
        # Rejection happened. Sample from corrected distribution.
        q_corrected = normalize(max(0, q_{n+1} − p_{n+1}))
        x_{t+n+1} ~ q_corrected
    else:
        # All K accepted. Sample a bonus token from q_{K+1} for free.
        x_{t+K+1} ~ q_{K+1}

    return n + 1 accepted tokens

Each speculative step costs one target forward pass plus K draft forward passes, and it produces a random number of accepted tokens between 1 and K + 1 (the +1 is the bonus token from q_{K+1} when all draft tokens were accepted, or the corrected sample when one is rejected).

2.3 Correctness-why accepted tokens are exactly distributed as target-only sampling

This is the keystone of the technique. Without this, speculative decoding would change the model's output distribution, which is unacceptable.

Claim. Each accepted token at position t+i has the marginal distribution q_i.

Proof sketch. Consider position t+i. Two cases:

  1. The draft proposes ~x ~ p_i, and we accept with probability min(1, q_i(~x) / p_i(~x)).
  2. If rejected, we sample from q_corrected = normalize(max(0, q_i − p_i)).

The total probability that we end up emitting any specific token y at this position is:

P(emit y) = P(draft proposed y) · P(accept y | drafted y)
          + P(reject)             · P(corrected sample = y)

Compute each piece:

  • P(drafted y) = p_i(y).
  • P(accept y | drafted y) = min(1, q_i(y) / p_i(y)), so the joint p_i(y) · min(1, q_i(y) / p_i(y)) = min(p_i(y), q_i(y)).
  • P(reject) = 1 − Σ_z min(p_i(z), q_i(z)). Using min(a, b) = a − max(0, a−b):
    Σ_z min(p_i(z), q_i(z)) = Σ_z p_i(z) − Σ_z max(0, p_i(z) − q_i(z)) = 1 − Σ_z max(0, p_i(z) − q_i(z))
    
    Equivalently P(reject) = Σ_z max(0, p_i(z) − q_i(z)) = Σ_z max(0, q_i(z) − p_i(z)) (the two are equal because Σ p = Σ q = 1, so Σ (p − q)_+ = Σ (q − p)_+).
  • P(corrected sample = y) = max(0, q_i(y) − p_i(y)) / Σ_z max(0, q_i(z) − p_i(z)).

Substituting:

P(emit y) = min(p_i(y), q_i(y))
          + [Σ_z max(0, q_i(z) − p_i(z))] · [max(0, q_i(y) − p_i(y)) / Σ_z max(0, q_i(z) − p_i(z))]
          = min(p_i(y), q_i(y)) + max(0, q_i(y) − p_i(y))
          = q_i(y)

The last equality uses min(a, b) + max(0, b − a) = b for a, b ≥ 0. So at every position the emission distribution is exactly q_i. The samples at different positions are not independent-but the marginal at every position matches the target-and the joint distribution over the accepted prefix can be shown to match the target's joint by a similar inductive argument. ∎

Why this is non-obvious. The cheap thing-just sampling from p and accepting with high probability when q ≈ p-would not be unbiased; it would shift the distribution toward p. The rejection-sampling correction (q − p clamped at zero and renormalized) is what makes the algorithm exact. The genius of the 2023 papers is in noticing that the correction is computable from the same target forward pass that you needed anyway.
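
The proof is short but easy to distrust, and a Monte-Carlo check makes it concrete. The toy below uses an arbitrary 4-token vocabulary and confirms that the emitted distribution matches q, not p:

import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.50, 0.30, 0.15, 0.05])   # draft distribution
q = np.array([0.25, 0.25, 0.40, 0.10])   # target distribution

def emit():
    y = rng.choice(4, p=p)                       # draft proposes y ~ p
    if rng.random() < min(1.0, q[y] / p[y]):     # accept w.p. min(1, q/p)
        return y
    corrected = np.maximum(q - p, 0.0)           # rejection: sample from (q - p)+
    return rng.choice(4, p=corrected / corrected.sum())

samples = np.array([emit() for _ in range(200_000)])
print(np.bincount(samples, minlength=4) / len(samples))   # ≈ q = [0.25 0.25 0.40 0.10]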

2.4 Bonus token

When all K draft tokens are accepted, we already have q_{K+1} from the verification forward pass-the target's distribution at position t + K + 1 conditioned on the verified prefix. We sample one extra token from it for free. This is why the maximum yield per step is K + 1.


3. The Speedup Formula

Now we derive the wall-clock speedup, which is the whole point.

3.1 Setup

Let:

  • T_target = wall-clock time for one target forward pass at length 1 (a single decode step).
  • T_target,K = wall-clock time for one target forward pass on K tokens of new input. For small K (say K ≤ 16), T_target,K ≈ T_target because the target was already memory-bound at batch=1; processing K tokens uses idle compute and adds only marginal time. We approximate T_target,K ≈ T_target in the basic model and refine later.
  • T_draft = single-step draft forward time.
  • K = draft length.
  • α = expected number of accepted tokens per speculative step, where α ∈ [1, K+1]. (Convention: includes the +1 bonus on full acceptance.)

3.2 Tokens per wallclock time

Cost of one speculative step:

T_step = K · T_draft + T_target,K  ≈  K · T_draft + T_target

Tokens emitted per step: α (in expectation).

Tokens per second:

throughput_spec = α / (K · T_draft + T_target)

Baseline (no speculation, single decode step per token):

throughput_base = 1 / T_target

Speedup:

S = throughput_spec / throughput_base
  = α · T_target / (T_target + K · T_draft)
  = α / (1 + K · (T_draft / T_target))

This is the central formula. Memorize it.

3.3 Sanity checks

  • If T_draft → 0 (free draft): S → α. The draft costs nothing; we get exactly α tokens per target call. Best possible.
  • If T_draft → T_target (draft as expensive as target): S → α / (1 + K). We are paying K + 1 target-equivalent calls per α tokens. Almost always worse than baseline (since α ≤ K + 1).
  • If α → 1 (no draft tokens accepted): S → 1 / (1 + K · T_draft / T_target) < 1. Bad draft hurts you.

3.4 Choosing K

α depends on K (more attempts have diminishing returns), and so does T_target,K. Treating α(K) as concave in K and T_target,K ≈ T_target for small K, the optimum balances the increasing numerator against the linear cost in the denominator. In practice, sweep K ∈ {2, 4, 6, 8, 12, 16} for your model pair on representative workloads and pick the empirical maximum. Common production sweet spots are K = 4 to K = 8.

3.5 Worked example

Suppose α = 3.5 (typical for a well-matched draft/target pair), K = 8, T_draft = T_target / 10. Then:

S = 3.5 / (1 + 8 · 0.1)
  = 3.5 / 1.8
  ≈ 1.94×

Roughly 2×. This matches the 2–3× range cited as typical for speculative decoding.

If the draft is faster (T_draft = T_target / 20) and acceptance is similar:

S = 3.5 / (1 + 8 · 0.05) = 3.5 / 1.4 ≈ 2.5×

If we have a great draft (α = 5.0, e.g., from EAGLE-style feature speculation) and T_draft / T_target = 0.1:

S = 5.0 / 1.8 ≈ 2.78×

The published claims of 2–3× across many papers are not coincidence; they fall out of the formula given realistic parameter values.
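
Encoded as a helper, the formula makes the parameter sensitivity obvious; the three calls reproduce the worked examples above:

def speedup(alpha, K, draft_ratio):
    """S = α / (1 + K · T_draft/T_target), the §3.2 formula."""
    return alpha / (1 + K * draft_ratio)

print(speedup(3.5, 8, 0.10))   # ≈ 1.94
print(speedup(3.5, 8, 0.05))   # = 2.5
print(speedup(5.0, 8, 0.10))   # ≈ 2.78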


4. Why Speculation Works in Compute Terms

The throughput formula tells us that it works; the roofline tells us why.

At batch=1 decode, the GPU is memory-bound: tensor cores are running at perhaps 1–3% of peak. The weights are streaming through HBM, doing one matvec per pass, and the FLOP units sit idle.

When we feed the target model K + 1 candidate tokens (the prompt context plus the K draft tokens) in a single forward pass, the matmul shape becomes (d_out × d_in) · (d_in × (K+1)). Arithmetic intensity rises by a factor of K + 1. For K = 8, we are doing 9× the FLOPs while reading the weights only once. Up to roughly the roofline ridge-at FP16, somewhere around K ≈ 64–128 for a single-request decode on H100-this extra FLOP cost is essentially free. It happens in the slack time between memory reads.

This is why the claim T_target,K ≈ T_target for small K is a good approximation. The forward pass still pays the same memory bill (read all weights), and the marginal compute cost is small until K exceeds the arithmetic-intensity ridge.

The fundamental trade made by speculative decoding: we use the GPU's idle FLOPs (which we were paying for anyway) to convert sequential target steps into parallel verification of speculative branches. The currency is GPU compute we weren't using; the payoff is reduced wall-clock latency.


5. Acceptance Rate α-Where It Comes From

α is the empirical quantity that determines whether speculation pays off. It depends on three things:

  1. Per-token agreement probability between draft and target. For each drafted token, the rejection-sampling acceptance probability is min(1, q(y) / p(y)) averaged over draft samples, which equals Σ_z min(p(z), q(z)). If p ≈ q this is close to 1; if p is wildly different this is close to 0.
  2. Length K: longer drafts give more chances but the geometric drop-off from rejections eventually dominates.
  3. Workload dependence: easy text (boilerplate code, formulaic responses) accepts more readily than hard text (novel reasoning, surprising vocabulary).

5.1 Geometric model

If we assume each token is accepted independently with probability β, then the number of accepted draft tokens is a truncated geometric:

P(n accepted) = β^n · (1 − β)        for 0 ≤ n < K
P(K accepted) = β^K

Expected accepted draft tokens:

E[n] = β · (1 − β^K) / (1 − β)

Plus the bonus token (+1) on every step (corrected sample on rejection, or q_{K+1} sample on full acceptance):

α = E[n] + 1 = (1 − β^{K+1}) / (1 − β)

5.2 Worked numbers

The independence assumption is optimistic but reasonable for back-of-envelope work.

  • β = 0.7, K = 4: α = (1 − 0.7^5) / 0.3 = (1 − 0.168) / 0.3 ≈ 2.77
  • β = 0.7, K = 8: α = (1 − 0.7^9) / 0.3 ≈ (1 − 0.040) / 0.3 ≈ 3.20
  • β = 0.8, K = 8: α = (1 − 0.8^9) / 0.2 ≈ (1 − 0.134) / 0.2 ≈ 4.33
  • β = 0.6, K = 8: α = (1 − 0.6^9) / 0.4 ≈ (1 − 0.010) / 0.4 ≈ 2.47

Published work on Llama-3-8B drafting Llama-3-70B reports per-token acceptance in the 0.6–0.8 range under typical chat workloads, giving accepted lengths of roughly 3–5 with K = 8. These are approximate ranges from public benchmarks; exact numbers vary by workload and are not promises.
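
The truncated-geometric model is two lines of code. The loop reproduces the bullets above, and the final comment shows how it composes with the §3 speedup formula:

def alpha_geometric(beta, K):
    """Expected tokens per speculative step (incl. bonus) under i.i.d. acceptance."""
    return (1 - beta ** (K + 1)) / (1 - beta)

for beta, K in [(0.7, 4), (0.7, 8), (0.8, 8), (0.6, 8)]:
    print(beta, K, round(alpha_geometric(beta, K), 2))   # 2.77, 3.20, 4.33, 2.47
# Composing with §3: speedup(alpha_geometric(0.7, 8), 8, 0.1) ≈ 3.20 / 1.8 ≈ 1.78×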

5.3 What kills α

  • Tokenizer mismatch. Even small differences (added special tokens, different BPE merges) catastrophically reduce agreement. Always use the same tokenizer for draft and target.
  • Sampling temperature. At T = 0 (greedy), agreement is brittle; one disagreement and the prefix diverges. At higher temperatures both distributions are smoother and min(p, q) mass increases.
  • Out-of-distribution context. If the draft was distilled on a narrow domain and the target is asked something else, the draft's predictions drift.

6. Variants

6.1 Vanilla speculative

Separate draft model. Simplest. Examples in production: Llama-3-8B drafting Llama-3-70B, or a custom 1B distilled draft drafting a 70B+ target.

Trade-offs. Two model copies live in GPU memory. The draft must be served on the same hardware (or close to it) to avoid network latency in the inner loop. Pipeline complexity rises: two sets of weights, two KV-caches, two CUDA streams.

6.2 Self-speculative

The draft is a part of the target model rather than a separate model.

  • Layer-skip self-speculative. Run only the early layers of the target as the draft (e.g., first 8 layers of a 32-layer model). The draft "head" is the target's own LM head. Cheap because no extra weights, but α is usually lower because the early-layer representation lacks the depth to predict tokens accurately.
  • Distilled head. Train an additional lightweight head on the target's hidden states to predict the next token. Slightly more weights, often higher α.

6.3 Medusa (Cai et al., 2024)

Add M extra "Medusa heads" on top of the target's last hidden state. Each head is a small MLP that predicts the token at offset +1, +2, ..., +M from the current position, in parallel.

h_t = target_last_hidden(x_{1..t})
prediction_at_offset_j = MedusaHead_j(h_t)        for j = 1..M

So one forward pass of the target produces M candidate tokens at positions t+1, ..., t+M. To verify, the target processes the M-token candidate as if it were prefill (the same trick as vanilla speculation), with tree-attention to handle multiple branches per position (see Section 7).

Why it's clever. No separate draft model. The Medusa heads add modest parameter count (a few percent of the target). Training fine-tunes only the heads while freezing the backbone, or fine-tunes both with a multi-objective loss.

Limitations. Each head predicts in isolation given h_t; later positions are increasingly hard to predict from h_t alone (no chain of conditioning), so per-offset acceptance falls off rapidly. Mitigated by sampling multiple candidates per offset and using tree-attention.

6.4 EAGLE / EAGLE-2 (Li et al., 2024)

Train a small autoregressive model that operates on the target's hidden states, not on tokens. The auxiliary model takes the target's last-layer hidden states h_{1..t} as input and predicts the next hidden state ~h_{t+1}, then the next, etc. Each predicted hidden state is decoded to a token via the target's LM head.

The intuition: predicting the next hidden state gives the auxiliary model access to a much richer signal than just the next token, dramatically improving α. The auxiliary model is small (typically a single transformer block plus a regression head).

EAGLE-2 adds dynamic tree expansion (branch where uncertain, prune where confident), pushing acceptance lengths higher.

As of 2026, EAGLE/EAGLE-2 is the de facto state of the art for self-speculative decoding on open-weight models. Reported accepted lengths sit in the 4–6 range with K = 8 on standard chat benchmarks, but specific numbers vary by setup and should be verified.

6.5 Lookahead decoding (Fu et al., 2024)

No draft model at all. Instead, exploit n-gram patterns from the target's own previous generations (or training data) to propose continuation candidates, then verify with a single target forward pass using a Jacobi-iteration-style parallel proposal.

Use case. When you can't ship a draft model (memory, deployment, latency constraints) but want some of the speedup. Gains are smaller (typically 1.3–1.8×, range approximate) but free in deployment terms.


7. Tree-Based Speculation

So far we've described linear speculation: a chain ~x_{t+1}, ~x_{t+2}, ..., ~x_{t+K}. One rejection terminates the prefix. With per-token acceptance β = 0.7 and K = 8, we get α ≈ 3.2.

Tree speculation generalizes to a tree of candidate continuations rooted at position t:

                t
               /|\
            a   b   c        # candidates at t+1
           /|   |\   \
         a1 a2 b1 b2 c1      # candidates at t+2 conditioned on parent

Each branch is a possible continuation. The target verifies the entire tree in one forward pass, using a custom tree attention mask so each node attends only to its ancestors.

TREE_VERIFY(tree):
    flatten tree to a sequence of nodes
    construct attention mask:
        node i attends to node j  iff  j is an ancestor of i in the tree
    one target forward pass on the flattened sequence with this mask
    for each root-to-leaf path:
        run the rejection-sampling accept/reject loop
    accept the longest accepted prefix across all paths
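
To make the ancestor-only mask concrete, here is a minimal NumPy sketch; the function name and the parent-pointer encoding are illustrative, not from any particular library:

import numpy as np

def tree_attention_mask(parent):
    # parent[i] = index of node i's parent in the flattened tree, or -1 for the root.
    # mask[i, j] is True iff node i may attend to node j,
    # i.e. j == i or j is an ancestor of i.
    n = len(parent)
    mask = np.zeros((n, n), dtype=bool)
    for i in range(n):
        j = i
        while j != -1:              # walk up to the root, marking ancestors
            mask[i, j] = True
            j = parent[j]
    return mask

# The tree from the diagram: 0=t, 1=a, 2=b, 3=c, 4=a1, 5=a2, 6=b1, 7=b2, 8=c1
parent = [-1, 0, 0, 0, 1, 1, 2, 2, 3]
print(tree_attention_mask(parent).astype(int))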

7.1 Why it helps

Multiple candidates at each depth give the target multiple chances to accept something. The expected accepted depth is higher than for a linear draft of the same total node count, because divergence in one branch doesn't kill the others.

7.2 Cost

Tree attention costs a forward pass on |tree| tokens with a custom mask. Modern attention kernels (FlashAttention with a tree mask, or specialized kernels) handle this efficiently; the overhead vs. linear verify is modest if |tree| is comparable to K.

7.3 Branching policy

How wide should the tree be? Branching factor b at each depth gives b^d leaves at depth d. Naively this explodes. Practical implementations use dynamic tree expansion: branch wider where the draft is uncertain (high entropy), prune where confident. EAGLE-2 popularized this.

A typical production tree has 25–60 nodes total, with deeper nodes thinner than shallower ones-e.g., (3, 3, 2, 2, 1, 1) branching across depths.


8. The Speculative-Batching Tension

Speculative decoding wins big at batch=1. Continuous batching wins big at batch≫1. They fight each other.

8.1 The trade

At batch=1: decode is memory-bound. Verifying K + 1 candidates uses idle compute, no extra wall-clock cost. Net win: ~2×.

At batch=64 (continuous batching at scale): decode is already compute-bound. The tensor cores are saturated emitting one token per request per step. Now verifying K + 1 candidates per request multiplies the compute load by ~K. The per-step time grows roughly linearly in K, and most of those extra-flop tokens are rejected. We are paying compute to do speculative work that produces only α tokens per step instead of one-but the cost grew by K + 1, not by 1.

The throughput formula at high batch (where T_target,K ≈ K · T_target because we are now compute-bound) becomes:

S_high_batch ≈ α / (K + K · T_draft / T_target) = α / (K · (1 + T_draft / T_target))

For α = 3.5, K = 8, T_draft / T_target = 0.1:

S_high_batch ≈ 3.5 / (8 · 1.1) ≈ 0.40

Speculation actively hurts throughput at high batch. This is not a small effect; it's a 2–3× slowdown vs. plain continuous batching.
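
The two regimes side by side, as a quick sketch of the formulas above (r stands for T_draft / T_target):

def spec_speedup_low_batch(alpha, K, r):
    # Memory-bound decode: verifying K+1 tokens is ~free, S = alpha / (1 + K*r)
    return alpha / (1 + K * r)

def spec_speedup_high_batch(alpha, K, r):
    # Compute-bound decode: verify costs ~K target steps, S = alpha / (K*(1+r))
    return alpha / (K * (1 + r))

alpha, K, r = 3.5, 8, 0.1
print(spec_speedup_low_batch(alpha, K, r))    # ~1.94 -- a win at batch=1
print(spec_speedup_high_batch(alpha, K, r))   # ~0.40 -- a 2.5x slowdown at high batch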

8.2 The production pattern

The reconciliation: enable speculation adaptively.

ADAPTIVE_SPECULATION_POLICY(batch_state):
    if current_batch_size <= LOW_BATCH_THRESHOLD:
        use_speculation = True
    elif current_batch_size >= HIGH_BATCH_THRESHOLD:
        use_speculation = False
    else:
        use_speculation = (priority == HIGH)   # for low-latency requests only

    return use_speculation

Typical thresholds: speculation on at batch ≤ 8, off at batch ≥ 32, with a high-priority override in the middle band.

This is the kind of policy decision that gets made at the scheduler level, not at the model level. It's also why speculative decoding is sometimes called "a single-user technique"-in pure-throughput regimes (training, batch inference at scale) it doesn't pay.

8.3 An important nuance

Speculation can still win at moderate batch if the target has spare compute headroom-for example, if the GPU is bandwidth-limited by quantized weights (W4 weights at batch ≤ 16 are still memory-bound on H100). The crossover batch size is workload- and hardware-specific. The right answer is to measure and program the scheduler accordingly.


9. Engineering Speculative Decoding

The algorithm is short. The implementation has corners.

9.1 Two KV-caches

Both draft and target maintain their own KV-cache. They must stay in lockstep with the accepted prefix, not the proposed prefix.

        accepted prefix    last accepted token    proposed (not yet accepted)
  target KV: [............................ T ]
  draft  KV: [............................ T ][~x_{t+1} ~x_{t+2} ... ~x_{t+K}]

When verification finishes:

  • If n tokens accepted (n ≤ K), the draft's KV-cache for positions t+n+1 .. t+K must be rolled back (dropped). The draft re-drafts from position t+n+1 next step.
  • The target's KV-cache must be extended with positions t+1 .. t+n+1 (the accepted prefix, including the bonus or corrected token).

For paged-attention KV-caches, "rolling back" the draft is a matter of returning pages to the free pool (cheap). For contiguous KV-caches, rolling back means truncating the cache pointers.

9.2 The verification forward pass

The target verifies K candidates by running a forward pass on K new tokens given its existing cache. This is identical to a chunked-prefill of length K. Existing prefill kernels handle it directly; no new attention kernel needed (unless using tree speculation, which needs a custom mask).

9.3 Synchronization

Naive implementation runs draft and target serially:

1. draft K steps
2. target verify
3. accept/reject
4. update both caches
5. goto 1

At step 1, the target is idle. At step 2, the draft is idle. On a single GPU, idle silicon is wasted money.

Optimization. Pipeline the draft for step t+1 during the target's verification of step t. The draft uses recently accepted tokens; if the target ends up rejecting some, the draft has to roll back, but the expected work is reduced.

This is the same kind of speculative-on-speculation idea as branch prediction in CPUs. Implementations vary; reference open-source implementations include the speculative path in vLLM, TensorRT-LLM, and SGLang.

9.4 Batching speculation

Within a batch, different requests will have different acceptance lengths per step. After a step, request A has 4 new tokens, request B has 1, request C has 5. The next iteration's batch is jagged. The scheduler must handle this. Continuous batching frameworks generally do; the bookkeeping is per-request KV-cache offsets and per-request next-token positions.

9.5 Pseudocode for the speculative loop

SPECULATIVE_GENERATE(prompt, max_tokens, K):
    target_kv = prefill(M_target, prompt)
    draft_kv  = prefill(M_draft, prompt)
    output = []

    while len(output) < max_tokens:
        # 1. Draft K tokens
        draft_tokens = []
        draft_probs  = []
        ctx = output[-1] if output else last_prompt_token
        for i in 1..K:
            p = M_draft.step(ctx, draft_kv)
            t = sample(p)
            draft_tokens.append(t)
            draft_probs.append(p)
            ctx = t

        # 2. Target verify in one pass. Implementations forward the last
        #    accepted token plus the K drafts to obtain K+1 distributions.
        q_dists = M_target.forward(draft_tokens, target_kv)   # K+1 distributions

        # 3. Accept-reject
        n_accepted = 0
        for i in 1..K:
            r = uniform(0, 1)
            if r < min(1, q_dists[i][draft_tokens[i]] / draft_probs[i][draft_tokens[i]]):
                n_accepted += 1
                output.append(draft_tokens[i])
            else:
                # Corrected sample
                q_corr = normalize(max(0, q_dists[i] - draft_probs[i]))
                output.append(sample(q_corr))
                break

        if n_accepted == K:
            # Bonus token
            output.append(sample(q_dists[K+1]))

        # 4. KV-cache hygiene
        target_kv.commit(n_accepted + 1)              # accepted + corrected/bonus
        draft_kv.rollback_to(target_kv.length)        # discard rejected draft positions
        draft_kv.append_token(output[-1])             # so next draft step starts from accepted token

    return output
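
A runnable NumPy sketch of the accept/reject core of step 3; the names verify, p_draft, and q_target are illustrative, with q_target assumed to hold the K+1 target distributions:

import numpy as np

rng = np.random.default_rng(0)

def verify(draft_tokens, p_draft, q_target):
    # Token-level rejection sampling: accept draft token i with probability
    # min(1, q_i[tok] / p_i[tok]); on rejection, sample from the renormalized
    # residual max(0, q_i - p_i); on full acceptance, sample a bonus token.
    out = []
    K = len(draft_tokens)
    for i, tok in enumerate(draft_tokens):
        if rng.random() < min(1.0, q_target[i][tok] / p_draft[i][tok]):
            out.append(tok)                          # accept
        else:
            residual = np.maximum(q_target[i] - p_draft[i], 0.0)
            residual /= residual.sum()               # corrected distribution
            out.append(int(rng.choice(len(residual), p=residual)))
            return out                               # rejection ends the prefix
    out.append(int(rng.choice(len(q_target[K]), p=q_target[K])))  # bonus token
    return out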

10. Disaggregated Inference-The Motivation

The second frontier idea: stop running prefill and decode on the same workers.

10.1 The problem with co-location

A co-located worker handles a stream of mixed requests. At any moment its scheduler chooses which requests to execute and in which phase:

  • Some requests are in prefill (compute-bound).
  • Some are in decode (memory-bound).
  • The scheduler wants to batch.

The optimal batching policy for prefill differs from the optimal policy for decode:

  • Prefill is compute-bound for any non-trivial prompt. Larger batches do not help much (we are already at the roofline ridge from the long sequence dimension), and they hurt latency by stretching T_prefill. Prefill prefers small batches-sometimes batch = 1-for low TTFT.
  • Decode is memory-bound at batch=1 and compute-bound around batch ≈ 256. Decode wants the largest batch the GPU memory allows, for throughput.

A co-located worker must serve both. Common compromises:

  1. Prefill-then-decode flushing. At each scheduler tick, run pending prefill, then a decode tick. Decode requests wait for prefill to finish; prefill requests wait for decode to flush. Latency for both phases suffers.
  2. Chunked prefill. Slice prefill into chunks of C tokens and interleave with decode steps in the same forward pass. Smooths the latency, but a chunk of prefill in the same forward pass as decode costs the decode requests time (because the forward pass runs at the longer sequence length).
  3. SLO violations under load. As load rises, the queue mixes more aggressively; both TTFT and TPOT degrade simultaneously. There is no separate knob to tune them independently.

10.2 The disaggregation insight

If we physically separate prefill workers and decode workers, each pool can be:

  • Sized independently (e.g., 1 prefill worker per 4 decode workers, depending on workload mix).
  • Tuned independently (different batch sizes, different scheduler policies).
  • Even on different hardware (prefill on H100s for compute, decode on cheaper GPUs with high HBM bandwidth).

The cost: a request must move between workers, which means transferring its KV-cache from the prefill worker to the decode worker.


11. DistServe (Zhong et al., OSDI 2024)

DistServe is the canonical reference design for disaggregated LLM serving.

11.1 Architecture

                  ┌───────────────┐
   request ─────► │  Global       │
                  │  Scheduler    │
                  └───┬───────┬───┘
                      │       │
                      ▼       ▼
            ┌────────────┐  ┌────────────┐
            │  Prefill   │  │  Prefill   │  ...   prefill pool
            │  Worker 1  │  │  Worker 2  │
            └─────┬──────┘  └─────┬──────┘
                  │ KV-cache      │
                  │ over RDMA     │
                  ▼               ▼
            ┌────────────┐  ┌────────────┐
            │  Decode    │  │  Decode    │  ...   decode pool
            │  Worker 1  │  │  Worker 2  │
            └─────┬──────┘  └─────┬──────┘
                  │               │
                  └──────► token stream to user

The request flow:

1. Request arrives at scheduler.
2. Scheduler picks a prefill worker (load balancing).
3. Prefill worker: prefill, populate KV-cache.
4. Scheduler picks a decode worker (load balancing on KV memory pressure).
5. KV-cache transferred from prefill worker to decode worker (RDMA).
6. Decode worker: continuous-batch decode until completion.
7. Tokens streamed to client during decode.

11.2 Why each pool can be tuned independently

  • Prefill pool. Optimize for low TTFT under SLO. Use small batches (often 1), maybe with chunked prefill for very long prompts. Configure for fast tensor cores. Prefill workers do not need huge KV memory.
  • Decode pool. Optimize for high decode throughput. Use the largest continuous batch the GPU memory allows. Configure for HBM bandwidth. Decode workers do need huge KV memory and benefit from W4-quantized weights.

11.3 KV-cache transfer

This is the new cost. Per-request KV-cache size (FP16, no quantization) for a model with L layers, H heads, head dim d, and S sequence length:

KV_size_bytes = 2 (K and V) · L · H · d · S · 2 bytes (FP16)
              = 4 · L · H · d · S

For Llama-3-70B (L=80, H_kv=8 after GQA, d=128) at S=8192:

KV_size = 4 · 80 · 8 · 128 · 8192 ≈ 2.7 GB    (FP16, with GQA)

For multi-head attention without GQA the same model would be ~22 GB. Most modern large models use GQA, putting per-request KV in the few-GB range.

Transfer time over RDMA (NVLink between GPUs ≈ 600 GB/s, InfiniBand HDR/NDR ≈ 25–50 GB/s per port, GH200 NVLink-C2C even higher): for ~3 GB at 50 GB/s, transfer takes ~60 ms. At 600 GB/s (NVLink), ~5 ms.

Crucially, the transfer is overlap-able with the first decode step on the receiving worker. If the prefill worker streams the KV-cache layer-by-layer, the decode worker can begin processing as soon as the first layer arrives. End-to-end transfer cost can be made sub-step-time with careful scheduling.

11.4 Reported gains

The DistServe paper reports 4–7× increases in the load a cluster can sustain while meeting both TTFT and TPOT SLOs, compared to co-located baselines. These are the paper's reported numbers; actual gains depend heavily on workload and hardware. The mechanism is straightforward: by separating the two phases, the scheduler is no longer forced into compromises that hurt both metrics.

11.5 Load balancing across pools

The global scheduler decides:

  • Which prefill worker gets the next request? Pick the one with the shortest prefill queue and enough free KV memory for the prompt's KV.
  • Which decode worker gets the request after prefill? Pick the one with the most free KV memory (decode pressure on memory; not on compute, which is shared across the batch).
  • How to size the pools? The ratio depends on workload. For chat (short prompts, long replies), more decode workers. For RAG / long-context summarization (long prompts, short replies), more prefill workers. The DistServe paper proposes a search/profiling procedure to size the pools.

12. Splitwise (Patel et al., 2024)

Microsoft's variant on the same idea, with a sharper focus on heterogeneous hardware.

12.1 Key idea

  • Prefill is compute-bound → use the most compute-dense GPUs (H100, MI300X).
  • Decode is bandwidth-bound → use GPUs with the best $/GB-of-HBM-bandwidth (A100, sometimes older accelerators that are cheaper but still bandwidth-rich).

A Splitwise-style cluster runs different GPU SKUs in different pools. You buy fewer expensive GPUs for prefill and more cheaper GPUs for decode. The economics improve substantially for workloads with imbalanced phase costs.

12.2 Trade-offs

Heterogeneity adds operational complexity (different drivers, different profiling, different failure modes), and KV-cache transfer between heterogeneous nodes is more constrained (PCIe rather than NVLink in some configs). Splitwise's published benchmarks demonstrate the cost-performance frontier; the actual deployment ratios are workload-specific.


13. Mooncake (Qin et al., 2024)

Moonshot AI's serving architecture for their Kimi chat product. Published in 2024 with full architectural detail.

13.1 KVCache-centric design

Mooncake's organizing principle: the KV-cache is the central data structure of the serving system, not the GPU worker.

The design:

  • A distributed KV-cache pool spans CPU memory across the cluster (and SSDs as a backing store). Total capacity is far larger than what fits on the GPUs alone.
  • GPUs hold only the working set of KV-cache they need right now. Other entries live in CPU memory or SSD.
  • A scheduler routes requests to whichever GPU has the relevant prefix already hot-or, if none, to the GPU that can most cheaply load the prefix from the pool.

13.2 Why this matters for chat

Chat workloads have enormous prefix overlap:

  • System prompts repeat across requests.
  • Multi-turn conversations share their entire history.
  • Tools / RAG contexts are reused.

A 32K-token system prompt prefilled fresh every request is pure waste; with prefix caching its KV is already computed and we just have to find it and use it. Prefix cache hit rates on production chat traffic are commonly cited in the 50–80% range (range approximate; depends entirely on workload).

Mooncake's distributed pool maximizes the chance of a hit anywhere in the cluster.

13.3 Disaggregation in Mooncake

Mooncake is also disaggregated (prefill / decode separation), and the two ideas compose: prefill workers consult the global cache before doing any work; if the prefix is hit, they may skip prefill entirely and just hand the cached KV to a decode worker.

13.4 Reported gains

The Mooncake paper documents substantially higher GPU utilization and request throughput than co-located baselines on Moonshot's production traffic. Specific numbers in the paper depend on their workload mix; treat as approximate. The architectural lesson is robust: at production scale, the global KV-cache is the system.


14. KV-Cache Transfer in Detail

The cost that disaggregation pays. Worth understanding precisely.

14.1 Sizing

Per-request KV-cache (FP16, GQA, L layers, H_kv KV heads, d head dim, S sequence length):

KV_size = 4 · L · H_kv · d · S    bytes

Approximate scenarios (FP16, GQA with H_kv = 8):

Model L S KV size
Llama-3-8B 32 8K ~1.1 GB
Llama-3-70B 80 8K ~2.7 GB
Llama-3-70B 80 32K ~10.7 GB
Llama-3-405B 126 8K ~4.2 GB
Llama-3-405B 126 128K ~67 GB

These are uncompressed. KV-quantization (INT8, sometimes INT4) cuts these by 2× or 4× at small accuracy cost.
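
The table rows follow directly from the formula; a minimal calculator (decimal GB, matching the numbers above):

def kv_bytes(L, H_kv, d, S, bytes_per_elem=2):
    # 2 (K and V) * L layers * H_kv heads * d head dim * S tokens * element size
    return 2 * L * H_kv * d * S * bytes_per_elem

print(kv_bytes(80, 8, 128, 8 * 1024) / 1e9)      # Llama-3-70B  @ 8K   -> ~2.7 GB
print(kv_bytes(80, 8, 128, 32 * 1024) / 1e9)     # Llama-3-70B  @ 32K  -> ~10.7 GB
print(kv_bytes(126, 8, 128, 128 * 1024) / 1e9)   # Llama-3-405B @ 128K -> ~67.6 GB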

14.2 Transfer bandwidth

Link Approx bandwidth
NVLink 4 (intra-node, H100) 600 GB/s
NVSwitch fabric (inter-node, NVL72) ~900 GB/s aggregate
InfiniBand NDR 50 GB/s per port
InfiniBand HDR 25 GB/s per port
PCIe Gen5 x16 64 GB/s (often less in practice)

Transfer time = KV_size / bandwidth.

For 3 GB over 50 GB/s InfiniBand: 60 ms. For 3 GB over 600 GB/s NVLink: 5 ms. For 10 GB over 50 GB/s: 200 ms.

14.3 Overlap with first decode step

The transfer can overlap with the first decode step on the receiving worker. The receiving worker needs the KV-cache for layer ℓ only when it computes attention at layer ℓ. If the prefill worker streams KV in layer order, and the decode worker is processing layers in the same order, the only thing that needs to arrive before decode can start is the KV for layer 0. After that, the rest can stream in parallel with computation, hiding most of the transfer latency.

This "layer-pipelined" KV transfer is implemented in DistServe and extensively in Mooncake. It is the engineering move that makes disaggregation production-viable.

14.4 Disaggregation pseudocode

DISAGGREGATED_REQUEST(prompt):
    # 1. Scheduler routing
    prefill_worker = scheduler.pick_prefill_worker(prompt)
    decode_worker  = scheduler.pick_decode_worker(prompt)

    # 2. Prefix cache lookup (if available)
    cached_kv, cache_offset = global_cache.lookup(prompt)
    if cache_offset == len(prompt):
        # full hit: skip prefill entirely
        kv = cached_kv
    else:
        # 3. Prefill (possibly chunked, possibly resuming from cache)
        prefill_worker.load_kv_prefix(cached_kv)
        kv = prefill_worker.prefill(prompt[cache_offset:])

    # 4. Streamed transfer to decode worker (layer-pipelined)
    transfer_handle = prefill_worker.stream_kv_to(decode_worker, kv)

    # 5. Decode worker waits for layer-0 KV, then begins
    decode_worker.await_layer(transfer_handle, layer=0)
    for tok in decode_worker.decode_loop(transfer_handle):
        yield tok          # stream to client

    # 6. Async write-back to global cache for future hits
    global_cache.insert_async(prompt, kv)

15. Combining the Techniques-A Production Stack

Each individual technique gives a multiplicative factor. Production-grade inference is a stack.

15.1 The full stack, ordered

A modern serving stack (2026) typically includes:

  1. Paged attention-non-contiguous KV-cache allocation, eliminates fragmentation. Enables (2)–(4).
  2. Continuous batching-token-level scheduling across requests. Pushes decode toward compute-bound.
  3. Chunked prefill-slices long prefills into chunks that fit alongside decode. Smooths TTFT under load.
  4. Prefix caching-global KV reuse across requests with shared prefixes. Eliminates redundant prefill.
  5. W4 weight quantization-4-bit weights with FP16 activations. Cuts memory traffic ~4×, important for decode.
  6. (Optional) Disaggregation-separate prefill and decode pools. Lets you hit both TTFT and TPOT SLOs at higher load.
  7. (Optional) Speculative decoding-adaptively enabled at low batch / high priority. Cuts per-request decode latency ~2×.

15.2 Hypothetical attribution

A back-of-envelope walkthrough of how each layer contributes. Numbers are illustrative, not measured.

Suppose a baseline naive implementation of Llama-3-70B serves at throughput 1× (whatever absolute units we choose) with TTFT and TPOT both far above SLO at moderate load.

  • Add paged attention: little throughput change at low load, but enables larger effective batch size (less wasted KV memory) → ~1.5× throughput at high load.
  • Add continuous batching: now at batch ≈ 64 effectively → ~5–10× throughput vs naive (decode crossing into compute-bound regime). TTFT only modestly improved.
  • Add chunked prefill: TTFT under load improves ~2–3×; throughput roughly flat.
  • Add prefix caching (chat workload, 60% prefix hit rate): effective prefill compute ~0.4× of original; TTFT improves another ~2×; throughput modestly better.
  • Add W4 quantization: decode bandwidth-bound regime improves ~3× (we read 4 bits of weight per active param instead of 16). Throughput at low-to-moderate batch ~2–3× better; high-batch gains smaller.
  • Add disaggregation: TTFT and TPOT can both meet SLO at higher load. SLO-attainable throughput at SLO ~3–5× the co-located version (per DistServe-class results, range approximate).
  • Add adaptive speculative decoding: low-batch / high-priority requests see ~2× lower TPOT.

Multiplied together, a fully-optimized stack lands roughly 50–200× the naive baseline on relevant metrics, depending on workload. These factors are illustrative and not promises. What matters is that they are roughly multiplicative-none of them subsumes the others; each addresses a different bottleneck.

15.3 What doesn't compose

A few combinations require care:

  • Speculation + high batch: as analyzed, hurts. Use adaptive policy.
  • Speculation + tree attention + paged attention: requires the paged attention kernel to support custom masks. Most modern kernels (FlashAttention v3, vLLM's paged kernels) do.
  • Disaggregation + prefix caching: requires the prefix cache to be global, not per-worker. Otherwise prefix hits collide with worker locality. This is exactly Mooncake's design.
  • Disaggregation + speculation: speculation lives entirely on the decode worker. The prefill worker doesn't need to know about it. Compose freely.

16. Frontier Directions (research-stage as of 2026)

These are active research areas. Treat as ideas to track, not as production techniques.

16.1 Continuous depth / early-exit (research-stage)

Not all tokens need all the model's layers to predict correctly. "Easy" tokens (function words, boilerplate) might be settled by layer 12 of a 32-layer model. Early-exit decoding adds a per-layer prediction head and exits when the head's confidence exceeds a threshold.

Status: works in research, but production deployment is rare because (a) the calibrated thresholds are workload-specific, (b) early-exit at layer ℓ produces a partial KV-cache (only layers up to ℓ), which other speculative-style techniques want to be complete. Active research as of 2026.

16.2 Multi-token prediction (Gloeckle et al., 2024) (research-stage in production deployment)

Train the model to predict the next N tokens directly, rather than just the next token, by adding N parallel output heads. The 2024 paper showed gains on code and reasoning benchmarks. Decode can then emit N tokens per forward pass (with verification, similar to Medusa).

Status: training-time technique, requires retraining the model. Some open-weight frontier models (2025–2026) include MTP heads natively. Production adoption growing but not yet ubiquitous.

16.3 Diffusion language models (research-stage)

Cast text generation as a diffusion process over the entire output sequence, denoising all positions in parallel. Several papers (2023–2026) demonstrate non-autoregressive parallel decoding, with quality approaching autoregressive baselines on some tasks.

Status: research-stage. Quality gap with autoregressive models has narrowed but not closed for general chat as of 2026. Watch this space-if the gap closes, decode is no longer sequential, and the entire framing of "decode is the bottleneck" changes.


17. Practical Exercises

Six problems. Treat them as if you were on a whiteboard with a 70B-model engineer; show the derivations.

Exercise 1-Derive the speedup formula

State and derive the speculative decoding throughput speedup formula S = α / (1 + K · T_draft / T_target). Identify each assumption and where it can break.

Solution sketch. Tokens per step: α (definition). Time per step: T_target,K + K · T_draft. Approximation: T_target,K ≈ T_target for K below the arithmetic-intensity ridge of the target model on the current hardware. Throughput: α / (T_target + K · T_draft). Baseline throughput: 1 / T_target. Ratio: α · T_target / (T_target + K · T_draft) = α / (1 + K · T_draft / T_target). Breaks when (a) batch is high enough that T_target,K is not approximately T_target (compute-bound regime), (b) draft model causes target cache contention (e.g., they share GPU memory and one evicts the other).

Exercise 2-Compute the speedup for given parameters

Given α = 3.5, K = 8, T_draft = T_target / 10, compute the expected speedup. Then redo with T_draft = T_target / 5. Then with α = 2.0.

Solution.

  • Base case: S = 3.5 / (1 + 8 · 0.1) = 3.5 / 1.8 ≈ 1.94×.
  • T_draft = T_target / 5: S = 3.5 / (1 + 8 · 0.2) = 3.5 / 2.6 ≈ 1.35×.
  • α = 2.0, T_draft = T_target / 10: S = 2.0 / 1.8 ≈ 1.11×. Marginal.

Lesson: speedup is sensitive to both α and T_draft / T_target. A weak draft (low α) or a slow draft (high T_draft) both kill the gain.

Exercise 3-Derive α from per-token acceptance β

Assume each draft token is accepted independently with probability β. Derive α = (1 − β^{K+1}) / (1 − β). Check the limits β → 0 and β → 1.

Solution sketch. Number of accepted draft tokens n is truncated geometric: P(n=k) = β^k(1−β) for k < K, P(n=K) = β^K. Expected accepted draft: E[n] = Σ_{k=0}^{K−1} k β^k (1−β) + K β^K = β (1 − β^K) / (1 − β). Plus 1 bonus / corrected token always: α = β(1 − β^K)/(1 − β) + 1 = (1 − β^{K+1})/(1 − β). Limit β → 0: α → 1 (just the corrected sample). Limit β → 1: α → K + 1 (every draft accepted plus bonus).
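
A quick Monte-Carlo check of the derivation (illustrative sketch):

import numpy as np

def alpha_formula(beta, K):
    return (1 - beta ** (K + 1)) / (1 - beta)

def alpha_sim(beta, K, trials=200_000, seed=0):
    rng = np.random.default_rng(seed)
    # A draft token survives only if it and all its predecessors are accepted:
    # cumprod turns the first rejection into zeros for the rest of the row.
    accepted = (rng.random((trials, K)) < beta).cumprod(axis=1).sum(axis=1)
    return accepted.mean() + 1            # +1 for the bonus/corrected token

print(alpha_formula(0.7, 8))   # ~3.20
print(alpha_sim(0.7, 8))       # ~3.20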

Exercise 4-KV-cache transfer budget

A disaggregated cluster runs Llama-3-70B at 32K context with FP16 KV (GQA, H_kv = 8, d = 128, L = 80). Transfer between prefill and decode workers is over a 50 GB/s link. (a) Compute KV-cache size per request. (b) Compute raw transfer time. (c) The decode worker's first decode step takes 30 ms. By overlapping transfer with the first decode step, what fraction of the transfer time can be hidden? (d) What if you switch to INT8 KV-cache?

Solution. (a) KV = 4 · 80 · 8 · 128 · 32768 = 10.7 GB (FP16, 2 bytes per element accounted for in the leading 4 = 2(K+V) · 2 bytes). (b) 10.7 GB / 50 GB/s = 214 ms. (c) The first 30 ms of transfer overlap with the first decode step. Hidden fraction: 30 / 214 ≈ 14%. Most of the transfer is not hidden by a single decode step. To fully hide, layer-pipelining: layer 0 transfer (≈ 134 MB at 50 GB/s ≈ 2.7 ms) finishes before decode starts; subsequent layers stream in parallel with later decode steps (not just the first). Each decode step is 30 ms; per-layer transfer is 10.7 GB / 80 / 50 GB/s ≈ 2.7 ms. As long as decode step time > per-layer transfer time, layer-pipelined transfer hides fully-true here. (d) INT8 halves KV to 5.35 GB, halves transfer to 107 ms; same layer-pipelining argument hides it even more easily.

Exercise 5-Adaptive speculation policy

Design the scheduler logic for adaptively enabling speculative decoding in a continuous-batching server. Inputs available: current_batch_size, priority flag per request, recent_α_estimate, K, T_draft, T_target, current GPU compute utilization. Output: per-request speculation enable/disable.

Solution sketch.

SHOULD_SPECULATE(request, batch_state, system):
    # Compute expected throughput with and without speculation
    # at the current batch size
    α      = recent_α_estimate
    T_d    = T_draft
    T_t    = T_target
    B      = batch_state.size
    util   = system.compute_utilization()    # 0..1, fraction of FLOPs used

    # Rough single-step time scales as max(memory_bound_time, compute_bound_time).
    # At low util the headroom for K-token verify is essentially free.
    # At high util the K-token verify multiplies cost ~K.
    headroom = 1.0 - util
    effective_K_cost = K * (1 - headroom * 0.7)    # heuristic

    spec_throughput = α / (1 + effective_K_cost * T_d / T_t)
    if request.priority == HIGH:
        # always favor latency for high priority, even at small loss
        return spec_throughput > 0.9
    return spec_throughput > 1.05                  # only when net gain

Real implementations measure α online per request class and re-evaluate the policy periodically.

Exercise 6-Workload routing for disaggregated serving

You operate a disaggregated cluster with 8 prefill workers (H100) and 16 decode workers (A100) serving Llama-3-70B. Three workload classes share the cluster:

  • (W1) Chat: 200-token prompts, 400-token replies, 40% of traffic.
  • (W2) RAG: 8K-token prompts, 100-token replies, 50% of traffic.
  • (W3) Code completion: 1K-token prompts, 50-token replies, 10% of traffic.

(a) Reason about whether the prefill/decode pool sizing is appropriate. (b) Propose routing rules. (c) Where should prefix caching matter most?

Solution sketch. (a) Prefill compute is roughly proportional to prompt length. Weighted prompt length = 0.4·200 + 0.5·8000 + 0.1·1000 = 80 + 4000 + 100 = 4180. Weighted reply length = 0.4·400 + 0.5·100 + 0.1·50 = 160 + 50 + 5 = 215. Decode compute is roughly proportional to reply length × batch utilization; it's bandwidth-bound, so what matters is how many concurrent requests we can keep in decode given KV memory. With ~3 GB KV per request average and ~80 GB usable per A100, ~26 concurrent decode requests per A100. With 16 A100s, ~420 concurrent decode slots.

The prefill load is dominated by RAG (W2). 8 H100s might be tight or generous depending on prefill tokens/sec per H100. Sketch: an H100 prefilling 70B at FP16 hits ~10K–20K tokens/sec (range approximate). 8 of them ≈ 100K tokens/sec. To serve a workload mix where each request is ~4180 prompt tokens on average, that's ~24 requests/sec arrival rate sustainable. If your traffic exceeds that, add prefill workers.

(b) Routing rules:

  • All workloads use the same prefill/decode pools.
  • Within prefill: prioritize W3 (low-latency code completion) over W2 (RAG can tolerate higher TTFT). Hold W2 in a chunked-prefill queue to avoid head-of-line blocking on the H100s.
  • Within decode: continuous-batch all three. Pin W3's decode to a smaller, lower-latency decode pool subset if TPOT SLOs differ.

(c) Prefix caching matters most for W1 (chat: huge multi-turn prefix overlap) and W2 (RAG: shared system prompt + retrieved-context overlap on popular queries). It matters least for W3 (code completion is mostly novel context per request, though shared system prompt for the IDE may still help).


18. Summary

The 2024–2026 inference frontier is a story about separating concerns that were always different.

  • Decode is sequential and memory-bound. Speculative decoding pulls the lever of converting sequential target steps into parallel verification of speculative branches, paid for with idle GPU compute that we already owned. The math gives 2–3× per-request speedup at low batch.
  • Prefill is parallel and compute-bound; decode wants the opposite hardware policy. Disaggregation finally separates them, giving each its own pool, each tuned independently, each potentially on different hardware. DistServe / Splitwise / Mooncake are the canonical references; gains of several factors at SLO are reported.
  • The two techniques compose with each other and with the rest of the stack (paged attention, continuous batching, chunked prefill, prefix caching, W4). Each contributes a multiplicative factor; full-stack production systems are 50–200× a naive baseline (illustrative, workload-dependent).
  • Each technique has a regime where it doesn't help. Speculation hurts at high batch. Disaggregation adds operational complexity and is unnecessary when load is so low that co-located scheduling never compromises. Quantization has accuracy costs. Knowing the regimes is the engineering work.

The reader who has internalized this chapter should be able to: derive the speculative speedup formula on demand; explain why decode is memory-bound and prefill compute-bound; sketch a disaggregated cluster's request flow and KV-transfer pipeline; argue for or against speculation given a batch state; size prefill/decode pools for a workload mix; and read papers on EAGLE, Medusa, DistServe, Splitwise, Mooncake without needing the introductions.

The next deep dive in this curriculum (Month 5, Week 18) builds on these primitives toward end-to-end production serving stacks.


References (canonical, for further reading)

  • Leviathan, Kalman, Matias. Fast Inference from Transformers via Speculative Decoding. ICML 2023.
  • Chen, Borgeaud, Irving, et al. Accelerating Large Language Model Decoding with Speculative Sampling. 2023.
  • Cai et al. Medusa: Simple LLM Inference Acceleration Framework with Multiple Decoding Heads. 2024.
  • Li et al. EAGLE / EAGLE-2: Speculative Sampling Requires Rethinking Feature Uncertainty. 2024.
  • Fu et al. Lookahead Decoding. 2024.
  • Zhong et al. DistServe: Disaggregating Prefill and Decoding for Goodput-Optimized Large Language Model Serving. OSDI 2024.
  • Patel et al. Splitwise: Efficient Generative LLM Inference Using Phase Splitting. 2024.
  • Qin et al. Mooncake: A KVCache-centric Disaggregated Architecture for LLM Serving. 2024.
  • Gloeckle et al. Better & Faster Large Language Models via Multi-token Prediction. 2024.

(Citations are by name and year. Treat performance numbers in this chapter as approximate ranges from public reporting; reproduce on your own workload before relying on them.)

Deep Dive 11-Numerics and Mixed Precision

"Floating-point arithmetic is the silent assassin of deep learning. Most training divergences are not bugs in the model; they are bugs in the number system."

A neural network is, at the end of the day, a chain of arithmetic operations executed on finite-precision hardware. The mathematics on the whiteboard treats real numbers; the GPU treats sequences of bits with explicit rounding rules. Whether your run converges, plateaus, NaNs, or silently bias-shifts is determined by the gap between those two worlds. This chapter is the reference for closing that gap.

We will derive-not just state-IEEE-754 floating point, walk through every format relevant to ML (FP64, FP32, TF32, FP16, BF16, FP8 E4M3 / E5M2, FP4), explain why each operation in a transformer needs the precision it does, write out the loss-scaling and FP8 delayed-scaling algorithms in pseudocode, and finish with worked exercises that you should be able to do on paper.

Read this once carefully. Then re-read sections 4, 7, and 11 the next time a training run NaNs.


Table of contents

  1. IEEE-754 in 30 minutes
  2. The ML floating-point zoo
  3. Operation-by-operation precision requirements
  4. The standard mixed-precision recipe
  5. Loss scaling, derived
  6. Why BF16 is different (and what it costs)
  7. FP8 training in detail
  8. TF32: the silent precision drop
  9. Adam + low precision pitfalls
  10. Catastrophic cancellation in reductions
  11. Numerical stability tricks in transformers
  12. Detecting and handling NaN
  13. Determinism
  14. Practical exercises

1. IEEE-754 in 30 minutes

1.1 Why we cannot use real numbers

A real number x ∈ ℝ requires, in general, infinite information to represent. Computers store fixed-width approximations. The IEEE-754 standard (1985, revised 2008 and 2019) defines a family of binary floating-point formats and the rounding rules for arithmetic on them.

A binary floating-point number is a triple (s, e, m) interpreted as

x = (-1)^s × 2^E × M

where:

  • s is one sign bit (0 = positive, 1 = negative),
  • e is the biased exponent stored in n_exp bits,
  • m is the mantissa (also called the significand fraction) stored in n_man bits.

The "biased" part means that the exponent field stores e = E + bias, where bias = 2^(n_exp - 1) - 1. We do this so that the exponent field can represent both negative and positive E while remaining an unsigned integer-e ranges from 0 to 2^n_exp - 1, and E ranges from 1 - bias to bias.

The mantissa stores only the fractional part. There is an implicit leading 1 for normal numbers:

M = 1.m_{n_man-1} m_{n_man-2} ... m_1 m_0   (binary)
  = 1 + sum_{i=0}^{n_man-1} m_i × 2^(i - n_man)

So for FP32 with 23 mantissa bits, M lies in [1, 2) with a granularity of 2^-23 ≈ 1.19e-7.

1.2 The FP32 example

FP32 (binary32): 1 + 8 + 23 = 32 bits.

  • bias = 2^7 - 1 = 127.
  • Normal e range: 1 to 254, so E ranges from -126 to +127.
  • Smallest normal: 1.0 × 2^-126 ≈ 1.175e-38.
  • Largest finite: (2 - 2^-23) × 2^127 ≈ 3.403e+38.
  • Machine epsilon: ε = 2^-23 ≈ 1.19e-7, the gap between 1 and the next representable number.

FP32 represents about 7 decimal significant digits because log10(2^23) ≈ 6.92.
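
A short sketch that decodes a normal FP32 value into its (s, e, m) fields with the standard library (fp32_fields is an illustrative helper; it ignores zeros, subnormals, and specials):

import struct

def fp32_fields(x):
    # Reinterpret the 32 bits, then slice sign / biased exponent / mantissa.
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    s = bits >> 31
    e = (bits >> 23) & 0xFF
    m = bits & 0x7FFFFF
    M = 1 + m / 2**23                  # significand with the implicit leading 1
    return s, e, m, (-1) ** s * M * 2.0 ** (e - 127)

print(fp32_fields(1.0))    # (0, 127, 0, 1.0)
print(fp32_fields(-6.5))   # (1, 129, 5242880, -6.5): -1.625 * 2^2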

1.3 Subnormals (denormals)

The exponent code e = 0 is special: it represents subnormal (denormal) numbers, which fill the gap between zero and the smallest normal:

x_subnormal = (-1)^s × 2^(1 - bias) × (0.m_{n_man-1} ... m_0)_2

Note the implicit leading bit becomes 0 instead of 1, and the exponent is fixed at 1 - bias (not 0 - bias; the off-by-one makes the transition from subnormals to normals continuous). For FP32:

  • Smallest positive subnormal: 2^-23 × 2^-126 = 2^-149 ≈ 1.4e-45.
  • Largest subnormal: (1 - 2^-23) × 2^-126, just under the smallest normal.

Subnormals enable gradual underflow: as a value shrinks below the smallest normal, it loses precision bit by bit but does not abruptly become zero. Some hardware flushes subnormals to zero (FTZ/DAZ flags) for performance-this matters in DSPs and is occasionally encountered on GPUs.

1.4 Special values

The exponent code e = 2^n_exp - 1 is also special:

Exponent field Sign / Mantissa Meaning
e = 0 m = 0, s = 0 +0
e = 0 m = 0, s = 1 -0
e = 0 m ≠ 0 subnormal
1 ≤ e ≤ 2^n_exp - 2 any normal
e = all-ones m = 0 ±inf
e = all-ones m ≠ 0 NaN

NaN comes in two flavours: quiet NaN (qNaN) and signaling NaN (sNaN), distinguished by the high bit of the mantissa. ML rarely cares; both propagate through arithmetic and we generally trap on either.

+0 versus -0: they compare equal, but 1 / (+0) = +inf while 1 / (-0) = -inf. This is a common source of subtle bugs in custom ops.

1.5 Rounding modes

IEEE-754 defines five rounding modes; only the first two below matter day to day:

  1. Round to nearest, ties to even (RNE)-default. If a real number falls exactly between two representable floats, pick the one whose last mantissa bit is 0. RNE is unbiased: averaging many rounded results does not introduce a systematic drift.
  2. Round toward zero (truncation)-used in some quantization paths.
  3. Round toward +inf, round toward -inf-interval arithmetic.
  4. Round to nearest, ties away from zero-common in financial code, rare in ML.

Stochastic rounding (section 9.2) is not in IEEE-754 but is increasingly important in low-bit training.

The fundamental error bound: for any operation op ∈ {+, -, ×, /, sqrt}, the IEEE-754 result satisfies

fl(a op b) = (a op b) × (1 + δ),    |δ| ≤ ε / 2

with ε the machine epsilon and the relative error bounded by the unit roundoff u = ε / 2.

1.6 Sources of error in + and ×

Multiplication. fl(a × b) = ab × (1 + δ) with |δ| ≤ u. The error is always relative to the magnitude of the result, so multiplication is benign.

Addition. fl(a + b) = (a + b) × (1 + δ) likewise, but the worst case happens when a ≈ -b: the relative error stays small, but the absolute error of the result is large compared to the (small) result. This is catastrophic cancellation: subtracting two nearly equal numbers exposes their roundoff.

Example in FP32:

a = 1.0000001
b = 1.0000000
a - b = 1e-7  (in real arithmetic)
        but  fl(a) = 1.0000001 ± 6e-8
              fl(b) = 1.0       ± 6e-8
              fl(a - b) = 1e-7 with absolute error ~ 1.2e-7

The result is dominated by roundoff. Section 10 returns to this.
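
The same example, executed in FP32 with NumPy:

import numpy as np

a = np.float32(1.0000001)        # rounds to 1 + 2**-23 = 1.00000011920929...
b = np.float32(1.0)
diff = float(a - b)
print(diff)                      # 1.1920929e-07, not the true 1e-07
print((diff - 1e-7) / 1e-7)      # ~0.19: a 19% relative error from one subtraction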

1.7 Compound operations: FMA

Most modern hardware exposes a fused multiply-add (FMA): fl(a × b + c) computed with one rounding instead of two. FMA is more accurate than a separate mul then add, and most tensor-core math is built on it. Because rounding makes floating-point addition non-associative, (a + b) + c ≠ a + (b + c) in general; different summation orders (and whether the compiler contracts to FMA) yield bitwise-different results.


2. The ML floating-point zoo

We now place every format used in practice next to the others.

2.1 Bit layouts

            sign  exp  man   bias   smallest normal       max finite
FP64        1     11   52    1023   ~2.225e-308           ~1.798e+308
FP32        1      8   23    127    ~1.175e-38            ~3.403e+38
TF32        1      8   10    127    ~1.175e-38            ~3.403e+38   (compute, not storage)
FP16        1      5   10     15    ~6.104e-5             ~6.550e+4
BF16        1      8    7    127    ~1.175e-38            ~3.389e+38
FP8 E4M3    1      4    3      7    ~2^-6 = 1.56e-2       448  (OFP8) or 240 (inf-reserving variant)
FP8 E5M2    1      5    2     15    ~6.104e-5             ~5.734e+4
FP4 E2M1    1      2    1      1    ~0.5                  6

A few subtleties:

  • TF32 is not a storage format. Tensors are stored as FP32 in memory; the tensor core internally rounds operands to a 1+8+10 layout (FP32 range, FP16 precision) before doing the matmul, then accumulates in FP32. From the user's API, the tensors look like FP32-only the computation is degraded. Section 8.
  • FP8 E4M3 in the OFP8 standard (Micikevicius et al., 2022; adopted by NVIDIA's TransformerEngine, Intel, AMD, ARM) does not support infinities. The e=all-ones, m=all-ones codepoint is reused as a finite value, raising max from 240 to 448. Some implementations instead reserve that codepoint for inf/NaN, giving max = 240. Both variants exist; check your library.
  • FP8 E5M2 is fully IEEE-754-shaped: it has inf and NaN. Max ≈ 57344, smallest normal ≈ 6.10e-5-the same range as FP16.
  • FP4 E2M1 has only 16 codepoints total (including sign and zero). Practical FP4 training requires per-block scaling (e.g., MXFP4 with 32-element blocks) to be viable at all.

2.2 Range and precision, side-by-side

Format Decimal digits Dynamic range (orders of magnitude) Use
FP64 ~15–16 ~600 Scientific computing; rarely ML
FP32 ~7 ~76 Master weights, optimizer state
TF32 ~4 ~76 Hidden tensor-core compute
FP16 ~3–4 ~10 Mixed-precision compute (legacy)
BF16 ~2–3 ~76 Mixed-precision compute (default)
FP8 E4M3 ~1–2 ~5 Activations, weights
FP8 E5M2 ~1 ~10 Gradients
FP4 <1 ~2 Inference; experimental training

Decimal digits = log10(2^(n_man + 1)) (the +1 from the implicit leading bit). Dynamic range = log10(max / smallest_normal).

The two axes-precision (mantissa bits) and range (exponent bits)-trade off independently. FP16 and BF16 are both 16-bit, but FP16 spends 5 exponent bits and 10 mantissa bits, while BF16 spends 8 and 7. BF16 gives up half the precision in exchange for the full FP32 range. For deep learning, where activations and gradients can span many orders of magnitude, range matters more than precision. This is the single most important fact in ML numerics.

2.3 Memory cost

Per parameter, with Adam optimizer states:

Configuration Weights Grads m v Total
Pure FP32 4 4 4 4 16 B
FP16/BF16 + FP32 master 2 + 4 2 4 4 16 B
FP8 + FP32 master 1 + 4 1 4 4 14 B
FP8 + FP16 master 1 + 2 1 2 2 8 B

Mixed precision saves memory for activations (during forward we hold half-precision tensors), not for parameters-until you're willing to sacrifice the FP32 master weights, which most production runs are not.


3. Operation-by-operation precision requirements

Different operations have different sensitivity to precision. Get this wrong and you waste bits where they don't matter while starving operations that do.

3.1 Matrix multiply

For C = A × B where A ∈ ℝ^{m×k}, B ∈ ℝ^{k×n}:

C_{ij} = sum_{p=1}^{k} A_{ip} × B_{pj}

With low-precision inputs and a k-element accumulation, the error is dominated by the accumulator, not the multiplicands. Each multiplication contributes one rounding (relative error at most u); the sum then accumulates k of these. For naive sequential summation, the worst-case error is O(k × u × max|x|).

Tensor cores always accumulate in higher precision than they multiply:

Input Accumulator
FP16 / BF16 FP32
FP8 E4M3 / E5M2 FP32
TF32 FP32
INT8 INT32

You can ask for FP16 accumulation on some old hardware; you should not. For k ≈ 4096 (typical hidden dim), FP16 accumulation can lose 3 decimal digits; FP32 accumulation keeps the error around the unit roundoff.

The matmul lesson: inputs can be cheap; the accumulator must be expensive.

3.2 Reductions (sum, mean, norm)

Reductions are more sensitive than matmuls because:

  1. The number of terms N (e.g., the feature dimension in LayerNorm) can be larger than the inner dim of typical matmuls.
  2. Naive sequential summation has O(N) worst-case error growth.
  3. Layer norm / RMS norm involves both a sum (mean) and a sum of squares (variance)-the latter is even more sensitive.

Practical rule: reductions always promote to FP32, regardless of input dtype. PyTorch and TF do this by default for mean, sum, var, norm. Check your custom kernels.

We dissect reductions in section 10.

3.3 Softmax

The softmax s_i = exp(x_i) / sum_j exp(x_j) has two failure modes in low precision:

  • Overflow: if any x_i > log(max_finite), then exp(x_i) = inf. For FP16, log(65504) ≈ 11.09. Logits routinely exceed this in attention.
  • Underflow: if all exp(x_i) are below the smallest normal, the denominator is zero. For FP16, smallest normal is ~6e-5, so any x_i < log(6e-5) ≈ -9.7 underflows.

Standard fix: subtract the max,

m = max_i x_i
s_i = exp(x_i - m) / sum_j exp(x_j - m)

Now the largest exponent argument is 0, so exp(...) ≤ 1. The subtraction is exact as long as x_i and m are close in magnitude (which they are after normalization).

This works mathematically because exp(x_i) / sum_j exp(x_j) = exp(x_i - m) / sum_j exp(x_j - m): we have multiplied numerator and denominator by exp(-m).

Online softmax (FlashAttention) uses an incremental version of this trick; we cover it in the attention deep dive.
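
A minimal NumPy sketch of the failure and the fix in FP16 (the specific logits are illustrative):

import numpy as np

def softmax_stable(x):
    z = x - x.max()                  # largest exponent argument is now 0
    e = np.exp(z)
    return e / e.sum()

logits = np.array([20.0, 5.0, 1.0], dtype=np.float16)
naive = np.exp(logits) / np.exp(logits).sum()   # exp(20) > 65504 -> inf, inf/inf -> nan
print(naive)                                    # [nan, 0, 0] -- overflow
print(softmax_stable(logits))                   # [~1, ~3e-7, 0] -- finite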

3.4 Gradient accumulation

When training with micro-batching, you accumulate gradients across micro-batches before stepping the optimizer:

grad_buf += grad_micro

If grad_buf is in BF16 and grad_micro is small, the accumulation underflows. Always accumulate in higher precision than the gradients themselves. PyTorch's GradScaler and DeepSpeed do this automatically; if you write your own pipeline, you must do it explicitly.

Concretely: gradients computed in BF16 should accumulate into an FP32 buffer. Gradients computed in FP8 should accumulate into FP16 or FP32. Otherwise the small contributions are lost: large_buf + tiny = large_buf whenever tiny < ε × large_buf.
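
A minimal NumPy sketch of the stall, FP16 accumulator versus FP32 accumulator (the gradient value and step count are illustrative):

import numpy as np

grad = np.float16(1e-4)
acc16 = np.float16(0.0)
acc32 = np.float32(0.0)
for _ in range(4096):
    acc16 = np.float16(acc16 + grad)   # FP16 accumulator
    acc32 += np.float32(grad)          # FP32 accumulator

print(acc16)   # ~0.25: stalls once grad < ulp(acc16)/2; further updates are lost
print(acc32)   # ~0.41: close to the true sum (4096 * 1e-4)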


4. The standard mixed-precision recipe

The Micikevicius et al. (2018) recipe is the foundation. Every modern training stack-Apex, PyTorch AMP, DeepSpeed, Megatron, JAX/Flax-implements it.

4.1 The four invariants

  1. Master weights in FP32. The "real" parameters live in FP32. We make a cast copy in low precision (FP16 or BF16) for the forward and backward passes.
  2. Forward and backward in low precision. Activations, weight matmuls, attention, layer norms (with FP32 accumulation), all run in FP16 or BF16.
  3. Gradients in low precision are produced by autograd, then immediately upcast to FP32 before going into the optimizer.
  4. Optimizer states (m, v for Adam) in FP32. The optimizer step is computed entirely in FP32; only after stepping do we re-cast the master weights down to low precision for the next forward.

If you violate any of these, expect divergence on long runs.

4.2 The full step

# Setup
master_weights_fp32 = init_weights()
adam_m_fp32 = zeros_like(master_weights_fp32)
adam_v_fp32 = zeros_like(master_weights_fp32)
loss_scale = 2**15  # FP16 only; for BF16 set to 1

for batch in dataloader:
    # 1. Cast master to low precision for compute
    weights_lp = master_weights_fp32.to(low_precision_dtype)

    # 2. Forward in low precision
    logits = model(batch.x, weights_lp)
    loss   = cross_entropy(logits, batch.y)

    # 3. Scale loss (FP16 only)
    loss_scaled = loss * loss_scale

    # 4. Backward in low precision; produces grads in low precision
    grads_lp = backward(loss_scaled, weights_lp)

    # 5. Upcast and unscale
    grads_fp32 = grads_lp.to(fp32) / loss_scale

    # 6. NaN/inf check (FP16 only); skip step if found
    if any_nan_or_inf(grads_fp32):
        loss_scale /= 2
        continue

    # 7. Optimizer step in FP32
    adam_m_fp32 = beta1 * adam_m_fp32 + (1-beta1) * grads_fp32
    adam_v_fp32 = beta2 * adam_v_fp32 + (1-beta2) * grads_fp32 ** 2
    master_weights_fp32 -= lr * adam_m_fp32 / (sqrt(adam_v_fp32) + eps)

    # 8. Optional: increase loss scale after a streak of clean steps
    successful_steps += 1
    if successful_steps >= 2000:
        loss_scale *= 2
        successful_steps = 0

The loss_scale machinery is unique to FP16 (section 5). For BF16, you can set it to 1 and remove the NaN-check / dynamic-update branches, but you should still upcast grads to FP32 before the optimizer step.

4.3 What about activation memory?

In the forward pass, we save activations for the backward pass. These should be stored in the same precision they were computed in (FP16/BF16/FP8)-that's where the memory saving comes from. The FP32 master weights add only 4 × N_params bytes, while activation memory grows with batch_size × seq_len × hidden_dim × num_layers, so for big models the half-precision activations dominate.

This is also why "FP32 training" without master weights wastes memory: FP32 activations are 2× the BF16 ones, and the activations are usually the bigger chunk.


5. Loss scaling, derived

5.1 Why FP16 needs it

Gradients in deep networks late in training can be very small. The smallest positive normal FP16 is 2^-14 ≈ 6.10e-5; with subnormals you can get down to 2^-24 ≈ 5.96e-8, but with shrinking precision. Anything smaller silently becomes zero.

In practice, late-training gradients for many parameters cluster around 1e-7 to 1e-9. They underflow FP16. The optimizer sees zero, the parameter does not update, the network plateaus.

5.2 The trick

Multiply the loss by a large constant S before backward:

loss_scaled = S × loss

By the chain rule, every gradient is multiplied by S:

∂(S × loss) / ∂w = S × ∂loss / ∂w

Now if the original gradient was 1e-9, the scaled gradient is S × 1e-9. With S = 2^15 = 32768, the scaled value is ~3.3e-5, which FP16 can represent (as a subnormal, just below the smallest normal 6.1e-5).

After the backward pass, before the optimizer step, upcast to FP32 and divide by S to restore the true gradient magnitude.
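
The underflow and the rescue, in two lines of NumPy:

import numpy as np

g = 1e-9                           # a typical late-training gradient
print(np.float16(g))               # 0.0 -- silently underflows
print(np.float16(g * 2**15))       # ~3.3e-05 -- representable after scaling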

5.3 Static loss scaling

Pick S once and leave it. Common choices: 2^7, 2^10, 2^15. Too small: gradients still underflow. Too large: gradients overflow to inf, kill the step.

Static scaling is simple but fragile. A bad batch can blow it up; a phase change in training can render it suboptimal.

5.4 Dynamic loss scaling

The standard algorithm (used by NVIDIA Apex AMP, PyTorch torch.amp.GradScaler):

S          := 2^15        # initial loss scale
streak     := 0
patience   := 2000        # successful steps before doubling
backoff    := 0.5         # multiplier on overflow (halve)
growth     := 2.0         # multiplier after streak (double)
S_max      := 2^24
S_min      := 1.0

for each step:
    grads_lp := backward(S * loss, weights_lp)
    grads_fp := upcast(grads_lp) / S

    if any_nan_or_inf(grads_fp):
        S       := max(S * backoff, S_min)
        streak  := 0
        skip optimizer step    # crucial: do NOT update weights
        continue

    optimizer.step(grads_fp)
    streak += 1
    if streak >= patience:
        S       := min(S * growth, S_max)
        streak  := 0

Two important details often missed:

  1. On overflow, you must skip the step, not clip and proceed. The corrupted gradients have no useful information.
  2. Check for NaN/inf on the scaled gradients before unscaling, or equivalently on the unscaled-the operation is just a divide. Fused implementations check during the unscale.

5.5 Choosing S_max and patience

S_max = 2^24 is a safety cap rather than a target: a gradient of ~1 scaled by 2^24 ≈ 1.7e7 would overflow FP16 (max ≈ 6.5e4), so the dynamic backoff keeps runs well below the cap. In practice runs settle to S between 2^10 and 2^16, occasionally pushing higher.

patience = 2000 is an empirical choice from Apex. Lower (say 100) and you double too aggressively, causing frequent overflow-rollback cycles. Higher (10000) and you under-utilize the FP16 range during long calm phases.

5.6 Why this is invisible to the user (mostly)

PyTorch wraps it in GradScaler:

scaler = torch.cuda.amp.GradScaler()
for batch in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast(dtype=torch.float16):
        loss = model(batch).loss
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)          # divides grads by S in-place
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)  # now safe
    scaler.step(optimizer)              # skips the step if inf/nan detected
    scaler.update()                     # adjusts S

Two lines you must not skip: unscale_ before clip_grad_norm_ (otherwise you clip the scaled grads, with the wrong norm), and scaler.update() after every step.


6. Why BF16 is different (and what it costs)

6.1 The free lunch (almost)

BF16 has 8 exponent bits, the same as FP32. Its dynamic range matches FP32: roughly 1.18e-38 to 3.4e+38. Gradients do not underflow. Loss scaling is unnecessary.

This is why BF16 has eaten the world. Hopper, TPU v3/v4/v5, Ampere, and AMD MI250+ all natively support it. Modern training defaults are:

  • Master weights: FP32
  • Compute: BF16
  • Optimizer states: FP32
  • No loss scaling, no NaN-rollback, no dynamic scale tuning.

6.2 The catch

BF16 has 7 mantissa bits, vs 10 for FP16 and 23 for FP32. That's log10(2^8) ≈ 2.4 decimal digits of precision (counting the implicit leading bit). Two consequences:

  1. Catastrophic cancellation is more likely in any subtraction or near-cancelling sum.
  2. Accumulation errors compound more aggressively. A million-element BF16 sum can lose almost all precision (we compute this in exercise 14.2).

Mitigation:

  • Accumulate everything in FP32. This is the default in PyTorch/JAX for reductions; double-check custom kernels.
  • Keep master weights in FP32. Adam's tiny updates would otherwise be lost in BF16 weights (section 9).
  • For very long runs (>10^6 steps), some practitioners use FP32 for selected layers (final norm, classification head) to avoid drift.

6.3 BF16 versus FP16 in practice

Concern FP16 BF16
Range overflow Likely without scaling Almost never
Range underflow Likely without scaling Almost never
Mantissa precision 10 bits (~3 decimal) 7 bits (~2.4 decimal)
Loss scaling needed Yes, ideally dynamic No
Hardware support Volta+ Ampere+, TPU all
Fine-tuning safety Fragile Robust

For new code, default to BF16 unless you're targeting hardware that lacks it.


7. FP8 training in detail

FP8 was introduced as a training format with NVIDIA Hopper (H100, 2022). Supporting libraries: TransformerEngine (NVIDIA), MS-AMP (Microsoft), and increasingly native frameworks.

The key insight: FP8 cannot be used "in place" of FP16/BF16. The dynamic range is too small (~5 orders of magnitude for E4M3, ~10 for E5M2). You must scale per-tensor, and you must update the scale carefully.

7.1 The two FP8 formats

                                 E4M3                                 E5M2
Bits (sign + exp + mantissa)     1 + 4 + 3                            1 + 5 + 2
Exponent bias                    7                                    15
Smallest normal                  2^-6 ≈ 0.0156                        2^-14 ≈ 6.10e-5
Smallest subnormal               2^-9 ≈ 0.00195                       2^-16 ≈ 1.53e-5
Max finite (with inf reserved)   240                                  57344
Max finite (saturating variant)  448                                  n/a (inf is real)
Has inf?                         No (saturating) / Yes (IEEE-style)   Yes
Has NaN?                         Yes (a single codepoint per sign)    Yes

E4M3 trades range for precision (a 3-bit mantissa beats 2-bit). E5M2 trades precision for range (it matches FP16's exponent range exactly), which is useful for gradients, which span many orders of magnitude.

Standard assignment (TransformerEngine, OFP8 paper): activations and weights → E4M3; gradients → E5M2. The intuition: weights and activations have a tighter distribution after layer norm; gradients are wider and need range.

7.2 Per-tensor scaling

For each tensor X we maintain a scalar S_X in FP32. The quantize/dequantize pair:

quantize:    X_fp8 = round( clip(X_fp32 * S_X, -FP8_MAX, +FP8_MAX) )
dequantize:  X_fp32_reconstructed = X_fp8 / S_X

We choose S_X so that X_fp32 × S_X lands close to (but not above) FP8_MAX. If amax_X = max|X_fp32|, the optimal scale is roughly

S_X = FP8_MAX / amax_X * margin

where margin < 1 (e.g., 1 / 2^k for some small k) gives headroom for transient spikes.
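
A minimal quantize/dequantize round trip (assumes a PyTorch build with the torch.float8_e4m3fn dtype, available since 2.1):

import torch

x = torch.randn(4096) * 4.0
amax = x.abs().max().item()
scale = 448.0 / amax                          # E4M3: FP8_MAX = 448, margin = 1
x_fp8 = (x * scale).to(torch.float8_e4m3fn)   # quantize
x_hat = x_fp8.to(torch.float32) / scale       # dequantize
rel = ((x - x_hat).abs() / x.abs().clamp_min(1e-6)).median()
print(rel)                                    # roughly 1-3% per element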

7.3 The matmul

Tensor cores execute the matmul on (X_fp8, W_fp8) and accumulate in FP32. The scales come out in the dequantize:

Y_acc  = matmul(X_fp8, W_fp8)        # tensor cores accumulate in FP32 internally
Y_fp32 = Y_acc / (S_X * S_W)         # apply both scales in one step

Then we re-quantize Y to FP8 with its own scale S_Y for the next layer. All of this is one fused kernel in TransformerEngine.

7.4 Delayed (lazy) scaling

The naive approach-compute amax(X) now, set S_X = FP8_MAX / amax(X), then quantize-adds a full reduction over X before every matmul. That reduction would dominate cost.

Delayed scaling instead uses the amax from the previous step:

S_X[t]  = FP8_MAX / max_history[t-1] * margin
amax_X[t] = max|X[t]|       # computed alongside the matmul, cheap
push amax_X[t] into history; trim to last K entries

The history is typically K = 16 to K = 1024 steps. We keep the maximum over the last K rather than the most recent value, to be robust to dips.

7.5 Algorithm in pseudocode

struct FP8TensorScaling {
    fp32  scale            # quantize multiplier this step
    fp32  amax_history[K]  # last K observed amax values
    int   history_idx
}

def fp8_matmul(X_fp32, W_fp32, x_meta, w_meta):
    # 1. Compute scales from previous-step amax
    s_x = FP8_MAX / (max(x_meta.amax_history) + EPS)
    s_w = FP8_MAX / (max(w_meta.amax_history) + EPS)

    # 2. Quantize current tensors and observe current amax
    X_fp8 = round(clip(X_fp32 * s_x, -FP8_MAX, +FP8_MAX))
    W_fp8 = round(clip(W_fp32 * s_w, -FP8_MAX, +FP8_MAX))
    amax_x_now = max(abs(X_fp32))      # FP32 reduction, fused with the cast
    amax_w_now = max(abs(W_fp32))

    # 3. Tensor-core matmul; FP32 accumulator
    Y_fp32 = matmul_fp8_to_fp32(X_fp8, W_fp8) / (s_x * s_w)

    # 4. Update history for next step
    x_meta.amax_history[x_meta.history_idx] = amax_x_now
    w_meta.amax_history[w_meta.history_idx] = amax_w_now
    x_meta.history_idx = (x_meta.history_idx + 1) mod K
    w_meta.history_idx = (w_meta.history_idx + 1) mod K

    return Y_fp32

The + EPS guards against amax = 0 on a freshly-initialized layer or a fully-pruned weight matrix.

7.6 NaN/inf detection

E4M3 in the saturating variant has no inf representation, so an out-of-range value silently saturates to ±448. This is normally fine; clipping is the desired behavior. But it does mean that you cannot detect overflow by checking for inf in the FP8 tensor. Instead:

  • Check the FP32 amax before quantization. If it explodes, the previous step was bad.
  • Check for NaN in the FP32 master weights and the FP32 dequantized output.
  • Optional: check if the FP8 tensor saturates more than X% of its elements; that's a sign the scale is wrong.

E5M2 has a real inf, so standard isinf checks work.

7.7 A worked numerical example

Consider an activation tensor X with amax = 12.5. Using E4M3 with FP8_MAX = 448:

S_X = 448 / 12.5 = 35.84

Take a single value x = 1.7:

x * S_X = 60.928
round to E4M3:    60.928 → 60   (E4M3 step at this magnitude is 4: 56, 60, 64)
fp8 stored:       60
dequantize:       60 / 35.84 = 1.6741
absolute error:   |1.7 - 1.6741| = 0.0259
relative error:   0.0152 (1.5%)

That's the precision floor for FP8: roughly 1–3% relative error per element, which the matmul's FP32 accumulator partially averages out across thousands of multiplies.

For a value near the max, say x = 12.0:

x * S_X = 430.08
round to E4M3:    430.08 → 416  (step at this magnitude is 32: 384, 416, 448)
dequantize:       416 / 35.84 = 11.607
relative error:   0.033 (3.3%)

And for a tiny value x = 0.001:

x * S_X = 0.0358
This is still a normal E4M3 value: it lies in the binade [2^-5, 2^-4), where the step is 2^-8 ≈ 0.0039 (subnormals only start below 2^-6 ≈ 0.0156)
Quantize: 0.0358 → 9 × 2^-8 ≈ 0.0352
Dequantize: 0.0352 / 35.84 ≈ 9.81e-4
Absolute error: ~1.9e-5
Relative error: 1.9%

The takeaway: FP8's relative precision is roughly constant (a few percent) in the well-scaled regime, falling off only in the subnormal tail: here, pre-scaling values below 0.0156 / 35.84 ≈ 4.4e-4.
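
You can replay the three roundings above empirically (again assuming torch.float8_e4m3fn is available):

import torch

s = 448.0 / 12.5                                   # = 35.84
for x in (1.7, 12.0, 0.001):
    q = torch.tensor(x * s).to(torch.float8_e4m3fn)
    x_hat = q.to(torch.float32).item() / s
    print(x, x_hat, abs(x - x_hat) / x)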

7.8 What goes in FP8 and what doesn't

Even in a fully FP8-trained model:

  • Weights, activations, gradients: FP8 (E4M3 / E5M2 split).
  • Optimizer states: FP32 (or FP16, with care).
  • Master weights: FP32 (always).
  • Layer norm / RMS norm gain and bias: FP32 or BF16.
  • Embedding tables: usually BF16; the distribution is too long-tailed for E4M3.
  • Final classifier / logits: BF16 or FP32; softmax is too sensitive.
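
In TransformerEngine, this whole policy (E4M3/E5M2 split, delayed scaling, amax history) is configured in a few lines. A sketch, with API names as of TE 1.x; treat the exact arguments as indicative:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,     # E4M3 forward, E5M2 for gradients
    amax_history_len=16,          # the K of section 7.4
    amax_compute_algo="max",      # max over history, not most-recent
)
linear = te.Linear(8192, 8192, bias=True).cuda()
x = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = linear(x)                 # FP8 matmul, FP32 accumulation, BF16 out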

8. TF32: the silent precision drop

TF32 (TensorFloat-32) is NVIDIA-specific (Ampere onwards). It is not a storage format; you cannot allocate a TF32 tensor.

8.1 What it actually is

When tensor cores execute an FP32 matmul, they internally:

  1. Read FP32 inputs A, B from memory.
  2. Round each to TF32 (1 + 8 + 10 bits), discarding 13 mantissa bits.
  3. Multiply using TF32-precision multipliers.
  4. Accumulate the products in FP32.
  5. Write FP32 output to memory.

So the user sees an FP32 matmul, but the compute runs at tensor-core throughput (on A100, TF32 is ~8× the classic FP32 rate, about half the FP16 tensor-core rate) while the mantissa precision is FP16-level.

8.2 When it bites

For most training, TF32 is fine: losses converge to within ~0.01% of true-FP32 results. But:

  • Numerical methods that depend on FP32 precision in the small mantissa (e.g., orthogonalization, Gram-Schmidt, eigendecomposition, large-N integration) can fail subtly.
  • Very small learning rates with large weights: w + small_update may not change w if the update is below TF32 ulp.
  • Reproduce-old-paper validation: you are no longer doing what the paper did.

8.3 The toggle

In PyTorch (post-1.7):

torch.backends.cuda.matmul.allow_tf32 = True   # off by default since PyTorch 1.12
torch.backends.cudnn.allow_tf32 = True         # convolutions; on by default

# To turn off:
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

PyTorch's default has flipped across versions: matmul TF32 was on by default from 1.7 through 1.11, and off by default from 1.12 onward; the cuDNN convolution flag has stayed on by default. If you need bit-stable repros or are debugging numerical drift, turn both off and re-test.
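
Measuring the drift on your own hardware takes a few lines (Ampere or newer assumed):

import torch

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")
torch.backends.cuda.matmul.allow_tf32 = False
ref = a @ b                                   # classic FP32 matmul
torch.backends.cuda.matmul.allow_tf32 = True
tf32 = a @ b                                  # tensor-core TF32 matmul
print((ref - tf32).abs().max().item())        # order 1e-2 to 1e-1 at this size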

8.4 TF32 versus BF16 + FP32 master

These two routes give similar speed and similar accuracy:

  • TF32: keep all tensors in FP32, hardware drops 13 mantissa bits internally.
  • BF16 + master: compute in BF16 (16 bits, 7 mantissa), keep FP32 master.

The BF16 path uses half the memory for activations. TF32 uses none of the BF16 machinery (no autocast wrapping). Most modern training has migrated to BF16 + master because of the activation memory win.


9. Adam + low precision pitfalls

9.1 The fundamental issue

Adam's update is

m  := β1 m + (1 - β1) g
v  := β2 v + (1 - β2) g²
m̂  := m / (1 - β1^t)
v̂  := v / (1 - β2^t)
p  := p - lr × m̂ / (sqrt(v̂) + ε)

For a typical late-training step: lr ≈ 1e-4, m̂ / sqrt(v̂) ≈ O(1), so the update Δp ≈ 1e-4. Meanwhile a typical weight value is ~1e-1 to ~1.

If p is stored in BF16:

  • p ≈ 0.5; the BF16 ulp at 0.5 is 2^-8 ≈ 3.9e-3.
  • Δp ≈ 1e-4, much smaller than the ulp.
  • p - Δp rounds back to p. The update is lost.

The same calculation in FP32: p ≈ 0.5, FP32 ulp at 0.5 is 2^-24 ≈ 6e-8, much smaller than Δp ≈ 1e-4. The update sticks.

This is why FP32 master weights are non-negotiable. The entire purpose of master weights is to be the high-precision substrate where Adam's tiny update can survive.
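
The lost update is easy to reproduce (PyTorch):

import torch

p_bf16 = torch.tensor(0.5, dtype=torch.bfloat16)
print(p_bf16 - 1e-4 == p_bf16)    # True: the update vanished
p_fp32 = torch.tensor(0.5)
print(p_fp32 - 1e-4 == p_fp32)    # False: the update sticks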

9.2 Stochastic rounding

If you absolutely must store master weights in low precision (memory pressure on a 100B+ model), one fix is stochastic rounding.

Standard (deterministic) round-to-nearest-even returns the nearest representable value, with ties going to even. Each individual round is nearly unbiased, but when every update is smaller than half an ulp, every update rounds away to nothing: over many accumulations into a low-precision accumulator, small updates are systematically dropped.

Stochastic rounding: round up with probability proportional to the residual:

def stochastic_round(x_fp32, target_dtype):
    x_lo = floor_to(x_fp32, target_dtype)   # next representable below
    x_hi = next_above(x_lo, target_dtype)   # next representable above
    residual = (x_fp32 - x_lo) / (x_hi - x_lo)   # in [0, 1)
    if random_uniform(0, 1) < residual:
        return x_hi
    else:
        return x_lo

This is unbiased in expectation even after repeated accumulation: E[round(x)] = x. In practice, with stochastic rounding into BF16 master weights, even very small Adam updates accumulate correctly over many steps, because the probability of rounding up matches the residual.

Cost: requires per-element random numbers (philox-style RNG, fast on GPU). Some FP8 training recipes (HFP8, MS-AMP) use stochastic rounding for the FP32 → FP8 cast on weights to retain trainability.
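
For BF16 targets, a standard implementation trick is add-random-bits-then-truncate: inject 16 uniform random bits below the BF16 mantissa of the FP32 value, then drop the low half. A tensor-level sketch (ignores NaN/inf edge cases; illustrative only):

import torch

def stochastic_round_to_bf16(x_fp32: torch.Tensor) -> torch.Tensor:
    bits = x_fp32.view(torch.int32)            # reinterpret bits, no copy
    noise = torch.randint(0, 1 << 16, bits.shape,
                          dtype=torch.int32, device=x_fp32.device)
    rounded = (bits + noise) & ~0xFFFF         # carry decides up vs down
    return rounded.view(torch.float32).to(torch.bfloat16)  # exact cast

# Repeated tiny updates now survive in expectation:
p = torch.full((10_000,), 0.5)
for _ in range(100):
    p = stochastic_round_to_bf16(p - 1e-4).float()
print(p.mean())                                # ≈ 0.49; RNE would print 0.50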

9.3 Practical guidance

  • Default: FP32 master weights. Done.
  • If memory-bound on master weights: stochastic rounding into BF16 master weights. Validate convergence carefully.
  • Never: deterministic RNE rounding into BF16 master weights for a long run. The bias accumulates.

10. Catastrophic cancellation in reductions

10.1 The error model

Naive sequential summation:

s = x_1
for i in 2..N:
    s = s + x_i

At each step, the floating-point add introduces a relative error |δ| ≤ u. The error in s after N adds satisfies, in the worst case,

|fl(sum) - true_sum| ≤ N × u × max_i |x_i|

(actually the bound is (N - 1) × u × sum |x_i| for non-negative inputs, but N × u × max is a useful approximation when terms are similar in magnitude).

FP16 example: N = 10^6, u = 2^-11 ≈ 4.9e-4 (for FP16, one ulp at 1.0 is 2^-10, so u = 2^-11).

relative error ~ N × u = 10^6 × 4.9e-4 = 490

That's 490× the magnitude of max|x_i|: catastrophic.

FP32: u = 2^-24 ≈ 6e-8. Same N:

relative error ~ 10^6 × 6e-8 = 0.06

6%: bad but recoverable.

BF16: u = 2^-8 ≈ 4e-3. Same N:

relative error ~ 10^6 × 4e-3 = 4000

Worse than FP16. BF16's range advantage does not save you here.

10.2 Pairwise summation

Recursive halving:

def pairwise_sum(x):
    if len(x) == 1:
        return x[0]
    mid = len(x) // 2
    return pairwise_sum(x[:mid]) + pairwise_sum(x[mid:])

Error bound: O(log N × u × max|x_i|).

For N = 10^6: log_2(10^6) ≈ 20. So FP16 pairwise error is 20 × u × max ≈ 0.01, nearly five orders of magnitude better than naive.

This is what NumPy, PyTorch, TF, JAX all use for sum, mean, etc. It's the default-but only if you're calling the framework's reduction. Custom CUDA kernels that use a single-thread accumulator have naive O(N × u) error. Be careful.

10.3 Kahan summation

Track and re-add the lost low-order bits:

def kahan_sum(x):
    s = 0
    c = 0    # compensation
    for xi in x:
        y = xi - c
        t = s + y
        c = (t - s) - y    # the rounding error
        s = t
    return s

Error bound: O(u × max|x_i|), independent of N. Cost: 4 ops per element instead of 1.

On GPU, the 4× slowdown is usually not worth it: pairwise summation is O(log N) error and parallelizes naturally on a tree-reduction. Kahan is mostly used in scientific computing and rarely in ML.

10.4 Reduction precision in practice

PyTorch and most frameworks upcast to FP32 for reductions even when the input is BF16/FP16:

x_bf16 = torch.randn(1_000_000, dtype=torch.bfloat16)
m = x_bf16.mean()    # internally: cast to FP32, pairwise reduce, cast back

You can disable this for some kernels (e.g., keep_dtype=True flags), but you almost never want to. Always reduce in FP32, then cast the scalar result back.

LayerNorm and RMSNorm specifically:

# Pseudocode for fused LayerNorm
def layernorm(x_bf16, gamma, beta, eps):
    x_fp32 = x_bf16.to(fp32)            # upcast
    mean = pairwise_mean(x_fp32, dim=-1)
    var  = pairwise_mean((x_fp32 - mean)**2, dim=-1)
    x_norm = (x_fp32 - mean) / sqrt(var + eps)
    out = x_norm * gamma + beta
    return out.to(bf16)                  # downcast at the end

The upcast at the start and downcast at the end are crucial; the variance computation in BF16 would lose 3 decimal digits to cancellation.


11. Numerical stability tricks in transformers

11.1 Softmax with max subtraction

Already covered in section 3.3. Restating for completeness:

def stable_softmax(x):
    m = max(x)
    z = exp(x - m)
    return z / sum(z)

This is equivalent to plain softmax, never overflows, and underflows only the most-negative entries to zero (which is correct behavior-they have negligible probability).

In online softmax (FlashAttention), we extend this to streaming: when we see a new chunk of logits with max m_new, we rescale the running sum:

m_new_global = max(m_old_global, m_new)
S_new = S_old × exp(m_old_global - m_new_global) + sum_in_chunk(exp(x - m_new_global))
m_old_global = m_new_global
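
In plain Python, the streaming update is a few lines (a sketch; chunks is any iterable of logit blocks):

import math

def online_softmax_stats(chunks):
    m, S = float("-inf"), 0.0
    for block in chunks:
        m_new = max(m, max(block))
        S = S * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in block)
        m = m_new
    return m, S          # softmax(x_i) = exp(x_i - m) / S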

11.2 LayerNorm / RMSNorm with FP32 accumulator

LayerNorm computes mean and variance across the feature dim. Variance is E[(x - μ)²], which is a difference of two near-equal terms when computed naively as E[x²] - μ². Always use the centered form:

μ = mean(x)
σ² = mean((x - μ)²)

and always in FP32. Re-cast at the very end.

RMSNorm is simpler: σ_rms² = mean(x²). No subtraction, but still use FP32 to keep the squared-sum from overflowing or losing precision.

11.3 Attention with √dₖ scaling

The dot-product attention logit is q · k = sum_{i=1}^{d_k} q_i k_i.

Assume q_i, k_i are independent random variables with zero mean and unit variance. Then:

E[q · k] = 0
Var[q · k] = sum Var[q_i k_i] = d_k × Var[q_i] × Var[k_i] = d_k

So the standard deviation of q · k grows like √d_k. For d_k = 64, std ≈ 8; for d_k = 128, std ≈ 11.

After softmax, large logits saturate the distribution into a near-one-hot. To keep gradients flowing, we scale:

attn = softmax((Q K^T) / sqrt(d_k))

This puts the logits back to unit-variance regardless of d_k, preserving gradient signal early in training.

The numerical bonus: with logits at unit variance, max(Q K^T / sqrt(d_k)) rarely exceeds 5–10, comfortably within FP16 / BF16 / E4M3 range after the max-subtraction softmax trick.
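
The variance claim is easy to verify empirically (PyTorch):

import torch

d_k = 128
q = torch.randn(100_000, d_k)
k = torch.randn(100_000, d_k)
logits = (q * k).sum(dim=-1)
print(logits.std())                  # ≈ sqrt(128) ≈ 11.3
print((logits / d_k ** 0.5).std())   # ≈ 1.0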

11.4 Logit soft-cap (Gemma, others)

Gemma 2 introduced logit soft-capping: clip the pre-softmax logits with tanh:

logits = soft_cap × tanh(logits / soft_cap)

with soft_cap = 30 or similar. This prevents extreme logits from blowing up the softmax (mostly a problem with very long contexts where one or two logits can drift huge), and incidentally regularizes the model.

The same trick appears as z-loss (PaLM, T5): an auxiliary loss λ_z × (log Σ exp(logits))² that pushes the log-partition function toward zero, preventing logit drift.
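
Both tricks are one-liners in practice (raw_logits and cross_entropy are placeholders; the 1e-4 coefficient is PaLM's published value):

import torch

soft_cap = 30.0
capped = soft_cap * torch.tanh(raw_logits / soft_cap)   # Gemma-2 style soft-cap

log_z = torch.logsumexp(raw_logits, dim=-1)             # z-loss term
loss = cross_entropy + 1e-4 * (log_z ** 2).mean()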

11.5 Embedding scaling

Embedding tables in many architectures (the original Transformer, T5) multiply by √d_model after lookup. Reason: the embedding entries are initialized small (N(0, 1/d_model)); without rescaling they would be drowned out by the positional encoding (which is unit-variance). Numerically: keeps the early-layer activations in a sensible range.

11.6 Output projection / unembedding

The unembedding (last Linear to vocab size) is often shared with the input embedding (tied weights). Some recipes scale the output of the last LayerNorm by 1/√d_model or skip the final norm. Others (LLaMA, GPT-NeoX) apply a final RMSNorm specifically because the output magnitudes drift over many layers. Numerically, you want logits to land near O(1) for a stable softmax.


12. Detecting and handling NaN

12.1 Where NaN comes from

Common sources:

  1. Divide by zero (or by a value that underflows to zero). E.g., 1 / (sqrt(v) + eps) with eps = 1e-10: in FP16, sqrt(v) can underflow to 0 and eps = 1e-10 is below the smallest FP16 subnormal, so the denominator becomes exactly 0 and the result is inf (with NaN following downstream).
  2. Overflow to inf, then inf - inf or 0 × inf. After overflow, subsequent ops produce NaN.
  3. Invalid op: sqrt(negative), log(0), log(negative). The negative usually came from a tiny numerical error in a quantity that should mathematically be non-negative.
  4. Bad data: a NaN in the input batch propagates everywhere.

12.2 The typical failure mode

A single FP16 overflow in a forward pass:

  1. Some logit = inf.
  2. softmax(inf, ...) → involves inf / inf = NaN.
  3. NaN propagates through attention, FFN, all subsequent layers.
  4. Loss = NaN.
  5. loss.backward() produces NaN gradients on every parameter.
  6. Optimizer step: weight = weight - lr × NaN / NaN → all weights become NaN.
  7. Game over.

This is recoverable only if you detect before step 6.

12.3 Detection

# inside the training loop
if torch.isnan(loss) or torch.isinf(loss):
    # Skip this step entirely
    optimizer.zero_grad()
    if isinstance(scaler, GradScaler):
        scaler._scale = scaler._scale * 0.5   # private attr; scaler.update() does this itself on a skipped step
    log_warning(f"NaN/inf loss at step {step}; skipping")
    continue

A more thorough check after backward:

def grads_finite(model):
    for p in model.parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            return False
    return True

torch.amp.GradScaler does this automatically when you call scaler.step(optimizer): if any grad is NaN/inf, the optimizer step is skipped and scaler.update() halves the scale.

12.4 Gradient clipping

Clip the gradient norm (or per-tensor) to bound the worst-case update:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

How it works:

total_norm = sqrt(sum_p ||p.grad||² )
if total_norm > max_norm:
    scale = max_norm / total_norm
    for p in params:
        p.grad *= scale

Clipping prevents a single bad batch from producing a runaway update that pushes weights into a regime where subsequent forward passes overflow. It does not prevent NaN if NaN is already in the gradients-total_norm becomes NaN and clipping does nothing useful. Always check for NaN before or after clipping.

Important: with GradScaler, you must call scaler.unscale_(optimizer) before clipping, otherwise you clip the scaled gradients (the wrong norm).

12.5 Skip-step recovery

Algorithm:

Save checkpoint every N steps.
On NaN:
    1. Reset optimizer momentum (Adam m, v) to last good checkpoint.
    2. Restore weights to last good checkpoint.
    3. Reduce LR by 0.5x for K steps.
    4. (FP16) Halve loss scale.
    5. Resume.

Production runs implement this as a watchdog. Without it, a single catastrophic batch destroys days of compute.
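
A sketch of the watchdog body (checkpoint path and dict keys are placeholders):

import torch

def restore_and_back_off(model, optimizer, ckpt_path, lr_factor=0.5):
    state = torch.load(ckpt_path)
    model.load_state_dict(state["model"])            # restore weights
    optimizer.load_state_dict(state["optimizer"])    # restores Adam m, v too
    for group in optimizer.param_groups:
        group["lr"] *= lr_factor                     # back off for a while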

12.6 Eps placement matters

The Adam denominator: sqrt(v) + eps. Two equivalent formulations have different numerical properties:

  • update = m / (sqrt(v) + eps): eps added outside the sqrt; the default in PyTorch.
  • update = m / sqrt(v + eps²): eps folded inside the sqrt; slightly different behavior near v = 0.

PyTorch's default eps = 1e-8. In FP16 storage that underflows; this is one reason Adam states are always FP32. In FP32 it's fine.


13. Determinism

13.1 Sources of non-determinism

GPU training is non-deterministic by default. Sources:

  1. Atomic adds in reductions: many CUDA kernels (e.g., scatter_add, some softmax kernels, certain backward passes) use atomicAdd for thread-safe accumulation. The order of atomic adds is non-deterministic, and FP32 addition is non-associative. So you get bit-different results across runs even with the same inputs.
  2. CUDA workspace reuse: cuBLAS picks different algorithms based on workspace size and available memory. Different runs → different algorithms → bit-different results.
  3. Multi-threaded data loading: workers can return batches in different orders.
  4. NCCL collectives: ring/tree algorithms have run-dependent ordering.
  5. cuDNN heuristics: cuDNN benchmarks kernels and picks the fastest, but the choice depends on transient hardware state.

13.2 PyTorch deterministic mode

import torch
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Required for some cuBLAS kernels:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"   # or ":16:8"

# Seed everything:
import random; random.seed(0)
import numpy as np; np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

Combined with single-process, single-worker data loading and a fixed seed, you can get bit-exact reproducibility on a single GPU.

13.3 Cost

Deterministic mode is slower:

  • scatter_add and embedding gradients: 2–10× slower (because we lose atomic-add).
  • Some convolution algorithms: 1.2–2× slower.
  • Multi-GPU training: harder still, because NCCL collectives are not bit-deterministic without specific configuration.

Use deterministic mode for debugging only. Production runs should accept non-determinism and rely on statistical reproducibility (the loss curve looks the same up to small noise).

13.4 What "reproducible" means in practice

For a paper or ablation study:

  • Run the same configuration 3 times with different seeds.
  • Report mean ± std of final metrics.
  • If ablations are within the std, they are noise.

Bit-exact reproducibility is rarely the goal. Statistical reproducibility (results within seed-noise) is.


14. Practical exercises

Solutions inline. Try each before reading the answer.

14.1 FP16 representable-zero

Problem: Show that the value 1e-5 (decimal) cannot be represented as a normal FP16 number. What is the closest FP16 value?

Solution:

FP16 smallest positive normal = 2^-14 ≈ 6.1035e-5.

1e-5 < 6.1e-5, so it is below the smallest normal-it's in subnormal range.

FP16 subnormal step = 2^-14 × 2^-10 = 2^-24 ≈ 5.96e-8.

1e-5 / 5.96e-8 ≈ 167.77. Round to nearest even: 168.

Closest FP16 = 168 × 2^-24 ≈ 1.0014e-5. Relative error: 0.14%.

So 1e-5 is representable in FP16-but only as a subnormal, with ~10× less precision than a normal FP16 value of similar magnitude. If subnormals are flushed (FTZ), 1e-5 becomes 0. This is exactly the regime where FP16 gradients silently underflow without loss scaling.
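
PyTorch confirms the arithmetic:

import torch

x = torch.tensor(1e-5, dtype=torch.float16)
print(x.item())                           # 1.0013580322265625e-05 = 168 × 2^-24
print(torch.finfo(torch.float16).tiny)    # 6.103515625e-05, the smallest normal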

14.2 BF16 accumulation error

Problem: You sum N = 10^6 values, each ~U(-1, 1). Estimate the absolute error of naive sequential BF16 summation versus pairwise BF16 summation versus FP32 pairwise.

Solution:

BF16: u = 2^-8 ≈ 3.9e-3.

Naive: N × u × max|x_i| ≈ 10^6 × 3.9e-3 × 1 = 3900. The error dwarfs the true sum (which is O(sqrt(N)) ≈ 1000 by CLT). Total noise.

Pairwise BF16: log_2(N) × u × max|x_i| ≈ 20 × 3.9e-3 × 1 ≈ 0.08. Acceptable.

FP32 pairwise: u = 2^-24 ≈ 6e-8. Error ≈ 20 × 6e-8 = 1.2e-6. Negligible.

Lesson: BF16 reductions are usable only with pairwise (or better) summation. Always upcast to FP32 anyway, because it's free on modern hardware.

14.3 Loss-scale recovery trace

Problem: A run starts with S = 2^15 = 32768. After every overflow, S halves. Over how many consecutive overflows would S reach 2^0 = 1? At what point does S drop below the regime where it's helpful (assume "helpful" means S ≥ 2^7 = 128)?

Solution:

2^15 / 2^k = 2^(15-k). To reach S = 1 = 2^0, we need 15 halvings.

To drop below 2^7, we need to fall to 2^6 = 64: solving 2^(15-k) = 2^6 gives k = 9. So 9 consecutive overflow halvings push S below the useful regime.

If the dynamic-scaling patience is 2000 successful steps for a doubling, we can recover slowly: from S = 2^6 to S = 2^15 takes 9 doublings = 18000 successful steps minimum. In practice a single bad batch causes one halving, but a phase change (e.g., LR warmup ending, distribution shift) can cause cascading overflows-9 in a row is unlikely but not impossible.

Implication: monitor loss_scale as a training metric. A scale that has been falling for 100 steps is a warning sign.

14.4 7B model optimizer-state memory

Problem: For a 7B parameter model with Adam optimizer, compute optimizer-state memory for: (a) FP32 master weights, FP32 m, v. (b) BF16 master weights, FP16 m, v. (c) Bonus: total memory including weights, gradients, master weights for case (a) and (b).

Solution:

N = 7e9 parameters.

(a) Pure FP32 optimizer states:

  • m in FP32: 7e9 × 4 = 28 GB
  • v in FP32: 7e9 × 4 = 28 GB
  • Total optimizer state: 56 GB

(b) BF16 master + FP16 m, v:

  • m in FP16: 7e9 × 2 = 14 GB
  • v in FP16: 7e9 × 2 = 14 GB
  • Total optimizer state: 28 GB

(c) Full memory accounting:

Case (a), standard mixed precision:

  • BF16 weights for compute: 7e9 × 2 = 14 GB
  • BF16 gradients: 7e9 × 2 = 14 GB
  • FP32 master weights: 7e9 × 4 = 28 GB
  • FP32 m + v: 56 GB
  • Total just for params/grads/opt: 112 GB

Case (b), aggressive low precision:

  • BF16 weights: 14 GB
  • BF16 gradients: 14 GB
  • BF16 master weights: 14 GB
  • FP16 m + v: 28 GB
  • Total: 70 GB

But case (b) requires stochastic rounding for the BF16 master weights and may sacrifice convergence quality. This is why ZeRO/FSDP stage 3 (sharding optimizer states across GPUs) is more popular than aggressive low-precision optimizers.
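
The accounting generalizes to a one-line calculator (bytes per component are the knobs):

def param_grad_opt_gb(n, weight_b=2, grad_b=2, master_b=4, m_b=4, v_b=4):
    return n * (weight_b + grad_b + master_b + m_b + v_b) / 1e9

print(param_grad_opt_gb(7e9))                              # 112.0, case (a)
print(param_grad_opt_gb(7e9, master_b=2, m_b=2, v_b=2))    # 70.0, case (b)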

14.5 FP8 scale evolution

Problem: An activation tensor's amax history over the last 5 steps is [2.1, 2.4, 8.5, 2.3, 2.2] (the 8.5 is a transient spike from an outlier batch). You use E4M3 (FP8_MAX = 448) with margin = 1 (no headroom factor). What scale does the next step use: (a) Using the most recent amax (2.2)? (b) Using the max of history (8.5)?

What's the consequence of each choice?

Solution:

(a) S = 448 / 2.2 ≈ 203.6. Tight scale: every quantization uses the full FP8 range. But on the next outlier batch (similar to the spike), the tensor would saturate at ~448 / 203.6 ≈ 2.2, clipping any value above 2.2. We'd lose the outliers.

(b) S = 448 / 8.5 ≈ 52.7. Looser scale: most batches under-utilize the FP8 range (max value used: 2.4 × 52.7 ≈ 126, well below 448). But outliers up to 8.5 are represented faithfully.

Standard practice: use max of recent history to get robustness to spikes, possibly with an additional margin (e.g., margin = 1/2, giving an extra 2× headroom). The cost is some wasted FP8 range on calm batches; the benefit is graceful handling of outliers.

This is why K (history length) matters: too short and you forget spikes (under-scaled, clip outliers); too long and you over-pad indefinitely (over-scaled, waste precision).

14.6 Softmax overflow boundary in FP16

Problem: For a FP16 softmax (without max subtraction), at what magnitude does the largest logit cause exp to overflow? Compare to the typical pre-scaling logit magnitude in attention with d_k = 128 and unit-variance Q, K.

Solution:

FP16 max = 65504 ≈ 6.55e4. exp(x) overflows when x > ln(65504) ≈ 11.09.

Without scaling, attention logit q · k has std √d_k = √128 ≈ 11.3. So a single-σ logit already overflows FP16. A 3σ outlier (x ≈ 34) overflows by 23 in log space, i.e., by exp(23) ≈ 10^10.

With scaling by 1/√d_k, logits have std 1. Now a 5σ outlier (x ≈ 5) gives exp(5) ≈ 148, comfortably representable.

The takeaway: √d_k scaling is necessary, not optional, for FP16 attention. Even with max-subtraction softmax, gradients through unscaled logits can blow up. Scaling is built into every transformer for this reason.


Closing remarks

A few things to remember when this chapter is closed:

  1. BF16 + FP32 master weights + FP32 reductions is the modern default. It's robust, well-supported, and conceptually simple. Reach for FP16 only on hardware that doesn't have BF16; reach for FP8 only when you've measured the savings and committed to dealing with delayed scaling.

  2. Range matters more than precision for ML. This is why BF16 ate FP16 and why FP8 split into two formats (E4M3 for tight distributions, E5M2 for wide ones).

  3. The accumulator is sacred. Tensor cores will let you compute in 8 or 16 bits, but they accumulate in 32. Reductions you write yourself should do the same.

  4. Master weights exist because Adam updates are tiny. Not for any other reason. If you ever invent an optimizer with O(1) updates (some second-order methods approach this), you may be able to drop the master.

  5. Most NaN crashes are loss-scale or learning-rate problems. Before re-debugging the model, check the simplest things: is loss_scale stable? is the LR schedule sane? did the data have a NaN?

  6. Determinism is a debugging tool, not a production goal. Statistical reproducibility (across seeds) is what matters for science.

The next chapter (12_KERNEL_FUSION.md, if you're working through the curriculum in order) builds on this: now that we know which precisions are needed where, we can design custom kernels that fuse multiple operations while respecting these precision rules.

Deep Dive 12 - Kernel Fusion: Theory, Practice, and the Compilers That Do It For You

Chapter 11 told you which precisions to use where. Chapter 12 tells you how to schedule those operations onto the GPU so the precision decisions actually pay off - by eliminating the HBM round-trips that dominate end-to-end latency in modern deep-learning workloads.

This chapter is self-contained. You can read it standalone; it pulls forward concepts from chapters 01 (GPU architecture), 02 (CUDA), 03 (Triton), 04 (PyTorch internals + Inductor), 05 (JAX + XLA), 07 (FlashAttention), and 11 (numerics) and will reference them by chapter number rather than re-deriving.


Table of contents

  1. Why fuse at all
  2. The HBM round-trip cost model
  3. Fusion taxonomy
  4. Vertical fusion derived
  5. Horizontal fusion derived
  6. GEMM epilogue fusion
  7. Streaming-reduction fusion: the FlashAttention pattern
  8. Compiler-driven fusion: XLA, TorchInductor, Triton
  9. Hand-rolled fusion in Triton: three full kernels
  10. Precision discipline under fusion
  11. The limits of fusion
  12. Profiling fused kernels with Nsight Compute
  13. When NOT to fuse
  14. Practical exercises
  15. Cheat sheet and further reading

1. Why fuse at all

The single most important observation in modern GPU performance work:

Most deep-learning operators in the forward pass of a transformer are memory-bandwidth-bound, not compute-bound.

To see why, recall the roofline from chapter 01. A modern H100 GPU has roughly:

  • BF16 dense tensor-core throughput: ~989 TFLOPS.
  • HBM3 bandwidth: ~3.0 TB/s.

The crossover arithmetic intensity at which a kernel transitions from memory-bound to compute-bound is:

I_crossover = peak_FLOPS / peak_BW = 989e12 / 3.0e12 ≈ 330 FLOP/byte.

A pure elementwise operation like y = a * x + b performs 2 FLOPs per 12 bytes moved (4 for x, 4 for y, optionally 4 for a; in BF16 halve it; either way the intensity is well under 1 FLOP/byte). The GPU sits at <1% of peak FLOPs and 100% of peak bandwidth for the entire kernel.

The consequence: if your network is a chain of elementwise ops and small reductions, total time is determined almost entirely by total bytes moved through HBM - not by total work done. A LayerNorm → Linear → GELU → Dropout chain executed as four separate kernels reads and writes the activation tensor through HBM four times. Fused into one kernel, it reads once, writes once.

For the typical transformer hidden state at batch=8, seqlen=4096, hidden=8192, BF16:

activation size = 8 * 4096 * 8192 * 2 bytes = 512 MiB

Each HBM round-trip costs 512 MiB / 3 TB/s ≈ 175 µs. Saving three round-trips per layer × 80 layers = 240 round-trips = 42 ms per forward pass - for free, just by stopping the round-tripping. That is the prize.


2. The HBM round-trip cost model

Let n_op be the number of fused operations in a chain, S the size of the activation tensor in bytes, BW the HBM bandwidth, and K_launch the per-kernel launch overhead (~5 µs on a modern driver). Time for unfused execution:

T_unfused = n_op * (2*S / BW + K_launch)

(Each op reads S, writes S.)

Time for fused execution (one kernel reads once, writes once, does all the work in registers/SMEM):

T_fused = 2*S / BW + K_launch + T_compute_in_kernel

For elementwise chains, T_compute_in_kernel is negligible compared to the HBM term, so:

speedup ≈ n_op    (asymptotically, ignoring launch overhead)

A fused chain of 5 elementwise ops is roughly 5× faster than the unfused version, regardless of how clever the unfused kernels are individually. This is the headline result that motivates every fusion compiler ever written.
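
The cost model fits in a function; plugging in the numbers below reproduces the asymptotic speedup (bandwidth and launch overhead are the H100-ish assumptions above):

def fusion_speedup(n_op, s_bytes, bw=3.0e12, k_launch=5e-6):
    t_unfused = n_op * (2 * s_bytes / bw + k_launch)   # each op: read S, write S
    t_fused = 2 * s_bytes / bw + k_launch
    return t_unfused / t_fused

print(fusion_speedup(5, 128 * 2**20))                  # ≈ 5.0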

Worked numerical example

Take the post-attention residual stream of a Llama-3-70B layer, batch=1, seqlen=8192:

hidden = 8192,  bf16,   activation = 1 * 8192 * 8192 * 2 = 128 MiB

The post-attention chain: x + attn_out → RMSNorm → linear_gate → silu → linear_up · gate → linear_down → x + ffn_out.

Counting just the elementwise pieces (residual add, RMSNorm scale, SiLU, elementwise multiply, residual add) - five elementwise/light-reduction operations on the activation tensor:

T_unfused_elementwise = 5 * (2 * 128 MiB / 3 TB/s + 5 µs)
                     = 5 * (85 µs + 5 µs)
                     = 450 µs

T_fused_elementwise   = 2 * 128 MiB / 3 TB/s + 5 µs
                     = 90 µs

Per layer, fusion saves ~360 µs in the elementwise chain at this shape; across 80 layers that is ~29 ms per forward pass over the 8K sequence. On a real inference engine the saving is closer to 30–50% of total latency because the matmuls still dominate, but eliminating elementwise round-trips is the single most impactful generic optimization in deep-learning compilers.


3. Fusion taxonomy

Fusion comes in five shapes, in roughly ascending implementation difficulty:

#  Pattern                                 Example                                   Difficulty
1  Elementwise → elementwise               (a + b) * c                               Trivial (every compiler does it)
2  Elementwise → reduction                 sum(x * x) (used in RMSNorm)              Easy
3  Reduction → elementwise (broadcast)     RMSNorm = x / sqrt(mean(x²) + ε) * γ      Medium (needs two-pass or online algorithm)
4  GEMM + epilogue                         gelu(A @ B + bias)                        Medium (CUTLASS/cuBLASLt epilogue API)
5  Streaming reduction over GEMM output    FlashAttention: softmax(Q@Kᵀ / √d) @ V    Hard (requires algorithmic redesign; chapter 07)

A sixth, more ambitious shape - multi-GEMM fusion, where two matrix multiplies sharing an intermediate are fused (e.g., the FFN's up_proj and gate_proj in SwiGLU) - is increasingly common in production inference engines but requires either (a) careful CUTLASS programming or (b) horizontal fusion at the Triton level.

The taxonomy axis you actually care about is vertical vs horizontal:

  • Vertical (producer-consumer) fusion combines operations along the data-flow direction: op B consumes op A's output, so we keep A's output in registers/SMEM and feed it directly to B without writing to HBM. All of patterns 1–5 above are vertical.
  • Horizontal (sibling) fusion combines independent operations that have no data dependency, executing them in the same kernel to amortize launch overhead and (sometimes) share input loads. Example: q = x @ Wq; k = x @ Wk; v = x @ Wv can be done as one fused kernel that loads x once.

The next two sections derive each rigorously.


4. Vertical fusion derived

4.1 The producer-consumer pattern

Consider two operations:

B = f(A)
C = g(B)

Unfused, the dataflow through HBM is:

HBM:  read A   →   write B
HBM:  read B   →   write C
Total HBM traffic = |A| + 2|B| + |C|

Fused into one kernel:

For each tile of A:
    load tile_A from HBM into registers
    tile_B = f(tile_A)            # stays in registers
    tile_C = g(tile_B)            # stays in registers
    store tile_C to HBM
Total HBM traffic = |A| + |C|

We saved 2|B| bytes of HBM traffic. If f and g are elementwise and same-shape, |A| = |B| = |C|, so we cut traffic by 50% and (in the bandwidth-bound regime) doubled throughput.

4.2 The shape-compatibility requirement

Vertical fusion works only when the producer's output layout matches the consumer's input layout at the tile granularity. Two cases:

  • Pointwise op → pointwise op: trivially compatible (same element-to-element correspondence). Always fusible.
  • Reduction → broadcast: the reduction shrinks the tensor; the broadcast re-expands it. Fusion is possible but requires either (a) keeping the reduction result in SMEM and re-reading per element (the two-pass RMSNorm pattern), or (b) computing the reduction online during the consumer pass.

4.3 Worked example: RMSNorm fused

The naive RMSNorm:

mean_sq = (x * x).mean(dim=-1, keepdim=True)   # kernels 1 + 2: mul, then reduce
rrms    = torch.rsqrt(mean_sq + eps)           # kernel 3 (tiny)
y       = x * rrms * gamma                     # kernel 4

Four eager-mode kernels, the expensive ones each round-tripping x (or a derivative of it) through HBM. Fused, in pseudocode:

def rmsnorm_fused(x, gamma, eps):
    # x: (..., H)
    # one tile = one row (H elements)
    for row in tiles(x):
        # pass 1: reduction
        s = 0.0
        for j in range(0, H, BLOCK):
            xj = load(row, j)            # HBM → registers
            s += sum(xj * xj)            # accumulate in fp32
        rrms = rsqrt(s / H + eps)        # scalar

        # pass 2: scale-broadcast
        for j in range(0, H, BLOCK):
            xj = load(row, j)            # HBM → registers (re-read!)
            gj = load(gamma, j)
            store(row, j, xj * rrms * gj)

We read x twice and write y once, 3|x| total HBM bytes, and we eliminated the materialized intermediates and the extra kernel launches. The unfused version above moves about 5|x| (read x, write x²; read x², write mean_sq; read x, write y); the fused version moves 3|x|. For the Llama-3-70B example in §2, speedup: 5/3 ≈ 1.67×.

A single-pass RMSNorm uses Welford-style online statistics to avoid the re-read of x - that drops traffic to 2|x|, the absolute floor. The Triton kernel in chapter 03 shows this.


5. Horizontal fusion derived

5.1 The independent-siblings pattern

Now consider three independent operations sharing an input:

Q = X @ Wq
K = X @ Wk
V = X @ Wv

Unfused: three kernel launches, each reading X from HBM. HBM traffic = 3|X| + |Q| + |K| + |V|.

Horizontally fused: one kernel reads X once and produces all three outputs. HBM traffic = |X| + |Q| + |K| + |V|. Saves 2|X|.

For X of shape (batch * seqlen, hidden) = (8 * 4096, 8192) in BF16 = 512 MiB, savings = 1 GiB of HBM traffic ≈ 340 µs at H100 bandwidth, just for the QKV projections per layer.

In practice, modern inference engines (vLLM, TensorRT-LLM) fuse QKV by concatenating [Wq, Wk, Wv] along the output dimension into a single W_qkv of shape (hidden, 3*head_dim*n_heads), and slicing after the matmul. This is mathematically identical to horizontal fusion and is the standard pattern - if a transformer codebase you read does not fuse QKV, that's a perf bug.
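
The fused-QKV pattern in PyTorch terms (shapes assume MHA with equal head dims; the weights are placeholders):

import torch

hidden = 8192
x = torch.randn(32768, hidden, device="cuda", dtype=torch.bfloat16)
Wq, Wk, Wv = (torch.randn(hidden, hidden, device="cuda", dtype=torch.bfloat16)
              for _ in range(3))

W_qkv = torch.cat([Wq, Wk, Wv], dim=1)   # (hidden, 3*hidden), done once at load
qkv = x @ W_qkv                          # one kernel, one read of x
q, k, v = qkv.chunk(3, dim=-1)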

5.2 The SwiGLU case

For the FFN block:

gate = silu(X @ W_gate)
up   = X @ W_up
y    = (gate * up) @ W_down

W_gate and W_up can be horizontally fused into W_gu of shape (hidden, 2 * inter_dim), sliced into halves, with SiLU and the elementwise multiply fused as the epilogue. This saves one full read of X per FFN per layer. For Llama-3-70B at batch=1, seqlen=8192, that is ~45 µs/layer × 80 layers ≈ 3.6 ms per forward pass: substantial.
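
The same trick in code (W_gate, W_up, W_down, and x are placeholders; the concatenation happens once at load time):

import torch
import torch.nn.functional as F

W_gu = torch.cat([W_gate, W_up], dim=1)   # (hidden, 2 * inter_dim)
gate, up = (x @ W_gu).chunk(2, dim=-1)    # one read of x for both projections
y = (F.silu(gate) * up) @ W_down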


6. GEMM epilogue fusion

A GEMM epilogue is any elementwise operation chained immediately after C = A @ B. The CUTLASS library (and its successor, CuTeDSL) supports declarative epilogue fusion via a templated programming model. Common epilogues:

  • C = A @ B + bias (the GEMM-bias-add pattern in every linear layer).
  • C = act(A @ B + bias) where act ∈ {ReLU, GELU, SiLU}.
  • C = act(A @ B + bias) * scale (for quantized inference).
  • C = act(A @ B + bias) + residual (the residual stream pattern in transformers - fuses the residual add into the matmul).

6.1 Why epilogue fusion is cheap

The GEMM kernel already has the output tile C_tile resident in registers immediately after computing it. Applying an elementwise function to it before storing costs zero additional HBM traffic. The only cost is a handful of extra instructions per register, well below the noise floor of the matmul itself.

In CUTLASS terms, the epilogue is a templated EpilogueOp that receives the accumulator tile and produces the output tile:

using EpilogueOp = cutlass::epilogue::thread::LinearCombinationGELU<
    ElementOutput,           // bf16
    128 / sizeof_bits<...>,  // vector length
    ElementAccumulator,      // fp32
    ElementCompute>;         // fp32

The LinearCombinationGELU epilogue computes act(α * accumulator + β * bias) in registers, then stores to HBM. One kernel; zero round-trip.

6.2 The "fuse the residual add into the matmul" trick

The residual connection in a transformer block computes:

x_new = x + linear(layernorm(x))

If the linear is a CUTLASS GEMM with epilogue D = α*(A@B) + β*C, you can pass x itself as C with β=1, and the residual add costs zero extra HBM traffic - the GEMM was going to write its output anyway; with the epilogue, it writes the residual-added output instead. This saves a full |x| round-trip per block, every layer, every forward pass.

This is implemented by addmm in PyTorch (when properly routed to CUBLASLt with the epilogue path) and by every production inference engine. If you write y = linear(x) + residual as two separate kernels in a hot path, that's a perf bug.
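
In PyTorch the pattern is addmm with the residual as the input argument; whether it actually reaches the cuBLASLt epilogue path depends on dtype, shape, and version, so profile rather than assume:

import torch

M, K, N = 4096, 8192, 8192
h = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)      # normed activations
W = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
residual = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)

y = torch.addmm(residual, h, W)   # beta=1 default: y = residual + h @ W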


7. Streaming-reduction fusion: the FlashAttention pattern

The most algorithmically sophisticated fusion in modern AI is FlashAttention (chapter 07). The naive attention computation:

S = Q @ Kᵀ          # (B, H, M, N) - materialized
P = softmax(S, axis=-1)  # (B, H, M, N) - materialized
O = P @ V            # (B, H, M, d)

The intermediates S and P are O(M·N) and dominate memory at long sequence length. At seqlen=8192, S for a single batch×head pair is 8192² × 2 bytes = 128 MiB in BF16. For B=8, H=64, the total is 64 GiB. Doesn't fit.

FlashAttention's insight: S and P never need to be materialized in HBM. Compute the softmax incrementally, tile-by-tile, while accumulating O directly. The mathematical machinery is the online softmax derived in chapter 03 §online-softmax and rigorously in chapter 07.

For fusion purposes, the structural lesson is:

A reduction (softmax-then-matmul) over a streaming source (Q@Kᵀ computed tile-by-tile) can be fused into a single kernel that never materializes the intermediate.

This pattern - streaming a producer through a reducer with state kept in registers/SMEM - generalizes well beyond attention. Examples in production:

  • Cross-entropy loss fused with the final logits projection. Logits at vocab=128k × seqlen=4096 × batch=8 are 16 GiB; never materialize.
  • Top-k sampling fused with logits. Same memory argument.
  • MoE router + dispatch fused. The router's softmax + top-k + scatter can all run in a single kernel.

The price: the fused kernel is algorithmically non-trivial. Each new instance requires real engineering. Compilers cannot yet derive these fusions automatically; you write them by hand in Triton or CUDA.


8. Compiler-driven fusion: XLA, TorchInductor, Triton

Three major systems perform deep-learning kernel fusion in production. Their philosophies differ; their results converge.

8.1 XLA (JAX, TensorFlow)

Chapter 05 covers XLA in depth. The relevant fusion passes:

  • fusion pass: the canonical pass. Groups elementwise/broadcast/reduce ops into "fusion clusters" and emits one kernel per cluster. Driven by a cost model that estimates HBM traffic.
  • gpu_fusion_pipeline: the GPU-specific lowering. Emits LLVM IR with a single CUDA kernel per fusion. Modern XLA also emits Triton for some patterns (matmul + epilogue).
  • priority_fusion: newer pass with a priority queue over fusion candidates.

Fusion in XLA is declarative: you write pure functional JAX, XLA decides what fuses. You can inspect with jax.jit(f).lower(...).compile().as_text() (chapter 05).
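
For example (JAX; prints the optimized HLO, in which the chain should appear as a single fusion):

import jax
import jax.numpy as jnp

def f(x):
    return jax.nn.gelu(x * 2.0 + 1.0)

compiled = jax.jit(f).lower(jnp.ones((1024,))).compile()
print(compiled.as_text()[:500])   # look for one 'fusion' op containing mul/add/gelu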

8.2 TorchInductor (PyTorch 2)

Chapter 04 covers Inductor. Its fusion strategy:

  • Scheduler-driven node fusion: Inductor builds a graph of IRNodes (one per ATen op), then greedily fuses adjacent nodes whose fusion satisfies a memory-locality cost model.
  • Emit Triton or C++ for the fused kernel. GPU path emits Triton; CPU path emits C++ with OpenMP.
  • Pointwise + reduction + pointwise is the bread-and-butter fusion class. More than 80% of the speedups Inductor delivers come from this pattern (per PyTorch's perf blogs).

Inspect with TORCH_LOGS="output_code" (chapter 04) - you get the actual Triton source Inductor generated.

8.3 Triton autotuning (manual, but compiler-assisted)

Triton (chapter 03) is the kernel-author's tool, not a graph compiler. You write the fused kernel; Triton handles the lowering. The compiler contribution is in autotuning - exploring tile shapes, num_warps, num_stages combinations and picking the best.

Production stack composition (typical 2026 inference engine):

Model architecture (PyTorch)
torch.compile + Inductor      ← elementwise + reduction fusion (auto)
CUBLASLt / CUTLASS matmuls    ← GEMM + epilogue fusion (manual config)
FlashAttention / xFormers     ← streaming-reduction fusion (handwritten)
Triton custom kernels         ← anything Inductor missed (handwritten)

The lesson: let the compiler do the easy fusions; reserve human effort for the algorithmically hard ones (FlashAttention, paged attention, fused MoE, fused quantized GEMMs like Marlin).


9. Hand-rolled fusion in Triton: three full kernels

We work three increasingly complex examples. All assume the reader has read chapter 03 (Triton).

9.1 Kernel 1 - Fused bias-GELU-residual

Operation: y = gelu(x @ W + bias) + residual, where x: (M, K), W: (K, N), bias: (N,), residual: (M, N).

The matmul itself uses standard tiled GEMM (chapter 02). The fusion is in the epilogue: after computing the accumulator tile, apply bias, GELU, and the residual add in registers, then store.

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BM': 128, 'BN': 256, 'BK': 32}, num_warps=8, num_stages=3),
        triton.Config({'BM': 64,  'BN': 128, 'BK': 32}, num_warps=4, num_stages=4),
        triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def fused_linear_gelu_residual(
    x_ptr, w_ptr, bias_ptr, residual_ptr, y_ptr,
    M, N, K,
    sxm, sxk, swk, swn, srm, srn, sym, syn,
    BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BM + tl.arange(0, BM)
    offs_n = pid_n * BN + tl.arange(0, BN)
    offs_k = tl.arange(0, BK)

    x_ptrs = x_ptr + offs_m[:, None] * sxm + offs_k[None, :] * sxk
    w_ptrs = w_ptr + offs_k[:, None] * swk + offs_n[None, :] * swn

    acc = tl.zeros((BM, BN), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BK)):
        x = tl.load(x_ptrs, mask=(offs_m[:, None] < M) & (offs_k[None, :] < K - k * BK), other=0.0)
        w = tl.load(w_ptrs, mask=(offs_k[:, None] < K - k * BK) & (offs_n[None, :] < N), other=0.0)
        acc += tl.dot(x, w)
        x_ptrs += BK * sxk
        w_ptrs += BK * swk

    # --- epilogue, in registers, no HBM round-trip ---
    bias = tl.load(bias_ptr + offs_n, mask=offs_n < N, other=0.0)
    acc = acc + bias[None, :]

    # GELU approximation (tanh form), fp32
    c = 0.7978845608  # sqrt(2/pi)
    acc_g = 0.5 * acc * (1.0 + tl.math.tanh(c * (acc + 0.044715 * acc * acc * acc)))

    residual = tl.load(
        residual_ptr + offs_m[:, None] * srm + offs_n[None, :] * srn,
        mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0,
    )
    y = acc_g + residual.to(tl.float32)

    tl.store(
        y_ptr + offs_m[:, None] * sym + offs_n[None, :] * syn,
        y.to(y_ptr.dtype.element_ty),
        mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
    )

HBM traffic accounting:

  • Unfused (4 kernels: matmul, bias-add, GELU, residual-add): the matmul reads |x| + |W| and writes |y|; each elementwise kernel then reads and writes a |y|-sized tensor (plus bias and the residual). Total ≈ |x| + |W| + |bias| + 8|y|.
  • Fused: |x| + |W| + |bias| + |residual| + |y| ≈ |x| + |W| + 2|y|.
  • Savings: ~6|y| HBM bytes per call.

9.2 Kernel 2 - Fused RMSNorm with online statistics

Two-pass RMSNorm requires re-reading x. One-pass uses Welford-style online updates. For RMSNorm specifically, since we only need the sum of squares (not variance), the update is simply additive:

@triton.jit
def rmsnorm_fwd_fused(
    x_ptr, gamma_ptr, y_ptr,
    stride_xm, stride_xn,
    stride_ym, stride_yn,
    N, eps,
    BLOCK_N: tl.constexpr,
):
    # One program instance handles one row.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N

    x = tl.load(x_ptr + row * stride_xm + cols * stride_xn, mask=mask, other=0.0).to(tl.float32)
    sum_sq = tl.sum(x * x, axis=0)
    rrms = 1.0 / tl.sqrt(sum_sq / N + eps)
    gamma = tl.load(gamma_ptr + cols, mask=mask, other=0.0).to(tl.float32)

    y = (x * rrms * gamma).to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + row * stride_ym + cols * stride_yn, y, mask=mask)

If N > BLOCK_N (hidden dim larger than what fits in one tile), this becomes two-pass with shared-memory state. For modern transformer hidden dims (4096–16384), one-tile-per-row is feasible up to BLOCK_N=16384 on H100 (uses ~64 KiB of registers/SMEM).

Note the precision discipline: load in BF16, promote to FP32 for the reduction and the divide, store back in BF16. Chapter 11 §3.3 explains why this is mandatory - accumulating sum-of-squares in BF16 catastrophically loses precision past hidden ≈ 1024.

9.3 Kernel 3 - Fused softmax (causal-masked, for attention)

The streaming softmax kernel from chapter 03, with causal masking and tile-wise online normalization:

@triton.jit
def causal_softmax(
    s_ptr, o_ptr, stride_b, stride_h, stride_m, stride_n,
    M, N,
    BLOCK_N: tl.constexpr,
):
    pid_bh = tl.program_id(0)
    pid_m  = tl.program_id(1)
    # Process one query row at a time
    row = pid_m
    cols = tl.arange(0, BLOCK_N)
    base = pid_bh * stride_h + row * stride_m

    # Online softmax state
    m_i = -float('inf')
    l_i = 0.0
    # First pass: find max and partial sum
    for start in range(0, N, BLOCK_N):
        offs = start + cols
        mask = (offs < N) & (offs <= row)              # causal
        s = tl.load(s_ptr + base + offs * stride_n, mask=mask, other=-float('inf')).to(tl.float32)
        m_new = tl.maximum(m_i, tl.max(s, axis=0))
        l_i = l_i * tl.exp(m_i - m_new) + tl.sum(tl.exp(s - m_new), axis=0)
        m_i = m_new

    # Second pass: normalize and store
    for start in range(0, N, BLOCK_N):
        offs = start + cols
        mask = (offs < N) & (offs <= row)
        s = tl.load(s_ptr + base + offs * stride_n, mask=mask, other=-float('inf')).to(tl.float32)
        p = tl.exp(s - m_i) / l_i
        tl.store(o_ptr + base + offs * stride_n, p.to(o_ptr.dtype.element_ty), mask=mask)

This kernel is a building block; the full FlashAttention kernel goes one step further and fuses the matmul P @ V into the same loop, never materializing P at all. See chapter 07 for the full derivation; the punchline is that the inner loop interleaves Q @ Kᵀ tile computation, online softmax update, and P @ V accumulation - all in registers.


10. Precision discipline under fusion

Fusion makes precision choices more dangerous, not less, because:

  1. Intermediate values that used to be written to HBM in their materialized dtype are now stored in registers in whatever dtype the producer last computed in. A BF16 elementwise op that previously rounded its output to BF16 may now keep it as FP32 in registers, and the next op consumes FP32 - which may be silently better, but is also a behavioral change.
  2. Accumulators inside a fused reduction must be FP32, not the input dtype. Chapter 11 §3.2 derives this for reductions; the rule applies verbatim inside fused kernels.
  3. Epilogues on GEMMs typically receive FP32 accumulator tiles and downcast at the last step. If you insert a precision-sensitive operation (a divide, an exp, a log) in the epilogue, do it in FP32 before the downcast.

The discipline cheat sheet:

Operation                                Compute dtype inside the fused kernel               Why
Elementwise add/mul                      match input                                         no precision loss either way
Elementwise divide / sqrt / exp / log    FP32                                                nonlinear; small inputs lose precision in BF16
Reduction (sum, mean, dot, max)          FP32                                                catastrophic cancellation in BF16 past ~256 elements
Softmax                                  FP32 internally, BF16 output                        both the reduction and the exp need FP32
LayerNorm / RMSNorm                      FP32 statistics, BF16 output                        reduction + divide
GEMM accumulator                         FP32 (tensor cores do this for BF16/FP16 inputs)    hardware default
GELU / SiLU activation                   FP32 in the epilogue, else match input              tanh/exp inside

If your fused kernel diverges from the unfused reference past ~1e-3 in BF16, you almost certainly downcasted an accumulator too early.


11. The limits of fusion

Fusion is not free; it competes for finite GPU resources. Three hard constraints:

11.1 Register pressure

Each Triton/CUDA kernel uses some number of registers per thread. An H100 SM has 65,536 32-bit registers shared across active warps. Occupancy = active_warps / max_warps_per_SM. A fused kernel with deeper computation needs more registers per thread; past a threshold, occupancy collapses and HBM-fetch latency stops being hidden by warp switching.

The relationship is:

max_threads_per_SM = registers_per_SM / registers_per_thread

If your fused kernel uses 128 regs/thread, you get 65536 / 128 = 512 threads per SM: only 16 warps. If 16 warps are enough to hide latency, this is fine; if you need 32 or 64, you've over-fused.

Diagnostic: nvcc --ptxas-options=-v (CUDA) or Triton's autotune output reports registers per thread. Above 128, look hard.

11.2 Shared memory capacity

H100 has 228 KiB of SMEM per SM (configurable). Fused kernels often use SMEM to stage intermediate tiles. Past the capacity, you can't fit two concurrent thread blocks per SM, halving occupancy.

For matmul kernels with epilogues, SMEM is dominated by the A and B tiles: 2 * BM * BK * dtype_bytes + 2 * BK * BN * dtype_bytes (the 2 is for double-buffering). Epilogue logic is usually register-only.

11.3 The kernel-launch amortization plateau

For very small inputs (batch=1, seqlen=1 in decoding), kernel launch overhead (~5 µs) is comparable to kernel runtime. Fusion's value is huge - eliminating a launch saves more than the kernel itself takes. But for very large inputs, launch overhead is amortized to zero and fusion's only value is the HBM-traffic reduction.

The decoding regime (single-token autoregressive) is the most fusion-sensitive workload in AI infrastructure. Every inference engine in production fuses aggressively because of this.

11.4 Tile-shape mismatches

If op A is naturally tiled (64, 128) and op B (128, 64), you cannot fuse them in the obvious way - A's output tile doesn't match B's input tile. You either accept a transpose in registers (cheap if it fits) or accept the unfused cost. Compiler-driven fusion (XLA, Inductor) deals with this by not attempting fusions that require expensive layout changes; the cost model rejects them.


12. Profiling fused kernels with Nsight Compute

You have a fused kernel. Is it actually fast? Nsight Compute (ncu) is the answer.

The minimal workflow:

ncu --set full --kernel-name fused_linear_gelu_residual -o report ./my_app
ncu-ui report.ncu-rep   # open the GUI

The metrics that matter for fused kernels:

Metric                                                  Meaning                       Healthy value
sm__throughput.avg.pct_of_peak_sustained_elapsed        SM utilization                >70% for compute-bound; <30% expected when memory-bound
dram__throughput.avg.pct_of_peak_sustained_elapsed      HBM utilization               >70% for memory-bound (you want this)
l1tex__t_sectors_pipe_lsu_mem_global_op_ld.sum          Global loads in sectors       compare unfused vs fused; should drop
launch__registers_per_thread                            Register pressure             <128 typical; >196 alarming
launch__shared_mem_per_block                            SMEM use                      <96 KiB to allow 2 blocks/SM on H100 default
smsp__warps_eligible.avg.pct_of_peak_sustained_elapsed  Warp-scheduler utilization    >70% means latency is well hidden

The "Did fusion work?" test: profile the unfused chain and the fused kernel. Compare dram__bytes_read.sum + dram__bytes_write.sum. A successful fusion reduces this by approximately the predicted amount from §2.


13. When NOT to fuse

Three situations where fusion is the wrong call:

13.1 When the unfused kernels are already individually fast and the activations need to be saved for the backward pass

For training, the activation produced by an intermediate op is often needed by the backward pass. If you fuse the op with the next one, you must either (a) recompute the activation in the backward pass (the activation-checkpointing pattern), or (b) write the intermediate to HBM anyway, eliminating the fusion benefit.

PyTorch's torch.compile handles this with AOTAutograd partitioning - the forward and backward graphs are jointly optimized; the partitioner decides what to save vs recompute. For hand-rolled training kernels, this trade is explicit.
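A minimal sketch of option (a) as a hand-rolled autograd.Function (the forward here is plain PyTorch for readability; in a real fused kernel the same save-nothing/recompute structure applies):

```python
import torch

class GELUResidual(torch.autograd.Function):
    """Recompute-in-backward: save only the input; the GELU intermediate is
    never written to HBM in forward and is recomputed in backward."""

    @staticmethod
    def forward(ctx, x, residual):
        ctx.save_for_backward(x)
        return torch.nn.functional.gelu(x) + residual

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        with torch.enable_grad():               # recompute the forward piece
            x_ = x.detach().requires_grad_()
            y = torch.nn.functional.gelu(x_)
        (dx,) = torch.autograd.grad(y, x_, grad_out)
        return dx, grad_out                     # grads w.r.t. x and residual
```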

13.2 When fusion harms debuggability and the perf delta is small

A fused kernel that delivers 5% latency improvement but is 10× harder to debug, profile, and modify is a net loss for an actively evolving codebase. Save aggressive hand-fusion for the inner loop of mature, stable code paths.

13.3 When the fused kernel's autotune surface is too large

A fused matmul with 3 epilogue variants × 5 tile shapes × 4 num_warps × 3 num_stages = 180 autotune configurations. Each can take seconds to compile and benchmark. For one-off scripts, the autotune time exceeds the inference time saved. Production engines amortize this with a tuning cache (Triton's @triton.autotune does this automatically; the cache key is the tuple of key= argument values, typically the input shapes).
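A minimal illustration of that caching behavior (hypothetical kernel; the key= list is what determines the cache key):

```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({"BLOCK": 1024}, num_warps=4),
        triton.Config({"BLOCK": 4096}, num_warps=8),
    ],
    key=["n_elements"],   # benchmark once per distinct n_elements, then reuse
)
@triton.jit
def scale_add(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, 2 * x + y, mask=mask)
```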


14. Practical exercises

Exercise 1 - Quantify the win

Take a Llama-2-7B forward pass at batch=4, seqlen=2048, BF16. Compute, from first principles:

  • Total HBM bytes moved by the elementwise + normalization operations in one decoder layer, unfused (model each LayerNorm/RMSNorm as 3 kernels, each residual add as 1, each activation as 1).
  • The same, fully fused (each block executes RMSNorm and the SwiGLU pipeline as single fused kernels).
  • Estimated latency saving per layer on H100 (3 TB/s HBM).

Hint: hidden = 4096; intermediate = 11008; head_dim = 128; n_layers = 32. Show your work.
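A starter sketch for the bookkeeping (constants from the hint; the per-op kernel counts are the exercise):

```python
B, S, H, I = 4, 2048, 4096, 11008    # batch, seqlen, hidden, intermediate
BF16 = 2                             # bytes per element
tokens = B * S

def pass_bytes(width: int, n_passes: int) -> int:
    """HBM bytes for n_passes full read-or-write passes over a (tokens, width) activation."""
    return n_passes * tokens * width * BF16

# Example: one unfused residual add = read x + read y + write out = 3 passes at width H.
print(pass_bytes(H, 3) / 1e6, "MB")  # ~201 MB per residual add
```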

Exercise 2 - Implement and benchmark fused RMSNorm

Implement the single-pass RMSNorm kernel from §9.2 in Triton. Benchmark vs the PyTorch nn.RMSNorm equivalent at shapes (B, S, H) = (8, 4096, 4096) and (1, 1, 4096) (training vs decode). Report:

  • Throughput (TB/s of effective HBM bandwidth).
  • Numerical max-abs-error vs a reference computed in torch.float64.

Bonus: show that BF16-accumulator RMSNorm diverges from FP64 by >1e-2 at H=8192 and explain why.

Exercise 3 - GEMM epilogue fusion in CUTLASS

Pick one of: PyTorch's addmm (with the linear → add → activation pattern) or CUTLASS's LinearCombinationGELU epilogue example. Profile the fused vs unfused (separate matmul + bias + GELU) version at shape (M, N, K) = (8192, 8192, 8192) and report:

  • Latency difference.
  • HBM bytes moved (from ncu).
  • Justify the gap with the §2 cost model.

Exercise 4 - Find a fusion in Inductor's output

Write a small PyTorch function with a fusible chain (e.g., x.relu().mul(2).add(1).sigmoid()). Compile with torch.compile, set TORCH_LOGS="output_code", and inspect the generated Triton kernel. Confirm that all four ops appear in one kernel. Find one example in your own codebase (or a public model) where Inductor failed to fuse a chain you expected it to, and explain why (read the Inductor scheduler logs).
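A ready-to-run scaffold for the first half (run it with TORCH_LOGS="output_code" set to dump the generated kernel):

```python
import torch

def chain(x):
    return x.relu().mul(2).add(1).sigmoid()

compiled = torch.compile(chain)
x = torch.randn(4096, 4096, device="cuda")
compiled(x)   # the logged Triton source should contain all four ops in one kernel
```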

Exercise 5 - Implement FlashAttention v1 in Triton

(Stretch.) Working from chapter 07's algorithmic pseudocode and chapter 03's Triton tutorial, implement FlashAttention v1 (forward only, no causal mask). Benchmark vs torch.nn.functional.scaled_dot_product_attention (which dispatches to FlashAttention) at (B, H, S, D) = (4, 16, 4096, 128). You should be within 3× of the optimized version on first attempt; closing the gap requires deeper tile-shape and stage tuning.
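A baseline-side timing harness you can reuse (shapes from the exercise; CUDA-event timing is one reasonable choice, not part of the exercise statement):

```python
import torch
import torch.nn.functional as F

B, H, S, D = 4, 16, 4096, 128
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16) for _ in range(3))

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
for _ in range(10):                                   # warmup
    F.scaled_dot_product_attention(q, k, v)
torch.cuda.synchronize()
start.record()
for _ in range(100):
    F.scaled_dot_product_attention(q, k, v)
end.record()
torch.cuda.synchronize()
print(start.elapsed_time(end) / 100, "ms per call")   # target for your Triton kernel
```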

Exercise 6 - Precision regression hunt

Take the fused RMSNorm from exercise 2. Deliberately introduce a bug: compute the sum-of-squares in BF16 instead of FP32. Show numerically that:

  • The error vs FP64 grows with the hidden dimension.
  • The error grows roughly as O(sqrt(H)) in expectation - sublinear in H - as predicted by central-limit-theorem accumulation of independent rounding errors (worst case O(H) for naive sequential summation).
  • The error at H=8192 is large enough to perturb downstream logits past the temperature-sampling threshold for typical LLMs.

Connect to chapter 11 §3.2.
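A quick simulation to see the expected error-growth curve before touching the kernel (CPU-only sketch; the Python loop stands in for a naive sequential accumulator):

```python
import torch

torch.manual_seed(0)
for H in (1024, 2048, 4096, 8192):
    x = torch.randn(H, dtype=torch.float64)
    ref = (x * x).sum()                        # FP64 reference
    acc = torch.zeros((), dtype=torch.bfloat16)
    for v in x.to(torch.bfloat16):             # sequential BF16 sum-of-squares
        acc = acc + v * v
    rel_err = (abs(acc.double() - ref) / ref).item()
    print(H, f"{rel_err:.2e}")                 # error grows with H, roughly as sqrt(H)
```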


15. Cheat sheet and further reading

Cheat sheet

  • Fuse elementwise chains aggressively. Compilers (Inductor, XLA) do this for free; verify they did.
  • Fuse GEMM epilogues. bias + activation + residual belong in the matmul kernel. Use addmm, cuBLASLt, or CUTLASS.
  • Fuse QKV and gate/up projections. Always. If you see three separate matmuls for Q, K, V - that's a perf bug.
  • Fuse reductions with their producers/consumers (RMSNorm, softmax, top-k). Online algorithms (Welford, online softmax) make this single-pass.
  • Reserve hand-Triton for the algorithmically hard cases (FlashAttention, fused MoE, paged attention, fused quantized GEMM).
  • Keep FP32 accumulators inside fused kernels. Always. See chapter 11.
  • Profile with ncu and check that HBM traffic dropped by the predicted amount; if it didn't, fusion didn't happen.

Further reading

  • PyTorch Inductor docs - pytorch.org/docs/stable/torch.compiler_inductor.html. The scheduler and fusion-cost-model sections.
  • XLA fusion - openxla.org/xla/operation_semantics. The Fusion instruction and the fusion_kind enum.
  • CUTLASS epilogues - the cutlass/epilogue/ directory in the CUTLASS repo. Especially LinearCombinationGeneric.
  • FlashAttention papers - Dao et al., FlashAttention (2022); FlashAttention-2 (2023); FlashAttention-3 (2024). Each is a different fusion algorithm on the same operator.
  • Triton tutorials - triton-lang.org/main/getting-started/tutorials/. The fused-softmax and fused-attention tutorials are the canonical references.
  • Horace He, Making Deep Learning Go Brrrr From First Principles (blog). The clearest exposition of the bandwidth-bound argument that motivates all of this.
  • NVIDIA CUDA Best Practices Guide - the Memory Optimizations chapter. Foundational.

Chapter 13 (not yet written) will continue with custom autograd for fused kernels - how to register backward passes for the kernels you fuse, and how to compose them through torch.autograd.Function and register_autograd (chapter 04 §custom-ops). Until that chapter exists, the canonical reference is the FlashAttention repo's csrc/ directory.

Deep Dives: Self-Contained Reference Chapters

Twelve chapters that take the AI Systems curriculum from "moderate-depth survey + external paper assignments" to self-contained mastery resources. Each chapter was authored to let a reader master the topic from the document alone, without needing the underlying papers, vendor whitepapers, or framework docs as primary sources.

Total: ~104,000 words / ~14,500 lines across 12 files. Each chapter is 7,000–11,000 words, layered (intuition → mechanism → math → numbers → diagrams → exercises), and ends with worked exercises.


Reading Order and Curriculum Mapping

The deep dives are designed to be read in tandem with the monthly modules. Recommended pairing:

| When | Read this deep dive | After the monthly module |
|---|---|---|
| Week 5–6 | 01_GPU_ARCHITECTURE.md | Month 2 §5 |
| Week 6–7 | 02_CUDA_PROGRAMMING.md | Month 2 §6–7 |
| Week 8 | 03_TRITON.md | Month 2 §8 |
| Week 9–10 | 04_PYTORCH_INTERNALS.md | Month 3 §9–10 |
| Week 11 | 05_JAX_XLA.md | Month 3 §11 |
| Week 13–16 | 06_DISTRIBUTED_TRAINING.md | Month 4 (all) |
| Week 17 | 07_ATTENTION_TRANSFORMER.md | Month 5 §17 |
| Week 18 | 08_INFERENCE_SERVING.md | Month 5 §18 |
| Week 19 | 09_QUANTIZATION.md | Month 5 §19 |
| Week 20 | 10_SPECULATIVE_DISAGGREGATION.md | Month 5 §20 |
| Week 16 + always | 11_NUMERICS_AND_MIXED_PRECISION.md | Month 4 §16 (referenced everywhere) |
| Week 8 + always | 12_KERNEL_FUSION.md | Month 2 §8 (referenced from Months 3 and 5 too) |

You can also read the deep dives standalone as a reference text. Topical order:

  • Hardware foundation: 01 → 02 → 03
  • Framework foundation: 04 → 05
  • Numerical foundation: 11 (orthogonal to all others; reference often)
  • Training: 06 (which assumes 11)
  • Inference architecture: 07 → 08 → 09 → 10

Chapter Index

`01_GPU_ARCHITECTURE.md` - NVIDIA GPU Architecture and Memory Hierarchy

~9,100 words. Throughput vs latency machines; SIMT/SIMD/MIMD; the streaming multiprocessor (warps, schedulers, registers, divergence, ITS); the full memory hierarchy with H100 numbers; tensor cores (WMMA, mma.sync, fragments, all precisions including FP8); 2:4 sparsity; cp.async + TMA + thread-block clusters; occupancy theory derived from first principles with three worked numerical examples; NVLink/NVSwitch/NVL72; Ada and Blackwell deltas (with explicit uncertainty); AMD MI300X contrast; 5 worked exercises.

`02_CUDA_PROGRAMMING.md` - CUDA From First Kernel to Optimized GEMM

~8,800 words. Programming model, qualifiers, launch syntax; indexing and grid-stride loops; memory transfer (sync/async, pinned, UVM, zero-copy); streams and events with overlap pipelines; error-handling discipline; coalescing rules with worked numerical examples; shared memory bank conflicts and the [32][33] fix derived; reductions (four variants, perf evolution); six-stage tiled GEMM walkthrough with code (naive → coalesced → SMEM tiled → register tiled → tensor-core wmma → cp.async double-buffered); nvcuda::wmma and mma.sync PTX; cooperative groups + Hopper TBC; profiling discipline; complete buildable BF16 GEMM at 2048×2048; 6 exercises.

`03_TRITON.md` - The Triton GPU DSL

~7,000 words. Why Triton (block-level programming model); @triton.jit, program instances, tl.load/store/dot; mask semantics; tl.constexpr specialization; @triton.autotune configs and caching; online softmax derivation (single-element and block forms), Welford, log-sum-exp; six full annotated kernels: vector add, naive matmul, autotuned tiled matmul with L2-friendly swizzle, fused softmax, RMSNorm forward+backward, simplified causal flash-attention; compilation pipeline (Python → MLIR → PTX); torch integration via torch.library.custom_op; nine concrete pitfalls; Triton vs CUTLASS vs hand-CUDA decision table; 6 exercises.

`04_PYTORCH_INTERNALS.md` - PyTorch From Tensor to Inductor

~8,400 words. Layered architecture with ASCII trace of a + b; torch.Tensorat::Tensorc10::TensorImplc10::Storage; strides and views; the dispatcher (DispatchKey priority, key sets, TORCH_LIBRARY_IMPL); native_functions.yaml codegen; autograd engine (dynamic tape, Function/Node, next_edges, custom autograd.Function, version counters); requires_grad vs no_grad vs inference_mode; autocast as a dispatcher layer; torch.compile end-to-end (TorchDynamo, AOTAutograd, Inductor, guards, modes, TORCH_LOGS); modern custom op path (@torch.library.custom_op, register_fake, register_autograd) with Triton example; C++ extension skeleton; CUDA caching allocator with free-list algorithm and stream-aware reuse; profiler internals; 6 exercises.

`05_JAX_XLA.md` - JAX Transformations and the XLA Compiler

~8,200 words. Why JAX (functional, composable, XLA-default, TPU-first); pure functions as the unit of compilation; PyTrees and tree_util; stateless PRNGs (PRNGKey, split, the three rules); tracing and jaxprs with worked annotated example; jit cache keys, recompilation costs, AOT lowering; grad/vjp/jvp/higher-order/jacrev/custom VJPs; vmap semantics; legacy pmap vs unified jit + Mesh/PartitionSpec; shard_map; structured loops (scan, while_loop); XLA HLO with op table, full pipeline (jaxpr → StableHLO → HLO → device), fusion, layout, GSPMD with worked Megatron-MLP propagation; TPU vs GPU; Equinox/Flax/Optax; pallas; 6 exercises.

`06_DISTRIBUTED_TRAINING.md` - Communication, Parallelism, and Schedule Math

~9,400 words. Memory decomposition (16Φ accounting); collective primitive definitions; derivation of all 5 all-reduce algorithms including ring's bandwidth-optimality proof; NCCL algorithm selection; DDP with bucketing/overlap; full ZeRO-1/2/3 memory math table; FSDP (wrapping, prefetch, mixed precision, checkpointing, CPU offload, FSDP2); Megatron column-/row-parallel derivations and the column→row chain trick; attention/MLP TP layouts; pipeline schedules with ASCII diagrams (naive, GPipe, 1F1B, interleaved 1F1B, Zero Bubble) with bubble formulas; 3D parallelism decision matrix with worked examples for 8B/70B/405B; sequence parallelism; FP16 loss scaling derivation; FP8 (E4M3/E5M2 + per-tensor); profiling for overlap; fault tolerance; cluster topology; 6 exercises.

`07_ATTENTION_TRANSFORMER.md` - Transformer Math and FlashAttention

~8,600 words. Autoregressive setup; scaled dot-product attention with full √dₖ derivation from variance argument; multi-head; causal masking; MQA/GQA with worked Llama-3 group sizes; RoPE complex-number derivation showing <q'_m, k'_n> depends only on m-n; ALiBi/sliding window/YaRN/NTK/PI; pre-norm vs post-norm; RMSNorm vs LayerNorm; SwiGLU and the 8/3 d_ff ratio; KV-cache math with Llama-3-70B worked example (~2.5 GB at 8K with GQA); O(S²) cost analysis; full FlashAttention derivation of online softmax with inductive proof of equivalence to all-at-once softmax; tiled algorithm pseudocode; FA-2 and FA-3 deltas; flash_attn_with_kvcache; 6 exercises.

`08_INFERENCE_SERVING.md` - Paged Attention, Continuous Batching, vLLM

~8,200 words. Cost-model derivation T_step ≈ (W + b·KV·S) / B_HBM; H100 arithmetic-intensity crossover (~295 FLOP/byte); decode-batch-1 sits at ~1 FLOP/byte; PagedAttention with block pool sizing, page tables, block manager pseudocode, ASCII block-table diagram, fragmentation analysis; Orca-style continuous batching with full scheduler pseudocode; vLLM architecture and engine main loop; chunked prefill (Sarathi-Serve); eviction strategies (swap vs recompute); prefix caching via cumulative content hashing; speculative decoding preview; DistServe/Mooncake/Splitwise disaggregation with KV-transfer engineering; TTFT/TPOT/throughput/goodput SLOs; tuning levers with rules of thumb; 6 exercises.

`09_QUANTIZATION.md` - Number Formats, AWQ, GPTQ, SmoothQuant, FP8

~10,600 words. Why quantize (arithmetic-intensity argument); number formats (FP32/FP16/BF16/FP8 E4M3/E5M2/INT8/INT4) with bias, range, precision derivations; affine quantization with full derivation of scale/zero_point; granularity (per-tensor/per-channel/per-group) and the 4.13–4.25 effective-bits computation; outlier theory with Var(e_y) = σ_w² · ‖x‖²; AWQ full derivation of W = (W·diag(s))(diag(s)⁻¹·x) identity with worked numerical example; GPTQ derived from Optimal Brain Surgeon with H = 2·X·X^T, Cholesky efficiency, lazy-batch blocks, full pseudocode; SmoothQuant with α derivation; activation quantization (static/dynamic/per-token); FP8 inference (H100 hardware, TransformerEngine, delayed scaling); Marlin kernel; mixed precision (LLM.int8()); calibration; evaluation discipline; 6 exercises.

`10_SPECULATIVE_DISAGGREGATION.md` - Speculative Decoding & Disaggregated Inference

~8,900 words. Latency framing (TTFT vs TPOT); speculative decoding with rejection-sampling correctness proof; speedup formula S = α / (1 + K · T_draft / T_target) derived; geometric model α = (1 − β^{K+1}) / (1 − β) with worked numbers; variants (vanilla, self-speculative, Medusa, EAGLE/EAGLE-2, lookahead); tree speculation with attention-mask construction; speculation-batching tension showing speculation hurts at saturated batch; engineering (dual KV-caches, rollback, pipelining); DistServe architecture; Splitwise heterogeneous-hardware angle; Mooncake KVCache-centric design; KV-transfer sizing and overlap; full production stack composition with attribution; frontier directions explicitly marked research-stage; 6 exercises.

`11_NUMERICS_AND_MIXED_PRECISION.md` - Floating Point and Training Stability

~8,600 words. IEEE-754 derivation including subnormals, RNE, FMA, fl(a op b) = (a op b)(1+δ) model; full bit layouts and side-by-side range/precision table for FP64/FP32/TF32/FP16/BF16/FP8 E4M3/E5M2/FP4; per-operation precision (matmul accumulator rule, reductions, softmax overflow); standard mixed-precision recipe with full pseudocode; loss scaling derivation including dynamic GradScaler algorithm; BF16 advantages; FP8 in detail with delayed scaling, amax history, full pseudocode, worked numerical example; TF32; Adam + low precision pitfall with stochastic rounding; catastrophic cancellation (naive vs pairwise vs Kahan); transformer stability tricks (stable softmax, online softmax recurrence, √dₖ derivation, logit soft-cap, z-loss); NaN handling; determinism; 6 exercises.

`12_KERNEL_FUSION.md` - Kernel Fusion: Theory, Practice, and the Compilers That Do It For You

~8,000 words. The HBM round-trip cost model with worked Llama-3-70B numbers; fusion taxonomy (vertical/horizontal, five patterns); vertical fusion derived with RMSNorm worked example; horizontal fusion with QKV-projection and SwiGLU gate_up fusion math; GEMM epilogue fusion including the "fuse the residual into the matmul" trick; streaming-reduction fusion as the FlashAttention pattern; compiler-driven fusion in XLA, TorchInductor, Triton with production stack composition; three full Triton kernels (fused linear-GELU-residual, single-pass RMSNorm, causal masked online softmax); precision discipline under fusion (cheat-sheet table); register-pressure / SMEM / launch-amortization / tile-mismatch limits with H100 numbers; Nsight Compute metrics for verifying fusion worked; when NOT to fuse (training activations, debug cost, autotune-time inflation); 6 exercises.


Anti-Fabrication Discipline

Each chapter was authored under explicit anti-fabrication rules. Numbers cited fall into four classes:

  • Hardware constants (warp = 32 threads, H100 = 132 SMs, 80 GB HBM3, ~3 TB/s HBM3 BW, ~989 TFLOPS BF16 dense): unhedged.
  • Algorithm complexities (ring all-reduce bandwidth, FlashAttention HBM access, GPipe bubble): derived in the text.
  • Approximate / illustrative numbers (per-layer perf factors, hit rates, speedups in real systems): explicitly hedged with "~" or "approximate" or "verify with vendor docs."
  • Research-stage techniques (multi-token prediction, diffusion LMs, FP4 production deployments): flagged as such.

Layered Pedagogy

Every chapter follows the same shape:

  1. Why the topic exists: what problem it solves.
  2. Mental model: the right way to think about it.
  3. Mechanism: how it actually works, step by step.
  4. Math: derivations, not assertions.
  5. Numbers: concrete worked examples with specific hardware/model/shape choices.
  6. Diagrams: ASCII where they clarify.
  7. Code: for engineering chapters, runnable kernels and pseudocode.
  8. Pitfalls: the things you'll get wrong on first attempt.
  9. Exercises: six per chapter, with worked answers.

This is the mastery layout: anyone who reads a chapter end-to-end and completes the exercises has internalized the topic at a working-engineer level.


How to Use This Resource

As a curriculum companion: read the monthly module, then the matching deep dive, then return to the lab with both as references.

As a reference text: keep it open in a tab while you work; jump to the relevant section by topic.

As interview prep: each chapter's exercise section is approximately the depth of a senior-level systems-engineering interview. If you can solve the exercises cold, you can answer the interview question.

As a teaching resource: each chapter is a self-contained lecture's worth of material. Use it to onboard a new engineer to a sub-topic in a single afternoon.