
Multi-Head Attention

In PyTorch, the attention layer is one line: F.scaled_dot_product_attention(q, k, v, is_causal=True). Llama-3.1-8B calls it 32 times per token — once per layer — and the bulk of inference latency lives inside that one function. This lesson is what that single line expands to: three projections, a reshape, a softmaxed dot product, a concat, and an output projection. Five steps of tensor algebra. The rest of this track is twelve different ways to make those five steps faster.

The “multi-head” part is the trick that made attention work at scale. Instead of one giant attention sum mixing the whole d_model dimension, the model reshapes the projected tensors so it runs H attention operations in parallel, each in a narrower subspace, then concatenates the results. There is no per-head loop in any production kernel. Heads are an axis you reshape into and out of — same matmuls, same memory, different way of partitioning the dot product. If you walk away from this lesson with one mental image, make it that.

Every optimization in the rest of this module — the KV cache, FlashAttention, GQA, MLA, paged attention, chunked prefill — is a transformation of the math on this page.

TL;DR

  • An attention head is a triplet (Q, K, V) of projections from d_model → d_head. MHA runs H of them in parallel and concatenates the outputs.
  • The whole thing is one big matmul, one reshape, one scaled-dot-product, one concat, one matmul. There is no separate code path per head — heads are an axis, not a module.
  • The scale factor 1/√d_head exists so the softmax doesn’t saturate as d_head grows; without it, training a wide head diverges.
  • Per-layer cost is O(N · d_model²) for the projections and O(N² · d_model) for attention. The N² term is what you spend the rest of this track learning to tame (GQA, KV cache, FlashAttention, paged attention, chunked prefill, MLA).

What attention actually is

Attention is the only learned subroutine in a transformer that mixes information across token positions. Every other op — feed-forward layers, RMSNorm, residual adds — is point-wise; it operates on each token independently. Without attention, a transformer is just an MLP run independently per token, with no way to reason about context.

The mechanism, in plain English: for each token, build a query vector q. Look at every other token’s key vector k and decide how much that token matters by computing q · k (a dot product is large when the two directions match). Pass those raw similarities through a softmax to turn them into weights, then take a weighted sum of every token’s value vector v. That weighted sum is the new representation for the current token.
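To make that mechanism concrete, here is a minimal single-head, single-query version in numpy (the dimensions and random inputs are purely illustrative):

import numpy as np

rng = np.random.default_rng(0)
d = 8                            # width of q, k, v in this toy example
N = 5                            # number of tokens in the context

q = rng.normal(size=d)           # query for the current token
K = rng.normal(size=(N, d))      # one key per token
V = rng.normal(size=(N, d))      # one value per token

scores = K @ q / np.sqrt(d)              # how well each key matches the query
weights = np.exp(scores - scores.max())
weights /= weights.sum()                 # softmax: similarities -> probabilities
new_repr = weights @ V                   # weighted sum of values, shape (d,)

print(weights.round(3), new_repr.shape)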

Multi-head attention runs H of these in parallel in narrower subspaces, so the model can look for several different kinds of relationships — syntactic, semantic, positional — at the same time without one big head having to do everything.

Mental model

The whole layer is one tensor pipeline with a head-axis bolted on by a reshape. There is no per-head loop in any production implementation. If the diagram makes you think “heads are independent networks,” delete that picture — they share the same big matmul, they’re just sliced differently after it.

The shapes, with numbers

For a sequence of N tokens and a model width of d_model, MHA uses three weight matrices W_Q, W_K, W_V of shape (d_model, d_model) and an output projection W_O of shape (d_model, d_model). The “heads” appear when you take the resulting (B, N, d_model) tensor and reshape it to (B, N, H, d_head) where H · d_head = d_model, then transpose to (B, H, N, d_head) so the head axis is adjacent to the batch.

Model            d_model    H      d_head    Param count, MHA layer
GPT-2 small      768        12     64        4 · 768² ≈ 2.36 M
GPT-3 175B       12288      96     128       4 · 12288² ≈ 604 M
Llama-3.1 8B     4096       32     128       4 · 4096² ≈ 67 M *
Llama-3.1 405B   16384      128    128       4 · 16384² ≈ 1.07 B *
DeepSeek-V3      7168       128    128       uses MLA, not MHA †

* Llama-3.1 actually uses GQA — the K and V projections are smaller. The number above is what plain MHA would cost at the same (d_model, H). The next lesson, GQA, MQA & MLA, is exactly the story of how to shrink that. † DeepSeek-V3 uses Multi-head Latent Attention. Different shapes, same idea.
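To make “heads are an axis” concrete, here is a small numpy check using GPT-2-small’s dimensions. Slicing W_Q by columns is just an illustrative way of viewing one shared projection matrix per head, not a claim about how any library stores its weights:

import numpy as np

rng = np.random.default_rng(0)
B, N, d_model, H = 1, 4, 768, 12
d_head = d_model // H                                        # 64

x = rng.normal(size=(B, N, d_model))
W_Q = rng.normal(size=(d_model, d_model)) / np.sqrt(d_model)

# One big matmul, then reshape the head axis out of d_model.
Q = x @ W_Q                                                  # (B, N, d_model)
Q_heads = Q.reshape(B, N, H, d_head).transpose(0, 2, 1, 3)   # (B, H, N, d_head)

# "Per-head" view: head h just reads columns [h*d_head, (h+1)*d_head) of W_Q.
for h in range(H):
    Q_h = x @ W_Q[:, h * d_head:(h + 1) * d_head]            # (B, N, d_head)
    assert np.allclose(Q_h, Q_heads[:, h])

print("heads are slices of the same matmul:", Q_heads.shape)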

The math, in one screen

Given input x of shape (B, N, d_model):

# 1. Project. Three (d_model, d_model) matmuls.
Q = x @ W_Q                                              # (B, N, d_model)
K = x @ W_K                                              # (B, N, d_model)
V = x @ W_V                                              # (B, N, d_model)

# 2. Split into heads. This is purely a reshape — zero FLOPs.
Q = Q.reshape(B, N, H, d_head).transpose(0, 2, 1, 3)     # (B, H, N, d_head)
K = K.reshape(B, N, H, d_head).transpose(0, 2, 1, 3)
V = V.reshape(B, N, H, d_head).transpose(0, 2, 1, 3)

# 3. Scaled dot-product attention, per head, in parallel.
scores = Q @ K.transpose(-1, -2) / sqrt(d_head)          # (B, H, N, N)
if causal:
    scores = mask_upper_triangular(scores)               # decoder
weights = softmax(scores, axis=-1)                       # (B, H, N, N)
ctx = weights @ V                                        # (B, H, N, d_head)

# 4. Concat heads back. Another reshape.
ctx = ctx.transpose(0, 2, 1, 3).reshape(B, N, d_model)

# 5. Output projection.
y = ctx @ W_O                                            # (B, N, d_model)

Five steps. Three projections, one attention, one output. Production kernels (FlashAttention) fuse step 3 — they never materialize the (N, N) scores matrix in HBM — but the math is identical.
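To tie this back to the one-liner from the intro, the following PyTorch sketch checks that the manual five-step math matches F.scaled_dot_product_attention with is_causal=True (shapes and seed are arbitrary):

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, N, d_head = 2, 4, 16, 32
q, k, v = (torch.randn(B, H, N, d_head) for _ in range(3))

# Manual scaled dot-product attention with a causal mask.
scores = q @ k.transpose(-1, -2) / math.sqrt(d_head)            # (B, H, N, N)
causal = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal, float("-inf"))
manual = torch.softmax(scores, dim=-1) @ v                      # (B, H, N, d_head)

# Fused kernel; never materializes the (N, N) scores in HBM.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))                 # expect: True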

Why divide by √d_head

At initialization, q and k behave like independent random vectors whose components have unit variance, so their dot product has variance d_head (a sum of d_head independent unit-variance products). Without scaling, larger heads push softmax inputs into the saturating tail, gradients vanish, and training diverges — the classic failure mode. Dividing by √d_head normalizes the variance back to 1. This is not optional; every implementation does it.
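A quick numerical check of that claim, using standard-normal toy vectors: the dot-product variance tracks d_head until the 1/√d_head scale flattens it back to roughly 1.

import numpy as np

rng = np.random.default_rng(0)
for d_head in (16, 64, 256, 1024):
    q = rng.normal(size=(10_000, d_head))
    k = rng.normal(size=(10_000, d_head))
    raw = (q * k).sum(axis=1)                   # unscaled dot products
    scaled = raw / np.sqrt(d_head)
    print(f"d_head={d_head:5d}  var(raw)={raw.var():8.1f}  var(scaled)={scaled.var():.2f}")
# var(raw) grows like d_head; var(scaled) stays near 1, so softmax inputs
# keep a sane dynamic range no matter how wide the head is.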

Where the FLOPs go

Per layer, ignoring batch:

Step                 FLOPs                   Notes
Q,K,V projections    3 · 2 · N · d_model²    Three matmuls. Compute-heavy, not N².
QKᵀ                  2 · H · N² · d_head     This is the N² term. Same as 2 · N² · d_model.
softmax · V          2 · H · N² · d_head     Another 2 · N² · d_model.
Output projection    2 · N · d_model²        One matmul.

Total: 8 · N · d_model² + 4 · N² · d_model. The crossover where attention dominates the projections is at N ≈ 2 · d_model. For Llama-3.1 8B (d_model = 4096) that’s around 8K tokens — exactly where most production workloads sit. Long context is N²-bound by definition.
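Plugging the formula into a short back-of-the-envelope script for a d_model = 4096 model (attention layer of one block only, plain MHA, GQA ignored) shows where the crossover lands:

d_model = 4096

def mha_flops(N, d_model):
    proj = 8 * N * d_model**2        # Q, K, V and output projections
    attn = 4 * N**2 * d_model        # QK^T plus softmax(.) @ V
    return proj, attn

for N in (2_048, 8_192, 32_768, 131_072):
    proj, attn = mha_flops(N, d_model)
    print(f"N={N:7d}  proj={proj:.2e}  attn={attn:.2e}  attn share={attn / (proj + attn):.0%}")
# Attention overtakes the projections right around N = 2 * d_model = 8192,
# and dominates completely by 128K context.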

Run it in your browser

Two cells, sharing the same Python namespace via group="mha". Run them in order; the second cell uses arrays the first cell defined.

Python cell (group "mha") — Cell 1: set up the model dims and the input. Run me first.
Python cell (group "mha") — Cell 2: full forward pass. Variables x, W_Q, W_K, W_V, W_O, B, N, H, d_head, d_model are inherited from cell 1.

If the second cell prints zeros for positions 0–9, your causal mask is wired correctly. If it doesn’t, you forgot the tril.
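The cell contents are not reproduced here, but a self-contained numpy version of the same forward pass might look like the sketch below; the dimensions, seed, and the exact print are illustrative assumptions, not the actual cell code:

import numpy as np

rng = np.random.default_rng(0)
B, N, d_model, H = 1, 16, 256, 8
d_head = d_model // H

# Cell-1-style setup: input and the four weight matrices.
x = rng.normal(size=(B, N, d_model))
W_Q, W_K, W_V, W_O = (rng.normal(size=(d_model, d_model)) / np.sqrt(d_model)
                      for _ in range(4))

def softmax(a, axis=-1):
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

# Cell-2-style forward pass: the five steps from above.
Q = (x @ W_Q).reshape(B, N, H, d_head).transpose(0, 2, 1, 3)
K = (x @ W_K).reshape(B, N, H, d_head).transpose(0, 2, 1, 3)
V = (x @ W_V).reshape(B, N, H, d_head).transpose(0, 2, 1, 3)

scores = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(d_head)       # (B, H, N, N)
causal = np.tril(np.ones((N, N), dtype=bool))                # keep j <= i
scores = np.where(causal, scores, -np.inf)
weights = softmax(scores, axis=-1)                           # rows sum to 1

ctx = (weights @ V).transpose(0, 2, 1, 3).reshape(B, N, d_model)
y = ctx @ W_O                                                # (B, N, d_model)

# Token 0 can only attend to itself, so every future position gets weight 0.
print(weights[0, 0, 0, 1:11].round(3))   # ten zeros if the causal mask is wired right
print("output shape:", y.shape)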

Quick check

Fill in the blank
The scaling factor inside the softmax that keeps gradients sane as d_head grows.
scores = Q @ K.transpose(-1, -2) / ____
It cancels the variance growth of a d_head-dimensional dot product.
Quick check
GPT-3 175B has d_model = 12288 and 96 heads. A practitioner doubles H to 192 (and halves d_head to 64) without changing anything else. What changes?

Key takeaways

  1. MHA is one matmul pipeline with a reshape in the middle, not H independent attention modules. Every production implementation treats heads as an axis.
  2. H · d_head = d_model always. Doubling H halves d_head and changes neither parameter count nor FLOPs.
  3. The 1/√d_head scale is non-negotiable. It exists for variance, not for “stability” or “intuition.”
  4. Per-layer cost is 8 · N · d_model² + 4 · N² · d_model. The N² term dominates above N ≈ 2 · d_model.
  5. Everything else in this module — GQA, MLA, KV cache, FlashAttention, paged attention, chunked prefill — is a transformation of these five steps. Hold this picture in your head before reading the next four lessons.

Why this matters

Multi-head attention is the only learned subroutine in a transformer that mixes information across token positions — every other op (FFN, RMSNorm, residuals) is point-wise. Get this layer right and the rest is bookkeeping; get it wrong and your model can’t reason about context. Every optimization in production inference — flash attention, KV caching, GQA, MLA, paged attention, chunked prefill — is a transformation of exactly the math on this page. If you can’t write MHA from scratch in numpy, you’ll plateau at “I tweak model.config.num_attention_heads” and nothing else.

Go deeper