
Deep Dive 10. Fine-Tuning: From SFT to RLHF

Prerequisites. Linear algebra, calculus through gradients, basic probability, a working mental model of transformer training. Familiarity with cross-entropy losses and autoregressive language modeling is assumed.

Cross-references.

  • Distributed training (FSDP, ZeRO-3, tensor/pipeline parallelism): AI_SYSTEMS_PLAN/DEEP_DIVES/06.
  • Mixed-precision and FP8 numerics: AI_SYSTEMS_PLAN/DEEP_DIVES/11.
  • Eval discipline (held-out sets, calibration): AI_SYSTEMS_PLAN/DEEP_DIVES/08.

Scope. This is the document the curriculum's reading list points to. Sequence 15 names LoRA, QLoRA, DPO, and GRPO without deriving them. Here we derive them end-to-end and pair the math with the engineering. The DPO derivation in §8 is the centerpiece.


Table of contents

  1. The decision matrix: prompt vs RAG vs fine-tuning
  2. Supervised fine-tuning (SFT)
  3. Catastrophic forgetting
  4. LoRA: full derivation
  5. QLoRA: full derivation
  6. Preference learning: RLHF concepts
  7. PPO for RLHF (high-level)
  8. DPO: full derivation
  9. GRPO
  10. Reward model design
  11. Preference data curation
  12. Constitutional AI / RLAIF
  13. Frontier-scale fine-tuning
  14. Full FT vs LoRA: the decision
  15. Evaluation before and after FT
  16. The end-to-end FT workflow
  17. Practical exercises

1. The decision matrix: prompt vs RAG vs fine-tuning

A pretrained model has three knobs you can turn to bend its behavior toward your task. They are not interchangeable; they live on different axes and answer different problems.

1.1 The three knobs

Prompt engineering. You ship the model unchanged. At inference time you prepend instructions, examples, or a system message that elicits the desired behavior. The model's weights are static; its context changes.

  • What it gives you. Behavior, in-context. Few-shot patterns, persona, output format, refusal policy.
  • What it costs. Tokens per call. A 4 k-token system prompt on every request is a 4 k-token tax on every request, forever.

Retrieval-augmented generation (RAG). At inference time, retrieve relevant documents from an external index (vector DB, BM25, hybrid) and inject them into the context. The model's weights are static; its facts come from outside.

  • What it gives you. Knowledge access, freshness, citations, attribution.
  • What it costs. Retrieval latency, index construction, index maintenance, retrieval-quality engineering, plus the per-call token tax for the retrieved passages.

Fine-tuning (FT). You change the weights. New (prompt, completion) pairs or preference pairs gradient-update the model so the desired behavior is baked in rather than re-elicited every call.

  • What it gives you. Stable behavior across many prompts, with no per-call prompt overhead. New tone, new format conventions, new domain idiom.
  • What it costs. A one-time training run (from a few GPU-hours for a small LoRA to thousands of GPU-hours for full FT of a 70B), plus a per-model serving slot (or, with multi-LoRA, a shared base plus small adapters).

1.2 Rules of thumb

The clean decision rule is the stability × type matrix:

            | Stable                        | Volatile
Behavior    | Fine-tune                     | Prompt
Knowledge   | RAG (or embed in FT if tiny)  | RAG

  • Behavior = how the model responds: tone, format, style, safety posture, reasoning pattern, refusals, JSON conformance, persona.
  • Knowledge = facts the model relies on: product catalog, docs, policies, yesterday's customer ticket history.
  • Stable = changes monthly or slower; volatile = changes daily or faster.

Stable behavior across many domains → fine-tune. The behavior is the same regardless of which fact you're answering with; bake it in once and pay zero prompt overhead per call.

Stable knowledge that doesn't fit in the prompt → RAG. Even if your manual never changes, you cannot fit 100 MB of docs in a prompt. Retrieve.

Volatile knowledge → RAG, always. Re-training every time a doc changes is absurd. Re-index instead.

Volatile behavior → prompt. If your team is iterating on tone twice a week, you cannot ship a fine-tune twice a week. Adjust the prompt; promote to FT only when it stabilizes.

1.3 Cost comparison (order of magnitude)

Let T = tokens of prompt overhead, Q = queries per day, c_in = cost per input token.

  • Prompt cost / day: T · Q · c_in. Linear in queries, forever.
  • RAG cost / day: (T + T_retrieved) · Q · c_in + index_ops. Slightly worse than prompt because retrieved chunks are extra context.
  • FT cost: C_train (one-time), plus Q · c_in per day at the base token count (no overhead). Training compute amortizes; per call you pay only for the actual question.

The tipping point: if T is 2–4 k tokens and traffic is non-trivial, the cumulative prompt tax exceeds the one-time FT cost within weeks. This is how the math stops being abstract.
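A back-of-the-envelope version of the tipping-point arithmetic, as a minimal Python sketch. Every number in it is hypothetical (a 3k-token overhead, 100k queries/day, $1 per million input tokens, a $2,000 training run), and it deliberately ignores serving-side differences such as a dedicated slot for the fine-tuned model.

```python
def break_even_days(prompt_overhead_tokens: int, queries_per_day: int,
                    cost_per_input_token: float, train_cost: float) -> float:
    """Days until the cumulative prompt tax T * Q * c_in equals a one-time C_train."""
    daily_tax = prompt_overhead_tokens * queries_per_day * cost_per_input_token
    return train_cost / daily_tax


# Hypothetical inputs: 3k-token system prompt, 100k queries/day,
# $1 per million input tokens, $2,000 all-in training run.
days = break_even_days(3_000, 100_000, 1e-6, 2_000.0)
print(f"break-even after {days:.1f} days")
```

Under these assumptions the prompt tax pays for the fine-tune in about a week; halving the traffic doubles the break-even time.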

1.4 What you usually combine

In production you almost never pick one. The default stack is:

  1. Base model (pretrained + instruct-tuned by the vendor).
  2. Light fine-tune (LoRA) on stable behavior: tone, JSON shape, refusals.
  3. RAG for the knowledge that lives in your DB, docs, tickets.
  4. Prompt for the residual: the things you tweak weekly.

When this deep dive talks about "fine-tuning," it almost always means layer 2: a LoRA on top of an instruct-tuned base, sometimes followed by DPO.


2. Supervised fine-tuning (SFT)

SFT is the simplest form of fine-tuning. You have demonstrations: pairs (x, y) where x is a prompt and y is the response you want the model to produce. You train the model to maximize p(y | x).

2.1 Setup

The dataset is a list of (x, y) pairs. For chat models, x typically includes the system prompt and prior turns; y is the assistant's reply.

Tokenize each pair into a single sequence: [x_tokens] [y_tokens]. Build an attention mask covering the whole sequence so the model attends causally across the boundary.

2.2 The loss

Standard autoregressive cross-entropy:

L_SFT(θ) = - E_{(x,y) ~ D} [ Σ_{t=1..|y|}  log p_θ(y_t | x, y_{<t}) ]

The crucial detail: mask the prompt tokens out of the loss. Concretely, build a labels tensor that is a copy of the input ids, with all positions belonging to x set to -100 (the PyTorch ignore index). Hugging Face causal-LM models shift the labels by one internally, so position t is scored on predicting token t + 1. Only the y positions contribute to the loss.

2.2.1 Why mask the prompt

Two reasons:

  1. You don't want to teach the model to predict prompts. The user writes the prompt; modeling its distribution is wasted gradient. Worse: in chat, the prompt distribution is bizarre (system prompts, special tokens, role markers) and you don't want the model to bias toward producing it.
  2. Effective sample efficiency. Half your tokens being prompt is half your gradient being noise from the model's perspective.

In transformers, the canonical pattern is:

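input_ids = tokenizer(x + y, return_tensors="pt").input_ids[0]
labels = input_ids.clone()
# Mask the prompt. Assumes tokenizing x alone yields a prefix of tokenizing
# x + y, which can fail if a token merges across the x/y boundary.
labels[: len(tokenizer(x).input_ids)] = -100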

(With chat templates you mask everything except assistant turns.)
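To make the masking and the label shift concrete, here is a pure-Python sketch of the loss computation that Hugging Face causal-LM models perform internally. The list-of-lists `logprobs` input and the function name are illustrative assumptions, not a library API.

```python
import math

IGNORE_INDEX = -100  # PyTorch's default ignore_index for cross-entropy


def masked_causal_nll(logprobs, labels):
    """Mean negative log-likelihood over unmasked target tokens.

    logprobs[t][v] is log p(token v | tokens[: t + 1]), i.e. position t predicts
    token t + 1. labels is an unshifted copy of input_ids with prompt positions
    set to IGNORE_INDEX; the shift happens here.
    """
    total, count = 0.0, 0
    for t in range(len(labels) - 1):
        target = labels[t + 1]
        if target == IGNORE_INDEX:
            continue  # prompt token: contributes nothing to the loss
        total -= logprobs[t][target]
        count += 1
    return total / count


# Toy example: vocab of 3, sequence = 2 prompt tokens + 2 response tokens.
input_ids = [0, 1, 2, 1]
labels = [IGNORE_INDEX, IGNORE_INDEX, 2, 1]
uniform = [math.log(1 / 3)] * 3
logprobs = [uniform] * len(input_ids)
print(masked_causal_nll(logprobs, labels))  # ≈ 1.0986 (= log 3): only the 2 response tokens count
```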

2.3 Data quality dominates data quantity

Folklore but well-supported: 1 000 hand-curated examples beat 100 000 noisy ones. The reason is mechanical: SFT tightens the model's output distribution toward the training distribution, including its mistakes. A noisy dataset gives the model permission to be sloppy.

Practical rules:

  • Curate ruthlessly. Read every example yourself before training. If a human wouldn't be proud to ship that response, the model shouldn't either.
  • Diversity matters more than count. 1 000 examples covering 50 task archetypes beat 10 000 examples of the same five.
  • Prefer expert demonstrations. Subject-matter experts produce 3–5× cleaner data than crowdworkers, and SFT is bottlenecked by the quality ceiling of its data, not by volume.

2.4 Hyperparameters that matter

These are not magic. They follow from gradient stability and from the size of the update you're trying to make.

  • Learning rate.
      ◦ Full FT: 1e-5 to 5e-5 is the typical band. Larger models want smaller LRs; a 70B should sit at the low end. These bands are approximate.
      ◦ LoRA: 1e-4 to 5e-4, roughly ten times higher, because the update B · A starts at zero and involves far fewer trainable parameters, so it can absorb larger steps without destabilizing the model.
  • Epochs. 1–3. More epochs trade generalization for memorization. If your eval plateaus at epoch 2, stop. If it is still climbing at epoch 3, you probably need more data rather than more epochs.
  • Warmup. 5–10 % of total steps, linear from 0 to peak LR. Skipping warmup on a freshly-loaded pretrained model is a great way to corrupt early layers.
  • LR schedule. Cosine decay to ~10 % of peak LR is the default; linear decay is fine.
  • Batch size. Whatever fits in memory after gradient accumulation. The effective batch size matters more than the per-step batch size; aim for 64–256 sequences-equivalent.
  • Sequence length. Match production. Don't train on 512-token sequences and serve 4 k.
  • Sequence packing. Pack short examples into one sequence (with attention boundaries between them) to fill context efficiently. Without packing, a dataset of 200-token chat turns leaves 87.5 % of each 1 600-token row as padding.
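A minimal sketch of greedy first-fit-decreasing packing, working on example lengths only. Real trainers additionally keep per-example position ids or attention boundaries so that packed examples cannot attend to each other; this sketch does not model that. The function names are hypothetical.

```python
def pack_first_fit(lengths, max_len):
    """Greedy first-fit-decreasing packing of example lengths into rows of max_len.

    Returns a list of rows, each a list of example lengths. Examples longer than
    max_len are rejected rather than truncated.
    """
    rows, used = [], []
    for n in sorted(lengths, reverse=True):
        if n > max_len:
            raise ValueError(f"example of {n} tokens exceeds max_len={max_len}")
        for i, u in enumerate(used):
            if u + n <= max_len:      # first existing row with room
                rows[i].append(n)
                used[i] += n
                break
        else:                         # no row had room: open a new one
            rows.append([n])
            used.append(n)
    return rows


def utilization(rows, max_len):
    """Fraction of row slots that hold real tokens."""
    return sum(map(sum, rows)) / (len(rows) * max_len)


lengths = [200] * 64                  # 64 short chat turns
unpacked = [[n] for n in lengths]
packed = pack_first_fit(lengths, 1600)
print(len(unpacked), utilization(unpacked, 1600))  # 64 0.125
print(len(packed), utilization(packed, 1600))      # 8 1.0
```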

2.5 Minimal SFT pseudocode (TRL)

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig

model_name = "meta-llama/Llama-3.1-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="bfloat16")

ds = load_dataset("json", data_files="sft.jsonl", split="train")
# rows: {"messages": [{"role": "system", ...}, {"role": "user", ...},
#                     {"role": "assistant", ...}]}

cfg = SFTConfig(
    output_dir="out/sft",
    num_train_epochs=2,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,        # effective batch 32
    learning_rate=2e-5,
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    bf16=True,
    packing=True,
    max_seq_length=4096,   # may be named max_length in newer TRL releases
)

trainer = SFTTrainer(model=model,
                     processing_class=tokenizer,   # `tokenizer=` in older TRL releases
                     train_dataset=ds, args=cfg)
trainer.train()

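With packing=True, SFTTrainer does the sequence packing. Whether it masks non-assistant tokens for messages-format data depends on your TRL version and configuration (recent releases expose an assistant-only-loss option; older ones relied on a completion-only data collator), so check the docs for the version you run.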


3. Catastrophic forgetting

A model fine-tuned heavily on a narrow distribution forgets things it used to know. Not metaphorically: measurably. MMLU drops; GSM8K drops; safety behavior drifts. This is catastrophic forgetting, and it is the central risk of fine-tuning.

3.1 Why it happens

Pretraining packs an enormous amount of knowledge into the model's weights, distributed across the same parameters your fine-tune updates. SFT pushes the weights toward a small distribution (your data), and gradients that move the model toward your data are not, in general, gradients that preserve far-away knowledge. The model is solving a different problem now, and the old solution is collateral.

3.2 Mitigations, in order of strength

  1. Keep epochs low. 1 epoch with rich data forgets less than 5 epochs on the same data.
  2. Mix in instruction-tune data. During FT, blend in 5–20 % general instruction data (e.g., a slice of the original SFT distribution if you have it, or a public mix). This anchors the model.
  3. Use LoRA. A small rank-r perturbation constrains how far the weights can move (§4), so LoRA tends to forget less than full FT. The base remains intact and can be detached from the adapter at any time.
  4. KL-regularize toward the base. Add a term β · KL(π_θ || π_base) to the loss, so updates that move the output distribution far from the base are penalized. This is the same idea that makes RLHF stable (§6).
  5. Replay buffer. Cache representative examples from the base distribution and interleave them.
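Mitigation 4 on a toy next-token distribution, as a sketch. In a real trainer the KL is computed from full-vocabulary logits at every response position and summed or averaged; here a single three-outcome distribution stands in for that, and `beta=0.1` is an arbitrary illustrative value.

```python
import math


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as probability lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def kl_regularized_loss(task_loss, policy_probs, base_probs, beta):
    """Task loss plus beta * KL(pi_theta || pi_base)."""
    return task_loss + beta * kl_divergence(policy_probs, base_probs)


base = [0.7, 0.2, 0.1]       # base model's next-token distribution (toy)
close = [0.65, 0.25, 0.1]    # small drift: small penalty
far = [0.05, 0.05, 0.9]      # large drift: large penalty
for policy in (base, close, far):
    print(round(kl_regularized_loss(1.0, policy, base, beta=0.1), 4))
```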

3.3 Why FT on a raw pretrained base is dangerous; FT on an instruct model is safer

A pretrained-only base model has not learned to follow instructions, refuse, or behave safely. SFT on top of that base on a narrow domain produces a model that is good at your task and aggressively bad at everything else, including safety. SFT on top of a vendor instruct-tuned model preserves the instruction/safety scaffolding by construction (especially with LoRA), and your fine-tune adds a thin layer of domain behavior.

The lesson: unless you have a very specific reason, always FT on top of the instruct-tuned variant.


4. LoRA: full derivation

LoRA (Hu et al., 2021) is the dominant parameter-efficient fine-tuning method. The derivation is short and the consequences are large.

4.1 The empirical insight

When you fully fine-tune a pretrained model, the weight update Δ = W' − W is empirically low rank. That is, even though Δ is a d × k matrix with nominal rank min(d, k), almost all of its singular values are tiny. The fine-tuning update lives in a low-dimensional subspace of weight space.

This is intuitively reasonable: pretraining already filled in the heavy features; fine-tuning is steering, not relearning.

4.2 The decomposition

Parameterize the update as a low-rank product:

Δ = B · A,    where  B ∈ R^{d × r},  A ∈ R^{r × k},   r ≪ min(d, k)

Then B · A is at most rank r by construction. Replace W + Δ with W + B · A and freeze W. The trainable parameters are A and B only.

The forward pass becomes:

h = (W + B · A) · x  =  W·x + B·(A·x)

The right-hand side shows the implementation: compute A·x first (small, r × 1), then B · (A·x) (d × 1), and add to the original W · x. No new full-size matmuls.
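A quick NumPy check of the rewrite above, with toy sizes (assumed, not from the text): materializing W + B·A and using the two skinny matmuls give the same output, and the LoRA path adds only r·k + d·r multiply-adds instead of a second d·k one. B is set nonzero here purely so the check is non-trivial.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r = 64, 48, 4                      # toy sizes; real layers are ~4096 wide

W = rng.standard_normal((d, k))          # frozen pretrained weight
A = rng.standard_normal((r, k)) * 0.01   # trainable, r x k
B = rng.standard_normal((d, r)) * 0.01   # trainable, d x r (nonzero for this check)
x = rng.standard_normal(k)

# Reference: materialize the full d x k update, then multiply.
h_full = (W + B @ A) @ x

# LoRA path: two skinny matmuls, never forming B @ A.
h_lora = W @ x + B @ (A @ x)

assert np.allclose(h_full, h_lora)
print("extra multiply-adds, LoRA path:", r * k + d * r, "vs dense update:", d * k)
```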

4.3 Initialization

You need Δ = 0 at the start of training so the model output equals the pretrained model exactly. The standard choice:

  • A random: Gaussian in the original paper; PEFT's default is Kaiming-uniform.
  • B = 0.

Then B · A = 0 regardless of A's values, so Δ = 0 at step 0. Gradients still flow, but not symmetrically: B's gradient is (∂L/∂Δ) · Aᵀ, which is nonzero because A ≠ 0, while A's gradient is multiplied by B and therefore starts at zero.

Walk through the gradients explicitly. Let L be the loss and let g = ∂L/∂(BA) ∈ R^{d × k}. Then:

∂L/∂B = g · Aᵀ        # nonzero when A is nonzero
∂L/∂A = Bᵀ · g        # zero when B = 0

At step 0, B = 0, so ∂L/∂A = 0. Only B updates. After one step, B ≠ 0, and A starts updating too. So the initialization is deliberately asymmetric: only B moves on the first step, and A receives gradient from the second step onward. In practice this one-step delay is harmless.

(Some implementations swap the convention, A = 0 and B random, which mirrors the situation: A moves first and B follows. Either works; just match the library's docs.)
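The gradient asymmetry can be checked numerically. This NumPy sketch assumes a toy squared-error loss L = ½‖(W + BA)x − t‖² so the gradients are easy to write by hand; the sizes and learning rate are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r, lr = 16, 12, 4, 0.1

W = rng.standard_normal((d, k))                  # frozen pretrained weight
A = rng.standard_normal((r, k)) / np.sqrt(k)     # random init
B = np.zeros((d, r))                             # zero init => B @ A = 0
x = rng.standard_normal(k)
target = rng.standard_normal(d)


def grads(A, B):
    """dL/dB and dL/dA for L = 0.5 * ||(W + B A) x - target||^2."""
    h = (W + B @ A) @ x
    g = np.outer(h - target, x)                  # dL/dDelta, shape (d, k)
    return g @ A.T, B.T @ g                      # shapes (d, r) and (r, k)


assert np.allclose((W + B @ A) @ x, W @ x)      # step 0: output equals the base model

dB0, dA0 = grads(A, B)                           # dA0 is exactly zero because B == 0
B = B - lr * dB0                                 # the first SGD step moves only B
dB1, dA1 = grads(A, B)                           # now B != 0, so A receives gradient

print("step 0:", np.linalg.norm(dB0), np.linalg.norm(dA0))
print("step 1:", np.linalg.norm(dB1), np.linalg.norm(dA1))
```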

4.4 Scaling: the α/r trick

LoRA introduces a scalar:

Δ = (α / r) · B · A

The scaling factor α/r decouples learning rate from rank. Without it, doubling r would double the magnitude of B · A's expected output (more basis vectors summed), and you'd have to halve the LR to compensate.

Common practice: fix α and sweep r. Typical: α = 16 or α = 32, r ∈ {8, 16, 32, 64}. With α/r scaling, the effective LR stays sane across rank changes.

4.5 Where to apply LoRA

The original paper applied LoRA only to the query and value projections of attention (W_q, W_v). The reasoning: under a fixed budget of trainable parameters, adapting W_q and W_v gave the best results in the paper's ablations.

Modern practice broadens this:

  • Q, K, V, O (all four attention projections). Adds parameters but gives a noticeable quality bump.
  • MLP weights (W_up, W_down, W_gate for gated MLPs like SwiGLU). Best quality. More parameters. The general consensus from recent fine-tuning work is that targeting MLPs matters as much as attention.
  • Embeddings and LM head. Usually skip; large parameter counts and small benefit for most tasks. Apply only when changing vocabulary semantics (e.g., adding new tokens).

The rule: more LoRA targets → better quality, more parameters. For most applications, all linear layers in the transformer block is the default that is hard to beat.

4.6 Memory math

Per matrix W ∈ R^{d × k}:

  • Full FT trains d · k parameters.
  • LoRA trains r · (d + k) parameters.

For d = k = 4096 and r = 16:

  • Full FT: 4096 · 4096 = 16 777 216 ≈ 16.8 M parameters per matrix.
  • LoRA: 16 · (4096 + 4096) = 131 072 ≈ 131 k parameters per matrix.
  • Ratio: 128× fewer trainable parameters per matrix.

For Llama-7B (32 layers, applying LoRA to Q, V at r = 16):

  • 32 layers × 2 matrices × 131 072 ≈ 8.4 M trainable parameters.
  • The full model is ~7 B parameters.
  • Trainable fraction: 8.4M / 7B ≈ 0.12 %.

You move 0.12 % of the parameters and recover most of the full-FT quality. This is the LoRA promise.

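The optimizer state is also tiny. Adam stores two moments (m, v) per trainable parameter in fp32, i.e. 8 bytes. Mixed-precision training also keeps fp32 master weights plus bf16 weights and gradients, for ~16 bytes per trainable parameter. Full FT of a 7B: 7B × 16 bytes ≈ 112 GB before activations (56 GB of it is Adam's moments alone). LoRA at 8.4 M trainable parameters: ~134 MB by the same accounting, negligible next to the ~14 GB frozen bf16 base.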
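The arithmetic of this subsection as a small Python script, so the numbers can be re-run for other shapes. The 16-bytes-per-parameter figure is the usual mixed-precision Adam accounting (bf16 weights and gradients, fp32 master weights, fp32 m and v); treat it as an estimate that excludes activations.

```python
def lora_params(d, k, r):
    """Trainable parameters of one LoRA adapter on a d x k weight: r * (d + k)."""
    return r * (d + k)


full = 4096 * 4096
lora = lora_params(4096, 4096, 16)
print(full, lora, full // lora)                    # 16777216 131072 128

# Llama-7B-style: 32 layers, LoRA on W_q and W_v only, r = 16.
trainable = 32 * 2 * lora
print(trainable, f"{trainable / 7e9:.2%}")         # 8388608 0.12%

# Mixed-precision Adam, bytes per *trainable* parameter:
# bf16 weights (2) + bf16 grads (2) + fp32 master weights (4) + fp32 m and v (8).
BYTES_PER_PARAM = 2 + 2 + 4 + 8
print(f"full FT 7B: {7e9 * BYTES_PER_PARAM / 1e9:.0f} GB")        # 112 GB
print(f"LoRA:       {trainable * BYTES_PER_PARAM / 1e6:.0f} MB")  # 134 MB
```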

4.7 Inference: merge or keep separate

Two deployment modes:

  1. Merge. At inference time, compute W' = W + (α/r) · B · A once and replace the base weight. Zero serving overhead: the model is shape-identical to the base.
  2. Keep separate. Carry B and A as side tensors. Apply W·x + (α/r)·B·(A·x) at every forward. Tiny overhead. Lets you hot-swap adapters at request time.
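A NumPy sketch (toy sizes, made-up adapter values) showing that the two modes compute the same thing, that merging preserves the base shape, and that subtracting the adapter back out recovers W. In bf16 the merge/unmerge round-trip is only approximate; float64 is used here for a clean check.

```python
import numpy as np

rng = np.random.default_rng(1)
d, k, r, alpha = 32, 24, 8, 16
scale = alpha / r                         # the alpha/r factor from 4.4

W = rng.standard_normal((d, k))
A = rng.standard_normal((r, k)) * 0.1
B = rng.standard_normal((d, r)) * 0.1     # pretend training has moved B off zero
x = rng.standard_normal((5, k))           # a batch of 5 inputs, one per row

# Mode 2, keep separate: base matmul plus the scaled low-rank side path.
h_separate = x @ W.T + scale * (x @ A.T) @ B.T

# Mode 1, merge: fold the adapter into the weight once, then serve W_merged.
W_merged = W + scale * B @ A
h_merged = x @ W_merged.T

assert W_merged.shape == W.shape          # shape-identical to the base
assert np.allclose(h_separate, h_merged)

# Unmerging recovers the base (up to float rounding), enabling adapter swaps.
assert np.allclose(W_merged - scale * B @ A, W)
```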

4.8 Multi-LoRA serving

This is the modern multi-tenant pattern. Load one base model; carry many small adapters; route each request to the right one.

  • One base model in GPU memory (e.g., ~16 GB for an 8B in bf16).
  • 50 customer-specific adapters at ~50 MB each = 2.5 GB.
  • Total weight memory: ~18.5 GB. One H100-80GB serves all 50 fine-tunes, with the rest of its memory left for KV cache.

Frameworks supporting this: vLLM (--enable-lora), LoRAX, S-LoRA. The batched matrix multiply for multiple adapters in the same batch is the nontrivial systems work; the libraries handle it.
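What the routing amounts to, as a deliberately naive NumPy sketch: one shared base matmul for the whole batch, then each row gets its own adapter's low-rank correction. The tenant names, sizes, and the per-row loop are illustrative assumptions; the frameworks above replace the loop with batched kernels.

```python
import numpy as np

rng = np.random.default_rng(2)
d, k, r, scale = 32, 32, 4, 2.0

W = rng.standard_normal((d, k))                        # shared base weight
adapters = {                                           # hypothetical tenant adapters
    name: (rng.standard_normal((d, r)) * 0.1,          # B
           rng.standard_normal((r, k)) * 0.1)          # A
    for name in ("acme", "globex", "initech")
}


def multi_lora_forward(x, adapter_ids):
    """One shared base matmul for the batch, then a per-row low-rank correction.

    x: (batch, k). adapter_ids[i] names the adapter for row i (None = base only).
    """
    h = x @ W.T                                        # shared across all tenants
    for i, name in enumerate(adapter_ids):
        if name is not None:
            B, A = adapters[name]
            h[i] += scale * B @ (A @ x[i])
    return h


x = rng.standard_normal((4, k))
ids = ["acme", "globex", None, "acme"]
h = multi_lora_forward(x, ids)

# Each row matches serving that tenant's merged model on its own.
for i, name in enumerate(ids):
    W_i = W if name is None else W + scale * adapters[name][0] @ adapters[name][1]
    assert np.allclose(h[i], W_i @ x[i])
```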

4.9 Minimal LoRA pseudocode (PEFT)

from peft import LoraConfig, get_peft_model

lora = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora)
model.print_trainable_parameters()
# trainable params: ~42M / total params: ~8B / trainable %: ~0.5

Drop this in front of an SFTTrainer and you have LoRA SFT.


5. QLoRA: full derivation

QLoRA (Dettmers et al., 2023) combines LoRA with aggressive 4-bit weight quantization. The combination is the reason a 70B fine-tune is feasible on a single 48 GB GPU.

5.1 The insight

LoRA already shrinks the optimizer state and trainable parameters. The remaining memory hog is the base model weights themselves (e.g., 70B in bf16 = 140 GB). If you quantize them to 4 bits, the base model takes 35 GB. The LoRA adapters stay in higher precision (bf16) and continue to train normally; the frozen quantized base is used only for forward and backward through the frozen weights.

The trick is that the gradient through a frozen weight only needs to read the weight, not write it. So you can leave the base in 4-bit storage and dequantize on the fly during the matmul; no quantization-aware-training machinery is needed for the base.

5.2 NF4 (NormalFloat-4) quantization

Standard INT4 quantization splits the value range into 16 evenly-spaced levels. For tensor values that are roughly uniformly distributed, this is near-optimal. For neural-network weights, which are well-modeled as zero-mean normal, uniform spacing wastes precision in the tails (where few weights live) and under-resolves the dense bulk near zero.

NF4's solution: choose the 16 levels to be the 16 quantiles of a standard normal distribution. Concretely, the levels are:

q_i = Φ⁻¹( (i + 0.5) / 16 ),    i = 0, 1, ..., 15

(Approximately: the QLoRA paper builds the levels asymmetrically, 7 negative, 8 positive, and 0 exactly.) A normally distributed weight tensor then puts approximately equal mass in each of the 16 NF4 bins. This is the sense in which NF4 is information-theoretically optimal for normal weights: every 4-bit code is used equally often.

The lookup table is fixed: derived once from the normal quantile function, normalized to [-1, 1], and hardcoded. Quantization at runtime is: divide each block of weights (64 values in QLoRA) by its absmax to map it into [-1, 1], then for each value find the nearest level in the NF4 table, store the level index (4 bits), and keep one scale per block.
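A pure-Python sketch of blockwise absmax quantization with an NF4-style codebook. Assumption, stated loudly: the codebook here is the symmetric (i + 0.5)/16 approximation from the formula above, rescaled to [-1, 1], not the exact bitsandbytes NF4 table; the block size of 64 follows the text.

```python
import random
from statistics import NormalDist

# Approximate NF4-style codebook: the 16 quantiles from 5.2, rescaled to [-1, 1].
# NOTE: symmetric approximation, not the exact bitsandbytes table (which is
# asymmetric and includes 0 exactly).
_raw = [NormalDist().inv_cdf((i + 0.5) / 16) for i in range(16)]
_peak = max(abs(q) for q in _raw)
LEVELS = [q / _peak for q in _raw]


def quantize_block(block):
    """Absmax-scale one block into [-1, 1]; store the nearest level's 4-bit index."""
    scale = max(abs(v) for v in block) or 1.0           # guard an all-zero block
    idx = [min(range(16), key=lambda j: abs(v / scale - LEVELS[j])) for v in block]
    return idx, scale


def dequantize_block(idx, scale):
    return [LEVELS[j] * scale for j in idx]


def quantize(weights, block_size=64):
    """Blockwise quantization with one absmax scale per block_size weights."""
    return [quantize_block(weights[i:i + block_size])
            for i in range(0, len(weights), block_size)]


random.seed(0)
w = [random.gauss(0.0, 0.02) for _ in range(256)]        # weight-like values
blocks = quantize(w)
w_hat = [v for idx, s in blocks for v in dequantize_block(idx, s)]
print(len(blocks), "scales for", len(w), "weights")
print("max abs error:", max(abs(a - b) for a, b in zip(w, w_hat)))
```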

5.3 Double quantization

The scale factors themselves take memory. For a 70B model with a block size of 64, you have 70 B / 64 ≈ 1.1 B scale factors. In fp32, that's ~4.4 GB of scales-a non-trivial fraction of the quantized model.

Double quantization quantizes the scale factors themselves to FP8. Saves ~3 GB on a 70B. Not glamorous but free.
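The savings can be reproduced from bits per parameter. This sketch assumes fp32 first-level scales per 64 weights and, with double quantization, 8-bit scales plus one fp32 constant per 256 scales (the setting the QLoRA paper describes); it ignores smaller details of the scheme.

```python
def scale_bits_per_param(block_size=64, dq_block_size=None):
    """Extra bits per weight spent on quantization constants.

    Without double quantization: one fp32 scale per block_size weights.
    With it: each scale stored in 8 bits, plus one fp32 second-level constant
    per dq_block_size scales.
    """
    if dq_block_size is None:
        return 32 / block_size
    return 8 / block_size + 32 / (block_size * dq_block_size)


params = 70e9
plain = scale_bits_per_param()                     # 32/64 = 0.5 bits per weight
dq = scale_bits_per_param(dq_block_size=256)       # 8/64 + 32/(64*256), about 0.127
print(f"{plain:.3f} -> {dq:.3f} bits/param")
print(f"scales: {params * plain / 8 / 1e9:.2f} GB -> {params * dq / 8 / 1e9:.2f} GB")
```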

5.4 Paged optimizers

Even with all of the above, training spikes can OOM the GPU. Paged optimizers use NVIDIA's Unified Memory (UVM) to allocate optimizer state so it can page between GPU and CPU memory automatically, like virtual memory at the OS level. The momentum/variance states are large but are touched only in the optimizer step, not during forward/backward, so paging them out at peak activation memory costs little.

Result: the GPU survives transient memory pressure that would otherwise kill the run.

5.5 Memory budget: 70B on a single 48 GB GPU at r=64

Counting:

  • Base weights (NF4). 70 B params × 4 bits = 35 GB, plus ~1 GB of quantization constants after double quantization.
  • LoRA adapters. Apply LoRA to all seven linear projections in each of the 80 layers (560 matrices). With model dim 8 192, grouped-query K/V projections of width 1 024, and MLP width 28 672, r=64 gives ~10.4 M trainable params per layer, ~830 M in total. In bf16 that is ~1.7 GB of weights and ~1.7 GB of gradients; Adam's two moments add ~1.7 GB with the 8-bit paged optimizer (~6.6 GB in fp32). Total: ~5 GB.
  • Activations and gradients. With activation checkpointing, this is the dominant remaining cost. For batch 1, seqlen 2048, on a 70B with AC: ~6–10 GB. Without AC: 30+ GB and you OOM.
  • Slack for kernels, KV, fragmentation. ~3 GB.

Total: 35 + 1 + 5 + ~8 + 3 ≈ 52 GB, a few GB over the line on a 48 GB GPU (RTX A6000, RTX 6000 Ada, A40). It fits only with help: the paged optimizer spilling state to CPU memory, shorter sequences, or a smaller rank or fewer LoRA targets. A single H100-80GB has comfortable margin.
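The budget above as a script, so it can be re-run for other ranks or targets. It assumes Llama-3-70B-style shapes and reuses the rough activation and slack figures from the list; only the LoRA parameter count is exact arithmetic.

```python
# Llama-3-70B-style shapes (assumed): hidden 8192, GQA K/V width 1024
# (8 KV heads x 128), MLP width 28672, 80 layers.
HIDDEN, KV, MLP, LAYERS, R = 8192, 1024, 28672, 80, 64

# (in_features, out_features) of the seven linear projections in one layer.
PROJ = {
    "q": (HIDDEN, HIDDEN), "k": (HIDDEN, KV), "v": (HIDDEN, KV), "o": (HIDDEN, HIDDEN),
    "gate": (HIDDEN, MLP), "up": (HIDDEN, MLP), "down": (MLP, HIDDEN),
}

per_layer = sum(R * (fan_in + fan_out) for fan_in, fan_out in PROJ.values())
lora_total = per_layer * LAYERS
print(f"LoRA params: {lora_total / 1e6:.0f} M")

GB = 1e9
budget = {
    "base weights (NF4, 4 bits)": 70e9 * 0.5 / GB,
    "quant constants (after DQ)": 70e9 * 0.127 / 8 / GB,
    "LoRA weights (bf16)": lora_total * 2 / GB,
    "LoRA grads (bf16)": lora_total * 2 / GB,
    "Adam m, v (8-bit paged)": lora_total * 2 / GB,
    "activations (with AC, rough)": 8.0,
    "kernels / fragmentation slack": 3.0,
}
for item, gb in budget.items():
    print(f"{item:32s} {gb:6.1f} GB")
print(f"{'total':32s} {sum(budget.values()):6.1f} GB")
```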

5.6 QLoRA pseudocode

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="bfloat16",
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.3-70B-Instruct",
    quantization_config=bnb,
    device_map="auto",
)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, LoraConfig(r=64, lora_alpha=16, ...))  # "..." elides the remaining LoraConfig args (see 4.9)

Set optim="paged_adamw_8bit" in the training arguments (TrainingArguments or SFTConfig) to use the paged 8-bit AdamW optimizer.


6. Preference learning: RLHF concepts

SFT teaches the model to imitate good demonstrations. But humans are often better at judging than generating: writing the perfect customer-support reply is hard; picking which of two replies is better is easy. Preference learning leverages this asymmetry.

6.1 Why preferences instead of demonstrations

  • Cheaper. A pairwise comparison is faster than authoring a gold reply.
  • More reliable. Two raters tend to agree on rankings even when they'd produce different "ideal" answers.
  • Captures style and nuance. "This response is more empathetic" is easy to mark and very hard to specify.
  • Negative information. SFT can't tell the model what not to do; preferences can.

6.2 The classic RLHF pipeline (InstructGPT, 2022)

  1. SFT. Standard supervised fine-tune on demonstrations.
  2. Reward model (RM). Collect preference triples (x, y_w, y_l), where y_w is the chosen and y_l the rejected response. Train a model r_φ(x, y) → ℝ that scores responses, with the loss derived in §10.
  3. RL. Fine-tune the SFT policy π_θ with reinforcement learning to maximize expected reward, subject to a KL penalty toward the SFT model.

The objective for stage 3:

J(θ) = E_{x ~ D, y ~ π_θ(·|x)} [ r_φ(x, y) ] - β · KL( π_θ(·|x) || π_SFT(·|x) )

Equivalently, per-token:

J(θ) = E [ Σ_t  r_t  -  β · log(π_θ(y_t|x, y_<t) / π_SFT(y_t|x, y_<t)) ]

where r_t is typically zero for non-final tokens and r_φ(x, y) for the final token.

6.3 The KL constraint, derived

Why the KL penalty? Without it, the policy will exploit the reward model. The RM is a fitted approximation of human preference; it has blind spots. Pure reward maximization runs the policy toward whatever the RM accidentally likes-verbosity, hedging, specific tokens-and quality collapses. This is reward hacking (§10.3).

The KL term β · KL(π_θ || π_SFT) says: stay close to the SFT model. The SFT model is a known-good distribution (it produces fluent text); large deviations are suspicious. β controls how tight the leash is.

In closed form, the KL-constrained optimal policy is

π*(y | x) = (1/Z(x)) · π_SFT(y | x) · exp( r(x, y) / β )

with Z(x) = Σ_y π_SFT(y|x) · exp(r(x,y)/β). Derivation:

We maximize, for each prompt x,

F(π) = E_{y ~ π} [r(x, y)] - β · KL(π(·|x) || π_SFT(·|x))
     = Σ_y π(y|x) · r(x, y) - β · Σ_y π(y|x) · log(π(y|x)/π_SFT(y|x))

subject to Σ_y π(y|x) = 1. Lagrangian:

L = Σ_y π(y|x) [ r(x, y) - β·log(π(y|x)/π_SFT(y|x)) ] - λ(x)·(Σ_y π(y|x) - 1)

Take ∂/∂π(y|x):

r(x, y) - β·log(π(y|x)/π_SFT(y|x)) - β - λ(x) = 0

Solve for π:

log(π(y|x)/π_SFT(y|x)) = (r(x, y) - β - λ(x)) / β
π(y|x) = π_SFT(y|x) · exp((r(x, y) - β - λ(x)) / β)
       = π_SFT(y|x) · exp(r(x, y)/β) · C(x)

where C(x) = exp(-(β + λ(x))/β) is constant in y. Apply the normalization Σ_y π(y|x) = 1:

C(x) = 1 / Σ_y [ π_SFT(y|x) · exp(r(x, y)/β) ] = 1 / Z(x)

So:

π*(y | x) = (1 / Z(x)) · π_SFT(y | x) · exp(r(x, y) / β)        (★)

This is the KL-regularized RL optimal policy. We will use (★) in the DPO derivation in §8-it is the central identity.
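A numerical sanity check of (★) for one prompt with a handful of discrete responses. The reference policy, rewards, and β are toy values chosen for illustration. The script confirms that the closed form attains F(π*) = β · log Z(x), which follows by substituting (★) into F, and that no randomly sampled policy does better.

```python
import math
import random


def objective(pi, pi_ref, r, beta):
    """F(pi) = E_pi[r] - beta * KL(pi || pi_ref) for one prompt, discrete y."""
    return sum(p * (ri - beta * math.log(p / q))
               for p, q, ri in zip(pi, pi_ref, r) if p > 0)


def optimal_policy(pi_ref, r, beta):
    """Closed form (★): pi*(y) = pi_ref(y) * exp(r(y) / beta) / Z."""
    unnorm = [q * math.exp(ri / beta) for q, ri in zip(pi_ref, r)]
    Z = sum(unnorm)
    return [u / Z for u in unnorm], Z


pi_sft = [0.5, 0.3, 0.15, 0.05]     # toy reference policy over 4 responses
reward = [0.0, 1.0, 2.0, 0.5]       # toy rewards
beta = 0.5

pi_star, Z = optimal_policy(pi_sft, reward, beta)
best = objective(pi_star, pi_sft, reward, beta)
assert abs(best - beta * math.log(Z)) < 1e-9        # max value is beta * log Z

random.seed(0)
for _ in range(10_000):                             # no random policy beats pi*
    weights = [random.random() for _ in pi_sft]
    total = sum(weights)
    pi = [v / total for v in weights]
    assert objective(pi, pi_sft, reward, beta) <= best + 1e-12
print("pi* =", [round(p, 3) for p in pi_star])
```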


7. PPO for RLHF (high-level)

Stage 3 of RLHF (the actual RL fine-tune) is traditionally done with PPO (Schulman et al., 2017). PPO is a policy-gradient algorithm with a trust-region-style clip to keep updates small.

7.1 The PPO clip

Let ratio_t(θ) = π_θ(y_t|x, y_<t) / π_θ_old(y_t|x, y_<t) be the importance ratio between the current and the previous policy iterate. Let A_t be the estimated advantage (token-level). PPO maximizes:

L_clip(θ) = E_t [ min( ratio_t · A_t,  clip(ratio_t, 1-ε, 1+ε) · A_t ) ]

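with ε ≈ 0.1–0.2. Why the min of a clipped and an unclipped term: if the ratio has already moved past 1+ε with a positive advantage (or below 1-ε with a negative one), min() selects the clipped term, which is constant in θ and has zero gradient, so the update cannot push further in that direction. In the opposite cases the unclipped term is the smaller one and is selected: if the advantage is negative and the ratio has risen above 1+ε (a bad action became more likely), the gradient still pulls the policy back, and a good action whose ratio fell below 1-ε is still pushed back up. The clip limits how far the policy moves toward an action, not how far it can retreat from a mistake.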
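The four cases, as a sketch of the per-token clipped objective and its derivative with respect to the ratio. The function is a hypothetical helper for illustration; in a real PPO implementation autograd supplies this derivative.

```python
def ppo_clip_term(ratio, advantage, eps=0.2):
    """Per-token PPO objective and its derivative w.r.t. the ratio.

    Returns (value, d value / d ratio). The derivative is 0 whenever min()
    selects the clipped branch, i.e. the update in that direction is blocked.
    """
    unclipped = ratio * advantage
    clipped = min(max(ratio, 1 - eps), 1 + eps) * advantage
    if unclipped <= clipped:
        return unclipped, advantage          # unclipped branch: gradient flows
    return clipped, 0.0                      # clipped branch: no gradient


cases = [
    ("A > 0, ratio already > 1+eps", 1.5, +1.0),   # clipped: stop pushing up
    ("A > 0, ratio < 1-eps", 0.5, +1.0),           # unclipped: pushed back up
    ("A < 0, ratio already < 1-eps", 0.5, -1.0),   # clipped: stop pushing down
    ("A < 0, ratio > 1+eps", 1.5, -1.0),           # unclipped: pulled back down
]
for label, ratio, adv in cases:
    value, grad = ppo_clip_term(ratio, adv)
    print(f"{label:30s} value={value:+.2f} d/dratio={grad:+.1f}")
```

Since PPO maximizes this objective, a positive derivative raises the ratio and a negative one lowers it.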

7.2 The four-model setup

PPO RLHF carries four models in memory simultaneously:

  1. Actor / policy (π_θ)-the model being trained.
  2. Critic / value function (V_φ)-estimates expected return at each token, used to compute advantages via GAE.
  3. Reward model (r_ψ)-frozen, scores final responses.
  4. Reference policy (π_SFT)-frozen, used in the KL penalty.

Memory cost: four models of comparable size (the critic and RM are often initialized from the same base as the actor). For a 7B base, that is ~28 B parameters' worth of weights, plus gradients and optimizer state for the two models being trained (actor and critic). For a 70B base, RLHF is genuinely a multi-node enterprise.

7.3 Why PPO RLHF is hard

  • Hyperparameter sensitivity. β, ε, RM-LR, actor-LR, critic-LR, KL target, GAE-λ, all interact. Small changes can collapse training.
  • Reward hacking. RM is imperfect; the actor finds exploits.
  • KL ratchet. As training progresses, the policy can drift steadily away from π_SFT even with the KL penalty, especially on long generations.
  • Memory. Four models. Distributed RLHF on a 70B is real research infrastructure.
  • Sample inefficiency. Each PPO step requires generating completions (slow, autoregressive) before the gradient step.

DPO (§8) was motivated by all of this: can we get the same alignment benefit without the RL stack?


8. DPO: full derivation

This is the chapter's centerpiece. DPO (Rafailov et al., 2023) shows that the classic RLHF objective has a closed-form optimum, that this optimum can be reparameterized in terms of the policy alone, and that the resulting objective is a simple supervised cross-entropy loss on preference pairs. No reward model. No PPO. No critic. No four-model setup.

8.1 The starting point: the KL-constrained RL objective

From §6.3 we had the optimal policy under the KL-regularized RL objective:

π*(y | x) = (1 / Z(x)) · π_SFT(y | x) · exp( r(x, y) / β )       (★)

Two observations:

  • Given π_SFT and β, the reward r(x, y) and the optimal policy π*(y|x) determine each other: either one can be solved for from the other.
  • We will not solve for π* from r. We will go the other direction: solve for r in terms of π* and π_SFT.

8.2 Inverting (★) to express r in terms of π* and π_SFT

Take the log of (★):

log π*(y|x) = log π_SFT(y|x) + r(x, y)/β - log Z(x)

Solve for r(x, y):

r(x, y) = β · [ log π*(y|x) - log π_SFT(y|x) ] + β · log Z(x)
        = β · log( π*(y|x) / π_SFT(y|x) ) + β · log Z(x)             (♦)

This is the reward-policy duality. The reward function and the optimal policy are in one-to-one correspondence, modulo the prompt-dependent constant β · log Z(x). Importantly, Z(x) depends on x (through π_SFT and r) but not on y; this is what will let it cancel in a moment.

8.3 The Bradley-Terry preference model

Humans rank pairs. Given a prompt x and two completions y_w (winner / chosen) and y_l (loser / rejected), the probability that y_w is preferred is modeled as

P(y_w ≻ y_l | x) = σ( r(x, y_w) - r(x, y_l) )                          (BT)

where σ is the logistic sigmoid. This is the Bradley-Terry model (Bradley & Terry, 1952). It is the standard parametric assumption in preference learning: pairwise probabilities are determined by a difference of latent scores.

8.4 Substituting (♦) into (BT): the cancellation

The reward difference is:

r(x, y_w) - r(x, y_l)
  = [ β·log(π*(y_w|x)/π_SFT(y_w|x)) + β·log Z(x) ]
  - [ β·log(π*(y_l|x)/π_SFT(y_l|x)) + β·log Z(x) ]
  = β · [ log(π*(y_w|x)/π_SFT(y_w|x)) - log(π*(y_l|x)/π_SFT(y_l|x)) ]

The β · log Z(x) terms cancel because they don't depend on y. This cancellation is what makes DPO possible-the partition function, which would otherwise be intractable to compute, vanishes.

Define the implicit reward (the policy-side log-ratio):

r̂_θ(x, y) := β · log( π_θ(y|x) / π_ref(y|x) )                          (▼)

where we have replaced π* with our trainable π_θ and π_SFT with π_ref (the reference, typically a frozen copy of the SFT model). If π_θ is the optimal policy for r, then:

r(x, y_w) - r(x, y_l) = r̂_θ(x, y_w) - r̂_θ(x, y_l)
                      = β·log(π_θ(y_w|x)/π_ref(y_w|x))
                      - β·log(π_θ(y_l|x)/π_ref(y_l|x))

8.5 The DPO loss

Plug back into (BT) and take the negative log-likelihood over a dataset of preference pairs D = { (x, y_w, y_l) }:

L_DPO(θ) = - E_{(x,y_w,y_l)~D} [
    log σ(  β · log(π_θ(y_w|x)/π_ref(y_w|x))
          - β · log(π_θ(y_l|x)/π_ref(y_l|x)) )
]                                                                       (DPO)

This is the DPO loss. Let's read what it says. Define Δ̂(x, y_w, y_l) := r̂_θ(x, y_w) - r̂_θ(x, y_l). Then L_DPO = -E[log σ(Δ̂)].

  • When π_θ raises y_w's likelihood relative to π_ref more than y_l's, Δ̂ is large positive, σ(Δ̂) → 1, loss → 0. Good.
  • When π_θ does the opposite, loss is large.
  • The gradient pushes π_θ to increase the relative log-prob of winners and decrease the relative log-prob of losers, with reference π_ref defining "relative."
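The loss itself is a few lines once the per-response log-probabilities are available. In this sketch each argument is assumed to be the summed token log-probability of a full response (computed elsewhere, with prompt tokens masked as in §2); the function names are illustrative, not TRL's API.

```python
import math


def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))


def dpo_loss(policy_logp_w, policy_logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """DPO loss for one pair, from sequence log-probs log pi(y|x)."""
    r_w = beta * (policy_logp_w - ref_logp_w)    # implicit reward of y_w, (▼)
    r_l = beta * (policy_logp_l - ref_logp_l)    # implicit reward of y_l
    return -log_sigmoid(r_w - r_l)


# At initialization pi_theta == pi_ref, so both implicit rewards are 0: loss = log 2.
print(dpo_loss(-42.0, -37.0, -42.0, -37.0))      # 0.6931... (= log 2)
# Policy has raised y_w and lowered y_l relative to the reference: loss falls.
print(dpo_loss(-40.0, -45.0, -42.0, -37.0))
```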

8.6 The gradient: what DPO actually does

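Differentiate L_DPO. Let u = Δ̂ (β is already inside r̂_θ). Then L = -log σ(u), so dL/du = -(1 - σ(u)) = -σ(-u). Since the reference terms are constant, ∇_θ Δ̂ = β · [∇_θ log π_θ(y_w|x) - ∇_θ log π_θ(y_l|x)], and the gradient is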

∇_θ L_DPO = -β · σ( -Δ̂ ) · [ ∇_θ log π_θ(y_w|x) - ∇_θ log π_θ(y_l|x) ]

Read this carefully:

  • σ(-Δ̂) is the model's current error mass on this pair: it approaches 1 when the model confidently ranks the pair wrongly (Δ̂ ≪ 0), equals 0.5 at Δ̂ = 0, and approaches 0 when the model already ranks it correctly (Δ̂ ≫ 0).
  • The gradient is then the difference of log-probability gradients of winner and loser, scaled by error.
  • The update increases log π_θ(y_w|x) and decreases log π_θ(y_l|x), more so on pairs the model gets wrong. This is exactly the desired behavior, and it requires no reward model at all.
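The closed-form gradient can be checked against finite differences on a toy policy. Assumption: each "response" is a single outcome of a softmax over four logits, so ∇_θ log π_θ(y_i|x) = e_i − π and the bracket reduces to e_w − e_l. The reference is uniform and frozen; β = 0.5 is arbitrary.

```python
import math


def softmax(theta):
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    total = sum(exps)
    return [e / total for e in exps]


def implicit_margin(theta, ref, w, l, beta):
    """Delta-hat for a policy whose 'responses' are single categorical outcomes."""
    pi = softmax(theta)
    return beta * (math.log(pi[w] / ref[w]) - math.log(pi[l] / ref[l]))


def dpo_loss(theta, ref, w, l, beta):
    return math.log1p(math.exp(-implicit_margin(theta, ref, w, l, beta)))  # -log sigma


def dpo_grad(theta, ref, w, l, beta):
    """Closed form from 8.6: -beta * sigma(-Delta) * (e_w - e_l) for a softmax policy."""
    weight = 1.0 / (1.0 + math.exp(implicit_margin(theta, ref, w, l, beta)))  # sigma(-Delta)
    return [-beta * weight * ((i == w) - (i == l)) for i in range(len(theta))]


def numeric_grad(f, theta, h=1e-6):
    """Central finite differences, one coordinate at a time."""
    grad = []
    for i in range(len(theta)):
        up, down = list(theta), list(theta)
        up[i] += h
        down[i] -= h
        grad.append((f(up) - f(down)) / (2 * h))
    return grad


theta = [0.2, -0.4, 1.0, 0.1]       # toy policy logits over 4 "responses"
ref = softmax([0.0] * 4)            # uniform frozen reference
w, l, beta = 0, 2, 0.5              # y_w is response 0, y_l is response 2

analytic = dpo_grad(theta, ref, w, l, beta)
numeric = numeric_grad(lambda t: dpo_loss(t, ref, w, l, beta), theta)
assert all(abs(a - n) < 1e-6 for a, n in zip(analytic, numeric))
print([round(a, 4) for a in analytic])
```

Only the winner's and loser's coordinates receive gradient in this toy; in a real LM the same pair-weighted difference flows through the shared parameters behind every token's log-probability.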

8.7 Hyperparameter β

β is the KL strength inherited from the original RL objective.

  • Small β (~0.01): weak KL regularization, model can drift far from π_ref. Stronger preference fitting, more risk of degeneration.
  • Large β (~1.0): strong leash, model stays close to π_ref, preference signal is effectively diluted.
  • Typical: 0.1–0.5. Llama-style alignment runs sit around 0.1.

If your DPO output is bizarre or repetitive, try larger β. If it's identical to the SFT model, try smaller β.

8.8 The reference model in practice

π_ref is typically a frozen copy of the SFT model at the start of DPO. It is loaded once and used only to compute log π_ref(y_w|x) and log π_ref(y_l|x), with no gradients.

Engineering tricks:

  • Precompute log-probs of π_ref once for the whole dataset. The reference is frozen; you can run it offline and cache.
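  • With a disk-cached reference, the reference model need not be resident during training, which frees one full copy of the weights from VRAM.
  • Variants rework or remove the reference: IPO changes the loss but keeps π_ref, while CPO and SimPO are reference-free. Performance varies by dataset; SimPO has been competitive on chat benchmarks while skipping the reference forward pass entirely.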
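A sketch of offline reference caching. `ref_logprob(prompt, response)` is a hypothetical callable standing in for a no-grad forward pass of the frozen π_ref that returns the summed response log-prob; the cache is keyed by a hash of the (prompt, response) text and serialized as JSON lines you can write to disk.

```python
import hashlib
import json
from typing import Callable, Dict, Iterable, List


def pair_key(prompt: str, response: str) -> str:
    """Stable cache key for one (prompt, response) pair."""
    payload = json.dumps([prompt, response], ensure_ascii=False)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def build_reference_cache(
    rows: Iterable[Dict[str, str]],
    ref_logprob: Callable[[str, str], float],
) -> Dict[str, float]:
    """Score every chosen/rejected response once with the frozen reference.

    ref_logprob is a hypothetical stand-in for a forward pass of π_ref.
    """
    cache: Dict[str, float] = {}
    for row in rows:
        for response in (row["chosen"], row["rejected"]):
            key = pair_key(row["prompt"], response)
            if key not in cache:  # repeated (prompt, response) pairs are scored once
                cache[key] = ref_logprob(row["prompt"], response)
    return cache


def to_jsonl(cache: Dict[str, float]) -> List[str]:
    """One JSON object per line, ready to write to disk."""
    return [json.dumps({"key": k, "ref_logprob": v}) for k, v in cache.items()]


def from_jsonl(lines: Iterable[str]) -> Dict[str, float]:
    """Rebuild the cache from lines produced by to_jsonl."""
    cache: Dict[str, float] = {}
    for line in lines:
        record = json.loads(line)
        cache[record["key"]] = record["ref_logprob"]
    return cache
```

During DPO, look up `pair_key(prompt, chosen)` and `pair_key(prompt, rejected)` instead of running the reference model.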

8.9 DPO vs PPO RLHF: the engineering scorecard

Axis | PPO RLHF | DPO
--- | --- | ---
Reward model | Required, separately trained | Implicit in the loss
Sampling during training | Yes (slow, autoregressive) | No (offline pairs)
Models in memory | 4 (actor, critic, RM, ref) | 2 (policy, ref)
Hyperparameter count | High | Low (β, LR, epochs)
Stability | Notoriously fragile | Stable
Quality ceiling | Slightly higher in some studies | Comparable on most
Implementation effort | Substantial | A training loop

For most teams, DPO is the right starting point.

8.10 DPO pseudocode (TRL)

from trl import DPOTrainer, DPOConfig
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

base = "out/sft"  # the SFT checkpoint
tokenizer = AutoTokenizer.from_pretrained(base)
policy = AutoModelForCausalLM.from_pretrained(base, torch_dtype="bfloat16")
ref    = AutoModelForCausalLM.from_pretrained(base, torch_dtype="bfloat16")

# rows: {"prompt": "...", "chosen": "...", "rejected": "..."}
ds = load_dataset("json", data_files="prefs.jsonl", split="train")

cfg = DPOConfig(
    output_dir="out/dpo",
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
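    learning_rate=5e-7,           # DPO LR is typically 10× or more below the SFT LR
    beta=0.1,                     # KL strength (§8.7)
    warmup_ratio=0.05,
    bf16=True,
)
# Recent TRL releases take the tokenizer as processing_class=; older ones used tokenizer=.
trainer = DPOTrainer(model=policy, ref_model=ref, processing_class=tokenizer,
                     train_dataset=ds, args=cfg)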
trainer.train()

DPO with LoRA: wrap the policy with get_peft_model(...) first (or pass peft_config= to DPOTrainer); you can then omit ref_model= and TRL will use the same network with its adapters disabled as the reference. If the LoRA sits on top of the SFT checkpoint, disabling the adapters gives you the SFT model back, which is exactly π_ref. Because no second copy of the weights is loaded, LoRA-DPO costs almost the same memory as LoRA-SFT.


9. GRPO

GRPO (Group Relative Policy Optimization, DeepSeek 2024) is a recent PPO variant that drops the value function (the critic) and replaces it with group-relative statistics. It is the technique behind DeepSeek-Math and DeepSeek-R1's reasoning fine-tunes.

9.1 The insight

PPO's advantage A_t = R_t - V_φ(s_t) requires a learned value model V_φ. The value model is roughly the same size as the actor, doubling training memory.

GRPO observes: if you sample G completions from the same prompt, the empirical mean and standard deviation of the group's rewards already form a serviceable baseline. No critic needed.

9.2 The objective

For each prompt x, sample G completions {y_i}. Compute a reward r_i for each (from a reward model, a verifier, or, in DeepSeek-R1's case, rule-based correctness and format checks). Compute the group-relative advantage:

A_i = (r_i - mean({r_1, ..., r_G})) / std({r_1, ..., r_G})

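Apply this advantage to every token of y_i, then run a PPO-style clipped update, where ratio_{i,t} = π_θ(y_{i,t} | x, y_{i,<t}) / π_θold(y_{i,t} | x, y_{i,<t}) is measured against the policy that sampled the group:

L_GRPO(θ) = - E [ (1/G) Σ_i (1/|y_i|) Σ_t  min( ratio_{i,t}·A_i,  clip(ratio_{i,t}, 1-ε, 1+ε)·A_i ) ]
            + β · KL( π_θ || π_ref )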

The KL is added directly to the loss (rather than as a per-token reward shaping, as in PPO RLHF).
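The per-token clipped term and the length-normalized group loss, as a pure-Python sketch. It assumes the ratios ratio_{i,t} and the advantages A_i are already computed and omits the β·KL term; ε = 0.2 is an assumed default, not a value from the text.

```python
from typing import Sequence


def clipped_term(ratio: float, advantage: float, clip_eps: float = 0.2) -> float:
    """min(ratio·A, clip(ratio, 1-ε, 1+ε)·A) for one token."""
    clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
    return min(ratio * advantage, clipped * advantage)


def grpo_surrogate_loss(
    token_ratios: Sequence[Sequence[float]],  # ratio_{i,t} for each completion i
    advantages: Sequence[float],              # A_i, shared by every token of y_i
    clip_eps: float = 0.2,                    # assumed ε
) -> float:
    """-(1/G) Σ_i (1/|y_i|) Σ_t clipped_term(...); the β·KL term is omitted."""
    total = 0.0
    for ratios, adv in zip(token_ratios, advantages):
        total += sum(clipped_term(r, adv, clip_eps) for r in ratios) / len(ratios)
    return -total / len(advantages)
```

The min makes the update pessimistic: once a token's ratio has moved past the clip boundary in the direction its advantage favors, the clipped branch is constant and that token stops contributing gradient.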

9.3 What changed vs PPO

  • No critic. Save ~50 % of memory and compute.
  • Per-prompt baselining. Reduces variance compared to a global baseline; especially effective for verifiable tasks where the reward is binary or near-binary.
  • G is typically 8–16. Memory cost: G completions in flight per step, but still cheaper than a critic.

9.4 When to reach for GRPO

  • Reasoning tasks with verifiable rewards (math, code unit tests, formal verification). A rule-based reward is exact on what it checks, leaving far less room for reward hacking than a learned RM, and group-relative baselining is extremely informative.
  • When you cannot afford the critic memory.
  • When PPO is unstable and DPO is insufficient (DPO is offline; GRPO is online, which matters for tasks where the policy is supposed to explore).

For chat alignment with subjective preferences, DPO is still simpler.


10. Reward model design

If you do go the PPO/RM route, the reward model is its own subsystem.

10.1 Architecture

The standard recipe: take the SFT model, replace the LM head with a linear scalar head, and train. This means:

  • Same backbone as the policy (already pretrained on language; speaks the same dialect).
  • One scalar output per (x, y) pair (the reward).

In code: take the last-token hidden state of (x, y), project to a scalar via Linear(d_model, 1).

10.2 Training loss

Given preference pairs (x, y_w, y_l), train under Bradley-Terry:

L_RM(φ) = - E [ log σ( r_φ(x, y_w) - r_φ(x, y_l) ) ]

This is the NLL of (BT) with r = r_φ. Notice the symmetry with the DPO loss: in DPO, the implicit reward r̂_θ is a function of the policy; here, r_φ is a separate model.
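A pure-Python sketch of the scalar head and the Bradley-Terry loss. `hidden_states` (one final-layer vector per token), `w`, and `b` stand in for the backbone output and the weights of Linear(d_model, 1); in a real model these are tensors.

```python
import math
from typing import Sequence


def last_token_index(attention_mask: Sequence[int]) -> int:
    """Position of the last real (mask == 1) token; works for left or right padding."""
    return max(i for i, m in enumerate(attention_mask) if m == 1)


def scalar_reward(hidden_states: Sequence[Sequence[float]],
                  attention_mask: Sequence[int],
                  w: Sequence[float], b: float = 0.0) -> float:
    """Linear(d_model, 1) applied to the last real token's final hidden state."""
    h = hidden_states[last_token_index(attention_mask)]
    return sum(hi * wi for hi, wi in zip(h, w)) + b


def bt_loss(r_chosen: float, r_rejected: float) -> float:
    """-log σ(r_w - r_l), computed stably."""
    z = r_chosen - r_rejected
    if z >= 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))
```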

10.3 Reward hacking

The model trained against the RM is incentivized to maximize r_φ, not true human preference. The RM is a fitted approximation with blind spots. When the policy finds these blind spots, you get:

  • Length bias. RMs trained on short-vs-long pairs that humans reasonably preferred often learn "longer is better." The policy generates longer and longer responses with no quality gain. The classic RLHF failure.
  • Sycophancy. RM rewards agreement with the user; policy becomes a yes-man.
  • Token-level exploits. RM has a quirk on certain tokens; policy finds it.

Mitigations:

  • KL constraint. Prevents drift from the SFT distribution where the RM was actually calibrated.
  • Length normalization. Subtract a length term from the RM target.
  • Multiple RMs. Average several RMs, or take their min, to reduce blind-spot exploitation.
  • Reward overoptimization curves. Plot RM-score vs human-judged quality during training; stop when human quality plateaus or drops even as RM-score rises.
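One way to combine two of these mitigations, as a hypothetical shaping function: a conservative ensemble aggregate (min by default) minus a linear length penalty. The penalty form and `alpha` are assumptions to tune against your overoptimization curves, not a prescribed recipe.

```python
from typing import Sequence


def shaped_reward(rm_scores: Sequence[float], n_tokens: int,
                  alpha: float = 0.001, use_min: bool = True) -> float:
    """Hypothetical shaping: ensemble aggregate minus a linear length penalty.

    rm_scores: one score per reward model for the same (x, y).
    alpha: assumed per-token penalty; tune it for your setup.
    """
    agg = min(rm_scores) if use_min else sum(rm_scores) / len(rm_scores)
    return agg - alpha * n_tokens
```

Taking the min means a response must look good to every RM, so a blind spot in one model no longer pays off on its own.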

10.4 Process reward models (PRM)

For multi-step reasoning (math, coding), an outcome reward model (ORM) scores only the final answer, so every step of the chain of thought receives the same credit regardless of which step went wrong. Process reward models score every step:

  • Train PRM on (prompt, partial CoT, step is correct?) data.
  • During RL, give per-step reward (or accumulate stepwise rewards as a shaped final reward).

Used in OpenAI's "Let's Verify Step by Step" (Lightman et al., 2023); DeepSeekMath also reports GRPO with process supervision. The result is a denser, step-level training signal for reasoning.
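Two common ways to turn per-step PRM outputs into a solution-level score, as a sketch. The product (probability that every step is correct, treating steps as independent) is the aggregation Lightman et al. describe; the minimum is a frequently used alternative. The inputs are assumed to be per-step correctness probabilities from a trained PRM.

```python
import math
from typing import Sequence


def prm_solution_score(step_probs: Sequence[float], how: str = "product") -> float:
    """Aggregate per-step P(step is correct) into one solution score."""
    if how == "product":   # P(every step correct), treating steps as independent
        return math.prod(step_probs)
    if how == "min":       # the weakest step decides
        return min(step_probs)
    raise ValueError(f"unknown aggregation: {how!r}")
```

A solution whose steps score [0.99, 0.99, 0.2] is penalized by either rule, even if an ORM happened to like its final answer.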


11. Preference data curation

Reward models and DPO live or die by preference data quality.

11.1 Sources

  • Human raters (gold). Highest quality, highest cost. $0.50–$5 per pair depending on complexity. Domain experts for specialized fields.
  • LLM-as-judge (cheap). Use a strong frontier model to rank pairs. Cheap and fast, but biased: known issues include position bias (favoring the first option), length bias, self-bias (favoring its own model family), and verbosity bias.
  • Hybrid. Use LLM-as-judge for the bulk and humans for stratified audits and disagreement resolution. Common production pattern.

11.2 Volume

  • 5 k–10 k pairs is a viable starting point for DPO if the pairs cover the target distribution.
  • 30 k–50 k pairs is a strong fine-tune.
  • 100 k+ approaches diminishing returns unless the underlying task is broad.

These are softer numbers than SFT volume; preference learning is more sample-efficient because each pair contains a comparison rather than just a positive example.

11.3 Quality control

  • Inter-rater agreement. Multi-annotate ~10 % of the data and measure Cohen's κ. Below 0.4 is concerning; aim for 0.6+.
  • Stratify by difficulty. Easy pairs ("clearly better" vs "clearly worse") teach little. Hard pairs (close calls between two reasonable answers) drive most learning. Bias toward harder pairs.
  • Ensure the chosen is actually good. A pair where both options are bad teaches the model to be slightly less bad. Discard such pairs during curation.
  • De-duplicate prompts. Many similar prompts dominate the loss and bias the model.
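Cohen's κ for two raters on the same items, as a dependency-free sketch: observed agreement p_o corrected by the agreement p_e expected if both raters labelled at random with their own label frequencies, κ = (p_o - p_e) / (1 - p_e). For more than two raters, a multi-rater statistic such as Fleiss' κ is the usual substitute.

```python
from collections import Counter
from typing import Hashable, Sequence


def cohens_kappa(a: Sequence[Hashable], b: Sequence[Hashable]) -> float:
    """Cohen's κ for two raters who labelled the same items in the same order."""
    if len(a) != len(b) or not a:
        raise ValueError("need two non-empty label sequences of equal length")
    n = len(a)
    p_o = sum(x == y for x, y in zip(a, b)) / n                     # observed agreement
    count_a, count_b = Counter(a), Counter(b)
    p_e = sum(count_a[k] * count_b[k] for k in count_a) / (n * n)   # chance agreement
    if p_e == 1.0:
        return float("nan")  # both raters used one identical label: κ is undefined
    return (p_o - p_e) / (1.0 - p_e)
```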

11.4 Self-reward and Constitutional AI

If human raters are unaffordable, use the model itself or a sibling model as the rater, guided by a written constitution (a list of principles like "be honest," "don't help with weapons," "ask for clarification when ambiguous"). Two stages:

  1. Generate two completions per prompt.
  2. Have the model judge which better follows the constitution; that becomes your preference pair.

This is the scaffolding behind RLAIF (§12).
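A sketch of the two-stage loop above. `generate` and `judge` are hypothetical callables (a sampler for the policy and a constitution-prompted judge returning "first" or "second"); the presentation order is randomized to blunt the position bias noted in §11.1, and the rows use the prompt/chosen/rejected layout of the DPO example in §8.10.

```python
import random
from typing import Callable, Dict, Optional


def make_preference_pair(
    prompt: str,
    generate: Callable[[str], str],          # hypothetical sampler
    judge: Callable[[str, str, str], str],   # hypothetical judge: "first" or "second"
    rng: random.Random,
) -> Optional[Dict[str, str]]:
    a, b = generate(prompt), generate(prompt)
    if a == b:
        return None  # identical completions teach nothing
    # Randomize which completion the judge sees first.
    first, second = (b, a) if rng.random() < 0.5 else (a, b)
    verdict = judge(prompt, first, second)
    if verdict not in ("first", "second"):
        raise ValueError(f"judge returned {verdict!r}")
    chosen, rejected = (first, second) if verdict == "first" else (second, first)
    return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
```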


12. Constitutional AI / RLAIF

Constitutional AI (Bai et al., 2022) replaces the human rater with an AI rater bound by an explicit set of written principles: the constitution. It scales preference data collection from "as many humans as you can hire" to "as many GPUs as you can spin up."

12.1 Two stages

SL-CAI (Supervised Learning). For each prompt, the model produces an initial response, then critiques its own response against the constitution, then revises. The SFT data is (prompt, revised_response). This bakes in the constitution's behavior at the SFT stage.

RL-CAI (Reinforcement Learning). The model produces two responses to each prompt. Another model, also bound by the constitution, judges which is better. The resulting preference pairs train an RM (or feed DPO/GRPO directly).

12.2 Why it scales

Humans are the bottleneck in RLHF data. RLAIF removes that bottleneck:

  • Cost. GPU-hours instead of human-hours, often 10–100× cheaper per pair.
  • Throughput. Millions of pairs in days, not months.
  • Consistency. A constitution is reproducible; humans are not.

12.3 What it loses

  • Bias inheritance. The judge model has its own biases, and they propagate.
  • Constitutional drift. A long constitution is not always followed precisely; principles get weighted unevenly.
  • Worse on out-of-distribution preferences. Where humans use common sense the model lacks, RLAIF fails.

In practice, hybrid pipelines (RLAIF for breadth, human-rater preference data for hot-button axes such as safety, factuality, and sensitive domains) tend to outperform either pure approach.


13. Frontier-scale fine-tuning

The curriculum's Sequence 15 leaves a gap: how do you actually fine-tune a 70B (or 405B, or 671B) model? The answer involves the distributed-training stack from AI_SYSTEMS_PLAN/DEEP_DIVES/06 and the numerics from /11. Here is the integration view.

13.1 The model parallelism axis

A 70B model in bf16 is 140 GB of weights, before activations and optimizer state. It does not fit on a single 80 GB GPU even for inference. Training on a single node (8×80 GB = 640 GB) requires:

  • FSDP (Fully Sharded Data Parallel, ZeRO-3 equivalent). Shards parameters, gradients, and optimizer state across data-parallel ranks. Each rank holds 1/N of the parameters at rest; gathers full layers on demand for forward/backward. Cross-ref /06.
  • Activation checkpointing. Discard most activations during the forward pass; recompute them during backward. This costs roughly one extra forward pass (about a third more compute) and cuts activation memory by a large, architecture-dependent factor. Without it, 70B SFT does not fit anywhere reasonable.
  • Mixed precision. Bf16 for parameters and activations; fp32 for optimizer master weights and accumulations. FP8 if your GPUs support it (H100, B200); see /11 for the scaling-factor and stochastic-rounding considerations.

13.2 Practical configuration: 70B FT on 8×H100

  • Method. QLoRA or LoRA (rarely full FT at this scale on a single node).
  • LoRA r. 64 (higher than at 7B, because the larger model has more capacity to exploit).
  • Sharding. FSDP with full-shard, mixed precision bf16.
  • Activation checkpointing. On.
  • Per-device batch. 1, with gradient accumulation to effective batch 64–128.
  • Sequence length. Match production (4 k–8 k typical).
  • LR. 1e-4 for LoRA at this scale.
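  • Throughput expectation. Measure it on your setup. As a ceiling check, LoRA with activation checkpointing costs roughly 6 × 70 B ≈ 4 × 10¹¹ FLOPs per token (forward, activation-gradient backward, recompute); 8×H100 at a realistic 30–50 % of dense bf16 peak therefore lands around 5–10 k tokens/sec, and even 100 % utilization would cap it below ~19 k.

This is a practical configuration. A single 8×H100 node can fine-tune 70B in a few hours to days, depending on data volume.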
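Two pieces of scheduling arithmetic for this configuration, as a sketch: the gradient-accumulation count that reaches a target effective batch, and wall-clock time from a token budget. The throughput argument is whatever you measure on your own setup; the numbers below are illustrative only.

```python
import math


def grad_accum_steps(target_batch: int, per_device_batch: int, world_size: int) -> int:
    """Smallest accumulation count whose effective batch reaches target_batch."""
    return math.ceil(target_batch / (per_device_batch * world_size))


def wall_clock_hours(n_tokens: float, epochs: float, tokens_per_sec: float) -> float:
    """Training time from a token budget and a measured throughput."""
    return n_tokens * epochs / tokens_per_sec / 3600.0
```

With per-device batch 1 on 8 GPUs, an effective batch of 128 needs 16 accumulation steps; a 100 M-token dataset for 2 epochs at a hypothetical 10 k tokens/sec takes about 5.6 hours.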

13.3 Full FT at 70B+: 32×H100

Full fine-tuning at 70B requires roughly 32×H100 with multi-node FSDP, or 8 nodes × 8 GPUs with carefully tuned communication. The bottleneck is the all-gather/reduce-scatter pattern of FSDP across the cluster interconnect (NVLink within node, InfiniBand between). See /06 for the full breakdown.

For 405B+, you are squarely in tensor + pipeline + data + sequence parallelism territory. The tooling is Megatron-LM, NeMo, or DeepSpeed at scale. Cross-ref /06.
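The arithmetic behind the GPU count, as a sketch using the standard mixed-precision Adam accounting of 16 bytes per parameter (bf16 weights and gradients, fp32 master weights, fp32 Adam moments), fully sharded across GPUs. Activations, communication buffers, and fragmentation come on top.

```python
def full_ft_state_gb_per_gpu(n_params: float, n_gpus: int,
                             bytes_per_param: int = 16) -> float:
    """Per-GPU weights + gradients + Adam state under full (ZeRO-3 / FSDP) sharding.

    16 bytes/param = bf16 weights (2) + bf16 grads (2) + fp32 master weights (4)
    + fp32 Adam m (4) + fp32 Adam v (4). Activations and buffers are extra.
    """
    return n_params * bytes_per_param / n_gpus / 1e9
```

At 70 B, eight GPUs would each need 140 GB for weights, gradients, and optimizer state alone, beyond an 80 GB H100; 32 GPUs bring that to 35 GB each, leaving headroom for activations.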

13.4 The takeaway

For most teams, the advice is simple: do not fully fine-tune a 70B+ model. Use QLoRA. The quality gap is small (§14), the cost gap is large, and the operational complexity gap is enormous.


14. Full FT vs LoRA: the decision

14.1 Quality

Full FT consistently beats LoRA in head-to-head comparisons, but the gap is small: usually 0.5–2 percentage points on benchmark suites. For most tasks, this is below the noise floor of evaluation.

LoRA's quality scales with r. The curve flattens around r=64 for most tasks; pushing to r=128 rarely helps. The right move is to target more modules (Q, K, V, O, MLP gate/up/down) rather than push r higher.

14.2 Cost

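LoRA is dramatically cheaper in training memory and storage. The compute saving is more modest: the forward pass and the activation gradients still run through the full model, and only the weight gradients of the frozen matrices are skipped (roughly a third of training FLOPs). The memory savings come from: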

  • Smaller optimizer state. 100× fewer trainable parameters → 100× smaller Adam state.
  • Smaller gradients. Same factor.
  • Smaller checkpoints. Only the LoRA tensors are saved, not the full weights.

QLoRA is another 2–4× on top of that for memory.
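A back-of-envelope sketch separating LoRA's two kinds of savings. It uses the common approximation of 2N FLOPs per token for each of the forward pass, the activation-gradient pass, and the weight-gradient pass, and 12 bytes of fp32 Adam state (master copy, m, v) per trainable parameter. The 40 M LoRA figure is roughly the all-linear, r=16 count for Llama-7B from Exercise 1; adapter matmul FLOPs are ignored as negligible.

```python
def train_flops_per_token(n_params: float, lora: bool) -> float:
    """Approximate training FLOPs per token (no activation recomputation)."""
    forward = 2 * n_params
    grad_activations = 2 * n_params
    grad_weights = 0.0 if lora else 2 * n_params  # frozen base weights need no weight grads
    return forward + grad_activations + grad_weights


def adam_state_bytes(n_trainable: float) -> float:
    """fp32 master copy + Adam m + Adam v = 12 bytes per trainable parameter."""
    return 12 * n_trainable


N_BASE, N_LORA = 6.7e9, 40e6
compute_ratio = train_flops_per_token(N_BASE, lora=False) / train_flops_per_token(N_BASE, lora=True)
memory_ratio = adam_state_bytes(N_BASE) / adam_state_bytes(N_LORA)
print(f"full/LoRA compute per token: {compute_ratio:.2f}x, optimizer state: {memory_ratio:.0f}x")
```

Under these assumptions, optimizer state shrinks by two orders of magnitude while per-token compute drops only by a third.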

14.3 Adapter portability

A LoRA adapter is tens of MB on disk. You can email it. You can store 1 000 of them in a directory. You can deploy multi-tenant fine-tunes serving 100+ customers from one base model (§4.8).

A full fine-tune is the size of the model: tens of GB. Each one is its own deployment.

14.4 When full FT actually wins

  • Very large data (>100 k–1 M examples). LoRA's low-rank constraint starts to bite when there's enough signal to overflow the rank-r bottleneck.
  • Substantial behavior shift. New language, new modality, new output structure: these are big distribution moves, and LoRA can be insufficient.
  • Continued pretraining (not really fine-tuning). Domain-adaptive pretraining on hundreds of millions of tokens of new corpus benefits from full updates.

For everything else (task-specific FT, persona FT, format/tone FT, preference learning), LoRA is the answer.

14.5 Decision matrix

Situation | Recommended
--- | ---
<10 k examples, behavior tweak | LoRA
10 k–100 k examples, single domain | LoRA (r=32–64)
Multi-tenant: 1 base + many customers | LoRA (multi-LoRA serving)
100 k–1 M examples, broad shift | Full FT or large QLoRA
Continued pretraining on new corpus | Full FT
Low VRAM (single 24–48 GB GPU), large base | QLoRA
Frontier scale (70B+) | QLoRA, almost always
Preference alignment | DPO with LoRA
Reasoning RL with verifiable rewards | GRPO

15. Evaluation before and after FT

Eval is non-negotiable. It is also the most-skipped step in fine-tuning projects. The headline failure mode of fine-tuning is "the new model behaves better on the dev set and worse in production"-which is exactly what bad eval discipline produces. Cross-ref /08.

15.1 The four eval surfaces

  1. In-distribution held-out test set. Build before you train. Hold out 5–20 % of your fine-tune data, never let it touch a training batch. Report metrics here.
  2. Out-of-distribution eval. General-capability suites: MMLU, GSM8K, HumanEval, plus your own out-of-domain prompts. The question this answers: did we lose general capability? If MMLU drops 5 points, you have a forgetting problem.
  3. Production traffic eval. Ship to a small fraction of users (1 %), compare aggregate metrics (resolution rate, escalation rate, CSAT) against the previous model. The only eval that matters in the end.
  4. Calibration. Did the model become overconfident? Test on prompts where the right answer is "I don't know" or "I need clarification." Fine-tuned models often lose calibration because the FT data contains few abstentions.
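A sketch of expected calibration error (ECE) with equal-width confidence bins: the bin-size-weighted average gap between mean confidence and accuracy. How you extract a confidence from an LLM (a verbalized probability, an answer-token probability) is an assumption left to your eval harness; 10 bins is a conventional default.

```python
from typing import List, Sequence, Tuple


def expected_calibration_error(confidences: Sequence[float],
                               correct: Sequence[bool],
                               n_bins: int = 10) -> float:
    """Σ_b (|b| / n) · |mean confidence in b - accuracy in b| over equal-width bins."""
    n = len(confidences)
    bins: List[List[Tuple[float, bool]]] = [[] for _ in range(n_bins)]
    for c, ok in zip(confidences, correct):
        idx = min(int(c * n_bins), n_bins - 1)  # confidence 1.0 goes in the top bin
        bins[idx].append((c, ok))
    ece = 0.0
    for b in bins:
        if b:
            avg_conf = sum(c for c, _ in b) / len(b)
            accuracy = sum(ok for _, ok in b) / len(b)
            ece += len(b) / n * abs(avg_conf - accuracy)
    return ece
```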

15.2 The pre-FT baseline

Before training, evaluate the base model + your prompt + your RAG on the same eval set. This is your floor. Any FT that doesn't beat the floor by a meaningful margin (≥ 5 % on the metric you care about) is not worth shipping.

This is the most common mistake: teams skip the baseline, train, observe "the model works," ship, and discover later that the prompt-only baseline was already there.

15.3 Eval gates in CI

  • Run eval after every fine-tune.
  • Compare to the production model.
  • Block deploy on regression > X % on any axis (in-domain, OOD, safety, calibration).

This is hygiene. It's also rare in practice. Building it once pays for itself ten times over.
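A minimal CI gate, as a sketch: each metric has a direction and an allowed relative regression versus production, and the function returns the violations so a CI step can fail when the list is non-empty. Metric names and thresholds here are hypothetical placeholders.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Gate:
    metric: str
    higher_is_better: bool
    max_regression: float  # allowed relative change against the production value


def check_gates(candidate: Dict[str, float], production: Dict[str, float],
                gates: List[Gate]) -> List[str]:
    """Return one message per gate the candidate model fails."""
    failures = []
    for g in gates:
        new, old = candidate[g.metric], production[g.metric]
        if g.higher_is_better:
            regressed = new < old * (1 - g.max_regression)
        else:
            regressed = new > old * (1 + g.max_regression)
        if regressed:
            failures.append(f"{g.metric}: {old:.4f} -> {new:.4f}")
    return failures
```

In CI, print the returned list and exit non-zero when it is non-empty, so the failing axis is visible in the log.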

15.4 Things that go wrong in eval

  • Train-eval contamination. The eval set leaks into training data through some path you didn't notice. Always hash and check.
  • Metric overfitting. Optimizing for eval metric without measuring qualities the metric doesn't capture (e.g., toxicity, hallucination).
  • Ignoring OOD. "Our customer-support metric is up 12 %!" while MMLU drops 8 points. The model is now narrower; in production it hits OOD prompts and degrades.
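A sketch of the "hash and check" step: normalize whitespace and case, fingerprint with SHA-256, and report eval items whose fingerprint also appears in the training data. This only catches exact matches after normalization; paraphrased leakage needs fuzzier methods such as n-gram overlap.

```python
import hashlib
import re
from typing import Iterable, List, Set


def normalize(text: str) -> str:
    """Lowercase and collapse whitespace so trivial edits don't hide a match."""
    return re.sub(r"\s+", " ", text.strip().lower())


def fingerprint(text: str) -> str:
    return hashlib.sha256(normalize(text).encode("utf-8")).hexdigest()


def contaminated(train_texts: Iterable[str], eval_texts: Iterable[str]) -> List[str]:
    """Eval items whose normalized text also occurs in the training data."""
    train_hashes: Set[str] = {fingerprint(t) for t in train_texts}
    return [t for t in eval_texts if fingerprint(t) in train_hashes]
```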

16. The end-to-end FT workflow

A practical sequence that compounds rather than thrashes:

Step 1: Define the eval set first

Before any training. Hold out 200–500 examples. Define the metrics. Run the base model + prompt + RAG against it; record the baseline numbers. Cross-ref /08.

Step 2: Baseline with prompt + RAG

Try to solve the problem without FT. Iterate prompt and retrieval until you've extracted what's reasonable. Record the result. This is your baseline.

Step 3: SFT-LoRA on small data

Curate ~1 000 high-quality (prompt, completion) pairs. Run a LoRA SFT (r=16, 1–3 epochs, LR 2e-4). Evaluate. If the lift over baseline is sufficient and OOD eval is intact, ship.

Step 4: Scale data or escalate

If the lift is insufficient:

  • First: scale data to 10 k examples. Most gains come from data volume, not method change.
  • Then: increase r and target modules.
  • Last resort: full FT.

Step 5: Add preference learning (DPO) if behavior alignment matters

Once SFT is good, collect 5 k–30 k preference pairs (chosen, rejected). Run LoRA-DPO at LR 5e-7, β=0.1, 1 epoch. Evaluate on:

  • In-domain held-out preference accuracy.
  • OOD capability suites.
  • Calibration.

Step 6: Ship behind eval gates

Deploy with feature flag, A/B against the previous model. Watch production metrics for at least one full traffic cycle (week or two). Roll forward only if metrics improve and don't regress on safety.

Step 7: Monitor for drift

Re-evaluate periodically (monthly). Fine-tuned models can degrade as the production traffic distribution shifts. When eval drops, retrain on fresh data; don't try to patch.


17. Practical exercises

Exercise 1: LoRA trainable parameters for Llama-7B at r=16, Q+V

Llama-7B parameters: 32 layers, hidden dim d = 4096, attention projection matrices are 4096 × 4096.

Per matrix at r=16:

trainable = r · (d_in + d_out) = 16 · (4096 + 4096) = 131 072 ≈ 0.13 M

Q and V per layer: 2 · 131 072 = 262 144. Across 32 layers: 32 · 262 144 = 8 388 608 ≈ 8.4 M trainable parameters.

Total Llama-7B parameters: ~6.7 B. Trainable fraction: 8.4 M / 6.7 B ≈ 0.13 %.

If you target Q, K, V, O instead of just Q+V: 4 matrices × 32 layers × 131 072 ≈ 16.8 M trainable, still only 0.25 %.

If you target Q, K, V, O plus the three MLP matrices (gate, up, down; 4096 × 11008 each, with down transposed, for Llama-7B): per MLP matrix at r=16, 16 · (4096 + 11008) = 241 664 ≈ 0.24 M. Three of them per layer ≈ 0.72 M. Across 32 layers: ~23.2 M. Total with attention: ~40.0 M. Still just under 0.6 % of the model.
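The counts above as a short script, so you can swap in other shapes or ranks. Shapes are the Llama-7B dimensions stated in the exercise; each LoRA pair adds r·(d_in + d_out) parameters.

```python
from typing import List, Tuple


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """A is r × d_in and B is d_out × r, so one adapter adds r·(d_in + d_out)."""
    return r * (d_in + d_out)


def lora_total(shapes: List[Tuple[int, int]], r: int, n_layers: int) -> int:
    """Trainable parameters for the listed (d_in, d_out) matrices in every layer."""
    return n_layers * sum(lora_params(d_in, d_out, r) for d_in, d_out in shapes)


D, FF, LAYERS, R = 4096, 11008, 32, 16
ATTN = [(D, D)] * 4                  # q, k, v, o
MLP = [(D, FF), (D, FF), (FF, D)]    # gate, up, down

qv = lora_total(ATTN[:2], R, LAYERS)
qkvo = lora_total(ATTN, R, LAYERS)
all_linear = lora_total(ATTN + MLP, R, LAYERS)
for name, n in [("Q+V", qv), ("Q,K,V,O", qkvo), ("all linear", all_linear)]:
    print(f"{name}: {n:,} trainable ({n / 6.7e9:.2%} of ~6.7 B)")
```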

Exercise 2: Derive the DPO loss

Given:

  1. The KL-regularized RL objective with optimal policy π*(y|x) = (1/Z(x)) · π_SFT(y|x) · exp(r(x,y)/β). (See §6.3 for the derivation.)
  2. The Bradley-Terry preference model P(y_w ≻ y_l | x) = σ(r(x, y_w) - r(x, y_l)).

Step A: invert (1) to express r in terms of π* and π_SFT:

log π*(y|x) = log π_SFT(y|x) + r(x,y)/β - log Z(x)
r(x, y) = β · log(π*(y|x)/π_SFT(y|x)) + β · log Z(x)

Step B: substitute into the BT difference:

r(x, y_w) - r(x, y_l)
  = β·log(π*(y_w|x)/π_SFT(y_w|x)) - β·log(π*(y_l|x)/π_SFT(y_l|x))

The β·log Z(x) terms cancel because they are independent of y.

Step C: replace π* with π_θ (trainable) and π_SFT with π_ref (frozen):

P(y_w ≻ y_l | x; θ) = σ(  β · log(π_θ(y_w|x)/π_ref(y_w|x))
                        - β · log(π_θ(y_l|x)/π_ref(y_l|x)) )

Step D: take the negative log-likelihood:

L_DPO(θ) = -E[ log σ(  β · log(π_θ(y_w|x)/π_ref(y_w|x))
                     - β · log(π_θ(y_l|x)/π_ref(y_l|x)) ) ]

Done.
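A numerical sanity check of Steps A and B on a toy three-response space, with made-up values for π_SFT, r, and β (illustrative only): build π* from the closed form, confirm Step A recovers r exactly once β·log Z(x) is added back, and confirm the winner-loser difference needs no Z at all.

```python
import math

pi_sft = [0.5, 0.3, 0.2]   # toy π_SFT(y|x) over three responses
r = [1.0, -0.5, 0.3]       # toy rewards r(x, y)
beta = 0.5

# π*(y|x) = π_SFT(y|x) · exp(r(x,y)/β) / Z(x)
weights = [p * math.exp(ri / beta) for p, ri in zip(pi_sft, r)]
Z = sum(weights)
pi_star = [w / Z for w in weights]

# Step A: r = β·log(π*/π_SFT) + β·log Z
recovered = [beta * math.log(ps / p0) + beta * math.log(Z) for ps, p0 in zip(pi_star, pi_sft)]

# Step B: the BT difference for winner 0 vs loser 1, with no Z term
w, l = 0, 1
diff_without_z = (beta * math.log(pi_star[w] / pi_sft[w])
                  - beta * math.log(pi_star[l] / pi_sft[l]))
```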

Exercise 3: QLoRA memory budget for 70B on a 48 GB GPU at r=64

Inputs:

  • Base model: 70 B parameters.
  • Quantization: NF4 (4 bits per weight), with double quantization.
  • LoRA: r=64, applied to all linear modules.
  • Per-device batch: 1, sequence length 2 048, activation checkpointing on.

Computation:

  1. Quantized base model.
    • Naive: 70 · 10⁹ · 4 bits / 8 = 35 · 10⁹ bytes = 35 GB.
    • Quantization constants with double quantization: ~0.127 bits per parameter (the QLoRA paper's figure), so 70 · 10⁹ · 0.127 / 8 ≈ 1.1 GB.
    • Subtotal: ~36.1 GB.
  2. LoRA adapters.
    • Llama-3 70B has 80 layers, hidden dim 8 192, and MLP intermediate 28 672. Linear modules per layer: Q (8192×8192), K (8192×1024), V (8192×1024), O (8192×8192), gate (8192×28672), up (8192×28672), down (28672×8192). K and V are narrow because of GQA: Llama-3 70B has 8 KV heads of dimension 128, so they project 8 192 → 1 024.
    • Trainable per matrix at r=64:
      • Q: 64·(8192+8192) = 1.05 M
      • K: 64·(8192+1024) = 0.59 M
      • V: 64·(8192+1024) = 0.59 M
      • O: 64·(8192+8192) = 1.05 M
      • gate: 64·(8192+28672) = 2.36 M
      • up: 64·(8192+28672) = 2.36 M
      • down: 64·(28672+8192) = 2.36 M
      • Total per layer: ~10.4 M
    • Across 80 layers: ~830 M trainable parameters.
    • In bf16: 830 M · 2 bytes = 1.66 GB.
    • Subtotal: ~1.7 GB.
  3. Optimizer state (paged AdamW 8-bit).
    • 8-bit Adam stores both moments in 8 bits: ~1 byte per moment × 2 moments × 830 M ≈ 1.7 GB.
    • With paging, peaks may spill to CPU RAM transparently.
    • Subtotal: ~1.7 GB resident, more in CPU RAM.
  4. Activations + gradients (with activation checkpointing).
    • Highly model- and sequence-length-dependent. For 70B, batch 1, seq 2 048, bf16, with AC: roughly 6–10 GB resident peak.
    • Subtotal: ~8 GB.
  5. Slack: KV state during generation in eval, kernel workspaces, fragmentation: ~2 GB.

Sum: 36.1 + 1.7 + 1.7 + 8 + 2 ≈ 49.5 GB.

Conclusion: 70B QLoRA at r=64 slightly overshoots a 48 GB GPU and fits comfortably on an 80 GB GPU. To make 48 GB work in practice, drop to r=32, reduce sequence length to 1 024, accept paging-induced slowdowns, or combine these.
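The same budget as a small calculator, so you can vary r, activation memory, or slack. Assumptions are those of the exercise (Llama-3-70B-like shapes, bf16 adapters, 8-bit Adam at 1 byte per moment), plus ~0.127 bits per parameter for double-quantized NF4 constants as reported in the QLoRA paper; the activation and slack figures are rough estimates passed in as parameters.

```python
def qlora_budget_gb(n_params: float = 70e9, n_layers: int = 80, d: int = 8192,
                    d_kv: int = 1024, d_ff: int = 28672, r: int = 64,
                    activations_gb: float = 8.0, slack_gb: float = 2.0) -> dict:
    """Rough per-GPU memory budget (GB) for QLoRA with LoRA on all linear modules."""
    GB = 1e9
    shapes = [(d, d), (d, d_kv), (d, d_kv), (d, d),   # q, k, v, o
              (d, d_ff), (d, d_ff), (d_ff, d)]        # gate, up, down
    lora_params = n_layers * sum(r * (i + o) for i, o in shapes)
    parts = {
        "nf4_weights": n_params * 4 / 8 / GB,          # 4 bits per weight
        "quant_constants": n_params * 0.127 / 8 / GB,  # double-quantized constants
        "lora_adapters": lora_params * 2 / GB,         # bf16
        "adam_8bit": lora_params * 2 / GB,             # 1 byte x 2 moments
        "activations": activations_gb,                 # rough estimate, passed in
        "slack": slack_gb,
    }
    parts["total"] = sum(parts.values())
    parts["lora_params"] = lora_params                 # added after the sum on purpose
    return parts


budget = qlora_budget_gb()
print(f"r=64: {budget['total']:.1f} GB")
print(f"r=32: {qlora_budget_gb(r=32)['total']:.1f} GB")
```

Halving r halves only the adapter and optimizer terms; the activation term usually moves more with sequence length.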

Exercise 4: Preference data collection guidelines (5-page spec)

Brief outline; flesh into a real spec for your team.

§1. Goals. Collect N preference pairs to fine-tune model M on behavior axis B. Define B precisely (e.g., "tone consistent with brand voice while preserving accuracy").

§2. Sources and stratification.
  • 60 % from production traffic (real user prompts).
  • 30 % from adversarial / edge-case prompts authored by the team.
  • 10 % from synthetic prompts generated by an LLM with seed topics.
  • Stratify by: domain, difficulty, length, sensitive content.

§3. Generation protocol.
  • For each prompt, sample two completions from M.
  • Use temperature 0.7 to ensure diversity but maintain quality.
  • Discard prompts where both responses are clearly bad.
  • Discard prompts where both responses are nearly identical.

§4. Rater pool.
  • 6–10 raters minimum.
  • 2 senior raters as gold standard.
  • Onboarding: 100 calibration pairs; raters must achieve κ ≥ 0.6 vs gold.
  • Re-calibration weekly with 20 fresh gold pairs.

§5. Annotation interface.
  • Show prompt + two completions in random order.
  • Rater selects "A better," "B better," "equal," or "both bad."
  • Optional comment field for hard cases.
  • Discard "equal" and "both bad" from training data.

§6. Quality controls.
  • 10 % of pairs are gold-standard, double-annotated by senior raters. Cohen's κ is measured weekly per rater; raters with κ < 0.5 are retrained or removed.
  • 5 % of pairs are duplicates surfaced after a 1-week gap; raters who flip on duplicates are flagged.
  • LLM-as-judge runs in parallel for triage; high-disagreement pairs surface to senior review.

§7. Stratification check.
  • Audit the final dataset distribution by domain / difficulty bins.
  • Reject and resample if any bin is <5 % or >40 % of the total.

§8. Privacy and safety.
  • Strip PII before raters see prompts.
  • Skip pairs where both responses are policy-violating.
  • Document and version the rater guidelines.

§9. Versioning and provenance.
  • Each pair carries: source prompt id, model checkpoint, sampling config, rater id, timestamp, agreement scores.
  • The dataset is versioned; every fine-tune cites the dataset hash.
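Sketches of two mechanical checks from the spec: the §7 bin audit (including expected bins that ended up empty) and the §3 near-identical filter. The bin labels and the 0.95 similarity threshold are assumptions; `difflib` character similarity is a cheap stand-in for whatever similarity measure your team adopts.

```python
from collections import Counter
from difflib import SequenceMatcher
from typing import Iterable, List, Sequence


def out_of_range_bins(labels: Sequence[str], expected_bins: Iterable[str],
                      lo: float = 0.05, hi: float = 0.40) -> List[str]:
    """Bins whose share of the dataset falls outside [lo, hi]; empty bins count as 0."""
    counts = Counter(labels)
    n = len(labels)
    bins = set(expected_bins) | set(counts)
    return sorted(b for b in bins if not lo <= counts.get(b, 0) / n <= hi)


def near_identical(a: str, b: str, threshold: float = 0.95) -> bool:
    """Flag completion pairs too similar to carry a preference signal."""
    return SequenceMatcher(None, a, b).ratio() >= threshold
```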

Exercise 5: Eval matrix for a customer-support fine-tune

Eval surfaces, with target metrics:

Surface | Test set | Metrics | Pass criteria
--- | --- | --- | ---
In-domain | 500 held-out support tickets | Resolution rate, factual accuracy, brand voice score | ≥ baseline + 8 %
Out-of-domain | MMLU 1k, GSM8K 200, HumanEval 164 | Standard scores | ≤ 2 pp regression vs base instruct
Refusals | 200 prompts known to require refusal | Refusal rate, refusal quality (LLM judge) | ≥ 95 % refusal rate
Hallucinations | 200 prompts with known answers | Hallucination rate (human-judged) | ≤ 3 %
Calibration | 100 ambiguous prompts | "I don't know" rate, expected calibration error | ECE ≤ baseline
Adversarial / safety | 300 jailbreak-style prompts | Safety violation rate | ≤ 0.5 %
Long-context | 50 long-thread tickets | Resolution rate | ≥ baseline
Multi-turn | 100 multi-turn conversations | Turn-level coherence (LLM judge) | ≥ 4.0 / 5
Production A/B | 1 % live traffic | CSAT, escalation rate, AHT | CSAT ≥ control, escalation ≤ control

Run all but the A/B in CI. A/B runs only after CI passes.

Exercise 6: GRPO step on a 4-completion group, rewards [0.8, 0.5, 0.3, 0.7]

Group rewards: r = [0.8, 0.5, 0.3, 0.7].

Mean: (0.8 + 0.5 + 0.3 + 0.7) / 4 = 2.3 / 4 = 0.575.

Variance:

((0.8-0.575)² + (0.5-0.575)² + (0.3-0.575)² + (0.7-0.575)²) / 4
= (0.0506 + 0.0056 + 0.0756 + 0.0156) / 4
= 0.1475 / 4
= 0.0369

Std: √0.0369 ≈ 0.192.

Group-relative advantages:

  • A_1 = (0.8 - 0.575) / 0.192 = 0.225 / 0.192 ≈ +1.17
  • A_2 = (0.5 - 0.575) / 0.192 = -0.075 / 0.192 ≈ -0.39
  • A_3 = (0.3 - 0.575) / 0.192 = -0.275 / 0.192 ≈ -1.43
  • A_4 = (0.7 - 0.575) / 0.192 = 0.125 / 0.192 ≈ +0.65

Sanity: positive advantages for above-average completions (1, 4), negative for below (2, 3). Sum of advantages is zero by construction (mean-centered). Magnitudes scaled by the within-group spread.

Per-completion update (sketch, ignoring KL):

  • Completion 1 gets pushed up with strength 1.17.
  • Completion 2 gets pushed down with strength 0.39.
  • Completion 3 gets pushed down with strength 1.43.
  • Completion 4 gets pushed up with strength 0.65.

The PPO clip then bounds each per-token ratio update. The KL term β · KL(π_θ || π_ref) is added to the loss separately to keep the policy close to the reference.
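The worked example as code. It uses the population standard deviation (dividing by G), matching the arithmetic above; some implementations divide by G - 1 instead, which only rescales the advantages. The small `eps` is an assumed guard for groups where every reward is equal.

```python
import math
from typing import List, Sequence


def group_relative_advantages(rewards: Sequence[float], eps: float = 1e-8) -> List[float]:
    """A_i = (r_i - mean) / std over one group of G completions."""
    g = len(rewards)
    mean = sum(rewards) / g
    std = math.sqrt(sum((x - mean) ** 2 for x in rewards) / g)  # population std
    return [(x - mean) / (std + eps) for x in rewards]


adv = group_relative_advantages([0.8, 0.5, 0.3, 0.7])
print([round(a, 2) for a in adv])  # [1.17, -0.39, -1.43, 0.65]
```

If every completion in a group gets the same reward, all advantages are zero and the group contributes no policy-gradient signal, only the KL term.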


Appendix: End-to-end pseudocode skeleton

The following is the complete shape of an SFT → DPO pipeline using TRL on a single 8×H100 node. It is meant to be readable, not runnable; specific imports and config flags will drift across TRL versions.

# ---- 0. shared config ----
BASE = "meta-llama/Llama-3.1-8B-Instruct"
SFT_OUT = "out/sft"
DPO_OUT = "out/dpo"

# ---- 1. SFT-LoRA ----
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

tok = AutoTokenizer.from_pretrained(BASE)
model = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype="bfloat16")
model = get_peft_model(model, LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q_proj","k_proj","v_proj","o_proj",
                    "gate_proj","up_proj","down_proj"],
    task_type="CAUSAL_LM"))

sft_data = load_dataset("json", data_files="data/sft.jsonl", split="train")

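sft_trainer = SFTTrainer(
    model=model, processing_class=tok,   # older TRL releases: tokenizer=tok
    train_dataset=sft_data,
    args=SFTConfig(
        output_dir=SFT_OUT, num_train_epochs=2,
        per_device_train_batch_size=4, gradient_accumulation_steps=8,
        learning_rate=2e-4, warmup_ratio=0.05,
        bf16=True, packing=True,
        max_seq_length=4096,             # newer TRL releases name this max_length
    ),
)
sft_trainer.train()

# Fold the SFT adapter into the base weights: the merged checkpoint *is* the
# SFT model, which DPO must use as π_ref (not the pre-SFT base).
SFT_MERGED = SFT_OUT + "-merged"
merged = sft_trainer.model.merge_and_unload()
merged.save_pretrained(SFT_MERGED)
tok.save_pretrained(SFT_MERGED)

# ---- 2. eval after SFT ----
# (run held-out eval, OOD MMLU/GSM8K, calibration; gate on metrics)

# ---- 3. DPO-LoRA ----
from trl import DPOTrainer, DPOConfig

# A fresh LoRA on the merged SFT weights. With peft_config= and
# ref_model=None, TRL uses the adapter-disabled model (the SFT model) as π_ref.
policy = AutoModelForCausalLM.from_pretrained(SFT_MERGED,
                                              torch_dtype="bfloat16")
prefs = load_dataset("json", data_files="data/prefs.jsonl", split="train")

DPOTrainer(
    model=policy, ref_model=None, processing_class=tok,
    train_dataset=prefs,
    peft_config=LoraConfig(
        r=16, lora_alpha=32, lora_dropout=0.05,
        target_modules=["q_proj","k_proj","v_proj","o_proj",
                        "gate_proj","up_proj","down_proj"],
        task_type="CAUSAL_LM"),
    args=DPOConfig(
        output_dir=DPO_OUT, num_train_epochs=1,
        per_device_train_batch_size=2, gradient_accumulation_steps=8,
        learning_rate=5e-7, beta=0.1,
        warmup_ratio=0.05, bf16=True,
    ),
).train()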

# ---- 4. eval after DPO ----
# Repeat held-out eval, OOD eval, preference accuracy, calibration.
# Ship behind A/B if all gates pass.

The pipeline embodies the chapter's whole argument: a small LoRA SFT on curated demonstrations, eval gates, then a DPO pass on preference pairs, then more eval, then a careful production rollout. No PPO, no critic, no separately trained reward model. For many tasks the result is competitive with the classical RLHF stack at a fraction of the operational cost, which is why this recipe has become a common default.


Citations and further reading

  • Hu, E. J. et al. (2021). LoRA: Low-Rank Adaptation of Large Language Models. arXiv:2106.09685.
  • Dettmers, T. et al. (2023). QLoRA: Efficient Finetuning of Quantized LLMs. arXiv:2305.14314.
  • Rafailov, R. et al. (2023). Direct Preference Optimization: Your Language Model is Secretly a Reward Model. arXiv:2305.18290.
  • Shao, Z. et al. (2024). DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models (introduces GRPO). arXiv:2402.03300.
  • Bai, Y. et al. (2022). Constitutional AI: Harmlessness from AI Feedback. arXiv:2212.08073.
  • Bradley, R. A. and Terry, M. E. (1952). Rank Analysis of Incomplete Block Designs: I. The Method of Paired Comparisons. Biometrika.
  • Schulman, J. et al. (2017). Proximal Policy Optimization Algorithms. arXiv:1707.06347.
  • Lightman, H. et al. (2023). Let's Verify Step by Step. arXiv:2305.20050.
  • Ouyang, L. et al. (2022). Training language models to follow instructions with human feedback (InstructGPT). arXiv:2203.02155.

Cross-references inside this curriculum:

  • Distributed training (FSDP, ZeRO, tensor/pipeline parallelism): AI_SYSTEMS_PLAN/DEEP_DIVES/06.
  • Mixed precision, FP8, numerics: AI_SYSTEMS_PLAN/DEEP_DIVES/11.
  • Eval discipline: AI_SYSTEMS_PLAN/DEEP_DIVES/08.

End of Deep Dive 10.
