Week 16 - Mixed Precision, FP8, Numerical Stability at Scale
16.1 Conceptual Core
- Modern training uses multiple precisions simultaneously:
- Compute (matmul) in BF16 / FP8.
- Master weights in FP32.
- Optimizer states in FP32 (Adam: momentum + variance).
- Gradients accumulated in FP32 to avoid loss-scale issues.
- FP8 training (Hopper / Blackwell): two formats, E4M3 (more mantissa bits, less range; used for activations and weights) and E5M2 (more range; used for gradients). NVIDIA's Transformer Engine library handles the casting (see the format sketch after this list).
- The challenges of low-precision training:
- Loss scaling (FP16): scale the loss by a power of 2 before backward to prevent gradient underflow; unscale before the optimizer step. Done automatically by `torch.cuda.amp.GradScaler` (see the loop sketch after this list).
- Per-tensor scaling (FP8): each tensor needs its own scale factor (derived from its max absolute value), recomputed every step. This is delicate; Transformer Engine handles it.
- Numerical stability: occasional NaNs from low-precision overflow. Detect and handle (skip step or reduce LR).
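To make the E4M3/E5M2 trade-off concrete: recent PyTorch builds expose both formats as dtypes, so you can inspect their limits directly (a quick check, assuming a PyTorch version that ships the FP8 dtypes):

```python
import torch

# E4M3: 4 exponent bits, 3 mantissa bits -> more precision, max ~448.
# E5M2: 5 exponent bits, 2 mantissa bits -> more range, max ~57344.
# (Requires a PyTorch build that includes the FP8 dtypes.)
print(torch.finfo(torch.float8_e4m3fn))  # the format used for activations/weights
print(torch.finfo(torch.float8_e5m2))    # the format used for gradients
```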
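And the FP16 loss-scaling loop in code, as a minimal sketch: the model, data, and hyperparameters are placeholders, but the scaler calls are the real `torch.cuda.amp` API.

```python
import torch
from torch.cuda.amp import autocast, GradScaler

# Toy model and data, only to make the loop runnable.
model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = GradScaler()  # maintains the dynamic loss scale

for step in range(100):
    x = torch.randn(32, 1024, device="cuda")
    optimizer.zero_grad(set_to_none=True)
    with autocast(dtype=torch.float16):        # FP16 compute region
        loss = model(x).float().pow(2).mean()  # dummy loss
    scaler.scale(loss).backward()  # scale loss before backward (underflow guard)
    scaler.unscale_(optimizer)     # unscale so clipping sees true gradient magnitudes
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)         # silently skips the step on inf/NaN grads
    scaler.update()                # shrinks scale after a skip, grows it slowly otherwise
```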
16.2 Mechanical Detail
- NVIDIA Transformer Engine (the `transformer-engine` Python package): drop-in replacement layers (`te.Linear`, `te.LayerNorm`, `te.TransformerLayer`) that use FP8 internally with automatic scaling.
- Activation memory dominates training memory at long contexts. With BF16 activations, a 32K-context sequence on an 8B-parameter model can need ~80 GB for activations alone. FP8 halves this; FlashAttention reduces it further.
- Gradient accumulation: accumulation steps × per-step batch = effective batch size. Memory scales with the per-step batch; convergence is governed by the effective batch (see the sketch after this list).
- Communication / compute overlap: with FSDP, you want the allgather for layer N+1 to overlap with the compute of layer N. With TP, you want one microbatch's backward allreduce to overlap with the next microbatch's forward. Profile to verify; without overlap, you're leaving 20-50% of throughput on the table.
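The gradient-accumulation arithmetic in code, as promised above (a minimal sketch with toy shapes: micro-batch 4 × 8 accumulation steps = effective batch 32):

```python
import torch

model = torch.nn.Linear(256, 256).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)
accum_steps = 8  # effective batch = accum_steps * micro-batch size

opt.zero_grad(set_to_none=True)
for micro_step in range(accum_steps * 10):       # 10 effective optimizer steps
    x = torch.randn(4, 256, device="cuda")       # only the micro-batch is resident
    loss = model(x).pow(2).mean() / accum_steps  # divide so grads average, not sum
    loss.backward()                              # grads accumulate into .grad
    if (micro_step + 1) % accum_steps == 0:
        opt.step()                               # one update per effective batch
        opt.zero_grad(set_to_none=True)
```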
16.3 Lab - "FP8 Train a Small Model"
On at least one H100/H200/B200 (you may need to rent for a day):
1. Take your week 14 FSDP setup. Replace all linear layers with `te.Linear` and wrap blocks with `te.fp8_autocast` (see the sketch at the end of this lab).
2. Train the same model in BF16 vs FP8. Compare:
- Throughput.
- Memory.
- Loss curve (the stability test: FP8 should match BF16 within noise).
3. Document any NaN events and recovery actions.
If H100+ is unavailable, do this lab in BF16 + `torch.cuda.amp`, comparing against FP32. The instability dynamics are similar at lower stakes.
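A minimal sketch of the step-1 pattern, assuming the `transformer_engine` package and an FP8-capable GPU; the recipe settings are illustrative, not tuned:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# HYBRID = E4M3 for forward tensors, E5M2 for gradients (matches 16.1).
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

# Drop-in replacement for torch.nn.Linear; dimensions are placeholders.
layer = te.Linear(4096, 4096, bias=True).cuda()
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)

# FP8 is active only inside the autocast region; per-tensor scales
# (amax history) are updated each step by Transformer Engine.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)
y.float().sum().backward()  # backward also runs in FP8 per the recipe
```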
16.4 Idiomatic & Diagnostic Drill
- Track per-tensor scale factors across training steps. Sudden scale-factor drops indicate impending NaN; sudden rises indicate underflow risk on the next step.
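Transformer Engine keeps its per-tensor scales in library internals, so for a self-contained illustration here is the same idea at the level of GradScaler's global loss scale, whose only downward moves come from inf/NaN steps (a sketch; model and data are toys):

```python
import torch
from torch.cuda.amp import autocast, GradScaler

model = torch.nn.Linear(512, 512).cuda()
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = GradScaler()
prev = scaler.get_scale()

for step in range(200):
    x = torch.randn(8, 512, device="cuda")
    opt.zero_grad(set_to_none=True)
    with autocast(dtype=torch.float16):
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
    cur = scaler.get_scale()
    if cur < prev:
        # GradScaler halves its scale only after detecting inf/NaN grads,
        # so a drop here means the step above was skipped. Log it.
        print(f"step {step}: loss scale {prev:.0f} -> {cur:.0f} (inf/NaN grads, step skipped)")
    prev = cur
```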
16.5 Production Slice
- Build a "training health" dashboard: loss, grad norm, parameter norm, scale factors, throughput, memory. Alert on grad-norm spikes (data quality issue or instability) and on plateaued loss without grad-norm decrease (saturated learning).
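A minimal metrics collector to feed such a dashboard (a sketch: the function name and the tokens/throughput bookkeeping are placeholders, and the logging backend, W&B, TensorBoard, or similar, is up to you):

```python
import time
import torch

def health_metrics(model, loss, step_start, tokens_in_step):
    """One training step's health snapshot; call after backward()."""
    grads = [p.grad.norm() for p in model.parameters() if p.grad is not None]
    grad_norm = torch.norm(torch.stack(grads)).item() if grads else 0.0
    param_norm = torch.norm(
        torch.stack([p.detach().norm() for p in model.parameters()])
    ).item()
    return {
        "loss": loss.item(),
        "grad_norm": grad_norm,    # alert on spikes: bad data or instability
        "param_norm": param_norm,  # slow drift is normal; jumps are not
        "tokens_per_s": tokens_in_step / (time.time() - step_start),
        "mem_gib": torch.cuda.max_memory_allocated() / 2**30,
    }
```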
Month 4 Capstone Deliverable
A `distributed-training/` directory:
1. `allreduce-bench/` (week 13): bandwidth measurements + topology doc.
2. `fsdp-scaling/` (week 14): scaling efficiency study.
3. `tp-attention/` (week 15): hand-rolled TP attention + benchmark.
4. `fp8-train/` (week 16): FP8 vs BF16 comparison.
A `PARALLELISM_GUIDE.md` decision matrix: given (model size, GPU count, NVLink topology, IB bandwidth), which 3D-parallelism degrees do you pick?
Recommended Reading Done This Month
- The ZeRO paper (SC 2020).
- The Megatron-LM paper.
- GPipe (NeurIPS 2019) and PipeDream (SOSP 2019) for pipeline parallelism.
- The NCCL design doc.
- Narayanan et al., Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM (SC 2021): the canonical 3D-parallelism paper.
- Micikevicius et al., FP8 Formats for Deep Learning (2022).