Distributed Training: Mathematics and Engineering
A self-contained reference for understanding how modern foundation models are trained across hundreds-to-thousands of GPUs. By the end, you should be able to read ZeRO, Megatron-LM, GPipe, and PipeDream from summary alone, derive the memory and communication costs of any sharding scheme, and design a 3D-parallel configuration for a given model and cluster.
1. Why Distributed Training: The Memory Wall
1.1 Intuition
A single H100 80GB GPU has 80 GB of HBM. A single A100 has 40 or 80 GB. Modern foundation models exceed this by orders of magnitude:
- GPT-3 175B in FP16: 350 GB just for parameters.
- Llama-3 405B in BF16: 810 GB just for parameters.
But parameters are only one of four memory bills. Training requires:
- Parameters (W): the weights themselves.
- Gradients (∇W): same shape as parameters.
- Optimizer states (O): Adam keeps m (first moment) and v (second moment), both the same shape as the parameters.
- Activations (A): intermediate forward outputs needed for backward; these scale with batch × sequence × hidden, not with parameter count.
Distributed training is, fundamentally, the engineering discipline of partitioning these four buckets across many devices while keeping the arithmetic correct and the communication cheap.
1.2 Memory Decomposition (the Standard Accounting)
Let Φ be the number of parameters in the model (e.g., Φ = 7 × 10⁹ for a 7B). Let:
- b_p = bytes per parameter in storage precision (FP16/BF16 = 2, FP32 = 4).
- b_g = bytes per gradient (typically the same as b_p).
- b_o = optimizer-state bytes per parameter (Adam keeps FP32 m and v plus an FP32 master copy of the weights, so b_o = 12 under standard mixed precision).
Standard mixed-precision Adam memory per replica, in bytes:
```
M_param  = 2Φ    (BF16 parameters)
M_grad   = 2Φ    (BF16 gradients; often kept in FP32 → 4Φ)
M_opt    = 12Φ   (FP32 master weights 4Φ + FP32 m 4Φ + FP32 v 4Φ)
M_static = 16Φ   ← with BF16 gradients; with FP32 gradients it is 18Φ (20Φ in some setups)
```
The 16Φ figure is what the ZeRO paper writes as (2 + 2 + K)·Φ with K = 12: 2 bytes each for FP16 parameters and gradients, plus K = 12 bytes of optimizer state. Variations exist; the principle is what matters.
For Φ = 70B with M_static ≈ 16Φ: 1.12 TB of static memory. No single 80GB GPU comes close. Hence: shard, replicate, or pipeline.
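The accounting above can be scripted in a few lines. This is a minimal sketch, assuming BF16 parameters, a configurable gradient width, and 12 bytes of FP32 Adam state per parameter; static_memory_bytes is a hypothetical helper, not a library function.

```python
def static_memory_bytes(num_params: float, grad_bytes: int = 2) -> dict:
    """Per-replica static memory (bytes) for mixed-precision Adam without sharding.

    Assumes BF16 parameters (2 bytes), gradients of `grad_bytes` bytes each,
    and FP32 master weights + m + v (4 + 4 + 4 = 12 bytes per parameter).
    """
    params = 2 * num_params
    grads = grad_bytes * num_params
    optimizer = 12 * num_params
    return {
        "params": params,
        "grads": grads,
        "optimizer": optimizer,
        "total": params + grads + optimizer,
    }


if __name__ == "__main__":
    for name, phi in [("7B", 7e9), ("70B", 70e9)]:
        total = static_memory_bytes(phi)["total"]
        print(f"{name}: {total / 1e12:.2f} TB static (BF16 grads)")
```

For 70B this reproduces the 1.12 TB figure; passing grad_bytes=4 gives the 18Φ FP32-gradient variant.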
1.3 Activation Memory
Activations dominate at long sequence lengths. For a Transformer with L layers, hidden size h, batch B, sequence S, the activation memory per layer without recomputation is approximately:
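$$
M_{\text{act}}^{\text{layer}} \approx S \cdot B \cdot h \cdot \left(10 + \frac{24}{t} + \frac{5\,a\,S}{h\,t}\right) \text{ bytes}
$$

where t is the tensor-parallel degree and a is the number of attention heads (Korthikanti et al. 2022 give a precise formula). The 5·a·S/(h·t) term, which contributes 5·a·S²·B/t bytes, grows quadratically in sequence length and is what motivates sequence parallelism (Section 10).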
With selective activation checkpointing (recomputing attention but storing the rest), this drops to roughly S · B · h · (10 + 24/t).
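A small calculator for the per-layer estimate S·B·h·(10 + 24/t + 5·a·S/(h·t)) (tensor parallelism, no sequence parallelism, 16-bit activations). This is a sketch: activation_bytes_per_layer is our own illustrative name, and the example layer shape is hypothetical.

```python
def activation_bytes_per_layer(S: int, B: int, h: int, a: int, t: int = 1,
                               selective_recompute: bool = False) -> float:
    """Approximate activation bytes per Transformer layer (16-bit activations).

    Uses S*B*h*(10 + 24/t + 5*a*S/(h*t)); selective recomputation drops the
    attention term 5*a*S/(h*t), leaving S*B*h*(10 + 24/t).
    """
    per_element = 10 + 24 / t
    if not selective_recompute:
        per_element += 5 * a * S / (h * t)
    return S * B * h * per_element


if __name__ == "__main__":
    # Hypothetical layer: h=8192, a=64 heads, S=8192, B=1, t=8.
    full = activation_bytes_per_layer(8192, 1, 8192, 64, t=8)
    sel = activation_bytes_per_layer(8192, 1, 8192, 64, t=8, selective_recompute=True)
    print(f"full: {full / 1e9:.2f} GB/layer, selective: {sel / 1e9:.2f} GB/layer")
```

For this shape the attention term accounts for 40 of the 53 bytes per (S·B·h) element, which is why selective recomputation pays off at long sequence lengths.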
1.4 The Three Axes of Parallelism
Each of the four memory buckets can be partitioned along three orthogonal axes:
| Axis | What it splits | What it costs |
|---|---|---|
| Data parallel (DP) | Batch | All-reduce of gradients per step |
| Tensor parallel (TP) | Hidden dim of each layer | All-reduce inside each layer (×4 per Transformer block) |
| Pipeline parallel (PP) | Layer depth | Point-to-point sends between stages + bubble |
ZeRO/FSDP is a refinement of DP that also shards parameters/gradients/optimizer states. Sequence parallel splits along the sequence dim. The full design space is rich; the rest of this chapter walks through it.
2. Communication Primitives
Distributed training is impossible without a vocabulary of collective operations. We define each precisely, in terms of input shape, output shape, and the data each rank ends with.
2.1 Point-to-Point
send(tensor, dst) / recv(tensor, src): rank src sends a buffer of shape T to rank dst. After the call, both ranks hold T (typically). Used in pipeline parallelism between adjacent stages.
Cost model: α + M/β where α is link latency and β is link bandwidth, M is message size in bytes.
2.2 Collectives
Let there be N ranks 0, …, N-1. Each rank holds a tensor x_i.
broadcast(root): after the call, every rank holds x_root. Input shape on root: T; output shape on every rank: T.
reduce(op, root): after the call, root holds op(x_0, x_1, …, x_{N-1}) elementwise (for op = sum, min, max). Other ranks: undefined. Input/output shape: T.
all_reduce(op): every rank ends with op(x_0, …, x_{N-1}). Equivalent to reduce followed by broadcast. Input shape per rank: T; output shape per rank: T. Total data on each rank stays T.
all_gather: each rank i holds a chunk x_i of shape T. After the call, every rank holds the concatenation [x_0 ‖ x_1 ‖ … ‖ x_{N-1}] of shape N · T. Total bytes received per rank: (N-1)·T.
reduce_scatter(op): each rank i holds a tensor of shape N · T partitioned into N chunks. After the call, rank i holds the elementwise reduction of the i-th chunks across all ranks. Output shape per rank: T.
all_to_all: each rank i holds N chunks of shape T, where chunk j is destined for rank j. After the call, rank i holds chunks i from each rank 0..N-1, concatenated, shape N · T. Used in expert-parallel MoE for the dispatch/combine step.
A clean identity: all_reduce = reduce_scatter + all_gather. This decomposition is the basis for ring all-reduce and for ZeRO-2/FSDP.
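To make the semantics concrete, here is a single-process simulation in which each rank is a Python list. It is not the torch.distributed or NCCL API; it only checks that reduce_scatter followed by all_gather reproduces all_reduce (with op = sum).

```python
from typing import List

Vector = List[float]


def all_reduce(xs: List[Vector]) -> List[Vector]:
    """Every rank ends with the elementwise sum of all ranks' vectors."""
    total = [sum(vals) for vals in zip(*xs)]
    return [list(total) for _ in xs]


def reduce_scatter(xs: List[Vector]) -> List[Vector]:
    """Rank i ends with the reduced i-th chunk (length must divide evenly by N)."""
    n = len(xs)
    chunk = len(xs[0]) // n
    total = [sum(vals) for vals in zip(*xs)]
    return [total[i * chunk:(i + 1) * chunk] for i in range(n)]


def all_gather(chunks: List[Vector]) -> List[Vector]:
    """Every rank ends with the concatenation of all ranks' chunks."""
    full = [v for c in chunks for v in c]
    return [list(full) for _ in chunks]


if __name__ == "__main__":
    ranks = [[1.0, 2.0, 3.0, 4.0], [10.0, 20.0, 30.0, 40.0]]
    assert all_gather(reduce_scatter(ranks)) == all_reduce(ranks)
    print(all_reduce(ranks)[0])  # [11.0, 22.0, 33.0, 44.0]
```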
2.3 Algorithm Trees for All-Reduce
We model communication time as T = α · L + M / β where L is the number of sequential link traversals (latency-limited) and M is the bytes that traverse the slowest link (bandwidth-limited). We want both small.
2.3.1 Naive (gather-to-root, scatter-from-root)
Every rank sends its M bytes to rank 0 sequentially; rank 0 reduces and broadcasts. Time:
```
T_naive = (N-1) · (α + M/β)      (gather, sequential through root's NIC)
        + (N-1) · (α + M/β)      (broadcast, same)
        ≈ 2(N-1)·α + 2(N-1)·M/β
```
Latency: O(N). Bandwidth: O(N·M). Root's link is a bottleneck; worst-case algorithm. Useful only for tiny N and tiny M.
2.3.2 Binary Tree
Build a binary tree over ranks. Reduce phase: leaves send to parents, parents reduce, propagate to root. Broadcast phase: reverse. Each rank participates in log₂(N) rounds of reduce + log₂(N) of broadcast.
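Per round, every active rank sends M bytes, so:

$$
T_{\text{tree}} \approx 2\log_2 N \cdot \alpha + 2\log_2 N \cdot \frac{M}{\beta}
$$

Latency: O(log N). Bandwidth: O(M·log N). This beats naive on both terms, but for large M its bandwidth term is far worse than ring's ≈2M (Section 2.3.5). Used for small-message all-reduces (e.g., scalar metrics).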
2.3.3 Recursive Doubling
For all-gather: pair up ranks at distances 1, 2, 4, …, N/2; each pair exchanges all the data it currently holds. After log₂(N) rounds, each rank holds all N chunks. Time per round: α + M_round/β, where M_round doubles each round.
With chunks of c bytes per rank, the total received per rank is c + 2c + 4c + … + (N/2)c = (N-1)·c, spread across log₂(N) rounds. For all-reduce, apply the same pairing to the full M-byte buffer, reducing instead of concatenating: log₂(N) rounds of M bytes each, i.e. log₂N · α + log₂N · M/β, half the binary tree's cost.
2.3.4 Recursive Halving + Recursive Doubling (Rabenseifner)
For large messages: do reduce-scatter in log₂(N) recursive-halving rounds, then all-gather in log₂(N) recursive-doubling rounds.
- Reduce-scatter: in round k, each rank exchanges M / 2^k bytes with its partner at distance 2^(k-1).
- After log₂(N) rounds, each rank holds M/N bytes that are the reduction of the corresponding slice.
- All-gather: reverse the pattern, doubling the chunk size each round.
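Total bytes sent per rank:

$$
2\left(\frac{M}{2} + \frac{M}{4} + \cdots + \frac{M}{N}\right) = \frac{2(N-1)}{N} \cdot M
$$

Latency: 2·log₂(N)·α. Bandwidth-optimal in M, latency-logarithmic. MPI libraries commonly use it for large-message all-reduce; NCCL instead selects among the algorithms listed in Section 3.1.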
2.3.5 Ring All-Reduce (Patarasuk & Yuan, 2009)
Arrange the N ranks in a logical ring. Split the buffer into N chunks of size M/N.
Phase A: reduce-scatter (N-1 steps). In step k (k = 0, …, N-2):
- Rank i sends chunk (i - k) mod N to rank (i+1) mod N.
- Rank i receives chunk (i - k - 1) mod N from rank (i-1) mod N, adds it to its local copy.
After N-1 steps, rank i owns the fully reduced chunk indexed (i+1) mod N (or some rotation; conventions vary).
Phase B: all-gather (N-1 steps). Same ring traversal: each rank passes its fully reduced chunk forward, eventually every rank holds all N reduced chunks.
Bandwidth analysis. In each step, every rank sends and receives one chunk of M/N bytes. Total steps: 2(N-1). Total bytes per rank (send or receive): 2(N-1) · M/N = 2(N-1)/N · M.
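The two phases can be checked with a single-process simulation. This sketch assumes the buffer length divides evenly by N and uses the Phase A indexing (k = 0, …, N-2); it also counts elements sent per rank to confirm the 2(N-1)/N · M figure. ring_all_reduce is a hypothetical name, not an NCCL function.

```python
from typing import List, Tuple


def ring_all_reduce(data: List[List[float]]) -> Tuple[List[List[float]], int]:
    """Simulate ring all-reduce (sum) over N ranks.

    Returns each rank's final buffer and the number of elements each rank sent.
    The buffer length must be divisible by N.
    """
    n = len(data)
    c = len(data[0]) // n               # chunk size M/N
    buf = [list(x) for x in data]       # per-rank local copies

    def chunk(idx: int) -> slice:
        return slice(idx * c, (idx + 1) * c)

    sent = 0
    # Phase A (reduce-scatter): in step k, rank i sends chunk (i - k) mod N to rank i+1,
    # which adds it into its own copy of that chunk.
    for k in range(n - 1):
        msgs = [buf[i][chunk((i - k) % n)] for i in range(n)]  # all sends happen "at once"
        for i in range(n):
            dst, s = (i + 1) % n, chunk((i - k) % n)
            buf[dst][s] = [a + b for a, b in zip(buf[dst][s], msgs[i])]
        sent += c
    # Rank i now owns the fully reduced chunk (i + 1) mod N.
    # Phase B (all-gather): in step k, rank i forwards chunk (i + 1 - k) mod N unchanged.
    for k in range(n - 1):
        msgs = [buf[i][chunk((i + 1 - k) % n)] for i in range(n)]
        for i in range(n):
            buf[(i + 1) % n][chunk((i + 1 - k) % n)] = msgs[i]
        sent += c
    return buf, sent


if __name__ == "__main__":
    ranks = [[float(r * 10 + j) for j in range(8)] for r in range(4)]
    out, sent = ring_all_reduce(ranks)
    print(out[0], "elements sent per rank:", sent)  # 2(N-1)/N * 8 = 12
```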
Latency: O(N). Bandwidth: 2(N-1)/N · M per rank, which approaches 2M as N→∞. This is bandwidth-optimal: no all-reduce algorithm can send fewer than 2(N-1)/N · M bytes per rank. The O(N) latency is the price; for large messages on high-bandwidth links the latency term is small relative to the bandwidth term, and ring is the right choice.
Why bandwidth-optimal? Each rank's data must influence every other rank, and each rank must end up with the same M bytes. The min cuts of the ring topology force at least 2(N-1)/N · M bytes through each link.
Engineering note. Ring is NCCL's default for large intra-node all-reduces because NVLink's full-duplex links map cleanly onto a ring. For inter-node traffic, NCCL may switch to its tree algorithm (a double binary tree), whose logarithmic latency wins for small and medium messages.
2.4 Summary Table
| Algorithm | Latency | Bandwidth (per rank) | Best regime |
|---|---|---|---|
| Naive | 2(N-1)·α | 2(N-1)·M | Never |
| Binary tree | 2 log₂N · α | 2 log₂N · M | Tiny messages |
| Recursive doubling | log₂N · α | log₂N · M | All-gather, small-mid |
| Rabenseifner | 2 log₂N · α | 2(N-1)/N · M | Large M, low N |
| Ring all-reduce | 2(N-1)·α | 2(N-1)/N · M | Large M, intra-node |
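The table can be turned into an α-β calculator. This is a sketch using exactly the latency and bandwidth terms above; the α and β values in the usage example are hypothetical. In this idealized model Rabenseifner matches ring's bandwidth term with lower latency, so it wins for large messages; ring's practical advantages (contention-free mapping onto physical rings, easy pipelining) are not captured by the model.

```python
import math


def allreduce_time(alg: str, n: int, m_bytes: float, alpha: float, beta: float) -> float:
    """Alpha-beta all-reduce time using the latency and bandwidth terms from the table."""
    lg = math.log2(n)
    if alg == "naive":
        steps, volume = 2 * (n - 1), 2 * (n - 1) * m_bytes
    elif alg == "tree":
        steps, volume = 2 * lg, 2 * lg * m_bytes
    elif alg == "recursive_doubling":
        steps, volume = lg, lg * m_bytes
    elif alg == "rabenseifner":
        steps, volume = 2 * lg, 2 * (n - 1) / n * m_bytes
    elif alg == "ring":
        steps, volume = 2 * (n - 1), 2 * (n - 1) / n * m_bytes
    else:
        raise ValueError(f"unknown algorithm: {alg!r}")
    return steps * alpha + volume / beta


ALGORITHMS = ("naive", "tree", "recursive_doubling", "rabenseifner", "ring")


def best_algorithm(n: int, m_bytes: float, alpha: float, beta: float) -> str:
    """Algorithm with the lowest modeled time for this message size."""
    return min(ALGORITHMS, key=lambda a: allreduce_time(a, n, m_bytes, alpha, beta))


if __name__ == "__main__":
    # Hypothetical link: 5 us latency, 50 GB/s bandwidth, 64 ranks.
    for size in (8, 1e6, 1e9):
        print(f"{size:>12.0f} B -> {best_algorithm(64, size, 5e-6, 50e9)}")
```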
3. NCCL Specifics
NCCL (NVIDIA Collective Communications Library) is what every PyTorch/JAX distributed run uses on NVIDIA GPUs. Understanding its choices can be the difference between 30% and 70% of peak.
3.1 Algorithm Selection
NCCL maintains a cost model and picks among:
- Ring: bandwidth-optimal; default for large messages, intra-node, NVLink-rich.
- Tree (double-binary tree): better latency for small messages, especially inter-node.
- CollNet: uses InfiniBand SHARP (Mellanox switch-based aggregation) for in-network reduction: the switches perform the reduction, so each host sends and receives less data.
- NVLS (NVLink SHARP, on H100/Hopper systems with NVSwitch v3): in-NVSwitch reduction, like CollNet but for NVLink.
The choice is made per-collective based on (count, datatype, topology) using internal tuning tables.
3.2 Reading NCCL_DEBUG=INFO
Typical lines:

```
NCCL INFO Channel 00/02 : 0[1c000] -> 1[1d000] via P2P/IPC
NCCL INFO Channel 00 : 0[1c000] -> 1[1d000] via NET/IB/0/GDRDMA
```
What to extract:
- Channel count (02 here): NCCL uses multiple parallel rings; more channels means more concurrent SM use.
- Transport (P2P/IPC, NET/IB, SHM): tells you whether traffic is staying on NVLink, going through shared memory, or hitting the network.
- GDRDMA: GPUDirect RDMA, meaning the NIC reads directly from GPU memory, which is desirable.
- Ring 00 : 0 1 2 3 …: the ring order. For two-node 8-GPU jobs, NCCL builds rings that traverse intra-node NVLink before crossing IB.

If you see Tree lines for a large all-reduce, NCCL has decided the message is small enough or the topology sparse enough. It is sometimes wrong; you can override it (NCCL_ALGO, Section 3.3).
3.3 Tuning Knobs
| Variable | Effect |
|---|---|
| NCCL_DEBUG=INFO | Verbose init logs, ring topology |
| NCCL_DEBUG_SUBSYS=ALL | Even more (COLL, INIT, NET) |
| NCCL_IB_HCA=mlx5_0,mlx5_1 | Which IB cards to use (rail affinity) |
| NCCL_SOCKET_IFNAME=eth0 | Which Ethernet interface for bootstrap |
| NCCL_TOPO_FILE=topo.xml | Override auto-detected topology |
| NCCL_ALGO=Ring (or Tree, CollNet, NVLS) | Force algorithm |
| NCCL_PROTO=LL,LL128,Simple | Force protocol; LL = low-latency (small msgs), Simple = bulk |
| NCCL_NTHREADS, NCCL_MAX_NCHANNELS | Channel parallelism |
| NCCL_P2P_DISABLE=1 | Disable peer-to-peer (debug) |
| NCCL_IB_GID_INDEX | RoCE-specific; pick the right GID |
| NCCL_BUFFSIZE | Per-channel buffer size; affects pipelining |
3.4 Rail Affinity
Modern multi-NIC nodes use rails: GPU 0 talks to NIC 0, GPU 1 to NIC 1, etc. NCCL needs to know this; otherwise GPU 0 might end up sending through NIC 3 over PCIe, halving inter-node bandwidth.
Verify: nvidia-smi topo -m shows the GPU↔NIC connectivity matrix (PXB, PHB, NV2, etc.). Make sure NCCL's chosen NCCL_IB_HCA list aligns with the rails.
4. Data Parallelism (DDP)
4.1 The Pattern
Every rank holds the full model. Each step:
- Local forward on local micro-batch.
- Local backward.
- All-reduce of gradients across all ranks (sum, then divide by N).
- Local optimizer step (identical on every rank → models stay in sync).
This requires that initial weights are identical (broadcast at init) and that gradient all-reduce is exact. Both are easy to guarantee.
4.2 Memory Cost
Per rank (mixed precision, Adam), with Φ parameters:
- Params (BF16): 2Φ
- Grads (FP32 for stability, often): 4Φ, or BF16 → 2Φ
- Optimizer master weights (FP32): 4Φ
- Optimizer m (FP32): 4Φ
- Optimizer v (FP32): 4Φ
Total static: ~16-18Φ per rank. DDP does not reduce model state memory: every rank pays full price. With Φ = 7B, that's ~112 GB on each rank. Already too big for one H100 80GB.
4.3 PyTorch DDP Implementation Details
The all-reduce isn't done as one giant call. PyTorch DDP uses gradient bucketing:
- Gradients are grouped into buckets of ~25MB (configurable via bucket_cap_mb).
- As each parameter's gradient is computed during backward, it's added to its bucket.
- When a bucket is complete (all of its gradients are ready), NCCL's allReduce for it is launched asynchronously, overlapping with continuing backward computation on other parameters.
- Sync point at end of backward: wait for all bucket all-reduces to finish.
This overlap is critical: without it, you'd have backward → idle → all-reduce → idle. With it, all-reduce hides under backward.
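A simplified sketch of the hook (bucket_for, bucket.add, bucket.full, and last_param_in_bucket are illustrative helpers, not DDP internals):

```python
import torch.distributed as dist

def on_grad_ready(param):
    bucket = bucket_for(param)  # hypothetical lookup of this param's bucket
    bucket.add(param.grad)
    if bucket.full() or last_param_in_bucket(param):
        # Launch without blocking; backward keeps running while NCCL reduces.
        bucket.handle = dist.all_reduce(bucket.buffer, async_op=True)
```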
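The bucketing idea itself can be sketched as greedy packing in reverse parameter order. This is a simplified model, assuming gradients become ready in exact reverse registration order; PyTorch DDP's real bucket assignment differs in details, and assign_buckets is a hypothetical name.

```python
from typing import List, Tuple


def assign_buckets(param_sizes: List[Tuple[str, int]], cap_bytes: int) -> List[List[str]]:
    """Greedily pack gradients into buckets of at most cap_bytes.

    Parameters are visited in reverse registration order, approximating the order
    in which backward produces their gradients. A tensor larger than the cap
    gets a bucket of its own.
    """
    buckets: List[List[str]] = []
    current: List[str] = []
    current_bytes = 0
    for name, nbytes in reversed(param_sizes):
        if current and current_bytes + nbytes > cap_bytes:
            buckets.append(current)  # closed bucket: its all-reduce can start
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += nbytes
    if current:
        buckets.append(current)      # final bucket: closes at the end of backward
    return buckets


if __name__ == "__main__":
    MB = 2 ** 20
    params = [("embed", 20 * MB), ("layer0.w", 10 * MB), ("layer1.w", 10 * MB), ("head", 20 * MB)]
    print(assign_buckets(params, cap_bytes=25 * MB))
```

The last bucket to close holds the earliest layers (here the embedding), which is exactly the all-reduce that cannot hide under backward.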
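4.4 Scaling Efficiency
Throughput per GPU relative to a single GPU (scaling efficiency):

$$
\eta = \frac{T_{\text{compute}}}{T_{\text{compute}} + T_{\text{allreduce}}^{\text{exposed}}}
$$

where T_allreduce^exposed is the part of the all-reduce not hidden under backward. If T_allreduce is fully hidden, efficiency → 100%. In practice, the last bucket can't be hidden (it holds the final gradients), and small messages near the start have high latency relative to their size.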
Gotcha 1: small models, fast nodes. If T_compute per step is small (say 50ms) and the all-reduce of the final bucket is 20ms, you're at 70% efficiency before any other loss.
Gotcha 2: stragglers. All-reduce is barrier-like; one slow rank slows everyone. A rank that stalls past the process-group timeout (the timeout argument of init_process_group) will kill the job.
Gotcha 3: find_unused_parameters=True. Adds an extra pass and an extra all-reduce of a bitmask; avoid if possible.
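4.5 Micro-example
Llama-2 7B, 8×H100, batch 8M tokens. Φ = 7e9, BF16 grads = 14 GB. Ring all-reduce on 8 GPUs at 600 GB/s NVLink:

$$
T_{\text{ring}} \approx \frac{2(N-1)}{N} \cdot \frac{M}{\beta} = \frac{2 \cdot 7}{8} \cdot \frac{14\ \text{GB}}{600\ \text{GB/s}} \approx 40.8\ \text{ms}
$$

If a forward+backward step takes 400 ms, bucketing hides ~90% of the 40.8 ms, leaving an effective overhead of ~4-8 ms. Efficiency: 98-99%. This is what makes 7B "easy" on 8 GPUs as far as communication goes (Section 4.2 shows the model state itself still does not fit in plain DDP).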
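The arithmetic above, as a sketch: the helper names are hypothetical, latency is ignored, and the 90% hidden fraction is the assumption from the text rather than a measurement.

```python
def ring_allreduce_seconds(n_ranks: int, message_bytes: float, bandwidth_bytes_per_s: float) -> float:
    """Bandwidth term of ring all-reduce: 2(N-1)/N * M / beta (latency ignored)."""
    return 2 * (n_ranks - 1) / n_ranks * message_bytes / bandwidth_bytes_per_s


def scaling_efficiency(t_compute: float, t_comm: float, hidden_fraction: float) -> float:
    """eta = T_compute / (T_compute + exposed communication)."""
    exposed = (1 - hidden_fraction) * t_comm
    return t_compute / (t_compute + exposed)


if __name__ == "__main__":
    t_ar = ring_allreduce_seconds(8, 14e9, 600e9)  # 7B BF16 grads over 600 GB/s NVLink
    eta = scaling_efficiency(0.400, t_ar, hidden_fraction=0.9)
    print(f"all-reduce: {t_ar * 1e3:.1f} ms, efficiency: {eta:.1%}")
    # Gotcha 1 from Section 4.4: 50 ms compute, 20 ms fully exposed final bucket.
    print(f"small model: {scaling_efficiency(0.050, 0.020, 0.0):.1%}")
```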
5. ZeRO (Rajbhandari et al., SC 2020)
The Zero Redundancy Optimizer observes that DDP duplicates the model-state buckets (parameters, gradients, and optimizer states) N times across N ranks. ZeRO shards them.
5.1 The Insight
Recall standard mixed-precision Adam memory: ~16Φ per rank. Of that, 12Φ is optimizer states (FP32 master, m, v). DDP keeps a full copy on every rank. Why? Because the optimizer step happens locally, and you need the states to take a step.
But: you only need the optimizer state of parameter p on the rank that owns p. If we shard:
- Stage 1: shard optimizer states.
- Stage 2: also shard gradients.
- Stage 3: also shard parameters.
5.2 ZeRO-1: Optimizer State Sharding
Partition the parameter index space into N shards: rank i owns shard i's optimizer state (master weights + m + v, all FP32).
Per-step protocol:
1. Forward: full BF16 weights on every rank (still replicated).
2. Backward: full BF16 gradients on every rank.
3. All-reduce gradients (same as DDP).
4. Each rank applies the optimizer to its shard of params (using its shard of m, v, master weights).
5. All-gather the updated BF16 weights so every rank has the full updated model.
Memory per rank:
```
Params    (BF16): 2Φ
Grads     (BF16): 2Φ        (allocated full, but could be sharded after the all-reduce)
Optimizer (FP32): 12Φ / N
Total ≈ 4Φ + 12Φ/N
```
For Φ=70B, N=64: 4·70 + 12·70/64 = 280 + 13.1 = 293 GB. Still bad. ZeRO-1 alone isn't enough for big models.
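Communication. As described, ZeRO-1 adds an all_gather of params after the optimizer step, costing (N-1)/N · 2Φ bytes (BF16) on top of the gradient all-reduce: roughly 1.5× DDP communication. Replacing the all-reduce with a reduce-scatter (each rank only needs the reduced gradients of its own shard) brings the total back to DDP's volume, which is how the ZeRO paper counts it.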
5.3 ZeRO-2: Optimizer + Gradient Sharding
Replace gradient all-reduce with reduce-scatter: each rank ends up with the reduced gradient of its shard only. Now only sharded gradients need to be stored.
Per-step protocol:
1. Forward: full params (replicated).
2. Backward: full grads computed locally.
3. Reduce-scatter gradients → each rank has reduced grads for its shard only.
4. Local optimizer step on the shard.
5. All-gather updated params.
Memory per rank:
```
Params    (BF16): 2Φ
Grads     (BF16): 2Φ / N    ← sharded after reduce-scatter
Optimizer (FP32): 12Φ / N
Total ≈ 2Φ + 14Φ/N
```
For Φ=70B, N=64: 140 + 14·70/64 = 140 + 15.3 = 155 GB. Still too much.
Communication. reduce_scatter + all_gather = all_reduce algorithmically, so total bytes are the same as DDP. ZeRO-2 is essentially free in communication relative to DDP: same bytes, less memory.
5.4 ZeRO-3: Full Sharding
Also shard parameters. Now no rank holds the full model. Forward must gather the params for the current layer, do the matmul, then free them.
Per-step protocol (per layer):
1. Forward layer l: all_gather the BF16 params of layer l. Do matmul. Free the gathered params (keep only your shard).
2. Backward layer l: all_gather BF16 params again (or kept from forward). Compute grads. reduce_scatter grads → each rank holds only its grad shard. Free grad buffer.
3. After full backward: each rank applies optimizer to its param shard. No final all-gather needed; next forward will gather as needed.
Memory per rank (everything sharded):

$$
M_{\text{ZeRO-3}} = \frac{2\Phi}{N} + \frac{2\Phi}{N} + \frac{12\Phi}{N} = \frac{16\Phi}{N}
$$

For Φ=70B, N=64: 16·70/64 = 17.5 GB of static state. Now there's room for activations!
Communication cost. ZeRO-3 adds, per step:
- One all_gather of params per layer (forward) ≈ (N-1)/N · 2Φ total over the model.
- One all_gather of params per layer (backward; can be elided with smart reuse).
- One reduce_scatter of grads per layer ≈ (N-1)/N · 2Φ.
Roughly 1.5× DDP communication. The tradeoff: you save N× memory at the cost of 1.5× bytes on the wire. Worth it when you can't fit otherwise.
5.5 Memory Math Summary
For one rank, parameters Φ, mixed-precision Adam (K=12 for opt states):
| Scheme | Params | Grads | Opt states | Total |
|---|---|---|---|---|
| DDP | 2Φ | 2Φ | 12Φ | 16Φ |
| ZeRO-1 | 2Φ | 2Φ | 12Φ/N | 4Φ + 12Φ/N |
| ZeRO-2 | 2Φ | 2Φ/N | 12Φ/N | 2Φ + 14Φ/N |
| ZeRO-3 | 2Φ/N | 2Φ/N | 12Φ/N | 16Φ/N |
(Note: some implementations keep grads in FP32 → 4Φ, shifting all rows. The structure is identical.)
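A per-rank calculator for the table, assuming BF16 parameters and gradients and K = 12 bytes of optimizer state; stage 0 stands for plain DDP. zero_memory_bytes is a hypothetical helper, and activations and temporary buffers are ignored.

```python
def zero_memory_bytes(num_params: float, n: int, stage: int, k: int = 12) -> float:
    """Per-rank model-state bytes for DDP (stage 0) and ZeRO stages 1-3.

    Assumes BF16 params and grads (2 bytes each) and k bytes of optimizer state per parameter.
    """
    params = 2 * num_params
    grads = 2 * num_params
    opt = k * num_params
    if stage >= 1:
        opt /= n      # ZeRO-1: shard optimizer states
    if stage >= 2:
        grads /= n    # ZeRO-2: also shard gradients
    if stage >= 3:
        params /= n   # ZeRO-3: also shard parameters
    return params + grads + opt


if __name__ == "__main__":
    phi, n = 70e9, 64
    for stage in range(4):
        gb = zero_memory_bytes(phi, n, stage) / 1e9
        print(f"stage {stage}: {gb:.1f} GB per rank")
```

For Φ = 70B and N = 64 this reproduces the 293 GB, 155 GB, and 17.5 GB figures from Sections 5.2 to 5.4.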
5.6 ZeRO-Offload, ZeRO-Infinity
- ZeRO-Offload: move optimizer states (and optionally gradients) to CPU memory and run the optimizer step on CPU. Trades GPU↔CPU PCIe bandwidth (~32 GB/s) for GPU memory.
- ZeRO-Infinity: extends to NVMe; uses overlap so prefetched param shards arrive in time for the layer that needs them. Enables training models that exceed total GPU memory of the cluster.
These are useful when you can't add GPUs but have RAM and SSDs.
6. FSDP: PyTorch's ZeRO-3
PyTorch's Fully Sharded Data Parallel is the in-tree implementation of ZeRO-3, with a few production refinements.
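6.1 Wrapping Discipline
You don't shard parameter-by-parameter: that is too fine-grained, and the all-gather overhead per matmul would crush you. You shard at the unit level, where a "unit" is a meaningful submodule like a Transformer block.

```python
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Each LlamaDecoderLayer becomes its own FSDP unit.
policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)
# Assumes torch.distributed is initialized and `model` is already constructed.
model = FSDP(model, auto_wrap_policy=policy)  # plus any other FSDP options
```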
This wraps each LlamaDecoderLayer as an FSDP unit. The whole model is also wrapped (the outer FSDP). Each unit holds one all-gather buffer and one reduce-scatter buffer.
6.2 The Per-Unit Dance
For each unit during forward:
- Pre-forward hook: all_gather params of this unit (BF16) into a flat buffer.
- Forward compute: standard matmul/attention.
- Post-forward hook: free the all-gathered params; only the local shard remains.

For each unit during backward:
- Pre-backward hook: all_gather params again (since they were freed).
- Backward compute: produces gradients in the all-gathered shape.
- Post-backward hook: reduce_scatter gradients across ranks; each rank ends with its shard's gradients. Free the all-gathered params.
After full backward, the optimizer steps on local shards. No final all-gather (next forward does it lazily).
6.3 Prefetching
The naive schedule has bubbles: you can't compute layer l+1 while waiting for its all-gather. Fix: prefetch.
With backward_prefetch=BackwardPrefetch.BACKWARD_PRE, FSDP issues the all-gather for the next unit to run in backward (the previous unit in forward order) before the current unit's gradient computation starts, so that all-gather overlaps with the current unit's backward compute.
BACKWARD_POST is the conservative variant: it issues the prefetch only after the current unit's gradients are computed. BACKWARD_PRE is what you want when memory allows (it temporarily holds two all-gather buffers).
There's also forward_prefetch=True for the forward pass, with the same idea.
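A toy timing model shows why prefetching helps. It assumes one communication stream and one compute stream, fixed per-unit all-gather and compute times, and ignores reduce-scatter and memory limits; backward_time is a hypothetical name, not FSDP's scheduler.

```python
def backward_time(num_units: int, gather_s: float, compute_s: float, prefetch: bool) -> float:
    """Toy model of an FSDP-style backward pass (one all-gather + one compute per unit).

    Without prefetch, each unit waits for its own all-gather before computing.
    With prefetch, unit i+1's all-gather runs during unit i's compute, so after
    the first (exposed) all-gather each step costs max(gather, compute).
    """
    if num_units == 0:
        return 0.0
    if not prefetch:
        return num_units * (gather_s + compute_s)
    return gather_s + (num_units - 1) * max(gather_s, compute_s) + compute_s


if __name__ == "__main__":
    # Hypothetical per-unit times in ms for a 32-unit model.
    for pf in (False, True):
        print(f"prefetch={pf}: {backward_time(32, 2.0, 3.0, pf):.0f} ms")
```

When compute dominates, prefetching hides all but the first all-gather; when the all-gather dominates, the schedule becomes communication-bound and prefetching only removes the serialization.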
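6.4 Mixed Precision

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # all-gathered params and matmuls in BF16
    reduce_dtype=torch.float32,   # gradient reduce-scatter in FP32
    buffer_dtype=torch.bfloat16,  # module buffers in BF16
)
# Assumes an initialized process group and an existing `model`.
model = FSDP(model, mixed_precision=mp_policy)
```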
- param_dtype=BF16: all-gathered params are BF16; matmuls run in BF16.
- reduce_dtype=FP32: gradients are reduce-scattered in FP32 to avoid precision loss when many small gradients are summed across ranks. This costs 2× bandwidth on grads but stabilizes training.
- Optimizer in FP32: ZeRO-style master weights, m, v all FP32 on the local shard.
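6.5 Activation Checkpointing

```python
import functools

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)

# Recompute each wrapped block's activations during backward.
# `model` and `policy` are the FSDP model and wrap policy from Section 6.1.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=functools.partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    ),
    auto_wrap_policy=policy,
)
```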
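For each wrapped block, only the input is saved; intermediate activations are recomputed during backward. Memory: one B · S · h input per block (plus one block's full set of intermediates while it is being recomputed) instead of every intermediate of every block. Compute: ~+33% (one extra forward, since backward costs about twice the forward).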
NO_REENTRANT is the modern impl; supports arbitrary autograd graphs.
6.6 CPU Offload
Param shards live on CPU; before all-gather, they're brought back to GPU. Use only when out of GPU memory; PCIe is slow.
6.7 The FlatParameter
Internally, FSDP flattens all params in a unit into a single 1D tensor. Reasons:
- One all-gather per unit (not one per param tensor).
- Predictable memory layout.
- Easier sharding math: just split the flat tensor into N equal slices (with padding).
Optimizer sees this flat parameter; optim.step operates on the local slice. State_dict round-tripping un-flattens.
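The pad-and-slice math can be sketched as follows. This only mimics the idea; FSDP's actual FlatParameter tracks much more metadata, and both helper names are hypothetical.

```python
from typing import List, Tuple


def shard_bounds(numels: List[int], world_size: int, rank: int) -> Tuple[int, int, int]:
    """(start, end, padding) of this rank's slice of the flattened, padded unit."""
    total = sum(numels)
    padded = -(-total // world_size) * world_size  # round up to a multiple of world_size
    shard = padded // world_size
    return rank * shard, (rank + 1) * shard, padded - total


def params_in_shard(numels: List[int], world_size: int, rank: int) -> List[Tuple[int, int, int]]:
    """(param_index, local_start, local_end) for each piece of an original param this rank owns."""
    start, end, _ = shard_bounds(numels, world_size, rank)
    pieces, offset = [], 0
    for idx, n in enumerate(numels):
        lo, hi = max(start, offset), min(end, offset + n)
        if lo < hi:
            pieces.append((idx, lo - offset, hi - offset))
        offset += n
    return pieces


if __name__ == "__main__":
    numels = [10, 7, 3]  # hypothetical unit with three small parameter tensors
    for r in range(8):
        print(r, shard_bounds(numels, 8, r), params_in_shard(numels, 8, r))
```

A single shard can cover pieces of several original tensors (rank 3 above) or only padding (rank 7), which is why un-flattening for state_dict needs this bookkeeping.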
6.8 FSDP2 (per-parameter sharding)
The successor (in-tree as of PyTorch 2.4+) abandons FlatParameter in favor of per-parameter sharding via DTensor. Cleaner semantics, better composition with TP/PP, slightly higher dispatch cost. The collective behavior is the same.
7. Tensor Parallelism (Megatron-LM)
When a single layer's weight matrix is too big for one GPU, or when matmul throughput is the limit, split the weight matrix itself.
7.1 The Two Splits
A linear layer computes Y = X · W where X ∈ ℝ^{B×K}, W ∈ ℝ^{K×N}, Y ∈ ℝ^{B×N}.
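Column-parallel. Split W along the output (column) dimension across t ranks:

$$
W = \left[\, W_1 \mid W_2 \mid \cdots \mid W_t \,\right], \qquad W_i \in \mathbb{R}^{K \times N/t}
$$

Each rank i holds W_i and a copy of X. Computes Y_i = X · W_i locally, shape B × N/t. Output is sharded along columns.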
- Forward output: sharded; Y_i lives on rank i. To get the full Y, concatenate: all_gather along the last dim. But if the next op is row-parallel, no gather is needed (see below).
- Backward of input: ∂L/∂X = ∂L/∂Y · W^T = ∑_i ∂L/∂Y_i · W_i^T. Each rank computes its partial sum ∂L/∂Y_i · W_i^T, which already has the full B × K shape. To get the full ∂L/∂X (for the previous layer, which has the full X), all-reduce these partial sums.
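Row-parallel. Split W along the input (row) dimension, and the input X along its matching column dimension:

$$
W = \begin{bmatrix} W_1 \\ W_2 \\ \vdots \\ W_t \end{bmatrix}, \quad W_i \in \mathbb{R}^{K/t \times N}, \qquad X = \left[\, X_1 \mid \cdots \mid X_t \,\right], \quad X_i \in \mathbb{R}^{B \times K/t}
$$

Each rank holds W_i and the corresponding input shard X_i. Computes Y_i = X_i · W_i ∈ ℝ^{B × N} (full output dim, but a partial sum).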
- Forward output: full shape but a partial sum. Need an all-reduce to get the true Y = ∑_i Y_i.
- Backward of input: ∂L/∂X_i = ∂L/∂Y · W_i^T. After the forward all-reduce, Y is replicated, so everything downstream (and therefore ∂L/∂Y) is identical on every rank. The output ∂L/∂X_i is sharded, exactly the input format the previous (column-parallel) layer expects. No comm.
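Both splits can be checked numerically with pure-Python matrices. In this sketch the t ranks are simulated in one process: concatenation stands in for the all-gather and an elementwise sum stands in for the all-reduce. All function names are illustrative.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def split_cols(m: Matrix, t: int) -> List[Matrix]:
    """t column blocks (used for W in column-parallel and for X in row-parallel)."""
    w = len(m[0]) // t
    return [[row[i * w:(i + 1) * w] for row in m] for i in range(t)]


def split_rows(m: Matrix, t: int) -> List[Matrix]:
    """t row blocks (used for W in row-parallel)."""
    r = len(m) // t
    return [m[i * r:(i + 1) * r] for i in range(t)]


def column_parallel(x: Matrix, w: Matrix, t: int) -> Matrix:
    """Rank i computes X @ W_i; concatenating along columns plays the role of all-gather."""
    shards = [matmul(x, w_i) for w_i in split_cols(w, t)]
    return [sum((s[r] for s in shards), []) for r in range(len(x))]


def row_parallel(x: Matrix, w: Matrix, t: int) -> Matrix:
    """Rank i computes X_i @ W_i (a partial sum); summing plays the role of all-reduce."""
    partials = [matmul(x_i, w_i) for x_i, w_i in zip(split_cols(x, t), split_rows(w, t))]
    return [[sum(p[r][c] for p in partials) for c in range(len(w[0]))] for r in range(len(x))]


if __name__ == "__main__":
    X = [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
    W = [[1.0, 0.0, 2.0, 1.0], [0.0, 1.0, 1.0, 2.0], [3.0, 1.0, 0.0, 1.0], [1.0, 2.0, 1.0, 0.0]]
    full = matmul(X, W)
    assert column_parallel(X, W, 2) == full and row_parallel(X, W, 2) == full
    print(full)
```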
7.2 The Megatron Pattern
The genius of Megatron-LM: chain a column-parallel followed by a row-parallel so that the sharded output of the first is exactly the sharded input of the second. The all-gather between them is avoided.
MLP block (h → 4h → h):
```
Y = GeLU(X · W_1) · W_2

W_1: column-parallel (split 4h along output)
     → Each rank computes X · W_1_i, shape B×4h/t. NO all-gather.
GeLU: elementwise, sharded fine.
W_2: row-parallel (split 4h along input)
     → Each rank has its shard of GeLU output, computes (GeLU_i) · W_2_i.
     → All-reduce the result to get full Y.
```
Forward comm: one all-reduce at the end of the MLP. Backward comm: one all-reduce when the gradient reaches the MLP input (∂L/∂X for the X feeding W_1).
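The column-then-row chain can be verified the same way. This single-process sketch uses the tanh approximation of GeLU (an assumption; any elementwise activation works) and shows that the per-rank partial outputs need only one final sum, with no gather in between. Function names are illustrative.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def gelu(v: float) -> float:
    """tanh approximation of GeLU."""
    return 0.5 * v * (1 + math.tanh(math.sqrt(2 / math.pi) * (v + 0.044715 * v ** 3)))


def mlp_reference(x: Matrix, w1: Matrix, w2: Matrix) -> Matrix:
    hidden = [[gelu(v) for v in row] for row in matmul(x, w1)]
    return matmul(hidden, w2)


def mlp_megatron(x: Matrix, w1: Matrix, w2: Matrix, t: int) -> Matrix:
    """W1 column-parallel, W2 row-parallel: the only communication is one final sum."""
    f = len(w1[0]) // t  # FFN width per rank (4h/t)
    partials = []
    for i in range(t):
        w1_i = [row[i * f:(i + 1) * f] for row in w1]  # column block of W1
        w2_i = w2[i * f:(i + 1) * f]                    # matching row block of W2
        h_i = [[gelu(v) for v in row] for row in matmul(x, w1_i)]  # local, no gather
        partials.append(matmul(h_i, w2_i))
    # all-reduce: elementwise sum of the per-rank partial outputs
    return [[sum(p[r][c] for p in partials) for c in range(len(w2[0]))] for r in range(len(x))]


if __name__ == "__main__":
    x = [[0.5, -1.0], [2.0, 0.25]]                                          # B=2, h=2
    w1 = [[((i * 8 + j) % 5) * 0.1 - 0.2 for j in range(8)] for i in range(2)]  # h x 4h
    w2 = [[((i * 2 + j) % 7) * 0.1 - 0.3 for j in range(2)] for i in range(8)]  # 4h x h
    print(mlp_reference(x, w1, w2))
    print(mlp_megatron(x, w1, w2, t=4))
```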
Self-attention block. With a heads, head dim d_h, hidden h = a · d_h:
```
Q = X · W_Q,  K = X · W_K,  V = X · W_V      (h → h)
A = softmax(Q K^T / √d_h) · V                 (per-head)
Y = A · W_O                                   (h → h)

W_Q, W_K, W_V: column-parallel. Each rank gets a/t heads.
     → Q_i, K_i, V_i shape B×S×(h/t).
     → Per-head attention is local; no comm.
W_O: row-parallel.
     → All-reduce at end.
```
Forward comm: one all-reduce at the end of attention. Backward comm: one all-reduce when the gradient reaches the attention input (∂L/∂X for the X feeding W_Q, W_K, W_V).
7.3 Total Comm per Transformer Block
Per block (attention + MLP), TP communication is:
- 2 all-reduces in forward (one at end of attention, one at end of MLP).
- 2 all-reduces in backward (one on ∂L/∂X at the MLP input, one on ∂L/∂X at the attention input).
4 all-reduces per block per step.
Each all-reduce is on activations of size B · S · h BF16 = 2BSh bytes. With ring all-reduce on t ranks: 2(t-1)/t · 2BSh bytes per all-reduce.
For a Transformer with L blocks, total TP comm per step (bytes sent per rank):

$$
C_{\text{TP}} = 4L \cdot \frac{2(t-1)}{t} \cdot 2BSh
$$

For Llama-3 70B (h=8192, L=80), B=4, S=8192, t=8: 4·80 · 1.75 · 2·4·8192·8192 bytes = 4·80·1.75·512 MiB ≈ 280 GiB (≈ 301 GB) per step crossing the TP comm. On 600 GB/s NVLink: ~500 ms per step. This is why TP wants intra-node only: cross-node TP at 50 GB/s (IB) would be ~12× slower.
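The per-step estimate 4L · 2(t-1)/t · 2BSh as a function, assuming BF16 activations, ring all-reduce, and the 4 all-reduces per block counted above; the helper name is hypothetical and latency is ignored.

```python
def tp_comm_bytes_per_step(layers: int, batch: int, seq: int, hidden: int, t: int,
                           bytes_per_elem: int = 2) -> float:
    """Bytes each rank sends for TP all-reduces per step: 4 per block, ring over t ranks."""
    per_allreduce = 2 * (t - 1) / t * bytes_per_elem * batch * seq * hidden
    return 4 * layers * per_allreduce


if __name__ == "__main__":
    total = tp_comm_bytes_per_step(layers=80, batch=4, seq=8192, hidden=8192, t=8)
    print(f"{total / 2**30:.0f} GiB per step, {total / 600e9 * 1e3:.0f} ms at 600 GB/s")
```

For the Llama-3 70B numbers this gives 280 GiB (about 301 GB) per rank per step.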
7.4 Why TP Needs NVLink
In summary: TP's comm is on the critical path of every layer. It is hard to overlap with compute, since the next layer can't start until this all-reduce finishes (short of specialized overlap tricks, e.g. those built around sequence parallelism). So bandwidth matters immensely.
Rule of thumb: TP degree ≤ GPUs per node. On an 8×H100 DGX, TP=8 fits one node. TP=16 across two nodes through IB → ruinous.
7.5 Embedding TP
Vocabulary is large (e.g., 128K tokens × 8K hidden = 1B params just for embedding). TP splits along vocab:
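- Each rank holds V/t rows of the embedding table.
- Embedding lookup: each rank looks up the rows for tokens in its vocab slice and writes zeros for the others; an all-reduce combines them.
- Output projection (the LM head, often tied to the embedding): column-parallel along the vocab (output) dimension, so each rank produces logits only for its V/t slice, and the loss is computed on the sharded logits without gathering them.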
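The softmax over vocab requires care: subtract the per-row max before exp, with the global per-row max computed via all-reduce(max); the per-row sum of exponentials (the normalizer) is then combined via all-reduce(sum).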
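A sketch of the vocab-parallel normalizer for one row of logits, with ranks simulated as list slices. Each rank needs only two scalar reductions (a max and a sum) rather than the full logits. The function name is hypothetical and this is not Megatron-LM's API.

```python
import math
from typing import List


def vocab_parallel_log_softmax(shards: List[List[float]]) -> List[List[float]]:
    """Log-softmax of one row of logits split across ranks along the vocab dim."""
    local_max = [max(s) for s in shards]
    global_max = max(local_max)                                # all_reduce(max)
    local_sum = [sum(math.exp(v - global_max) for v in s) for s in shards]
    log_z = global_max + math.log(sum(local_sum))              # all_reduce(sum)
    return [[v - log_z for v in s] for s in shards]            # each rank keeps its slice


if __name__ == "__main__":
    logits = [1.0, 3.0, -2.0, 0.5, 4.0, 2.0]
    shards = [logits[0:2], logits[2:4], logits[4:6]]  # t = 3 ranks, V/t = 2
    print(vocab_parallel_log_softmax(shards))
```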
7.6 Micro-example
A row-parallel W: K=8192 → N=8192, t=8. Each rank holds W_i: 1024 × 8192. Forward:
- Input shard X_i: B × 1024 (output of previous column-parallel).
- Local matmul: Y_i = X_i · W_i, shape B × 8192.
- All-reduce: Y = ∑_i Y_i, shape B × 8192.
Each rank's matmul uses 2 · B · 1024 · 8192 FLOPs (BF16). Each rank's all-reduce sends 2(t-1)/t · 2 · B · 8192 bytes. Compare compute time vs comm time to assess overlap potential.
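To make the comparison concrete, this sketch plugs in a hypothetical round-number peak of 1e15 BF16 FLOP/s and the 600 GB/s link used earlier; neither is a measured figure, and latency is ignored. Because both terms scale linearly with B, their ratio depends only on the layer shape and the hardware.

```python
from typing import Tuple


def matmul_vs_allreduce(tokens: int, k_shard: int, n_out: int, t: int,
                        peak_flops: float, link_bytes_per_s: float) -> Tuple[float, float]:
    """(compute_seconds, comm_seconds) for one row-parallel layer on one rank, BF16.

    compute: local X_i @ W_i with X_i of shape tokens x k_shard, W_i of shape k_shard x n_out.
    comm: ring all-reduce of the tokens x n_out BF16 output.
    """
    flops = 2 * tokens * k_shard * n_out
    comm_bytes = 2 * (t - 1) / t * 2 * tokens * n_out
    return flops / peak_flops, comm_bytes / link_bytes_per_s


if __name__ == "__main__":
    # Assumed round numbers: 1e15 FLOP/s peak, 600 GB/s link, B = 8192 tokens.
    comp, comm = matmul_vs_allreduce(8192, 1024, 8192, 8, 1e15, 600e9)
    print(f"compute {comp * 1e6:.0f} us, all-reduce {comm * 1e6:.0f} us")
```

With these assumed numbers the all-reduce takes almost 3× as long as the matmul, so it cannot hide behind this layer's compute.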
8. Pipeline Parallelism
Splits the model along the depth axis. Layers 1..L/4 on stage 0, L/4+1..L/2 on stage 1, etc.
8.1 Naive Pipeline
Stage 0: forward layers 1..L/P → send activations to stage 1 → wait → receive grads from stage 1 → backward → done. Each stage waits for the others; only one stage active at a time. Utilization: 1/P. Useless without microbatching.
8.2 GPipe (Huang et al., 2018)
Split each minibatch into M microbatches. Pipeline them: stage 0 starts microbatch 1, then microbatch 2, then 3, etc. After all forwards complete, run backwards in reverse.
ASCII schedule for P=4, M=4 (t_F = t_B, one slot each):

```
time →
S0: F1 F2 F3 F4 .  .  .  .  .  .  B4 B3 B2 B1
S1: .  F1 F2 F3 F4 .  .  .  .  B4 B3 B2 B1 .
S2: .  .  F1 F2 F3 F4 .  .  B4 B3 B2 B1 .  .
S3: .  .  .  F1 F2 F3 F4 B4 B3 B2 B1 .  .  .
```

The dots are bubbles: idle time. With P stages and M microbatches:
- Total useful time per stage: M · t_F + M · t_B = M · (t_F + t_B).
- Total wall time: (M + P - 1) · (t_F + t_B).
- Bubble fraction: (P-1) / (M + P - 1).
Bubble = (P-1)/(M+P-1). For P=4, M=4: 3/7 ≈ 43%. For M=16: 3/19 ≈ 16%. For M=64: 3/67 ≈ 4.5%. Make M large.
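The bubble formula as a one-liner; the optional v covers the interleaved schedule of Section 8.4, where the bubble becomes (P-1)/(v·M + P - 1). bubble_fraction is a hypothetical helper.

```python
def bubble_fraction(p: int, m: int, v: int = 1) -> float:
    """Idle fraction per stage: (P-1)/(v*M + P - 1); v=1 for GPipe/1F1B, v>1 for interleaved."""
    return (p - 1) / (v * m + p - 1)


if __name__ == "__main__":
    for m in (4, 16, 64):
        print(f"P=4, M={m}: {bubble_fraction(4, m):.1%}")
    print(f"P=4, M=8, v=4: {bubble_fraction(4, 8, v=4):.1%}")
```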
GPipe memory issue: stage 0 must keep activations from M microbatches alive (it does forward 1..M before any backward). Activation memory: O(M · per_microbatch_act). Use activation checkpointing to bring it down.
8.3 1F1B (PipeDream)
Interleave forwards and backwards: as soon as a microbatch reaches the last stage, start its backward immediately. Backward propagates back; meanwhile, new forwards continue feeding in. Each stage maintains the invariant that it does one forward then one backward (1F1B).
ASCII for P=4, M=8 (t_F = t_B, one slot each, communication time ignored):

```
S0: F1 F2 F3 F4 .  .  .  B1 F5 B2 F6 B3 F7 B4 F8 B5 .  B6 .  B7 .  B8
S1: .  F1 F2 F3 .  .  B1 F4 B2 F5 B3 F6 B4 F7 B5 F8 B6 .  B7 .  B8 .
S2: .  .  F1 F2 .  B1 F3 B2 F4 B3 F5 B4 F6 B5 F7 B6 F8 B7 .  B8 .  .
S3: .  .  .  F1 B1 F2 B2 F3 B3 F4 B4 F5 B5 F6 B6 F7 B7 F8 B8 .  .  .
```
Bubble fraction is the same: (P-1)/(M+P-1). But peak activation memory per stage drops dramatically. Stage 0 only holds activations for in-flight microbatches at any moment ≈ P (warm-up forwards before first backward). Memory: O(P) not O(M).
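The 1F1B schedule can be reproduced with a small list-scheduling simulation. This sketch assumes t_F = t_B = one slot and zero communication time, and uses the common warm-up rule (stage s runs P - s - 1 warm-up forwards, then alternates one forward and one backward, then drains backwards); the function names are hypothetical.

```python
from typing import Dict, List, Tuple

Op = Tuple[str, int]  # ("F" or "B", microbatch index)


def one_f_one_b_order(p: int, m: int, stage: int) -> List[Op]:
    """Op order on one stage: warm-up forwards, then alternate F/B, then drain B."""
    warmup = min(p - stage - 1, m)
    ops: List[Op] = [("F", j) for j in range(1, warmup + 1)]
    next_f, next_b = warmup + 1, 1
    while next_b <= m:
        if next_f <= m:
            ops.append(("F", next_f))
            next_f += 1
        ops.append(("B", next_b))
        next_b += 1
    return ops


def peak_in_flight(order: List[Op]) -> int:
    """Max microbatches whose activations are held (forward done, backward not yet)."""
    live = peak = 0
    for kind, _ in order:
        live += 1 if kind == "F" else -1
        peak = max(peak, live)
    return peak


def simulate(p: int, m: int) -> Dict[Tuple[int, str, int], int]:
    """Start slot of each (stage, kind, microbatch), with t_F = t_B = 1 and free communication."""
    orders = [one_f_one_b_order(p, m, s) for s in range(p)]
    start: Dict[Tuple[int, str, int], int] = {}
    pos, free = [0] * p, [0] * p
    while any(pos[s] < len(orders[s]) for s in range(p)):
        progressed = False
        for s in range(p):
            if pos[s] == len(orders[s]):
                continue
            kind, j = orders[s][pos[s]]
            # F on stage s needs F on stage s-1; B on stage s needs B on stage s+1.
            if (kind == "F" and s == 0) or (kind == "B" and s == p - 1):
                ready = 0
            else:
                dep = (s - 1, "F", j) if kind == "F" else (s + 1, "B", j)
                if dep not in start:
                    continue
                ready = start[dep] + 1
            start[(s, kind, j)] = max(ready, free[s])
            free[s] = start[(s, kind, j)] + 1
            pos[s] += 1
            progressed = True
        if not progressed:
            raise RuntimeError("schedule deadlocked")
    return start


if __name__ == "__main__":
    P, M = 4, 8
    slots = simulate(P, M)
    print("makespan:", max(slots.values()) + 1, "slots")  # 2(M + P - 1) = 22
    print("peak in-flight per stage:", [peak_in_flight(one_f_one_b_order(P, M, s)) for s in range(P)])
```

The makespan matches GPipe's (same bubble), while stage 0 never holds more than P microbatches of activations.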
This is why every modern pipeline uses 1F1B-style scheduling.
8.4 Interleaved 1F1B (Megatron-LM)
Idea: each stage owns multiple non-contiguous chunks of layers. With v "virtual stages" per physical stage:
- Layers 1..L/(P·v) on stage 0's chunk 1.
- Layers L/(P·v)+1..2L/(P·v) on stage 1's chunk 1.
- ...
- Layers L/v + 1..L/v + L/(P·v) on stage 0's chunk 2.
- ...
Each microbatch traverses P · v chunks (with comm between every chunk).
The bubble shrinks: (P-1)/(v · M + P - 1). With v=4, M=8, P=4: 3/(32+3) = 8.6% vs 3/11 = 27% for vanilla 1F1B. Cost: more comm (v× more sends/recvs), more scheduling complexity.
8.5 Zero Bubble Pipeline (Qi et al., 2023)
Split the backward into two parts:
- B-input (B_X): compute ∂L/∂X (gradient with respect to the layer input), which the previous stage needs.
- B-weight (B_W): compute ∂L/∂W (gradient with respect to the weights), which is only needed locally for the optimizer.
B_W doesn't block the previous stage. Schedule B_X as soon as possible; defer B_W to fill bubbles. With careful scheduling, achievable bubble approaches zero.
Tradeoff: scheduling complexity increases; some implementations require equal t_F = t_{B_X} and divisible structure.
8.6 Memory and Throughput Comparison
Per-stage activation memory (with M microbatches, P stages):
| Schedule | Bubble | Activation memory per stage |
|---|---|---|
| Naive | (P-1)/P | 1 microbatch |
| GPipe | (P-1)/(M+P-1) | M microbatches |
| 1F1B | (P-1)/(M+P-1) | ~P microbatches |
| Interleaved 1F1B (v) | (P-1)/(vM+P-1) | ~P microbatches (slightly more than 1F1B) |
| Zero Bubble | ~0 | similar to 1F1B |
8.7 PP Communication
Between adjacent stages: send/recv of activations (forward) and gradients (backward). Per microbatch, per stage boundary: 2 · B_micro · S · h bytes in each direction (BF16), so 4 · B_micro · S · h bytes for forward plus backward. Unlike TP, this is point-to-point, not a collective, and it scales fine across IB. Pipeline parallelism is the right axis for crossing nodes.
8.8 PP Caveats
- LayerNorm/RMSNorm statistics are local, which is fine.
- Loss computation happens on the last stage; it must send the loss scalar back if you want it everywhere.
- Optimizer state is local to each stage; no comm.
- Pipeline imbalance: if stages have unequal compute (e.g., stage 0 also holds the cheap embedding, while the last stage also holds the LM head + loss), the slowest stage sets the pace. Solution: rebalance layer counts, giving fewer Transformer layers to the stages that carry extra work.
9. 3D Parallelism: DP × TP × PP¶
9.1 The Decision Matrix¶
Given a model with parameters Φ, layers L, hidden h, and a cluster of G GPUs in nodes of g_n GPUs each, choose (t, p, d) such that t · p · d = G:
- t = TP degree.
- p = PP degree.
- d = DP (or FSDP) degree.
Constraints:
- t ≤ g_n: TP must stay intra-node (NVLink/NVSwitch).
- p divides L: pipeline stages need integer layer counts.
- Per-GPU memory must fit: parameters + gradients + optimizer state + activations.
- Global batch size: d × micro_batch × M = global_batch. The global batch determines convergence, so it can't grow without bound.
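The non-memory constraints above can be enumerated directly. This is a minimal sketch that adds one assumption of its own: t divides the node size, so TP groups tile nodes exactly. Memory fit is left to the heuristic in 9.2. The cluster and batch values are hypothetical, and the 80 layers match the 70B model in Exercise 2.

```python
def valid_configs(num_gpus, gpus_per_node, num_layers, global_batch, micro_batch):
    """Return (t, p, d, M) tuples satisfying the constraints above (memory not checked)."""
    configs = []
    for t in range(1, gpus_per_node + 1):
        if gpus_per_node % t:
            continue                      # t <= g_n, and TP groups tile a node
        for p in range(1, num_layers + 1):
            if num_layers % p or num_gpus % (t * p):
                continue                  # integer layers per stage; t * p * d = G
            d = num_gpus // (t * p)
            if global_batch % (d * micro_batch):
                continue                  # d * micro_batch * M must equal the global batch
            configs.append((t, p, d, global_batch // (d * micro_batch)))
    return configs


# 64 GPUs, 8 per node, 80 layers, global batch 512 sequences, microbatch size 1.
cfgs = valid_configs(64, 8, 80, global_batch=512, micro_batch=1)
for cfg in cfgs[:5]:
    print(cfg)
```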
9.2 Decision Heuristic¶
- Compute parameters per GPU: Φ / (t · p). PP splits the layers across stages, and TP shards each stage's layers.
- Compute static memory per GPU: 16 · Φ / (t · p) with a DP-replicated optimizer, or 16 · Φ / (t · p · d) with FSDP, where state is also sharded across the DP dimension.
- Add activation memory per GPU. With activation checkpointing and sequence parallelism, this is much smaller.
- Pick the smallest t that fits, then the smallest p; d fills the rest.
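A rough encoding of this heuristic follows. Two simplifying assumptions are labeled here and in the code. First, activation memory is a single per-GPU estimate at t=1, divided by t because sequence parallelism shards it. Second, PP's effect on activations is ignored. Sizes are in GB, and parameter counts are in billions.

```python
def static_gb_per_gpu(params_b, t, p, d, fsdp):
    """16 bytes/param of static state in GB; params_b is the parameter count in billions."""
    shards = t * p * (d if fsdp else 1)
    return 16 * params_b / shards


def pick_config(params_b, candidates, activation_gb_t1, gpu_gb=80, fsdp=True):
    """First (t, p, d) in preference order (smallest t, then smallest p) that fits.

    Assumption: activation memory is activation_gb_t1 / t (SP shards it ~t ways);
    the effect of PP on activation memory is ignored in this sketch.
    """
    for t, p, d in sorted(candidates):
        total = static_gb_per_gpu(params_b, t, p, d, fsdp) + activation_gb_t1 / t
        if total <= gpu_gb:
            return t, p, d
    return None


cands = [(8, 2, 4), (8, 1, 8), (1, 1, 64)]
print(pick_config(8, cands, activation_gb_t1=10))   # Example A-like: (1, 1, 64)
print(pick_config(70, cands, activation_gb_t1=70))  # Example B-like: (8, 1, 8)
```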
Why prefer small t and p?
- TP comm is on the critical path of every layer.
- PP creates bubbles unless you can make M huge (which inflates batch).
- DP comm overlaps cleanly with backward.
So the priority is: maximize DP, use TP only as needed for memory/compute, and use PP only when TP within a node plus DP/FSDP sharding still can't fit the model.
9.3 Worked Examples¶
Example A: 8B model on 64 H100 (8 nodes × 8 GPUs)¶
Φ = 8B. Static memory at 16Φ = 128 GB.
Try t=1, p=1, d=64 with FSDP-3:
- Per-GPU static: 16Φ/64 = 2 GB. Easy fit.
- Activations: with B_micro=4, S=8192, BF16, ~10 GB. Plenty of room.
- TP comm: none. PP comm: none. Just FSDP all-gather/reduce-scatter.
- Pick this. Simple is fast.
Example B: 70B model on 64 H100¶
Φ = 70B. Static at 16Φ = 1.12 TB.
Try t=1, p=1, d=64 FSDP-3: static state per GPU = 17.5 GB. That fits, but activations + workspace can push the total close to 80 GB (see Exercise 2).
Try t=8, p=1, d=8 (TP within node, DP across nodes):
- Per-GPU params (sharded by TP, replicated across DP) = 2Φ/t = 2·70/8 = 17.5 GB in BF16.
- With FSDP across DP: divide by 8 more → 2.2 GB.
- TP comm hot inside each node, DP comm reasonable.
- Likely faster than pure FSDP-3 because activations are sharded too (with sequence parallelism) and matmuls are bigger per rank.
Try t=8, p=2, d=4:
- Stage holds 35B. Per-GPU params = 2·35/8 = 8.75 GB.
- Activations on stage 0 have to live for ~p=2 microbatches → smaller pressure.
- Bubble (P-1)/(M+P-1) = 1/(M+1); with M=8, ~11% bubble. Not great.
- Probably not worth it for 70B; pick the previous.
Example C: 405B model on 1024 H100 (128 nodes × 8 GPUs)¶
Φ = 405B. Static at 16Φ = 6.48 TB. Need ~6.5 TB / 80 GB = ~81 GPUs minimum just for state, but activation+workspace really demands more.
Try t=8, p=8, d=16:
- Stage holds Φ/p = 50.6B params.
- Per-GPU params = 2·50.6 / 8 = 12.6 GB.
- With FSDP across the d=16 dim: 12.6 / 16 = 0.79 GB. Good.
- Activation memory per stage with checkpointing + SP: typically 5-15 GB per GPU.
- TP comm intra-node. PP comm node-to-node (reasonable, ~few GB per microbatch). DP/FSDP comm across the d=16 dim (across more nodes).
- 8 stages → bubble (P-1)/(M+P-1). With M=64, bubble = 7/71 ≈ 10%. Acceptable.
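This has the same shape as the published Llama 3 405B setup (TP = 8 within a node, PP across nodes, DP outermost), although that run used a larger cluster and different PP and DP degrees.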
9.4 The Order of Ranks¶
A subtle but critical point: rank assignment order matters. Convention (Megatron): rank = pp_rank · (d · t) + dp_rank · t + tp_rank.
So the innermost dimension is TP. This places TP groups within consecutive ranks, which on a node maps to neighbors connected by NVLink. DP comes next, then PP outermost. Get this wrong and you get TP all-reduces over IB → catastrophe.
torch.distributed.new_group lets you create the three groups (TP, DP, PP) with explicit rank lists.
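The rank lists for this ordering can be computed without any GPUs. The sketch below is pure Python and uses the formula rank = pp·(d·t) + dp·t + tp. The world size and degrees are illustrative. The trailing comment notes how the lists would be handed to torch.distributed.new_group.

```python
def parallel_groups(world_size, t, p):
    """Rank lists for TP, DP and PP groups with rank = pp*(d*t) + dp*t + tp."""
    d = world_size // (t * p)
    assert t * p * d == world_size, "t * p * d must equal the world size"

    def rank(pp, dp, tp):
        return pp * d * t + dp * t + tp

    tp_groups = [[rank(pp, dp, tp) for tp in range(t)] for pp in range(p) for dp in range(d)]
    dp_groups = [[rank(pp, dp, tp) for dp in range(d)] for pp in range(p) for tp in range(t)]
    pp_groups = [[rank(pp, dp, tp) for pp in range(p)] for dp in range(d) for tp in range(t)]
    return tp_groups, dp_groups, pp_groups


tp_groups, dp_groups, pp_groups = parallel_groups(world_size=16, t=2, p=2)
print(tp_groups[0], dp_groups[0], pp_groups[0])  # [0, 1] [0, 2, 4, 6] [0, 8]

# In a real job every rank would call torch.distributed.new_group(ranks=g) for
# every g in each list: all ranks must create all groups, in the same order.
```

Because TP varies fastest, every TP group is a run of consecutive ranks. On 8-GPU nodes that keeps each group inside one node.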
10. Sequence Parallelism (Korthikanti et al., 2022)¶
10.1 The Problem¶
Even with TP, certain ops are not parallelized: LayerNorm, dropout, residual adds. These keep the full activation tensor of shape B × S × h on every TP rank. As S grows (long context training), this dominates memory.
10.2 The Idea¶
Split these ops along the sequence dimension. Each TP rank holds B × S/t × h for LayerNorm/Dropout. The activation memory drops by t.
But TP linears need the full sequence (matmul is in (B·S) × h). So at the boundary: all-gather to convert from sequence-parallel (B × S/t × h) to TP-parallel (B × S × h) on each rank, do the column-parallel linear, etc.
The all-reduce that used to follow the row-parallel linear is split into its two halves: a reduce-scatter at the end of the TP region and an all-gather at the start of the next TP region. Comm volume: same total (all-reduce = reduce-scatter + all-gather; SP just moves the boundary). But now the LayerNorm/dropout/residual sit in the cheap SP region.
Net: same comm, much less activation memory. That's the deal.
10.3 Where Each Lives¶
[SP region] LayerNorm | activations B × S/t × h (sharded along S)
↓ all-gather to TP region
[TP region] Linear (column-parallel) → Linear (row-parallel)
↓ reduce-scatter back to SP region
[SP region] Residual + LayerNorm
reduce_scatter at the TP→SP boundary does double duty: aggregates the row-parallel partial sum and shards along sequence in one op.
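The "double duty" claim can be checked with a toy simulation in plain Python. Lists stand in for per-rank tensors, and no NCCL is involved. A reduce-scatter hands each rank exactly its sequence slice of the summed partials. All-gathering those slices then recovers what an all-reduce would have produced.

```python
def all_reduce(partials):
    """Every rank receives the elementwise sum of all ranks' tensors."""
    total = [sum(vals) for vals in zip(*partials)]
    return [list(total) for _ in partials]


def reduce_scatter(partials):
    """Rank r receives only its 1/t slice (along the sequence) of the elementwise sum."""
    t = len(partials)
    total = [sum(vals) for vals in zip(*partials)]
    chunk = len(total) // t
    return [total[r * chunk:(r + 1) * chunk] for r in range(t)]


def all_gather(shards):
    """Every rank receives the concatenation of all ranks' shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


# t = 2 TP ranks, S = 4 tokens (hidden dim collapsed to one number per token).
# Each rank holds the partial sum produced by its row-parallel linear.
partials = [[1, 2, 3, 4], [10, 20, 30, 40]]
sp_shards = reduce_scatter(partials)   # rank 0 owns tokens 0-1, rank 1 owns tokens 2-3
print(sp_shards)                       # [[11, 22], [33, 44]]
print(all_gather(sp_shards) == all_reduce(partials))  # True
```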
10.4 Synergy with Long Context¶
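The 5 a S² / h attention term grows quadratically with S, and without SP the LayerNorm/dropout/residual activations are also replicated on every TP rank. SP shards those sequence-linear activations, and activation checkpointing avoids keeping the attention scores. Together they make S=128K feasible on commodity-cluster sizes.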
Modern LLM training stacks (Megatron-LM, NeMo, Pax, MaxText) commonly run with SP, or an equivalent sequence sharding, whenever TP > 1.
11. Mixed Precision and FP8¶
11.1 Why Lower Precision¶
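Peak dense tensor-core throughput on H100 SXM:
- TF32: ~495 TFLOPS
- BF16/FP16: ~989 TFLOPS
- FP8: ~1979 TFLOPS
NVIDIA's headline figures (989 / 1979 / 3958 TFLOPS) assume 2:4 structured sparsity. Each halving of precision doubles peak throughput and halves the memory and bandwidth needed per value. The training imperative: keep math in low precision, and escalate to FP32 only where stability requires it.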
11.2 FP16 with Loss Scaling¶
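FP16 dynamic range is small. Normal values span ~6.1e-5 to 6.55e4, and subnormals extend down to ~6e-8 with reduced precision. Many gradients are smaller than this: they lose precision below ~6e-5 and underflow to zero below ~6e-8 → no learning.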
Trick: loss scaling. Multiply the loss by s before backward; gradients are s × larger and fit in FP16 range. Before the optimizer step, divide gradients by s.
Dynamic loss scaling (the production version):
s = initial_scale (e.g., 2^16)
on each step:
grads = backward(s · loss)
if any(isnan or isinf in grads):
s = s / 2 ← overflow: halve the scale, skip step
else:
grads = grads / s
optimizer.step()
if no overflow for `growth_interval` steps (e.g., 2000):
s = s · 2 ← grow when stable
This keeps s near the maximum that doesn't overflow. PyTorch's GradScaler implements exactly this.
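The pseudocode can be written as a runnable sketch that operates on a flat list of scaled gradients. The defaults mirror the pseudocode and GradScaler's init_scale=2**16 and growth_interval=2000. In a PyTorch job you would use torch.amp.GradScaler itself; this class only illustrates the rule.

```python
import math


class DynamicLossScaler:
    """Sketch of the dynamic loss-scaling rule above, on a flat list of gradients."""

    def __init__(self, init_scale=2.0 ** 16, growth_interval=2000,
                 growth_factor=2.0, backoff_factor=0.5):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self._good_steps = 0  # consecutive steps without overflow

    def unscale_or_skip(self, scaled_grads):
        """Return unscaled gradients, or None if this step must be skipped."""
        if any(math.isinf(g) or math.isnan(g) for g in scaled_grads):
            self.scale *= self.backoff_factor      # overflow: shrink the scale, skip step
            self._good_steps = 0
            return None
        grads = [g / self.scale for g in scaled_grads]  # unscale with the scale used in backward
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            self.scale *= self.growth_factor       # stable long enough: grow
            self._good_steps = 0
        return grads


scaler = DynamicLossScaler(init_scale=8.0, growth_interval=2)
print(scaler.unscale_or_skip([float("inf")]), scaler.scale)  # None 4.0
print(scaler.unscale_or_skip([8.0]), scaler.scale)           # [2.0] 4.0
print(scaler.unscale_or_skip([8.0]), scaler.scale)           # [2.0] 8.0
```

Unscaling happens before any growth, so the gradients are always divided by the same s that multiplied the loss.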
11.3 BF16¶
BF16 has FP32's dynamic range (8 exponent bits) but only 7 mantissa bits. For gradient magnitude purposes, equivalent to FP32. No loss scaling needed.
The penalty is reduced precision in each stored value (8 significant bits vs FP32's 24). However, tensor cores accumulate BF16 matmul products in FP32, so long dot products don't compound the rounding error. Net effect: FP32's dynamic range at half the storage.
This is why most modern training pipelines (Llama, Mistral, etc.) use BF16 rather than FP16.
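The range-versus-precision trade can be seen from Python alone. Binary16 round-trips through the struct module's 'e' format. BF16 is emulated by rounding the FP32 bit pattern to its top 16 bits, which is an illustration and not how GPUs convert. NaN and inf edge cases are ignored.

```python
import struct


def to_fp16(x):
    """Round-trip a Python float through IEEE binary16 (struct 'e' format)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x):
    """Round-trip through bfloat16 by rounding the FP32 pattern to its top 16 bits
    (round-to-nearest-even; NaN/inf handling omitted in this sketch)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)      # round to nearest, ties to even
    bits = ((bits >> 16) << 16) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


tiny = 1e-8                      # a small gradient value
print(to_fp16(tiny))             # 0.0: below FP16's smallest subnormal
print(to_bf16(tiny))             # nonzero, within ~0.4% of 1e-8 (FP32 exponent range)

x = 1.0 + 2.0 ** -9
print(to_fp16(x) == x)           # True: FP16 keeps 10 mantissa bits
print(to_bf16(x) == x)           # False: BF16's 7 mantissa bits round it to 1.0

try:
    to_fp16(1e5)                 # above FP16's max of 65504
except OverflowError as err:
    print("FP16 overflow:", err)
```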
11.4 FP8¶
Two FP8 formats:
- E4M3 (4 exponent bits, 3 mantissa bits): range ~±448. Higher precision, lower range. Use for forward activations and weights.
- E5M2 (5 exponent bits, 2 mantissa bits): range ~±57344. Lower precision, higher range. Use for backward gradients (which need range).
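The quoted ranges follow from the bit layouts. The sketch below derives each format's largest finite value and smallest normal value. It assumes the E4M3 variant without infinities (often called E4M3FN), where only the all-ones pattern is NaN, and an IEEE-style E5M2.

```python
def fp8_limits(exp_bits, man_bits, has_inf):
    """Largest finite value and smallest normal value of an FP8 format."""
    bias = 2 ** (exp_bits - 1) - 1
    if has_inf:
        # E5M2, IEEE-style: the all-ones exponent is reserved for inf/NaN.
        max_exp = (2 ** exp_bits - 2) - bias
        max_frac = 2 - 2.0 ** -man_bits
    else:
        # E4M3 without infinities: the top exponent is usable, but its
        # all-ones mantissa encodes NaN, so the largest mantissa is one step lower.
        max_exp = (2 ** exp_bits - 1) - bias
        max_frac = 2 - 2 * 2.0 ** -man_bits
    return max_frac * 2.0 ** max_exp, 2.0 ** (1 - bias)


print(fp8_limits(4, 3, has_inf=False))  # (448.0, 0.015625)
print(fp8_limits(5, 2, has_inf=True))   # (57344.0, 6.103515625e-05)
```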
Standard recipe (NVIDIA TransformerEngine):
forward: weights and activations cast to E4M3 (per-tensor scaled).
matmul: FP8 inputs, higher-precision (FP32) accumulation, output in BF16.
backward: incoming gradient cast to E5M2.
optimizer: FP32 master weights, FP32 m, v.
11.5 Per-Tensor Scaling¶
Each tensor T gets a scale s_T. The FP8 representation is T_fp8 = round(T / s_T). To dequantize: T_bf16 = T_fp8 · s_T. The scale is updated each step using the amax (maximum absolute value) of the previous N steps.
s = amax / fp8_max_representable, with a safety margin.
TransformerEngine maintains an amax history per tensor (forward weights, forward activations, backward grads, etc.). The amax used for the scale is the maximum over the history window, or simply the most recent entry.
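A minimal sketch of delayed per-tensor scaling, using this section's convention T_fp8 = T / s with s = amax · 2^margin / fp8_max. TransformerEngine multiplies by the reciprocal factor instead, which is equivalent. This class and its parameter names are illustrative, not TransformerEngine's API. Values are only scaled and clamped here; rounding onto the FP8 grid is omitted.

```python
from collections import deque


class DelayedScaling:
    """Per-tensor delayed scaling: s = max(amax history) * 2**margin / fp8_max."""

    def __init__(self, fp8_max=448.0, history_len=16, margin=0):
        self.fp8_max = fp8_max
        self.margin = margin
        self.amax_history = deque(maxlen=history_len)
        self.scale = 1.0  # no history yet

    def quantize(self, values):
        """Scale with the factor from *previous* steps, then refresh it for the next step."""
        s = self.scale
        q = [max(-self.fp8_max, min(self.fp8_max, v / s)) for v in values]  # saturate
        self.amax_history.append(max(abs(v) for v in values))
        amax = max(self.amax_history)
        if amax > 0:
            self.scale = amax * 2.0 ** self.margin / self.fp8_max
        return q, s

    @staticmethod
    def dequantize(q, s):
        return [x * s for x in q]


scaler = DelayedScaling(history_len=4)
q1, s1 = scaler.quantize([1000.0, -10.0])  # step 1: s = 1.0, so 1000 saturates at 448
q2, s2 = scaler.quantize([1000.0, -10.0])  # step 2: s = 1000/448 from step 1's amax
print(q1[0], DelayedScaling.dequantize(q2, s2)[0])
```

Step 1 shows the cost of "delayed": the first time a large value appears, the stale scale saturates it. From step 2 on, the history has caught up.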
11.6 Engineering Notes¶
- FP8 wins compute-bound layers (large matmuls). Layers that are bandwidth-bound or have small compute don't benefit.
- Loss curves with FP8 should track BF16 to within noise. If they diverge, scaling is wrong.
- Delayed scaling (using amax from previous steps) is faster than current scaling (computing amax during this step's pass) but slightly less stable.
12. Verifying Communication Overlap¶
A correctly implemented DDP/FSDP overlaps gradient comm with backward compute. Verifying this is a profiling exercise.
12.1 What "Overlap" Looks Like in a Trace¶
In torch.profiler or NSight, you'll see two streams active simultaneously:
- Compute stream: matmuls, attention.
- NCCL stream: all-reduce/all-gather/reduce-scatter kernels.
A healthy backward looks like:
Each all-reduce kernel runs on the NCCL stream while the next matmul runs on compute. Total wall time ≈ max(compute, comm), not sum.
12.2 Failure Modes¶
No overlap. Compute and NCCL are serialized. Causes:
- Bucket too small → all-reduce of a tiny bucket is latency-dominated, no compute to hide it.
- find_unused_parameters=True forcing extra sync.
- The all-gather in FSDP-3 must complete before forward proceeds. This is intentional; there's no way to hide it without prefetch.
Negative overlap. Comm kernel slows compute kernel because they share SMs (NCCL kernels use SMs to copy data). On H100, NCCL has dedicated copy engines for some collectives → less interference.
12.3 FSDP Prefetch in the Trace¶
With BACKWARD_PRE, you should see:
backward(L80): |compute|
all_gather(L79): |comm | ← issued before backward(L79)
backward(L79): |compute|
all_gather(L78): |comm |
The all-gather of unit l-1 runs while unit l's backward runs. Without BACKWARD_PRE, the all-gather of l-1 only starts after unit l's backward finishes → critical path doubled.
12.4 Measuring with Numbers¶
Run two configs:
- Real training: measure step time T_step.
- Training with comm replaced by a no-op (or --bench-no-comm): measure step time T_compute.
Overlap quality: T_step / T_compute. Ideal: 1.0. Practical: 1.05–1.15 for a healthy DDP/FSDP setup.
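A small helper for turning the two measurements into numbers. It assumes one extra, hypothetical measurement, t_comm: the communication time of the same collectives run in isolation. With t_comm you can also estimate what fraction of the communication was hidden. The timings in the example are invented.

```python
def overlap_report(t_step, t_compute, t_comm=None):
    """ratio = T_step / T_compute; exposed = time by which comm extends the step.

    t_comm (optional) is the comm time of the same collectives measured alone;
    if given, also return the fraction of comm hidden behind compute.
    """
    ratio = t_step / t_compute
    exposed = t_step - t_compute
    hidden = None if t_comm is None else 1.0 - exposed / t_comm
    return ratio, exposed, hidden


ratio, exposed, hidden = overlap_report(t_step=1.10, t_compute=1.00, t_comm=0.40)
print(f"{ratio:.2f}x, {exposed * 1000:.0f} ms exposed, {hidden:.0%} of comm hidden")
```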
13. Fault Tolerance¶
A 1024-GPU job running for a week has many opportunities to fail: hardware errors, NCCL hangs, an uncorrectable ECC error in HBM, an OOM from an unusually long sequence. Any one of them kills the job.
13.1 NCCL Timeouts¶
PyTorch's NCCL process group runs a watchdog. If a collective hasn't completed within the timeout passed to dist.init_process_group(timeout=...), the rank aborts. Recent PyTorch releases default this timeout to 10 minutes for NCCL.
Causes:
- One rank deadlocked (e.g., infinite loop in dataloader).
- One rank slower (bad GPU; watch for GPU temperature throttling).
- Mismatched collectives: ranks issue different collectives, in a different order, or with different shapes (e.g., rank 0 calls all_reduce(x) while rank 1 calls all_gather(y)).
Detection:
- Set TORCH_NCCL_ASYNC_ERROR_HANDLING=1: on a timeout or NCCL error, the watchdog aborts the communicator and tears down the process instead of leaving it hung.
- TORCH_NCCL_DESYNC_DEBUG=1: dumps per-rank state when something hangs.
- Monitor with dcgmi for GPU faults.
13.2 Checkpointing¶
Full checkpoint. Gather the full model and optimizer state to rank 0; save one big file. Simple, but: huge memory spike (rank 0 needs RAM for the full model), slow at scale.
Sharded checkpoint (FSDP, DCP). Each rank saves its shard; metadata describes how to reassemble. Modern PyTorch: torch.distributed.checkpoint.save and load. Allows resharding (load a 64-rank checkpoint onto 128 ranks; reshape happens automatically).
Cadence. Every N steps; N chosen so that wasted-compute-on-failure is acceptable. With ~5% failure rate per day, every 30-60 min is reasonable. Async checkpointing (writing in background while training continues) helps.
Atomic writes. Write to ckpt.tmp then rename. If a node crashes mid-write, no partial corruption.
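The write-then-rename pattern for a single file is sketched below. It writes to a temporary file in the same directory, fsyncs it, then swaps it in with os.replace, which is atomic on POSIX when source and destination are on the same filesystem. Network and parallel filesystems may give weaker guarantees, so treat this as a local-disk sketch. For sharded checkpoints, the same idea applies at the directory level.

```python
import os
import tempfile


def atomic_write_bytes(path, data):
    """Write `data` so readers see either the old file or the complete new one."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")  # same filesystem as target
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())        # bytes reach the disk before the rename
        os.replace(tmp_path, path)      # atomic swap; never a half-written `path`
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)         # a crash mid-write leaves only a stray .tmp
        raise


with tempfile.TemporaryDirectory() as workdir:
    ckpt = os.path.join(workdir, "ckpt.bin")
    atomic_write_bytes(ckpt, b"step-100")
    atomic_write_bytes(ckpt, b"step-200")   # replaces the previous checkpoint atomically
    with open(ckpt, "rb") as f:
        print(f.read())                     # b'step-200'
```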
13.3 Elastic Training¶
Cluster topology changes (nodes go away). Solutions:
- torchrun --nnodes=2:8: the rendezvous backend lets nodes join and leave, with the job running on anywhere from 2 to 8 nodes.
- On rank loss: surviving ranks re-form the process group and resume from the latest checkpoint with the new world size (sharded checkpoints make this resharding transparent).
In practice: most teams just use static configs and re-launch on failure. Elastic is mostly for ML-as-a-service platforms.
13.4 Determinism Considerations¶
NCCL all-reduce is typically run-to-run deterministic for a fixed world size, topology, and algorithm (same input → same output). However, floating-point addition is not associative, so a different reduction order, such as a different ring traversal, yields slightly different sums. With fixed seeds and deterministic ops, training can be bit-reproducible on the same world size, but probably not on a different one.
For deterministic training: fix the world size and use deterministic ops. Accept that changing the job's scale breaks bit-determinism.
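A pure-Python illustration of the ordering effect, not NCCL's kernel. In a ring reduce-scatter, different chunks start accumulating at different ranks. The same four contributions then produce different sums depending on where the ring starts, and neither matches the exact value of 2.0.

```python
def ring_sum(values, start):
    """Accumulate one element around a ring of ranks, beginning at rank `start`."""
    n = len(values)
    acc = 0.0
    for i in range(n):
        acc += values[(start + i) % n]
    return acc


# One gradient element as contributed by each of 4 ranks; the exact sum is 2.0.
vals = [1e16, 1.0, -1e16, 1.0]
print(ring_sum(vals, start=0))  # 1.0
print(ring_sum(vals, start=1))  # 0.0
```

Adding 1.0 to ±1e16 is a rounding tie in double precision (spacing 2 at that magnitude). Whether the 1.0 survives depends on which operand it meets first.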
14. Cluster Topology¶
Performance depends on physical wiring. Here's what each layer does and why it matters.
14.1 Inside the Node¶
NVLink (GPU-to-GPU links). On H100: 18 NVLink-4 links per GPU at 25 GB/s per direction each → 450 GB/s per direction, or 900 GB/s bidirectional, aggregated over all 18 links.
NVSwitch (the chip that connects NVLinks). DGX H100: 4× NVSwitch chips form a non-blocking fabric, so any GPU can talk to any other at full NVLink speed. Without NVSwitch (e.g., boards whose GPUs are wired directly to one another), GPU 0 ↔ GPU 1 might get 600 GB/s while GPU 0 ↔ GPU 5 gets only 300 GB/s (one hop away).
NVLink SHARP (NVLS). NVSwitch v3 (H100) does in-network reduction: ranks send data to switch, switch reduces, sends result back. Halves bandwidth requirement on the link (instead of 2(N-1)/N · M, more like (N-1)/N · M).
PCIe. CPU↔GPU and GPU↔NIC paths. Gen5 x16 = 64 GB/s. Bottleneck for CPU offload (ZeRO-Offload).
14.2 Between Nodes¶
InfiniBand (or RoCE). H100 generation: NDR (400 Gb/s = 50 GB/s) per port. DGX H100 has 8× NDR cards: one per GPU (rail-aligned).
Rail topology. Each GPU has its own NIC. The cluster fabric is structured as R rails, where rail r connects only NIC r of each node to a top-of-rack switch. Cross-rail traffic must go through a higher-tier switch. NCCL, properly configured, keeps each rank's traffic on its own rail.
Fat tree / spine-leaf. Typical large-cluster topology: leaf switches connect ~32 nodes (= 256 GPUs); spine switches connect leaves. Bisection bandwidth determined by spine fan-out.
Hop count. Same-node: 0 hops (NVSwitch). Same-rack: 1 hop (leaf). Cross-rack: 3 hops (leaf-spine-leaf). Each hop adds ~1µs of latency and shares bandwidth with other tenants.
14.3 GPU↔NIC Affinity¶
nvidia-smi topo -m output:
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 NIC0 NIC1 NIC2 NIC3 CPU
GPU0 X NV12 NV12 NV12 NV12 NV12 NV12 NV12 PXB SYS SYS SYS 0
NIC0 PXB SYS SYS SYS SYS SYS SYS SYS X ...
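Legend: NV12 = a bonded set of 12 NVLinks. PXB = path through one or more PCIe bridges/switches without crossing the CPU host bridge (good). SYS = path crosses the inter-socket (SMP) interconnect between CPUs (bad).
For NCCL: if you set NCCL_IB_HCA, make sure it includes the NIC that is PCIe-local to each GPU (mlx5_0 for GPU 0, etc.). NCCL uses the detected topology to pair each GPU with its nearest NIC.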
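Checking affinity by eye gets tedious on large nodes. The sketch below parses a matrix shaped like the excerpt above and reports each GPU's closest NIC. It assumes whitespace-separated cells and NIC columns named NIC0, NIC1, ...; other nvidia-smi versions may label or lay out columns differently, so adjust the filter. Path types are ranked closest-first following the nvidia-smi legend.

```python
# Closest-first ordering of PCIe path types from the nvidia-smi topo legend.
PATH_PREFERENCE = ["PIX", "PXB", "PHB", "NODE", "SYS"]


def nearest_nics(topo_text):
    """Map each GPU row to the NIC column(s) with the closest path type."""
    rows = [line.split() for line in topo_text.strip().splitlines()]
    header = rows[0]  # column names; the header has no leading row-label cell
    nic_cols = [i for i, name in enumerate(header) if name.startswith("NIC")]
    nearest = {}
    for row in rows[1:]:
        label, cells = row[0], row[1:]
        if not label.startswith("GPU"):
            continue
        rank = {}
        for i in nic_cols:
            if i < len(cells):
                path = cells[i]
                rank[header[i]] = (PATH_PREFERENCE.index(path)
                                   if path in PATH_PREFERENCE else len(PATH_PREFERENCE))
        if rank:
            best = min(rank.values())
            nearest[label] = [nic for nic, r in rank.items() if r == best]
    return nearest


SAMPLE = """
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 NIC0 NIC1 NIC2 NIC3 CPU
GPU0 X NV12 NV12 NV12 NV12 NV12 NV12 NV12 PXB SYS SYS SYS 0
NIC0 PXB SYS SYS SYS SYS SYS SYS SYS X
"""
print(nearest_nics(SAMPLE))  # {'GPU0': ['NIC0']}
```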
14.4 Why Topology Is in This Chapter¶
Algorithm choice depends on topology. Ring all-reduce on NVLink: optimal. Ring all-reduce that crosses spine switches every step: catastrophic (latency dominated, contention with other tenants).
Decision: TP within node, DP/PP across nodes. NCCL's auto-topology gets this right most of the time, but verify.
15. Practical Exercises¶
Six worked problems. Work through them yourself before reading the answers.
Exercise 1: All-reduce Bandwidth¶
You have 8 GPUs in a ring at 600 GB/s NVLink each. You all-reduce a tensor of 14 GB (BF16 gradients of a 7B model). How long does the all-reduce take, ignoring latency?
Answer. Ring all-reduce sends 2(N-1)/N · M = 2·7/8 · 14 = 24.5 GB per rank. At 600 GB/s: 24.5 / 600 = 40.8 ms.
If, instead of ring, NCCL chose tree (suboptimal here): 2 log₂(8) · M = 6 · 14 = 84 GB per rank → 140 ms. Big difference; this is why NCCL picks ring on intra-node.
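The answer's arithmetic is shown below as two small functions. Latency terms are ignored, as in the exercise. The tree formula mirrors the answer's model of a non-pipelined tree, with the full message at each level going up and then down.

```python
import math


def ring_allreduce_s(n, msg_bytes, bw_bytes_per_s):
    """Bandwidth term of ring all-reduce: 2(N-1)/N * M / beta."""
    return 2 * (n - 1) / n * msg_bytes / bw_bytes_per_s


def naive_tree_allreduce_s(n, msg_bytes, bw_bytes_per_s):
    """Non-pipelined tree: reduce up and broadcast down, full message per level."""
    return 2 * math.log2(n) * msg_bytes / bw_bytes_per_s


print(f"ring: {ring_allreduce_s(8, 14e9, 600e9) * 1e3:.1f} ms")       # 40.8 ms
print(f"tree: {naive_tree_allreduce_s(8, 14e9, 600e9) * 1e3:.1f} ms")  # 140.0 ms
```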
Exercise 2: Peak Memory for FSDP + Activation Checkpointing¶
Compute peak per-GPU memory for a 70B model under FSDP-3 + activation checkpointing on 64 H100 80GB GPUs. Assume BF16 params/grads, FP32 optimizer (Adam, K=12), batch 4 per GPU, sequence 8192, hidden 8192, 80 layers, BF16 activations. Use the simplified bound: with checkpointing, per-layer activation ~ B · S · h · 12 bytes (rough).
Answer.
- Static (FSDP-3): 16Φ / N = 16 · 70 / 64 = 17.5 GB.
- Activation per layer at the 12-byte bound: 4 · 8192 · 8192 · 12 ≈ 3.2 GB. If all 80 layers held that much at once, the total would be ~257 GB, far beyond 80 GB. But with checkpointing, those full activations exist only transiently, for the one layer currently being recomputed.
Recheck: with full (block-level) activation checkpointing, you store only the input of each transformer block (one BF16 tensor per block). That's 4 · 8192 · 8192 · 2 = 537 MB per layer × 80 = 43 GB. Add the recompute buffer for the layer currently being recomputed (transient, ~3.2 GB by the bound above).
So total: 17.5 (state) + 43 (stored block inputs) + ~10 (workspace, FSDP all-gather buffers, recompute buffer) ≈ 70 GB. Tight on 80 GB. Solutions: increase N, decrease batch, add TP=2.
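The same tally as a short script, in decimal GB. The ~10 GB workspace figure is the answer's rough allowance, not a derived number. It is meant to cover the transient recompute buffer, which is computed separately only for reference.

```python
GB = 1e9
phi, n = 70e9, 64
B, S, h, layers = 4, 8192, 8192, 80

static = 16 * phi / n / GB                    # FSDP-3: all 16 bytes/param sharded N ways
recompute_buffer = B * S * h * 12 / GB        # one layer at the 12-byte bound (transient)
stored_inputs = B * S * h * 2 * layers / GB   # one BF16 block input kept per layer
workspace = 10.0                              # rough allowance from the answer (assumption)

total = static + stored_inputs + workspace
print(f"static {static:.1f} + inputs {stored_inputs:.1f} + ~{workspace:.0f} = {total:.1f} GB "
      f"(recompute buffer ~{recompute_buffer:.1f} GB, counted in workspace)")
```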
Exercise 3: Pipeline Bubble¶
You run GPipe with P=8 stages. What's the minimum number of microbatches to keep bubble below 5%?
Answer. Bubble = (P-1)/(M+P-1) ≤ 0.05. Solve: (M + 7) ≥ 7/0.05 = 140. So M ≥ 133. With per-microbatch batch of 1 and 133 microbatches, you need batch ≥ 133 per DP replica. May not be feasible if your global batch budget is small.
With 1F1B, same bubble. With interleaved 1F1B (v=4): 7/(4M+7) ≤ 0.05 → 4M ≥ 133 → M ≥ 34 (M = 33 gives 7/139 ≈ 5.04%). Much better.
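Solving for M by hand invites off-by-one errors at the boundary. This sketch uses exact rational arithmetic instead. Passing max_bubble through str() turns 0.05 into exactly 1/20 rather than its binary float approximation.

```python
import math
from fractions import Fraction


def min_microbatches(p, max_bubble, v=1):
    """Smallest M with (P-1)/(v*M + P-1) <= max_bubble, using exact arithmetic."""
    target = Fraction(str(max_bubble))                 # 0.05 -> exactly 1/20
    needed = (Fraction(p - 1) / target - (p - 1)) / v  # lower bound on M
    return max(1, math.ceil(needed))


print(min_microbatches(8, 0.05))       # 133
print(min_microbatches(8, 0.05, v=4))  # 34
```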
Exercise 4: TP Communication Cost¶
A Transformer block with h=8192, B=4, S=8192, runs with t=8 TP. Compute total bytes communicated per block per step (forward + backward).
Answer. There are 4 all-reduces per block (2 forward, 2 backward), each moving 2(t-1)/t · 2BSh bytes per rank.
2BSh = 2 · 4 · 8192 · 8192 ≈ 537 MB (512 MiB).
Per all-reduce: 1.75 · 537 MB ≈ 940 MB. Times 4: ≈ 3.76 GB per block per step.
With L=80 blocks: ~300 GB per step of TP comm. On 600 GB/s NVLink: ~0.5 s. This is why TP stays on NVLink and why implementations try to overlap TP communication with compute. Sequence parallelism alone does not reduce this volume.
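The per-rank TP volume in integer bytes, assuming BF16 activations and the 4-all-reduces-per-block count from the answer:

```python
def tp_bytes_per_block(b, s, h, t, bytes_per_elem=2, allreduces=4):
    """Per-rank bytes for one block's TP all-reduces (ring volume 2(t-1)/t each)."""
    msg = b * s * h * bytes_per_elem          # one B x S x h activation tensor
    return allreduces * 2 * (t - 1) * msg // t


per_block = tp_bytes_per_block(4, 8192, 8192, t=8)
per_step = 80 * per_block
print(f"{per_block / 1e9:.2f} GB per block, {per_step / 1e9:.0f} GB per step, "
      f"{per_step / 600e9:.2f} s at 600 GB/s")
```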
Exercise 5: 3D Parallel Sizing¶
Llama-3 405B on 1024 H100 80GB. Choose (t, p, d). Show your reasoning.
Answer.
- t = 8 (max, intra-node).
- Try p = 8. Stage params = 405/8 = 50.6B. Per-GPU after TP: 2 · 50.6 / 8 = 12.6 GB (BF16).
- d = 1024 / (8·8) = 16. With ZeRO/FSDP across DP: per-GPU state = (16Φ)/(t·p·d) = 16 · 405 / 1024 ≈ 6.3 GB. Fits comfortably; activations + workspace ~30 GB; total ~36 GB; within 80 GB.
- With PP=8 and a global batch of ~1024 sequences split over d=16 DP replicas, each pipeline sees 64 sequences per step. With microbatch size 1, that gives M = 64 microbatches. Bubble = 7/(64+7) ≈ 9.9%. With interleaved 1F1B v=4: 7/(256+7) ≈ 2.7%. Excellent.
(t=8, p=8, d=16) or thereabouts. Real systems may vary p based on per-stage compute and activation memory.
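A quick check of the sizing numbers. Microbatches are counted per pipeline, that is, per DP replica. The microbatch size of 1 is an assumption.

```python
t, p, num_gpus = 8, 8, 1024
d = num_gpus // (t * p)                         # 16 DP replicas

phi = 405e9
state_gb = 16 * phi / (t * p * d) / 1e9         # FSDP/ZeRO-3 across DP: ~6.3 GB per GPU

global_batch, micro_batch = 1024, 1             # sequences; microbatch size is an assumption
m = global_batch // (d * micro_batch)           # microbatches per pipeline (per DP replica)
bubble = (p - 1) / (m + p - 1)                  # 1F1B
bubble_interleaved = (p - 1) / (4 * m + p - 1)  # interleaved 1F1B, v = 4

print(d, round(state_gb, 2), m, f"{bubble:.1%}", f"{bubble_interleaved:.1%}")
```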
Exercise 6: ZeRO Stage Tradeoff¶
Your model is 13B; you have 16 A100 40GB GPUs. Standard mixed-precision Adam. Which ZeRO stage is the minimum that fits?
Answer.
- DDP: 16Φ = 16·13 = 208 GB per rank; won't fit.
- ZeRO-1: 4Φ + 12Φ/N = 52 + 9.75 = 61.75 GB; won't fit on 40 GB.
- ZeRO-2: 2Φ + 14Φ/N = 26 + 11.4 = 37.4 GB; borderline, and with activations and workspace it will OOM.
- ZeRO-3: 16Φ/N = 13 GB; fits with room for activations.
ZeRO-3 / FSDP-3 is required. With activation checkpointing and a moderate batch, you'll stay around 30-35 GB peak.
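The stage-by-stage accounting as a function, assuming the exercise's 2 + 2 + 12 bytes per parameter. The 10 GB reserve for activations and workspace is a hypothetical budget, chosen only to show the selection.

```python
def zero_static_gb(params_b, n, stage):
    """Per-rank static memory in GB for ZeRO stage 0 (plain DDP) through 3.

    Assumes BF16 params (2 B), BF16 grads (2 B), 12 B FP32 Adam state per parameter.
    """
    params, grads, opt = 2.0, 2.0, 12.0
    if stage >= 1:
        opt /= n        # ZeRO-1 shards optimizer state
    if stage >= 2:
        grads /= n      # ZeRO-2 also shards gradients
    if stage >= 3:
        params /= n     # ZeRO-3 also shards parameters
    return (params + grads + opt) * params_b


def min_zero_stage(params_b, n, gpu_gb, reserve_gb):
    """Smallest stage whose static memory leaves reserve_gb for activations/workspace."""
    for stage in range(4):
        if zero_static_gb(params_b, n, stage) + reserve_gb <= gpu_gb:
            return stage
    return None


for stage in range(4):
    print(stage, zero_static_gb(13, 16, stage))   # 208.0, 61.75, 37.375, 13.0
print("minimum stage:", min_zero_stage(13, 16, gpu_gb=40, reserve_gb=10))  # 3
```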
Concluding Synthesis¶
The whole apparatus is one idea repeated at three resolutions:
- Memory is the binding constraint. Single-GPU memory grows ~2× per generation; model size grows ~10× per year. The gap is what distributed training closes.
- Communication is the price. Every form of partitioning (DP grad sync, ZeRO param gather, TP activation all-reduce, PP activation send) trades memory for bandwidth.
- Topology determines algorithm. Ring is bandwidth-optimal but latency-linear. Tree is latency-logarithmic but bandwidth-suboptimal. Inside a node (NVLink), bandwidth is plentiful → ring; across nodes (IB), pick based on message size and tier.
The mature configuration for a frontier model:
- TP = node size (e.g., 8), so TP all-reduces are NVLink-confined.
- PP across nodes, so cross-node traffic is point-to-point (cheaper than all-reduce).
- DP fills the rest, with FSDP/ZeRO-3 sharding state across the DP ranks.
- Sequence parallelism on for any TP > 1, to keep activation memory in check.
- Activation checkpointing to make long context tractable.
- BF16 throughout, FP8 in the matmul-heavy layers; FP32 master weights.
When you can derive each line of this recipe from first principles-why TP wants NVLink, why FSDP-3 needs prefetch, why PP wants 1F1B, why ZeRO trades 1.5× comm for N× memory-you have mastered distributed training.
Appendix: Notation¶
- Φ: number of model parameters.
- N: DP world size (or generic rank count for collectives).
- t, p, d: TP, PP, DP degrees.
- L: number of Transformer layers.
- h: hidden dimension.
- S: sequence length.
- B: batch size.
- M: number of microbatches (PP), or message size in bytes (collectives).
- α: link latency (s).
- β: link bandwidth (B/s).
- K: Adam optimizer state multiplier (= 12 for FP32 m, v, master weights).
Appendix: Paper References (for further reading; the math above is self-contained)¶
- Patarasuk & Yuan, "Bandwidth Optimal All-reduce Algorithms for Clusters of Workstations," 2009.
- Rajbhandari et al., "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models," SC 2020.
- Shoeybi et al., "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism," 2019.
- Narayanan et al., "Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM," SC 2021.
- Huang et al., "GPipe: Efficient Training of Giant Neural Networks Using Pipeline Parallelism," NeurIPS 2019.
- Narayanan et al., "PipeDream: Generalized Pipeline Parallelism for DNN Training," SOSP 2019.
- Korthikanti et al., "Reducing Activation Recomputation in Large Transformer Models," 2022.
- Zhao et al., "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel," 2023.
- Qi et al., "Zero Bubble Pipeline Parallelism," 2023.
- Micikevicius et al., "Mixed Precision Training," ICLR 2018.
- NVIDIA, TransformerEngine documentation (FP8 training recipes).
These give you the published numerical results and ablations; the algorithms themselves you now own.