Checkpoint Sharding
A 175B-parameter checkpoint with Adam optimizer state is roughly 1.4 TB at about 8 bytes per parameter. It is closer to 2.5 TB if it stores BF16 weights plus FP32 master weights and FP32 Adam moments (14 bytes per parameter). A 405B checkpoint is over 3 TB. Saving these with naive torch.save from rank 0 takes minutes, blocks the training step, and depends on rank 0's connection to storage. Sharded checkpointing makes every rank write its own slice in parallel, so the wall-clock cost drops from minutes to seconds.
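Checkpoint size is parameter count times bytes stored per parameter. The sketch below compares two assumed layouts. The first is a flat 8 bytes per parameter, which reproduces the figures above. The second is an illustrative 14-byte mixed-precision Adam layout. The layout dictionary is an assumption, not any framework's actual checkpoint format.

```python
# Back-of-envelope checkpoint size: parameter count x bytes stored per parameter.
# ASSUMPTION: the per-component byte counts are illustrative, not a real file format.

MIXED_PRECISION_ADAM = {
    "bf16_weights": 2,         # working copy of the parameters
    "fp32_master_weights": 4,  # optimizer's full-precision copy
    "fp32_adam_momentum": 4,   # first moment
    "fp32_adam_variance": 4,   # second moment
}


def checkpoint_tb(num_params: float, bytes_per_param: float) -> float:
    """Checkpoint size in decimal terabytes (1 TB = 1e12 bytes)."""
    return num_params * bytes_per_param / 1e12


full_bytes = sum(MIXED_PRECISION_ADAM.values())  # 14 bytes per parameter
for name, n in [("175B", 175e9), ("405B", 405e9)]:
    print(f"{name}: {checkpoint_tb(n, 8):.2f} TB at 8 B/param, "
          f"{checkpoint_tb(n, full_bytes):.2f} TB at {full_bytes} B/param")
```

For 175B this prints 1.40 TB and 2.45 TB. For 405B it prints 3.24 TB and 5.67 TB.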
What gets sharded
A model checkpoint contains three things: the model parameters, the optimizer state (momentum, variance, master weights), and any auxiliary state (RNG state, scheduler state, training step counter). With FSDP / ZeRO-3, each rank already holds its shard of the parameters and optimizer state at training time. It also holds its own copy of the small auxiliary state. The sharded checkpoint just persists what the rank already has.
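Auxiliary state is small but easy to forget. The sketch below uses Python's `random` module as a stand-in for a rank's RNG to show why it must be persisted. Restoring the saved state makes the resumed run draw the same numbers the uninterrupted run would have drawn. The function names are hypothetical. In PyTorch the analogous state comes from calls such as `torch.get_rng_state()` and `torch.cuda.get_rng_state()`.

```python
import pickle
import random

# Toy stand-in for per-rank auxiliary state: an RNG, a step counter, and a
# scheduler value. Python's `random` plays the role of the framework RNGs.


def capture_aux_state(rng: random.Random, step: int, lr: float) -> bytes:
    return pickle.dumps({"rng": rng.getstate(), "step": step, "lr": lr})


def restore_aux_state(blob: bytes) -> tuple[random.Random, int, float]:
    state = pickle.loads(blob)
    rng = random.Random()
    rng.setstate(state["rng"])
    return rng, state["step"], state["lr"]


rng = random.Random(1234)
for _ in range(10):                 # "train" for a while, consuming random numbers
    rng.random()
blob = capture_aux_state(rng, step=10, lr=3e-4)

expected = [rng.random() for _ in range(3)]     # what the uninterrupted run draws next
resumed_rng, step, lr = restore_aux_state(blob)
resumed = [resumed_rng.random() for _ in range(3)]
print(step, resumed == expected)                # 10 True
```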
For a DP=64 cluster, each rank holds 1/64 of the parameters and optimizer state, so a 1 TB checkpoint becomes about 16 GB per rank. All 64 ranks write concurrently to a parallel filesystem. As long as the filesystem's aggregate bandwidth keeps up, the wall-clock time is the maximum per-rank write time, not the sum.
The naive write (rank 0 gathers all the weights and dumps them to one file) takes roughly total_size / per_server_BW = 1 TB / 10 GB/s = 100 seconds, plus the all-gather that brings the weights to rank 0. The sharded write takes per_rank_size / per_rank_BW = 16 GB / 2 GB/s = 8 seconds. That is a roughly 12x speedup even before counting the naive path's all-gather. This is in line with what production training runs report.
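The arithmetic above fits in a small model. This sketch makes several assumptions: decimal units (1 TB = 1000 GB), perfectly parallel rank writes, and no cost for the naive path's all-gather. The `pfs_aggregate_gbps` cap is a hypothetical parameter. It shows when the parallel filesystem, rather than the per-rank link, becomes the limit.

```python
# Back-of-envelope write-time model for naive vs. sharded checkpointing.
# ASSUMPTIONS: decimal units, ranks write fully in parallel, the naive path's
# all-gather is ignored, and `pfs_aggregate_gbps` (hypothetical) caps total PFS bandwidth.


def naive_write_s(total_gb: float, server_gbps: float) -> float:
    # Rank 0 writes the whole checkpoint through one connection.
    return total_gb / server_gbps


def sharded_write_s(total_gb: float, world: int, rank_gbps: float,
                    pfs_aggregate_gbps: float = float("inf")) -> float:
    per_rank_gb = total_gb / world
    # Wall clock is the slowest rank, unless the PFS itself is the bottleneck.
    effective = min(rank_gbps, pfs_aggregate_gbps / world)
    return per_rank_gb / effective


naive = naive_write_s(1000, 10)          # 100 s
sharded = sharded_write_s(1000, 64, 2)   # 15.6 GB per rank at 2 GB/s
print(f"naive {naive:.0f} s, sharded {sharded:.1f} s, speedup {naive / sharded:.1f}x")
print(f"sharded, PFS capped at 50 GB/s: {sharded_write_s(1000, 64, 2, 50):.1f} s")
```

Without rounding 15.6 GB up to 16 GB, the sharded estimate is about 7.8 s. With the PFS capped at an assumed 50 GB/s, the same write takes 20 s. Check the aggregate filesystem bandwidth before relying on the per-rank figure.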
Output formats
PyTorch's torch.distributed.checkpoint (DCP) is the standard sharded API. It writes one file per rank into a single checkpoint directory: __0_0.distcp, __1_0.distcp, etc. A small metadata file records the sharding layout. Loading reads each file in parallel, optionally re-sharding if the load topology differs from the save topology.
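Here is a toy version of the same layout using only the Python standard library. Each "rank" (a thread here) writes only its own slice. A small metadata file records which global range each file holds, so a loader can reassemble the data. The file names and JSON metadata are hypothetical and do not reproduce DCP's on-disk format.

```python
import json
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor

# Toy sharded layout: one file per "rank" plus a metadata file describing the split.
# File names and JSON format are hypothetical, not DCP's .distcp/.metadata format.


def shard_bounds(numel: int, world: int, rank: int) -> tuple[int, int]:
    per = -(-numel // world)            # ceil division; the last shard may be short
    start = min(rank * per, numel)
    return start, min(start + per, numel)


def save_sharded(params: list[float], world: int, ckpt_dir: str) -> None:
    def write_rank(rank: int) -> dict:
        start, end = shard_bounds(len(params), world, rank)
        fname = f"shard_{rank}.json"
        with open(os.path.join(ckpt_dir, fname), "w") as f:
            json.dump(params[start:end], f)
        return {"file": fname, "start": start, "end": end}

    # Threads stand in for ranks writing concurrently.
    with ThreadPoolExecutor(max_workers=world) as pool:
        layout = list(pool.map(write_rank, range(world)))
    with open(os.path.join(ckpt_dir, "metadata.json"), "w") as f:
        json.dump({"numel": len(params), "shards": layout}, f)


def load_full(ckpt_dir: str) -> list[float]:
    with open(os.path.join(ckpt_dir, "metadata.json")) as f:
        meta = json.load(f)
    out = [0.0] * meta["numel"]
    for shard in meta["shards"]:
        with open(os.path.join(ckpt_dir, shard["file"])) as f:
            out[shard["start"]:shard["end"]] = json.load(f)
    return out


params = [float(i) for i in range(10)]
with tempfile.TemporaryDirectory() as d:
    save_sharded(params, world=4, ckpt_dir=d)
    print(sorted(os.listdir(d)))
    print(load_full(d) == params)
```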
DeepSpeed's ZeRO checkpoints have the same per-rank shape with different file naming. Its zero_to_fp32.py script consolidates them into a single FP32 state dict. Megatron-LM uses its own scheme keyed by TP rank, PP stage, and DP replica.
For interoperability, recent PyTorch versions can convert sharded checkpoints to a single .pt file for downstream consumption (inference deployments, model card uploads). The conversion lives in torch.distributed.checkpoint.format_utils, for example its dcp_to_torch_save function. It is one-time and runs offline.
Async checkpointing
Even an 8-second write stalls training if it is synchronous. Modern training pipelines write checkpoints asynchronously. The gradient step finishes, the parameters and optimizer state are copied to a CPU staging buffer (or a separate GPU scratch buffer), the next step starts, and the write completes in the background.
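PyTorch DCP supports this through torch.distributed.checkpoint.async_save, which stages the state and runs the write in the background. The catch is the staging buffer. Each in-flight checkpoint transiently needs its full size in host memory across the cluster, 1 TB in this example. For a DP=64 cluster, that is about 16 GB of host memory per rank (roughly 128 GB on a node with 8 GPUs). That fits on most production hosts but is worth budgeting.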
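Here is a minimal single-process sketch of the staging-buffer pattern. It assumes a dict of Python lists stands in for GPU tensors and a JSON file stands in for the PFS. It is not DCP's implementation. It only shows why the blocking cost is the copy into the staging buffer rather than the write itself.

```python
import copy
import json
import os
import tempfile
import threading

# Async checkpointing sketch: the caller blocks only for the snapshot copy;
# serialization and the file write run on a background thread.


def async_checkpoint(state: dict, path: str) -> threading.Thread:
    staged = copy.deepcopy(state)          # blocking: copy into the staging buffer

    def write() -> None:
        tmp = path + ".tmp"
        with open(tmp, "w") as f:
            json.dump(staged, f)
        os.replace(tmp, path)              # publish under the final name once complete

    t = threading.Thread(target=write)
    t.start()
    return t


state = {"step": 100, "weights": [0.5, -1.25, 2.0]}
with tempfile.TemporaryDirectory() as d:
    ckpt = os.path.join(d, "step_100.json")
    writer = async_checkpoint(state, ckpt)
    # Training resumes immediately and mutates the live state.
    state["step"] = 101
    state["weights"][0] = 9.9
    writer.join()                          # e.g. before the next checkpoint starts
    with open(ckpt) as f:
        print(json.load(f))                # still the step-100 snapshot
```

The write goes to a temporary file and is renamed into place with `os.replace`. A crash mid-write therefore never leaves a half-written file under the final name. Joining the writer before starting the next checkpoint bounds staging memory to one in-flight copy.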
For frequency: most production training runs write a checkpoint every 1-4 hours. With a 60-second blocking write every hour, that is 60 s / 3600 s ≈ 1.7% wall-clock overhead. An async write stalls training only for the staging copy, about 5 seconds, which keeps overhead well under 0.5%. Either is acceptable; the async path is preferred.
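The overhead arithmetic generalizes to any stall and interval. This sketch assumes training pauses only for `stall_s` seconds per checkpoint. That is the whole write if synchronous, and only the staging copy if asynchronous.

```python
# Fraction of wall-clock time lost to checkpoint stalls.
# ASSUMPTION: training pauses for `stall_s` seconds once every `interval_s` seconds.


def checkpoint_overhead(stall_s: float, interval_s: float) -> float:
    return stall_s / interval_s


for label, stall in [("sync, 60 s write", 60), ("async, 5 s stall", 5)]:
    for hours in (1, 4):
        pct = 100 * checkpoint_overhead(stall, hours * 3600)
        print(f"{label}, every {hours} h: {pct:.2f}%")
```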
Resilience and rotation
Checkpoints are the recovery mechanism for drain-and-replace and any other failure mode that interrupts the training run. Production policies typically keep:
- Last 3-5 checkpoints on the active PFS.
- Every Nth checkpoint (every 10th, every 100th) archived to S3.
- The "best" checkpoint (lowest validation loss) flagged separately and never auto-deleted.
The rotation policy and the archive transfer are usually a separate cron job, not part of the training loop. Tools like NVIDIA NeMo, MosaicML Composer, and DeepSpeed all have built-in rotation; rolling your own is fine for smaller deployments.
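The policy above can be expressed as a pure function that a cron job runs against the checkpoints on the PFS. This sketch makes three assumptions. Checkpoints are identified by training step, validation losses are available, and "every Nth checkpoint" is approximated by steps that are multiples of `archive_every_steps`. The function and parameter names are hypothetical.

```python
# Hypothetical rotation planner for a cron job: decides which checkpoints stay on the
# hot PFS, which go to the archive tier, and which can be dropped from the PFS.
# ASSUMPTIONS: checkpoints are keyed by training step; validation losses are known.


def plan_rotation(steps: list[int], val_loss: dict[int, float], keep_last: int = 3,
                  archive_every_steps: int = 10_000) -> dict[str, set[int]]:
    ordered = sorted(steps)
    # The most recent `keep_last` checkpoints stay on the PFS (guard keep_last == 0).
    keep_hot = set(ordered[-keep_last:]) if keep_last > 0 else set()
    # The "best" checkpoint is flagged separately and never auto-deleted.
    best = {min(ordered, key=lambda s: val_loss.get(s, float("inf")))} if ordered else set()
    # "Every Nth checkpoint" approximated as steps that are multiples of archive_every_steps.
    archive = {s for s in ordered if s % archive_every_steps == 0}
    delete_from_pfs = set(ordered) - keep_hot - best
    return {"keep_hot": keep_hot, "best": best, "archive": archive,
            "delete_from_pfs": delete_from_pfs}


steps = list(range(1000, 13000, 1000))                      # steps 1000 .. 12000
val_loss = {s: 1.0 + abs(s - 4000) / 1000 for s in steps}   # lowest loss at step 4000
for tier, members in plan_rotation(steps, val_loss).items():
    print(tier, sorted(members))
```

The function only plans. A real job should confirm that every archive upload has finished before deleting anything listed in `delete_from_pfs`.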
For S3 archival, the transfer is a parallel S3 multipart upload. It is typically done with s5cmd, or with the AWS CLI after tuning its multipart_chunksize and max_concurrent_requests S3 settings. A 1 TB checkpoint uploads in 5-15 minutes depending on the available (often cross-region) bandwidth.
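Multipart uploads are bounded by S3's documented limits: at most 10,000 parts per upload, with each part between 5 MiB and 5 GiB except the last. The sketch below picks a part size that respects those limits and estimates upload time at a given link speed. The 8 MiB preferred part size and the Gbit/s figures are assumptions for illustration.

```python
# Part-size arithmetic for S3 multipart uploads of checkpoint files.
# S3 limits: at most 10,000 parts; each part 5 MiB..5 GiB (the last part may be smaller).
# ASSUMPTIONS: the 8 MiB preferred part size and the link speeds are illustrative.
MiB = 1024**2
GiB = 1024**3
MAX_PARTS = 10_000
MIN_PART = 5 * MiB
MAX_PART = 5 * GiB


def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def part_size(object_bytes: int, preferred: int = 8 * MiB) -> int:
    """Smallest part size >= preferred that fits the object into at most 10,000 parts."""
    size = max(preferred, MIN_PART, ceil_div(object_bytes, MAX_PARTS))
    if size > MAX_PART:
        raise ValueError("object too large for one multipart upload")
    return size


def upload_minutes(total_bytes: float, gbit_per_s: float) -> float:
    return total_bytes * 8 / (gbit_per_s * 1e9) / 60


shard = 16 * 10**9                                   # one rank's 16 GB shard
p = part_size(shard)
print(f"16 GB shard: {p // MiB} MiB parts, {ceil_div(shard, p)} parts")
print(f"1 TB single file: {part_size(10**12) / MiB:.1f} MiB parts")
for gbps in (10, 25):
    print(f"1 TB at {gbps} Gbit/s: {upload_minutes(1e12, gbps):.1f} min")
```

At an assumed 10 to 25 Gbit/s of sustained bandwidth, the estimate for 1 TB lands between about 5 and 13 minutes. That is consistent with the range above.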
Cross-topology resharding
Loading a DP=64 checkpoint into a DP=32 run requires resharding: each new rank picks up two old shards. If the new world is bigger than the old one (say DP=128), each new rank picks up half of an old shard.
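PyTorch DCP supports this natively: the metadata file records the sharding layout, and the loader works out which bytes go where. DeepSpeed's route is less direct. zero_to_fp32.py consolidates the shards into full FP32 model weights (without optimizer state), which are then re-sharded for the new world size.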
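The loader's bookkeeping can be sketched for the simplest case: one flat tensor split into equal contiguous shards across ranks. For each new rank, the plan lists which old shard to read and which global element range to take from it. Real DCP metadata covers many tensors and richer layouts, and the function names here are hypothetical.

```python
# Resharding plan for one flat tensor split into contiguous shards (ceil division,
# last shard may be short). plan[new_rank] lists (old_rank, start, end) reads in
# global element indices. A simplification; real loaders handle many tensors.


def bounds(numel: int, world: int, rank: int) -> tuple[int, int]:
    per = -(-numel // world)
    start = min(rank * per, numel)
    return start, min(start + per, numel)


def reshard_plan(numel: int, old_world: int,
                 new_world: int) -> list[list[tuple[int, int, int]]]:
    plan = []
    for new_rank in range(new_world):
        lo, hi = bounds(numel, new_world, new_rank)
        reads = []
        for old_rank in range(old_world):
            o_lo, o_hi = bounds(numel, old_world, old_rank)
            start, end = max(lo, o_lo), min(hi, o_hi)
            if start < end:                    # the two ranges overlap
                reads.append((old_rank, start, end))
        plan.append(reads)
    return plan


# 64 -> 32: each new rank reads two whole old shards.
print(reshard_plan(6400, 64, 32)[0])   # [(0, 0, 100), (1, 100, 200)]
# 64 -> 128: each new rank reads half of one old shard.
print(reshard_plan(6400, 64, 128)[1])  # [(0, 50, 100)]
```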
This matters operationally because cluster sizes change during a long training run. A drained node, a hardware repair, an opportunistic burst-up: any of these change the world size between checkpoint save and load. Resharding makes the checkpoint format independent of the cluster topology that created it.
What this means in practice
- Use sharded checkpointing for any model where the full checkpoint is over a few hundred GB. PyTorch DCP, DeepSpeed, and Megatron-LM all support it.
- Pair sharded writes with GPUDirect Storage when the PFS supports it. The direct HBM-to-PFS path is faster than bouncing through host memory on the way to the PFS.
- Use async writes with a staging buffer. Synchronous checkpoint writes that block the training step are 5-10x more expensive in wall-clock than necessary.
- Rotation policy matters. Keep recent checkpoints on hot PFS, archive older ones to S3. The archive copy is the recovery path of last resort; the hot copy is the recovery path of normal operation.
- For resharding: test loading a sharded checkpoint into a different world size before you need to. Production failure recovery often runs at a smaller cluster size than the original, and a checkpoint that cannot be resharded is useless in that scenario.
- Verify checkpoint correctness immediately after writing by loading it from disk and computing a few inference outputs. The most common production bug is a checkpoint that was written without error but is corrupted (truncated file, wrong dtype, missing tensor). Validation runs catch this before a restart has to depend on it.
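A per-shard checksum manifest is a cheap complement to the load-and-infer check, not a replacement for it. Hash each shard's bytes in memory at save time, then re-read the files and compare right after the write. This catches truncated or partially written files. It cannot catch a checkpoint that was already wrong before serialization (wrong dtype, missing tensor). The file names and manifest format are hypothetical.

```python
import hashlib
import json
import os
import tempfile

# Hypothetical per-shard checksum manifest. Digests are computed from the in-memory
# bytes before writing, so a short or truncated write on disk shows up on verification.


def write_shards(ckpt_dir: str, shards: dict[str, bytes]) -> None:
    manifest = {}
    for name, data in shards.items():
        with open(os.path.join(ckpt_dir, name), "wb") as f:
            f.write(data)
        manifest[name] = {"size": len(data), "sha256": hashlib.sha256(data).hexdigest()}
    with open(os.path.join(ckpt_dir, "manifest.json"), "w") as f:
        json.dump(manifest, f)


def file_sha256(path: str) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):   # 1 MiB at a time
            h.update(chunk)
    return h.hexdigest()


def verify(ckpt_dir: str) -> list[str]:
    """Names of shard files that are missing, the wrong size, or fail the digest check."""
    with open(os.path.join(ckpt_dir, "manifest.json")) as f:
        manifest = json.load(f)
    bad = []
    for name, want in sorted(manifest.items()):
        path = os.path.join(ckpt_dir, name)
        if (not os.path.exists(path)
                or os.path.getsize(path) != want["size"]
                or file_sha256(path) != want["sha256"]):
            bad.append(name)
    return bad


with tempfile.TemporaryDirectory() as d:
    write_shards(d, {"shard_0.bin": b"\x01" * 1000, "shard_1.bin": b"\x02" * 1000})
    print(verify(d))                                  # []
    with open(os.path.join(d, "shard_1.bin"), "r+b") as f:
        f.truncate(500)                               # simulate a truncated write
    print(verify(d))                                  # ['shard_1.bin']
```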
Sharded checkpointing is one of the few things in storage that has a clear answer: parallel writes, async, with rotation. The cost of getting it wrong is hours of lost training every time something fails.
Updated 2026-05-10