
Deep Dive 05: JAX and XLA

Reading contract. This chapter is a self-contained reference. After reading and working the exercises you should be able to (a) read and write idiomatic JAX, (b) reason about what happens when @jax.jit is applied to a function, (c) inspect jaxprs and HLO, (d) shard a computation across a multi-host TPU/GPU cluster using Mesh + PartitionSpec, and (e) pick between jit with sharding, shard_map, and (legacy) pmap for a given workload. We do not punt to the JAX docs.


Table of contents

  1. Why JAX exists
  2. Functional purity: the unit of compilation
  3. PyTrees and jax.tree_util
  4. Stateless PRNGs (PRNGKey)
  5. Tracing and jaxprs
  6. jax.jit: caching, recompilation, static args
  7. jax.grad, value_and_grad, jvp, vjp
  8. jax.vmap and per-example gradients
  9. Device parallelism: pmap (legacy) vs jit + sharding (modern)
  10. jax.shard_map: when you want manual control
  11. Structured loops: lax.scan, lax.fori_loop, lax.while_loop
  12. XLA: HLO IR, compilation pipeline, fusion, layout, GSPMD
  13. TPU vs GPU under XLA
  14. Module systems on top: Equinox and Flax
  15. jax.experimental.pallas (Triton-like kernel DSL)
  16. Practical exercises (with worked answers)
  17. Cheat-sheet appendix

1. Why JAX exists

JAX is a library for numerical computing whose central thesis is:

Numerical programs are functions of arrays. If you keep them pure, you can compose program transformations on them (autodiff, vectorization, parallelization, just-in-time compilation) and you can lower the result through a single optimizing array compiler (XLA) to CPU, GPU, or TPU.

That sentence has every load-bearing word in JAX's design. Let us unpack it against the contrast that most readers carry: PyTorch.

1.1 PyTorch's design (so we have a foil)

PyTorch is eager and imperative: a tensor operation runs immediately on the host's accelerator queue. The autograd graph is built dynamically as a side effect of forward computation-every tensor with requires_grad=True allocates a node in a graph stored on the tensor itself. Modules are stateful objects (nn.Module) that own their parameters, buffers, and (transitively) the optimizer state. To go fast you typically (a) call into eager kernels, (b) use torch.compile (TorchDynamo + Inductor) which traces Python bytecode and lowers to fused kernels, or (c) drop into custom CUDA / Triton.

PyTorch optimizes for debuggability and programmer ergonomics in idiomatic Python. You can print a tensor, set a Python breakpoint inside a forward pass, mutate a list of layers conditionally, and it all works.

1.2 JAX's design choices

JAX takes an almost orthogonal set of choices:

  1. Pure functions are the unit. A JAX-compilable function takes arrays in, returns arrays out, and has no side effects: no mutation of outer Python state, no in-place tensor edits, no random state hidden in a global. Everything that would be state is threaded explicitly through arguments and return values.

  2. Composable transformations. Once a function is pure, JAX can give you several function-to-function transformations:

     • jax.jit: trace and compile via XLA.
     • jax.grad: return a function computing the gradient.
     • jax.vmap: return a function that runs the original over a new batch axis.
     • jax.pmap / jit with sharding: return a function that runs across devices.

     These compose: jit(vmap(grad(f))) is meaningful and well-defined. This composability is the soul of JAX.

  3. XLA as the default backend. Where PyTorch's "real" backend is a constellation of cuDNN/cuBLAS calls in eager mode and Inductor-generated Triton at compile time, JAX always lowers to HLO (XLA's IR) and lets the XLA compiler emit device code. This means TPU and GPU share most of the toolchain. (XLA was originally a TPU compiler at Google, and that lineage shows.)

  4. TPU first-class. Unlike PyTorch, where TPU support is delivered through torch_xla as an extra layer, JAX speaks XLA natively. A JAX program written for one TPU pod core scales to thousands of cores essentially by adding sharding annotations.

1.3 Trade-offs

The JAX bargain:

  • You give up in-place mutation, easy printing inside compiled code, dynamic shapes that change every call, and Python-level control flow that depends on traced values.
  • You get a uniform compilation pipeline, world-class autodiff that composes with everything, free vectorization (vmap), free distribution (jit + sharding), and predictable performance because compilation is explicit.

For research workloads with stable shapes and heavy linear algebra (transformers, diffusion, scientific computing), this trade is excellent. For workloads with ragged shapes, dynamic graphs of varying topology, or heavy host-side branching (some RL systems, classical NLP pipelines), it can be painful.


2. Functional purity: the unit of compilation

A pure function in JAX's sense:

  • Outputs depend only on inputs. No global reads (other than constants closed over at trace time).
  • Has no observable side effects. No global writes, no I/O, no mutation of arguments.
  • Same input → same output, every call.
import jax
import jax.numpy as jnp

# Pure
def loss(params, x, y):
    pred = x @ params["W"] + params["b"]
    return jnp.mean((pred - y) ** 2)

# Impure-mutates a global counter
counter = 0
def bad(x):
    global counter
    counter += 1
    return x * 2

jax.jit(bad) will appear to work but counter will be incremented exactly once per trace, not once per call. The bug surfaces silently. This is a recurring pattern: JAX does not police purity at runtime; it assumes it. If you violate it, you get correct-looking output and incorrect semantics.
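Continuing the snippet above, a minimal sketch of how the bug shows up, and the pure alternative in which the counter becomes an explicit argument and return value:

fast_bad = jax.jit(bad)
for _ in range(3):
    fast_bad(jnp.ones(3))
print(counter)            # 1, not 3: the increment ran only during tracing

# Pure version: thread the count through explicitly.
def good(x, count):
    return x * 2, count + 1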

The discipline imposed by purity buys two enormous things:

  1. Trivial reverse-mode AD. With no side effects, the chain rule is just structural induction over the jaxpr. There is nothing to "undo."
  2. Trivial parallelization. Pure functions are referentially transparent; you can run them on any device, in any order, multiple times, without changing program meaning.

2.1 Where state lives

If the model has weights, optimizer moments, RNG state, batch-norm running stats, those are values that the user threads through arguments:

def train_step(params, opt_state, rng, batch):
    rng, sub = jax.random.split(rng)
    grads = jax.grad(loss)(params, batch, sub)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, rng

Compare with PyTorch where optimizer.step() mutates param.data and param.grad in place. JAX surfaces the mutation as new return values. The train_step itself remains pure.


3. PyTrees and jax.tree_util

A neural network has hundreds or thousands of parameters. Threading them as positional arguments is unworkable. JAX solves this by making arbitrary nested Python containers first-class.

A PyTree is, recursively, either:

  • a leaf (an array, a scalar, anything not a registered container), or
  • a container node (by default tuple, list, dict, None, namedtuple) with PyTree children.

Custom dataclasses can be registered with jax.tree_util.register_pytree_node (or, for dataclasses, @jax.tree_util.register_dataclass / Equinox's automatic registration).

3.1 Why this matters

Every JAX transformation that takes a function f(x) -> y and produces f'(x) -> y' operates over PyTrees: x and y may be arbitrarily nested. The transformations preserve structure.

Example: jax.grad applied to a function whose first argument is params = {"layer1": {"W": ..., "b": ...}, "layer2": {...}} returns a gradient PyTree with the same shape as params. You never write flatten_params glue code by hand.

3.2 The core API

from jax import tree_util as tu

leaves, treedef = tu.tree_flatten(params)   # list of arrays + structure spec
params2 = tu.tree_unflatten(treedef, leaves) # rebuild

# Map a function over every leaf:
doubled = jax.tree.map(lambda x: 2 * x, params)

# Map across two PyTrees that share structure:
sum_tree = jax.tree.map(lambda a, b: a + b, params, grads)

In recent JAX, the public surface is jax.tree.map, jax.tree.leaves, jax.tree.structure, etc. The underlying jax.tree_util module remains.

3.3 PyTree as the universal interface

Internally, every JAX transformation:

  1. Flattens its inputs to a flat list of leaves + a tree-structure description.
  2. Operates on the flat list (where everything is just an array).
  3. Reconstructs the output structure on the way out.

This is why you can pass a dict of arrays to jit, grad, vmap, pmap and they all "do the right thing"-they each call tree_flatten and treat leaves uniformly.

3.4 Custom PyTree

import dataclasses
@jax.tree_util.register_dataclass
@dataclasses.dataclass
class GRUCell:
    W: jax.Array
    U: jax.Array
    b: jax.Array

Now GRUCell instances are valid PyTrees. jax.grad will return a GRUCell of gradients.
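A quick illustration (shapes are arbitrary): once registered, the cell flattens like any other container, and jax.tree.map returns a new GRUCell.

cell = GRUCell(W=jnp.zeros((4, 4)), U=jnp.zeros((4, 4)), b=jnp.zeros((4,)))
print(len(jax.tree.leaves(cell)))                 # 3 leaves: W, U, b
scaled = jax.tree.map(lambda x: 0.1 * x, cell)    # another GRUCell, every leaf scaled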

Contrast with PyTorch. PyTorch's nn.Module is a class with a state_dict() method. JAX's analog is "any PyTree." The "module" abstraction is built on top (Flax, Equinox), not into the core.

3.5 Exercise

Given params = {"a": jnp.zeros((3,)), "b": [jnp.ones((2,2)), jnp.ones((2,))]}, write the call that returns {"a": shape (3,), "b": [shape (2,2), shape (2,)]}.

jax.tree.map(lambda x: x.shape, params)

4. Stateless PRNGs (PRNGKey)

Random numbers are state. State is impure. JAX therefore cannot have a global RNG (well, it could, but it would break composition with jit/vmap/pmap).

Solution: an explicit, immutable key, threaded by the user.

key = jax.random.PRNGKey(42)         # an array of shape (2,) uint32 (historically)
key, subkey = jax.random.split(key)  # 2 new keys; old key conceptually consumed
x = jax.random.normal(subkey, (1024,))

Three rules:

  1. Never reuse a key. random.normal(key, ...) is a pure function of key; passing the same key gives identical samples.
  2. Always split before consuming. split(key, n) returns n fresh keys.
  3. Threading is the user's job. Every function that consumes randomness takes a key argument.

4.1 Why this design

  • Reproducibility. Two runs with the same starting key produce bit-identical results, even across vmap/pmap/sharding.
  • Composability. vmap(f) over a batched-keys argument gives per-example randomness with no surprise. With a global RNG, vmap could not say what the per-example samples should be.
  • Determinism under parallelism. Each device gets its own key derived deterministically from the master key. No race conditions, no per-device RNG state to seed.

4.2 Idiom: per-step splitting

def train_step(params, rng, batch):
    rng, dropout_key = jax.random.split(rng)
    logits = model(params, batch, dropout_key)
    ...
    return params, rng  # return the *new* rng for the next step

You can split into many keys at once: keys = jax.random.split(rng, num=8) and vmap over them.
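For example, a sketch of per-example noise (the batch contents here are placeholders): split once per example, then vmap over the keys.

def add_noise(key, x):
    return x + 0.1 * jax.random.normal(key, x.shape)

batch = jnp.ones((8, 16))                                        # placeholder batch
keys = jax.random.split(jax.random.PRNGKey(0), num=batch.shape[0])
noisy = jax.vmap(add_noise)(keys, batch)                         # independent noise per row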

4.3 Key types

Modern JAX has typed key arrays (jax.random.key) and multiple PRNG implementations (e.g., threefry2x32 and rbg). For day-to-day work PRNGKey(seed) is fine; use jax.random.key(seed, impl=...) if you want typed keys or a different algorithm. The conceptual model—explicit, splittable, stateless—is unchanged.


5. Tracing and jaxprs

This is the conceptual heart of JAX. Internalize it and the rest follows.

5.1 What @jax.jit actually does

When you call a jit-decorated function for the first time with concrete arguments:

  1. JAX inspects each argument's shape and dtype (and static-argnum python values).
  2. It calls your Python function with abstract Tracer objects in place of those arguments-objects that record every operation performed on them but do not compute values.
  3. The resulting trace is a jaxpr (JAX expression): a small typed IR of primitive operations.
  4. The jaxpr is lowered to HLO and compiled by XLA for the target device.
  5. The compiled executable is cached, keyed by (function identity, abstract input signature, static-arg values).
  6. The actual concrete arguments are run through the executable.

Subsequent calls with the same abstract signature skip steps 2–5 and just run the cached executable.

5.2 A worked jaxpr

import jax
import jax.numpy as jnp

def f(x, y):
    a = x * y
    b = jnp.sin(a)
    return jnp.sum(b)

print(jax.make_jaxpr(f)(jnp.ones((3,)), jnp.ones((3,))))

You will see something like:

{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = mul a b
    d:f32[3] = sin c
    e:f32[] = reduce_sum[axes=(0,)] d
  in (e,) }

Reading it:

  • a:f32[3] b:f32[3]: two float32 inputs of shape (3,).
  • The let block names intermediate values.
  • mul, sin, reduce_sum are JAX primitives-the leaves of the jaxpr.
  • in (e,): the output tuple.

A jaxpr is pure, typed, and closed: every variable is bound, every operation is a primitive, every shape and dtype is known. This is the input XLA receives.

5.3 Concrete vs Abstract vs Traced

Three kinds of values flow through JAX code:

  • Concrete arrays (jax.Array): real data on a device. Default outside jit.
  • Abstract arrays (ShapedArray): metadata only—shape, dtype, optional weak type—with no data. Used for tracing.
  • Tracers (Tracer): wrappers presented to the user's Python function during tracing. They look like arrays (they have .shape, .dtype, support +, *, jnp.sin, etc.) but every operation on them appends a node to the jaxpr.

A common pitfall: writing Python control flow on a tracer.

@jax.jit
def f(x):
    if x > 0:           # fails: ConcretizationTypeError, a tracer cannot be branched on
        return x
    else:
        return -x

The condition x > 0 is a tracer because x is. Python's if requires a concrete bool. You must use jax.lax.cond(x > 0, ..., ...) (or jnp.where for elementwise selection), which becomes part of the jaxpr.
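A minimal fix, shown both ways (elementwise select vs. a traced branch):

@jax.jit
def f_where(x):
    return jnp.where(x > 0, x, -x)      # elementwise select; becomes select_n in the jaxpr

@jax.jit
def f_cond(x):
    # Whole-branch version: the predicate must be a scalar; both branches are traced.
    return jax.lax.cond(jnp.sum(x) > 0, lambda v: v, lambda v: -v, x)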

5.4 Static arguments

If x is a Python scalar that determines the shape of arrays-e.g. number of layers-make it static:

from functools import partial
@partial(jax.jit, static_argnums=(1,))
def make_noise(rng, n):     # n is static: it determines the output shape
    return jax.random.normal(rng, (n,))

n is now treated as part of the cache key, not as a tracer. Different n values trigger different compilations.

5.5 Print-debugging

print(x) inside a jitted function runs only at trace time and prints the tracer (useful for inspecting shapes and dtypes), not the value. For per-call debug prints use jax.debug.print("x = {}", x), which compiles into a host callback.
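A side-by-side sketch:

@jax.jit
def f(x):
    print("trace-time:", x)                        # prints a Tracer, and only once per trace
    jax.debug.print("runtime: x[0] = {}", x[0])    # prints the concrete value on every call
    return x * 2

f(jnp.arange(4.0))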


6. jax.jit: cache, recompilation, costs

6.1 The cache key

The compilation cache is keyed (essentially) by:

| Component | What changes the key |
| --- | --- |
| Function identity | The Python function object |
| Abstract input signature | shape and dtype of every leaf in the input PyTree |
| Static argument values | Python == equality of static_argnums / static_argnames values |
| PyTree structure | the treedef of the inputs |
| Device / sharding context | target backend & sharding spec |

So:

  • Calling f(x_f32_3x4) then f(x_f32_3x4) → 1 compile, 2 calls.
  • Calling f(x_f32_3x4) then f(x_f32_3x5) → 2 compiles (different shape).
  • Calling f(x_f32_3x4) then f(x_f64_3x4) → 2 compiles (different dtype).
  • Calling f({'a': x, 'b': y}) then f({'b': y, 'a': x}) → 1 compile (dicts have stable PyTree order).
  • Changing the value of a non-static argument → 0 compiles.
  • Changing the value of a static argument → 1 compile per distinct value.

6.2 Recompilation costs

Compilation is not free: HLO optimization plus device codegen can cost hundreds of milliseconds to tens of seconds for transformer-scale models. If your training loop accidentally retraces every step, the program runs but at compile-bound throughput. Symptoms:

  • First step: 5 s. Second step: 5 s. Third step: 5 s. (Should be: 5 s, 50 ms, 50 ms.)
  • Memory growth in the HLO module cache.

Detection:

jax.config.update("jax_log_compiles", True)  # logs every compile

or set JAX_LOG_COMPILES=1. You should see one line per training-loop function, not one per step.

Common causes of accidental retraces:

  1. Padding-by-actual-length: each batch has a slightly different sequence length, so shapes vary. Fix: pad to a small set of bucket lengths.
  2. Passing Python ints that the function uses to construct shapes-make them static_argnums.
  3. Passing different PyTree structures (e.g., a dict with an optional key).

6.3 jit is lazy, dispatch is async

jit-compiled calls return immediately with futures (a jax.Array backed by a pending computation). Use jax.block_until_ready(x) or x.block_until_ready() when timing.
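A timing sketch (the function and input here are arbitrary stand-ins): without block_until_ready you measure dispatch, not the computation.

import time

def f(x):
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((2048, 2048))
g = jax.jit(f)
g(x).block_until_ready()                  # warm-up: pay the compile cost here

t0 = time.perf_counter()
g(x).block_until_ready()                  # now you time the actual device execution
print(f"step time: {time.perf_counter() - t0:.4f}s")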

6.4 Ahead-of-time lowering

lowered = jax.jit(f).lower(jnp.ones((4,)), jnp.ones((4,)))
print(lowered.as_text())            # StableHLO MLIR text
print(lowered.compiler_ir(dialect="hlo"))  # HLO module
compiled = lowered.compile()
print(compiled.cost_analysis())     # FLOPs, bytes, etc., for benchmarking

This is the inspection toolchain you will use repeatedly.


7. jax.grad and friends

7.1 Reverse-mode AD on a jaxpr

jax.grad(f) returns a function g such that g(x) equals df/dx evaluated at x. Mechanically:

  1. Trace f to a jaxpr (the primal jaxpr).
  2. Walk the jaxpr forward, recording residuals where needed.
  3. Construct a transposed / reverse jaxpr that computes the cotangent: for each primitive, JAX has registered differentiation rules (a JVP rule plus a transpose rule, or a direct VJP).
  4. Return that as a callable jaxpr (typically jitted).

Functional purity is what makes step 3 trivial: each primitive has a local linearization, and the chain rule is just composition because there are no hidden state edges to break it.

By convention, grad(f) differentiates with respect to the first argument and expects a scalar output.

def loss(params, x, y):
    return jnp.mean((x @ params["W"] - y) ** 2)

grad_fn = jax.grad(loss)
g = grad_fn(params, x, y)   # g has the same PyTree shape as params

For multiple argnums: jax.grad(loss, argnums=(0, 1)) returns a tuple of grads.

7.2 value_and_grad

You usually want both the loss value and the gradient. jax.value_and_grad(loss) returns (loss, grads) from a single trace-no double work.

(loss_val, grads) = jax.value_and_grad(loss)(params, x, y)

7.3 Higher-order

jax.grad(jax.grad(f)) is well-defined (it traces grad(f) and differentiates that trace). For Hessians, use jax.hessian(f) (which is jacfwd(jacrev(f)) under the hood for a typical scalar function).

7.4 Forward and reverse mode primitives

Two lower-level operators:

  • jax.jvp(f, primals, tangents): forward mode. Computes f(primals) and the directional derivative J · tangents in one pass. Good when input dim ≪ output dim.
  • jax.vjp(f, *primals): reverse mode. Returns (f(primals), vjp_fn), where vjp_fn(cotangent) computes cotangent · J. This is what grad is built on. (A sketch of both follows the rule of thumb below.)

Rule of thumb:

  • Few outputs, many inputs (training loss → loss is scalar): reverse mode (grad/vjp).
  • Few inputs, many outputs (sensitivity of a vector-valued function to a small parameter): forward mode (jvp).
  • For Jacobians: jax.jacrev (reverse) for tall Jacobians, jax.jacfwd (forward) for wide Jacobians.
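A small sketch of both primitives on an R^3 → R^3 function:

def f(x):
    return jnp.sin(x) * x

x = jnp.arange(3.0)
v = jnp.ones(3)

y, jv = jax.jvp(f, (x,), (v,))        # forward: y = f(x), jv = J(x) @ v
y, vjp_fn = jax.vjp(f, x)             # reverse: y = f(x) plus a pullback
(vj,) = vjp_fn(jnp.ones(3))           # vj = u^T J(x) for cotangent u = ones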

7.5 Custom derivatives

@jax.custom_vjp
def stable_softmax(x):
    z = x - jnp.max(x)
    e = jnp.exp(z)
    return e / jnp.sum(e)

def fwd(x):
    y = stable_softmax(x)
    return y, y                         # residual saved for the backward pass

def bwd(y, g):
    # Hand-derived softmax VJP: y * (g - <g, y>)
    return (y * (g - jnp.sum(g * y)),)

stable_softmax.defvjp(fwd, bwd)

Use this when (a) you have a hand-derived gradient that is more numerically stable than the autodiff one, or (b) you want to break a gradient (stop_gradient is a single-line alternative for that).

7.6 Contrast with PyTorch

In PyTorch, loss.backward() walks the dynamically-built graph attached to the tensor, populates .grad fields by mutation, and frees the graph. In JAX, grad(loss) builds a new function that returns a new PyTree of gradients. There is no .grad attribute, no graph to free, no optimizer.zero_grad() needed-purity makes "zero out before backward" unnecessary because nothing is mutated.


8. jax.vmap: vectorization is a transformation

vmap(f) returns a function that runs f over an extra leading axis without you writing the batched code. It is not a Python for loop. It rewrites the jaxpr to push the batch dimension through every primitive.

8.1 The basic interface

def dot(a, b):                     # a:(d,) b:(d,) -> ()
    return jnp.sum(a * b)

batched_dot = jax.vmap(dot)        # (B,d), (B,d) -> (B,)
batched_dot(jnp.ones((32, 5)), jnp.ones((32, 5)))

8.2 in_axes / out_axes

Specify which axis of each argument is the batched axis:

# Batch over a's first axis, broadcast b:
jax.vmap(dot, in_axes=(0, None))   # (B,d), (d,) -> (B,)

# Batch over a's last axis, b's first axis:
jax.vmap(dot, in_axes=(-1, 0))

None means "this argument is not batched-broadcast it." out_axes controls where the batch axis appears in the output (default 0).

For PyTree arguments, in_axes is itself a PyTree (or a single int, applied uniformly).

8.3 How vmap works under the hood

For each primitive p, JAX has registered a batching rule: given how each input is batched, what is the batched output and along which axis? vmap walks the jaxpr applying these rules. Most primitives have rules that turn into a single fatter primitive call-e.g., vmap(dot) becomes a matmul, not a Python loop. That is why vmap is fast: it produces good HLO.

8.4 Composes with grad

The canonical example: per-example gradients.

In standard training, grad(loss) gives the gradient of the mean loss-a single PyTree summed across the batch. Sometimes you want one gradient per example (for influence functions, differential privacy, gradient clipping per example, etc.).

def per_example_loss(params, x, y):     # x:(d,), y:()-single example
    pred = x @ params["W"] + params["b"]
    return (pred - y) ** 2

per_example_grad = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))
# per_example_grad(params, X, Y) returns grads with leading axis B

Read it carefully:

  • grad(per_example_loss) is a function that, given a single example, returns a single gradient PyTree.
  • vmap(...) lifts that over a batch axis on (x, y) while sharing params.
  • The result has the same PyTree structure as params but with an added leading batch dimension on every leaf.

This is two lines of JAX. The PyTorch equivalent requires torch.func.vmap(torch.func.grad(...)) (formerly functorch).

8.5 vmap for the inference case

A model that runs on a single example can be batched by writing the model for one example and vmap-ing it. This is sometimes cleaner than worrying about broadcasting in the model definition. In practice, most JAX models are written batched (because matmul already handles the batch dimension for free), but vmap is invaluable for non-trivial axes (e.g., MoE expert routing, beam search, per-head attention).


9. Device parallelism: pmap (legacy) vs jit + sharding (modern)

JAX has been through a small evolution here. Understand both because real codebases mix them.

9.1 pmap (the original)

pmap is "parallel map": it `jit - compiles a function and runs it on multiple devices, with one shard of the leading axis per device.

from functools import partial

@partial(jax.pmap, axis_name="i")  # 8 devices: leading axis of every argument must be 8
def step(params, batch):
    grads = jax.grad(loss)(params, batch)         # `loss` stands for any scalar loss(params, batch)
    grads = jax.lax.pmean(grads, axis_name="i")   # all-reduce gradients across devices
    new_params = jax.tree.map(lambda p, g: p - 1e-3 * g, params, grads)
    return new_params

# params must be replicated across devices (leading axis 8):
params = jax.tree.map(lambda x: jnp.broadcast_to(x, (8,) + x.shape), params)
# batch must be pre-split by the caller: shape (8, B // 8, ...)
new_params = step(params, batch)

Cross-device communication is via collectives inside the function: jax.lax.psum(x, axis_name="i"), pmean, all_gather, etc., where axis_name is set by pmap(..., axis_name="i").

pmap is single-program multiple-data with explicit sharding by the user, restricted to one batch dimension and one mesh axis. It composes (pmap(vmap(grad(f))) is valid), but it is awkward when you need 2D meshes, host coordination across many TPU hosts, or per-tensor sharding choices.

9.2 pjit (intermediate) and the unified jit (modern)

Around 2022 JAX introduced pjit: a jit that took a Mesh and PartitionSpecs and let XLA's GSPMD partitioner shard arbitrary tensors across an arbitrary mesh. This was the right abstraction. By 2024, pjit and jit were unified: today jax.jit natively understands sharding, and pjit is an alias.

The modern path:

import numpy as np

import jax
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

devices = jax.devices()                     # e.g., 8 GPUs or a 2x4 TPU slice
mesh = Mesh(np.array(devices).reshape(2, 4), axis_names=("data", "model"))

# Place an array sharded:
def shard(x, spec):
    return jax.device_put(x, NamedSharding(mesh, spec))

x   = shard(x_host,    P("data", None))     # batch sharded over 2 devices
W   = shard(W_host,    P(None, "model"))    # output dim sharded over 4 devices
b   = shard(b_host,    P("model"))

@jax.jit
def forward(x, W, b):
    y = x @ W + b
    # Optionally constrain intermediate shardings:
    y = jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P("data", "model")))
    return y

What happens under the hood:

  1. jit traces forward with abstract sharded inputs.
  2. The jaxpr lowers to HLO with sharding annotations on the inputs, outputs, and any with_sharding_constraint points.
  3. GSPMD (the partitioner inside XLA) propagates sharding through the whole HLO module, deciding per-op how each tensor is laid out.
  4. GSPMD inserts collectives (all-reduce, all-gather, reduce-scatter, all-to-all) where needed.
  5. XLA emits one program per device; each device runs only its slice.

This is GSPMD: General SPMD partitioner. The user writes a single-device-shaped program with annotations on a few key tensors; the compiler figures out the rest.

9.3 Mesh, PartitionSpec, NamedSharding

  • Mesh: a logical N-dimensional grid of devices with named axes. Common patterns:
    • 1D ("data",): pure data parallelism.
    • 2D ("data", "model"): DP × tensor-parallel (Megatron-style).
    • 3D ("data", "fsdp", "model") or ("pp", "data", "model"): pipeline + DP + TP.
  • PartitionSpec (alias P): for an array with shape (d0, d1, …, dn), P(spec0, spec1, …) says how each axis is partitioned over the mesh. Each entry is either:
    • None: replicated along this array axis.
    • "name": sharded along the mesh axis "name".
    • A tuple ("a", "b"): sharded along the product of mesh axes a and b.
  • NamedSharding(mesh, P(...)): binds a PartitionSpec to a concrete mesh.

Examples: - P("data", None) on (B, D): sharded batch, replicated features (FSDP-like for activations). - P(None, "model") on (D_in, D_out): replicated D_in, sharded D_out (tensor-parallel weight). - P(("data", "fsdp")) on (D,) with mesh ("data", "fsdp"): sharded over both axes simultaneously.

9.4 with_sharding_constraint

Inside a jit-ed function, you can pin an intermediate sharding:

y = jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P("data", "model")))

This is a hint to GSPMD: "after this point, y must be sharded this way." Use it to:

  • Break ambiguity when GSPMD picks a sharding you don't want.
  • Force a re-shard at a known boundary (e.g., between tensor-parallel attention and tensor-parallel MLP).

9.5 Output sharding

jit with sharding can take in_shardings and out_shardings arguments to specify how inputs and outputs are sharded. Default: inferred from the actual input shardings and propagated by GSPMD.

forward_p = jax.jit(
    forward,
    in_shardings=(NamedSharding(mesh, P("data", None)),
                  NamedSharding(mesh, P(None, "model")),
                  NamedSharding(mesh, P("model"))),
    out_shardings=NamedSharding(mesh, P("data", "model")),
)

9.6 When pmap is still useful

pmap survives for:

  • Quick single-axis SPMD where setting up a Mesh feels heavy.
  • Legacy code.
  • A few research patterns that depend on pmap's tight coupling between the Python-side leading axis and the device axis.

For new code at scale, prefer jit + sharding.


10. jax.shard_map: when you want manual control

jit + sharding is implicit: you annotate, GSPMD figures out collectives. shard_map is explicit: you write what each device sees, and you call collectives by hand.

from functools import partial

from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

# mesh as defined in section 9.
@partial(shard_map, mesh=mesh, in_specs=(P("data", None), P(None, "model")),
                               out_specs=P("data", "model"))
def matmul(x_local, W_local):
    # x_local and W_local have the *local* shape on each device.
    # Here the contraction dim is unsharded, so no collective is needed;
    # when a dim *is* sharded across devices, we call collectives explicitly.
    y_local = x_local @ W_local
    return y_local     # already correctly sharded as P("data", "model")

Inside shard_map you operate on the local shard. Collectives (jax.lax.psum, jax.lax.all_gather, jax.lax.all_to_all, jax.lax.ppermute) reference the mesh axis names and run across the corresponding device subset.

When to use shard_map over jit+sharding:

  • You need an algorithm that GSPMD does not synthesize well (custom ring all-reduce, expert-parallel routing, sequence parallelism with overlap).
  • You want predictable collective placement for performance debugging.
  • You are implementing a low-level kernel (e.g., a custom attention with sequence sharding and explicit all_to_all).

When to stick with jit+sharding:

  • Standard transformer training. GSPMD does an excellent job.
  • You want one piece of code that retargets between mesh shapes without rewriting.

The mental model: jit+sharding is "declare the partitioning, let the compiler handle parallelism"; shard_map is "I am writing SPMD by hand, shoulder-to-shoulder with the hardware."
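A sketch of the case where a collective is genuinely needed (same mesh and imports as above): the contraction dimension is sharded over "model", so each device computes a partial matmul that we reduce by hand.

@partial(shard_map, mesh=mesh, in_specs=(P(None, "model"), P("model", None)),
                               out_specs=P(None, None))
def matmul_allreduce(x_local, W_local):
    # Each device holds a (B, D/4) slice of x and a (D/4, F) slice of W: the local
    # product is a partial sum over the contraction dim, reduced explicitly across "model".
    partial_y = x_local @ W_local
    return jax.lax.psum(partial_y, axis_name="model")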


11. Structured loops: lax.scan, lax.fori_loop, lax.while_loop

A Python for loop inside a jit-traced function unrolls into the jaxpr. For 4 iterations that is fine. For 1024 iterations the jaxpr (and the resulting HLO and the compile time) explodes. Use structured control flow.

11.1 jax.lax.scan: the workhorse

scan is a stateful map-reduce. Signature (informal):

def f(carry, x):                # carry: state, x: per-step input
    new_carry = ...
    y = ...
    return new_carry, y

final_carry, ys = jax.lax.scan(f, init_carry, xs)

It is O(T) in compile size regardless of T, because the loop body is traced once and reused.

Use scan for:

  • RNN forward passes.
  • Sampling loops where each step depends on the last.
  • Any reduction with an explicit state that you want to materialize per step (ys).

Reverse-mode AD through scan works: the loop body is differentiated once and the backward pass is itself a scan. By default the forward pass saves residuals for every step, so memory grows as O(T); for long sequences combine scan with jax.checkpoint (rematerialization) to trade compute for memory. The unroll= argument trades compile size for throughput by unrolling several iterations per loop step.
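A concrete example: an exponential moving average over a sequence, a minimal stand-in for an RNN step.

def ema_step(carry, x_t):
    new_carry = 0.9 * carry + 0.1 * x_t
    return new_carry, new_carry            # carry the state forward, emit it per step

xs = jnp.arange(10.0)
final, trace = jax.lax.scan(ema_step, jnp.array(0.0), xs)
# final: the last EMA value; trace: shape (10,), the EMA at every step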

11.2 jax.lax.fori_loop

def body(i, state): ...
final = jax.lax.fori_loop(0, N, body, init)

Like scan but does not stack per-step outputs (no ys). Slightly cheaper, less expressive. Reverse-mode AD works only when the trip count is static (the loop then lowers like a scan); for differentiable loops prefer scan.

11.3 jax.lax.while_loop

def cond(state): ...
def body(state): ...
final = jax.lax.while_loop(cond, body, init)

Truly dynamic iteration count. Cannot be reverse-mode differentiated in the general case (the number of iterations is data-dependent). Use for inference-only loops with data-dependent termination (e.g., autoregressive sampling until EOS).
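A small example with a data-dependent trip count (halve until the norm drops below 1), fine for inference but not reverse-differentiable:

def halve_until_small(x):
    def cond(state):
        x, n = state
        return jnp.linalg.norm(x) > 1.0
    def body(state):
        x, n = state
        return x / 2, n + 1
    return jax.lax.while_loop(cond, body, (x, jnp.array(0)))

x_final, n_steps = halve_until_small(jnp.array([4.0, 3.0]))   # n_steps == 3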

11.4 jax.lax.cond and jax.lax.switch

For data-dependent branching:

y = jax.lax.cond(pred, true_fn, false_fn, x)
y = jax.lax.switch(idx, [fn0, fn1, fn2], x)

Both branches are traced and compiled; only one runs at execution. This means both branches must return the same PyTree structure with matching shapes/dtypes.


12. XLA: the compiler under the hood

XLA (Accelerated Linear Algebra) is the compiler. JAX is one front-end; TensorFlow and PyTorch/XLA are others. We focus on what JAX programmers need to know.

12.1 HLO: the IR

HLO ("High Level Optimizer") is XLA's intermediate representation. An HLO module is a collection of computations; each computation is a list of instructions; each instruction has a name, an opcode, operand references, and a typed shape.

Modern XLA actually uses StableHLO (an MLIR dialect) at the boundary, then lowers internally to HLO. The two are largely isomorphic for our purposes.

A small HLO snippet for jnp.sum(jnp.sin(x * y)):

HloModule jit_f, entry_computation_layout={(f32[3]{0}, f32[3]{0})->(f32[])}

ENTRY main {
  x = f32[3]{0} parameter(0)
  y = f32[3]{0} parameter(1)
  prod = f32[3]{0} multiply(x, y)
  s    = f32[3]{0} sine(prod)
  zero = f32[] constant(0)
  ROOT sum = f32[] reduce(s, zero), dimensions={0}, to_apply=add_f32
}

You can dump HLO from JAX:

hlo = jax.jit(f).lower(jnp.ones((3,)), jnp.ones((3,))).compiler_ir(dialect="hlo")
print(hlo.to_string())

or get the post-optimization HLO via compiled.as_text() / compiled.runtime_executable(). Setting the env var XLA_FLAGS=--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text will dump every compiled module.

12.2 The HLO op set you should recognize

| Op | Meaning |
| --- | --- |
| parameter(i) | The i-th input |
| constant(...) | A literal |
| add, multiply, subtract, divide | Elementwise arithmetic |
| sine, exponential, log, tanh, ... | Elementwise unary |
| compare | Elementwise comparison |
| convert | Dtype cast |
| broadcast(..., dimensions=...) | Reshape/expand to a larger shape |
| reshape | Same data, new shape |
| transpose(..., dimensions=...) | Permute axes |
| slice, dynamic-slice | Static / dynamic slicing |
| concatenate | Concatenate along a dim |
| reduce(operand, init, dimensions=, to_apply=) | Reduction with a reducer computation |
| dot(a, b, lhs_contracting_dims=..., rhs_contracting_dims=..., lhs_batch_dims=..., rhs_batch_dims=...) | Generalized matmul |
| convolution | Generalized conv |
| gather, scatter | Indirect read / write |
| select | Elementwise where |
| tuple, get-tuple-element | Tuple constructors / accessors |
| while, conditional | Structured control flow |
| all-reduce, all-gather, reduce-scatter, all-to-all, collective-permute | Cross-device collectives |
| custom-call | Escape hatch to a hand-written kernel (cuDNN, custom CUDA, Pallas) |

dot is the linchpin. It is fully general: contracting dims are reduced, batch dims are kept, the rest are output. A standard (M, K) × (K, N) → (M, N) matmul has lhs_contracting=[1], rhs_contracting=[0], no batch dims. Attention's (B, H, S, D) × (B, H, T, D) has lhs_contracting=[3], rhs_contracting=[3], lhs_batch=[0,1], rhs_batch=[0,1].
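You rarely write dot_general by hand (jnp.matmul and jnp.einsum lower to it), but it is worth seeing the dimension numbers spelled out once:

# (M, K) @ (K, N): contract lhs axis 1 with rhs axis 0, no batch dims.
a = jnp.ones((4, 8)); b = jnp.ones((8, 16))
c = jax.lax.dot_general(a, b, dimension_numbers=(((1,), (0,)), ((), ())))   # (4, 16)

# Attention scores (B, H, S, D) x (B, H, T, D) -> (B, H, S, T):
q = jnp.ones((2, 4, 10, 64)); k = jnp.ones((2, 4, 12, 64))
scores = jax.lax.dot_general(q, k, dimension_numbers=(((3,), (3,)), ((0, 1), (0, 1))))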

12.3 The compilation pipeline

Roughly:

  jaxpr
    │  (lowered by JAX)
  StableHLO (MLIR)
    │  (XLA front-end)
  HLO
    │  (XLA HLO passes)
  Optimized HLO
    │  (Backend: GPU / TPU / CPU)
  Device code (PTX/SASS via LLVM-NVPTX,  TPU machine code,  x86 LLVM)

The HLO passes do the heavy lifting:

  1. Algebraic simplification. x * 1 → x, concat(slice, slice) → original, fold constants.
  2. Layout assignment. Pick physical memory layouts (which dim is fastest-varying, tile shapes, padding) per buffer to match the target hardware.
  3. Sharding propagation (GSPMD). From annotated tensors, infer shardings everywhere; insert collectives.
  4. Operator fusion. Combine adjacent elementwise ops, plus a producer reduction or matmul, into a single kernel-this is the single biggest performance win on GPU. JAX programs that look like 100 small numpy ops often compile to a handful of fused kernels.
  5. Memory scheduling. Order ops to minimize peak memory; insert rematerialization if needed.
  6. Lowering. GPU: emit LLVM IR (or call into cuBLAS/cuDNN for some patterns), then to PTX. TPU: emit TPU-specific IR for the matrix unit and vector unit, schedule across HBM/VMEM.

12.4 Fusion in detail

Fusion is the reason JAX feels fast. Consider:

def f(x, y, z):
    return jnp.tanh(x * y + z)

In eager numpy this is three kernel launches: multiply, add, tanh-each reads from and writes to HBM. XLA fuses them into one kernel: read x, y, z once, do all the elementwise math in registers, write the result once. On memory-bound workloads (most elementwise + small reductions) this is a 3–10× speedup.

XLA's GPU backend also fuses producer reductions and consumer elementwise (and vice versa) and a matmul with epilogue elementwise (and prologue elementwise on its inputs) where profitable. The resulting fused kernel is what you see as a single HLO fusion instruction in the post-optimization dump.

Caveat: what runs is the fused program, not your Python. A plain print inside a jitted function fires only once, at trace time, so it can seem to disappear; use jax.debug.print for runtime values, or run under jax.disable_jit() to execute eagerly while debugging.

12.5 Layout assignment

Layout = how a logical shape (N, H, W, C) maps onto physical memory. On TPUs, tiling matters enormously: the matrix unit prefers (128, 128) tiles aligned in particular ways; XLA inserts paddings and chooses layouts so dot products hit fast paths. On GPUs, the choice is less critical (memory is more uniform) but still matters for shared memory and tensor cores.

You can sometimes coax XLA with shape choices (sizes that are multiples of 128 on TPU, multiples of 8 with bf16 for tensor cores on GPU).

12.6 GSPMD: how a single jit targets thousands of devices

GSPMD's input: an HLO module with sharding annotations on some subset of values (parameters, outputs, with_sharding_constraint points). Its output: an HLO module where every value has a sharding, with collectives inserted at boundaries.

The propagation is bidirectional and uses cost models. Key rules:

  • A dot between A: P("data", None) (sharded on batch) and B: P(None, "model") (sharded on output dim) produces a result P("data", "model") with no collective-purely local matmul.
  • A dot between A: P("data", "k") and B: P("k", "model") (sharded on the contraction dim) requires an all-reduce over the k axis after the local matmul.
  • An elementwise op with mismatched shardings inserts an all-gather or a reduce-scatter to align them.
  • A sequence-axis split followed by self-attention typically needs an all-to-all to switch from sequence-sharded to head-sharded.

GSPMD picks the lowest-cost combination. For most transformer workloads with a sensible mesh and a sensible PartitionSpec, the result is close to what an expert hand-writer would do. Where it isn't, you reach for with_sharding_constraint or shard_map.

12.7 Sharding propagation example

# Using the shard() helper from section 9.2:
mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ("data", "model"))
W1 = shard(W1_host, P(None, "model"))   # (D, 4D), sharded out dim
W2 = shard(W2_host, P("model", None))   # (4D, D), sharded in dim
x  = shard(x_host,  P("data", None))    # (B, D), sharded batch

@jax.jit
def mlp(x, W1, W2):
    h = jax.nn.gelu(x @ W1)
    y = h @ W2
    return y

GSPMD will:

  1. Compute x @ W1: x is P("data", None), W1 is P(None, "model"), so the result is P("data", "model") (no collective).
  2. gelu: elementwise, sharding unchanged.
  3. h @ W2: h is P("data", "model"), W2 is P("model", None). The contraction is on the "model" axis, so an all-reduce across "model" is inserted after the local matmul, yielding P("data", None).

That is the standard Megatron tensor-parallel MLP, derived automatically from three PartitionSpecs.


13. TPU vs GPU under JAX/XLA

XLA was conceived at Google with TPUs as its primary target, then extended to GPU and CPU. That history influences the runtime behavior.

Mark the TPU specifics here as "broadly true, version-dependent." TPU generations differ (v3/v4/v5p/v5e/Trillium…), and details have shifted.

13.1 Architectural differences (high level)

| Aspect | GPU (NVIDIA) | TPU |
| --- | --- | --- |
| Compute units | SMs with FP32/FP16 cores + tensor cores | Matrix multiply unit (systolic array, 128×128 typical) + vector unit |
| Memory hierarchy | Registers → shared mem → L2 → HBM | Registers → VMEM (per core) → HBM |
| Numeric formats | FP32, FP16, BF16, FP8, INT8 | Primarily BF16 / FP32 / INT8; FP8 in newer gens |
| Interconnect | NVLink within a node, IB / Ethernet across nodes | Native ICI (inter-chip interconnect) in pod topology: 2D/3D torus |
| Sweet spot | Heterogeneous workloads, irregular ops, custom CUDA | Big regular dense matmuls, large pods |

13.2 Compiler differences

  • Fusion granularity. TPU XLA tends to produce very large fused regions-sometimes the entire transformer block becomes one or two ops. GPU XLA fuses aggressively but is bounded by SM resource limits.
  • Layout. TPUs are pickier about layout (the matrix unit's tile size is fixed). XLA pads aggressively to align-small tensors can have substantial padding overhead. Sizes that are multiples of 128 (sometimes 256) are friendly.
  • Collectives. TPU pods have a 2D or 3D torus; XLA's collective scheduler is tightly tuned for it. GPU collectives go through NCCL (or XLA's own GPU collectives), and topology is less regular.
  • Async dispatch. Both backends launch async; on TPU the compiler often overlaps collectives with compute aggressively (because the cost model is well-understood).

13.3 Why XLA was TPU-first

TPUs are useless without a compiler-there is no eager kernel library equivalent to cuDNN. Every TPU program goes through XLA. So XLA had to be excellent at TPU codegen, sharding, and pod-scale collectives from day one. JAX inherited all of that. Running JAX on TPU is "the original path"; running JAX on GPU shares the front end but uses a different (also mature) backend.

13.4 Practical implications

  • A JAX program that runs well on 1 GPU often runs well on 1 TPU core with no changes.
  • A JAX program that scales to 8 GPUs via jit+sharding scales to a 2048-core TPU pod by enlarging the Mesh: same code.
  • TPU memory per core is typically smaller than a top-end GPU's HBM. You will lean more on FSDP-style sharding and rematerialization on TPU.

14. Module systems: Equinox and Flax

JAX core has no nn.Module. Two libraries dominate.

14.1 Flax (flax.linen and the newer flax.nnx)

flax.linen (the long-standing API):

from flax import linen as nn
class MLP(nn.Module):
    hidden: int
    out: int
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden)(x)
        x = nn.relu(x)
        return nn.Dense(self.out)(x)

rng = jax.random.PRNGKey(0)
dummy_input = jnp.ones((1, 32))              # any input with the right shape/dtype
model = MLP(hidden=64, out=10)
params = model.init(rng, dummy_input)        # returns a PyTree of params
y = model.apply(params, dummy_input)

Idioms:

  • Modules are dataclasses; calling them constructs thunks, not param-owning objects.
  • model.init(rng, x) traces the module to create a parameter PyTree.
  • model.apply(params, x) runs the forward pass; params are an argument, never owned.

This is JAX-functional through and through. Optimizer state is also a PyTree, threaded through train_step.

flax.nnx is a newer API that gives you a more PyTorch-like stateful feel while preserving JAX semantics under the hood. Pick linen for stable production code and the largest ecosystem; pick nnx if you prefer the ergonomics.

14.2 Equinox

Equinox treats modules as PyTrees of dataclasses directly:

import equinox as eqx
class MLP(eqx.Module):
    l1: eqx.nn.Linear
    l2: eqx.nn.Linear
    def __init__(self, key):
        k1, k2 = jax.random.split(key)
        self.l1 = eqx.nn.Linear(784, 128, key=k1)
        self.l2 = eqx.nn.Linear(128, 10,  key=k2)
    def __call__(self, x):
        return self.l2(jax.nn.relu(self.l1(x)))

model = MLP(jax.random.PRNGKey(0))
y = model(x)                                  # call the model directly
grads = jax.grad(loss)(model, x, y_true)      # model itself is the param tree

Idioms:

  • A module is its parameters. There is no separate params PyTree.
  • eqx.partition(model, eqx.is_array) separates trainable arrays from non-array fields when needed (e.g., for optax).

Pick Equinox if you like the "model is data" mental model and small dependency. Pick Flax for the bigger ecosystem (pretrained checkpoints, integrations, MaxText/Pax).

14.3 Optimizer: optax

Both work with optax, the standard JAX optimizer library. optax exposes optimizers as (init, update) function pairs that operate on PyTrees:

import optax
opt = optax.adamw(1e-3)
opt_state = opt.init(params)
grads = jax.grad(loss)(params, ...)
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

Note: pure functions, no mutation. optax chains transformations (optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(...))) which makes complex schedules trivially composable.
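Tying the pieces together, a minimal jitted train step in the style of the sections above. The model, its apply signature, and the (x, y) batch layout are illustrative assumptions (a Flax module and a regression loss), not the library's one true recipe.

@jax.jit
def train_step(params, opt_state, batch):
    x, y = batch
    def loss_fn(p):
        pred = model.apply(p, x)
        return jnp.mean((pred - y) ** 2)
    loss_val, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val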


15. jax.experimental.pallas: Triton-like kernels in JAX

When the compiler's automatic codegen is not enough-typically for fused attention, custom flash-attention variants, or quantized kernels-Pallas lets you write kernels in a Python DSL that lowers to:

  • Triton on GPU, and
  • Mosaic on TPU.

Sketch:

from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        grid=(x.shape[0] // 128,),
        in_specs=[pl.BlockSpec((128,), lambda i: (i,)),
                  pl.BlockSpec((128,), lambda i: (i,))],
        out_specs=pl.BlockSpec((128,), lambda i: (i,)),
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

Mental model: Pallas kernels look like CUDA/Triton kernels (you reason about blocks and refs/pointers), but they integrate as a single HLO custom-call in your JAX program, with full vmap and jit composition. Use it sparingly-only when the surrounding XLA-generated code leaves serious performance on the table-and benchmark against the un-Pallas baseline.

This is the JAX answer to PyTorch's "drop into Triton". It is a young area; expect API motion. The big production win so far is fused-attention kernels.


16. Practical exercises (with worked answers)

How to use these. Read the question. Pause. Predict. Then read the answer. If you predicted wrong, re-read the relevant section.

Exercise 1-What gets recompiled?

You have:

@jax.jit
def f(x, y):
    return x @ y + 1.0

f(jnp.ones((4, 8)),  jnp.ones((8, 16)))     # call A
f(jnp.ones((4, 8)),  jnp.ones((8, 16)))     # call B
f(jnp.ones((4, 8)),  jnp.ones((8, 32)))     # call C
f(jnp.ones((4, 8)).astype(jnp.bfloat16),
  jnp.ones((8, 16)).astype(jnp.bfloat16))   # call D

Which calls trigger compilation?

Answer. A compiles. B reuses A's cache (identical abstract signature). C compiles a new entry (different shape on y). D compiles a new entry (different dtype on both). Total: 3 compilations, 4 calls.

Exercise 2-Why did my training loop slow to a crawl?

Symptom: every step takes ~2 s, no GPU utilization between steps. JAX_LOG_COMPILES=1 shows a compile log every step.

Likely cause. Variable-shape inputs. Most often: padding sequences to their actual lengths inside the data loader, so each batch has shape (B, L) with a different L.

Fixes (any of):

  1. Pad to a fixed maximum length (cheapest, slight wasted compute).
  2. Pad to a small set of bucket lengths (e.g., 128/256/512/1024): at most that many compilations.
  3. Mark sequence length as static_argnums only if you genuinely have a tiny number of distinct values.

Exercise 3-Per-example gradient norms

You want, for each example in a batch, the L2 norm of its per-example gradient. Write it.

def per_example_loss(params, x, y):
    return ((x @ params["W"] - y) ** 2).mean()    # scalar per example

per_grad = jax.vmap(jax.grad(per_example_loss),
                    in_axes=(None, 0, 0))         # share params, batch x and y

def per_example_norm(params, X, Y):
    grads = per_grad(params, X, Y)                # PyTree, leaves shape (B, ...)
    flat  = jax.tree.leaves(grads)
    sq    = sum(jnp.sum(g.reshape(g.shape[0], -1) ** 2, axis=1) for g in flat)
    return jnp.sqrt(sq)                           # shape (B,)

Reading: vmap(grad(...)) produces a function that returns gradient PyTrees with a leading batch axis; we then compute the per-row norm. No Python loop.

Exercise 4-A Megatron-style MLP, by PartitionSpec

You have an MLP Linear(D, 4D) → gelu → Linear(4D, D). You have an 8-device mesh ("data", "model") with shape (2, 4). Specify shardings such that:

  • The batch is data-parallel.
  • Both Linear layers are tensor-parallel along the hidden axis.
  • The output of the MLP is data-sharded, model-replicated (so a downstream layer-norm sees a complete vector per example).

Answer.

mesh = Mesh(devices.reshape(2, 4), ("data", "model"))
x_spec   = P("data", None)          # (B, D)
W1_spec  = P(None, "model")         # (D, 4D)
W2_spec  = P("model", None)         # (4D, D)
y_spec   = P("data", None)          # (B, D)

GSPMD will compute x @ W1 locally (no collective), gelu locally, then h @ W2 with an all-reduce over "model" on the contraction axis to produce y_spec. This is exactly Megatron-LM's "column-parallel then row-parallel" pattern, derived from PartitionSpecs.

Exercise 5-Why won't this differentiate?

@jax.jit
def f(x):
    n = 0
    while jnp.linalg.norm(x) > 1.0:
        x = x / 2
        n += 1
    return x, n

You wrap it in jax.grad(lambda x: f(x)[0].sum()) and get an error. Why?

Answer. The Python while runs at trace time and depends on a traced value (jnp.linalg.norm(x) > 1.0 is a tracer). You will hit a ConcretizationTypeError even before differentiation. The fix is jax.lax.while_loop, but: while_loop does not support reverse-mode AD because the iteration count is data-dependent. If you need a differentiable variant, use jax.lax.scan with a known maximum number of iterations, masking the unused steps.
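A sketch of that workaround: cap the loop at a fixed budget (MAX_STEPS is an illustrative constant) and mask steps past the stopping point. Once the norm drops below 1 the state stops changing, so the result matches the while loop up to the cap, and reverse-mode AD works.

MAX_STEPS = 32

def f_scan(x):
    def step(carry, _):
        x, n = carry
        keep_going = jnp.linalg.norm(x) > 1.0
        x = jnp.where(keep_going, x / 2, x)            # halve only while still "running"
        n = n + keep_going.astype(jnp.int32)
        return (x, n), None
    (x, n), _ = jax.lax.scan(step, (x, jnp.array(0)), None, length=MAX_STEPS)
    return x, n

grad_fn = jax.grad(lambda x: f_scan(x)[0].sum())       # now differentiable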

Exercise 6-Reading a jaxpr

What does this code do?

def g(x):
    return jnp.where(x > 0, x, 0.5 * x)

print(jax.make_jaxpr(g)(jnp.array([-1.0, 2.0, -3.0])))

Answer. Approximately:

{ lambda ; a:f32[3]. let
    b:bool[3] = gt a 0.0
    c:f32[3]  = mul a 0.5
    d:f32[3]  = select_n b c a
  in (d,) }

It is a leaky-ReLU-ish function (slope 0.5 on the negative side). The jaxpr makes the elementwise nature explicit and shows that where becomes a select_n (3-way select primitive) rather than a Python branch. It will compile to a single fused HLO kernel (compare, multiply, select).


17. Cheat sheet

17.1 The four core transformations

| Transformation | Maps f to | Mental model |
| --- | --- | --- |
| jax.jit(f) | A function that traces, lowers to HLO, compiles, caches, runs | "Make it fast" |
| jax.grad(f) | A function returning df/dx | "Differentiate" |
| jax.vmap(f) | A function with an extra batch axis | "Vectorize" |
| jax.pmap(f) / jit(..., in_shardings=...) | A function running across devices | "Parallelize" |

They compose. jit(vmap(grad(f))) is well-defined and idiomatic.

17.2 Common errors and their meaning

| Error | Cause | Fix |
| --- | --- | --- |
| ConcretizationTypeError | Python control flow on a tracer | Use jax.lax.cond/where, or make the variable static |
| TracerArrayConversionError | Tried to convert a tracer to numpy or use it as a Python int | Push the work into JAX-land or restructure |
| Repeated compiles | Variable shapes / dtypes / static args / pytree structures | Stabilize shapes, use bucketing, audit static_argnums |
| OOM during compile | Long Python for loop being unrolled | Use lax.scan |
| Silent wrong answer | Side effect (mutation, global) inside a jitted function | Make it pure |

17.3 Inspection toolbox

jax.make_jaxpr(f)(x)                                 # show jaxpr
jax.jit(f).lower(x).as_text()                        # StableHLO MLIR
jax.jit(f).lower(x).compiler_ir(dialect="hlo")       # HLO
jax.jit(f).lower(x).compile().as_text()              # post-opt HLO
jax.jit(f).lower(x).compile().cost_analysis()        # FLOPs, bytes
jax.config.update("jax_log_compiles", True)          # see every compile
JAX_TRACEBACK_FILTERING=off                          # (env) full Python tracebacks
XLA_FLAGS=--xla_dump_to=/tmp/xla --xla_dump_hlo_as_text   # (env) dump every module

17.4 When to reach for which loop primitive

| Need | Primitive |
| --- | --- |
| Static iteration count, want per-step outputs, want AD | jax.lax.scan |
| Static iteration count, no per-step outputs, no AD over the loop count | jax.lax.fori_loop |
| Data-dependent iteration count, no AD over the loop | jax.lax.while_loop |
| Data-dependent branch, both branches valid | jax.lax.cond |
| Choose among N branches by index | jax.lax.switch |

17.5 Sharding spec recipes (mesh ("data", "model"))

| Goal | Inputs | Weights | Note |
| --- | --- | --- | --- |
| Pure DP (replicated weights) | P("data", ...) | P(None, ...) | Standard data parallel |
| FSDP-style | P("data", ...) | P("data", ...) (gathered before use) | Combine with with_sharding_constraint |
| Tensor parallel (Megatron MLP) | P("data", None) | W1: P(None, "model"), W2: P("model", None) | All-reduce after W2 |
| 2D parallelism | P("data", None) | TP weights as above | ("data", "model") mesh |

17.6 Mental discipline

  • State-as-argument. Anything that "carries over"-params, opt state, RNG, batchnorm running stats-is an argument and a return value, never a global.
  • Trace once, run many. Every jit-decorated function should be traced a small number of times across the entire program lifetime.
  • Annotate sharding sparsely. Annotate the inputs and key intermediate shardings; let GSPMD figure out the rest.
  • Profile with the post-opt HLO. What you wrote in Python and what runs on the device can diverge dramatically due to fusion. Read the post-optimization HLO before optimizing.
  • Read the jaxpr when surprised. It is short and exact.

Closing

JAX is a small core (pure functions over PyTrees, traced into jaxprs, lowered to XLA HLO) with a large amount of leverage on top (composable transformations, GSPMD, Pallas, the Flax/Equinox ecosystem). Its design exacts a discipline-purity, explicit state, structured control flow, deliberate sharding-and pays back with a programming model that scales seamlessly from a single laptop GPU to a 4096-chip TPU pod with the same source code.

If you internalize four ideas from this chapter, make them: (1) purity makes transformations compose, (2) jit is trace-then-cache, and the cache key is the abstract signature, (3) HLO is the IR; fusion and GSPMD are where the magic lives, (4) Mesh + PartitionSpec is how you tell the compiler about your hardware, and the rest is propagation. Everything else in JAX is a refinement of those.
