Week 11 - JAX, XLA, HLO
11.1 Conceptual Core
- JAX is a library for composable function transformations: `jit`, `grad`, `vmap`, `pmap`, `shard_map`. Underneath, every transformed function is traced into a jaxpr and compiled by XLA (a minimal jaxpr example follows this list).
- XLA is a domain-specific compiler for linear algebra. Its IR is HLO (High Level Operations): a small, well-defined op set with parametric shapes. XLA optimizes via fusion, layout assignment, and sharding propagation.
- JAX is favored for: research clarity (functional purity), TPU support (XLA is the TPU compiler), and large-scale training (sharding via `pjit`/`shard_map` is more elegant than PyTorch's parallelism stack).
- PyTorch is favored for: ecosystem maturity, eager-mode debugging, and the depth of its NVIDIA hardware tuning.
- Both are defensible choices; the curriculum requires fluency in both because real production stacks use both.
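
To make the tracing model concrete, here is a minimal sketch (the function is arbitrary): `jax.make_jaxpr` prints the jaxpr that `jit`, `grad`, and `vmap` all operate on.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

# Tracing replaces x with an abstract value and records every primitive
# into a jaxpr; this prints something like:
#   { lambda ; a:f32[3]. let
#       b:f32[3] = integer_pow[y=2] a
#       c:f32[] = reduce_sum[axes=(0,)] b
#     in (c,) }
print(jax.make_jaxpr(f)(jnp.ones(3)))
```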
11.2 Mechanical Detail
- `jax.jit` traces the function with abstract `Tracer` arguments, builds a jaxpr, and compiles it with XLA. The compiled artifact is cached by input shapes, dtypes, and static args (see the composition sketch after this list).
- `jax.grad` is reverse-mode AD that operates on jaxprs and is purely functional: closures and side effects don't survive `grad`.
- `jax.vmap` vectorizes a function across a new axis. The classic example: a function that operates on one example becomes a batched function.
- `pjit`/`shard_map` (the modern unified `jit` with sharding annotations) distribute computation across devices.
- HLO inspection: `jax.jit(f).lower(*args).compiler_ir(dialect="hlo")` returns the HLO. Read it; it has the same value as reading Inductor's Triton output.
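
A minimal sketch of how these transformations compose (the loss function and shapes are illustrative):

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    return (jnp.dot(x, w) - y) ** 2           # scalar per-example loss

# grad -> per-example gradient w.r.t. w; vmap -> batched over (x, y);
# jit -> compiled once per shape/dtype signature, then cached.
grad_fn = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.zeros(4)
xs, ys = jnp.ones((8, 4)), jnp.ones(8)
per_example_grads = grad_fn(w, xs, ys)        # shape (8, 4)

# Lowered IR for one shape signature, as text.
print(jax.jit(loss).lower(w, xs[0], ys[0]).as_text())
```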
11.3 Lab - "JAX Equivalent"
Re-implement your Month 1 training loop in JAX (a minimal sketch follows this list):
- Pure-functional model (no `nn.Module` mutation).
- optax for the optimizer.
- `jax.jit` the train step.
- Add `jax.vmap` somewhere meaningful (e.g., per-example metric computation).
- Compare end-to-end throughput with the PyTorch baseline.
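
A minimal sketch of the shape of that loop; the linear model, dict-of-arrays params, and shapes below are placeholders, not your Month 1 model:

```python
import jax
import jax.numpy as jnp
import optax

def predict(params, x):
    return x @ params["w"] + params["b"]

def loss_fn(params, batch):
    preds = predict(params, batch["x"])
    return jnp.mean((preds - batch["y"]) ** 2)

optimizer = optax.adam(1e-3)

@jax.jit
def train_step(params, opt_state, batch):
    # Pure function: params and opt_state go in, new ones come out.
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

def per_example_error(params, batch):
    # jax.vmap turns the single-example error into a batched metric.
    single = lambda x, y: jnp.sum((predict(params, x) - y) ** 2)
    return jax.vmap(single)(batch["x"], batch["y"])

params = {"w": jnp.zeros((4, 1)), "b": jnp.zeros(1)}
opt_state = optimizer.init(params)
batch = {"x": jnp.ones((32, 4)), "y": jnp.ones((32, 1))}
params, opt_state, loss = train_step(params, opt_state, batch)
print(loss, per_example_error(params, batch).shape)    # scalar, (32,)
```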
11.4 Idiomatic & Diagnostic Drill
- Inspect the HLO. Identify a fusion XLA produced, then check whether PyTorch+Inductor produced (or didn't produce) the same fusion for the equivalent code.
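
One way to hunt for fusions, sketched below with an arbitrary function: `Compiled.as_text()` returns the post-optimization HLO, in which XLA's fused computations appear as `fusion` instructions.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jax.nn.relu(x) * 2.0)      # elementwise ops + a reduction

x = jnp.ones((1024, 1024))
compiled = jax.jit(f).lower(x).compile()

# Post-optimization HLO; grep for the fusions XLA actually formed.
for line in compiled.as_text().splitlines():
    if "fusion" in line:
        print(line.strip())
```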
11.5 Production Slice
- For TPU-bound work, JAX is the right tool. Document the JAX install path on Cloud TPU; rent a v4-8 for two hours; run a small training job; capture the HLO and the XLA cost analysis.
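
A sketch of the capture step, assuming a working TPU runtime; note that `cost_analysis()` output format varies by backend and JAX version:

```python
import jax
import jax.numpy as jnp

print(jax.devices())                  # expect TpuDevice entries on a v4-8

def step(x):
    return jnp.tanh(x @ x.T).sum()

x = jnp.ones((2048, 2048))
compiled = jax.jit(step).lower(x).compile()

# Persist the optimized HLO and XLA's cost estimates for the write-up.
with open("step_hlo.txt", "w") as f:
    f.write(compiled.as_text())
print(compiled.cost_analysis())       # flops, bytes accessed, etc.
```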