PyTorch Internals: A Deep Dive Reference¶
A self-contained chapter for the AI Systems curriculum, Month 3. Target reader: a backend/SRE engineer who already writes PyTorch model code, wants to understand what happens beneath `tensor.add_()` and `model.compile()`, and is willing to read C++-flavoured pseudocode. Goal: after this chapter you should be able to read the PyTorch source tree, debug a dispatch-related bug, write a custom op that survives `torch.compile`, and reason about performance from first principles.
0. How To Read This Chapter¶
PyTorch is a stack of layers. Understanding it means understanding which layer owns which decision. The chapter walks down the stack on the way in (Python -> ATen -> dispatcher -> backend), then back up (autograd, AMP, compile) because higher layers are easier to follow once you know the substrate.
For each topic you will see four passes:
- Intuition -- what mental model is correct.
- Mechanism -- the actual data structures and control flow.
- Minimal code -- the smallest example that exercises the mechanism.
- Dispatch trace -- "if I were the dispatcher, what would I do step by step." This is the most underrated reasoning tool in PyTorch -- once you can simulate the dispatcher in your head, almost every weird bug becomes obvious.
All code examples target PyTorch 2.4+ on Linux/CUDA. Where source paths are referenced, they are relative to the pytorch/pytorch repo root.
1. The Layered Architecture¶
1.1 Intuition¶
When you write c = a + b in Python you are at the top of a five-layer cake. The layers exist because each one solves a different problem:
| Layer | Language | Job |
|---|---|---|
| Python frontend (`torch.*`) | Python | Ergonomics, autograd surfaces, `nn.Module` |
| Pybind shim (`torch._C`) | C++/pybind11 | Convert PyObject -> C++ args, hold the GIL boundary |
| ATen (`at::Tensor`, ops) | C++ | The op API. Defines what `add` means, type-erased over backends |
| Dispatcher (c10) | C++ | Pick which kernel to run (Autograd? CUDA? Autocast?) |
| Backend kernels | C++/CUDA/Triton/MPS/etc. | Actually compute the bytes |
The dispatcher is the keystone. ATen does not call kernels directly. ATen says "here is add(Tensor, Tensor) -> Tensor, dispatcher, please find the right implementation given these tensor properties." This indirection is what lets autograd, autocast, vmap, FakeTensor, meta tensors, and quantization all hook in at the same place.
1.2 Mechanism¶
Top-to-bottom for c = a + b where a, b are CUDA float32, requires_grad=True:
Python: c = a + b
-> Tensor.__add__(self, other)
-> torch._C._TensorBase.add(self, other) # via pybind
C++ (ATen): at::add(self, other)
-> at::_ops::add_Tensor::call(self, other) # codegen'd
Dispatcher: Dispatcher::singleton().call(op_handle, stack)
-> picks kernel by computed DispatchKeySet
1. Autograd kernel (records grad_fn, then redispatches)
2. AMP/Autocast kernel (maybe casts, then redispatches)
3. CUDA kernel (the real one)
CUDA: vectorized elementwise add launches; returns Tensor
The order of layers is not arbitrary. Autograd wraps autocast wraps the backend, because:
- Autograd needs to see the original op so it can record the right grad_fn.
- Autocast needs to decide casts before the backend sees the dtypes.
- The backend just computes.
You will see the same pattern reappear: any new cross-cutting concern (functionalization, batching for vmap, fake-tensor tracing) becomes a new DispatchKey somewhere in the stack.
1.3 ASCII trace of a + b¶
+---------------------------+
Python --> | Tensor.__add__ |
+-------------+-------------+
|
v torch._C (pybind)
+---------------------------+
| at::add(Tensor, Tensor) | ATen
+-------------+-------------+
|
v
+---------------------------+
| OperatorHandle::call | Dispatcher
| computes DispatchKeySet |
+-------------+-------------+
|
redispatches through keys:
|
+--------+--------+
| Autograd kernel | records AddBackward; redispatch w/o Autograd key
+--------+--------+
|
+--------+--------+
| Autocast kernel | (if active) cast inputs; redispatch w/o Autocast key
+--------+--------+
|
+--------+--------+
| CUDA kernel | vectorized_elementwise_kernel<<<...>>>(...)
+-----------------+
Memorise that picture. Almost every "why is this slow / wrong / weird" question maps to a layer in it.
2. Tensor Representation¶
2.1 The four-level wrapping¶
A torch.Tensor you hold in Python is a thin handle. Underneath:
torch.Tensor # Python object, subclass of torch._C._TensorBase
-> at::Tensor # C++ value type, ~one pointer wide
-> c10::TensorImpl # the heap object: dtype, sizes, strides, key set
-> c10::Storage # owns the bytes (or shares them with views)
-> c10::DataPtr -> raw void* + Device + Allocator
at::Tensor is essentially intrusive_ptr<TensorImpl>. Copying a tensor in C++ bumps a refcount; it does not copy bytes. That is why Tensor a = b; is cheap -- they share the same TensorImpl (and hence the same Storage).
2.2 Fields you must know¶
TensorImpl (in c10/core/TensorImpl.h) carries roughly:
| Field | Type | Meaning |
|---|---|---|
| `storage_` | `c10::Storage` | The byte buffer. Shared between views. |
| `sizes_` | `SmallVector<int64_t>` | Shape. |
| `strides_` | `SmallVector<int64_t>` | How many elements (not bytes) to step per dim. |
| `storage_offset_` | `int64_t` | Where this tensor starts inside the storage, in elements. |
| `dtype_` | `caffe2::TypeMeta` | float32, bfloat16, int64, ... |
| `device_` | `c10::Device` | (DeviceType::CUDA, index=0), (CPU, -1), etc. |
| `layout_` | `c10::Layout` | Strided, Sparse, SparseCsr, Mkldnn. |
| `key_set_` | `DispatchKeySet` | Bitset of dispatch keys (Autograd, CUDA, ...). |
| `requires_grad_` | bool | Lives via AutogradMeta, not directly here. |
| `autograd_meta_` | `unique_ptr<AutogradMetaInterface>` | grad, grad_fn, version counter. |
Storage (in c10/core/Storage.h) carries:
| Field | Type | Meaning |
|---|---|---|
| `data_ptr_` | `c10::DataPtr` | Owning pointer + device + allocator deleter. |
| `size_bytes_` | `size_t` | Capacity in bytes. |
| `allocator_` | `c10::Allocator*` | Where to ask for memory. On CUDA this is the caching allocator (Section 12). |
| `resizable_` | bool | Can `resize_` grow the buffer? |
Note: Storage does not know dtype or shape. It is just bytes. Two views of the same Storage can in principle even disagree on dtype (e.g., a.view(torch.int32) reinterprets bits).
2.3 The single most important invariant¶
For a strided tensor:
address_of_element(i0, i1, ..., in) =
storage.data_ptr
+ dtype_size * (storage_offset + sum_k(ik * stride_k))
That's it. Sizes, strides, storage_offset, dtype, base pointer. Five things determine where every element lives. Views are just other tensors that share storage but have different (sizes, strides, storage_offset). Contiguous is a property of the strides relative to the sizes, not of memory itself.
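To make the invariant concrete, here is a minimal sketch (assuming a small contiguous base tensor and a transposed view like the ones in Section 3) that recomputes an element's storage offset by hand and checks it against normal indexing:

```python
import torch

base = torch.arange(24, dtype=torch.float32).view(2, 3, 4)
t = base.transpose(0, 2)                       # non-contiguous view, same storage

def manual_lookup(view, flat, *idx):
    # element position = storage_offset + sum(index * stride), in elements
    off = view.storage_offset() + sum(i * s for i, s in zip(idx, view.stride()))
    return flat[off]

flat = base.reshape(-1)                        # linear walk over the shared storage
print(manual_lookup(t, flat, 3, 1, 1).item(), t[3, 1, 1].item())  # identical
```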
2.4 Why decouple view from storage¶
If shape lived inside storage you would copy bytes for transpose, narrow, unsqueeze. Decoupling lets these be O(1) metadata-only ops. The cost: kernels must respect arbitrary strides, or you must contiguous() first. PyTorch favours the first for "shape ops" and the second for "compute ops" -- compute kernels typically demand contiguous (or one of a few canonical memory formats) input.
3. Strides and Views¶
3.1 Stride arithmetic¶
For a contiguous (2, 3, 4) float32 tensor:
sizes = [2, 3, 4]
strides = [12, 4, 1] # elements (not bytes)
element (i, j, k) -> offset = 12*i + 4*j + 1*k
Strides for contiguous (row-major) tensors are: stride[i] = prod(sizes[i+1:]).
3.2 Three view ops with no copy¶
import torch
a = torch.arange(24, dtype=torch.float32).view(2, 3, 4) # contiguous
print(a.stride()) # (12, 4, 1)
b = a.transpose(0, 2) # swap dim 0 and dim 2
print(b.shape, b.stride()) # torch.Size([4, 3, 2]), (1, 4, 12)
print(b.is_contiguous()) # False
print(b.data_ptr() == a.data_ptr()) # True -- same storage
c = a.narrow(1, 1, 2) # along dim 1, start=1, length=2
print(c.shape, c.stride(), c.storage_offset())
# torch.Size([2, 2, 4]), (12, 4, 1), 4
d = a.unsqueeze(0) # add a leading dim of size 1, stride 0 (or any)
print(d.shape, d.stride()) # torch.Size([1, 2, 3, 4]), (24, 12, 4, 1) or similar
What changed in each:
| Op | sizes | strides | storage_offset | storage |
|---|---|---|---|---|
| `transpose(0,2)` | (4,3,2) | (1,4,12) | 0 | shared |
| `narrow(1,1,2)` | (2,2,4) | (12,4,1) | 1*4 = 4 | shared |
| `unsqueeze(0)` | (1,2,3,4) | (24,12,4,1) | 0 | shared |
The transposed tensor's storage looks the same byte-for-byte; we just re-described how to walk it. That is the entire trick.
3.3 is_contiguous and contiguous()¶
is_contiguous() returns true iff the strides equal the canonical strides for the sizes.
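A simplified sketch of that check (assumption: the real implementation in TensorImpl also special-cases dims of size 0 and 1, which this version ignores):

```python
import torch

def canonical_strides(sizes):
    # row-major: stride[i] = prod(sizes[i+1:])
    strides, acc = [0] * len(sizes), 1
    for i in range(len(sizes) - 1, -1, -1):
        strides[i] = acc
        acc *= sizes[i]
    return tuple(strides)

def is_contiguous_simplified(t):
    return tuple(t.stride()) == canonical_strides(tuple(t.shape))

a = torch.arange(24.0).view(2, 3, 4)
print(is_contiguous_simplified(a), a.is_contiguous())                                   # True True
print(is_contiguous_simplified(a.transpose(0, 2)), a.transpose(0, 2).is_contiguous())   # False False
```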
Many fused/elementwise CUDA kernels assume contiguous input so they can do unit-stride vectorised loads. If you pass them a transposed tensor they would either:
- error out with a stride check, or
- silently fall back to a strided kernel that is 5-10x slower.
So shape-bending operations are often followed by .contiguous() before a heavy op:
y = x.transpose(1, 2).contiguous() # make a real copy with canonical strides
z = some_fused_kernel(y)
contiguous() allocates new storage and copies. Skipping it when needed wastes throughput; calling it when not needed wastes memory.
3.4 Memory format: channels_last¶
For 4D image tensors, two memory layouts are common:
- `torch.contiguous_format` (NCHW): strides `(C*H*W, H*W, W, 1)`.
- `torch.channels_last` (NHWC): strides `(C*H*W, 1, W*C, C)`.
Both have shape (N, C, H, W) -- only strides differ. CUDNN and many fused conv kernels are faster on channels_last for FP16/BF16 because it matches tensor-core memory access patterns. You opt in with:
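(A minimal sketch; the conv layer below is just a placeholder module.)

```python
import torch
import torch.nn as nn

model = nn.Conv2d(3, 8, kernel_size=3, padding=1).cuda()
x = torch.randn(8, 3, 224, 224, device="cuda")

x = x.to(memory_format=torch.channels_last)          # same shape, NHWC strides (copies if needed)
model = model.to(memory_format=torch.channels_last)  # converts the conv weights as well
y = model(x)
print(x.is_contiguous(), x.is_contiguous(memory_format=torch.channels_last))  # False True
```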
Now x.is_contiguous() is false but x.is_contiguous(memory_format=torch.channels_last) is true. Kernels that understand the format will not insert a copy; kernels that do not will materialize a contiguous tensor on the way in.
3.5 The dispatcher's view of strides¶
The dispatcher does not know or care about strides. Strides live in TensorImpl. The kernel underneath cares. This is why a stride bug is almost always a kernel bug, never a dispatcher bug.
4. The Dispatcher¶
4.1 Intuition¶
The dispatcher is a polymorphism mechanism more powerful than virtual functions. A virtual call dispatches on one type. The PyTorch dispatcher dispatches on a set of features that come from all tensor inputs combined: the union of their dispatch keys plus thread-local state (autocast active? grad enabled?).
Think of it as: every op is a function pointer table indexed by DispatchKey. Every call computes a key set, picks the highest priority key in the set, looks up the kernel, runs it. The kernel may "redispatch" -- remove its own key from the set and ask the dispatcher to do it again -- to chain effects.
4.2 The DispatchKey enum¶
Defined in c10/core/DispatchKey.h. The keys form a priority order. Roughly (highest first):
... functorch / vmap / FuncTorchBatched ...
PythonTLSSnapshot
Python # __torch_dispatch__
Functionalize # for export / aot
... per-backend autograd ...
AutogradOther / AutogradCPU / AutogradCUDA / AutogradXPU / AutogradMeta
... AMP / autocast keys ...
AutocastCPU
AutocastCUDA
AutocastXPU
... tracing / profiling ...
... backend keys (lowest, where real work happens) ...
CPU
CUDA
MPS
XPU
Meta # shape-only "fake" tensors
SparseCPU / SparseCUDA
QuantizedCPU / QuantizedCUDA
Each tensor has a DispatchKeySet -- a 64-bit bitset over these keys. A typical CUDA tensor with requires_grad=True has {AutogradCUDA, CUDA}. If you enter a with torch.autocast("cuda"): block, the thread-local "included" set adds AutocastCUDA, so calls inside that block effectively dispatch on {AutogradCUDA, AutocastCUDA, CUDA}.
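You can inspect these key sets from Python through an internal helper (a sketch; `torch._C._dispatch_keys` is a private API and its exact output format varies across releases):

```python
import torch

a = torch.randn(4, device="cuda", requires_grad=True)
b = torch.randn(4)                      # CPU, no grad
print(torch._C._dispatch_keys(a))       # e.g. DispatchKeySet(CUDA, ..., AutogradCUDA)
print(torch._C._dispatch_keys(b))       # e.g. DispatchKeySet(CPU, ...)
```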
4.3 Computing the key for a call¶
Per call:
ks = empty
for each tensor argument t:
ks |= t.key_set()
ks |= local_include_set() # e.g. AutocastCUDA inside autocast()
ks &= ~local_exclude_set() # e.g. inference_mode excludes Autograd
top = ks.highest_priority_key()
kernel = op_table[op_id][top]
kernel(args)
This is conceptually ~50 lines of C++ in aten/src/ATen/core/dispatch/Dispatcher.{h,cpp}. Real implementation has fast paths and caching but the model is exact.
4.4 Redispatch¶
A kernel can run, then ask the dispatcher to re-run the same op but with its own key removed. That is how chaining works. In pseudocode for the autograd kernel for add:
Tensor add_autograd(const Tensor& a, const Tensor& b) {
// 1. Make output by redispatching to lower keys (skip Autograd).
auto out = at::redispatch::add(
    ks.remove(c10::DispatchKey::AutogradCUDA),  // the caller's key set minus Autograd
    a, b);
// 2. If grad mode is on and any input requires grad, build the graph node.
if (compute_requires_grad(a, b)) {
auto node = std::make_shared<AddBackward0>();
node->set_next_edges(collect_next_edges(a, b));
set_history(out, node);
}
return out;
}
So autograd doesn't do the math. It records bookkeeping, then asks the layer below to do the math.
4.5 TORCH_LIBRARY and TORCH_LIBRARY_IMPL¶
Two macros in C++. The first declares ops in a namespace; the second registers a backend implementation for some keys.
#include <torch/library.h>
// Declare the op. Namespace "myops". The schema string uses ATen's Python-like type syntax.
TORCH_LIBRARY(myops, m) {
m.def("triple(Tensor x) -> Tensor");
}
// Implement for CPU.
Tensor triple_cpu(const Tensor& x) {
return x * 3;
}
TORCH_LIBRARY_IMPL(myops, CPU, m) {
m.impl("triple", triple_cpu);
}
// Implement for CUDA.
Tensor triple_cuda(const Tensor& x) {
return x * 3;
}
TORCH_LIBRARY_IMPL(myops, CUDA, m) {
m.impl("triple", triple_cuda);
}
You can also write m.impl("aten::add.Tensor", &my_add) to override a built-in op for your library / dispatch key. This is how custom backends, quantized ops, and out-of-tree devices plug in without touching ATen.
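Once the library is loaded (for example as part of a C++ extension), the op shows up in the `torch.ops` namespace; a sketch of the call side (the `.so` path is a placeholder):

```python
import torch
# torch.ops.load_library("path/to/libmyops.so")   # only needed for a standalone shared object
x = torch.randn(4)
y = torch.ops.myops.triple(x)   # routes through the dispatcher like any built-in op
```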
4.6 Layered keys: why autograd > autocast > backend¶
Imagine the opposite order. If autocast were above autograd, then when autograd records grad_fn, the recorded op would already have the cast applied -- so the backward would see the cast and might run in low precision unintentionally. Putting autograd on top means it sees the user-level op and the user-level dtypes. Conversely autocast sits above the backend so the cast happens before the kernel chooses its codepath.
The general rule: cross-cutting concerns that transform the call (cast, batch, fake-tensor-ify) sit above the backend; concerns that observe and record (autograd) sit above the transformers so their recording is faithful to user intent.
4.7 __torch_dispatch__ and __torch_function__¶
Two extension points worth knowing:
- `__torch_function__`: a Python-level override. If you define a subclass of `Tensor` with this method, `torch.add(my_tensor, ...)` will route through your method before doing anything else. Used by libraries like `torch.compile`'s subclass tracing, `torch.func`, and pretty-printing-only wrappers.
- `__torch_dispatch__`: a post-dispatcher override. Called from inside the dispatcher at the `Python` key. You see the canonical op (`aten.add.Tensor`) with already-resolved overloads. Used for FakeTensor, FunctionalTensor, `LoggingTensor`, and is the right hook for "I want to intercept everything below the API level."
If you only ever debug models you may never write either, but you will see them mentioned in stack traces and dynamo logs.
5. ATen Op Registration: native_functions.yaml¶
ATen's op surface is defined declaratively in aten/src/ATen/native/native_functions.yaml. Each entry looks roughly like:
- func: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
variants: function, method
dispatch:
CPU: add_cpu
CUDA: add_cuda
SparseCPU, SparseCUDA: add_sparse
MkldnnCPU: mkldnn_add
autogen: add.out
tags: pointwise, canonical
This declares:
- Schema: `add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor`. This is the canonical Python-typed signature. The `.Tensor` suffix is the overload name (vs `add.Scalar`).
- Variants: generate `at::add(...)` (function form) and `Tensor::add(...)` (method form).
- Dispatch table: which C++ function implements which key.
- Autogen: also generate the `add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)` variant from the `out=` pattern.
A codegen tool (driven from torchgen/, output mostly under build/aten/src/ATen/) reads this YAML and emits:
- Op symbol headers: `at::_ops::add_Tensor` callables.
- Function variants: free functions in `at::` and methods on `at::Tensor`.
- Default implementations: `add_(...) { return at::add(...).copy_(...) }`-style helpers.
- Autograd derivative bindings (combined with `derivatives.yaml`): `AddBackward0::apply` etc.
- Python bindings: `THPVariable_add` etc. via `tools/autograd/`.
When you read PyTorch source and cannot find at::add's definition: it's generated. Look at aten/src/ATen/native/BinaryOps.cpp for the kernels and at the YAML for the contract.
derivatives.yaml (in tools/autograd/) is the sibling file. Each entry binds an op to its VJP:
- name: add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
self: grad
other: maybe_multiply(grad, alpha)
That tiny snippet is the autograd of add. The codegen turns it into an AddBackward0 Function class that returns (grad, alpha*grad).
6. The Autograd Engine¶
6.1 Intuition¶
Autograd in PyTorch is a dynamic tape: the graph is built every forward pass, executed once in reverse, and discarded. There is no "compile autograd". This buys flexibility (control flow, dynamic shapes) at the cost of allocating one graph node per differentiable op per forward.
6.2 Mechanism¶
Each differentiable op produces an output Tensor whose grad_fn is a Node (subclass of torch::autograd::Node, formerly Function). The Node holds:
next_edges_: list of(Node*, input_nr)pairs pointing at the Nodes that produced this op's inputs.- saved tensors / scalars needed for the backward (e.g., for
mul, both inputs). apply(grads_out) -> grads_in: the VJP.
grad_fn is null on leaves; leaves with requires_grad=True instead have an AccumulateGrad Node which writes into tensor.grad.
When you call loss.backward():
1. Engine seeds the gradient for `loss` (default: 1.0 if scalar).
2. It performs a reverse topological traversal starting at loss.grad_fn.
3. For each Node in topo order, call Node.apply(grad_outputs).
Result: grad_inputs, one per next_edge.
4. Send each grad_input to the corresponding next_edge's Node, accumulating.
5. When a Node's incoming grad count is satisfied, schedule it.
The engine is multi-threaded across devices: there is one worker per device that owns Nodes for that device (torch/csrc/autograd/engine.cpp). CPU work runs on the calling thread.
6.3 Trace of a tiny graph¶
import torch
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
c = a * b
d = c + a
d.backward()
print(a.grad, b.grad) # tensor(4.) tensor(2.)
Forward graph (as the dispatcher's autograd kernels build it):
AccumulateGrad(a) AccumulateGrad(b)
^ ^
| |
+-----------+ +------------+
| |
MulBackward0 <- saves a, b
|
v
c
|
+-----------+
|
AddBackward0 <- needs alpha=1 only
|
v
d
Backward execution starting from d with seed 1:
AddBackward0(grad=1) -> grad_c = 1, grad_a_partial = 1
MulBackward0(grad=1) -> grad_a_partial2 = b = 3, grad_b = a = 2
AccumulateGrad(a): a.grad = 1 + 3 = 4
AccumulateGrad(b): b.grad = 2
Note how a had two paths to it (through c and direct) and the engine summed contributions at AccumulateGrad. That summing is what next_edges plus accumulation buys you for free.
6.4 VJP definition per op¶
For each forward op y = f(x1, ..., xn), the VJP is grads_in = J^T @ grad_y. PyTorch implements VJPs op-by-op so it never materialises a Jacobian. For mul(a, b):
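(a minimal sketch of the two VJP expressions, written as eager ops; `grad_out` is the incoming gradient, `a` and `b` the saved forward inputs)

```python
import torch

a, b = torch.tensor(2.0), torch.tensor(3.0)
grad_out = torch.tensor(1.0)     # incoming gradient for the output of mul
grad_a = grad_out * b            # d(a*b)/da = b
grad_b = grad_out * a            # d(a*b)/db = a
```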
These are themselves PyTorch ops that go through the dispatcher. If you call `backward(create_graph=True)`, they are recorded by autograd in turn, which is what enables double backward (computing gradients of gradients). With the default `create_graph=False`, the engine runs them without building a further graph.
6.5 Custom torch.autograd.Function¶
You write one when no built-in op covers your forward, or when you have a faster handwritten backward. The shape:
import torch
class FusedScaleClamp(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale, lo, hi):
ctx.save_for_backward(x)
ctx.scale = scale
ctx.lo, ctx.hi = lo, hi
y = torch.clamp(x * scale, min=lo, max=hi)
return y
@staticmethod
def backward(ctx, grad_y):
(x,) = ctx.saved_tensors
s, lo, hi = ctx.scale, ctx.lo, ctx.hi
# grad flows only where the clamp didn't saturate
scaled = x * s
mask = (scaled > lo) & (scaled < hi)
grad_x = grad_y * s * mask.to(grad_y.dtype)
# No grads w.r.t. python scalars
return grad_x, None, None, None
# Use it
x = torch.randn(8, requires_grad=True)
y = FusedScaleClamp.apply(x, 2.0, -1.0, 1.0)
y.sum().backward()
print(x.grad)
Three rules to keep in mind:
- The number of returned grads in `backward` must equal the number of inputs to `forward`. Use `None` for non-differentiable inputs (Python scalars, ints).
- Anything you `save_for_backward` must be a tensor; non-tensor context goes on `ctx.<attr>`.
- If your forward calls only differentiable PyTorch ops, you usually don't need a custom Function -- just write the function. Custom Functions are for when you sidestep autograd (e.g., calling a Triton kernel, or wanting a fused/cheaper backward).
6.6 Versioning and inplace¶
Each tensor carries a version counter, shared by all views of the same data. Inplace ops bump it. Saved tensors record the version they were saved at. On backward, the engine checks: if the version is now higher, you mutated a tensor that was needed for grad, and the engine raises `RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation`. This is the famous error. The fix is almost always: don't `+=` into something whose value the backward needs.
7. requires_grad, no_grad, inference_mode¶
7.1 requires_grad¶
A per-tensor flag (lives in AutogradMeta). When true, ops involving the tensor produce outputs whose grad_fn is set, and the tensor itself appears in the graph (via AccumulateGrad if it is a leaf).
Default: false for plain tensors, true for nn.Parameter.
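A quick check of the default and of how the flag shows up on outputs:

```python
import torch
import torch.nn as nn

t = torch.randn(3)
w = nn.Parameter(torch.randn(3))
print(t.requires_grad, w.requires_grad)    # False True
print((t * 2).grad_fn, (w * 2).grad_fn)    # None  <MulBackward0 object ...>
```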
7.2 no_grad¶
A thread-local override. Inside with torch.no_grad(): (or @torch.no_grad()), GradMode::is_enabled() returns false. The autograd kernel for each op checks this flag and, if grad mode is off, skips recording -- it just redispatches to the layer below. The op still runs through the autograd dispatch key (because the inputs still have AutogradCUDA in their key set), it just doesn't build graph nodes.
Use case: evaluation. You still get inference correctness; you save the bookkeeping cost of building backward graphs.
7.3 inference_mode¶
A stronger thread-local mode introduced for inference workloads. Inside a `with torch.inference_mode():` block:
- The Autograd dispatch keys are excluded from the key set entirely. The dispatcher no longer enters the autograd kernel at all -- it goes straight to the backend kernel.
- Outputs are marked as inference tensors. Their version counter is disabled. They cannot later be used in autograd.
Why is it faster than no_grad?
| Cost | no_grad | inference_mode |
|---|---|---|
| Enter autograd dispatch kernel | yes | no |
| Allocate AutogradMeta for outputs | yes (cheap but nonzero) | no |
| Bump version counter on inplace | yes | no |
| Lookup kernel via dispatcher | once per op | once per op |
You skip a whole layer of dispatch and a pile of small allocations. Benchmarks show ~5-15% overhead reduction on small ops where dispatch cost dominates.
The price: outputs cannot be used in autograd later. If you accidentally pass an inference tensor into a training graph, you get RuntimeError: Inference tensors cannot be saved for backward.
7.4 When to use which¶
- Training loop: neither -- you need the graph.
- Eval loop you might re-enter training from: `torch.no_grad()`.
- Pure inference server: `torch.inference_mode()`. Wrap once at the top of the request handler.
8. AMP / Autocast¶
8.1 Intuition¶
Mixed precision exists because matmul on tensor cores is much faster in FP16/BF16 than FP32, but reductions (loss, softmax denominator, layer norm stats) want FP32 to avoid catastrophic cancellation. Autocast classifies ops:
- lower-precision allowed (matmul, conv, linear): cast inputs down before running.
- must stay in FP32 (loss functions, softmax, layer norm in some cases): leave inputs alone, or cast up.
- promote (add of mixed dtypes): cast all inputs to the highest dtype present.
8.2 Mechanism: it's just another dispatch key¶
When you enter with torch.autocast("cuda", dtype=torch.float16):, the thread-local include set adds AutocastCUDA. Now any op called on a CUDA tensor inside the block has AutocastCUDA in its key set. The autocast kernel for that op runs before the backend kernel. It looks up the op's autocast policy:
// pseudocode for an autocast-lower op like matmul
Tensor matmul_autocastCUDA(const Tensor& a, const Tensor& b) {
auto target = at::autocast::current_dtype(c10::DeviceType::CUDA); // e.g. half
auto a_cast = cached_cast(target, a);
auto b_cast = cached_cast(target, b);
return at::redispatch::matmul(/*remove AutocastCUDA*/, a_cast, b_cast);
}
For a "must stay FP32" op, the autocast kernel casts up instead. cached_cast keeps a small thread-local cache so if you matmul the same weight twice in a forward you don't re-cast it.
The lists live in aten/src/ATen/autocast_mode.cpp (and FP16/BF16 specific lists). They are explicit: you can read which ops are "lower", "fp32", "promote".
8.3 GradScaler vs autocast-only¶
The danger of FP16 is gradient underflow: tiny gradients become 0. The fix is loss scaling: multiply the loss by a big number S before backward, divide grads by S before the optimizer step. If any grad becomes Inf/NaN, skip the step and shrink S; otherwise gradually grow S.
scaler = torch.amp.GradScaler("cuda")  # torch.cuda.amp.GradScaler is the older, deprecated spelling
for x, y in loader:
opt.zero_grad(set_to_none=True)
with torch.autocast("cuda", dtype=torch.float16):
out = model(x)
loss = loss_fn(out, y)
scaler.scale(loss).backward()
scaler.step(opt)
scaler.update()
BF16 has the same exponent range as FP32, so underflow is not an issue and you do not need GradScaler:
for x, y in loader:
opt.zero_grad(set_to_none=True)
with torch.autocast("cuda", dtype=torch.bfloat16):
out = model(x)
loss = loss_fn(out, y)
loss.backward()
opt.step()
Rule of thumb on Ampere or later: BF16 unless you have a specific reason. Older GPUs (Volta, Turing) lack good BF16, so FP16 + GradScaler.
9. torch.compile Pipeline¶
torch.compile(model) returns a wrapped callable. Under the hood it composes three pieces: TorchDynamo (capture), AOTAutograd (joint forward+backward capture), Inductor (codegen).
+---------------+      +--------------+      +--------------+      +-----------+
| Python module | ---> | TorchDynamo  | ---> | AOTAutograd  | ---> | Inductor  |
+---------------+      +--------------+      +--------------+      +-----------+
                       captures FX graph     joint fwd/bwd         Triton/C++
                       + guards              traced into           kernels
                                             core ATen             + scheduling
9.1 TorchDynamo¶
Source: torch/_dynamo/.
Dynamo is a Python-level tracer that hooks into CPython's frame evaluation API (PEP 523). It registers an alternative frame evaluator. When Python is about to execute a function decorated by torch.compile, Dynamo intercepts the bytecode, symbolically executes it, and produces:
- An FX graph of tensor ops (`torch.fx.GraphModule`).
- A set of guards -- runtime predicates over inputs that, if true, mean it's safe to reuse this graph (e.g., "input 0 is `torch.float32`", "input 0's shape is `(B, 1024)` for some int B", "this Python int equals 7").
- Residual bytecode that calls the compiled graph and does anything Dynamo couldn't handle.
The key idea: instead of tracing a Tensor program, Dynamo traces Python bytecode, with FakeTensors standing in for real tensors. Every PyTorch op invocation gets recorded into the FX graph. Every Python-level operation that isn't a tensor op (a list append, an if x.shape[0] > 16:) is either:
- specialised into a guard ("we recorded this branch when the shape was 32; if at runtime it's not, recompile"), or
- becomes a graph break.
Graph breaks¶
A graph break happens when Dynamo encounters something it cannot symbolically execute. Examples:
- `print(x)` -- has a side effect Dynamo doesn't model.
- Calling into a third-party C extension that isn't a torch op.
- A `try/except` whose handling Dynamo isn't sure how to capture.
- Data-dependent control flow on a tensor without using `torch.cond`.
When a break happens, Dynamo:
- Compiles the graph it has so far.
- Falls back to the Python interpreter for the offending statement.
- Resumes tracing from the next statement -- producing a second graph after the break.
Each graph compiles separately and you pay the call/launch overhead between them. One graph break = lost optimization opportunity. Many = torch.compile may even be slower than eager.
Inspecting¶
import torch
@torch.compile
def f(x):
y = x.sin()
print("hi") # graph break
return y.cos()
print(torch._dynamo.explain(f)(torch.randn(4)))
# prints: number of graphs, number of breaks, reasons, locations
torch._dynamo.explain is your first debugging tool. If it says "1 graph, 0 breaks", you're golden. If it says "5 graphs, 4 breaks", read each reason and fix.
Other useful env knobs:
TORCH_LOGS="dynamo" # what dynamo is tracing
TORCH_LOGS="graph_breaks" # only the breaks
TORCH_LOGS="recompiles" # why a graph re-compiled at runtime
TORCH_LOGS="output_code" # the generated kernels (see Inductor)
9.2 AOTAutograd¶
Source: torch/_functorch/aot_autograd.py and friends.
Once Dynamo hands an FX graph to the compiler backend, AOTAutograd does two things:
- Joint trace of forward + backward. It runs the forward FX graph through `make_fx` with grad enabled, then calls `.backward()` to also trace the backward. The output is one FX graph containing both, plus a partition that assigns nodes to "forward" or "backward" subgraphs (so they can run separately at runtime).
- Decomposition to core ATen. Higher-level ops (e.g., `torch.nn.functional.layer_norm`) are decomposed into their constituent core ops (`mean`, `var`, `mul`, `add`, ...). This shrinks the op surface Inductor must handle from thousands to ~250 core ops.
The decomposition table lives in torch/_decomp/. You can see what an op decomposes to:
from torch._decomp import core_aten_decompositions
table = core_aten_decompositions()
for k, v in list(table.items())[:5]:
print(k)
After AOTAutograd you have two FX graphs in core ATen: forward_graph and backward_graph. They are pure functions of inputs (and saved-for-backward tensors). Now Inductor compiles each.
9.3 Inductor¶
Source: torch/_inductor/.
Inductor is a lowering compiler. It takes an FX graph in core ATen, builds an internal IR (Inductor IR) that represents loops over tensors, fuses adjacent pointwise ops into bigger loops, schedules reductions, and emits target code.
Two backends:
- CUDA / ROCm: emits Triton kernels. Triton handles the GPU-specific tiling and memory hierarchy; Inductor decides what to fuse and what shapes to specialise on.
- CPU: emits C++ with OpenMP pragmas, optionally with vector ISA intrinsics (AVX2/AVX-512). Compiles via the system compiler at runtime.
Pipeline inside Inductor:
core ATen FX graph
-> lowering (each op -> Inductor IR ops; e.g. add -> Pointwise(...))
-> scheduler (group nodes that can fuse; assign to kernels)
-> codegen (emit Triton or C++)
-> compile (call Triton's autotuner / call cc -O3)
-> wrapper (Python wrapper that calls each kernel in order)
Fusion is the big win. A handwritten layer norm in eager is 5+ kernels (mean, var, sub, mul, add). Inductor often emits one fused Triton kernel. Same for activation+linear-bias, embedding+layernorm, etc.
Inspecting generated code¶
With TORCH_LOGS="output_code" the generated Triton (or C++) source is dumped. It's worth reading at least once -- you'll see something like:
@triton.jit
def triton_poi_fused_add_mul_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK: tl.constexpr):
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel
a = tl.load(in_ptr0 + xindex, xmask)
b = tl.load(in_ptr1 + xindex, xmask)
c = a * 2.0 + b
tl.store(out_ptr0 + xindex, c, xmask)
That's a fused a*2 + b. Two separate eager ops collapsed into one launch.
You can also dump the FX graph after AOTAutograd:
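(same TORCH_LOGS mechanism as above; `train.py` is a placeholder for your entry point)

```
TORCH_LOGS="aot_graphs" python train.py    # forward/backward FX graphs after AOTAutograd
```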
9.4 Guards and recompilation¶
A compiled artifact is keyed by (op graph, guards). At each call, Dynamo evaluates the guards over the actual inputs. If all hold, run the cached compiled artifact. If any fail, recompile.
Common guards:
- Type guards: input is a `Tensor`, dtype is `float32`, device is `cuda:0`.
- Shape guards: rank is 3, sizes are `(2, ?, 1024)`. The `?` may be symbolic (a free variable) or static (a specific int) depending on the dynamic-shape mode.
- `requires_grad` guards: input had `requires_grad=True`. (Recompile if you switch eval/train without telling it.)
- Python guards: a constant Python int equals N, a list has length M, a particular `nn.Module` instance is the same object.
Dynamic vs static shapes:
- `torch.compile(model)` (default in 2.4+): tries dynamic shapes when it sees the same shape vary; specialises when it doesn't.
- `torch.compile(model, dynamic=False)`: always specialise on shapes. Faster code, more recompiles.
- `torch.compile(model, dynamic=True)`: assume dynamic from the start. Fewer recompiles, sometimes slower per-iter.
If you see frequent recompiles (TORCH_LOGS="recompiles"), the usual culprits are:
- Variable batch size with `dynamic=False`.
- Variable sequence length without marking it dynamic (see the sketch below).
- Calling with different `requires_grad` settings (eval vs train without `.eval()`/`.train()`).
- Using Python lists/tuples whose length varies.
9.5 Modes¶
torch.compile(model, mode="default")
torch.compile(model, mode="reduce-overhead")
torch.compile(model, mode="max-autotune")
- default: balanced. Compile time low-ish, runtime good.
- reduce-overhead: enables CUDA graphs around the compiled region. CUDA graphs eliminate per-op CUDA launch overhead by recording a sequence and replaying it as one submission. Big win for small batches and lots of small ops. Constraints: shapes must be static across replays, and tensors must live at the same addresses (Inductor handles this with persistent input buffers; you may need to `.clone()` inputs or warm up).
- max-autotune: Triton autotunes block sizes per kernel, multiple template variants for matmul, longest compile time, often best runtime.
9.6 End-to-end example¶
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super().__init__()
self.l1 = nn.Linear(1024, 4096)
self.l2 = nn.Linear(4096, 1024)
def forward(self, x):
return self.l2(torch.relu(self.l1(x)))
model = Net().cuda().to(torch.bfloat16)
model = torch.compile(model, mode="reduce-overhead")
x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)
# Warm-up: compiles + records cuda graph
for _ in range(3):
y = model(x)
# Steady-state: every call is a single CUDA graph replay
for _ in range(1000):
y = model(x)
What happened on first call:
- Dynamo hooks `forward`, traces it into FX (3 nodes: linear, relu, linear).
- AOTAutograd skips backward (no grad needed), decomposes linear -> matmul + add.
- Inductor lowers, fuses relu with the second matmul's bias-add prologue if possible, generates two Triton matmul kernels and a fused activation/bias kernel.
- CUDA graph captures the launch sequence on call 2.
- Calls 3+ replay the graph.
10. Custom Op Registration (Modern Path)¶
You want to register a new op so that:
- It has an autograd rule.
- It survives `torch.compile` (Dynamo and Inductor know what to do).
- It works under FakeTensor / meta tracing.
The modern API is `torch.library`, available in 2.4+. Avoid the older `torch.autograd.Function`-only path when going through `torch.compile`; Dynamo will graph-break on it.
10.1 Skeleton¶
import torch
# 1. Declare and implement the op for real backends
@torch.library.custom_op("mylib::myadd", mutates_args=())
def myadd(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y # or call into a Triton kernel here
# 2. Tell the compiler/dynamo how shapes propagate (the "fake" / meta impl)
@myadd.register_fake
def _(x, y):
# Must match real op's output shape/dtype/device, with no real compute
return torch.empty_like(x)
# 3. Register an autograd rule
def myadd_setup_context(ctx, inputs, output):
# Save what backward needs
pass # nothing for plain add
def myadd_backward(ctx, grad):
return grad, grad # dx, dy
myadd.register_autograd(myadd_backward, setup_context=myadd_setup_context)
Now mylib::myadd is a first-class op. You can call torch.ops.mylib.myadd(x, y) and it goes through the dispatcher like any built-in.
10.2 What each piece is for¶
- `custom_op`: the user-facing implementation. Runs in eager mode.
- `register_fake`: Dynamo / FakeTensor / `torch.export` use this to symbolically execute your op. It must allocate output tensors with correct shape/dtype/device but no real values. Without this, Dynamo will graph-break at your op.
- `register_autograd`: the VJP. Mirrors `torch.autograd.Function.backward`. Setup context can save tensors via `ctx.save_for_backward(...)`.
- `mutates_args`: tuple of arg names that are mutated in place. The compiler needs to know this for correctness when reordering / re-using buffers.
10.3 A Triton kernel as a custom op (worked example)¶
import torch
import triton
import triton.language as tl
@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK: tl.constexpr):
pid = tl.program_id(0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
mask = offs < N
a = tl.load(x_ptr + offs, mask=mask)
b = tl.load(y_ptr + offs, mask=mask)
tl.store(out_ptr + offs, a + b, mask=mask)
def _add_launch(x, y):
out = torch.empty_like(x)
N = x.numel()
BLOCK = 1024
grid = ((N + BLOCK - 1) // BLOCK,)
_add_kernel[grid](x, y, out, N, BLOCK=BLOCK)
return out
@torch.library.custom_op("mylib::triton_add", mutates_args=())
def triton_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
assert x.is_cuda and y.is_cuda and x.shape == y.shape
return _add_launch(x.contiguous(), y.contiguous())
@triton_add.register_fake
def _(x, y):
return torch.empty_like(x)
def _bwd(ctx, g):
return g, g
triton_add.register_autograd(_bwd)
# Use it
x = torch.randn(4096, device="cuda", requires_grad=True)
y = torch.randn(4096, device="cuda", requires_grad=True)
z = torch.ops.mylib.triton_add(x, y)
z.sum().backward()
print(x.grad.shape, y.grad.shape)
# It also works under torch.compile thanks to register_fake
def f(a, b):
return torch.ops.mylib.triton_add(a, b).relu()
g = torch.compile(f)
g(x, y)
The key win: Dynamo treats triton_add as an opaque op (it doesn't try to look inside the Triton kernel). It uses register_fake to know the shape, and register_autograd to know how grads flow. Inductor will not fuse with surrounding ops -- but it also won't graph-break.
If you do want Inductor to fuse with surrounding ops, write the kernel in core ATen (let Inductor codegen its own Triton). Custom Triton ops are for cases where you have a hand-tuned kernel that beats codegen.
11. C++ Extension Path¶
When you need raw C++/CUDA, the supported flow is torch.utils.cpp_extension. There are two flavours:
- JIT (`load`, `load_inline`): compiles on first import. Great for development.
- AOT (`setup.py` with `CUDAExtension`/`CppExtension`): produces a wheel.
11.1 Minimal setup.py skeleton¶
src/kernel.cu:
#include <torch/extension.h>
__global__ void scale_kernel(const float* in, float* out, float s, int N) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < N) out[i] = in[i] * s;
}
torch::Tensor scale_cuda(torch::Tensor x, double s) {
TORCH_CHECK(x.is_cuda(), "x must be cuda");
TORCH_CHECK(x.scalar_type() == torch::kFloat32, "x must be float32");
auto y = torch::empty_like(x);
int N = x.numel();
int block = 256;
int grid = (N + block - 1) / block;
scale_kernel<<<grid, block>>>(x.data_ptr<float>(), y.data_ptr<float>(),
static_cast<float>(s), N);
return y;
}
src/binding.cpp:
#include <torch/extension.h>
torch::Tensor scale_cuda(torch::Tensor x, double s);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("scale", &scale_cuda, "scale(x, s) = x * s on CUDA");
}
setup.py:
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name="my_ext",
ext_modules=[
CUDAExtension(
name="my_ext",
sources=["src/binding.cpp", "src/kernel.cu"],
extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3"]},
),
],
cmdclass={"build_ext": BuildExtension},
)
Build and use:
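A sketch of the build-and-import round trip (module and function names follow the files above; the editable install is one of several equivalent ways to build):

```python
# Build first, from the project root:  pip install -e .
import torch
import my_ext   # the extension module name from setup.py

x = torch.randn(1024, device="cuda")
y = my_ext.scale(x, 2.0)
print(torch.allclose(y, x * 2))   # True
```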
11.2 ABI considerations¶
PyTorch is built with a specific C++ ABI (the C++11 GCC ABI on Linux). Your extension must be compiled against the same PyTorch headers and the same compiler ABI flag. Practical rules:
- Always build the extension on the machine where it will run, against the installed PyTorch wheel, or publish per-PyTorch-version wheels.
- Match the CUDA toolkit major version to PyTorch's CUDA major (e.g. CUDA 12.1 PyTorch -> CUDA 12.x toolkit).
- Avoid passing C++ exceptions across the pybind boundary. Use `TORCH_CHECK` for user errors -- it raises a Python `RuntimeError`.
- Don't statically link the C++ standard library; let the system one be used.
- For wheels, use `manylinux2014` or newer base images; build separate wheels per (PyTorch version, CUDA version, Python version) tuple.
11.3 Registering the C++ op into the dispatcher¶
If you want it to be a real dispatcher op (so autograd, autocast, etc. integrate), use TORCH_LIBRARY (Section 4.5) instead of (or in addition to) the pybind binding. That gives you torch.ops.myext.scale(x, s) and full dispatch behavior.
12. The CUDA Caching Allocator¶
12.1 Why caching¶
cudaMalloc and cudaFree are slow (often 100us+ each) and synchronous on the default stream. A naive implementation would call them per tensor. PyTorch instead routes all CUDA tensor allocations through a caching allocator (c10/cuda/CUDACachingAllocator.cpp).
12.2 Mechanism¶
Conceptually:
CachingAllocator state per device:
large_blocks: sorted-by-size list of free blocks >= 1MB
small_blocks: sorted-by-size list of free blocks < 1MB
active_blocks: {ptr -> Block} # currently held by a Tensor
allocate(size, stream):
round size up (small to nearest 512B; large to nearest 2MB)
search the appropriate free list for a block of >= size on a compatible stream
if found: split off the suffix as a free block, return prefix
else:
cudaMalloc a fresh segment (geometric growth: 2MB, 20MB, 200MB, ...)
carve a block out of it, return it
free(ptr):
mark Block free
record the stream we last used it on; only reusable on that stream
unless cuda events confirm cross-stream safety
coalesce with neighbors in the same segment if both free
Two keys to internalize:
- `free()` does not call `cudaFree`. It returns the block to the pool. From the driver's perspective the memory is still allocated.
- Stream-aware reuse. Memory freed on stream A cannot be reused on stream B until events confirm the prior work has finished. This is why multi-stream code can OOM where single-stream code wouldn't: the allocator is conservatively keeping blocks pinned to the original stream.
12.3 empty_cache¶
Walks the free lists and actually cudaFrees segments that contain only free blocks. Returns memory to the driver -- visible to other processes (e.g., another container sharing the GPU). Does not shrink anything currently in use, and does not improve performance for your own process (the caching allocator was already going to reuse those blocks). Use it when you need to release memory across processes; do not sprinkle it through your training loop.
12.4 Reading memory_summary¶
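The summary and the raw counters behind it come from the `torch.cuda` memory APIs:

```python
import torch

print(torch.cuda.memory_allocated())   # bytes currently held by tensors
print(torch.cuda.memory_reserved())    # bytes cudaMalloc'd by the caching allocator
print(torch.cuda.memory_summary())     # the formatted table below
```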
You get a table like:
| | Cur Usage | Peak Usage | Tot Alloc | Tot Freed |
|---|---|---|---|---|
| Allocated memory | ... | ... | ... | ... |
| Active memory | ... | ... | ... | ... |
| GPU reserved memory | ... | ... | ... | ... |
| Non-releasable mem | ... | ... | ... | ... |
Definitions:
- Allocated: bytes currently held by Tensors.
- Active: allocated plus blocks whose frees are still pending on a stream (so they cannot be coalesced yet).
- Reserved: total `cudaMalloc`'d. Reserved - Allocated = sitting in the cache.
- Non-releasable: free but cannot be given back to the driver, because the segment still has at least one in-use block.
12.5 Fragmentation¶
The classic failure mode: you have 3GB free in the cache, but it's split into 300 blocks of ~10MB each, none big enough to satisfy a 100MB request. The allocator does coalesce neighbors, but only within the same segment. Mitigations:
- `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` (PyTorch 2.0+). The allocator uses CUDA virtual memory APIs (`cuMemMap`/`cuMemAddressReserve`) to grow a single backing segment instead of allocating many. Drastically reduces fragmentation for variable-shape workloads.
- `PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:N`: don't split blocks larger than N MB, reducing tiny suffix blocks scattered around.
- Avoid pathological patterns: alternating very large and very small allocations on the same stream.
If you OOM at "tried to allocate 1GiB but only 500MiB free although 4GiB reserved", that is fragmentation. Check memory_summary; consider expandable segments.
13. Profiling Internals¶
torch.profiler.profile(...) (in torch/profiler/) records per-op entry/exit by registering callbacks at the dispatcher. Each time the dispatcher enters or exits an op, it calls every registered observer. The profiler is one such observer; so are autograd hooks and record_function.
13.1 Minimal use¶
import torch
from torch.profiler import profile, record_function, ProfilerActivity
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
with_stack=False,
) as prof:
with record_function("forward"):
y = model(x)
with record_function("loss"):
loss = (y - target).pow(2).mean()
with record_function("backward"):
loss.backward()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
prof.export_chrome_trace("trace.json") # open with chrome://tracing or perfetto
Open the trace and you get a flame-graph-like view: CPU dispatcher events on top, CUDA kernel events on the bottom with launch arrows. The gaps between kernels are launch overhead.
13.2 What each column means¶
- `Self CPU`: time the op itself spent on the CPU (dispatcher + Python -> C++ + kernel launch).
- `CPU total`: includes children. A `linear` op's total includes its `matmul` and `add` children.
- `Self CUDA` / `CUDA total`: same on GPU, measured via CUDA events.
- `# of Calls`: how many times this op key (with these shapes) was hit.
13.3 Reading patterns¶
- "Self CPU >> Self CUDA, kernels short": you are launch-overhead bound. Try
torch.compile(mode="reduce-overhead")for CUDA graphs, or batch up small ops. - "Self CUDA dominates, one kernel is 80% of it": profile that kernel; consider a different algorithm (FlashAttention, a fused MoE), or see if Inductor will generate a better one with
max-autotune. - "CUDA gaps with no work": host is too slow producing input. DataLoader is the usual suspect; bump
num_workers, prefetch, pin memory. - "memcpy dominating": you're moving data CPU<->GPU per step. Pin host memory, pre-load to GPU, or use
non_blocking=Truewith pinned source.
13.4 Autograd hooks for finer questions¶
hooks = []
for name, p in model.named_parameters():
h = p.register_hook(lambda g, n=name: print(n, g.norm()))
hooks.append(h)
loss.backward()
for h in hooks: h.remove()
Hook fires from the engine's worker thread when the grad is computed. Useful for "which parameter's grad is NaN".
14. Practical Exercises (with answers)¶
Exercise 1: stride sleuthing¶
You have:
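```python
import torch

a = torch.randn(3, 4, 5)     # values arbitrary; only sizes/strides matter
b = a.permute(2, 0, 1)
```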
Without running this, what are b.shape, b.stride(), b.storage_offset()? Is b.is_contiguous()? Why?
Answer. a.shape=(3,4,5), a.stride()=(20,5,1). Permute remaps dims by index: new dim 0 = old dim 2, new dim 1 = old dim 0, new dim 2 = old dim 1. So b.shape=(5,3,4), b.stride()=(1,20,5), b.storage_offset()=0. Not contiguous because canonical strides for (5,3,4) would be (12,4,1) and ours are (1,20,5).
Exercise 2: dispatch trace¶
You write:
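```python
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    y = a @ b
```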
where a, b are CUDA fp32 tensors with requires_grad=True. List the keys in the dispatch set for @, the order of kernels invoked, and the dtype of y.
Answer.
- Per-tensor key set: {AutogradCUDA, CUDA}.
- Local include from autocast: adds AutocastCUDA.
- Local exclude from no_grad: does not exclude Autograd keys (that's inference_mode's job). However, the autograd kernel checks GradMode::is_enabled() and, finding it false, just redispatches without recording.
So the order is: AutogradCUDA kernel (skips recording, redispatches) -> AutocastCUDA kernel (casts a,b to bf16, redispatches) -> CUDA kernel (runs bf16 matmul). y.dtype is bfloat16. y.requires_grad is False.
Exercise 3: detect a graph break¶
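Consider a compiled function along these lines (a minimal sketch; the exact body is assumed, only the `.item()`-driven branch matters):

```python
import torch

@torch.compile
def f(x):
    n = (x > 0).sum()
    if n.item() > 10:          # pulls a value back to Python -> data-dependent branch
        return x * 2
    return x * 3
```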
Why is this slow / breaky, and how do you fix?
Answer. n.item() materialises a tensor value to a Python int. Dynamo cannot symbolically execute that branch -- it triggers a graph break (or specialisation / recompile per value of n). The fix: use torch.where (data-dependent on tensor) or torch.cond (if you really need control flow):
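(sticking to the `torch.where` form; names follow the sketch above)

```python
@torch.compile
def f(x):
    n = (x > 0).sum()
    return torch.where(n > 10, x * 2, x * 3)   # the condition stays a tensor; no graph break
```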
Exercise 4: a custom op that survives compile¶
Write a custom op clip_norm(x, max_norm) that scales x so its L2 norm is at most max_norm. Make sure it works under torch.compile.
Answer.
import torch
@torch.library.custom_op("mylib::clip_norm", mutates_args=())
def clip_norm(x: torch.Tensor, max_norm: float) -> torch.Tensor:
n = x.norm()
scale = (max_norm / (n + 1e-12)).clamp(max=1.0)
return x * scale
@clip_norm.register_fake
def _(x, max_norm):
return torch.empty_like(x)
# autograd: easiest is to leave it to the implementation
# since it uses only autograd-aware ops; but custom_op disables
# autograd-through-implementation by default. So:
def _bwd(ctx, g):
x, scale = ctx.saved
# d(x*scale)/dx = scale (treating scale as constant for simplicity)
return g * scale, None
def _setup(ctx, inputs, output):
x, max_norm = inputs
n = x.norm()
scale = (max_norm / (n + 1e-12)).clamp(max=1.0)
ctx.saved = (x, scale)
clip_norm.register_autograd(_bwd, setup_context=_setup)
Note we approximate the gradient by treating scale as a constant -- that's the standard / desired behaviour for gradient clipping.
Exercise 5: why does this OOM?¶
Training fine yesterday, today OOMs at the same batch size. Memory summary shows Reserved=22GiB, Allocated=8GiB. Tried empty_cache, no change. What are two most likely causes and one mitigation?
Answer. Causes:
1. Fragmentation: the 14 GiB cache is split into too many small free blocks for a big request to find a contiguous free run. Mitigation: PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True (PyTorch 2.0+).
2. Stream-pinned blocks: a multi-stream codepath freed memory on stream A; the allocator can't yet hand it to stream B. Mitigation: synchronize, or unify streams.
empty_cache does nothing here because all 14GiB of free cache is non-releasable (segments still have allocated blocks).
Exercise 6: inference_mode vs no_grad latency¶
You have a hot inference path that runs ~50 small ops per request (lots of LayerNorm, GELU, small matmul). You measure 8% latency drop switching no_grad -> inference_mode. Why? Where would the gain be much smaller?
Answer. inference_mode excludes Autograd dispatch keys entirely, so each op skips the autograd kernel layer (a function call, a check, and an output AutogradMeta allocation). Saving ~hundreds of nanoseconds per op times ~50 ops times batch is a measurable percentage when individual ops are short.
The gain shrinks toward zero as ops get bigger: a single 4096x4096 fp16 matmul takes milliseconds, dwarfing the dispatch cost. The win is in launch-overhead-bound regimes. For one big op per request, prefer profiling to see if it's worth bothering.
15. A Coherent Mental Model To Keep¶
If you remember nothing else, remember these seven sentences:
- A `Tensor` is a `(storage, sizes, strides, storage_offset, dtype, device)` tuple.
- Views share storage; `contiguous()` materialises.
- The dispatcher picks a kernel from a key set assembled from inputs and thread-local mode; layers redispatch by removing their own key.
- Autograd is a layer in the dispatcher that, when grad mode is on, builds a tape; `loss.backward()` traverses it.
- `inference_mode` is faster than `no_grad` because it removes the Autograd layer entirely.
- Autocast is just a dispatcher layer that casts inputs before the backend runs.
- `torch.compile` is Dynamo (capture Python -> FX) + AOTAutograd (joint forward/backward in core ATen) + Inductor (Triton/C++ codegen with fusion); guards govern when the compiled artifact is reused.
Everything else -- caching allocator, profiler, custom ops -- hangs off these. Once you can simulate the dispatcher in your head and reason about the compile pipeline, PyTorch internals stop feeling like a foreign country and become a place you live.
Appendix A: Source Tree Map¶
A high-confidence cheat sheet for navigating the repo:
| Path | Contents |
|---|---|
| `c10/core/` | TensorImpl, Storage, DispatchKey, Device, Layout. The smallest, most stable foundation. |
| `c10/cuda/` | CUDACachingAllocator, CUDA stream/event wrappers. |
| `aten/src/ATen/core/` | The dispatcher (Dispatcher.cpp), op registration (library.cpp). |
| `aten/src/ATen/native/` | Op kernels (CPU / generic). BinaryOps.cpp, ReduceOps.cpp, etc. |
| `aten/src/ATen/native/cuda/` | CUDA kernels. |
| `aten/src/ATen/native/native_functions.yaml` | The op-schema source of truth. |
| `aten/src/ATen/autocast_mode.cpp` | Autocast policies (which ops are FP16/BF16/FP32/promote). |
| `tools/autograd/derivatives.yaml` | Op-by-op VJPs for codegen. |
| `torch/csrc/autograd/` | Engine, Function (Node), saved variables. |
| `torch/_dynamo/` | TorchDynamo (frame eval, symbolic execution, guards). |
| `torch/_functorch/aot_autograd.py` | AOTAutograd. |
| `torch/_decomp/` | Decompositions to core ATen. |
| `torch/_inductor/` | Inductor lowering, scheduler, codegen (codegen/triton.py, codegen/cpp.py). |
| `torch/library.py` | Modern custom-op API (custom_op, register_fake, register_autograd). |
| `torch/utils/cpp_extension.py` | JIT and AOT C++/CUDA extension build helpers. |
| `torch/profiler/` | Profiler frontend; backend in torch/csrc/profiler/. |
When in doubt, git grep for the op name in aten/src/ATen/native/ -- the kernel is almost always there.
Appendix B: Useful Environment Variables¶
TORCH_LOGS=dynamo,graph_breaks,recompiles,aot_graphs,output_code
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb:128
TORCH_SHOW_DISPATCH_TRACE=1 # prints kernel choice per op (verbose)
TORCH_USE_CUDA_DSA=1 # device-side assertions for shape/index errors
CUDA_LAUNCH_BLOCKING=1 # serialises kernel launches; better stack traces
TORCHINDUCTOR_CACHE_DIR=/tmp/inductor # control compile cache location
TORCHINDUCTOR_MAX_AUTOTUNE=1 # equivalent to mode="max-autotune"
End of chapter.