
ZeRO & FSDP2

DDP works until the model stops fitting on a single GPU. A 70B model in BF16 is 140 GB of weights, plus ~280 GB of AdamW optimizer state, plus ~40 GB of gradients — 460 GB of state per rank if you replicate. No accelerator on the market in 2026 has 460 GB. DDP, the foundation of multi-GPU training, becomes physically impossible at frontier scale.

The fix is to shard what was being replicated. ZeRO (Rajbhandari et al., 2020) progressively chops the replicated state across data-parallel ranks: stage 1 shards the optimizer state, stage 2 adds gradients, stage 3 adds the parameters themselves. By stage 3, no rank holds the full model — each holds 1/N of every tensor and reconstructs full layers on the fly using AllGather.

FSDP is PyTorch’s implementation of ZeRO-3 (2022). FSDP2 (2024) is the per-parameter rewrite on top of DTensor — instead of wrapping whole modules, each parameter knows its sharding independently. It composes cleanly with tensor parallel and pipeline parallel, wraps faster, and gives finer-grained control over the AllGather/ReduceScatter overlap that determines whether your run hits 70% MFU or 30%.

The Python is one loop: for layer in model.layers: fully_shard(layer). Behind that loop, PyTorch replaces every parameter with a sharded DTensor, the forward pass inserts AllGather pre-hooks (“fetch the full weight just before this layer runs”), and the backward pass inserts ReduceScatter post-hooks (“scatter the gradient just after this layer’s backward finishes”). Your training loop doesn’t change. Your scaling does.

TL;DR

  • DDP replicates the full model on every rank. Out of memory at large scale.
  • ZeRO (Rajbhandari et al., 2020) progressively shards what’s replicated: stages 1 (optimizer state), 2 (+ gradients), 3 (+ parameters). Stage 3 = the whole model lives across ranks; only the active layer is materialized.
  • FSDP is PyTorch’s ZeRO-3 (2022). FSDP2 (2024) replaces module-level sharding with per-parameter sharding via DTensor, composes cleanly with TP/PP, much faster wraps, simpler code.
  • For a 70B BF16 model on 8 GPUs: DDP doesn’t fit at all. FSDP2 fits with room for a 32K-token batch.
  • Default in production through 2026: FSDP2 (HSDP) — hybrid sharding within a node + replication across nodes — composed with tensor parallel and pipeline parallel for very large runs.

Why this matters

A 70B model in BF16 is 140 GB of weights, plus ~280 GB of optimizer state (AdamW), plus ~40 GB of gradients = 460 GB of state. Doesn’t fit on any single accelerator. FSDP/ZeRO is what makes large-scale training possible at all for any team that doesn’t run on $100K-per-node TPU pods.

It’s also the most common reason a training run is slow when it shouldn’t be — communication-compute overlap is a delicate dance, and FSDP misconfiguration commonly costs 30–50% of throughput silently.

Mental model

Each stage trades memory for communication. ZeRO-3 (= FSDP) shards everything; the price is that each forward and backward pass must AllGather the layer’s parameters before computing it.

Concrete walkthrough — the four ZeRO stages

Stage | What’s sharded | Memory per rank | Comm pattern
Baseline (DDP) | nothing | full × 1 | AllReduce gradients
ZeRO-1 | optimizer state | full params + grads + (1/N) optim | AllReduce gradients
ZeRO-2 | + gradients | full params + (1/N)(grads + optim) | ReduceScatter grads
ZeRO-3 / FSDP | + parameters | (1/N) of everything | AllGather params + ReduceScatter grads

For 70B BF16 on 8 ranks:

Stage | Mem/rank | Fits on H100 80GB?
DDP | 460 GB | No
ZeRO-1 | 215 GB | No
ZeRO-2 | 175 GB | No
FSDP / ZeRO-3 | 57 GB | Yes (with selective recompute)

ZeRO-3 is the only setting that fits a 70B BF16 training step on a single 8×H100 node.

What FSDP does each step

Forward pass (per layer):
  1. AllGather this layer's full weight from its N shards across DP ranks
  2. Compute forward
  3. Free the full weight (keep the activation)

Backward pass (per layer, in reverse):
  4. AllGather this layer's full weight again
  5. Compute backward, produce the full gradient
  6. ReduceScatter the gradient: each rank ends up with its 1/N shard of the averaged grad
  7. Free the full weight

Optimizer step:
  8. Each rank updates its 1/N shard of params using its 1/N shard of grad and optim state

The two communication ops — AllGather (forward + backward) and ReduceScatter (backward) — should overlap with compute. The whole game of FSDP performance tuning is making sure they do.
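To make the schedule concrete, here is a minimal sketch of that per-layer collective pattern, written with raw torch.distributed calls rather than FSDP itself. It assumes an initialized process group (e.g. launched via torchrun), that weight_shard is this rank's flattened 1/N slice of the layer's weight, and it fakes the gradient since no real autograd runs here.

# Illustrative sketch of one layer's FSDP dataflow, not the real FSDP internals.
# Assumes dist.init_process_group has already run (e.g. via torchrun) and that
# weight_shard.numel() * world_size == D_IN * D_OUT.
import torch
import torch.distributed as dist

D_IN, D_OUT = 1024, 1024  # toy layer shape

def fsdp_layer_step(weight_shard, x):
    """AllGather -> compute -> free -> ReduceScatter for a single layer."""
    world = dist.get_world_size()

    # Steps 1/4: AllGather the full weight from each rank's 1/N shard
    full_flat = torch.empty(weight_shard.numel() * world,
                            dtype=weight_shard.dtype, device=weight_shard.device)
    dist.all_gather_into_tensor(full_flat, weight_shard)
    w = full_flat.view(D_IN, D_OUT)

    # Steps 2/5: compute with the materialized full weight
    y = x @ w
    full_grad = torch.ones_like(full_flat)  # stand-in for the real weight gradient

    # Steps 3/7: free the full weight as soon as compute is done
    del w, full_flat

    # Step 6: ReduceScatter so each rank keeps its 1/N shard of the averaged grad
    grad_shard = torch.empty_like(weight_shard)
    dist.reduce_scatter_tensor(grad_shard, full_grad, op=dist.ReduceOp.AVG)
    return y, grad_shard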

FSDP1 vs FSDP2

FSDP1 (2022) wrapped at the module level. You’d wrap(transformer_layer) and the entire layer’s params were flattened into a single sharded unit. Caveats:

  • Coarse: a transformer_layer of mixed-shape params got one big AllGather instead of fine-grained per-tensor ones
  • Couldn’t compose with tensor parallel cleanly (the wrap boundary fought the TP boundary)
  • Slow to wrap large models (~10s for a 70B at startup)

FSDP2 (2024) shards per-parameter via PyTorch’s DTensor:

# FSDP1 (deprecated)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP1
model = FSDP1(model, auto_wrap_policy=...)

# FSDP2 (modern)
from torch.distributed._composable.fsdp import fully_shard
for layer in model.layers:
    fully_shard(layer)   # shards each layer's parameters separately
fully_shard(model)       # shards the rest
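After fully_shard, each parameter is registered as a DTensor. A quick way to inspect the resulting per-parameter sharding (a sketch, assuming a process group is already initialized; the DTensor import lives under torch.distributed._tensor in older PyTorch releases):

# Inspect per-parameter sharding after the fully_shard calls above.
from torch.distributed.tensor import DTensor  # torch.distributed._tensor in older releases

for name, param in model.named_parameters():
    if isinstance(param, DTensor):
        # placements is typically (Shard(0),) under FSDP2;
        # to_local() is this rank's 1/N slice of the parameter
        print(name, param.placements, tuple(param.to_local().shape))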

FSDP2 advantages:

  • Composes with TP — a parameter can be both sharded (FSDP) on the data dim and sharded (TP) on the head dim. DTensor handles the bookkeeping.
  • Faster wraps — per-tensor sharding is essentially free at startup.
  • Cleaner gradient communication — ReduceScatter happens at parameter granularity, easier to overlap.
  • Future-proof — Megatron-Core, TorchTitan, and most 2024-2026 production stacks use FSDP2.

Hybrid sharding (HSDP)

For multi-node training, full FSDP across all ranks means cross-node AllGathers — expensive over IB / Ethernet. HSDP shards within a node (NVLink) and replicates across nodes, so the cross-node traffic becomes a gradient AllReduce instead of parameter AllGathers:

from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh(..., mesh_dim_names=("inter_node", "intra_node"))
fully_shard(model, mesh=mesh)   # 2D mesh: shard within a node, replicate (DDP-style) across nodes

This is the production default for any run with more than one node.
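For concreteness, a hypothetical HSDP setup for 4 nodes of 8 GPUs each might look like the following (the mesh shape and torchrun launch are assumptions, not from the article):

# Hypothetical: 4 nodes x 8 GPUs, launched with torchrun --nnodes=4 --nproc-per-node=8
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard

mesh = init_device_mesh("cuda", (4, 8), mesh_dim_names=("inter_node", "intra_node"))
for layer in model.layers:
    fully_shard(layer, mesh=mesh)   # shard each layer within its node, replicate across nodes
fully_shard(model, mesh=mesh)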

Real numbers — Llama-2 70B training, 8×H100

Setup | Memory/rank | tokens/sec | Notes
DDP | OOM | n/a | Doesn’t fit
FSDP1 (full shard) | 62 GB | 18,500 | Works
FSDP2 (full shard) | 57 GB | 22,000 | Faster wraps, better overlap
FSDP2 + selective ckpt | 41 GB | 20,500 | Headroom for longer sequences
FSDP2 + TP=2 + PP=2 (4D) | 32 GB | 28,000 | Required for 405B

FSDP2 alone gets you to 70B; composed with TP/PP, you scale arbitrarily.
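The TP composition follows the common 2D-mesh pattern sketched below, under assumptions: 16 ranks split as 8-way data parallel × 2-way tensor parallel, and hypothetical submodule names in the TP plan. It illustrates the idea, not the article's exact recipe.

# Sketch: FSDP2 (data dim) composed with tensor parallel (TP dim) on one 2D mesh.
# Assumes 16 ranks = 8-way DP x 2-way TP; submodule names in the plan are hypothetical.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    parallelize_module, ColwiseParallel, RowwiseParallel,
)

mesh = init_device_mesh("cuda", (8, 2), mesh_dim_names=("dp", "tp"))

for block in model.layers:
    # TP first: column-shard the up projection, row-shard the down projection
    parallelize_module(block, mesh["tp"], {
        "attention.wq": ColwiseParallel(),
        "attention.wo": RowwiseParallel(),
    })
    # Then FSDP shards whatever each TP rank still holds, along the DP mesh dim
    fully_shard(block, mesh=mesh["dp"])
fully_shard(model, mesh=mesh["dp"])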

Run it in your browser — pick the right ZeRO stage

Python (editable): a memory budget across DDP, ZeRO-1, ZeRO-2, and FSDP for various model sizes.
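The embedded notebook isn't reproduced here, so below is a minimal sketch of the same comparison. It takes the article's 70B component sizes (140 GB weights, 40 GB grads, 280 GB AdamW state) as inputs; the function and stage names are mine, and activations are ignored.

# Per-rank memory for each ZeRO stage, using the article's 70B BF16 component sizes.
# Numbers are illustrative and match the tables above to within a few GB.
def per_rank_gb(params_gb, grads_gb, optim_gb, n_ranks, stage):
    if stage == "ddp":      # everything replicated
        return params_gb + grads_gb + optim_gb
    if stage == "zero1":    # shard optimizer state
        return params_gb + grads_gb + optim_gb / n_ranks
    if stage == "zero2":    # + shard gradients
        return params_gb + (grads_gb + optim_gb) / n_ranks
    if stage == "zero3":    # + shard parameters (FSDP)
        return (params_gb + grads_gb + optim_gb) / n_ranks
    raise ValueError(f"unknown stage: {stage}")

for stage in ("ddp", "zero1", "zero2", "zero3"):
    print(f"{stage:>5}: {per_rank_gb(140, 40, 280, 8, stage):6.1f} GB/rank")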

You’ll see DDP becomes infeasible past ~10B parameters, while FSDP keeps per-rank memory dropping roughly linearly as you add ranks.

Quick check

You're training a 70B model with FSDP2 across 8×H100. Throughput (tokens/sec) is 60% of expected. The profiler shows AllGather operations dominating each step. What's the most likely fix?

Key takeaways

  1. FSDP2 is the production default in 2026: full sharding on a single node, HSDP across nodes. FSDP1 is legacy.
  2. The whole game is communication-compute overlap. AllGather hides behind prior compute when configured right; it surfaces as exposed step time when not.
  3. ZeRO-3 / FSDP-only scales to ~70B on a node. Past that, compose with TP and PP — the 4D mesh.
  4. Selective activation checkpointing pairs with FSDP to free memory for prefetching and longer sequences.
  5. Profile every run: torch.profiler with NVTX ranges (a minimal sketch below). AllGather bars vs compute bars tell you instantly whether your FSDP config is right.
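A minimal profiling sketch, assuming a train_step function and batch from your own loop; record_function gives named ranges in the trace, and torch.cuda.nvtx.range_push/range_pop serve the same role under Nsight:

# Capture one training step and export a Chrome trace to inspect AllGather vs compute overlap.
from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with record_function("train_step"):       # named range around the step
        loss = train_step(batch)              # placeholder for your own step function
prof.export_chrome_trace("fsdp_step.json")    # open in chrome://tracing or Perfetto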
