Week 14 - Data Parallelism: DDP, ZeRO, FSDP

14.1 Conceptual Core

  • DDP (Distributed Data Parallel): every rank holds the full model. Each step, ranks process different micro-batches; gradients are allreduced before the optimizer step. Simple but memory-inefficient: model, gradients, and optimizer states are replicated N times.
  • ZeRO (Zero Redundancy Optimizer; Rajbhandari et al., SC 2020): starts from the observation that DDP wastes memory by replicating optimizer states: Adam's FP32 momentum and variance alone cost 8 bytes per parameter (four times the FP16 weights), and 12 bytes with the FP32 master copy. Three stages (see the memory sketch after this list):
  • ZeRO-1: shard optimizer states across ranks, cutting optimizer-state memory by the data-parallel degree N.
  • ZeRO-2: also shard gradients.
  • ZeRO-3: also shard parameters. Fetches them just-in-time per layer via allgather; reduces gradients via reduce-scatter.
  • FSDP (Fully Sharded Data Parallel; PyTorch's implementation of ZeRO-3): the modern standard for training models that don't fit in single-GPU memory under DDP.
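
To make the stages concrete, here is a back-of-the-envelope sketch of per-GPU memory for the model states, following the mixed-precision accounting in the ZeRO paper (2 bytes for FP16 params, 2 for FP16 grads, K = 12 for Adam's FP32 master weights, momentum, and variance); the 1B-parameter / 8-rank numbers below are illustrative.

```python
def zero_memory_gb(num_params: float, n_ranks: int) -> dict:
    """Per-GPU memory (GB) for model states under each ZeRO stage.

    Mixed-precision accounting from the ZeRO paper: FP16 params (2 B),
    FP16 grads (2 B), Adam states K = 12 B (FP32 master weights +
    momentum + variance) per parameter.
    """
    P, G, K = 2, 2, 12  # bytes per parameter
    GB = 1024 ** 3
    return {
        "DDP (all replicated)":    (P + G + K) * num_params / GB,
        "ZeRO-1 (opt sharded)":    (P + G + K / n_ranks) * num_params / GB,
        "ZeRO-2 (+grad sharded)":  (P + (G + K) / n_ranks) * num_params / GB,
        "ZeRO-3 (+param sharded)": (P + G + K) / n_ranks * num_params / GB,
    }

# Example: a 1B-parameter model on 8 ranks.
for stage, gb in zero_memory_gb(1e9, 8).items():
    print(f"{stage:26s} {gb:6.2f} GB/GPU")
```

Running this shows the progression from ~14.9 GB/GPU under plain replication down to ~1.9 GB/GPU under full ZeRO-3 sharding, before activations and buffers.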

14.2 Mechanical Detail

  • FSDP wraps modules. Each FSDP unit (typically one transformer layer) shards its parameters across all ranks. The forward pass, for each unit i:
    1. Allgather unit i's parameters (just before its forward).
    2. Compute unit i's forward.
    3. Free the non-local parameters.
    4. Move to unit i+1.
  • Backward is symmetric, with reduce-scatter for gradients.
  • Mixed precision in FSDP: parameters cast to BF16 for compute, gradients reduced in FP32 for stability, optimizer states kept in FP32. Configurable via the MixedPrecision policy (see the first sketch after this list).
  • Activation checkpointing: instead of storing all activations for the backward pass, store them only at unit boundaries and recompute the forward inside each unit when needed. Trades ~30% extra compute for ~3× less activation memory. Essential for long-context LLM training (see the second sketch after this list).
  • CPU offload: optimizer states or even parameters can spill to CPU RAM (slower, but allows larger models). Used in low-memory regimes; avoid it in production HPC.
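
A minimal sketch of the wrapping and mixed-precision setup described above, using the FSDP1 API from torch.distributed.fsdp. TransformerBlock here is a hypothetical stand-in for your model's layer class (it goes in transformer_layer_cls), and the snippet assumes the process group is already initialized, e.g. via torchrun.

```python
import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


# Hypothetical transformer layer; substitute your real block class.
class TransformerBlock(nn.Module):
    def __init__(self, d_model: int = 1024):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads=16, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.GELU(),
                                 nn.Linear(4 * d_model, d_model))

    def forward(self, x):
        x = x + self.attn(x, x, x, need_weights=False)[0]
        return x + self.mlp(x)


model = nn.Sequential(*[TransformerBlock() for _ in range(24)])

# Each TransformerBlock becomes one FSDP unit: its parameters are
# allgathered just before its forward/backward and freed right after.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={TransformerBlock},
)

# BF16 compute, FP32 gradient reduction; the optimizer steps on the
# FP32 sharded parameters, so optimizer states stay in FP32.
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # params cast to BF16 for fwd/bwd compute
    reduce_dtype=torch.float32,   # gradients reduce-scattered in FP32
    buffer_dtype=torch.bfloat16,
)

# Assumes torch.distributed is already initialized (e.g. via torchrun).
fsdp_model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    mixed_precision=mp_policy,
    device_id=torch.cuda.current_device(),
)
```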
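And a sketch of the two memory levers, layered on the fsdp_model above. Note that checkpoint_wrapper and apply_activation_checkpointing live in a private module (torch.distributed.algorithms._checkpoint.checkpoint_wrapper) whose path has been stable across recent releases but is not a public API guarantee.

```python
import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

# Activation checkpointing: store activations only at block boundaries,
# recompute inside each block during backward.
apply_activation_checkpointing(
    fsdp_model,
    checkpoint_wrapper_fn=functools.partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    ),
    check_fn=lambda m: isinstance(m, TransformerBlock),
)

# CPU offload is chosen at wrap time, not applied afterwards:
#   FSDP(model, ..., cpu_offload=CPUOffload(offload_params=True))
# It keeps sharded params (and their grads) in host RAM, paying PCIe
# traffic on every allgather; expect a substantial throughput hit.
```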

14.3 Lab: "FSDP a Small Model"

On 4-8 GPUs (a single node is fine):

  1. Train a 1B-parameter transformer under FSDP, using transformer_auto_wrap_policy.
  2. Compare memory and throughput across three setups: DDP on a model small enough not to OOM (the baseline), FSDP on that same small model, and FSDP on the larger model DDP cannot fit.
  3. Add activation checkpointing; re-measure.
  4. Add CPU offload; observe the speed cost.
  5. Compute scaling efficiency: throughput_8gpu / (8 × throughput_1gpu). A helper for steps 2 and 5 is sketched below.
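
A small measurement helper for steps 2 and 5; the batch iterable, loss_fn, and tokens_per_batch bookkeeping are placeholders for your actual training loop.

```python
import time

import torch


def measure_throughput(model, optimizer, loss_fn, batches, tokens_per_batch,
                       warmup: int = 5):
    """Return (tokens/sec, peak GB) over a short training run on this rank."""
    torch.cuda.reset_peak_memory_stats()
    start, tokens = None, 0
    for step, (x, y) in enumerate(batches):
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if step == warmup:            # start timing only after warmup steps
            torch.cuda.synchronize()
            start = time.perf_counter()
        elif step > warmup:
            tokens += tokens_per_batch
    torch.cuda.synchronize()
    tokens_per_sec = tokens / (time.perf_counter() - start)
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    return tokens_per_sec, peak_gb


def scaling_efficiency(tput_8gpu: float, tput_1gpu: float) -> float:
    """Step 5: 1.0 means perfect linear scaling across 8 GPUs."""
    return tput_8gpu / (8 * tput_1gpu)
```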

14.4 Idiomatic & Diagnostic Drill

  • Run torch.profiler with with_stack=True on an FSDP step. Identify the allgather and reduce-scatter calls and measure their fraction of total step time (sketch below).
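
A sketch of the drill, assuming fsdp_model, optimizer, and an input batch x from the lab above. Exact collective kernel names vary across PyTorch/NCCL versions, so the filter matches substrings rather than exact names.

```python
from torch.profiler import ProfilerActivity, profile

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    with_stack=True,   # attribute kernels to their Python call sites
) as prof:
    loss = fsdp_model(x).sum()   # stand-in loss for one training step
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

events = prof.key_averages()
# Rough fraction only: aggregated parent ops can double-count children.
total = sum(e.cuda_time_total for e in events)
comm = sum(e.cuda_time_total for e in events
           if "AllGather" in e.key or "ReduceScatter" in e.key)
print(events.table(sort_by="cuda_time_total", row_limit=20))
print(f"collectives: {100 * comm / total:.1f}% of CUDA time")
```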

14.5 Production Slice

  • FSDP's BackwardPrefetch.BACKWARD_PRE overlaps the current unit's backward compute with the parameter allgather for the next unit in backward order. Verify it is enabled; without it, large models leave 20-30% of performance on the table (see the sketch below).
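
Enabling it is a constructor argument; in recent PyTorch releases BACKWARD_PRE is already the default, so passing it explicitly mostly guards against it being disabled elsewhere. wrap_policy and mp_policy are from the 14.2 sketch.

```python
import torch
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

fsdp_model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,   # from the 14.2 sketch
    mixed_precision=mp_policy,      # from the 14.2 sketch
    # Prefetch the next unit's parameters (in backward order) while the
    # current unit's gradients are still being computed, hiding allgather
    # latency behind backward compute.
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    device_id=torch.cuda.current_device(),
)
```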
