
Deep Dive 07: Attention, the Transformer, and FlashAttention

A self-contained reference. By the end of this chapter you should be able to: derive scaled dot-product attention from first principles, implement causal multi-head attention from scratch matching F.scaled_dot_product_attention, reason about KV-cache memory for any decoder-only LLM, derive the online softmax that powers FlashAttention, and explain why FA-2 and FA-3 each roughly doubled throughput.


Table of contents

  1. From "predict the next token" to a transformer
  2. Scaled dot-product attention: full derivation
  3. Multi-head attention
  4. Causal masking
  5. MQA and GQA: why decode is bandwidth-bound
  6. Position encodings: Sinusoidal, Learned, RoPE, ALiBi, Sliding window, YaRN/PI/NTK
  7. The transformer block: pre-norm, residuals, FFN
  8. LayerNorm vs RMSNorm
  9. Activations: GeLU, SwiGLU
  10. KV-cache: math, layouts, paged attention
  11. Attention complexity: O(S^2) is the enemy
  12. FlashAttention: derivation of the online softmax and tiled algorithm
  13. FlashAttention-2 deltas
  14. FlashAttention-3 deltas
  15. Decode-time variant: flash_attn_with_kvcache
  16. Practical exercises

Notation used throughout:

  • B = batch size
  • S = sequence length (often L_q or L_k for query/key length separately)
  • H = number of attention heads
  • d = model hidden size (sometimes d_model)
  • d_h = per-head dimension; usually d / H
  • d_k = key dimension per head (= d_h in standard attention)
  • d_v = value dimension per head (= d_h in standard attention)
  • H_q = number of query heads, H_kv = number of K/V heads (for GQA/MQA)
  • V = vocab size
  • N = batch * sequence dimension when flattened


1. From "predict the next token" to a transformer

1.1 The autoregressive language modeling setup

A language model is a probability distribution over token sequences. Given a vocabulary of size V and a sequence x_1, x_2, ..., x_T of token IDs in {0, 1, ..., V-1}, the model factorizes the joint probability via the chain rule:

P(x_1, x_2, ..., x_T) = prod_{t=1..T} P(x_t | x_1, ..., x_{t-1})

We call P(x_t | x_<t) the next-token distribution. Training a decoder-only transformer is exactly fitting a parametric model p_theta(x_t | x_<t) by minimizing the negative log likelihood:

L(theta) = -(1/T) sum_{t=1..T} log p_theta(x_t | x_<t)

For a batch, the loss is averaged over every sequence and position. The gradient signal at every position t trains the model to predict x_t from its left context x_<t.

The structure we need:

  • Input: a sequence of T tokens, embedded as vectors of dimension d. So x in R^{T x d}.
  • Output: T vectors of dimension d, i.e. an output in R^{T x d} (one vector per position). These get projected to V-dim logits and softmaxed to give the next-token distribution at each position.
  • Constraint: the output at position t must depend only on inputs x_{<=t} (causal); otherwise the loss leaks the answer.
  • Inductive bias: every output position should be a function of all previous positions, not just the immediately preceding one. We do not want the gradient to have to traverse hundreds of recurrent steps.

The transformer's answer is: at every position, aggregate information from all earlier positions in parallel via attention, then mix it position-wise with an MLP, then repeat L layers deep.

1.2 Why attention: the "gather" intuition

Imagine T = 1024 tokens and we are computing the new representation for position t. Some earlier positions are highly relevant ("the antecedent of the pronoun two paragraphs back"); most are not. Conceptually we want a soft-lookup:

  • Each position t emits a query q_t: what it is looking for.
  • Each position s emits a key k_s: what it advertises.
  • Each position s emits a value v_s: the payload it would contribute.
  • The new representation at t is a weighted sum of v_s, where the weight is high when q_t and k_s match.

A natural similarity is the dot product q_t . k_s. To turn unbounded scores into a convex combination we softmax across s. That gives:

a_{t,s} = softmax_s( q_t . k_s )         # attention weights at row t
h_t = sum_s a_{t,s} v_s                  # output at position t

This is parallel across all (t, s), so we batch it as one matrix multiply. This is the entire idea. Everything below is making it numerically stable, making it scale to many heads, making it causal, encoding position, and making it fit in memory at long context.


2. Scaled dot-product attention: full derivation

2.1 The formula

Given matrices Q in R^{S x d_k}, K in R^{S x d_k}, V in R^{S x d_v}:

Attention(Q, K, V) = softmax( Q K^T / sqrt(d_k) ) V

Step by step:

  1. scores = Q K^T # raw scores, shape (S, S)
  2. scores_scaled = scores / sqrt(d_k) # divide by sqrt(d_k)
  3. P = softmax(scores_scaled, axis=-1) # row-wise softmax, shape (S, S)
  4. O = P V # output, shape (S, d_v)

Each row P[t, :] is a probability distribution over key positions. Each row O[t, :] = sum_s P[t, s] * V[s, :]. So O[t, :] is the convex combination of value rows weighted by how well key s matches query t.
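The four steps translate almost line for line into NumPy. This is a minimal single-head sketch with no batch dimension. The function name `sdpa` and the random shapes are illustrative choices, not any library's API. The loop at the end checks the convex-combination reading of each output row.

```python
import numpy as np

def sdpa(Q, K, V):
    """Scaled dot-product attention for one head: Q (S_q, d_k), K (S_k, d_k), V (S_k, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T                                        # step 1: raw scores, (S_q, S_k)
    scores = scores / np.sqrt(d_k)                          # step 2: scale
    scores = scores - scores.max(axis=-1, keepdims=True)    # stability shift (Section 2.3)
    P = np.exp(scores)
    P = P / P.sum(axis=-1, keepdims=True)                   # step 3: row-wise softmax
    return P @ V, P                                         # step 4: output (S_q, d_v)

rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(5, 16)), rng.normal(size=(5, 16)), rng.normal(size=(5, 8))
O, P = sdpa(Q, K, V)

# Each row of P is a distribution over keys; O[t] is the P[t]-weighted sum of value rows.
assert np.allclose(P.sum(axis=-1), 1.0)
for t in range(Q.shape[0]):
    assert np.allclose(O[t], sum(P[t, s] * V[s] for s in range(K.shape[0])))
```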

2.2 Why divide by sqrt(d_k): the variance derivation

This is not aesthetic. It is required to keep the softmax in its useful regime as d_k grows.

Assume each component of q in R^{d_k} and each component of k in R^{d_k} are independent random variables with mean 0 and variance 1. (This is the post-normalization, post-init regime that pre-norm transformers operate in.)

The dot product is

q . k = sum_{i=1..d_k} q_i k_i

The expected value is E[q.k] = sum E[q_i] E[k_i] = 0 (independence + zero mean). The variance:

Var(q.k) = sum_{i=1..d_k} Var(q_i k_i)
         = sum_{i=1..d_k} ( E[q_i^2] E[k_i^2] - 0 )
         = sum_{i=1..d_k} (1 * 1)
         = d_k

So q.k has standard deviation sqrt(d_k). Without scaling, individual scores have magnitudes that grow as sqrt(d_k). For modern d_k (64 or 128), this puts the softmax inputs in the regime where one entry dominates and the rest are crushed to zero. Two consequences:

  1. The softmax becomes nearly one-hot. Attention degenerates to a hard argmax look-up, which is hard to train. The softmax Jacobian is dp_i/dz_j = p_i (delta_{ij} - p_j), and when p is one-hot every entry of it is (nearly) zero, so almost no gradient reaches the scores.
  2. The forward and backward become numerically fragile. Tiny perturbations to a single near-max score flip which key wins.

Dividing by sqrt(d_k):

Var(q.k / sqrt(d_k)) = Var(q.k) / d_k = 1

so the scaled scores have standard deviation 1 regardless of d_k. The softmax stays in a regime with meaningful gradients. That is the entire argument.
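The variance argument is easy to check empirically. The sketch below assumes, as in the derivation, i.i.d. standard-normal components (a modeling assumption, not a property of trained weights) and estimates both variances by Monte Carlo. `score_variances` is a hypothetical helper name.

```python
import numpy as np

def score_variances(d_k, n_samples=20_000, seed=0):
    """Monte Carlo estimate of Var(q.k) and Var(q.k / sqrt(d_k)) for i.i.d. N(0, 1) components."""
    rng = np.random.default_rng(seed)
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = np.einsum("nd,nd->n", q, k)          # one dot product per sample
    return dots.var(), (dots / np.sqrt(d_k)).var()

for d_k in (16, 64, 128):
    raw, scaled = score_variances(d_k)
    print(f"d_k={d_k:4d}  Var(q.k)={raw:8.2f}  Var(q.k/sqrt(d_k))={scaled:.3f}")
```

For each d_k the raw variance should land near d_k and the scaled variance near 1, up to sampling noise of roughly a percent.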

A common confusion: why sqrt(d_k) and not d_k? Because we want standard deviation to be 1, not variance. Variance scales linearly with d_k, so the standard deviation scales as sqrt(d_k).

2.3 Numerically stable softmax

Real implementations never compute softmax(z) as exp(z) / sum(exp(z)) naively because exp can overflow. The standard trick:

softmax(z)_i = exp(z_i - max(z)) / sum_j exp(z_j - max(z))

Subtracting the max is mathematically a no-op (numerator and denominator are both multiplied by exp(-max(z))), but it keeps every exponent <= 0, so every exp value lies in (0, 1]. Hold on to this idea: the same identity is what makes FlashAttention's online softmax possible (Section 12).
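Here is a quick demonstration of why the max-subtraction matters, assuming float64 NumPy arrays. The naive form overflows on large logits. The shifted form returns exactly the distribution it would give for the logits [0, -1, -2].

```python
import numpy as np

def softmax_naive(z):
    e = np.exp(z)                 # overflows to inf for large z
    return e / e.sum()

def softmax_stable(z):
    e = np.exp(z - z.max())       # every exponent <= 0, so each e_i is in (0, 1]
    return e / e.sum()

z = np.array([1000.0, 999.0, 998.0])
with np.errstate(over="ignore", invalid="ignore"):
    print(softmax_naive(z))       # inf / inf: every entry is nan
print(softmax_stable(z))          # same as softmax of [0, -1, -2]
```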

2.4 Tensor shapes

Walk through one attention layer.

  • Input: x in R^{B x S x d}
  • Project: Q = x W_Q, K = x W_K, V = x W_V (each W is d x d), so Q, K, V in R^{B x S x d}
  • Reshape: view as (B, S, H, d_h) where d_h = d / H, then transpose to (B, H, S, d_h)
  • Scores: Q K^T / sqrt(d_h) over the last two dims gives (B, H, S, S)
  • Softmax: along the last axis, giving (B, H, S, S)
  • Apply V: (B, H, S, S) x (B, H, S, d_h) -> (B, H, S, d_h)
  • Reshape: transpose back to (B, S, H, d_h), reshape to (B, S, d); call this h
  • Project: O = h W_O, W_O in R^{d x d}, output (B, S, d)


3. Multi-head attention

3.1 Why multi-head

A single attention head computes one scoring function q.k. There are many useful relations between tokens (syntactic dependency, coreference, positional adjacency, semantic similarity), and forcing one head to encode all of them in a single d-dim space is a bottleneck. Multi-head says: split the d-dim space into H subspaces, each of dim d_h = d / H, and run an independent attention per subspace. Concatenate the H outputs, then project.

The total parameter count and FLOPs do not change: one big d x d projection is identical to H independent (d x d_h) projections concatenated. The computational difference is in the QK^T step: instead of one (S, d)x(d, S) matmul of cost O(S^2 d), you do H independent (S, d_h)x(d_h, S) matmuls of cost O(S^2 d_h) each, totalling O(H * S^2 * d_h) = O(S^2 d). Same FLOPs.

What changes is the expressive structure. Each head's scores are softmax-normalized independently, so head i can spend all its probability mass on syntactic neighbors while head j attends globally to topical anchors.

3.2 Fused QKV projection

Implementations almost always fuse the three projections into one matmul:

qkv = x @ W_qkv        # W_qkv in R^{d x 3d}
Q, K, V = qkv.split(d, dim=-1)

This is one big GEMM (general matrix multiply) instead of three small ones. On GPUs, one large matmul beats three small ones because of launch overhead and tensor-core utilization. The math is identical: the columns of W_qkv are just [W_Q | W_K | W_V] concatenated.
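The following NumPy check uses illustrative random shapes. It confirms that one fused GEMM followed by a split reproduces the three separate projections, up to floating-point rounding. The variable names mirror the text; nothing here is a library API.

```python
import numpy as np

rng = np.random.default_rng(0)
B, S, d = 2, 5, 12
x = rng.normal(size=(B, S, d))
W_Q, W_K, W_V = (rng.normal(size=(d, d)) for _ in range(3))

W_qkv = np.concatenate([W_Q, W_K, W_V], axis=1)   # (d, 3d): columns [W_Q | W_K | W_V]
qkv = x @ W_qkv                                   # one GEMM, (B, S, 3d)
Q, K, V = np.split(qkv, 3, axis=-1)               # each (B, S, d)

assert np.allclose(Q, x @ W_Q) and np.allclose(K, x @ W_K) and np.allclose(V, x @ W_V)
```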

For GQA (Section 5) the fusion is asymmetric: W_qkv has shape d x ((H_q + 2*H_kv) * d_h), so the K and V slices are narrower than Q.

3.3 Pseudocode for multi-head attention

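import math
import torch

def mha(x, W_qkv, W_o, H):
    # x: (B, S, d); W_qkv: (d, 3d); W_o: (d, d)
    B, S, d = x.shape
    d_h = d // H
    qkv = x @ W_qkv                                 # (B, S, 3d)
    q, k, v = qkv.split(d, dim=-1)                  # each (B, S, d)
    q = q.view(B, S, H, d_h).transpose(1, 2)        # (B, H, S, d_h)
    k = k.view(B, S, H, d_h).transpose(1, 2)
    v = v.view(B, S, H, d_h).transpose(1, 2)
    scores = q @ k.transpose(-1, -2) / math.sqrt(d_h)   # (B, H, S, S)
    # True strictly above the diagonal = future keys to hide; broadcasts over (B, H)
    causal_mask = torch.triu(torch.ones(S, S, dtype=torch.bool, device=x.device), diagonal=1)
    scores = scores.masked_fill(causal_mask, float("-inf"))
    p = torch.softmax(scores, dim=-1)               # (B, H, S, S)
    out = p @ v                                     # (B, H, S, d_h)
    # at this point out should match F.scaled_dot_product_attention(q, k, v, is_causal=True)
    out = out.transpose(1, 2).reshape(B, S, d)      # (B, S, d)
    return out @ W_o                                # (B, S, d)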

4. Causal masking

4.1 Why we need it

In autoregressive training, we feed the model the entire sequence x_1..x_T once, compute the output at every position in parallel, and demand that position t predict x_{t+1}. If position t's output is allowed to depend on x_{>t}, the model can drive the loss to zero by copying the next token from its own input, and it learns nothing useful.

We therefore need: output at position t is a function of x_1..t only. Concretely, in the (S, S) attention matrix, row t may have nonzero entries only in columns 1..t. Columns t+1..S must be zeroed.

4.2 Implementation: -inf before softmax

You cannot simply zero the probabilities after softmax. Each row has already been normalized, so zeroing entries leaves a row that no longer sums to 1 and breaks the convex-combination invariant. Instead, set the scores (pre-softmax) at masked positions to -inf:

M[t, s] = 0          if s <= t
M[t, s] = -inf       if s > t

scores += M
P = softmax(scores, dim=-1)

Because exp(-inf) = 0, those positions contribute nothing to the normalization sum or the output. The remaining positions still form a valid probability distribution.

In code:

import torch

mask = torch.full((S, S), float('-inf'))
mask = torch.triu(mask, diagonal=1)   # zeros on/below diagonal, -inf above
scores = scores + mask                # scores: (B, H, S, S); mask broadcasts over (B, H)

torch.triu(..., diagonal=1) keeps the strictly upper triangle of its input and sets everything on and below the diagonal to zero. Because the input is all -inf, the result is -inf above the diagonal and 0 on and below it, which is exactly M.

4.3 The mask in matrix form

For S = 4:

M = [ [  0  -inf -inf -inf ]
      [  0    0  -inf -inf ]
      [  0    0    0  -inf ]
      [  0    0    0    0  ] ]

After softmax, the attention probability matrix is lower triangular. Each row sums to 1. Row t spreads probability over columns 1..t.
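Building the S = 4 mask in NumPy and applying it to all-equal scores makes the lower-triangular structure concrete. With no preference among visible keys, row t (1-indexed) should put probability 1/t on each of its first t columns and exactly 0 elsewhere. In this sketch, `np.triu` stands in for `torch.triu` from Section 4.2.

```python
import numpy as np

S = 4
M = np.triu(np.full((S, S), -np.inf), k=1)   # -inf strictly above the diagonal, 0 elsewhere
scores = np.zeros((S, S)) + M                # all-equal raw scores, then masked

e = np.exp(scores - scores.max(axis=-1, keepdims=True))   # exp(-inf) = 0 for masked keys
P = e / e.sum(axis=-1, keepdims=True)
print(P)
```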


5. MQA and GQA: why decode is bandwidth-bound

5.1 Standard MHA

Standard multi-head attention has H query heads and H K/V heads-every query head has its own private K and V. KV-cache size scales with H.

5.2 Multi-Query Attention (MQA)

Shazeer 2019. All H query heads share a single K head and a single V head. So while Q has shape (B, H, S, d_h), K and V have shape (B, 1, S, d_h). The attention scores Q @ K^T broadcast K across the H query heads: (B, H, S, d_h) @ (B, 1, d_h, S) -> (B, H, S, S).

KV-cache shrinks by a factor of H. For Llama-3-70B (H=64) this would be a 64x reduction. The trade-off: model quality drops measurably because all heads must agree on what to store. MQA was the right answer for PaLM and early-era inference engines but has been largely superseded.

5.3 Grouped-Query Attention (GQA)

Ainslie et al. 2023. Compromise between MHA (H KV heads) and MQA (1 KV head). Pick H_kv such that H_q is a multiple of H_kv. Group every G = H_q / H_kv query heads to share one K/V head.

  • Llama-3-70B: H_q = 64, H_kv = 8, so G = 8.
  • Llama-3-8B: H_q = 32, H_kv = 8, so G = 4.

KV-cache shrinks by G compared to MHA, while quality stays close to MHA in the published benchmarks. This is now the default for serious LLMs (Llama 2 70B and the Llama 3 family, Mistral, Qwen, DeepSeek, ...).
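Below is a minimal sketch of the grouping rule. It assumes the cache stores only H_kv heads and that query head h reads K/V head h // G. `gqa_attention` is a hypothetical helper. It materializes the repeated heads with `np.repeat` for clarity, whereas production kernels typically index the shared head directly instead of copying it. Causal masking is omitted to keep the focus on head sharing.

```python
import numpy as np

def gqa_attention(q, k, v):
    """q: (H_q, S, d_h); k, v: (H_kv, S, d_h). Non-causal, single sequence, for illustration."""
    H_q, H_kv = q.shape[0], k.shape[0]
    assert H_q % H_kv == 0, "H_q must be a multiple of H_kv"
    G = H_q // H_kv
    # np.repeat gives [k0]*G + [k1]*G + ..., so query head h pairs with K/V head h // G.
    k_rep = np.repeat(k, G, axis=0)              # (H_q, S, d_h)
    v_rep = np.repeat(v, G, axis=0)
    scores = q @ k_rep.transpose(0, 2, 1) / np.sqrt(q.shape[-1])   # (H_q, S, S)
    p = np.exp(scores - scores.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v_rep                             # (H_q, S, d_h)
```

With H_q = 8 and H_kv = 2 (G = 4), query head 5 should give the same result as plain single-head attention against K/V head 1.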

5.4 Why this matters for inference

During decode (generating one token at a time after the prompt is processed), each step does:

  1. Read Q for the current token (1 token, fast).
  2. Read K and V for all S tokens in the cache (the entire history).
  3. Compute attention.
  4. Append new K and V to the cache.

Step 2 reads the full KV-cache from HBM every single decode step. With S = 8192 tokens, MHA, BF16, 80 layers, 64 heads, d_h = 128:

KV bytes per token per layer per head = 2 (K and V) * 128 * 2 (BF16) = 512
Total KV bytes = 80 * 64 * 8192 * 512 = 21.5 GB

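Reading 21.5 GB at H100 HBM bandwidth (~3.35 TB/s) takes ~6.4 ms per token for the KV alone. Every request pays this traffic separately. The weights read at each step, by contrast, are shared across the whole batch, and the attention FLOPs per step are small next to either. Decode is memory-bandwidth-bound.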

GQA cuts this KV traffic by 8x; MQA cuts it by 64x. This is why every serious LLM ships GQA: when KV reads dominate the step time (large batches, long contexts), an 8x smaller cache translates into up to ~8x more decode throughput.

The exact KV-cache memory formula (for one request):

bytes = 2 * num_layers * H_kv * seq_len * d_h * dtype_bytes
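The calculator below implements the formula above as a sketch. It assumes the hypothetical MHA/GQA/MQA variants share Llama-3-70B's other dimensions (80 layers, d_h = 128, BF16, 8K tokens). It also treats the ~3.35 TB/s H100 figure from the text as a pure bandwidth bound, with no kernel overheads.

```python
def kv_cache_bytes(num_layers, h_kv, seq_len, d_h, dtype_bytes):
    """KV-cache size for one request; the leading 2 counts K and V."""
    return 2 * num_layers * h_kv * seq_len * d_h * dtype_bytes

HBM_BYTES_PER_S = 3.35e12   # H100 HBM bandwidth figure used in the text (assumed peak, no overheads)

# Llama-3-70B-like shape (80 layers, d_h = 128, BF16, 8K tokens); only H_kv varies.
for name, h_kv in (("MHA", 64), ("GQA", 8), ("MQA", 1)):
    b = kv_cache_bytes(num_layers=80, h_kv=h_kv, seq_len=8192, d_h=128, dtype_bytes=2)
    print(f"{name}: {b / 1e9:6.2f} GB, {b / HBM_BYTES_PER_S * 1e3:5.2f} ms to read once")
```

Under these assumptions the loop reports roughly 21.5 GB and 6.4 ms for MHA, 2.7 GB and 0.8 ms for GQA, and 0.34 GB and 0.1 ms for MQA, consistent with the hand calculation above.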

6. Position encodings

The attention operation softmax(QK^T / sqrt(d_k)) V is permutation-equivariant in the sequence dimension: if you shuffle the rows of Q, K, V identically, the output rows shuffle the same way. Without a position signal the model cannot tell "the cat sat" from "sat cat the". We need to inject position.

6.1 Sinusoidal (Vaswani 2017)

For position p (0-indexed) and embedding dimension i, define

PE[p, 2i]   = sin(p / 10000^{2i/d})
PE[p, 2i+1] = cos(p / 10000^{2i/d})

Then x_p_with_pos = embed(x_p) + PE[p].

Why this form? Two properties:

  1. Each dimension pair is a sinusoid with a different frequency, with wavelengths ranging from 2pi (i = 0) up to nearly 10000 * 2pi (as i approaches d/2).
  2. PE[p + k] is a linear function of PE[p] for any fixed offset k, because sin and cos satisfy
       sin(a + b) = sin(a) cos(b) + cos(a) sin(b)
       cos(a + b) = cos(a) cos(b) - sin(a) sin(b)
     so the model can learn linear projections that compute relative offsets.
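The linear-shift property can be verified numerically. The sketch builds PE as defined above (with d even). A hypothetical helper `shift_by` applies the fixed 2x2 map given by the angle-addition identities to each (sin, cos) pair. That map depends only on k, never on p.

```python
import numpy as np

def sinusoidal_pe(num_pos, d):
    """PE[p, 2i] = sin(p / 10000^(2i/d)), PE[p, 2i+1] = cos(p / 10000^(2i/d))."""
    p = np.arange(num_pos)[:, None]                 # (num_pos, 1)
    omega = 10000.0 ** (-np.arange(0, d, 2) / d)    # (d/2,) frequencies 1 / 10000^(2i/d)
    pe = np.zeros((num_pos, d))
    pe[:, 0::2] = np.sin(p * omega)
    pe[:, 1::2] = np.cos(p * omega)
    return pe

def shift_by(pe_row, k, d):
    """Map PE[p] to PE[p + k] with a per-pair 2x2 linear map that depends only on k."""
    omega = 10000.0 ** (-np.arange(0, d, 2) / d)
    s, c = pe_row[0::2], pe_row[1::2]               # sin(p w), cos(p w)
    out = np.empty_like(pe_row)
    out[0::2] = s * np.cos(k * omega) + c * np.sin(k * omega)   # sin(p w + k w)
    out[1::2] = c * np.cos(k * omega) - s * np.sin(k * omega)   # cos(p w + k w)
    return out
```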

Why it generalizes mediocrely: the model still has to learn to use the relative-position structure, and the additive coupling means position info gets mixed with content via the projection matrices. In practice sinusoidal extrapolation to 2x the trained context degrades quickly.

6.2 Learned absolute position embeddings

Instead of the closed-form PE, allocate a learnable matrix P in R^{S_max x d} and add P[p] to embed(x_p). Used by GPT-2, BERT, OPT.

Pros: simple, often matches or beats sinusoidal at trained lengths. Cons: cannot extrapolate at all, since there are no learned vectors for positions beyond S_max. Hard limit on context length.

6.3 RoPE: Rotary Position Embedding (Su et al. 2021)

The most-used position encoding in modern LLMs (Llama, Mistral, Qwen, DeepSeek, Gemma, ...). Worth deriving in full.

6.3.1 Goal

We want the dot product of the modified query and key to depend only on the relative offset m - n (and on the contents of the tokens), not on the absolute positions m and n separately. Concretely, we want position-dependent linear maps R_m (for RoPE, rotations) such that the modified query at position m, q'_m = R_m q_m, and the modified key at position n, k'_n = R_n k_n, satisfy

q'_m . k'_n = g(q_m, k_n, m - n)

for some function g.

6.3.2 Complex-number formulation

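Pair up the d_h components of q in R^{d_h} into d_h/2 pairs. Treat each pair (q_{2i}, q_{2i+1}) as a complex number z_i = q_{2i} + i * q_{2i+1}. Here the i multiplying q_{2i+1} (and the i inside exp(...) below) is the imaginary unit, while subscripts i index the pair. Same for k.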

For a fixed angular frequency theta_i, define the rotation by position m as

z_i^{(m)} = z_i * exp(i * m * theta_i)

i.e. multiply the complex number by a unit-magnitude complex of angle m * theta_i. Equivalently, in R^2,

[ q'_{2i}   ]   [ cos(m theta_i)  -sin(m theta_i) ] [ q_{2i}   ]
[ q'_{2i+1} ] = [ sin(m theta_i)   cos(m theta_i) ] [ q_{2i+1} ]

This is a 2D rotation by angle m * theta_i in the (2i, 2i+1) plane.

The key calculation: the inner product of two rotated complex numbers z_a^{(m)} and z_b^{(n)} is

< z_a * exp(i m theta) , z_b * exp(i n theta) >
    = Re( (z_a exp(i m theta)) * conj(z_b exp(i n theta)) )
    = Re( z_a conj(z_b) * exp(i (m - n) theta) )

Crucially this depends on m - n only, not on m and n separately. Summed across all pairs i (each with its own theta_i),

q'_m . k'_n = sum_i Re( z_{q,i} conj(z_{k,i}) * exp(i (m-n) theta_i) )
            = g(q_m, k_n, m - n)

This is exactly the relative-position-only similarity we wanted.
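Here is a numerical check of the relative-position property, using NumPy complex arithmetic exactly as in the derivation. `rope_rotate` is an illustrative helper that works on a single vector and assumes base = 10000. It is not the batched, table-based implementation of 6.3.4. Two query/key pairs with the same offset m - n = 3 but very different absolute positions should give the same score.

```python
import numpy as np

def rope_rotate(x, pos, base=10000.0):
    """Rotate pairs (x[2i], x[2i+1]) by angle pos * theta_i via complex multiplication."""
    d_h = x.shape[-1]
    theta = base ** (-np.arange(0, d_h, 2) / d_h)    # theta_i = base^(-2i/d_h)
    z = x[0::2] + 1j * x[1::2]                        # pair i -> complex number z_i
    z = z * np.exp(1j * pos * theta)                  # multiply by unit complex of angle pos*theta_i
    out = np.empty_like(x)
    out[0::2], out[1::2] = z.real, z.imag
    return out

rng = np.random.default_rng(0)
q, k = rng.normal(size=64), rng.normal(size=64)
s1 = rope_rotate(q, 10) @ rope_rotate(k, 7)       # m - n = 3
s2 = rope_rotate(q, 1000) @ rope_rotate(k, 997)   # m - n = 3, far from the origin
print(s1, s2)
```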

6.3.3 Choice of frequencies

Following sinusoidal precedent,

theta_i = base^{-2i / d_h}     for i = 0, 1, ..., d_h/2 - 1

with base = 10000 typically. So low-i pairs rotate fast (tracking local relative offsets) and high-i pairs rotate slow (tracking global offsets).

6.3.4 Implementation

In practice you precompute two tables of shape (S_max, d_h/2):

cos_table[m, i] = cos(m * theta_i)
sin_table[m, i] = sin(m * theta_i)

Then the rotation at position m is applied componentwise to a query/key of shape (..., d_h):

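import torch

def rope_tables(max_len, d_h, base=10000.0):
    # theta_i = base^(-2i/d_h) for i = 0 .. d_h/2 - 1; tables have shape (max_len, d_h/2)
    theta = base ** (-torch.arange(0, d_h, 2, dtype=torch.float32) / d_h)
    angles = torch.arange(max_len, dtype=torch.float32)[:, None] * theta[None, :]
    return angles.cos(), angles.sin()

def apply_rope(x, cos_table, sin_table, positions):
    # x: (B, H, S, d_h)
    # positions: (S,) integer positions
    cos = cos_table[positions]          # (S, d_h/2), broadcasts over (B, H)
    sin = sin_table[positions]          # (S, d_h/2)
    x1 = x[..., 0::2]                   # even-indexed (B, H, S, d_h/2)
    x2 = x[..., 1::2]                   # odd-indexed
    rot1 = x1 * cos - x2 * sin          # 2D rotation of each (even, odd) pair
    rot2 = x1 * sin + x2 * cos
    # interleave back: (..., d_h/2, 2) -> (..., d_h)
    return torch.stack([rot1, rot2], dim=-1).flatten(-2)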

Three rules to remember:

  • RoPE is applied to Q and K, not to V.
  • RoPE is applied after the linear projections W_Q, W_K, before the attention scores are computed.
  • RoPE is applied per head, not per model-dim.

Inference with RoPE on a KV-cache: the K stored in the cache is already rotated. When you append a new token, you rotate its K with the current position and append. No re-rotation of past K is needed.

6.3.5 Why RoPE became dominant

  • No additional parameters (just trig tables).
  • Encodes relative position by construction, and with the extension techniques in 6.6 it can be stretched to contexts longer than it was trained on.
  • Composes cleanly with FlashAttention because it is applied before attention, not as an additive bias inside the softmax.

6.4 ALiBi: Attention with Linear Biases (Press et al. 2022)

Even simpler: modify the score matrix directly with a position-dependent bias.

scores[t, s] = q_t . k_s / sqrt(d_k) + slope_h * (s - t)

Here slope_h is a head-specific positive slope (precomputed). Under the causal mask s <= t, so slope_h * (s - t) <= 0: the penalty grows linearly with how far back the key is. Different heads have different slopes, so some attend close and some attend far.

Slopes for H heads are typically chosen as a geometric sequence:

slope_h = 2^{-8 h / H}   for h = 1, ..., H

Pros: zero-shot extrapolation (performance degrades smoothly when going to contexts longer than trained) and no extra parameters.

Cons: less expressive than RoPE in practice. Largely supplanted by RoPE in flagship models, but BLOOM and a few others used ALiBi.
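The bias tensor is cheap to build once per sequence length. This sketch assumes the geometric slope schedule above, which is intended for H a power of two. It folds the causal mask in as -inf. `alibi_bias` is a hypothetical name. The result would be added to q k^T / sqrt(d_k) head by head before the softmax.

```python
import numpy as np

def alibi_bias(num_heads, seq_len):
    """bias[h, t, s] = slope_h * (s - t) for s <= t, and -inf (causal mask) for s > t."""
    h = np.arange(1, num_heads + 1)
    slopes = 2.0 ** (-8.0 * h / num_heads)                  # 2^(-8/H), ..., 2^(-8)
    t = np.arange(seq_len)[:, None]
    s = np.arange(seq_len)[None, :]
    bias = slopes[:, None, None] * (s - t)[None, :, :]      # (H, S, S), <= 0 on/below diagonal
    return np.where(s <= t, bias, -np.inf)
```

With H = 8, head 1 has slope 1/2, so a key 4 positions back gets a bias of -2. Head 8 has slope 1/256, so the same key is penalized by only -1/64.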

6.5 Sliding window attention (Mistral)

Restricts each token to attend only to the previous W tokens (e.g. W = 4096). This is a modification of the mask, not a position encoding, but it is closely related because it is how you get "positional" locality.

Mask:

M[t, s] = -inf  if s > t                    (causal)
M[t, s] = -inf  if t - s >= W               (out of window)
M[t, s] = 0     otherwise

In stacked layers, the receptive field grows: after L layers, token t can indirectly see tokens up to L*W positions back (information propagates through the stack like a CNN). Mistral 7B uses W = 4096 with 32 layers, giving a theoretical receptive field of ~131K tokens.

KV-cache benefit: only the most recent W K/V need to be stored, capping KV memory at O(W) instead of O(S).
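The two mask conditions combine into a single boolean test. This sketch uses 0-indexed positions and a hypothetical helper name. The last line is just the receptive-field arithmetic from the text.

```python
import numpy as np

def sliding_window_mask(seq_len, window):
    """0 where key s is visible from query t (t - window < s <= t), -inf elsewhere."""
    t = np.arange(seq_len)[:, None]
    s = np.arange(seq_len)[None, :]
    visible = (s <= t) & (t - s < window)    # causal AND inside the window
    return np.where(visible, 0.0, -np.inf)

print(32 * 4096)   # layers * window for the Mistral 7B numbers in the text: 131072
```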

6.6 Context extension: YaRN, NTK-aware, Position Interpolation

A model trained at context length L_train often needs to be extended to 4L_train or 8L_train at deploy time. Three families, all working in the RoPE frequency domain:

Position Interpolation (PI) (Chen et al. 2023): scale all positions by L_train / L_target. Geometrically, every token's RoPE rotation is slowed by the scale factor, so the maximum rotation angle the model sees is unchanged. It works, but degrades quality: high-frequency dimensions lose discriminative power.

NTK-aware scaling (bloc97 / community 2023): instead of scaling all frequencies uniformly, change the RoPE base such that high-frequency dimensions are barely affected and only low-frequency dimensions are interpolated. Concretely

base' = base * (L_target / L_train)^(d_h / (d_h - 2))

Better preservation of local relative-position info than PI.
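One line of algebra shows why this base change spares high frequencies: theta'_i = (base')^{-2i/d_h} = theta_i * (L_target / L_train)^{-2i/(d_h - 2)}. At i = 0 the factor is 1, so the fastest pair is untouched. At the last pair, i = d_h/2 - 1, it is exactly L_train / L_target, the same slowdown PI applies to every pair. The NumPy sketch checks this for an assumed 4x extension with d_h = 128.

```python
import numpy as np

def rope_thetas(d_h, base):
    return base ** (-np.arange(0, d_h, 2) / d_h)   # theta_i for i = 0 .. d_h/2 - 1

d_h, base, scale = 128, 10000.0, 4.0                # scale = L_target / L_train (assumed 4x)
base_ntk = base * scale ** (d_h / (d_h - 2))        # NTK-aware base from the formula above

ratio = rope_thetas(d_h, base_ntk) / rope_thetas(d_h, base)   # per-pair frequency change
print(ratio[0], ratio[-1])   # highest frequency untouched; lowest slowed by 1/scale, as in PI
```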

YaRN (Peng et al. 2023): a more careful per-frequency-band schedule. Bands are split into "extrapolation" (high freq, kept as-is), "interpolation" (low freq, scaled), and a transition region. Adds a small temperature correction to the attention softmax. Empirically the strongest of the three with comparable fine-tuning.

In all three, you typically fine-tune for a small number of steps on long-context data after applying the schedule.


7. The transformer block

7.1 The structure

A decoder-only transformer layer has two sub-blocks: attention and FFN (also called MLP). Each sub-block has a residual connection and a normalization. The two competing wirings are post-norm (Vaswani 2017) and pre-norm (used by every modern LLM).

Post-norm (original):

h = LayerNorm( x + Attention(x) )
y = LayerNorm( h + FFN(h) )

Pre-norm:

h = x + Attention( LayerNorm(x) )
y = h + FFN( LayerNorm(h) )

The difference matters. In pre-norm, the residual stream x flows from input to output unmodified by any normalization; each layer adds a correction computed from a normalized view of x. This means deep pre-norm transformers are easier to train: gradients flow through the residuals without being scaled by repeated normalizers. Post-norm transformers required learning-rate warmup to train at depth and were fragile beyond ~12 layers without tricks.

Modern stacks all use pre-norm. Llama, Mistral, and most recent open-weights models pair it with RMSNorm (GPT-NeoX, by contrast, uses pre-norm with LayerNorm).

7.2 ASCII diagram of a pre-norm block

     x  -----------------------------------+
     |                                     |
     v                                     |
  RMSNorm                                  |
     |                                     |
     v                                     |
  Attention(Q=K=V=norm(x), causal,         |
            with RoPE on Q and K)          |
     |                                     |
     v                                     |
     +-------------(add residual)----------+
     |
     v
     h  ----------------------------------------------+
     |                                                |
     v                                                |
  RMSNorm                                             |
     |                                                |
     v                                                |
  FFN  (e.g. SwiGLU: down(silu(gate(x)) * up(x)))     |
     |                                                |
     v                                                |
     +---------(add residual)-------------------------+
     |
     v
     y

7.3 Pseudocode

def block(x, params):
    h = x + attention(rmsnorm(x, params.norm1), params.attn)
    y = h + ffn(rmsnorm(h, params.norm2), params.ffn)
    return y

def transformer(tokens, params):
    x = embed(tokens, params.embed)
    for layer_params in params.layers:
        x = block(x, layer_params)
    x = rmsnorm(x, params.final_norm)
    logits = x @ params.embed.weight.T   # tied embeddings, often
    return logits

8. LayerNorm vs RMSNorm

8.1 LayerNorm

Ba et al. 2016. For a vector x in R^d:

mean = (1/d) sum_i x_i
var  = (1/d) sum_i (x_i - mean)^2
y    = (x - mean) / sqrt(var + eps)
out  = gamma * y + beta

Two reductions across the feature axis (mean and variance), two learnable parameters per dim (gamma is scale, beta is shift), one elementwise subtract, one elementwise divide.

8.2 RMSNorm

Zhang & Sennrich 2019. Drops the mean centering:

rms = sqrt( (1/d) sum_i x_i^2 + eps )
out = (x / rms) * weight

One reduction across the feature axis (sum of squares), one learnable parameter per dim (weight, the scale), no shift, no mean subtraction.
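Here are both normalizations side by side, as a NumPy sketch (eps = 1e-5 is an assumed default). For inputs that already have zero mean, LayerNorm with beta = 0 and RMSNorm with the same scale vector produce identical outputs. This is one way to see that RMSNorm is LayerNorm minus the centering step.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    mean = x.mean(axis=-1, keepdims=True)                  # reduction 1
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)   # reduction 2
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

def rms_norm(x, weight, eps=1e-5):
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)   # single reduction
    return x / rms * weight
```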

8.3 Why RMSNorm is enough

Empirical observation (Zhang & Sennrich, then countless replications): in pre-norm transformers, the mean centering of LayerNorm contributes little to model quality. The crucial operation is the variance normalization: bounding the magnitude of x so that the subsequent linear layer sees inputs of controlled scale. Mean subtraction is redundant given the high dimensionality and the fact that the gamma/beta parameters can absorb shifts.

Computational benefits:

  • One reduction instead of two, so the normalization runs faster (Zhang & Sennrich report 7-64% lower running time across the models they tested).
  • Half the parameters (no beta).
  • Slightly better numerics (the mean subtraction can subtract two similar values, losing precision).

Every Llama-family model and most modern open-weights LLMs use RMSNorm.


9. Activation functions in the FFN

9.1 The plain FFN

Vaswani 2017 used a two-layer MLP per position (with ReLU; GPT and BERT later switched the nonlinearity to GeLU):

ffn(x) = W_2 ( gelu( W_1 x ) )

Sizes: W_1 maps d -> d_ff and W_2 maps d_ff -> d, with d_ff = 4d typically. Two matmuls.

9.2 GeLU

Gaussian Error Linear Unit, Hendrycks & Gimpel 2016:

gelu(x) = x * Phi(x)

where Phi is the standard normal CDF. Approximated as:

gelu(x) ~= 0.5 x (1 + tanh( sqrt(2/pi) (x + 0.044715 x^3) ))

Smoother than ReLU near 0, approximately ReLU for large |x|. Empirically better than ReLU for transformers.
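This pure-Python sketch compares the exact definition with the tanh approximation on a grid over [-6, 6]. The exact form uses math.erf, since Phi(x) = 0.5 * (1 + erf(x / sqrt(2))). The grid resolution is arbitrary.

```python
import math

def gelu_exact(x):
    return x * 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))    # x * Phi(x)

def gelu_tanh(x):
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

xs = [i / 100 for i in range(-600, 601)]
max_err = max(abs(gelu_exact(x) - gelu_tanh(x)) for x in xs)
print(f"max |exact - tanh approx| on [-6, 6]: {max_err:.2e}")
```

The maximum discrepancy comes out below 1e-3, small enough that the two forms are often used interchangeably.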

9.3 SwiGLU

Shazeer 2020. A gated FFN: instead of one input matmul followed by a nonlinearity, project the input through two matrices, apply a nonlinearity to one of them, multiply the two elementwise, then project down.

gate(x) = silu( W_gate x )      # silu(x) = x * sigmoid(x)
up(x)   = W_up   x
ffn(x)  = W_down ( gate(x) * up(x) )

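Three weight matrices instead of two. To keep the parameter count approximately equal to a plain GeLU FFN with d_ff = 4d, SwiGLU implementations use d_ff = (8/3) d ~= 2.67 d, since 3 * d * (8/3) d = 2 * d * 4d = 8 d^2. Llama-family models start from (8/3) d and round up to a hardware-friendly multiple; Llama 3 also applies an extra multiplier, giving d_ff = 14336 for d = 4096 in the 8B model.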
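Below is a NumPy sketch of the three-matrix FFN plus the parameter bookkeeping. It assumes the row-vector convention x @ W, so W_gate and W_up are d x d_ff, rather than the column-vector W x used in the formulas. It picks d = 768 only because that value is divisible by 3.

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))                 # x * sigmoid(x)

def swiglu_ffn(x, W_gate, W_up, W_down):
    """x: (..., d); W_gate, W_up: (d, d_ff); W_down: (d_ff, d). Row-vector convention."""
    return (silu(x @ W_gate) * (x @ W_up)) @ W_down

d = 768
d_ff_plain, d_ff_swiglu = 4 * d, 8 * d // 3       # 3072 and 2048
params_plain = 2 * d * d_ff_plain                 # W_1, W_2
params_swiglu = 3 * d * d_ff_swiglu               # W_gate, W_up, W_down
print(params_plain, params_swiglu)                # equal because d is divisible by 3
```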

The cost: 3 matmuls instead of 2. At the same d_ff that is +50% FFN compute; at the reduced d_ff = (8/3) d, parameters and FLOPs match the plain FFN. The benefit: empirically better quality at matched parameters and compute. It is now the default: Llama, Mistral, and Qwen all use SwiGLU.

The "GLU family" is parameterized by which nonlinearity wraps the gate: GLU (sigmoid), ReGLU (relu), GeGLU (gelu), SwiGLU (silu). SwiGLU won.


10. KV-cache

10.1 Why it exists

In autoregressive decode, you generate one new token at a time. Naively each step would be:

for t in 1..T_gen:
    full_input = prompt + generated_so_far     # length t
    run full transformer forward on full_input
    sample the next token

The forward pass on length-S input does O(S^2 d) work in attention. So generating T tokens is O(T^3 d) total. This is catastrophic.

Observation: at decode step t+1, the K and V matrices for positions 1..t are exactly the same as they were at step t. Only one new K/V pair is added at position t+1. So we can cache K and V across steps.

10.2 The decode loop with KV-cache

K_cache, V_cache = empty
# Prefill: process the prompt of length P in one big forward
Q, K, V = project(prompt)                      # RoPE applied at positions 1..P
K_cache, V_cache = K, V
output = attention(Q, K_cache, V_cache, causal=True)
# take the last position's logits, sample x_{P+1}

# Decode: one token at a time
for t in P+1..P+T_gen:
    x_t = embed(token[t])                      # the token sampled at the previous step
    q, k, v = project(x_t)                     # each (B, H, 1, d_h); RoPE at position t
    K_cache = cat(K_cache, k, dim=seq)         # grow by 1
    V_cache = cat(V_cache, v, dim=seq)
    out = attention(q, K_cache, V_cache, causal=False)  # no mask: the newest query may see every cached key
    logits = unembed(out)
    sample x_{t+1}

Per-step work in decode: project 1 token (O(d^2)), do attention with Q of length 1 against K/V of length t (O(t * d)), FFN on 1 token (O(d^2)). Linear in t, not quadratic. Generating T tokens is O(T^2 d) total-quadratic in T, not cubic.
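The claim that caching changes nothing numerically can be tested directly. This sketch makes several simplifying assumptions: a single head, no batch, no RoPE or FFN, and every token run through the decode path (prompt length 0). It then compares the result against recomputing full causal attention over the whole sequence.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def causal_attention(q, k, v):
    S = q.shape[0]
    mask = np.triu(np.full((S, S), -np.inf), k=1)
    return softmax(q @ k.T / np.sqrt(q.shape[-1]) + mask) @ v

rng = np.random.default_rng(0)
d, T = 8, 6
W_q, W_k, W_v = (rng.normal(size=(d, d)) for _ in range(3))
x = rng.normal(size=(T, d))                          # stand-in embeddings for T tokens

full = causal_attention(x @ W_q, x @ W_k, x @ W_v)   # recompute everything at once

K_cache = np.empty((0, d))
V_cache = np.empty((0, d))
decoded = []
for t in range(T):
    x_t = x[t:t + 1]                                 # only the newest token, (1, d)
    q, k, v = x_t @ W_q, x_t @ W_k, x_t @ W_v
    K_cache = np.concatenate([K_cache, k])           # grow the cache by one row
    V_cache = np.concatenate([V_cache, v])
    # The single query may attend to every cached key, so no mask is needed here.
    decoded.append(softmax(q @ K_cache.T / np.sqrt(d)) @ V_cache)
decoded = np.concatenate(decoded)
```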

10.3 KV-cache memory

For a single request:

bytes = 2 * num_layers * H_kv * seq_len * d_h * dtype_bytes

Where:

  • 2 covers K and V.
  • num_layers: number of transformer blocks.
  • H_kv: K/V heads (= H_q in MHA, smaller in GQA, 1 in MQA).
  • seq_len: how many tokens are in the cache.
  • d_h: per-head dim.
  • dtype_bytes: 2 for FP16/BF16, 1 for FP8/INT8.

10.4 Worked example: Llama-3-70B at 8K context, BF16

  • num_layers = 80
  • H_kv = 8 (GQA, H_q = 64, group size 8)
  • seq_len = 8192
  • d_h = 128
  • dtype_bytes = 2

bytes = 2 * 80 * 8 * 8192 * 128 * 2 = 2,684,354,560 ~= 2.5 GiB per request

(If it were MHA with H = 64 instead of GQA H_kv = 8, this would be ~20 GiB per request, which is why GQA exists.)

Per token added to the cache:

bytes_per_token = 2 * 80 * 8 * 128 * 2 = 327,680 bytes ~= 320 KiB per token

So generating 8K new tokens adds ~2.5 GiB to the cache. At H100's ~3.35 TB/s HBM bandwidth, just reading that 2.5 GiB (about 2.7 GB) once costs ~0.8 ms.

For a 32K context the KV is 10 GiB per request; for 128K it is 40 GiB per request. KV-cache, not weights, becomes the dominant memory consumer at long context.

Same model, MQA hypothetical (H_kv = 1):

bytes = 2 * 80 * 1 * 8192 * 128 * 2 = 320 MiB per request

Same model, MHA hypothetical (H_kv = 64):

bytes = 2 * 80 * 64 * 8192 * 128 * 2 = 20 GiB per request

The 8x reduction MHA -> GQA is exactly what makes long-context inference practical.

10.5 Layout: contiguous vs paged

Contiguous layout. Allocate one big tensor per request of shape (num_layers, 2, H_kv, max_seq_len, d_h). Simple, but you must reserve max_seq_len up front. If max_seq_len = 8K but a request stops at 200 tokens, you wasted (1 - 200/8192) ~= 97% of that allocation.

In a serving system with many concurrent requests of varying lengths, a contiguous layout means you must either:

  • pre-size for the worst case and accept massive waste, or
  • refuse to admit a new request unless a full max-length slab is free.

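This caps concurrency badly. With the Llama-3-70B numbers above, 40 GiB of HBM set aside for KV holds only 16 slabs of 2.5 GiB (8K tokens each), even if the typical request uses only a few hundred tokens and most of each slab sits empty.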

Paged attention (vLLM, Kwon et al. 2023). Treat the KV-cache like virtual memory: divide it into fixed-size blocks (e.g., 16 tokens per block). Each request keeps a block table, a list of the physical block IDs it owns. When a request grows past its last block, it allocates a new block from a free pool. When it finishes, its blocks return to the pool.

This is exactly the OS virtual-memory abstraction applied to attention. Benefits:

  • Internal fragmentation is bounded by one partially filled block per request (under 16 tokens of waste, instead of up to 8K).
  • Many more concurrent requests fit in the same HBM.
  • Copy-on-write block sharing enables prefix caching: if requests share a system prompt, they can share its KV blocks.
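The bookkeeping side of paged attention is small. The class below is a hypothetical sketch, not vLLM's API, and it only tracks block tables and free blocks. A real engine also stores the K/V tensors in a preallocated pool indexed by these block IDs, and it handles block sharing and copy-on-write.

```python
class PagedKVAllocator:
    """Hypothetical sketch of block-table bookkeeping for a paged KV-cache (no tensor storage)."""

    def __init__(self, num_blocks, block_size=16):
        self.block_size = block_size
        self.free = list(range(num_blocks))   # pool of physical block IDs
        self.tables = {}                      # request id -> list of physical block IDs
        self.lengths = {}                     # request id -> number of tokens stored

    def append_token(self, req):
        """Reserve a slot for one new token; returns (physical block ID, offset in block)."""
        table = self.tables.setdefault(req, [])
        n = self.lengths.get(req, 0)
        if n % self.block_size == 0:          # no block yet, or last block full: take a new one
            if not self.free:
                raise MemoryError("KV block pool exhausted")
            table.append(self.free.pop())
        self.lengths[req] = n + 1
        return table[n // self.block_size], n % self.block_size

    def release(self, req):
        """Return all of a finished request's blocks to the free pool."""
        self.free.extend(self.tables.pop(req, []))
        self.lengths.pop(req, None)
```

A 200-token request with 16-token blocks occupies 13 blocks and wastes only the 8 unused slots of its last block.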

The cost is that the attention kernel must be paging-aware, indirecting through the block table instead of reading one contiguous slab. PagedAttention kernels (vLLM, FA-3, SGLang) are non-trivial; the attention loop has to gather K/V from non-contiguous memory. Most modern serving engines use paged KV.
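To see what "indirecting through the block table" means, here is a NumPy sketch of single-query, single-head decode attention over a paged cache. The pool layout (num_blocks, block_size, d_h) and the function name `paged_decode_attention` are assumptions made for illustration. A real kernel gathers tile by tile in SRAM rather than materializing a contiguous copy as this sketch does.

```python
import numpy as np

def paged_decode_attention(q, k_pool, v_pool, block_table, seq_len):
    """q: (d_h,); k_pool, v_pool: (num_blocks, block_size, d_h);
    block_table: physical block IDs for this request, in logical order."""
    block_size, d_h = k_pool.shape[1], k_pool.shape[2]
    n_blocks = -(-seq_len // block_size)                 # ceil(seq_len / block_size)
    ids = np.asarray(block_table[:n_blocks])
    # Gather this request's blocks, flatten to (tokens, d_h), drop the unused tail of the last block.
    K = k_pool[ids].reshape(-1, d_h)[:seq_len]
    V = v_pool[ids].reshape(-1, d_h)[:seq_len]
    s = K @ q / np.sqrt(d_h)                             # (seq_len,) scores
    p = np.exp(s - s.max())
    return (p / p.sum()) @ V                             # (d_h,)

# Usage: 3 logical blocks scattered across a 10-block pool, checked against a contiguous cache.
rng = np.random.default_rng(0)
block_size, d_h, seq_len = 4, 8, 10
k_pool = rng.standard_normal((10, block_size, d_h))
v_pool = rng.standard_normal((10, block_size, d_h))
table = [7, 2, 5]
q = rng.standard_normal(d_h)

K_contig = np.concatenate([k_pool[b] for b in table])[:seq_len]
V_contig = np.concatenate([v_pool[b] for b in table])[:seq_len]
s = K_contig @ q / np.sqrt(d_h)
weights = np.exp(s - s.max()) / np.exp(s - s.max()).sum()
ref = weights @ V_contig
out = paged_decode_attention(q, k_pool, v_pool, table, seq_len)
print(np.allclose(out, ref))
```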


11. Attention complexity: why long context is hard

For a sequence of length S with model dim d (single head):

QK^T:     (S x d) @ (d x S)  -> (S x S),   FLOPs ~ S^2 * d
softmax:  on S x S,                         FLOPs ~ S^2
P @ V:    (S x S) @ (S x d)  -> (S x d),   FLOPs ~ S^2 * d
Total compute:          O(S^2 * d)
Score-matrix memory:    O(S^2)

For S = 32K, the score matrix per head is 32K * 32K = 1 Gi (2^30) entries. In BF16 that is 2 GiB. Per head. Per layer. For one request. Materializing the score matrix in HBM is the binding constraint at long context: long before you run out of FLOPs, you run out of memory to hold the intermediate softmax matrix.
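The per-head figure grows quadratically. The quick calculation below assumes BF16 scores and, as a hypothetical case, 64 heads materialized at once for a single layer of one request. It shows how fast that grows:

```python
def score_matrix_bytes(seq_len, num_heads=1, dtype_bytes=2):
    """Bytes needed to materialize the S x S score matrix for num_heads heads."""
    return num_heads * seq_len * seq_len * dtype_bytes

GiB = 1024 ** 3
for S in (4096, 8192, 32768, 131072):
    per_head = score_matrix_bytes(S) / GiB
    all_heads = score_matrix_bytes(S, num_heads=64) / GiB
    print(f"S={S:>6}: {per_head:7.3f} GiB per head, {all_heads:8.1f} GiB for 64 heads")
```

Each doubling of S quadruples the footprint. At 32K, the 64-head case alone exceeds the HBM of any single GPU discussed here.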

This is what FlashAttention solves.


12. FlashAttention

Dao et al. 2022. Two ideas working together: (a) tile the attention computation so that only small blocks of Q, K, V are in fast memory at any time, never the full S x S score matrix; (b) compute softmax online over the K tiles so each output row's normalization stays correct without seeing all scores at once.

12.1 GPU memory hierarchy

  • HBM (high-bandwidth memory): 80 GB on H100 SXM (141 GB on the H200), ~3.35 TB/s. Big, slow by GPU standards.
  • SRAM / shared memory / register file: ~256 KB per SM, ~20 TB/s. Tiny, fast.

Standard attention writes the S x S score matrix to HBM and reads it back between steps, because it doesn't fit in SRAM. The QK^T step writes S, softmax reads S and writes P, and the P @ V matmul reads P again. That S x S traffic dominates HBM traffic. FlashAttention's goal: do all attention work for a Q tile inside SRAM, never spilling the score matrix.

12.2 The online softmax: full derivation

The non-trivial part is doing softmax incrementally across K tiles. Suppose K is split into tiles K_1, K_2, ..., K_T and we want to compute, for a fixed Q tile (call its rows q):

P = softmax( [s_1, s_2, ..., s_T] )    where s_j = q K_j^T / sqrt(d_k), tiles concatenated along the key (column) axis
O = P [V_1 ; V_2 ; ... ; V_T]

We process one tile (K_j, V_j) at a time, maintaining for each row of Q three running statistics:

m   = current running max over scores seen so far
l   = current running sum of exp(score - m) over scores seen so far
O_  = current running unnormalized output (sum of exp(score - m) * v)

After processing all T tiles, the final output is O_ / l. Per row.

Now the recurrence. Suppose we have processed tiles 1..j-1 and have state (m, l, O_). We process tile j with scores s_j (a block of K_BLOCK columns) and values V_j.

Step 1: compute the local max of the new tile.

m_local = max( s_j )                               # scalar per row of q

Step 2: update the running max.

m_new = max( m, m_local )                          # scalar per row

Step 3: rescale the old running sum and output to be relative to m_new. The reason: the old l was computed as sum exp(score - m). To put it on the m_new scale, we multiply by exp(m - m_new), since exp(score - m) * exp(m - m_new) = exp(score - m_new):

correction = exp( m - m_new )                      # scalar per row, in [0, 1] (0 on the first tile, where m = -inf)
l_new_partial = l * correction
O_new_partial = O_ * correction                    # vector per row

Step 4: compute the new tile's contribution on the m_new scale.

p_j = exp( s_j - m_new )                           # block of weights, in (0, 1]
l_new = l_new_partial + sum( p_j )                 # add new tile's mass
O_new = O_new_partial + p_j @ V_j                  # add new tile's contribution

Step 5: store (m_new, l_new, O_new) as the new state.

After all tiles are processed, divide:

O_final = O_ / l                                   # the actual softmaxed output
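The five steps can be checked numerically. The sketch below uses NumPy, though any array library would work. It runs the recurrence for a single query row over K/V tiles and compares the result against the all-at-once softmax. The function name `online_softmax_attention` and the tile size are illustrative.

```python
import numpy as np

def online_softmax_attention(s, V, tile):
    """s: (S_k,) scores for one query row; V: (S_k, d_v). Consumes s and V in tiles of `tile` columns."""
    m, l = -np.inf, 0.0
    O_ = np.zeros(V.shape[1])
    for start in range(0, len(s), tile):
        s_j, V_j = s[start:start + tile], V[start:start + tile]
        m_new = max(m, s_j.max())            # steps 1-2: local max, then running max
        correction = np.exp(m - m_new)       # step 3: rescale old state (0 on the first tile)
        p_j = np.exp(s_j - m_new)            # step 4: new tile on the m_new scale
        l = l * correction + p_j.sum()
        O_ = O_ * correction + p_j @ V_j
        m = m_new                            # step 5: store the new state
    return O_ / l                            # final normalization

rng = np.random.default_rng(0)
s = rng.standard_normal(1000) * 10          # wide spread so the running max actually moves
V = rng.standard_normal((1000, 64))
p = np.exp(s - s.max())
reference = (p / p.sum()) @ V
print(np.allclose(online_softmax_attention(s, V, tile=128), reference))
```

Changing `tile` (including tile=1 or a tile covering everything) leaves the result unchanged up to rounding, which is the equivalence proved next.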

12.2.1 Why this is mathematically equal to the all-at-once softmax

Let s_1, ..., s_T be the per-tile score blocks and V_1, ..., V_T the per-tile value blocks. Let m_global = max over all entries of all s_j. The all-at-once softmax gives:

p_global_j = exp(s_j - m_global) / Z,    Z = sum_j sum_entries exp(s_j - m_global)
O = sum_j p_global_j @ V_j
  = (1/Z) sum_j exp(s_j - m_global) @ V_j

We need to show that the online algorithm produces exactly this.

Inductive claim: after processing tiles 1..j with running state (m, l, O_),

O_ = sum_{k=1..j} exp(s_k - m) @ V_k
l  = sum_{k=1..j} sum_entries exp(s_k - m)

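Base case j = 0: m = -inf, l = 0, O_ = 0, and the empty sums are zero. On the first update the correction factor is exp(-inf - m_1) = 0, which multiplies l = 0 and O_ = 0, so the initial state contributes nothing. The new tile's weights use m_new = m_1, so exp(s - m) is never evaluated with m = -inf. (If a row has seen only -inf scores, as when its first tile is fully masked, m_new stays -inf and exp(m - m_new) = exp(-inf - (-inf)) = NaN; implementations guard this case.)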

Inductive step: assume the claim holds for j-1 with running max m. We process tile j with new tile max m_j = max(s_j), new running max m' = max(m, m_j). The correction factor is c = exp(m - m').

After the update, the new running output is:

O_'  = O_ * c + exp(s_j - m') @ V_j
     = c * sum_{k=1..j-1} exp(s_k - m) @ V_k + exp(s_j - m') @ V_j
     = sum_{k=1..j-1} exp(m - m') exp(s_k - m) @ V_k + exp(s_j - m') @ V_j
     = sum_{k=1..j-1} exp(s_k - m') @ V_k + exp(s_j - m') @ V_j
     = sum_{k=1..j} exp(s_k - m') @ V_k

Identical algebra for l. So the inductive claim holds with m replaced by m'.

By induction, after T tiles, m equals the global max m_global, and:

O_ = sum_{k=1..T} exp(s_k - m_global) @ V_k
l  = sum_{k=1..T} sum_entries exp(s_k - m_global) = Z

so O_ / l = O. In exact arithmetic the online softmax equals the all-at-once softmax, with no approximation; in floating point the two differ only by rounding. The numerical-stability trick (subtract the running max) is the same trick as ordinary stable softmax, just done incrementally.

12.3 The tiled algorithm

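import numpy as np

def flash_attention_forward(Q, K, V, B_q=64, B_k=64, causal=False):
    # Inputs: Q (S_q x d), K (S_k x d), V (S_k x d). Output: O (S_q x d).
    # Tile sizes: B_q (Q rows per tile), B_k (K rows per tile).
    # NumPy reference for the loop structure; "SRAM" comments mark what a real kernel keeps on chip.
    S_q, d = Q.shape
    S_k = K.shape[0]
    O = np.empty((S_q, V.shape[1]))

    for q_start in range(0, S_q, B_q):                     # outer: over Q
        Q_tile = Q[q_start : q_start + B_q]                # (b_q, d), load to SRAM
        b_q = Q_tile.shape[0]                              # last tile may be shorter than B_q
        m = np.full(b_q, -np.inf)                          # running max, in SRAM
        l = np.zeros(b_q)                                  # running sum
        O_ = np.zeros((b_q, V.shape[1]))                   # running unnormalized output

        for k_start in range(0, S_k, B_k):                 # inner: over K, V
            K_tile = K[k_start : k_start + B_k]            # (b_k, d), load
            V_tile = V[k_start : k_start + B_k]            # (b_k, d), load
            S = Q_tile @ K_tile.T / np.sqrt(d)             # (b_q, b_k), in SRAM
            if causal:                                     # query i may see key j only if j <= i
                q_pos = np.arange(q_start, q_start + b_q)[:, None]
                k_pos = np.arange(k_start, k_start + K_tile.shape[0])[None, :]
                S = np.where(k_pos <= q_pos, S, -np.inf)
            m_local = S.max(axis=1)                        # (b_q,)
            m_new = np.maximum(m, m_local)                 # (b_q,)
            correction = np.exp(m - m_new)                 # (b_q,)
            P = np.exp(S - m_new[:, None])                 # (b_q, b_k)
            l = l * correction + P.sum(axis=1)             # (b_q,)
            O_ = O_ * correction[:, None] + P @ V_tile     # (b_q, d)
            m = m_new

        O[q_start : q_start + b_q] = O_ / l[:, None]       # final normalize, store to HBM
    return O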

ASCII picture:

Q (S_q x d)               K (S_k x d)           V (S_k x d)        O (S_q x d)
+-----+                  +---+---+---+        +---+---+---+        +-----+
| Q_1 |  outer loop -->  |K_1|K_2|K_3|  ...   |V_1|V_2|V_3|        | O_1 |
+-----+                  +---+---+---+        +---+---+---+        +-----+
| Q_2 |                                                            | O_2 |
+-----+                                                            +-----+
| ... |                                                            | ... |
+-----+                                                            +-----+

For each Q_i:
  for k_tile = 1..T:
    score block S_ik = Q_i K_k^T / sqrt(d)             [stays in SRAM]
    update (m, l, O_) with online softmax
  emit O_i = O_ / l                                    [single write to HBM]

12.4 Why this saves memory and bandwidth

Memory: the only on-the-fly working set is one Q tile (B_q x d), one K tile (B_k x d), one V tile (B_k x d), one score block (B_q x B_k), and the running state (B_q x (d + 2)). All small. The full S x S score matrix is never materialized anywhere.

The on-chip working set is therefore O(B_q * d + B_k * d + B_q * B_k), with B_q and B_k (e.g., 64 or 128) chosen so it fits in SRAM. The dominant memory is then the inputs and the output in HBM, both O(S * d). So overall O(S * d) memory instead of O(S^2).

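Bandwidth: standard attention's HBM traffic is dominated by writing and re-reading the S x S score matrix: O(S^2) reads plus O(S^2) writes, on top of O(S * d) for Q, K, V, O. FA reads each Q tile once and writes each O tile once (O(S * d) total). It does stream K and V through SRAM once per Q tile, so K/V traffic is O(S^2 * d / B_q). Dao et al. show that FA's total is Θ(S^2 d^2 / M) HBM accesses for SRAM size M (with d <= M <= S * d), versus Θ(S * d + S^2) for standard attention. For typical head dimensions d^2 is several to tens of times smaller than M, and the S x S matrix never touches HBM at all, so traffic drops by a large factor.

The exact factor depends on d, the tile sizes, and the SRAM capacity. Standard attention is bandwidth-bound on real GPUs, so the reduction translates almost directly into a several-fold wall-clock speedup at long context.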

12.5 Backward pass

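For the backward, you also recompute the attention block by block: you don't store the S x S matrix during the forward, so you can't read it back. The trick is to store only the per-row statistics m and l from the forward, which are O(S) total (FA-2 stores just their combination, the logsumexp m + log l). In the backward, recompute scores tile by tile, rebuild the probabilities from the stored statistics, and derive dQ, dK, dV. Total backward FLOPs are roughly 2.5x forward FLOPs: five S^2 * d-sized matmuls (recomputing QK^T, plus dV, dP, dQ, dK) versus two in the forward. The recomputation therefore adds about one forward matmul's worth of work. But memory stays O(S) vs O(S^2) for naive backward.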
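The recomputation works because the stored per-row statistics pin down the softmax exactly. The NumPy sketch below is an illustration, not FA's kernel. It keeps only the per-row logsumexp L = m + log(l) from a (densely computed, for brevity) forward pass. It then rebuilds each probability tile as exp(S_tile - L) without re-running the online recurrence, and checks dV = P^T dO tile by tile. The sizes and the upstream gradient `dO` are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
S_q, S_k, d, B_k = 16, 64, 8, 16
Q, K, V = (rng.standard_normal((n, d)) for n in (S_q, S_k, S_k))
dO = rng.standard_normal((S_q, d))                 # upstream gradient of the attention output

# Forward: keep only the O(S_q) logsumexp per row (scores computed densely here for brevity).
S = Q @ K.T / np.sqrt(d)
m = S.max(axis=1)
l = np.exp(S - m[:, None]).sum(axis=1)
L = m + np.log(l)                                  # logsumexp per query row

# Backward: rebuild each probability tile from Q, K_tile, and L; no S x S matrix is kept.
dV = np.zeros_like(V)
for k0 in range(0, S_k, B_k):
    S_tile = Q @ K[k0:k0 + B_k].T / np.sqrt(d)     # recomputed scores, (S_q, B_k)
    P_tile = np.exp(S_tile - L[:, None])           # exact softmax probabilities for this tile
    dV[k0:k0 + B_k] = P_tile.T @ dO                # dV = P^T dO, one K/V tile at a time

P_full = np.exp(S - L[:, None])                    # dense reference probabilities
print(np.allclose(dV, P_full.T @ dO), np.allclose(P_full.sum(axis=1), 1.0))
```

Only dV is shown here. dQ and dK follow the same pattern: recompute the tile, then apply the corresponding matmul.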


13. FlashAttention-2 deltas

Dao 2023. FA-1 was already great but had several inefficiencies that FA-2 addressed. FA-2 is roughly 2x faster than FA-1 on A100, getting close to GEMM efficiency.

The key ideas:

Better work partition. FA-1 parallelized over (batch * heads), with the sequence-length loop running sequentially inside each (batch, head) thread block. For long sequences with few heads or small batches (large S, small B * H), GPU utilization was poor because there were too few thread blocks to fill the SMs. FA-2 also parallelizes over the sequence dimension, i.e. over (batch * heads * Q tiles), so the outer Q-tile loop is parallel too. This makes long-context attention scale to all SMs.

Reduced non-matmul work. GPU tensor cores execute matmul-shaped work at peak throughput; everything else (exp, max, divide) runs on much slower CUDA cores. FA-1 spent more time on these "non-matmul" operations than necessary, particularly on the per-block rescaling: it kept the output normalized, dividing by the running l on every iteration. FA-2 keeps the output unnormalized inside the loop and defers the divide-by-l to a single pass at the end, as in the Section 12.3 loop. It also saves only the per-row logsumexp for the backward. The share of FLOPs spent in matmuls improves from ~70% to ~95%.
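For contrast with the deferred divide in the Section 12.3 loop, here is a NumPy sketch of the eager, always-normalized update described above for FA-1, reduced to one query row. The function name is hypothetical. It produces the same result, but it pays an extra per-row multiply and divide on every iteration, which is the kind of non-matmul work FA-2 removes.

```python
import numpy as np

def eager_normalized_attention(s, V, tile):
    """One query row; keeps the output normalized after every tile (FA-1 style update)."""
    m, l = -np.inf, 0.0
    O = np.zeros(V.shape[1])                      # always a properly normalized weighted average
    for start in range(0, len(s), tile):
        s_j, V_j = s[start:start + tile], V[start:start + tile]
        m_new = max(m, s_j.max())
        correction = np.exp(m - m_new)
        p_j = np.exp(s_j - m_new)
        l_new = l * correction + p_j.sum()
        # Undo the old normalization (* l), rescale, accumulate, then renormalize (/ l_new):
        # an extra multiply and divide per row on every iteration.
        O = (O * (l * correction) + p_j @ V_j) / l_new
        m, l = m_new, l_new
    return O

rng = np.random.default_rng(1)
s = rng.standard_normal(512) * 5
V = rng.standard_normal((512, 32))
p = np.exp(s - s.max())
print(np.allclose(eager_normalized_attention(s, V, tile=64), (p / p.sum()) @ V))
```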

Causal masking optimization. The lower-triangular mask means the upper-right triangle of the attention matrix is zero. FA-2 skips entire K tiles that are guaranteed to be fully masked (where k_tile_start > q_tile_end), saving roughly half the work on causal attention.

Forward and backward both improved. The new partition and the reduced non-matmul work apply to both passes, and the backward also gets ~2x.

Net: ~2x throughput vs FA-1, getting attention up to 50-70% of the A100's theoretical peak, vs ~25-40% for FA-1.


14. FlashAttention-3 deltas (Hopper-targeted, 2024)

Shah, Bikshandi, Dao et al. 2024. Targets H100 specifically. Not just an algorithmic improvement; it leverages new hardware features:

TMA (Tensor Memory Accelerator). H100 has dedicated hardware for async bulk memory copies between HBM and SRAM. FA-3 uses TMA to overlap loading the next K/V tile with computing on the current one. This hides HBM latency in the inner loop.

Asynchronous WGMMA tensor cores. H100's WGMMA instruction issues matmul work to the tensor cores asynchronously, so the warpgroup can keep computing other things (softmax, normalization) while a previous matmul is still finishing. FA-3 combines this with "warp specialization": producer warps issue TMA loads while consumer warpgroups run the matmuls and the softmax. Two consumer warpgroups are scheduled in a ping-pong pattern so that one's softmax overlaps the other's matmuls. Within a warpgroup, the softmax of one block is pipelined against the matmuls of the next, which keeps the tensor cores busy most of the time.

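FP8 support. H100's FP8 tensor cores deliver ~2x the throughput of BF16. FA-3 supports FP8 for Q, K, V. It limits quantization error with block quantization and incoherent processing, which multiplies Q and K by a random orthogonal transform to spread out outliers. Critical for serving models like Llama-3 in FP8 at maximum throughput.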

Net: 1.5-2.0x over FA-2 on H100 in FP16/BF16, reaching up to ~740 TFLOPs/s (~75% of theoretical peak), and close to 1.2 PFLOPs/s with FP8.

FA-3 is Hopper-only because the techniques rely on H100-specific hardware. On A100 you still want FA-2.


15. The decode-time variant: flash_attn_with_kvcache

Decode is a different beast from prefill. In decode:

  • Q has length 1 (one new token).
  • K and V are read from the existing KV-cache.
  • The new K and V need to be appended to the cache.
  • Causal masking is implicit: the new Q sees all cached K (which are all earlier positions by construction).

A naive decode would be three separate kernels: append K, append V, attention. Three HBM round-trips. The fused flash_attn_with_kvcache:

  1. Take new Q (B, 1, H_q, d_h), new K and V (B, 1, H_kv, d_h), the existing KV-cache, and a per-batch sequence-length array (cache_seqlens). flash-attn uses a (batch, seqlen, heads, head_dim) layout.
  2. Inside one kernel, write the new K, V into the cache at the correct slots (using the seq-len array).
  3. Run flash attention with Q against the now-updated K, V cache.
  4. Return output.
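A reference (unfused) NumPy sketch of these four steps for one decode step over a contiguous cache may help pin down the indexing. Several details are assumptions: the function name `decode_step_with_cache`, the (B, max_seq, H_kv, d_h) cache layout (which mirrors flash-attn's (batch, seqlen, heads, head_dim) convention), and the GQA head mapping in which query head h reads KV head h // G. The real kernel performs steps 2 and 3 in a single launch.

```python
import numpy as np

def decode_step_with_cache(q, k_new, v_new, k_cache, v_cache, cache_seqlens):
    """
    q:                (B, H_q, d_h)        query for the single new token
    k_new, v_new:     (B, H_kv, d_h)       K/V for that token
    k_cache, v_cache: (B, max_seq, H_kv, d_h), updated in place
    cache_seqlens:    (B,) int array of tokens already cached; incremented in place
    """
    B, H_q, d_h = q.shape
    H_kv = k_new.shape[1]
    G = H_q // H_kv                                   # query heads per KV head
    out = np.empty_like(q)
    for b in range(B):
        pos = cache_seqlens[b]
        k_cache[b, pos] = k_new[b]                    # step 2: write new K, V at slot `pos`
        v_cache[b, pos] = v_new[b]
        n = pos + 1                                   # the new query sees every cached token plus itself
        for h in range(H_q):
            K = k_cache[b, :n, h // G]                # (n, d_h)
            V = v_cache[b, :n, h // G]
            s = K @ q[b, h] / np.sqrt(d_h)            # step 3: attention over the valid prefix only
            p = np.exp(s - s.max())
            out[b, h] = (p / p.sum()) @ V
        cache_seqlens[b] = n
    return out                                        # step 4

rng = np.random.default_rng(0)
B, H_q, H_kv, d_h, max_seq = 2, 8, 2, 16, 32
k_cache = rng.standard_normal((B, max_seq, H_kv, d_h))   # slots past cache_seqlens are stale garbage
v_cache = rng.standard_normal((B, max_seq, H_kv, d_h))
cache_seqlens = np.array([5, 11])
q = rng.standard_normal((B, H_q, d_h))
k_new = rng.standard_normal((B, H_kv, d_h))
v_new = rng.standard_normal((B, H_kv, d_h))
out = decode_step_with_cache(q, k_new, v_new, k_cache, v_cache, cache_seqlens)
print(out.shape, cache_seqlens)
```

The per-sequence length array is what lets sequences of different lengths share one batched call. Stale slots beyond each length are never read.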

The win: only one HBM read of the cache, only one write of the new K/V slot, no separate launch overhead. On long contexts this is the difference between 200 and 350 tokens/sec/request on a single H100.

For paged KV layouts, flash_attn_with_kvcache also accepts a block_table argument and gathers K/V from non-contiguous physical blocks. Paged decode kernels of this kind, whether FlashAttention's or an engine's own, are what production engines (vLLM, SGLang, TRT-LLM) actually call during decode.


16. Practical exercises

Six problems. Solve all six with pencil and paper, then in code.

Exercise 1: Derive the online softmax for 3 blocks

The scores are split into three blocks s = [s^(1), s^(2), s^(3)] with corresponding values V = [V^(1), V^(2), V^(3)]. Set initial state m = -inf, l = 0, O_ = 0. Step through the algorithm:

  1. After processing s^(1): write m_1, l_1, O_1 in terms of s^(1), V^(1).
  2. After processing s^(2): write m_2, l_2, O_2 using m_1, l_1, O_1 and s^(2), V^(2). Show the correction factor exp(m_1 - m_2).
  3. After processing s^(3): write m_3, l_3, O_3.
  4. Show that O_3 / l_3 equals the all-at-once softmax(s) @ V.

This exercise tests that you can keep the inductive invariant straight across multiple correction steps. The algebra is identical to the proof in Section 12.2.1 but explicit for T = 3.

Exercise 2: KV-cache size for various models

Compute the KV-cache size in bytes per request for the following configurations at seq_len = 4096, BF16 (2 bytes/element). Use the formula: bytes = 2 * num_layers * H_kv * seq_len * d_h * dtype_bytes.

Model          num_layers   H_q   H_kv   d_h   KV cache (MiB)
Llama-3-8B     32           32    8      128   ?
Llama-3-70B    80           64    8      128   ?
Mistral 7B     32           32    8      128   ?
GPT-3 (MHA)    96           96    96     128   ?

(Spoiler: 512, 1280, 512, 18432 MiB respectively.) Notice the GPT-3 result: 18 GiB of KV per request at 4K context with MHA. This is why GQA is non-negotiable for serving.

Repeat at seq_len = 32768 and explain in one sentence why GPT-3-style MHA is infeasible at long context for serving.

Exercise 3: Implement causal-masked MHA in PyTorch

Implement the function

def my_attention(Q, K, V, causal=True):
    # Q, K, V: (B, H, S, d_h)
    # returns: (B, H, S, d_h)
    ...

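matching torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=causal) to within 1e-5 in FP32. BF16 keeps only about 3 significant decimal digits, so compare BF16 runs with a tolerance closer to 1e-2.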

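Edge cases to test:

  • S = 1 (the decode-step shape).
  • S = 1 query against a length-N cache.
  • Different lengths for query and key (use the lower-triangular mask appropriately). For non-square inputs, PyTorch's is_causal=True aligns the mask to the top-left, so a single query against a length-N cache would see only key 0. Decode needs the bottom-right alignment, where query i sits at position L_k - L_q + i.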
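The alignment issue in the last two edge cases is easy to get wrong. The NumPy sketch below contrasts the top-left alignment that PyTorch documents for is_causal=True with non-square inputs against the bottom-right alignment that decode needs. The helper name `causal_mask` and its `align` parameter are hypothetical.

```python
import numpy as np

def causal_mask(L_q, L_k, align="bottom_right"):
    """Boolean (L_q, L_k) mask, True where query i may attend to key j."""
    i = np.arange(L_q)[:, None]
    j = np.arange(L_k)[None, :]
    offset = L_k - L_q if align == "bottom_right" else 0   # absolute position of query 0
    return j <= i + offset

print(causal_mask(1, 5, "top_left").astype(int))       # [[1 0 0 0 0]]: decode query sees only key 0
print(causal_mask(1, 5, "bottom_right").astype(int))   # [[1 1 1 1 1]]: decode query sees the whole cache
```

For square inputs the two alignments coincide with the ordinary lower-triangular mask, which is why the bug only appears once L_q != L_k.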

Exercise 4: Implement RoPE and verify the relative-position property

Implement apply_rope(x, positions) with d_h = 64 and base = 10000. Then numerically verify:

  1. For random q, k and various m, n, the value of apply_rope(q, m) . apply_rope(k, n) depends only on (m - n), not on m and n separately.
  2. Specifically, vary m, n simultaneously by the same shift and show the dot product is invariant to within float precision.
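One way to set up the check is the NumPy sketch below. It uses the interleaved-pair convention, in which dimensions 2i and 2i+1 rotate together by angle pos * base^(-2i/d_h). Many implementations, including Hugging Face's Llama code, instead pair dimension i with i + d_h/2. That is the same rotation up to a fixed permutation of dimensions, and it has the same relative-position property.

```python
import numpy as np

def apply_rope(x, pos, base=10000.0):
    """Rotate each pair (x[2i], x[2i+1]) by angle pos * base**(-2i/d). x: (..., d) with d even."""
    d = x.shape[-1]
    inv_freq = base ** (-np.arange(0, d, 2) / d)     # (d/2,) per-pair rotation frequencies
    angle = pos * inv_freq
    cos, sin = np.cos(angle), np.sin(angle)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin      # 2-D rotation of each pair
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(64), rng.standard_normal(64)
m, n = 37, 5
base_dot = apply_rope(q, m) @ apply_rope(k, n)
for shift in (0, 1, 100, 4000):
    shifted = apply_rope(q, m + shift) @ apply_rope(k, n + shift)
    print(shift, np.isclose(shifted, base_dot))      # invariant: depends only on m - n
```

Because each pair is rotated, apply_rope also preserves vector norms, which is a useful second sanity check.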

Exercise 5: GQA broadcasting

In GQA, K and V have shape (B, H_kv, S, d_h) but Q has shape (B, H_q, S, d_h). Implement the GQA attention so that each group of G = H_q / H_kv query heads shares one K/V head. Verify that GQA with H_kv = H_q reduces exactly to standard MHA.

Hint: one approach is to repeat-interleave K and V along the head axis by G; another is to reshape Q to (B, H_kv, G, S, d_h) and broadcast. The reshape approach saves memory.
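The two approaches in the hint can be compared side by side in NumPy. This sketch is non-causal for brevity, and the function names are hypothetical. Both functions map query head h to KV head h // G: the reshape splits h into (h // G, h % G), and np.repeat along the head axis places G consecutive copies of each KV head.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def gqa_reshape(Q, K, V):
    """Q: (B, H_q, S, d); K, V: (B, H_kv, S, d). Broadcasts K/V over the group axis, no copies."""
    B, H_q, S, d = Q.shape
    G = H_q // K.shape[1]
    Qg = Q.reshape(B, K.shape[1], G, S, d)                        # query head h -> (h // G, h % G)
    scores = Qg @ K[:, :, None].swapaxes(-1, -2) / np.sqrt(d)     # (B, H_kv, G, S, S)
    out = softmax(scores) @ V[:, :, None]                         # (B, H_kv, G, S, d)
    return out.reshape(B, H_q, S, d)

def gqa_repeat(Q, K, V):
    """Same computation, but materializes G copies of each K/V head."""
    G = Q.shape[1] // K.shape[1]
    K_rep, V_rep = np.repeat(K, G, axis=1), np.repeat(V, G, axis=1)
    scores = Q @ K_rep.swapaxes(-1, -2) / np.sqrt(Q.shape[-1])
    return softmax(scores) @ V_rep

rng = np.random.default_rng(0)
B, H_q, H_kv, S, d = 2, 8, 2, 5, 16
Q = rng.standard_normal((B, H_q, S, d))
K = rng.standard_normal((B, H_kv, S, d))
V = rng.standard_normal((B, H_kv, S, d))
print(np.allclose(gqa_reshape(Q, K, V), gqa_repeat(Q, K, V)))

# With H_kv == H_q (G = 1) the grouped computation is plain MHA.
K_full = rng.standard_normal((B, H_q, S, d))
V_full = rng.standard_normal((B, H_q, S, d))
mha = softmax(Q @ K_full.swapaxes(-1, -2) / np.sqrt(d)) @ V_full
print(np.allclose(gqa_reshape(Q, K_full, V_full), mha))
```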

Exercise 6: Decode-time KV growth and HBM bandwidth bound

Take Llama-3-70B (GQA, 80 layers, H_kv = 8, d_h = 128, BF16).

  1. Compute the KV-cache size at S = 0, 1024, 4096, 16384, 65536.
  2. Assuming HBM bandwidth = 3.35 TB/s, compute the minimum time per decode step due to reading the entire cache.
  3. Plot tokens/sec vs context length implied by this lower bound.
  4. Compare to actual measured decode throughput from a real engine (e.g., vLLM benchmarks). Where does the gap come from?

Expected answers (rough):

  • At S = 4096: KV is ~1.25 GiB; bandwidth-bound time per step is ~0.4 ms; that is a ceiling of ~2500 tok/s/request. Real systems hit 50-150 tok/s at this point. That is largely because every decode step must also stream the model weights (~140 GB in BF16 for 70B parameters, sharded across GPUs), and because engines batch multiple requests that share the bandwidth rather than running solo decode.
  • At S = 65536: KV is ~20 GiB; the ceiling drops to ~155 tok/s/request.

The exercise drives home two points. The KV-bound time per decode step grows linearly with context length, so the tok/s ceiling falls as 1/S. And at long context the dominant per-request cost is HBM bandwidth on KV, not FLOPs.
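Parts 1 and 2 can be scripted. This sketch counts only the KV reads, as the exercise does. It deliberately ignores weight reads, compute, and batching, which is exactly the gap part 4 asks about. The 3.35 TB/s figure is treated as decimal (3.35e12 bytes/s).

```python
GiB = 1024 ** 3
HBM_BYTES_PER_S = 3.35e12          # H100 SXM HBM bandwidth from the exercise, decimal TB/s

def kv_bytes(seq_len, num_layers=80, h_kv=8, d_h=128, dtype_bytes=2):
    """Llama-3-70B-shaped KV-cache size for one request."""
    return 2 * num_layers * h_kv * seq_len * d_h * dtype_bytes

for S in (0, 1024, 4096, 16384, 65536):
    b = kv_bytes(S)
    t = b / HBM_BYTES_PER_S                       # seconds to stream the whole cache once
    ceiling = float("inf") if t == 0 else 1 / t   # tok/s upper bound from KV reads alone
    print(f"S={S:>6}: KV={b / GiB:6.2f} GiB, step >= {t * 1e3:6.3f} ms, <= {ceiling:10.0f} tok/s")
```

Plotting `ceiling` against S gives the 1/S curve; the measured numbers from a real engine will sit well below it.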


Closing notes

What you should now be able to derive without notes:

  1. Why divide by sqrt(d_k) (variance argument, Section 2.2).
  2. The exact KV-cache memory formula for any decoder model (Section 10.3).
  3. Why GQA exists and what factor it saves (Section 5).
  4. RoPE's relative-position property by complex number rotation (Section 6.3).
  5. The online softmax algorithm and its proof of equivalence (Section 12.2).
  6. The asymptotic argument: standard attention is O(S^2 d) compute, O(S^2) memory, and Θ(S d + S^2) HBM traffic. FlashAttention is O(S^2 d) compute and O(S * d) memory, with Θ(S^2 d^2 / M) HBM traffic for SRAM size M and no S x S matrix ever written to HBM.

What you should know from memory but cannot derive:

  • Concrete per-version deltas of FA-1 → FA-2 → FA-3 (Sections 13, 14).
  • Specific architecture choices of Llama-3 / Mistral: GQA group sizes; SwiGLU hidden size, nominally 8/3 x d but 14336 = 3.5 x 4096 in Llama-3-8B and Mistral 7B; RoPE base 10000 or extended bases, e.g., 500,000 in Llama-3.

Cross-reference the inference-side material in Month 5 of the AI Systems Plan: /home/voseghale/projects/self_dev/AI_SYSTEMS_PLAN/. KV-cache memory formulas, paged attention, and FlashAttention's wall-time implications are the single biggest driver of inference engineering decisions, and they all flow from the math in this chapter.
