Month 3 - Framework Internals: PyTorch, torch.compile, JAX/XLA, Custom Ops

Goal: by the end of week 12 you can (a) read PyTorch's dispatcher source and trace an op from Python through ATen to a CUDA kernel, (b) explain torch.compile's graph capture and Inductor backend, (c) read JAX/XLA HLO and reason about XLA optimizations, and (d) ship a custom CUDA kernel as a PyTorch extension callable from Python.
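As a warm-up for goal (a), it can help to have a mental model of key-based dispatch before opening the real source. The sketch below is a toy in pure Python and is not PyTorch's dispatcher: `ToyTensor`, `register`, `dispatch`, and the `TRACE` log are all hypothetical names invented for illustration, and the "cuda" kernel is a host-side stub, not a real kernel launch. It only shows the general idea that an op name plus a backend key selects a kernel.

```python
# Toy model of key-based operator dispatch. This is NOT PyTorch's real
# dispatcher; every name here is hypothetical and exists only for illustration.
from typing import Callable, Dict, List, Tuple

# Dispatch table: (op name, backend key) -> kernel function.
_KERNELS: Dict[Tuple[str, str], Callable] = {}
# Human-readable log of which kernel each call resolved to.
TRACE: List[str] = []


class ToyTensor:
    """A list of numbers tagged with a device string."""

    def __init__(self, data, device="cpu"):
        self.data = list(data)
        self.device = device


def register(op: str, key: str):
    """Decorator that registers a kernel for (op, key) in the table."""
    def deco(fn: Callable) -> Callable:
        _KERNELS[(op, key)] = fn
        return fn
    return deco


def dispatch(op: str, *tensors: ToyTensor) -> ToyTensor:
    """Pick a kernel from the inputs' device and call it."""
    devices = {t.device for t in tensors}
    if len(devices) != 1:
        raise RuntimeError(f"mixed devices: {sorted(devices)}")
    key = devices.pop()
    kernel = _KERNELS.get((op, key))
    if kernel is None:
        raise NotImplementedError(f"no kernel for op={op!r}, key={key!r}")
    TRACE.append(f"{op} -> {key} -> {kernel.__name__}")
    return kernel(*tensors)


@register("add", "cpu")
def add_cpu(a: ToyTensor, b: ToyTensor) -> ToyTensor:
    return ToyTensor([x + y for x, y in zip(a.data, b.data)], "cpu")


@register("add", "cuda")
def add_cuda_stub(a: ToyTensor, b: ToyTensor) -> ToyTensor:
    # Stand-in for a device kernel launch; it actually runs on the host.
    return ToyTensor([x + y for x, y in zip(a.data, b.data)], "cuda")


if __name__ == "__main__":
    out = dispatch("add", ToyTensor([1, 2]), ToyTensor([3, 4]))
    print(out.data, TRACE[-1])
```

When reading the real source, one way to use this toy is as a checklist: look for where the key is computed from the inputs, where the table lookup happens, and where the selected kernel is invoked.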

Deep-dive companions (read in tandem):
- Weeks 9–10, 12 → the PyTorch Internals deep dive: full layered architecture trace, dispatcher mechanics, autograd engine, complete torch.compile pipeline (Dynamo + AOTAutograd + Inductor), modern custom-op path with Triton, CUDA caching allocator algorithm.
- Week 11 → the JAX/XLA deep dive: pure-functional model, jaxpr tracing with annotated examples, full XLA pipeline, GSPMD with Megatron-MLP propagation walkthrough.

Worked investigation (hands-on, real GPU): Read an nsys trace of a training step to see why a GPU reporting 100% utilization can still be slow, and spot dataloader starvation, sync points, and fusion opportunities in the timeline. See the Worked Examples section.


Weeks
