
Week 11 - JAX, XLA, HLO

11.1 Conceptual Core

  • JAX is a library for composable function transformations: jit, grad, vmap, pmap, shard_map. Underneath, every transformed function is traced into a jaxpr and compiled by XLA.
  • XLA is a domain-specific compiler for linear algebra. Its IR is HLO (High Level Operations): a small, well-defined op set with parametric shapes. XLA optimizes via fusion, layout assignment, and sharding propagation.
  • JAX is favored for: research clarity (functional purity), TPU support (XLA is the TPU compiler), large-scale training (sharding via pjit/shard_map is more elegant than PyTorch's parallelism stack).
  • PyTorch is favored for: ecosystem maturity, eager-mode debugging, NVIDIA hardware-tuning depth.
  • Both are correct answers; the curriculum requires fluency in both because real production stacks use both.
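The tracing-and-transformation pipeline above can be seen directly: a minimal sketch (hypothetical linear-model loss, not from the curriculum) showing that `grad` and `jit` compose, and that `make_jaxpr` exposes the IR XLA compiles.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared-error loss for a toy linear model.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# Transformations compose: the gradient function is itself traced and compiled.
grad_loss = jax.jit(jax.grad(loss))

w = jnp.ones(3)
x = jnp.ones((4, 3))
y = jnp.zeros(4)

print(jax.make_jaxpr(loss)(w, x, y))  # the jaxpr that XLA will compile
g = grad_loss(w, x, y)                # shape (3,), same pytree as w
```

Printing the jaxpr first is a useful habit: it is the exact functional program all transformations operate on, before any XLA optimization.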

11.2 Mechanical Detail

  • jax.jit traces the function with abstract Tracer arguments, builds a jaxpr, compiles with XLA. The compiled artifact is cached by input shapes/dtypes/static args.
  • jax.grad is reverse-mode AD that operates on jaxprs, so it is purely functional. Closures over mutable state and side effects don't survive grad.
  • jax.vmap vectorizes a function across a new axis. The classic example: a function that operates on one example becomes a batched function.
  • pjit / shard_map distribute computation across devices. pjit has been folded into jax.jit (jit with sharding annotations); shard_map gives explicit per-device SPMD control.
  • HLO inspection: jax.jit(f).lower(args).compiler_ir(dialect="hlo") returns the HLO text. Read this; it has the same diagnostic value as reading Inductor's generated Triton.
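A minimal sketch of the inspection workflow (the toy function is hypothetical; note that recent JAX versions lower to StableHLO by default, so `as_text()` on the lowered and compiled stages is the most portable way to read the IR):

```python
import jax
import jax.numpy as jnp

def f(x):
    # Two elementwise ops that XLA will typically fuse into one kernel.
    return jnp.tanh(x) * 2.0

lowered = jax.jit(f).lower(jnp.ones((8, 8)))
print(lowered.as_text())    # IR as lowered from the jaxpr, pre-optimization

compiled = lowered.compile()
print(compiled.as_text())   # optimized HLO after XLA passes (look for fusion)
```

Comparing the two dumps shows exactly what XLA's optimization pipeline did to your program.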

11.3 Lab: "JAX Equivalent"

Re-implement your Month 1 training loop in JAX:

  • Pure-functional model (no nn.Module mutation).
  • optax for the optimizer.
  • jax.jit the train step.
  • Apply jax.vmap somewhere meaningful (e.g., per-example metric computation).
  • Compare end-to-end throughput with the PyTorch baseline.

11.4 Idiomatic & Diagnostic Drill

  • Inspect the HLO. Identify a fusion produced by XLA, then check whether PyTorch+Inductor produced (or failed to produce) an equivalent fusion for the same code.

11.5 Production Slice

  • For TPU-bound work, JAX is the right tool. Document the JAX install path on Cloud TPU; rent a v4-8 for two hours; run a small training job; capture the HLO and the XLA cost analysis.
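Capturing the cost analysis does not require a TPU to rehearse: a minimal sketch on a toy function (the exact return format of `cost_analysis()` varies across JAX versions, so treat the keys below as illustrative, not guaranteed).

```python
import jax
import jax.numpy as jnp

def f(x):
    # A matmul gives the cost model something non-trivial to estimate.
    return (x @ x.T).sum()

compiled = jax.jit(f).lower(jnp.ones((128, 128))).compile()

# XLA's static estimates for the compiled module (FLOPs, bytes accessed, ...).
ca = compiled.cost_analysis()
print(ca)
```

On a Cloud TPU the same two calls, `as_text()` for the HLO and `cost_analysis()` for the estimates, are what you capture for the production-slice deliverable.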
