Deep Dive 10-Fine-Tuning: From SFT to RLHF¶
Prerequisites. Linear algebra, calculus through gradients, basic probability, a working mental model of transformer training. Familiarity with cross-entropy losses and autoregressive language modeling is assumed.
Cross-references.
- Distributed training (FSDP, ZeRO-3, tensor/pipeline parallelism): AI_SYSTEMS_PLAN/DEEP_DIVES/06.
- Mixed-precision and FP8 numerics: AI_SYSTEMS_PLAN/DEEP_DIVES/11.
- Eval discipline (held-out sets, calibration): AI_SYSTEMS_PLAN/DEEP_DIVES/08.

Scope. This is the document the curriculum's reading list points to. Sequence 15 names LoRA, QLoRA, DPO, and GRPO without deriving them. Here we derive them end-to-end and pair the math with the engineering. The DPO derivation in §8 is the centerpiece.
Table of contents¶
- The decision matrix: prompt vs RAG vs fine-tuning
- Supervised fine-tuning (SFT)
- Catastrophic forgetting
- LoRA-full derivation
- QLoRA-full derivation
- Preference learning-RLHF concepts
- PPO for RLHF (high-level)
- DPO-full derivation
- GRPO
- Reward model design
- Preference data curation
- Constitutional AI / RLAIF
- Frontier-scale fine-tuning
- Full FT vs LoRA-the decision
- Evaluation before and after FT
- The end-to-end FT workflow
- Practical exercises
1. The decision matrix: prompt vs RAG vs fine-tuning¶
A pretrained model has three knobs you can turn to bend its behavior toward your task. They are not interchangeable; they live on different axes and answer different problems.
1.1 The three knobs¶
Prompt engineering. You ship the model unchanged. At inference time you prepend instructions, examples, or a system message that elicits the desired behavior. The model's weights are static; its context changes.
- What it gives you. Behavior, in-context. Few-shot patterns, persona, output format, refusal policy.
- What it costs. Tokens per call. A 4 k-token system prompt on every request is a 4 k-token tax on every request, forever.
Retrieval-augmented generation (RAG). At inference time, retrieve relevant documents from an external index (vector DB, BM25, hybrid) and inject them into the context. The model's weights are static; its facts come from outside.
- What it gives you. Knowledge access, freshness, citations, attribution.
- What it costs. Retrieval latency, index construction, index maintenance, retrieval-quality engineering, plus the per-call token tax for the retrieved passages.
Fine-tuning (FT). You change the weights. New (prompt, completion) pairs
or preference pairs gradient-update the model so the desired behavior is baked
in rather than re-elicited every call.
- What it gives you. Stable behavior across many prompts, with no per-call prompt overhead. New tone, new format conventions, new domain idiom.
- What it costs. A one-time training run (ranging from a few GPU-hours for small LoRA to thousands of GPU-hours for full FT of a 70B), plus a per-model serving slot (or, with multi-LoRA, a shared base + small adapters).
1.2 Rules of thumb¶
The clean decision rule is the stability × type matrix:
| | Stable | Volatile |
|---|---|---|
| Behavior | Fine-tune | Prompt |
| Knowledge | RAG (or embed in FT if tiny) | RAG |
- Behavior = how the model responds: tone, format, style, safety posture, reasoning pattern, refusals, JSON conformance, persona.
- Knowledge = facts the model relies on: product catalog, docs, policies, yesterday's customer ticket history.
- Stable = changes monthly or slower; volatile = changes daily or faster.
Stable behavior across many domains → fine-tune. The behavior is the same regardless of which fact you're answering with; bake it in once and pay zero prompt overhead per call.
Stable knowledge that doesn't fit in the prompt → RAG. Even if your manual never changes, you cannot fit 100 MB of docs in a prompt. Retrieve.
Volatile knowledge → RAG, always. Re-training every time a doc changes is absurd. Re-index instead.
Volatile behavior → prompt. If your team is iterating on tone twice a week, you cannot ship a fine-tune twice a week. Adjust the prompt; promote to FT only when it stabilizes.
1.3 Cost comparison (order of magnitude)¶
Let T = tokens of prompt overhead, Q = queries per day, c_in = cost per input token.
- Prompt cost / day ≈ T · Q · c_in. Linear in queries, forever.
- RAG cost / day ≈ (T + T_retrieved) · Q · c_in + index_ops. Slightly worse than prompt because retrieved chunks are extra context.
- FT cost ≈ C_train (one-time) + Q · c_in at the base token count (no overhead). Compute amortizes; per-call you pay only for the actual question.
The tipping point: if T is 2–4 k tokens and traffic is non-trivial, the accumulated prompt tax exceeds the one-time FT cost within weeks. This is how the math stops being abstract.
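To make the arithmetic concrete, here is a tiny break-even sketch. The traffic, token, and price figures are hypothetical placeholders; plug in your own.

```python
# Illustrative break-even between a recurring prompt tax and a one-time
# fine-tune. All numbers below are hypothetical assumptions.

def prompt_tax_per_day(T: int, Q: int, c_in: float) -> float:
    """Daily cost of T overhead tokens on every one of Q requests."""
    return T * Q * c_in

def break_even_days(C_train: float, T: int, Q: int, c_in: float) -> float:
    """Days until the one-time training cost beats the recurring prompt tax."""
    return C_train / prompt_tax_per_day(T, Q, c_in)

# Hypothetical: 3k-token system prompt, 50k requests/day, $0.10 per 1M input tokens.
tax = prompt_tax_per_day(T=3000, Q=50_000, c_in=0.10 / 1_000_000)
days = break_even_days(C_train=500.0, T=3000, Q=50_000, c_in=0.10 / 1_000_000)
```

With these made-up numbers the prompt tax is about $15/day, so a $500 training run pays for itself in roughly a month.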
1.4 What you usually combine¶
In production you almost never pick one. The default stack is:
- Base model (pretrained + instruct-tuned by the vendor).
- Light fine-tune (LoRA) on stable behavior-tone, JSON shape, refusals.
- RAG for the knowledge that lives in your DB, docs, tickets.
- Prompt for the residual-the things you tweak weekly.
When this deep dive talks about "fine-tuning," it almost always means layer 2: a LoRA, sometimes a DPO on top, on top of an instruct-tuned base.
2. Supervised fine-tuning (SFT)¶
SFT is the simplest form of fine-tuning. You have demonstrations: pairs
(x, y) where x is a prompt and y is the response you want the model to
produce. You train the model to maximize p(y | x).
2.1 Setup¶
The dataset is a list of (x, y) pairs. For chat models, x typically
includes the system prompt and prior turns; y is the assistant's reply.
Tokenize each pair into a single sequence: [x_tokens] [y_tokens]. Build an
attention mask covering the whole sequence so the model attends causally
across the boundary.
2.2 The loss¶
Standard autoregressive cross-entropy:

L_SFT(θ) = − Σ_{t=1}^{|y|} log p_θ( y_t | x, y_{<t} )
The crucial detail: mask the prompt tokens out of the loss. Concretely, build a labels tensor that is a copy of the input ids, with all positions belonging to x set to -100 (the PyTorch ignore index); the model shifts labels by one internally when computing the loss. Only the y positions contribute to the loss.
2.2.1 Why mask the prompt¶
Two reasons:
- You don't want to teach the model to predict prompts. The user writes the prompt; modeling its distribution is wasted gradient. Worse: in chat, the prompt distribution is bizarre (system prompts, special tokens, role markers) and you don't want the model to bias toward producing it.
- Effective sample efficiency. Half your tokens being prompt is half your gradient being noise from the model's perspective.
In transformers, the canonical pattern is:
input_ids = tokenizer(x + y, return_tensors="pt").input_ids[0]
labels = input_ids.clone()
labels[: len(tokenizer(x).input_ids)] = -100 # mask the prompt
(With chat templates you mask everything except assistant turns.)
2.3 Data quality dominates data quantity¶
Folklore but well-supported: 1 000 hand-curated examples beat 100 000 noisy ones. The reason is mechanical: SFT tightens the model's output distribution toward the training distribution, including its mistakes. A noisy dataset gives the model permission to be sloppy.
Practical rules:
- Curate ruthlessly. Read every example yourself before training. If a human wouldn't be proud to ship that response, the model shouldn't either.
- Diversity matters more than count. 1 000 examples covering 50 task archetypes beat 10 000 examples of the same five.
- Prefer expert demonstrations. Subject-matter experts produce 3–5× cleaner data than crowdworkers, and SFT is bottlenecked by ceiling, not volume.
2.4 Hyperparameters that matter¶
These are not magic. They follow from gradient stability and from the size of the update you're trying to make.
- Learning rate.
- Full FT: 1e-5 to 5e-5 is the typical band (approximate). Larger models want smaller LRs; a 70B should be at the low end.
- LoRA: 1e-4 to 5e-4, a decade higher, because the adapter starts from a zero update (B = 0) and is much smaller, so it can absorb a larger LR without destabilizing.
- Epochs. 1–3. More epochs trade generalization for memorization. If your eval is plateauing at epoch 2, stop. If it's still climbing at epoch 3, the bottleneck is probably too little data, not too few epochs.
- Warmup. 5–10 % of total steps, linear from 0 to peak LR. Skipping warmup on a freshly-loaded pretrained model is a great way to corrupt early layers.
- LR schedule. Cosine decay to ~10 % of peak LR is the default; linear decay is fine.
- Batch size. Whatever fits in memory after gradient accumulation. The effective batch size matters more than the per-step batch size; aim for 64–256 sequences-equivalent.
- Sequence length. Match production. Don't train on 512-token sequences and serve 4 k.
- Sequence packing. Pack short examples into one sequence (with attention boundaries between them) to fill context efficiently. A dataset of 200-token chat turns wastes 87 % of a 1 600-token forward pass without packing.
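A minimal sketch of the packing idea, using first-fit over token counts. This is illustrative only; real packers (e.g., TRL's) differ in detail and also construct the attention boundaries between packed examples.

```python
# Greedy first-fit packing: concatenate short examples into fixed-capacity
# sequences so the context window isn't wasted on padding.

def pack(example_lengths, max_len):
    """Pack token counts into bins of capacity max_len.
    Returns a list of bins, each a list of example indices."""
    bins, capacities = [], []
    for idx, n in enumerate(example_lengths):
        for b, cap in enumerate(capacities):
            if n <= cap:                 # fits in an existing bin
                bins[b].append(idx)
                capacities[b] -= n
                break
        else:                            # open a new bin
            bins.append([idx])
            capacities.append(max_len - n)
    return bins

lengths = [200, 200, 1400, 200, 900, 200, 200]   # token counts per example
bins = pack(lengths, max_len=1600)               # 7 examples -> 3 sequences
```

Seven short examples collapse into three 1 600-token forward passes instead of seven mostly-padding ones.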
2.5 Minimal SFT pseudocode (TRL)¶
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig
model_name = "meta-llama/Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="bfloat16")
ds = load_dataset("json", data_files="sft.jsonl", split="train")
# rows: {"messages": [{"role": "system", ...}, {"role": "user", ...},
# {"role": "assistant", ...}]}
cfg = SFTConfig(
output_dir="out/sft",
num_train_epochs=2,
per_device_train_batch_size=4,
gradient_accumulation_steps=8, # effective batch 32
learning_rate=2e-5,
warmup_ratio=0.05,
lr_scheduler_type="cosine",
bf16=True,
packing=True,
max_seq_length=4096,
)
trainer = SFTTrainer(model=model, tokenizer=tokenizer,
train_dataset=ds, args=cfg)
trainer.train()
TRL's SFTTrainer masks the prompt for you when messages follows the chat
template, and packing=True does the sequence packing.
3. Catastrophic forgetting¶
A model fine-tuned heavily on a narrow distribution forgets things it used to know. Not metaphorically-measurably. MMLU drops; GSM8K drops; safety behavior drifts. This is catastrophic forgetting, and it is the central risk of fine-tuning.
3.1 Why it happens¶
Pretraining packs an enormous amount of knowledge into the model's weights. That knowledge lives as a fragile equilibrium of activations. SFT pushes the weights toward a small distribution (your data), and gradients that move the model toward your data are not, in general, gradients that preserve far-away knowledge. The model is solving a different problem now, and the old solution is collateral.
3.2 Mitigations, in order of strength¶
- Keep epochs low. 1 epoch with rich data forgets less than 5 epochs on the same data.
- Mix in instruction-tune data. During FT, blend in 5–20 % general instruction data (e.g., a slice of the original SFT distribution if you have it, or a public mix). This anchors the model.
- Use LoRA. A small rank-r perturbation of the weights cannot express drastic forgetting (§4). The base remains intact and can be detached from the adapter at any time.
- KL-regularize toward the base. Add a term β · KL(π_θ || π_base) to the loss, so updates that move the output distribution far from the base are penalized. This is the same idea that makes RLHF stable (§6).
- Replay buffer. Cache representative examples from the base distribution and interleave them.
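The KL-anchor mitigation can be sketched on a toy next-token distribution. The probabilities, β, and the task-loss value below are made up for illustration; in practice this term is computed from model logits, per token.

```python
# Penalize the fine-tuned distribution for drifting from the base.
import math

def kl(p, q):
    """KL(p || q) for discrete distributions given as probability lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

base     = [0.5, 0.3, 0.2]   # pi_base over a tiny 3-token vocabulary (toy)
finetune = [0.7, 0.2, 0.1]   # pi_theta after some updates (toy)

beta = 0.1
task_loss = 1.25             # placeholder value for the task loss
total = task_loss + beta * kl(finetune, base)   # drift adds to the loss
```

The further the fine-tuned distribution drifts, the larger the penalty; at zero drift the penalty vanishes.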
3.3 Why FT-from-scratch is dangerous; FT-from-instruct is safer¶
A pretrained-only base model has not learned to follow instructions, refuse, or behave safely. SFT on top of that base on a narrow domain produces a model that is good at your task and aggressively bad at everything else, including safety. SFT on top of a vendor instruct-tuned model preserves the instruction/safety scaffolding by construction (especially with LoRA), and your fine-tune adds a thin layer of domain behavior.
The lesson: unless you have a very specific reason, always FT on top of the instruct-tuned variant.
4. LoRA-full derivation¶
LoRA (Hu et al., 2021) is the dominant parameter-efficient fine-tuning method. The derivation is short and the consequences are large.
4.1 The empirical insight¶
When you fully fine-tune a pretrained model, the weight update Δ = W' − W
is empirically low rank. That is, even though Δ is a d × k matrix with
nominal rank min(d, k), almost all of its singular values are tiny. The
fine-tuning update lives in a low-dimensional subspace of weight space.
This is intuitively reasonable: pretraining already filled in the heavy features; fine-tuning is steering, not relearning.
4.2 The decomposition¶
Parameterize the update as a low-rank product:

Δ = B · A,   B ∈ R^{d × r},   A ∈ R^{r × k},   r ≪ min(d, k)

Then B · A is at most rank r by construction. Replace W + Δ with W + B · A and freeze W. The trainable parameters are A and B only.
The forward pass becomes:

h = W · x + B · (A · x)

The right-hand side shows the implementation: compute A·x first (an r-vector), then B · (A·x) (a d-vector), and add to the original W · x. No new full-size matmuls.
4.3 Initialization¶
You need Δ = 0 at the start of training so the model output equals the
pretrained model exactly. The standard choice:
- A ~ Gaussian (e.g., Kaiming-uniform, the default in most LoRA libs).
- B = 0.
Then B · A = 0 regardless of A's values, so Δ = 0 at step 0.
Gradients still flow despite the zero initialization. Walk through them explicitly. Let L be the loss and let g = ∂L/∂(BA) ∈ R^{d × k}. Then:

∂L/∂B = g · Aᵀ   (nonzero when A ≠ 0)
∂L/∂A = Bᵀ · g   (zero when B = 0)

At step 0, B = 0, so ∂L/∂A = 0. Only B updates. After one step, B ≠ 0, and A starts updating too. The initialization is asymmetric on purpose: the trainer takes a step on B first, and A starts moving on the second step. In practice this works fine and converges comparably to symmetric initializations.
(Some libraries swap the convention-A zero, B Gaussian-which gives
the symmetric result. Either is fine; just match the library's docs.)
4.4 Scaling: the α/r trick¶
LoRA introduces a scalar:

Δ = (α / r) · B · A
The scaling factor α/r decouples learning rate from rank. Without it,
doubling r would double the magnitude of B · A's expected output (more
basis vectors summed), and you'd have to halve the LR to compensate.
Common practice: fix α and sweep r. Typical: α = 16 or α = 32,
r ∈ {8, 16, 32, 64}. With α/r scaling, the effective LR stays sane
across rank changes.
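A toy pure-Python forward pass makes the zero init and the α/r scaling concrete. Dimensions are deliberately tiny and the weight values are made up; real implementations operate on batched tensors.

```python
# h = W·x + (alpha/r) · B·(A·x), computed the cheap way: A·x first.

def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

def lora_forward(W, A, B, x, alpha, r):
    base = matvec(W, x)
    low = matvec(A, x)                 # r-dimensional intermediate
    delta = matvec(B, low)             # back up to d dimensions
    return [b + (alpha / r) * d for b, d in zip(base, delta)]

W = [[1.0, 0.0], [0.0, 1.0]]           # frozen base weight (d = k = 2)
A = [[0.3, -0.2]]                      # r = 1, Gaussian-ish init (toy values)
B = [[0.0], [0.0]]                     # zero init => delta = 0 at step 0
x = [2.0, 1.0]

h0 = lora_forward(W, A, B, x, alpha=16, r=1)   # identical to W·x at step 0
B = [[0.5], [-0.5]]                            # pretend-trained B
h1 = lora_forward(W, A, B, x, alpha=16, r=1)   # now the adapter contributes
```

At step 0 the output equals the base model exactly, which is the whole point of the B = 0 initialization.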
4.5 Where to apply LoRA¶
The original paper applied LoRA only to the query and value projections
of attention (W_q, W_v). The reasoning: those are the projections most
sensitive to fine-tuning.
Modern practice broadens this:
- Q, K, V, O (all four attention projections). Adds parameters but gives a noticeable quality bump.
- MLP weights (W_up, W_down, W_gate for gated MLPs like SwiGLU). Best quality, more parameters. The general consensus from recent fine-tuning work is that targeting MLPs matters as much as attention.
- Embeddings and LM head. Usually skip; large parameter counts and small benefit for most tasks. Apply only when changing vocabulary semantics (e.g., adding new tokens).
The rule: more LoRA targets → better quality, more parameters. For most applications, all linear layers in the transformer block is the default that is hard to beat.
4.6 Memory math¶
Per matrix W ∈ R^{d × k}:
- Full FT trains d · k parameters.
- LoRA trains r · (d + k) parameters.

For d = k = 4096 and r = 16:
- Full FT: 4096 · 4096 = 16 777 216 ≈ 16.8 M parameters per matrix.
- LoRA: 16 · (4096 + 4096) = 131 072 ≈ 131 k parameters per matrix.
- Ratio: 128× fewer trainable parameters per matrix.
For Llama-7B (32 layers, applying LoRA to Q, V at r = 16):
- 32 layers × 2 matrices × 131 072 ≈ 8.4 M trainable parameters.
- The full model is ~7 B parameters.
- Trainable fraction: 8.4M / 7B ≈ 0.12 %.
You move 0.12 % of the parameters and recover most of the full-FT quality. This is the LoRA promise.
The optimizer state is also tiny. Adam stores two moments (m, v) per trainable parameter in fp32, and mixed-precision training adds an fp32 master weight and fp32 gradient: 7B · 4 · 4 bytes = 112 GB of optimizer-side state for full FT. LoRA at 0.12 % trainable: ~135 MB. The optimizer fits in cache.
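The counts above can be reproduced with a few lines of arithmetic:

```python
# LoRA parameter-count arithmetic from this section.

def lora_params(d, k, r):
    return r * (d + k)

full = 4096 * 4096                       # full FT params per matrix
lora = lora_params(4096, 4096, r=16)     # LoRA params per matrix
ratio = full / lora                      # 128x fewer trainables

# Llama-7B-style count: 32 layers, LoRA on Q and V at r = 16
trainable = 32 * 2 * lora
fraction = trainable / 7e9               # ~0.12 % of the model
```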
4.7 Inference: merge or keep separate¶
Two deployment modes:
- Merge. At inference time, compute W' = W + (α/r) · B · A once and replace the base weight. Zero serving overhead; the model is shape-identical to the base.
- Keep separate. Carry B and A as side tensors. Apply W·x + (α/r)·B·(A·x) at every forward. Tiny overhead. Lets you hot-swap adapters at request time.
4.8 Multi-LoRA serving¶
This is the modern multi-tenant pattern. Load one base model; carry many small adapters; route each request to the right one.
- One base model in GPU memory (e.g., 14 GB for an 8B in bf16).
- 50 customer-specific adapters at ~50 MB each = 2.5 GB.
- Total VRAM: 16.5 GB. One H100-80GB serves 50 fine-tunes.
Frameworks supporting this: vLLM (`--enable-lora`), LoRAX, S-LoRA. The batched matrix multiply for multiple adapters in the same batch is the nontrivial systems work; the libraries handle it.
4.9 Minimal LoRA pseudocode (PEFT)¶
from peft import LoraConfig, get_peft_model
lora = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora)
model.print_trainable_parameters()
# trainable params: ~42M / total params: ~8B / trainable %: ~0.5
Drop this in front of an SFTTrainer and you have LoRA SFT.
5. QLoRA-full derivation¶
QLoRA (Dettmers et al., 2023) combines LoRA with aggressive 4-bit weight quantization. The combination is the reason a 70B fine-tune is feasible on a single 48 GB GPU.
5.1 The insight¶
LoRA already shrinks the optimizer state and trainable parameters. The remaining memory hog is the base model weights themselves (e.g., 70B in bf16 = 140 GB). If you quantize them to 4 bits, the base model takes 35 GB. The LoRA adapters stay in higher precision (bf16) and continue to train normally; the frozen quantized base is used only for forward and backward through the frozen weights.
The trick is that the gradient through a frozen weight only needs to read the weight, not write it. So you can leave the base in 4-bit storage and dequantize on the fly during the matmul; no quantization-aware-training machinery is needed for the base.
5.2 NF4 (NormalFloat-4) quantization¶
Standard INT4 quantization splits the value range into 16 evenly-spaced levels. For tensor values that are roughly uniformly distributed, this is near-optimal. For neural-network weights, which are well-modeled as zero-mean normal, uniform spacing wastes precision in the tails (where few weights live) and underdescribes the bulk near zero.
NF4's solution: choose the 16 levels to be quantiles of a standard normal distribution, rescaled so they span [-1, 1]. Concretely, the levels sit at equal-probability-mass points of N(0, 1); the QLoRA paper builds the negative and positive halves separately so that 0 is represented exactly. Then a normally-distributed weight tensor has approximately uniform mass in each of the 16 NF4 bins. This is information-theoretically optimal for normal weights: each level carries the same amount of information.
The lookup table is fixed-derived once from the normal CDF-and hardcoded. Quantization at runtime is: divide the tensor by its absmax into the [-1, 1] range, then for each value find the nearest level in the NF4 table, store the level index (4 bits) and remember the scale.
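The construction can be sketched as follows. This is an illustrative approximation (midpoints of equal-mass bins, rescaled), not bitsandbytes' exact hardcoded table, which builds the two halves so that 0 is representable exactly.

```python
# NF4-style levels: equal-probability-mass points of N(0,1), rescaled to [-1, 1].
from statistics import NormalDist

def nf4ish_levels(n=16):
    nd = NormalDist()
    # midpoints of n equal-mass bins under the normal CDF
    qs = [nd.inv_cdf((i + 0.5) / n) for i in range(n)]
    m = max(abs(q) for q in qs)
    return [q / m for q in qs]           # rescale so extremes land at +/-1

def quantize(value, levels):
    """Nearest-level lookup: store the 4-bit index, keep the per-block scale."""
    return min(range(len(levels)), key=lambda i: abs(levels[i] - value))

levels = nf4ish_levels()
idx = quantize(0.07, levels)             # 4-bit code for an absmax-normalized weight
```

Because the levels are normal quantiles, a normally distributed tensor spreads roughly evenly across the 16 codes, which is exactly the property uniform INT4 lacks.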
5.3 Double quantization¶
The scale factors themselves take memory. For a 70B model with a block size of 64, you have 70 B / 64 ≈ 1.1 B scale factors. In fp32, that's ~4.4 GB of scales-a non-trivial fraction of the quantized model.
Double quantization quantizes the scale factors themselves to FP8. Saves ~3 GB on a 70B. Not glamorous but free.
5.4 Paged optimizers¶
Even with all of the above, training spikes can OOM the GPU. NVIDIA's Unified Memory (UVM) lets you allocate optimizer state in a way that can page between GPU and CPU memory automatically, like virtual memory at the OS level. Optimizer states for momentum/variance are large but infrequently accessed during forward/backward; paging them out during peak activation memory is invisible.
Result: the GPU survives transient memory pressure that would otherwise kill the run.
5.5 Memory budget-70B on a single 48 GB GPU at r=64¶
Counting:
- Base weights (NF4). 70 B params × 4 bits = 35 GB, plus a small overhead for quantization constants (negligible after double quantization).
- LoRA adapters. Apply LoRA to all linear layers (~7 modules per layer × 80 layers = 560 modules). A representative matrix for a 70B is roughly 8192 × 8192 (model dim 8 192, MLP up to 28 672). Trainable per such matrix at r = 64 ≈ 64 · (8192 + 8192) = 1 048 576 ≈ 1 M. Across modules: roughly 200–300 M trainable params, in bf16 = 400–600 MB. Optimizer state (Adam, fp32, 2× the params) ≈ 2 GB. Total: ~2.5 GB.
- Activations and gradients. With activation checkpointing, this is the dominant remaining cost. For batch 1, seqlen 2048, on a 70B with AC: ~6–10 GB. Without AC: 30+ GB and you OOM.
- Slack for kernels, KV cache, fragmentation. ~3 GB.

Total: 35 + 2.5 + ~8 + 3 ≈ 48–49 GB. Right at the line on a 48 GB GPU (RTX A6000, RTX 6000 Ada, A40). A single H100-80GB has comfortable margin.
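The same budget as a script, with the mid-range figures from the list above taken as assumptions:

```python
# 70B QLoRA budget in (decimal) GB. Mid-range figures from the text,
# treated as assumptions rather than measurements.

base_weights = 70e9 * 4 / 8 / 1e9   # NF4: 4 bits/param -> 35 GB
adapters_and_opt = 2.5              # bf16 adapters + fp32 Adam moments
activations = 8.0                   # batch 1, seqlen 2048, with checkpointing
slack = 3.0                         # kernels, KV cache, fragmentation

total = base_weights + adapters_and_opt + activations + slack   # ~48.5 GB
```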
5.6 QLoRA pseudocode¶
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
bnb = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype="bfloat16",
bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.3-70B-Instruct",
quantization_config=bnb,
device_map="auto",
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, LoraConfig(r=64, lora_alpha=16, ...))
Pass optim="paged_adamw_8bit" to the trainer for paged optimizers.
6. Preference learning-RLHF concepts¶
SFT teaches the model to imitate good demonstrations. But humans are often better at judging than generating: writing the perfect customer-support reply is hard; picking which of two replies is better is easy. Preference learning leverages this asymmetry.
6.1 Why preferences instead of demonstrations¶
- Cheaper. A pairwise comparison is faster than authoring a gold reply.
- More reliable. Two raters tend to agree on rankings even when they'd produce different "ideal" answers.
- Captures style and nuance. "This response is more empathetic" is easy to mark and very hard to specify.
- Negative information. SFT can't tell the model what not to do; preferences can.
6.2 The classic RLHF pipeline (InstructGPT, 2022)¶
- SFT. Standard supervised fine-tune on demonstrations.
- Reward model (RM). Collect preference pairs (x, y_w, y_l) (chosen, rejected). Train a model r_φ(x, y) → ℝ that scores responses, with the loss derived in §10.
- RL. Fine-tune the SFT policy π_θ with reinforcement learning to maximize expected reward, subject to a KL penalty toward the SFT model.
The objective for stage 3:

J(θ) = E_{x ~ D, y ~ π_θ(·|x)} [ r_φ(x, y) ] − β · KL( π_θ(·|x) || π_SFT(·|x) )

Equivalently, per-token, the reward handed to the RL algorithm is

r_t = r̃_t − β · log( π_θ(y_t | x, y_<t) / π_SFT(y_t | x, y_<t) )

where r̃_t is typically zero for non-final tokens and r_φ(x, y) for the final token.
6.3 The KL constraint, derived¶
Why the KL penalty? Without it, the policy will exploit the reward model. The RM is a fitted approximation of human preference; it has blind spots. Pure reward maximization runs the policy toward whatever the RM accidentally likes-verbosity, hedging, specific tokens-and quality collapses. This is reward hacking (§10.3).
The KL term β · KL(π_θ || π_SFT) says: stay close to the SFT model. The
SFT model is a known-good distribution (it produces fluent text); large
deviations are suspicious. β controls how tight the leash is.
In closed form, the KL-constrained optimal policy is

π*(y|x) = (1 / Z(x)) · π_SFT(y|x) · exp( r(x, y) / β )        (★)

with Z(x) = Σ_y π_SFT(y|x) · exp(r(x,y)/β). Derivation:
We maximize, for each prompt x,
F(π) = E_{y ~ π} [r(x, y)] - β · KL(π(·|x) || π_SFT(·|x))
= Σ_y π(y|x) · r(x, y) - β · Σ_y π(y|x) · log(π(y|x)/π_SFT(y|x))
subject to Σ_y π(y|x) = 1. Lagrangian:

Λ(π, λ) = Σ_y π(y|x) · r(x, y) − β · Σ_y π(y|x) · log(π(y|x)/π_SFT(y|x)) − λ(x) · ( Σ_y π(y|x) − 1 )
Take ∂/∂π(y|x) and set it to zero:

r(x, y) − β · ( log(π(y|x)/π_SFT(y|x)) + 1 ) − λ(x) = 0
Solve for π:
log(π(y|x)/π_SFT(y|x)) = (r(x, y) - β - λ(x)) / β
π(y|x) = π_SFT(y|x) · exp((r(x, y) - β - λ(x)) / β)
= π_SFT(y|x) · exp(r(x, y)/β) · C(x)
where C(x) = exp(-(β + λ(x))/β) is constant in y. Apply the
normalization Σ_y π(y|x) = 1:
So:

π*(y|x) = (1 / Z(x)) · π_SFT(y|x) · exp( r(x, y) / β )        (★)
This is the KL-regularized RL optimal policy. We will use (★) in the DPO derivation in §8-it is the central identity.
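A numeric sanity check of (★) on a toy three-token vocabulary: the closed-form policy should score at least as well as hand-picked alternatives under F(π) = E[r] − β·KL. All numbers are made up for illustration.

```python
# Verify the closed-form optimum on a tiny discrete problem.
import math

pi_sft = [0.5, 0.3, 0.2]    # reference policy over 3 "responses" (toy)
r = [1.0, 0.2, -0.5]        # rewards (toy)
beta = 0.5

def F(pi):
    """The KL-regularized objective: expected reward minus beta * KL."""
    reward = sum(p * ri for p, ri in zip(pi, r))
    kl = sum(p * math.log(p / q) for p, q in zip(pi, pi_sft) if p > 0)
    return reward - beta * kl

# Closed form (*): pi*(y) proportional to pi_sft(y) * exp(r(y)/beta)
unnorm = [q * math.exp(ri / beta) for q, ri in zip(pi_sft, r)]
Z = sum(unnorm)
pi_star = [u / Z for u in unnorm]

candidates = [pi_sft, [0.8, 0.1, 0.1], [1/3, 1/3, 1/3]]
best = all(F(pi_star) >= F(pi) for pi in candidates)
```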
7. PPO for RLHF (high-level)¶
Stage 3 of RLHF (the actual RL fine-tune) is traditionally done with PPO (Schulman et al., 2017). PPO is a policy-gradient algorithm with a trust-region-style clip to keep updates small.
7.1 The PPO clip¶
Let ratio_t(θ) = π_θ(y_t|x, y_<t) / π_θ_old(y_t|x, y_<t) be the importance
ratio between the current and the previous policy iterate. Let A_t be the
estimated advantage (token-level). PPO maximizes:

L_CLIP(θ) = E_t [ min( ratio_t(θ) · A_t, clip(ratio_t(θ), 1 − ε, 1 + ε) · A_t ) ]
with ε ≈ 0.1–0.2. Why the min of a clipped and unclipped term: if the update would push the ratio beyond 1+ε (or below 1−ε) in the direction the advantage favors, the clipped version is taken, which has zero gradient, preventing runaway moves. If the advantage is negative and the ratio has risen above 1+ε, the unclipped term is more negative and is what the min selects, so the gradient can still pull the policy away from bad actions.
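The min() behavior is easy to verify numerically; a one-token sketch (real PPO applies this per token over a batch):

```python
# The clipped surrogate for a single token.

def ppo_term(ratio, adv, eps=0.2):
    clipped = max(1 - eps, min(1 + eps, ratio))
    return min(ratio * adv, clipped * adv)

# Positive advantage, ratio past 1+eps: the clipped value wins (flat in ratio,
# zero gradient -- no runaway update).
a = ppo_term(1.5, adv=+1.0)
# Negative advantage, ratio past 1+eps: the unclipped (more negative) term
# wins, so the gradient can still push the policy away from the bad action.
b = ppo_term(1.5, adv=-1.0)
```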
7.2 The four-model setup¶
PPO RLHF carries four models in memory simultaneously:
- Actor / policy (π_θ): the model being trained.
- Critic / value function (V_φ): estimates expected return at each token, used to compute advantages via GAE.
- Reward model (r_ψ): frozen, scores final responses.
- Reference policy (π_SFT): frozen, used in the KL penalty.
Memory cost: roughly 2× the actor for the critic (often initialized from the same base), plus actor + RM + reference. For a 7B base, you're managing ~28 B parameters' worth of model state. For a 70B base, RLHF is genuinely a multi-node enterprise.
7.3 Why PPO RLHF is hard¶
- Hyperparameter sensitivity. β, ε, RM LR, actor LR, critic LR, KL target, and GAE λ all interact. Small changes can collapse training.
- Reward hacking. The RM is imperfect; the actor finds exploits.
- KL ratchet. As training progresses, the policy can drift steadily from π_SFT even with the KL penalty, especially on long generations.
- Memory. Four models. Distributed RLHF on a 70B is real research infrastructure.
- Sample inefficiency. Each PPO step requires generating completions (slow, autoregressive) before the gradient step.
DPO (§8) was motivated by all of this: can we get the same alignment benefit without the RL stack?
8. DPO-full derivation¶
This is the chapter's centerpiece. DPO (Rafailov et al., 2023) shows that the classic RLHF objective has a closed-form optimum, that this optimum can be reparameterized in terms of the policy alone, and that the resulting objective is a simple supervised cross-entropy loss on preference pairs. No reward model. No PPO. No critic. No four-model setup.
8.1 The starting point: the KL-constrained RL objective¶
From §6.3 we have the optimal policy under the KL-regularized RL objective:

π*(y|x) = (1 / Z(x)) · π_SFT(y|x) · exp( r(x, y) / β )        (★)
Two observations:
- The reward r(x, y) and the policy π*(y|x), together with π_SFT, fully determine each other (given β). One can be solved from the others.
- We will not solve for π* from r. We will go the other direction: solve for r in terms of π* and π_SFT.
8.2 Inverting (★) to express r in terms of π* and π_SFT¶
Take the log of (★):

log π*(y|x) = log π_SFT(y|x) + r(x, y)/β − log Z(x)
Solve for r(x, y):
r(x, y) = β · [ log π*(y|x) - log π_SFT(y|x) ] + β · log Z(x)
= β · log( π*(y|x) / π_SFT(y|x) ) + β · log Z(x) (♦)
This is the reward-policy duality. The reward function and the optimal policy are in 1-to-1 correspondence (modulo the prompt-dependent constant log Z(x)). Importantly, Z(x) depends only on x and π_SFT, not on y; this will let it cancel in a moment.
8.3 The Bradley-Terry preference model¶
Humans rank pairs. Given a prompt x and two completions y_w (winner /
chosen) and y_l (loser / rejected), the probability that y_w is
preferred is modeled as

P( y_w ≻ y_l | x ) = σ( r(x, y_w) − r(x, y_l) )        (BT)
where σ is the logistic sigmoid. This is the Bradley-Terry model (Bradley & Terry, 1952). It is the standard parametric assumption in preference learning: pairwise probabilities are determined by a difference of latent scores.
8.4 Substituting (♦) into (BT)-the cancellation¶
The reward difference is:
r(x, y_w) - r(x, y_l)
= [ β·log(π*(y_w|x)/π_SFT(y_w|x)) + β·log Z(x) ]
- [ β·log(π*(y_l|x)/π_SFT(y_l|x)) + β·log Z(x) ]
= β · [ log(π*(y_w|x)/π_SFT(y_w|x)) - log(π*(y_l|x)/π_SFT(y_l|x)) ]
The β · log Z(x) terms cancel because they don't depend on y. This
cancellation is what makes DPO possible-the partition function, which
would otherwise be intractable to compute, vanishes.
Define the implicit reward (the policy-side log-ratio):

r̂_θ(x, y) := β · log( π_θ(y|x) / π_ref(y|x) )

where we have replaced π* with our trainable π_θ and π_SFT with π_ref (the reference, typically a frozen copy of the SFT model). Then:
r(x, y_w) - r(x, y_l) = r̂_θ(x, y_w) - r̂_θ(x, y_l)
= β·log(π_θ(y_w|x)/π_ref(y_w|x))
- β·log(π_θ(y_l|x)/π_ref(y_l|x))
8.5 The DPO loss¶
Plug back into (BT) and take the negative log-likelihood over a dataset of
preference pairs D = { (x, y_w, y_l) }:
L_DPO(θ) = - E_{(x,y_w,y_l)~D} [
log σ( β · log(π_θ(y_w|x)/π_ref(y_w|x))
- β · log(π_θ(y_l|x)/π_ref(y_l|x)) )
] (DPO)
This is the DPO loss. Let's read what it says. Define
Δ̂(x, y_w, y_l) := r̂_θ(x, y_w) - r̂_θ(x, y_l). Then L_DPO = -E[log σ(Δ̂)].
- When
π_θraisesy_w's likelihood relative toπ_refmore thany_l's,Δ̂is large positive,σ(Δ̂) → 1, loss → 0. Good. - When
π_θdoes the opposite, loss is large. - The gradient pushes
π_θto increase the relative log-prob of winners and decrease the relative log-prob of losers, with referenceπ_refdefining "relative."
8.6 The gradient-what DPO actually does¶
Differentiate L_DPO. Since Δ̂ already carries the factor β, L = −log σ(Δ̂), so dL/dΔ̂ = −(1 − σ(Δ̂)) = −σ(−Δ̂). The gradient is

∇_θ L_DPO = − E_{(x, y_w, y_l) ~ D} [ β · σ(−Δ̂) · ( ∇_θ log π_θ(y_w|x) − ∇_θ log π_θ(y_l|x) ) ]
Read this carefully:
- σ(−Δ̂) is the model's current error mass on this pair: it's near 1 when the model is wrong (Δ̂ < 0) and near 0 when right.
- The gradient is then the difference of log-probability gradients of winner and loser, scaled by that error.
- The update increases log π_θ(y_w|x) and decreases log π_θ(y_l|x), more so on pairs the model gets wrong. This is exactly the desired behavior, and it requires no reward model at all.
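The loss can be checked on a single pair. The log-prob values below are hypothetical stand-ins for summed token log-probs from the policy and reference forwards.

```python
# The DPO loss on one preference pair.
import math

def dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    delta = beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l))
    return -math.log(1 / (1 + math.exp(-delta)))   # -log sigmoid(delta)

# Policy == reference (no margin yet): loss is exactly log 2.
l0 = dpo_loss(-40.0, -45.0, -40.0, -45.0)
# Policy has raised the winner and lowered the loser relative to reference:
# the implicit-reward margin is positive and the loss drops.
l1 = dpo_loss(-38.0, -47.0, -40.0, -45.0)
```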
8.7 Hyperparameter β¶
β is the KL strength inherited from the original RL objective.
- Small β (~0.01): weak KL regularization; the model can drift far from π_ref. Stronger preference fitting, more risk of degeneration.
- Large β (~1.0): strong leash; the model stays close to π_ref and the preference signal is effectively diluted.
- Typical: 0.1–0.5. Llama-style alignment runs sit around 0.1.
If your DPO output is bizarre or repetitive, try larger β. If it's identical to the SFT model, try smaller β.
8.8 The reference model in practice¶
π_ref is typically a frozen copy of the SFT model at the start of DPO. It is loaded once and used only to compute log π_ref(y_w|x) and log π_ref(y_l|x); no gradients.
Engineering tricks:
- Precompute log-probs of π_ref once for the whole dataset. The reference is frozen; you can run it offline and cache. A disk-cached reference halves your VRAM during DPO.
- No-reference DPO variants (e.g., IPO, CPO, SimPO) remove or rework the reference. Performance varies by dataset; SimPO has been competitive on chat benchmarks while halving the reference cost.
8.9 DPO vs PPO RLHF-the engineering scorecard¶
| Axis | PPO RLHF | DPO |
|---|---|---|
| Reward model | Required, separately trained | Implicit in the loss |
| Sampling during training | Yes (slow, autoregressive) | No (offline pairs) |
| Models in memory | 4 (actor, critic, RM, ref) | 2 (policy, ref) |
| Hyperparameter count | High | Low (β, LR, epochs) |
| Stability | Notoriously fragile | Stable |
| Quality ceiling | Slightly higher in some studies | Comparable on most |
| Implementation effort | Substantial | A training loop |
For most teams, DPO is the right starting point.
8.10 DPO pseudocode (TRL)¶
from trl import DPOTrainer, DPOConfig
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
base = "out/sft" # the SFT checkpoint
tokenizer = AutoTokenizer.from_pretrained(base)
policy = AutoModelForCausalLM.from_pretrained(base, torch_dtype="bfloat16")
ref = AutoModelForCausalLM.from_pretrained(base, torch_dtype="bfloat16")
# rows: {"prompt": "...", "chosen": "...", "rejected": "..."}
ds = load_dataset("json", data_files="prefs.jsonl", split="train")
cfg = DPOConfig(
output_dir="out/dpo",
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=8,
learning_rate=5e-7, # DPO LR is ~10× smaller than SFT LR
beta=0.1,
warmup_ratio=0.05,
bf16=True,
)
trainer = DPOTrainer(model=policy, ref_model=ref, tokenizer=tokenizer,
train_dataset=ds, args=cfg)
trainer.train()
DPO with LoRA: wrap policy in a get_peft_model(...) first; you can
omit ref_model= and TRL will automatically use the base model under
the LoRA as the reference (because disabling adapters gives you the SFT
model back). This is a sweet trick that makes LoRA-DPO cost almost the
same as LoRA-SFT.
9. GRPO¶
GRPO (Group Relative Policy Optimization, DeepSeek 2024) is a recent PPO variant that drops the value function (the critic) and replaces it with group-relative statistics. It is the technique behind DeepSeek-Math and DeepSeek-R1's reasoning fine-tunes.
9.1 The insight¶
PPO's advantage A_t = R_t - V_φ(s_t) requires a learned value model
V_φ. The value model is roughly the same size as the actor, doubling
training memory.
GRPO observes: if you sample G completions from the same prompt, the
empirical mean and standard deviation of the group's rewards already form
a serviceable baseline. No critic needed.
9.2 The objective¶
For each prompt x, sample G completions {y_i}. Compute reward r_i
for each (from a reward model, a verifier, or in DeepSeek's case, a rule-based math grader). Compute the group-relative advantage:

A_i = (r_i - mean({r_1, …, r_G})) / std({r_1, …, r_G})

Apply this advantage to all tokens in y_i, then run a PPO-style clipped update:
L_GRPO(θ) = - E [ Σ_i Σ_t min( ratio_{i,t}·A_i, clip(ratio_{i,t}, 1-ε, 1+ε)·A_i ) ]
+ β · KL( π_θ || π_ref )
The KL is added directly to the loss (rather than as a per-token reward shaping, as in PPO RLHF).
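A minimal sketch of the group-relative advantage and the clipped surrogate (function names are illustrative; the KL penalty and masking of padded tokens are omitted):

```python
import torch

def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # rewards: [G] scalar rewards for G completions of one prompt.
    # Group-relative advantage: standardize within the group.
    return (rewards - rewards.mean()) / (rewards.std(unbiased=False) + eps)

def grpo_loss(logp_new, logp_old, advantages, clip_eps=0.2):
    # logp_new/logp_old: [G, T] per-token log-probs under the current
    # and behavior policies. One advantage per completion, broadcast
    # to every token of y_i. KL(π_θ || π_ref) is added separately.
    ratio = (logp_new - logp_old).exp()
    adv = advantages.unsqueeze(-1)                      # [G, 1]
    surrogate = torch.min(ratio * adv,
                          ratio.clamp(1 - clip_eps, 1 + clip_eps) * adv)
    return -surrogate.mean()
```

On a 4-completion group with rewards [0.8, 0.5, 0.3, 0.7], this yields advantages of roughly [+1.17, −0.39, −1.43, +0.65]: above-average completions pushed up, below-average pushed down.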
9.3 What changed vs PPO¶
- No critic. Save ~50 % of memory and compute.
- Per-prompt baselining. Reduces variance compared to a global baseline; especially effective for verifiable tasks where the reward is binary or near-binary.
- Group size. G is typically 8–16. Memory cost: G completions in flight per step, but still cheaper than a critic.
9.4 When to reach for GRPO¶
- Reasoning tasks with verifiable rewards (math, code unit tests, formal verification). The rule-based reward is exact, no reward hacking, and group-relative baselining is extremely informative.
- When you cannot afford the critic memory.
- When PPO is unstable and DPO is insufficient (DPO is offline; GRPO is online, which matters for tasks where the policy is supposed to explore).
For chat alignment with subjective preferences, DPO is still simpler.
10. Reward model design¶
If you do go the PPO/RM route, the reward model is its own subsystem.
10.1 Architecture¶
The standard recipe: take the SFT model, replace the LM head with a linear scalar head, and train. This means:
- Same backbone as the policy (already pretrained on language; speaks the same dialect).
- One scalar output per
(x, y)pair (the reward).
In code: take the last-token hidden state of (x, y), project to a
scalar via Linear(d_model, 1).
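A sketch of that head in PyTorch (shapes and names are illustrative, not from any specific codebase); the backbone's hidden states come in, one scalar reward per sequence comes out:

```python
import torch
import torch.nn as nn

class RewardHead(nn.Module):
    """Scalar reward from the last non-padding token's hidden state."""

    def __init__(self, d_model: int):
        super().__init__()
        self.score = nn.Linear(d_model, 1)

    def forward(self, hidden_states, attention_mask):
        # hidden_states: [B, T, d_model]; attention_mask: [B, T] of 0/1.
        # Index each sequence's last real token, then project to a scalar.
        last = attention_mask.sum(dim=1) - 1                 # [B]
        idx = last.view(-1, 1, 1).expand(-1, 1, hidden_states.size(-1))
        pooled = hidden_states.gather(1, idx).squeeze(1)     # [B, d_model]
        return self.score(pooled).squeeze(-1)                # [B]
```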
10.2 Training loss¶
Given preference pairs (x, y_w, y_l), train under Bradley-Terry:

L_RM(φ) = - E_{(x, y_w, y_l)} [ log σ( r_φ(x, y_w) - r_φ(x, y_l) ) ]

This is the same NLL of (BT) with r = r_φ. Notice the symmetry with the
DPO loss-in DPO, the implicit reward r̂_θ is a function of the policy;
here, r_φ is a separate model.
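The loss itself is a one-liner given scalar reward batches `r_w`, `r_l` for the chosen and rejected responses:

```python
import torch
import torch.nn.functional as F

def rm_loss(r_w: torch.Tensor, r_l: torch.Tensor) -> torch.Tensor:
    # Bradley-Terry NLL: -log sigmoid(r_phi(x, y_w) - r_phi(x, y_l)).
    # logsigmoid is the numerically stable form of log(sigmoid(.)).
    return -F.logsigmoid(r_w - r_l).mean()
```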
10.3 Reward hacking¶
The model trained against the RM is incentivized to maximize r_φ, not
true human preference. The RM is a fitted approximation with blind
spots. When the policy finds these blind spots, you get:
- Length bias. Human raters mildly but consistently prefer longer answers, so RMs trained on their pairs learn "longer is better." The policy generates longer and longer responses with no quality gain. The classic RLHF failure.
- Sycophancy. RM rewards agreement with the user; policy becomes a yes-man.
- Token-level exploits. RM has a quirk on certain tokens; policy finds it.
Mitigations:
- KL constraint. Prevents drift from the SFT distribution where the RM was actually calibrated.
- Length normalization. Subtract a length term from the RM target.
- Multiple RMs. Average several RMs, or take the min, to reduce blind-spot exploitation.
- Reward overoptimization curves. Plot RM-score vs human-judged quality during training; stop when human quality plateaus or drops even as RM-score rises.
10.4 Process reward models (PRM)¶
For multi-step reasoning (math, coding), an outcome reward model (ORM) scores only the final answer, giving zero gradient through the chain of thought. Process reward models score every step:
- Train PRM on (prompt, partial CoT, step-is-correct?) data.
- During RL, give per-step reward (or accumulate stepwise rewards as a shaped final reward).
Used in OpenAI's "Let's Verify Step by Step" (Lightman et al., 2023) and implicitly by GRPO with verifiable rewards. Substantially better gradient signal for reasoning.
11. Preference data curation¶
Reward models and DPO live or die by preference data quality.
11.1 Sources¶
- Human raters (gold). Highest quality, highest cost. $0.50–$5 per pair depending on complexity. Domain experts for specialized fields.
- LLM-as-judge (cheap). Use a strong frontier model to rank pairs. Cheap, fast, biased-known issues include position bias (favoring the first option), length bias, self-bias (favoring its own family), and verbosity bias.
- Hybrid. Use LLM-as-judge for the bulk and humans for stratified audits and disagreement resolution. Common production pattern.
11.2 Volume¶
- 5 k–10 k pairs is a viable starting point for DPO if the pairs cover the target distribution.
- 30 k–50 k pairs is a strong fine-tune.
- 100 k+ approaches diminishing returns unless the underlying task is broad.
These are softer numbers than SFT volume; preference learning is more sample-efficient because each pair contains a comparison rather than just a positive example.
11.3 Quality control¶
- Inter-rater agreement. Multi-annotate ~10 % of the data and measure Cohen's κ. Below 0.4 is concerning; aim for 0.6+.
- Stratify by difficulty. Easy pairs ("clearly better" vs "clearly worse") teach little. Hard pairs (close calls between two reasonable answers) drive most learning. Bias toward harder pairs.
- Ensure the chosen is actually good. A pair where both options are bad teaches the model to be slightly less bad. Discard such pairs during curation.
- De-duplicate prompts. Many similar prompts dominate the loss and bias the model.
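Cohen's κ for the inter-rater check is simple enough to compute inline. A minimal two-rater sketch (not tied to any annotation tool):

```python
def cohens_kappa(rater_a: list, rater_b: list) -> float:
    # Observed agreement minus chance agreement, normalized.
    n = len(rater_a)
    p_o = sum(x == y for x, y in zip(rater_a, rater_b)) / n
    labels = set(rater_a) | set(rater_b)
    p_e = sum((rater_a.count(l) / n) * (rater_b.count(l) / n)
              for l in labels)
    return (p_o - p_e) / (1 - p_e)
```

Two raters who agree on 3 of 4 binary choices with balanced label marginals land at κ = 0.5: above the 0.4 concern line, below the 0.6 target.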
11.4 Self-reward and Constitutional AI¶
If human raters are unaffordable, use the model itself or a sibling model as the rater, guided by a written constitution (a list of principles like "be honest," "don't help with weapons," "ask for clarification when ambiguous"). Two stages:
- Generate two completions per prompt.
- Have the model judge which better follows the constitution; that becomes your preference pair.
This is the scaffolding behind RLAIF (§12).
12. Constitutional AI / RLAIF¶
Constitutional AI (Bai et al., 2022) replaces the human rater with an AI rater bound by an explicit set of written principles-the constitution. It scales preference data collection from "as many humans as you can hire" to "as many GPUs as you can spin up."
12.1 Two stages¶
SL-CAI (Supervised Learning). For each prompt, the model produces an
initial response, then critiques its own response against the
constitution, then revises. The SFT data is (prompt, revised_response).
This bakes in the constitution's behavior at the SFT stage.
RL-CAI (Reinforcement Learning). The model produces two responses to each prompt. Another model-also bound by the constitution-judges which is better. The resulting preference pairs train an RM (or feed DPO/GRPO directly).
12.2 Why it scales¶
Humans are the bottleneck in RLHF data. RLAIF removes that bottleneck:
- Cost. GPU-hours instead of human-hours, often 10–100× cheaper per pair.
- Throughput. Millions of pairs in days, not months.
- Consistency. A constitution is reproducible; humans are not.
12.3 What it loses¶
- Bias inheritance. The judge model has its own biases, and they propagate.
- Constitutional drift. A long constitution is not always followed precisely; principles get weighted unevenly.
- Worse on out-of-distribution preferences. Where humans use common sense the model lacks, RLAIF fails.
In practice, hybrid pipelines-RLAIF for breadth, human-rater preference data for hot-button axes (safety, factuality, sensitive domains)-outperform either pure approach.
13. Frontier-scale fine-tuning¶
The curriculum's Sequence 15 leaves a gap: how do you actually fine-tune
a 70B (or 405B, or 671B) model? The answer involves the distributed-
training stack from AI_SYSTEMS_PLAN/DEEP_DIVES/06 and the numerics from
/11. Here is the integration view.
13.1 The model parallelism axis¶
A 70B in bf16 is 140 GB of weights alone, plus activations, plus optimizer state. It does not fit on a single 80 GB GPU even for inference. Training on a single node (8×80 GB = 640 GB) requires:
- FSDP (Fully Sharded Data Parallel, ZeRO-3 equivalent). Shards parameters, gradients, and optimizer state across data-parallel ranks. Each rank holds 1/N of the parameters at rest and gathers full layers on demand for forward/backward. Cross-ref /06.
- Activation checkpointing. Discard activations during forward; recompute during backward. Roughly 30 % more compute (one extra forward pass) for ~5× less activation memory. Without it, 70B SFT does not fit anywhere reasonable.
- Mixed precision. Bf16 for parameters and activations; fp32 for optimizer master weights and accumulations. FP8 if your GPUs support it (H100, B200) - see /11 for the scaling-factor and stochastic-rounding considerations.
13.2 Practical configuration: 70B FT on 8×H100¶
- Method. QLoRA or LoRA (rarely full FT at this scale on a single node).
- LoRA r. 64 (more than 7B because the model has more capacity to exploit).
- Sharding. FSDP with full-shard, mixed precision bf16.
- Activation checkpointing. On.
- Per-device batch. 1, with gradient accumulation to effective batch 64–128.
- Sequence length. Match production (4 k–8 k typical).
- LR. 1e-4 for LoRA at this scale.
- Throughput expectation. Roughly 15–40 k tokens/sec across 8×H100 with activation checkpointing.
These numbers are plausible: a single 8×H100 node can fine-tune a 70B in a few hours to days, depending on data volume.
13.3 Full FT at 70B+: 32×H100¶
Full fine-tuning at 70B requires roughly 32×H100 with multi-node FSDP, or
8 nodes × 8 GPUs with carefully tuned communication. The bottleneck is
the all-gather/reduce-scatter pattern of FSDP across the cluster
interconnect (NVLink within node, InfiniBand between). See /06 for the
full breakdown.
For 405B+, you are squarely in tensor + pipeline + data + sequence
parallelism territory. The tooling is Megatron-LM, NeMo, or DeepSpeed at
scale. Cross-ref /06.
13.4 The takeaway¶
For most teams, never full-FT a 70B+. Use QLoRA. The quality gap is small (§14), the cost gap is large, and the operational complexity gap is enormous.
14. Full FT vs LoRA-the decision¶
14.1 Quality¶
Full FT consistently beats LoRA in head-to-head comparisons, but the gap is small: usually 0.5–2 percentage points on benchmark suites. For most tasks, this is below the noise floor of evaluation.
LoRA's quality scales with r. The curve flattens around r=64 for most
tasks; pushing to r=128 rarely helps. The right move is to target more
modules (Q, K, V, O, MLP gate/up/down) rather than push r higher.
14.2 Cost¶
LoRA is 10–100× cheaper in training compute. The savings come from:
- Smaller optimizer state. 100× fewer trainable parameters → 100× smaller Adam state.
- Smaller gradients. Same factor.
- No need to checkpoint full weights. Only the LoRA tensors.
QLoRA is another 2–4× on top of that for memory.
14.3 Adapter portability¶
A LoRA adapter is tens of MB on disk. You can email it. You can store 1 000 of them in a directory. You can deploy multi-tenant fine-tunes serving 100+ customers from one base model (§4.8).
A full fine-tune is the size of the model-tens of GB. Each one is its own deployment.
14.4 When full FT actually wins¶
- Very large data (>100 k–1 M examples). LoRA's low-rank constraint starts to bite when there's enough signal to overflow the rank-r bottleneck.
- Substantial behavior shift. New language, new modality, new output structure - these are big distribution moves; LoRA can be insufficient.
- Continued pretraining (not really fine-tuning). Domain-adaptive pretraining on hundreds of millions of tokens of new corpus benefits from full updates.
For everything else-task-specific FT, persona FT, format/tone FT, preference learning-LoRA is the answer.
14.5 Decision matrix¶
| Situation | Recommended |
|---|---|
| <10 k examples, behavior tweak | LoRA |
| 10 k–100 k examples, single domain | LoRA (r=32–64) |
| Multi-tenant: 1 base + many customers | LoRA (multi-LoRA serving) |
| 100 k–1 M examples, broad shift | Full FT or large QLoRA |
| Continued pretraining on new corpus | Full FT |
| Low VRAM (single 24–48 GB GPU), large base | QLoRA |
| Frontier scale (70B+) | QLoRA, almost always |
| Preference alignment | DPO with LoRA |
| Reasoning RL with verifiable rewards | GRPO |
15. Evaluation before and after FT¶
Eval is non-negotiable. It is also the most-skipped step in fine-tuning
projects. The headline failure mode of fine-tuning is "the new model
behaves better on the dev set and worse in production"-which is exactly
what bad eval discipline produces. Cross-ref /08.
15.1 The four eval surfaces¶
- In-distribution held-out test set. Build before you train. Hold out 5–20 % of your fine-tune data, never let it touch a training batch. Report metrics here.
- Out-of-distribution eval. General-capability suites: MMLU, GSM8K, HumanEval, plus your own out-of-domain prompts. The question this answers: did we lose general capability? If MMLU drops 5 points, you have a forgetting problem.
- Production traffic eval. Ship to a small fraction of users (1 %), compare aggregate metrics (resolution rate, escalation rate, CSAT) against the previous model. The only eval that matters in the end.
- Calibration. Did the model become overconfident? Test on prompts where the right answer is "I don't know" or "I need clarification." Fine-tuned models often lose calibration because the FT data contains few abstentions.
15.2 The pre-FT baseline¶
Before training, evaluate the base model + your prompt + your RAG on the same eval set. This is your floor. Any FT that doesn't beat the floor by a meaningful margin (≥ 5 % on the metric you care about) is not worth shipping.
This is the most common mistake: teams skip the baseline, train, observe "the model works," ship, and discover later that the prompt-only baseline was already there.
15.3 Eval gates in CI¶
- Run eval after every fine-tune.
- Compare to the production model.
- Block deploy on regression > X % on any axis (in-domain, OOD, safety, calibration).
This is hygiene. It's also rare in practice. Building it once pays for itself ten times over.
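A minimal gate, as a sketch (the axis names and the 2-point threshold are placeholders for your own):

```python
def passes_gates(candidate: dict, production: dict,
                 max_regression: float = 0.02) -> bool:
    # Block deploy if any eval axis regresses more than max_regression
    # (absolute points) vs the production model. Axes might include
    # in-domain accuracy, OOD suites, safety, calibration.
    return all(candidate[axis] >= production[axis] - max_regression
               for axis in production)
```

Wire this into CI so a fine-tune that wins in-domain but drops an OOD or safety axis never ships by accident.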
15.4 Things that go wrong in eval¶
- Train-eval contamination. The eval set leaks into training data through some path you didn't notice. Always hash and check.
- Metric overfitting. Optimizing for eval metric without measuring qualities the metric doesn't capture (e.g., toxicity, hallucination).
- Ignoring OOD. "Our customer-support metric is up 12 %!" while MMLU drops 8 points. The model is now narrower; in production it hits OOD prompts and degrades.
16. The end-to-end FT workflow¶
A practical sequence that compounds rather than thrashes:
Step 1-Define the eval set first¶
Before any training. Hold out 200–500 examples. Define the metrics. Run
the base model + prompt + RAG against it; record the baseline numbers.
Cross-ref /08.
Step 2-Baseline: prompt + RAG¶
Try to solve the problem without FT. Iterate prompt and retrieval until you've extracted what's reasonable. Record the result. This is your baseline.
Step 3-SFT-LoRA on small data¶
Curate ~1 000 high-quality (prompt, completion) pairs. Run a LoRA SFT
(r=16, 1–3 epochs, LR 2e-4). Evaluate. If the lift over baseline is
sufficient and OOD eval is intact, ship.
Step 4-Scale data or escalate¶
If the lift is insufficient:
- First: scale data to 10 k examples. Most gains come from data
volume, not method change.
- Then: increase r and target modules.
- Last resort: full FT.
Step 5-Add preference learning (DPO) if behavior alignment matters¶
Once SFT is good, collect 5 k–30 k preference pairs (chosen, rejected). Run LoRA-DPO at LR 5e-7, β=0.1, 1 epoch. Evaluate on:
- In-domain held-out preference accuracy.
- OOD capability suites.
- Calibration.
Step 6-Ship behind eval gates¶
Deploy with feature flag, A/B against the previous model. Watch production metrics for at least one full traffic cycle (week or two). Roll forward only if metrics improve and don't regress on safety.
Step 7-Monitor for drift¶
Re-evaluate periodically (monthly). Fine-tuned models can degrade as production traffic distribution shifts. When eval drops, retrain on fresh data-don't try to patch.
17. Practical exercises¶
Exercise 1-LoRA trainable parameters for Llama-7B at r=16, Q+V¶
Llama-7B parameters: 32 layers, hidden dim d = 4096, attention
projection matrices are 4096 × 4096.
Per matrix at r=16: r · (d_in + d_out) = 16 · (4096 + 4096) = 131 072 ≈ 131 K.
Q and V per layer: 2 · 131 K = 262 K. Across 32 layers:
32 · 262 K ≈ 8.4 M trainable parameters.
Total Llama-7B parameters: ~6.7 B. Trainable fraction:
8.4 M / 6.7 B ≈ 0.13 %.
If you target Q, K, V, O instead of just Q+V: 4 matrices × 32 layers × 131 K ≈ 16.8 M trainable, still 0.25 %.
If you target Q, K, V, O plus the three MLP matrices (gate, up, down, each 4096 × 11008 for Llama-7B): per MLP matrix at r=16, 16 · (4096 + 11008) = 241 664 ≈ 242 K. Three of them per layer ≈ 725 K. Across 32 layers: ~23.2 M. Total with attention: ~40 M. Still under 0.6 % of the model.
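The exact integer counts (where the prose rounds) can be reproduced with a few lines; `lora_params` is an illustrative helper:

```python
def lora_params(r: int, shapes: list) -> int:
    # shapes: list of (d_in, d_out) per adapted matrix in ONE layer.
    # Each LoRA pair adds r*d_in (A) + r*d_out (B) parameters.
    return sum(r * (d_in + d_out) for d_in, d_out in shapes)

layers, d, d_ff = 32, 4096, 11008
qv = lora_params(16, [(d, d)] * 2) * layers        # Q+V only
attn = lora_params(16, [(d, d)] * 4) * layers      # Q, K, V, O
mlp = lora_params(16, [(d, d_ff)] * 3) * layers    # gate, up, down
print(qv, attn, attn + mlp)
```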
Exercise 2-Derive the DPO loss¶
Given:
1. The KL-regularized RL objective with optimal policy
π*(y|x) = (1/Z(x)) · π_SFT(y|x) · exp(r(x,y)/β). (See §6.3 for the
derivation.)
2. The Bradley-Terry preference model
P(y_w ≻ y_l | x) = σ(r(x, y_w) - r(x, y_l)).
Step A-invert (1) to express r in terms of π* and π_SFT:
log π*(y|x) = log π_SFT(y|x) + r(x,y)/β - log Z(x)
r(x, y) = β · log(π*(y|x)/π_SFT(y|x)) + β · log Z(x)
Step B-substitute into the BT difference:

r(x, y_w) - r(x, y_l) = β · log(π*(y_w|x)/π_SFT(y_w|x)) - β · log(π*(y_l|x)/π_SFT(y_l|x))

The β·log Z(x) terms cancel because they are independent of y.
Step C-replace π* with π_θ (trainable) and π_SFT with π_ref (frozen):

P_θ(y_w ≻ y_l | x) = σ( β · log(π_θ(y_w|x)/π_ref(y_w|x)) - β · log(π_θ(y_l|x)/π_ref(y_l|x)) )
Step D-take negative log-likelihood:

L_DPO(θ) = - E_{(x, y_w, y_l)} [ log σ( β · log(π_θ(y_w|x)/π_ref(y_w|x)) - β · log(π_θ(y_l|x)/π_ref(y_l|x)) ) ]

Done.
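For a single pair, the final loss reduces to a scalar expression; a sketch with precomputed sequence log-probs (names are illustrative):

```python
import math

def dpo_pair_loss(pi_w: float, pi_l: float,
                  ref_w: float, ref_l: float, beta: float = 0.1) -> float:
    # pi_* / ref_*: sequence log-probs of chosen (w) and rejected (l)
    # under the policy and the frozen reference.
    logits = beta * ((pi_w - ref_w) - (pi_l - ref_l))
    return math.log(1 + math.exp(-logits))  # -log sigmoid(logits)
```

At initialization (policy = reference) the logits are zero and the loss is log 2 ≈ 0.693; increasing the chosen margin relative to the reference drives it down.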
Exercise 3-QLoRA memory budget for 70B on a 48 GB GPU at r=64¶
Inputs:
- Base model: 70 B parameters.
- Quantization: NF4 (4 bits per weight), with double quantization.
- LoRA: r=64, applied to all linear modules.
- Per-device batch: 1, sequence length 2 048, activation checkpointing on.
Computation:
- Quantized base model.
  - Naive: 70 · 10⁹ · 4 bits / 8 = 35 · 10⁹ bytes = 35 GB.
  - Quantization scales with double-quant: ~0.5 GB.
  - Subtotal: ~35.5 GB.
- LoRA adapters.
  - 70B has roughly 80 layers, hidden dim ~8 192, MLP intermediate ~28 672. Linear modules per layer: Q (8192×8192), K (8192×1024 - Llama-3 70B uses GQA with 8 KV heads of dim 128, so K and V are 8192×1024), V (8192×1024), O (8192×8192), gate (8192×28672), up (8192×28672), down (28672×8192).
  - Trainable per matrix at r=64:
    - Q: 64·(8192+8192) = 1.05 M
    - K: 64·(8192+1024) = 0.59 M
    - V: 64·(8192+1024) = 0.59 M
    - O: 64·(8192+8192) = 1.05 M
    - gate: 64·(8192+28672) = 2.36 M
    - up: 64·(8192+28672) = 2.36 M
    - down: 64·(28672+8192) = 2.36 M
  - Total per layer: ~10.4 M.
  - Across 80 layers: ~830 M trainable parameters.
  - In bf16: 830 M · 2 = 1.66 GB.
  - Subtotal: ~1.7 GB.
- Optimizer state (paged AdamW 8-bit).
  - 8-bit Adam stores moments in 8-bit: ~1 byte per moment × 2 moments × 830 M ≈ 1.7 GB.
  - With paging, peaks may spill to CPU RAM transparently.
  - Subtotal: ~1.7 GB resident, more in CPU.
- Activations + gradients (with activation checkpointing).
  - Highly model- and seqlen-dependent. For 70B, batch 1, seq 2 048, bf16, with AC: roughly 6–10 GB resident peak.
  - Subtotal: ~8 GB.
- Slack. KV state during forward generation in eval, kernel workspaces, fragmentation: ~2 GB.
Sum: 35.5 + 1.7 + 1.7 + 8 + 2 ≈ 49 GB.
Conclusion: 70B QLoRA at r=64 overshoots a 48 GB GPU by about a gigabyte but fits comfortably on an 80 GB GPU. To make 48 GB work in practice, drop to r=32, reduce sequence length to 1 024, or accept paging-induced slowdowns.
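The budget arithmetic can be checked with a quick script (the 830 M adapter-parameter count comes from the layer breakdown above):

```python
GB = 1e9
base_weights = 70e9 * 0.5 / GB      # NF4: 4 bits = 0.5 bytes per weight
double_quant = 0.5                  # quantization scales/zeros overhead
lora_bf16 = 830e6 * 2 / GB          # ~830 M adapter params in bf16
adam_8bit = 830e6 * 2 * 1 / GB      # 2 moments x 1 byte each
activations = 8.0                   # batch 1, seq 2048, AC on (estimate)
slack = 2.0                         # workspaces, fragmentation
total = (base_weights + double_quant + lora_bf16
         + adam_8bit + activations + slack)
print(f"{total:.1f} GB")            # ≈ 48.8 GB
```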
Exercise 4-Preference data collection guidelines (5-page spec)¶
Brief outline; flesh into a real spec for your team.
§1. Goals. Collect N preference pairs to fine-tune model M on
behavior axis B. Define B precisely (e.g., "tone consistent with
brand voice while preserving accuracy").
§2. Sources and stratification.
- 60 % from production traffic (real user prompts).
- 30 % from adversarial / edge-case prompts authored by the team.
- 10 % from synthetic prompts generated by an LLM with seed topics.
- Stratify by: domain, difficulty, length, sensitive content.
§3. Generation protocol.
- For each prompt, sample two completions from M.
- Use temperature 0.7 to ensure diversity but maintain quality.
- Discard prompts where both responses are clearly bad.
- Discard prompts where both responses are nearly identical.
§4. Rater pool.
- 6–10 raters minimum.
- 2 senior raters as gold standard.
- Onboarding: 100 calibration pairs, must achieve κ ≥ 0.6 vs gold.
- Re-calibration weekly with 20 fresh gold pairs.
§5. Annotation interface.
- Show prompt + two completions in random order.
- Rater selects "A better," "B better," "equal," "both bad."
- Optional comment field for hard cases.
- Discard "equal" and "both bad" from training data.
§6. Quality controls.
- 10 % of pairs are gold-standard, double-annotated by senior raters. Cohen's κ measured weekly per rater; raters with κ < 0.5 are retrained or removed.
- 5 % of pairs are duplicates surfaced after a 1-week gap; raters who flip on duplicates are flagged.
- LLM-as-judge runs in parallel for triage; high-disagreement pairs surface to senior review.
§7. Stratification check.
- Audit final dataset distribution by domain / difficulty bins.
- Reject and resample if any bin is <5 % or >40 % of total.
§8. Privacy and safety.
- Strip PII before raters see prompts.
- Skip pairs where both responses are policy-violating.
- Document and version the rater guidelines.
§9. Versioning and provenance.
- Each pair carries: source prompt id, model checkpoint, sampling config, rater id, timestamp, agreement scores.
- Dataset is versioned; every fine-tune cites the dataset hash.
Exercise 5-Eval matrix for a customer-support fine-tune¶
Eval surfaces, with target metrics:
| Surface | Test set | Metrics | Pass criteria |
|---|---|---|---|
| In-domain | 500 held-out support tickets | Resolution rate, factual accuracy, brand voice score | ≥ baseline + 8 % |
| Out-of-domain | MMLU 1k, GSM8K 200, HumanEval 164 | Standard scores | ≤ 2 pp regression vs base instruct |
| Refusals | 200 prompts known to require refusal | Refusal rate, refusal quality (LLM judge) | ≥ 95 % refusal rate |
| Hallucinations | 200 prompts with known answers | Hallucination rate (human-judged) | ≤ 3 % |
| Calibration | 100 ambiguous prompts | "I don't know" rate, expected calibration error | ECE ≤ baseline |
| Adversarial / safety | 300 jailbreak-style prompts | Safety violation rate | ≤ 0.5 % |
| Long-context | 50 long-thread tickets | Resolution rate | ≥ baseline |
| Multi-turn | 100 multi-turn conversations | Turn-level coherence (LLM judge) | ≥ 4.0 / 5 |
| Production A/B | 1 % live traffic | CSAT, escalation rate, AHT | CSAT ≥ control, escalation ≤ control |
Run all but the A/B in CI. A/B runs only after CI passes.
Exercise 6-GRPO step on a 4-completion group, rewards [0.8, 0.5, 0.3, 0.7]¶
Group rewards: r = [0.8, 0.5, 0.3, 0.7].
Mean: (0.8 + 0.5 + 0.3 + 0.7) / 4 = 2.3 / 4 = 0.575.
Variance:
((0.8-0.575)² + (0.5-0.575)² + (0.3-0.575)² + (0.7-0.575)²) / 4
= (0.0506 + 0.0056 + 0.0756 + 0.0156) / 4
= 0.1475 / 4
= 0.0369
Std: √0.0369 ≈ 0.192.
Group-relative advantages:
- A_1 = (0.8 - 0.575) / 0.192 = 0.225 / 0.192 ≈ +1.17
- A_2 = (0.5 - 0.575) / 0.192 = -0.075 / 0.192 ≈ -0.39
- A_3 = (0.3 - 0.575) / 0.192 = -0.275 / 0.192 ≈ -1.43
- A_4 = (0.7 - 0.575) / 0.192 = 0.125 / 0.192 ≈ +0.65
Sanity: positive advantages for above-average completions (1, 4), negative for below (2, 3). Sum of advantages is zero by construction (mean-centered). Magnitudes scaled by the within-group spread.
Per-completion update (sketch, ignoring KL):
- Completion 1 gets pushed up with strength 1.17.
- Completion 2 gets pushed down with strength 0.39.
- Completion 3 gets pushed down with strength 1.43.
- Completion 4 gets pushed up with strength 0.65.
The PPO clip then bounds each per-token ratio update. The KL term
β · KL(π_θ || π_ref) is added to the loss separately to keep the
policy close to the reference.
Appendix-End-to-end pseudocode skeleton¶
The following is the complete shape of an SFT → DPO pipeline using TRL on a single 8×H100 node. It is meant to be readable, not runnable; specific imports and config flags will drift across TRL versions.
# ---- 0. shared config ----
BASE = "meta-llama/Llama-3.1-8B-Instruct"
SFT_OUT = "out/sft"
DPO_OUT = "out/dpo"
# ---- 1. SFT-LoRA ----
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
tok = AutoTokenizer.from_pretrained(BASE)
model = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype="bfloat16")
model = get_peft_model(model, LoraConfig(
r=16, lora_alpha=32, lora_dropout=0.05,
target_modules=["q_proj","k_proj","v_proj","o_proj",
"gate_proj","up_proj","down_proj"],
task_type="CAUSAL_LM"))
sft_data = load_dataset("json", data_files="data/sft.jsonl", split="train")
SFTTrainer(
model=model, tokenizer=tok, train_dataset=sft_data,
args=SFTConfig(
output_dir=SFT_OUT, num_train_epochs=2,
per_device_train_batch_size=4, gradient_accumulation_steps=8,
learning_rate=2e-4, warmup_ratio=0.05,
bf16=True, packing=True, max_seq_length=4096,
),
).train()
# ---- 2. eval after SFT ----
# (run held-out eval, OOD MMLU/GSM8K, calibration; gate on metrics)
# ---- 3. DPO-LoRA ----
from trl import DPOTrainer, DPOConfig
# Load the SFT-LoRA adapter onto the base as a PEFT model; with
# ref_model=None, TRL uses the base (adapters disabled) as the reference.
from peft import AutoPeftModelForCausalLM
policy = AutoPeftModelForCausalLM.from_pretrained(SFT_OUT,
                                                  torch_dtype="bfloat16")
prefs = load_dataset("json", data_files="data/prefs.jsonl", split="train")
DPOTrainer(
model=policy, ref_model=None, tokenizer=tok,
train_dataset=prefs,
args=DPOConfig(
output_dir=DPO_OUT, num_train_epochs=1,
per_device_train_batch_size=2, gradient_accumulation_steps=8,
learning_rate=5e-7, beta=0.1,
warmup_ratio=0.05, bf16=True,
),
).train()
# ---- 4. eval after DPO ----
# Repeat held-out eval, OOD eval, preference accuracy, calibration.
# Ship behind A/B if all gates pass.
The pipeline embodies the chapter's whole argument: a small LoRA SFT on curated demonstrations, eval gates, then a DPO pass on preference pairs, then more eval, then a careful production rollout. No PPO. No critic. No reward model. The result is competitive with the full classical RLHF stack at a fraction of the operational cost-and that is why the field has converged on this recipe as the default.
Citations and further reading¶
- Hu, E. J. et al. (2021). LoRA: Low-Rank Adaptation of Large Language Models. arXiv:2106.09685.
- Dettmers, T. et al. (2023). QLoRA: Efficient Finetuning of Quantized LLMs. arXiv:2305.14314.
- Rafailov, R. et al. (2023). Direct Preference Optimization: Your Language Model is Secretly a Reward Model. arXiv:2305.18290.
- Shao, Z. et al. (2024). DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models (introduces GRPO). arXiv:2402.03300.
- Bai, Y. et al. (2022). Constitutional AI: Harmlessness from AI Feedback. arXiv:2212.08073.
- Bradley, R. A. and Terry, M. E. (1952). Rank Analysis of Incomplete Block Designs: I. The Method of Paired Comparisons. Biometrika.
- Schulman, J. et al. (2017). Proximal Policy Optimization Algorithms. arXiv:1707.06347.
- Lightman, H. et al. (2023). Let's Verify Step by Step. arXiv:2305.20050.
- Ouyang, L. et al. (2022). Training language models to follow instructions with human feedback (InstructGPT). arXiv:2203.02155.
Cross-references inside this curriculum:
- Distributed training (FSDP, ZeRO, tensor/pipeline parallelism):
AI_SYSTEMS_PLAN/DEEP_DIVES/06. - Mixed precision, FP8, numerics:
AI_SYSTEMS_PLAN/DEEP_DIVES/11. - Eval discipline:
AI_SYSTEMS_PLAN/DEEP_DIVES/08.
End of Deep Dive 10.