FSDP2 Pain and Payoff: What Actually Cut Memory on GPU and TPU

FSDP2 is sold as a drop-in upgrade over legacy FSDP. In isolation that is
roughly true. Inside a real training stack that also runs tensor
parallelism, expert parallelism, pipeline parallelism, sequence
parallelism, and a hybrid Muon + AdamW optimizer, it is not. This post is
the unvarnished account of rolling fully_shard across both CUDA (H200)
and XLA (TPU v6e) for the MegaCpp nanochat POC, with particular focus on
the three knobs that actually mattered: reshard_after_forward,
MixedPrecisionPolicy, and activation checkpointing.
The headline: FSDP2 delivered real memory headroom, but only after we stopped treating it as a "set it and forget it" wrapper and started treating it as a small graph of communication decisions you own per module group.
Two backends, one contract
The POC has two sharding surfaces in production:
- CUDA FSDP2 via torch.distributed.fsdp.fully_shard, wrapping each transformer block, the embedding, and the LM head as separate groups on the DP mesh.
- XLA FSDP as SPMD ZeRO-3 on TPU v6e, implemented in a standalone apply_fsdp_sharding() that composes with the TP/EP mesh instead of wrapping modules.
They share nothing at the API level. We kept the public wrapper
(apply_cuda_fsdp, apply_fsdp_sharding) honest about that: no
polymorphism, no "backend=auto". Code-wise it is two implementations
with a narrow shared contract on what stays replicated (routing,
Mamba-like internals, norm/QK-norm weights, LoRA adapters) and what
gets sharded (attention Q/K/V/O, MLP up/gate/down, expert banks).
Everything below is about the CUDA side unless called out, because the XLA SPMD path is mostly automatic once the mesh is right.
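For concreteness, here is a minimal sketch of the CUDA wrapping pattern. The module attributes (blocks, embed, lm_head), the name-based replicated filter, and the use of fully_shard's ignored_params for the whole replicated set are illustrative assumptions, not the actual apply_cuda_fsdp code (ignored_params also requires a recent PyTorch):

```python
import torch
from torch.distributed.fsdp import fully_shard

def apply_cuda_fsdp_sketch(model: torch.nn.Module, dp_mesh) -> torch.nn.Module:
    # Params the contract keeps replicated (routing, norm/QK-norm, LoRA, ...)
    # are handed to fully_shard as ignored_params so FSDP2 never shards them.
    replicated = {
        p for name, p in model.named_parameters()
        if any(tag in name for tag in ("router", "norm", "lora_"))
    }

    # One FSDP2 group per transformer block: each group all-gathers its own
    # params right before its forward and manages only what existed at wrap time.
    for block in model.blocks:
        fully_shard(block, mesh=dp_mesh, ignored_params=replicated)

    # Embedding and LM head become their own small groups on the same DP mesh.
    fully_shard(model.embed, mesh=dp_mesh, ignored_params=replicated)
    fully_shard(model.lm_head, mesh=dp_mesh, ignored_params=replicated)

    # Wrap the root last so it only owns whatever is left over.
    fully_shard(model, mesh=dp_mesh, ignored_params=replicated)
    return model
```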
reshard_after_forward: the knob that looks free and isn't
The first real decision per FSDP2 group is reshard_after_forward. The
docs make it sound like a pure memory/speed slider:
- True (default): free unsharded params after forward; re-all-gather them in backward.
- False: keep unsharded params resident; skip one all-gather per block per step.
- N (int): hybrid, reshard to a smaller mesh of size N.
On a small model this is a clean tradeoff. On a 4.7B-param MoE
(NAM52-class) with TP=2, EP=2, FSDP2=4 on H200, it is a trap.
Our first serious run was a depth-52 MoE with
NANOCHAT_FSDP2_NO_RESHARD_AFTER_FWD=1. The intuition was the standard
one: we had the HBM, so trading memory for one fewer collective per
block should be a win. It was not. The run OOM'd before steady state,
and the forensic pass in the FA4 + FSDP2 scaling receipt made the
mechanism obvious: with reshard_after_forward=False, all 52 blocks'
unsharded parameters accumulate in HBM simultaneously. The param
state for a single block is manageable. The param state for all 52 is
not. Nemotron's reference config does not do this, and for good
reason.
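Back-of-envelope, and ignoring the TP/EP splits for simplicity: 4.7B params in bf16 is roughly 9.4 GB of unsharded weights. With reshard_after_forward=True a rank holds its ~2.4 GB shard (at FSDP2=4) plus roughly one block's unsharded params at a time; with False, essentially the full ~9.4 GB of unsharded block params is resident by the end of forward, on top of gradients, optimizer state, and activations.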
The fix was to revert the env toggle to default (True). The block
groups now free their unsharded shards between forward and backward,
and the MoE with EP + TP + FSDP2 + compile finally reached
compile-stable steady state on the validated lane. That result
overwrote a lot of intuition: "keep unsharded params resident" is
defensible for a narrow window of small, dense, single-group models.
It is actively harmful the moment you wrap every transformer block
separately on a deep stack.
Where reshard_after_forward=False did still earn its keep was on
the root-level groups that own a single small tensor: embedding, lm
head, and root-level auxiliary modules. They are small enough that
keeping them resident costs a rounding-error amount of HBM and
genuinely removes a collective. We expose that as a narrow per-group
decision inside apply_cuda_fsdp, not a global env flag anymore.
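In code the decision is just a different reshard_after_forward per fully_shard call rather than a global flag; a sketch, reusing the illustrative module names from above:

```python
# Deep block stack: keep the default so only roughly one block's worth of
# unsharded params is live between forward and backward.
for block in model.blocks:
    fully_shard(block, mesh=dp_mesh, reshard_after_forward=True)

# Tiny root-level groups: keeping them unsharded is a rounding error in HBM
# and removes one all-gather per step each.
fully_shard(model.embed, mesh=dp_mesh, reshard_after_forward=False)
fully_shard(model.lm_head, mesh=dp_mesh, reshard_after_forward=False)
```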
A second knob, NANOCHAT_FSDP2_PREFETCH_LIMIT, installs explicit
forward/backward prefetch lists of length N on each wrapped block.
With limit=1 the behaviour reproduces the default next/prev overlap
but issues the all-gather earlier from the CPU side. Higher limits
trade reserved memory for more overlap. On our lane limit=1 was the
sweet spot; anything higher reintroduced the same class of accumulation
failure, just more gradually.
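Under the hood the limit maps onto FSDP2's explicit prefetch hooks on each wrapped module. A sketch of installing lists of length limit on an already-wrapped block list (the env-var plumbing and edge cases of the real code are omitted):

```python
def install_prefetch(blocks, limit: int = 1) -> None:
    """Tell each wrapped block which neighbours' all-gathers to issue early.
    limit=1 reproduces the default next/prev overlap but from the CPU side;
    larger limits buy more overlap at the cost of more reserved memory."""
    for i, block in enumerate(blocks):
        block.set_modules_to_forward_prefetch(blocks[i + 1 : i + 1 + limit])
        block.set_modules_to_backward_prefetch(blocks[max(0, i - limit) : i][::-1])
```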
MixedPrecisionPolicy: bf16 params, fp32 reduce, no exceptions
The second knob, and the one that cost us the most debugging time when
we got it wrong, is MixedPrecisionPolicy. FSDP2 lets you pick
independent dtypes for:
- param_dtype: the dtype unsharded params are cast to before forward
- reduce_dtype: the dtype used for the gradient reduce-scatter
The combination we committed to, and would commit to again, is
param_dtype=bf16, reduce_dtype=fp32.
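Concretely, the committed policy is a couple of lines shared across every fully_shard call (sketch):

```python
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # unsharded params, and hence fwd/bwd compute, in bf16
    reduce_dtype=torch.float32,   # gradient reduce-scatter accumulates in fp32
)
for block in model.blocks:
    fully_shard(block, mesh=dp_mesh, mp_policy=mp_policy)
```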
Two findings pushed us there. First, on H200 with Muon driving the 2D
matrix weights, bf16 param casts are fine in the forward/backward pass
but unacceptable in the grad reduction. Muon's Newton-Schulz iteration
is numerically touchy; it is derived to be stable on well-conditioned
gradient matrices and shows visible drift when the reduce introduces
systematic bf16 rounding. Lifting reduce_dtype to fp32 cost a small
amount of communication bandwidth and removed an entire class of
"why did depth-52 diverge at step 3000" bisects.
Second, we discovered a concrete FSDP2 + torch nightly interaction
where to_accumulated_grad_if_needed() was being called before
_unsharded_param was fully initialised on the first step. We landed
a narrow workaround (skip that call on the first step under the
affected nightly), and it shows up explicitly in the FA4/FSDP2 scaling
log as "CUDA FSDP2 workaround: skipping to_accumulated_grad_if_needed() before _unsharded_param init on this torch nightly". Keeping the workaround gated on nightly detection
rather than making it unconditional was worth the extra branch: the
same code has since run clean on two later nightlies without the
workaround firing.
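The gate itself is nothing clever; a sketch of the shape, with a deliberately loose predicate (the real check pins the exact affected nightly version string):

```python
import torch

def first_step_workaround_needed() -> bool:
    """Only skip the first-step call on the specific affected torch nightly,
    so the workaround never fires on releases or on later nightlies."""
    v = torch.__version__  # e.g. "2.8.0.dev20260115+cu128" on a nightly build
    return ".dev" in v     # illustrative predicate; real code matches the pinned build
```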
On the XLA side, the equivalent of MixedPrecisionPolicy is the XLA
autocast regime plus an explicit grad-reduction dtype. The v6e lane
uses bf16 activations with fp32 grad reduce as well, kept symmetric
with CUDA so the distributed optimizer stress harness can compare
cross-backend invariants instead of chasing backend-specific
numerical drift.
Activation checkpointing: FP8 as the default, not the optimization
For anyone used to AMP-era training, "activation checkpointing" still
evokes torch.utils.checkpoint wrappers bolted on after the fact.
On H200 that instinct is wrong. The cheap default is now FP8
activation checkpointing via Transformer Engine, not bf16
rematerialization.
Our rollout uses enable_fp8_activation_checkpointing(model) and the
matching get_fp8_activation_checkpoint_context() to stream
checkpointed activations in FP8 rather than bf16. For the NAM52
support-region training receipts, that change alone moved peak
activation memory from the blocker column to the non-blocker column on
the compiled TP + SP + EP + FSDP2 + compile lane. The dense
MixedPrecisionPolicy already gave us bf16 compute; FP8 checkpointing
compressed the stored activations without changing the forward math.
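In the training loop this composes roughly as follows. The helper names are the project's own (quoted above); the call shape around them is an assumption, not a Transformer Engine API:

```python
# One-time setup after model construction: store checkpointed activations as FP8.
enable_fp8_activation_checkpointing(model)

for batch in loader:
    with get_fp8_activation_checkpoint_context():
        loss = model(**batch)          # forward math unchanged, bf16 compute
    loss.backward()                    # backward rematerializes from FP8 activations
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```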
Two caveats held us up. The first is coexistence with torch.compile:
early runs threw torch._dynamo recompile-limit warnings around the
MoE scatter site. These were performance warnings, not correctness
blockers, but they masked real compile failures. We now treat a
non-zero recompile_limit_count as a gate: if it trips on the
validated lane, the FP8 checkpoint context has to be inspected before
any throughput claim is accepted.
The second is ordering with LoRA injection. If LoRA is injected after
fully_shard, FSDP2 does not manage the new adapter params. Without
register_lora_grad_hooks(model, dp_mesh) they silently stop
gradient-syncing across DP ranks. We hit this exactly once. It
manifested as adapters diverging between ranks after a few hundred
steps, with no training-time error. The fix is one line; the lesson
is that FSDP2's "only what existed at wrap time" contract is
load-bearing.
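What register_lora_grad_hooks has to do is small. A hedged sketch of the idea only (not the actual implementation), assuming adapter params are identifiable by name and dp_mesh is the DP DeviceMesh:

```python
import torch.distributed as dist

def register_lora_grad_hooks_sketch(model, dp_mesh):
    """Manually average adapter grads across DP ranks, because params added
    after fully_shard() are not managed or synced by FSDP2."""
    group = dp_mesh.get_group()
    world = dist.get_world_size(group)

    def hook(param):
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)
            param.grad.div_(world)

    for name, p in model.named_parameters():
        if "lora_" in name and p.requires_grad:
            # Fires once the grad for this param has been fully accumulated.
            p.register_post_accumulate_grad_hook(hook)
```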
Checkpointing the state, not just the activations
Beyond activation checkpointing there is the older and more boring
topic of saving model state. FSDP2 here is less interesting than it
looks. We checkpoint the DP-sharded parameters directly and
reconstruct full tensors on load, matching the pattern the CUDA
FSDP2 wrapper exposes via its ignored_params contract for LoRA and
its shared MixedPrecisionPolicy for non-sharded branches.
The part that required actual work was teaching the checkpoint
pipeline about the hybrid optimizer. Muon state (momentum buffers,
second-moment buffers) and AdamW state (exp_avg, exp_avg_sq) live
on different sharding axes because the two optimizers shard differently
(DistMuon is per-parameter-group chunking; DistAdamW is ZeRO-2
style row sharding of first-dim-divisible params). A naive "save
optimizer state dict" round-trips fine on a single rank but asserts on
resume if the world size changes. We pinned resumes to the same world
size for the POC, flagged world-size-invariant checkpointing as an
open item, and moved on.
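The pin is enforced the boring way; a sketch of the resume guard (the metadata key is illustrative):

```python
import torch.distributed as dist

def assert_resumable(ckpt_meta: dict) -> None:
    """Refuse to resume if the saved world size differs from the current one,
    since Muon and AdamW state shard on different axes and are not resharded."""
    saved = ckpt_meta["world_size"]        # written at save time (illustrative key)
    current = dist.get_world_size()
    if saved != current:
        raise RuntimeError(
            f"checkpoint saved at world_size={saved}, resuming at {current}; "
            "world-size-invariant resume is still an open item"
        )
```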
What actually cut memory
Stepping back, the memory wins that held up in the receipts were boring, in the right way:
- reshard_after_forward=True on every block group (the default). The single largest lever, and it is "do nothing". Our real contribution was removing the env flag that overrode it.
- MixedPrecisionPolicy(param_dtype=bf16, reduce_dtype=fp32) as the committed policy across both backends.
- FP8 activation checkpointing via TE, not bf16 rematerialization.
- Explicit per-group wrapping so the root FSDP2 group degrades to a no-param state instead of owning TP boundary params.
- LoRA adapters held out via ignored_params when injected before wrap, or resynced via register_lora_grad_hooks when injected after.
None of these is novel. What is easy to get wrong, and what cost us
real time, is the interaction: FSDP2 + MoE + EP + compile + FP8
checkpointing is not a linear sum of five independent features. Each
pair has at least one sharp edge. The only durable defense is the
receipts: every lane above is reproduced by a checked-in bring-up
log that records actual_backend, fa4_backend_confirmed,
recompile_limit_count, and capacity_error_count on the exact
validated geometry. If those fields drift on a new run, we do not
ship.
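The gate is deliberately dumb; a sketch, assuming both receipts are flat JSON with those four fields at the top level (the real schema may differ):

```python
import json

FIELDS = ("actual_backend", "fa4_backend_confirmed",
          "recompile_limit_count", "capacity_error_count")

def gate_receipt(validated_path: str, new_run_path: str) -> None:
    """Fail the lane if any gate field drifted from the checked-in receipt."""
    with open(validated_path) as f:
        validated = json.load(f)
    with open(new_run_path) as f:
        new_run = json.load(f)
    drift = {k: (validated.get(k), new_run.get(k))
             for k in FIELDS if validated.get(k) != new_run.get(k)}
    if drift:
        raise SystemExit(f"receipt drift, do not ship: {drift}")
```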
What we are not claiming
We are not claiming that FSDP2 dominates any particular pure-TP or Megatron-style pipeline-parallel baseline on our workload. The validated compile-stable lane on 4x H200 at depth 16 reached roughly 1500 tok/sec on a small dense test geometry; that number exists to prove the lane is alive, not to win a benchmark table. Scaling throughput is the next phase, and it does not belong in the memory story.
We are also not claiming parity with Nemotron's production config.
The latest audit still lists a handful of gaps (moe_router_dtype=fp32
vs our bf16 autocast, moe_shared_expert_overlap, tp_comm_overlap
userbuffers, pre-collective NaN checks, average_in_collective,
cross_entropy_fusion_impl). Each one is tracked; none are blocking
the FSDP2 rollout described above.
The honest verdict is that FSDP2 is worth the rollout cost, the payoff
is mostly in memory and mostly from defaults, and the debt is in the
combinatorial matrix with every other parallelism dimension you
already use. For the MegaCpp nanochat POC, by David Gornshtein and
Boris Tamarkin, that trade was worth it. If you are about to do the
same migration, start with reshard_after_forward=True, commit to a
single MixedPrecisionPolicy, and only then go looking for the
interesting knobs.
References
- TENSOR_PARALLELISM.md
- BACKEND_STOPLIGHT_MATRIX.md
- CURRENT_STATE.md
- tp_sp_ep_fsdp_h200_bringup_2026-04-07.md
- fa4_fsdp2_scaling_2026-03-22.json
- 11-adaptive-sharding-auto-fit.md
- TRAINING_PLAN.md
- training_plan_en.md
- CHANGELOG.md