
Week 10 - torch.compile, TorchDynamo, Inductor

10.1 Conceptual Core

  • torch.compile (PyTorch 2.0+) is a JIT compiler that captures Python+PyTorch code into a graph and compiles it to optimized kernels. The pipeline:
  • TorchDynamo: hooks CPython frame evaluation, captures bytecode into FX graphs, and inserts graph breaks around ops it can't trace.
  • AOTAutograd: traces the forward pass and generates the matching backward ahead of time, partitions the joint graph into separate forward and backward graphs, and decomposes high-level ops into a small "core ATen" set.
  • Inductor: the default backend. Lowers the FX graph to Triton kernels (for CUDA) or C++/OpenMP (for CPU), scheduling with kernel fusion.
  • The user-visible promise: ~30-50% speedup on training, more for inference, with one decorator. The reality: graph breaks and silent fallbacks make this a discipline, not a free lunch.

10.2 Mechanical Detail

  • Graph breaks: any operation Dynamo can't trace falls back to eager execution. Common causes: data-dependent control flow on tensor values (e.g., `if` on a tensor comparison), print calls, custom Python objects Dynamo can't model.
  • `torch._dynamo.explain(model)(input)`: shows each graph break with its reason.
  • `TORCH_COMPILE_DEBUG=1`: dumps every stage of compilation. Massive output; useful when debugging performance regressions.
  • Inductor codegen: `TORCH_LOGS=output_code` shows the generated Triton kernels. Read these: they're surprisingly readable and often reveal optimization opportunities you can replicate by hand.
  • Modes: mode="reduce-overhead" (CUDA graphs), mode="max-autotune" (heavy autotuning), default. Choose based on the workload.
  • Caching: compiled artifacts are cached on disk (by default under /tmp/torchinductor_<user>; configurable via TORCHINDUCTOR_CACHE_DIR). First run is slow; subsequent calls are fast.

10.3 Lab: "Compile and Compare"

Take your honest-training-loop from Month 1 and add model = torch.compile(model). Measure:

  1. First-step time (compilation cost).
  2. Steady-state step time vs. uncompiled.
  3. With TORCH_LOGS="recompiles": how many recompilations occurred, and why?
  4. With mode="max-autotune": how much extra speed vs. the default, and is it worth the extra compile time?

Triage any graph breaks; report in COMPILE_LOG.md.
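A timing harness for the lab could be sketched as follows (the step function and input are placeholders for your own training step and batch):

```python
import time

def time_steps(step_fn, x, n_steps=20):
    """Time the first step (which includes compile cost for a compiled
    step_fn) and the mean steady-state step time separately."""
    t0 = time.perf_counter()
    step_fn(x)
    first = time.perf_counter() - t0

    t0 = time.perf_counter()
    for _ in range(n_steps):
        step_fn(x)
    steady = (time.perf_counter() - t0) / n_steps
    return first, steady

# compare: time_steps(train_step, batch) vs.
#          time_steps(torch.compile(train_step), batch)
```

On CUDA, call torch.cuda.synchronize() before each perf_counter read; kernel launches are asynchronous, and wall-clock timing without synchronization under-reports GPU time.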

10.4 Idiomatic & Diagnostic Drill

  • The "guard" system: every compiled artifact carries assumptions about input shapes, dtypes, and requires_grad. A call that violates a guard triggers recompilation. Dynamic shapes are a special hell: investigate dynamic=True for serving workloads with variable input sizes.

10.5 Production Slice

  • For inference, torch.compile + CUDA graphs (mode="reduce-overhead") is the production path. Document the compile-warmup procedure for your serving stack.
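A compile-warmup procedure for a serving stack might look like this (a sketch; `warm_up`, the shape list, and the device are placeholders to adapt to your stack):

```python
import torch

def warm_up(model, example_shapes, device="cuda"):
    """Compile the model and run one forward pass per expected input
    shape, so every guard set is compiled before traffic arrives."""
    compiled = torch.compile(model, mode="reduce-overhead")
    with torch.inference_mode():
        for shape in example_shapes:
            compiled(torch.randn(*shape, device=device))
    return compiled

# serving_model = warm_up(model, [(1, 3, 224, 224), (8, 3, 224, 224)])
```

Without warmup, the first request for each new shape pays the full compile (or CUDA-graph capture) latency; enumerating expected shapes up front moves that cost to deploy time.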
