
PyTorch Internals: A Deep Dive Reference

A self-contained chapter for the AI Systems curriculum, Month 3. Target reader: a backend/SRE engineer who already writes PyTorch model code, wants to understand what happens beneath tensor.add_() and model.compile(), and is willing to read C++-flavoured pseudocode. Goal: after this chapter you should be able to read the PyTorch source tree, debug a dispatch-related bug, write a custom op that survives torch.compile, and reason about performance from first principles.


0. How To Read This Chapter

PyTorch is a stack of layers. Understanding it means understanding which layer owns which decision. The chapter walks down the stack on the way in (Python -> ATen -> dispatcher -> backend), then back up (autograd, AMP, compile) because higher layers are easier to follow once you know the substrate.

For each topic you will see four passes:

  1. Intuition -- what mental model is correct.
  2. Mechanism -- the actual data structures and control flow.
  3. Minimal code -- the smallest example that exercises the mechanism.
  4. Dispatch trace -- "if I were the dispatcher, what would I do step by step." This is the most underrated reasoning tool in PyTorch -- once you can simulate the dispatcher in your head, almost every weird bug becomes obvious.

All code examples target PyTorch 2.4+ on Linux/CUDA. Where source paths are referenced, they are relative to the pytorch/pytorch repo root.


1. The Layered Architecture

1.1 Intuition

When you write c = a + b in Python you are at the top of a five-layer cake. The layers exist because each one solves a different problem:

| Layer | Language | Job |
|---|---|---|
| Python frontend (torch.*) | Python | Ergonomics, autograd surfaces, nn.Module |
| Pybind shim (torch._C) | C++/pybind11 | Convert PyObject -> C++ args, hold the GIL boundary |
| ATen (at::Tensor, ops) | C++ | The op API. Defines what add means, type-erased over backends |
| Dispatcher (c10) | C++ | Pick which kernel to run (Autograd? CUDA? Autocast?) |
| Backend kernels | C++/CUDA/Triton/MPS/etc. | Actually compute the bytes |

The dispatcher is the keystone. ATen does not call kernels directly. ATen says "here is add(Tensor, Tensor) -> Tensor, dispatcher, please find the right implementation given these tensor properties." This indirection is what lets autograd, autocast, vmap, FakeTensor, meta tensors, and quantization all hook in at the same place.

1.2 Mechanism

Top-to-bottom for c = a + b where a, b are CUDA float32, requires_grad=True:

Python:           c = a + b
                    -> Tensor.__add__(self, other)
                    -> torch._C._TensorBase.add(self, other)   # via pybind
C++ (ATen):       at::add(self, other)
                    -> at::_ops::add_Tensor::call(self, other)  # codegen'd
Dispatcher:       Dispatcher::singleton().call(op_handle, stack)
                    -> picks kernel by computed DispatchKeySet
                    1. Autocast kernel (if autocast is active: casts, then redispatches)
                    2. Autograd kernel (records grad_fn, then redispatches)
                    3. CUDA kernel (the real one)
CUDA:             vectorized elementwise add launches; returns Tensor

The order of layers is not arbitrary. Autocast wraps autograd, which wraps the backend, because:

  • Autocast must cast before autograd records anything, so the casts themselves end up in the graph and tensors are saved for backward in the post-cast dtype.
  • Autograd needs to record exactly the op that will execute, so it sits directly above the backend.
  • The backend just computes.

You will see the same pattern reappear: any new cross-cutting concern (functionalization, batching for vmap, fake-tensor tracing) becomes a new DispatchKey somewhere in the stack.

1.3 ASCII trace of a + b

            +---------------------------+
 Python --> | Tensor.__add__            |
            +-------------+-------------+
                          |
                          v   torch._C (pybind)
            +---------------------------+
            | at::add(Tensor, Tensor)   |   ATen
            +-------------+-------------+
                          |
                          v
            +---------------------------+
            | OperatorHandle::call      |   Dispatcher
            |  computes DispatchKeySet  |
            +-------------+-------------+
                          |
              redispatches through keys:
                          |
                 +--------+--------+
                  | Autocast kernel |  (if active) cast inputs; redispatch w/o Autocast key
                 +--------+--------+
                          |
                 +--------+--------+
                  | Autograd kernel |  records AddBackward; redispatch w/o Autograd key
                 +--------+--------+
                          |
                 +--------+--------+
                 | CUDA kernel     |  vectorized_elementwise_kernel<<<...>>>(...)
                 +-----------------+

Memorise that picture. Almost every "why is this slow / wrong / weird" question maps to a layer in it.


2. Tensor Representation

2.1 The four-level wrapping

A torch.Tensor you hold in Python is a thin handle. Underneath:

torch.Tensor                 # Python object, subclass of torch._C._TensorBase
    -> at::Tensor            # C++ value type, ~one pointer wide
        -> c10::TensorImpl   # the heap object: dtype, sizes, strides, key set
            -> c10::Storage  # owns the bytes (or shares them with views)
                -> c10::DataPtr -> raw void* + Device + Allocator

at::Tensor is essentially intrusive_ptr<TensorImpl>. Copying a tensor in C++ bumps a refcount; it does not copy bytes. That is why Tensor a = b; is cheap -- they share the same TensorImpl (and hence the same Storage).

2.2 Fields you must know

TensorImpl (in c10/core/TensorImpl.h) carries roughly:

| Field | Type | Meaning |
|---|---|---|
| storage_ | c10::Storage | The byte buffer. Shared between views. |
| sizes_ | SmallVector<int64_t> | Shape. |
| strides_ | SmallVector<int64_t> | How many elements (not bytes) to step per dim. |
| storage_offset_ | int64_t | Where this tensor starts inside the storage, in elements. |
| dtype_ | caffe2::TypeMeta | float32, bfloat16, int64, ... |
| device_ | c10::Device | (DeviceType::CUDA, index=0), (CPU, -1), etc. |
| layout_ | c10::Layout | Strided, Sparse, SparseCsr, Mkldnn. |
| key_set_ | DispatchKeySet | Bitset of dispatch keys (Autograd, CUDA, ...). |
| requires_grad_ | bool | Lives via AutogradMeta, not directly here. |
| autograd_meta_ | unique_ptr<AutogradMetaInterface> | grad, grad_fn, version counter. |

Storage (in c10/core/Storage.h) carries:

| Field | Type | Meaning |
|---|---|---|
| data_ptr_ | c10::DataPtr | Owning pointer + device + allocator deleter. |
| size_bytes_ | size_t | Capacity in bytes. |
| allocator_ | c10::Allocator* | Where to ask for memory. On CUDA this is the caching allocator (Section 12). |
| resizable_ | bool | Can resize_ grow the buffer? |

Note: Storage does not know dtype or shape. It is just bytes. Two views of the same Storage can in principle even disagree on dtype (e.g., a.view(torch.int32) reinterprets bits).
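
A quick way to see that from Python (untyped_storage() is the handle to the raw byte buffer):

import torch

a = torch.arange(6, dtype=torch.float32)
b = a.view(2, 3)                       # new sizes/strides, same bytes
c = a.view(torch.int32)                # same bytes reinterpreted as int32
print(a.untyped_storage().data_ptr()
      == b.untyped_storage().data_ptr()
      == c.untyped_storage().data_ptr())   # True
print(c[1].item())                         # 1065353216 -- the bit pattern of float32 1.0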

2.3 The single most important invariant

For a strided tensor:

address_of_element(i0, i1, ..., in) =
    storage.data_ptr
    + dtype_size * (storage_offset + sum_k(ik * stride_k))

That's it. Sizes, strides, storage_offset, dtype, base pointer. Five things determine where every element lives. Views are just other tensors that share storage but have different (sizes, strides, storage_offset). Contiguous is a property of the strides relative to the sizes, not of memory itself.
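
A quick Python check of that formula on an awkward view (element_size() is the dtype size in bytes):

import torch

a = torch.arange(24, dtype=torch.float32).view(2, 3, 4)
t = a.transpose(0, 2)[1:]                 # a view with a nonzero offset and odd strides
i, j, k = 2, 1, 1
elem = t.storage_offset() + i * t.stride(0) + j * t.stride(1) + k * t.stride(2)
addr = t.untyped_storage().data_ptr() + t.element_size() * elem
print(addr == t[i, j, k].data_ptr())      # True -- those five numbers locate every element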

2.4 Why decouple view from storage

If shape lived inside storage you would copy bytes for transpose, narrow, unsqueeze. Decoupling lets these be O(1) metadata-only ops. The cost: kernels must respect arbitrary strides, or you must contiguous() first. PyTorch favours the first for "shape ops" and the second for "compute ops" -- compute kernels typically demand contiguous (or one of a few canonical memory formats) input.


3. Strides and Views

3.1 Stride arithmetic

For a contiguous (2, 3, 4) float32 tensor:

sizes   = [2, 3, 4]
strides = [12, 4, 1]            # elements (not bytes)
element (i, j, k) -> offset = 12*i + 4*j + 1*k

Strides for contiguous (row-major) tensors are: stride[i] = prod(sizes[i+1:]).
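
That rule is easy to verify by hand:

import torch

def contiguous_strides(sizes):
    strides, acc = [], 1
    for s in reversed(sizes):       # walk right-to-left, accumulating the product
        strides.append(acc)
        acc *= s
    return tuple(reversed(strides))

print(contiguous_strides((2, 3, 4)))    # (12, 4, 1)
print(torch.empty(2, 3, 4).stride())    # (12, 4, 1)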

3.2 Three view ops with no copy

import torch
a = torch.arange(24, dtype=torch.float32).view(2, 3, 4)   # contiguous
print(a.stride())           # (12, 4, 1)

b = a.transpose(0, 2)        # swap dim 0 and dim 2
print(b.shape, b.stride())   # torch.Size([4, 3, 2]), (1, 4, 12)
print(b.is_contiguous())     # False
print(b.data_ptr() == a.data_ptr())   # True -- same storage

c = a.narrow(1, 1, 2)        # along dim 1, start=1, length=2
print(c.shape, c.stride(), c.storage_offset())
# torch.Size([2, 2, 4]), (12, 4, 1), 4

d = a.unsqueeze(0)            # add a leading dim of size 1 (its stride value never matters)
print(d.shape, d.stride())   # torch.Size([1, 2, 3, 4]), (24, 12, 4, 1)

What changed in each:

| Op | sizes | strides | storage_offset | storage |
|---|---|---|---|---|
| transpose(0,2) | (4,3,2) | (1,4,12) | 0 | shared |
| narrow(1,1,2) | (2,2,4) | (12,4,1) | 1*4 = 4 | shared |
| unsqueeze(0) | (1,2,3,4) | (24,12,4,1) | 0 | shared |

The transposed tensor's storage looks the same byte-for-byte; we just re-described how to walk it. That is the entire trick.

3.3 is_contiguous and contiguous()

is_contiguous() returns true iff strides equal the canonical strides for the sizes:

strides[i] == prod(sizes[i+1:])    for all i (and ==1 for the last)

Many fused/elementwise CUDA kernels assume contiguous input so they can do unit-stride vectorised loads. If you pass them a transposed tensor they would either:

  • error out with a stride check, or
  • silently fall back to a strided kernel that is 5-10x slower.

So shape-bending operations are often followed by .contiguous() before a heavy op:

y = x.transpose(1, 2).contiguous()   # make a real copy with canonical strides
z = some_fused_kernel(y)

contiguous() allocates new storage and copies. Skipping it when needed wastes throughput; calling it when not needed wastes memory.

3.4 Memory format: channels_last

For 4D image tensors, two memory layouts are common:

  • torch.contiguous_format (NCHW): strides (C*H*W, H*W, W, 1).
  • torch.channels_last (NHWC): strides (C*H*W, 1, W*C, C).

Both have shape (N, C, H, W) -- only strides differ. CUDNN and many fused conv kernels are faster on channels_last for FP16/BF16 because it matches tensor-core memory access patterns. You opt in with:

x = x.to(memory_format=torch.channels_last)
model = model.to(memory_format=torch.channels_last)

Now x.is_contiguous() is false but x.is_contiguous(memory_format=torch.channels_last) is true. Kernels that understand the format will not insert a copy; kernels that do not will materialize a contiguous tensor on the way in.
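
Strides make the difference visible:

import torch

x = torch.randn(2, 3, 8, 8)                                 # NCHW: strides (192, 64, 8, 1)
y = x.to(memory_format=torch.channels_last)
print(y.shape, y.stride())                                  # same shape, strides (192, 1, 24, 3)
print(y.is_contiguous())                                    # False
print(y.is_contiguous(memory_format=torch.channels_last))   # True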

3.5 The dispatcher's view of strides

The dispatcher does not know or care about strides. Strides live in TensorImpl. The kernel underneath cares. This is why a stride bug is almost always a kernel bug, never a dispatcher bug.


4. The Dispatcher

4.1 Intuition

The dispatcher is a polymorphism mechanism more powerful than virtual functions. A virtual call dispatches on one type. The PyTorch dispatcher dispatches on a set of features that come from all tensor inputs combined: the union of their dispatch keys plus thread-local state (autocast active? grad enabled?).

Think of it as: every op is a function pointer table indexed by DispatchKey. Every call computes a key set, picks the highest priority key in the set, looks up the kernel, runs it. The kernel may "redispatch" -- remove its own key from the set and ask the dispatcher to do it again -- to chain effects.

4.2 The DispatchKey enum

Defined in c10/core/DispatchKey.h. The keys form a priority order. Roughly (highest first):

... functorch / vmap / FuncTorchBatched ...
PythonTLSSnapshot
... AMP / autocast keys ...
AutocastCPU
AutocastCUDA
AutocastXPU
... tracing / profiling ...
... per-backend autograd ...
AutogradOther / AutogradCPU / AutogradCUDA / AutogradXPU / AutogradMeta
Functionalize             # for export / aot
Python                    # __torch_dispatch__ (sits below autograd)
... backend keys (lowest, where real work happens) ...
CPU
CUDA
MPS
XPU
Meta                      # shape-only "fake" tensors
SparseCPU / SparseCUDA
QuantizedCPU / QuantizedCUDA

Each tensor has a DispatchKeySet -- a 64-bit bitset over these keys. A typical CUDA tensor with requires_grad=True has {AutogradCUDA, CUDA}. If you enter a with torch.autocast("cuda"): block, the thread-local "included" set adds AutocastCUDA, so calls inside that block effectively dispatch on {AutogradCUDA, AutocastCUDA, CUDA}.

4.3 Computing the key for a call

Per call:

ks = empty
for each tensor argument t:
    ks |= t.key_set()
ks |= local_include_set()      # e.g. AutocastCUDA inside autocast()
ks &= ~local_exclude_set()     # e.g. inference_mode excludes Autograd
top = ks.highest_priority_key()
kernel = op_table[op_id][top]
kernel(args)

This is conceptually ~50 lines of C++ in aten/src/ATen/core/dispatch/Dispatcher.{h,cpp}. Real implementation has fast paths and caching but the model is exact.
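
You can peek at a tensor's key set from Python through a private helper (unstable API, debugging only; the exact contents vary by build and tensor):

import torch

t = torch.randn(2, requires_grad=True)
print(torch._C._dispatch_keys(t))
# e.g. DispatchKeySet(CPU, ADInplaceOrView, AutogradCPU, ...)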

4.4 Redispatch

A kernel can run, then ask the dispatcher to re-run the same op but with its own key removed. That is how chaining works. In pseudocode for the autograd kernel for add:

Tensor add_autograd(const Tensor& a, const Tensor& b) {
    // 1. Make output by redispatching to lower keys (skip Autograd).
    auto out = at::redispatch::add(
        ks.remove(c10::DispatchKey::AutogradCUDA),   // same key set, minus our own key
        a, b);
    // 2. If grad mode is on and any input requires grad, build the graph node.
    if (compute_requires_grad(a, b)) {
        auto node = std::make_shared<AddBackward0>();
        node->set_next_edges(collect_next_edges(a, b));
        set_history(out, node);
    }
    return out;
}

So autograd doesn't do the math. It records bookkeeping, then asks the layer below to do the math.

4.5 TORCH_LIBRARY and TORCH_LIBRARY_IMPL

Two macros in C++. The first declares ops in a namespace; the second registers a backend implementation for some keys.

#include <torch/library.h>

// Declare the op in namespace "myops". The schema string uses ATen's Python-like type syntax.
TORCH_LIBRARY(myops, m) {
    m.def("triple(Tensor x) -> Tensor");
}

// Implement for CPU.
Tensor triple_cpu(const Tensor& x) {
    return x * 3;
}
TORCH_LIBRARY_IMPL(myops, CPU, m) {
    m.impl("triple", triple_cpu);
}

// Implement for CUDA.
Tensor triple_cuda(const Tensor& x) {
    return x * 3;
}
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
    m.impl("triple", triple_cuda);
}

You can also write m.impl("aten::add.Tensor", &my_add) to override a built-in op for your library / dispatch key. This is how custom backends, quantized ops, and out-of-tree devices plug in without touching ATen.
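
From Python, ops registered this way surface under torch.ops once the shared library is loaded (a sketch, assuming the snippet above was compiled into a hypothetical libmyops.so):

import torch

torch.ops.load_library("libmyops.so")    # path to the compiled library
x = torch.randn(4)
y = torch.ops.myops.triple(x)            # goes through the dispatcher like any built-in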

4.6 Layered keys: why autocast > autograd > backend

The Autocast keys sit above the Autograd keys on purpose. The cast happens before autograd records anything, so the casts themselves are autograd-exposed (they appear as nodes in the graph) and the tensors saved for backward are the post-cast, low-precision ones. The backward therefore runs in the same dtype the forward actually used, and the cast nodes convert the incoming gradients back to each leaf's original dtype. Autograd in turn sits above the backend so that what it records is exactly the op the backend executes. The comment next to the Autocast keys in c10/core/DispatchKey.h documents this choice: autocasting precedes autograd so that casts are autograd-exposed and inputs are saved for backward in the post-autocast dtype.

The general rule: cross-cutting concerns that transform the call (cast, batch, fake-tensor-ify) sit above whatever needs to see the transformed call; autograd sits below the transformers so its tape matches what actually executed, and above the backends so it can intercept every real kernel.

4.7 __torch_dispatch__ and __torch_function__

Two extension points worth knowing:

  • __torch_function__: a Python-level override. If you define a subclass of Tensor with this method, torch.add(my_tensor, ...) will route through your method before doing anything else. Used by libraries like torch.compile's subclass tracing, torch.func, and pretty-printing-only wrappers.
  • __torch_dispatch__: a post-dispatcher override. Called from inside the dispatcher at the Python key. You see the canonical op (aten.add.Tensor) with already-resolved overloads. Used for FakeTensor, FunctionalTensor, LoggingTensor, and is the right hook for "I want to intercept everything below the API level."

If you only ever debug models you may never write either, but you will see them mentioned in stack traces and dynamo logs.
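
A minimal taste of the second hook, using the TorchDispatchMode helper from torch.utils._python_dispatch (a sketch; the exact ops printed depend on what you run):

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print("dispatch:", func)            # canonical, overload-resolved aten op
        return func(*args, **(kwargs or {}))

with LoggingMode():
    x = torch.randn(4)
    y = (x * 2).sin()
# prints lines like: aten.randn.default, aten.mul.Tensor, aten.sin.default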


5. ATen Op Registration: native_functions.yaml

ATen's op surface is defined declaratively in aten/src/ATen/native/native_functions.yaml. Each entry looks roughly like:

- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  variants: function, method
  dispatch:
    CPU: add_cpu
    CUDA: add_cuda
    SparseCPU, SparseCUDA: add_sparse
    MkldnnCPU: mkldnn_add
  autogen: add.out
  tags: pointwise, canonical

This declares:

  • Schema: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor. This is the canonical Python-typed signature. The .Tensor suffix is the overload name (vs add.Scalar).
  • Variants: generate at::add(...) (function form) and Tensor::add(...) (method form).
  • Dispatch table: which C++ function implements which key.
  • Autogen: also generate the add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) variant from the out= pattern.

A codegen tool (driven from torchgen/, output mostly under build/aten/src/ATen/) reads this YAML and emits:

  1. Op symbol headers: at::_ops::add_Tensor callables.
  2. Function variants: free functions in at:: and methods on at::Tensor.
  3. Default implementations: helpers in the style of an in-place add_ that calls the functional at::add and copy_-s the result back into self.
  4. Autograd derivative bindings (combined with derivatives.yaml): AddBackward0::apply etc.
  5. Python bindings: THPVariable_add etc. via tools/autograd/.

When you read PyTorch source and cannot find at::add's definition: it's generated. Look at aten/src/ATen/native/BinaryOps.cpp for the kernels and at the YAML for the contract.

derivatives.yaml (in tools/autograd/) is the sibling file. Each entry binds an op to its VJP:

- name: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  self: grad
  other: maybe_multiply(grad, alpha)

That tiny snippet is the autograd of add. The codegen turns it into an AddBackward0 Function class that returns (grad, alpha*grad).
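
You can check that binding directly in eager mode:

import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
torch.add(a, b, alpha=2.0).sum().backward()
print(a.grad)   # all ones  (self:  grad)
print(b.grad)   # all twos  (other: alpha * grad)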


6. The Autograd Engine

6.1 Intuition

Autograd in PyTorch is a dynamic tape: the graph is built every forward pass, executed once in reverse, and discarded. Nothing is precompiled or cached across iterations. This buys flexibility (control flow, dynamic shapes) at the cost of allocating one graph node per differentiable op per forward.

6.2 Mechanism

Each differentiable op produces an output Tensor whose grad_fn is a Node (subclass of torch::autograd::Node, formerly Function). The Node holds:

  • next_edges_: list of (Node*, input_nr) pairs pointing at the Nodes that produced this op's inputs.
  • saved tensors / scalars needed for the backward (e.g., for mul, both inputs).
  • apply(grads_out) -> grads_in: the VJP.

grad_fn is null on leaves; leaves with requires_grad=True instead have an AccumulateGrad Node which writes into tensor.grad.

When you call loss.backward():

1. Engine seeds the gradient for `loss` (default: 1.0 if scalar).
2. It performs a reverse topological traversal starting at loss.grad_fn.
3. For each Node in topo order, call Node.apply(grad_outputs).
   Result: grad_inputs, one per next_edge.
4. Send each grad_input to the corresponding next_edge's Node, accumulating.
5. When a Node's incoming grad count is satisfied, schedule it.

The engine is multi-threaded across devices: there is one worker per device that owns Nodes for that device (torch/csrc/autograd/engine.cpp). CPU work runs on the calling thread.

6.3 Trace of a tiny graph

import torch
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a * b
d = c + a
d.backward()
print(a.grad, b.grad)   # tensor(4.) tensor(2.)

Forward graph (as the dispatcher's autograd kernels build it):

        AccumulateGrad(a)            AccumulateGrad(b)
              ^       ^                    ^
              |       |                    |
              |       +--------+   +-------+
              |                |   |
              |             MulBackward0      <- saves a, b
              |                |
              |                v
              |                c
              |                |
              +-------+    +---+
                      |    |
                   AddBackward0      <- needs only alpha=1
                      |
                      v
                      d

Backward execution starting from d with seed 1:

AddBackward0(grad=1) -> grad_c = 1, grad_a_partial = 1
MulBackward0(grad=1)  -> grad_a_partial2 = b = 3, grad_b = a = 2
AccumulateGrad(a): a.grad = 1 + 3 = 4
AccumulateGrad(b): b.grad = 2

Note how a had two paths to it (through c and direct) and the engine summed contributions at AccumulateGrad. That summing is what next_edges plus accumulation buys you for free.
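
You can walk the same graph from Python:

import torch

a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a * b
d = c + a
print(d.grad_fn)                   # <AddBackward0 object at ...>
print(d.grad_fn.next_functions)    # ((<MulBackward0 ...>, 0), (<AccumulateGrad ...>, 0))
print(c.grad_fn.next_functions)    # ((<AccumulateGrad ...>, 0), (<AccumulateGrad ...>, 0))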

6.4 VJP definition per op

For each forward op y = f(x1, ..., xn), the VJP is grads_in = J^T @ grad_y. PyTorch implements VJPs op-by-op so it never materialises a Jacobian. For mul(a, b):

y = a*b
dy/da = b
dy/db = a
VJP given grad_y:
    grad_a = grad_y * b
    grad_b = grad_y * a

These are themselves PyTorch ops that go through the dispatcher. If you call backward(create_graph=True) (or torch.autograd.grad(..., create_graph=True)), they run with grad mode on and get recorded into a new graph -- that is what enables double backward (gradients of gradients). With the default create_graph=False the engine runs them with grad mode off, so no second graph is built.
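
A minimal double-backward example:

import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3
(g,) = torch.autograd.grad(y, x, create_graph=True)   # 3 * x**2 = 12, itself differentiable
(g2,) = torch.autograd.grad(g, x)                      # d(3x^2)/dx = 6x = 12
print(g.item(), g2.item())                             # 12.0 12.0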

6.5 Custom torch.autograd.Function

You write one when no built-in op covers your forward, or when you have a faster handwritten backward. The shape:

import torch

class FusedScaleClamp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, lo, hi):
        ctx.save_for_backward(x)
        ctx.scale = scale
        ctx.lo, ctx.hi = lo, hi
        y = torch.clamp(x * scale, min=lo, max=hi)
        return y

    @staticmethod
    def backward(ctx, grad_y):
        (x,) = ctx.saved_tensors
        s, lo, hi = ctx.scale, ctx.lo, ctx.hi
        # grad flows only where the clamp didn't saturate
        scaled = x * s
        mask = (scaled > lo) & (scaled < hi)
        grad_x = grad_y * s * mask.to(grad_y.dtype)
        # No grads w.r.t. python scalars
        return grad_x, None, None, None

# Use it
x = torch.randn(8, requires_grad=True)
y = FusedScaleClamp.apply(x, 2.0, -1.0, 1.0)
y.sum().backward()
print(x.grad)

Three rules to keep in mind:

  1. The number of returned grads in backward must equal the number of inputs to forward. Use None for non-differentiable inputs (Python scalars, ints).
  2. Anything you save_for_backward must be a tensor; non-tensor context goes on ctx.<attr>.
  3. If your forward calls only differentiable PyTorch ops, you usually don't need a custom Function -- just write the function. Custom Functions are for when you sidestep autograd (e.g., calling a Triton kernel, or wanting a fused/cheaper backward).

6.6 Versioning and inplace

Each tensor carries a version counter (shared between views of the same data). Inplace ops bump it. Saved tensors record the version they were saved at. On backward, the engine checks: if the version is now higher, you mutated a tensor that was needed for grad and the engine raises RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation. This is the famous error. The fix is almost always: don't += into something whose value the backward needs.
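
A minimal reproducer of that error:

import torch

a = torch.randn(3, requires_grad=True)
b = a * 2.0
c = b.pow(2)          # pow saves its input b for backward (d/db b**2 = 2*b)
b += 1                # in-place op bumps the version counter shared by b's data
c.sum().backward()    # RuntimeError: ... modified by an inplace operation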


7. requires_grad, no_grad, inference_mode

7.1 requires_grad

A per-tensor flag (lives in AutogradMeta). When true, ops involving the tensor produce outputs whose grad_fn is set, and the tensor itself appears in the graph (via AccumulateGrad if it is a leaf).

Default: false for plain tensors, true for nn.Parameter.

7.2 no_grad

A thread-local override. Inside with torch.no_grad(): (or @torch.no_grad()), GradMode::is_enabled() returns false. The autograd kernel for each op checks this flag and, if grad mode is off, skips recording -- it just redispatches to the layer below. The op still runs through the autograd dispatch key (because the inputs still have AutogradCUDA in their key set), it just doesn't build graph nodes.

Use case: evaluation. You still get inference correctness; you save the bookkeeping cost of building backward graphs.
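
A two-line check of the effect:

import torch

x = torch.randn(3, requires_grad=True)
with torch.no_grad():
    y = x * 2
print(y.requires_grad, y.grad_fn)   # False None -- nothing was recorded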

7.3 inference_mode

A stronger thread-local mode introduced for inference workloads. Inside a with torch.inference_mode(): block:

  1. The Autograd dispatch keys are excluded from the key set entirely. The dispatcher no longer enters the autograd kernel at all -- it goes straight to the backend kernel.
  2. Outputs are marked as inference tensors. Their version counter is disabled. They cannot later be used in autograd.

Why is it faster than no_grad?

| Cost | no_grad | inference_mode |
|---|---|---|
| Enter autograd dispatch kernel | yes | no |
| Allocate AutogradMeta for outputs | yes (cheap but nonzero) | no |
| Bump version counter on inplace | yes | no |
| Lookup kernel via dispatcher | once per op | once per op |

You skip a whole layer of dispatch and a pile of small allocations. Benchmarks show ~5-15% overhead reduction on small ops where dispatch cost dominates.

The price: outputs cannot be used in autograd later. If you accidentally pass an inference tensor into a training graph, you get RuntimeError: Inference tensors cannot be saved for backward.
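
Both effects are easy to observe:

import torch

x = torch.randn(3, requires_grad=True)
with torch.inference_mode():
    z = x * 2
print(z.is_inference(), z.requires_grad)   # True False
# Feeding z into autograd later raises:
# RuntimeError: Inference tensors cannot be saved for backward ...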

7.4 When to use which

  • Training (or anything you might backpropagate through): nothing -- you need the graph.
  • Eval loop you might re-enter training from: torch.no_grad().
  • Pure inference server: torch.inference_mode(). Wrap once at the top of the request handler.

8. AMP / Autocast

8.1 Intuition

Mixed precision exists because matmul on tensor cores is much faster in FP16/BF16 than FP32, but reductions (loss, softmax denominator, layer norm stats) want FP32 to avoid catastrophic cancellation. Autocast classifies ops:

  • lower-precision allowed (matmul, conv, linear): cast inputs down before running.
  • must stay in FP32 (loss functions, softmax, layer norm in some cases): leave inputs alone, or cast up.
  • promote (add of mixed dtypes): cast all inputs to the highest dtype present.

8.2 Mechanism: it's just another dispatch key

When you enter with torch.autocast("cuda", dtype=torch.float16):, the thread-local include set adds AutocastCUDA. Now any op called on a CUDA tensor inside the block has AutocastCUDA in its key set. The autocast kernel for that op runs before the autograd and backend kernels (Autocast is the highest-priority of the three keys). It looks up the op's autocast policy:

// pseudocode for an autocast-lower op like matmul
Tensor matmul_autocastCUDA(const Tensor& a, const Tensor& b) {
    auto target = at::autocast::current_dtype(c10::DeviceType::CUDA);  // e.g. half
    auto a_cast = cached_cast(target, a);
    auto b_cast = cached_cast(target, b);
    return at::redispatch::matmul(/*remove AutocastCUDA*/, a_cast, b_cast);
}

For a "must stay FP32" op, the autocast kernel casts up instead. cached_cast keeps a small thread-local cache so if you matmul the same weight twice in a forward you don't re-cast it.

The lists live in aten/src/ATen/autocast_mode.cpp (and FP16/BF16 specific lists). They are explicit: you can read which ops are "lower", "fp32", "promote".
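
A small dtype check of the policy in action (assumes a CUDA device):

import torch

a = torch.randn(16, 16, device="cuda")             # float32
b = torch.randn(16, 16, device="cuda")
with torch.autocast("cuda", dtype=torch.float16):
    mm = a @ b                                      # "lower" op: inputs cast down, output is fp16
print(mm.dtype, a.dtype)                            # torch.float16 torch.float32 (inputs untouched)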

8.3 GradScaler vs autocast-only

The danger of FP16 is gradient underflow: tiny gradients become 0. The fix is loss scaling: multiply the loss by a big number S before backward, divide grads by S before the optimizer step. If any grad becomes Inf/NaN, skip the step and shrink S; otherwise gradually grow S.

scaler = torch.amp.GradScaler("cuda")
for x, y in loader:
    opt.zero_grad(set_to_none=True)
    with torch.autocast("cuda", dtype=torch.float16):
        out = model(x)
        loss = loss_fn(out, y)
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()

BF16 has the same exponent range as FP32, so underflow is not an issue and you do not need GradScaler:

for x, y in loader:
    opt.zero_grad(set_to_none=True)
    with torch.autocast("cuda", dtype=torch.bfloat16):
        out = model(x)
        loss = loss_fn(out, y)
    loss.backward()
    opt.step()

Rule of thumb on Ampere or later: BF16 unless you have a specific reason. Older GPUs (Volta, Turing) lack good BF16, so FP16 + GradScaler.


9. torch.compile Pipeline

torch.compile(model) returns a wrapped callable. Under the hood it composes three pieces: TorchDynamo (capture), AOTAutograd (joint forward+backward capture), Inductor (codegen).

+------------------+      +---------------+      +--------------+      +------------+
|  Python module   | ---> |  TorchDynamo  | ---> | AOTAutograd  | ---> |  Inductor  |
+------------------+      +---------------+      +--------------+      +------------+
                           captures FX graph      joint fwd/bwd         Triton/C++
                           + guards                traced into           kernels
                                                   core ATen             + scheduling

9.1 TorchDynamo

Source: torch/_dynamo/.

Dynamo is a Python-level tracer that hooks into CPython's frame evaluation API (PEP 523). It registers an alternative frame evaluator. When Python is about to execute a function decorated by torch.compile, Dynamo intercepts the bytecode, symbolically executes it, and produces:

  1. An FX graph of tensor ops (torch.fx.GraphModule).
  2. A set of guards -- runtime predicates over inputs that, if true, mean it's safe to reuse this graph (e.g., "input 0 is torch.float32", "input 0's shape is (B, 1024) for some int B", "this Python int equals 7").
  3. A residual bytecode that calls the compiled graph and does anything Dynamo couldn't handle.

The key idea: instead of tracing a Tensor program, Dynamo traces Python bytecode, with FakeTensors standing in for real tensors. Every PyTorch op invocation gets recorded into the FX graph. Every Python-level operation that isn't a tensor op (a list append, an if x.shape[0] > 16:) is either:

  • specialised into a guard ("we recorded this branch when the shape was 32; if at runtime it's not, recompile"), or
  • becomes a graph break.

Graph breaks

A graph break happens when Dynamo encounters something it cannot symbolically execute. Examples:

  • print(x) -- has a side effect Dynamo doesn't model.
  • Calling into a third-party C extension that isn't a torch op.
  • A try/except whose handling Dynamo isn't sure how to capture.
  • Data-dependent control flow on a tensor without using torch.cond.
  • Mutating a global Python object.

When a break happens, Dynamo:

  1. Compiles the graph it has so far.
  2. Falls back to the Python interpreter for the offending statement.
  3. Resumes tracing from the next statement -- producing a second graph after the break.

Each graph compiles separately and you pay the call/launch overhead between them. One graph break = lost optimization opportunity. Many = torch.compile may even be slower than eager.

Inspecting

import torch

@torch.compile
def f(x):
    y = x.sin()
    print("hi")          # graph break
    return y.cos()

torch._dynamo.explain(f)(torch.randn(4))
# prints: number of graphs, number of breaks, reasons, locations

torch._dynamo.explain is your first debugging tool. If it says "1 graph, 0 breaks", you're golden. If it says "5 graphs, 4 breaks", read each reason and fix.

Other useful env knobs:

TORCH_LOGS="dynamo"             # what dynamo is tracing
TORCH_LOGS="graph_breaks"       # only the breaks
TORCH_LOGS="recompiles"         # why a graph re-compiled at runtime
TORCH_LOGS="output_code"        # the generated kernels (see Inductor)

9.2 AOTAutograd

Source: torch/_functorch/aot_autograd.py and friends.

Once Dynamo hands an FX graph to the compiler backend, AOTAutograd does two things:

  1. Joint trace of forward + backward. It runs the forward FX graph through make_fx with grad enabled, then calls .backward() to also trace the backward. The output is one FX graph containing both, plus a partition that assigns nodes to "forward" or "backward" subgraphs (so they can run separately at runtime).
  2. Decomposition to core ATen. Higher-level ops (e.g., torch.nn.functional.layer_norm) are decomposed into their constituent core ops (mean, var, mul, add, ...). This shrinks the op surface Inductor must handle from thousands to ~250 core ops.

The decomposition table lives in torch/_decomp/. You can see what an op decomposes to:

from torch._decomp import core_aten_decompositions
table = core_aten_decompositions()
for k, v in list(table.items())[:5]:
    print(k)

After AOTAutograd you have two FX graphs in core ATen: forward_graph and backward_graph. They are pure functions of inputs (and saved-for-backward tensors). Now Inductor compiles each.

9.3 Inductor

Source: torch/_inductor/.

Inductor is a lowering compiler. It takes an FX graph in core ATen, builds an internal IR (Inductor IR) that represents loops over tensors, fuses adjacent pointwise ops into bigger loops, schedules reductions, and emits target code.

Two backends:

  • CUDA / ROCm: emits Triton kernels. Triton handles the GPU-specific tiling and memory hierarchy; Inductor decides what to fuse and what shapes to specialise on.
  • CPU: emits C++ with OpenMP pragmas, optionally with vector ISA intrinsics (AVX2/AVX-512). Compiles via the system compiler at runtime.

Pipeline inside Inductor:

core ATen FX graph
    -> lowering        (each op -> Inductor IR ops; e.g. add -> Pointwise(...))
    -> scheduler       (group nodes that can fuse; assign to kernels)
    -> codegen         (emit Triton or C++)
    -> compile         (call Triton's autotuner / call cc -O3)
    -> wrapper         (Python wrapper that calls each kernel in order)

Fusion is the big win. A handwritten layer norm in eager is 5+ kernels (mean, var, sub, mul, add). Inductor often emits one fused Triton kernel. Same for activation+linear-bias, embedding+layernorm, etc.

Inspecting generated code

TORCH_LOGS="output_code" python my_script.py

You get the Triton (or C++) source dumped. It's worth reading at least once -- you'll see something like:

@triton.jit
def triton_poi_fused_add_mul_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    a = tl.load(in_ptr0 + xindex, xmask)
    b = tl.load(in_ptr1 + xindex, xmask)
    c = a * 2.0 + b
    tl.store(out_ptr0 + xindex, c, xmask)

That's a fused a*2 + b. Two separate eager ops collapsed into one launch.

You can also dump the FX graph after AOTAutograd:

TORCH_LOGS="aot_graphs"

9.4 Guards and recompilation

A compiled artifact is keyed by (op graph, guards). At each call, Dynamo evaluates the guards over the actual inputs. If all hold, run the cached compiled artifact. If any fail, recompile.

Common guards:

  • Type guards: input is a Tensor, dtype is float32, device is cuda:0.
  • Shape guards: rank is 3, sizes are (2, ?, 1024). The ? may be symbolic (a free variable) or static (a specific int) depending on the dynamic-shape mode.
  • requires_grad guards: input had requires_grad=True. (Recompile if you switch eval/train without telling it.)
  • Python guards: a constant Python int equals N, a list has length M, a particular nn.Module instance is the same object.

Dynamic vs static shapes:

  • torch.compile(model) (default in 2.4+): tries dynamic shapes when it sees the same shape vary; specialises when it doesn't.
  • torch.compile(model, dynamic=False): always specialise on shapes. Faster code, more recompiles.
  • torch.compile(model, dynamic=True): assume dynamic from the start. Fewer recompiles, sometimes slower per-iter.

If you see frequent recompiles (TORCH_LOGS="recompiles"), the usual culprits are:

  1. Variable batch size with dynamic=False.
  2. Variable sequence length without marking it dynamic.
  3. Calling with different requires_grad settings (eval vs train without .eval()/.train()).
  4. Using Python lists/tuples whose length varies.
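
For culprit 2, one option is to declare the varying dimension dynamic before the first call, so the very first compile already uses a symbolic size for it (a sketch; torch._dynamo.mark_dynamic is the relevant API):

import torch

def attn_scores(x):                       # x: (batch, seq, dim)
    return (x @ x.transpose(-1, -2)).softmax(dim=-1)

compiled = torch.compile(attn_scores)

x = torch.randn(4, 128, 64)
torch._dynamo.mark_dynamic(x, 1)          # treat dim 1 (sequence length) as symbolic
compiled(x)                               # compiles once with a symbolic seq length
compiled(torch.randn(4, 256, 64))         # different seq length, no recompile expected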

9.5 Modes

torch.compile(model, mode="default")
torch.compile(model, mode="reduce-overhead")
torch.compile(model, mode="max-autotune")
  • default: balanced. Compile time low-ish, runtime good.
  • reduce-overhead: enables CUDA graphs around the compiled region. CUDA graphs eliminate per-op CUDA launch overhead by recording a sequence and replaying it as one submission. Big win for small batches and lots of small ops. Constraints: shapes must be static across replays, and tensors must live at the same addresses (Inductor handles this with persistent input buffers; you may need to .clone() inputs or warm up).
  • max-autotune: Triton autotunes block sizes per kernel, multiple template variants for matmul, longest compile time, often best runtime.

9.6 End-to-end example

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(1024, 4096)
        self.l2 = nn.Linear(4096, 1024)
    def forward(self, x):
        return self.l2(torch.relu(self.l1(x)))

model = Net().cuda().to(torch.bfloat16)
model = torch.compile(model, mode="reduce-overhead")

x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)
# Warm-up: compiles + records cuda graph
for _ in range(3):
    y = model(x)

# Steady-state: every call is a single CUDA graph replay
for _ in range(1000):
    y = model(x)

What happened on first call:

  1. Dynamo hooks forward, traces it into FX (3 nodes: linear, relu, linear).
  2. AOTAutograd skips backward (no grad needed), decomposes linear -> matmul + add.
  3. Inductor lowers, fuses the relu and bias adds with adjacent pointwise work where it can, and generates two Triton matmul kernels plus a fused pointwise kernel.
  4. CUDA graph captures the launch sequence on call 2.
  5. Calls 3+ replay the graph.

10. Custom Op Registration (Modern Path)

You want to register a new op so that:

  • It has an autograd rule.
  • It survives torch.compile (Dynamo and Inductor know what to do).
  • It works under FakeTensor / meta tracing.

The modern API is torch.library, available in 2.4+. Avoid the plain torch.autograd.Function-only path when going through torch.compile; Dynamo will often graph-break on it.

10.1 Skeleton

import torch

# 1. Declare and implement the op for real backends
@torch.library.custom_op("mylib::myadd", mutates_args=())
def myadd(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y       # or call into a Triton kernel here

# 2. Tell the compiler/dynamo how shapes propagate (the "fake" / meta impl)
@myadd.register_fake
def _(x, y):
    # Must match real op's output shape/dtype/device, with no real compute
    return torch.empty_like(x)

# 3. Register an autograd rule
def myadd_setup_context(ctx, inputs, output):
    # Save what backward needs
    pass  # nothing for plain add

def myadd_backward(ctx, grad):
    return grad, grad   # dx, dy

myadd.register_autograd(myadd_backward, setup_context=myadd_setup_context)

Now mylib::myadd is a first-class op. You can call torch.ops.mylib.myadd(x, y) and it goes through the dispatcher like any built-in.

10.2 What each piece is for

  • custom_op: the user-facing implementation. Runs in eager mode.
  • register_fake: Dynamo / FakeTensor / torch.export use this to symbolically execute your op. It must allocate output tensors with correct shape/dtype/device but no real values. Without this, Dynamo will graph-break at your op.
  • register_autograd: the VJP. Mirrors torch.autograd.Function.backward. Setup context can save tensors via ctx.save_for_backward(...).
  • mutates_args: tuple of arg names that are mutated in place. The compiler needs to know this for correctness when reordering / re-using buffers.
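
A quick way to check that the three registrations agree with each other is torch.library.opcheck, which exercises the real implementation, the fake implementation, and the autograd registration and raises if they disagree (a sketch, using the op defined above):

import torch

x = torch.randn(8, requires_grad=True)
y = torch.randn(8, requires_grad=True)
torch.library.opcheck(torch.ops.mylib.myadd.default, (x, y))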

10.3 A Triton kernel as a custom op (worked example)

import torch
import triton
import triton.language as tl

@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < N
    a = tl.load(x_ptr + offs, mask=mask)
    b = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, a + b, mask=mask)

def _add_launch(x, y):
    out = torch.empty_like(x)
    N = x.numel()
    BLOCK = 1024
    grid = ((N + BLOCK - 1) // BLOCK,)
    _add_kernel[grid](x, y, out, N, BLOCK=BLOCK)
    return out

@torch.library.custom_op("mylib::triton_add", mutates_args=())
def triton_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    assert x.is_cuda and y.is_cuda and x.shape == y.shape
    return _add_launch(x.contiguous(), y.contiguous())

@triton_add.register_fake
def _(x, y):
    return torch.empty_like(x)

def _bwd(ctx, g):
    return g, g

triton_add.register_autograd(_bwd)

# Use it
x = torch.randn(4096, device="cuda", requires_grad=True)
y = torch.randn(4096, device="cuda", requires_grad=True)
z = torch.ops.mylib.triton_add(x, y)
z.sum().backward()
print(x.grad.shape, y.grad.shape)

# It also works under torch.compile thanks to register_fake
def f(a, b):
    return torch.ops.mylib.triton_add(a, b).relu()
g = torch.compile(f)
g(x, y)

The key win: Dynamo treats triton_add as an opaque op (it doesn't try to look inside the Triton kernel). It uses register_fake to know the shape, and register_autograd to know how grads flow. Inductor will not fuse with surrounding ops -- but it also won't graph-break.

If you do want Inductor to fuse with surrounding ops, write the kernel in core ATen (let Inductor codegen its own Triton). Custom Triton ops are for cases where you have a hand-tuned kernel that beats codegen.


11. C++ Extension Path

When you need raw C++/CUDA, the supported flow is torch.utils.cpp_extension. There are two flavours:

  • JIT (load, load_inline): compiles on first import. Great for development.
  • AOT (setup.py with CUDAExtension / CppExtension): produces a wheel.

11.1 Minimal setup.py skeleton

my_ext/
  setup.py
  src/
    binding.cpp
    kernel.cu

src/kernel.cu:

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>   // at::cuda::getCurrentCUDAStream

__global__ void scale_kernel(const float* in, float* out, float s, int N) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < N) out[i] = in[i] * s;
}

torch::Tensor scale_cuda(torch::Tensor x, double s) {
    TORCH_CHECK(x.is_cuda(), "x must be cuda");
    TORCH_CHECK(x.scalar_type() == torch::kFloat32, "x must be float32");
    TORCH_CHECK(x.is_contiguous(), "x must be contiguous");
    auto y = torch::empty_like(x);
    int N = x.numel();
    int block = 256;
    int grid = (N + block - 1) / block;
    // Launch on the stream PyTorch is currently using, not the legacy default stream.
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    scale_kernel<<<grid, block, 0, stream>>>(x.data_ptr<float>(), y.data_ptr<float>(),
                                             static_cast<float>(s), N);
    return y;
}

src/binding.cpp:

#include <torch/extension.h>

torch::Tensor scale_cuda(torch::Tensor x, double s);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("scale", &scale_cuda, "scale(x, s) = x * s on CUDA");
}

setup.py:

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="my_ext",
    ext_modules=[
        CUDAExtension(
            name="my_ext",
            sources=["src/binding.cpp", "src/kernel.cu"],
            extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3"]},
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)

Build and use:

pip install -e .
import torch, my_ext
x = torch.randn(1024, device="cuda")
y = my_ext.scale(x, 2.5)

11.2 ABI considerations

PyTorch is built with a specific C++ ABI setting (the GCC dual-ABI flag, _GLIBCXX_USE_CXX11_ABI, on Linux). Your extension must be compiled against the same PyTorch headers and the same ABI flag. Practical rules:

  • Always build the extension on the machine where it will run, against the installed PyTorch wheel, or publish per-PyTorch-version wheels.
  • Match the CUDA toolkit major version to PyTorch's CUDA major (e.g. CUDA 12.1 PyTorch -> CUDA 12.x toolkit).
  • Avoid passing C++ exceptions across the pybind boundary. Use TORCH_CHECK for user errors -- it raises Python RuntimeError.
  • Don't statically link C++ standard library; let the system one be used.
  • For wheels, use manylinux2014 or newer base images; build separate wheels per (PyTorch version, CUDA version, Python version) tuple.

11.3 Registering the C++ op into the dispatcher

If you want it to be a real dispatcher op (so autograd, autocast, etc. integrate), use TORCH_LIBRARY (Section 4.5) instead of (or in addition to) the pybind binding. That gives you torch.ops.myext.scale(x, s) and full dispatch behavior.
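
From Python, the dispatcher-registered version then shows up under torch.ops. A hedged sketch, assuming binding.cpp also contains TORCH_LIBRARY(myext, m) { m.def("scale(Tensor x, float s) -> Tensor"); } and TORCH_LIBRARY_IMPL(myext, CUDA, m) { m.impl("scale", &scale_cuda); }:

import torch
import my_ext                       # importing the pybind module loads the .so, which also
                                    # runs the static TORCH_LIBRARY registrations
x = torch.randn(1024, device="cuda")
y = torch.ops.myext.scale(x, 2.5)   # a real dispatcher op, visible to autograd/profiler/compile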


12. The CUDA Caching Allocator

12.1 Why caching

cudaMalloc and cudaFree are slow (often 100us+ each), and cudaFree synchronises the device. A naive implementation would call them per tensor. PyTorch instead routes all CUDA tensor allocations through a caching allocator (c10/cuda/CUDACachingAllocator.cpp).

12.2 Mechanism

Conceptually:

CachingAllocator state per device:
  large_blocks:   sorted-by-size list of free blocks >= 1MB
  small_blocks:   sorted-by-size list of free blocks <  1MB
  active_blocks:  {ptr -> Block}     # currently held by a Tensor

allocate(size, stream):
    round size up (small to nearest 512B; large to nearest 2MB)
    search the appropriate free list for a block of >= size on a compatible stream
    if found: split off the suffix as a free block, return prefix
    else:
        cudaMalloc a fresh segment (2MB for small requests, 20MB for medium ones, or the rounded request size for large ones)
        carve a block out of it, return it

free(ptr):
    mark Block free
    record the stream we last used it on; only reusable on that stream
        unless cuda events confirm cross-stream safety
    coalesce with neighbors in the same segment if both free

Two keys to internalize:

  1. free() does not call cudaFree. It returns the block to the pool. From the driver's perspective the memory is still allocated.
  2. Stream-aware reuse. Memory freed on stream A cannot be reused on stream B until events confirm the prior work has finished. This is why multi-stream code can OOM where single-stream code wouldn't: the allocator is conservatively keeping blocks pinned to the original stream.

12.3 empty_cache

torch.cuda.empty_cache()

Walks the free lists and actually cudaFrees segments that contain only free blocks. Returns memory to the driver -- visible to other processes (e.g., another container sharing the GPU). Does not shrink anything currently in use, and does not improve performance for your own process (the caching allocator was already going to reuse those blocks). Use it when you need to release memory across processes; do not sprinkle it through your training loop.
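
You can watch the cache from Python (a sketch assuming a CUDA device and an otherwise idle process):

import torch

x = torch.empty(256 * 1024 * 1024, device="cuda")              # 1 GiB of float32
print(torch.cuda.memory_allocated() >> 20, "MiB allocated")    # ~1024
del x
print(torch.cuda.memory_allocated() >> 20, "MiB allocated")    # ~0
print(torch.cuda.memory_reserved() >> 20, "MiB reserved")      # ~1024: still cached
torch.cuda.empty_cache()
print(torch.cuda.memory_reserved() >> 20, "MiB reserved")      # back near 0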

12.4 Reading memory_summary

print(torch.cuda.memory_summary(device=0, abbreviated=False))

You get a table like:

|---|------------|-----------|-----------|-----------|
|   |   Cur Usage|  Peak Usage| Tot Alloc | Tot Freed |
|---|------------|-----------|-----------|-----------|
|Allocated memory   |  ...      |  ...      |  ...      |  ...
|Active memory      |  ...      |  ...      |  ...      |  ...
|GPU reserved memory|  ...      |  ...      |  ...      |  ...
|Non-releasable mem |  ...      |  ...      |  ...      |  ...

Definitions:

  • Allocated: bytes currently held by Tensors.
  • Active: allocated plus blocks whose pending stream work prevents them from being reused yet.
  • Reserved: total cudaMalloc'd. Reserved - Allocated = sitting in the cache.
  • Non-releasable: free-but-cannot-be-given-back-to-driver because the segment still has at least one in-use block.

12.5 Fragmentation

The classic failure mode: you have 3GB free in the cache, but it's split into 300 blocks of ~10MB each, none big enough to satisfy a 100MB request. The allocator does coalesce neighbors, but only within the same segment. Mitigations:

  1. PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True (PyTorch 2.0+). The allocator uses CUDA virtual memory APIs (cuMemMap/cuMemAddressReserve) to grow a single backing segment instead of allocating many. Drastically reduces fragmentation for variable-shape workloads.
  2. PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:N: don't split blocks larger than N MB, reducing tiny suffix blocks scattered around.
  3. Avoid pathologic patterns: alternating very large and very small allocations on the same stream.

If you OOM at "tried to allocate 1GiB but only 500MiB free although 4GiB reserved", that is fragmentation. Check memory_summary; consider expandable segments.


13. Profiling Internals

torch.profiler.profile(...) (in torch/profiler/) records per-op entry/exit by registering callbacks at the dispatcher. Each time the dispatcher enters or exits an op, it calls every registered observer. The profiler is one such observer; so are autograd hooks and record_function.

13.1 Minimal use

import torch
from torch.profiler import profile, record_function, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    record_shapes=True,
    with_stack=False,
) as prof:
    with record_function("forward"):
        y = model(x)
    with record_function("loss"):
        loss = (y - target).pow(2).mean()
    with record_function("backward"):
        loss.backward()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
prof.export_chrome_trace("trace.json")    # open with chrome://tracing or perfetto

Open the trace and you get a flame-graph-like view: CPU dispatcher events on top, CUDA kernel events on the bottom with launch arrows. The gaps between kernels are launch overhead.

13.2 What each column means

  • Self CPU: time the op itself spent in CPU (dispatcher + Python -> C++ + kernel launch).
  • CPU total: includes children. A linear op's total includes its matmul and add children.
  • Self CUDA / CUDA total: same on GPU, measured via CUDA events.
  • # of Calls: how many times this op key (with these shapes) was hit.

13.3 Reading patterns

  • "Self CPU >> Self CUDA, kernels short": you are launch-overhead bound. Try torch.compile(mode="reduce-overhead") for CUDA graphs, or batch up small ops.
  • "Self CUDA dominates, one kernel is 80% of it": profile that kernel; consider a different algorithm (FlashAttention, a fused MoE), or see if Inductor will generate a better one with max-autotune.
  • "CUDA gaps with no work": host is too slow producing input. DataLoader is the usual suspect; bump num_workers, prefetch, pin memory.
  • "memcpy dominating": you're moving data CPU<->GPU per step. Pin host memory, pre-load to GPU, or use non_blocking=True with pinned source.

13.4 Autograd hooks for finer questions

hooks = []
for name, p in model.named_parameters():
    h = p.register_hook(lambda g, n=name: print(n, g.norm()))
    hooks.append(h)
loss.backward()
for h in hooks: h.remove()

Hook fires from the engine's worker thread when the grad is computed. Useful for "which parameter's grad is NaN".


14. Practical Exercises (with answers)

Exercise 1: stride sleuthing

You have:

import torch
a = torch.arange(60).reshape(3, 4, 5)
b = a.permute(2, 0, 1)

Without running this, what are b.shape, b.stride(), b.storage_offset()? Is b.is_contiguous()? Why?

Answer. a.shape=(3,4,5), a.stride()=(20,5,1). Permute remaps dims by index: new dim 0 = old dim 2, new dim 1 = old dim 0, new dim 2 = old dim 1. So b.shape=(5,3,4), b.stride()=(1,20,5), b.storage_offset()=0. Not contiguous because canonical strides for (5,3,4) would be (12,4,1) and ours are (1,20,5).

Exercise 2: dispatch trace

You write:

with torch.autocast("cuda", dtype=torch.bfloat16):
    with torch.no_grad():
        y = a @ b

where a, b are CUDA fp32 tensors with requires_grad=True. List the keys in the dispatch set for @, the order of kernels invoked, and the dtype of y.

Answer.

  • Per-tensor key set: {AutogradCUDA, CUDA}.
  • Local include from autocast: adds AutocastCUDA.
  • Local exclude from no_grad: none -- no_grad does not remove Autograd keys (that's inference_mode's job); it only flips the thread-local GradMode flag, which the autograd kernel checks.

So the order is: AutocastCUDA kernel (casts a, b to bf16, redispatches) -> AutogradCUDA kernel (GradMode is off, so it skips recording and just redispatches) -> CUDA kernel (runs the bf16 matmul). y.dtype is bfloat16. y.requires_grad is False.

Exercise 3: detect a graph break

@torch.compile
def f(x, n):
    if n.item() > 0:
        return x.sin()
    else:
        return x.cos()

Why is this slow / breaky, and how do you fix?

Answer. n.item() materialises a tensor value to a Python int. Dynamo cannot symbolically execute that branch -- it triggers a graph break (or specialisation / recompile per value of n). The fix: use torch.where (data-dependent on tensor) or torch.cond (if you really need control flow):

@torch.compile
def f(x, n):
    return torch.where(n > 0, x.sin(), x.cos())

Exercise 4: a custom op that survives compile

Write a custom op clip_norm(x, max_norm) that scales x so its L2 norm is at most max_norm. Make sure it works under torch.compile.

Answer.

import torch

@torch.library.custom_op("mylib::clip_norm", mutates_args=())
def clip_norm(x: torch.Tensor, max_norm: float) -> torch.Tensor:
    n = x.norm()
    scale = (max_norm / (n + 1e-12)).clamp(max=1.0)
    return x * scale

@clip_norm.register_fake
def _(x, max_norm):
    return torch.empty_like(x)

# autograd: easiest is to leave it to the implementation
# since it uses only autograd-aware ops; but custom_op disables
# autograd-through-implementation by default. So:
def _bwd(ctx, g):
    (scale,) = ctx.saved_tensors
    # d(x*scale)/dx = scale (treating scale as a constant -- the usual clipping convention)
    return g * scale, None

def _setup(ctx, inputs, output):
    x, max_norm = inputs
    n = x.norm()
    scale = (max_norm / (n + 1e-12)).clamp(max=1.0)
    ctx.save_for_backward(scale)

clip_norm.register_autograd(_bwd, setup_context=_setup)

Note we approximate the gradient by treating scale as a constant -- that's the standard / desired behaviour for gradient clipping.

Exercise 5: why does this OOM?

Training fine yesterday, today OOMs at the same batch size. Memory summary shows Reserved=22GiB, Allocated=8GiB. Tried empty_cache, no change. What are two most likely causes and one mitigation?

Answer. Two likely causes:

  1. Fragmentation: the 14 GiB of cached free memory is split into too many small blocks for a big request to find a contiguous run. Mitigation: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True (PyTorch 2.0+).
  2. Stream-pinned blocks: a multi-stream codepath freed memory on stream A; the allocator can't yet hand it to stream B. Mitigation: synchronize, or unify streams.

empty_cache does nothing here because all 14GiB of free cache is non-releasable (segments still have allocated blocks).

Exercise 6: inference_mode vs no_grad latency

You have a hot inference path that runs ~50 small ops per request (lots of LayerNorm, GELU, small matmul). You measure 8% latency drop switching no_grad -> inference_mode. Why? Where would the gain be much smaller?

Answer. inference_mode excludes Autograd dispatch keys entirely, so each op skips the autograd kernel layer (a function call, a check, and an output AutogradMeta allocation). Saving ~hundreds of nanoseconds per op times ~50 ops times batch is a measurable percentage when individual ops are short.

The gain shrinks toward zero as ops get bigger: a single 4096x4096 fp16 matmul takes milliseconds, dwarfing the dispatch cost. The win is in launch-overhead-bound regimes. For one big op per request, prefer profiling to see if it's worth bothering.


15. A Coherent Mental Model To Keep

If you remember nothing else, remember these seven sentences:

  1. A Tensor is a (storage, sizes, strides, storage_offset, dtype, device) tuple.
  2. Views share storage; contiguous() materialises.
  3. The dispatcher picks a kernel from a key set assembled from inputs and thread-local mode; layers redispatch by removing their own key.
  4. Autograd is a layer in the dispatcher that, when grad mode is on, builds a tape; loss.backward() traverses it.
  5. inference_mode is faster than no_grad because it removes the Autograd layer entirely.
  6. Autocast is just a dispatcher layer that casts inputs before the backend runs.
  7. torch.compile is Dynamo (capture Python -> FX) + AOTAutograd (joint forward/backward in core ATen) + Inductor (Triton/C++ codegen with fusion); guards govern when the compiled artifact is reused.

Everything else -- caching allocator, profiler, custom ops -- hangs off these. Once you can simulate the dispatcher in your head and reason about the compile pipeline, PyTorch internals stop feeling like a foreign country and become a place you live.


Appendix A: Source Tree Map

A high-confidence cheat sheet for navigating the repo:

| Path | Contents |
|---|---|
| c10/core/ | TensorImpl, Storage, DispatchKey, Device, Layout. The smallest, most stable foundation. |
| c10/cuda/ | CUDACachingAllocator, CUDA stream/event wrappers. |
| aten/src/ATen/core/ | The dispatcher (Dispatcher.cpp), op registration (library.cpp). |
| aten/src/ATen/native/ | Op kernels (CPU / generic). BinaryOps.cpp, ReduceOps.cpp, etc. |
| aten/src/ATen/native/cuda/ | CUDA kernels. |
| aten/src/ATen/native/native_functions.yaml | The op-schema source of truth. |
| aten/src/ATen/autocast_mode.cpp | Autocast policies (which ops are FP16/BF16/FP32/promote). |
| tools/autograd/derivatives.yaml | Op-by-op VJPs for codegen. |
| torch/csrc/autograd/ | Engine, Function (Node), saved variables. |
| torch/_dynamo/ | TorchDynamo (frame eval, symbolic execution, guards). |
| torch/_functorch/aot_autograd.py | AOTAutograd. |
| torch/_decomp/ | Decompositions to core ATen. |
| torch/_inductor/ | Inductor lowering, scheduler, codegen (codegen/triton.py, codegen/cpp.py). |
| torch/library.py | Modern custom-op API (custom_op, register_fake, register_autograd). |
| torch/utils/cpp_extension.py | JIT and AOT C++/CUDA extension build helpers. |
| torch/profiler/ | Profiler frontend; backend in torch/csrc/profiler/. |

When in doubt, git grep for the op name in aten/src/ATen/native/ -- the kernel is almost always there.

Appendix B: Useful Environment Variables

TORCH_LOGS=dynamo,graph_breaks,recompiles,aot_graphs,output_code
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb:128
TORCH_SHOW_DISPATCH_TRACE=1            # prints kernel choice per op (verbose)
TORCH_USE_CUDA_DSA=1                   # device-side assertions for shape/index errors
CUDA_LAUNCH_BLOCKING=1                  # serialises kernel launches; better stack traces
TORCHINDUCTOR_CACHE_DIR=/tmp/inductor   # control compile cache location
TORCHINDUCTOR_MAX_AUTOTUNE=1            # equivalent to mode="max-autotune"

End of chapter.
