# Month 9-Week 1: DDP, FSDP, multi-GPU run

## Week summary
- Goal: Read foundational distributed-training papers (ZeRO, FSDP). Run a real multi-GPU FSDP training job. Internalize what scaling looks like.
- Time: ~10 h over 3 sessions.
- Output: Multi-GPU FSDP run with documented scaling efficiency; paper notes.
- Sequences relied on: 16-distributed-training rungs 01, 02, 05, 06, 10.
## Why this week matters
You will never pretrain a frontier model. You absolutely will: read papers that reference DDP/FSDP/ZeRO, work alongside ML researchers, debug scaling regressions, and decide whether to scale up or out. Concept depth + one real multi-GPU run is the right ratio.
## Prerequisites
- M08 complete.
- Budget for time on 2 GPUs (~$5–15 for one run on RunPod or Lambda Labs).
## Recommended cadence
- Session A-Tue/Wed evening (~3 h): memory math + DDP
- Session B-Sat morning (~4 h): ZeRO + FSDP papers
- Session C-Sun afternoon (~3 h): multi-GPU run
## Session A-Memory math + DDP
Goal: Compute training memory for a 7B model. Understand DDP and its bottleneck.
### Part 1-Transformer memory math (75 min)
For a model with N parameters in bf16, training memory:
- Weights: 2N bytes.
- Gradients: 2N bytes (same shape as weights).
- Optimizer state (AdamW): 8N bytes (fp32 momentum + variance).
- Activations: depends on batch × seq × layers; with checkpointing, much less.
For 7B params: 7×2 + 7×2 + 7×8 = ~84 GB before activations. Single A100 (80GB) is just barely insufficient without optimizer-state sharding or quantization.
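A quick way to sanity-check the arithmetic (a minimal sketch using the byte counts above; "GB" here means 10⁹ bytes, and activations are excluded):

```python
def training_memory_gb(n_params_billion: float) -> float:
    """Model-state memory for bf16 training with AdamW, excluding activations."""
    weights = 2 * n_params_billion  # bf16 weights: 2 bytes/param
    grads = 2 * n_params_billion    # bf16 gradients: 2 bytes/param
    optim = 8 * n_params_billion    # fp32 momentum + variance: 4 + 4 bytes/param
    return weights + grads + optim

print(training_memory_gb(7))   # 84.0 GB -- over a single 80 GB A100
print(training_memory_gb(13))  # 156.0 GB -- the 13B self-check number below
```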
Read: EleutherAI's Transformer Math 101 blog post (search). Or Stas Bekman's "How to fit larger models" guide.
### Part 2-DDP fundamentals (60 min)
DDP (DistributedDataParallel):
- Each GPU holds the full model.
- Each gets a different mini-batch.
- After backward, gradients are all-reduced across GPUs (averaged).
- All-reduce bandwidth is the bottleneck.
Read PyTorch DDP overview docs (search "pytorch ddp tutorial"). Plus the original DDP paper for context.
When DDP works:
- Model fits on one GPU.
- You want to train on more data faster.

When DDP doesn't:
- Model doesn't fit on one GPU. (You need ZeRO/FSDP.)
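To make the mental model concrete, here is a minimal DDP sketch (a toy model, not your SFT script; `ddp_min.py` is a hypothetical file name, launched with `torchrun --nproc_per_node=2 ddp_min.py`):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # torchrun sets RANK/LOCAL_RANK/WORLD_SIZE for each process.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(1024, 1024).cuda()   # toy stand-in for a real model
    model = DDP(model, device_ids=[local_rank])  # full copy on every GPU
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for _ in range(10):
        x = torch.randn(32, 1024, device="cuda")  # each rank sees different data
        loss = model(x).pow(2).mean()
        loss.backward()        # gradients all-reduced (averaged) across ranks here
        optimizer.step()
        optimizer.zero_grad()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```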
### Part 3-Self-check (45 min)
For a 13B model in bf16 with AdamW, on 4× A100 80GB:
- DDP: needs 156 GB per GPU → won't fit. ZeRO-3 or FSDP needed.
- Sharded across 4 GPUs: 156/4 ≈ 39 GB per GPU → fits, but tight once activations are added on top.
Predict before measuring.
### Output of Session A
- Memory math for 7B and 13B.
- DDP mental model.
## Session B-ZeRO + FSDP papers
Goal: Read ZeRO and FSDP papers. Understand what each shards.
### Part 1-ZeRO paper (90 min)
Read: ZeRO (arxiv.org/abs/1910.02054). Sections 1, 2, 3.
Three stages:
- ZeRO-1: shard optimizer state across GPUs. Saves ~4×.
- ZeRO-2: also shard gradients. Saves ~8×.
- ZeRO-3: also shard parameters. Per-GPU memory for model states scales as 1/N (N = number of GPUs), so savings grow with GPU count.
Tradeoff: each stage adds communication. ZeRO-3 has the most communication but the most memory savings.
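To see the stages numerically, a small sketch reusing Session A's byte counts (model states only, activations excluded; "stage 0" means plain DDP, no sharding):

```python
def zero_mem_per_gpu_gb(n_params_billion: float, n_gpus: int, stage: int) -> float:
    """Per-GPU model-state memory (GB) for bf16 weights/grads + fp32 AdamW state."""
    weights = 2.0 * n_params_billion
    grads = 2.0 * n_params_billion
    optim = 8.0 * n_params_billion
    if stage >= 1:
        optim /= n_gpus    # ZeRO-1: shard optimizer state
    if stage >= 2:
        grads /= n_gpus    # ZeRO-2: also shard gradients
    if stage >= 3:
        weights /= n_gpus  # ZeRO-3: also shard parameters
    return weights + grads + optim

for stage in range(4):  # 13B model on 4 GPUs
    print(f"ZeRO-{stage}: {zero_mem_per_gpu_gb(13, 4, stage):.1f} GB/GPU")
# ZeRO-0: 156.0, ZeRO-1: 78.0, ZeRO-2: 58.5, ZeRO-3: 39.0
```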
### Part 2-FSDP paper (75 min)
Read: FSDP (arxiv.org/abs/2304.11277). Sections 1–4.
FSDP = Fully Sharded Data Parallel. PyTorch-native equivalent of ZeRO-3.
Key design:
- Parameters sharded; gathered just-in-time for each layer's forward.
- After forward, params re-sharded (freeing memory).
- Same for backward.
Wrapping policies determine granularity-wrap each transformer block, or wrap individual layers? Coarser units mean fewer, larger all-gathers; finer units mean lower peak memory but more communication calls.
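A minimal wrapping sketch with PyTorch's FSDP API, using a toy `nn.TransformerEncoder` stand-in (a real LLM would pass its own decoder-layer class instead):

```python
import functools
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Toy stand-in: 4 encoder blocks. A real LLM would use its block class
# (e.g. Qwen2DecoderLayer) as the wrapping unit.
model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=256, nhead=4), num_layers=4
)

policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},
)

# Each block becomes one FSDP unit: its params are all-gathered just before
# that block's forward/backward and re-sharded right after.
# (Needs torch.distributed initialized, e.g. under torchrun, to actually run.)
model = FSDP(model, auto_wrap_policy=policy)
```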
### Part 3-bf16, mixed precision (30 min)
Read PyTorch AMP docs.
Modern training uses bf16 (brain float 16) instead of fp16:
- Same memory as fp16.
- Same dynamic range as fp32 (no overflow).
- Almost as stable as fp32.
Why bf16 wins: training stability without the need for loss scaling.
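You can check the range claim directly, and confirm that bf16 autocast needs no GradScaler (a minimal sketch; requires a CUDA device):

```python
import torch

# fp16 overflows early; bf16 shares fp32's exponent range.
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e38, same order as fp32's ~3.40e38

model = torch.nn.Linear(16, 16).cuda()
x = torch.randn(4, 16, device="cuda")

# bf16 autocast: no GradScaler / loss scaling needed (fp16 would require it).
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(x).pow(2).mean()
loss.backward()
```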
### Output of Session B
- ZeRO + FSDP paper notes.
- Mental model: when to use each stage.
## Session C-Multi-GPU FSDP run
Goal: Run a real multi-GPU FSDP training job. Observe scaling.
### Part 1-Rent + setup (45 min)
Rent 2× A10 (or similar) on RunPod or Lambda Labs (~$2–3/hr).
Use Hugging Face Accelerate for easy multi-GPU training.
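A minimal setup sketch (the config answers shown are one reasonable set for a 2-GPU bf16 run, not the only option):

```bash
pip install accelerate
accelerate config
# answer the prompts: this machine, multi-GPU, 1 machine, 2 processes, bf16 mixed precision
```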
### Part 2-Run (90 min)
Adapt your M08-W02 SFT script to use Accelerate's FSDP plugin (a sketch assuming a Qwen2 base model; `model`, `optimizer`, `dataloader`, and `compute_loss` come from your existing script):
```python
import functools

from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer

# Shard at decoder-block granularity; save checkpoints as a full (unsharded) state dict.
fsdp_plugin = FullyShardedDataParallelPlugin(
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={Qwen2DecoderLayer},
    ),
    state_dict_type="FULL_STATE_DICT",
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

# Standard training loop wrapped with accelerator; model, optimizer, dataloader,
# and compute_loss come from your M08-W02 script.
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
for batch in dataloader:
    loss = compute_loss(model, batch)
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
```
Launch (a minimal sketch; `train_fsdp.py` stands in for your adapted script name):
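```bash
accelerate launch --num_processes 2 train_fsdp.py
```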
### Part 3-Observe scaling (45 min)
Compare:
- Single-GPU baseline: tokens/sec.
- 2-GPU FSDP: tokens/sec.
Scaling efficiency = (2-GPU throughput) / (2 × single-GPU throughput).
Likely: a ~1.5–1.7× speedup over single GPU, i.e. roughly 75–85% scaling efficiency. Not 2× because of communication overhead.
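A tiny helper for the write-up (the throughput numbers are hypothetical):

```python
def scaling_efficiency(multi_gpu_tps: float, single_gpu_tps: float, n_gpus: int) -> float:
    """Fraction of ideal linear scaling achieved (1.0 = perfect)."""
    return multi_gpu_tps / (n_gpus * single_gpu_tps)

# Hypothetical numbers: 2 GPUs at 8000 tok/s vs 1 GPU at 5000 tok/s.
eff = scaling_efficiency(8000, 5000, 2)
print(f"{eff:.0%} efficiency, {eff * 2:.1f}x speedup")  # 80% efficiency, 1.6x speedup
```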
Why not 2×?
- Gradient reduction (reduce-scatter in FSDP) takes time.
- Parameter all-gathers for FSDP add latency.
- Data loading may bottleneck.
Document the run in your distributed-experiments/ directory. Include: scaling efficiency, observed GPU memory, throughput numbers.
### Output of Session C
- Multi-GPU FSDP run completed.
- Scaling efficiency documented.
## End-of-week artifact
- ZeRO + FSDP paper notes
- Multi-GPU FSDP run
- Scaling efficiency documented
## End-of-week self-assessment
- I can compute training memory for any transformer.
- I can explain what each ZeRO stage shards.
- I can launch a multi-GPU job with Accelerate.
## Common failure modes for this week
- Skipping the math. Memory accounting is foundational.
- Not running on real multi-GPU hardware. Reading vs. doing: both are required.
- Treating "FSDP works" as the lesson. The lesson is the scaling efficiency and what limits it.
## What's next (preview of M09-W02)
Track final push (part 1) + first OSS PR upstream.