SM Architecture
When you read about a GPU, the marketing number is “16,000 CUDA cores.” That number is technically true and almost completely useless for understanding how a kernel runs. The real architectural unit isn’t a single core — it’s a Streaming Multiprocessor (SM), a small self-contained processor with its own instruction schedulers, register file, and on-chip scratch memory. An H100 is 132 of these. A B200 is 148.
Each SM has a fixed budget of resources: 256 KB of registers, 228 KB of shared memory, 64 warp slots, 4 tensor cores. Every “kernel optimization” decision — tile size, occupancy, pipeline depth — is fundamentally a budget allocation across that fixed list. When CUTLASS picks a 128×128×64 tile with 3 pipeline stages and 8 warps per CTA, those numbers aren’t magic; they’re a careful negotiation between accumulator registers, shared-memory tile area, and how many CTAs fit on one SM at once.
This lesson is the hardware view: what’s actually inside one SM, what the warp is (32 threads in lockstep — the real unit of execution), why tensor cores are the reason the chip exists, and how to read a kernel’s tile shape as a resource allocation. Once this clicks, CUTLASS code stops looking like magic numbers.
TL;DR
- A modern GPU is a fleet of Streaming Multiprocessors (SMs). Each SM is its own little processor: 4 warp schedulers, 64–128 CUDA cores, 4 tensor cores, a register file, and shared memory. Everything that runs on a GPU runs on one SM at a time.
- The warp is the real unit of execution: 32 threads marching in lockstep through one instruction at a time. Threads in a warp can’t truly diverge — they take turns when they branch. A kernel that thinks “thread” is the unit is going to be slow.
- Tensor Cores do the heavy lifting in modern AI: each retires small FP16/BF16/FP8 matrix-multiply-accumulate tiles (16×8×16 per warp-level mma.sync, larger shapes via warpgroup-level wgmma). Hopper has 4 per SM; Blackwell’s 5th-gen cores are faster and add FP4/FP6 support.
- Occupancy is the fraction of an SM’s potential warps that are actually resident. High occupancy hides memory latency; low occupancy with high arithmetic intensity is fine. Don’t chase occupancy as an end in itself.
- The same kernel on H100 (132 SMs) vs B200 (148 SMs) vs MI355X (288 CUs) sees different parallelism budgets. Tile size and CTA count are tuned per-chip.
Mental model
Memorize the resources of one SM. Every kernel-tuning decision (tile size, occupancy, register pressure, smem usage) is a budget allocation across this list.
What’s inside an H100 SM
| Resource | H100 SM | What it bounds |
|---|---|---|
| Warp schedulers | 4 | Up to 4 warps issue per cycle |
| CUDA cores (FP32) | 128 | Vector / scalar non-tensor work |
| Tensor cores | 4 (4th gen) | All the AI matmul throughput |
| Register file | 64 K × 32-bit (256 KB) | Per-thread registers; tile size |
| Shared memory + L1 | 228 KB + 28 KB driver reserved | On-chip scratch + L1 cache |
| Max threads resident | 2048 | Occupancy ceiling |
| Max warps resident | 64 | Same as 2048 / 32 |
| Max blocks resident | 32 | Number of CTAs co-resident |
| HBM bandwidth (chip) | 3.35 TB/s | Cross-SM memory traffic |
For comparison: a B200 SM bumps tensor cores to FP4-capable 5th-gen, raises shared memory to 256 KB, and adds the Tensor Memory Accelerator (TMA) as a first-class async-copy engine. MI355X uses CUs (Compute Units) instead of SMs — same idea, different name, slightly different layout.
The warp is the real unit
A warp is 32 threads. They share an instruction pointer. When they all take the same branch, they execute together at full throughput. When they diverge — half take the if, half take the else — the warp serializes through both branches: it executes the if-branch with the else-half masked off, then the else-branch with the if-half masked off. Throughput halves per divergence level.
```cpp
// Bad: lanes pick different work based on lane id.
// Both branches execute serially within the warp, each half masked off in turn.
if (threadIdx.x % 2 == 0) { do_heavy_work_A(); } else { do_heavy_work_B(); }

// Good: all 32 threads in the warp run the same instruction, on different data.
out[i] = heavy_work(in[i]);
```

This is why “embarrassingly parallel” maps well to GPUs while anything branch-heavy needs care. Sorting, sparse data, dynamic batching — all classic GPU pain points because they branch.
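A toy cost model (plain Python, not a hardware simulator) makes the serialization arithmetic concrete — the warp re-runs the branch body once per distinct path any lane takes:

```python
def warp_issue_cost(lane_paths, body_cost=1):
    """Cycles a 32-lane warp spends on a divergent region.

    lane_paths: 32 path ids, one per lane (which branch that lane takes).
    The warp serializes: it executes the body once per distinct path,
    with all other lanes masked off.
    """
    assert len(lane_paths) == 32
    return body_cost * len(set(lane_paths))

uniform = warp_issue_cost([0] * 32)                  # all lanes agree -> 1 pass
split = warp_issue_cost([i % 2 for i in range(32)])  # even/odd split  -> 2 passes
worst = warp_issue_cost(list(range(32)))             # 32-way switch   -> 32 passes
```

The even/odd `threadIdx.x % 2` pattern is the two-pass case: a 2× slowdown regardless of how cheap each masked-off half is.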
Tensor cores do (almost) all the work
The dominant Hopper instruction is wgmma.mma_async — a warp-group (128-thread) async matmul, e.g. 64×128×16 at FP16, that retires asynchronously while the warp group does other work. Per SM, the four tensor cores collectively deliver ~2048 dense BF16 FMAs/cycle, i.e. ~4096 FLOPs/cycle/SM (an FMA is 2 FLOPs). Times ~1.83 GHz boost × 132 SMs ≈ 989 dense BF16 TFLOPS — NVIDIA’s published H100 SXM5 spec. FP8 doubles that to ~1979 TFLOPS because each FP8 multiply-accumulate packs 2× the per-cycle throughput. (mma.sync, the per-warp 16×8×16 instruction in older code, is still legal on Hopper, but wgmma is the path the compiler picks for serious kernels.)
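Working backward from the published spec is a useful sanity check. A sketch of the arithmetic (the ~1.83 GHz boost clock is an assumed round number; exact clocks vary by SKU):

```python
# Dense BF16 on H100: ~2048 FMAs/cycle/SM across the 4 tensor cores.
fmas_per_cycle_per_sm = 2048
flops_per_fma = 2          # multiply + add
sms = 132
clock_hz = 1.83e9          # assumed boost clock

tflops = fmas_per_cycle_per_sm * flops_per_fma * sms * clock_hz / 1e12
# lands within ~1% of the 989 TFLOPS dense BF16 spec; FP8 doubles the
# per-cycle rate, giving ~1979
```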
The lesson: everything else on the SM is bookkeeping for the tensor cores. CUDA cores load and stage data; warp schedulers issue wgmma; shared memory holds tiles; registers hold accumulators. The CUDA cores themselves don’t even hit 100 TFLOPS — the tensor cores are the reason the chip exists.
On Blackwell (SM_100), this generalizes: 5th-gen tensor cores add FP4/FP6 (tcgen05.mma), and Tensor Memory (TMEM) replaces the register file as the accumulator destination for these new instructions — you can no longer fit a Blackwell-class accumulator in regs alone, so the architecture provides a dedicated pool the tensor cores write into directly.
Occupancy, briefly
Occupancy = (warps actually resident on SM) / (max warps the SM can hold). Each warp consumes:
- registers (per-thread × 32)
- shared memory (per-CTA, divided across warps in the CTA)
- a “warp slot” (max 64 warps per SM on H100)
If you use 128 registers/thread on H100, max warps is 64K / (128 × 32) = 16 — 25% occupancy. If your kernel has high arithmetic intensity (lots of math per memory load) that’s fine — you don’t need extra warps to hide memory latency. If your kernel is memory-bound, low occupancy means HBM stalls show through.
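That register arithmetic generalizes to a one-liner. A minimal sketch using the H100 limits from the table above (register pressure only — shared memory and the CTA limit can lower the answer further):

```python
def max_resident_warps(regs_per_thread,
                       regfile_regs=64 * 1024,  # 64K 32-bit registers per H100 SM
                       warp_size=32,
                       warp_slots=64):          # H100 max resident warps
    """Warp ceiling from register pressure alone."""
    by_regs = regfile_regs // (regs_per_thread * warp_size)
    return min(warp_slots, by_regs)

occupancy = max_resident_warps(128) / 64   # 16 resident warps -> 25%
light = max_resident_warps(32)             # 64: warp slots bind, not registers
```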
Don’t chase occupancy. Chase throughput. A FlashAttention kernel runs at ~25% occupancy by design — it spends most of its registers on accumulators because that’s where the win is.
Reading CUTLASS by SM resources
A CUTLASS kernel definition like `gemm<tile=128x128x64, stages=3, warps_per_cta=8>` translates to:
- Tile 128×128×64: each CTA computes a 128×128 output block, processing K in chunks of 64. The 16,384 FP32 accumulators distribute across 8 warps × 32 lanes = 256 threads → 64 regs/thread for the accumulator alone (well under H100’s 255-reg/thread cap).
- 3 stages: the async-copy pipeline keeps 3 K-chunks in flight (stage 0 reading from HBM, stage 1 in SMEM, stage 2 being matmul’d). Triple-buffered to hide HBM latency.
- 8 warps per CTA: 256 threads. Tensor cores run on warp-groups; at 4 schedulers per SM and 8 warps per CTA, each scheduler has 2 warps to round-robin. Hides instruction issue latency.
- Shared memory used: 3 stages × (128 × 64 + 64 × 128) × 2 bytes = 96 KB. Fits in 228 KB SMEM. Could fit 2 CTAs per SM → potentially higher occupancy, but the kernel typically chooses 1 to keep registers per warp high.
The binding constraint at this scale is registers per thread, not the SM’s aggregate 256 KB register file. A 256×256 tile would need 65,536 FP32 accumulators / 256 threads = 256 regs/thread — past the 255-reg cap, so the compiler spills. That’s the real reason kernels stop at 128×128 (or grow warps_per_cta to 16 / use a CGA cluster).
The numbers aren’t magic; they’re an SM-budget allocation. Once you can read tile shapes as resource allocations, kernel code becomes legible.
SM resource calculator
The 256×256 tile fails because its 65,536 FP32 accumulators would consume the entire 256 KB register file — and, at 256 registers per thread, blow past the 255-register cap, so the compiler spills to local memory, killing throughput. This is exactly why CUTLASS / Triton autotuning emits only a handful of tile shapes — the rest don’t fit.
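A minimal stand-in for that calculator, in plain Python. The limits are the H100 numbers from the table above, and the one-register-per-FP32-accumulator assumption matches the walkthrough:

```python
def tile_budget(m, n, k, stages, threads,
                elem_bytes=2,                 # BF16/FP16 operand tiles
                smem_limit=228 * 1024,        # H100 usable shared memory
                regfile_bytes=256 * 1024,     # 64K x 32-bit registers
                reg_cap_per_thread=255):
    """Per-SM budget check for one GEMM CTA.

    Returns (accumulator regs/thread, smem bytes, fits_on_sm).
    """
    acc_regs_per_thread = (m * n) // threads              # FP32 accumulators
    smem = stages * (m * k + k * n) * elem_bytes          # A-tile + B-tile per stage
    fits = (acc_regs_per_thread <= reg_cap_per_thread
            and smem <= smem_limit
            and m * n * 4 <= regfile_bytes)               # accumulators vs whole file
    return acc_regs_per_thread, smem, fits

small = tile_budget(128, 128, 64, stages=3, threads=256)  # (64, 98304 B = 96 KB, True)
big = tile_budget(256, 256, 64, stages=3, threads=256)    # 256 regs/thread: fails
```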
Key takeaways
- The SM is the unit you optimize against. Memorize its resources: register file, shared memory, warp slots, tensor cores.
- A warp is 32 threads in lockstep. Branch divergence serializes; embarrassingly-parallel patterns thrive.
- Tensor cores are why the GPU is fast. Everything else on the SM is bookkeeping to keep the tensor cores fed.
- Occupancy is a means, not an end. High arithmetic intensity → low occupancy is fine.
- Reading kernel tile shapes = reading SM resource allocations. Once this clicks, CUTLASS / Triton kernels stop being magic.
Go deeper
- Docs: CUDA C++ Programming Guide — Hardware Implementation. Authoritative. The “Compute Capability” tables list per-SM resources for every architecture from Kepler through Blackwell.
- Docs: NVIDIA Hopper Tuning Guide. How to think about TMA, WGMMA, async memory pipelines, warp specialization. Required reading before writing serious H100 code.
- Docs: NVIDIA Blackwell Tuning Guide. 5th-gen tensor cores, FP4, Tensor Memory. The 2025+ extension of the Hopper guide.
- Blog: ThunderKittens — Hazy Research. Best modern explanation of warp-specialized programming in plain English. Read alongside the SM picture above.
- Blog: siboehm — How to Optimize a CUDA Matmul. Walks from naive to near-cuBLAS in stages. Best ground-up CUDA tutorial on the modern open web.
- Repo: NVIDIA/cutlass. The reference. Look at `cute/atom/mma_traits_sm90.hpp` to see how Hopper tensor-core operations are encoded.
Companion to: Thread Hierarchy (the programmer’s view) and Shared Memory (the per-SM scratch pad). This lesson is the hardware view: what’s actually inside the chip.
Why this matters
Every fast GPU kernel — Flash Attention, cuBLAS GEMM, PagedAttention, vLLM kernels — is written to the SM. The programmer’s “thread block” abstraction is real, but the unit you optimize against is one SM with a fixed budget of registers, shared memory, and warp slots. Knowing the SM’s resources by heart is what lets you read a CUTLASS source file and immediately see why it picked a 128×128 tile, three pipelined stages, and 8 warps per CTA. Without this picture, kernel code looks like magic numbers.