FSDP2 pain and payoff: what actually reduced memory
A practical look at selective wrapping, reshard timing, mixed precision, and the interaction between sharding, pipeline boundaries, and heterogeneous model blocks.

The easy story about FSDP2 is simple: shard parameters, gather them for compute, then reduce-scatter on the way back. That story is directionally right, but it hides the operational question that really decides whether memory improves: what owns the live parameter state at each boundary of the forward and backward passes?
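As a mental model only, the three steps can be mimicked with plain Python lists standing in for tensors and ranks. The helper names (`shard`, `all_gather`, `reduce_scatter`) are illustrative and are not PyTorch APIs, and the sketch sums gradients where a real run may average them.

```python
def shard(flat, world_size):
    """Split a flat parameter list into equal, zero-padded per-rank shards."""
    chunk = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (chunk * world_size - len(flat))
    return [padded[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def all_gather(shards, numel):
    """Rebuild the full parameter view from every rank's shard, dropping padding."""
    return [x for s in shards for x in s][:numel]


def reduce_scatter(per_rank_grads, world_size):
    """Sum full-size gradients across ranks, then hand each rank only its shard."""
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return shard(summed, world_size)


params = [1.0, 2.0, 3.0, 4.0, 5.0]
shards = shard(params, world_size=2)          # what each rank stores at rest
full = all_gather(shards, numel=len(params))  # transient full view for compute
grad_shards = reduce_scatter([[1.0] * 5, [2.0] * 5], world_size=2)
```

The only full-size object here is `full`, and how long it stays referenced is exactly the part the simple story leaves out.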
Why selective wrapping beats global wrapping
FSDP2 tends to help most when the wrapped region matches a real execution boundary:
- a pipeline stage
- a large dense projection region
- another block family with predictable collective ownership
It tends to help less when wrapping crosses too many seams at once, especially optimizer boundaries, compile-sensitive regions, or expert-routing surfaces.
| Surface | Stable posture | Why |
|---|---|---|
| pipeline-stage wrapper | strong candidate | aligns sharding with execution ownership |
| large dense projection blocks | often worth it | high parameter volume and predictable gather pattern |
| tiny helper modules | often leave replicated | low memory upside, higher complexity tax |
| mixed runtime seams | wrap conservatively | harder to prove live-state behavior |
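One way to turn the table into a repeatable decision is a small planning step that runs before any wrapping happens. The sketch below is an assumption-heavy illustration: `ModuleInfo`, `plan_wrapping`, `apply_plan`, the `mixed_seam` flag, and the 10M-parameter threshold are all hypothetical, and the import path assumes a recent PyTorch release (2.6 or later) where `fully_shard` is exposed from `torch.distributed.fsdp`. Parameters that do not get their own unit fall into the root unit here, which is one choice among several.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ModuleInfo:
    name: str                 # dotted submodule path, e.g. "layers.0"
    param_count: int          # number of parameters owned by this submodule
    mixed_seam: bool = False  # touches routing, compile-sensitive, or optimizer seams


def plan_wrapping(modules, min_params=10_000_000):
    """Pick submodules that get their own FSDP unit; the rest stay with the root."""
    return [m.name for m in modules
            if m.param_count >= min_params and not m.mixed_seam]


def apply_plan(model, names):
    """Wrap the chosen submodules first, then the root (bottom-up)."""
    # Lazy import so the planning step runs without torch or a process group.
    # Assumes PyTorch 2.6+; older releases expose fully_shard elsewhere.
    from torch.distributed.fsdp import fully_shard
    for name in names:
        fully_shard(model.get_submodule(name))
    fully_shard(model)  # root unit collects whatever was not wrapped above
    return model


# Hypothetical inventory of a small hybrid model.
modules = [
    ModuleInfo("layers.0", 200_000_000),
    ModuleInfo("layers.1", 200_000_000),
    ModuleInfo("norm", 4_096),
    ModuleInfo("router", 50_000_000, mixed_seam=True),
]
plan = plan_wrapping(modules)
```

Keeping the plan as plain data makes it easy to log and review which surfaces were deliberately left out, separately from the call that actually requires a distributed setup.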
The hidden issue is often timing, not sharding itself
Peak memory is not decided only by whether parameters are sharded. It is often decided by how long full views remain live. If several wrapped regions hold full parameter state longer than expected, the theoretical benefit of sharding gets eaten by overlapping materialization and transient buffers.
That leads to a better mental model: make full parameter views exist for the shortest trustworthy window.
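A toy timeline shows why the window matters more than the sharding flag. The sketch below is not a profiler: `peak_live`, the four equal-sized regions, and the one-step-per-layer schedule are assumptions chosen to keep the arithmetic visible, and prefetch and transient buffers are ignored.

```python
def peak_live(windows):
    """Peak of summed sizes over half-open (start, end, size) residency windows."""
    events = []
    for start, end, size in windows:
        events.append((start, size))
        events.append((end, -size))
    # At equal timestamps, frees (negative deltas) sort before allocations.
    events.sort()
    live = peak = 0
    for _, delta in events:
        live += delta
        peak = max(peak, live)
    return peak


LAYERS, SIZE = 4, 1  # four wrapped regions, one size unit each when unsharded
END = 2 * LAYERS     # forward occupies steps 0..3, backward steps 4..7

# Freed after forward, gathered again in backward: two short windows per layer.
resharded = [(i, i + 1, SIZE) for i in range(LAYERS)]
resharded += [(END - 1 - i, END - i, SIZE) for i in range(LAYERS)]

# Kept unsharded from its forward step until its backward step finishes.
kept = [(i, END - i, SIZE) for i in range(LAYERS)]

peak_resharded = peak_live(resharded)
peak_kept = peak_live(kept)
```

In this toy schedule the resharded variant peaks at one region's worth of full parameters, while the kept variant peaks at all four, even though both are "sharded" at rest.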
Mixed precision and optimizer state matter
The memory story is never just about model weights. Optimizer state can erase a surprising amount of the gain if the sharded-parameter path is not treated explicitly. That is why FSDP2 rollouts usually stabilize only after teams check:
- mixed-precision policy
- optimizer ownership assumptions
- whether shard-backed parameters are being handled explicitly enough
In practice, a broad sharding rollout often forces a second cleanup in optimizer behavior. If that cleanup does not happen, the win on model state becomes a partial loss somewhere else.
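A rough per-rank budget with separate buckets makes that kind of partial loss easy to see. The numbers below are assumptions for illustration: fp32 sharded parameters and gradients (4 bytes each) and an Adam-style optimizer with two fp32 moments (8 bytes per parameter). `Budget` and `per_rank_budget` are hypothetical names, and transient unsharded views are deliberately not counted.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Budget:
    params: int
    grads: int
    optimizer: int
    activations: int
    overhead: int

    @property
    def total(self):
        return (self.params + self.grads + self.optimizer
                + self.activations + self.overhead)


def ceil_div(a, b):
    return -(-a // b)


def per_rank_budget(n_params, world_size, *, optimizer_sharded=True,
                    param_bytes=4, grad_bytes=4, optimizer_bytes=8,
                    activations=0, overhead=0):
    """Persistent bytes per rank; transient unsharded views are not counted."""
    opt_div = world_size if optimizer_sharded else 1
    return Budget(
        params=ceil_div(n_params * param_bytes, world_size),
        grads=ceil_div(n_params * grad_bytes, world_size),
        optimizer=ceil_div(n_params * optimizer_bytes, opt_div),
        activations=activations,
        overhead=overhead,
    )


# 1B parameters on 8 ranks, with and without shard-sized optimizer state.
good = per_rank_budget(1_000_000_000, world_size=8)
leaky = per_rank_budget(1_000_000_000, world_size=8, optimizer_sharded=False)
```

Under these assumptions the parameter and gradient buckets are identical in both cases, and the whole difference sits in the optimizer bucket.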
The practical check is whether the optimizer is stepping the local shard it actually owns, not a full-shaped tensor that only looks convenient in logs. The checked-in FSDP2 local-shard optimizer sample recovers local row bounds before the step and rescales QKV split metadata to the local row dimension, while the memory budget sample keeps parameters, optimizer state, activations, and overhead as separate capacity buckets.
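The row-bound recovery can be sketched in a few lines. This is one possible reading of that step, not the checked-in sample: it assumes chunk-style sharding along dimension 0 (ceiling-sized chunks, with trailing ranks possibly empty) and interprets "rescaling" the QKV split metadata as intersecting each fused segment with the local row range. `local_row_bounds`, `local_splits`, and the `[8, 2, 2]` layout are hypothetical.

```python
def local_row_bounds(total_rows, world_size, rank):
    """Half-open [lo, hi) row range for `rank` under chunk-style dim-0 sharding."""
    chunk = -(-total_rows // world_size)  # ceiling division
    lo = min(rank * chunk, total_rows)
    hi = min(lo + chunk, total_rows)
    return lo, hi


def local_splits(global_splits, lo, hi):
    """Rows of each fused segment (e.g. Q, K, V) that fall inside [lo, hi)."""
    out, seg_start = [], 0
    for rows in global_splits:
        seg_end = seg_start + rows
        out.append(max(0, min(hi, seg_end) - max(lo, seg_start)))
        seg_start = seg_end
    return out


qkv_rows = [8, 2, 2]  # hypothetical fused QKV weight: 8 Q rows, 2 K rows, 2 V rows
bounds = [local_row_bounds(sum(qkv_rows), 4, r) for r in range(4)]
splits = [local_splits(qkv_rows, lo, hi) for lo, hi in bounds]
```

The useful property is that each rank's split sizes sum to its own local row count, so an optimizer that slices by Q, K, and V never indexes past the shard it actually holds.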
Why heterogeneous models make the story narrower
Hybrid stacks make FSDP2 more useful and less universal at the same time. Different block families put pressure on different seams. Dense blocks, expert blocks, and specialized layers do not all have the same ownership shape, so a single blanket wrapping rule is rarely the best one.
That is why the strongest FSDP2 lessons are narrow:
- shard the large, repeatable parameter surfaces
- keep ownership boundaries explicit
- avoid treating every helper around dispatch or routing as a sharding target
Compile and overlap make bad assumptions visible
Compile and pipeline overlap do not invalidate FSDP2. They make sloppy assumptions more obvious. The relevant question is no longer just "does FSDP2 reduce memory?" It becomes "under this schedule and this compile mode, what remains live at the same time?"
That is the question that separates a real memory win from a fragile benchmark result.
The useful summary
The payoff was real, but narrower than the marketing version. FSDP2 helps when it is allowed to be specific: stage-aware, ownership-aware, and conservative about mixed seams. It becomes expensive when it is treated as a universal switch that can paper over optimizer, compile, and routing complexity.
Frequently asked questions
Why can reshard_after_forward make memory look worse even when the schedule seems cleaner?
fully_shard uses that setting to decide whether the unsharded parameter view is freed after forward or kept resident into backward. Leaving it unsharded can remove one backward-side all-gather, but it also widens the live window, so adjacent wrapped regions can overlap into a higher peak even while the communication schedule looks simpler.
Why can root-only fully_shard(model) behave much worse than selective wrapping?
fully_shard groups parameters by FSDP unit. If the only real unit is the root module, the gather boundary becomes too coarse, so too much parameter state is materialized together and the memory story collapses toward one big-group execution instead of stage-aware selective ownership. The safer FSDP2 pattern is still bottom-up wrapping of the repeated heavy blocks, then a root wrapper for the leftover embeddings and heads. The posts "Framework survey: FSDP vs Megatron vs DeepSpeed" and "FSDP, CUDA, and Megatron DDP" are the adjacent reads for that boundary.
Why can backward overlap still spike memory after selective wrapping?
When does reshard_after_forward = N help instead of just adding another knob?
When N matches a real fast-island ownership boundary, usually an intra-node group that can carry the heavier parameter re-materialization traffic without pushing it onto the slower cross-node fabric. In that narrow case, the setting is not "middle ground for free"; it is a deliberate choice to keep the big gathers local while still re-sharding before the wider data-parallel boundary. The post "FSDP, CUDA, and Megatron DDP" is the closest local companion for that posture.
Terms used in this article
Start here for quick definitions, then follow the linked posts for deeper context.
FSDP2: PyTorch's Fully Sharded Data Parallel v2 wrapper API. On CUDA it shards parameters, gradients, and optimizer state across the data-parallel group; in the TPU/XLA posts here it is usually a memory-goal analogy, not the actual eager wrapper mechanism.
CUDA: NVIDIA's GPU programming stack: compiler, runtime, driver, libraries, and kernel toolchain used by CUDA training and inference lanes.