FlashAttention-3

In PyTorch 2.5+, F.scaled_dot_product_attention(q, k, v, is_causal=True) on an H100 dispatches, without you doing anything, to the FlashAttention-3 kernel. That single function call is what gets executed every time a Llama-3-70B layer runs. The kernel under that one Python line is what most of “fast inference on Hopper” actually means.
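A minimal sketch of that call, assuming a CUDA build of PyTorch 2.5+ on an H100 (the sdpa_kernel context manager from torch.nn.attention pins the backend so you can verify which kernel you are actually getting; the shapes are just for illustration):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# One layer's worth of attention inputs: (batch, heads, seq_len, head_dim), bf16 on GPU.
q, k, v = (torch.randn(1, 8, 4096, 128, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))

# Restrict dispatch to the flash backend; if it can't run on this
# hardware/dtype/shape combination, SDPA raises instead of silently
# falling back, which makes it easy to confirm what is executing.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 4096, 128])
```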

The thing FlashAttention got right, and that the field couldn’t see for five years before it, is that attention is memory-bound, not compute-bound. A naive kernel writes the full N × N score matrix to HBM, reads it back to softmax it, writes it again, reads it once more for the P · V step. The math takes a microsecond; the round-trips take ten. FA-1 (Tri Dao, 2022) replaced that with a tiled pass that never materializes the N × N matrix, using a six-line “online softmax” trick to keep a running max and sum without ever needing the full row. Memory traffic dropped from O(N²) to O(N), throughput went up 2–4×, and long-context models became feasible.

FA-2 reorganized the work: fewer non-matmul FLOPs and better parallelism across the sequence. FA-3 (2024) rewrote everything for Hopper: TMA async copies move data while the SMs compute, warp specialization splits each block into producer warps (loads) and consumer warps (matmuls), WGMMA instructions push tensor-core throughput past the older mma.sync ceiling, and FP8 with block-scaled accumulation nearly doubles peak again. Result: ~75% of H100 BF16 peak, ~85% of FP8 peak. Same model, same numbers, faster attention; the Blackwell-tuned version of the same kernel family pushes another 1.5–2× through 2026.

TL;DR

  • Standard attention writes the N × N matrix to HBM, then reads it back. That’s O(N²) memory traffic and is the bottleneck — not the FLOPs.
  • FlashAttention-1/2 tile the computation, never materialize the full N × N matrix in HBM, and use the online softmax trick to fuse softmax with matmul.
  • FlashAttention-3 (2024) rewrites for Hopper: TMA async copies, warp-specialized producer/consumer, FP8 with block-scaled accumulation. ~75% of H100 peak for fp16/bf16, and FP8 throughput nearly doubles that.
  • Net effect on inference: same model, same numbers — 2–4× faster attention layer, longer contexts feasible, lower memory pressure.

Why memory traffic is the whole game

A naive attention kernel reads from and writes to HBM (the GPU’s main memory) something like 20× more bytes than the math actually requires. HBM bandwidth, not FLOPs, is what bottlenecks LLM inference. FlashAttention is the single biggest kernel-level win of the last five years — it’s why long-context models exist at scale.

If you’re going to read one ML systems paper, this family is it. Then re-read it after touching CUDA for a month.

Mental model — three regimes

The conceptual leap is FA-1: never materialize the N × N attention matrix. Everything since is engineering on top of that.

Concrete walkthrough — why it’s actually faster

Standard attention’s memory traffic

For sequence length N, head dim d:

Q, K, V → 3 × N × d bytes loaded
Q · Kᵀ → S → N × N bytes written, then read back
softmax(S) → P → N × N written + read again
P · V → O → N × d written

Total HBM traffic: ≈ 4N² + 8Nd bytes (in fp16/bf16). At N = 8192, d = 128 that’s 268 MB per attention call — read and written across HBM. Per layer. Per token in decode. The math takes ~1 µs; the memory reads take 5–10 µs.
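As a quick sanity check on those numbers, here is the same arithmetic in plain Python (the split into a ~4N² score-matrix term and an 8Nd Q/K/V/O term follows the formula above):

```python
# Back-of-the-envelope HBM traffic for one naive attention call at N=8192, d=128.
N, d, B = 8192, 128, 2                 # B = bytes per element in fp16/bf16

score_bytes = 2 * N * N * B            # the ~4N^2 term: the N x N score matrix written out and read back
qkvo_bytes = 4 * N * d * B             # the 8Nd term: Q, K, V loaded and O written

print(f"score-matrix traffic: {score_bytes / 1e6:6.1f} MB")  # ~268 MB, the number quoted above
print(f"Q/K/V/O traffic:      {qkvo_bytes / 1e6:6.1f} MB")   # ~8.4 MB, negligible by comparison
print(f"total:                {(score_bytes + qkvo_bytes) / 1e6:6.1f} MB")
```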

FlashAttention’s tiling

Instead, split Q into blocks of B_r rows, K and V into blocks of B_c columns. For each Q-block, stream over all K/V blocks, keep a running max and sum for the softmax, never write the intermediate N × N.

The online softmax keeps (m, ℓ) — running max and sum of exponentials — and rescales the partial output as new K-blocks arrive:

m^{(j)} = \max\bigl(m^{(j-1)}, \tilde{m}^{(j)}\bigr), \qquad \ell^{(j)} = e^{\,m^{(j-1)} - m^{(j)}}\,\ell^{(j-1)} + e^{\,\tilde{m}^{(j)} - m^{(j)}}\,\tilde{\ell}^{(j)}

where m̃⁽ʲ⁾ and ℓ̃⁽ʲ⁾ are the max and the sum of exponentials computed over the current K-block alone.

After the last K-block, divide by ℓ. Mathematically identical to standard softmax; numerically stable; never stores the full attention matrix.
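The output accumulator follows the same pattern. One common way to write the companion update, per row of the Q-block, with Õ⁽ʲ⁾ the current block’s unnormalized contribution (its exponentiated scores times its V rows):

O^{(j)} = e^{\,m^{(j-1)} - m^{(j)}}\, O^{(j-1)} + e^{\,\tilde{m}^{(j)} - m^{(j)}}\, \tilde{O}^{(j)}, \qquad O = O^{(\mathrm{last})} / \ell^{(\mathrm{last})}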

What FlashAttention-3 added on Hopper

| Mechanism | What it does | Why it helps |
| --- | --- | --- |
| TMA (Tensor Memory Accelerator) | Async bulk copies HBM↔SRAM with one instruction | Frees the SMs to compute while data moves |
| Warp specialization | Some warps do TMA loads (producer), others do WGMMA matmul (consumer) | Compute and IO overlap perfectly within a CTA |
| WGMMA | Hopper’s warpgroup matmul instruction | Higher throughput than mma.sync for big tiles |
| FP8 with block scaling | E4M3 inputs, FP32 accumulation, per-block scale factors | ~2× throughput vs BF16 with negligible quality loss |

Result: ~75% of H100 BF16 peak, ~85% of FP8 peak in published numbers. On Blackwell (B200/B300) the same kernel family — adapted for 5th-gen tensor cores and FP4 — pushes attention throughput further still; expect another ~1.5–2× over H100 for matched configurations as the kernels mature through 2026.
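To make the FP8-with-block-scaling row concrete, here is a toy simulation of the idea in PyTorch (not FA-3’s kernel, just the numerics): each block is scaled to fit E4M3’s range, cast to 8 bits, and the partial products are accumulated in FP32 with the scales folded back in. It assumes a PyTorch build with the torch.float8_e4m3fn dtype (2.1+); blockwise_fp8_matmul and the block size of 128 are made up for illustration.

```python
import torch

E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def blockwise_fp8_matmul(a: torch.Tensor, b: torch.Tensor, block: int = 128) -> torch.Tensor:
    """Toy per-block FP8 matmul: quantize each K-dim block, accumulate in FP32."""
    out = torch.zeros(a.shape[0], b.shape[1], dtype=torch.float32)
    for s in range(0, a.shape[1], block):
        a_blk, b_blk = a[:, s:s + block], b[s:s + block, :]
        sa = a_blk.abs().amax().clamp(min=1e-12) / E4M3_MAX          # per-block scale factors
        sb = b_blk.abs().amax().clamp(min=1e-12) / E4M3_MAX
        a8 = (a_blk / sa).to(torch.float8_e4m3fn).to(torch.float32)  # quantize, then widen for the matmul
        b8 = (b_blk / sb).to(torch.float8_e4m3fn).to(torch.float32)
        out += (a8 @ b8) * (sa * sb)                                 # FP32 accumulation, scales folded back in
    return out

a, b = torch.randn(64, 256), torch.randn(256, 64)
err = (blockwise_fp8_matmul(a, b) - a @ b).abs().max()
print(f"max abs error vs fp32 matmul: {err.item():.4f}")  # small compared to the output magnitudes
```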

Run it yourself

A pure-Python toy that demonstrates the online softmax — exactly how FlashAttention’s inner loop accumulates without materializing the full attention matrix:

Online vs standard softmax — proves they’re identical, but online uses O(d) memory.
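The page’s editable widget isn’t reproduced here; a minimal stand-in in plain Python (standard library only) looks like this:

```python
import math, random

def standard_attention_row(scores, values):
    """Reference: materialize the whole softmax row, then take the weighted sum."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]      # needs all N scores in memory at once
    denom = sum(exps)
    d = len(values[0])
    return [sum(e * v[k] for e, v in zip(exps, values)) / denom for k in range(d)]

def online_attention_row(scores, values):
    """FlashAttention-style: stream over (score, value) pairs keeping only (m, l, o)."""
    m, l = float("-inf"), 0.0
    o = [0.0] * len(values[0])                    # running unnormalized output: O(d) memory
    for s, v in zip(scores, values):
        m_new = max(m, s)
        scale = math.exp(m - m_new)               # rescale factor for everything seen so far
        w = math.exp(s - m_new)                   # weight of the new score under the new max
        l = l * scale + w
        o = [oi * scale + w * vi for oi, vi in zip(o, v)]
        m = m_new
    return [oi / l for oi in o]                   # one division at the very end

random.seed(0)
N, d = 1024, 8
scores = [random.gauss(0, 3) for _ in range(N)]
values = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
ref = standard_attention_row(scores, values)
online = online_attention_row(scores, values)
print(max(abs(a - b) for a, b in zip(ref, online)))  # ~1e-15: identical up to float rounding
```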

The handful of running-max-and-sum lines is the entire mathematical insight of FlashAttention. Everything else is GPU engineering.

Quick check

Standard attention's bottleneck is most accurately described as: HBM memory traffic from writing and re-reading the N × N score matrix, not the FLOPs.

Key takeaways

  1. Attention is memory-bound, not compute-bound. This single fact is what makes the whole FlashAttention family the right approach.
  2. Online softmax = running max + running sum + output rescale. Mathematically equal to standard softmax, but doesn’t need the full attention matrix in memory.
  3. FA-2 was great on A100; FA-3 is the Hopper-specific rewrite. TMA + warp specialization + WGMMA + FP8 stacked together. Don’t use FA-1/2 on H100 if you can avoid it.
  4. torch.nn.functional.scaled_dot_product_attention already calls FA under the hood in modern PyTorch — you usually don’t write the kernel yourself.
  5. The sweet spot is large N. For short sequences (under ~512 tokens), naive attention is fine; the tiling overhead doesn’t pay back.

Go deeper
