All-Gather vs Reduce-Scatter
Ring all-reduce is not one collective in NCCL's implementation; it is two phases glued together. Pulling those phases apart is the move that makes FSDP and ZeRO stage 3 possible. The bytes on the wire stay the same; the placement in the training step changes, and the memory footprint changes with it.
The decomposition
The identity is exact: all-reduce(x) ≡ all-gather(reduce-scatter(x)). Reduce-scatter starts with a size-N tensor on every GPU, each holding its own values (such as its local gradients). It sums them elementwise across GPUs and leaves each GPU with exactly one shard of size N/P: the slice of the sum that GPU is responsible for. All-gather takes one shard of size N/P per GPU and reconstructs the full size-N tensor on every GPU. Compose the two and every GPU ends with the same fully reduced tensor that all-reduce would have produced.
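A minimal sketch of the identity follows. It models each GPU's buffer as a Python list and each collective as a plain function over the list of per-rank buffers, with no NCCL and no GPUs involved. The function names mirror the collectives but are illustrative, not any library's API. Shards are contiguous, N is assumed divisible by P, and integer values keep the sums exact.

```python
from typing import List

Tensor = List[int]  # one rank's buffer; ints keep the sums exact


def shard_range(n: int, p: int, rank: int) -> range:
    """Contiguous slice owned by `rank`; assumes n is divisible by p."""
    size = n // p
    return range(rank * size, (rank + 1) * size)


def reduce_scatter(inputs: List[Tensor]) -> List[Tensor]:
    """Sum across ranks, but each rank keeps only its own N/P slice of the sum."""
    p, n = len(inputs), len(inputs[0])
    return [[sum(t[i] for t in inputs) for i in shard_range(n, p, r)] for r in range(p)]


def all_gather(shards: List[Tensor]) -> List[Tensor]:
    """Concatenate every rank's shard and hand the full tensor to every rank."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def all_reduce(inputs: List[Tensor]) -> List[Tensor]:
    """Reference all-reduce: every rank gets the full elementwise sum."""
    total = [sum(column) for column in zip(*inputs)]
    return [list(total) for _ in inputs]


# Four ranks, N = 8: each rank starts with different "local gradients".
local_grads = [[10 * r + i for i in range(8)] for r in range(4)]
shards = reduce_scatter(local_grads)
assert all(len(s) == 2 for s in shards)  # each rank holds N/P = 2 elements
assert all_gather(shards) == all_reduce(local_grads)
```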
The bytes match too. Each half moves (P-1)/P × N bytes per GPU, so the pair moves 2(P-1)/P × N, the same lower bound that ring all-reduce achieves end to end. Reduce-scatter is the first half of the ring schedule (the reduce phase); all-gather is the second half (the gather phase). Splitting them costs nothing in total bytes on the wire. It is not free in the step plan: you now run two collectives where you ran one, and each has its own launch and synchronization overhead. But on a per-byte basis the decomposition is exact.
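The byte count can be checked with a small simulation of the ring schedule. It assumes P ranks on a unidirectional ring, N divisible by P, and 4-byte elements, and it is a teaching model rather than NCCL's actual kernel. At every step each rank sends one N/P chunk to its right neighbor: the neighbor accumulates the chunk during reduce-scatter and overwrites its copy during all-gather. The simulator tallies what each rank puts on the wire.

```python
from typing import List, Tuple


def ring_all_reduce(data: List[List[int]], elem_bytes: int = 4) -> Tuple[List[List[int]], List[int], List[int]]:
    """Return (final buffers, bytes sent per rank in phase 1, bytes sent per rank in total)."""
    p, n = len(data), len(data[0])
    assert n % p == 0, "sketch assumes N divisible by P"
    c = n // p
    buf = [list(row) for row in data]
    sent = [0] * p

    def chunk(idx: int) -> range:
        return range(idx * c, (idx + 1) * c)

    def ring_step(pick_chunk, accumulate: bool) -> None:
        # Snapshot every outgoing message first so all ranks "send" simultaneously.
        msgs = []
        for r in range(p):
            idx = pick_chunk(r)
            msgs.append((r, idx, [buf[r][i] for i in chunk(idx)]))
        for r, idx, payload in msgs:
            dst = (r + 1) % p
            for i, v in zip(chunk(idx), payload):
                buf[dst][i] = buf[dst][i] + v if accumulate else v
            sent[r] += len(payload) * elem_bytes

    # Phase 1, reduce-scatter: after P-1 steps rank r owns the full sum of chunk (r+1) % P.
    for s in range(p - 1):
        ring_step(lambda r: (r - s) % p, accumulate=True)
    reduce_scatter_bytes = list(sent)

    # Phase 2, all-gather: circulate the finished chunks for another P-1 steps.
    for s in range(p - 1):
        ring_step(lambda r: (r + 1 - s) % p, accumulate=False)
    return buf, reduce_scatter_bytes, sent


P, N = 4, 8
data = [[10 * r + i for i in range(N)] for r in range(P)]
out, rs_bytes, total_bytes = ring_all_reduce(data)
n_bytes = N * 4
assert all(row == [sum(col) for col in zip(*data)] for row in out)
assert rs_bytes == [(P - 1) * n_bytes // P] * P          # (P-1)/P x N per rank
assert total_bytes == [2 * (P - 1) * n_bytes // P] * P   # 2(P-1)/P x N per rank
```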
Why FSDP wants the halves separated
DDP holds the full parameter tensor and the full gradient tensor on every GPU at every moment. Forward reads the full params, backward writes the full grads, and one all-reduce per step (or per gradient bucket) keeps the grads in sync. The model is replicated; only data is sharded.
FSDP holds only a shard of size params/P on each GPU at rest. The full parameter tensor exists nowhere in the cluster except briefly, layer by layer, when a forward or backward pass needs it. Forward issues an all-gather of layer N's parameters just before computing layer N, runs the layer, and discards the gathered copy. Backward does the mirror move: it all-gathers the parameters again to compute layer N's backward, then runs reduce-scatter on the layer's gradients to leave each GPU with only its shard of the reduced gradient. Each GPU briefly holds one layer's full local gradient before the reduce-scatter, but the full reduced gradient never exists on any single GPU.
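The per-layer choreography can be written down as an event trace. This is a hypothetical sketch, not PyTorch's FSDP code. `fsdp_step_trace` and its `free_after_forward` flag are illustrative names: the flag models freeing the gathered parameters after forward, which is what forces the backward re-gather.

```python
from typing import List


def fsdp_step_trace(num_layers: int, free_after_forward: bool = True) -> List[str]:
    """Order of collectives and frees for one training step over `num_layers` sharded layers."""
    events: List[str] = []
    for i in range(num_layers):                       # forward: layer 0 upward
        events.append(f"all-gather params[{i}]")
        events.append(f"forward layer {i}")
        if free_after_forward:
            events.append(f"free params[{i}]")        # back to a params/P shard
    for i in reversed(range(num_layers)):             # backward: last layer first
        if free_after_forward:
            events.append(f"all-gather params[{i}]")  # re-gather what forward freed
        events.append(f"backward layer {i}")          # yields this layer's full local grads
        events.append(f"free params[{i}]")
        events.append(f"reduce-scatter grads[{i}]")   # keep only the grads/P shard
    return events


for event in fsdp_step_trace(2):
    print(event)
```

With `free_after_forward=False` the backward all-gathers disappear. The gathered parameters then stay resident between forward and backward, so this setting trades memory for one fewer collective per layer.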
The memory savings are direct. DDP's resident state per GPU is roughly params + grads + optimizer. For mixed-precision Adam that comes to about 16 bytes per parameter: 2 for fp16 params, 2 for fp16 grads, and 12 for the fp32 master weights and the two Adam moments. Pure fp32 Adam also lands at 16 (4 + 4 + 8). FSDP cuts each of those to params/P + grads/P + optimizer/P, saving (P-1)/P of every term. On a 70B model across 64 GPUs, that is the difference between fitting and not fitting in H100 memory: roughly 1.1 TB of state per GPU under DDP versus about 17.5 GB under FSDP, against 80 GB of HBM and before activations. The price is that every layer's parameters now have to be all-gathered on the fly, and every layer's gradients have to be reduce-scattered, instead of being co-resident with compute. See FSDP vs ZeRO for the fuller comparison.
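Here is the 70B arithmetic as a sketch. It assumes the 16 bytes per parameter of mixed-precision Adam state (2 + 2 + 12) and nothing else. Activations, temporary buffers, and the transiently gathered layers are ignored, so treat the FSDP figure as a floor.

```python
# Bytes of persistent training state per parameter for mixed-precision Adam.
BYTES_PER_PARAM = {
    "fp16 params": 2,
    "fp16 grads": 2,
    "fp32 master weights + Adam moments": 12,
}


def state_gb_per_gpu(num_params: int, num_gpus: int, sharded: bool) -> float:
    """Resident params + grads + optimizer state on one GPU, in GB (1e9 bytes)."""
    total_bytes = num_params * sum(BYTES_PER_PARAM.values())
    per_gpu = total_bytes / num_gpus if sharded else total_bytes
    return per_gpu / 1e9


H100_GB = 80
ddp = state_gb_per_gpu(70_000_000_000, 64, sharded=False)  # replicated on every GPU
fsdp = state_gb_per_gpu(70_000_000_000, 64, sharded=True)  # 1/64 of each term
print(f"DDP:  {ddp:,.1f} GB per GPU, fits: {ddp < H100_GB}")
print(f"FSDP: {fsdp:,.1f} GB per GPU, fits: {fsdp < H100_GB}")
```

This prints 1,120.0 GB for DDP and 17.5 GB for FSDP.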
ZeRO stage 3 = same shape
DeepSpeed's ZeRO is the same idea, staged. Stage 1 shards optimizer states only (the Adam moments and master weights), keeping params and grads replicated. Gradient traffic is essentially DDP's. Grads are still reduced across ranks every step, and because each rank updates only its own optimizer shard, the updated parameters are all-gathered after the step. The total volume stays at DDP's all-reduce level. Stage 2 also shards gradients, replacing the all-reduce with a reduce-scatter of grads in backward. Params stay replicated, so forward and backward run without parameter collectives; only the post-step all-gather of updated params remains. Stage 3 finally shards parameters too, and at that point the runtime is doing exactly what FSDP does: all-gather of params before each layer, reduce-scatter of grads after. The naming differs; the collective shape does not.
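The next sketch shows what each stage buys. It uses the accounting from the ZeRO paper: 2 + 2 + K bytes per parameter with K = 12 for mixed-precision Adam, and per-GPU communication volume in units of Ψ (the parameter count) for ring collectives at large P. It illustrates those formulas only. It is not DeepSpeed's own memory estimator, and it ignores activations and buffers.

```python
from typing import Dict, Tuple

K = 12  # fp32 master weights + two Adam moments, bytes per parameter


def zero_stage_costs(p: int) -> Dict[int, Tuple[float, float]]:
    """Map ZeRO stage -> (bytes of state per parameter per GPU, comm volume in units of Ψ)."""
    return {
        0: (2 + 2 + K, 2.0),        # DDP: everything replicated, one all-reduce of grads
        1: (2 + 2 + K / p, 2.0),    # shard optimizer states; grads reduced, updated params gathered
        2: (2 + (2 + K) / p, 2.0),  # also shard grads: reduce-scatter + post-step all-gather
        3: ((2 + 2 + K) / p, 3.0),  # also shard params: 2 all-gathers + 1 reduce-scatter
    }


for stage, (mem, vol) in zero_stage_costs(64).items():
    print(f"stage {stage}: {mem:6.3f} B/param per GPU, {vol:.0f}Ψ traffic per GPU")
```

Stage 3's 3Ψ assumes backward re-gathers the parameters that forward freed. That re-gather is the extra 0.5x over DDP.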
The cost: 2 to 3 collectives per layer per step
DDP runs roughly one all-reduce per gradient bucket per step, on the order of tens to a few hundred collectives per step (PyTorch's default 25 MB bucket coalesces aggressively). FSDP runs its collectives per layer, or more precisely per FSDP unit, which is usually one transformer block. Each unit gets one all-gather on the way down and one reduce-scatter on the way up. If forward freed the gathered parameters, as FSDP does by default, backward adds a second all-gather. If DDP's buckets land at roughly one per layer, a 70-layer transformer goes from ~70 all-reduces per step to ~140 collectives under FSDP, or ~210 with the backward re-gather.
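The next sketch gives a rough per-step collective count. It makes three assumptions. First, DDP's buckets are exactly 25 MiB, reading PyTorch's `bucket_cap_mb=25` as MiB and ignoring DDP's smaller first bucket, so the real count can differ by one or so. Second, each FSDP unit issues two collectives per step, or three when backward re-gathers parameters that forward freed. Third, the 1B-parameter fp32-gradient model is a hypothetical example.

```python
MIB = 1024 * 1024


def ddp_allreduces_per_step(grad_bytes: int, bucket_cap_mb: int = 25) -> int:
    """One all-reduce per full-or-partial bucket of gradients."""
    bucket_bytes = bucket_cap_mb * MIB
    return -(-grad_bytes // bucket_bytes)  # ceiling division


def fsdp_collectives_per_step(num_units: int, regather_in_backward: bool = True) -> int:
    """All-gather in forward, reduce-scatter in backward, plus an optional backward re-gather."""
    per_unit = 3 if regather_in_backward else 2
    return num_units * per_unit


# Hypothetical 1B-parameter model with fp32 gradients (4 bytes each).
print("DDP all-reduces:", ddp_allreduces_per_step(1_000_000_000 * 4))
# A 70-layer transformer wrapped as one FSDP unit per layer.
print("FSDP, params kept fwd->bwd:", fsdp_collectives_per_step(70, regather_in_backward=False))
print("FSDP, re-gather in backward:", fsdp_collectives_per_step(70))
```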
Each all-gather or reduce-scatter moves half the bytes of an all-reduce over the same tensor. With two collectives per layer, total bytes on the wire therefore match the DDP baseline; a backward re-gather adds a third half, for about 1.5x. Launch overhead grows with the call count: every collective pays NCCL dispatch cost, kernel launch latency, and CUDA stream synchronization, and FSDP issues two to three times as many of them. That is fine if every collective overlaps cleanly with compute. It is brutal if any of them stall the step. This is why FSDP needs compute-comm overlap more than DDP does. DDP can absorb a slow all-reduce late in backward because most of the compute is done by then and only the optimizer step waits on it. FSDP cannot tolerate a slow all-gather for layer N+1. Forward of layer N+1 cannot start until that gather lands, and forward of layer N is the only compute available to hide it.
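A toy cost model makes the trade concrete. It assumes a fixed per-call overhead (`launch_us`, standing in for dispatch, kernel launch, and stream sync) plus a bandwidth term over the ring's wire bytes. It ignores per-hop latency and overlap, and both constants are made-up round numbers, not measurements.

```python
from typing import Tuple


def collective_us(wire_bytes: float, launch_us: float, gb_per_s: float) -> float:
    """Fixed per-call overhead plus time to push `wire_bytes` (1 GB/s = 1e3 bytes per us)."""
    return launch_us + wire_bytes / (gb_per_s * 1e3)


def per_layer_us(layer_bytes: int, p: int, launch_us: float = 20.0,
                 gb_per_s: float = 50.0) -> Tuple[float, float]:
    """Compare DDP's one all-reduce with FSDP's all-gather + reduce-scatter for one layer."""
    half = (p - 1) / p * layer_bytes  # wire bytes for one reduce-scatter or all-gather
    ddp = collective_us(2 * half, launch_us, gb_per_s)
    fsdp = 2 * collective_us(half, launch_us, gb_per_s)
    return ddp, fsdp


for layer_bytes in (256 * 1024 * 1024, 64 * 1024):  # a fat layer and a thin one
    ddp, fsdp = per_layer_us(layer_bytes, p=64)
    share = (fsdp - ddp) / fsdp
    print(f"{layer_bytes:>11,d} B: DDP {ddp:9.1f} us, FSDP {fsdp:9.1f} us, extra-launch share {share:.1%}")
```

The bandwidth terms cancel, so FSDP's extra cost is exactly one more `launch_us` per layer. That is noise on the fat layer and close to half the time on the thin one.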
What this means in practice
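- If FSDP's all-gather of layer N+1's parameters does not complete before forward of layer N finishes, the step stalls for the gap. The levers are NCCL channel count and framework prefetch depth. More channels can raise the bandwidth each gather achieves, at the cost of SMs taken from compute. Prefetching one layer further ahead (gathering layer N+2 while layer N is still computing) gives a slow gather an extra layer of compute to hide behind, at the cost of holding one more layer's full parameters in memory.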
- Each reduce-scatter moves half the bytes of the equivalent all-reduce, and FSDP issues it per layer rather than per 25 MB bucket. On smaller models or thinner layers, the smaller message can change what NCCL selects without warning. NCCL's Tree algorithm is an all-reduce path, so for reduce-scatter the shift usually shows up as a different protocol (LL, LL128, or Simple) or, on newer NCCL releases, a different algorithm. If a DDP-to-FSDP migration suddenly produces strange communication time, check the `NCCL_DEBUG=INFO` output for the algorithm pick before tuning anything else.
- The extra collectives interact with gradient bucketing differently than DDP does. DDP fuses small gradients into 25 MB buckets; FSDP issues one reduce-scatter per FSDP unit regardless of its size, so very thin layers can land below the bandwidth-bound regime and pay full per-call latency for tiny payloads. Layer fusion (or wider FSDP units) is the lever that helps.
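The prefetch-depth lever can be sketched as a two-stream timeline, with one serial communication stream running all-gathers and one serial compute stream running layers. `prefetch_depth` is a hypothetical knob for how many layers ahead a gather may be issued; at 0, each gather waits for the previous layer's compute to finish. Real runtimes also contend for SMs and memory bandwidth, which this model ignores.

```python
from typing import List


def forward_time(compute: List[int], gather: List[int], prefetch_depth: int) -> int:
    """Wall-clock forward time when layer i must wait for its own all-gather."""
    n = len(compute)
    gather_end = [0] * n
    compute_end = [0] * n
    for i in range(n):
        # Gather i may be issued once layer (i - prefetch_depth - 1) has finished computing.
        gate = i - prefetch_depth - 1
        issue = compute_end[gate] if gate >= 0 else 0
        comm_free = gather_end[i - 1] if i > 0 else 0   # one gather at a time
        gather_end[i] = max(issue, comm_free) + gather[i]
        prev_done = compute_end[i - 1] if i > 0 else 0  # one layer at a time
        compute_end[i] = max(gather_end[i], prev_done) + compute[i]
    return compute_end[-1]


compute = [10, 10, 10, 10]  # arbitrary time units per layer
gather = [8, 8, 16, 8]      # layer 2's gather is slow (e.g. a congested link)
for depth in (0, 1, 2):
    print(f"prefetch depth {depth}: forward takes {forward_time(compute, gather, depth)}")
```

With these made-up numbers the forward pass drops from 80 units with no prefetch to 54 at depth 1 and 52 at depth 2. Each extra level of depth also keeps one more layer's gathered parameters resident, which is the memory price of hiding the slow gather.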
Updated 2026-05-10