
Operator Fusion


Prereqs: SM Architecture, Shared Memory. Fusion exists because HBM is slow and SMEM/registers are fast.

TL;DR

  • A pointwise op like add, mul, relu, sigmoid reads its inputs from HBM and writes its outputs to HBM. Doing this for 5 chained ops means 10 HBM trips and ~5× the bandwidth a single-pass implementation would need.
  • Operator fusion = generate one kernel that does the whole chain, keeping intermediates in registers. 2–10× speedup on bandwidth-bound chains, often more on long ones.
  • Inductor (PyTorch’s torch.compile), XLA (JAX/TF), TVM, IREE all do this automatically. They differ in how aggressive they are and what they fuse.
  • Three fusion classes:
    1. Pointwise + pointwise — trivial; always done.
    2. Reduction + pointwise (epilogue fusion) — a softmax fused with the multiply that follows it; the canonical attention optimization.
    3. Matmul + pointwise (output-tile fusion) — a GEMM with bias-add or activation in the same kernel; what CUTLASS calls “epilogue.”
  • The fusion boundary is a memory-format change or a non-fusable op (random sample, sort, complex control flow). Compilers cluster fusable ops between these boundaries.

Why this matters

Almost every “we made the model 2× faster” win in production AI compilers is actually a fusion win. The math didn’t get faster; the memory traffic dropped. Reading torch.compile’s output and seeing one big fused kernel where there used to be ten small ones is the visible payoff. Knowing what fuses and what doesn’t is what lets you reason about why your model is fast or slow without reading PTX.

Mental model

Five chained pointwise ops fused into one kernel: same math, ~5× less memory traffic. On bandwidth-bound chains (i.e., almost all pointwise chains), runtime tracks bytes moved, so the fused version runs ~5× faster.

Concrete walkthrough

Three classes of fusion

1. Pointwise + pointwise — the canonical case.

y = torch.relu(x * w + b)

Naive: three kernels (mul, add, relu). Each reads its operands from HBM, computes one elementwise op, and writes the result back. Total HBM traffic = 5 reads + 3 writes (mul reads x and w; add reads the intermediate and b; relu reads the intermediate).

Fused: one kernel computes y[i] = max(0, x[i]*w[i] + b[i]) in registers. Total HBM traffic = 3 reads + 1 write, i.e. 2× less.
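A minimal Triton sketch of what that fused kernel can look like (the kernel name, 1-D flattening, and launch details are illustrative assumptions, not Inductor's actual output):

import triton
import triton.language as tl

@triton.jit
def fused_mul_add_relu(x_ptr, w_ptr, b_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)  # 3 HBM reads in total
    w = tl.load(w_ptr + offs, mask=mask)
    b = tl.load(b_ptr + offs, mask=mask)
    y = tl.maximum(x * w + b, 0.0)        # intermediates stay in registers
    tl.store(y_ptr + offs, y, mask=mask)  # 1 HBM write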

Every modern AI compiler does this trivially. The question is how far it can extend the fusion before it has to break.

2. Reduction + pointwise (epilogue fusion).

y = torch.exp(x) / torch.exp(x).sum(dim=-1, keepdim=True) # softmax

The reduction (sum) and the divide can be fused: one pass computes the row sum, then a second traversal of the same tile divides each element. With online softmax (the trick FlashAttention popularized: keep a running max and a running sum in a single pass), it gets tighter still.
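A scalar-loop sketch of the online-softmax recurrence (plain PyTorch for readability; a real kernel runs this per tile, in registers):

import torch

def online_softmax(row):
    # Keep a running max m and a running sum s of exp(x - m),
    # rescaling s whenever the max grows, so one pass over the row suffices.
    m = torch.tensor(float("-inf"))
    s = torch.tensor(0.0)
    for x in row:
        m_new = torch.maximum(m, x)
        s = s * torch.exp(m - m_new) + torch.exp(x - m_new)
        m = m_new
    return torch.exp(row - m) / s

row = torch.randn(8)
assert torch.allclose(online_softmax(row), torch.softmax(row, dim=0))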

3. Matmul + pointwise (output-tile fusion).

y = torch.relu(x @ W + b)

The matmul writes its output tile to SMEM/registers. Before writing to HBM, we apply + b and relu in-place. This is what CUTLASS calls an epilogue — a transformation applied to the output tile before final store. Inductor and XLA do it automatically; CUTLASS lets you specify it via the CollectiveEpilogue template (see CuTe & CUTLASS 4).
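In Triton terms the epilogue is a few extra lines between the accumulator loop and the store. A sketch, assuming row-major layouts and M, N, K divisible by the block sizes:

import triton
import triton.language as tl

@triton.jit
def matmul_bias_relu(a_ptr, b_ptr, bias_ptr, c_ptr, M, N, K,
                     BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                     BLOCK_K: tl.constexpr):
    rm = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k0 in range(0, K, BLOCK_K):
        rk = k0 + tl.arange(0, BLOCK_K)
        a = tl.load(a_ptr + rm[:, None] * K + rk[None, :])
        b = tl.load(b_ptr + rk[:, None] * N + rn[None, :])
        acc += tl.dot(a, b)
    # epilogue: bias add + relu applied to the output tile in registers,
    # right before the single store to HBM
    bias = tl.load(bias_ptr + rn)
    acc = tl.maximum(acc + bias[None, :], 0.0)
    tl.store(c_ptr + rm[:, None] * N + rn[None, :], acc)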

What blocks fusion

Four categories of fusion-breakers:

  • Reductions across the wrong axis. A sum over the last axis fuses with downstream pointwise ops; a sum whose axis crosses the tile boundary doesn't fuse easily.
  • Nondeterministic or global ops. Random sampling, dropout (without a fused-dropout kernel), top-k, sort. These need a full tensor materialized in HBM (to sample from, or to rank globally) before the next op can run.
  • Memory-format / dtype changes. Casting from FP16 to FP32 mid-chain, transposing, reshaping with stride changes — these often break fusion because the next op needs a different view of memory.
  • Control flow. A torch.where(x > 0, f(x), g(x)) that calls different functions per element is hard to fuse cleanly; compilers may evaluate both and select, or break the fusion.

What torch.compile actually does

Inductor’s lowering pipeline (simplified):

  1. Receive an FX graph from Dynamo (Python tracing).
  2. Decompose composite ops into elementary ones (see the sketch after this list).
  3. Cluster fusable ops together. The clustering is greedy: walk the graph, group ops that are pointwise-compatible, break clusters at boundaries.
  4. Lower each cluster to a Triton kernel template. Inductor’s templates are parameterized by the ops in the cluster — the kernel body is generated from the FX subgraph.
  5. Autotune the resulting kernel.
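Step 2 in miniature: a composite op like silu decomposes into elementary pointwise ops, which step 3 can then cluster. A hand-written illustration, not Inductor's actual decomposition table:

import torch
import torch.nn.functional as F

# composite:  F.silu(x)
# decomposed: x * torch.sigmoid(x) -- two pointwise ops, now fusion candidates
def silu_decomposed(x):
    return x * torch.sigmoid(x)

x = torch.randn(4)
assert torch.allclose(silu_decomposed(x), F.silu(x))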

You can see this:

import torch

torch._inductor.config.trace.enabled = True

@torch.compile
def f(x, w, b):
    return torch.relu(x @ w + b)

x, w, b = torch.randn(64, 64), torch.randn(64, 64), torch.randn(64)
f(x, w, b)  # produces /tmp/torchinductor_*/output_code.py

In output_code.py you’ll see one Triton kernel doing matmul + bias + relu, instead of three separate calls. Open it and you can read the fusion in plain Python.

Inductor’s fusion heuristic, briefly

Inductor scores potential fusions by an estimate of HBM bytes saved. Two ops fuse if:

  • Their inputs/outputs share a memory layout (or can be made to).
  • One’s output is the other’s input (no extra users) — or they’re sibling reads of the same tensor.
  • The combined kernel’s register pressure is below a threshold.
  • The combined kernel doesn’t exceed the runtime’s max kernel size.

Most useful debug knob: TORCH_LOGS=fusion python script.py — prints the fusion decisions and why some pairs were rejected. When torch.compile “should have fused something but didn’t,” this is where you look.

When to override

Sometimes the compiler picks a worse fusion than you would. Three escape hatches:

  • Decomposition tweaks. Rewrite your model to be more fusable: x * (x > 0) is harder for the compiler than torch.relu(x), because it must prove the two are equivalent before applying its relu patterns.
  • @torch.compile(mode='max-autotune'). Tries more fusion configs.
  • Hand-write the kernel in Triton or CUTLASS, the tools you've already met. Write the fused version and drop it in via torch.library.custom_op, as sketched below.
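A minimal sketch of that route (the op name mylib::fused_mul_relu is made up, and the body falls back to eager ops where a real version would launch the hand-written Triton kernel):

import torch

@torch.library.custom_op("mylib::fused_mul_relu", mutates_args=())
def fused_mul_relu(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # a real implementation would launch the custom Triton kernel here
    return torch.relu(x * w)

@fused_mul_relu.register_fake
def _(x, w):
    # shape/dtype inference so torch.compile can trace through the op
    return torch.empty_like(x)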

In practice: trust Inductor for 90% of cases; reach for hand-fusion only when profiling shows the auto-fusion missed a clear opportunity.

Run it in your browser — fuse a tiny graph

A toy fusion engine in Python: scan a graph for pointwise ops, cluster them, and count HBM trips before vs. after fusion.
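A minimal sketch of such an engine, assuming a simple list-of-triples graph format (the op set and cost model are illustrative, not Inductor's):

POINTWISE = {"add", "mul", "relu", "sigmoid", "exp"}

# A graph is a list of (op, inputs, output) triples in topological order.
graph = [
    ("mul",  ["x", "w"],  "t1"),
    ("add",  ["t1", "b"], "t2"),
    ("relu", ["t2"],      "y"),
    ("sum",  ["y"],       "s"),  # reduction: closes the pointwise cluster
]

def cluster_pointwise(graph):
    # Greedy clustering: consecutive pointwise ops share a cluster;
    # anything else closes the current cluster and stands alone.
    clusters, current = [], []
    for node in graph:
        if node[0] in POINTWISE:
            current.append(node)
        else:
            if current:
                clusters.append(current)
                current = []
            clusters.append([node])
    if current:
        clusters.append(current)
    return clusters

def hbm_trips(nodes):
    # HBM trips if the cluster runs as ONE kernel: reads of tensors
    # produced outside the cluster, plus one write (intermediates stay
    # in registers; assume only the final output escapes).
    produced = {out for _, _, out in nodes}
    reads = {t for _, ins, _ in nodes for t in ins if t not in produced}
    return len(reads) + 1

unfused = sum(hbm_trips([node]) for node in graph)
clusters = cluster_pointwise(graph)
fused = sum(hbm_trips(c) for c in clusters)

for i, c in enumerate(clusters):
    print(f"cluster {i}: {[op for op, _, _ in c]} -> {hbm_trips(c)} HBM trips")
print(f"unfused: {unfused} trips, fused: {fused}, saved: {unfused - fused}")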

The output mirrors how Inductor reports fusion: a graph of clusters with HBM saving estimates per cluster.

Quick check

Fill in the blank
The CUTLASS term for a fusion that combines a matmul with the pointwise op that follows it (bias add, activation):
Hint: Greek for 'after speech'.
Quick check
A model has 12 elementwise ops chained between two large matmuls. The student adds `@torch.compile`. The expected speedup at the chain level:

Key takeaways

  1. Fusion saves HBM traffic, not FLOPs. The math is the same; the memory traffic drops. That's where the speedup lives.
  2. Three classes: pointwise+pointwise (trivial), reduction+pointwise (e.g., softmax), matmul+pointwise (epilogue / output-tile).
  3. Inductor / XLA / TVM all auto-fuse. They differ in aggressiveness; Inductor’s mode='max-autotune' tries harder.
  4. Fusion-breakers: memory-format changes, nondeterministic ops, complex control flow.
  5. Trust the compiler for 90% of cases; hand-fuse via Triton when profiling shows a clearly-missed opportunity.

Go deeper