
Deep Dive 12 - Kernel Fusion: Theory, Practice, and the Compilers That Do It For You

Chapter 11 told you which precisions to use where. Chapter 12 tells you how to schedule those operations onto the GPU so the precision decisions actually pay off - by eliminating the HBM round-trips that dominate end-to-end latency in modern deep-learning workloads.

This chapter can be read standalone: it pulls forward concepts from chapters 01 (GPU architecture), 02 (CUDA), 03 (Triton), 04 (PyTorch internals + Inductor), 05 (JAX + XLA), 07 (FlashAttention), and 11 (numerics), and references them by chapter number rather than re-deriving them.


Table of contents

  1. Why fuse at all
  2. The HBM round-trip cost model
  3. Fusion taxonomy
  4. Vertical fusion derived
  5. Horizontal fusion derived
  6. GEMM epilogue fusion
  7. Streaming-reduction fusion: the FlashAttention pattern
  8. Compiler-driven fusion: XLA, TorchInductor, Triton
  9. Hand-rolled fusion in Triton: three full kernels
  10. Precision discipline under fusion
  11. The limits of fusion
  12. Profiling fused kernels with Nsight Compute
  13. When NOT to fuse
  14. Practical exercises
  15. Cheat sheet and further reading

1. Why fuse at all

The single most important observation in modern GPU performance work:

Most deep-learning operators in the forward pass of a transformer are memory-bandwidth-bound, not compute-bound.

To see why, recall the roofline from chapter 01. A modern H100 GPU has roughly:

  • BF16 dense tensor-core throughput: ~989 TFLOPS.
  • HBM3 bandwidth: ~3.0 TB/s.

The crossover arithmetic intensity at which a kernel transitions from memory-bound to compute-bound is:

I_crossover = peak_FLOPS / peak_BW = 989e12 / 3.0e12 ≈ 330 FLOP/byte.

A pure elementwise operation like y = a * x + b performs 2 FLOPs per element against 8–12 bytes moved in FP32 (4 to read x, 4 to write y, plus 4 more if a is a tensor rather than a scalar; in BF16 halve everything - either way the intensity is well under 1 FLOP/byte). The GPU sits at <1% of peak FLOPs and ~100% of peak bandwidth for the entire kernel.

The consequence: if your network is a chain of elementwise ops and small reductions, total time is determined almost entirely by total bytes moved through HBM - not by total work done. A LayerNorm → Linear → GELU → Dropout chain executed as four separate kernels reads and writes the activation tensor through HBM four times. Fused into one kernel, it reads once, writes once.

For the typical transformer hidden state at batch=8, seqlen=4096, hidden=8192, BF16:

activation size = 8 * 4096 * 8192 * 2 bytes = 512 MiB

Each pass of this tensor through HBM (one full read or one full write) costs 512 MiB / 3 TB/s ≈ 180 µs, so an eliminated intermediate - one write plus one read - saves ≈ 360 µs. Saving three round-trips per layer × 80 layers = 240 round-trips ≈ 86 ms per forward pass - for free, just by stopping the round-tripping. That is the prize.


2. The HBM round-trip cost model

Let n_op be the number of fused operations in a chain, S the size of the activation tensor in bytes, BW the HBM bandwidth, and K_launch the per-kernel launch overhead (~5 µs on a modern driver). Time for unfused execution:

T_unfused = n_op * (2*S / BW + K_launch)

(Each op reads S, writes S.)

Time for fused execution (one kernel reads once, writes once, does all the work in registers/SMEM):

T_fused = 2*S / BW + K_launch + T_compute_in_kernel

For elementwise chains, T_compute_in_kernel is negligible compared to the HBM term, so:

speedup ≈ n_op    (asymptotically, ignoring launch overhead)

A fused chain of 5 elementwise ops is roughly 5× faster than the unfused version, regardless of how clever the unfused kernels are individually. This is the headline result that motivates every fusion compiler ever written.

Worked numerical example

Take the post-attention residual stream of a Llama-3-70B layer, batch=1, seqlen=8192:

hidden = 8192,  bf16,   activation = 1 * 8192 * 8192 * 2 = 128 MiB

The post-attention chain: x + attn_out → RMSNorm → linear_gate → silu → linear_up · gate → linear_down → x + ffn_out.

Counting just the elementwise pieces (residual add, RMSNorm scale, SiLU, elementwise multiply, residual add) - five elementwise/light-reduction operations on the activation tensor:

T_unfused_elementwise = 5 * (2 * 128 MiB / 3 TB/s + 5 µs)
                     = 5 * (90 µs + 5 µs)
                     = 475 µs

T_fused_elementwise   = 2 * 128 MiB / 3 TB/s + 5 µs
                     = 95 µs

Per layer, fusion saves ~380 µs in the elementwise chain. Across 80 layers and a 100-token decode at this sequence length, that's ~3.0 seconds. On a real inference engine the saving is closer to 30–50% of total latency because the matmuls still dominate, but eliminating elementwise round-trips is the single most impactful generic optimization in deep-learning compilers.
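Plugging the model into a few lines of Python makes the arithmetic easy to replay. A sketch (the helper name is ours, not a library API):

def hbm_cost_model(n_op, size_bytes, bw=3.0e12, launch=5e-6):
    """Section-2 model: each unfused op reads and writes the full tensor;
    the fused kernel does one read and one write total."""
    t_unfused = n_op * (2 * size_bytes / bw + launch)
    t_fused = 2 * size_bytes / bw + launch
    return t_unfused, t_fused

S = 1 * 8192 * 8192 * 2          # the 128 MiB activation from the example
t_u, t_f = hbm_cost_model(5, S)
print(f"{t_u * 1e6:.0f} us unfused, {t_f * 1e6:.0f} us fused")
# prints ≈472 us unfused, ≈94 us fused - the 475/95 µs above, modulo rounding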


3. Fusion taxonomy

Fusion comes in five shapes, in roughly ascending implementation difficulty:

  1. Elementwise → elementwise - e.g. (a + b) * c. Trivial; every compiler does it.
  2. Elementwise → reduction - e.g. sum(x * x), as used in RMSNorm. Easy.
  3. Reduction → elementwise (broadcast) - e.g. RMSNorm = x / sqrt(mean(x²) + ε) * γ. Medium; needs a two-pass or online algorithm.
  4. GEMM + epilogue - e.g. gelu(A @ B + bias). Medium; the CUTLASS/cuBLASLt epilogue API.
  5. Streaming reduction over GEMM output - FlashAttention: softmax(Q@Kᵀ / √d) @ V. Hard; requires algorithmic redesign (chapter 07).

A sixth, more ambitious shape - multi-GEMM fusion, where two matrix multiplies sharing an intermediate are fused (e.g., the FFN's up_proj and gate_proj in SwiGLU) - is increasingly common in production inference engines but requires either (a) careful CUTLASS programming or (b) horizontal fusion at the Triton level.

The taxonomy axis you actually care about is vertical vs horizontal:

  • Vertical (producer-consumer) fusion combines operations along the data-flow direction: op B consumes op A's output, so we keep A's output in registers/SMEM and feed it directly to B without writing to HBM. All of patterns 1–5 above are vertical.
  • Horizontal (sibling) fusion combines independent operations that have no data dependency, executing them in the same kernel to amortize launch overhead and (sometimes) share input loads. Example: q = x @ Wq; k = x @ Wk; v = x @ Wv can be done as one fused kernel that loads x once.

The next two sections derive each rigorously.


4. Vertical fusion derived

4.1 The producer-consumer pattern

Consider two operations:

B = f(A)
C = g(B)

Unfused, the dataflow through HBM is:

HBM:  read A   →   write B
HBM:  read B   →   write C
Total HBM traffic = |A| + 2|B| + |C|

Fused into one kernel:

For each tile of A:
    load tile_A from HBM into registers
    tile_B = f(tile_A)            # stays in registers
    tile_C = g(tile_B)            # stays in registers
    store tile_C to HBM
Total HBM traffic = |A| + |C|

We saved 2|B| bytes of HBM traffic. If f and g are elementwise and same-shape, |A| = |B| = |C|, so we cut traffic by 50% and (in the bandwidth-bound regime) doubled throughput.

4.2 The shape-compatibility requirement

Vertical fusion works only when the producer's output layout matches the consumer's input layout at the tile granularity. Two cases:

  • Pointwise op → pointwise op: trivially compatible (same element-to-element correspondence). Always fusible.
  • Reduction → broadcast: the reduction shrinks the tensor; the broadcast re-expands it. Fusion is possible but requires either (a) keeping the reduction result in SMEM and re-reading per element (the two-pass RMSNorm pattern), or (b) computing the reduction online during the consumer pass.

4.3 Worked example: RMSNorm fused

The naive RMSNorm:

mean_sq = (x * x).mean(dim=-1, keepdim=True)   # kernels 1-2: mul materializes x*x, mean reduces it
rrms    = torch.rsqrt(mean_sq + eps)           # kernel 3 (tiny)
y       = x * rrms * gamma                     # kernel 4 (two kernels in eager)

Several kernels, each round-tripping x (or a derivative of it) through HBM. Fused, in pseudocode:

def rmsnorm_fused(x, gamma, eps):
    # x: (..., H)
    # one tile = one row (H elements)
    for row in tiles(x):
        # pass 1: reduction
        s = 0.0
        for j in range(0, H, BLOCK):
            xj = load(row, j)            # HBM → registers
            s += sum(xj * xj)            # accumulate in fp32
        rrms = rsqrt(s / H + eps)        # scalar

        # pass 2: scale-broadcast
        for j in range(0, H, BLOCK):
            xj = load(row, j)            # HBM → registers (re-read!)
            gj = load(gamma, j)
            store(row, j, xj * rrms * gj)

We read x twice and write y once - total 3|x| HBM bytes - and x * x never materializes. The unfused version moves at least 5|x|: the mul kernel reads x and writes x² (2|x|), the mean kernel reads x² (|x|), and the scale kernel reads x and writes y (2|x|) - plus the saved kernel launches. Speedup on traffic alone: 5/3 ≈ 1.67×.

A single-pass RMSNorm uses Welford-style online statistics to avoid the re-read of x - that drops traffic to 2|x|, the absolute floor. The Triton kernel in chapter 03 shows this.


5. Horizontal fusion derived

5.1 The independent-siblings pattern

Now consider three independent operations sharing an input:

Q = X @ Wq
K = X @ Wk
V = X @ Wv

Unfused: three kernel launches, each reading X from HBM. HBM traffic = 3|X| + |Q| + |K| + |V|.

Horizontally fused: one kernel reads X once and produces all three outputs. HBM traffic = |X| + |Q| + |K| + |V|. Saves 2|X|.

For X of shape (batch * seqlen, hidden) = (8 * 4096, 8192) in BF16 = 512 MiB, savings = 1 GiB of HBM traffic ≈ 360 µs at H100 bandwidth, just for the QKV projections per layer.

In practice, modern inference engines (vLLM, TensorRT-LLM) fuse QKV by concatenating [Wq, Wk, Wv] along the output dimension into a single W_qkv of shape (hidden, 3*head_dim*n_heads), and slicing after the matmul. This is mathematically identical to horizontal fusion and is the standard pattern - if a transformer codebase you read does not fuse QKV, that's a perf bug.
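A minimal sketch of the pattern in plain PyTorch, assuming standard multi-head attention (under GQA the K and V slices are narrower); shapes and names are illustrative:

import torch

hidden, n_heads, head_dim = 8192, 64, 128
dev, dt = "cuda", torch.bfloat16
Wq = torch.randn(hidden, n_heads * head_dim, device=dev, dtype=dt)
Wk = torch.randn_like(Wq)
Wv = torch.randn_like(Wq)

# Offline, once: concatenate along the output dimension.
W_qkv = torch.cat([Wq, Wk, Wv], dim=1)   # (hidden, 3 * n_heads * head_dim)

x = torch.randn(8 * 4096, hidden, device=dev, dtype=dt)
qkv = x @ W_qkv                          # one GEMM; x is read from HBM once
q, k, v = qkv.chunk(3, dim=-1)           # views into qkv; no extra traffic

The same concatenation gives the fused gate/up projection described next in §5.2.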

5.2 The SwiGLU case

For the FFN block:

gate = silu(X @ W_gate)
up   = X @ W_up
y    = (gate * up) @ W_down

W_gate and W_up can be horizontally fused into W_gu of shape (hidden, 2 * inter_dim), sliced into halves, with SiLU and the elementwise multiply fused as the epilogue. This saves one |X| HBM read per FFN per layer. For the §2 example (batch=1, seqlen=8192, hidden=8192), |X| = 128 MiB ≈ 45 µs at 3 TB/s; across 80 layers that's ~3.6 ms per forward pass - substantial.


6. GEMM epilogue fusion

A GEMM epilogue is any elementwise operation chained immediately after C = A @ B. The CUTLASS library (and its CuTe DSL) supports declarative epilogue fusion via a templated programming model. Common epilogues:

  • C = A @ B + bias (the GEMM-bias-add pattern in every linear layer).
  • C = act(A @ B + bias) where act ∈ {ReLU, GELU, SiLU}.
  • C = act(A @ B + bias) * scale (for quantized inference).
  • C = act(A @ B + bias) + residual (the residual stream pattern in transformers - fuses the residual add into the matmul).

6.1 Why epilogue fusion is cheap

The GEMM kernel already has the output tile C_tile resident in registers immediately after computing it. Applying an elementwise function to it before storing costs zero additional HBM traffic. The only cost is a handful of extra instructions per register, well below the noise floor of the matmul itself.

In CUTLASS terms, the epilogue is a templated EpilogueOp that receives the accumulator tile and produces the output tile:

using EpilogueOp = cutlass::epilogue::thread::LinearCombinationGELU<
    ElementOutput,           // bf16
    128 / sizeof_bits<...>,  // vector length
    ElementAccumulator,      // fp32
    ElementCompute>;         // fp32

The LinearCombinationGELU epilogue computes act(α * accumulator + β * bias) in registers, then stores to HBM. One kernel; zero round-trip.

6.2 The "fuse the residual add into the matmul" trick

The residual connection in a transformer block computes:

x_new = x + linear(layernorm(x))

If the linear is a CUTLASS GEMM with epilogue D = α*(A@B) + β*C, you can pass x itself as C with β=1. The GEMM was going to write its output anyway; with the epilogue, it writes the residual-added output instead, so the add costs nothing beyond the read of x that any implementation needs. This saves a full round-trip of the intermediate (one write plus one read) per block, every layer, every forward pass.

This is implemented by addmm in PyTorch (when properly routed to cuBLASLt with the epilogue path) and by every production inference engine. If you write y = linear(x) + residual as two separate kernels in a hot path, that's a perf bug.
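In PyTorch this is a one-line change. A sketch - with the caveat that whether addmm actually takes the cuBLASLt epilogue path depends on dtype, layout, and build, so verify with a profiler:

import torch

x = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)
W = torch.randn(8192, 8192, device="cuda", dtype=torch.bfloat16)
residual = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)

y_slow = x @ W + residual              # two kernels: GEMM writes y, add re-reads it
y_fast = torch.addmm(residual, x, W)   # beta=1 folds the add into the GEMM's store

torch.testing.assert_close(y_slow, y_fast, rtol=1e-2, atol=1e-1)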


7. Streaming-reduction fusion: the FlashAttention pattern

The most algorithmically sophisticated fusion in modern AI is FlashAttention (chapter 07). The naive attention computation:

S = Q @ Kᵀ          # (B, H, M, N) - materialized
P = softmax(S, axis=-1)  # (B, H, M, N) - materialized
O = P @ V            # (B, H, M, d)

The intermediates S and P are O(M·N) and dominate memory at long sequence length. At seqlen=8192, S for a single batch×head pair is 8192² × 2 bytes = 128 MiB in BF16. For B=8, H=64, materializing S and P totals 128 GiB. Doesn't fit.

FlashAttention's insight: S and P never need to be materialized in HBM. Compute the softmax incrementally, tile-by-tile, while accumulating O directly. The mathematical machinery is the online softmax derived in chapter 03 §online-softmax and rigorously in chapter 07.

For fusion purposes, the structural lesson is:

A reduction (softmax-then-matmul) over a streaming source (Q@Kᵀ computed tile-by-tile) can be fused into a single kernel that never materializes the intermediate.

This pattern - streaming a producer through a reducer with state kept in registers/SMEM - generalizes well beyond attention. Examples in production:

  • Cross-entropy loss fused with the final logits projection. Logits at vocab=128k × seqlen=4096 × batch=8 are 8 GiB in BF16 (16 GiB in FP32); never materialize them.
  • Top-k sampling fused with logits. Same memory argument.
  • MoE router + dispatch fused. The router's softmax + top-k + scatter can all run in a single kernel.

The price: the fused kernel is algorithmically non-trivial. Each new instance requires real engineering. Compilers cannot yet derive these fusions automatically; you write them by hand in Triton or CUDA.
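The state update itself fits in a dozen lines. A reference implementation in plain PyTorch - useful as a correctness oracle while writing the kernel; the function name and blocking are ours:

import torch

def streaming_softmax_matmul(q_row, K, V, block=128):
    """o = softmax(q_row @ K.T) @ V, computed tile-by-tile without ever
    materializing the score row. m, l, o are the running state a fused
    kernel keeps in registers."""
    m = torch.tensor(float("-inf"))
    l = torch.tensor(0.0)
    o = torch.zeros(V.shape[1])
    for s in range(0, K.shape[0], block):
        scores = q_row @ K[s:s + block].T        # one tile of the score row
        m_new = torch.maximum(m, scores.max())
        alpha = torch.exp(m - m_new)             # rescale the old state
        p = torch.exp(scores - m_new)
        l = l * alpha + p.sum()
        o = o * alpha + p @ V[s:s + block]
        m = m_new
    return o / l

# Oracle check: matches torch.softmax(q_row @ K.T, -1) @ V to fp32 precision.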


8. Compiler-driven fusion: XLA, TorchInductor, Triton

Three major systems perform deep-learning kernel fusion in production. Their philosophies differ; their results converge.

8.1 XLA (JAX, TensorFlow)

Chapter 05 covers XLA in depth. The relevant fusion passes:

  • fusion pass: the canonical pass. Groups elementwise/broadcast/reduce ops into "fusion clusters" and emits one kernel per cluster. Driven by a cost model that estimates HBM traffic.
  • gpu_fusion_pipeline: the GPU-specific lowering. Emits LLVM IR with a single CUDA kernel per fusion. Modern XLA also emits Triton for some patterns (matmul + epilogue).
  • priority_fusion: newer pass with a priority queue over fusion candidates.

Fusion in XLA is declarative: you write pure functional JAX, XLA decides what fuses. You can inspect with jax.jit(f).lower(...).compile().as_text() (chapter 05).

8.2 TorchInductor (PyTorch 2)

Chapter 04 covers Inductor. Its fusion strategy:

  • Scheduler-driven node fusion: Inductor builds a graph of IRNodes (one per ATen op), then greedily fuses adjacent nodes whose fusion satisfies a memory-locality cost model.
  • Emit Triton or C++ for the fused kernel. GPU path emits Triton; CPU path emits C++ with OpenMP.
  • Pointwise + reduction + pointwise is the bread-and-butter fusion class. More than 80% of the speedups Inductor delivers come from this pattern (per PyTorch's perf blogs).

Inspect with TORCH_LOGS="output_code" (chapter 04) - you get the actual Triton source Inductor generated.
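To see it end to end, a minimal sketch (any fusible elementwise chain works):

import torch

def f(x):
    return torch.sigmoid(x.relu() * 2 + 1)

compiled = torch.compile(f)
compiled(torch.randn(4096, 4096, device="cuda"))
# Run with TORCH_LOGS="output_code" to dump the generated Triton source;
# all four elementwise ops should land in a single kernel.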

8.3 Triton autotuning (manual, but compiler-assisted)

Triton (chapter 03) is the kernel-author's tool, not a graph compiler. You write the fused kernel; Triton handles the lowering. The compiler contribution is in autotuning - exploring tile shapes, num_warps, num_stages combinations and picking the best.

Production stack composition (typical 2026 inference engine):

Model architecture (PyTorch)
torch.compile + Inductor      ← elementwise + reduction fusion (auto)
cuBLASLt / CUTLASS matmuls    ← GEMM + epilogue fusion (manual config)
FlashAttention / xFormers     ← streaming-reduction fusion (handwritten)
Triton custom kernels         ← anything Inductor missed (handwritten)

The lesson: let the compiler do the easy fusions; reserve human effort for the algorithmically hard ones (FlashAttention, paged attention, fused MoE, fused quantized GEMMs like Marlin).


9. Hand-rolled fusion in Triton: three full kernels

We work three increasingly complex examples. All assume the reader has read chapter 03 (Triton).

9.1 Kernel 1 - Fused bias-GELU-residual

Operation: y = gelu(x @ W + bias) + residual, where x: (M, K), W: (K, N), bias: (N,), residual: (M, N).

The matmul itself uses standard tiled GEMM (chapter 02). The fusion is in the epilogue: after computing the accumulator tile, apply bias, GELU, and the residual add in registers, then store.

import triton
import triton.language as tl

@triton.autotune(
    configs=[
        triton.Config({'BM': 128, 'BN': 256, 'BK': 32}, num_warps=8, num_stages=3),
        triton.Config({'BM': 64,  'BN': 128, 'BK': 32}, num_warps=4, num_stages=4),
        triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def fused_linear_gelu_residual(
    x_ptr, w_ptr, bias_ptr, residual_ptr, y_ptr,
    M, N, K,
    sxm, sxk, swk, swn, srm, srn, sym, syn,
    BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BM + tl.arange(0, BM)
    offs_n = pid_n * BN + tl.arange(0, BN)
    offs_k = tl.arange(0, BK)

    x_ptrs = x_ptr + offs_m[:, None] * sxm + offs_k[None, :] * sxk
    w_ptrs = w_ptr + offs_k[:, None] * swk + offs_n[None, :] * swn

    acc = tl.zeros((BM, BN), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BK)):
        x = tl.load(x_ptrs,
                    mask=(offs_m[:, None] < M) & (offs_k[None, :] < K - k * BK),
                    other=0.0)
        w = tl.load(w_ptrs,
                    mask=(offs_k[:, None] < K - k * BK) & (offs_n[None, :] < N),
                    other=0.0)
        acc += tl.dot(x, w)
        x_ptrs += BK * sxk
        w_ptrs += BK * swk

    # --- epilogue, in registers, no HBM round-trip ---
    bias = tl.load(bias_ptr + offs_n, mask=offs_n < N, other=0.0)
    acc = acc + bias[None, :]

    # GELU approximation (tanh form), fp32
    c = 0.7978845608  # sqrt(2/pi)
    acc_g = 0.5 * acc * (1.0 + tl.math.tanh(c * (acc + 0.044715 * acc * acc * acc)))

    residual = tl.load(
        residual_ptr + offs_m[:, None] * srm + offs_n[None, :] * srn,
        mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0,
    )
    y = acc_g + residual.to(tl.float32)

    tl.store(
        y_ptr + offs_m[:, None] * sym + offs_n[None, :] * syn,
        y.to(y_ptr.dtype.element_ty),
        mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
    )

HBM traffic accounting:

  • Unfused (4 kernels: matmul, bias-add, GELU, residual-add): |x| + |W| + |bias| + |residual| + 7|y| - four writes of y plus three re-reads of the intermediate - ≈ |x| + |W| + 8|y|, treating |residual| ≈ |y| and |bias| as negligible.
  • Fused: |x| + |W| + |bias| + |residual| + |y| ≈ |x| + |W| + 2|y|.
  • Savings: 6|y| HBM bytes per call.
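A minimal host-side launcher for this kernel (the wrapper name is ours; tile sizes come from the autotuner, and the triton/torch imports match the listing above):

import torch

def fused_linear_gelu_residual_fwd(x, w, bias, residual):
    M, K = x.shape
    _, N = w.shape
    y = torch.empty((M, N), device=x.device, dtype=x.dtype)
    # BM/BN/BK are chosen by @triton.autotune, so the grid is a lambda over meta.
    grid = lambda meta: (triton.cdiv(M, meta['BM']), triton.cdiv(N, meta['BN']))
    fused_linear_gelu_residual[grid](
        x, w, bias, residual, y,
        M, N, K,
        x.stride(0), x.stride(1), w.stride(0), w.stride(1),
        residual.stride(0), residual.stride(1), y.stride(0), y.stride(1),
    )
    return y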

9.2 Kernel 2 - Fused RMSNorm with online statistics

Two-pass RMSNorm requires re-reading x. One-pass uses Welford-style online updates. For RMSNorm specifically, since we only need the sum of squares (not variance), the update is simply additive:

@triton.jit
def rmsnorm_fwd_fused(
    x_ptr, gamma_ptr, y_ptr,
    stride_xm, stride_xn,
    stride_ym, stride_yn,
    N, eps,
    BLOCK_N: tl.constexpr,
):
    # One program instance handles one row.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_N)
    mask = cols < N

    x = tl.load(x_ptr + row * stride_xm + cols * stride_xn, mask=mask, other=0.0).to(tl.float32)
    sum_sq = tl.sum(x * x, axis=0)
    rrms = 1.0 / tl.sqrt(sum_sq / N + eps)
    gamma = tl.load(gamma_ptr + cols, mask=mask, other=0.0).to(tl.float32)

    y = (x * rrms * gamma).to(y_ptr.dtype.element_ty)
    tl.store(y_ptr + row * stride_ym + cols * stride_yn, y, mask=mask)

If N > BLOCK_N (hidden dim larger than what fits in one tile), this becomes two-pass with shared-memory state. For modern transformer hidden dims (4096–16384), one-tile-per-row is feasible up to BLOCK_N=16384 on H100 (uses ~64 KiB of registers/SMEM).

Note the precision discipline: load in BF16, promote to FP32 for the reduction and the divide, store back in BF16. Chapter 11 §3.3 explains why this is mandatory - accumulating sum-of-squares in BF16 catastrophically loses precision past hidden ≈ 1024.
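A matching launcher and a quick check against an FP32 reference (the wrapper name is ours; triton/torch imports as above):

import torch

def rmsnorm_fwd(x, gamma, eps=1e-6):
    M, N = x.shape
    y = torch.empty_like(x)
    # One program per row; BLOCK_N must be a power of two >= N.
    rmsnorm_fwd_fused[(M,)](
        x, gamma, y,
        x.stride(0), x.stride(1), y.stride(0), y.stride(1),
        N, eps, BLOCK_N=triton.next_power_of_2(N),
    )
    return y

x = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)
g = torch.ones(8192, device="cuda", dtype=torch.bfloat16)
ref = (x.float() * torch.rsqrt((x.float() ** 2).mean(-1, keepdim=True) + 1e-6)
       * g.float()).to(torch.bfloat16)
print((rmsnorm_fwd(x, g) - ref).abs().max())   # should be within ~1 BF16 ulp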

9.3 Kernel 3 - Fused softmax (causal-masked, for attention)

The streaming softmax kernel from chapter 03, with causal masking and tile-wise online normalization:

@triton.jit
def causal_softmax(
    s_ptr, o_ptr, stride_b, stride_h, stride_m, stride_n,
    M, N,
    BLOCK_N: tl.constexpr,
):
    pid_bh = tl.program_id(0)
    pid_m  = tl.program_id(1)
    # Process one query row at a time
    row = pid_m
    cols = tl.arange(0, BLOCK_N)
    base = pid_bh * stride_h + row * stride_m

    # Online softmax state
    m_i = -float('inf')
    l_i = 0.0
    # First pass: find max and partial sum
    for start in range(0, N, BLOCK_N):
        offs = start + cols
        mask = (offs < N) & (offs <= row)              # causal
        s = tl.load(s_ptr + base + offs * stride_n, mask=mask, other=-float('inf')).to(tl.float32)
        m_new = tl.maximum(m_i, tl.max(s, axis=0))
        l_i = l_i * tl.exp(m_i - m_new) + tl.sum(tl.exp(s - m_new), axis=0)
        m_i = m_new

    # Second pass: normalize and store
    for start in range(0, N, BLOCK_N):
        offs = start + cols
        mask = (offs < N) & (offs <= row)
        s = tl.load(s_ptr + base + offs * stride_n, mask=mask, other=-float('inf')).to(tl.float32)
        p = tl.exp(s - m_i) / l_i
        tl.store(o_ptr + base + offs * stride_n, p.to(o_ptr.dtype.element_ty), mask=mask)

This kernel is a building block; the full FlashAttention kernel goes one step further and fuses the matmul P @ V into the same loop, never materializing P at all. See chapter 07 for the full derivation; the punchline is that the inner loop interleaves Q @ Kᵀ tile computation, online softmax update, and P @ V accumulation - all in registers.
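A matching launch, assuming contiguous (B, H, M, N) scores so that one program-id axis can walk the flattened batch×head dimension via stride_h:

import torch

def causal_softmax_fwd(s, BLOCK_N=1024):
    B, H, M, N = s.shape
    o = torch.empty_like(s)
    # Axis 0: flattened batch*head; axis 1: query row.
    causal_softmax[(B * H, M)](
        s, o, s.stride(0), s.stride(1), s.stride(2), s.stride(3),
        M, N, BLOCK_N=BLOCK_N,
    )
    return o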


10. Precision discipline under fusion

Fusion makes precision choices more dangerous, not less, because:

  1. Intermediate values that used to be written to HBM in their materialized dtype are now stored in registers in whatever dtype the producer last computed in. A BF16 elementwise op that previously rounded its output to BF16 may now keep it as FP32 in registers, and the next op consumes FP32 - which may be silently better, but is also a behavioral change.
  2. Accumulators inside a fused reduction must be FP32, not the input dtype. Chapter 11 §3.2 derives this for reductions; the rule applies verbatim inside fused kernels.
  3. Epilogues on GEMMs typically receive FP32 accumulator tiles and downcast at the last step. If you insert a precision-sensitive operation (a divide, an exp, a log) in the epilogue, do it in FP32 before the downcast.

The discipline cheat sheet:

  • Elementwise add/mul: match the input dtype - no precision loss either way.
  • Elementwise divide / sqrt / exp / log: FP32 - nonlinear; small inputs lose precision in BF16.
  • Reduction (sum, mean, dot, max): FP32 - catastrophic cancellation in BF16 past ~256 elements.
  • Softmax: FP32 internally, BF16 output - both the reduction and the exp need FP32.
  • LayerNorm / RMSNorm: FP32 statistics, BF16 output - reduction + divide.
  • GEMM accumulator: FP32 - the hardware default; tensor cores already accumulate BF16/FP16 inputs in FP32.
  • GELU / SiLU activation: FP32 if in the epilogue, else match the input - tanh/exp inside.

If your fused kernel diverges from the unfused reference past ~1e-3 in BF16, you almost certainly downcasted an accumulator too early.
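The reduction row of the table is easy to demonstrate - the same experiment as exercise 6 in §14, in miniature; error magnitudes in the comments are typical, not exact:

import torch

H = 8192
x = torch.randn(H, dtype=torch.float64)
exact = (x * x).sum()

acc32 = (x.float() ** 2).sum()                 # FP32 accumulator
acc16 = torch.zeros((), dtype=torch.bfloat16)  # BF16 accumulator, sequential adds
for xi in x.to(torch.bfloat16):
    acc16 = acc16 + xi * xi

print(abs(acc32.double() - exact) / exact)     # ~1e-7
print(abs(acc16.double() - exact) / exact)     # typically 1e-2 .. 1e-1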


11. The limits of fusion

Fusion is not free; it competes for finite GPU resources. Three hard constraints:

11.1 Register pressure

Each Triton/CUDA kernel uses some number of registers per thread. An H100 SM has 65,536 32-bit registers shared across active warps. Occupancy = active_warps / max_warps_per_SM. A fused kernel with deeper computation needs more registers per thread; past a threshold, occupancy collapses and HBM-fetch latency stops being hidden by warp switching.

The relationship is:

max_threads_per_SM = registers_per_SM / registers_per_thread

If your fused kernel uses 128 regs/thread, you get 65536 / 128 = 512 threads per SM - only 16 of the 64 possible warps. If 16 warps are enough to hide latency, this is fine; if you need 32 or 64, you've over-fused.

Diagnostic: nvcc --ptxas-options=-v (CUDA) or Triton's autotune output reports registers per thread. Above 128, look hard.

11.2 Shared memory capacity

H100 has 228 KiB of SMEM per SM (configurable). Fused kernels often use SMEM to stage intermediate tiles. Past half that capacity per thread block, a second concurrent block no longer fits on the SM, halving occupancy.

For matmul kernels with epilogues, SMEM is dominated by the A and B tiles: 2 * BM * BK * dtype_bytes + 2 * BK * BN * dtype_bytes (the 2 is for double-buffering). For BM=128, BN=256, BK=32 in BF16 that is 16 KiB + 32 KiB = 48 KiB, comfortably allowing two blocks per SM. Epilogue logic is usually register-only.

11.3 The kernel-launch amortization plateau

For very small inputs (batch=1, seqlen=1 in decoding), kernel launch overhead (~5 µs) is comparable to kernel runtime. Fusion's value is huge - eliminating a launch saves more than the kernel itself takes. But for very large inputs, launch overhead is amortized to zero and fusion's only value is the HBM-traffic reduction.

The decoding regime (single-token autoregressive) is the most fusion-sensitive workload in AI infrastructure. Every inference engine in production fuses aggressively because of this.

11.4 Tile-shape mismatches

If op A is naturally tiled (64, 128) and op B (128, 64), you cannot fuse them in the obvious way - A's output tile doesn't match B's input tile. You either accept a transpose in registers (cheap if it fits) or accept the unfused cost. Compiler-driven fusion (XLA, Inductor) deals with this by not attempting fusions that require expensive layout changes; the cost model rejects them.


12. Profiling fused kernels with Nsight Compute

You have a fused kernel. Is it actually fast? Nsight Compute (ncu) is the answer.

The minimal workflow:

ncu --set full --kernel-name fused_linear_gelu_residual -o report ./my_app
ncu-ui report.ncu-rep   # open the GUI

The metrics that matter for fused kernels:

  • sm__throughput.avg.pct_of_peak_sustained_elapsed - SM utilization. Healthy: >70% if compute-bound; <30% if memory-bound (expected here).
  • dram__throughput.avg.pct_of_peak_sustained_elapsed - HBM utilization. Healthy: >70% for a memory-bound kernel (you want this).
  • l1tex__t_sectors_pipe_lsu_mem_global_op_ld.sum - global loads, in sectors. Compare unfused vs fused; it should drop.
  • launch__registers_per_thread - register pressure. <128 is typical; >196 is alarming.
  • launch__shared_mem_per_block - SMEM use. <96 KiB to allow 2 blocks/SM at the H100 default carveout.
  • smsp__warps_eligible.avg.pct_of_peak_sustained_elapsed - warp-scheduler utilization. >70% means latency is well-hidden.

The "Did fusion work?" test: profile the unfused chain and the fused kernel. Compare dram__bytes_read.sum + dram__bytes_write.sum. A successful fusion reduces this by approximately the predicted amount from §2.


13. When NOT to fuse

Three situations where fusion is the wrong call:

13.1 When the unfused kernels are already individually fast and the activations need to be saved for the backward pass

For training, the activation produced by an intermediate op is often needed by the backward pass. If you fuse the op with the next one, you must either (a) recompute the activation in the backward pass (the activation-checkpointing pattern), or (b) write the intermediate to HBM anyway, eliminating the fusion benefit.

PyTorch's torch.compile handles this with AOTAutograd partitioning - the forward and backward graphs are jointly optimized; the partitioner decides what to save vs recompute. For hand-rolled training kernels, this trade is explicit.

13.2 When fusion harms debuggability and the perf delta is small

A fused kernel that delivers 5% latency improvement but is 10× harder to debug, profile, and modify is a net loss for an actively evolving codebase. Save aggressive hand-fusion for the inner loop of mature, stable code paths.

13.3 When the fused kernel's autotune surface is too large

A fused matmul with 3 epilogue variants × 5 tile shapes × 4 num_warps × 3 num_stages = 180 autotune configurations. Each can take seconds to compile and benchmark. For one-off scripts, the autotune time exceeds the inference time saved. Production engines amortize this with a tuning cache (Triton's @triton.autotune does this automatically - cache key is the input shape).


14. Practical exercises

Exercise 1 - Quantify the win

Take a Llama-2-7B forward pass at batch=4, seqlen=2048, BF16. Compute, from first principles:

  • Total HBM bytes moved by the elementwise + normalization operations in one decoder layer, unfused (model each LayerNorm/RMSNorm as 3 kernels, each residual add as 1, each activation as 1).
  • The same, fully fused (each block executes RMSNorm and the SwiGLU pipeline as single fused kernels).
  • Estimated latency saving per layer on H100 (3 TB/s HBM).

Hint: hidden = 4096; intermediate = 11008; head_dim = 128; n_layers = 32. Show your work.

Exercise 2 - Implement and benchmark fused RMSNorm

Implement the single-pass RMSNorm kernel from §9.2 in Triton. Benchmark vs the PyTorch nn.RMSNorm equivalent at shapes (B, S, H) = (8, 4096, 4096) and (1, 1, 4096) (training vs decode). Report:

  • Throughput (TB/s of effective HBM bandwidth).
  • Numerical max-abs-error vs an FP32 reference computed in torch.float64.

Bonus: show that BF16-accumulator RMSNorm diverges from FP64 by >1e-2 at H=8192 and explain why.

Exercise 3 - GEMM epilogue fusion in CUTLASS

Pick one of: PyTorch's addmm (with the linear → add → activation pattern) or CUTLASS's LinearCombinationGELU epilogue example. Profile the fused vs unfused (separate matmul + bias + GELU) version at shape (M, N, K) = (8192, 8192, 8192) and report:

  • Latency difference.
  • HBM bytes moved (from ncu).
  • Justify the gap with the §2 cost model.

Exercise 4 - Find a fusion in Inductor's output

Write a small PyTorch function with a fusible chain (e.g., x.relu().mul(2).add(1).sigmoid()). Compile with torch.compile, set TORCH_LOGS="output_code", and inspect the generated Triton kernel. Confirm that all four ops appear in one kernel. Find one example in your own codebase (or a public model) where Inductor failed to fuse a chain you expected it to, and explain why (read the Inductor scheduler logs).

Exercise 5 - Implement FlashAttention v1 in Triton

(Stretch.) Working from chapter 07's algorithmic pseudocode and chapter 03's Triton tutorial, implement FlashAttention v1 (forward only, no causal mask). Benchmark vs torch.nn.functional.scaled_dot_product_attention (which dispatches to FlashAttention) at (B, H, S, D) = (4, 16, 4096, 128). You should be within 3× of the optimized version on first attempt; closing the gap requires deeper tile-shape and stage tuning.

Exercise 6 - Precision regression hunt

Take the fused RMSNorm from exercise 2. Deliberately introduce a bug: compute the sum-of-squares in BF16 instead of FP32. Show numerically that:

  • The error vs FP64 grows with the hidden dimension.
  • The error grows without bound but sublinearly - roughly O(sqrt(H)), from the central-limit-theorem accumulation of rounding noise.
  • The error at H=8192 is large enough to perturb downstream logits past the temperature-sampling threshold for typical LLMs.

Connect to chapter 11 §3.2.


15. Cheat sheet and further reading

Cheat sheet

  • Fuse elementwise chains aggressively. Compilers (Inductor, XLA) do this for free; verify they did.
  • Fuse GEMM epilogues. bias + activation + residual belong in the matmul kernel. Use addmm, cuBLASLt, or CUTLASS.
  • Fuse QKV and gate/up projections. Always. If you see three separate matmuls for Q, K, V - that's a perf bug.
  • Fuse reductions with their producers/consumers (RMSNorm, softmax, top-k). Online algorithms (Welford, online softmax) make this single-pass.
  • Reserve hand-Triton for the algorithmically hard cases (FlashAttention, fused MoE, paged attention, fused quantized GEMM).
  • Keep FP32 accumulators inside fused kernels. Always. See chapter 11.
  • Profile with ncu and check that HBM traffic dropped by the predicted amount; if it didn't, fusion didn't happen.

Further reading

  • PyTorch Inductor docs - pytorch.org/docs/stable/torch.compiler_inductor.html. The scheduler and fusion-cost-model sections.
  • XLA fusion - openxla.org/xla/operation_semantics. The Fusion instruction and the fusion_kind enum.
  • CUTLASS epilogues - the cutlass/epilogue/ directory in the CUTLASS repo. Especially LinearCombinationGeneric.
  • FlashAttention papers - Dao et al., FlashAttention (2022); FlashAttention-2 (2023); FlashAttention-3 (2024). Each is a different fusion algorithm on the same operator.
  • Triton tutorials - triton-lang.org/main/getting-started/tutorials/. The fused-softmax and fused-attention tutorials are the canonical references.
  • Horace He, Making Deep Learning Go Brrrr From First Principles (blog). The clearest exposition of the bandwidth-bound argument that motivates all of this.
  • NVIDIA CUDA Best Practices Guide - the Memory Optimizations chapter. Foundational.

Chapter 13 (not yet written) will continue with custom autograd for fused kernels - how to register backward passes for the kernels you fuse, and how to compose them through torch.autograd.Function and register_autograd (chapter 04 §custom-ops). Until that chapter exists, the canonical reference is the FlashAttention repo's csrc/ directory.
