
Deep Dive 11-Numerics and Mixed Precision

"Floating-point arithmetic is the silent assassin of deep learning. Most training divergences are not bugs in the model; they are bugs in the number system."

A neural network is, at the end of the day, a chain of arithmetic operations executed on finite-precision hardware. The mathematics on the whiteboard treats real numbers; the GPU treats sequences of bits with explicit rounding rules. Whether your run converges, plateaus, NaNs, or silently bias-shifts is determined by the gap between those two worlds. This chapter is the reference for closing that gap.

We will derive-not just state-IEEE-754 floating point, walk through every format relevant to ML (FP64, FP32, TF32, FP16, BF16, FP8 E4M3 / E5M2, FP4), explain why each operation in a transformer needs the precision it does, write out the loss-scaling and FP8 delayed-scaling algorithms in pseudocode, and finish with worked exercises that you should be able to do on paper.

Read this once carefully. Then re-read sections 4, 7, and 11 the next time a training run NaNs.


Table of contents

  1. IEEE-754 in 30 minutes
  2. The ML floating-point zoo
  3. Operation-by-operation precision requirements
  4. The standard mixed-precision recipe
  5. Loss scaling, derived
  6. Why BF16 is different (and what it costs)
  7. FP8 training in detail
  8. TF32: the silent precision drop
  9. Adam + low precision pitfalls
  10. Catastrophic cancellation in reductions
  11. Numerical stability tricks in transformers
  12. Detecting and handling NaN
  13. Determinism
  14. Practical exercises

1. IEEE-754 in 30 minutes

1.1 Why we cannot use real numbers

A real number x ∈ ℝ requires, in general, infinite information to represent. Computers store fixed-width approximations. The IEEE-754 standard (1985, revised 2008 and 2019) defines a family of binary floating-point formats and the rounding rules for arithmetic on them.

A binary floating-point number is a triple (s, e, m) interpreted as

x = (-1)^s × 2^E × M

where: - s is one sign bit (0 = positive, 1 = negative), - e is the biased exponent stored in n_exp bits, - m is the mantissa (also called significand fraction) stored in n_man bits.

The "biased" part means that the exponent field stores e = E + bias, where bias = 2^(n_exp - 1) - 1. We do this so that the exponent field can represent both negative and positive E while remaining an unsigned integer-e ranges from 0 to 2^n_exp - 1, and E ranges from 1 - bias to bias.

The mantissa stores only the fractional part. There is an implicit leading 1 for normal numbers:

M = 1.m_{n_man-1} m_{n_man-2} ... m_1 m_0   (binary)
  = 1 + sum_{i=0}^{n_man-1} m_i × 2^(i - n_man)

So for FP32 with 23 mantissa bits, M lies in [1, 2) with a granularity of 2^-23 ≈ 1.19e-7.

1.2 The FP32 example

FP32 (binary32): 1 + 8 + 23 = 32 bits.

  • bias = 2^7 - 1 = 127.
  • Normal e range: 1 to 254. So E ranges from -126 to +127.
  • Smallest normal: 1.0 × 2^-126 ≈ 1.175e-38.
  • Largest finite: (2 - 2^-23) × 2^127 ≈ 3.403e+38.
  • Machine epsilon ε = 2^-23 ≈ 1.19e-7, the gap between 1 and the next representable number.

FP32 represents about 7 decimal significant digits because log10(2^23) ≈ 6.92.
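
To make the layout concrete, here is a small sketch in plain Python (using only the standard struct module) that unpacks a value into its FP32 sign, biased-exponent, and fraction fields and reassembles it from them:

```python
import struct

def decode_fp32(x: float):
    """Split a value into its FP32 (sign, biased exponent, fraction) fields."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]   # the raw 32-bit pattern
    s = bits >> 31
    e = (bits >> 23) & 0xFF        # biased exponent, 8 bits
    m = bits & 0x7FFFFF            # fraction, 23 bits
    E = e - 127                    # unbiased exponent (normal numbers only)
    M = 1 + m / 2**23              # significand with the implicit leading 1
    return s, e, m, (-1) ** s * M * 2.0 ** E

print(decode_fp32(1.0))     # (0, 127, 0, 1.0)
print(decode_fp32(-0.75))   # (1, 126, 4194304, -0.75)
```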

1.3 Subnormals (denormals)

The exponent code e = 0 is special: it represents subnormal (denormal) numbers, which fill the gap between zero and the smallest normal:

x_subnormal = (-1)^s × 2^(1 - bias) × (0.m_{n_man-1} ... m_0)_2

Note the implicit leading bit becomes 0 instead of 1, and the exponent is fixed at 1 - bias rather than 0 - bias; that off-by-one keeps the transition from the largest subnormal to the smallest normal continuous. For FP32:

  • Smallest positive subnormal: 2^-23 × 2^-126 = 2^-149 ≈ 1.4e-45.
  • Largest subnormal: (1 - 2^-23) × 2^-126, just under the smallest normal.

Subnormals enable gradual underflow: as a value shrinks below the smallest normal, it loses precision bit by bit but does not abruptly become zero. Some hardware flushes subnormals to zero (FTZ/DAZ flags) for performance-this matters in DSPs and is occasionally encountered on GPUs.

1.4 Special values

The exponent code e = 2^n_exp - 1 is also special:

| Exponent field | Mantissa | Sign | Meaning |
|---|---|---|---|
| e = 0 | m = 0 | s = 0 | +0 |
| e = 0 | m = 0 | s = 1 | -0 |
| e = 0 | m ≠ 0 | any | subnormal |
| 1 ≤ e ≤ 2^n_exp - 2 | any | any | normal |
| e = all-ones | m = 0 | any | ±inf |
| e = all-ones | m ≠ 0 | any | NaN |

NaN comes in two flavours: quiet NaN (qNaN) and signaling NaN (sNaN), distinguished by the high bit of the mantissa. ML rarely cares; both propagate through arithmetic and we generally trap on either.

**+0 versus -0:** they compare equal, but 1 / (+0) = +inf while 1 / (-0) = -inf. This is a common source of subtle bugs in custom ops.

1.5 Rounding modes

IEEE-754 defines five modes; only two matter day to day:

  1. Round to nearest, ties to even (RNE)-default. If a real number falls exactly between two representable floats, pick the one whose last mantissa bit is 0. RNE is unbiased: averaging many rounded results does not introduce a systematic drift.
  2. Round toward zero (truncation)-used in some quantization paths.
  3. Round toward +inf, round toward -inf-interval arithmetic.
  4. Round to nearest, ties away from zero-common in financial code, rare in ML.

Stochastic rounding (section 9.2) is not in IEEE-754 but is increasingly important in low-bit training.

The fundamental error bound: for any operation op ∈ {+, -, ×, /, sqrt}, the IEEE-754 result satisfies

fl(a op b) = (a op b) × (1 + δ),    |δ| ≤ ε / 2

with ε the machine epsilon and the relative error bounded by the unit roundoff u = ε / 2.

1.6 Sources of error in + and ×

Multiplication. fl(a × b) = ab × (1 + δ) with |δ| ≤ u. The error is always relative to the magnitude of the result, so multiplication is benign.

Addition. fl(a + b) = (a + b) × (1 + δ) likewise, but the worst case happens when a ≈ -b: the relative error stays small, but the absolute error of the result is large compared to the (small) result. This is catastrophic cancellation: subtracting two nearly equal numbers exposes their roundoff.

Example in FP32:

a = 1.0000001
b = 1.0000000
a - b = 1e-7  (in real arithmetic)
        but  fl(a) = 1.0000001 ± 6e-8
              fl(b) = 1.0       ± 6e-8
              fl(a - b) = 1e-7 with absolute error ~ 1.2e-7

The result is dominated by roundoff. Section 10 returns to this.

1.7 Compound operations: FMA

Most modern hardware exposes a fused multiply-add (FMA): fl(a × b + c) computed with one rounding instead of two. FMA is more accurate than a separate multiply followed by an add, and most tensor-core math is built on it. Remember also that floating-point addition is not associative: (a + b) + c ≠ a + (b + c) in general, so the order in which a kernel or compiler schedules its adds and FMAs changes results bit-for-bit.
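
A two-line illustration of the non-associativity (here in ordinary Python FP64, but the same effect appears in every IEEE format):

```python
a, b, c = 1e16, -1e16, 1.0
print((a + b) + c)   # 1.0 -- the cancellation happens first, then the 1.0 survives
print(a + (b + c))   # 0.0 -- the 1.0 is absorbed into 1e16 before the cancellation
```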


2. The ML floating-point zoo

We now place every format used in practice next to the others.

2.1 Bit layouts

            sign  exp  man   bias   smallest normal       max finite
FP64        1     11   52    1023   ~2.225e-308           ~1.798e+308
FP32        1      8   23    127    ~1.175e-38            ~3.403e+38
TF32        1      8   10    127    ~1.175e-38            ~3.403e+38   (compute, not storage)
FP16        1      5   10     15    ~6.104e-5             ~6.550e+4
BF16        1      8    7    127    ~1.175e-38            ~3.389e+38
FP8 E4M3    1      4    3      7    ~2^-6 = 1.56e-2       448  (or 240 if inf/NaN reserved)
FP8 E5M2    1      5    2     15    ~6.104e-5             ~5.734e+4
FP4 E2M1    1      2    1      1    1.0 (min subnormal 0.5)   6

A few subtleties:

  • TF32 is not a storage format. Tensors are stored as FP32 in memory; the tensor core internally rounds operands to a 1+8+10 layout (FP32 range, FP16 precision) before doing the matmul, then accumulates in FP32. From the user's API, the tensors look like FP32-only the computation is degraded. Section 8.
  • FP8 E4M3 in the OFP8 standard (Micikevicius et al., 2022; adopted by NVIDIA's TransformerEngine, Intel, AMD, ARM) does not support infinities. The e=all-ones, m=all-ones codepoint is reused as a finite value, raising max from 240 to 448. Some implementations instead reserve that codepoint for inf/NaN, giving max = 240. Both variants exist; check your library.
  • FP8 E5M2 is fully IEEE-754-shaped: it has inf and NaN. Max ≈ 57344, smallest normal ≈ 6.10e-5-the same range as FP16.
  • FP4 E2M1 has only 16 codepoints total (including sign and zero). Practical FP4 training requires per-block scaling (e.g., MXFP4 with 32-element blocks) to be viable at all.

2.2 Range and precision, side-by-side

| Format | Decimal digits | Dynamic range (orders of magnitude) | Use |
|---|---|---|---|
| FP64 | ~15–16 | ~600 | Scientific computing; rarely ML |
| FP32 | ~7 | ~76 | Master weights, optimizer state |
| TF32 | ~4 | ~76 | Hidden tensor-core compute |
| FP16 | ~3–4 | ~10 | Mixed-precision compute (legacy) |
| BF16 | ~2–3 | ~76 | Mixed-precision compute (default) |
| FP8 E4M3 | ~1–2 | ~5 | Activations, weights |
| FP8 E5M2 | ~1 | ~10 | Gradients |
| FP4 | <1 | ~2 | Inference; experimental training |

Decimal digits = log10(2^(n_man + 1)) (the +1 from the implicit leading bit). Dynamic range = log10(max / smallest_normal).
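
Both columns follow directly from the bit counts; a small sketch of the arithmetic (the saturating E4M3 max of 448 does not follow the generic IEEE-style formula, so that row is left out here):

```python
import math

def format_stats(n_exp: int, n_man: int):
    bias = 2 ** (n_exp - 1) - 1
    digits = math.log10(2 ** (n_man + 1))            # decimal digits incl. implicit bit
    max_finite = (2 - 2 ** -n_man) * 2.0 ** bias     # IEEE-style layout
    smallest_normal = 2.0 ** (1 - bias)
    dyn_range = math.log10(max_finite / smallest_normal)
    return digits, dyn_range

for name, ne, nm in [("FP32", 8, 23), ("FP16", 5, 10), ("BF16", 8, 7), ("FP8 E5M2", 5, 2)]:
    d, r = format_stats(ne, nm)
    print(f"{name:9s}  ~{d:.1f} digits  ~{r:.0f} orders of magnitude")
```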

The two axes-precision (mantissa bits) and range (exponent bits)-trade off independently. FP16 and BF16 are both 16-bit, but FP16 spends 5 exponent bits and 10 mantissa bits, while BF16 spends 8 and 7. BF16 gives up half the precision in exchange for the full FP32 range. For deep learning, where activations and gradients can span many orders of magnitude, range matters more than precision. This is the single most important fact in ML numerics.

2.3 Memory cost

Per parameter, with Adam optimizer states:

| Configuration | Weights | Grads | m | v | Total (bytes/param) |
|---|---|---|---|---|---|
| Pure FP32 | 4 | 4 | 4 | 4 | 16 B |
| FP16/BF16 + FP32 master | 2 + 4 | 2 | 4 | 4 | 16 B |
| FP8 + FP32 master | 1 + 4 | 1 | 4 | 4 | 14 B |
| FP8 + FP16 master | 1 + 2 | 1 | 2 | 2 | 8 B |

Mixed precision saves memory for activations (during forward we hold half-precision tensors), not for parameters-until you're willing to sacrifice the FP32 master weights, which most production runs are not.


3. Operation-by-operation precision requirements

Different operations have different sensitivity to precision. Get this wrong and you waste bits where they don't matter while starving operations that do.

3.1 Matrix multiply

For C = A × B where A ∈ ℝ^{m×k}, B ∈ ℝ^{k×n}:

C_{ij} = sum_{p=1}^{k} A_{ip} × B_{pj}

With low-precision inputs and a k-element accumulation, the error is dominated by the **accumulator**, not the multiplicands. Each multiplication contributes one rounding (relative error ≤ u); the sum then accumulates k of these. For naive sequential summation, the worst-case error is O(k × u × max|x|).

Tensor cores always accumulate in higher precision than they multiply:

| Input | Accumulator |
|---|---|
| FP16 / BF16 | FP32 |
| FP8 E4M3 / E5M2 | FP32 |
| TF32 | FP32 |
| INT8 | INT32 |

You can ask for FP16 accumulation on some old hardware; you should not. For k ≈ 4096 (typical hidden dim), FP16 accumulation can lose 3 decimal digits; FP32 accumulation keeps the error around the unit roundoff.

The matmul lesson: inputs can be cheap; the accumulator must be expensive.
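
A rough way to see the accumulator effect on CPU, using NumPy FP16 scalars as a stand-in for an FP16 hardware accumulator (the dimensions and random data here are arbitrary):

```python
import numpy as np

k = 4096
rng = np.random.default_rng(0)
a = rng.standard_normal(k).astype(np.float16)
b = rng.standard_normal(k).astype(np.float16)

acc16 = np.float16(0.0)                 # FP16 multiply, FP16 accumulate
for x, y in zip(a, b):
    acc16 = np.float16(acc16 + x * y)

acc32 = np.dot(a.astype(np.float32), b.astype(np.float32))   # FP32 accumulate
print(float(acc16), float(acc32), abs(float(acc16) - acc32))
```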

3.2 Reductions (sum, mean, norm)

Reductions are more sensitive than matmuls because:

  1. The number of terms N (e.g., the feature dimension in LayerNorm) can be larger than the inner dim of typical matmuls.
  2. Naive sequential summation has O(N) worst-case error growth.
  3. Layer norm / RMS norm involves both a sum (mean) and a sum of squares (variance)-the latter is even more sensitive.

Practical rule: reductions always promote to FP32, regardless of input dtype. PyTorch and TF do this by default for mean, sum, var, norm. Check your custom kernels.

We dissect reductions in section 10.

3.3 Softmax

The softmax s_i = exp(x_i) / sum_j exp(x_j) has two failure modes in low precision:

  • Overflow: if any x_i > log(max_finite), then exp(x_i) = inf. For FP16, log(65504) ≈ 11.09. Logits routinely exceed this in attention.
  • Underflow: if all exp(x_i) are below the smallest normal, the denominator is zero. For FP16, smallest normal is ~6e-5, so any x_i < log(6e-5) ≈ -9.7 underflows.

Standard fix: subtract the max,

m = max_i x_i
s_i = exp(x_i - m) / sum_j exp(x_j - m)

Now the largest exponent argument is 0, so exp(...) ≤ 1. The subtraction is exact as long as x_i and m are close in magnitude (which they are after normalization).

This works mathematically because exp(x_i) / sum_j exp(x_j) = exp(x_i - m) / sum_j exp(x_j - m): we are just multiplying numerator and denominator by exp(-m).

Online softmax (FlashAttention) uses an incremental version of this trick; we cover it in the attention deep dive.

3.4 Gradient accumulation

When training with micro-batching, you accumulate gradients across micro-batches before stepping the optimizer:

grad_buf += grad_micro

If grad_buf is in BF16 and grad_micro is small, the accumulation underflows. Always accumulate in higher precision than the gradients themselves. PyTorch's GradScaler and DeepSpeed do this automatically; if you write your own pipeline, you must do it explicitly.

Concretely: gradients computed in BF16 should accumulate into an FP32 buffer. Gradients computed in FP8 should accumulate into FP16 or FP32. Otherwise the small contributions are lost: large_buf + tiny = large_buf whenever tiny < ε × large_buf.
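
A minimal, self-contained sketch of the FP32 accumulation buffer; the toy torch.nn.Linear and the squared-output loss below are placeholders, not part of any real pipeline:

```python
import torch

model = torch.nn.Linear(8, 4, dtype=torch.bfloat16)                    # stand-in model
micro_batches = [torch.randn(16, 8, dtype=torch.bfloat16) for _ in range(4)]

# One FP32 buffer per parameter, so tiny BF16 contributions are not lost.
fp32_grad_bufs = {name: torch.zeros_like(p, dtype=torch.float32)
                  for name, p in model.named_parameters()}

for xb in micro_batches:
    loss = model(xb).float().pow(2).mean() / len(micro_batches)        # toy loss, scaled for the mean
    loss.backward()                                                    # produces BF16 parameter grads
    for name, p in model.named_parameters():
        if p.grad is not None:
            fp32_grad_bufs[name] += p.grad.float()                     # upcast, then accumulate
            p.grad = None                                              # free the BF16 gradient

# fp32_grad_bufs now holds full-precision accumulated gradients for the optimizer step.
print({k: v.norm().item() for k, v in fp32_grad_bufs.items()})
```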


4. The standard mixed-precision recipe

The Micikevicius et al. (2018) recipe is the foundation. Every modern training stack-Apex, PyTorch AMP, DeepSpeed, Megatron, JAX/Flax-implements it.

4.1 The four invariants

  1. Master weights in FP32. The "real" parameters live in FP32. We make a cast copy in low precision (FP16 or BF16) for the forward and backward passes.
  2. Forward and backward in low precision. Activations, weight matmuls, attention, layer norms (with FP32 accumulation), all run in FP16 or BF16.
  3. Gradients in low precision are produced by autograd, then immediately upcast to FP32 before going into the optimizer.
  4. Optimizer states (m, v for Adam) in FP32. The optimizer step is computed entirely in FP32; only after stepping do we re-cast the master weights down to low precision for the next forward.

If you violate any of these, expect divergence on long runs.

4.2 The full step

# Setup
master_weights_fp32 = init_weights()
adam_m_fp32 = zeros_like(master_weights_fp32)
adam_v_fp32 = zeros_like(master_weights_fp32)
loss_scale = 2**15  # FP16 only; for BF16 set to 1
successful_steps = 0

for batch in dataloader:
    # 1. Cast master to low precision for compute
    weights_lp = master_weights_fp32.to(low_precision_dtype)

    # 2. Forward in low precision
    logits = model(batch.x, weights_lp)
    loss   = cross_entropy(logits, batch.y)

    # 3. Scale loss (FP16 only)
    loss_scaled = loss * loss_scale

    # 4. Backward in low precision; produces grads in low precision
    grads_lp = backward(loss_scaled, weights_lp)

    # 5. Upcast and unscale
    grads_fp32 = grads_lp.to(fp32) / loss_scale

    # 6. NaN/inf check (FP16 only); skip step if found
    if any_nan_or_inf(grads_fp32):
        loss_scale /= 2
        continue

    # 7. Optimizer step in FP32
    adam_m_fp32 = beta1 * adam_m_fp32 + (1-beta1) * grads_fp32
    adam_v_fp32 = beta2 * adam_v_fp32 + (1-beta2) * grads_fp32 ** 2
    master_weights_fp32 -= lr * adam_m_fp32 / (sqrt(adam_v_fp32) + eps)

    # 8. Optional: increase loss scale after a streak of clean steps
    successful_steps += 1
    if successful_steps >= 2000:
        loss_scale *= 2
        successful_steps = 0

The loss_scale machinery is unique to FP16 (section 5). For BF16, you can set it to 1 and remove the NaN-check / dynamic-update branches, but you should still upcast grads to FP32 before the optimizer step.

4.3 What about activation memory?

In the forward pass, we save activations for the backward pass. These should be stored in the same precision they were computed in (FP16/BF16/FP8)-that's where the memory saving comes from. The FP32 master weights add only 4 × N_params bytes, while activation memory grows with batch_size × seq_len × hidden_dim × num_layers, so for big models the half-precision activations dominate.

This is also why "FP32 training" without master weights wastes memory: FP32 activations are 2× the BF16 ones, and the activations are usually the bigger chunk.


5. Loss scaling, derived

5.1 Why FP16 needs it

Gradients in deep networks at end of training can be very small. The smallest positive normal FP16 is 2^-14 ≈ 6.10e-5; with subnormals you can get down to 2^-24 ≈ 5.96e-8 but with shrinking precision. Anything smaller silently becomes zero.

In practice, late-training gradients for many parameters cluster around 1e-7 to 1e-9. They underflow FP16. The optimizer sees zero, the parameter does not update, the network plateaus.

5.2 The trick

Multiply the loss by a large constant S before backward:

loss_scaled = S × loss

By the chain rule, every gradient is multiplied by S:

∂(S × loss) / ∂w = S × ∂loss / ∂w

Now if the original gradient was 1e-9, the scaled gradient is S × 1e-9. With S = 2^15 = 32768, the scaled value is ~3.3e-5, which is comfortably representable.

After the backward pass, before the optimizer step, upcast to FP32 and divide by S to restore the true gradient magnitude.
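
A two-line check of the underflow-and-rescue, assuming PyTorch:

```python
import torch

g = torch.tensor(1e-9)
print(g.to(torch.float16))             # tensor(0., dtype=torch.float16): the gradient underflows
print((g * 2**15).to(torch.float16))   # ~3.28e-05: representable (as a subnormal) again
```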

5.3 Static loss scaling

Pick S once and leave it. Common choices: 2^7, 2^10, 2^15. Too small: gradients still underflow. Too large: gradients overflow to inf, kill the step.

Static scaling is simple but fragile. A bad batch can blow it up; a phase change in training can render it suboptimal.

5.4 Dynamic loss scaling

The standard algorithm (used by NVIDIA Apex AMP, PyTorch torch.amp.GradScaler):

S          := 2^15        # initial loss scale
streak     := 0
patience   := 2000        # successful steps before doubling
backoff    := 0.5         # multiplier on overflow (halve)
growth     := 2.0         # multiplier after streak (double)
S_max      := 2^24
S_min      := 1.0

for each step:
    grads_lp := backward(S * loss, weights_lp)
    grads_fp := upcast(grads_lp) / S

    if any_nan_or_inf(grads_fp):
        S       := max(S * backoff, S_min)
        streak  := 0
        skip optimizer step    # crucial: do NOT update weights
        continue

    optimizer.step(grads_fp)
    streak += 1
    if streak >= patience:
        S       := min(S * growth, S_max)
        streak  := 0

Two important details often missed:

  1. On overflow, you must skip the step, not clip and proceed. The corrupted gradients have no useful information.
  2. Check for NaN/inf on the scaled gradients before unscaling, or equivalently on the unscaled-the operation is just a divide. Fused implementations check during the unscale.

5.5 Choosing S_max and patience

S_max = 2^24 is an upper cap: a gradient of magnitude ~1 multiplied by 2^24 ≈ 1.7e7 is already far beyond the FP16 max (6.5e4), so there is no benefit to scaling higher. In practice runs settle to S between 2^10 and 2^16, occasionally pushing higher.

patience = 2000 is an empirical choice from Apex. Lower (say 100) and you double too aggressively, causing frequent overflow-rollback cycles. Higher (10000) and you under-utilize the FP16 range during long calm phases.

5.6 Why this is invisible to the user (mostly)

PyTorch wraps it in GradScaler:

scaler = torch.cuda.amp.GradScaler()
for batch in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.amp.autocast(dtype=torch.float16):
        loss = model(batch).loss
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)         # divides grads by S in-place
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)  # now safe
    scaler.step(optimizer)              # skips if inf/nan detected
    scaler.update()                     # adjusts S

Two lines you must not skip: unscale_ before clip_grad_norm_ (otherwise you clip the scaled grads, with the wrong norm), and scaler.update() after every step.


6. Why BF16 is different (and what it costs)

6.1 The free lunch (almost)

BF16 has 8 exponent bits, the same as FP32. Its dynamic range matches FP32: roughly 1.18e-38 to 3.4e+38. Gradients do not underflow. Loss scaling is unnecessary.

This is why BF16 has eaten the world. Hopper, TPU v3/v4/v5, Ampere, and AMD MI250+ all natively support it. Modern training defaults are:

  • Master weights: FP32
  • Compute: BF16
  • Optimizer states: FP32
  • No loss scaling, no NaN-rollback, no dynamic scale tuning.

6.2 The catch

BF16 has 7 mantissa bits, vs 10 for FP16 and 23 for FP32. That's log10(2^8) ≈ 2.4 decimal digits of precision (counting the implicit leading bit). Two consequences:

  1. Catastrophic cancellation is more likely in any subtraction or near-cancelling sum.
  2. Accumulation errors compound more aggressively. A million-element BF16 sum can lose almost all precision (we compute this in exercise 14.2).

Mitigation:

  • Accumulate everything in FP32. This is the default in PyTorch/JAX for reductions; double-check custom kernels.
  • Keep master weights in FP32. Adam's tiny updates would otherwise be lost in BF16 weights (section 9).
  • For very long runs (>10^6 steps), some practitioners use FP32 for selected layers (final norm, classification head) to avoid drift.

6.3 BF16 versus FP16 in practice

| Concern | FP16 | BF16 |
|---|---|---|
| Range overflow | Likely without scaling | Almost never |
| Range underflow | Likely without scaling | Almost never |
| Mantissa precision | 10 bits (~3 decimal) | 7 bits (~2.4 decimal) |
| Loss scaling needed | Yes, ideally dynamic | No |
| Hardware support | Volta+ | Ampere+, all TPUs |
| Fine-tuning safety | Fragile | Robust |

For new code, default to BF16 unless you're targeting hardware that lacks it.


7. FP8 training in detail

FP8 was introduced as a training format with NVIDIA Hopper (H100, 2022). Supporting libraries: TransformerEngine (NVIDIA), MS-AMP (Microsoft), and increasingly native frameworks.

The key insight: FP8 cannot be used "in place" of FP16/BF16. The dynamic range is too small (~5 orders of magnitude for E4M3, ~10 for E5M2). You must scale per-tensor, and you must update the scale carefully.

7.1 The two FP8 formats

|  | E4M3 | E5M2 |
|---|---|---|
| Bits | 1 + 4 + 3 | 1 + 5 + 2 |
| Bias | 7 | 15 |
| Smallest normal | 2^-6 ≈ 0.0156 | 2^-14 ≈ 6.10e-5 |
| Smallest subnormal | 2^-9 ≈ 0.00195 | 2^-16 ≈ 1.53e-5 |
| Max finite (inf reserved) | 240 | 57344 |
| Max finite (saturating variant) | 448 | n/a (inf is real) |
| Has inf? | No (saturating) / Yes (IEEE-style) | Yes |
| Has NaN? | Yes (one codepoint) | Yes |

E4M3 trades range for precision (3-bit mantissa beats 2-bit). E5M2 trades precision for range (matches FP16 range exactly-useful for gradients which span many orders).

Standard assignment (TransformerEngine, OFP8 paper): activations and weights → E4M3; gradients → E5M2. The intuition: weights and activations have a tighter distribution after layer norm; gradients are wider and need range.

7.2 Per-tensor scaling

For each tensor X we maintain a scalar S_X in FP32. The quantize/dequantize pair:

quantize:    X_fp8 = round( clip(X_fp32 * S_X, -FP8_MAX, +FP8_MAX) )
dequantize:  X_fp32_reconstructed = X_fp8 / S_X

We choose S_X so that X_fp32 × S_X lands close to (but not above) FP8_MAX. If amax_X = max|X_fp32|, the optimal scale is roughly

S_X = FP8_MAX / amax_X * margin

where margin < 1 (e.g., 1 / 2^k for some small k) gives headroom for transient spikes.
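
A sketch of the quantize/dequantize pair, assuming a PyTorch build recent enough to expose the torch.float8_e4m3fn dtype; on older versions the same logic applies to any simulated FP8 rounding:

```python
import torch

FP8_MAX = 448.0   # E4M3 saturating max

def quantize_e4m3(x_fp32: torch.Tensor, margin: float = 1.0):
    """Per-tensor scale and cast to FP8 E4M3; returns (fp8 tensor, FP32 scale)."""
    amax = x_fp32.abs().max().clamp(min=1e-12)
    scale = FP8_MAX / amax * margin
    x_fp8 = (x_fp32 * scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale

def dequantize(x_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return x_fp8.to(torch.float32) / scale

x = torch.randn(1024) * 3.0
x8, s = quantize_e4m3(x)
rel_err = (dequantize(x8, s) - x).abs() / x.abs().clamp(min=1e-6)
print(f"scale = {s.item():.2f}, median relative error = {rel_err.median().item():.2%}")
```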

7.3 The matmul

Tensor cores execute the matmul on (X_fp8, W_fp8) and accumulate in FP32. The scales come out in the dequantize:

Y_acc  = matmul(X_fp8, W_fp8)        # FP32 accumulation internally
Y_fp32 = Y_acc / (S_X * S_W)         # apply both scales in one step

Then we re-quantize Y to FP8 with its own scale S_Y for the next layer. All of this is one fused kernel in TransformerEngine.

7.4 Delayed (lazy) scaling

The naive approach-compute amax(X) now, set S_X = FP8_MAX / amax(X), then quantize-adds a full reduction over X before every matmul. That reduction would dominate cost.

Delayed scaling instead uses the amax from the previous step:

S_X[t]  = FP8_MAX / max_history[t-1] * margin
amax_X[t] = max|X[t]|       # computed alongside the matmul, cheap
push amax_X[t] into history; trim to last K entries

The history is typically K = 16 to K = 1024 steps. We keep the maximum over the last K rather than the most recent value, to be robust to dips.

7.5 Algorithm in pseudocode

struct FP8TensorScaling {
    fp32  scale            # quantize multiplier this step
    fp32  amax_history[K]  # last K observed amax values
    int   history_idx
}

def fp8_matmul(X_fp32, W_fp32, x_meta, w_meta):
    # 1. Compute scales from previous-step amax
    s_x = FP8_MAX / (max(x_meta.amax_history) + EPS)
    s_w = FP8_MAX / (max(w_meta.amax_history) + EPS)

    # 2. Quantize current tensors and observe current amax
    X_fp8 = round(clip(X_fp32 * s_x, -FP8_MAX, +FP8_MAX))
    W_fp8 = round(clip(W_fp32 * s_w, -FP8_MAX, +FP8_MAX))
    amax_x_now = max(abs(X_fp32))      # FP32 reduction, fused with the cast
    amax_w_now = max(abs(W_fp32))

    # 3. Tensor-core matmul; FP32 accumulator
    Y_fp32 = matmul_fp8_to_fp32(X_fp8, W_fp8) / (s_x * s_w)

    # 4. Update history for next step
    x_meta.amax_history[x_meta.history_idx] = amax_x_now
    w_meta.amax_history[w_meta.history_idx] = amax_w_now
    x_meta.history_idx = (x_meta.history_idx + 1) mod K
    w_meta.history_idx = (w_meta.history_idx + 1) mod K

    return Y_fp32

The + EPS guards against amax = 0 on a freshly-initialized layer or a fully-pruned weight matrix.

7.6 NaN/inf detection

E4M3 in the saturating variant has no inf representation, so an out-of-range value silently saturates to ±448. This is normally fine-clip is the desired behavior. But it does mean that you cannot detect overflow by checking for inf in the FP8 tensor. Instead:

  • Check the FP32 amax before quantization. If it explodes, the previous step was bad.
  • Check for NaN in the FP32 master weights and the FP32 dequantized output.
  • Optional: check if the FP8 tensor saturates more than X% of its elements; that's a sign the scale is wrong.

E5M2 has a real inf, so standard isinf checks work.

7.7 A worked numerical example

Consider an activation tensor X with amax = 12.5. Using E4M3 with FP8_MAX = 448:

S_X = 448 / 12.5 = 35.84

Take a single value x = 1.7:

x * S_X = 60.928
round to E4M3:    60.928 → 60   (E4M3 step at this magnitude is 4: 56, 60, 64)
fp8 stored:       60
dequantize:       60 / 35.84 = 1.6741
absolute error:   |1.7 - 1.6741| = 0.0259
relative error:   0.0152 (1.5%)

That's the precision floor for FP8: roughly 1–3% relative error per element, which the matmul's FP32 accumulator partially averages out across thousands of multiplies.

For a value near the max, say x = 12.0:

x * S_X = 430.08
round to E4M3:    430.08 → 416  (step at this magnitude is 32: ..., 384, 416, 448)
dequantize:       416 / 35.84 = 11.607
relative error:   0.033 (3.3%)

And for a tiny value x = 0.001:

x * S_X = 0.0358
This is a small normal value for E4M3 (smallest normal = 2^-6 ≈ 0.0156); the step at this magnitude is 2^-8 ≈ 0.0039
Quantize: round(0.0358 / 0.0039) * 0.0039 ≈ 9 * 0.0039 = 0.0352
Dequantize: 0.0352 / 35.84 ≈ 9.82e-4
Absolute error: ~2e-5
Relative error: 1.8%

The takeaway: FP8's relative precision is roughly constant in the well-scaled regime, falling off only in the subnormal tail.

7.8 What goes in FP8 and what doesn't

Even in a fully FP8-trained model:

  • Weights, activations, gradients: FP8 (E4M3 / E5M2 split).
  • Optimizer states: FP32 (or FP16, with care).
  • Master weights: FP32 (always).
  • Layer norm / RMS norm gain and bias: FP32 or BF16.
  • Embedding tables: usually BF16-distribution is too long-tailed for E4M3.
  • Final classifier / logits: BF16 or FP32-softmax is too sensitive.

8. TF32: the silent precision drop

TF32 (TensorFloat-32) is NVIDIA-specific (Ampere onwards). It is not a storage format; you cannot allocate a TF32 tensor.

8.1 What it actually is

When tensor cores execute an FP32 matmul, they internally:

  1. Read FP32 inputs A, B from memory.
  2. Round each to TF32 (1+8+10), discarding 13 mantissa bits.
  3. Multiply using TF32-precision multipliers.
  4. Accumulate the products in FP32.
  5. Write FP32 output to memory.

So the user sees an FP32 matmul, but the computation runs at tensor-core speed-far faster than true FP32 CUDA-core math-while the mantissa precision is only FP16-quality.

8.2 When it bites

For most training, TF32 is fine-losses converge to within 0.01% of true-FP32 results. But:

  • Numerical methods that depend on FP32 precision in the small mantissa (e.g., orthogonalization, Gram-Schmidt, eigendecomposition, large-N integration) can fail subtly.
  • Very small learning rates with large weights: w + small_update may not change w if the update is below TF32 ulp.
  • Reproduce-old-paper validation: you are no longer doing what the paper did.

8.3 The toggle

In PyTorch (post-1.7):

torch.backends.cuda.matmul.allow_tf32 = True   # matmuls on Ampere+ tensor cores
torch.backends.cudnn.allow_tf32 = True         # convolutions

# To turn off:
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

PyTorch's defaults have flipped across versions: early releases (1.7–1.11) enabled TF32 matmuls by default, while later releases default the matmul flag to False and keep the cuDNN convolution flag at True. Do not assume-check torch.backends.cuda.matmul.allow_tf32 on the version you ship. If you need bit-stable repros or are debugging numerical drift, turn both flags off and re-test.

8.4 TF32 versus BF16 + FP32 master

These two routes give similar speed and similar accuracy:

  • TF32: keep all tensors in FP32, hardware drops 13 mantissa bits internally.
  • BF16 + master: compute in BF16 (16 bits, 7 mantissa), keep FP32 master.

The BF16 path uses half the memory for activations. TF32 uses none of the BF16 machinery (no autocast wrapping). Most modern training has migrated to BF16 + master because of the activation memory win.


9. Adam + low precision pitfalls

9.1 The fundamental issue

Adam's update is

m  := β1 m + (1 - β1) g
v  := β2 v + (1 - β2) g²
m̂  := m / (1 - β1^t)
v̂  := v / (1 - β2^t)
p  := p - lr × m̂ / (sqrt(v̂) + ε)

For a typical late-training step: lr ≈ 1e-4, m̂ / sqrt(v̂) ≈ O(1), so the update Δp ≈ 1e-4. Meanwhile a typical weight value is ~1e-1 to ~1.

If p is stored in BF16:

  • p ≈ 0.5; the BF16 ulp at 0.5 is 2^-8 ≈ 3.9e-3.
  • Δp ≈ 1e-4, much smaller than the ulp.
  • p - Δp rounds back to p. The update is lost.

The same calculation in FP32: p ≈ 0.5, FP32 ulp at 0.5 is 2^-24 ≈ 6e-8, much smaller than Δp ≈ 1e-4. The update sticks.

This is why FP32 master weights are non-negotiable. The entire purpose of master weights is to be the high-precision substrate where Adam's tiny update can survive.
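
The lost update is easy to reproduce; a PyTorch two-liner (the 0.5 / 1e-4 values mirror the example above):

```python
import torch

p_bf16 = torch.tensor(0.5, dtype=torch.bfloat16)
p_fp32 = torch.tensor(0.5, dtype=torch.float32)
update = 1e-4

print(p_bf16 - update == p_bf16)   # tensor(True): the update is below the BF16 ulp and vanishes
print(p_fp32 - update == p_fp32)   # tensor(False): FP32 keeps the update
```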

9.2 Stochastic rounding

If you absolutely must store master weights in low precision (memory pressure on a 100B+ model), one fix is stochastic rounding.

Standard (deterministic) RNE rounding returns the nearest representable value, with ties going to even. Each individual round is unbiased, but when many small updates are repeatedly added into a low-precision accumulator, every update smaller than half an ulp rounds straight back to the old value: the small updates are systematically dropped.

Stochastic rounding: round up with probability proportional to the residual:

def stochastic_round(x_fp32, target_dtype):
    x_lo = floor_to(x_fp32, target_dtype)   # next representable below
    x_hi = next_above(x_lo, target_dtype)   # next representable above
    residual = (x_fp32 - x_lo) / (x_hi - x_lo)   # in [0, 1)
    if random_uniform(0, 1) < residual:
        return x_hi
    else:
        return x_lo

This is unbiased in expectation even after repeated accumulation: E[round(x)] = x. In practice, with stochastic rounding into BF16 master weights, even very small Adam updates accumulate correctly over many steps, because the probability of rounding up matches the residual.

Cost: requires per-element random numbers (philox-style RNG, fast on GPU). Some FP8 training recipes (HFP8, MS-AMP) use stochastic rounding for the FP32 → FP8 cast on weights to retain trainability.

9.3 Practical guidance

  • Default: FP32 master weights. Done.
  • If memory-bound on master weights: stochastic rounding into BF16 master weights. Validate convergence carefully.
  • Never: deterministic RNE rounding into BF16 master weights for a long run. The bias accumulates.

10. Catastrophic cancellation in reductions

10.1 The error model

Naive sequential summation:

s = x_1
for i in 2..N:
    s = s + x_i

At each step, the floating-point add introduces a relative error |δ| ≤ u. The error in s after N adds satisfies, in the worst case,

|fl(sum) - true_sum| ≤ N × u × max_i |x_i|

(actually the bound is (N - 1) × u × sum |x_i| for non-negative inputs, but N × u × max is a useful approximation when terms are similar in magnitude).

FP16 example: N = 10^6, u = 2^-11 ≈ 4.9e-4 (for FP16, one ulp at 1.0 is 2^-10, so u = 2^-11).

relative error ~ N × u = 10^6 × 4.9e-4 = 490

That's 490× the magnitude of max|x_i|: catastrophic.

FP32: u = 2^-24 ≈ 6e-8. Same N:

relative error ~ 10^6 × 6e-8 = 0.06

6%-bad but recoverable.

BF16: u = 2^-8 ≈ 4e-3. Same N:

relative error ~ 10^6 × 4e-3 = 4000

Worse than FP16. BF16's range advantage does not save you here.

10.2 Pairwise summation

Recursive halving:

def pairwise_sum(x):
    if len(x) == 1: return x[0]
    mid = len(x) // 2
    return pairwise_sum(x[:mid]) + pairwise_sum(x[mid:])

Error bound: O(log N × u × max|x_i|).

For N = 10^6: log_2(10^6) ≈ 20. So FP16 pairwise error is 20 × u × max ≈ 0.01, five orders of magnitude better than naive.

This is what NumPy, PyTorch, TF, JAX all use for sum, mean, etc. It's the default-but only if you're calling the framework's reduction. Custom CUDA kernels that use a single-thread accumulator have naive O(N × u) error. Be careful.
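
A quick NumPy illustration of the gap between naive and pairwise summation (the exact pairwise result depends on NumPy's internal block size, but it stays close to the true value):

```python
import numpy as np

# Naive sequential FP16 accumulation of 1.0's gets stuck: once the running sum hits
# 2048, adding 1.0 is only half an ulp at that magnitude and rounds back to 2048.
acc = np.float16(0.0)
for _ in range(60_000):
    acc = np.float16(acc + np.float16(1.0))
print(acc)                                                           # 2048.0, not 60000.0

# Pairwise summation (np.sum) keeps partial sums balanced; an FP32 accumulator is exact here.
print(np.sum(np.ones(60_000, dtype=np.float16)))                     # ~60000
print(np.sum(np.ones(60_000, dtype=np.float16), dtype=np.float32))   # 60000.0
```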

10.3 Kahan summation

Track and re-add the lost low-order bits:

def kahan_sum(x):
    s = 0
    c = 0    # compensation
    for xi in x:
        y = xi - c
        t = s + y
        c = (t - s) - y    # the rounding error
        s = t
    return s

Error bound: O(u × max|x_i|), independent of N. Cost: 4 ops per element instead of 1.

On GPU, the 4× slowdown is usually not worth it: pairwise summation is O(log N) error and parallelizes naturally on a tree-reduction. Kahan is mostly used in scientific computing and rarely in ML.

10.4 Reduction precision in practice

PyTorch and most frameworks upcast to FP32 for reductions even when the input is BF16/FP16:

x_bf16 = torch.randn(1_000_000, dtype=torch.bfloat16)
m = x_bf16.mean()    # internally: cast to FP32, pairwise reduce, cast back

Some kernels let you keep the reduction in the input dtype (e.g., by passing an explicit low-precision dtype argument), but you almost never want to. Always reduce in FP32, then cast the scalar result back.

LayerNorm and RMSNorm specifically:

# Pseudocode for fused LayerNorm
def layernorm(x_bf16, gamma, beta, eps):
    x_fp32 = x_bf16.to(fp32)            # upcast
    mean = pairwise_mean(x_fp32, dim=-1)
    var  = pairwise_mean((x_fp32 - mean)**2, dim=-1)
    x_norm = (x_fp32 - mean) / sqrt(var + eps)
    out = x_norm * gamma + beta
    return out.to(bf16)                  # downcast at the end

The upcast at the start and downcast at the end are crucial; the variance computation in BF16 would lose 3 decimal digits to cancellation.


11. Numerical stability tricks in transformers

11.1 Softmax with max subtraction

Already covered in section 3.3. Restating for completeness:

def stable_softmax(x):
    m = max(x)
    z = exp(x - m)
    return z / sum(z)

This is equivalent to plain softmax, never overflows, and underflows only the most-negative entries to zero (which is correct behavior-they have negligible probability).

In online softmax (FlashAttention), we extend this to streaming: when we see a new chunk of logits with max m_new, we rescale the running sum:

m_new_global = max(m_old_global, m_new)
S_new = S_old × exp(m_old_global - m_new_global) + sum_in_chunk(exp(x - m_new_global))
m_old_global = m_new_global
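
A sketch of the streaming denominator in NumPy; the chunking here is arbitrary, and FlashAttention does the same bookkeeping per tile, fused with the matmuls:

```python
import numpy as np

def online_softmax_denominator(chunks):
    """Running (max, sum of exp(x - max)) over a stream of logit chunks."""
    m, s = -np.inf, 0.0
    for x in chunks:
        m_new = max(m, float(np.max(x)))
        s = s * np.exp(m - m_new) + float(np.sum(np.exp(x - m_new)))   # rescale the old sum
        m = m_new
    return m, s

chunks = [np.random.randn(128) for _ in range(4)]
m, s = online_softmax_denominator(chunks)
full = np.concatenate(chunks)
assert np.isclose(s, np.sum(np.exp(full - full.max())))   # matches the one-shot version
```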

11.2 LayerNorm / RMSNorm with FP32 accumulator

LayerNorm computes mean and variance across the feature dim. Variance is E[(x - μ)²], which is a difference of two near-equal terms when computed naively as E[x²] - μ². Always use the centered form:

μ = mean(x)
σ² = mean((x - μ)²)

and always in FP32. Re-cast at the very end.

RMSNorm is simpler: σ_rms² = mean(x²). No subtraction, but still use FP32 to keep the squared-sum from overflowing or losing precision.

11.3 Attention with √dₖ scaling

The dot-product attention logit is q · k = sum_{i=1}^{d_k} q_i k_i.

Assume q_i, k_i are independent random variables with zero mean and unit variance. Then:

E[q · k] = 0
Var[q · k] = sum Var[q_i k_i] = d_k × Var[q_i] × Var[k_i] = d_k

So the standard deviation of q · k grows like √d_k. For d_k = 64, std ≈ 8; for d_k = 128, std ≈ 11.

After softmax, large logits saturate the distribution into a near-one-hot. To keep gradients flowing, we scale:

attn = softmax((Q K^T) / sqrt(d_k))

This puts the logits back to unit-variance regardless of d_k, preserving gradient signal early in training.

The numerical bonus: with logits at unit variance, max(Q K^T / sqrt(d_k)) rarely exceeds 5–10, comfortably within FP16 / BF16 / E4M3 range after the max-subtraction softmax trick.
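
A quick empirical check of the variance argument, with random unit-variance Q and K rows:

```python
import numpy as np

rng = np.random.default_rng(0)
for d_k in (64, 128, 256):
    q = rng.standard_normal((10_000, d_k))
    k = rng.standard_normal((10_000, d_k))
    logits = (q * k).sum(axis=-1)
    # std grows like sqrt(d_k); after 1/sqrt(d_k) scaling it sits near 1
    print(d_k, round(logits.std(), 2), round((logits / np.sqrt(d_k)).std(), 2))
```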

11.4 Logit soft-cap (Gemma, others)

Gemma 2 introduced logit soft-capping: clip the pre-softmax logits with tanh:

logits = soft_cap × tanh(logits / soft_cap)

with soft_cap = 30 or similar. This prevents extreme logits from blowing up the softmax (mostly a problem with very long contexts where one or two logits can drift huge), and incidentally regularizes the model.

The same trick appears as z-loss (PaLM, T5): an auxiliary loss z_loss × log(sum(exp(logits)))² that pushes the log-partition-function down, preventing logit drift.

11.5 Embedding scaling

Embedding tables in many architectures (the original Transformer, T5) multiply by √d_model after lookup. Reason: the embedding entries are initialized small (N(0, 1/d_model)); without rescaling they would be drowned out by the positional encoding (which is unit-variance). Numerically: keeps the early-layer activations in a sensible range.

11.6 Output projection / unembedding

The unembedding (last Linear to vocab size) is often shared with the input embedding (tied weights). Some recipes scale the output of the last LayerNorm by 1/√d_model or skip the final norm. Others (LLaMA, GPT-NeoX) apply a final RMSNorm specifically because the output magnitudes drift over many layers. Numerically, you want logits to land near O(1) for a stable softmax.


12. Detecting and handling NaN

12.1 Where NaN comes from

Common sources:

  1. Divide by zero (or by a value that underflows to zero). E.g., 1 / (sqrt(v) + eps) with eps = 1e-10 can NaN if sqrt(v) underflows to 0 in FP16 and eps is below the FP16 minimum.
  2. Overflow to inf, then inf - inf or 0 × inf. After overflow, subsequent ops produce NaN.
  3. Invalid op: sqrt(negative), log(0), log(negative). The negative usually came from a tiny numerical error in a quantity that should mathematically be non-negative.
  4. Bad data: a NaN in the input batch propagates everywhere.

12.2 The typical failure mode

A single FP16 overflow in a forward pass:

  1. Some logit = inf.
  2. softmax(inf, ...) → involves inf / inf = NaN.
  3. NaN propagates through attention, FFN, all subsequent layers.
  4. Loss = NaN.
  5. loss.backward() produces NaN gradients on every parameter.
  6. Optimizer step: weight = weight - lr × NaN / NaN → all weights become NaN.
  7. Game over.

This is recoverable only if you detect before step 6.

12.3 Detection

if torch.isnan(loss) or torch.isinf(loss):
    # Skip this step entirely
    optimizer.zero_grad()
    if isinstance(scaler, GradScaler):
        scaler._scale = scaler._scale * 0.5
    log_warning(f"NaN/inf loss at step {step}; skipping")
    continue

A more thorough check after backward:

def grads_finite(model):
    for p in model.parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            return False
    return True

torch.amp.GradScaler does this automatically when you call scaler.step(optimizer): if any grad is NaN/inf, the optimizer step is skipped and scaler.update() halves the scale.

12.4 Gradient clipping

Clip the gradient norm (or per-tensor) to bound the worst-case update:

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

How it works:

total_norm = sqrt(sum_p ||p.grad||² )
if total_norm > max_norm:
    scale = max_norm / total_norm
    for p in params:
        p.grad *= scale

Clipping prevents a single bad batch from producing a runaway update that pushes weights into a regime where subsequent forward passes overflow. It does not prevent NaN if NaN is already in the gradients-total_norm becomes NaN and clipping does nothing useful. Always check for NaN before or after clipping.

Important: with GradScaler, you must call scaler.unscale_(optimizer) before clipping, otherwise you clip the scaled gradients (the wrong norm).

12.5 Skip-step recovery

Algorithm:

Save checkpoint every N steps.
On NaN:
    1. Reset optimizer momentum (Adam m, v) to last good checkpoint.
    2. Restore weights to last good checkpoint.
    3. Reduce LR by 0.5x for K steps.
    4. (FP16) Halve loss scale.
    5. Resume.

Production runs implement this as a watchdog. Without it, a single catastrophic batch destroys days of compute.

12.6 Eps placement matters

The Adam denominator: sqrt(v) + eps. Two equivalent formulations have different numerical properties:

  • update = m / (sqrt(v) + eps): eps added outside the sqrt; this is PyTorch's default.
  • update = m / sqrt(v + eps²): eps inside the sqrt; slightly different behavior near v = 0.

PyTorch's default eps = 1e-8. In FP16 storage that underflows; this is one reason Adam states are always FP32. In FP32 it's fine.


13. Determinism

13.1 Sources of non-determinism

GPU training is non-deterministic by default. Sources:

  1. Atomic adds in reductions: many CUDA kernels (e.g., scatter_add, some softmax kernels, certain backward passes) use atomicAdd for thread-safe accumulation. The order of atomic adds is non-deterministic, and FP32 addition is non-associative. So you get bit-different results across runs even with the same inputs.
  2. CUDA workspace reuse: cuBLAS picks different algorithms based on workspace size and available memory. Different runs → different algorithms → bit-different results.
  3. Multi-threaded data loading: workers can return batches in different orders.
  4. NCCL collectives: ring/tree algorithms have run-dependent ordering.
  5. cuDNN heuristics: cuDNN benchmarks kernels and picks the fastest, but the choice depends on transient hardware state.

13.2 PyTorch deterministic mode

import torch
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Required for some cuBLAS kernels:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"   # or ":16:8"

# Seed everything:
import random; random.seed(0)
import numpy as np; np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

Combined with single-process, single-worker data loading and a fixed seed, you can get bit-exact reproducibility on a single GPU.

13.3 Cost

Deterministic mode is slower:

  • scatter_add and embedding gradients: 2–10× slower (because we lose atomic-add).
  • Some convolution algorithms: 1.2–2× slower.
  • Multi-GPU training: harder still, because NCCL collectives are not bit-deterministic without specific configuration.

Use deterministic mode for debugging only. Production runs should accept non-determinism and rely on statistical reproducibility (the loss curve looks the same up to small noise).

13.4 What "reproducible" means in practice

For a paper or ablation study:

  • Run the same configuration 3 times with different seeds.
  • Report mean ± std of final metrics.
  • If ablations are within the std, they are noise.

Bit-exact reproducibility is rarely the goal. Statistical reproducibility (results within seed-noise) is.


14. Practical exercises

Solutions inline. Try each before reading the answer.

14.1 FP16 representable-zero

Problem: Show that the value 1e-5 (decimal) cannot be represented as a normal FP16 number. What is the closest FP16 value?

Solution:

FP16 smallest positive normal = 2^-14 ≈ 6.1035e-5.

1e-5 < 6.1e-5, so it is below the smallest normal-it's in subnormal range.

FP16 subnormal step = 2^-14 × 2^-10 = 2^-24 ≈ 5.96e-8.

1e-5 / 5.96e-8 ≈ 167.77. Round to nearest even: 168.

Closest FP16 = 168 × 2^-24 ≈ 1.0014e-5. Relative error: 0.14%.

So 1e-5 is representable in FP16-but only as a subnormal, with ~10× less precision than a normal FP16 value of similar magnitude. If subnormals are flushed (FTZ), 1e-5 becomes 0. This is exactly the regime where FP16 gradients silently underflow without loss scaling.
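
You can confirm the numbers with NumPy's float16:

```python
import numpy as np

x = np.float16(1e-5)
print(float(x))            # 1.0013580322265625e-05, the closest FP16 value
print(float(x) / 2**-24)   # 168.0: it sits exactly 168 subnormal steps above zero
```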

14.2 BF16 accumulation error

Problem: You sum N = 10^6 values, each ~U(-1, 1). Estimate the absolute error of naive sequential BF16 summation versus pairwise BF16 summation versus FP32 pairwise.

Solution:

BF16: u = 2^-8 ≈ 3.9e-3.

Naive: N × u × max|x_i| ≈ 10^6 × 3.9e-3 × 1 = 3900. The error dwarfs the true sum (which is O(sqrt(N)) ≈ 1000 by CLT). Total noise.

Pairwise BF16: log_2(N) × u × max|x_i| ≈ 20 × 3.9e-3 × 1 ≈ 0.08. Acceptable.

FP32 pairwise: u = 2^-24 ≈ 6e-8. Error ≈ 20 × 6e-8 = 1.2e-6. Negligible.

Lesson: BF16 reductions are usable only with pairwise (or better) summation. Always upcast to FP32 anyway, because it's free on modern hardware.

14.3 Loss-scale recovery trace

Problem: A run starts with S = 2^15 = 32768. After every overflow, S halves. Over how many consecutive overflows would S reach 2^0 = 1? At what point does S drop below the regime where it's helpful (assume "helpful" means S ≥ 2^7 = 128)?

Solution:

2^15 / 2^k = 2^(15-k). To reach S = 1 = 2^0, we need 15 halvings.

To drop below 2^7, we need to fall to 2^6 = 64. That is 2^(15-k) = 2^6, i.e. k = 9. So 9 consecutive overflow halvings push S below the useful regime.

If the dynamic-scaling patience is 2000 successful steps for a doubling, we can recover slowly: from S = 2^6 to S = 2^15 takes 9 doublings = 18000 successful steps minimum. In practice a single bad batch causes one halving, but a phase change (e.g., LR warmup ending, distribution shift) can cause cascading overflows-9 in a row is unlikely but not impossible.

Implication: monitor loss_scale as a training metric. A scale that has been falling for 100 steps is a warning sign.

14.4 7B model optimizer-state memory

Problem: For a 7B parameter model with Adam optimizer, compute optimizer-state memory for: (a) FP32 master weights, FP32 m, v. (b) BF16 master weights, FP16 m, v. (c) Bonus: total memory including weights, gradients, master weights for case (a) and (b).

Solution:

N = 7e9 parameters.

(a) Pure FP32 optimizer states:

  • m in FP32: 7e9 × 4 = 28 GB
  • v in FP32: 7e9 × 4 = 28 GB
  • Total optimizer state: 56 GB

(b) BF16 master + FP16 m, v:

  • m in FP16: 7e9 × 2 = 14 GB
  • v in FP16: 7e9 × 2 = 14 GB
  • Total optimizer state: 28 GB

(c) Full memory accounting:

Case (a), standard mixed-precision:

  • BF16 weights for compute: 7e9 × 2 = 14 GB
  • BF16 gradients: 7e9 × 2 = 14 GB
  • FP32 master weights: 7e9 × 4 = 28 GB
  • FP32 m + v: 56 GB
  • Total just for params/grads/opt: 112 GB

Case (b), aggressive low-precision:

  • BF16 weights: 14 GB
  • BF16 gradients: 14 GB
  • BF16 master weights: 14 GB
  • FP16 m + v: 28 GB
  • Total: 70 GB

But case (b) requires stochastic rounding for the BF16 master weights and may sacrifice convergence quality. This is why ZeRO/FSDP stage 3 (sharding optimizer states across GPUs) is more popular than aggressive low-precision optimizers.

14.5 FP8 scale evolution

Problem: An activation tensor's amax history over the last 5 steps is [2.1, 2.4, 8.5, 2.3, 2.2] (the 8.5 is a transient spike from an outlier batch). You use E4M3 (FP8_MAX = 448) with margin = 1 (no headroom factor). What scale does the next step use: (a) Using the most recent amax (2.2)? (b) Using the max of history (8.5)?

What's the consequence of each choice?

Solution:

(a) S = 448 / 2.2 ≈ 203.6. Tight scale: every quantization uses the full FP8 range. But on the next outlier batch (similar to the spike), the tensor would saturate at ~448 / 203.6 ≈ 2.2, clipping any value above 2.2. We'd lose the outliers.

(b) S = 448 / 8.5 ≈ 52.7. Looser scale: most batches under-utilize the FP8 range (max value used: 2.4 × 52.7 ≈ 126, well below 448). But outliers up to 8.5 are represented faithfully.

Standard practice: use max of recent history to get robustness to spikes, possibly with an additional margin (e.g., margin = 1/2, giving an extra 2× headroom). The cost is some wasted FP8 range on calm batches; the benefit is graceful handling of outliers.

This is why K (history length) matters: too short and you forget spikes (under-scaled, clip outliers); too long and you over-pad indefinitely (over-scaled, waste precision).

14.6 Softmax overflow boundary in FP16

Problem: For a FP16 softmax (without max subtraction), at what magnitude does the largest logit cause exp to overflow? Compare to the typical pre-scaling logit magnitude in attention with d_k = 128 and unit-variance Q, K.

Solution:

FP16 max = 65504 ≈ 6.55e4. So exp(x) overflows when x > log(65504) ≈ 11.09.

Without scaling, attention logit q · k has std √d_k = √128 ≈ 11.3. So a single-σ logit already overflows FP16. A 3σ outlier (x ≈ 34) overflows by 23 in log space, i.e., by exp(23) ≈ 10^10.

With scaling by 1/√d_k, logits have std 1. Now a 5σ outlier (x ≈ 5) gives exp(5) ≈ 148, comfortably representable.

The math: √d_k scaling is necessary, not optional, for FP16 attention. Even with max-subtraction softmax, the gradient of the un-scaled logit can blow up. Scaling is built into every transformer for this reason.
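
A quick check of the overflow boundary with NumPy's float16 (the exp result is returned as FP16, so the overflow shows up as inf):

```python
import numpy as np

print(np.exp(np.float16(11.0)))   # ≈ 5.99e4, still finite in FP16
print(np.exp(np.float16(11.1)))   # inf: the result exceeds 65504 and overflows
```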


Closing remarks

A few things to remember when this chapter is closed:

  1. BF16 + FP32 master weights + FP32 reductions is the modern default. It's robust, well-supported, and conceptually simple. Reach for FP16 only on hardware that doesn't have BF16; reach for FP8 only when you've measured the savings and committed to dealing with delayed scaling.

  2. Range matters more than precision for ML. This is why BF16 ate FP16 and why FP8 split into two formats (E4M3 for tight distributions, E5M2 for wide ones).

  3. The accumulator is sacred. Tensor cores will let you compute in 8 or 16 bits, but they accumulate in 32. Reductions you write yourself should do the same.

  4. Master weights exist because Adam updates are tiny. Not for any other reason. If you ever invent an optimizer with O(1) updates (some second-order methods approach this), you may be able to drop the master.

  5. Most NaN crashes are loss-scale or learning-rate problems. Before re-debugging the model, check the simplest things: is loss_scale stable? is the LR schedule sane? did the data have a NaN?

  6. Determinism is a debugging tool, not a production goal. Statistical reproducibility (across seeds) is what matters for science.

The next chapter (12_KERNEL_FUSION.md, if you're working through the curriculum in order) builds on this: now that we know which precisions are needed where, we can design custom kernels that fuse multiple operations while respecting these precision rules.
