
Triton: A Deep Dive into the Block-Level GPU DSL

A self-contained reference. After working through this chapter you should be able to read, write, debug, autotune, and integrate Triton kernels into a PyTorch project without consulting any other source. We assume you know Python, basic linear algebra, what a GPU is, and roughly what CUDA does (threads, blocks, shared memory, warps). We do not assume you have ever written a CUDA kernel by hand.


Table of contents

  1. Why Triton exists
  2. The programming model
  3. Memory operations: load, store, mask, broadcast
  4. Math operations: elementwise, tl.dot, reductions, transcendentals
  5. Compile-time constants and specialization
  6. Autotuning
  7. Numerical stability patterns (online softmax derivation, Welford, log-sum-exp)
  8. Six fully-annotated real kernels
  9. The compilation pipeline (Python -> MLIR -> PTX)
  10. Integration with PyTorch (torch.library, torch.compile, autograd)
  11. Common pitfalls
  12. Triton vs CUTLASS vs hand-rolled CUDA
  13. Six exercises with worked answer sketches
  14. Closing notes

1. Why Triton exists

1.1 The CUDA per-thread model

CUDA exposes the GPU as a hierarchy of threads organised into warps (32 threads), then blocks (one or more warps), then a grid of blocks. The programmer writes code from the perspective of a single thread:

// CUDA: each thread handles a single output element of c = a + b
__global__ void vec_add(const float* a, const float* b, float* c, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) c[i] = a[i] + b[i];
}

This per-thread view is faithful to the hardware (an SM really does dispatch warps, each a SIMD-32 unit), but for a kernel author it forces a constant mental gear shift. When you think "tile of A times tile of B equals tile of C", you are reasoning block-wise. CUDA makes you re-derive that block-level operation as a sequence of per-thread loads, per-thread __shared__ writes, __syncthreads() barriers, per-thread MMA inputs, accumulator registers, mask handling at the trailing edge of the matrix, etc. Every one of those steps is a place to make a memory-coalescing, bank-conflict, or register-pressure mistake. Worse: changing the tile size invalidates almost all of that hand-crafted indexing.

1.2 The Triton thesis (MAPL 2019)

Tillet, Kung, and Cox (MAPL 2019, "Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations") observe that almost every high-performance GPU kernel for ML is naturally written as a loop over blocks of the output, where each iteration computes one output tile from input tiles. They propose: let the programmer write that loop directly, with block-shaped tensors as first-class values, and let a compiler lower the block program to a per-thread CUDA-like form -- choosing vector widths, shared-memory layouts, double-buffered pipelining, swizzling, and tensor core instructions automatically.

Concretely, the human writes (paraphrased):

Program instance pid is responsible for output rows pid*BM up to (but not including) pid*BM+BM. Allocate a block-tensor acc of shape [BM, BN]. Loop over the K dimension in blocks of BK: load a [BM, BK] tile of A and a [BK, BN] tile of B, then do acc += dot(a, b). After the loop, store acc to C.

The compiler turns that into hundreds of lines of well-vectorised PTX with correct shared-memory layout, asynchronous copies, and mma instructions.
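
To make the paraphrase concrete, here is a minimal pure-Python emulation of the block program (no GPU and no Triton involved). The function name blocked_matmul and the default tile sizes are illustrative assumptions, not part of any Triton API. Each (m0, n0) pair plays the role of one program instance, and the innermost sum stands in for dot(a, b):

```python
def blocked_matmul(A, B, BM=2, BN=2, BK=2):
    """Emulate the block program on nested lists: C = A @ B, tile by tile."""
    M, K, N = len(A), len(A[0]), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    # Each (m0, n0) pair corresponds to one program instance.
    for m0 in range(0, M, BM):
        for n0 in range(0, N, BN):
            rows = range(m0, min(m0 + BM, M))    # min() plays the role of the mask
            cols = range(n0, min(n0 + BN, N))
            acc = {(i, j): 0.0 for i in rows for j in cols}   # the [BM, BN] accumulator
            for k0 in range(0, K, BK):           # loop over K in blocks of BK
                ks = range(k0, min(k0 + BK, K))
                for i in rows:
                    for j in cols:
                        acc[i, j] += sum(A[i][k] * B[k][j] for k in ks)
            for (i, j), value in acc.items():    # store acc to C
                C[i][j] = value
    return C
```

Changing BM, BN, or BK changes only the loop bounds, not the result, which is the property the compiler exploits when it specialises on tile sizes.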

1.3 What Triton gives up, and why that is fine

Triton intentionally cannot express:

  • arbitrary intra-block thread-divergent control flow,
  • explicit __syncthreads(),
  • user-controlled shared memory,
  • warp-level shuffles as a primary interface (some primitives exist, but they are not the model).

For ML kernels this is rarely a loss: the large majority of useful kernels are tiled elementwise / reduce / matmul / softmax compositions, and for those the compiler's choices are typically at least as good as those of a moderately experienced CUDA author, and far less error-prone.

1.4 Where the productivity comes from

Three places:

  1. Indexing is automatic. You write tl.arange(0, BLOCK) and broadcast; the compiler generates the per-thread index arithmetic.
  2. Shared memory is automatic. tl.load of a 2D block lowers to coalesced global -> shared copies with the right swizzle for downstream tl.dot.
  3. Tile-size sweeps are cheap. Because block sizes are constexpr, the compiler specialises and you can autotune over a config grid in one decorator.

The empirical claim of the original paper, repeated through every release: hand-written Triton matmul reaches ~80% of cuBLAS on common square shapes and frequently beats cuBLAS on awkward shapes (small K, non-multiple-of-16 dims) where the vendor library has no specialised kernel.


2. The programming model

2.1 @triton.jit and the program-instance abstraction

A Triton kernel is a Python function decorated with @triton.jit:

import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements,
               BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)                    # 0..num_programs-1
    offsets = pid * BLOCK + tl.arange(0, BLOCK)    # vector of indices
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

The decorator does not run the function on the host. It captures the Python AST, lowers it to Triton IR (an MLIR dialect), then to PTX, on first call. The body of the function executes on the GPU.

Inside the body, you are one program instance. There is no threadIdx. The total number of program instances is set when you launch:

n = x.numel()
grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)   # 1D grid
add_kernel[grid](x, y, out, n, BLOCK=1024)

grid is either a tuple (1D/2D/3D) or a callable that receives the compile-time meta-parameters and returns a tuple. triton.cdiv(a, b) is ceil-division.

A program instance is roughly equivalent to a CUDA block, not a thread. Inside it, the compiler will allocate num_warps * 32 actual threads and distribute the block-level work across them. You influence that by passing num_warps= at launch (or via autotune).

2.2 Block-tensors as first-class values

tl.arange(0, N) and any expression built from it -- arithmetic, broadcasting, tl.load, tl.dot, reductions -- produces a block tensor whose shape is known at compile time. These tensors live in registers; the compiler stages them through shared memory when a layout change or a tl.dot operand requires it, and spills to (slow) local memory only as a last resort.

offs_m = tl.arange(0, BM)                 # shape [BM]
offs_n = tl.arange(0, BN)                 # shape [BN]
offs_2d = offs_m[:, None] * BN + offs_n[None, :]   # shape [BM, BN]

x[:, None] and x[None, :] are NumPy-style broadcasts; the compiler generates the per-thread index arithmetic that makes a 2D logical tile out of registers held by 32-wide warps.
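
As a sanity check on the broadcast rule, the following pure-Python sketch computes the same [BM, BN] grid of offsets with ordinary nested lists. The helper name tile_offsets is an illustrative assumption:

```python
def tile_offsets(BM, BN):
    """Pure-Python equivalent of offs_m[:, None] * BN + offs_n[None, :]."""
    offs_m = list(range(BM))          # tl.arange(0, BM)
    offs_n = list(range(BN))          # tl.arange(0, BN)
    # Row index m is broadcast across columns, column index n across rows.
    return [[m * BN + n for n in offs_n] for m in offs_m]
```

For BM=2 and BN=3 this yields the row-major offsets of a 2 x 3 tile, which is exactly what gets added to a base pointer before tl.load.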

2.3 Side-by-side: vector add

Same kernel in CUDA:

__global__ void add_kernel(const float* x, const float* y, float* out, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) out[idx] = x[idx] + y[idx];
}
// host:
add_kernel<<<(n + 255) / 256, 256>>>(x, y, out, n);

The Triton version is structurally similar but works at the vector level: one program does 1024 elements at once; the compiler decides how to assign those 1024 lanes across the program's num_warps * 32 threads (with the default num_warps=4 that is 128 threads, so 8 elements per thread, typically fetched with 128-bit vector loads). You did not have to choose. The CUDA author who wants the same vectorisation has to use float4 casts and do four-way unrolling by hand.

Exercise (warm-up). Modify add_kernel so it computes out = a*x + y (axpy). What changes? Answer: add an a scalar argument and replace the store value with a*x + y. Nothing else.


3. Memory operations

3.1 tl.load and tl.store

Signatures (simplified, current as of 2026 Triton; verify if you target an older release):

tl.load(pointer, mask=None, other=None, cache_modifier="", eviction_policy="",
        volatile=False)
tl.store(pointer, value, mask=None, cache_modifier="", eviction_policy="")

Key semantics:

  • pointer is either a scalar pointer (rare) or a block tensor of pointers formed by base_ptr + offsets. The compiler infers the block shape from offsets.
  • mask, when provided, is a same-shape boolean tensor. Lanes where mask=False skip the memory access entirely (no fault, no traffic).
  • other is the value substituted for masked-off lanes on load. It defaults to None, in which case masked-off lanes hold unspecified values, so pass it explicitly whenever those lanes feed arithmetic. Pick the value carefully: if a masked tile feeds a tl.dot reduction, the masked lanes still participate arithmetically, so other=0.0 makes them contribute zero -- which is what you want.

3.2 Boundary masking: why it always matters

Take a vector add over n=1000 elements with BLOCK=128. You will launch ceil(1000/128) = 8 programs covering 1024 elements. The last program must mask off the trailing 24 lanes; otherwise it reads and writes past the end of the tensors, which on a GPU shows up as an illegal-memory-access error at best and silent data corruption at worst:

offsets = pid * BLOCK + tl.arange(0, BLOCK)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask, other=0.0)

In CUDA the equivalent is the if (idx < n) guard around a single thread. Both forms compile to predicated memory instructions, so the masked version is not inherently faster; the gain is that the guard is written once for the whole block instead of being re-derived per thread.
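
The arithmetic in this example can be checked without a GPU. The sketch below emulates what each program instance computes for offsets and mask; the helper names cdiv and program_lanes are illustrative assumptions (cdiv mirrors the ceil-division that triton.cdiv performs):

```python
def cdiv(a, b):
    """Ceil-division, the same arithmetic as triton.cdiv."""
    return (a + b - 1) // b

def program_lanes(pid, n_elements, BLOCK):
    """Return (offsets, mask) as one program instance would compute them."""
    offsets = [pid * BLOCK + i for i in range(BLOCK)]
    mask = [off < n_elements for off in offsets]
    return offsets, mask

n, BLOCK = 1000, 128
num_programs = cdiv(n, BLOCK)                       # programs launched
_, last_mask = program_lanes(num_programs - 1, n, BLOCK)
masked_off = last_mask.count(False)                 # lanes disabled in the last program
```

Only the last program has any disabled lanes; every earlier program runs with a fully true mask.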

For 2D blocks, you need a 2D mask:

offs_m = pid_m * BM + tl.arange(0, BM)             # [BM]
offs_n = pid_n * BN + tl.arange(0, BN)             # [BN]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
ptrs = X + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
x = tl.load(ptrs, mask=mask, other=0.0)

This is the canonical pattern. Get the strides right and the mask right and you have a correct boundary-aware tile load.
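
To see how strides and the 2D mask interact, here is a pure-Python emulation of a masked tile load from a flat buffer. The function name load_tile and its argument order are assumptions made for this sketch; only the indexing arithmetic is taken from the pattern above:

```python
def load_tile(buf, M, N, stride_m, stride_n, pid_m, pid_n, BM, BN, other=0.0):
    """Emulate a masked [BM, BN] tile load from a flat buffer."""
    tile = []
    for i in range(BM):
        m = pid_m * BM + i
        row = []
        for j in range(BN):
            n = pid_n * BN + j
            in_bounds = (m < M) and (n < N)          # the 2D mask
            # Masked-off lanes never touch memory; they receive `other`.
            row.append(buf[m * stride_m + n * stride_n] if in_bounds else other)
        tile.append(row)
    return tile
```

With a 3 x 3 row-major buffer (stride_m=3, stride_n=1) and 2 x 2 tiles, the bottom-right tile contains one real element and three substituted values. Swapping the strides reads the same buffer as its transpose, which is why kernels take strides as arguments instead of assuming a layout.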

3.3 Broadcasting

tl.arange returns a 1D tensor; expressions broadcast like NumPy:

A = tl.zeros([BM, BN], dtype=tl.float32)           # [BM, BN] of zero
v = tl.arange(0, BN)                               # [BN]
A2 = A + v[None, :]                                # broadcast across rows

The compiler chooses the warp/lane layout so the broadcast costs nothing (it is a register-relabelling, not a memory op). This is one of the places Triton silently saves you a lot of CUDA boilerplate.

3.4 Cache and eviction hints

The cache_modifier argument takes values such as ".ca" (cache at all levels) and ".cg" (cache at global level, bypassing L1), and eviction_policy takes "evict_last" or "evict_first"; both map to PTX cache-operator and eviction-priority suffixes. Useful for reused tiles (the K-tile of A re-read by every column block) vs. one-shot tiles. As a first pass, leave them at default.

Exercise. Why does tl.load(ptr, mask=mask, other=0.0) in a tl.dot reduction not introduce numerical error from the masked lanes? Answer: because 0.0 is the additive identity, and tl.dot is a sum-of-products, masked-off lanes contribute exactly zero to the accumulator.


4. Math operations

4.1 Elementwise

All standard arithmetic (+ - * /), comparison (< > == !=), and tl.exp, tl.log, tl.sqrt, tl.rsqrt, tl.sin, tl.cos, tl.tanh, tl.sigmoid, tl.where(cond, a, b) operate elementwise on block tensors and follow the broadcasting rules of section 3.

tl.where is your branchless conditional. There is no per-lane if/else in Triton; you use where for value selection and mask for memory access gating.

4.2 tl.dot and tensor cores

acc = tl.dot(a, b, acc=acc, allow_tf32=True)
# a: [M, K], b: [K, N], acc: [M, N]; result shape [M, N]

tl.dot is the only operation in Triton that lowers to the GPU's tensor-core MMA instructions on Volta+. It requires:

  • M, N, K dimensions to be at least 16 (often 16/16/16 minimum; on Hopper larger MMAs are used). Depending on the Triton release, smaller dims are either rejected at compile time or lowered to FMA loops, which work but are slow.
  • dtypes that the hardware supports: fp16, bf16, tf32 (fp32 inputs go through tf32 on A100 if allow_tf32=True, which is Triton's default for tl.dot; newer releases spell this input_precision="tf32"), fp8 on Hopper.
  • acc to be fp32 (recommended) for numerical safety. The accumulator stays in registers across iterations of the K loop.

Without tl.dot you cannot use tensor cores from Triton. Memorise this.

4.3 Reductions

m = tl.max(x, axis=1)        # reduce along axis 1
s = tl.sum(x, axis=0)
mn = tl.min(x, axis=0)
am = tl.argmax(x, axis=1)    # newer Triton versions

Reductions return a tensor with the reduced axis removed. They lower to warp-level shuffle reductions where possible and shared-memory reductions across warps. You do not write either by hand.

4.4 Transcendentals

tl.exp, tl.log, tl.sqrt, tl.rsqrt, tl.sin, tl.cos, tl.erf, tl.tanh, tl.sigmoid exist. Two notes:

  • On NVIDIA they typically lower to the fast PTX approximate instructions for fp32 (tl.exp, for example, becomes a multiply followed by ex2.approx.f32), which are accurate to within a few ulp, fine for ML.
  • tl.exp2 is often slightly faster than tl.exp because the hardware natively computes base-2 exponential. You can rewrite softmax using exp2(x * log2_e) if you really need the cycles; usually not worth it.
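
The exp2 rewrite rests on the identity exp(x) = 2^(x * log2(e)). A quick host-side check in plain Python (the names LOG2_E and exp_via_exp2 are illustrative):

```python
import math

LOG2_E = math.log2(math.e)   # about 1.4427

def exp_via_exp2(x):
    """Compute exp(x) as 2 ** (x * log2(e))."""
    return 2.0 ** (x * LOG2_E)
```

In a softmax kernel the multiplication by log2(e) can usually be folded into an existing scale factor, so the rewrite costs no extra multiply.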

Exercise. Implement a fused GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). Sketch: one elementwise kernel, mask on boundaries; everything is a tl.tanh/*/+. About 12 lines.
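
When you write the kernel for this exercise, you need a host-side reference to compare against. The following pure-Python sketch evaluates the tanh approximation from the exercise for a single float, next to the exact erf-based GELU. Both function names are assumptions for illustration:

```python
import math

def gelu_tanh(x):
    """Tanh-approximation GELU for one float, matching the formula in the exercise."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    return 0.5 * x * (1.0 + math.tanh(inner))

def gelu_exact(x):
    """Exact GELU via the error function, for comparison."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))
```

The two agree closely over the usual activation range, so either can serve as a reference as long as the comparison tolerance allows for the approximation error.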


5. Compile-time constants: tl.constexpr

5.1 Why constexpr exists

Block sizes (BM, BN, BK, BLOCK) influence the shape of every block tensor in the kernel. Shape determines:

  • register allocation,
  • shared-memory tile size,
  • MMA instruction shape selection,
  • loop unrolling.

The compiler must know these at codegen time. So they are passed as tl.constexpr parameters:

@triton.jit
def kernel(..., BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr): ...

When you launch with kernel[grid](..., BM=128, BN=128, BK=32), Triton hashes the constexpr values and the dtypes/strides/etc. into a cache key. First call with a new key triggers compilation; subsequent calls reuse the cached PTX.

5.2 What else can be constexpr

Anything that should be specialised: a bool selecting causal vs. non-causal attention, an int choosing the reduction axis, an enum-like dtype. Use constexpr aggressively for branches that you want the compiler to delete in the specialised version.

@triton.jit
def attn(..., CAUSAL: tl.constexpr):
    if CAUSAL:
        # this branch is compiled away when CAUSAL=False
        mask = offs_m[:, None] >= offs_n[None, :]
        scores = tl.where(mask, scores, float("-inf"))

That if CAUSAL is not a runtime branch -- it is Python-level dead-code elimination at JIT time.

5.3 The cost: recompilation

Each unique (constexpr-value, dtype, contiguity) tuple is a separate kernel. Cache lives in ~/.triton/cache. If your constexpr space is unbounded (e.g. you set BLOCK=n per call) you will recompile constantly. Pick a finite menu and stick to it.
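
One way to keep the constexpr space bounded is to round every request to an entry of a fixed menu on the host, before launching. This is a minimal sketch under that assumption; the helper name pick_block and the menu values are illustrative, not a Triton API:

```python
def pick_block(n, menu=(128, 256, 512, 1024)):
    """Choose the smallest menu entry that covers n, else the largest entry.

    The menu values are illustrative; the point is that the set is finite,
    so at most len(menu) specialisations are ever compiled.
    """
    for block in menu:
        if n <= block:
            return block
    return menu[-1]
```

A launcher would then pass BLOCK=pick_block(n) and size the grid with ceil-division, so arbitrarily many distinct n values share four compiled kernels.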


6. Autotuning

6.1 The decorator

@triton.autotune(
    configs=[
        triton.Config({'BM': 128, 'BN': 128, 'BK': 32}, num_warps=4, num_stages=3),
        triton.Config({'BM': 128, 'BN': 64,  'BK': 32}, num_warps=4, num_stages=4),
        triton.Config({'BM': 64,  'BN': 128, 'BK': 32}, num_warps=4, num_stages=4),
        triton.Config({'BM': 64,  'BN': 64,  'BK': 32}, num_warps=2, num_stages=5),
        triton.Config({'BM': 256, 'BN': 64,  'BK': 32}, num_warps=8, num_stages=3),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(...): ...

key= lists the runtime arguments whose values define a tuning bucket. On first call with a particular (M, N, K) triple, Triton runs every config, measures wall time, picks the winner, caches it. Subsequent calls with the same (M, N, K) skip the search.
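
The bucketing behaviour can be modelled in a few lines of plain Python. This is a toy model of the idea, not Triton's implementation: the class ToyAutotuner and its bench callback are hypothetical, and the benchmark is injected so the example is deterministic:

```python
class ToyAutotuner:
    """Pure-Python model of the bucketed search; not Triton's implementation."""

    def __init__(self, configs, key, bench):
        self.configs = configs      # candidate configs (plain dicts here)
        self.key = key              # argument names that define a tuning bucket
        self.bench = bench          # hypothetical callable: (config, args) -> time
        self.cache = {}             # bucket -> winning config
        self.searches = 0           # how many full searches were run

    def best_config(self, **args):
        bucket = tuple(args[name] for name in self.key)
        if bucket not in self.cache:
            # First call for this bucket: time every config, keep the fastest.
            self.searches += 1
            self.cache[bucket] = min(self.configs, key=lambda c: self.bench(c, args))
        return self.cache[bucket]
```

The second call with the same key values hits the cache and runs no benchmarks, which is why a well-chosen key matters: too fine a key re-tunes constantly, too coarse a key reuses a config where it no longer fits.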

6.2 Knobs

  • Block sizes (BM, BN, BK for matmul; BLOCK_SIZE for elementwise) control register pressure and shared-memory tile size. Bigger usually means more reuse but more spills.
  • num_warps controls how many 32-thread warps form the program instance. More warps means more parallelism per SM but each warp gets fewer registers.
  • num_stages controls software pipelining depth: how many K-iterations worth of tiles are simultaneously in flight (cp.async on Ampere+, TMA on Hopper). 2-5 is typical. Too many overflows shared memory; too few starves the MMA units.

6.3 Practical autotune

  • Limit configs to a dozen or so. The first call already pays the worst-case cost (you run every config), so 30 configs with 200ms each is six seconds of warmup.
  • key=['M', 'N', 'K'] is right for matmul; for elementwise, key=['n'] or no key (single bucket) is fine.
  • Use prune_configs_by={'early_config_prune': fn, 'perf_model': fn} if the search is huge. Usually unnecessary.
  • Set TRITON_PRINT_AUTOTUNING=1 to see the table of configs and their timings. This is invaluable for sanity-checking.

6.4 Caching

Compiled kernels are cached on disk under ~/.triton/cache/<hash>/. The hash is over: function source, argument dtypes, constexpr values, GPU arch. Move to a different GPU and you recompile. This is fine.


7. Numerical stability patterns

Three patterns dominate ML kernels: online softmax, Welford for running mean/variance, log-sum-exp.

7.1 The naive softmax problem

softmax(x)_i = exp(x_i) / sum_j exp(x_j). If any x_i > 88.7 (fp32) or >11.1 (fp16), exp(x_i) overflows. The standard fix:

softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))

This requires two passes over x: one to find max(x), one to compute the numerator/denominator. For attention, where x is the score row of length seq_len, two passes mean two reads of a tensor that does not fit in registers if seq_len is large. Bandwidth-bound.
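
The overflow and its fix are easy to reproduce on the host. Note one assumption: Python floats are fp64, so the overflow threshold is near 709.8 instead of 88.7, and math.exp raises OverflowError where a GPU kernel would produce inf. The function names are illustrative:

```python
import math

def softmax_naive(xs):
    exps = [math.exp(x) for x in xs]        # raises OverflowError for large x
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable(xs):
    m = max(xs)                              # pass 1: row maximum
    exps = [math.exp(x - m) for x in xs]     # pass 2: shifted exponentials
    total = sum(exps)
    return [e / total for e in exps]
```

Subtracting the maximum makes the largest exponent exactly zero, so every exp argument is <= 0 and the result is bounded by 1.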

7.2 Online softmax: the derivation

We want a one-pass algorithm. Suppose we have processed a prefix of length k and we know:

  • m_k = max(x_1..x_k),
  • l_k = sum_{j=1..k} exp(x_j - m_k).

The softmax over the prefix is exp(x_j - m_k) / l_k.

Now we see x_{k+1}. The new max is m_{k+1} = max(m_k, x_{k+1}). We need l_{k+1} = sum_{j=1..k+1} exp(x_j - m_{k+1}).

Split the sum:

l_{k+1} = sum_{j=1..k} exp(x_j - m_{k+1})  +  exp(x_{k+1} - m_{k+1})
        = sum_{j=1..k} exp(x_j - m_k) * exp(m_k - m_{k+1})  +  exp(x_{k+1} - m_{k+1})
        = l_k * exp(m_k - m_{k+1})  +  exp(x_{k+1} - m_{k+1})

That is the recurrence. The factor exp(m_k - m_{k+1}) is <= 1 and rescales the running denominator whenever the max grows. The recurrence generalises trivially to blocks of new elements rather than single elements: you compute the local max and local sum of the new block, then combine with the running (m, l) exactly as above. This is the heart of flash-attention.

In code (block version):

m_block = tl.max(x_block, axis=0)            # local max
p_block = tl.exp(x_block - m_block)
l_block = tl.sum(p_block, axis=0)            # local sum
m_new = tl.maximum(m, m_block)
alpha = tl.exp(m - m_new)
beta  = tl.exp(m_block - m_new)
l_new = alpha * l + beta * l_block
m, l = m_new, l_new
# ...and rescale any running output by alpha
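
The block recurrence can be verified on the host against the two-pass computation. This pure-Python sketch follows the same variable names as the snippet above; the function names and the block size are illustrative assumptions:

```python
import math

def online_softmax_stats(xs, block=4):
    """Stream over xs in blocks, returning the running (m, l) pair."""
    m, l = float("-inf"), 0.0
    for start in range(0, len(xs), block):
        x_block = xs[start:start + block]
        m_block = max(x_block)                                  # local max
        l_block = sum(math.exp(x - m_block) for x in x_block)   # local sum
        m_new = max(m, m_block)
        alpha = math.exp(m - m_new)        # rescales the old running sum
        beta = math.exp(m_block - m_new)   # rescales the new block's sum
        l = alpha * l + beta * l_block
        m = m_new
    return m, l

def two_pass_stats(xs):
    """Reference: find the max first, then sum the shifted exponentials."""
    m = max(xs)
    return m, sum(math.exp(x - m) for x in xs)
```

On the first block m is -inf, so alpha = exp(-inf) = 0 and the empty running sum drops out cleanly; no special case is needed.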

7.3 Welford for variance

For RMSNorm forward you only need mean(x^2), which is a single pass. For LayerNorm forward you need both mean(x) and var(x). The naive one-pass formula var = mean(x^2) - mean(x)^2 suffers catastrophic cancellation in fp32 when the mean is large relative to the standard deviation. Welford is the one-pass numerically stable form. For block combination, given the count n, the mean, and M2 = sum (x - mean)^2 of two blocks a and b:

n_combined = n_a + n_b
delta      = mean_b - mean_a
mean_combined = mean_a + delta * n_b / n_combined
M2_combined   = M2_a + M2_b + delta^2 * n_a * n_b / n_combined
var = M2_combined / n_combined

You will see this exact pattern in the LayerNorm tutorial kernels.
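
The combination formula can be exercised in plain Python. The helper names welford_block and welford_combine are assumptions for this sketch; the arithmetic is the formula above, with M2 defined as the sum of squared deviations from the block mean:

```python
def welford_block(xs):
    """Return (n, mean, M2) for one block, where M2 = sum((x - mean)^2)."""
    n = len(xs)
    mean = sum(xs) / n
    return n, mean, sum((x - mean) ** 2 for x in xs)

def welford_combine(a, b):
    """Merge two (n, mean, M2) triples using the combination formula."""
    n_a, mean_a, M2_a = a
    n_b, mean_b, M2_b = b
    n = n_a + n_b
    delta = mean_b - mean_a
    mean = mean_a + delta * n_b / n
    M2 = M2_a + M2_b + delta * delta * n_a * n_b / n
    return n, mean, M2
```

Because the merge is associative, blocks can be combined in any order, which is what lets a kernel reduce per-block statistics in a tree.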

7.4 Log-sum-exp

logsumexp(x) = log(sum exp(x)) is m + log(sum exp(x - m)) for m = max(x). Same trick as softmax. Useful inside cross-entropy loss kernels where you want -x_target + logsumexp(x) directly.
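
A host-side sketch of the same identity, with the cross-entropy use spelled out for a single row. The function names are illustrative, and target is assumed to be an integer class index:

```python
import math

def logsumexp(xs):
    """Stable log(sum(exp(x))): shift by the max before exponentiating."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def cross_entropy(logits, target):
    """Loss for one row: -logits[target] + logsumexp(logits)."""
    return -logits[target] + logsumexp(logits)
```

The shifted form handles logits around 1000 without overflow, where the unshifted sum of exponentials would not be representable.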


8. Six annotated kernels

We progress from the trivial to flash-attention. Every kernel below is a runnable example. Imports assumed:

import torch
import triton
import triton.language as tl

8.1 Vector add

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n,
               BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

def add(x, y):
    out = torch.empty_like(x)
    n = x.numel()
    grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)
    add_kernel[grid](x, y, out, n, BLOCK=1024)
    return out

This kernel is bandwidth-bound; it should approach peak HBM bandwidth on any reasonable GPU. If you measure 80% of peak you are doing fine; the remaining 20% is launch overhead and tail effects.
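
To compare a measurement against peak bandwidth, count the bytes moved: vector add reads two elements and writes one per output element. A minimal sketch, where the function name and the timing in the usage note are illustrative assumptions and not measurements:

```python
def achieved_bandwidth_gbs(n_elements, bytes_per_element, seconds):
    """Effective bandwidth in GB/s for vector add (two reads, one write per element)."""
    total_bytes = 3 * n_elements * bytes_per_element
    return total_bytes / seconds / 1e9
```

For example, 2**28 fp32 elements processed in a hypothetical 2 ms corresponds to about 1611 GB/s; divide by your GPU's published peak to get the fraction achieved.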

8.2 Naive matmul

A naive matmul has each program compute one [BM, BN] output tile via an inner loop over K, using a plain 2D grid with no program-id swizzle and no autotuning. Useful as a teaching baseline; do not ship it.

@triton.jit
def matmul_naive(A, B, C, M, N, K,
                 sa_m, sa_k, sb_k, sb_n, sc_m, sc_n,
                 BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BM + tl.arange(0, BM)
    offs_n = pid_n * BN + tl.arange(0, BN)
    offs_k = tl.arange(0, BK)
    a_ptrs = A + offs_m[:, None] * sa_m + offs_k[None, :] * sa_k
    b_ptrs = B + offs_k[:, None] * sb_k + offs_n[None, :] * sb_n
    acc = tl.zeros([BM, BN], dtype=tl.float32)
    for k in range(0, K, BK):
        a_mask = (offs_m[:, None] < M) & ((k + offs_k)[None, :] < K)
        b_mask = ((k + offs_k)[:, None] < K) & (offs_n[None, :] < N)
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        acc += tl.dot(a, b)
        a_ptrs += BK * sa_k
        b_ptrs += BK * sb_k
    c_ptrs = C + offs_m[:, None] * sc_m + offs_n[None, :] * sc_n
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc.to(C.dtype.element_ty), mask=c_mask)

Already this is much simpler than the equivalent CUDA, which would need explicit shared-memory tiles, __syncthreads(), and per-thread MMA inputs.

8.3 Tiled matmul with autotune

The canonical Triton matmul. The two changes from 8.2 are:

  1. L2-cache-friendly program-id swizzle (group along M to reuse B tiles).
  2. @triton.autotune over a config grid.

def get_configs():
    return [
        triton.Config({'BM':128,'BN':256,'BK':32,'GROUP_M':8}, num_warps=8, num_stages=3),
        triton.Config({'BM':128,'BN':128,'BK':32,'GROUP_M':8}, num_warps=4, num_stages=4),
        triton.Config({'BM':128,'BN':64, 'BK':32,'GROUP_M':8}, num_warps=4, num_stages=4),
        triton.Config({'BM':64, 'BN':128,'BK':32,'GROUP_M':8}, num_warps=4, num_stages=4),
        triton.Config({'BM':64, 'BN':64, 'BK':32,'GROUP_M':8}, num_warps=2, num_stages=5),
        triton.Config({'BM':128,'BN':128,'BK':64,'GROUP_M':8}, num_warps=8, num_stages=3),
    ]

@triton.autotune(configs=get_configs(), key=['M', 'N', 'K'])
@triton.jit
def matmul_tiled(A, B, C, M, N, K,
                 sa_m, sa_k, sb_k, sb_n, sc_m, sc_n,
                 BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr,
                 GROUP_M: tl.constexpr):
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BM)
    num_pid_n = tl.cdiv(N, BN)
    # L2-friendly swizzle: walk in (GROUP_M, num_pid_n) super-blocks
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BM + tl.arange(0, BM)) % M
    offs_bn = (pid_n * BN + tl.arange(0, BN)) % N
    offs_k  = tl.arange(0, BK)
    a_ptrs = A + offs_am[:, None] * sa_m + offs_k[None, :] * sa_k
    b_ptrs = B + offs_k[:, None] * sb_k + offs_bn[None, :] * sb_n

    acc = tl.zeros([BM, BN], dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BK)):
        k_remaining = K - k * BK
        a_mask = offs_k[None, :] < k_remaining
        b_mask = offs_k[:, None] < k_remaining
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        acc = tl.dot(a, b, acc=acc)
        a_ptrs += BK * sa_k
        b_ptrs += BK * sb_k

    offs_cm = pid_m * BM + tl.arange(0, BM)
    offs_cn = pid_n * BN + tl.arange(0, BN)
    c_ptrs = C + offs_cm[:, None] * sc_m + offs_cn[None, :] * sc_n
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, acc.to(C.dtype.element_ty), mask=c_mask)

def matmul(a, b):
    assert a.shape[1] == b.shape[0]
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    grid = lambda meta: (triton.cdiv(M, meta['BM']) * triton.cdiv(N, meta['BN']),)
    matmul_tiled[grid](
        a, b, c, M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c

Why the swizzle? Without it, neighbouring program ids cover neighbouring N columns of C with the same row of A. With it, neighbours share both A and B reuse, which keeps the L2 hit rate high on big problems. This trick alone is often worth 20-40% on large square matmul.
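
The id remapping is plain integer arithmetic, so it can be traced on the host. The sketch below reproduces the computation from matmul_tiled in pure Python; the helper names swizzled_pid and launch_order are illustrative assumptions:

```python
def cdiv(a, b):
    """Ceil-division, as used to count tiles along each dimension."""
    return (a + b - 1) // b

def swizzled_pid(pid, M, N, BM, BN, GROUP_M):
    """Map a 1D program id to (pid_m, pid_n) with the same arithmetic as matmul_tiled."""
    num_pid_m = cdiv(M, BM)
    num_pid_n = cdiv(N, BN)
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)   # last group may be short
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n

def launch_order(M, N, BM, BN, GROUP_M):
    """List the output tiles in the order consecutive program ids visit them."""
    total = cdiv(M, BM) * cdiv(N, BN)
    return [swizzled_pid(p, M, N, BM, BN, GROUP_M) for p in range(total)]
```

With M=384, N=256, 128 x 128 tiles, and GROUP_M=2, the first four ids cover a 2 x 2 super-block of tiles before moving on, and every tile is visited exactly once.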

On A100/H100 this kernel reaches roughly 80% of cuBLAS for square matmul with M,N,K multiples of 128, and it frequently beats cuBLAS on irregular shapes (e.g. K=80, common in attention with head_dim=80). Do not take exact numbers on faith; benchmark on your hardware.

8.4 Fused softmax (rowwise, single pass-per-row)

For rows that fit entirely in one block (N <= BLOCK):

@triton.jit
def softmax_kernel(X, Y, sx, sy, N,
                   BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < N
    x = tl.load(X + row * sx + cols, mask=mask, other=-float("inf"))
    x = x - tl.max(x, axis=0)
    num = tl.exp(x)
    den = tl.sum(num, axis=0)
    y = num / den
    tl.store(Y + row * sy + cols, y, mask=mask)

def softmax(x):
    M, N = x.shape
    BLOCK = triton.next_power_of_2(N)
    y = torch.empty_like(x)
    softmax_kernel[(M,)](x, y, x.stride(0), y.stride(0), N, BLOCK=BLOCK,
                         num_warps=4 if BLOCK <= 1024 else 8)
    return y

Notes:

  • other=-inf for masked-off lanes ensures they cannot become the max and do not contribute to the sum (exp(-inf) = 0).
  • This whole row lives in registers, so we have a true one-pass softmax; no online-softmax recurrence needed.
  • For very wide rows (N > 64K), you would use a tiled version with the online-softmax recurrence from section 7.
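
The wrapper rounds BLOCK up to a power of two because tl.arange requires a power-of-two extent; the mask then disables the padding lanes. A pure-Python equivalent of that rounding, for n >= 1 (the helper padded_lanes is an illustrative assumption):

```python
def next_power_of_2(n):
    """Smallest power of two >= n (for n >= 1), the rounding triton.next_power_of_2 performs."""
    return 1 << (n - 1).bit_length()

def padded_lanes(n):
    """How many lanes of the BLOCK-wide row are masked off for a row of width n."""
    return next_power_of_2(n) - n
```

A row width just above a power of two is the worst case: N=1025 yields BLOCK=2048 with 1023 idle lanes, which is one reason to consider the tiled variant for awkward widths.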

8.5 RMSNorm forward + backward

RMSNorm: y_i = x_i / sqrt(mean(x^2) + eps) * w_i.

Forward.

@triton.jit
def rmsnorm_fwd(X, W, Y, RSTD, sx, sy, N, eps,
                BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < N
    x = tl.load(X + row * sx + cols, mask=mask, other=0.0).to(tl.float32)
    var = tl.sum(x * x, axis=0) / N
    rstd = 1.0 / tl.sqrt(var + eps)
    tl.store(RSTD + row, rstd)             # save for backward
    w = tl.load(W + cols, mask=mask, other=0.0)
    y = x * rstd * w
    tl.store(Y + row * sy + cols, y.to(Y.dtype.element_ty), mask=mask)

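Backward. Given the upstream gradient dy, compute dx and dw. In the math below a bare dy_i is that upstream gradient, while dy_i/dx_j denotes the partial derivative of y_i with respect to x_j: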

y_i = x_i * w_i * r,   where r = (mean(x^2) + eps)^{-1/2}
dr/dx_j = -(1/N) * x_j * r^3                          (since d/dx_j var = 2 x_j / N)
dy_i/dx_j = w_i * r * delta_{ij} + x_i * w_i * dr/dx_j
         = w_i * r * delta_{ij} - (1/N) * x_i * w_i * x_j * r^3

dx_j = sum_i dy_i * dy_i/dx_j
     = w_j * r * dy_j - (r^3 / N) * x_j * sum_i (x_i * w_i * dy_i)
     = r * (w_j * dy_j - x_j * (r^2 / N) * S),   where S = sum_i x_i * w_i * dy_i

So we need a single reduction S per row, then a fused elementwise pass:

@triton.jit
def rmsnorm_bwd_dx(X, W, DY, RSTD, DX, sx, sdy, sdx, N,
                   BLOCK: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    mask = cols < N
    x  = tl.load(X  + row * sx  + cols, mask=mask, other=0.0).to(tl.float32)
    w  = tl.load(W                  + cols, mask=mask, other=0.0).to(tl.float32)
    dy = tl.load(DY + row * sdy + cols, mask=mask, other=0.0).to(tl.float32)
    r  = tl.load(RSTD + row).to(tl.float32)
    S  = tl.sum(x * w * dy, axis=0)
    dx = r * (w * dy - x * (r * r / N) * S)
    tl.store(DX + row * sdx + cols, dx, mask=mask)
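
Before trusting a hand-derived gradient in a kernel, it is worth checking the formula numerically. The sketch below evaluates the dx expression for one row in pure Python and compares it with central finite differences of the scalar sum_i dy_i * y_i. All function names and the step size h are assumptions for this check:

```python
import math

def rmsnorm(x, w, eps=1e-6):
    """Forward pass for one row; returns (y, r) with r = (mean(x^2) + eps)^(-1/2)."""
    r = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [xi * wi * r for xi, wi in zip(x, w)], r

def rmsnorm_dx(x, w, dy, eps=1e-6):
    """dx_j = r * (w_j * dy_j - x_j * (r^2 / N) * S), S = sum_i x_i w_i dy_i."""
    _, r = rmsnorm(x, w, eps)
    N = len(x)
    S = sum(xi * wi * di for xi, wi, di in zip(x, w, dy))
    return [r * (wi * di - xi * (r * r / N) * S) for xi, wi, di in zip(x, w, dy)]

def numeric_dx(x, w, dy, eps=1e-6, h=1e-6):
    """Central differences of the scalar loss sum_i dy_i * y_i."""
    def loss(xv):
        y, _ = rmsnorm(xv, w, eps)
        return sum(di * yi for di, yi in zip(dy, y))
    grads = []
    for j in range(len(x)):
        xp = list(x); xp[j] += h
        xm = list(x); xm[j] -= h
        grads.append((loss(xp) - loss(xm)) / (2 * h))
    return grads
```

The same comparison against the kernel's DX output (with a looser tolerance for fp16 or bf16 inputs) makes a useful unit test.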

dW is reduced across rows; do it in a second small kernel (or fold it into the same kernel with atomics if you have many rows; pure reduce-then-store is usually faster).

8.6 Causal flash-attention forward (simplified)

We compute O = softmax(Q K^T / sqrt(d)) V, with a causal mask, in tiles. The trick is to never materialise the seq_len x seq_len score matrix -- we keep an online (m, l) per output row block and stream over key/value blocks. Backward is more complex; we omit it.

Shapes: Q, K, V are [B, H, S, D]; we run one program per (batch, head, M-block of Q rows). For brevity we drop the batch/head dims; assume Q, K, V are 2D [S, D] and that the launcher iterates over BH.

@triton.jit
def flash_attn_fwd(Q, K, V, O, L, M,            # L: per-row log-sum-exp, saved for backward; M is unused in this sketch
                   sq_s, sq_d, sk_s, sk_d, sv_s, sv_d, so_s, so_d,
                   S, D: tl.constexpr,
                   BM: tl.constexpr, BN: tl.constexpr,
                   CAUSAL: tl.constexpr):
    pid_m = tl.program_id(0)
    offs_m = pid_m * BM + tl.arange(0, BM)
    offs_d = tl.arange(0, D)
    offs_n = tl.arange(0, BN)

    # Load Q-tile once and keep it in registers
    q_ptrs = Q + offs_m[:, None] * sq_s + offs_d[None, :] * sq_d
    q_mask = offs_m[:, None] < S
    q = tl.load(q_ptrs, mask=q_mask, other=0.0)

    # Online softmax state
    m_i = tl.full([BM], -float("inf"), dtype=tl.float32)
    l_i = tl.zeros([BM], dtype=tl.float32)
    acc = tl.zeros([BM, D], dtype=tl.float32)

    scale = 1.0 / tl.sqrt(tl.full([], D, dtype=tl.float32))

    # Causal: only iterate up to the last K-block that can contribute.
    n_end = (pid_m + 1) * BM if CAUSAL else S
    for start_n in range(0, n_end, BN):
        cur_n = start_n + offs_n
        k_ptrs = K + cur_n[None, :] * sk_s + offs_d[:, None] * sk_d
        v_ptrs = V + cur_n[:, None] * sv_s + offs_d[None, :] * sv_d
        k_mask = cur_n[None, :] < S
        v_mask = cur_n[:, None] < S
        k = tl.load(k_ptrs, mask=k_mask, other=0.0)
        v = tl.load(v_ptrs, mask=v_mask, other=0.0)

        # Scores [BM, BN]
        s = tl.dot(q, k) * scale
        if CAUSAL:
            causal_mask = offs_m[:, None] >= cur_n[None, :]
            s = tl.where(causal_mask, s, float("-inf"))
        # also mask off out-of-range cur_n
        s = tl.where(cur_n[None, :] < S, s, float("-inf"))

        # Online softmax update
        m_new = tl.maximum(m_i, tl.max(s, axis=1))
        alpha = tl.exp(m_i - m_new)
        p = tl.exp(s - m_new[:, None])
        l_i = alpha * l_i + tl.sum(p, axis=1)
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
        m_i = m_new

    o = acc / l_i[:, None]
    o_ptrs = O + offs_m[:, None] * so_s + offs_d[None, :] * so_d
    tl.store(o_ptrs, o.to(O.dtype.element_ty), mask=q_mask)
    # Save log-sum-exp = m_i + log(l_i) for backward
    tl.store(L + offs_m, m_i + tl.log(l_i), mask=offs_m < S)

This is not the fastest flash-attention -- the production versions split the K loop differently, fuse the dropout, use TMA on Hopper, and have a dedicated backward kernel. But the structure -- Q-tile in registers, streaming K/V tiles, online (m, l), accumulator rescaled by alpha -- is the same and is the form you should be able to derive from scratch.

Two subtle points the code makes explicit:

  1. The accumulator update is acc = acc * alpha + dot(p, v). The alpha scaling corrects every previous K-block's contribution to the new denominator. Without it the answer is wrong.
  2. The causal mask is applied before the softmax exp, by setting masked-out positions to -inf. After exp, those become 0, which is the additive identity for the running sum and for the dot(p, v) accumulation.
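
Both points can be checked against a direct implementation on the host. The following pure-Python sketch streams over K/V in blocks of BN with a running (m, l, acc) per query row and compares with the straightforward softmax. It is a reference under simplifying assumptions (one query row at a time, i.e. BM=1, a single head, and len(Q) == len(K) in the causal case), not a model of the kernel's performance; the function names are illustrative:

```python
import math

def attention_reference(Q, K, V, causal=True):
    """Direct softmax(Q K^T / sqrt(d)) V with an optional causal mask."""
    d = len(Q[0])
    out = []
    for i, q in enumerate(Q):
        last = i + 1 if causal else len(K)
        s = [sum(a * b for a, b in zip(q, K[j])) / math.sqrt(d) for j in range(last)]
        m = max(s)
        p = [math.exp(v - m) for v in s]
        l = sum(p)
        out.append([sum(p[j] * V[j][c] for j in range(last)) / l for c in range(d)])
    return out

def attention_tiled(Q, K, V, BN=2, causal=True):
    """Stream over K/V in blocks of BN, keeping a running (m, l, acc) per query row."""
    d = len(Q[0])
    S = len(K)
    out = []
    for i, q in enumerate(Q):
        m_i, l_i, acc = float("-inf"), 0.0, [0.0] * d
        n_end = i + 1 if causal else S          # skip blocks that cannot contribute
        for start in range(0, n_end, BN):
            cols = [j for j in range(start, min(start + BN, S))
                    if not causal or j <= i]     # causal and bounds mask
            s = [sum(a * b for a, b in zip(q, K[j])) / math.sqrt(d) for j in cols]
            m_new = max(m_i, max(s))
            alpha = math.exp(m_i - m_new)        # rescale old contributions
            p = [math.exp(v - m_new) for v in s]
            l_i = alpha * l_i + sum(p)
            acc = [alpha * a + sum(pj * V[j][c] for pj, j in zip(p, cols))
                   for c, a in enumerate(acc)]
            m_i = m_new
        out.append([a / l_i for a in acc])
    return out
```

Dropping the alpha factor from the acc update makes the two functions disagree as soon as a later block raises the running max, which is a quick way to convince yourself of point 1.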

9. The compilation pipeline

9.1 The path

Python @triton.jit function
   |  AST capture
   v
Triton IR  (an MLIR dialect)
   |  high-level optimisation:
   |    layout selection, broadcasting elimination,
   |    reduction lowering, masked-load elision
   v
TritonGPU IR  (an MLIR dialect, HW-aware)
   |  GPU-specific:
   |    shared-memory allocation, pipelining (cp.async / TMA),
   |    swizzling, MMA selection
   v
LLVM IR
   |
   v
PTX (NVIDIA) or AMDGCN (AMD) or other backend
   |
   v
SASS (assembled by ptxas into a cubin at compile time, then cached)

You almost never need to look at the intermediate IRs. You frequently want to look at the PTX.

9.2 Inspecting compiled artefacts

kernel = matmul_tiled[grid](...)        # first call, compiles
# Get the compiled handle:
fn = matmul_tiled.warmup(*args, grid=grid)   # one approach
# Or, after a real call:
print(matmul_tiled.cache.values())
# For a specific compiled binary:
binary = next(iter(matmul_tiled.cache.values()))
print(binary.asm['ttir'])    # Triton IR
print(binary.asm['ttgir'])   # TritonGPU IR
print(binary.asm['llir'])    # LLVM IR
print(binary.asm['ptx'])     # PTX
print(binary.asm['cubin'])   # binary cubin (bytes)

(API has shifted across releases; the exact attribute names may differ in your version. Verify with dir(binary.asm).)

9.3 What to look for in PTX

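  • ld.global.v4.b32 / ld.global.v4.f32 -- vectorised (128-bit) loads. If your kernel emits only scalar ld.global.b32 or ld.global.f32 on a path you expected to be vectorised, the compiler could not prove that your addresses are contiguous and aligned, and the loads are probably not coalescing well.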
  • mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 -- tensor-core MMA. No such instruction means you are not using tensor cores.
  • cp.async.cg.shared.global (Ampere) or cp.async.bulk.tensor (Hopper) -- asynchronous global-to-shared copies, used by software pipelining.
  • bar.sync -- barriers between pipeline stages. Too many usually means num_stages is too low and you are stalling.
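The checklist above can be turned into a quick tally over a PTX dump. This is a rough sketch under stated assumptions: summarize_ptx and its category labels are hypothetical names, the classification is by instruction prefix only, and the sample text is hand-written for illustration rather than real compiler output. In practice you would pass it the PTX string obtained as in 9.2.

```python
import re

def summarize_ptx(ptx):
    """Count a few instruction families in a PTX string, keyed by our own labels."""
    counts = dict.fromkeys(
        ["vector_load", "scalar_load", "mma", "async_copy", "barrier"], 0)
    for line in ptx.splitlines():
        # Drop predicate guards such as "@%p1" so the opcode is the first token.
        tokens = [t for t in line.split() if not t.startswith("@")]
        if not tokens:
            continue
        op = tokens[0]
        if op.startswith("ld.global"):
            kind = "vector_load" if re.search(r"\.v\d+\.", op) else "scalar_load"
            counts[kind] += 1
        elif op.startswith(("mma.sync", "wgmma.mma_async")):
            counts["mma"] += 1
        elif op.startswith("cp.async") and "group" not in op and "wait" not in op:
            counts["async_copy"] += 1      # skip commit_group / wait_group bookkeeping
        elif op.startswith("bar.sync"):
            counts["barrier"] += 1
    return counts

# Hand-written sample lines, for illustration only.
sample = """
    @%p1 ld.global.v4.b32 { %r1, %r2, %r3, %r4 }, [ %rd1 + 0 ];
    ld.global.b32 { %r5 }, [ %rd2 + 0 ];
    cp.async.cg.shared.global [ %r6 + 0 ], [ %rd3 + 0 ], 0x10, %r7;
    cp.async.commit_group ;
    mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 { %f1, %f2, %f3, %f4 }, { %r8, %r9, %r10, %r11 }, { %r12, %r13 }, { %f1, %f2, %f3, %f4 };
    bar.sync 0;
"""
summary = summarize_ptx(sample)
```

A zero in the mma slot for a matmul kernel, or a scalar_load count that dwarfs vector_load, is the cue to go read the PTX properly.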

9.4 Useful environment variables

  • TRITON_PRINT_AUTOTUNING=1 -- prints the timing table for autotune.
  • TRITON_CACHE_DIR=/path -- override the default cache location.
  • TRITON_DEBUG=1 -- prints extra diagnostics.
  • TRITON_INTERPRET=1 -- runs the kernel on the CPU in pure Python emulation. Excruciatingly slow but the only way to single-step a Triton kernel under pdb. Worth the cost when you have a correctness bug you cannot localise.

10. Integration with PyTorch

10.1 Calling a Triton kernel from PyTorch

The simplest path: a Python wrapper around the kernel, callable from ordinary user code. The wrappers in section 8 (add, matmul, softmax) are exactly this. PyTorch tensors are passed directly: Triton uses each tensor's .data_ptr() as the kernel's pointer argument and infers the element dtype from the tensor, while sizes and strides are passed explicitly by the wrapper (x.stride(0) and so on).

10.2 Autograd

Subclass torch.autograd.Function:

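import torch
import triton

# rmsnorm_fwd and rmsnorm_bwd_dx are @triton.jit kernels assumed to be defined already.
class RMSNormFn(torch.autograd.Function):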
    @staticmethod
    def forward(ctx, x, w, eps):
        M, N = x.shape
        y = torch.empty_like(x)
        rstd = torch.empty(M, device=x.device, dtype=torch.float32)
        BLOCK = triton.next_power_of_2(N)
        rmsnorm_fwd[(M,)](x, w, y, rstd, x.stride(0), y.stride(0), N, eps,
                          BLOCK=BLOCK, num_warps=4)
        ctx.save_for_backward(x, w, rstd)
        ctx.N, ctx.BLOCK = N, BLOCK
        return y

    @staticmethod
    def backward(ctx, dy):
        x, w, rstd = ctx.saved_tensors
        dx = torch.empty_like(x)
        M = x.shape[0]
        rmsnorm_bwd_dx[(M,)](x, w, dy, rstd, dx,
                             x.stride(0), dy.stride(0), dx.stride(0),
                             ctx.N, BLOCK=ctx.BLOCK, num_warps=4)
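        # dW: reduce x*rstd*dy over the row dim, accumulating in fp32,
        # then cast back so the gradient dtype matches w.
        dw = (x.float() * rstd[:, None] * dy.float()).sum(dim=0).to(w.dtype)   # let PyTorch handle it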
        return dx, dw, None

def rmsnorm(x, w, eps=1e-6):
    return RMSNormFn.apply(x, w, eps)

The dW reduction is small enough that letting PyTorch handle it is fine; fusing it into a Triton kernel is an optimisation, not a correctness step.
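The dW expression can be sanity-checked against finite differences without a GPU. The sketch below assumes the usual RMSNorm definition y = x * rstd * w with rstd = 1/sqrt(mean(x^2) + eps), which is what the backward above implies. All function names are illustrative, and plain Python lists stand in for tensors.

```python
import math

def rmsnorm_forward(x, w, eps=1e-6):
    """Reference forward: returns y and the per-row rstd."""
    rstd = [1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps) for row in x]
    y = [[v * r * wj for v, wj in zip(row, w)] for row, r in zip(x, rstd)]
    return y, rstd

def rmsnorm_dw(x, rstd, dy):
    """dw[j] = sum over rows i of x[i][j] * rstd[i] * dy[i][j]."""
    return [sum(x[i][j] * rstd[i] * dy[i][j] for i in range(len(x)))
            for j in range(len(x[0]))]

def loss(x, w, dy):
    """Scalar whose gradient w.r.t. y is exactly dy."""
    y, _ = rmsnorm_forward(x, w)
    return sum(dy[i][j] * y[i][j] for i in range(len(x)) for j in range(len(w)))

def numeric_dw(x, w, dy, j, h=1e-4):
    """Central difference in w[j]."""
    wp, wm = list(w), list(w)
    wp[j] += h
    wm[j] -= h
    return (loss(x, wp, dy) - loss(x, wm, dy)) / (2 * h)

x = [[1.0, -2.0, 0.5], [3.0, 0.25, -1.0]]
w = [0.5, 1.5, -2.0]
dy = [[0.1, -0.2, 0.3], [0.4, 0.5, -0.6]]

_, rstd = rmsnorm_forward(x, w)
dw = rmsnorm_dw(x, rstd, dy)
dw_numeric = [numeric_dw(x, w, dy, j) for j in range(len(w))]
```

The same comparison applied to the Triton path would put torch.autograd.gradcheck or a PyTorch reference in place of the list arithmetic.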

10.3 torch.library ops

For interaction with torch.compile, register your Triton kernel as a custom op so the compiler treats it as opaque and does not try to trace through @triton.jit:

import torch
from torch.library import custom_op, register_fake

@custom_op("mylib::rmsnorm", mutates_args=())
def rmsnorm_op(x: torch.Tensor, w: torch.Tensor, eps: float) -> torch.Tensor:
    return RMSNormFn.apply(x, w, eps)

@register_fake("mylib::rmsnorm")
def _(x, w, eps):
    return torch.empty_like(x)        # shape/stride meta, no compute

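register_fake (formerly impl_abstract) provides shape inference for torch.compile's symbolic shape tracing. Without it torch.compile will fail to trace through your op. Note that a custom op is opaque to autograd as well: calling RMSNormFn.apply inside the op body does not give the op a gradient. For training, have the op body launch the forward kernel directly and attach the backward with rmsnorm_op.register_autograd.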

10.4 torch.compile

torch.compile already emits Triton kernels for many fusions (it has a backend called Inductor). Your handwritten Triton kernels coexist: Inductor will compile around them and treat them as black boxes (provided you registered them via torch.library). If you write a really good kernel that beats Inductor on a hot path -- common for attention variants and custom norms -- this is the way to ship it.

Two cautions:

  • torch.compile expects shape-stable callees. If your custom op recompiles on every shape it slows the graph; pre-warm with the shapes you care about.
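  • mutates_args=() declares that the op mutates none of its inputs, and torch.compile relies on that declaration. If you do mutate arguments (e.g. an in-place kernel), list them in mutates_args; an undeclared mutation is undefined behaviour and can produce silently wrong results under torch.compile.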

11. Common pitfalls

Shape-dependent recompilation. Every unique constexpr value triggers a new compile. Pin block sizes to a finite menu. Autotune key buckets your shapes; that is fine. What is not fine is a constexpr that tracks a continuously varying N directly (for example passing N itself as tl.constexpr) -- you will have a compile per N. BLOCK = triton.next_power_of_2(N) already quantises to one compile per power of two; clamp it (BLOCK = max(64, triton.next_power_of_2(N))) so that small N share one variant, and remember that each distinct BLOCK is still a separate compile.
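To get a feel for how large the menu of constexpr values becomes, it helps to count it. The sketch below uses a stand-in for triton.next_power_of_2 so it runs without Triton installed; pick_block is a hypothetical helper name and the floor of 64 is the value used in the text.

```python
def next_power_of_2(n):
    """Stand-in for triton.next_power_of_2 (smallest power of two >= n)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def pick_block(n, floor=64):
    """Quantised block size: one compiled variant per distinct return value."""
    return max(floor, next_power_of_2(n))

row_widths = range(1, 4097)                      # N varying "continuously"
variants = sorted({pick_block(n) for n in row_widths})
# 4096 distinct N values collapse onto a handful of BLOCK values.
```

Each entry in variants is one compile per kernel and dtype combination, so the cost is bounded by the logarithm of the largest N rather than by the number of distinct N.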

Mask correctness. Two failure modes:

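  • Forgetting to mask. Out-of-bounds reads often appear to "work" because the address still lands in memory the process has mapped (for example elsewhere in PyTorch's caching-allocator pool), so the kernel reads zeros or stale data instead of faulting; a different shape or allocation pattern turns the same bug into an illegal-memory-access error or silently wrong output. Always mask.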
  • other= mismatch. For a softmax row, other=0.0 makes masked lanes the max if all real values are negative -- wrong. Use other=-inf for max-reductions, other=0.0 for sum-reductions and dot accumulators.
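The other= rule is easy to see on a padded row. This is a small emulation, not Triton code: masked_load is a hypothetical helper that mimics a masked load of a BLOCK-wide row whose tail lanes fall outside the real data.

```python
def masked_load(row, block, other):
    """Emulate a masked load: lanes past len(row) receive `other`."""
    return [row[i] if i < len(row) else other for i in range(block)]

row = [-3.0, -1.5, -2.0]          # every real value is negative
BLOCK = 4                         # one padded lane

wrong_max = max(masked_load(row, BLOCK, 0.0))            # padding wins the max
right_max = max(masked_load(row, BLOCK, float("-inf")))  # padding can never win
right_sum = sum(masked_load(row, BLOCK, 0.0))            # 0.0 is neutral for sums
```

Here wrong_max is 0.0 even though no real element is 0, while right_max is -1.5 and right_sum is -6.5.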

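num_stages too high. Each pipeline stage holds a full [BM, BK] and [BK, BN] tile in shared memory, so the footprint grows as num_stages * (BM*BK + BK*BN) * element size. Running out of shared memory causes silent fallback to fewer stages or a hard compile error, depending on version. If autotune timing for a config is much worse than a config with smaller BM*BK + BK*BN, suspect shared-memory pressure (fewer resident blocks per SM) rather than register spills, which are the next pitfall.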
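A back-of-the-envelope budget makes this pitfall concrete. The sketch below only counts the pipelined A and B tiles and ignores any other shared-memory use by the compiler. The 163 KiB figure is an assumption for an A100-class per-block limit, so check your own device before relying on it. Both function names are illustrative.

```python
def pipeline_smem_bytes(bm, bn, bk, num_stages, elem_bytes=2):
    """Shared memory held by the pipelined [BM, BK] and [BK, BN] tiles."""
    return num_stages * (bm * bk + bk * bn) * elem_bytes

# Assumption: per-block shared-memory limit of an A100-class GPU.
SMEM_BUDGET = 163 * 1024

def max_stages(bm, bn, bk, elem_bytes=2, budget=SMEM_BUDGET):
    """Largest num_stages whose tiles alone still fit in the budget."""
    return budget // pipeline_smem_bytes(bm, bn, bk, 1, elem_bytes)

three_stage = pipeline_smem_bytes(128, 128, 64, num_stages=3)   # fp16 tiles
limit_128 = max_stages(128, 128, 64)
limit_256 = max_stages(256, 128, 64)
```

For fp16 tiles with BM = BN = 128 and BK = 64, three stages take 96 KiB and at most five fit under the assumed budget; growing BM to 256 drops that to three.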

Register-pressure spills. Same root cause: tiles too big for registers. Symptoms: ptxas reports spills (ptxas info: ... 2048 bytes stack frame, ... bytes spill stores), kernel runs at fraction of expected speed. Reduce BM, BN, or split the K loop.

Forgetting .to(dtype) on store. acc is fp32, your output buffer is fp16. tl.store(C_ptr, acc, ...) will implicitly cast in modern Triton but the cast may not be the rounding mode you expect (RTNE vs. RTZ). Be explicit: acc.to(C.dtype.element_ty).

Confusing tl.program_id axes with grid axes. tl.program_id(0) is the first grid dim, not the X coordinate of anything physical. With a 2D grid (grid_m, grid_n), tl.program_id(0) ranges 0..grid_m-1 and tl.program_id(1) 0..grid_n-1. Be consistent.

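Autotune key mistakes. Two opposite failure modes. If key= omits a shape-relevant arg (e.g. you forgot K), distinct shapes share one cache entry and the config benchmarked for the first K is silently reused for every other K: no re-tune, just a possibly poor config. If key= contains an arg that changes on almost every call (e.g. a raw, unbucketed sequence length), every new value triggers a fresh benchmark sweep. Include every shape-relevant runtime arg in key=, and bucket the ones that vary continuously.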
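A toy model of a keyed autotune cache shows what the key controls. This is a sketch only: ToyAutotuner is a hypothetical class, and it assumes the cache is a dictionary indexed by the values of the key arguments (the real autotuner also folds argument dtypes into the key). A benchmark sweep runs only when the key tuple has not been seen before.

```python
class ToyAutotuner:
    """Caches one 'best config' per tuple of key-argument values."""

    def __init__(self, key):
        self.key = key
        self.cache = {}
        self.sweeps = 0

    def pick_config(self, **shape):
        k = tuple(shape[name] for name in self.key)
        if k not in self.cache:
            self.sweeps += 1                              # benchmark sweep happens here
            self.cache[k] = "tuned for K=%d" % shape["K"]
        return self.cache[k]

# Key without K: both shapes map to the same entry.
short_key = ToyAutotuner(key=["M", "N"])
first = short_key.pick_config(M=512, N=512, K=64)
second = short_key.pick_config(M=512, N=512, K=8192)

# Key with K: each shape gets its own sweep and its own config.
full_key = ToyAutotuner(key=["M", "N", "K"])
full_key.pick_config(M=512, N=512, K=64)
third = full_key.pick_config(M=512, N=512, K=8192)
```

With the short key, the K=8192 call gets the config tuned for K=64 after a single sweep; with the full key there are two sweeps and two configs.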

Stride bugs. Always pass strides explicitly; never assume contiguous. A transposed view has non-trivial strides; passing x.T (or any other view) to a kernel that assumes a contiguous layout is a frequent source of garbage outputs.

Thinking tl.dot is a full matrix multiplication. It is a block matrix multiplication: it multiplies one [BM, BK] tile of A by one [BK, BN] tile of B. Looping over K and accumulating the partial products is your job.


12. Triton vs CUTLASS vs hand-rolled CUDA

Rough rules of thumb. None are universal; benchmark your case.

| Use case | Pick | Why |
| --- | --- | --- |
| Standard square matmul, fp16/bf16, big shapes | cuBLAS | Vendor-tuned per arch; near-peak. |
| GEMM with epilogue fusion (bias + GeLU + dropout) | CUTLASS or Triton | cuBLAS does no fusion. |
| Custom fused norm / softmax / attention | Triton | Productivity is decisive; perf is good. |
| Bleeding-edge attention (FA-3 style) | Hand CUDA / CUTLASS | TMA, warp specialisation, async fence patterns -- still ahead of Triton in 2025/2026. |
| Quirky shapes (small K, irregular masking) | Triton | Beats vendor libs frequently. |
| Sparse / structured-sparsity kernels | CUTLASS | Has sparse MMA primitives. |
| Research prototype, weeks-not-months | Triton | Always. |

CUTLASS is a C++ template library from NVIDIA: more verbose than Triton, more flexible than cuBLAS, exposes the actual MMA shapes and async pipeline primitives. Hand-rolled CUDA gets you the last 5-15% on niche ops; the engineering cost is large.

A reasonable rule: write everything in Triton first, profile, replace the top-1 hot kernel with CUTLASS or hand CUDA only if numbers force you to.


13. Exercises

Six problems, ordered by difficulty. Try each before reading the sketch.

13.1 Fused bias + ReLU + scale

Write a kernel y = relu(x + b) * s where x is [M, N], b is [N] (broadcast across rows), s is a scalar. One launch, one pass.

Sketch. 1D grid of M*ceil(N/BN) programs, or 2D grid of (M, ceil(N/BN)). Load x tile [1, BN], load b [BN] once per program, broadcast add, tl.where(z > 0, z, 0.0) * s, store. ~25 lines. Time should be HBM-bandwidth-bound.
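Before writing the kernel it can help to emulate the 2D grid in plain Python and confirm the tiling and tail masking. The sketch below is an emulation only, with illustrative names: each call to bias_relu_scale_program plays the role of one program instance, and BN is deliberately chosen so that the last column tile is partial.

```python
import math

def bias_relu_scale_program(x, b, s, y, pid_m, pid_n, BN):
    """One program instance: columns [pid_n*BN, pid_n*BN + BN) of row pid_m."""
    n = len(b)
    offs = [pid_n * BN + j for j in range(BN)]
    mask = [o < n for o in offs]                 # tail columns are masked off
    for o, ok in zip(offs, mask):
        if ok:
            z = x[pid_m][o] + b[o]               # bias broadcast across rows
            y[pid_m][o] = max(z, 0.0) * s        # ReLU, then scale

def bias_relu_scale(x, b, s, BN=4):
    """Launch the emulated (M, ceil(N/BN)) grid."""
    m, n = len(x), len(b)
    y = [[0.0] * n for _ in range(m)]
    for pid_m in range(m):                       # grid axis 0
        for pid_n in range(math.ceil(n / BN)):   # grid axis 1
            bias_relu_scale_program(x, b, s, y, pid_m, pid_n, BN)
    return y

x = [[1.0, -2.0, 3.0, -4.0, 5.0],
     [-1.0, 2.0, -3.0, 4.0, -5.0]]
b = [0.5, 0.5, 0.5, 0.5, 0.5]
out = bias_relu_scale(x, b, s=2.0)               # N=5, BN=4: second tile has 1 live lane
```

The real kernel replaces the inner Python loop with vectorised tl.load, tl.where and tl.store over the same offsets and mask.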

13.2 Rowwise top-k mask (k=1)

Given x [M, N], output y of same shape with y[i, j] = x[i, j] if x[i, j] is the rowwise max, else 0. Single pass per row.

Sketch. Like softmax: load row into registers, m = tl.max(x, axis=0), y = tl.where(x == m, x, 0.0), store. Care with ties (you will mark all of them; usually fine). Width must fit in BLOCK.
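A reference for checking the kernel's output, including the tie behaviour mentioned in the sketch, fits in a few lines. top1_mask_row is an illustrative name, and the list comprehension plays the role of tl.where(x == m, x, 0.0).

```python
def top1_mask_row(row):
    """Keep every element equal to the row maximum, zero the rest."""
    m = max(row)
    return [v if v == m else 0.0 for v in row]

unique_max = top1_mask_row([1.0, 4.0, 2.0, -2.0])
tied_max = top1_mask_row([1.0, 4.0, 4.0, -2.0])   # both tied entries survive
```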

13.3 LayerNorm forward with Welford

Write a LayerNorm forward that computes mean and variance with one pass using Welford's online algorithm. Compare numerical accuracy to the naive formula on a [1, 65536] input of magnitude 1e6.

Sketch. Tile the row into chunks of BLOCK, accumulate (n, mean, M2) across chunks with the combine formulas in section 7.3. After the loop, var = M2 / N, rstd = rsqrt(var + eps), second pass to compute and store y. Or, since the row fits in registers when N <= 64K, load it once and use the single-block path (mean first, then the mean of squared deviations from it) -- in fp32 that is also stable for any realistic magnitude, unlike the E[x^2] - E[x]^2 shortcut. Welford pays off when you must tile.
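The chunked accumulation can be prototyped in plain Python before it goes into a kernel. The sketch below assumes the standard pairwise combine for (n, mean, M2), which is presumably what section 7.3 gives, and the function names are illustrative. It runs in double precision, so it checks the algebra rather than the fp32 behaviour.

```python
import statistics

def chunk_stats(chunk):
    """(count, mean, sum of squared deviations) for one chunk."""
    n = len(chunk)
    mean = sum(chunk) / n
    return n, mean, sum((v - mean) ** 2 for v in chunk)

def combine(a, b):
    """Merge two (n, mean, M2) triples."""
    na, mean_a, m2_a = a
    nb, mean_b, m2_b = b
    n = na + nb
    delta = mean_b - mean_a
    mean = mean_a + delta * nb / n
    m2 = m2_a + m2_b + delta * delta * na * nb / n
    return n, mean, m2

def welford_mean_var(row, block):
    """Stream the row in chunks of `block`, as a tiled kernel would."""
    acc = (0, 0.0, 0.0)
    for start in range(0, len(row), block):
        acc = combine(acc, chunk_stats(row[start:start + block]))
    n, mean, m2 = acc
    return mean, m2 / n          # population variance, as LayerNorm uses

row = [1e6 + (i % 7) * 0.5 for i in range(1000)]   # large offset, small spread
mean, var = welford_mean_var(row, block=128)
```

In a kernel, chunk_stats corresponds to the per-tile tl.sum reductions and combine to the update of the running triple inside the loop.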

13.4 Causal-attention forward, drop the causal mask

Modify the flash-attention kernel in 8.6 so it accepts a CAUSAL: tl.constexpr parameter and, when False, iterates over the full K range. Verify: when CAUSAL=False, output equals softmax(QK^T/sqrt(d)) V to within fp16 precision.

Sketch. The kernel as written already has CAUSAL. The change is the n_end = (pid_m + 1) * BM if CAUSAL else S line and the tl.where causal mask gated on if CAUSAL:. Confirm that the if is on CAUSAL (a constexpr) so it is compiled away when False.

13.5 Matmul autotune key bug

A user complains: "the kernel is fast at the first shape I run and noticeably slower at later shapes, even though M and N are the same." Their decorator is @triton.autotune(configs=[...], key=['M', 'N']) for an M x K by K x N matmul. What is wrong?

Sketch. K is missing from the key. The autotuner caches only by (M, N), so (M, N, K1) and (M, N, K2) collide: the config benchmarked for K1 is reused for K2 without re-tuning, even when a different tile shape would win there. Add K to key.

13.6 Spotting tensor-core usage in PTX

Take the matmul kernel from 8.3, compile it for fp16 inputs, dump the PTX, and find the MMA instructions. What is the M-N-K shape of each mma.sync? What does it imply about the tile sizes Triton chose internally?

Sketch. Look for mma.sync.aligned.m16n8k16 (Ampere fp16) or wgmma.mma_async (Hopper). On Ampere fp16, the per-warp MMA shape is 16x8x16, so a [128, 128] accumulator across 4 warps is built from many of these MMAs in a tile schedule chosen by the compiler. The takeaway: the fundamental hardware tile is small; Triton's BM, BN, BK are block-level tiles assembled out of many MMAs.
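The "many MMAs per tile" point is simple arithmetic. The sketch below counts how many m16n8k16 instructions are needed to cover one K-step of a block tile, assuming the work divides evenly across warps; the function name is illustrative, and BK=32 is just an example value.

```python
def mma_count_per_k_step(bm, bn, bk, mma_m=16, mma_n=8, mma_k=16):
    """Number of MMA instructions needed for one [BM, BK] x [BK, BN] product."""
    return (bm // mma_m) * (bn // mma_n) * (bk // mma_k)

num_warps = 4
per_block = mma_count_per_k_step(128, 128, 32)   # whole block, one K iteration
per_warp = per_block // num_warps                # assumes an even split across warps
```

For a [128, 128] accumulator with BK=32 that is 256 MMAs per K iteration, or 64 per warp with 4 warps, which gives a rough expectation for how many mma.sync lines appear inside the unrolled loop body.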


14. Closing notes

You now have:

  • the conceptual model (block-level program, per-instance abstraction),
  • the API surface (tl.load/store/dot/sum/max, masks, constexpr, autotune),
  • the numerical-stability primitives (online softmax, Welford, log-sum-exp),
  • six worked kernels covering the 90% case (elementwise, matmul, softmax, norm, attention),
  • the integration story for PyTorch (autograd.Function, torch.library, torch.compile),
  • a debugging path (PTX inspection, TRITON_INTERPRET, autotune logs),
  • a comparative map against CUTLASS and hand CUDA.

Two final discipline rules. Always benchmark with triton.testing.do_bench, which runs warmup iterations and synchronises around the timed region for you; a hand-rolled timer that skips torch.cuda.synchronize measures launch overhead only, and one that skips warmup folds compile and autotune-search time into the result. Always check correctness against a PyTorch reference first, with torch.testing.assert_close(out_triton, out_torch, rtol=1e-2, atol=1e-2) for fp16, before you start chasing performance. A faster wrong kernel is not a kernel.

The single strongest piece of advice anyone has given about Triton: write the dumb version first, autotune it, look at the PTX once to make sure tensor cores fired, then move on. The framework is designed to reward that workflow. The hours you would have spent on per-thread indexing in CUDA become hours you spend on actual algorithms.
