Triton
Prereqs: SM Architecture, Thread Hierarchy, Shared Memory. Triton hides threads but not blocks; you need the SM picture in your head.
Before Triton, fast GPU code meant CUDA: 800 lines of C++ with intrinsics, template metaprogramming, manual cp.async orchestration, and a multi-day debugging cycle for every shape change. Hand-written tensor-core kernels were a specialist sport. Then OpenAI shipped Triton — a Python-syntax DSL where you program the block (one CTA’s worth of work), operate on whole tiles (vectors and matrices), and let the compiler pick the threads-per-warp distribution and emit register-tiled, SMEM-staged, tensor-core-using kernels.
The mental shift is the lesson: you write roughly 80 lines of what looks like NumPy on tiles, and the compiler turns it into the C++ you’d have hand-written. A kernel engineer who could ship one new optimized op per quarter in CUDA can ship one per week in Triton. Almost every modern open-source kernel project — vLLM, FlashAttention(-3), liger-kernel, Unsloth, the entirety of torch.compile’s code generation — uses Triton as its primary authoring surface. Knowing Triton fluently is table stakes for kernel work in 2026.
TL;DR
- Triton is a Python-syntax kernel DSL that compiles via MLIR (TritonGPU dialect → llvm) to PTX. You write what looks like NumPy on tiles; the compiler emits register-tiled, SMEM-staged, tensor-core-using kernels.
- The mental shift: you program the block, not the thread. A Triton “program” is one CTA’s worth of work. Inside it you operate on whole tiles (vectors, matrices); the compiler picks the threads-per-warp distribution.
- `@triton.autotune` picks the best (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages) per shape on first run. This is the feature that makes Triton competitive with hand-tuned CUTLASS without the maintenance burden.
- Triton is the daily-driver kernel language of OpenAI, of the Triton-using parts of PyTorch (`torch.compile` lowers to it), and of most performance-critical OSS work in 2025–2026. Hand-written CUTLASS still wins by 5–10% on edge cases.
- The current frontier: Triton 3.x adds Hopper TMA and warp specialization; Blackwell support (5th-gen tensor cores, FP4) lands incrementally through 2025–2026.
Mental model
Same compilation chain as the Foundation lessons taught — Triton just gives you an ergonomic frontend.
Hello, Triton
The single canonical example: vector add.
import triton
import triton.language as tl
import torch
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements,
               BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)  # this CTA's id along axis 0
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x, y):
    out = torch.empty_like(x)
    n = x.numel()
    BLOCK = 1024
    grid = ((n + BLOCK - 1) // BLOCK,)
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=BLOCK)
    return out

Things to notice:
- `pid = tl.program_id(0)` — this is the CTA's index along axis 0 of the launch grid. There's no `threadIdx`; the compiler picks the threads-per-warp layout.
- `tl.arange(0, BLOCK_SIZE)` — a vector of integers `[0, 1, ..., BLOCK_SIZE-1]`. This is your tile axis.
- `mask` — boundary tiles get partially filled; the mask hides the off-end lanes during loads and stores. Skip the mask and you crash.
- `tl.load`, `tl.store` — vector ops over the tile. The compiler emits coalesced loads, often via cp.async on Hopper.
- `BLOCK_SIZE: tl.constexpr` — compile-time constant. Different `BLOCK_SIZE` → different compiled kernel.
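A quick end-to-end sanity check (a sketch, not part of the original tutorial; it assumes a CUDA-capable GPU and the `add` wrapper defined above, and uses an element count that is deliberately not a multiple of `BLOCK_SIZE` so the mask path gets exercised):

```python
import torch

# Assumes add() / add_kernel from above are in scope and a CUDA GPU is available.
x = torch.randn(1_000_003, device="cuda")   # not a multiple of 1024 -> last tile is partial
y = torch.randn(1_000_003, device="cuda")
out = add(x, y)
torch.testing.assert_close(out, x + y)      # masked boundary tile matches eager PyTorch
```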
Matmul in 25 lines
One block per output tile — the same tiling scheme as the CUDA version from Thread Hierarchy, in far fewer lines:
@triton.autotune(
configs=[
triton.Config({'BLOCK_M':128,'BLOCK_N':128,'BLOCK_K':32}, num_warps=4, num_stages=3),
triton.Config({'BLOCK_M':128,'BLOCK_N':64, 'BLOCK_K':32}, num_warps=4, num_stages=4),
triton.Config({'BLOCK_M':64, 'BLOCK_N':128,'BLOCK_K':32}, num_warps=4, num_stages=4),
triton.Config({'BLOCK_M':128,'BLOCK_N':128,'BLOCK_K':64}, num_warps=8, num_stages=3),
],
key=['M', 'N', 'K'], # autotune cache keyed on these shape values
)
@triton.jit
def matmul_kernel(A, B, C, M, N, K,
                  stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc += tl.dot(a, b)  # this is one tensor-core mma
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c_ptrs = C + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc.to(tl.float16))

That `tl.dot(a, b)` line is the entire tensor-core call. Triton emits mma.sync instructions; on Hopper it emits WGMMA; on Blackwell, 5th-gen tensor cores. You don’t change a line of code. This is the productivity story.
Autotuning — what’s actually happening
`@triton.autotune` compiles each config the first time the kernel is called with new shape values matching `key=`. It runs each config a few times, picks the fastest, and caches the choice. Subsequent calls with matching shapes use the cached winner.
out = matmul_func(A, B) # first call: ~5–30 s as it tunes
out = matmul_func(A2, B2) # same shapes: instant, uses cached config
out = matmul_func(A3, B3) # different shapes: tunes again

The configs in your list aren’t magic — they’re the bands you’ve decided are reasonable. Standard recipe: include 4–8 configs spanning small/large block sizes and 4/8 warps. Triton picks. The compile time is real (especially with many configs); cache aggressively.
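The snippets above call a host-side wrapper named matmul_func that the kernel listing doesn't define. A minimal sketch of what it could look like, assuming the `matmul_kernel` above and fp16 inputs with shapes that are multiples of the block sizes (the K-loop loads without masks):

```python
import torch
import triton

def matmul_func(A, B):
    # Launcher for the autotuned matmul_kernel above (fp16 inputs, fp16 output).
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = torch.empty((M, N), device=A.device, dtype=torch.float16)
    # The grid is a callable: Triton invokes it with the winning config's
    # meta-parameters, so the launch grid always matches the chosen tiling.
    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),
                         triton.cdiv(N, META["BLOCK_N"]))
    matmul_kernel[grid](
        A, B, C, M, N, K,
        A.stride(0), A.stride(1),
        B.stride(0), B.stride(1),
        C.stride(0), C.stride(1),
    )
    return C
```

Note that BLOCK_M, BLOCK_N, BLOCK_K, num_warps, and num_stages are not passed at the call site; the autotuner injects them from the winning config.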
TritonGPU and the IR
Internally, Triton compiles your code to TritonGPU dialect in MLIR. You can see this:
TRITON_INTERPRET=0 TRITON_KERNEL_DUMP=1 TRITON_DUMP_DIR=/tmp/triton_dump python script.py

In /tmp/triton_dump/<your_kernel>/ you’ll find *.ttir (high-level Triton IR), *.ttgir (TritonGPU IR with thread/tile layout), *.llir (LLVM IR), *.ptx, and *.cubin. Reading these dumps is how kernel engineers debug “why is my kernel slow” — usually it’s a layout choice the autotuner made that’s wrong for your shapes.
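If you just want to see which artifacts a run produced, a small listing helper works (nothing Triton-specific, just walking the dump directory assumed above):

```python
from pathlib import Path

# Walk the dump directory set via TRITON_DUMP_DIR above; each compiled kernel
# gets its own subdirectory with one file per compilation stage.
for f in sorted(Path("/tmp/triton_dump").rglob("*")):
    if f.suffix in {".ttir", ".ttgir", ".llir", ".ptx", ".cubin"}:
        print(f"{f.parent.name}/{f.name}  ({f.stat().st_size} bytes)")
```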
When Triton wins, when it loses
Wins:
- Shapes the autotune covers well (typical AI workloads: matmul, attention, reductions).
- Anything with regular tiling.
- Anywhere developer time matters more than the last 5%.
- Cross-architecture (a Triton kernel runs on Ampere, Hopper, AMD MI300X, in principle Blackwell) without source change.
Loses to CUTLASS or hand-written CUDA:
- Edge shapes (very small N, weird strides, mixed-precision oddities).
- Kernels needing precise control over warp specialization, async pipelines, register banking.
- Anything pre-existing in CUTLASS that you can paste in.
In practice, ~95% of new production kernel work in 2025–2026 starts in Triton; CUTLASS is the fallback when the last 5–10% matters.
Run it in your browser — Triton-shaped tile algebra
Pyodide doesn’t have CUDA, but we can demonstrate the programming model — vectorized tile ops, masked loads, accumulator patterns — in numpy.
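Here is one NumPy sketch of that structure: a tiled matmul where each iteration of the two outer loops plays the role of one Triton program, with zero-padded "masked loads" at the ragged edges and an fp32 accumulator over BLOCK_K slabs. Function and variable names are illustrative, not Triton APIs.

```python
import numpy as np

def matmul_tiled(A, B, BLOCK_M=32, BLOCK_N=32, BLOCK_K=16):
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = np.zeros((M, N), dtype=np.float32)
    for m0 in range(0, M, BLOCK_M):            # grid axis 0: one "program" per row-tile
        for n0 in range(0, N, BLOCK_N):        # grid axis 1: one "program" per col-tile
            acc = np.zeros((BLOCK_M, BLOCK_N), dtype=np.float32)
            for k0 in range(0, K, BLOCK_K):    # K-loop, one BLOCK_K slab at a time
                # "masked load": copy the in-bounds slab, leave the rest zero
                a = np.zeros((BLOCK_M, BLOCK_K), dtype=np.float32)
                b = np.zeros((BLOCK_K, BLOCK_N), dtype=np.float32)
                a_blk = A[m0:m0 + BLOCK_M, k0:k0 + BLOCK_K]
                b_blk = B[k0:k0 + BLOCK_K, n0:n0 + BLOCK_N]
                a[:a_blk.shape[0], :a_blk.shape[1]] = a_blk
                b[:b_blk.shape[0], :b_blk.shape[1]] = b_blk
                acc += a @ b                   # stands in for tl.dot
            # "masked store": write back only the in-bounds part of the tile
            c_blk = C[m0:m0 + BLOCK_M, n0:n0 + BLOCK_N]
            c_blk[:] = acc[:c_blk.shape[0], :c_blk.shape[1]]
    return C

A = np.random.randn(100, 70).astype(np.float32)
B = np.random.randn(70, 130).astype(np.float32)
assert np.allclose(matmul_tiled(A, B), A @ B, atol=1e-4)
```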
The shape — outer loop over tiles, inner accumulator, masked loads at boundaries — is exactly the structure you’ll write in Triton. The only difference on the GPU is tl.dot becomes a tensor-core call and the entire tile lives in registers.
Key takeaways
- Triton = block-programming. `program_id` instead of `threadIdx`; tile-shaped tensors instead of per-thread scalars.
- `tl.dot` is the tensor-core call. One line, all the mma machinery.
- `@triton.autotune` is what makes it ship. It picks (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, num_stages) per shape. Cache aggressively.
- TritonGPU dialect → llvm → PTX → SASS. When debugging, dump the IR at every level.
- Triton wins on developer time, CUTLASS wins on the last 5%. Most new production kernel work in 2026 starts here.
Go deeper
- Paper: "Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations". The original paper. Section 4 has the tile-IR design that became TritonGPU.
- Docs: Triton Documentation. Up-to-date for Triton 3.x. The "Tutorials" pages are the canonical entry point — vector add, matmul, fused softmax, dropout, attention.
- Blog: OpenAI Blog — Introducing Triton. The original announcement. Useful for the "why" before diving into the tutorials.
- Video: Phil Tillet, "Triton: Programming Models for Hardware Accelerators". Talk by Triton's author. Best motivation for the design choices.
- Blog: Colfax — Triton Deep Dive Series. Modern (2024) walkthroughs of Hopper-specific Triton features (TMA, WGMMA).
- Repo: triton-lang/triton. The compiler. `python/triton/` is the frontend; `lib/Conversion/` is where the lowerings live.
- Repo: Dao-AILab/flash-attention. The Triton FlashAttention reference — a real-world frontier kernel readable in an evening.