
Backprop as a Graph

In a high-level framework — Keras, fastai, anything where “backprop” is a checkbox — .backward() looks like calculus magic. You wrote a forward pass; gradients appear; the optimizer steps. You never touch the chain rule.

What the abstraction hides is that backprop is not really about derivatives. The derivatives are bookkeeping. Backprop is graph execution in reverse, with the activations from the forward pass as the binding state. PyTorch’s autograd records a tape of every op as you run forward; .backward() walks that tape from the loss back to the leaves, computing one VJP per op and freeing each saved tensor as soon as its work is done.
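
You can watch the tape get built in a few lines of standard PyTorch: every non-leaf tensor carries a grad_fn node pointing at the op that produced it, and next_functions links one step further back.

import torch

x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()
print(y.grad_fn)                 # <SumBackward0 ...>
print(y.grad_fn.next_functions)  # ((<MulBackward0 ...>, 0),)
y.backward()                     # walk the tape from y back to the leaf x
print(x.grad)                    # tensor([2., 2., 2.])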

Why this reframe matters: once you see backward as “replay the forward graph backward, consuming the activations stored at each step,” you can predict where your training run will OOM. Spoiler — it’s almost never the weights. It’s the activations, which can outweigh the model 5–10× at long context. Every memory-saving training trick you’ll meet later — gradient checkpointing, activation offloading, micro-batching, FSDP — is some variation on “manage activations more cleverly during backward.”

TL;DR

  • Backprop isn’t really about derivatives — those are bookkeeping. It’s about executing a computational graph backward, with the forward activations as the binding state.
  • Activations dominate training memory. For a 7B model at 8K context, activations can exceed weights by 5–10× without checkpointing.
  • PyTorch’s autograd records a tape of operations during forward; backward replays it in reverse, freeing tensors as soon as their gradient is computed.
  • VJP (Vector-Jacobian Product) is the right mental model — .backward() doesn’t compute Jacobians, it computes $J^\top v$ for some upstream $v$, one op at a time.
  • The big systems consequence: gradient checkpointing trades compute for memory by recomputing activations during backward instead of storing them.

Mental model

The forward pass writes activations to memory; the backward pass reads them in reverse, computing one VJP per op, freeing each activation as soon as its corresponding op has computed its gradient.

The peak memory is the activation accumulation at the deepest point of the network — typically all activations for one forward pass live simultaneously by the time the loss is computed.

Why activations dominate

The clean formula (Korthikanti et al., Megatron 2022, assuming FlashAttention so the $N \times N$ attention-score matrix is never materialized in HBM):

$$\text{activation bytes per layer} \;\approx\; 34 \cdot s \cdot b \cdot h \cdot \text{(bytes/elem)}$$

where $s$ = sequence length, $b$ = batch size, $h$ = hidden dim. The 34 absorbs Q/K/V/O projections, the FFN intermediate (SwiGLU stores 2× inputs), layer-norm scratch, residual inputs — all the per-token tensors saved for the backward pass.
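
Plugging in the 70B numbers used below makes the per-layer cost concrete:

# Per-layer activation bytes: Llama-3.1 70B at s=8192, b=1, BF16.
s, b, h, bytes_per_elem = 8192, 1, 8192, 2
per_layer = 34 * s * b * h * bytes_per_elem
print(per_layer / 1e9)   # ~4.56 GB per layer; x 80 layers ~ 365 GB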

Without FlashAttention there’s also a $5 \cdot a \cdot s^2 \cdot b$ term for the score matrix, with $a$ the number of attention heads; at long context that term dwarfs everything (for 70B’s $a = 64$ heads at $s = 8192$, it’s $5 \cdot 64 \cdot 8192^2 \approx 21$ GB per layer, versus ~4.6 GB from the $34sbh$ term), which is exactly why Hopper-era training runs are FlashAttention-only.

For Llama-3.1 70B ($h = 8192$, $L = 80$ layers) at $s = 8192$, $b = 1$, BF16:

$$\text{activations} \;\approx\; 80 \cdot 34 \cdot 8192 \cdot 1 \cdot 8192 \cdot 2 \;\approx\; \mathbf{365\ GB}$$

Compare: weights are 140 GB; AdamW state in FP32 adds ~280 GB. Activations alone are 2–3× the weights at long context. The full forward+backward state is ~785 GB — fits across an 8×H100 node only with gradient checkpointing.

This is why every frontier training run uses gradient checkpointing — not an optimization, a feasibility requirement.

What .backward() actually does (autograd in 30 lines)

class Tensor:
    def __init__(self, data, parents=(), op=None, grad_fn=None):
        self.data = data
        self.grad = None
        self.parents = parents    # tensors this depends on
        self.grad_fn = grad_fn    # function: (upstream_grad) -> per-parent grads

    def backward(self, upstream=1.0):
        # Topological sort + reverse traversal.
        order, seen = [], set()
        def topo(t):
            if id(t) in seen:
                return
            seen.add(id(t))
            for p in t.parents:
                topo(p)
            order.append(t)
        topo(self)
        # Seed the root grad.
        self.grad = upstream
        # Walk in reverse, accumulating grads at each tensor.
        for t in reversed(order):
            if t.grad_fn is None:
                continue
            parent_grads = t.grad_fn(t.grad)
            for p, g in zip(t.parents, parent_grads):
                p.grad = g if p.grad is None else p.grad + g

# Each op registers its parents and a grad_fn that's its VJP.
def matmul(A, B):
    out = Tensor(A.data @ B.data,
                 parents=(A, B),
                 grad_fn=lambda gO: (gO @ B.data.T, A.data.T @ gO))
    return out

This is essentially what PyTorch’s autograd does, scaled and optimized: a tape, a topological sort, and per-op VJP closures.
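
Here’s the toy version actually running — a sketch that assumes NumPy arrays in .data, with a hypothetical msum op added so the scalar seed gradient makes sense:

import numpy as np

# Scalar-sum op (added for this demo) so upstream=1.0 at the root is meaningful.
def msum(X):
    return Tensor(X.data.sum(), parents=(X,),
                  grad_fn=lambda gO: (gO * np.ones_like(X.data),))

A = Tensor(np.random.randn(2, 3))
B = Tensor(np.random.randn(3, 4))
loss = msum(matmul(A, B))     # forward: tape is (A, B) -> matmul -> msum
loss.backward()               # backward: msum VJP, then matmul VJP

# The matmul VJP: dL/dA = upstream @ B^T, dL/dB = A^T @ upstream.
assert np.allclose(A.grad, np.ones((2, 4)) @ B.data.T)
assert np.allclose(B.grad, A.data.T @ np.ones((2, 4)))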

VJP, not Jacobian

For an op with $m$ inputs and $n$ outputs, the Jacobian is an $n \times m$ matrix — usually huge. We never materialize it. We compute $J^\top v$ for some upstream gradient $v \in \mathbb{R}^n$, which is just a vector of size $m$.

For matmul(A, B) = AB, the Jacobian w.r.t. $A$ is enormous, but the VJP is simple:

$$\frac{\partial \mathcal{L}}{\partial A} = \frac{\partial \mathcal{L}}{\partial Y} \cdot B^\top$$

— another matmul. Most ops have similarly clean VJPs that fit on one line.
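
A few more in the same toy style (sketches of the math, not PyTorch’s actual derivative code; NumPy arrays assumed). What each closure captures is exactly the activation memory that op pins until backward reaches it:

# Each factory runs the forward op and returns the VJP closure that
# backward will later call with the upstream gradient.
def op_add(A, B):        # Y = A + B: the closure captures nothing
    return A + B, lambda gO: (gO, gO)

def op_mul(A, B):        # Y = A * B (elementwise): captures A and B
    return A * B, lambda gO: (gO * B, gO * A)

def op_relu(A):          # Y = relu(A): captures only a boolean mask
    mask = A > 0
    return A * mask, lambda gO: (gO * mask,)

def op_matmul(A, B):     # Y = A @ B: captures A and B
    return A @ B, lambda gO: (gO @ B.T, A.T @ gO)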

Gradient checkpointing — the key memory-saving trick

Idea: don’t save all activations during forward. Save only at checkpoints. Recompute the rest during backward.

# Without checkpointing: 1 fwd, 1 bwd, peak memory = full activation stack
# With checkpointing every K layers: 1 fwd + 1 partial recomputation per
# segment during bwd, peak = K layers' worth

Trade-off: ~33% more compute (one extra forward through each non-checkpointed segment), but memory drops from $O(L)$ to $O(\sqrt{L})$ when roughly $\sqrt{L}$ checkpoints are placed evenly. For deep models this is the difference between “fits” and “doesn’t.”
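
A back-of-envelope check on where the $\sqrt{L}$ comes from, counting peak memory in “layers” of activations and ignoring constants:

# Checkpoint every K layers: keep L/K segment inputs, plus K layers of
# activations while recomputing one segment. Peak ~ L/K + K, which is
# minimized at K = sqrt(L).
L = 80
K = min(range(1, L + 1), key=lambda k: L / k + k)
print(K, round(L / K + K, 1))   # K = 9, peak ~ 17.9 "layers" vs 80 unchecked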

PyTorch’s torch.utils.checkpoint.checkpoint wraps a function so its activations aren’t saved; backward recomputes them.

from torch.utils.checkpoint import checkpoint

# Selective checkpointing: only checkpoint the attention block
# (which has the N×N matrix).
def forward(self, x):
    x = checkpoint(self.attn_block, x, use_reentrant=False)
    x = self.ffn_block(x)
    return x

Megatron’s “selective recompute” (Korthikanti et al., 2022) chooses which activations to keep based on cost — recomputing FlashAttention is cheap (the math is fast), so attention is always recomputed; recomputing the FFN is expensive, so it’s saved. This gets the memory of full checkpointing with ~5% extra compute instead of 33%.

Real numbers — Llama 7B at 8K context, single A100

Setup                    Activations   Total mem
No checkpointing         39 GB         OOM
Full checkpointing       4 GB          28 GB
Selective (attn only)    6 GB          30 GB

Selective is the production sweet spot.

Predict activation memory

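A minimal estimator for activation memory in a Transformer training run — it applies the $34 \cdot s \cdot b \cdot h$ formula from above; the “selective” fraction is an assumption eyeballed from the 7B table, not a measured number:

# Activation-memory estimator. Assumptions: the 34*s*b*h per-layer formula
# (FlashAttention, so no score matrix); "selective" keeps ~15% of
# activations, a ratio eyeballed from the 7B table above.
def activation_gb(layers, hidden, seq, batch, bytes_per_elem=2, mode="none"):
    per_layer = 34 * seq * batch * hidden * bytes_per_elem
    if mode == "none":         # save every activation
        total = layers * per_layer
    elif mode == "full":       # save only each layer's input; recompute the rest
        total = layers * seq * batch * hidden * bytes_per_elem + per_layer
    elif mode == "selective":  # recompute attention, keep the FFN side
        total = 0.15 * layers * per_layer
    else:
        raise ValueError(mode)
    return total / 1e9

# Llama-3.1 70B: 80 layers, h=8192, at s=8192, b=1, BF16.
for mode in ("none", "full", "selective"):
    print(f"{mode:9s}: {activation_gb(80, 8192, 8192, 1, mode=mode):6.1f} GB"
          "  (H100 HBM = 80 GB)")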

You’ll see Llama-70B at 8K with no checkpointing exceeds an H100; selective checkpointing brings it within budget.

Quick check

You're training a 13B model on 8 A100s and OOMing during the backward pass. The forward pass succeeds. What's the most appropriate first fix?

Key takeaways

  1. Backward is graph replay. Forward saves activations; backward consumes them via per-op VJPs.
  2. Activations, not weights, often dominate training memory. Predict them or you’ll OOM.
  3. Use VJP intuitions, not Jacobian intuitions. No one materializes $J$. Always think $J^\top v$.
  4. Selective gradient checkpointing is the production default. Pay ~5% compute for 3–5× lower activation memory.
  5. PyTorch autograd is ~600 lines of clever bookkeeping. Read torch/autograd/ once — the abstractions are simpler than they look.
