Month 3 - Framework Internals: PyTorch, torch.compile, JAX/XLA, Custom Ops

Goal: by the end of Week 12 you can (a) read PyTorch's dispatcher source and trace an op from Python through ATen to a CUDA kernel, (b) explain torch.compile's graph capture and Inductor backend, (c) read JAX/XLA HLO and reason about XLA optimizations, and (d) ship a custom CUDA kernel as a PyTorch extension callable from Python.
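To make goal (b) concrete, here is a minimal sketch (names like `inspect_backend` and `f` are illustrative; assumes PyTorch ≥ 2.0) that uses torch.compile's documented custom-backend hook to print the FX graph Dynamo captures, without compiling anything:

```python
import torch

def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    # Dynamo calls the backend with the captured FX graph and example
    # inputs; print the graph, then return the unmodified forward so
    # execution falls back to eager instead of going through Inductor.
    print(gm.graph)
    return gm.forward

@torch.compile(backend=inspect_backend)
def f(x):
    return torch.sin(x) ** 2 + torch.cos(x) ** 2

f(torch.randn(8))  # first call triggers graph capture and the printout
```

Swapping `backend=inspect_backend` for the default `backend="inductor"` is exactly the point where Inductor takes over and generates fused Triton/C++ kernels from that graph.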

Deep-dive companions (read in tandem):

- Weeks 9–10, 12 → DEEP_DIVES/04_PYTORCH_INTERNALS.md - full layered architecture trace, dispatcher mechanics, autograd engine, complete torch.compile pipeline (Dynamo + AOTAutograd + Inductor), modern custom-op path with Triton, CUDA caching allocator algorithm.
- Week 11 → DEEP_DIVES/05_JAX_XLA.md - pure-functional model, jaxpr tracing with annotated examples (sketched below), full XLA pipeline, GSPMD with a Megatron-MLP propagation walkthrough.
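As a taste of the jaxpr and HLO inspection covered in the Week 11 deep dive, a minimal sketch (the function `f` is illustrative; assumes a working jax install) using two public entry points, jax.make_jaxpr and the jax.jit AOT lowering API:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.dot(x, x.T).sum()

x = jnp.ones((4, 4))

# Trace f to a jaxpr: the pure-functional IR JAX records before handoff to XLA.
print(jax.make_jaxpr(f)(x))

# Lower through jit and dump the (Stable)HLO text that XLA compiles.
print(jax.jit(f).lower(x).as_text())
```

Reading the two printouts side by side shows what the XLA pipeline starts from before fusion and layout optimizations run.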


Weeks
