FlashAttention-3
In PyTorch 2.5+, `F.scaled_dot_product_attention(q, k, v, is_causal=True)` on an H100 dispatches — without you doing anything — to the FlashAttention-3 kernel. That single function call is what gets executed every time a Llama-3-70B layer runs. The kernel under that one Python line is what most of “fast inference on Hopper” actually means.
The thing FlashAttention got right, and that the field couldn’t see for five years before it, is that attention is memory-bound, not compute-bound. A naive kernel writes the full N × N score matrix to HBM, reads it back to softmax it, writes it again, reads it once more for the P · V step. The math takes a microsecond; the round-trips take ten. FA-1 (Tri Dao, 2022) replaced that with a tiled pass that never materializes the N × N matrix — using a six-line “online softmax” trick to keep the running max and sum without ever needing the full row. Memory traffic dropped from O(N²) to O(N), throughput went up 2–4×, and long-context models became feasible.
FA-2 cleaned up the math. FA-3 (2024) rewrote everything for Hopper: TMA async copies move data while the SMs compute, warp specialization splits each block into producer warps (loads) and consumer warps (matmuls), WGMMA instructions push tensor-core throughput past the older mma.sync ceiling, and FP8 with block-scaled accumulation nearly doubles peak again. Result: ~75% of H100 BF16 peak, ~85% of FP8 peak. Same model, same numbers, faster attention — and the Blackwell-tuned version of the same kernel family pushes another 1.5–2× through 2026.
TL;DR
- Standard attention writes the N × N score matrix to HBM, then reads it back. That O(N²) memory traffic is the bottleneck — not the FLOPs.
- FlashAttention-1/2 tile the computation, never materialize the full matrix in HBM, and use the online softmax trick to fuse softmax with matmul.
- FlashAttention-3 (2024) rewrites for Hopper: TMA async copies, warp-specialized producer/consumer, FP8 with block-scaled accumulation. ~75% of H100 peak for fp16/bf16, and FP8 throughput nearly doubles that.
- Net effect on inference: same model, same numbers — 2–4× faster attention layer, longer contexts feasible, lower memory pressure.
Why memory traffic is the whole game
A naive attention kernel reads from and writes to HBM (the GPU’s main memory) something like 20× more bytes than the math actually requires. HBM bandwidth, not FLOPs, is what bottlenecks LLM inference. FlashAttention is the single biggest kernel-level win of the last five years — it’s why long-context models exist at scale.
If you’re going to read one ML systems paper, this family is it. Then re-read it after touching CUDA for a month.
Mental model — three regimes
The conceptual leap is FA-1: never materialize the attention matrix. Everything since is engineering on top of that.
Concrete walkthrough — why it’s actually faster
Standard attention’s memory traffic
For sequence length N, head dim d:
- Q, K, V → 3 × N × d bytes loaded
- Q · Kᵀ → S → N × N bytes written, then read back
- softmax(S) → P → N × N written + read again
- P · V → O → N × d written

Total HBM traffic: O(N²) bytes (in fp16/bf16) — writing and re-reading S alone is 2 × N² elements × 2 bytes. At N = 8192 that’s 268 MB per attention call — read and written across HBM. Per layer. Per token in decode. The math takes ~1 µs; the memory reads take 5–10 µs.
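That arithmetic is easy to check in plain Python. The H100 figures used below (roughly 3.35 TB/s HBM3 bandwidth, roughly 990 TFLOP/s dense BF16) are nominal spec-sheet numbers, and the shapes are illustrative:

```python
# Back-of-envelope: HBM bytes moved vs. FLOPs for one naive attention call.
N, d = 8192, 128            # sequence length, head dim
elt = 2                     # bytes per element in fp16/bf16

qkv_load    = 3 * N * d * elt   # load Q, K, V
s_roundtrip = 2 * N * N * elt   # write S = Q·Kᵀ, read it back for softmax
p_roundtrip = 2 * N * N * elt   # write P = softmax(S), read it for P·V
o_store     = N * d * elt       # write the output O

total_bytes = qkv_load + s_roundtrip + p_roundtrip + o_store
flops = 4 * N * N * d           # two N×N×d matmuls, 2·N²·d FLOPs each

print(f"S round-trip alone: {s_roundtrip / 1e6:.0f} MB")   # the 268 MB in the text
print(f"total HBM traffic:  {total_bytes / 1e6:.0f} MB")
print(f"mem time  ~ {total_bytes / 3.35e12 * 1e6:.0f} us")  # at ~3.35 TB/s
print(f"math time ~ {flops / 990e12 * 1e6:.1f} us")         # at ~990 TFLOP/s BF16
```

Even with spec-sheet peak on both sides, the memory time dominates the math time by several times, which is the memory-bound claim in one division.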
FlashAttention’s tiling
Instead, split Q into blocks of rows, K and V into blocks of columns. For each Q-block, stream over all K/V blocks, keep a running max and sum for the softmax, and never write the intermediate N × N matrix.
The online softmax keeps m and ℓ — the running max and the running sum of exponentials — and rescales the partial output O as each new K-block Sᵢ = Q · Kᵢᵀ arrives:

m_new = max(m, rowmax(Sᵢ))
ℓ_new = e^(m − m_new) · ℓ + rowsum(e^(Sᵢ − m_new))
O_new = e^(m − m_new) · O + e^(Sᵢ − m_new) · Vᵢ

After the last K-block, divide by ℓ. Mathematically identical to standard softmax; numerically stable; never stores the full attention matrix.
What FlashAttention-3 added on Hopper
| Mechanism | What it does | Why it helps |
|---|---|---|
| TMA (Tensor Memory Accelerator) | Async bulk copies HBM↔SRAM with one instruction | Frees the SMs to compute while data moves |
| Warp specialization | Some warps do TMA loads (producer), others do WGMMA matmul (consumer) | Compute and IO overlap perfectly within a CTA |
| WGMMA | Hopper’s warpgroup matmul instruction | Higher throughput than mma.sync for big tiles |
| FP8 with block scaling | E4M3 inputs, FP32 accumulation, per-block scale factors | ~2× throughput vs BF16 with negligible quality loss |
Result: ~75% of H100 BF16 peak, ~85% of FP8 peak in published numbers. On Blackwell (B200/B300) the same kernel family — adapted for 5th-gen tensor cores and FP4 — pushes attention throughput further still; expect another ~1.5–2× over H100 for matched configurations as the kernels mature through 2026.
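The block-scaling row in the table can be made concrete with a toy in pure Python. This is a sketch, not real E4M3 arithmetic: the 4-bit-mantissa rounding below merely mimics FP8 coarseness, `E4M3_MAX = 448` is the format’s largest finite value, and the block size is arbitrary.

```python
import math

E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3

def quantize_block(xs):
    """Scale a block so its max |value| maps to E4M3_MAX, then round coarsely.
    Toy stand-in for FP8: keep a 4-bit mantissa grid via frexp/ldexp."""
    scale = max(abs(x) for x in xs) / E4M3_MAX or 1.0
    q = []
    for x in xs:
        v = x / scale
        if v == 0.0:
            q.append(0.0)
            continue
        m, e = math.frexp(v)                         # v = m * 2**e, 0.5 <= |m| < 1
        q.append(math.ldexp(round(m * 16) / 16, e))  # snap mantissa to 1/16 steps
    return q, scale

def block_scaled_dot(a, b, block=4):
    """Dot product with per-block quantized inputs and full-precision accumulation."""
    acc = 0.0
    for i in range(0, len(a), block):
        qa, sa = quantize_block(a[i:i+block])
        qb, sb = quantize_block(b[i:i+block])
        partial = sum(x * y for x, y in zip(qa, qb))  # the "tensor core" part
        acc += partial * sa * sb                      # undo both block scales
    return acc

a = [0.01 * i for i in range(16)]
b = [math.sin(i) for i in range(16)]
exact = sum(x * y for x, y in zip(a, b))
approx = block_scaled_dot(a, b)
print(exact, approx)  # close: per-block scales keep quantization error small
```

The design point this illustrates: because each block gets its own scale factor, outliers in one block don’t destroy the resolution of every other block, which is why the quality loss stays small.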
Run it in your browser
A pure-Python toy that demonstrates the online softmax — exactly how FlashAttention’s inner loop accumulates without materializing the full attention matrix:
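A minimal sketch of such a toy, for a single query row (the shapes and block size are arbitrary choices):

```python
import math, random

def naive_attention_row(scores, values):
    """Standard softmax(scores) · values for one query row."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    dim = len(values[0])
    return [sum(e * v[j] for e, v in zip(exps, values)) / z for j in range(dim)]

def flash_attention_row(scores, values, block=4):
    """Same result, streaming over K/V blocks with the online softmax.
    Never holds more than one block of scores at a time."""
    m = -math.inf       # running max
    l = 0.0             # running sum of exponentials
    dim = len(values[0])
    o = [0.0] * dim     # unnormalized partial output
    for i in range(0, len(scores), block):
        s_blk = scores[i:i+block]
        v_blk = values[i:i+block]
        m_new = max(m, max(s_blk))
        alpha = math.exp(m - m_new)              # rescale old stats to new max
        p = [math.exp(s - m_new) for s in s_blk]
        l = alpha * l + sum(p)
        o = [alpha * o[j] + sum(pi * v[j] for pi, v in zip(p, v_blk))
             for j in range(dim)]
        m = m_new
    return [x / l for x in o]                    # final normalization by ℓ

random.seed(0)
N, d = 16, 8
scores = [random.gauss(0, 3) for _ in range(N)]
values = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
ref = naive_attention_row(scores, values)
out = flash_attention_row(scores, values)
print(max(abs(a - b) for a, b in zip(ref, out)))  # tiny: attention stays exact
```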
Those ~7 lines of running-max-and-sum logic are the entire mathematical insight of FlashAttention. Everything else is GPU engineering.
Key takeaways
- Attention is memory-bound, not compute-bound. This is the single fact that makes the FlashAttention family the right design.
- Online softmax = running max + running sum + output rescale. Mathematically equal to standard softmax, but doesn’t need the full attention matrix in memory.
- FA-2 was great on A100; FA-3 is the Hopper-specific rewrite. TMA + warp specialization + WGMMA + FP8 stacked together. Don’t use FA-1/2 on H100 if you can avoid it.
- `torch.nn.functional.scaled_dot_product_attention` already calls FA under the hood in modern PyTorch — you usually don’t write the kernel yourself.
- The sweet spot is large N. For short sequences (under ~512 tokens), naive attention is fine; the tiling overhead doesn’t pay back.
Go deeper
- Paper: *FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness*. The original. Required reading. The IO-awareness framing changed how the field thinks about kernels.
- Paper: *FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning*. The cleaner re-derivation. Better warp partitioning, simpler online softmax statement.
- Paper: *FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-Precision*. Hopper-specific rewrite. TMA, warp specialization, FP8 with block scaling.
- Video: *Tri Dao — FlashAttention-3 talk*. Author walking through why FA-3 looks the way it does on Hopper.
- Blog: *FlashAttention-3 — author's blog post*. The most accessible writeup of the FA-3 architecture. Read before the paper.
- Repo: `Dao-AILab/flash-attention`. Reference implementation. Read `csrc/flash_attn/` for the actual CUDA.
- Blog: *FlashAttention-2 explained*. Best non-author explainer I've found. Diagrams beat the paper for intuition.
- Paper: *ThunderKittens: Simple, Fast, and Adorable AI Kernels*. A different DSL for writing FA-3-class kernels with less code. Worth reading after the FA-3 paper.