KV Cache Basics

When you call model.generate(input_ids, max_new_tokens=512) in PyTorch, the framework silently does two very different things back-to-back. Prefill runs your prompt through every layer once, in one big parallel matmul, and stashes the keys and values from every token in a buffer called the KV cache. Decode then loops: for each new token, run one forward pass that re-uses every prior K and V from the cache and only computes a new row for the just-generated token. Without that cache, generating the 512th new token would re-run the prompt and the previous 511 tokens through every layer — quadratic work in context length, and roughly 100× slower than the cached implementation.

The cache is the most important piece of state in an inference server. It’s a per-layer buffer of (K, V) pairs that grows by one row every time the model emits a token. Decode at step t reads all t cached rows and writes one new row. That single design decision is what turns LLM generation from “interesting research” into “ship it to a billion users.”
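
To make the decode step concrete, here is a minimal single-head sketch in plain PyTorch. The names (decode_step, W_q, k_cache and so on) are illustrative rather than any framework's API; a real model repeats this per layer and per KV head, which is exactly where the size formula below comes from.

```python
import torch

def decode_step(x_new, W_q, W_k, W_v, k_cache, v_cache):
    """One decode step for a single attention head (illustrative sketch).

    x_new:            (1, d_model)  hidden state of the token just generated
    k_cache, v_cache: (t, head_dim) keys/values for all t previous tokens
    """
    q = x_new @ W_q                                     # query for the new token only
    k_cache = torch.cat([k_cache, x_new @ W_k], dim=0)  # append one new K row
    v_cache = torch.cat([v_cache, x_new @ W_v], dim=0)  # append one new V row
    scores = (q @ k_cache.T) / k_cache.shape[-1] ** 0.5
    weights = torch.softmax(scores, dim=-1)             # attends over all t + 1 cached rows
    return weights @ v_cache, k_cache, v_cache          # output plus the grown cache
```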

The catch: the cache lives in HBM and grows linearly with context. For Llama 3.1 70B at 32K tokens the cache is ~10 GB per sequence, so a batch of 16 users holds more cache than weights. Every modern serving optimization in the rest of this module — PagedAttention, prefix caching, MLA, KV quantization, disaggregated prefill — is fundamentally about managing this one buffer.

TL;DR

  • Without a KV cache, every token would re-attend over the entire prompt. With it, each new token does one extra row of attention math.
  • The cache is 2 × n_layers × n_kv_heads × head_dim × seq_len × dtype_bytes. For Llama 3.1 70B at 32K context that’s ~10 GB per sequence — more than the weights once you batch 16 or so users.
  • The cache lives on the GPU and grows linearly with each generated token. Long contexts OOM on the cache long before they OOM on weights.
  • Every modern serving optimization (PagedAttention, prefix caching, MLA, KV quantization) is fundamentally about managing this one buffer.

Prefill vs decode — two very different jobs

LLM inference is two different jobs glued together. Prefill runs the prompt through every layer in one big parallel matmul — compute-bound, lots of FLOPs, runs at high tensor-core utilization. Decode generates one token at a time, attending back to everything before it — memory-bandwidth-bound, tiny FLOPs per token, dominated by HBM reads of the K and V cache. If decode recomputed K and V for the prompt every step, generation would be quadratic in context length and nothing past 1K tokens would be usable.

The KV cache is what bridges those two regimes: prefill fills it once, decode reads it forever after. Knowing its size and growth rate is the difference between explaining why a 70B model needs 4× H100s for 32K context and just guessing.
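
A rough back-of-envelope makes the two regimes visible. The numbers below are ballpark assumptions for a 70B-parameter model in BF16 on a single H100-class GPU, not measurements:

```python
# Rough, order-of-magnitude assumptions (not measurements):
# 70B parameters in BF16 on a single H100-class GPU.
params     = 70e9
hbm_bw     = 3.35e12          # ~3.35 TB/s HBM bandwidth (approximate)
peak_flops = 9.9e14           # ~990 TFLOP/s dense BF16 (approximate)

# Decode at batch size 1: every weight byte is streamed from HBM for each new
# token, but the matmul work is only ~2 FLOPs per parameter per token.
t_mem     = params * 2 / hbm_bw        # ~42 ms just moving the BF16 weights
t_compute = 2 * params / peak_flops    # ~0.14 ms of matmul work
print(f"decode:  memory {t_mem * 1e3:.0f} ms  vs  compute {t_compute * 1e3:.2f} ms")

# Prefill of an 8K-token prompt: the same weight reads, but 8192x the matmul
# work, so the tensor cores are the bottleneck instead of HBM.
t_compute_prefill = 2 * params * 8192 / peak_flops
print(f"prefill: compute {t_compute_prefill:.1f} s dominates the weight reads")

# At larger batch sizes the KV-cache reads (e.g. ~10 GB per 32K-context
# sequence) stack on top of the weight reads, keeping decode memory-bound.
```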

Mental model

Think of it as a per-layer append-only buffer that grows by one row of (K, V) pairs every generated token. Attention at step t reads all t rows in parallel.

The size math

KV-cache memory in bytes, for a single sequence:

size = 2 × n_layers × n_kv_heads × head_dim × seq_len × bytes_per_element
       └── 2 because you store K *and* V

For a few canonical models in BF16 (2 bytes/element), at 8K context unless noted:

Model                  layers   kv_heads   head_dim   KV cache
Llama 3.1 8B             32        8         128       1.0 GB
Llama 3.1 70B            80        8         128       2.5 GB
Llama 3.1 70B (32K)      80        8         128        10 GB
Llama 3.1 405B          126        8         128       3.9 GB

(GQA — n_kv_heads << n_query_heads — already saves ~8× over MHA. Without GQA, that 70B at 32K would be 80 GB of cache alone.)

For a batch of B concurrent users, multiply by B. This is why a serving system’s max concurrent users is almost entirely a function of how much VRAM is left after weights, not how fast the GPU is.

Run it in your browser

Compute the cache size for any model and context length right here. Edit the params and re-run.

No GPU needed — pure arithmetic.
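
A minimal sketch of that arithmetic in plain Python; the defaults below assume the Llama 3.1 70B-style configuration from the table (80 layers, 8 KV heads, head_dim 128, BF16), so edit them for your model:

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, dtype_bytes=2):
    """2 (K and V) x layers x KV heads x head_dim x tokens x batch x bytes/element."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * dtype_bytes

# Llama 3.1 70B-style config: 80 layers, 8 KV heads (GQA), head_dim 128, BF16
per_seq = kv_cache_bytes(n_layers=80, n_kv_heads=8, head_dim=128, seq_len=32_768)
print(f"one sequence at 32K:  {per_seq / 2**30:.1f} GiB")     # ~10 GiB

# Batch of 16 concurrent users at the same context length
batch16 = kv_cache_bytes(80, 8, 128, 32_768, batch=16)
print(f"16 users at 32K:      {batch16 / 2**30:.0f} GiB")     # ~160 GiB

# Same model without GQA (64 KV heads instead of 8), for comparison
mha = kv_cache_bytes(80, 64, 128, 32_768)
print(f"one sequence, no GQA: {mha / 2**30:.0f} GiB")         # ~80 GiB
```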

A few things this should make obvious: (1) batch and context length both blow up the cache linearly; (2) at long contexts on a big model, the cache is the binding constraint; (3) the n_kv_heads parameter (GQA) is doing enormous work — DeepSeek’s MLA squeezes this further.

Quick check

A 70B model with GQA (8 KV heads, 128 head_dim, 80 layers) is serving 16 concurrent users at 16K context in BF16. Roughly how much GPU memory is consumed by the KV cache alone?
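
(Sanity check with the formula above: 2 × 80 × 8 × 128 × 16,384 × 2 bytes ≈ 5 GB per sequence, so 16 concurrent users is roughly 80 GB of cache.)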

Key takeaways

  1. The cache is 2 × L × H_kv × D × T × dtype_bytes per sequence — memorize this formula. It is the most useful single equation in LLM serving.
  2. Cache rivals the weights at long contexts. A 70B model in BF16 weighs ~140 GB; its cache at 32K × bs=8 is already ~80 GB, and at bs=16 it overtakes the weights.
  3. GQA was the first major breakthrough in cache compression (8× over MHA). MLA (DeepSeek-V2/V3) compresses another ~10×.
  4. Decode is memory-bandwidth bound, not compute bound, because each token re-reads the whole cache. This is why FP8 / INT4 KV quantization matters.

Go deeper
