Week 14 - Data Parallelism: DDP, ZeRO, FSDP¶
14.1 Conceptual Core¶
- DDP (Distributed Data Parallel): every rank holds the full model. Each step, ranks process different micro-batches; gradients are allreduced before optimizer step. Simple, memory-inefficient (model + optimizer states replicated N times).
- ZeRO (Zero Redundancy Optimizer; Rajbhandari et al., SC 2020): observation that DDP wastes memory by replicating optimizer states. For mixed-precision Adam, the FP32 master weights, momentum, and variance take 12 bytes per parameter, roughly 6× the 2 bytes of the BF16/FP16 weights (see the back-of-the-envelope sketch after this list). Three stages:
- ZeRO-1: shard optimizer states across ranks. Saves N× on optimizer memory.
- ZeRO-2: also shard gradients.
- ZeRO-3: also shard parameters. Fetches them just-in-time per layer via allgather; reduces gradients via reduce-scatter.
- FSDP (Fully Sharded Data Parallel; PyTorch's implementation of ZeRO-3): the modern standard for training models that don't fit in single-GPU memory under DDP.
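A back-of-the-envelope way to see what each stage buys, using the ZeRO paper's mixed-precision Adam accounting (2 B/param BF16/FP16 weights, 2 B/param gradients, 12 B/param FP32 master weights + momentum + variance). Activations and temporary buffers are ignored, and the parameter count and rank count below are illustrative only:

```python
# Rough per-rank memory under DDP vs. the three ZeRO stages.
# Assumes mixed-precision Adam: 2 B/param weights, 2 B/param grads,
# 12 B/param optimizer state (FP32 master + momentum + variance).
def per_rank_gib(n_params: float, n_ranks: int, stage: str) -> float:
    weights, grads, opt = 2 * n_params, 2 * n_params, 12 * n_params  # bytes
    if stage == "ddp":            # everything replicated on every rank
        total = weights + grads + opt
    elif stage == "zero1":        # optimizer states sharded
        total = weights + grads + opt / n_ranks
    elif stage == "zero2":        # + gradients sharded
        total = weights + grads / n_ranks + opt / n_ranks
    elif stage == "zero3":        # + parameters sharded (FSDP FULL_SHARD)
        total = (weights + grads + opt) / n_ranks
    else:
        raise ValueError(stage)
    return total / 2**30

for stage in ("ddp", "zero1", "zero2", "zero3"):
    print(stage, f"{per_rank_gib(7e9, 8, stage):.1f} GiB")  # e.g. 7B params, 8 ranks
```

Under these assumptions a 7B-parameter model on 8 ranks goes from roughly 104 GiB per rank under DDP to about 13 GiB under ZeRO-3, which is the whole point of sharding.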
14.2 Mechanical Detail¶
- FSDP wraps modules. Each FSDP unit (typically a transformer layer) shards parameters across all ranks. Forward:
- Allgather parameters of unit i (just before its forward).
- Compute forward of unit i.
- Free non-local parameters.
- Move to unit i+1.
- Backward is symmetric, with reduce-scatter for gradients (a toy sketch of the per-unit flow follows this list).
- Mixed precision in FSDP: parameters in BF16 for compute, gradient reduction in FP32 for stability, optimizer states in FP32. Configurable via MixedPrecision (see the configuration sketch after this list).
- Activation checkpointing: instead of storing all activations for backward, store only at unit boundaries and recompute the forward when needed. Trades ~30% extra compute for ~3× less activation memory. Essential for long-context LLM training.
- CPU offload: optimizer states or even parameters can spill to CPU RAM (slower but allows larger models). Used in low-memory regimes; avoid in production HPC.
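To make the per-unit flow concrete, here is a toy, hand-rolled ZeRO-3-style forward over a stack of linear "units" built from raw torch.distributed collectives. This is not FSDP's internals, only the gather / compute / free pattern; it assumes a process group is already initialized, that each rank holds an equal flat shard of every unit's weight, and that all_gather_into_tensor is available (recent PyTorch):

```python
import torch
import torch.distributed as dist

def sharded_forward(x, shards, shapes):
    """Toy ZeRO-3 forward. shards[i]: this rank's equal slice of unit i's
    flattened weight; shapes[i]: the full weight's torch.Size."""
    world = dist.get_world_size()
    for shard, shape in zip(shards, shapes):
        # 1. Allgather the full flat parameter for this unit, just in time.
        full_flat = torch.empty(shard.numel() * world, dtype=shard.dtype,
                                device=shard.device)
        dist.all_gather_into_tensor(full_flat, shard)
        w = full_flat[: shape.numel()].view(shape)  # drop any padding, reshape
        # 2. Run this unit's forward with the materialized weight.
        x = torch.relu(x @ w.t())
        # 3. Drop the gathered copy; only the local shard stays resident.
        del full_flat, w
    return x
```

The backward direction mirrors this: recompute or reuse the gathered weight, compute local gradients, then reduce-scatter them (dist.reduce_scatter_tensor) so each rank keeps only its gradient shard.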
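A configuration sketch for the mixed-precision, activation-checkpointing, and CPU-offload bullets above, using the public torch.distributed.fsdp API in PyTorch 2.x. The tiny TransformerEncoder is a stand-in model, the dtypes are the common BF16-compute / FP32-reduce choice rather than the only valid one, and the activation-checkpointing helper's import path under torch.distributed.algorithms._checkpoint is version-dependent. A process group is assumed to be initialized as in the lab below:

```python
import functools
import torch
from torch.distributed.fsdp import (
    CPUOffload, FullyShardedDataParallel as FSDP, MixedPrecision,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper,
)

# Stand-in model: block class = TransformerEncoderLayer.
model = torch.nn.TransformerEncoder(
    torch.nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True),
    num_layers=6,
)

# BF16 compute, FP32 gradient reduction; optimizer states stay FP32 because
# the optimizer updates the (sharded) FP32 parameters.
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.bfloat16,
)

fsdp_model = FSDP(
    model,
    mixed_precision=mp_policy,
    cpu_offload=CPUOffload(offload_params=True),  # only for low-memory regimes
    device_id=torch.cuda.current_device(),
)

# Activation checkpointing at unit (transformer-block) boundaries.
apply_activation_checkpointing(
    fsdp_model,
    checkpoint_wrapper_fn=functools.partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    ),
    check_fn=lambda m: isinstance(m, torch.nn.TransformerEncoderLayer),
)
```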
14.3 Lab: "FSDP a Small Model"¶
On 4-8 GPUs (single node fine):
1. Train a 1B-parameter transformer in FSDP. Use transformer_auto_wrap_policy (a setup sketch follows this list).
2. Compare memory and throughput across three configurations: DDP on a model small enough not to OOM (the baseline), FSDP on that same small model, and FSDP on a scaled-up version that DDP cannot fit.
3. Add activation checkpointing; re-measure.
4. Add CPU offload; observe the speed cost.
5. Compute scaling efficiency (throughput_8gpu / (8 × throughput_1gpu)).
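A minimal sketch of step 1, assuming a torchrun launch (one process per GPU). The stock TransformerEncoder below is a stand-in that lands near 1B parameters with these sizes; data loading, the loss, and the training-loop body are elided:

```python
import functools, os
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    BackwardPrefetch, FullyShardedDataParallel as FSDP,
    MixedPrecision, ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Stand-in ~1B-parameter model; swap in your own model and block class.
block_cls = torch.nn.TransformerEncoderLayer
model = torch.nn.TransformerEncoder(
    block_cls(d_model=2048, nhead=16, dim_feedforward=8192, batch_first=True),
    num_layers=24,
)

fsdp_model = FSDP(
    model,
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls={block_cls}
    ),
    sharding_strategy=ShardingStrategy.FULL_SHARD,     # ZeRO-3-style sharding
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,   # see 14.5
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    ),
    device_id=torch.cuda.current_device(),
)
optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=3e-4)

# ... training loop: forward, loss, backward, optimizer.step() ...

# For step 2, record per-rank peak memory after a few steps:
peak_gib = torch.cuda.max_memory_allocated() / 2**30
if dist.get_rank() == 0:
    print(f"peak allocated per rank: {peak_gib:.1f} GiB")
```

Launch with, e.g., torchrun --nproc_per_node=8 train.py; the scaling-efficiency number in step 5 comes from running the same script with 1 and 8 processes and comparing tokens/s.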
14.4 Idiomatic & Diagnostic Drill¶
Run torch.profiler with with_stack=True on an FSDP step. Identify the allgather and reduce-scatter calls; measure their fraction of step time.
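A sketch of the drill, reusing fsdp_model and optimizer from the lab sketch. Exact collective event names vary by PyTorch/NCCL version, so this just filters for anything mentioning all_gather or reduce_scatter and treats the resulting fraction as a rough estimate:

```python
import torch
from torch.profiler import ProfilerActivity, profile

batch = torch.randn(8, 128, 2048, device="cuda")  # matches the lab's d_model

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    with_stack=True,
) as prof:
    loss = fsdp_model(batch).float().mean()   # stand-in loss for illustration
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

events = prof.key_averages()
total = sum(e.self_cuda_time_total for e in events)
comm = sum(
    e.self_cuda_time_total
    for e in events
    if "all_gather" in e.key or "reduce_scatter" in e.key
)
print(f"collectives: {comm / total:.1%} of CUDA self time")
print(events.table(sort_by="self_cuda_time_total", row_limit=20))
```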
14.5 Production Slice¶
- FSDP's BackwardPrefetch.BACKWARD_PRE overlaps backward compute with the next layer's allgather. Verify it's enabled; without it, large models leave 20-30% of performance on the table.
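For reference, the prefetch mode is a constructor argument. BACKWARD_PRE is the default in recent PyTorch releases, but it is worth setting and checking explicitly; the sketch below reuses the model from the 14.3 lab:

```python
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP

fsdp_model = FSDP(
    model,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # prefetch next unit's allgather
    # ... other kwargs as in the 14.3 sketch ...
)
```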