
Week 12 - Custom Operators: From CUDA Kernel to torch.ops

12.1 Conceptual Core

  • When PyTorch / JAX don't have a fast-enough op for your needs, you write one. The standard path:
      1. Implement the kernel (CUDA, Triton, or C++).
      2. Wrap it with the framework's extension API.
      3. Register it with the dispatcher.
      4. Define the autograd backward (forward + backward = autograd.Function; see the sketch after this list).
      5. Optionally support torch.compile by registering a fake (meta) implementation, so the compiler can infer output shapes without running the kernel.
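
A minimal sketch of the forward + backward pairing via torch.autograd.Function, using an elementwise scale as a stand-in for a real kernel (the op and all names here are illustrative):

    import torch

    class ScaleOp(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, alpha):
            ctx.alpha = alpha              # stash non-tensor state for backward
            return x * alpha               # stand-in for your real kernel launch

        @staticmethod
        def backward(ctx, grad_out):
            # return one gradient per forward input; None for the non-tensor alpha
            return grad_out * ctx.alpha, None

    x = torch.randn(4, requires_grad=True)
    ScaleOp.apply(x, 2.0).sum().backward()   # x.grad is now all 2.0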

12.2 Mechanical Detail

  • PyTorch C++ extension, the classic compiled path:
      • setup.py with torch.utils.cpp_extension.CUDAExtension.
      • C++/CUDA source with pybind11 bindings.
      • Built at install time; loadable as import myop.
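      • A minimal setup.py sketch (the module name and source-file names are placeholders):

    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    setup(
        name="myop",
        ext_modules=[
            CUDAExtension(name="myop", sources=["myop.cpp", "myop_kernel.cu"]),
        ],
        cmdclass={"build_ext": BuildExtension},  # handles nvcc + torch include paths
    )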
  • torch.library API (PyTorch 2.x) for dispatcher integration without C++:
    @torch.library.custom_op("myns::myop", mutates_args=())
    def myop(x: torch.Tensor) -> torch.Tensor:
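        # _my_triton_kernel is a placeholder for your Triton launch wrapper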
        return _my_triton_kernel(x)
    
    @myop.register_fake
    def _(x):
        return torch.empty_like(x)  # for compile/dynamo
    
  • Backward registration: torch.library.register_autograd("myns::myop", backward_fn).
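    A hedged sketch, assuming myns::myop computes an elementwise square (the real backward must match your kernel's math):

    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)           # stash what backward needs

    def backward_fn(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out            # d(x^2)/dx = 2x, chained with grad_out

    torch.library.register_autograd("myns::myop", backward_fn, setup_context=setup_context)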
  • Triton-as-custom-op: torch.compile recognizes Triton kernels and integrates them into the compiled graph without a graph break; this is the modern preferred path (see the sketch below).
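
A minimal sketch of a Triton kernel plus its Python launch wrapper (an elementwise scale; all names are illustrative). On recent PyTorch, calling scale inside a torch.compile'd function should be captured without a graph break:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _scale_kernel(x_ptr, out_ptr, alpha, n_elements, BLOCK: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offsets < n_elements               # guard the ragged last block
        x = tl.load(x_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x * alpha, mask=mask)

    def scale(x: torch.Tensor, alpha: float) -> torch.Tensor:
        out = torch.empty_like(x)
        n = x.numel()
        grid = (triton.cdiv(n, 1024),)
        _scale_kernel[grid](x, out, alpha, n, BLOCK=1024)
        return out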

12.3 Lab-"RMSNorm From Scratch"

RMSNorm is used in modern LLMs (Llama, Qwen). Implement it three ways:

  1. PyTorch: pure tensor ops (reference sketch below).
  2. Triton custom op: a fused kernel that reads the input, computes the RMS, normalizes, and scales, all in one pass over HBM.
  3. CUDA C++ extension: the same kernel in CUDA C++ with a pybind11 binding.
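
A hedged reference sketch of implementation 1; the eps default is an assumption (production implementations such as Llama's also upcast to float32 before the reduction):

    import torch

    def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # RMSNorm(x) = x * rsqrt(mean(x^2, dim=-1) + eps) * weight
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        return x * rms * weight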

For each implementation: forward + backward, autograd-correct (numerical-grad test; sketch below), and benchmarked against the others on (B, S, H) = (8, 4096, 4096) in BF16. Your fused Triton version should beat the pure-PyTorch one by 3-5×.
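A hedged sketch of the numerical-grad test, using torch.autograd.gradcheck on the rmsnorm_ref sketch above; gradcheck needs float64 and a small shape, so run it separately from the BF16 benchmark:

    import torch

    x = torch.randn(2, 8, 16, dtype=torch.float64, requires_grad=True)
    w = torch.randn(16, dtype=torch.float64, requires_grad=True)
    # raises if the analytic gradient diverges from the numerical estimate
    torch.autograd.gradcheck(rmsnorm_ref, (x, w))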

12.4 Idiomatic & Diagnostic Drill

  • Test your custom op under torch.compile. Verify it doesn't break the graph (check torch._dynamo.explain).
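    A hedged sketch of the check, assuming the myns::myop registered in 12.2 and a CUDA device:

    import torch

    def f(x):
        return torch.ops.myns.myop(x) * 2

    report = torch._dynamo.explain(f)(torch.randn(1024, device="cuda"))
    print(report.graph_break_count)   # expect 0
    print(report.break_reasons)       # empty if the op integrated cleanly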

12.5 Production Slice

  • Custom ops in production must ship binary artifacts compatible with the user's PyTorch version. Use torch.ops.load_library for shared-library loading, and pin the PyTorch version and C++ ABI you build against; a binary built for one PyTorch release generally won't load under another.
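    A hedged sketch of runtime loading; the library path and op name are placeholders:

    import torch

    # registers the ops the .so declares via TORCH_LIBRARY under torch.ops.*
    torch.ops.load_library("build/libmyop.so")
    y = torch.ops.myns.myop(torch.randn(8, device="cuda"))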

Month 3 Capstone Deliverable

A framework-internals/ directory:

  1. dispatcher-trace/ (week 9): the annotated walk through ATen.
  2. compile-bench/ (week 10): torch.compile measurements + graph-break triage.
  3. jax-baseline/ (week 11): a JAX training loop matching the PyTorch baseline; HLO analysis.
  4. rmsnorm-fused/ (week 12): three implementations, benchmark plot, autograd tests.

By the end of the month you should be comfortable reading framework source, the literacy that distinguishes systems engineers from framework users.


Further Reading

  • The torch.compile design doc on pytorch.org/docs/.
  • The Inductor design doc.
  • The JAX "How JAX primitives work" guide.
  • The XLA HLO operation semantics page.
  • The PyTorch dispatcher tutorial in pytorch/pytorch/wiki.
