Attention
Attention is the heart of the Transformer. It’s also the most memory-hungry, latency-sensitive, and innovation-rich block — which is why most production tricks (FlashAttention, GQA, MLA, RoPE) target it specifically.
Module capstone — build it
Implement attention from scratch, including the FlashAttention tile: the attention block as it actually runs in production, written from scratch and verified against Llama-3.2-1B.
Advanced · One focused weekend (~10 h) · Free Colab T4
A single-file PyTorch implementation: multi-head, GQA, RoPE, and a tiled FlashAttention-style forward pass. Outputs match Llama-3.2-1B's attention layer to numerical tolerance. A benchmark plot at seq_len up to 16384 shows tiled-vs-naive HBM traffic.
Build it — step by step
01 Set up the harness — Llama-3.2-1B on Colab 40 min
Load Llama-3.2-1B-Instruct via transformers. Pull out one attention layer's weights (Q, K, V, O projections + RoPE freqs). Save them to a .pt file so you can reload without HF round-trips; a sketch of the harness follows this step's notes.
checkpoint You can `torch.load("attn_layer_0.pt")` and get a dict with the right shapes for `(d_model=2048, H=32, H_kv=8, d_head=64)`.
watch out Llama-3.2-1B uses GQA: `num_key_value_heads=8`, not `32`. Assuming full-MHA shapes won't fail here, but it will break the shape checks two steps later.
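A minimal sketch of the extraction harness, assuming the standard Hugging Face `LlamaForCausalLM` module layout (`model.model.layers[0].self_attn`) and that you have already authenticated for the gated checkpoint. The dict keys are one possible layout; only the shapes in the checkpoint above matter.

```python
import torch
from transformers import AutoModelForCausalLM

# Load the reference model in float32 so later tolerance checks aren't dtype-limited.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B-Instruct", torch_dtype=torch.float32
)
attn = model.model.layers[0].self_attn  # first decoder layer's attention block

state = {
    "q_proj": attn.q_proj.weight.detach().clone(),  # (H * d_head, d_model) = (2048, 2048)
    "k_proj": attn.k_proj.weight.detach().clone(),  # (H_kv * d_head, d_model) = (512, 2048)
    "v_proj": attn.v_proj.weight.detach().clone(),  # (H_kv * d_head, d_model) = (512, 2048)
    "o_proj": attn.o_proj.weight.detach().clone(),  # (d_model, H * d_head) = (2048, 2048)
    "config": {
        "d_model": model.config.hidden_size,             # 2048
        "n_heads": model.config.num_attention_heads,     # 32
        "n_kv_heads": model.config.num_key_value_heads,  # 8, not 32 (GQA)
        "head_dim": model.config.hidden_size // model.config.num_attention_heads,  # 64
        "rope_theta": model.config.rope_theta,           # base for the RoPE frequencies
    },
}
torch.save(state, "attn_layer_0.pt")
```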
02 Naive multi-head attention — match the reference 90 min
Write `naive_attention(x, weights)` in pure PyTorch: Q/K/V projections, head split, scaled dot-product, mask, softmax, output projection. No FlashAttention. Test on a 256-token input. A sketch of the forward follows the watch-out below.
checkpoint Your output matches `model.layers[0].self_attn(x)` to atol=1e-3. RMSE printed and small.
watch out Forgetting that GQA needs to *repeat* the K, V heads `n_query_heads / n_kv_heads = 4` times before the matmul. The reference does this internally; your reimplementation must too.
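A sketch of the naive forward under the weight dict from the step 01 sketch; RoPE is deferred to step 03, and the causal mask assumes a full-sequence (no KV cache) forward.

```python
import math
import torch

def naive_attention(x, w):
    # x: (batch, seq, d_model); w: the dict saved in step 01 (illustrative layout).
    cfg = w["config"]
    B, N, _ = x.shape
    H, H_kv, d = cfg["n_heads"], cfg["n_kv_heads"], cfg["head_dim"]

    q = (x @ w["q_proj"].T).view(B, N, H, d).transpose(1, 2)     # (B, H, N, d)
    k = (x @ w["k_proj"].T).view(B, N, H_kv, d).transpose(1, 2)  # (B, H_kv, N, d)
    v = (x @ w["v_proj"].T).view(B, N, H_kv, d).transpose(1, 2)

    # GQA: each KV head serves H / H_kv query heads, so repeat K and V before the matmul.
    k = k.repeat_interleave(H // H_kv, dim=1)                    # (B, H, N, d)
    v = v.repeat_interleave(H // H_kv, dim=1)

    scores = q @ k.transpose(-2, -1) / math.sqrt(d)              # (B, H, N, N)
    causal = torch.triu(torch.ones(N, N, dtype=torch.bool, device=x.device), diagonal=1)
    scores = scores.masked_fill(causal, float("-inf"))
    out = torch.softmax(scores, dim=-1) @ v                      # (B, H, N, d)
    return out.transpose(1, 2).reshape(B, N, H * d) @ w["o_proj"].T
```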
03 Add RoPE — and compare against Llama's rotary_embedding 75 min
Implement rotary positional embedding from scratch: build cos/sin tables for the head_dim and apply them to Q and K before the attention matmul. Match Llama-3.1-style `rope_scaling` (the YaRN-like long-context frequency adjustment) if your model's config enables it; a sketch of the rotary tables follows below.
checkpoint Output still matches reference to atol=1e-3 with positions > 0.
watch out Some configs use partial RoPE, where only the first `rotary_dim` of head_dim is rotated. Llama-3's default has `rotary_dim == head_dim` (everything is rotated), but the partial case is the silent test that catches a wrong implementation.
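A sketch of the plain (unscaled) rotary tables and their application, using the half-split `rotate_half` convention the Hugging Face Llama code uses; `theta=500000` is Llama-3's `rope_theta`, and any `rope_scaling` adjustment of the frequencies is left out here.

```python
import torch

def rope_tables(head_dim, seq_len, theta=500000.0, device="cpu"):
    # One frequency per pair of channels; positions 0..seq_len-1.
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    pos = torch.arange(seq_len, device=device).float()
    angles = torch.outer(pos, inv_freq)           # (seq_len, head_dim / 2)
    angles = torch.cat([angles, angles], dim=-1)  # (seq_len, head_dim)
    return angles.cos(), angles.sin()

def rotate_half(x):
    # (x1, x2) -> (-x2, x1) over the two halves of head_dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)

def apply_rope(q, k, cos, sin):
    # q, k: (B, H, N, head_dim); cos, sin: (N, head_dim), broadcast over batch and heads.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
```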
04 Tile it — write a FlashAttention-style forward 120 min
Replace the inner attention with a block-tiled version: outer loop over Q blocks, inner loop over KV blocks, online softmax with a running max and running sum. Pure PyTorch (no Triton yet); block tiling is enough to demonstrate the IO win. A sketch of the tiled inner loop follows the watch-out below.
checkpoint Output still matches naive within atol=1e-3. The tiled implementation never materializes the full (N×N) score matrix.
watch out The classic bug: updating the running max for a new block without rescaling the previous accumulator (and running sum) by `exp(m_old - m_new)`. The math is in the FlashAttention paper, equation (6); copy it carefully.
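A sketch of the tiled forward for a single head with a causal mask, in pure PyTorch. `tiled_attention`, the block sizes, and the single-head `(N, d)` layout are illustrative choices, not the reference API; the rescale of `acc` and `l` by `exp(m_old - m_new)` is the step the watch-out above is about.

```python
import math
import torch

def tiled_attention(q, k, v, block_q=128, block_kv=128):
    # q, k, v: (N, d) for one head. Never materializes the full (N, N) score matrix.
    N, d = q.shape
    scale = 1.0 / math.sqrt(d)
    out = torch.empty_like(q)

    for qs in range(0, N, block_q):
        qe = min(qs + block_q, N)
        q_blk = q[qs:qe] * scale
        m = torch.full((qe - qs,), float("-inf"), device=q.device)  # running max
        l = torch.zeros(qe - qs, device=q.device)                   # running softmax denominator
        acc = torch.zeros(qe - qs, d, device=q.device)               # running numerator

        for ks in range(0, qe, block_kv):  # causal: only KV blocks at or before the Q block
            ke = min(ks + block_kv, qe)
            s = q_blk @ k[ks:ke].T                                   # (block_q, block_kv) scores
            qi = torch.arange(qs, qe, device=q.device).unsqueeze(1)
            ki = torch.arange(ks, ke, device=q.device).unsqueeze(0)
            s = s.masked_fill(ki > qi, float("-inf"))                # causal mask inside the block

            m_new = torch.maximum(m, s.max(dim=-1).values)
            alpha = torch.exp(m - m_new)                             # rescale factor for old state
            p = torch.exp(s - m_new.unsqueeze(1))
            l = l * alpha + p.sum(dim=-1)
            acc = acc * alpha.unsqueeze(1) + p @ v[ks:ke]
            m = m_new

        out[qs:qe] = acc / l.unsqueeze(1)
    return out
```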
05 Profile + benchmark — show the HBM traffic drop 60 min
Use `torch.profiler` with `record_shapes=True` to estimate HBM reads/writes for the naive vs tiled forward at seq_len 1024, 4096, and 16384. Plot total bytes transferred per call; a profiling sketch follows below.
checkpoint Tiled version moves ~10× fewer bytes than naive at seq_len 16384. Plot saved to PNG.
watch out PyTorch profiler only catches HBM through CUDA events; running on CPU gives meaningless numbers. Do this on a Colab GPU.
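One way to drive the comparison with `torch.profiler`, assuming `naive_forward(x, weights)` and `tiled_forward(x, weights)` stand for your full forward passes from steps 02 and 04; turning the trace into a bytes-per-call plot (for example by summing tensor sizes from the recorded shapes) is left to `bench.py`.

```python
import torch
from torch.profiler import profile, ProfilerActivity

def profile_forward(fn, x, label, steps=5):
    # Warm up, then profile a handful of CUDA forward calls.
    for _ in range(2):
        fn(x)
    torch.cuda.synchronize()
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=True,
        profile_memory=True,
    ) as prof:
        for _ in range(steps):
            fn(x)
        torch.cuda.synchronize()
    print(f"== {label} ==")
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
    return prof

state = torch.load("attn_layer_0.pt")  # dict saved in step 01
weights = {k: v.cuda() if torch.is_tensor(v) else v for k, v in state.items()}
for seq_len in (1024, 4096, 16384):
    x = torch.randn(1, seq_len, weights["config"]["d_model"], device="cuda")
    profile_forward(lambda t: naive_forward(t, weights), x, f"naive @ {seq_len}")  # your step 02 impl
    profile_forward(lambda t: tiled_forward(t, weights), x, f"tiled @ {seq_len}")  # your step 04 impl
```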
06 Push to GitHub with a clean README 45 min
One-file repo: `attention_from_scratch.py` (your impl) + `bench.py` (the comparison) + `README.md` (one-paragraph description, the benchmark plot, the equation reference).
checkpoint A reader can clone and run `pip install torch transformers matplotlib && python bench.py` and reproduce your plot.
You walk away with
A working attention block (MHA + GQA + RoPE + tiled forward) that matches a real model
Fluency reading FlashAttention-style code: tiles, online softmax, the running-max trick
A benchmark methodology you can re-use for any kernel comparison work
A repo whose README + plot stand on their own
Tools you'll use
PyTorch 2.x
Llama-3.2-1B (reference)
FlashAttention paper (Dao et al.)
torch.profiler
rotary_embedding reference impl