Week 17 - LLM Inference, the KV-Cache, Attention Math
17.1 Conceptual Core
- LLM inference has two distinct phases:
- Prefill: process all input tokens in parallel. Compute-bound (a big matmul). Latency dominated by sequence length × hidden dim.
- Decode: generate tokens one at a time, each requiring a forward pass attending to all previous tokens. Severely memory-bound. Latency dominated by reading model weights + KV-cache from HBM.
- The KV-cache: at decode step t, the model needs all previous keys and values to attend to. Recomputing them at every step makes generation O(t²) overall; caching them makes it O(t). The cache is large: for a 70B model with 8K context, on the order of 10 GB per request (the exact size depends on precision and the KV-head count).
- Why decode is memory-bound: at batch size 1, each generated token reads ~70 GB of weights plus ~10 GB of KV-cache → ~80 GB from HBM. At ~3 TB/s HBM bandwidth, that is ~25 ms per token at minimum, regardless of compute. This is why decode rarely uses tensor cores efficiently.
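The bandwidth arithmetic above can be checked in a few lines. This is a back-of-envelope sketch using the same illustrative figures as the text (~70 GB of weights, ~10 GB of KV-cache, ~3 TB/s HBM), not measured numbers:

```python
# Decode latency floor at batch size 1: every generated token must stream
# the weights and the KV-cache from HBM. Figures mirror the text above.
weight_bytes = 70e9   # ~70 GB of weights (illustrative)
kv_bytes = 10e9       # ~10 GB KV-cache at long context (illustrative)
hbm_bw = 3e12         # ~3 TB/s HBM bandwidth

bytes_per_token = weight_bytes + kv_bytes
latency_s = bytes_per_token / hbm_bw
print(f"min per-token latency: {latency_s * 1e3:.1f} ms")   # ~26.7 ms
print(f"max decode throughput: {1 / latency_s:.1f} tok/s")  # ~37.5 tok/s
```

No amount of compute optimization moves this floor; only reading fewer bytes (quantization, GQA, batching to amortize the weight read) does.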
17.2 Mechanical Detail
- Standard attention:
O = softmax(QK^T / √d_k) V. For a single decode step on the t-th token, Q is (1, d), K and V are (t, d).
- The KV-cache layout matters:
- `(num_layers, 2, batch, num_heads, max_seq_len, head_dim)` - naive contiguous. Wasteful: pre-allocates max_seq_len for every request, even short ones.
- Paged (vLLM): `(num_layers, 2, num_blocks, num_heads, block_size, head_dim)` plus per-request page tables. Fragmentation gone; sharing across requests becomes possible (prefix caching).
- FlashAttention (Dao et al., 2022 → v3 by 2024): tiles attention so the full t×t score matrix is never materialized in HBM. Streaming: process Q in chunks with an online softmax. Reduces HBM traffic from O(t²) to O(t·d). Critical for long contexts.
- Multi-Query Attention (MQA) and Grouped-Query Attention (GQA): share K/V heads across query heads. MQA: a single KV head per layer. GQA: groups of query heads share one KV head (e.g., Llama-2-70B style: 8 KV heads serving 64 query heads, an 8× cache shrink). Modern open models (Llama-3, Qwen) all use GQA.
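A single decode step of GQA against a KV-cache can be sketched in NumPy. This is a toy illustration (the function name and shapes are hypothetical, and real kernels vectorize over heads rather than looping):

```python
import numpy as np

def gqa_decode_step(q, k_cache, v_cache, n_kv_heads):
    """One decode step of grouped-query attention.

    q:       (n_q_heads, d)      query for the newest token
    k_cache: (n_kv_heads, t, d)  cached keys for all past tokens
    v_cache: (n_kv_heads, t, d)  cached values for all past tokens
    """
    n_q_heads, d = q.shape
    group = n_q_heads // n_kv_heads          # query heads per shared KV head
    out = np.empty_like(q)
    for h in range(n_q_heads):
        kv = h // group                      # which shared KV head this query uses
        scores = k_cache[kv] @ q[h] / np.sqrt(d)   # (t,) attention scores
        w = np.exp(scores - scores.max())
        w /= w.sum()                         # softmax over past tokens
        out[h] = w @ v_cache[kv]             # weighted sum of cached values
    return out

# toy shapes: 8 query heads sharing 2 KV heads, 5 cached tokens, head_dim=16
rng = np.random.default_rng(0)
q = rng.standard_normal((8, 16))
k = rng.standard_normal((2, 5, 16))
v = rng.standard_normal((2, 5, 16))
o = gqa_decode_step(q, k, v, n_kv_heads=2)
print(o.shape)  # (8, 16)
```

Note the cache savings: only 2 of the 8 heads' worth of K/V is ever stored, which is exactly the shrinkage GQA buys.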
17.3 Lab: "Decode From Scratch"
- Implement greedy decoding for a small Hugging Face model (Llama-3-8B fits on a single A100; use a smaller model on an L4):
- Prefill once, capture KV-cache.
- Decode loop: forward(token, kv_cache) → next_token.
- Append the new token's K and V to the KV-cache; feed next_token back in.
- Measure tokens/sec. Compute the achieved HBM BW (model weights × tokens / time).
- Replace standard attention with `flash_attn_with_kvcache`. Re-measure.
- Document the decode-vs-prefill latency split for a 1K-prefill, 512-decode request.
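The shape of the lab's prefill → decode loop can be exercised without a GPU using a stand-in "model". Everything here (the one random attention layer, shapes, vocab) is a toy stand-in for the Hugging Face model the lab actually uses; only the loop structure carries over:

```python
import time
import numpy as np

rng = np.random.default_rng(0)
d, vocab = 32, 100
# Toy stand-in for a real model: one attention layer + an output projection.
Wq, Wk, Wv = (rng.standard_normal((d, d)) * 0.1 for _ in range(3))
Wo = rng.standard_normal((d, vocab)) * 0.1
emb = rng.standard_normal((vocab, d))

def forward(token_id, k_cache, v_cache):
    """One forward pass for one token, attending over the cached K/V."""
    x = emb[token_id]                          # (d,)
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    k_cache.append(k); v_cache.append(v)       # append this step's K/V
    K, V = np.stack(k_cache), np.stack(v_cache)
    s = K @ q / np.sqrt(d)
    w = np.exp(s - s.max()); w /= w.sum()      # softmax over all past tokens
    logits = (w @ V) @ Wo
    return int(logits.argmax())                # greedy next token

prompt = [1, 2, 3, 4]
k_cache, v_cache = [], []
for t in prompt:              # "prefill" (a real prefill batches this matmul)
    nxt = forward(t, k_cache, v_cache)

t0 = time.perf_counter()
n_decode = 64
for _ in range(n_decode):     # decode loop: forward(token, kv_cache) -> next
    nxt = forward(nxt, k_cache, v_cache)
dt = time.perf_counter() - t0
print(f"{n_decode / dt:.0f} tok/s (toy model, CPU)")
```

In the real lab, swap `forward` for the HF model call with `use_cache=True` and compute achieved HBM bandwidth as model-weight bytes × tokens / time.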
17.4 Idiomatic & Diagnostic Drill
- `nsys profile` a decode step. The expected pattern: long matmuls (streaming the model weights) interspersed with small attention kernels. Tensor cores ~10-20% utilized.
17.5 Production Slice
- Inference cost is dominated by decode: tokens generated per hardware-hour. Build a one-page estimator: given (model size, GPU type, batch size, concurrent requests), what is your tokens-per-dollar?
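A minimal sketch of such an estimator, assuming decode stays memory-bound (weights streamed once per step, shared by the batch) and ignoring prefill cost. All inputs are illustrative assumptions, not vendor-quoted numbers:

```python
def tokens_per_dollar(weight_bytes, kv_bytes_per_req, hbm_bw, batch,
                      gpu_cost_per_hr):
    """Memory-bound decode throughput -> tokens per dollar (rough estimate).

    Assumes each decode step streams the weights once for the whole batch
    plus every request's KV-cache, and that this HBM traffic is the bottleneck.
    """
    step_bytes = weight_bytes + batch * kv_bytes_per_req
    step_s = step_bytes / hbm_bw       # memory-bound lower bound per step
    tok_per_s = batch / step_s         # one token per request per step
    return tok_per_s * 3600 / gpu_cost_per_hr

# e.g. 70 GB weights, 2 GB KV per request, 3 TB/s HBM, batch 8, $3/hr
# (all assumed figures) -> roughly 335k tokens per dollar
rate = tokens_per_dollar(70e9, 2e9, 3e12, 8, 3.0)
print(f"{rate:,.0f} tokens per dollar")
```

The key lever the formula exposes: batching amortizes the weight read, so tokens-per-dollar climbs with batch size until the KV-cache term (or compute) dominates.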