Deep Dive 05-JAX and XLA¶
Reading contract. This chapter is a self-contained reference. After reading and working the exercises you should be able to (a) read and write idiomatic JAX, (b) reason about what happens when
`@jax.jit` is applied to a function, (c) inspect jaxprs and HLO, (d) shard a computation across a multi-host TPU/GPU cluster using `Mesh` + `PartitionSpec`, and (e) pick between `jit`-with-sharding, `shard_map`, and (legacy) `pmap` for a given workload. We do not punt to the JAX docs.
Table of contents¶
- Why JAX exists
- Functional purity: the unit of compilation
- PyTrees and `jax.tree_util`
- Stateless PRNGs (`PRNGKey`)
- Tracing and jaxprs
- `jax.jit`: caching, recompilation, static args
- `jax.grad`, `value_and_grad`, `jvp`, `vjp`
- `jax.vmap` and per-example gradients
- Device parallelism: `pmap` (legacy) vs `jit` + sharding (modern)
- `jax.shard_map`: when you want manual control
- Structured loops: `lax.scan`, `lax.fori_loop`, `lax.while_loop`
- XLA: HLO IR, compilation pipeline, fusion, layout, GSPMD
- TPU vs GPU under XLA
- Module systems on top: Equinox and Flax
- `jax.experimental.pallas` (Triton-like kernel DSL)
- Practical exercises (with worked answers)
- Cheat-sheet appendix
1. Why JAX exists¶
JAX is a library for numerical computing whose central thesis is:
Numerical programs are functions of arrays. If you keep them pure, you can compose program transformations on them (autodiff, vectorization, parallelization, just-in-time compilation) and you can lower the result through a single optimizing array compiler (XLA) to CPU, GPU, or TPU.
That sentence has every load-bearing word in JAX's design. Let us unpack it against the contrast that most readers carry: PyTorch.
1.1 PyTorch's design (so we have a foil)¶
PyTorch is eager and imperative: a tensor operation runs immediately on the host's accelerator queue. The autograd graph is built dynamically as a side effect of forward computation: every tensor with `requires_grad=True` allocates a node in a graph stored on the tensor itself. Modules are stateful objects (`nn.Module`) that own their parameters, buffers, and (transitively) the optimizer state. To go fast you typically (a) call into eager kernels, (b) use `torch.compile` (TorchDynamo + Inductor), which traces Python bytecode and lowers to fused kernels, or (c) drop into custom CUDA / Triton.
PyTorch optimizes for debuggability and programmer ergonomics in idiomatic Python. You can print a tensor, set a Python breakpoint inside a forward pass, mutate a list of layers conditionally, and it all works.
1.2 JAX's design choices¶
JAX takes an almost orthogonal set of choices:
- Pure functions are the unit. A JAX-compilable function takes arrays in, returns arrays out, and has no side effects: no mutation of outer Python state, no in-place tensor edits, no random state hidden in a global. Everything that would be state is threaded explicitly through arguments and return values.

- Composable transformations. Once a function is pure, JAX can give you several function-to-function transformations:

    - `jax.jit`: trace and compile via XLA.
    - `jax.grad`: return a function computing the gradient.
    - `jax.vmap`: return a function that runs the original over a new batch axis.
    - `jax.pmap` / `jit` with sharding: return a function that runs across devices.

  These compose: `jit(vmap(grad(f)))` is meaningful and well-defined. This composability is the soul of JAX.

- XLA as the default backend. Where PyTorch's "real" backend is a constellation of cuDNN/cuBLAS calls in eager mode and Inductor-generated Triton at compile time, JAX always lowers to HLO (XLA's IR) and lets the XLA compiler emit device code. This means TPU and GPU share most of the toolchain. (XLA was originally a TPU compiler at Google, and that lineage shows.)

- TPU first-class. Unlike PyTorch, where TPU support is delivered through `torch_xla` as an extra layer, JAX speaks XLA natively. A JAX program written for one TPU pod core scales to thousands of cores essentially by adding sharding annotations.
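To make the composability claim concrete, here is a minimal sketch; the function `f` and all shapes are invented for illustration:

```python
import jax
import jax.numpy as jnp

def f(w, x):
    # scalar "loss" of parameters w for a single input x
    return jnp.sum(jnp.tanh(x @ w))

# gradient w.r.t. w, vectorized over a batch of x, compiled -- one composition:
g = jax.jit(jax.vmap(jax.grad(f), in_axes=(None, 0)))

w = jnp.ones((3, 2))
X = jnp.ones((5, 3))
grads = g(w, X)  # one gradient of shape (3, 2) per example: (5, 3, 2)
```

Each transformation returns an ordinary function, so they nest in any order that type-checks.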
1.3 Trade-offs¶
The JAX bargain:
- You give up in-place mutation, easy printing inside compiled code, dynamic shapes that change every call, and Python-level control flow that depends on traced values.
- You get a uniform compilation pipeline, world-class autodiff that composes with everything, free vectorization (`vmap`), free distribution (`jit` + sharding), and predictable performance because compilation is explicit.
For research workloads with stable shapes and heavy linear algebra (transformers, diffusion, scientific computing), this trade is excellent. For workloads with ragged shapes, dynamic graphs of varying topology, or heavy host-side branching (some RL systems, classical NLP pipelines), it can be painful.
2. Functional purity: the unit of compilation¶
A pure function in JAX's sense:
- Outputs depend only on inputs. No global reads (other than constants closed over at trace time).
- Has no observable side effects. No global writes, no I/O, no mutation of arguments.
- Same input → same output, every call.
```python
import jax
import jax.numpy as jnp

# Pure
def loss(params, x, y):
    pred = x @ params["W"] + params["b"]
    return jnp.mean((pred - y) ** 2)

# Impure: mutates a global counter
counter = 0

def bad(x):
    global counter
    counter += 1
    return x * 2
```
`jax.jit(bad)` will appear to work, but `counter` will be incremented exactly once per trace, not once per call. The bug surfaces silently. This is a recurring pattern: JAX does not police purity at runtime; it assumes it. If you violate it, you get correct-looking output and incorrect semantics.
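A short demonstration of the once-per-trace behavior described above:

```python
import jax
import jax.numpy as jnp

counter = 0

def bad(x):
    global counter
    counter += 1  # side effect: runs only while JAX traces, not per call
    return x * 2

jit_bad = jax.jit(bad)
jit_bad(jnp.ones(3))  # first call traces: counter becomes 1
jit_bad(jnp.ones(3))  # cached executable runs: counter stays 1
assert counter == 1   # two calls, one increment
```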
The discipline imposed by purity buys two enormous things:
- Trivial reverse-mode AD. With no side effects, the chain rule is just structural induction over the jaxpr. There is nothing to "undo."
- Trivial parallelization. Pure functions are referentially transparent; you can run them on any device, in any order, multiple times, without changing program meaning.
2.1 Where state lives¶
If the model has weights, optimizer moments, RNG state, batch-norm running stats, those are values that the user threads through arguments:
```python
def train_step(params, opt_state, rng, batch):
    rng, sub = jax.random.split(rng)
    grads = jax.grad(loss)(params, batch, sub)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, rng
```
Compare with PyTorch, where `optimizer.step()` mutates `param.data` and `param.grad` in place. JAX surfaces the mutation as new return values. The `train_step` itself remains pure.
3. PyTrees and jax.tree_util¶
A neural network has hundreds or thousands of parameters. Threading them as positional arguments is unworkable. JAX solves this by making arbitrary nested Python containers first-class.
A PyTree is, recursively, either:
- a leaf (an array, a scalar, anything not a registered container), or
- a container node (by default `tuple`, `list`, `dict`, `None`, `namedtuple`) with PyTree children.
Custom classes can be registered with `jax.tree_util.register_pytree_node` (or, for dataclasses, `@jax.tree_util.register_dataclass`; Equinox modules register automatically).
3.1 Why this matters¶
Every JAX transformation that takes a function f(x) -> y and produces f'(x) -> y' operates over PyTrees: x and y may be arbitrarily nested. The transformations preserve structure.
Example: jax.grad applied to a function whose first argument is params = {"layer1": {"W": ..., "b": ...}, "layer2": {...}} returns a gradient PyTree with the same shape as params. You never write flatten_params glue code by hand.
3.2 The core API¶
```python
from jax import tree_util as tu

leaves, treedef = tu.tree_flatten(params)     # list of arrays + structure spec
params2 = tu.tree_unflatten(treedef, leaves)  # rebuild

# Map a function over every leaf:
doubled = jax.tree.map(lambda x: 2 * x, params)

# Map across two PyTrees that share structure:
sum_tree = jax.tree.map(lambda a, b: a + b, params, grads)
```
In recent JAX, the public surface is `jax.tree.map`, `jax.tree.leaves`, `jax.tree.structure`, etc. The underlying `jax.tree_util` module remains.
3.3 PyTree as the universal interface¶
Internally, every JAX transformation:

1. Flattens its inputs to a flat list of leaves plus a tree-structure description.
2. Operates on the flat list (where everything is just an array).
3. Reconstructs the output structure on the way out.
This is why you can pass a dict of arrays to jit, grad, vmap, pmap and they all "do the right thing"-they each call tree_flatten and treat leaves uniformly.
3.4 Custom PyTree¶
```python
import dataclasses
import jax

@jax.tree_util.register_dataclass
@dataclasses.dataclass
class GRUCell:
    W: jax.Array
    U: jax.Array
    b: jax.Array
```
Now `GRUCell` instances are valid PyTrees. `jax.grad` will return a `GRUCell` of gradients.
Contrast with PyTorch. PyTorch's `nn.Module` is a class with a `state_dict()` method. JAX's analog is "any PyTree." The "module" abstraction is built on top (Flax, Equinox), not into the core.
3.5 Exercise¶
Given `params = {"a": jnp.zeros((3,)), "b": [jnp.ones((2,2)), jnp.ones((2,))]}`, write the call that returns `{"a": (3,), "b": [(2, 2), (2,)]}` (the shape of every leaf, in the same structure).
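One possible answer (spoiler), using `jax.tree.map` from §3.2:

```python
import jax
import jax.numpy as jnp

params = {"a": jnp.zeros((3,)), "b": [jnp.ones((2, 2)), jnp.ones((2,))]}

# Map every leaf to its shape; the dict/list structure is preserved.
shapes = jax.tree.map(lambda x: x.shape, params)
# shapes == {"a": (3,), "b": [(2, 2), (2,)]}
```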
4. Stateless PRNGs (PRNGKey)¶
Random numbers are state. State is impure. JAX therefore cannot have a global RNG (well, it could, but it would break composition with jit/vmap/pmap).
Solution: an explicit, immutable key, threaded by the user.
```python
key = jax.random.PRNGKey(42)         # an array of shape (2,) uint32 (historically)
key, subkey = jax.random.split(key)  # 2 new keys; old key conceptually consumed
x = jax.random.normal(subkey, (1024,))
```
Three rules:

- Never reuse a key. `random.normal(key, ...)` is a pure function of `key`; passing the same key gives identical samples.
- Always split before consuming. `split(key, n)` returns `n` fresh keys.
- Threading is the user's job. Every function that consumes randomness takes a key argument.
4.1 Why this design¶
- Reproducibility. Two runs with the same starting key produce bit-identical results, even across `vmap`/`pmap`/sharding.
- Composability. `vmap(f)` over a batched-keys argument gives per-example randomness with no surprises. With a global RNG, `vmap` could not say what the per-example samples should be.
- Determinism under parallelism. Each device gets its own key derived deterministically from the master key. No race conditions, no per-device RNG state to seed.
4.2 Idiom: per-step splitting¶
```python
def train_step(params, rng, batch):
    rng, dropout_key = jax.random.split(rng)
    logits = model(params, batch, dropout_key)
    ...
    return params, rng  # return the *new* rng for the next step
```
You can split into many keys at once: `keys = jax.random.split(rng, num=8)` and `vmap` over them.
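The split-then-`vmap` idiom, sketched end to end:

```python
import jax

rng = jax.random.PRNGKey(0)
keys = jax.random.split(rng, num=8)  # eight independent keys

# Per-example randomness: each row is drawn from its own key.
draws = jax.vmap(lambda k: jax.random.normal(k, (4,)))(keys)
assert draws.shape == (8, 4)
```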
4.3 Key types¶
Modern JAX also has typed key arrays created with `jax.random.key(seed)`, which carry their RNG implementation (e.g., `threefry2x32`, `rbg`) in the dtype. For day-to-day work `PRNGKey(seed)` is fine; for implementation-specific needs see the `jax.random.key` API. The conceptual model (explicit, splittable, stateless) is unchanged.
5. Tracing and jaxprs¶
This is the conceptual heart of JAX. Internalize it and the rest follows.
5.1 What @jax.jit actually does¶
When you call a `jit`-decorated function for the first time with concrete arguments:
1. JAX inspects each argument's shape and dtype (and static-argnum Python values).
2. It calls your Python function with abstract `Tracer` objects in place of those arguments: objects that record every operation performed on them but do not compute values.
3. The resulting trace is a jaxpr (JAX expression): a small typed IR of primitive operations.
4. The jaxpr is lowered to HLO and compiled by XLA for the target device.
5. The compiled executable is cached, keyed by (function identity, abstract input signature, static-arg values).
6. The actual concrete arguments are run through the executable.
Subsequent calls with the same abstract signature skip steps 2–5 and just run the cached executable.
5.2 A worked jaxpr¶
```python
import jax
import jax.numpy as jnp

def f(x, y):
    a = x * y
    b = jnp.sin(a)
    return jnp.sum(b)

print(jax.make_jaxpr(f)(jnp.ones((3,)), jnp.ones((3,))))
```
You will see something like:
```
{ lambda ; a:f32[3] b:f32[3]. let
    c:f32[3] = mul a b
    d:f32[3] = sin c
    e:f32[] = reduce_sum[axes=(0,)] d
  in (e,) }
```
Reading it:
- `a:f32[3] b:f32[3]`: two `float32` inputs of shape `(3,)`.
- The `let` block names intermediate values.
- `mul`, `sin`, `reduce_sum` are JAX primitives: the leaves of the jaxpr.
- `in (e,)`: the output tuple.
A jaxpr is pure, typed, and closed: every variable is bound, every operation is a primitive, every shape and dtype is known. This is the input XLA receives.
5.3 Concrete vs Abstract vs Traced¶
Three kinds of values flow through JAX code:
- Concrete arrays (`jax.Array`): real data on a device. Default outside `jit`.
- Abstract arrays (`ShapedArray`, `ConcreteArray`): metadata only (shape, dtype, optional weak type). Used for tracing.
- Tracers (`Tracer`): wrappers presented to the user's Python function during tracing. They look like arrays (they have `.shape`, `.dtype`, support `+`, `*`, `jnp.sin`, etc.) but every operation on them appends a node to the jaxpr.
A common pitfall: writing Python control flow on a tracer.
```python
@jax.jit
def f(x):
    if x > 0:  # ConcretizationError: a Tracer cannot be branched on
        return x
    else:
        return -x
```
The condition x > 0 is a tracer because x is. Python's if requires a concrete bool. You must use jax.lax.cond(x > 0, ..., ...) (or jnp.where for elementwise selection), which becomes part of the jaxpr.
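The working version with `lax.cond` (an absolute-value toy matching the broken example above):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Both branches become part of the jaxpr; the predicate picks one at run time.
    return jax.lax.cond(x > 0, lambda v: v, lambda v: -v, x)

assert f(jnp.float32(-3.0)) == 3.0
assert f(jnp.float32(2.0)) == 2.0
```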
5.4 Static arguments¶
If an argument is a Python scalar that determines the shape of arrays (e.g., a size or a number of layers), make it static:
```python
from functools import partial

@partial(jax.jit, static_argnums=(1,))
def sample_normal(rng, n):  # n is static
    return jax.random.normal(rng, (n,))
```
`n` is now treated as part of the cache key, not as a tracer. Different `n` values trigger different compilations.
5.5 Print-debugging¶
`print(x)` inside a jitted function prints the tracer (useful for inspecting shapes and dtypes), not the value. For per-call debug prints use `jax.debug.print("x = {}", x)`, which compiles into a host callback.
6. jax.jit: cache, recompilation, costs¶
6.1 The cache key¶
The compilation cache is keyed (essentially) by:
| Component | What changes the key |
|---|---|
| Function identity | The Python function object |
| Abstract input signature | shape and dtype of every leaf in the input PyTree |
| Static argument values | Python `==` equality of `static_argnums` / `static_argnames` values |
| PyTree structure | the treedef of inputs |
| Device / sharding context | target backend & sharding spec |
So:
- Calling `f(x_f32_3x4)` then `f(x_f32_3x4)` → 1 compile, 2 calls.
- Calling `f(x_f32_3x4)` then `f(x_f32_3x5)` → 2 compiles (different shape).
- Calling `f(x_f32_3x4)` then `f(x_f64_3x4)` → 2 compiles (different dtype).
- Calling `f({'a': x, 'b': y})` then `f({'b': y, 'a': x})` → 1 compile (dicts have stable PyTree order).
- Changing the value of a non-static argument → 0 compiles.
- Changing the value of a static argument → 1 compile per distinct value.
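A sketch that makes the cache-key rules observable, using a trace-time side effect as a probe (the specific shapes are arbitrary):

```python
import jax
import jax.numpy as jnp

traces = 0

@jax.jit
def f(x):
    global traces
    traces += 1  # increments only when JAX (re)traces, not on cache hits
    return x * 2

f(jnp.ones((3, 4)))             # new shape          -> compile #1
f(jnp.ones((3, 4)))             # same signature     -> cache hit
f(jnp.ones((3, 5)))             # different shape    -> compile #2
f(jnp.ones((3, 4), jnp.int32))  # different dtype    -> compile #3
assert traces == 3
```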
6.2 Recompilation costs¶
Compilation is not free: HLO optimization plus device codegen can cost hundreds of milliseconds to tens of seconds for transformer-scale models. If your training loop accidentally retraces every step, the program runs but at compile-bound throughput. Symptoms:
- First step: 5 s. Second step: 5 s. Third step: 5 s. (Should be: 5 s, 50 ms, 50 ms.)
- Memory growth in the HLO module cache.
Detection: set `jax.config.update("jax_log_compiles", True)` or the environment variable `JAX_LOG_COMPILES=1`. You should see one line per training-loop function, not one per step.
Common causes of accidental retraces:
- Padding-by-actual-length: each batch has a slightly different sequence length, so shapes vary. Fix: pad to a small set of bucket lengths.
- Passing Python ints that the function uses to construct shapes: make them `static_argnums`.
- Passing different PyTree structures (e.g., a dict with an optional key).
6.3 jit is lazy, dispatch is async¶
`jit`-compiled calls return immediately with futures (`jax.Array`s backed by a pending computation). Use `jax.block_until_ready(x)` or `x.block_until_ready()` when timing.
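The standard timing idiom (the function and matrix size here are arbitrary):

```python
import time
import jax
import jax.numpy as jnp

f = jax.jit(lambda a: a @ a)
x = jnp.ones((512, 512))
f(x).block_until_ready()  # warm-up: compile outside the timed region

t0 = time.perf_counter()
y = f(x)                   # returns a future almost immediately
y.block_until_ready()      # wait for the device computation to finish
elapsed = time.perf_counter() - t0
```

Without the `block_until_ready()` call you would be timing only the dispatch, not the computation.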
6.4 Ahead-of-time lowering¶
```python
lowered = jax.jit(f).lower(jnp.ones((4,)), jnp.ones((4,)))
print(lowered.as_text())                   # StableHLO MLIR text
print(lowered.compiler_ir(dialect="hlo"))  # HLO module
compiled = lowered.compile()
print(compiled.cost_analysis())            # FLOPs, bytes, etc., for benchmarking
```
This is the inspection toolchain you will use repeatedly.
7. jax.grad and friends¶
7.1 Reverse-mode AD on a jaxpr¶
`jax.grad(f)` returns a function `g` such that `g(x)` equals `df/dx` evaluated at `x`. Mechanically:
1. Trace `f` to a jaxpr (the primal jaxpr).
2. Walk the jaxpr forward, recording residuals where needed.
3. Construct a transposed / reverse jaxpr that computes the cotangent: for each primitive, JAX has a registered VJP/transpose rule.
4. Return that as a callable jaxpr (typically jitted).
Functional purity is what makes step 3 trivial: each primitive has a local linearization, and the chain rule is just composition because there are no hidden state edges to break it.
By convention, grad(f) differentiates with respect to the first argument and expects a scalar output.
```python
def loss(params, x, y):
    return jnp.mean((x @ params["W"] - y) ** 2)

grad_fn = jax.grad(loss)
g = grad_fn(params, x, y)  # g has the same PyTree structure as params
```
For multiple argnums: jax.grad(loss, argnums=(0, 1)) returns a tuple of grads.
7.2 value_and_grad¶
You usually want both the loss value and the gradient. `jax.value_and_grad(loss)` returns `(loss, grads)` from a single trace; no double work.
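A sketch with a linear model and made-up shapes:

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

w = jnp.zeros((3,))
x = jnp.ones((4, 3))
y = jnp.ones((4,))

val, grads = jax.value_and_grad(loss)(w, x, y)  # one trace, both outputs
assert val == 1.0                # predictions are all 0, targets all 1
assert grads.shape == w.shape
```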
7.3 Higher-order¶
jax.grad(jax.grad(f)) is well-defined (it traces grad(f) and differentiates that trace). For Hessians, use jax.hessian(f) (which is jacfwd(jacrev(f)) under the hood for a typical scalar function).
7.4 Forward and reverse mode primitives¶
Two lower-level operators:
- `jax.jvp(f, primals, tangents)`: forward mode. Computes `f(primals)` and the directional derivative `J · tangents` in one pass. Good when input dim ≪ output dim.
- `jax.vjp(f, *primals)`: reverse mode. Returns `(f(primals), vjp_fn)` where `vjp_fn(cotangent)` computes `cotangent · J`. This is what `grad` is built on.
Rule of thumb:
- Few outputs, many inputs (a training loss is scalar): reverse mode (`grad` / `vjp`).
- Few inputs, many outputs (sensitivity of a vector-valued function to a small parameter): forward mode (`jvp`).
- For Jacobians: `jax.jacrev` (reverse) for wide Jacobians (many inputs, few outputs), `jax.jacfwd` (forward) for tall Jacobians (many outputs, few inputs).
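A worked sketch of both operators on a tiny vector function (the values are chosen to be exact in float32):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.array([x[0] ** 2, x[0] * x[1]])

x = jnp.array([2.0, 3.0])
# Jacobian at x: [[2*x0, 0], [x1, x0]] = [[4, 0], [3, 2]]

# Forward mode: J @ v in one pass alongside f(x)
y, jv = jax.jvp(f, (x,), (jnp.array([1.0, 0.0]),))
assert (jv == jnp.array([4.0, 3.0])).all()   # first column of J

# Reverse mode: u @ J via the returned pullback
y2, vjp_fn = jax.vjp(f, x)
(uJ,) = vjp_fn(jnp.array([1.0, 0.0]))
assert (uJ == jnp.array([4.0, 0.0])).all()   # first row of J
```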
7.5 Custom derivatives¶
```python
@jax.custom_vjp
def stable_softmax(x):
    z = x - jnp.max(x)
    return jnp.exp(z) / jnp.sum(jnp.exp(z))

def fwd(x):
    y = stable_softmax(x)
    return y, y  # residual: the output itself

def bwd(y, g):  # hand-derived softmax VJP: y * (g - <g, y>)
    return (y * (g - jnp.dot(g, y)),)

stable_softmax.defvjp(fwd, bwd)
```
Use this when (a) you have a hand-derived gradient that is more numerically stable than the autodiff one, or (b) you want to break a gradient (stop_gradient is a single-line alternative for that).
7.6 Contrast with PyTorch¶
In PyTorch, loss.backward() walks the dynamically-built graph attached to the tensor, populates .grad fields by mutation, and frees the graph. In JAX, grad(loss) builds a new function that returns a new PyTree of gradients. There is no .grad attribute, no graph to free, no optimizer.zero_grad() needed-purity makes "zero out before backward" unnecessary because nothing is mutated.
8. jax.vmap: vectorization is a transformation¶
vmap(f) returns a function that runs f over an extra leading axis without you writing the batched code. It is not a Python for loop. It rewrites the jaxpr to push the batch dimension through every primitive.
8.1 The basic interface¶
```python
def dot(a, b):  # a:(d,) b:(d,) -> ()
    return jnp.sum(a * b)

batched_dot = jax.vmap(dot)  # (B,d), (B,d) -> (B,)
batched_dot(jnp.ones((32, 5)), jnp.ones((32, 5)))
```
8.2 in_axes / out_axes¶
Specify which axis of each argument is the batched axis:
```python
# Batch over a's first axis, broadcast b:
jax.vmap(dot, in_axes=(0, None))  # (B,d), (d,) -> (B,)

# Batch over a's last axis, b's first axis:
jax.vmap(dot, in_axes=(-1, 0))
```
`None` means "this argument is not batched; broadcast it." `out_axes` controls where the batch axis appears in the output (default 0).
For PyTree arguments, in_axes is itself a PyTree (or a single int, applied uniformly).
8.3 How vmap works under the hood¶
For each primitive `p`, JAX has a registered batching rule: given how each input is batched, what is the batched output and along which axis? `vmap` walks the jaxpr applying these rules. Most primitives have rules that turn into a single fatter primitive call (e.g., `vmap(dot)` becomes a matmul, not a Python loop). That is why `vmap` is fast: it produces good HLO.
8.4 Composes with grad¶
The canonical example: per-example gradients.
In standard training, `grad(loss)` gives the gradient of the mean loss: a single PyTree aggregated across the batch. Sometimes you want one gradient per example (for influence functions, differential privacy, per-example gradient clipping, etc.).
```python
def per_example_loss(params, x, y):  # x:(d,), y:() -- a single example
    pred = x @ params["W"] + params["b"]
    return (pred - y) ** 2

per_example_grad = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))
# per_example_grad(params, X, Y) returns grads with leading axis B
```
Read it carefully:
- grad(per_example_loss) is a function that, given a single example, returns a single gradient PyTree.
- vmap(...) lifts that over a batch axis on (x, y) while sharing params.
- The result has the same PyTree structure as params but with an added leading batch dimension on every leaf.
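With hypothetical shapes filled in, the result looks like this:

```python
import jax
import jax.numpy as jnp

def per_example_loss(params, x, y):  # one example: x:(3,), y:()
    pred = x @ params["W"] + params["b"]
    return (pred - y) ** 2

params = {"W": jnp.ones((3,)), "b": jnp.zeros(())}
X, Y = jnp.ones((8, 3)), jnp.zeros((8,))

per_example_grad = jax.vmap(jax.grad(per_example_loss), in_axes=(None, 0, 0))
grads = per_example_grad(params, X, Y)
assert grads["W"].shape == (8, 3)  # a leading batch axis on every leaf
assert grads["b"].shape == (8,)
```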
This is two lines of JAX. The PyTorch equivalent is `torch.func.vmap(torch.func.grad(...))` (formerly functorch).
8.5 vmap for the inference case¶
A model that runs on a single example can be batched by writing the model for one example and `vmap`-ing it. This is sometimes cleaner than worrying about broadcasting in the model definition. In practice, most JAX models are written batched (because `matmul` already handles the batch dimension for free), but `vmap` is invaluable for non-trivial axes (e.g., MoE expert routing, beam search, per-head attention).
9. Device parallelism: pmap (legacy) vs jit + sharding (modern)¶
JAX has been through a small evolution here. Understand both because real codebases mix them.
9.1 pmap (the original)¶
`pmap` is "parallel map": it `jit`-compiles a function and runs it on multiple devices, with one shard of the leading axis per device.
```python
@jax.pmap  # 8 devices: leading axis must be 8
def step(params, batch):
    ...
    return new_params

# params must be replicated across devices:
params = jax.tree.map(lambda x: jnp.broadcast_to(x, (8,) + x.shape), params)
batch = ...  # shape (8, B/8, ...): one shard per device
new_params = step(params, batch)
```
Cross-device communication is via collectives inside the function: `jax.lax.psum(x, axis_name="i")`, `pmean`, `all_gather`, etc., where `axis_name` is set by `pmap(..., axis_name="i")`.
`pmap` is single-program multiple-data with explicit sharding by the user, restricted to one batch dimension and one mesh axis. It composes (`pmap(vmap(grad(f)))` is valid), but it is awkward when you need 2D meshes, host coordination across many TPU hosts, or per-tensor sharding choices.
9.2 pjit (intermediate) and the unified jit (modern)¶
Around 2022 JAX introduced `pjit`: a `jit` that took a `Mesh` and `PartitionSpec`s and let XLA's GSPMD partitioner shard arbitrary tensors across an arbitrary mesh. This was the right abstraction. By 2024, `pjit` and `jit` were unified: today **`jax.jit` natively understands sharding**, and `pjit` is an alias.
The modern path:
```python
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

devices = jax.devices()  # e.g., 8 GPUs or a 2x4 TPU slice
mesh = Mesh(np.array(devices).reshape(2, 4), axis_names=("data", "model"))

# Place an array sharded:
def shard(x, spec):
    return jax.device_put(x, NamedSharding(mesh, spec))

x = shard(x_host, P("data", None))   # batch sharded over 2 devices
W = shard(W_host, P(None, "model"))  # output dim sharded over 4 devices
b = shard(b_host, P("model"))

@jax.jit
def forward(x, W, b):
    y = x @ W + b
    # Optionally constrain intermediate shardings:
    y = jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P("data", "model")))
    return y
```
What happens under the hood:
1. `jit` traces `forward` with abstract sharded inputs.
2. The jaxpr lowers to HLO with sharding annotations on the inputs, outputs, and any `with_sharding_constraint` points.
3. GSPMD (the partitioner inside XLA) propagates sharding through the whole HLO module, deciding per-op how each tensor is laid out.
4. GSPMD inserts collectives (`all-reduce`, `all-gather`, `reduce-scatter`, `all-to-all`) where needed.
5. XLA emits one program per device; each device runs only its slice.
This is GSPMD: General SPMD partitioner. The user writes a single-device-shaped program with annotations on a few key tensors; the compiler figures out the rest.
9.3 Mesh, PartitionSpec, NamedSharding¶
`Mesh`: a logical N-dimensional grid of devices with named axes. Common patterns:

- 1D: `("data",)`: pure data parallelism.
- 2D: `("data", "model")`: DP × tensor-parallel (Megatron-style).
- 3D: `("data", "fsdp", "model")` or `("pp", "data", "model")`: pipeline + DP + TP.

`PartitionSpec` (alias `P`): for an array with shape `(d0, d1, …, dn)`, a `P(spec0, spec1, …)` says how each axis is partitioned over the mesh. Each `spec_i` is either:

- `None`: replicated along this array axis.
- `"name"`: sharded along the mesh axis `"name"`.
- A tuple `("a", "b")`: sharded along the product of mesh axes `a` and `b`.

`NamedSharding(mesh, P(...))`: binds a `PartitionSpec` to a concrete mesh.
Examples:
- `P("data", None)` on `(B, D)`: sharded batch, replicated features (FSDP-like for activations).
- `P(None, "model")` on `(D_in, D_out)`: replicated `D_in`, sharded `D_out` (tensor-parallel weight).
- `P(("data", "fsdp"))` on `(D,)` with mesh `("data", "fsdp")`: sharded over both axes simultaneously.
9.4 with_sharding_constraint¶
Inside a `jit`-ed function, you can pin an intermediate sharding with `jax.lax.with_sharding_constraint` (as in the `forward` example above). This is a hint to GSPMD: "after this point, `y` must be sharded this way." Use it to:
- Break ambiguity when GSPMD picks a sharding you don't want.
- Force a re-shard at a known boundary (e.g., between tensor-parallel attention and tensor-parallel MLP).
9.5 Output sharding¶
`jit` with sharding can take `in_shardings` and `out_shardings` arguments to specify how inputs and outputs are sharded. Default: inferred from the actual input shardings and propagated by GSPMD.
```python
forward_p = jax.jit(
    forward,
    in_shardings=(NamedSharding(mesh, P("data", None)),
                  NamedSharding(mesh, P(None, "model")),
                  NamedSharding(mesh, P("model"))),
    out_shardings=NamedSharding(mesh, P("data", "model")),
)
```
9.6 When pmap is still useful¶
pmap survives for:
- Quick single-axis SPMD where setting up a Mesh feels heavy.
- Legacy code.
- A few research patterns that depend on pmap's tight coupling between Python-side leading axis and device axis.
For new code at scale, prefer jit + sharding.
10. jax.shard_map: when you want manual control¶
jit + sharding is implicit: you annotate, GSPMD figures out collectives. shard_map is explicit: you write what each device sees, and you call collectives by hand.
```python
from functools import partial

from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

@partial(shard_map, mesh=mesh, in_specs=(P("data", None), P(None, "model")),
         out_specs=P("data", "model"))
def matmul(x_local, W_local):
    # x_local has the *local* shape on each device.
    # We must call collectives explicitly (none are needed here: the
    # contracted dimension is unsharded, so local matmuls suffice).
    y_partial = x_local @ W_local
    return y_partial  # already correctly sharded
```
Inside shard_map you operate on the local shard. Collectives (jax.lax.psum, jax.lax.all_gather, jax.lax.all_to_all, jax.lax.ppermute) reference the mesh axis names and run across the corresponding device subset.
When to use shard_map over jit+sharding:
- You need an algorithm that GSPMD does not synthesize well (custom ring all-reduce, expert-parallel routing, sequence parallelism with overlap).
- You want predictable collective placement for performance debugging.
- You are implementing a low-level kernel (e.g., a custom attention with sequence sharding and explicit all_to_all).
When to stick with jit+sharding:
- Standard transformer training. GSPMD does an excellent job.
- You want one piece of code that retargets between mesh shapes without rewriting.
The mental model: jit+sharding is "declare the partitioning, let the compiler handle parallelism"; shard_map is "I am writing SPMD by hand, shoulder-to-shoulder with the hardware."
11. Structured loops: lax.scan, lax.fori_loop, lax.while_loop¶
A Python for loop inside a `jit`-traced function unrolls into the jaxpr. For 4 iterations that is fine. For 1024 iterations the jaxpr (and the resulting HLO, and the compile time) explodes. Use structured control flow.
11.1 `jax.lax.scan`: the workhorse¶
scan is a stateful map-reduce. Signature (informal):
```python
def f(carry, x):  # carry: state, x: per-step input
    new_carry = ...
    y = ...
    return new_carry, y

final_carry, ys = jax.lax.scan(f, init_carry, xs)
```
It is O(T) in compile size regardless of T, because the loop body is traced once and reused.
Use for:
- RNN forward passes.
- Sampling loops where each step depends on the last.
- Any reduction with an explicit state that you want to materialize per step (ys).
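A minimal `scan` (a running sum) to fix the carry/`ys` convention in your head:

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    new_carry = carry + x        # the loop state
    return new_carry, new_carry  # also emit the state at every step

xs = jnp.arange(4, dtype=jnp.float32)  # [0, 1, 2, 3]
total, partials = jax.lax.scan(step, jnp.float32(0.0), xs)
assert total == 6.0
assert (partials == jnp.array([0.0, 1.0, 3.0, 6.0])).all()
```

The body is traced once; `total` is the final carry and `partials` stacks the per-step outputs.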
`scan` differentiates correctly through the loop. The default reverse-mode AD over `scan` keeps residuals for every step (O(T) memory); for long sequences, combine with `jax.checkpoint` (rematerialization) to trade compute for memory, and use `unroll=` to trade compile time against kernel-launch overhead.
11.2 jax.lax.fori_loop¶
Like `scan` but does not stack per-step outputs (no `ys`). Slightly cheaper, less expressive. AD support is limited when the trip count is dynamic; for differentiable loops prefer `scan`.
11.3 jax.lax.while_loop¶
Truly dynamic iteration count. Cannot be reverse-mode differentiated in the general case (the number of iterations is data-dependent). Use for inference-only loops with data-dependent termination (e.g., autoregressive sampling until EOS).
11.4 jax.lax.cond and jax.lax.switch¶
For data-dependent branching:
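Sketches of both (the particular branch functions are chosen here for illustration):

```python
import jax
import jax.numpy as jnp

@jax.jit
def abs_val(x):
    # cond: a boolean predicate picks one of two traced branches
    return jax.lax.cond(x >= 0, lambda v: v, lambda v: -v, x)

@jax.jit
def apply_op(index, x):
    # switch: an integer index picks one of N traced branches
    return jax.lax.switch(index, [jnp.sin, jnp.cos, jnp.tanh], x)

assert abs_val(jnp.float32(-2.0)) == 2.0
assert apply_op(1, jnp.float32(0.0)) == 1.0  # cos(0) = 1
```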
Both branches are traced and compiled; only one runs at execution. This means both branches must return the same PyTree structure with matching shapes/dtypes.
12. XLA: the compiler under the hood¶
XLA (Accelerated Linear Algebra) is the compiler. JAX is one front-end; TensorFlow and PyTorch/XLA are others. We focus on what JAX programmers need to know.
12.1 HLO: the IR¶
HLO ("High Level Optimizer") is XLA's intermediate representation. An HLO module is a collection of computations; each computation is a list of instructions; each instruction has a name, an opcode, operand references, and a typed shape.
Modern XLA actually uses StableHLO (an MLIR dialect) at the boundary, then lowers internally to HLO. The two are largely isomorphic for our purposes.
A small HLO snippet for jnp.sum(jnp.sin(x * y)):
```
HloModule jit_f, entry_computation_layout={(f32[3]{0}, f32[3]{0})->(f32[])}

ENTRY main {
  x = f32[3]{0} parameter(0)
  y = f32[3]{0} parameter(1)
  prod = f32[3]{0} multiply(x, y)
  s = f32[3]{0} sine(prod)
  zero = f32[] constant(0)
  ROOT sum = f32[] reduce(s, zero), dimensions={0}, to_apply=add_f32
}
```
You can dump HLO from JAX:
```python
hlo = jax.jit(f).lower(jnp.ones((3,)), jnp.ones((3,))).compiler_ir(dialect="hlo")
print(hlo.to_string())
```
or get the post-optimization HLO via `compiled.as_text()` / `compiled.runtime_executable()`. Setting the env var `XLA_FLAGS=--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text` will dump every compiled module.
12.2 The HLO op set you should recognize¶
| Op | Meaning |
|---|---|
| `parameter(i)` | The i-th input |
| `constant(...)` | A literal |
| `add`, `multiply`, `subtract`, `divide` | Elementwise arithmetic |
| `sine`, `exponential`, `log`, `tanh`, ... | Elementwise unary |
| `compare` | Elementwise comparison |
| `convert` | Dtype cast |
| `broadcast(..., dimensions=...)` | Reshape/expand to a larger shape |
| `reshape` | Same data, new shape |
| `transpose(..., dimensions=...)` | Permute axes |
| `slice`, `dynamic-slice` | Static / dynamic slicing |
| `concatenate` | Concatenate along a dim |
| `reduce(operand, init, dimensions=, to_apply=)` | Reduction with a reducer computation |
| `dot(a, b, lhs_contracting_dims=..., rhs_contracting_dims=..., lhs_batch_dims=..., rhs_batch_dims=...)` | Generalized matmul |
| `convolution` | Generalized conv |
| `gather`, `scatter` | Indirect read / write |
| `select` | Elementwise where |
| `tuple`, `get-tuple-element` | Tuple constructors / accessors |
| `while`, `conditional` | Structured control flow |
| `all-reduce`, `all-gather`, `reduce-scatter`, `all-to-all`, `collective-permute` | Cross-device collectives |
| `custom-call` | Escape hatch to a hand-written kernel (cuDNN, custom CUDA, Pallas) |
dot is the linchpin. It is fully general: contracting dims are reduced, batch dims are kept, the rest are output. A standard (M, K) × (K, N) → (M, N) matmul has lhs_contracting=[1], rhs_contracting=[0], no batch dims. Attention's (B, H, S, D) × (B, H, T, D) has lhs_contracting=[3], rhs_contracting=[3], lhs_batch=[0,1], rhs_batch=[0,1].
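To make that generality concrete, here is a minimal sketch using jax.lax.dot_general, the JAX-level counterpart of HLO dot. Its dimension_numbers argument is ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)); output dims are ordered batch dims first, then the free dims of the lhs, then the free dims of the rhs:

```python
import jax
import jax.numpy as jnp

# Standard (M, K) x (K, N) matmul: contract dim 1 of lhs with dim 0
# of rhs, no batch dims.
a = jnp.ones((4, 8))
b = jnp.ones((8, 16))
mm = jax.lax.dot_general(a, b, dimension_numbers=(((1,), (0,)), ((), ())))
assert mm.shape == (4, 16)

# Attention scores: (B, H, S, D) x (B, H, T, D), contracting on D
# (dim 3 of both), batching on B and H (dims 0 and 1 of both).
q = jnp.ones((2, 3, 5, 7))
k = jnp.ones((2, 3, 11, 7))
scores = jax.lax.dot_general(
    q, k, dimension_numbers=(((3,), (3,)), ((0, 1), (0, 1))))
assert scores.shape == (2, 3, 5, 11)
```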
12.3 The compilation pipeline¶
Roughly:
jaxpr
│ (lowered by JAX)
▼
StableHLO (MLIR)
│ (XLA front-end)
▼
HLO
│ (XLA HLO passes)
▼
Optimized HLO
│ (Backend: GPU / TPU / CPU)
▼
Device code (PTX/SASS via LLVM-NVPTX, TPU machine code, x86 LLVM)
The HLO passes do the heavy lifting:
- Algebraic simplification. x * 1 → x, concat(slice, slice) → original, constant folding.
- Layout assignment. Pick physical memory layouts (which dim is fastest-varying, tile shapes, padding) per buffer to match the target hardware.
- Sharding propagation (GSPMD). From annotated tensors, infer shardings everywhere; insert collectives.
- Operator fusion. Combine adjacent elementwise ops, plus a producer reduction or matmul, into a single kernel-this is the single biggest performance win on GPU. JAX programs that look like 100 small numpy ops often compile to a handful of fused kernels.
- Memory scheduling. Order ops to minimize peak memory; insert rematerialization if needed.
- Lowering. GPU: emit LLVM IR (or call into cuBLAS/cuDNN for some patterns), then to PTX. TPU: emit TPU-specific IR for the matrix unit and vector unit, schedule across HBM/VMEM.
12.4 Fusion in detail¶
Fusion is the reason JAX feels fast. Consider:
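The code snippet is missing from this copy; given the description that follows (a multiply, an add, and a tanh over x, y, z), it was presumably something like:

```python
import jax
import jax.numpy as jnp

@jax.jit
def fused(x, y, z):
    # Three elementwise ops; under jit, XLA fuses them into one kernel.
    return jnp.tanh(x * y + z)
```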
In eager numpy this is three kernel launches: multiply, add, tanh-each reads from and writes to HBM. XLA fuses them into one kernel: read x, y, z once, do all the elementwise math in registers, write the result once. On memory-bound workloads (most elementwise + small reductions) this is a 3–10× speedup.
XLA's GPU backend also fuses producer reductions and consumer elementwise (and vice versa) and a matmul with epilogue elementwise (and prologue elementwise on its inputs) where profitable. The resulting fused kernel is what you see as a single HLO fusion instruction in the post-optimization dump.
Caveat: tracing and optimization can hide debugging aids. A plain Python print inside a jitted function fires once at trace time, not on every call, and dead-code elimination removes computations whose results are unused. Use jax.debug.print to print traced values at run time, or jax.disable_jit() when you need eager semantics for debugging.
12.5 Layout assignment¶
Layout = how a logical shape (N, H, W, C) maps onto physical memory. On TPUs, tiling matters enormously: the matrix unit prefers (128, 128) tiles aligned in particular ways; XLA inserts paddings and chooses layouts so dot products hit fast paths. On GPUs, the choice is less critical (memory is more uniform) but still matters for shared memory and tensor cores.
You can sometimes coax XLA with shape choices (sizes that are multiples of 128 on TPU, multiples of 8 with bf16 for tensor cores on GPU).
12.6 GSPMD: how a single jit targets thousands of devices¶
GSPMD's input: an HLO module with sharding annotations on some subset of values (parameters, outputs, with_sharding_constraint points). Its output: an HLO module where every value has a sharding, with collectives inserted at boundaries.
The propagation is bidirectional and uses cost models. Key rules:
- A dot between A: P("data", None) (sharded on batch) and B: P(None, "model") (sharded on the output dim) produces a result P("data", "model") with no collective-a purely local matmul.
- A dot between A: P("data", "k") and B: P("k", "model") (sharded on the contraction dim) requires an all-reduce over the k axis after the local matmul.
- An elementwise op with mismatched shardings inserts an all-gather or a reduce-scatter to align them.
- A sequence-axis split followed by self-attention typically needs an all-to-all to switch from sequence-sharded to head-sharded.
GSPMD picks the lowest-cost combination. For most transformer workloads with a sensible mesh and a sensible PartitionSpec, the result is close to what an expert hand-writer would do. Where it isn't, you reach for with_sharding_constraint or shard_map.
12.7 Sharding propagation example¶
mesh = Mesh(devices.reshape(2, 4), ("data", "model"))
# shard(a, spec) is shorthand for jax.device_put(a, NamedSharding(mesh, spec))
W1 = shard(W1_host, P(None, "model"))  # (D, 4D), sharded on the output dim
W2 = shard(W2_host, P("model", None))  # (4D, D), sharded on the input dim
x = shard(x_host, P("data", None))     # (B, D), sharded on batch
@jax.jit
def mlp(x, W1, W2):
h = jax.nn.gelu(x @ W1)
y = h @ W2
return y
GSPMD will:
1. Compute x @ W1: x is P("data", None), W1 is P(None, "model"), so result is P("data", "model") (no collective).
2. gelu: elementwise, sharding unchanged.
3. h @ W2: h is P("data", "model"), W2 is P("model", None). The contraction is on the "model" axis, so an all-reduce across "model" is inserted after the local matmul, yielding P("data", None).
That is the standard Megatron tensor-parallel MLP, derived automatically from three PartitionSpecs.
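The derivation can be checked end to end on a single machine by faking eight host devices. A sketch (the XLA_FLAGS setting only takes effect if it runs before jax is first imported):

```python
import os
# Fake 8 CPU devices so the (2, 4) mesh works without real hardware.
# Must be set before jax is imported anywhere in the process.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ("data", "model"))
shard = lambda a, spec: jax.device_put(a, NamedSharding(mesh, spec))

D = 8
x = shard(jnp.ones((4, D)), P("data", None))        # batch-sharded
W1 = shard(jnp.ones((D, 4 * D)), P(None, "model"))  # column-parallel
W2 = shard(jnp.ones((4 * D, D)), P("model", None))  # row-parallel

@jax.jit
def mlp(x, W1, W2):
    return jax.nn.gelu(x @ W1) @ W2

y = mlp(x, W1, W2)  # GSPMD inserts the all-reduce over "model"
assert y.shape == (4, D)
```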
13. TPU vs GPU under JAX/XLA¶
XLA was conceived at Google with TPUs as its primary target, then extended to GPU and CPU. That history influences the runtime behavior.
Treat the TPU specifics here as broadly true but version-dependent. TPU generations differ (v3/v4/v5p/v5e/Trillium…), and details have shifted between them.
13.1 Architectural differences (high level)¶
| Aspect | GPU (NVIDIA) | TPU |
|---|---|---|
| Compute units | SMs with FP32/FP16 cores + tensor cores | Matrix multiply unit (systolic array, 128×128 typical) + vector unit |
| Memory hierarchy | Registers → shared mem → L2 → HBM | Registers → VMEM (per core) → HBM |
| Numeric formats | FP32, FP16, BF16, FP8, INT8 | Primarily BF16 / FP32 / INT8; FP8 in newer gens |
| Interconnect | NVLink within a node, IB / Ethernet across nodes | Native ICI (inter-chip interconnect) in pod topology-2D/3D torus |
| Sweet spot | Heterogeneous workloads, irregular ops, custom CUDA | Big regular dense matmuls, large pods |
13.2 Compiler differences¶
- Fusion granularity. TPU XLA tends to produce very large fused regions-sometimes the entire transformer block becomes one or two ops. GPU XLA fuses aggressively but is bounded by SM resource limits.
- Layout. TPUs are pickier about layout (the matrix unit's tile size is fixed). XLA pads aggressively to align-small tensors can have substantial padding overhead. Sizes that are multiples of 128 (sometimes 256) are friendly.
- Collectives. TPU pods have a 2D or 3D torus; XLA's collective scheduler is tightly tuned for it. GPU collectives go through NCCL (or XLA's own GPU collectives), and topology is less regular.
- Async dispatch. Both backends launch async; on TPU the compiler often overlaps collectives with compute aggressively (because the cost model is well-understood).
13.3 Why XLA was TPU-first¶
TPUs are useless without a compiler-there is no eager kernel library equivalent to cuDNN. Every TPU program goes through XLA. So XLA had to be excellent at TPU codegen, sharding, and pod-scale collectives from day one. JAX inherited all of that. Running JAX on TPU is "the original path"; running JAX on GPU shares the front end but uses a different (also mature) backend.
13.4 Practical implications¶
- A JAX program that runs well on 1 GPU often runs well on 1 TPU core with no changes.
- A JAX program that scales to 8 GPUs via jit + sharding scales to a 2048-core TPU pod by enlarging the Mesh - same code.
- TPU memory per core is typically smaller than a top-end GPU's HBM. You will lean more on FSDP-style sharding and rematerialization on TPU.
14. Module systems: Equinox and Flax¶
JAX core has no nn.Module. Two libraries dominate.
14.1 Flax (flax.linen and the newer flax.nnx)¶
flax.linen (the long-standing API):
from flax import linen as nn
class MLP(nn.Module):
hidden: int
out: int
@nn.compact
def __call__(self, x):
x = nn.Dense(self.hidden)(x)
x = nn.relu(x)
return nn.Dense(self.out)(x)
model = MLP(64, 10)
params = model.init(rng, dummy_input) # returns a PyTree of params
y = model.apply(params, x)
Idioms:
- Modules are dataclasses; calling them constructs thunks, not param-owning objects.
- model.init(rng, x) traces the module to create a parameter PyTree.
- model.apply(params, x) runs the forward pass-params are an argument, never owned.
This is JAX-functional through and through. Optimizer state is also a PyTree, threaded through train_step.
flax.nnx is a newer API that gives you a more PyTorch-like stateful feel while preserving JAX semantics under the hood. Pick linen for stable production code and the largest ecosystem; pick nnx if you prefer the ergonomics.
14.2 Equinox¶
Equinox treats modules as PyTrees of dataclasses directly:
import equinox as eqx
class MLP(eqx.Module):
l1: eqx.nn.Linear
l2: eqx.nn.Linear
def __init__(self, key):
k1, k2 = jax.random.split(key)
self.l1 = eqx.nn.Linear(784, 128, key=k1)
self.l2 = eqx.nn.Linear(128, 10, key=k2)
def __call__(self, x):
return self.l2(jax.nn.relu(self.l1(x)))
model = MLP(jax.random.PRNGKey(0))
y = model(x) # call the model directly
grads = jax.grad(loss)(model, x, y_true) # model itself is the param tree
Idioms:
- A module is its parameters. There is no separate params PyTree.
- eqx.partition(model, eqx.is_array) separates trainable arrays from non-array fields when needed (e.g., for optax).
Pick Equinox if you like the "model is data" mental model and small dependency. Pick Flax for the bigger ecosystem (pretrained checkpoints, integrations, MaxText/Pax).
14.3 Optimizer: optax¶
Both work with optax, the standard JAX optimizer library. optax exposes optimizers as (init, update) function pairs that operate on PyTrees:
import optax
opt = optax.adamw(1e-3)
opt_state = opt.init(params)
grads = jax.grad(loss)(params, ...)
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
Note: pure functions, no mutation. optax chains transformations (optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(...))) which makes complex schedules trivially composable.
15. jax.experimental.pallas: Triton-like kernels in JAX¶
When the compiler's automatic codegen is not enough-typically for fused attention, custom flash-attention variants, or quantized kernels-Pallas lets you write kernels in a Python DSL that lowers to:
- Triton on GPU, and
- Mosaic on TPU.
Sketch:
from jax.experimental import pallas as pl
def add_kernel(x_ref, y_ref, o_ref):
o_ref[...] = x_ref[...] + y_ref[...]
@jax.jit
def add(x, y):
return pl.pallas_call(
add_kernel,
grid=(x.shape[0] // 128,),
in_specs=[pl.BlockSpec((128,), lambda i: (i,)),
pl.BlockSpec((128,), lambda i: (i,))],
out_specs=pl.BlockSpec((128,), lambda i: (i,)),
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
)(x, y)
Mental model: Pallas kernels look like CUDA/Triton kernels (you reason about blocks and refs/pointers), but they integrate as a single HLO custom-call in your JAX program, with full vmap and jit composition. Use it sparingly-only when the surrounding XLA-generated code leaves serious performance on the table-and benchmark against the un-Pallas baseline.
This is the JAX answer to PyTorch's "drop into Triton". It is a young area; expect API motion. The big production win so far is fused-attention kernels.
16. Practical exercises (with worked answers)¶
How to use these. Read the question. Pause. Predict. Then read the answer. If you predicted wrong, re-read the relevant section.
Exercise 1-What gets recompiled?¶
You have:
@jax.jit
def f(x, y):
return x @ y + 1.0
f(jnp.ones((4, 8)), jnp.ones((8, 16))) # call A
f(jnp.ones((4, 8)), jnp.ones((8, 16))) # call B
f(jnp.ones((4, 8)), jnp.ones((8, 32))) # call C
f(jnp.ones((4, 8)).astype(jnp.bfloat16),
jnp.ones((8, 16)).astype(jnp.bfloat16)) # call D
Which calls trigger compilation?
Answer. A compiles. B reuses A's cache (identical abstract signature). C compiles a new entry (different shape on y). D compiles a new entry (different dtype on both). Total: 3 compilations, 4 calls.
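You can verify the count directly: the Python body of a jitted function runs only while tracing, i.e. once per cache entry, so a global counter is a serviceable proxy for compilations:

```python
import jax
import jax.numpy as jnp

traces = 0

@jax.jit
def f(x, y):
    global traces
    traces += 1  # executes at trace time only, once per cache entry
    return x @ y + 1.0

f(jnp.ones((4, 8)), jnp.ones((8, 16)))   # A: traces and compiles
f(jnp.ones((4, 8)), jnp.ones((8, 16)))   # B: cache hit
f(jnp.ones((4, 8)), jnp.ones((8, 32)))   # C: new shape -> recompile
f(jnp.ones((4, 8), jnp.bfloat16),
  jnp.ones((8, 16), jnp.bfloat16))       # D: new dtype -> recompile
assert traces == 3
```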
Exercise 2-Why did my training loop slow to a crawl?¶
Symptom: every step takes ~2 s, no GPU utilization between steps. JAX_LOG_COMPILES=1 shows a compile log every step.
Likely cause. Variable-shape inputs. Most often: padding sequences to their actual lengths inside the data loader, so each batch has shape (B, L) with a different L.
Fixes (any of):
1. Pad to a fixed maximum length (cheapest, slight wasted compute).
2. Pad to a small set of bucket lengths (e.g., 128/256/512/1024)-at most that many compilations.
3. Mark sequence length as static_argnums only if you genuinely have a tiny number of distinct values.
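A minimal sketch of fix 2, bucketed padding (the helper name and bucket sizes are illustrative):

```python
import jax.numpy as jnp

BUCKETS = (128, 256, 512, 1024)

def pad_to_bucket(tokens, buckets=BUCKETS):
    """Pad a 1-D token array to the smallest bucket >= its length,
    so jit sees at most len(buckets) distinct shapes."""
    n = tokens.shape[0]
    target = next(b for b in buckets if b >= n)
    return jnp.pad(tokens, (0, target - n))
```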
Exercise 3-Per-example gradient norms¶
You want, for each example in a batch, the L2 norm of its per-example gradient. Write it.
def per_example_loss(params, x, y):
return ((x @ params["W"] - y) ** 2).mean() # scalar per example
per_grad = jax.vmap(jax.grad(per_example_loss),
in_axes=(None, 0, 0)) # share params, batch x and y
def per_example_norm(params, X, Y):
grads = per_grad(params, X, Y) # PyTree, leaves shape (B, ...)
flat = jax.tree.leaves(grads)
sq = sum(jnp.sum(g.reshape(g.shape[0], -1) ** 2, axis=1) for g in flat)
return jnp.sqrt(sq) # shape (B,)
Reading: vmap(grad(...)) produces a function that returns gradient PyTrees with a leading batch axis; we then compute the per-row norm. No Python loop.
Exercise 4-A Megatron-style MLP, by PartitionSpec¶
You have an MLP Linear(D, 4D) → gelu → Linear(4D, D). You have an 8-device mesh ("data", "model") with shape (2, 4). Specify shardings such that:
- The batch is data-parallel.
- Both Linear layers are tensor-parallel along the hidden axis.
- The output of the MLP is data-sharded, model-replicated (so a downstream layer-norm sees a complete vector per example).
Answer.
mesh = Mesh(devices.reshape(2, 4), ("data", "model"))
x_spec = P("data", None) # (B, D)
W1_spec = P(None, "model") # (D, 4D)
W2_spec = P("model", None) # (4D, D)
y_spec = P("data", None) # (B, D)
GSPMD will compute x @ W1 locally (no collective), gelu locally, then h @ W2 with an all-reduce over "model" on the contraction axis to produce y_spec. This is exactly Megatron-LM's "column-parallel then row-parallel" pattern, derived from PartitionSpecs.
Exercise 5-Why won't this differentiate?¶
You wrap it in jax.grad(lambda x: f(x)[0].sum()) and get an error. Why?
Answer. The Python while runs at trace time and depends on a traced value (jnp.linalg.norm(x) > 1.0 is a tracer). You will hit a ConcretizationTypeError even before differentiation. The fix is jax.lax.while_loop, but: while_loop does not support reverse-mode AD because the iteration count is data-dependent. If you need a differentiable variant, use jax.lax.scan with a known maximum number of iterations, masking the unused steps.
Exercise 6-Reading a jaxpr¶
What does this code do?
def g(x):
return jnp.where(x > 0, x, 0.5 * x)
print(jax.make_jaxpr(g)(jnp.array([-1.0, 2.0, -3.0])))
Answer. Approximately:
{ lambda ; a:f32[3]. let
b:bool[3] = gt a 0.0
c:f32[3] = mul a 0.5
d:f32[3] = select_n b c a
in (d,) }
It is a leaky-ReLU-ish function (slope 0.5 on the negative side). The jaxpr makes the elementwise nature explicit and shows that where becomes a select_n (3-way select primitive) rather than a Python branch. It will compile to a single fused HLO kernel (compare, multiply, select).
17. Cheat sheet¶
17.1 The four core transformations¶
| Transformation | Maps f to | Mental model |
|---|---|---|
| jax.jit(f) | A function that traces, lowers to HLO, compiles, caches, runs | "Make it fast" |
| jax.grad(f) | A function returning df/dx | "Differentiate" |
| jax.vmap(f) | A function with an extra batch axis | "Vectorize" |
| jax.pmap(f) / jit(..., in_shardings=...) | A function running across devices | "Parallelize" |
They compose. jit(vmap(grad(f))) is well-defined and idiomatic.
17.2 Common errors and their meaning¶
| Error | Cause | Fix |
|---|---|---|
| ConcretizationTypeError | Python control flow on a tracer | Use jax.lax.cond/where, or make the variable static |
| TracerArrayConversionError | Tried to convert a tracer to numpy or use it as a Python int | Push the work into JAX-land or restructure |
| Repeated compiles | Variable shapes / dtypes / static args / pytree structures | Stabilize shapes, use bucketing, audit static_argnums |
| OOM during compile | Long Python for loop being unrolled | Use lax.scan |
| Silent wrong answer | Side effect (mutation, global) inside a jitted function | Make it pure |
17.3 Inspection toolbox¶
jax.make_jaxpr(f)(x) # show jaxpr
jax.jit(f).lower(x).as_text() # StableHLO MLIR
jax.jit(f).lower(x).compiler_ir(dialect="hlo") # HLO
jax.jit(f).lower(x).compile().as_text() # post-opt HLO
jax.jit(f).lower(x).compile().cost_analysis() # FLOPs, bytes
jax.config.update("jax_log_compiles", True) # see every compile
JAX_TRACEBACK_FILTERING=off # (env) full Python tracebacks
XLA_FLAGS=--xla_dump_to=/tmp/xla --xla_dump_hlo_as_text # (env) dump every module
17.4 When to reach for which loop primitive¶
| Need | Primitive |
|---|---|
| Static iteration count, want per-step outputs, want AD | jax.lax.scan |
| Static iteration count, no per-step outputs, no AD on the loop count | jax.lax.fori_loop |
| Data-dependent iteration count, no AD over the loop | jax.lax.while_loop |
| Data-dependent branch, both branches valid | jax.lax.cond |
| Choose among N branches by index | jax.lax.switch |
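As a reminder of the scan calling convention, a minimal cumulative-sum sketch (the step function returns (new_carry, per-step output)):

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    carry = carry + x
    return carry, carry  # (new carry, per-step output)

# Final carry is the total; stacked outputs are the running sums.
total, cumsum = jax.lax.scan(step, jnp.asarray(0.0), jnp.arange(1.0, 5.0))
```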
17.5 Sharding spec recipes (mesh ("data", "model"))¶
| Goal | Inputs | Weights | Note |
|---|---|---|---|
| Pure DP (replicated weights) | P("data", ...) | P(None, ...) | Standard data parallel |
| FSDP-style | P("data", ...) | P("data", ...) (gathered before use) | Combine with with_sharding_constraint |
| Tensor parallel (Megatron MLP) | P("data", None) | W1: P(None, "model"), W2: P("model", None) | All-reduce after W2 |
| 2D parallelism | P("data", None) | TP weights as above | ("data", "model") mesh |
17.6 Mental discipline¶
- State-as-argument. Anything that "carries over"-params, opt state, RNG, batchnorm running stats-is an argument and a return value, never a global.
- Trace once, run many. Every jit-decorated function should be traced a small number of times across the entire program lifetime.
- Annotate sharding sparsely. Annotate the inputs and key intermediate shardings; let GSPMD figure out the rest.
- Profile with the post-opt HLO. What you wrote in Python and what runs on the device can diverge dramatically due to fusion. Read the post-optimization HLO before optimizing.
- Read the jaxpr when surprised. It is short and exact.
Closing¶
JAX is a small core (pure functions over PyTrees, traced into jaxprs, lowered to XLA HLO) with a large amount of leverage on top (composable transformations, GSPMD, Pallas, the Flax/Equinox ecosystem). Its design exacts a discipline-purity, explicit state, structured control flow, deliberate sharding-and pays back with a programming model that scales seamlessly from a single laptop GPU to a 4096-chip TPU pod with the same source code.
If you internalize four ideas from this chapter, make them: (1) purity makes transformations compose, (2) jit is trace-then-cache, and the cache key is the abstract signature, (3) HLO is the IR; fusion and GSPMD are where the magic lives, (4) Mesh + PartitionSpec is how you tell the compiler about your hardware, and the rest is propagation. Everything else in JAX is a refinement of those.