ZeRO & FSDP2
DDP works until the model stops fitting on a single GPU. A 70B model in BF16 is 140 GB of weights, plus ~280 GB of AdamW optimizer state, plus ~40 GB of gradients — 460 GB of state per rank if you replicate. No accelerator on the market in 2026 has 460 GB. DDP, the foundation of multi-GPU training, becomes physically impossible at frontier scale.
The fix is to shard what was being replicated. ZeRO (Microsoft, 2020) progressively chops the replicated state across data-parallel ranks: stage 1 shards the optimizer state, stage 2 adds gradients, stage 3 adds the parameters themselves. By stage 3, no rank holds the full model; each holds 1/N of every tensor and reconstructs full layers on the fly with AllGather.
FSDP is PyTorch's implementation of ZeRO-3 (2022). FSDP2 (2024) is the per-parameter rewrite on top of DTensor: instead of wrapping whole modules, each parameter knows its sharding independently. It composes cleanly with tensor parallel and pipeline parallel, wraps faster, and gives finer-grained control over the AllGather/ReduceScatter overlap that determines whether your run hits 70% MFU or 30%.
The Python is `for layer in model.layers: fully_shard(layer)`. Behind that loop: PyTorch replaces every parameter with a sharded DTensor, the forward pass inserts AllGather pre-hooks ("fetch the full weight just before this layer runs"), and the backward pass inserts ReduceScatter post-hooks ("scatter the gradient just after this layer's backward finishes"). Your training loop doesn't change. Your scaling does.
TL;DR
- DDP replicates the full model on every rank. Out of memory at large scale.
- ZeRO (Rajbhandari et al., 2020) progressively shards what’s replicated: stages 1 (optimizer state), 2 (+ gradients), 3 (+ parameters). Stage 3 = the whole model lives across ranks; only the active layer is materialized.
- FSDP is PyTorch’s ZeRO-3 (2022). FSDP2 (2024) replaces module-level sharding with per-parameter sharding via DTensor, composes cleanly with TP/PP, much faster wraps, simpler code.
- For a 70B BF16 model on 8 GPUs: DDP doesn’t fit at all. FSDP2 fits with room for a 32K-token batch.
- Default in production through 2026: FSDP2 (HSDP) — hybrid sharding within a node + replication across nodes — composed with tensor parallel and pipeline parallel for very large runs.
Why this matters
A 70B model in BF16 is 140 GB of weights, plus ~280 GB of optimizer state (AdamW), plus ~40 GB of gradients = 460 GB of state. Doesn’t fit on any single accelerator. FSDP/ZeRO is what makes large-scale training possible at all for any team that doesn’t run on $100K-per-node TPU pods.
It’s also the most common reason a training run is slow when it shouldn’t be — communication-compute overlap is a delicate dance, and FSDP misconfiguration commonly costs 30–50% of throughput silently.
Mental model
Each stage trades memory for communication. ZeRO-3 (= FSDP) shards everything; the price is that each forward and backward pass must AllGather the layer’s parameters before computing it.
Concrete walkthrough — the four ZeRO stages
| Stage | What’s sharded | Memory per rank | Comm pattern |
|---|---|---|---|
| Baseline (DDP) | nothing | full × 1 | AllReduce gradients |
| ZeRO-1 | optimizer state | full params + grads + (1/N) optim | AllReduce gradients |
| ZeRO-2 | + gradients | full params + (1/N)(grads + optim) | ReduceScatter grads |
| ZeRO-3 / FSDP | + parameters | (1/N) of everything | AllGather params + ReduceScatter grads |
For 70B BF16 on 8 ranks:
| Stage | Mem/rank | Fits on H100 80GB? |
|---|---|---|
| DDP | 460 GB | No |
| ZeRO-1 | 215 GB | No |
| ZeRO-2 | 175 GB | No |
| FSDP / ZeRO-3 | 57 GB | Yes (with selective recompute) |
ZeRO-3 is the only setting that fits a 70B BF16 training step on a single 8×H100 node.
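The per-rank column is simple arithmetic; here is a minimal sketch of it in Python (the helper name `per_rank_gb` and the component figures are taken from the numbers quoted in this lesson, not from any library):

```python
# Per-rank memory estimator, using the component sizes quoted above for a 70B BF16
# model: 140 GB weights, ~40 GB grads, ~280 GB AdamW optimizer state.
def per_rank_gb(weights, grads, optim, n_ranks, stage):
    """stage: 'ddp', 'zero1', 'zero2', or 'zero3' (FSDP full shard)."""
    if stage == "ddp":
        return weights + grads + optim
    if stage == "zero1":
        return weights + grads + optim / n_ranks
    if stage == "zero2":
        return weights + (grads + optim) / n_ranks
    if stage == "zero3":
        return (weights + grads + optim) / n_ranks
    raise ValueError(stage)

for stage in ("ddp", "zero1", "zero2", "zero3"):
    print(stage, per_rank_gb(140, 40, 280, n_ranks=8, stage=stage), "GB")
# ddp 460, zero1 215.0, zero2 180.0, zero3 57.5 -- in line with the table above
```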
What FSDP does each step
Forward pass (per layer):
1. AllGather this layer's full weight from N shards across DP ranks
2. Compute forward
3. Free the full weight (keep activation)
Backward pass (per layer, in reverse):
4. AllGather this layer's full weight again
5. Compute backward, produce full gradient
6. ReduceScatter the gradient: each rank ends up with its 1/N shard of the average grad
7. Free the full weight
Optimizer step:
8. Each rank updates its 1/N shard of params using its 1/N shard of grad and optim state

The two communication ops, AllGather (forward + backward) and ReduceScatter (backward), should overlap with compute. The whole game of FSDP performance tuning is making sure they do.
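One knob that maps directly onto this schedule is FSDP2's `reshard_after_forward` argument: it decides whether step 3 actually frees the gathered weight (forcing the re-gather in step 4) or keeps it resident until backward. A hedged sketch, reusing the `model.layers` structure from the intro snippet:

```python
# Sketch: trading memory against one AllGather per layer with reshard_after_forward.
# Assumes `model` with a .layers list, as in the fully_shard snippet shown later.
from torch.distributed._composable.fsdp import fully_shard

for layer in model.layers:
    # True (the default): free the gathered weight after forward (step 3) and
    # AllGather it again in backward (step 4) -- lowest memory, one extra collective.
    fully_shard(layer, reshard_after_forward=True)

# False: keep the gathered weights resident between forward and backward, skipping
# the backward AllGather at the cost of holding the full weights for the whole step.
fully_shard(model, reshard_after_forward=False)
```

Keeping weights unsharded costs memory, so in practice it is usually reserved for small modules (often the root), while the big transformer blocks reshard.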
FSDP1 vs FSDP2
FSDP1 (2022) wrapped at the module level. You'd `wrap(transformer_layer)` and the entire layer's params were flattened and sharded as one unit. Caveats:
- Coarse: a `transformer_layer` of mixed-shape params got one big AllGather instead of fine-grained per-tensor ones
- Couldn't compose with tensor parallel cleanly (the wrap boundary fought the TP boundary)
- Slow to wrap large models (~10s for a 70B at startup)
FSDP2 (2024) shards per-parameter via PyTorch’s DTensor:
```python
# FSDP1 (deprecated)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP1
model = FSDP1(model, auto_wrap_policy=...)

# FSDP2 (modern)
from torch.distributed._composable.fsdp import fully_shard
for layer in model.layers:
    fully_shard(layer)  # shards each layer's parameters separately
fully_shard(model)  # shards the rest
```

FSDP2 advantages:
- Composes with TP — a parameter can be both sharded (FSDP) on the data dim and sharded (TP) on the head dim. DTensor handles the bookkeeping.
- Faster wraps — per-tensor sharding is essentially free at startup.
- Cleaner gradient communication — ReduceScatter happens at parameter granularity, easier to overlap.
- Future-proof — Megatron-Core, TorchTitan, and most 2024-2026 production stacks use FSDP2.
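The per-parameter DTensor bookkeeping is easy to see directly; an illustrative check, assuming the `fully_shard` calls above have already been applied:

```python
# Quick sanity check (illustrative): after fully_shard, every parameter is a DTensor
# whose placements record how it is sharded across the mesh.
for name, param in model.named_parameters():
    print(name, param.placements)  # e.g. (Shard(dim=0),) on the data-parallel mesh
```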
Hybrid sharding (HSDP)
For multi-node training, full FSDP across all ranks means cross-node AllGathers, which are expensive over InfiniBand or Ethernet. HSDP shards within a node (NVLink) and replicates across nodes, so cross-node traffic becomes a gradient AllReduce rather than parameter AllGathers:
```python
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh(..., mesh_dim_names=("inter_node", "intra_node"))
fully_shard(model, mesh=mesh)  # 2D mesh: FSDP (shard) within node, DDP (replicate) across nodes
```

This is the production default for any run with more than one node.
Real numbers — Llama-2 70B training, 8×H100
| Setup | Memory/rank | tokens/sec | Notes |
|---|---|---|---|
| DDP | OOM | — | Doesn’t fit |
| FSDP1 (full shard) | 62 GB | 18,500 | Works |
| FSDP2 (full shard) | 57 GB | 22,000 | Faster wraps, better overlap |
| FSDP2 + selective ckpt | 41 GB | 20,500 | Headroom for longer sequences |
| FSDP2 + TP=2 + PP=2 (4D) | 32 GB | 28,000 | Required for 405B |
FSDP2 alone gets you to 70B; composed with TP/PP, you scale arbitrarily.
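What that composition looks like in code, as a rough sketch only: the mesh split (dp=4, tp=2 on one 8-GPU node) and the module names in the TP plan (`mlp.up_proj`, `mlp.down_proj`) are illustrative placeholders, not a complete recipe.

```python
# Rough sketch of TP + FSDP2 on one 8-GPU node (dp=4, tp=2). Assumes an initialized
# process group and a `model` with a .layers list; module names are placeholders.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    parallelize_module, ColwiseParallel, RowwiseParallel,
)
from torch.distributed._composable.fsdp import fully_shard

mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "tp"))
for layer in model.layers:
    # TP first: shard these weights on the head/hidden dim across the "tp" mesh dim
    parallelize_module(layer, mesh["tp"], {
        "mlp.up_proj": ColwiseParallel(),
        "mlp.down_proj": RowwiseParallel(),
    })
    # then FSDP2 on the "dp" dim; DTensor tracks both shardings per parameter
    fully_shard(layer, mesh=mesh["dp"])
fully_shard(model, mesh=mesh["dp"])
```

The usual order is TP first, then `fully_shard`, so FSDP only adds the data-parallel dim on top of the already-TP-sharded DTensors.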
Run it in your browser — pick the right ZeRO stage
You'll see DDP become infeasible past ~10B parameters, while FSDP keeps per-rank memory shrinking roughly as 1/N as you add ranks.
Key takeaways
- FSDP2 is the production default in 2026 for single-node and HSDP for multi-node. FSDP1 is legacy.
- The whole game is communication-compute overlap. AllGather hides behind prior compute when configured right; it surfaces as exposed communication time when not.
- ZeRO-3 / FSDP-only scales to ~70B on a node. Past that, compose with TP and PP — the 4D mesh.
- Selective activation checkpointing pairs with FSDP to free memory for prefetching and longer sequences.
- Profile every run with `torch.profiler` and NVTX ranges. AllGather bars vs compute bars tell you instantly whether your FSDP config is right (a minimal sketch follows this list).
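The sketch referenced in the last bullet, with `train_step` standing in for your own step function:

```python
# Minimal profiling sketch: capture a few steps and inspect the trace in Perfetto
# or chrome://tracing. `train_step` is a placeholder for your own step function.
from torch.profiler import profile, ProfilerActivity, record_function

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for i in range(4):
        with record_function(f"step_{i}"):  # NVTX-style range around each step
            train_step()
prof.export_chrome_trace("fsdp_trace.json")
# Healthy: AllGather/ReduceScatter kernels sit on a comm stream underneath matmuls.
# Unhealthy: compute stalls while a lone AllGather runs -- exposed communication.
```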
Go deeper
- Paper: "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models". The original ZeRO paper. Stages 1, 2, 3 defined here.
- Paper: "PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel". The FSDP1 retrospective paper. Required if you want to operate FSDP at scale.
- Docs: PyTorch `fully_shard` (FSDP2) docs. The current API reference. Replaces the FSDP1 wrapper.
- Blog: PyTorch Blog, "FSDP2 + Float8" (2024). How FSDP2 composes with FP8 training. The current 2025 frontier configuration.
- Repo: pytorch/torchtitan. Reference implementations of FSDP2 + TP + PP for Llama-class models. Read `torchtitan/parallelisms/`.
- Paper: "Megatron-Core: Distributed Training Frameworks". NVIDIA's alternative stack. Worth reading for the production-ML-systems perspective.
- Video: Less Wright, "Mastering FSDP2". Walkthrough of the API and tuning knobs by a PyTorch engineer.
- Blog: DeepSpeed ZeRO tutorials. The other big stack. Conceptually identical to FSDP, different ergonomics.