Operator Fusion
Prereqs: SM Architecture, Shared Memory. Fusion exists because HBM is slow and SMEM/registers are fast.
Almost every “we made the model 2× faster” win in production AI compilers is actually a fusion win. The math didn’t get faster; the memory traffic dropped. A pointwise op like add, mul, relu, or sigmoid reads its inputs from HBM and writes its outputs back to HBM. Doing this for 5 chained ops means 10 HBM trips and ~5× the bandwidth a single-pass implementation would need. Operator fusion generates one kernel that does the whole chain, keeping intermediates in registers — typically a 2–10× speedup on bandwidth-bound chains.
This is the lesson where reading torch.compile’s output and seeing one big fused kernel where there used to be ten small ones starts making sense. Knowing what fuses and what doesn’t is what lets you reason about why your model is fast or slow without reading PTX.
TL;DR
- A pointwise op like `add`, `mul`, `relu`, `sigmoid` reads its inputs from HBM and writes its outputs to HBM. Doing this for 5 chained ops means 10 HBM trips and ~5× the bandwidth a single-pass implementation would need.
- Operator fusion = generate one kernel that does the whole chain, keeping intermediates in registers. 2–10× speedup on bandwidth-bound chains, often more on long ones.
- Inductor (PyTorch’s `torch.compile`), XLA (JAX/TF), TVM, IREE all do this automatically. They differ in how aggressive they are and what they fuse.
- Three fusion classes:
- Pointwise + pointwise — trivial; always done.
- Reduction + pointwise (epilogue fusion) — a softmax fused with the multiply that follows it; the canonical attention optimization.
- Matmul + pointwise (output-tile fusion) — a GEMM with bias-add or activation in the same kernel; what CUTLASS calls “epilogue.”
- The fusion boundary is a memory-format change or a non-fusable op (random sample, sort, complex control flow). Compilers cluster fusable ops between these boundaries.
Mental model
Same math, 5× less memory traffic. On bandwidth-bound chains (i.e., almost all pointwise chains) the fused version runs ~5× faster.
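The traffic arithmetic behind this claim can be put in a back-of-envelope model. This is a toy accounting, assuming one input and one output per op and a 4-byte dtype; the function name is illustrative, not any compiler's API:

```python
def chain_traffic(n_ops, elems, bytes_per=4, fused=False):
    """Toy HBM-traffic model for a chain of n_ops pointwise ops over a
    tensor of `elems` elements (ignores caches and multi-input ops)."""
    if fused:
        trips = 2            # one read of the input, one write of the result
    else:
        trips = 2 * n_ops    # every op in the chain reads from and writes to HBM
    return trips * elems * bytes_per

unfused = chain_traffic(5, 1 << 20)               # 5 chained pointwise ops
fused = chain_traffic(5, 1 << 20, fused=True)
print(unfused // fused)  # → 5: same math, 5× less memory traffic
```

The FLOP count is identical in both columns of this model; only the `trips` factor changes, which is the whole point.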
Three classes of fusion
1. Pointwise + pointwise — the canonical case.
```python
y = torch.relu(x * w + b)
```
Naive: three kernels (mul, add, relu). Each one reads its operands from HBM, computes one elementwise op, writes back. Total HBM traffic = 6 reads + 3 writes.
Fused: one kernel does `y[i] = max(0, x[i]*w[i] + b[i])` in registers. Total HBM traffic = 3 reads + 1 write — 9 transfers down to 4, ~2.25× less.
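The three-pass vs. one-pass contrast can be sketched in plain Python, with lists standing in for HBM-resident tensors and local variables for registers (a sketch of the idea, not GPU code):

```python
import random

x = [random.uniform(-1, 1) for _ in range(1024)]
w = [random.uniform(-1, 1) for _ in range(1024)]
b = [random.uniform(-1, 1) for _ in range(1024)]

# Naive: three passes; t1 and t2 are intermediates a GPU would
# round-trip through HBM between kernels.
t1 = [xi * wi for xi, wi in zip(x, w)]
t2 = [ti + bi for ti, bi in zip(t1, b)]
y_naive = [max(0.0, ti) for ti in t2]

# Fused: one pass; intermediates never leave "registers" (locals).
y_fused = [max(0.0, xi * wi + bi) for xi, wi, bi in zip(x, w, b)]

assert y_naive == y_fused  # same math, one pass over memory instead of three
```

The results are bit-identical because the floating-point operations are the same in the same order; only the number of passes over memory differs.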
Every modern AI compiler does this trivially. The question is how far it can extend the fusion before it has to break.
2. Reduction + pointwise (epilogue fusion).
```python
y = torch.exp(x) / torch.exp(x).sum(dim=-1, keepdim=True)  # softmax
```
The reduction (sum) and the divide can be fused: one pass that computes the row sum, then a second tile-traversal that divides each element. With online softmax (the FlashAttention trick — running max + running sum in one pass), even tighter.
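Both variants can be sketched in plain Python. This is a hedged, scalar illustration of the online-softmax recurrence, not FlashAttention itself; the function names are illustrative:

```python
import math

def softmax_two_pass(row):
    # Reference: one pass for the max, one for the sum, one to divide.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]

def softmax_online(row):
    # Online softmax: keep a running max m and running sum s; when a new
    # max appears, rescale the accumulated sum by exp(old_m - new_m).
    m, s = float("-inf"), 0.0
    for v in row:
        new_m = max(m, v)
        s = s * math.exp(m - new_m) + math.exp(v - new_m)
        m = new_m
    return [math.exp(v - m) / s for v in row]

row = [0.5, -1.2, 3.0, 0.1]
assert all(abs(a - b) < 1e-9
           for a, b in zip(softmax_two_pass(row), softmax_online(row)))
```

The rescaling step is exact in real arithmetic, which is why the single-pass version matches the two-pass reference up to rounding.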
3. Matmul + pointwise (output-tile fusion).
```python
y = torch.relu(x @ W + b)
```
The matmul writes its output tile to SMEM/registers. Before writing to HBM, we apply `+ b` and `relu` in-place. This is what CUTLASS calls an epilogue — a transformation applied to the output tile before final store. Inductor and XLA do it automatically; CUTLASS lets you specify it via the `CollectiveEpilogue` template (see CuTe & CUTLASS 4).
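A toy, pure-Python GEMM makes the epilogue idea concrete — the bias-add and ReLU are applied to the accumulator before it is ever stored. The function name is hypothetical and this ignores tiling entirely:

```python
def matmul_fused_epilogue(x, W, b):
    """Toy GEMM with a fused bias+ReLU epilogue: each output element is
    transformed while still 'in registers', before the final store."""
    n, k, m = len(x), len(W), len(W[0])
    y = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            acc = 0.0                       # accumulator stays in a register
            for p in range(k):
                acc += x[i][p] * W[p][j]
            y[i][j] = max(0.0, acc + b[j])  # epilogue: +bias, ReLU, then store
    return y

x = [[1.0, 2.0]]
W = [[1.0, -1.0], [0.0, 1.0]]
b = [-0.5, -10.0]
assert matmul_fused_epilogue(x, W, b) == [[0.5, 0.0]]
```

An unfused version would store the raw matmul output, then launch two more kernels to re-read it for the bias and the ReLU — three trips over `y` instead of one.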
What blocks fusion
Three categories of fusion-breakers:
- Reductions across the wrong axis. A sum over the last axis fuses with downstream pointwise ops; a sum that crosses the tile boundary doesn’t fuse easily.
- Nondeterministic ops. Random sampling, dropout (without a fused-dropout kernel), top-k. The op needs its full output in HBM before the next op can sample from it.
- Memory-format / dtype changes. Casting from FP16 to FP32 mid-chain, transposing, reshaping with stride changes — these often break fusion because the next op needs a different view of memory.
- Control flow. A `torch.where(x > 0, f(x), g(x))` that calls different functions per element is hard to fuse cleanly; compilers may evaluate both and select, or break the fusion.
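The "evaluate both and select" strategy can be sketched in plain Python. This is an illustration of the idea, not any compiler's actual lowering; `f` and `g` are hypothetical stand-ins:

```python
f = lambda v: v * 2.0   # "then" branch (illustrative)
g = lambda v: v - 1.0   # "else" branch (illustrative)

def where_select(xs):
    out = []
    for v in xs:
        a, b = f(v), g(v)              # both branches computed unconditionally
        out.append(a if v > 0 else b)  # cheap select, like a predicated move
    return out

assert where_select([1.0, -1.0, 0.0]) == [2.0, -2.0, -1.0]
```

There is no per-element divergence, so the whole thing stays one fusable pointwise pass — at the cost of computing both branches everywhere, which is only a win when `f` and `g` are cheap.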
What torch.compile actually does
Inductor’s lowering pipeline (simplified):
- Receive an FX graph from Dynamo (Python tracing).
- Decompose composite ops into elementary ones.
- Cluster fusable ops together. The clustering is greedy: walk the graph, group ops that are pointwise-compatible, break clusters at boundaries.
- Lower each cluster to a Triton kernel template. Inductor’s templates are parameterized by the ops in the cluster — the kernel body is generated from the FX subgraph.
- Autotune the resulting kernel.
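The greedy clustering in step 3 can be sketched over a toy straight-line op list. The op names and the `POINTWISE` predicate are illustrative, not Inductor's real IR:

```python
POINTWISE = {"mul", "add", "relu", "sigmoid", "exp"}

def cluster(ops):
    """Walk ops in order; extend the current cluster while the op is
    pointwise, break at non-fusable boundaries (reductions, sorts, ...)."""
    clusters, current = [], []
    for op in ops:
        if op in POINTWISE:
            current.append(op)
        else:
            if current:
                clusters.append(current)
            clusters.append([op])   # the breaker gets its own kernel
            current = []
    if current:
        clusters.append(current)
    return clusters

print(cluster(["mul", "add", "relu", "sum", "exp", "sort"]))
# → [['mul', 'add', 'relu'], ['sum'], ['exp'], ['sort']]
```

Each inner list would become one generated kernel; the `sum` and `sort` act as the fusion boundaries described above.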
You can see this:
```python
import torch

torch._inductor.config.trace.enabled = True

@torch.compile
def f(x, w, b):
    return torch.relu(x @ w + b)

x, w = torch.randn(64, 64), torch.randn(64, 64)  # example inputs
b = torch.randn(64)
f(x, w, b)  # produces /tmp/torchinductor_*/output_code.py
```
In `output_code.py` you’ll see one Triton kernel doing matmul + bias + relu, instead of three separate calls. Open it and you can read the fusion in plain Python.
Inductor’s fusion heuristic, briefly
Inductor scores potential fusions by an estimate of HBM bytes saved. Two ops fuse if:
- Their inputs/outputs share a memory layout (or can be made to).
- One’s output is the other’s input (no extra users) — or they’re sibling reads of the same tensor.
- The combined kernel’s register pressure is below a threshold.
- The combined kernel doesn’t exceed the runtime’s max kernel size.
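A toy version of that scoring ties the criteria together — fusing a producer into its consumer saves one write and one read of the intermediate, gated by a register-pressure check. The thresholds, names, and sizes are made up for illustration; this is not Inductor's actual scorer:

```python
def bytes_saved(elems, bytes_per=2):
    """Fusing producer→consumer skips one write and one read of the
    intermediate (assume an fp16 intermediate, 2 bytes/element)."""
    return 2 * elems * bytes_per

def should_fuse(elems, regs_a, regs_b, max_regs=255):
    if regs_a + regs_b > max_regs:   # combined register pressure too high
        return False
    return bytes_saved(elems) > 0    # any HBM saving wins in this toy model

print(should_fuse(1 << 20, regs_a=40, regs_b=32))    # → True
print(should_fuse(1 << 20, regs_a=200, regs_b=100))  # → False
```

The real heuristic also checks layout compatibility and kernel size, but the shape is the same: estimate the traffic saved, reject if the fused kernel would be unlaunchable or spill.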
Most useful debug knob: `TORCH_LOGS=fusion python script.py` — prints the fusion decisions and why some pairs were rejected. When `torch.compile` “should have fused something but didn’t,” this is where you look.
When to override
Sometimes the compiler picks a worse fusion than you would. Three escape hatches:
- Decomposition tweaks. Rewrite your model to be more fusable: `relu(x * (x > 0))` is harder to fuse than `torch.relu(x)` because the compiler must prove they’re equivalent.
- `@torch.compile(mode='max-autotune')`. Tries more fusion configs.
- Hand-write the kernel in Triton or CUTLASS. The tools you’ve already met. Write the fused version, drop it in via `torch.library.custom_op`.
In practice: trust Inductor for 90% of cases; reach for hand-fusion only when profiling shows the auto-fusion missed a clear opportunity.
Key takeaways
- Fusion saves HBM traffic, not FLOPs. The math is the same; the memory traffic drops. That’s where the speedup lives.
- Three classes: pointwise+pointwise (trivial), reduction+pointwise (e.g., softmax), matmul+pointwise (epilogue / output-tile).
- Inductor / XLA / TVM all auto-fuse. They differ in aggressiveness; Inductor’s `mode='max-autotune'` tries harder.
- Fusion-breakers: memory-format changes, nondeterministic ops, complex control flow.
- Trust the compiler for 90% of cases; hand-fuse via Triton when profiling shows a clearly-missed opportunity.
Go deeper
- Docs: PyTorch — torch.compiler. Authoritative. The “Inductor” section explains the fusion pass and the templates Triton kernels are generated from.
- Blog: PyTorch Blog — torch.compile production cases. Real numbers on what fusion gets you on real models. The “Inductor performance” section has the breakdowns.
- Paper: TVM: An Automated End-to-End Optimizing Compiler for Deep Learning. The original paper that formalized AI-compiler fusion. Section 4 covers operator fusion in graph IR.
- Paper: Glow: Graph Lowering Compiler Techniques for Neural Networks. The Facebook side of the same story. Different fusion heuristic; instructive comparison.
- Blog: Making Deep Learning Go Brrrr — Horace He. Hands-on view from a PyTorch dev. The fusion section explains exactly what `torch.compile` is doing in production.
- Repo: pytorch/pytorch. `torch/_inductor/codegen/triton.py` is the generator that emits fused Triton kernels; `torch/_inductor/scheduler.py` is the fusion heuristic.