Week 12 - Custom Operators: From CUDA Kernel to torch.ops
12.1 Conceptual Core
- When PyTorch / JAX don't have a fast-enough op for your needs, you write one. The standard path (see the sketch after this list):
  - Implement the kernel (CUDA, Triton, or C++).
  - Wrap it with the framework's extension API.
  - Register it with the dispatcher.
  - Define the autograd backward (forward + backward = `autograd.Function`).
  - Optionally support `torch.compile` via abstract-shape registration.
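A minimal sketch of the forward + backward = `autograd.Function` pattern. The op (`ScaledSquare`, computing `alpha * x**2`) is a toy stand-in for a real kernel, not part of the lab:

```python
import torch

class ScaledSquare(torch.autograd.Function):
    """Toy stand-in for a custom kernel: y = alpha * x**2."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.save_for_backward(x)
        ctx.alpha = alpha
        # A real custom op would launch a CUDA/Triton kernel here.
        return alpha * x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # d/dx (alpha * x^2) = 2 * alpha * x; None for the non-tensor alpha arg.
        return grad_out * 2 * ctx.alpha * x, None

x = torch.randn(4, dtype=torch.double, requires_grad=True)
ScaledSquare.apply(x, 0.5).sum().backward()
torch.autograd.gradcheck(ScaledSquare.apply, (x, 0.5))  # numerical-grad check
```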
12.2 Mechanical Detail
- PyTorch C++ extension (the established path):
  - `setup.py` with `torch.utils.cpp_extension.CUDAExtension`.
  - C++/CUDA source with `pybind11` bindings.
  - Built at install time; loadable as `import myop`.
- `torch.library` API (PyTorch 2.x) for dispatcher integration without C++ (see the sketch after this list):
  - Backward registration: `torch.library.register_autograd("myns::myop", backward_fn)`.
- Triton-as-custom-op: `torch.compile` recognizes Triton kernels and integrates them into the compiled graph without a graph break; this is the modern preferred path.
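A minimal `torch.library` sketch, assuming PyTorch 2.4 or newer; the `myns` namespace and the toy `rms_scale` op are placeholders, with a pure-Python body standing in for a real kernel:

```python
import torch
from torch import Tensor

@torch.library.custom_op("myns::rms_scale", mutates_args=())
def rms_scale(x: Tensor) -> Tensor:
    # Stand-in "kernel": in the lab this would call CUDA/Triton code.
    rms = x.pow(2).mean(dim=-1, keepdim=True).add(1e-6).sqrt()
    return x / rms

@rms_scale.register_fake
def _(x):
    # Abstract-shape ("fake") impl: shape/dtype only, so torch.compile can trace it.
    return torch.empty_like(x)

def setup_context(ctx, inputs, output):
    (x,) = inputs
    ctx.save_for_backward(x)

def backward_fn(ctx, grad_out):
    (x,) = ctx.saved_tensors
    rms = x.pow(2).mean(dim=-1, keepdim=True).add(1e-6).sqrt()
    y = x / rms
    # Analytic gradient of y = x / rms(x): (g - y * mean(g * y, -1)) / rms.
    return (grad_out - y * (grad_out * y).mean(dim=-1, keepdim=True)) / rms

torch.library.register_autograd("myns::rms_scale", backward_fn, setup_context=setup_context)

x = torch.randn(8, 16, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(torch.ops.myns.rms_scale, (x,))  # numerical-grad check
```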
12.3 Lab: "RMSNorm From Scratch"
RMSNorm is used in modern LLMs (Llama, Qwen). Implement it three ways:
1. PyTorch: pure tensor ops.
2. Triton custom op: a fused kernel that reads input, computes RMS, normalizes, and scales, all in one pass over HBM.
3. CUDA C++ extension: same kernel in CUDA C++ with a pybind11 binding.
For each: forward + backward, autograd-correct (numerical-grad test), benchmarked against the others at (B, S, H) = (8, 4096, 4096) in BF16. Your fused Triton version should beat PyTorch by 3-5×. A starting point for implementation 1 is sketched below.
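This is a minimal sketch, assuming the usual RMSNorm formulation with a learned scale; the fp32 accumulation is a common stability choice, not a lab requirement:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Implementation 1: pure tensor ops. y = x / sqrt(mean(x^2) + eps) * weight."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Accumulate the mean-square in fp32 for stability, then cast back.
        ms = x.float().pow(2).mean(dim=-1, keepdim=True)
        return self.weight * (x.float() * torch.rsqrt(ms + self.eps)).to(x.dtype)

# Lab benchmark shape: (B, S, H) = (8, 4096, 4096) in BF16.
if torch.cuda.is_available():
    norm = RMSNorm(4096).to("cuda", torch.bfloat16)
    x = torch.randn(8, 4096, 4096, device="cuda", dtype=torch.bfloat16)
    y = norm(x)
```

Each of those tensor ops launches its own kernel and makes its own pass over HBM; the Triton and CUDA versions collapse them into a single fused pass, which is where the speedup comes from.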
12.4 Idiomatic & Diagnostic Drill
- Test your custom op under `torch.compile`. Verify it doesn't break the graph (check `torch._dynamo.explain`; see the snippet below).
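A sketch of the check, assuming the `myns::rms_scale` op registered in the 12.2 snippet above:

```python
import torch

def f(x):
    # Wraps the custom op in a small function so Dynamo has something to trace.
    return torch.ops.myns.rms_scale(x) + 1.0

explanation = torch._dynamo.explain(f)(torch.randn(8, 16))
print(explanation.graph_count)        # expect 1
print(explanation.graph_break_count)  # expect 0: the custom op stays in-graph
print(explanation.break_reasons)      # empty if nothing fell back to eager
```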
12.5 Production Slice
- Custom ops in production must ship binary artifacts compatible with the user's PyTorch version. Use `torch.ops.load_library` for shared-library loading; pin the PyTorch ABI (see the sketch below).
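A sketch of runtime loading; the path and op name are hypothetical, and the library is assumed to register its ops via a `TORCH_LIBRARY` block in C++:

```python
import torch

# Load a prebuilt shared object; its TORCH_LIBRARY registration runs at load
# time and installs the ops into the dispatcher (path is hypothetical).
torch.ops.load_library("myop/_C/libmyop.so")

# Registered ops then appear under torch.ops.<namespace>.<opname>.
x = torch.randn(8, 4096, 4096, device="cuda", dtype=torch.bfloat16)
y = torch.ops.myns.rms_norm(x)  # hypothetical op name
```

Because the shared object is compiled against specific PyTorch C++ symbols, wheels typically pin a narrow torch version range and ship one artifact per supported torch release.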
Month 3 Capstone Deliverable
A `framework-internals/` directory:
1. `dispatcher-trace/` (week 9): the annotated walk through ATen.
2. `compile-bench/` (week 10): torch.compile measurements plus graph-break triage.
3. `jax-baseline/` (week 11): a JAX training loop matching the PyTorch baseline; HLO analysis.
4. `rmsnorm-fused/` (week 12): three implementations, a benchmark plot, autograd tests.
By the end of the month you should be comfortable reading framework source; that literacy is what distinguishes systems engineers from framework users.
Recommended Reading for This Month
- The `torch.compile` design doc on pytorch.org/docs/.
- The Inductor design doc.
- The JAX "How JAX primitives work" guide.
- The XLA HLO operation semantics page.
- The PyTorch dispatcher tutorial in the pytorch/pytorch wiki.