Tensor Core Shape Constraints
In a managed-runtime language, the JIT will quietly pick a fast path or a slow path for your code, and the difference is rarely catastrophic. Hopper Tensor Cores are not so forgiving. They have a fixed instruction set that operates on fixed-shape tiles — and a kernel whose tile sizes don’t match those shapes silently falls back to CUDA cores at 8–16× lower throughput. There is no warning. The kernel runs, the output is correct, the wall-clock time is wrong by an order of magnitude.
This lesson is about the shapes themselves, why they exist, how to verify a kernel actually hit Tensor Cores, and the playbook for diagnosing a kernel that missed them. After this you should be able to look at a Triton autotuned BLOCK_M, BLOCK_N, BLOCK_K configuration and tell whether it’s about to make full use of an H100’s compute peak or quietly waste 90% of it.
TL;DR
- Hopper Tensor Cores execute the wgmma instruction family. The shape is m64 × N × k16 for fp16/bf16 (where N is any multiple of 8 from 8 up to 256) and m64 × N × k32 for fp8. Outside these shapes you fall back to mma.sync or, at worst, to CUDA cores at ~1/16th the throughput.
- A Triton kernel only hits wgmma when its per-warpgroup tile has m a multiple of 64 and k a multiple of 16 for fp16/bf16 (32 for fp8). Block sizes like (BLOCK_M=32, BLOCK_K=16) silently miss.
- The verification metric is sm__pipe_tensor_op_hmma_cycles_active.avg.pct_of_peak_sustained_active in NCU. Anything below 30% means the kernel is mostly running on CUDA cores. Production kernels target 60–85%.
- Blackwell adds fp4 (m64 × N × k64) and 2:4 sparse fp8 (m64 × N × k64). The shape philosophy carries forward; the K dimension expands as precision shrinks.
- The Capstone 1 fused RMSNorm+Matmul is fast by construction only if its tile shape hits wgmma. Verifying TC% with NCU before fusing is the discipline; the fused kernel inherits whatever the standalone matmul did wrong.
The concept, in plain English
A Tensor Core is not a general-purpose computer. It is a fixed-function unit that takes two small input tiles and a small accumulator tile, performs a single matrix-multiply-accumulate in one instruction, and writes the result back. Because the unit is fixed, the operands must arrive in a specific shape, in a specific layout, in a specific dtype. When all of these line up, the unit runs at peak throughput. When any of them is off, the compiler either rejects the instruction (compiler error) or, more commonly, silently lowers the operation to CUDA-core scalar code that imitates the same math at a fraction of the speed.
The fixed shapes are the floor. Every fast matmul on Hopper — cuBLAS, CUTLASS, hand-written CUDA, Triton autotuned, ThunderKittens, anything — eventually emits a stream of wgmma instructions in the canonical shape. Your job as a kernel author is to choose tile sizes such that the stream is contiguous and aligned, with no leftover edges that have to be handled scalar.
Mental model — the dispatch decision
Three landings, one call site. The same tl.dot() call with the same model output can land on wgmma, on mma.sync, or on scalar CUDA-core code, at very different throughput. NCU is the only thing that tells you which one happened.
The wgmma shape table — the spec you memorize
Hopper SM90 introduced wgmma (warpgroup matrix-multiply-accumulate). A warpgroup is 4 contiguous warps (128 threads) and the instruction operates over them collectively. The shapes:
| Operand dtypes | Shape (m × n × k) | n options | Peak TFLOPs/s (SXM, dense) |
|---|---|---|---|
| fp16/bf16 → fp32 | 64 × N × 16 | 8, 16, 24, …, 256 | 989 |
| fp8 (e4m3, e5m2) → fp32 | 64 × N × 32 | 8, 16, 24, …, 256 | 1979 |
| int8 → int32 | 64 × N × 32 | 8, 16, 24, …, 256 | 1979 |
| tf32 → fp32 | 64 × N × 8 | 8, 16, 24, …, 256 | 495 |
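| fp64 → fp64 | no wgmma form (runs through mma.sync) | n/a | 67 |
A few invariants: for every wgmma row, m is always 64. k is fixed per dtype (16 for fp16/bf16, 32 for fp8/int8, 8 for tf32). n is the only free axis and must be a multiple of 8 between 8 and 256. The accumulator is fp32 for the floating-point rows and int32 for int8. fp64 is the exception: it has no wgmma form and goes through the older mma.sync path.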
Blackwell SM100 extends the table downward: fp4 (NVFP4) at m64 × N × 64 doubles k again, and 2:4 sparse fp8 at m64 × N × 64 uses structured sparsity for an extra 2× over dense fp8. The pattern is consistent: as precision shrinks, k grows so that the data path stays equally utilized.
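One way to see the "k grows as precision shrinks" pattern is to multiply each k by its element size. The sketch below is only arithmetic on the k values quoted in this lesson; the dictionary and function names are made up for illustration.

```python
# Bits per operand element, and the per-instruction k quoted in this lesson.
BITS = {"tf32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "int8": 8, "fp4": 4}
K_PER_INSTRUCTION = {"tf32": 8, "fp16": 16, "bf16": 16, "fp8": 32, "int8": 32, "fp4": 64}


def k_slice_bytes(dtype: str) -> int:
    """Bytes in one operand row of a single k-slice for the given dtype."""
    return K_PER_INSTRUCTION[dtype] * BITS[dtype] // 8


for dtype in K_PER_INSTRUCTION:
    k = K_PER_INSTRUCTION[dtype]
    print(f"{dtype:>5}: k={k:>2} -> {k_slice_bytes(dtype)} bytes per row")
```

Under these numbers every dense dtype works out to the same 32 bytes per operand row, which is one reading of why k doubles each time the element width halves.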
Why these specific shapes
The shapes encode three hardware constraints simultaneously:
1. Operand layout in shared memory. wgmma reads its inputs from SMEM in specific 16-byte-aligned patterns called swizzled layouts. A 64-wide m dimension matches the 32-thread warp × 2 column pattern; a 16-wide k dimension of fp16 is 32 bytes per row (16 × 2 bytes). Picking m=32 means SMEM loads waste half the bandwidth.
2. Accumulator register footprint. The accumulator lives in registers spread across the warpgroup's 128 threads. For an fp32 accumulator with m=64 × n=128, that’s 64 × 128 = 8192 fp32 values, or exactly 64 registers per thread. n=128 is a sweet spot; n=256 doubles that to 128 registers per thread and risks register spill.
3. Tensor Core array geometry. Hopper has 4 fp16 Tensor Cores per SM, each operating on 16×16×16 tiles internally. The warpgroup-level m64 × n × k16 instruction is a tiled batch of these — the wgmma instruction “knows” how to feed 4 inner tiles in parallel, but only if the outer shape lines up.
A kernel that uses tile (m=32, n=128, k=16) would need a wgmma instruction with m=32. The hardware doesn’t have such an instruction. The compiler emits mma.sync (the Ampere-era warp-level instruction, m16n8k16 for fp16) instead, and mma.sync is on a slower path, ~520 TFLOPs/s peak vs 989. Throughput is roughly halved, silently.
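The register arithmetic behind the second constraint can be checked in a few lines. This is a minimal sketch that assumes the accumulator is split evenly across the 128 threads of one warpgroup and that each fp32 value occupies one 32-bit register; the function name is illustrative.

```python
WARPGROUP_THREADS = 128  # 4 warps x 32 threads
REGISTER_BYTES = 4       # one 32-bit register


def accumulator_registers_per_thread(m: int, n: int, acc_bytes: int = 4) -> int:
    """32-bit registers each thread needs to hold its share of an m x n accumulator."""
    total_bytes = m * n * acc_bytes
    return total_bytes // (WARPGROUP_THREADS * REGISTER_BYTES)


for n in (64, 128, 256):
    regs = accumulator_registers_per_thread(64, n)
    print(f"m=64, n={n:>3}: {regs} accumulator registers per thread")
```

At n=256 the accumulator alone takes 128 registers per thread, which leaves much less room under the 255-register per-thread ceiling for addresses, loop state, and pipelining.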
How a Triton kernel actually hits wgmma
Triton’s tl.dot() lowers to wgmma when the tile shape passed to it is a wgmma-compatible shape. The autotune surface controls this:
@triton.autotune(
    configs=[
        # GOOD: tile is (m=128, n=256, k=64). Each of the two warpgroups sees (m=64, n=256) and steps k in slices of 16.
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64,
                       'GROUP_M': 8}, num_warps=8, num_stages=4),
        # GOOD: (m=128, n=128, k=64). The single warpgroup covers m=128 as two m=64 steps.
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64,
                       'GROUP_M': 8}, num_warps=4, num_stages=4),
        # SILENT MISS: m=32 is not 64-aligned. Falls back to mma.sync.
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 32,
                       'GROUP_M': 8}, num_warps=4, num_stages=2),
        # SILENT MISS: k=8 is below the wgmma minimum (16 for fp16), and num_warps=2 is not a full warpgroup.
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 8,
                       'GROUP_M': 8}, num_warps=2, num_stages=2),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(...):
    ...
    acc += tl.dot(a, b, out_dtype=tl.float32)  # this becomes wgmma OR mma.sync OR scalar

The num_warps argument matters too. wgmma is a warpgroup instruction, so it requires num_warps to be a multiple of 4. num_warps=2 cannot use wgmma at all. This is the most common silent miss in Triton autotuned configs that copy from older Ampere examples.
The autotuner will sometimes pick a bad config because correctness is the same and timing variance occasionally favors it. Always print the chosen config and verify with NCU.
Verifying with NCU — the metric that matters
Run the kernel under NCU and look at exactly one number first:
ncu --set full --kernel-name matmul_kernel python bench.py

Then in the NCU output, find the metric tree section “GPU Speed of Light” and read:
| Metric | Meaning | Healthy value |
|---|---|---|
| sm__pipe_tensor_op_hmma_cycles_active.avg.pct_of_peak_sustained_active | fp16/bf16 TC pipe utilization | 60–85% for compute-bound kernels |
| sm__pipe_tensor_op_imma_cycles_active.avg.pct_of_peak_sustained_active | fp8/int8 TC pipe utilization | 60–85% for fp8 kernels |
| sm__inst_executed_pipe_fma.avg.pct_of_peak_sustained_active | fp32 CUDA-core utilization | should be low (< 10%) for a TC kernel |
If the TC metric is below 30%, the kernel is largely on CUDA cores. The fix is almost always a tile shape change. If it’s between 30–60%, the kernel is hitting some wgmma but losing throughput at edges (the M or N dimension isn’t a multiple of the tile, leftover scalar work is happening) — fix with shape padding or by autotuning over more configs.
NCU’s “Tensor Core” panel directly tells you which shape variant ran. Reading that panel before the source code is the senior workflow.
Concrete walkthrough — a silent miss in the wild
A common Triton matmul template you’ll see in tutorials:
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_warps=4, num_stages=2),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_warps=2, num_stages=2), # ← bad
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_warps=8, num_stages=3),
],
key=['M', 'N', 'K'],
)
@triton.jit
def naive_matmul(...): ...

At small problem shapes (M, N around 1024), Triton’s autotuner often picks the second config because total work fits and warp launch overhead is lower. The kernel runs. Output is correct. NCU’s TC pipe utilization is 22%. The fp32 CUDA-core pipe is at 41%. You shipped a kernel that is silently 4× slower than it needs to be.
The fix is to remove the num_warps=2 config from the autotune list: wgmma needs at least 4 warps, period. The autotuner then chooses only among configs that respect the hardware instead of fighting it.
Predict TC eligibility
Four classic traps account for most misses: BLOCK_M=32 (per-warpgroup m=32, not a multiple of 64), num_warps=2 (no wgmma), BLOCK_K=8 (below the fp16 minimum), and BLOCK_K=16 for fp8 (below the fp8 minimum). Each one is a kernel that “works” and is silently 4–10× slower than it should be.
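The four traps can be caught before a kernel ever launches. Below is a minimal sketch of such a predictor in plain Python. It only encodes the rules stated in this lesson (it does not ask Triton or the compiler what was actually emitted), it treats per-warpgroup m as any multiple of 64, and `wgmma_problems` and `K_MIN` are hypothetical names.

```python
from typing import List

# Smallest k a single wgmma instruction accepts per dtype, as quoted in this lesson.
K_MIN = {"fp16": 16, "bf16": 16, "fp8": 32, "int8": 32, "tf32": 8}


def wgmma_problems(block_m: int, block_n: int, block_k: int,
                   num_warps: int, dtype: str = "fp16") -> List[str]:
    """Return the reasons a config would miss wgmma (empty list = eligible)."""
    problems = []
    if num_warps < 4 or num_warps % 4 != 0:
        problems.append(f"num_warps={num_warps} is not a multiple of 4")
        warpgroups = 1  # assume one warpgroup so the remaining checks still run
    else:
        warpgroups = num_warps // 4
    if block_m % (64 * warpgroups) != 0:
        problems.append(
            f"per-warpgroup m={block_m / warpgroups:g} is not a multiple of 64")
    if block_n % 8 != 0 or not 8 <= block_n <= 256:
        problems.append(f"BLOCK_N={block_n} is not a multiple of 8 in [8, 256]")
    if block_k % K_MIN[dtype] != 0:
        problems.append(
            f"BLOCK_K={block_k} is not a multiple of {K_MIN[dtype]} for {dtype}")
    return problems


# (BLOCK_M, BLOCK_N, BLOCK_K, num_warps, dtype)
candidates = [
    (32, 128, 32, 4, "fp16"),   # trap: BLOCK_M=32
    (64, 128, 32, 2, "fp16"),   # trap: num_warps=2
    (128, 128, 8, 4, "fp16"),   # trap: BLOCK_K=8
    (128, 128, 16, 4, "fp8"),   # trap: BLOCK_K=16 for fp8
    (128, 256, 64, 8, "fp16"),  # expected to pass
]
for cfg in candidates:
    print(cfg, wgmma_problems(*cfg) or "eligible")
```

A prediction like this is a filter, not a proof: the NCU metric remains the only confirmation that the compiled kernel actually used the Tensor Core path.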
Key takeaways
- Hopper Tensor Cores have a fixed shape contract. wgmma fp16/bf16: m64 × N × k16. fp8: m64 × N × k32. Outside these shapes, you fall back to mma.sync or scalar code at a fraction of the throughput.
- Verify with NCU, always. sm__pipe_tensor_op_hmma_cycles_active.avg.pct_of_peak_sustained_active is the metric. Below 30% means the kernel missed wgmma.
- Triton autotune can pick bad configs. num_warps must be a multiple of 4 for wgmma. Per-warpgroup m must be a multiple of 64. BLOCK_K must hit the dtype minimum (16 fp16, 32 fp8, 8 tf32).
- Blackwell extends the pattern, doesn’t replace it. fp4 at m64 × N × k64. Sparse fp8 at m64 × N × k64. Same m=64 invariant; k grows as precision shrinks.
- Capstone 1 inherits this discipline. The fused RMSNorm+Matmul kernel cannot be faster than the standalone matmul that ships inside it. Verify TC% on the standalone first; only then fuse.
Go deeper
- Docs: PTX ISA, wgmma instruction reference. The instruction set itself. Bookmark this page; you will reference it any time you write a Hopper kernel.
- Docs: Hopper Tuning Guide. The official guide to TMA + wgmma + warp specialization. Section 3 (Tensor Core programming) is required reading.
- Paper: Hopper Architecture Whitepaper. Source for peak TFLOPs/s by dtype and the wgmma instruction shape table.
- Docs: CuTe GEMM Tutorial, Hopper SM90 path. How CUTLASS expresses wgmma shape constraints in CuTe layouts. The reference for production kernel structure.
- Paper: Outperforming cuBLAS on H100, FP16 Worked Example. A real Hopper kernel that lands at 84% TC peak, with the wgmma scheduling explained.
- Docs: Triton Matrix Multiplication Tutorial. The autotune surface where wgmma-eligible configs live.
- Paper: ThunderKittens: Simple, Fast, and Adorable AI Kernels. A tile-DSL alternative that bakes wgmma shape constraints into the type system.
- Blog: CUTLASS Tutorial, A Deep Dive into wgmma. The clearest single explanation of how wgmma is dispatched and what each operand means.
Why this matters
Every fast matmul on Hopper terminates in a wgmma instruction stream. The hardware peak (989 TFLOPs/s fp16) is achievable only through wgmma; CUDA cores top out at 67 TFLOPs/s fp32 (roughly a 15× gap) and mma.sync (Ampere-era) at ~520 TFLOPs/s. A kernel author who cannot identify whether their tile shape hits wgmma cannot reliably optimize a kernel: every speedup attempt is a coin flip.
This matters most when reading PRs in vLLM / SGLang / Triton: many “perf improvement” PRs that get cited by maintainers are tile-shape changes that flip a kernel from mma.sync to wgmma. Recognizing the shape contract is the prerequisite for proposing those PRs.
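To put the peak numbers in wall-clock terms, here is a back-of-the-envelope sketch. It assumes a 4096 × 4096 × 4096 GEMM running at exactly the quoted peak on each path, which no real kernel does, so the results are lower bounds rather than measurements. `gemm_time_ms` is an illustrative helper.

```python
def gemm_time_ms(m: int, n: int, k: int, peak_tflops: float) -> float:
    """Lower bound on GEMM time in milliseconds at the given peak throughput."""
    flops = 2 * m * n * k  # one multiply and one add per inner-product term
    return flops / (peak_tflops * 1e12) * 1e3


# Peak figures as quoted in this lesson (H100 SXM, dense).
PEAKS = {"wgmma fp16": 989.0, "mma.sync fp16": 520.0, "CUDA-core fp32": 67.0}

for name, peak in PEAKS.items():
    print(f"{name:>15}: {gemm_time_ms(4096, 4096, 4096, peak):.3f} ms")
```

At these peaks the same GEMM takes about 0.14 ms on the wgmma path and about 2.05 ms on CUDA cores, which is the gap a silent miss hides behind a correct output.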
The wgmma shape table
| Operand dtypes | Shape (m × n × k) | n options | Peak TFLOPs/s (SXM, dense) |
|---|---|---|---|
| fp16/bf16 → fp32 | 64 × N × 16 | 8, 16, 24, …, 256 | 989 |
| fp8 (e4m3, e5m2) → fp32 | 64 × N × 32 | 8, 16, 24, …, 256 | 1979 |
| int8 → int32 | 64 × N × 32 | 8, 16, 24, …, 256 | 1979 |
| tf32 → fp32 | 64 × N × 8 | 8, 16, 24, …, 256 | 495 |
| fp64 → fp64 | no wgmma form (runs through mma.sync) | n/a | 67 |
Invariants: m always 64, k fixed per dtype, n free as any multiple of 8 from 8 up to 256.
Blackwell SM100 additions: NVFP4 at m64 × N × 64; 2:4 sparse fp8 at m64 × N × 64. Pattern: as precision shrinks, k doubles to keep the data path utilized.
Why these shapes — three constraints
1. SMEM swizzled layouts. wgmma reads operands from SMEM in 16-byte-aligned swizzled patterns. m=64 matches the 32-thread warp × 2 column read pattern; k=16 is 32 bytes of fp16 per row (16 × 2 bytes). Smaller m wastes load bandwidth.
2. Accumulator register footprint. Per-warpgroup, the fp32 accumulator for m64 × n128 fits in 64 registers/thread when distributed across 128 threads. n=256 risks spill into local memory; n=64 leaves units idle.
3. Tensor Core array geometry. Each SM has 4 fp16 Tensor Core units; the wgmma instruction feeds 4 inner 16×16×16 tiles in parallel. The outer m64 × N × k16 is a tiled batch over those inner tiles.
A tile shape that violates any constraint fails dispatch and falls back. The compiler does not warn.
Triton autotune surface — what to do and what to avoid
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64, 'GROUP_M': 8},
num_warps=8, num_stages=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8},
num_warps=4, num_stages=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'GROUP_M': 8},
num_warps=4, num_stages=4),
# AVOID: per-warpgroup m=32, no wgmma
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_warps=4, num_stages=2),
# AVOID: num_warps=2, no wgmma (needs >= 4)
# triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_warps=2, num_stages=2),
# AVOID: BLOCK_K=8, below fp16 minimum
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 8}, num_warps=4, num_stages=2),
],
key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(...): ...

Per-warpgroup tile = (BLOCK_M / (num_warps // 4), BLOCK_N, BLOCK_K). For wgmma fp16/bf16:
- per-warpgroup m must be a multiple of 64
- BLOCK_K must be a multiple of 16
- num_warps must be a multiple of 4
- BLOCK_N must be a multiple of 8 between 8 and 256
For fp8: same except BLOCK_K must be a multiple of 32.
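Instead of pruning a hand-written list, the rules above can generate the search space directly. The sketch below enumerates a small candidate grid and keeps only configs that satisfy the fp16/bf16 rules. It treats per-warpgroup m as any multiple of 64, which is what the BLOCK_M=128, num_warps=4 config above relies on, and the grid values and function name are illustrative assumptions.

```python
from itertools import product


def eligible_fp16_configs(block_ms=(32, 64, 128, 256), block_ns=(64, 128, 256),
                          block_ks=(8, 16, 32, 64), warps=(2, 4, 8)):
    """Yield plain-dict configs that satisfy the fp16/bf16 wgmma rules."""
    for bm, bn, bk, nw in product(block_ms, block_ns, block_ks, warps):
        if nw % 4 != 0:
            continue  # wgmma needs whole warpgroups
        warpgroups = nw // 4
        if bm % (64 * warpgroups) != 0:
            continue  # per-warpgroup m must be a multiple of 64
        if bn % 8 != 0 or not 8 <= bn <= 256:
            continue  # n must be a multiple of 8 in [8, 256]
        if bk % 16 != 0:
            continue  # k must be a multiple of 16 for fp16/bf16
        yield {"BLOCK_M": bm, "BLOCK_N": bn, "BLOCK_K": bk, "num_warps": nw}


configs = list(eligible_fp16_configs())
print(len(configs), "eligible configs, for example:", configs[0])
```

Each dict could then be turned into a triton.Config by passing the BLOCK_* entries as the kwargs dict and num_warps as its own argument, with num_stages chosen separately.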
NCU verification
ncu --set full --kernel-name matmul_kernel python bench.py

Read these metrics first:
| Metric | Meaning | Healthy value |
|---|---|---|
| sm__pipe_tensor_op_hmma_cycles_active.avg.pct_of_peak_sustained_active | fp16/bf16 TC pipe | 60–85% (compute-bound) |
| sm__pipe_tensor_op_imma_cycles_active.avg.pct_of_peak_sustained_active | fp8/int8 TC pipe | 60–85% |
| sm__pipe_tensor_op_qmma_cycles_active.avg.pct_of_peak_sustained_active | fp4 TC pipe (Blackwell) | 60–85% |
| sm__inst_executed_pipe_fma.avg.pct_of_peak_sustained_active | fp32 CUDA-core pipe | < 10% (TC kernels) |
| sm__inst_executed_pipe_xu.avg.pct_of_peak_sustained_active | special-function unit | usually < 5% |
Diagnostic table:
| TC % | fp32 % | Diagnosis | Fix |
|---|---|---|---|
| < 5 | high | All on CUDA cores | Tile shape entirely wrong; see Triton config rules above |
| 5–30 | moderate | Mostly mma.sync or fallback | Per-warpgroup m not a multiple of 64, or num_warps not multiple of 4 |
| 30–60 | low | wgmma + scalar edges | Problem dims not multiple of tile; pad or reduce tile |
| 60–85 | very low | Healthy | None |
| > 90 | very low | Verify with second tool | Often a measurement artifact at small problem sizes |
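The diagnostic table can be expressed as a small lookup, which is handy when triaging many kernels from exported NCU numbers. This is a sketch only: the thresholds come from the table above, the 85–90% gap that the table leaves open is treated as healthy by assumption, and `diagnose` is a hypothetical helper.

```python
from typing import Tuple


def diagnose(tc_pct: float, fp32_pct: float) -> Tuple[str, str]:
    """Map the two NCU percentages to a (diagnosis, suggested fix) pair."""
    if tc_pct < 5:
        result = ("all on CUDA cores", "tile shape entirely wrong; revisit the config rules")
    elif tc_pct < 30:
        result = ("mostly mma.sync or fallback", "check per-warpgroup m and num_warps")
    elif tc_pct < 60:
        result = ("wgmma plus scalar edges", "pad problem dims or reduce the tile")
    elif tc_pct <= 90:
        result = ("healthy", "none")
    else:
        result = ("suspiciously high", "verify with a second tool")
    # The fp32 pipe should be low (< 10%) once a kernel is really on Tensor Cores.
    if tc_pct >= 30 and fp32_pct >= 10:
        result = (result[0] + ", but fp32 pipe is unexpectedly busy", result[1])
    return result


print(diagnose(22, 41))
print(diagnose(71, 4))
```

The fp32 percentage is used only as a cross-check here; the table treats it as corroborating evidence rather than as an independent threshold.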
Concrete walkthrough — diagnosing a silent miss
A Triton matmul ships with this autotune config and reports correct output:
@triton.autotune(configs=[
    triton.Config({'BLOCK_M': 32, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_warps=4, num_stages=2),
    triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_warps=8, num_stages=3),
], key=['M', 'N', 'K'])  # key is required by triton.autotune

NCU says: sm__pipe_tensor_op_hmma_cycles_active.avg.pct_of_peak_sustained_active = 22%, sm__inst_executed_pipe_fma.avg.pct_of_peak_sustained_active = 41%.
Walking the diagnostic:
- TC at 22% → kernel is mostly NOT on wgmma.
- fp32 CUDA pipe at 41% → significant work on scalar units.
- Look at the chosen config: autotuner picked BLOCK_M=32, num_warps=4. Per-warpgroup m = 32 / (4 // 4) = 32, which is not a multiple of 64. Fails wgmma dispatch. Falls back to mma.sync (Ampere-style) at ~520 TFLOPs/s peak; the rest spills to fp32 CUDA cores for accumulator handling.
- Fix: remove the BLOCK_M=32 config from autotune list. Re-run; autotuner picks the BLOCK_M=128 config; TC% jumps to 71%, fp32 drops to 4%, kernel runs 3.2× faster on the same shape.
Real numbers — Hopper kernel landings
Where production kernels hit on the TC pipe:
| Kernel | Tile (M, N, K) | num_warps | TC % | Notes |
|---|---|---|---|---|
| cuBLAS fp16 GEMM | (auto) | (auto) | 78–84% | Best of breed default |
| CUTLASS hand-tuned (4096³) | 128 × 256 × 64 | 8 | 84% | The 2024 paper |
| Triton autotuned matmul | 128 × 128 × 64 | 4 | 71% | Default tutorial config |
| Triton (BLOCK_M=32 trap) | 32 × 128 × 32 | 4 | 22% | Silent miss |
| ThunderKittens GEMM | 192 × 256 × 64 | 8 | 87% | The newest tile DSL |
| Marlin INT4 GEMV | (decode) | 2 | < 5% | HBM-bound; TC isn’t the right metric |
Marlin’s < 5% TC is correct — INT4 decode is HBM-bound and the bytes/op is the right metric. Always reconcile the metric you’re reading against the regime the kernel is in.
Key takeaways
- wgmma shapes: m64 × N × k16 (fp16/bf16), m64 × N × k32 (fp8/int8), m64 × N × k8 (tf32). Outside these → no wgmma.
- Triton: per-warpgroup m a multiple of 64. num_warps multiple of 4. BLOCK_K hits dtype minimum.
- NCU sm__pipe_tensor_op_hmma_cycles_active.avg.pct_of_peak_sustained_active is the metric. < 30% means broken.
- Blackwell fp4 / sparse fp8 expand k to 64. The m=64 invariant carries forward.
- Capstone 1 fused RMSNorm+Matmul: verify TC% on the standalone matmul before fusing.