Triton: A Deep Dive into the Block-Level GPU DSL¶
A self-contained reference. After working through this chapter you should be able to read, write, debug, autotune, and integrate Triton kernels into a PyTorch project without consulting any other source. We assume you know Python, basic linear algebra, what a GPU is, and roughly what CUDA does (threads, blocks, shared memory, warps). We do not assume you have ever written a CUDA kernel by hand.
Table of contents¶
- Why Triton exists
- The programming model
- Memory operations: load, store, mask, broadcast
- Math operations: elementwise, tl.dot, reductions, transcendentals
- Compile-time constants and specialization
- Autotuning
- Numerical stability patterns (online softmax derivation, Welford, log-sum-exp)
- Six fully-annotated real kernels
- The compilation pipeline (Python -> MLIR -> PTX)
- Integration with PyTorch (torch.library, torch.compile, autograd)
- Common pitfalls
- Triton vs CUTLASS vs hand-rolled CUDA
- Six exercises with worked answer sketches
- Closing notes
1. Why Triton exists¶
1.1 The CUDA per-thread model¶
CUDA exposes the GPU as a hierarchy of threads organised into warps (32 threads), then blocks (one or more warps), then a grid of blocks. The programmer writes code from the perspective of a single thread:
// CUDA: each thread handles a single output element of c = a + b
__global__ void vec_add(const float* a, const float* b, float* c, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) c[i] = a[i] + b[i];
}
This per-thread view is faithful to the hardware (an SM really does dispatch
warps, each a SIMD-32 unit), but for a kernel author it forces a constant
mental gear shift. When you think "tile of A times tile of B equals tile of
C", you are reasoning block-wise. CUDA makes you re-derive that block-level
operation as a sequence of per-thread loads, per-thread __shared__ writes,
__syncthreads() barriers, per-thread MMA inputs, accumulator registers,
mask handling at the trailing edge of the matrix, etc. Every one of those
steps is a place to make a memory-coalescing, bank-conflict, or
register-pressure mistake. Worse: changing the tile size invalidates almost
all of that hand-crafted indexing.
1.2 The Triton thesis (MAPL 2019)¶
Tillet, Kung, and Cox (MAPL 2019, "Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations") observe that almost every high-performance GPU kernel for ML is naturally written as a loop over blocks of the output, where each iteration computes one output tile from input tiles. They propose: let the programmer write that loop directly, with block-shaped tensors as first-class values, and let a compiler lower the block program to a per-thread CUDA-like form -- choosing vector widths, shared-memory layouts, double-buffered pipelining, swizzling, and tensor core instructions automatically.
Concretely, the human writes (paraphrased):
Program instance pid is responsible for output rows pid*BM ... pid*BM+BM. Allocate a block-tensor acc of shape [BM, BN]. Loop over the K dimension in blocks of BK: load a [BM, BK] tile of A, a [BK, BN] tile of B, do acc += dot(a, b). After the loop, store acc to C.
The compiler turns that into hundreds of lines of well-vectorised PTX with
correct shared-memory layout, asynchronous copies, and mma instructions.
1.3 What Triton gives up, and why that is fine¶
Triton intentionally cannot express:
- arbitrary intra-block thread-divergent control flow,
- explicit __syncthreads(),
- user-controlled shared memory,
- warp-level shuffles as a primary interface (some primitives exist, but they are not the model).
For ML kernels this is rarely a loss: 95% of useful kernels are tiled elementwise / reduce / matmul / softmax compositions, and for those the compiler's choices are at least as good as a moderately experienced CUDA author and far less error-prone.
1.4 Where the productivity comes from¶
Three places:
- Indexing is automatic. You write tl.arange(0, BLOCK) and broadcast; the compiler generates the per-thread index arithmetic.
- Shared memory is automatic. tl.load of a 2D block lowers to coalesced global -> shared copies with the right swizzle for downstream tl.dot.
- Tile-size sweeps are cheap. Because block sizes are constexpr, the compiler specialises and you can autotune over a config grid in one decorator.
The empirical claim of the original paper, repeated through every release: hand-written Triton matmul reaches ~80% of cuBLAS on common square shapes and frequently beats cuBLAS on awkward shapes (small K, non-multiple-of-16 dims) where the vendor library has no specialised kernel.
2. The programming model¶
2.1 @triton.jit and the program-instance abstraction¶
A Triton kernel is a Python function decorated with @triton.jit:
import triton
import triton.language as tl
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements,
BLOCK: tl.constexpr):
pid = tl.program_id(axis=0) # 0..num_programs-1
offsets = pid * BLOCK + tl.arange(0, BLOCK) # vector of indices
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
tl.store(out_ptr + offsets, x + y, mask=mask)
The decorator does not run the function on the host. It captures the Python AST, lowers it to Triton IR (an MLIR dialect), then to PTX, on first call. The body of the function executes on the GPU.
Inside the body, you are one program instance. There is no threadIdx.
The total number of program instances is set when you launch:
n = x.numel()
grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),) # 1D grid
add_kernel[grid](x, y, out, n, BLOCK=1024)
grid is either a tuple (1D/2D/3D) or a callable that receives the
compile-time meta-parameters and returns a tuple. triton.cdiv(a, b) is
ceil-division.
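The grid-size arithmetic is worth internalising; a plain-Python sketch, where cdiv is a stand-in for triton.cdiv:

```python
# Ceil-division, equivalent to what triton.cdiv(a, b) computes.
def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b

# For n = 1000 elements: one program suffices at BLOCK=1024,
# four programs at BLOCK=256 (the last one partially masked).
grid_1024 = (cdiv(1000, 1024),)
grid_256 = (cdiv(1000, 256),)
```

Note that the grid callable form lets the autotuner change BLOCK without you recomputing the grid by hand.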
A program instance is roughly equivalent to a CUDA block, not a thread.
Inside it, the compiler will allocate num_warps * 32 actual threads and
distribute the block-level work across them. You influence that by passing
num_warps= at launch (or via autotune).
2.2 Block-tensors as first-class values¶
tl.arange(0, N) and any expression built from it -- arithmetic,
broadcasting, tl.load, tl.dot, reductions -- produces a block tensor
whose shape is known at compile time. These tensors live in registers (with
the compiler spilling to shared memory or even global as a last resort).
offs_m = tl.arange(0, BM) # shape [BM]
offs_n = tl.arange(0, BN) # shape [BN]
offs_2d = offs_m[:, None] * BN + offs_n[None, :] # shape [BM, BN]
x[:, None] and x[None, :] are NumPy-style broadcasts; the compiler
generates the per-thread index arithmetic that makes a 2D logical tile out
of registers held by 32-wide warps.
2.3 Side-by-side: vector add¶
Same kernel in CUDA:
__global__ void add_kernel(const float* x, const float* y, float* out, int n) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < n) out[idx] = x[idx] + y[idx];
}
// host:
add_kernel<<<(n + 255) / 256, 256>>>(x, y, out, n);
The Triton version is structurally similar but works at the vector level:
one program does 1024 elements at once; the compiler decides how to assign
those 1024 lanes across 32 threads (probably 4 elements per thread, with
128-bit loads). You did not have to choose. The CUDA author who wants the
same vectorisation has to use float4 casts and do four-way unrolling by
hand.
Exercise (warm-up). Modify add_kernel so it computes
out = a*x + y (axpy). What changes? Answer: add a scalar argument a
and replace the store value with a*x + y. Nothing else.
3. Memory operations¶
3.1 tl.load and tl.store¶
Signatures (simplified, current as of 2026 Triton; verify if you target an older release):
tl.load(pointer, mask=None, other=0.0, cache=None, eviction_policy=None,
volatile=False)
tl.store(pointer, value, mask=None, cache=None, eviction_policy=None)
Key semantics:
- pointer is either a scalar pointer (rare) or a block tensor of pointers formed by base_ptr + offsets. The compiler infers the block shape from offsets.
- mask, when provided, is a same-shape boolean tensor. Lanes where mask=False skip the memory access entirely (no fault, no traffic).
- other is the value substituted for masked-off lanes on load. Default 0.0. Pick this carefully: if you mask in a tl.dot reduction, the masked lanes still participate arithmetically, so other=0.0 makes them contribute zero -- which is what you want.
3.2 Boundary masking: why it always matters¶
Take a vector add over n=1000 elements with BLOCK=128. You will launch
ceil(1000/128) = 8 programs covering 1024 elements. The last program
must mask off the trailing 24 lanes or it will read past the end of the
tensor and either segfault or corrupt data:
offsets = pid * BLOCK + tl.arange(0, BLOCK)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
In CUDA the equivalent is the if (idx < n) guard around a single thread.
The Triton version is more efficient because the entire warp issues one
masked load instruction; in CUDA each thread independently decides to load
or not.
For 2D blocks, you need a 2D mask:
offs_m = pid_m * BM + tl.arange(0, BM) # [BM]
offs_n = pid_n * BN + tl.arange(0, BN) # [BN]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
ptrs = X + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
x = tl.load(ptrs, mask=mask, other=0.0)
This is the canonical pattern. Get the strides right and the mask right and you have a correct boundary-aware tile load.
3.3 Broadcasting¶
tl.arange returns a 1D tensor; expressions broadcast like NumPy:
A = tl.zeros([BM, BN], dtype=tl.float32) # [BM, BN] of zero
v = tl.arange(0, BN) # [BN]
A2 = A + v[None, :] # broadcast across rows
The compiler chooses the warp/lane layout so the broadcast costs nothing (it is a register-relabelling, not a memory op). This is one of the places Triton silently saves you a lot of CUDA boilerplate.
3.4 Cache and eviction hints¶
cache="ca" (cache all), cache="cg" (cache global, skip L1),
eviction_policy="evict_last" | "evict_first" map to PTX cache-modifier
suffixes. Useful for reused tiles (the K-tile of A re-read by every column
block) vs. one-shot tiles. As a first pass, leave them at default.
Exercise. Why does tl.load(ptr, mask=mask, other=0.0) in a tl.dot
reduction not introduce numerical error from the masked lanes? Answer:
because 0.0 is the additive identity, and tl.dot is a sum-of-products,
masked-off lanes contribute exactly zero to the accumulator.
4. Math operations¶
4.1 Elementwise¶
All standard arithmetic (+ - * /), comparison (< > == !=), and
tl.exp, tl.log, tl.sqrt, tl.rsqrt, tl.sin, tl.cos, tl.tanh, tl.sigmoid,
tl.where(cond, a, b) operate elementwise on block tensors and follow the
broadcasting rules of section 3.
tl.where is your branchless conditional. There is no per-lane if/else
in Triton; you use where for value selection and mask for memory access
gating.
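tl.where follows the same elementwise-selection semantics as NumPy's where, so the branchless pattern can be prototyped on the CPU before porting (a NumPy sketch, not Triton code):

```python
import numpy as np

# Branchless leaky-ReLU: value selection via where, no per-lane branch.
# The tl.where version inside a kernel is identical modulo the namespace.
def leaky_relu(x: np.ndarray, slope: float = 0.01) -> np.ndarray:
    return np.where(x > 0, x, slope * x)

y = leaky_relu(np.array([-2.0, 0.0, 3.0]))  # negative lanes scaled by slope
```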
4.2 tl.dot and tensor cores¶
acc = tl.dot(a, b, acc=acc, allow_tf32=True)
# a: [M, K], b: [K, N], acc: [M, N]; result shape [M, N]
tl.dot is the only operation in Triton that lowers to the GPU's
tensor-core MMA instructions on Volta+. It requires:
- M, N, K dimensions of at least 16 (often 16/16/16 minimum; on Hopper larger MMAs are used). Sub-16 dims lower to FMA loops, which is fine but slow.
- dtypes that the hardware supports: fp16, bf16, tf32 (the fp32 path goes through tf32 on Ampere+ when allow_tf32=True, the default), fp8 on Hopper.
- acc in fp32 (recommended) for numerical safety. The accumulator stays in registers across iterations of the K loop.
Without tl.dot you cannot use tensor cores from Triton. Memorise this.
4.3 Reductions¶
m = tl.max(x, axis=1) # reduce along axis 1
s = tl.sum(x, axis=0)
mn = tl.min(x, axis=0)
am = tl.argmax(x, axis=1) # newer Triton versions
Reductions return a tensor with the reduced axis removed. They lower to warp-level shuffle reductions where possible and shared-memory reductions across warps. You do not write either by hand.
4.4 Transcendentals¶
tl.exp, tl.log, tl.sqrt, tl.rsqrt, tl.sin, tl.cos, tl.erf,
tl.tanh, tl.sigmoid exist. Two notes:
- On NVIDIA they typically lower to the fast PTX intrinsics (exp.approx.f32 etc.) for fp32, which are ~1 ulp accurate, fine for ML.
- tl.exp2 is often slightly faster than tl.exp because the hardware natively computes the base-2 exponential. You can rewrite softmax using exp2(x * log2_e) if you really need the cycles; usually not worth it.
Exercise. Implement a fused GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) *
(x + 0.044715 * x^3))). Sketch: one elementwise kernel, mask on
boundaries; everything is a tl.tanh/*/+. About 12 lines.
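Before writing the kernel, it helps to check the tanh approximation against the exact erf-based GELU; a plain-Python sanity check of the formula (not the Triton kernel itself):

```python
import math

# Tanh approximation of GELU, as in the exercise statement.
def gelu_tanh(x: float) -> float:
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi)
                                      * (x + 0.044715 * x ** 3)))

# Exact GELU via the Gaussian CDF.
def gelu_exact(x: float) -> float:
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

# The approximation tracks the exact form to ~1e-3 absolute error
# over the useful input range, which is fine for ML.
for x in [-3.0, -1.0, 0.0, 0.5, 2.0]:
    assert abs(gelu_tanh(x) - gelu_exact(x)) < 1e-3
```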
5. Compile-time constants: tl.constexpr¶
5.1 Why constexpr exists¶
Block sizes (BM, BN, BK, BLOCK) influence the shape of every block
tensor in the kernel. Shape determines:
- register allocation,
- shared-memory tile size,
- MMA instruction shape selection,
- loop unrolling.
The compiler must know these at codegen time. So they are passed as
tl.constexpr parameters (sketch):
@triton.jit
def matmul_kernel(A, B, C, M, N, K,
                  BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr):
    ...
When you launch with kernel[grid](..., BM=128, BN=128, BK=32), Triton
hashes the constexpr values and the dtypes/strides/etc. into a cache key.
First call with a new key triggers compilation; subsequent calls reuse the
cached PTX.
5.2 What else can be constexpr¶
Anything that should be specialised: a bool selecting causal vs. non-causal
attention, an int choosing a reduction axis, an enum-like dtype. Use
constexpr aggressively for branches that you want the compiler to delete
in the specialised version.
@triton.jit
def attn(..., CAUSAL: tl.constexpr):
if CAUSAL:
# this branch is compiled away when CAUSAL=False
mask = offs_m[:, None] >= offs_n[None, :]
scores = tl.where(mask, scores, float("-inf"))
That if CAUSAL is not a runtime branch -- it is Python-level dead-code
elimination at JIT time.
5.3 The cost: recompilation¶
Each unique (constexpr-value, dtype, contiguity) tuple is a separate kernel.
Cache lives in ~/.triton/cache. If your constexpr space is unbounded
(e.g. you set BLOCK=n per call) you will recompile constantly. Pick a
finite menu and stick to it.
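A common way to keep the constexpr space finite is to round the runtime size up to a power of two from a capped menu; sketched in plain Python (in real code, triton.next_power_of_2 does the rounding, and the cap is a design choice, not a Triton requirement):

```python
# Round n up to the next power of two, capped at a maximum block size,
# so the number of distinct compiled kernels stays bounded.
def pick_block(n: int, max_block: int = 4096) -> int:
    block = 1
    while block < n and block < max_block:
        block *= 2
    return block

# 1000 -> 1024; 100000 -> 4096 (capped: switch to a tiled loop instead).
assert pick_block(1000) == 1024
assert pick_block(100_000) == 4096
```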
6. Autotuning¶
6.1 The decorator¶
@triton.autotune(
configs=[
triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
triton.Config({'BM': 128, 'BN': 64, 'BK': 32}, num_warps=4, num_stages=4),
triton.Config({'BM': 64, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=4),
triton.Config({'BM': 64, 'BN': 64, 'BK': 32}, num_warps=2, num_stages=5),
triton.Config({'BM': 256, 'BN': 64, 'BK': 32}, num_warps=8, num_stages=3),
],
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(...): ...
key= lists the runtime arguments whose values define a tuning bucket.
On first call with a particular (M, N, K) triple, Triton runs every
config, measures wall time, picks the winner, caches it. Subsequent calls
with the same (M, N, K) skip the search.
6.2 Knobs¶
- Block sizes (BM, BN, BK for matmul; BLOCK_SIZE for elementwise) control register pressure and shared-memory tile size. Bigger usually means more reuse but more spills.
- num_warps controls how many 32-thread warps form the program instance. More warps means more parallelism per SM but each warp gets fewer registers.
- num_stages controls software pipelining depth: how many K-iterations worth of tiles are simultaneously in flight (cp.async on Ampere+, TMA on Hopper). 2-5 is typical. Too many overflows shared memory; too few starves the MMA units.
6.3 Practical autotune¶
- Limit configs to a dozen or so. The first call already pays the worst-case cost (you run every config), so 30 configs with 200ms each is six seconds of warmup.
- key=['M', 'N', 'K'] is right for matmul; for elementwise, key=['n'] or no key (single bucket) is fine.
- Use prune_configs_by={'early_config_prune': fn, 'perf_model': fn} if the search is huge. Usually unnecessary.
- Set TRITON_PRINT_AUTOTUNING=1 to see the table of configs and their timings. This is invaluable for sanity-checking.
6.4 Caching¶
Compiled kernels are cached on disk under ~/.triton/cache/<hash>/. The
hash is over: function source, argument dtypes, constexpr values, GPU
arch. Move to a different GPU and you recompile. This is fine.
7. Numerical stability patterns¶
Three patterns dominate ML kernels: online softmax, Welford for running mean/variance, log-sum-exp.
7.1 The naive softmax problem¶
softmax(x)_i = exp(x_i) / sum_j exp(x_j). If any x_i > 88.7 (fp32) or
>11.1 (fp16), exp(x_i) overflows. The standard fix subtracts the max:
softmax(x)_i = exp(x_i - m) / sum_j exp(x_j - m), where m = max_j x_j.
The result is unchanged (the exp(-m) factor cancels between numerator and denominator), and every exponent is now <= 0.
This requires two passes over x: one to find max(x), one to compute
the numerator/denominator. For attention, where x is the score row of
length seq_len, two passes mean two reads of a tensor that does not fit
in registers if seq_len is large. Bandwidth-bound.
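For reference, the two-pass stable softmax in plain Python; it makes a useful CPU oracle when testing kernels:

```python
import math

# Pass 1: find the max. Pass 2: exponentiate shifted values, normalise.
def softmax_two_pass(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Large inputs no longer overflow: exp(x - m) <= 1 always.
probs = softmax_two_pass([1000.0, 1001.0, 1002.0])
assert abs(sum(probs) - 1.0) < 1e-12
```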
7.2 Online softmax: the derivation¶
We want a one-pass algorithm. Suppose we have processed a prefix of length
k and we know:
- m_k = max(x_1..x_k),
- l_k = sum_{j=1..k} exp(x_j - m_k).
The softmax over the prefix is exp(x_j - m_k) / l_k.
Now we see x_{k+1}. The new max is m_{k+1} = max(m_k, x_{k+1}). We need
l_{k+1} = sum_{j=1..k+1} exp(x_j - m_{k+1}).
Split the sum:
l_{k+1} = sum_{j=1..k} exp(x_j - m_{k+1}) + exp(x_{k+1} - m_{k+1})
= sum_{j=1..k} exp(x_j - m_k) * exp(m_k - m_{k+1}) + exp(x_{k+1} - m_{k+1})
= l_k * exp(m_k - m_{k+1}) + exp(x_{k+1} - m_{k+1})
That is the recurrence. The factor exp(m_k - m_{k+1}) is <= 1 and
rescales the running denominator whenever the max grows. The recurrence
generalises trivially to blocks of new elements rather than single
elements: you compute the local max and local sum of the new block, then
combine with the running (m, l) exactly as above. This is the heart of
flash-attention.
In code (block version):
m_block = tl.max(x_block, axis=0) # local max
p_block = tl.exp(x_block - m_block)
l_block = tl.sum(p_block, axis=0) # local sum
m_new = tl.maximum(m, m_block)
alpha = tl.exp(m - m_new)
beta = tl.exp(m_block - m_new)
l_new = alpha * l + beta * l_block
m, l = m_new, l_new
# ...and rescale any running output by alpha
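The recurrence can be checked in plain Python by streaming over blocks and comparing the final (m, l) against the direct stable-softmax denominator (a CPU sketch of the math, not kernel code):

```python
import math

def online_ml(blocks):
    # Running max m and running rescaled denominator l, updated per block.
    m, l = float("-inf"), 0.0
    for block in blocks:
        m_block = max(block)
        l_block = sum(math.exp(x - m_block) for x in block)
        m_new = max(m, m_block)
        l = l * math.exp(m - m_new) + l_block * math.exp(m_block - m_new)
        m = m_new
    return m, l

xs = [3.0, -1.0, 7.0, 2.0, 5.0, 0.5]
m, l = online_ml([xs[0:2], xs[2:4], xs[4:6]])
direct = sum(math.exp(x - max(xs)) for x in xs)
assert m == max(xs) and abs(l - direct) < 1e-12
```

With blocks of size 1 this degenerates to the single-element recurrence derived above.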
7.3 Welford for variance¶
For RMSNorm forward you only need mean(x^2), which is a single pass. For
LayerNorm forward you need both mean(x) and var(x). The naive
"sum of squares minus square of sum" formula is catastrophic in fp32 for
large vectors. Welford is the one-pass numerically stable form. For block
combination:
n_combined = n_a + n_b
delta = mean_b - mean_a
mean_combined = mean_a + delta * n_b / n_combined
M2_combined = M2_a + M2_b + delta^2 * n_a * n_b / n_combined
var = M2_combined / n_combined
You will see this exact pattern in the LayerNorm tutorial kernels.
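The block-combination formulas can be verified in plain Python against a direct variance computation:

```python
def welford_stats(xs):
    # Streaming Welford over one chunk: returns (n, mean, M2).
    n, mean, m2 = 0, 0.0, 0.0
    for x in xs:
        n += 1
        d = x - mean
        mean += d / n
        m2 += d * (x - mean)
    return n, mean, m2

def welford_combine(a, b):
    # Combine two (n, mean, M2) summaries, exactly as in the formulas above.
    n_a, mean_a, m2_a = a
    n_b, mean_b, m2_b = b
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    m2 = m2_a + m2_b + delta * delta * n_a * n_b / n
    return n, mean, m2

xs = [1.0, 4.0, 2.0, 8.0, 5.0, 7.0]
n, mean, m2 = welford_combine(welford_stats(xs[:3]), welford_stats(xs[3:]))
mu = sum(xs) / len(xs)
direct_var = sum((x - mu) ** 2 for x in xs) / len(xs)
assert abs(m2 / n - direct_var) < 1e-12
```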
7.4 Log-sum-exp¶
logsumexp(x) = log(sum exp(x)) is m + log(sum exp(x - m)) for m =
max(x). Same trick as softmax. Useful inside cross-entropy loss kernels
where you want -x_target + logsumexp(x) directly.
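The same max-shift trick, in plain Python:

```python
import math

# Stable log-sum-exp: shift by the max, exponentiate, add the max back.
def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

# Agrees with the naive formula where the naive formula does not overflow...
naive = math.log(sum(math.exp(x) for x in [1.0, 2.0, 3.0]))
assert abs(logsumexp([1.0, 2.0, 3.0]) - naive) < 1e-12
# ...and survives inputs where exp(x) alone would overflow fp64.
assert abs(logsumexp([1000.0, 1000.0]) - (1000.0 + math.log(2.0))) < 1e-9
```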
8. Six annotated kernels¶
We progress from the trivial to flash-attention. Every kernel below is a runnable example. Imports assumed:
import torch
import triton
import triton.language as tl
8.1 Vector add¶
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n,
BLOCK: tl.constexpr):
pid = tl.program_id(0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
mask = offs < n
x = tl.load(x_ptr + offs, mask=mask)
y = tl.load(y_ptr + offs, mask=mask)
tl.store(out_ptr + offs, x + y, mask=mask)
def add(x, y):
out = torch.empty_like(x)
n = x.numel()
grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)
add_kernel[grid](x, y, out, n, BLOCK=1024)
return out
This kernel is bandwidth-bound; it should approach peak HBM bandwidth on any reasonable GPU. If you measure 80% of peak you are doing fine; the remaining 20% is launch overhead and tail effects.
8.2 Naive matmul¶
A first matmul: each program computes one [BM, BN] tile of C via an inner loop over K blocks, with no L2 swizzle and no autotuning. Useful as a teaching baseline; do not ship it.
@triton.jit
def matmul_naive(A, B, C, M, N, K,
sa_m, sa_k, sb_k, sb_n, sc_m, sc_n,
BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BM + tl.arange(0, BM)
offs_n = pid_n * BN + tl.arange(0, BN)
offs_k = tl.arange(0, BK)
a_ptrs = A + offs_m[:, None] * sa_m + offs_k[None, :] * sa_k
b_ptrs = B + offs_k[:, None] * sb_k + offs_n[None, :] * sb_n
acc = tl.zeros([BM, BN], dtype=tl.float32)
for k in range(0, K, BK):
a_mask = (offs_m[:, None] < M) & ((k + offs_k)[None, :] < K)
b_mask = ((k + offs_k)[:, None] < K) & (offs_n[None, :] < N)
a = tl.load(a_ptrs, mask=a_mask, other=0.0)
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
acc += tl.dot(a, b)
a_ptrs += BK * sa_k
b_ptrs += BK * sb_k
c_ptrs = C + offs_m[:, None] * sc_m + offs_n[None, :] * sc_n
c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
tl.store(c_ptrs, acc.to(C.dtype.element_ty), mask=c_mask)
Already this is much simpler than the equivalent CUDA, which would need
explicit shared-memory tiles, __syncthreads(), and per-thread MMA inputs.
8.3 Tiled matmul with autotune¶
The canonical Triton matmul. The two changes from 8.2 are:
- L2-cache-friendly program-id swizzle (group along M to reuse B tiles).
- @triton.autotune over a config grid.
def get_configs():
return [
triton.Config({'BM':128,'BN':256,'BK':32,'GROUP_M':8}, num_warps=8, num_stages=3),
triton.Config({'BM':128,'BN':128,'BK':32,'GROUP_M':8}, num_warps=4, num_stages=4),
triton.Config({'BM':128,'BN':64, 'BK':32,'GROUP_M':8}, num_warps=4, num_stages=4),
triton.Config({'BM':64, 'BN':128,'BK':32,'GROUP_M':8}, num_warps=4, num_stages=4),
triton.Config({'BM':64, 'BN':64, 'BK':32,'GROUP_M':8}, num_warps=2, num_stages=5),
triton.Config({'BM':128,'BN':128,'BK':64,'GROUP_M':8}, num_warps=8, num_stages=3),
]
@triton.autotune(configs=get_configs(), key=['M', 'N', 'K'])
@triton.jit
def matmul_tiled(A, B, C, M, N, K,
sa_m, sa_k, sb_k, sb_n, sc_m, sc_n,
BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
GROUP_M: tl.constexpr):
pid = tl.program_id(0)
num_pid_m = tl.cdiv(M, BM)
num_pid_n = tl.cdiv(N, BN)
# L2-friendly swizzle: walk in (GROUP_M, num_pid_n) super-blocks
num_pid_in_group = GROUP_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = (pid_m * BM + tl.arange(0, BM)) % M
offs_bn = (pid_n * BN + tl.arange(0, BN)) % N
offs_k = tl.arange(0, BK)
a_ptrs = A + offs_am[:, None] * sa_m + offs_k[None, :] * sa_k
b_ptrs = B + offs_k[:, None] * sb_k + offs_bn[None, :] * sb_n
acc = tl.zeros([BM, BN], dtype=tl.float32)
for k in range(0, tl.cdiv(K, BK)):
k_remaining = K - k * BK
a_mask = offs_k[None, :] < k_remaining
b_mask = offs_k[:, None] < k_remaining
a = tl.load(a_ptrs, mask=a_mask, other=0.0)
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
acc = tl.dot(a, b, acc=acc)
a_ptrs += BK * sa_k
b_ptrs += BK * sb_k
offs_cm = pid_m * BM + tl.arange(0, BM)
offs_cn = pid_n * BN + tl.arange(0, BN)
c_ptrs = C + offs_cm[:, None] * sc_m + offs_cn[None, :] * sc_n
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, acc.to(C.dtype.element_ty), mask=c_mask)
def matmul(a, b):
assert a.shape[1] == b.shape[0]
M, K = a.shape
_, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
grid = lambda meta: (triton.cdiv(M, meta['BM']) * triton.cdiv(N, meta['BN']),)
matmul_tiled[grid](
a, b, c, M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
)
return c
Why the swizzle? Without it, neighbouring program ids cover neighbouring N columns of C with the same row of A. With it, neighbours share both A and B reuse, which keeps the L2 hit rate high on big problems. This trick alone is often worth 20-40% on large square matmul.
On A100/H100 this kernel reaches roughly 80% of cuBLAS for square matmul with M,N,K multiples of 128, and it frequently beats cuBLAS on irregular shapes (e.g. K=80, common in attention with head_dim=80). Do not take exact numbers on faith; benchmark on your hardware.
8.4 Fused softmax (rowwise, single pass-per-row)¶
For rows that fit entirely in one block (N <= BLOCK_N):
@triton.jit
def softmax_kernel(X, Y, sx, sy, N,
BLOCK: tl.constexpr):
row = tl.program_id(0)
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(X + row * sx + cols, mask=mask, other=-float("inf"))
x = x - tl.max(x, axis=0)
num = tl.exp(x)
den = tl.sum(num, axis=0)
y = num / den
tl.store(Y + row * sy + cols, y, mask=mask)
def softmax(x):
M, N = x.shape
BLOCK = triton.next_power_of_2(N)
y = torch.empty_like(x)
softmax_kernel[(M,)](x, y, x.stride(0), y.stride(0), N, BLOCK=BLOCK,
num_warps=4 if BLOCK <= 1024 else 8)
return y
Notes:
- other=-inf for masked-off lanes ensures they cannot become the max and do not contribute to the sum (exp(-inf) = 0).
- This whole row lives in registers, so we have a true one-pass softmax; no online-softmax recurrence needed.
- For very wide rows (N > 64K), you would use a tiled version with the online-softmax recurrence from section 7.
8.5 RMSNorm forward + backward¶
RMSNorm: y_i = x_i / sqrt(mean(x^2) + eps) * w_i.
Forward.
@triton.jit
def rmsnorm_fwd(X, W, Y, RSTD, sx, sy, N, eps,
BLOCK: tl.constexpr):
row = tl.program_id(0)
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(X + row * sx + cols, mask=mask, other=0.0).to(tl.float32)
var = tl.sum(x * x, axis=0) / N
rstd = 1.0 / tl.sqrt(var + eps)
tl.store(RSTD + row, rstd) # save for backward
w = tl.load(W + cols, mask=mask, other=0.0)
y = x * rstd * w
tl.store(Y + row * sy + cols, y.to(Y.dtype.element_ty), mask=mask)
Backward. Given dy, compute dx, dw. The math:
y_i = x_i * w_i * r, where r = (mean(x^2) + eps)^{-1/2}
dr/dx_j = -(1/N) * x_j * r^3 (since d/dx_j var = 2 x_j / N)
dy_i/dx_j = w_i * r * delta_{ij} + x_i * w_i * dr/dx_j
= w_i * r * delta_{ij} - (1/N) * x_i * w_i * x_j * r^3
dx_j = sum_i dy_i * dy_i/dx_j
= w_j * r * dy_j - (r^3 / N) * x_j * sum_i (x_i * w_i * dy_i)
= r * (w_j * dy_j - x_j * (r^2 / N) * S), where S = sum_i x_i * w_i * dy_i
So we need a single reduction S per row, then a fused elementwise pass:
@triton.jit
def rmsnorm_bwd_dx(X, W, DY, RSTD, DX, sx, sdy, sdx, N,
BLOCK: tl.constexpr):
row = tl.program_id(0)
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(X + row * sx + cols, mask=mask, other=0.0).to(tl.float32)
w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
dy = tl.load(DY + row * sdy + cols, mask=mask, other=0.0).to(tl.float32)
r = tl.load(RSTD + row).to(tl.float32)
S = tl.sum(x * w * dy, axis=0)
dx = r * (w * dy - x * (r * r / N) * S)
tl.store(DX + row * sdx + cols, dx, mask=mask)
dW is reduced across rows; do it in a second small kernel (or fold it
into the same kernel with atomics if you have many rows; pure reduce-then-
store is usually faster).
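The dx formula from the derivation can be spot-checked against central finite differences in plain Python (a numerical sanity check, not kernel code; the test values are arbitrary):

```python
import math

def rmsnorm(x, w, eps=1e-6):
    r = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [xi * r * wi for xi, wi in zip(x, w)]

def rmsnorm_dx(x, w, dy, eps=1e-6):
    # dx_j = r * (w_j*dy_j - x_j*(r^2/N)*S), S = sum_i x_i*w_i*dy_i
    n = len(x)
    r = 1.0 / math.sqrt(sum(v * v for v in x) / n + eps)
    s = sum(xi * wi * dyi for xi, wi, dyi in zip(x, w, dy))
    return [r * (wj * dyj - xj * (r * r / n) * s)
            for xj, wj, dyj in zip(x, w, dy)]

x, w, dy, h = [1.0, -2.0, 3.0], [0.5, 1.5, 1.0], [0.1, -0.3, 0.2], 1e-6
dx = rmsnorm_dx(x, w, dy)
for j in range(3):
    xp = x[:]; xp[j] += h
    xm = x[:]; xm[j] -= h
    numeric = sum(d * (yp - ym) for d, yp, ym in
                  zip(dy, rmsnorm(xp, w), rmsnorm(xm, w))) / (2 * h)
    assert abs(dx[j] - numeric) < 1e-5
```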
8.6 Causal flash-attention forward (simplified)¶
We compute O = softmax(Q K^T / sqrt(d)) V, with a causal mask, in tiles.
The trick is to never materialise the seq_len x seq_len score matrix --
we keep an online (m, l) per output row block and stream over key/value
blocks. Backward is more complex; we omit it.
Shapes: Q, K, V are [B, H, S, D]; we run one program per (batch, head,
M-block of Q rows). For brevity we drop the batch/head dims; assume Q, K,
V are 2D [S, D] and that the launcher iterates over BH.
@triton.jit
def flash_attn_fwd(Q, K, V, O, L, M, # M, L for backward
sq_s, sq_d, sk_s, sk_d, sv_s, sv_d, so_s, so_d,
S, D: tl.constexpr,
BM: tl.constexpr, BN: tl.constexpr,
CAUSAL: tl.constexpr):
pid_m = tl.program_id(0)
offs_m = pid_m * BM + tl.arange(0, BM)
offs_d = tl.arange(0, D)
offs_n = tl.arange(0, BN)
# Load Q-tile once and keep it in registers
q_ptrs = Q + offs_m[:, None] * sq_s + offs_d[None, :] * sq_d
q_mask = offs_m[:, None] < S
q = tl.load(q_ptrs, mask=q_mask, other=0.0)
# Online softmax state
m_i = tl.full([BM], -float("inf"), dtype=tl.float32)
l_i = tl.zeros([BM], dtype=tl.float32)
acc = tl.zeros([BM, D], dtype=tl.float32)
scale = 1.0 / tl.sqrt(tl.full([], D, dtype=tl.float32))
# Causal: only iterate up to the last K-block that can contribute.
n_end = (pid_m + 1) * BM if CAUSAL else S
for start_n in range(0, n_end, BN):
cur_n = start_n + offs_n
k_ptrs = K + cur_n[None, :] * sk_s + offs_d[:, None] * sk_d
v_ptrs = V + cur_n[:, None] * sv_s + offs_d[None, :] * sv_d
k_mask = cur_n[None, :] < S
v_mask = cur_n[:, None] < S
k = tl.load(k_ptrs, mask=k_mask, other=0.0)
v = tl.load(v_ptrs, mask=v_mask, other=0.0)
# Scores [BM, BN]
s = tl.dot(q, k) * scale
if CAUSAL:
causal_mask = offs_m[:, None] >= cur_n[None, :]
s = tl.where(causal_mask, s, float("-inf"))
# also mask off out-of-range cur_n
s = tl.where(cur_n[None, :] < S, s, float("-inf"))
# Online softmax update
m_new = tl.maximum(m_i, tl.max(s, axis=1))
alpha = tl.exp(m_i - m_new)
p = tl.exp(s - m_new[:, None])
l_i = alpha * l_i + tl.sum(p, axis=1)
acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
m_i = m_new
o = acc / l_i[:, None]
o_ptrs = O + offs_m[:, None] * so_s + offs_d[None, :] * so_d
tl.store(o_ptrs, o.to(O.dtype.element_ty), mask=q_mask)
# Save log-sum-exp = m_i + log(l_i) for backward
tl.store(L + offs_m, m_i + tl.log(l_i), mask=offs_m < S)
This is not the fastest flash-attention -- the production versions split
the K loop differently, fuse the dropout, use TMA on Hopper, and have a
dedicated backward kernel. But the structure -- Q-tile in registers,
streaming K/V tiles, online (m, l), accumulator rescaled by alpha --
is the same and is the form you should be able to derive from scratch.
Two subtle points the code makes explicit:
- The accumulator update is acc = acc * alpha + dot(p, v). The alpha scaling corrects every previous K-block's contribution to the new denominator. Without it the answer is wrong.
- The causal mask is applied before the softmax exp, by setting out-of-range positions to -inf. After exp, those become 0, which is the additive identity for the running sum and for the dot(p, v) accumulation.
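The streaming (m, l, acc) update can be validated on the CPU against the direct computation; a NumPy sketch of the non-causal case, with small arbitrary shapes chosen for the check:

```python
import numpy as np

def attn_direct(q, k, v):
    # Reference: materialise the full score matrix, stable softmax, then @V.
    s = (q @ k.T) / np.sqrt(q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

def attn_streaming(q, k, v, bn):
    # One (m, l, acc) state per query row; stream over key/value blocks.
    sm = q.shape[0]
    m = np.full(sm, -np.inf)
    l = np.zeros(sm)
    acc = np.zeros((sm, v.shape[1]))
    scale = 1.0 / np.sqrt(q.shape[1])
    for start in range(0, k.shape[0], bn):
        s = (q @ k[start:start + bn].T) * scale       # [sm, bn] scores
        m_new = np.maximum(m, s.max(axis=1))
        alpha = np.exp(m - m_new)                     # rescale old state
        p = np.exp(s - m_new[:, None])
        l = alpha * l + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ v[start:start + bn]
        m = m_new
    return acc / l[:, None]

rng = np.random.default_rng(0)
q = rng.normal(size=(4, 8))
k = rng.normal(size=(6, 8))
v = rng.normal(size=(6, 8))
assert np.allclose(attn_streaming(q, k, v, bn=2), attn_direct(q, k, v))
```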
9. The compilation pipeline¶
9.1 The path¶
Python @triton.jit function
| AST capture
v
Triton IR (an MLIR dialect)
| high-level optimisation:
| layout selection, broadcasting elimination,
| reduction lowering, masked-load elision
v
TritonGPU IR (an MLIR dialect, HW-aware)
| GPU-specific:
| shared-memory allocation, pipelining (cp.async / TMA),
| swizzling, MMA selection
v
LLVM IR
|
v
PTX (NVIDIA) or AMDGCN (AMD) or other backend
|
v
SASS (assembled by ptxas at load time)
You almost never need to look at the intermediate IRs. You frequently want to look at the PTX.
9.2 Inspecting compiled artefacts¶
kernel = matmul_tiled[grid](...) # first call, compiles
# Get the compiled handle:
fn = matmul_tiled.warmup(*args, grid=grid) # one approach
# Or, after a real call:
print(matmul_tiled.cache.values())
# For a specific compiled binary:
binary = next(iter(matmul_tiled.cache.values()))
print(binary.asm['ttir']) # Triton IR
print(binary.asm['ttgir']) # TritonGPU IR
print(binary.asm['llir']) # LLVM IR
print(binary.asm['ptx']) # PTX
print(binary.asm['cubin']) # binary cubin (bytes)
(API has shifted across releases; the exact attribute names may differ in
your version. Verify with dir(binary.asm).)
9.3 What to look for in PTX¶
- ld.global.v4.f32 / ld.global.v8.f16 -- vectorised loads. If your kernel emits scalar ld.global.f32, your indexing is not coalesced.
- mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 -- tensor-core MMA. No such instruction means you are not using tensor cores.
- cp.async.cg.shared.global (Ampere) or cp.async.bulk.tensor (Hopper) -- asynchronous global-to-shared copies, used by software pipelining.
- bar.sync -- barriers between pipeline stages. Too many usually means num_stages is too low and you are stalling.
9.4 Useful environment variables¶
- TRITON_PRINT_AUTOTUNING=1 -- prints the timing table for autotune.
- TRITON_CACHE_DIR=/path -- override the default cache location.
- TRITON_DEBUG=1 -- prints extra diagnostics.
- TRITON_INTERPRET=1 -- runs the kernel on the CPU in pure Python emulation. Excruciatingly slow but the only way to single-step a Triton kernel under pdb. Worth the cost when you have a correctness bug you cannot localise.
10. Integration with PyTorch¶
10.1 Calling a Triton kernel from PyTorch¶
The simplest path: a Python wrapper around the kernel, callable from
ordinary user code. The wrappers in section 8 (add, matmul, softmax)
are exactly this. PyTorch tensors are passed directly; their .data_ptr()
becomes the kernel's pointer argument, dtype and strides are read from the
tensor metadata.
10.2 Autograd¶
Subclass torch.autograd.Function:
class RMSNormFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x, w, eps):
M, N = x.shape
y = torch.empty_like(x)
rstd = torch.empty(M, device=x.device, dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
rmsnorm_fwd[(M,)](x, w, y, rstd, x.stride(0), y.stride(0), N, eps,
BLOCK=BLOCK, num_warps=4)
ctx.save_for_backward(x, w, rstd)
ctx.N, ctx.BLOCK = N, BLOCK
return y
@staticmethod
def backward(ctx, dy):
x, w, rstd = ctx.saved_tensors
dx = torch.empty_like(x)
M = x.shape[0]
rmsnorm_bwd_dx[(M,)](x, w, dy, rstd, dx,
x.stride(0), dy.stride(0), dx.stride(0),
ctx.N, BLOCK=ctx.BLOCK, num_warps=4)
# dW: reduce x*rstd*dy over the row dim
dw = (x * rstd[:, None] * dy).sum(dim=0) # let PyTorch handle it
return dx, dw, None
def rmsnorm(x, w, eps=1e-6):
return RMSNormFn.apply(x, w, eps)
The dW reduction is small enough that letting PyTorch handle it is fine;
fusing it into a Triton kernel is an optimisation, not a correctness step.
10.3 torch.library ops¶
For interaction with torch.compile, register your Triton kernel as a
custom op so the compiler treats it as opaque and does not try to trace
through @triton.jit:
import torch
from torch.library import custom_op, register_fake
@custom_op("mylib::rmsnorm", mutates_args=())
def rmsnorm_op(x: torch.Tensor, w: torch.Tensor, eps: float) -> torch.Tensor:
return RMSNormFn.apply(x, w, eps)
@register_fake("mylib::rmsnorm")
def _(x, w, eps):
return torch.empty_like(x) # shape/stride meta, no compute
The register_fake (formerly impl_abstract) registration provides shape inference for
torch.compile's fake-tensor tracing. Without it, torch.compile cannot
propagate shapes through your op and will graph-break or raise an error
when it encounters it.
10.4 torch.compile¶
torch.compile already emits Triton kernels for many fusions (it has a
backend called Inductor). Your handwritten Triton kernels coexist: Inductor
will compile around them and treat them as black boxes (provided you
registered them via torch.library). If you write a really good kernel
that beats Inductor on a hot path -- common for attention variants and
custom norms -- this is the way to ship it.
Two cautions:
- torch.compile expects shape-stable callees. If your custom op recompiles on every shape it slows the graph; pre-warm with the shapes you care about.
- mutates_args=() is critical for op safety. If you do mutate arguments (e.g. an in-place kernel), declare them or you will silently corrupt the graph cache.
11. Common pitfalls¶
Shape-dependent recompilation. Every unique constexpr value triggers a
new compile. Pin block sizes to a finite menu. Autotune's key buckets your
shapes; that is fine. What is not fine is deriving a constexpr directly from
a continuously varying N. BLOCK = triton.next_power_of_2(N) at least buckets
N into powers of two, so the compile count is bounded by log2 of the largest
N, but small rows still spawn many tiny variants; clamp it
(BLOCK = max(64, triton.next_power_of_2(N))) or pick from a fixed menu. If N
is genuinely unbounded, tile the row with a fixed BLOCK and a loop instead.
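A pure-Python sketch of the bucketing idea. The `pick_block` helper is hypothetical (not part of Triton), and the inline `next_power_of_2` mirrors what `triton.next_power_of_2` computes:

```python
def next_power_of_2(n: int) -> int:
    # smallest power of two >= n (same result as triton.next_power_of_2)
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def pick_block(n: int, menu=(64, 128, 256, 512, 1024, 2048, 4096)) -> int:
    # clamp the power-of-two bucket to a fixed menu, so the number of
    # distinct constexpr values (and therefore compiles) stays bounded
    p = next_power_of_2(n)
    for b in menu:
        if b >= p:
            return b
    return menu[-1]
```

Every row length between 129 and 256 now maps to the same BLOCK, so they all share one compiled kernel.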
Mask correctness. Two failure modes:
- Forgetting to mask. Out-of-bounds reads at the boundary often appear to "work" on Ampere because they happen to land in mapped memory; on Hopper with TMA they hard-fault. Always mask.
- An other= mismatch. For a softmax row, other=0.0 makes the masked lanes the max whenever all the real values are negative -- wrong. Use other=-inf for max-reductions, other=0.0 for sum-reductions and dot accumulators.
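The failure is easy to reproduce in plain Python. Here a 3-element row is padded to a block of 4, once with 0.0 and once with -inf (the values and block width are illustrative):

```python
import math

row = [-3.0, -1.0, -2.0]    # all real values negative
BLOCK = 4                   # row is padded out to the block width

pad_zero = row + [0.0] * (BLOCK - len(row))        # other=0.0
pad_ninf = row + [-math.inf] * (BLOCK - len(row))  # other=-inf

max_zero = max(pad_zero)    # the 0.0 padding wins: wrong max
max_ninf = max(pad_ninf)    # -inf padding is inert: correct max (-1.0)

# for a sum-reduction the roles flip: 0.0 is the inert pad value
sum_zero = sum(pad_zero)    # correct sum (-6.0)
```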
num_stages too high. Each pipeline stage holds a full
[BM, BK] and [BK, BN] tile in shared memory. Exceeding the shared-memory
budget causes a silent fallback to fewer stages or a hard compile error,
depending on the version. If autotune timing for a config is much worse than
for a config with a smaller BM*BK + BK*BN, suspect the shared-memory budget.
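The arithmetic is worth doing by hand when a config looks suspicious. A small helper (hypothetical, not part of Triton) that computes the shared-memory footprint of a pipelined matmul config, assuming fp16 tiles:

```python
def pipeline_smem_bytes(BM, BN, BK, num_stages, elt_bytes=2):
    # each stage buffers one [BM, BK] A-tile and one [BK, BN] B-tile
    return num_stages * (BM * BK + BK * BN) * elt_bytes

# a 128x128x32 fp16 config at 4 stages: 64 KiB of shared memory
footprint = pipeline_smem_bytes(128, 128, 32, 4)   # 65536
```

Compare against your GPU's shared-memory-per-block limit (on the order of 100-228 KiB depending on the part); the same tiles at 6 stages need 96 KiB, which already crowds out occupancy on many GPUs.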
Register-pressure spills. Same root cause: tiles too big, this time for the
register file. Symptoms: ptxas reports spills (ptxas info: ... 2048 bytes stack frame,
... bytes spill stores) and the kernel runs at a fraction of the expected speed.
Reduce BM, BN, or split the K loop.
Forgetting .to(dtype) on store. acc is fp32, your output buffer is
fp16. tl.store(c_ptr, acc, ...) will implicitly cast in modern Triton,
but the cast may not use the rounding mode you expect (RTNE vs. RTZ). Be
explicit: acc.to(c_ptr.dtype.element_ty).
Confusing tl.program_id axes with grid axes. tl.program_id(0) is
the first grid dim, not the X coordinate of anything physical. With a
2D grid (grid_m, grid_n), tl.program_id(0) ranges 0..grid_m-1 and
tl.program_id(1) 0..grid_n-1. Be consistent.
Autotune cache-key bugs. If key= omits a shape-relevant arg (e.g. you
forgot K), distinct shapes collide on the same cache entry and the autotuner
silently reuses a config tuned for a different problem -- no re-benchmark,
just a bad config. Conversely, a key containing a value that changes on
every call re-runs the benchmark sweep every call. Include every
shape-relevant runtime arg in key=, bucketed if it varies continuously.
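The collision is easy to model with a dictionary, which is essentially what the autotuner's cache is (the names here are illustrative, not Triton's internals):

```python
cache = {}

def get_config(M, N, K, key_fields):
    # the autotuner hashes only the args listed in key=
    key = tuple(key_fields)
    if key not in cache:
        cache[key] = f"config tuned for K={K}"  # stand-in for a benchmark sweep
    return cache[key]

# key=['M', 'N'] -- K omitted:
a = get_config(1024, 1024, 64,   (1024, 1024))
b = get_config(1024, 1024, 8192, (1024, 1024))
# a == b: the K=8192 call silently reuses the config tuned for K=64
```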
Stride bugs. Always pass strides explicitly; never assume contiguity.
A view of a transposed tensor has non-trivial strides; taking .T and
treating the result as contiguous is a frequent source of garbage outputs.
Thinking tl.dot is the whole matrix multiplication. It multiplies one
[BM, BK] tile by one [BK, BN] tile -- block matrix multiplication. The
K-loop is yours to write.
12. Triton vs CUTLASS vs hand-rolled CUDA¶
Rough rules of thumb. None are universal; benchmark your case.
| Use case | Pick | Why |
|---|---|---|
| Standard square matmul, fp16/bf16, big shapes | cuBLAS | Vendor-tuned per arch; near-peak. |
| GEMM with epilogue fusion (bias + GeLU + dropout) | CUTLASS or Triton | cuBLAS does no fusion. |
| Custom fused norm / softmax / attention | Triton | Productivity is decisive; perf is good. |
| Bleeding-edge attention (FA-3 style) | Hand CUDA / CUTLASS | TMA, warp specialisation, async fence patterns -- still ahead of Triton in 2025/2026. |
| Quirky shapes (small K, irregular masking) | Triton | Beats vendor libs frequently. |
| Sparse / structured-sparsity kernels | CUTLASS | Has sparse MMA primitives. |
| Research prototype, weeks-not-months | Triton | Always. |
CUTLASS is a C++ template library from NVIDIA: more verbose than Triton, more flexible than cuBLAS, exposes the actual MMA shapes and async pipeline primitives. Hand-rolled CUDA gets you the last 5-15% on niche ops; the engineering cost is large.
A reasonable rule: write everything in Triton first, profile, replace the top-1 hot kernel with CUTLASS or hand CUDA only if numbers force you to.
13. Exercises¶
Six problems, ordered by difficulty. Try each before reading the sketch.
13.1 Fused bias + ReLU + scale¶
Write a kernel y = relu(x + b) * s where x is [M, N], b is [N]
(broadcast across rows), s is a scalar. One launch, one pass.
Sketch. 1D grid of M * ceil(N/BN) programs, or a 2D grid of (M,
ceil(N/BN)). Load an x tile [1, BN], load b [BN] once per program,
broadcast-add, tl.where(z > 0, z, 0.0) * s, store. ~25 lines. The kernel
should be HBM-bandwidth-bound.
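Before writing the kernel, pin down a reference to test against. A pure-Python version of the same computation (lists of lists standing in for tensors; the helper name is ours):

```python
def bias_relu_scale_ref(x, b, s):
    # y[i][j] = relu(x[i][j] + b[j]) * s, with b broadcast across rows
    return [[max(v + bj, 0.0) * s for v, bj in zip(row, b)] for row in x]

y = bias_relu_scale_ref([[1.0, -2.0], [0.0, 3.0]], [1.0, -1.0], 2.0)
# y == [[4.0, 0.0], [2.0, 4.0]]
```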
13.2 Rowwise top-k mask (k=1)¶
Given x [M, N], output y of same shape with y[i, j] = x[i, j] if
x[i, j] is the rowwise max, else 0. Single pass per row.
Sketch. Like softmax: load row into registers, m = tl.max(x, axis=0),
y = tl.where(x == m, x, 0.0), store. Care with ties (you will mark all
of them; usually fine). Width must fit in BLOCK.
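A pure-Python reference that makes the tie behaviour explicit -- ties all survive, everything else becomes zero (the helper name is ours):

```python
def rowwise_top1_mask(x):
    # keep every element equal to its row max (ties included), zero the rest
    return [[v if v == max(row) else 0.0 for v in row] for row in x]

y = rowwise_top1_mask([[3.0, 1.0, 3.0], [0.0, -1.0, -2.0]])
# y == [[3.0, 0.0, 3.0], [0.0, 0.0, 0.0]]
```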
13.3 LayerNorm forward with Welford¶
Write a LayerNorm forward that computes mean and variance with one pass
using Welford's online algorithm. Compare numerical accuracy to the naive
formula on a [1, 65536] input of magnitude 1e6.
Sketch. Tile the row into chunks of BLOCK, accumulate (n, mean, M2)
across chunks with the combine formulas in section 7.3. After the loop,
var = M2 / N, rstd = rsqrt(var + eps), second pass to compute and
store y. Or, since the row fits in registers when N <= 64K, use the
naive single-block reduction path -- in fp32 that is also stable for
any realistic magnitude. Welford pays off when you must tile.
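The chunk-combine step is the only delicate part. In plain Python, mirroring the section-7.3 combine formulas:

```python
def welford_combine(na, ma, M2a, nb, mb, M2b):
    # merge two (count, mean, sum-of-squared-deviations) triples
    n = na + nb
    delta = mb - ma
    mean = ma + delta * nb / n
    M2 = M2a + M2b + delta * delta * na * nb / n
    return n, mean, M2

# two chunks of [1..6]: (3, 2.0, 2.0) and (3, 5.0, 2.0)
n, mean, M2 = welford_combine(3, 2.0, 2.0, 3, 5.0, 2.0)
# n == 6, mean == 3.5, M2 == 17.5; var = M2 / n
```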
13.4 Causal-attention forward, drop the causal mask¶
Modify the flash-attention kernel in 8.6 so it accepts a CAUSAL: tl.constexpr
parameter and, when False, iterates over the full K range. Verify: when
CAUSAL=False, output equals softmax(QK^T/sqrt(d)) V to within fp16
precision.
Sketch. The kernel as written already has CAUSAL. The change is the
n_end = (pid_m + 1) * BM if CAUSAL else S line and the tl.where causal
mask gated on if CAUSAL:. Confirm that the if is on CAUSAL (a
constexpr) so it is compiled away when False.
13.5 Matmul autotune key bug¶
A user complains: "every call recompiles, even at the same shape." Their
decorator is @triton.autotune(configs=[...], key=['M', 'N']) for an
M x K by K x N matmul. What is wrong?
Sketch. K is missing from the key, so (M, N, K1) and (M, N, K2)
collide on the same cache entry and the autotuner silently reuses whichever
config was tuned first; meanwhile the JIT still specialises on properties of
K (divisibility by 16, equality to 1), so a varying K generates fresh
compiles that look like "recompiling every call". Add K to key= so
tuning tracks the real shape.
13.6 Spotting tensor-core usage in PTX¶
Take the matmul kernel from 8.3, compile it for fp16 inputs, dump the PTX,
and find the MMA instructions. What is the M-N-K shape of each mma.sync?
What does it imply about the tile sizes Triton chose internally?
Sketch. Look for mma.sync.aligned.m16n8k16 (Ampere fp16) or
wgmma.mma_async (Hopper). On Ampere fp16, the per-warp MMA shape is
16x8x16, so a [128, 128] accumulator across 4 warps is built from
many of these MMAs in a tile schedule chosen by the compiler. The
takeaway: the fundamental hardware tile is small; Triton's BM, BN, BK
are block-level tiles assembled out of many MMAs.
14. Closing notes¶
You now have:
- the conceptual model (block-level program, per-instance abstraction),
- the API surface (tl.load/store/dot/sum/max, masks, constexpr, autotune),
- the numerical-stability primitives (online softmax, Welford, log-sum-exp),
- six worked kernels covering the 90% case (elementwise, matmul, softmax, norm, attention),
- the integration story for PyTorch (autograd.Function, torch.library, torch.compile),
- a debugging path (PTX inspection, TRITON_INTERPRET, autotune logs),
- a comparative map against CUTLASS and hand CUDA.
Two final discipline rules. Benchmark with triton.testing.do_bench,
which handles CUDA synchronisation and warmup for you; naively timing the
first call with a wall clock folds the autotune search and compilation into
the measurement and gives you misleading numbers.
Always check correctness against a PyTorch reference first, with
torch.testing.assert_close(out_triton, out_torch, rtol=1e-2, atol=1e-2)
for fp16, before you start chasing performance. A faster wrong kernel is
not a kernel.
The single strongest piece of advice anyone has given about Triton: write the dumb version first, autotune it, look at the PTX once to make sure tensor cores fired, then move on. The framework is designed to reward that workflow. The hours you would have spent on per-thread indexing in CUDA become hours you spend on actual algorithms.