
Week 17 - LLM Inference, the KV-Cache, Attention Math

17.1 Conceptual Core

  • LLM inference has two distinct phases:
  • Prefill: process all input tokens in parallel. Compute-bound (a big matmul). Latency dominated by sequence length × hidden dim.
  • Decode: generate tokens one at a time, each requiring a forward pass attending to all previous tokens. Severely memory-bound. Latency dominated by reading model weights + KV-cache from HBM.
  • The KV-cache: at decode step t, the model needs every previous key and value to attend to. Recomputing them at each step costs O(t²) work over a generation; caching them makes it O(t). The cache is large: for a 70B-class model with 8K context, on the order of 10 GB per request (exact size depends on KV precision and head count).
  • Why decode is memory-bound: each generated token reads ~70 GB of weights (assuming 8-bit weights; double that at FP16) and ~10 GB of KV-cache → ~80 GB from HBM. At ~3 TB/s HBM bandwidth, that's ~25 ms minimum, regardless of compute. This is why decode rarely uses tensor cores efficiently.
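A back-of-envelope script to reproduce those numbers. The assumptions (8-bit weights, 8-bit KV-cache, 80 layers, 64 heads of dim 128, no GQA, ~3 TB/s HBM) are illustrative defaults, not measurements:

```python
# Back-of-envelope: why decode is HBM-bandwidth-bound (all numbers are assumptions).
def decode_step_latency_ms(
    n_params: float = 70e9,       # model parameters
    bytes_per_param: int = 1,     # assume 8-bit weights; use 2 for FP16
    n_layers: int = 80,
    n_kv_heads: int = 64,         # full MHA here; GQA would shrink the cache 8x
    head_dim: int = 128,
    context_len: int = 8192,
    kv_bytes: int = 1,            # 8-bit KV-cache
    hbm_bw_gbs: float = 3000.0,   # ~3 TB/s aggregate HBM bandwidth
) -> float:
    weight_bytes = n_params * bytes_per_param
    # KV-cache per request: 2 (K and V) x layers x kv_heads x head_dim x seq_len.
    kv_cache_bytes = 2 * n_layers * n_kv_heads * head_dim * context_len * kv_bytes
    # Each decode step streams the weights plus the whole cache from HBM (batch = 1).
    total_bytes = weight_bytes + kv_cache_bytes
    print(f"KV-cache per request: {kv_cache_bytes / 1e9:.1f} GB")
    return total_bytes / (hbm_bw_gbs * 1e9) * 1e3

print(f"min latency per decoded token: {decode_step_latency_ms():.1f} ms")
```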

17.2 Mechanical Detail

  • Standard attention: O = softmax(QK^T / √d_k) V. For a single decode step on the t-th token, Q is (1, d), K and V are (t, d).
  • The KV-cache layout matters:
  • Naive contiguous: (num_layers, 2, batch, num_heads, max_seq_len, head_dim). Wasteful: pre-allocates max_seq_len slots for every request, even short ones.
  • Paged (vLLM): (num_layers, 2, num_blocks, num_heads, block_size, head_dim) plus per-request page tables. Fragmentation is gone, and sharing across requests becomes possible (prefix caching).
  • FlashAttention (Dao et al., 2022 → v3 by 2024): tiles attention so the full t×t score matrix is never materialized in HBM. Streaming: process K/V in tiles with an online softmax. Reduces HBM traffic from O(t²) to O(t·d). Critical for long contexts.
  • Multi-Query Attention (MQA) and Grouped-Query Attention (GQA): share K/V heads across query heads. MQA: a single KV head per layer. GQA: groups of query heads share each KV head (e.g., 8 KV heads for 64 query heads, an 8× KV-cache shrinkage, as in Llama-2-70B). Modern open models (Llama-3, Qwen) all use GQA. Toy sketches of the decode-step math, the online softmax, and the paged lookup follow below.
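A minimal PyTorch sketch of one decode step, tying together the (1, d) query, the growing KV-cache, and GQA head sharing. Dimensions are toy values, and the repeat_interleave expansion is just one illustrative way to express the sharing:

```python
import torch

# Toy dimensions (illustrative, not any particular model's config).
n_q_heads, n_kv_heads, head_dim = 64, 8, 128
group = n_q_heads // n_kv_heads            # 8 query heads share each KV head (GQA)

# KV-cache holding the previous t tokens: (n_kv_heads, t, head_dim).
t = 1000
k_cache = torch.randn(n_kv_heads, t, head_dim)
v_cache = torch.randn(n_kv_heads, t, head_dim)

# One decode step: projections for the single new token.
q_new = torch.randn(n_q_heads, 1, head_dim)    # Q is (1, d) per head
k_new = torch.randn(n_kv_heads, 1, head_dim)
v_new = torch.randn(n_kv_heads, 1, head_dim)

# Append the new token's K/V to the cache (now t+1 entries).
k_cache = torch.cat([k_cache, k_new], dim=1)
v_cache = torch.cat([v_cache, v_new], dim=1)

# GQA: expand each KV head so every query head in its group can attend to it.
k = k_cache.repeat_interleave(group, dim=0)     # (n_q_heads, t+1, head_dim)
v = v_cache.repeat_interleave(group, dim=0)

# O = softmax(Q K^T / sqrt(d_k)) V for the single new query position.
scores = q_new @ k.transpose(-1, -2) / head_dim**0.5   # (n_q_heads, 1, t+1)
out = scores.softmax(dim=-1) @ v                        # (n_q_heads, 1, head_dim)
print(out.shape)
```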
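The online-softmax idea behind FlashAttention, reduced to the decode case of a single query row: stream over K/V chunks, keep a running max and denominator, never hold the full score vector. A didactic sketch of the numerics only, not the fused kernel:

```python
import torch

def online_softmax_attention(q, k, v, chunk=256):
    """Attention for one query row, streaming over K/V chunks (didactic sketch)."""
    d = q.shape[-1]
    m = torch.tensor(float("-inf"))     # running max of scores seen so far
    denom = torch.tensor(0.0)           # running softmax denominator
    acc = torch.zeros(d)                # running weighted sum of V rows
    for start in range(0, k.shape[0], chunk):
        k_blk, v_blk = k[start:start + chunk], v[start:start + chunk]
        s = (k_blk @ q) / d ** 0.5      # scores for this chunk
        m_new = torch.maximum(m, s.max())
        scale = torch.exp(m - m_new)    # rescale earlier partial sums to the new max
        p = torch.exp(s - m_new)
        denom = denom * scale + p.sum()
        acc = acc * scale + p @ v_blk
        m = m_new
    return acc / denom

t, d = 1000, 128
q, k, v = torch.randn(d), torch.randn(t, d), torch.randn(t, d)
ref = torch.softmax(k @ q / d ** 0.5, dim=0) @ v
assert torch.allclose(online_softmax_attention(q, k, v), ref, atol=1e-4)
```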
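And a toy version of the paged layout's bookkeeping: a per-request page table maps logical token positions to physical blocks, allocated on demand. Names and block size are illustrative, not vLLM's actual data structures:

```python
# Toy paged KV-cache bookkeeping (illustrative; not vLLM's implementation).
BLOCK_SIZE = 16                          # tokens per physical block

free_blocks = list(range(1024))          # pool of physical block ids
page_tables: dict[int, list[int]] = {}   # request id -> its physical block ids

def slot_for_token(req_id: int, token_pos: int) -> tuple[int, int]:
    """Return (physical_block, offset) where this token's K/V should live."""
    table = page_tables.setdefault(req_id, [])
    if token_pos % BLOCK_SIZE == 0:      # crossed into a new logical block
        table.append(free_blocks.pop())  # allocate a physical block on demand
    return table[token_pos // BLOCK_SIZE], token_pos % BLOCK_SIZE

# A 40-token request owns exactly 3 blocks, not max_seq_len worth of memory.
for pos in range(40):
    slot_for_token(req_id=7, token_pos=pos)
print(page_tables[7])
```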

17.3 Lab: "Decode From Scratch"

  1. Implement greedy decoding for a small Hugging Face model (Llama-3-8B fits on a single A100; use a smaller model on an L4); a minimal sketch follows after this list:
  2. Prefill once, capture KV-cache.
  3. Decode loop: forward(token, kv_cache) → next_token.
  4. Append the new token's K/V to the cache.
  5. Measure tokens/sec. Compute the achieved HBM bandwidth (≈ model-weight bytes × tokens / time).
  6. Replace standard attention with flash_attn_with_kvcache. Re-measure.
  7. Document the decode-vs-prefill latency split for a 1K-prefill, 512-decode request.
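A minimal sketch of steps 1-5, assuming the standard transformers use_cache / past_key_values API; the model id, dtype, and decode length are placeholders to adapt to your GPU:

```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B"   # swap for a smaller model on an L4
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16
).to("cuda").eval()

input_ids = tok("The KV-cache exists because", return_tensors="pt").input_ids.to("cuda")

with torch.no_grad():
    # Prefill: one parallel pass over the whole prompt; capture the KV-cache.
    out = model(input_ids, use_cache=True)
    past = out.past_key_values
    next_tok = out.logits[:, -1].argmax(dim=-1, keepdim=True)

    torch.cuda.synchronize(); start = time.time()
    n_decode = 512
    for _ in range(n_decode):
        # Decode: feed only the newest token; the cache supplies all earlier K/V.
        out = model(next_tok, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_tok = out.logits[:, -1].argmax(dim=-1, keepdim=True)
    torch.cuda.synchronize(); elapsed = time.time() - start

tps = n_decode / elapsed
weight_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"{tps:.1f} tok/s, achieved HBM BW ≈ {weight_bytes * tps / 1e9:.0f} GB/s")
```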

17.4 Idiomatic & Diagnostic Drill

  • nsys profile a decode step. The expected pattern: matmul kernels whose runtime is dominated by streaming model weights, interspersed with small attention kernels; tensor cores only ~10-20% utilized.

17.5 Production Slice

  • Inference cost is dominated by decode: tokens generated × hardware cost per hour. Build a one-page estimator: given (model size, GPU type, batch size, concurrent requests), what is your tokens-per-dollar? A starting-point sketch follows.
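A starting point for the estimator, assuming a purely bandwidth-bound decode model; the example prices and sizes are assumptions to replace with your own measurements and rates:

```python
# Rough tokens-per-dollar estimator (assumed prices; purely bandwidth-bound model).
def tokens_per_dollar(
    model_bytes: float,          # weight bytes read per decode step (e.g. 70e9 at 8-bit)
    kv_bytes_per_req: float,     # KV-cache bytes per concurrent request
    batch_size: int,             # requests decoded together per step
    hbm_bw_gbs: float,           # sustained HBM bandwidth, GB/s
    gpu_cost_per_hour: float,    # $/hour for the GPU (or the tensor-parallel group)
) -> float:
    # Per decode step: weights are read once and amortized across the batch;
    # each request's KV-cache is read in full.
    bytes_per_step = model_bytes + batch_size * kv_bytes_per_req
    steps_per_sec = hbm_bw_gbs * 1e9 / bytes_per_step
    tokens_per_sec = steps_per_sec * batch_size
    dollars_per_sec = gpu_cost_per_hour / 3600
    return tokens_per_sec / dollars_per_sec

# Example (all numbers illustrative): 70B at 8-bit, 2.7 GB KV per request, batch 32,
# ~3 TB/s HBM, $4/hour for the serving GPU(s).
print(f"{tokens_per_dollar(70e9, 2.7e9, 32, 3000, 4.0):,.0f} tokens per dollar")
```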
