Backprop as a Graph
In a managed-language framework — Keras, fastai, anything where “backprop” is a checkbox — .backward() looks like calculus magic. You wrote a forward pass; gradients appear; the optimizer steps. You never touch the chain rule.
What the abstraction hides is that backprop is not really about derivatives. The derivatives are bookkeeping. Backprop is graph execution in reverse, with the activations from the forward pass as the binding state. PyTorch’s autograd records a tape of every op as you run forward; .backward() walks that tape from the loss back to the leaves, computing one VJP per op and freeing each saved tensor as soon as its work is done.
Why this reframe matters: once you see backward as “replay the forward graph backward, consuming the activations stored at each step,” you can predict where your training run will OOM. Spoiler: it’s almost never the weights. It’s the activations, which can outweigh the model 5–10× at long context. Every memory-saving training trick you’ll meet later (gradient checkpointing, activation offloading, micro-batching, FSDP) is some variation on “manage activations more cleverly during backward.”
TL;DR
- Backprop isn’t really about derivatives — those are bookkeeping. It’s about executing a computational graph backward, with the forward activations as the binding state.
- Activations dominate training memory. For a 7B model at 8K context, activations can exceed weights by 5–10× without checkpointing.
- PyTorch’s autograd records a tape of operations during forward; backward replays it in reverse, freeing tensors as soon as their gradient is computed.
- VJP (Vector-Jacobian Product) is the right mental model: .backward() doesn’t compute Jacobians, it computes $v^\top J$ for some upstream gradient $v$, one op at a time.
- The big systems consequence: gradient checkpointing trades compute for memory by recomputing activations during backward instead of storing them.
Mental model
The forward pass writes activations to memory; the backward pass reads them in reverse, computing one VJP per op, freeing each activation as soon as its corresponding op has computed its gradient.
The peak memory is the activation accumulation at the deepest point of the network — typically all activations for one forward pass live simultaneously by the time the loss is computed.
Why activations dominate
The clean formula (Korthikanti et al., Megatron 2022, assuming FlashAttention so the attention-score matrix is never materialized in HBM):

$$\text{activation bytes per layer} \approx 34\, s\, b\, h$$

where $s$ = sequence length, $b$ = batch, $h$ = hidden dim. The 34 absorbs Q/K/V/O projections, the FFN intermediate (SwiGLU stores 2× inputs), layer-norm scratch, residual inputs — all the per-token tensors saved for the backward pass.
Without FlashAttention there’s also a $5\, a\, b\, s^2$ term for the score matrix (with $a$ = attention heads); at long context that quadratic term dwarfs everything, which is exactly why Hopper-era training runs are FlashAttention-only.
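To make the formula concrete, here is a quick estimator sketch. The function name and the shape constants are ours, not from any library, and the coefficients are the Megatron-style ones described above:

```python
# Sketch (names ours): per-layer activation memory from the Megatron-style
# formula. With FlashAttention: ~34*s*b*h bytes per layer; without it, the
# attention-score matrix adds roughly a 5*a*b*s^2 term.
def activation_bytes_per_layer(s, b, h, a=None, flash=True):
    total = 34 * s * b * h              # per-token tensors saved for backward
    if not flash:
        total += 5 * a * b * s * s      # N×N score-matrix term
    return total

# Llama-70B-ish shapes: h=8192, 80 layers, 8K context, batch 1
per_layer = activation_bytes_per_layer(s=8192, b=1, h=8192)
total_gb = per_layer * 80 / 1e9
print(f"{per_layer / 1e9:.2f} GB/layer -> {total_gb:.0f} GB for 80 layers")
```

At batch 1 this lands around 183 GB of activations for the 80-layer stack — past a single H100’s 80 GB before weights, gradients, or optimizer state are even counted.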
For Llama-3.1 70B ($h = 8192$, 80 layers) at $s = 8192$, BF16: roughly $34 \times 8192 \times 8192 \approx 2.3$ GB per layer per sequence, about 183 GB across the stack at batch 1 — growing linearly with batch.
Compare: weights are 140 GB; AdamW state in FP32 adds ~280 GB. Activations alone are 2–3× the weights at long context. The full forward+backward state is ~785 GB — fits across an 8×H100 node only with gradient checkpointing.
This is why every frontier training run uses gradient checkpointing — not an optimization, a feasibility requirement.
What .backward() actually does (autograd in 30 lines)
```python
class Tensor:
    def __init__(self, data, parents=(), grad_fn=None):
        self.data = data
        self.grad = None
        self.parents = parents   # tensors this depends on
        self.grad_fn = grad_fn   # function: (upstream_grad) -> per-parent grads

    def backward(self, upstream=1.0):
        # Topological sort + reverse traversal.
        order, seen = [], set()
        def topo(t):
            if id(t) in seen: return
            seen.add(id(t))
            for p in t.parents: topo(p)
            order.append(t)
        topo(self)
        # seed root grad
        self.grad = upstream
        # walk in reverse, accumulating grads at each tensor
        for t in reversed(order):
            if t.grad_fn is None: continue
            parent_grads = t.grad_fn(t.grad)
            for p, g in zip(t.parents, parent_grads):
                p.grad = g if p.grad is None else p.grad + g

# Each op registers its parents and a grad_fn that's its VJP.
def matmul(A, B):
    return Tensor(A.data @ B.data, parents=(A, B),
                  grad_fn=lambda gO: (gO @ B.data.T, A.data.T @ gO))
```

This is essentially what PyTorch’s autograd does, scaled and optimized: a tape, a topological sort, and per-op VJP closures.
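If you want to convince yourself the matmul grad_fn above is right, a standalone finite-difference check (NumPy; variable names are ours) compares the VJP against a numerical directional derivative of L = sum(A @ B):

```python
import numpy as np

# Sanity check (ours): the matmul VJP pair (gO @ B.T, A.T @ gO) should match
# a central-difference directional derivative of L = sum(A @ B).
rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((4, 5))
gO = np.ones((3, 5))                       # upstream grad of sum() is all-ones

dA, dB = gO @ B.T, A.T @ gO                # the VJP pair from grad_fn

# Directional derivative along a random direction D should equal <dA, D>.
D = rng.standard_normal(A.shape)
eps = 1e-6
numeric = (np.sum((A + eps * D) @ B) - np.sum((A - eps * D) @ B)) / (2 * eps)
analytic = np.sum(dA * D)
print(abs(numeric - analytic) < 1e-6)      # True
```

The same check works for any grad_fn you write: perturb one input along a random direction and compare against the inner product with the VJP output.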
VJP, not Jacobian
For an op with $n$ inputs and $m$ outputs, the Jacobian $J$ is an $m \times n$ matrix — usually huge. We never materialize it. We compute $v^\top J$ for some upstream gradient $v$, which is just a vector of size $n$.
For matmul(A, B) = AB, the Jacobian w.r.t. $A$ is enormous, but the VJP is simple: $\partial L/\partial A = g_O B^\top$ and $\partial L/\partial B = A^\top g_O$ — another matmul each. Most ops have similarly clean VJPs that fit on one line.
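Some quick arithmetic (the shapes are ours, chosen as a typical mid-sized matmul) shows why materializing $J$ is out of the question:

```python
# d(out)/dA for out = A @ B, as a dense Jacobian, would have (m*n) x (m*k)
# entries; the VJP output gO @ B.T has just A's m*k entries.
m, k, n = 4096, 4096, 4096
jacobian_entries = (m * n) * (m * k)
vjp_entries = m * k
print(f"Jacobian: {jacobian_entries:.1e} entries")   # ~2.8e+14 — never built
print(f"VJP output: {vjp_entries:.1e} entries")      # ~1.7e+07 — one matmul away
```

Seven orders of magnitude separate the object calculus talks about from the object backprop actually computes.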
Gradient checkpointing — the key memory-saving trick
Idea: don’t save all activations during forward. Save only at checkpoints. Recompute the rest during backward.
```python
# Without checkpointing: 1 fwd, 1 bwd, peak memory = full activation stack
# With checkpointing every K layers: 1 fwd + 1 partial recomputation per
#   segment, peak = one K-layer segment's activations plus the checkpoints
```

Trade-off: ~33% more compute (one extra forward through each non-checkpointed segment), but peak memory drops from $O(L)$ to $O(\sqrt{L})$ layers’ worth if checkpoints are placed evenly. For deep models this is the difference between fits and doesn’t.
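The $O(\sqrt{L})$ claim is easy to check numerically. A sketch (function name is ours): peak resident activations with a checkpoint every K layers is roughly one K-layer segment plus the L/K stored checkpoints, minimized near $K = \sqrt{L}$:

```python
import math

# Peak resident activations (in "layers' worth") for an L-layer chain when
# checkpointing every K layers: one recomputed K-layer segment is live at a
# time, plus the ~L/K stored checkpoints.
def peak_layers(L, K):
    return math.ceil(L / K) + K

L = 80
best_K = min(range(1, L + 1), key=lambda K: peak_layers(L, K))
print(best_K, peak_layers(L, best_K))   # -> 8 18
```

For L = 80 the optimum lands at K = 8 (close to √80 ≈ 9) with a peak of 18 layers’ worth, versus 80 with no checkpointing — the square-root shape in practice.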
PyTorch’s torch.utils.checkpoint.checkpoint wraps a function so its activations aren’t saved; backward recomputes them.
```python
from torch.utils.checkpoint import checkpoint

# Selective checkpointing: only checkpoint the attention block
# (the one with the N×N score matrix)
def forward(self, x):
    x = checkpoint(self.attn_block, x, use_reentrant=False)
    x = self.ffn_block(x)
    return x
```

Megatron’s “selective recompute” (Korthikanti et al., 2022) chooses which activations to keep based on cost: recomputing attention is cheap with FlashAttention (the math is fast), so attention is always recomputed; recomputing the FFN is expensive, so its activations are saved. This gets the memory of full checkpointing with ~5% extra compute instead of 33%.
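The mechanics of recompute-in-backward fit in a few lines without PyTorch. A toy sketch (all names ours) over a chain of tanh layers: forward saves layer inputs only every K layers, and backward re-runs each segment from its checkpoint before applying that segment’s VJPs in reverse — producing the same gradient as the save-everything version:

```python
import numpy as np

# Toy recompute-in-backward (all names ours): a chain of identical tanh
# layers, differentiated first with every activation saved, then with
# checkpoints every K layers and per-segment recomputation.
fwd = lambda x: np.tanh(x)
vjp = lambda x, g: g * (1.0 - np.tanh(x) ** 2)   # needs the layer *input*

def grad_full(x0, n_layers):
    xs = [x0]                           # save every layer input
    for _ in range(n_layers):
        xs.append(fwd(xs[-1]))
    g = np.ones_like(x0)                # loss = sum(output)
    for i in reversed(range(n_layers)):
        g = vjp(xs[i], g)
    return g

def grad_checkpointed(x0, n_layers, K):
    ckpts, x = {0: x0}, x0              # save inputs only every K layers
    for i in range(n_layers):
        x = fwd(x)
        if (i + 1) % K == 0 and (i + 1) < n_layers:
            ckpts[i + 1] = x
    g = np.ones_like(x0)
    for s in sorted(ckpts, reverse=True):
        e = min(s + K, n_layers)
        xs = [ckpts[s]]                 # recompute this segment's activations
        for _ in range(s, e):
            xs.append(fwd(xs[-1]))
        for i in reversed(range(e - s)):
            g = vjp(xs[i], g)
    return g

x0 = np.linspace(-1.0, 1.0, 5)
print(np.allclose(grad_full(x0, 10), grad_checkpointed(x0, 10, 3)))  # True
```

grad_checkpointed holds at most one K-layer segment of activations at a time — exactly the peak-memory shape checkpointing trades the extra forward pass for.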
Real numbers — Llama 7B at 8K context, single A100
| Setup | Activations | Total mem |
|---|---|---|
| No checkpointing | 39 GB | OOM |
| Full checkpointing | 4 GB | 28 GB |
| Selective (attn only) | 6 GB | 30 GB |
Selective is the production sweet spot.
Predict activation memory
Plug Llama-70B at 8K context into the formula above: with no checkpointing the activations alone exceed a single H100’s 80 GB, while selective checkpointing brings the run within budget.
Key takeaways
- Backward is graph replay. Forward saves activations; backward consumes them via per-op VJPs.
- Activations, not weights, often dominate training memory. Predict them or you’ll OOM.
- Use VJP intuitions, not Jacobian intuitions. No one materializes $J$. Always think $v^\top J$.
- Selective gradient checkpointing is the production default. Pay ~5% compute for 3–5× lower activation memory.
- PyTorch autograd is the same bookkeeping, scaled and hardened. Read `torch/autograd/` once — the abstractions are simpler than they look.
Go deeper
- Paper: *Automatic Differentiation in Machine Learning: A Survey*. The reference for *what* AD actually is. Section 3 on reverse mode is enough.
- Paper: *Training Deep Nets with Sublinear Memory Cost*. The original gradient checkpointing paper. Surprisingly accessible.
- Paper: *Reducing Activation Recomputation in Large Transformer Models*. Selective recompute. The recipe behind every modern training run.
- Video: Andrej Karpathy, *micrograd: a tiny autograd engine*. Builds autograd in 100 lines from scratch. The clearest explanation that exists.
- Repo: karpathy/micrograd. The reference 100-line autograd. Read it before reading PyTorch's.
- Docs: PyTorch autograd notes. The official explanation of how the tape, hooks, and `retain_graph` actually work. Required for serious training work.
- Blog: Horace He, *what shapes do matmuls like?* Activation memory framing applied to real PyTorch profiling.
- Docs: PyTorch `torch.utils.checkpoint`. How to enable selective + activation checkpointing in practice.