Deep Dive 12 - Kernel Fusion: Theory, Practice, and the Compilers That Do It For You¶
Chapter 11 told you which precisions to use where. Chapter 12 tells you how to schedule those operations onto the GPU so the precision decisions actually pay off - by eliminating the HBM round-trips that dominate end-to-end latency in modern deep-learning workloads.
This chapter is self-contained in its derivations, but it pulls forward concepts from chapters 01 (GPU architecture), 02 (CUDA), 03 (Triton), 04 (PyTorch internals + Inductor), 05 (JAX + XLA), 07 (FlashAttention), and 11 (numerics), referencing them by chapter number rather than re-deriving them.
Table of contents¶
- Why fuse at all
- The HBM round-trip cost model
- Fusion taxonomy
- Vertical fusion derived
- Horizontal fusion derived
- GEMM epilogue fusion
- Streaming-reduction fusion: the FlashAttention pattern
- Compiler-driven fusion: XLA, TorchInductor, Triton
- Hand-rolled fusion in Triton: three full kernels
- Precision discipline under fusion
- The limits of fusion
- Profiling fused kernels with Nsight Compute
- When NOT to fuse
- Practical exercises
- Cheat sheet and further reading
1. Why fuse at all¶
The single most important observation in modern GPU performance work:
Most deep-learning operators in the forward pass of a transformer are memory-bandwidth-bound, not compute-bound.
To see why, recall the roofline from chapter 01. A modern H100 GPU has roughly:
- BF16 dense tensor-core throughput: ~989 TFLOPS.
- HBM3 bandwidth: ~3.0 TB/s.
The crossover arithmetic intensity at which a kernel transitions from memory-bound to compute-bound is the ratio of the two: 989 TFLOPS / 3.0 TB/s ≈ 330 FLOPs per byte of HBM traffic.
A pure elementwise operation like `y = a * x + b` performs 2 FLOPs per ~12 bytes moved (4 for reading x, 4 for writing y, optionally 4 for a; in BF16 halve it; the details don't matter - the intensity is well under 1 FLOP/byte, orders of magnitude below the ~330 FLOP/byte crossover). The GPU sits at <1% of peak FLOPs and 100% of peak bandwidth for the entire kernel.
The consequence: if your network is a chain of elementwise ops and small reductions, total time is determined almost entirely by total bytes moved through HBM - not by total work done. A LayerNorm → Linear → GELU → Dropout chain executed as four separate kernels reads and writes the activation tensor through HBM four times. Fused into one kernel, it reads once, writes once.
For the typical transformer hidden state at batch=8, seqlen=4096, hidden=8192 in BF16, the activation tensor is 8 × 4096 × 8192 × 2 bytes = 512 MiB.
Each HBM round-trip costs 512 MiB / 3 TB/s ≈ 175 µs. Saving three round-trips per layer × 80 layers = 240 round-trips = 42 ms per forward pass - for free, just by stopping the round-tripping. That is the prize.
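These figures are worth checking by hand. A throwaway sketch of the arithmetic, using the shapes and the 3 TB/s bandwidth assumed above:

```python
# Back-of-envelope HBM round-trip cost for the activation tensor above.
# Shapes and bandwidth are the chapter's assumptions (H100, BF16).
batch, seqlen, hidden, dtype_bytes = 8, 4096, 8192, 2
hbm_bw = 3.0e12  # bytes/s

tensor_bytes = batch * seqlen * hidden * dtype_bytes
trip_us = tensor_bytes / hbm_bw * 1e6

print(f"tensor size: {tensor_bytes / 2**20:.0f} MiB")            # 512 MiB
print(f"one HBM pass: {trip_us:.0f} us")                         # ~179 us
print(f"3 trips x 80 layers: {3 * 80 * trip_us / 1e3:.0f} ms")   # ~43 ms
```

Exact division gives ≈179 µs per pass; the text rounds down to 175 µs, but the conclusion is the same.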
2. The HBM round-trip cost model¶
Let n_op be the number of fused operations in a chain, S the size of the activation tensor in bytes, BW the HBM bandwidth, and K_launch the per-kernel launch overhead (~5 µs on a modern driver). Time for unfused execution:

T_unfused = n_op × (2S / BW + K_launch)

(Each op reads S, writes S.)

Time for fused execution (one kernel reads once, writes once, does all the work in registers/SMEM):

T_fused = 2S / BW + K_launch + T_compute_in_kernel

For elementwise chains, T_compute_in_kernel is negligible compared to the HBM term, so:

T_unfused / T_fused ≈ n_op
A fused chain of 5 elementwise ops is roughly 5× faster than the unfused version, regardless of how clever the unfused kernels are individually. This is the headline result that motivates every fusion compiler ever written.
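The cost model is small enough to encode directly; a minimal helper (function names are mine, not from any library) that reproduces the n_op speedup:

```python
def unfused_time_s(n_op, size_bytes, bw=3.0e12, launch_s=5e-6):
    """n_op kernels, each reading and writing the full tensor."""
    return n_op * (2 * size_bytes / bw + launch_s)

def fused_time_s(n_op, size_bytes, bw=3.0e12, launch_s=5e-6):
    """One kernel: one read, one write; elementwise compute assumed free."""
    return 2 * size_bytes / bw + launch_s

size = 128 * 2**20  # a 128 MiB activation, as in the worked example
s_un = unfused_time_s(5, size)
s_fu = fused_time_s(5, size)
print(f"unfused {s_un*1e6:.0f} us, fused {s_fu*1e6:.0f} us, "
      f"speedup {s_un/s_fu:.1f}x")
```

With exact arithmetic the numbers land a few percent above the rounded figures quoted in the text, but the speedup is exactly n_op = 5.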
Worked numerical example¶
Take the post-attention residual stream of a Llama-3-70B layer, batch=1, seqlen=8192:
The post-attention chain: x + attn_out → RMSNorm → linear_gate → silu → linear_up · gate → linear_down → x + ffn_out.
Counting just the elementwise pieces (residual add, RMSNorm scale, SiLU, elementwise multiply, residual add) - five elementwise/light-reduction operations on the activation tensor:
```
T_unfused_elementwise = 5 × (2 × 128 MiB / 3 TB/s + 5 µs)
                      = 5 × (85 µs + 5 µs)
                      = 450 µs

T_fused_elementwise   = 2 × 128 MiB / 3 TB/s + 5 µs
                      ≈ 90 µs
```
Per layer, fusion saves ~360 µs in the elementwise chain at this shape. Across 80 layers, that's ~29 ms per forward pass. On a real inference engine the saving is closer to 30–50% of total latency because the matmuls still dominate, but eliminating elementwise round-trips is the single most impactful generic optimization in deep-learning compilers.
3. Fusion taxonomy¶
Fusion comes in five shapes, in roughly ascending implementation difficulty:
| # | Pattern | Example | Difficulty |
|---|---|---|---|
| 1 | Elementwise → elementwise | `(a + b) * c` | Trivial (every compiler does it) |
| 2 | Elementwise → reduction | `sum(x * x)` (used in RMSNorm) | Easy |
| 3 | Reduction → elementwise (broadcast) | RMSNorm = `x / sqrt(mean(x²) + ε) * γ` | Medium (needs two-pass or online algorithm) |
| 4 | GEMM + epilogue | `gelu(A @ B + bias)` | Medium (CUTLASS/CUBLASLt epilogue API) |
| 5 | Streaming reduction over GEMM output | FlashAttention: `softmax(Q@Kᵀ / √d) @ V` | Hard (requires algorithmic redesign; chapter 07) |
A sixth, more ambitious shape - multi-GEMM fusion, where two matrix multiplies sharing an intermediate are fused (e.g., the FFN's up_proj and gate_proj in SwiGLU) - is increasingly common in production inference engines but requires either (a) careful CUTLASS programming or (b) horizontal fusion at the Triton level.
The taxonomy axis you actually care about is vertical vs horizontal:
- Vertical (producer-consumer) fusion combines operations along the data-flow direction: op B consumes op A's output, so we keep A's output in registers/SMEM and feed it directly to B without writing to HBM. All of patterns 1–5 above are vertical.
- Horizontal (sibling) fusion combines independent operations that have no data dependency, executing them in the same kernel to amortize launch overhead and (sometimes) share input loads. Example:
`q = x @ Wq; k = x @ Wk; v = x @ Wv` can be done as one fused kernel that loads `x` once.
The next two sections derive each rigorously.
4. Vertical fusion derived¶
4.1 The producer-consumer pattern¶
Consider two operations, a producer B = f(A) and a consumer C = g(B).
Unfused, the dataflow through HBM is two kernels: the first reads A and writes B; the second reads B back and writes C. Total HBM traffic = |A| + 2|B| + |C|.
Fused into one kernel:

```
for each tile of A:
    load tile_A from HBM into registers
    tile_B = f(tile_A)      # stays in registers
    tile_C = g(tile_B)      # stays in registers
    store tile_C to HBM
```

Total HBM traffic = |A| + |C|
We saved 2|B| bytes of HBM traffic. If f and g are elementwise and same-shape, |A| = |B| = |C|, so we cut traffic by 50% and (in the bandwidth-bound regime) doubled throughput.
4.2 The shape-compatibility requirement¶
Vertical fusion works only when the producer's output layout matches the consumer's input layout at the tile granularity. Two cases:
- Pointwise op → pointwise op: trivially compatible (same element-to-element correspondence). Always fusible.
- Reduction → broadcast: the reduction shrinks the tensor; the broadcast re-expands it. Fusion is possible but requires either (a) keeping the reduction result in SMEM and re-reading per element (the two-pass RMSNorm pattern), or (b) computing the reduction online during the consumer pass.
4.3 Worked example: RMSNorm fused¶
The naive RMSNorm:
```python
mean_sq = (x * x).mean(dim=-1, keepdim=True)  # kernel 1
rrms = torch.rsqrt(mean_sq + eps)             # kernel 2
y = x * rrms * gamma                          # kernel 3
```
Three kernels, each round-tripping x (or a derivative of it) through HBM. Fused, in pseudocode:
```python
def rmsnorm_fused(x, gamma, eps):
    # x: (..., H); one tile = one row (H elements)
    for row in tiles(x):
        # pass 1: reduction
        s = 0.0
        for j in range(0, H, BLOCK):
            xj = load(row, j)          # HBM → registers
            s += sum(xj * xj)          # accumulate in fp32
        rrms = rsqrt(s / H + eps)      # scalar
        # pass 2: scale-broadcast
        for j in range(0, H, BLOCK):
            xj = load(row, j)          # HBM → registers (re-read!)
            gj = load(gamma, j)
            store(row, j, xj * rrms * gj)
```
We read x twice and write y once - total 3|x| HBM bytes - but we eliminated the materialized intermediates and saved two kernel launches. For the Llama-3-70B example in §2, the unfused eager version moves ~5|x| bytes (x·x is materialized: read x, write x², read x² for the mean, then read x again and write y); the fused version moves 3|x|. Speedup: 5/3 ≈ 1.67×.
A single-pass RMSNorm keeps the entire row resident in registers/SMEM while accumulating the sum of squares, avoiding the re-read of x - that drops traffic to 2|x|, the absolute floor. The Triton kernel in chapter 03 shows this.
5. Horizontal fusion derived¶
5.1 The independent-siblings pattern¶
Now consider three independent operations sharing an input: Q = X @ Wq, K = X @ Wk, V = X @ Wv - three sibling matmuls with no data dependencies among them.
Unfused: three kernel launches, each reading X from HBM. HBM traffic = 3|X| + |Q| + |K| + |V|.
Horizontally fused: one kernel reads X once and produces all three outputs. HBM traffic = |X| + |Q| + |K| + |V|. Saves 2|X|.
For X of shape (batch * seqlen, hidden) = (8 * 4096, 8192) in BF16 = 512 MiB, savings = 1 GiB of HBM traffic ≈ 340 µs at H100 bandwidth, just for the QKV projections per layer.
In practice, modern inference engines (vLLM, TensorRT-LLM) fuse QKV by concatenating [Wq, Wk, Wv] along the output dimension into a single W_qkv of shape (hidden, 3*head_dim*n_heads), and slicing after the matmul. This is mathematically identical to horizontal fusion and is the standard pattern - if a transformer codebase you read does not fuse QKV, that's a perf bug.
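The weight-concatenation form of QKV fusion is easy to verify numerically. A sketch with numpy standing in for the GEMM (toy shapes; `n_heads` and `head_dim` here are illustrative, not Llama's):

```python
import numpy as np

rng = np.random.default_rng(0)
hidden, head_dim, n_heads = 64, 8, 4
proj = head_dim * n_heads

X  = rng.standard_normal((16, hidden))
Wq = rng.standard_normal((hidden, proj))
Wk = rng.standard_normal((hidden, proj))
Wv = rng.standard_normal((hidden, proj))

# Unfused: three matmuls, X read three times.
q, k, v = X @ Wq, X @ Wk, X @ Wv

# Horizontally fused: one matmul against the concatenated weight,
# X read once; slice the output afterwards.
W_qkv = np.concatenate([Wq, Wk, Wv], axis=1)   # (hidden, 3*proj)
qkv = X @ W_qkv
q2, k2, v2 = qkv[:, :proj], qkv[:, proj:2*proj], qkv[:, 2*proj:]

assert np.allclose(q, q2) and np.allclose(k, k2) and np.allclose(v, v2)
```

The slicing is a view in framework tensors, so the fused form costs nothing beyond the one concatenated GEMM.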
5.2 The SwiGLU case¶
For the FFN block: y = (silu(x @ W_gate) * (x @ W_up)) @ W_down.
W_gate and W_up can be horizontally fused into W_gu of shape (hidden, 2 * inter_dim), sliced into halves, with SiLU and the elementwise multiply fused as the epilogue. Saves one |X| HBM read per FFN per layer. For Llama-3-70B at seqlen=8192 prefill, |X| = 128 MiB, so ~45 µs/layer × 80 layers ≈ 3.6 ms per forward pass - substantial.
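The same trick for the SwiGLU pair, sketched in numpy (toy shapes; `silu` written out by hand):

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

rng = np.random.default_rng(1)
hidden, inter = 32, 96
X = rng.standard_normal((8, hidden))
W_gate = rng.standard_normal((hidden, inter))
W_up   = rng.standard_normal((hidden, inter))

# Unfused: two matmuls reading X twice.
ref = silu(X @ W_gate) * (X @ W_up)

# Horizontally fused: one matmul against [W_gate | W_up], X read once;
# SiLU and the elementwise product become the epilogue.
W_gu = np.concatenate([W_gate, W_up], axis=1)  # (hidden, 2*inter)
gu = X @ W_gu
fused = silu(gu[:, :inter]) * gu[:, inter:]

assert np.allclose(ref, fused)
```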
6. GEMM epilogue fusion¶
A GEMM epilogue is any elementwise operation chained immediately after C = A @ B. The CUTLASS library (and its successor, CuTeDSL) supports declarative epilogue fusion via a templated programming model. Common epilogues:
- `C = A @ B + bias` (the GEMM-bias-add pattern in every linear layer).
- `C = act(A @ B + bias)` where `act ∈ {ReLU, GELU, SiLU}`.
- `C = act(A @ B + bias) * scale` (for quantized inference).
- `C = act(A @ B + bias) + residual` (the residual-stream pattern in transformers - fuses the residual add into the matmul).
6.1 Why epilogue fusion is cheap¶
The GEMM kernel already has the output tile C_tile resident in registers immediately after computing it. Applying an elementwise function to it before storing costs zero additional HBM traffic. The only cost is a handful of extra instructions per register, well below the noise floor of the matmul itself.
In CUTLASS terms, the epilogue is a templated EpilogueOp that receives the accumulator tile and produces the output tile:
```cpp
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationGELU<
    ElementOutput,            // bf16
    128 / sizeof_bits<...>,   // vector length
    ElementAccumulator,       // fp32
    ElementCompute>;          // fp32
```
The LinearCombinationGELU epilogue computes act(α * accumulator + β * bias) in registers, then stores to HBM. One kernel; zero round-trip.
6.2 The "fuse the residual add into the matmul" trick¶
The residual connection in a transformer block computes y = linear(x) + residual: the stream input is added back onto the linear's output.
If the linear is a CUTLASS GEMM with epilogue D = α*(A@B) + β*C, you can pass x itself as C with β=1, and the residual add costs zero extra HBM traffic - the GEMM was going to write its output anyway; with the epilogue, it writes the residual-added output instead. This saves a full |x| round-trip per block, every layer, every forward pass.
This is implemented by addmm in PyTorch (when properly routed to CUBLASLt with the epilogue path) and by every production inference engine. If you write y = linear(x) + residual as two separate kernels in a hot path, that's a perf bug.
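The β-trick is just the epilogue's linear combination with the residual passed as C; `torch.addmm(input, mat1, mat2, beta=1, alpha=1)` computes exactly `beta*input + alpha*(mat1 @ mat2)`. A numpy sketch of the algebra being fused:

```python
import numpy as np

rng = np.random.default_rng(2)
M, K, N = 8, 16, 16
x = rng.standard_normal((M, K))
W = rng.standard_normal((K, N))
residual = rng.standard_normal((M, N))

# Two kernels: matmul, then a separate residual add
# (one extra |y| HBM round-trip).
two_kernel = (x @ W) + residual

# One kernel: the GEMM epilogue D = alpha*(A@B) + beta*C,
# with the residual passed as C and beta = 1.
alpha, beta = 1.0, 1.0
epilogue = alpha * (x @ W) + beta * residual

assert np.allclose(two_kernel, epilogue)
```

The identity is trivial; the point is that the right-hand side is the exact shape of the hardware epilogue, so the add rides along with the matmul's store for free.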
7. Streaming-reduction fusion: the FlashAttention pattern¶
The most algorithmically sophisticated fusion in modern AI is FlashAttention (chapter 07). The naive attention computation:
```python
S = Q @ Kᵀ               # (B, H, M, N) - materialized
P = softmax(S, axis=-1)  # (B, H, M, N) - materialized
O = P @ V                # (B, H, M, d)
```
The intermediates S and P are O(M·N) and dominate memory at long sequence length. At seqlen=8192 with head_dim=128, S for a single batch×head pair is 8192² × 2 bytes = 128 MiB in BF16. For B=8, H=64, S alone is 64 GiB - and P is as large again. Doesn't fit.
FlashAttention's insight: S and P never need to be materialized in HBM. Compute the softmax incrementally, tile-by-tile, while accumulating O directly. The mathematical machinery is the online softmax derived in chapter 03 §online-softmax and rigorously in chapter 07.
For fusion purposes, the structural lesson is:
A reduction (softmax-then-matmul) over a streaming source (Q@Kᵀ computed tile-by-tile) can be fused into a single kernel that never materializes the intermediate.
This pattern - streaming a producer through a reducer with state kept in registers/SMEM - generalizes well beyond attention. Examples in production:
- Cross-entropy loss fused with the final logits projection. Logits at vocab=128k × seqlen=4096 × batch=8 are 16 GiB in FP32; never materialize them.
- Top-k sampling fused with logits. Same memory argument.
- MoE router + dispatch fused. The router's softmax + top-k + scatter can all run in a single kernel.
The price: the fused kernel is algorithmically non-trivial. Each new instance requires real engineering. Compilers cannot yet derive these fusions automatically; you write them by hand in Triton or CUDA.
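The state machinery behind all of these is the online-softmax recurrence: a running max `m` and running normalizer `l`, rescaled as each new tile arrives. A numpy sketch, streamed over tiles of one score row:

```python
import numpy as np

rng = np.random.default_rng(3)
scores = rng.standard_normal(4096)   # one row of S = Q @ K^T
BLOCK = 256

# Online pass: holds at most one tile at a time, keeps (m, l) as state.
m, l = -np.inf, 0.0
for start in range(0, scores.size, BLOCK):
    s = scores[start:start + BLOCK]
    m_new = max(m, s.max())
    # Rescale the old normalizer to the new max, then add this tile.
    l = l * np.exp(m - m_new) + np.exp(s - m_new).sum()
    m = m_new

# Reference: materialize the whole row at once.
ref = np.exp(scores - scores.max()).sum()
assert np.isclose(l, ref) and np.isclose(m, scores.max())
```

The `exp(m - m_new)` rescaling is the whole trick: it makes the partial normalizer valid under the updated max, so the final (m, l) pair is exact regardless of tile order.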
8. Compiler-driven fusion: XLA, TorchInductor, Triton¶
Three major systems perform deep-learning kernel fusion in production. Their philosophies differ; their results converge.
8.1 XLA (JAX, TensorFlow)¶
Chapter 05 covers XLA in depth. The relevant fusion passes:
- `fusion` pass: the canonical pass. Groups elementwise/broadcast/reduce ops into "fusion clusters" and emits one kernel per cluster. Driven by a cost model that estimates HBM traffic.
- `gpu_fusion_pipeline`: the GPU-specific lowering. Emits LLVM IR with a single CUDA kernel per fusion. Modern XLA also emits Triton for some patterns (matmul + epilogue).
- `priority_fusion`: newer pass with a priority queue over fusion candidates.
Fusion in XLA is declarative: you write pure functional JAX, XLA decides what fuses. You can inspect with jax.jit(f).lower(...).compile().as_text() (chapter 05).
8.2 TorchInductor (PyTorch 2)¶
Chapter 04 covers Inductor. Its fusion strategy:
- Scheduler-driven node fusion: Inductor builds a graph of `IRNode`s (one per ATen op), then greedily fuses adjacent nodes whose fusion satisfies a memory-locality cost model.
- Emit Triton or C++ for the fused kernel. The GPU path emits Triton; the CPU path emits C++ with OpenMP.
- Pointwise + reduction + pointwise is the bread-and-butter fusion class. More than 80% of the speedups Inductor delivers come from this pattern (per PyTorch's perf blogs).
Inspect with TORCH_LOGS="output_code" (chapter 04) - you get the actual Triton source Inductor generated.
8.3 Triton autotuning (manual, but compiler-assisted)¶
Triton (chapter 03) is the kernel-author's tool, not a graph compiler. You write the fused kernel; Triton handles the lowering. The compiler contribution is in autotuning - exploring tile shapes, num_warps, num_stages combinations and picking the best.
Production stack composition (typical 2026 inference engine):
```
Model architecture (PyTorch)
        │
        ▼
torch.compile + Inductor      ← elementwise + reduction fusion (auto)
        │
        ▼
CUBLASLt / CUTLASS matmuls    ← GEMM + epilogue fusion (manual config)
        │
        ▼
FlashAttention / xFormers     ← streaming-reduction fusion (handwritten)
        │
        ▼
Triton custom kernels         ← anything Inductor missed (handwritten)
```
The lesson: let the compiler do the easy fusions; reserve human effort for the algorithmically hard ones (FlashAttention, paged attention, fused MoE, fused quantized GEMMs like Marlin).
9. Hand-rolled fusion in Triton: three full kernels¶
We work three increasingly complex examples. All assume the reader has read chapter 03 (Triton).
9.1 Kernel 1 - Fused bias-GELU-residual¶
Operation: y = gelu(x @ W + bias) + residual, where x: (M, K), W: (K, N), bias: (N,), residual: (M, N).
The matmul itself uses standard tiled GEMM (chapter 02). The fusion is in the epilogue: after computing the accumulator tile, apply bias, GELU, and the residual add in registers, then store.
```python
import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BM': 128, 'BN': 256, 'BK': 32}, num_warps=8, num_stages=3),
        triton.Config({'BM': 64,  'BN': 128, 'BK': 32}, num_warps=4, num_stages=4),
        triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def fused_linear_gelu_residual(
    x_ptr, w_ptr, bias_ptr, residual_ptr, y_ptr,
    M, N, K,
    sxm, sxk, swk, swn, srm, srn, sym, syn,
    BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BM + tl.arange(0, BM)
    offs_n = pid_n * BN + tl.arange(0, BN)
    offs_k = tl.arange(0, BK)
    x_ptrs = x_ptr + offs_m[:, None] * sxm + offs_k[None, :] * sxk
    w_ptrs = w_ptr + offs_k[:, None] * swk + offs_n[None, :] * swn
    acc = tl.zeros((BM, BN), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BK)):
        # Mask both the K tail and the M/N tails for non-multiple shapes.
        x = tl.load(x_ptrs,
                    mask=(offs_m[:, None] < M) & (offs_k[None, :] < K - k * BK),
                    other=0.0)
        w = tl.load(w_ptrs,
                    mask=(offs_k[:, None] < K - k * BK) & (offs_n[None, :] < N),
                    other=0.0)
        acc += tl.dot(x, w)
        x_ptrs += BK * sxk
        w_ptrs += BK * swk
    # --- epilogue, in registers, no HBM round-trip ---
    bias = tl.load(bias_ptr + offs_n, mask=offs_n < N, other=0.0)
    acc = acc + bias[None, :]
    # GELU approximation (tanh form), fp32
    c = 0.7978845608  # sqrt(2/pi)
    acc_g = 0.5 * acc * (1.0 + tl.math.tanh(c * (acc + 0.044715 * acc * acc * acc)))
    mask_mn = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    residual = tl.load(
        residual_ptr + offs_m[:, None] * srm + offs_n[None, :] * srn,
        mask=mask_mn, other=0.0,
    )
    y = acc_g + residual.to(tl.float32)
    tl.store(
        y_ptr + offs_m[:, None] * sym + offs_n[None, :] * syn,
        y.to(y_ptr.dtype.element_ty),
        mask=mask_mn,
    )
```
HBM traffic accounting:

- Unfused (4 kernels: matmul, bias-add, GELU, residual-add): `|x| + |W| + |bias| + |residual| + 7|y|` - the matmul writes y once, then each of the three elementwise kernels reads and rewrites it.
- Fused: `|x| + |W| + |bias| + |residual| + |y|`.
- Savings: `6|y|` HBM bytes per call.
9.2 Kernel 2 - Fused RMSNorm with online statistics¶
Two-pass RMSNorm requires re-reading x; the one-pass version keeps the whole row resident and accumulates the statistic online. For RMSNorm specifically, since we only need the sum of squares (not a running mean and variance), the update is simply additive - no Welford recurrence required:
```python
@triton.jit
def rmsnorm_fwd_fused(
    x_ptr, gamma_ptr, y_ptr,
    stride_xm, stride_xn,
    stride_ym, stride_yn,
    N, eps,
    BLOCK_N: tl.constexpr,
):
    # One program instance handles one row.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N
    x = tl.load(x_ptr + row * stride_xm + cols * stride_xn,
                mask=mask, other=0.0).to(tl.float32)
    sum_sq = tl.sum(x * x, axis=0)
    rrms = 1.0 / tl.sqrt(sum_sq / N + eps)
    gamma = tl.load(gamma_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    y = (x * rrms * gamma).to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + row * stride_ym + cols * stride_yn, y, mask=mask)
```
If N > BLOCK_N (hidden dim larger than what fits in one tile), this becomes two-pass with shared-memory state. For modern transformer hidden dims (4096–16384), one-tile-per-row is feasible up to BLOCK_N=16384 on H100 (uses ~64 KiB of registers/SMEM).
Note the precision discipline: load in BF16, promote to FP32 for the reduction and the divide, store back in BF16. Chapter 11 §3.3 explains why this is mandatory - accumulating sum-of-squares in BF16 catastrophically loses precision past hidden ≈ 1024.
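This failure mode is easy to reproduce without a GPU. In the sketch below, BF16 rounding is simulated by truncating an FP32 value's bit pattern to its top 16 bits (a crude round-toward-zero stand-in, since numpy has no bfloat16), and the sum of squares is rounded after every add:

```python
import numpy as np

def to_bf16(x):
    """Crude BF16 stand-in: keep the top 16 bits of the FP32 pattern."""
    bits = np.float32(x).view(np.uint32) & np.uint32(0xFFFF0000)
    return bits.view(np.float32)

rng = np.random.default_rng(4)
H = 8192
x = rng.standard_normal(H).astype(np.float32)

ref = float(np.sum(x.astype(np.float64) ** 2))  # FP64 reference

acc32 = np.float32(0.0)
acc_bf = to_bf16(0.0)
for v in x:
    acc32 = np.float32(acc32 + v * v)           # FP32 accumulator
    acc_bf = to_bf16(acc_bf + to_bf16(v * v))   # BF16 accumulator (buggy)

err32 = abs(acc32 - ref) / ref
err_bf = abs(acc_bf - ref) / ref
print(f"fp32 accumulator rel. error: {err32:.2e}")
print(f"bf16 accumulator rel. error: {err_bf:.2e}")  # far worse
```

Once the accumulator grows large, increments smaller than its BF16 ulp are partially or wholly lost, which is exactly why the fused kernel promotes to FP32 before reducing.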
9.3 Kernel 3 - Fused softmax (causal-masked, for attention)¶
The streaming softmax kernel from chapter 03, with causal masking and tile-wise online normalization:
```python
@triton.jit
def causal_softmax(
    s_ptr, o_ptr, stride_b, stride_h, stride_m, stride_n,
    M, N,
    BLOCK_N: tl.constexpr,
):
    # Grid axis 0 indexes the flattened (batch, head) dimension.
    pid_bh = tl.program_id(0)
    pid_m = tl.program_id(1)
    # Process one query row at a time.
    row = pid_m
    cols = tl.arange(0, BLOCK_N)
    base = pid_bh * stride_h + row * stride_m
    # Online softmax state.
    m_i = -float('inf')
    l_i = 0.0
    # First pass: running max and running normalizer.
    for start in range(0, N, BLOCK_N):
        offs = start + cols
        mask = (offs < N) & (offs <= row)  # causal
        s = tl.load(s_ptr + base + offs * stride_n,
                    mask=mask, other=-float('inf')).to(tl.float32)
        m_new = tl.maximum(m_i, tl.max(s, axis=0))
        l_i = l_i * tl.exp(m_i - m_new) + tl.sum(tl.exp(s - m_new), axis=0)
        m_i = m_new
    # Second pass: normalize and store.
    for start in range(0, N, BLOCK_N):
        offs = start + cols
        mask = (offs < N) & (offs <= row)
        s = tl.load(s_ptr + base + offs * stride_n,
                    mask=mask, other=-float('inf')).to(tl.float32)
        p = tl.exp(s - m_i) / l_i
        tl.store(o_ptr + base + offs * stride_n,
                 p.to(o_ptr.dtype.element_ty), mask=mask)
```
This kernel is a building block; the full FlashAttention kernel goes one step further and fuses the matmul P @ V into the same loop, never materializing P at all. See chapter 07 for the full derivation; the punchline is that the inner loop interleaves Q @ Kᵀ tile computation, online softmax update, and P @ V accumulation - all in registers.
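Putting the pieces together outside a GPU kernel: a numpy sketch of the FlashAttention forward recurrence (no causal mask, toy sizes), streaming over key/value tiles while keeping only (m, l, O) as state. This shows the structure chapter 07 derives, not a performance implementation:

```python
import numpy as np

def naive_attention(Q, K, V):
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V

def flash_attention(Q, K, V, block=32):
    M, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full(M, -np.inf)   # running row max
    l = np.zeros(M)           # running normalizer
    O = np.zeros((M, d))      # unnormalized output accumulator
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T * scale              # one (M, block) tile; full S never exists
        m_new = np.maximum(m, S.max(axis=-1))
        alpha = np.exp(m - m_new)         # rescale old state to the new max
        P = np.exp(S - m_new[:, None])
        l = l * alpha + P.sum(axis=-1)
        O = O * alpha[:, None] + P @ Vb
        m = m_new
    return O / l[:, None]

rng = np.random.default_rng(5)
Q, K, V = (rng.standard_normal((128, 16)) for _ in range(3))
assert np.allclose(naive_attention(Q, K, V), flash_attention(Q, K, V))
```

The accumulator O is rescaled by the same `alpha` as the normalizer, which is what lets the P @ V matmul ride inside the softmax loop.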
10. Precision discipline under fusion¶
Fusion makes precision choices more dangerous, not less, because:
- Intermediate values that used to be written to HBM in their materialized dtype are now stored in registers in whatever dtype the producer last computed in. A BF16 elementwise op that previously rounded its output to BF16 may now keep it as FP32 in registers, and the next op consumes FP32 - which may be silently better, but is also a behavioral change.
- Accumulators inside a fused reduction must be FP32, not the input dtype. Chapter 11 §3.2 derives this for reductions; the rule applies verbatim inside fused kernels.
- Epilogues on GEMMs typically receive FP32 accumulator tiles and downcast at the last step. If you insert a precision-sensitive operation (a divide, an exp, a log) in the epilogue, do it in FP32 before the downcast.
The discipline cheat sheet:
| Operation | Compute dtype inside fused kernel | Why |
|---|---|---|
| Elementwise add/mul | match input | no precision loss either way |
| Elementwise divide / sqrt / exp / log | FP32 | nonlinear; small inputs lose precision in BF16 |
| Reduction (sum, mean, dot, max) | FP32 | catastrophic cancellation in BF16 past ~256 elements |
| Softmax | FP32 internally, BF16 output | both reduction and exp need FP32 |
| LayerNorm / RMSNorm | FP32 statistics, BF16 output | reduction + divide |
| GEMM accumulator | FP32 (tensor cores already do this for BF16/FP16 inputs) | hardware default |
| GELU / SiLU activation | FP32 if the epilogue, else match | tanh/exp inside |
If your fused kernel diverges from the unfused reference past ~1e-3 in BF16, you almost certainly downcasted an accumulator too early.
11. The limits of fusion¶
Fusion is not free; it competes for finite GPU resources. Three hard constraints:
11.1 Register pressure¶
Each Triton/CUDA kernel uses some number of registers per thread. An H100 SM has 65,536 32-bit registers shared across active warps. Occupancy = active_warps / max_warps_per_SM. A fused kernel with deeper computation needs more registers per thread; past a threshold, occupancy collapses and HBM-fetch latency stops being hidden by warp switching.
The relationship is roughly: resident threads per SM = min(2048, 65536 / regs_per_thread).
If your fused kernel uses 128 regs/thread, you get 65536 / 128 = 512 threads per SM - only 16 warps. If 16 warps are enough to hide your latency, this is fine; if you need 32 or more, you've over-fused.
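The arithmetic, as a throwaway helper (65,536 registers and 2,048 max threads per SM are the H100 figures cited above; real allocation has granularity this ignores):

```python
def warps_per_sm(regs_per_thread, regfile=65536, max_threads=2048):
    """Resident warps per SM as limited by the register file alone."""
    threads = min(max_threads, regfile // regs_per_thread)
    return threads // 32

for r in (32, 64, 128, 256):
    print(f"{r:3d} regs/thread -> {warps_per_sm(r):2d} warps resident")
# 32 regs/thread keeps the SM full (64 warps); 128 leaves only 16.
```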
Diagnostic: nvcc --ptxas-options=-v (CUDA) or Triton's autotune output reports registers per thread. Above 128, look hard.
11.2 Shared memory capacity¶
H100 has 228 KiB of SMEM per SM (configurable). Fused kernels often use SMEM to stage intermediate tiles. Past the capacity, you can't fit two concurrent thread blocks per SM, halving occupancy.
For matmul kernels with epilogues, SMEM is dominated by the A and B tiles: 2 * BM * BK * dtype_bytes + 2 * BK * BN * dtype_bytes (the 2 is for double-buffering). Epilogue logic is usually register-only.
11.3 The kernel-launch amortization plateau¶
For very small inputs (batch=1, seqlen=1 in decoding), kernel launch overhead (~5 µs) is comparable to kernel runtime. Fusion's value is huge - eliminating a launch saves more than the kernel itself takes. But for very large inputs, launch overhead is amortized to zero and fusion's only value is the HBM-traffic reduction.
The decoding regime (single-token autoregressive) is the most fusion-sensitive workload in AI infrastructure. Every inference engine in production fuses aggressively because of this.
11.4 Tile-shape mismatches¶
If op A is naturally tiled (64, 128) and op B (128, 64), you cannot fuse them in the obvious way - A's output tile doesn't match B's input tile. You either accept a transpose in registers (cheap if it fits) or accept the unfused cost. Compiler-driven fusion (XLA, Inductor) deals with this by not attempting fusions that require expensive layout changes; the cost model rejects them.
12. Profiling fused kernels with Nsight Compute¶
You have a fused kernel. Is it actually fast? Nsight Compute (ncu) is the answer.
The minimal workflow:
```shell
ncu --set full --kernel-name fused_linear_gelu_residual -o report ./my_app
ncu-ui report.ncu-rep   # open the GUI
```
The metrics that matter for fused kernels:
| Metric | Meaning | Healthy value |
|---|---|---|
| `sm__throughput.avg.pct_of_peak_sustained_elapsed` | SM utilization | >70% for compute-bound, <30% for memory-bound (expected) |
| `dram__throughput.avg.pct_of_peak_sustained_elapsed` | HBM utilization | >70% for memory-bound (you want this) |
| `l1tex__t_sectors_pipe_lsu_mem_global_op_ld.sum` | Global loads in sectors | Compare unfused vs fused; should drop |
| `launch__registers_per_thread` | Register pressure | <128 typical, >196 alarming |
| `launch__shared_mem_per_block` | SMEM use | <96 KiB to allow 2 blocks/SM on H100 default |
| `smsp__warps_eligible.avg.pct_of_peak_sustained_elapsed` | Warp-scheduler utilization | >70% means latency is well-hidden |
The "Did fusion work?" test: profile the unfused chain and the fused kernel. Compare dram__bytes_read.sum + dram__bytes_write.sum. A successful fusion reduces this by approximately the predicted amount from §2.
13. When NOT to fuse¶
Three situations where fusion is the wrong call:
13.1 When the unfused kernels are already individually fast and the activations need to be saved for the backward pass¶
For training, the activation produced by an intermediate op is often needed by the backward pass. If you fuse the op with the next one, you must either (a) recompute the activation in the backward pass (the activation-checkpointing pattern), or (b) write the intermediate to HBM anyway, eliminating the fusion benefit.
PyTorch's torch.compile handles this with AOTAutograd partitioning - the forward and backward graphs are jointly optimized; the partitioner decides what to save vs recompute. For hand-rolled training kernels, this trade is explicit.
13.2 When fusion harms debuggability and the perf delta is small¶
A fused kernel that delivers 5% latency improvement but is 10× harder to debug, profile, and modify is a net loss for an actively evolving codebase. Save aggressive hand-fusion for the inner loop of mature, stable code paths.
13.3 When the fused kernel's autotune surface is too large¶
A fused matmul with 3 epilogue variants × 5 tile shapes × 4 num_warps × 3 num_stages = 180 autotune configurations. Each can take seconds to compile and benchmark. For one-off scripts, the autotune time exceeds the inference time saved. Production engines amortize this with a tuning cache (Triton's @triton.autotune does this automatically - cache key is the input shape).
14. Practical exercises¶
Exercise 1 - Quantify the win¶
Take a Llama-2-7B forward pass at batch=4, seqlen=2048, BF16. Compute, from first principles:
- Total HBM bytes moved by the elementwise + normalization operations in one decoder layer, unfused (model each LayerNorm/RMSNorm as 3 kernels, each residual add as 1, each activation as 1).
- The same, fully fused (each block executes RMSNorm and the SwiGLU pipeline as single fused kernels).
- Estimated latency saving per layer on H100 (3 TB/s HBM).
Hint: hidden = 4096; intermediate = 11008; head_dim = 128; n_layers = 32. Show your work.
Exercise 2 - Implement and benchmark fused RMSNorm¶
Implement the single-pass RMSNorm kernel from §9.2 in Triton. Benchmark vs the PyTorch nn.RMSNorm equivalent at shapes (B, S, H) = (8, 4096, 4096) and (1, 1, 4096) (training vs decode). Report:
- Throughput (TB/s of effective HBM bandwidth).
- Numerical max-abs-error vs an FP32 reference computed in
torch.float64.
Bonus: show that BF16-accumulator RMSNorm diverges from FP64 by >1e-2 at H=8192 and explain why.
Exercise 3 - GEMM epilogue fusion in CUTLASS¶
Pick one of: PyTorch's addmm (with the linear → add → activation pattern) or CUTLASS's LinearCombinationGELU epilogue example. Profile the fused vs unfused (separate matmul + bias + GELU) version at shape (M, N, K) = (8192, 8192, 8192) and report:

- Latency difference.
- HBM bytes moved (from `ncu`).
- Justify the gap with the §2 cost model.
Exercise 4 - Find a fusion in Inductor's output¶
Write a small PyTorch function with a fusible chain (e.g., x.relu().mul(2).add(1).sigmoid()). Compile with torch.compile, set TORCH_LOGS="output_code", and inspect the generated Triton kernel. Confirm that all four ops appear in one kernel. Find one example in your own codebase (or a public model) where Inductor failed to fuse a chain you expected it to, and explain why (read the Inductor scheduler logs).
Exercise 5 - Implement FlashAttention v1 in Triton¶
(Stretch.) Working from chapter 07's algorithmic pseudocode and chapter 03's Triton tutorial, implement FlashAttention v1 (forward only, no causal mask). Benchmark vs torch.nn.functional.scaled_dot_product_attention (which dispatches to FlashAttention) at (B, H, S, D) = (4, 16, 4096, 128). You should be within 3× of the optimized version on first attempt; closing the gap requires deeper tile-shape and stage tuning.
Exercise 6 - Precision regression hunt¶
Take the fused RMSNorm from exercise 2. Deliberately introduce a bug: compute the sum-of-squares in BF16 instead of FP32. Show numerically that:

- The error vs FP64 grows with the hidden dimension.
- For random inputs the growth rate is roughly O(sqrt(H)) (central-limit-theorem accumulation of rounding noise), approaching O(H) in the worst case.
- The error at H=8192 is large enough to perturb downstream logits past the temperature-sampling threshold for typical LLMs.
Connect to chapter 11 §3.2.
15. Cheat sheet and further reading¶
Cheat sheet¶
- Fuse elementwise chains aggressively. Compilers (Inductor, XLA) do this for free; verify they did.
- Fuse GEMM epilogues. `bias + activation + residual` belong in the matmul kernel. Use `addmm`, CUBLASLt, or CUTLASS.
- Fuse QKV and gate/up projections. Always. If you see three separate matmuls for Q, K, V - that's a perf bug.
- Fuse reductions with their producers/consumers (RMSNorm, softmax, top-k). Online algorithms (Welford, online softmax) make this single-pass.
- Reserve hand-Triton for the algorithmically hard cases (FlashAttention, fused MoE, paged attention, fused quantized GEMM).
- Keep FP32 accumulators inside fused kernels. Always. See chapter 11.
- Profile with `ncu` and check that HBM traffic dropped by the predicted amount; if it didn't, fusion didn't happen.
Further reading¶
- PyTorch Inductor docs - pytorch.org/docs/stable/torch.compiler_inductor.html. The scheduler and fusion-cost-model sections.
- XLA fusion - openxla.org/xla/operation_semantics. The `Fusion` instruction and the `fusion_kind` enum.
- CUTLASS epilogues - the `cutlass/epilogue/` directory in the CUTLASS repo. Especially `LinearCombinationGeneric`.
- FlashAttention papers - Dao et al., FlashAttention (2022); FlashAttention-2 (2023); FlashAttention-3 (2024). Each is a different fusion algorithm on the same operator.
- Triton tutorials - triton-lang.org/main/getting-started/tutorials/. The fused-softmax and fused-attention tutorials are the canonical references.
- Horace He, Making Deep Learning Go Brrrr From First Principles (blog). The clearest exposition of the bandwidth-bound argument that motivates all of this.
- NVIDIA CUDA Best Practices Guide - the Memory Optimizations chapter. Foundational.
Chapter 13 (not yet written) will continue with custom autograd for fused kernels - how to register backward passes for the kernels you fuse, and how to compose them through torch.autograd.Function and register_autograd (chapter 04 §custom-ops). Until that chapter exists, the canonical reference is the FlashAttention repo's csrc/ directory.