Tags: fsdp2, pytorch, tpu, h200, mixed-precision, checkpointing, training

FSDP2 Pain and Payoff: What Actually Cut Memory on GPU and TPU

8 min read · David Gornshtein

FSDP2 is sold as a drop-in upgrade over legacy FSDP. In isolation that is roughly true. Inside a real training stack that also runs tensor parallelism, expert parallelism, pipeline parallelism, sequence parallelism, and a hybrid Muon + AdamW optimizer, it is not. This post is the unvarnished account of rolling fully_shard across both CUDA (H200) and XLA (TPU v6e) for the MegaCpp nanochat POC, with particular focus on the three knobs that actually mattered: reshard_after_forward, MixedPrecisionPolicy, and activation checkpointing.

The headline: FSDP2 delivered real memory headroom, but only after we stopped treating it as a "set it and forget it" wrapper and started treating it as a small graph of communication decisions you own per module group.

Two backends, one contract

The POC has two sharding surfaces in production:

  • CUDA FSDP2 via torch.distributed.fsdp.fully_shard, wrapping each transformer block, the embedding, and the LM head as separate groups on the DP mesh.
  • XLA FSDP as SPMD ZeRO-3 on TPU v6e, implemented in a standalone apply_fsdp_sharding() that composes with the TP/EP mesh instead of wrapping modules.

They share nothing at the API level. We kept the public wrapper (apply_cuda_fsdp, apply_fsdp_sharding) honest about that: no polymorphism, no "backend=auto". Code-wise it is two implementations with a narrow shared contract on what stays replicated (routing, Mamba-like internals, norm/QK-norm weights, LoRA adapters) and what gets sharded (attention Q/K/V/O, MLP up/gate/down, expert banks).
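The contract is narrow enough to express as a name predicate. A minimal sketch of that split, using illustrative name patterns rather than the project's actual module names:

```python
# Hedged sketch of the shared replicated/sharded contract between the
# CUDA and XLA paths. Patterns are illustrative, not the real names.
REPLICATED_PATTERNS = ("router", "mamba", "norm", "lora_")
SHARDED_PATTERNS = ("q_proj", "k_proj", "v_proj", "o_proj",
                    "up_proj", "gate_proj", "down_proj", "experts")

def is_sharded(param_name: str) -> bool:
    """Replicated families win ties: routing, Mamba-like internals,
    norm/QK-norm weights, and LoRA adapters stay replicated; attention
    projections, MLP matrices, and expert banks get sharded."""
    if any(p in param_name for p in REPLICATED_PATTERNS):
        return False
    return any(p in param_name for p in SHARDED_PATTERNS)
```

Both backends enforce the same predicate, which is what keeps a two-implementation design honest without polymorphism.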

Everything below is about the CUDA side unless called out, because the XLA SPMD path is mostly automatic once the mesh is right.

reshard_after_forward: the knob that looks free and isn't

The first real decision per FSDP2 group is reshard_after_forward. The docs make it sound like a pure memory/speed slider:

  • True (default): free unsharded params after forward; re-all-gather them in backward.
  • False: keep unsharded params resident; skip one all-gather per block per step.
  • N (int): hybrid, reshard to a smaller mesh of size N.

On a small model this is a clean tradeoff. On a 4.7B-param MoE (NAM52-class) with TP=2, EP=2, FSDP2=4 on H200, it is a trap.

Our first serious run was a depth-52 MoE with NANOCHAT_FSDP2_NO_RESHARD_AFTER_FWD=1. The intuition was the standard one: we had the HBM, so trading memory for one fewer collective per block should be a win. It was not. The run OOM'd before steady state, and the forensic pass in the FA4 + FSDP2 scaling receipt made the mechanism obvious: with reshard_after_forward=False, all 52 blocks' unsharded parameters accumulate in HBM simultaneously. The param state for a single block is manageable. The param state for all 52 is not. Nemotron's reference config does not do this, and for good reason.
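The OOM mechanism is plain arithmetic. A back-of-envelope sketch (the per-block figure below is hypothetical, and this ignores activations and optimizer state entirely):

```python
def peak_unsharded_param_gib(num_blocks: int, gib_per_block: float,
                             reshard_after_forward: bool,
                             prefetch_limit: int = 1) -> float:
    """Rough peak HBM held by unsharded *parameters* alone.
    With resharding on, only the current block plus its prefetched
    neighbors is unsharded at once; with it off, every block's
    unsharded params coexist by the end of forward."""
    if reshard_after_forward:
        return gib_per_block * (1 + prefetch_limit)
    return gib_per_block * num_blocks

# Hypothetical 0.5 GiB of unsharded bf16 params per block, depth 52:
resident = peak_unsharded_param_gib(52, 0.5, reshard_after_forward=False)
resharded = peak_unsharded_param_gib(52, 0.5, reshard_after_forward=True)
```

Even with generous HBM, the `False` column scales with depth while the `True` column stays flat, which is the whole story of the OOM.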

The fix was to revert the env toggle to the default (True). The block groups now free their unsharded parameters between forward and backward, and the MoE with EP + TP + FSDP2 + compile finally reached compile-stable steady state on the validated lane. That result overwrote a lot of intuition: "keep unsharded params resident" is defensible for a narrow window of small, dense, single-group models. It is actively harmful the moment you wrap every transformer block separately on a deep stack.

Where reshard_after_forward=False did still earn its keep was on the groups that own a single small tensor: the embedding, the LM head, and root-level auxiliary modules. They are small enough that keeping them resident costs a rounding-error amount of HBM and genuinely removes a collective. We expose that as a narrow per-group decision inside apply_cuda_fsdp, not a global env flag anymore.
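The per-group decision is small enough to write down. A minimal sketch of the policy, with hypothetical group names, of the kind of branch apply_cuda_fsdp carries internally:

```python
# Small root-level groups keep unsharded params resident; every deep
# per-block group takes the default. Group names are hypothetical.
SMALL_RESIDENT_GROUPS = {"embedding", "lm_head"}

def reshard_after_forward_for(group_name: str) -> bool:
    """True (the FSDP2 default, free after forward) for block groups;
    False only for tiny groups where residency saves one all-gather
    at rounding-error HBM cost. The result would feed
    fully_shard(module, reshard_after_forward=...)."""
    return group_name not in SMALL_RESIDENT_GROUPS
```

The point of making it a named function rather than an env flag is that the decision is reviewable per group instead of being a global override.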

A second knob, NANOCHAT_FSDP2_PREFETCH_LIMIT, installs explicit forward/backward prefetch lists of length N on each wrapped block. With limit=1 the behaviour reproduces the default next/prev overlap but issues the all-gather earlier from the CPU side. Higher limits trade reserved memory for more overlap. On our lane limit=1 was the sweet spot; anything higher reintroduced the same class of accumulation failure, just more gradually.
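The prefetch lists themselves are just index windows over the block stack. A sketch of what limit=N installs, with the resulting index lists standing in for the FSDP2 modules that would be handed to set_modules_to_forward_prefetch / set_modules_to_backward_prefetch on recent torch:

```python
def build_prefetch_lists(num_blocks: int, limit: int):
    """For each wrapped block i: the next `limit` block indices to
    forward-prefetch, and the previous `limit` (in reverse execution
    order) to backward-prefetch. limit=1 matches the default next/prev
    overlap; larger limits hold more unsharded params resident."""
    fwd = [list(range(i + 1, min(i + 1 + limit, num_blocks)))
           for i in range(num_blocks)]
    bwd = [list(range(i - 1, max(i - 1 - limit, -1), -1))
           for i in range(num_blocks)]
    return fwd, bwd
```

Seen this way, raising the limit is visibly the same failure class as reshard_after_forward=False, just windowed: each extra unit of limit keeps one more block's unsharded params alive.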

MixedPrecisionPolicy: bf16 params, fp32 reduce, no exceptions

The second knob, and the one that cost us the most debugging time when we got it wrong, is MixedPrecisionPolicy. FSDP2 lets you pick independent dtypes for:

  • param_dtype: the dtype unsharded params cast to before forward
  • reduce_dtype: the dtype used for the gradient reduce-scatter

The combination we committed to, and would commit to again, is param_dtype=bf16, reduce_dtype=fp32.

Two findings pushed us there. First, on H200 with Muon driving the 2D matrix weights, bf16 param casts are fine in the forward/backward pass but unacceptable in the grad reduction. Muon's Newton-Schulz iteration is numerically touchy; it is derived to be stable on well-conditioned gradient matrices and shows visible drift when the reduce introduces systematic bf16 rounding. Lifting reduce_dtype to fp32 cost a small amount of communication bandwidth and removed an entire class of "why did depth-52 diverge at step 3000" bisects.
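The direction of the drift is easy to demonstrate without a GPU. The sketch below emulates a bf16 gradient reduce by truncating fp32 mantissas (round-toward-zero, a simplification of real bf16 rounding) and compares the accumulated error against an fp32 reduce across eight simulated DP ranks:

```python
import numpy as np

def to_bf16(x: np.ndarray) -> np.ndarray:
    """Emulate bf16 by zeroing the low 16 bits of the fp32 encoding,
    leaving a 7-bit mantissa (truncation; real bf16 rounds to nearest)."""
    u = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    return (u & np.uint32(0xFFFF0000)).view(np.float32)

rng = np.random.default_rng(0)
grads = rng.standard_normal((8, 1024)).astype(np.float32)  # 8 simulated DP ranks

exact = grads.astype(np.float64).sum(axis=0)        # fp64 reference reduce
fp32_reduce = grads.sum(axis=0)                     # reduce_dtype=fp32
bf16_reduce = np.zeros(1024, dtype=np.float32)      # reduce_dtype=bf16
for g in grads:                                     # ring-style sequential adds
    bf16_reduce = to_bf16(bf16_reduce + to_bf16(g))

err_fp32 = np.abs(fp32_reduce - exact).max()
err_bf16 = np.abs(bf16_reduce - exact).max()
```

The bf16 reduce error is orders of magnitude larger per step, which is precisely the perturbation a conditioning-sensitive iteration like Newton-Schulz amplifies over thousands of steps.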

Second, we discovered a concrete FSDP2 + torch nightly interaction where to_accumulated_grad_if_needed() was being called before _unsharded_param was fully initialised on the first step. We landed a narrow workaround (skip that call on the first step under the affected nightly), and it shows up explicitly in the FA4/FSDP2 scaling log as "CUDA FSDP2 workaround: skipping to_accumulated_grad_if_needed() before _unsharded_param init on this torch nightly". Keeping the workaround gated on nightly detection rather than making it unconditional was worth the extra branch: the same code has since run clean on two later nightlies without the workaround firing.

On the XLA side, the equivalent of MixedPrecisionPolicy is the XLA autocast regime plus an explicit grad-reduction dtype. The v6e lane uses bf16 activations with fp32 grad reduce as well, kept symmetric with CUDA so the distributed optimizer stress harness can compare cross-backend invariants instead of chasing backend-specific numerical drift.

Activation checkpointing: FP8 as the default, not the optimization

For anyone used to AMP-era training, "activation checkpointing" still evokes torch.utils.checkpoint wrappers bolted on after the fact. On H200 that instinct is wrong. The cheap default is now FP8 activation checkpointing via Transformer Engine, not bf16 rematerialization.

Our rollout uses enable_fp8_activation_checkpointing(model) and the matching get_fp8_activation_checkpoint_context() to stream checkpointed activations in FP8 rather than bf16. For the NAM52 support-region training receipts, that change alone moved peak activation memory from the blocker column to the non-blocker column on the compiled TP + SP + EP + FSDP2 + compile lane. The dense MixedPrecisionPolicy already gave us bf16 compute; FP8 checkpointing compressed the stored activations without changing the forward math.

Two caveats held us up. The first is coexistence with torch.compile: early runs threw torch._dynamo recompile-limit warnings around the MoE scatter site. These were performance warnings, not correctness blockers, but they masked real compile failures. We now treat a non-zero recompile_limit_count as a gate: if it trips on the validated lane, the FP8 checkpoint context has to be inspected before any throughput claim is accepted.

The second is ordering with LoRA injection. If LoRA is injected after fully_shard, FSDP2 does not manage the new adapter params. Without register_lora_grad_hooks(model, dp_mesh) they silently stop gradient-syncing across DP ranks. We hit this exactly once. It manifested as adapters diverging between ranks after a few hundred steps, with no training-time error. The fix is one line; the lesson is that FSDP2's "only what existed at wrap time" contract is load-bearing.

Checkpointing the state, not just the activations

Beyond activation checkpointing there is the older and more boring topic of saving model state. FSDP2 here is less interesting than it looks. We checkpoint the DP-sharded parameters directly and reconstruct full tensors on load, matching the pattern the CUDA FSDP2 wrapper exposes via its ignored_params contract for LoRA and its shared MixedPrecisionPolicy for non-sharded branches.

The part that required actual work was teaching the checkpoint pipeline about the hybrid optimizer. Muon state (momentum buffers, second-moment buffers) and AdamW state (exp_avg, exp_avg_sq) live on different sharding axes because the two optimizers shard differently (DistMuon is per-parameter-group chunking; DistAdamW is ZeRO-2 style row sharding of first-dim-divisible params). A naive "save optimizer state dict" round-trips fine on a single rank but asserts on resume if the world size changes. We pinned resumes to the same world size for the POC, flagged world-size-invariant checkpointing as an open item, and moved on.
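The stopgap is a one-check guard at resume time. A minimal sketch, assuming the checkpoint metadata records the saving world size under a hypothetical key:

```python
def check_resume_world_size(ckpt_meta: dict, world_size: int) -> None:
    """Pin resumes to the saved world size: Muon state is chunked
    per parameter group while AdamW state is ZeRO-2 row-sharded, so a
    mismatched resume would re-chunk one of them incorrectly. Fail
    loudly at load instead of asserting deep inside the optimizer."""
    saved = ckpt_meta["world_size"]   # hypothetical metadata key
    if saved != world_size:
        raise RuntimeError(
            f"checkpoint saved at world_size={saved}, resume attempted "
            f"at world_size={world_size}; world-size-invariant "
            "resharding is still an open item")
```

Failing at load time with both numbers in the message turned a confusing mid-resume assert into a one-line diagnosis.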

What actually cut memory

Stepping back, the memory wins that held up in the receipts were boring, in the right way:

  • reshard_after_forward=True on every block group (the default). The single largest lever, and it is "do nothing". Our real contribution was removing the env flag that overrode it.
  • MixedPrecisionPolicy(param_dtype=bf16, reduce_dtype=fp32) as the committed policy across both backends.
  • FP8 activation checkpointing via TE, not bf16 rematerialization.
  • Explicit per-group wrapping so the root FSDP2 group degrades to a no-param state instead of owning TP boundary params.
  • LoRA adapters held out via ignored_params when injected before wrap, or resynced via register_lora_grad_hooks when injected after.

None of these is novel. What is easy to get wrong, and what cost us real time, is the interaction: FSDP2 + MoE + EP + compile + FP8 checkpointing is not a linear sum of five independent features. Each pair has at least one sharp edge. The only durable defense is the receipts: every lane above is reproduced by a checked-in bring-up log that records actual_backend, fa4_backend_confirmed, recompile_limit_count, and capacity_error_count on the exact validated geometry. If those fields drift on a new run, we do not ship.
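The gate itself is deliberately dumb. A sketch of the check, assuming the receipt is a flat dict carrying the four fields named above (the schema is hypothetical; the field names are from the logs):

```python
REQUIRED_FIELDS = ("actual_backend", "fa4_backend_confirmed",
                   "recompile_limit_count", "capacity_error_count")

def receipt_gate(receipt: dict) -> list:
    """Return every reason to reject a run; empty list means ship.
    Accumulating reasons instead of failing fast makes the bring-up
    log show all drift at once."""
    problems = [f"missing field: {k}" for k in REQUIRED_FIELDS
                if k not in receipt]
    if receipt.get("recompile_limit_count", 0) != 0:
        problems.append("recompile limit tripped: inspect the FP8 "
                        "checkpoint context before any throughput claim")
    if receipt.get("capacity_error_count", 0) != 0:
        problems.append("MoE capacity errors present")
    if receipt.get("fa4_backend_confirmed") is False:
        problems.append("FA4 backend not confirmed")
    return problems
```

Anything non-empty means the run does not ship, regardless of how good its throughput number looks.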

What we are not claiming

We are not claiming that FSDP2 dominates any particular pure-TP or Megatron-style pipeline-parallel baseline on our workload. The validated compile-stable lane on 4x H200 at depth 16 reached roughly 1500 tok/sec on a small dense test geometry; that number exists to prove the lane is alive, not to win a benchmark table. Scaling throughput is the next phase, and it does not belong in the memory story.

We are also not claiming parity with Nemotron's production config. The latest audit still lists a handful of gaps (moe_router_dtype=fp32 vs our bf16 autocast, moe_shared_expert_overlap, tp_comm_overlap userbuffers, pre-collective NaN checks, average_in_collective, cross_entropy_fusion_impl). Each one is tracked; none are blocking the FSDP2 rollout described above.

The honest verdict is that FSDP2 is worth the rollout cost, the payoff is mostly in memory and mostly from defaults, and the debt is in the combinatorial matrix with every other parallelism dimension you already use. For the MegaCpp nanochat POC, by David Gornshtein and Boris Tamarkin, that trade was worth it. If you are about to do the same migration, start with reshard_after_forward=True, commit to a single MixedPrecisionPolicy, and only then go looking for the interesting knobs.

References

  • TENSOR_PARALLELISM.md
  • BACKEND_STOPLIGHT_MATRIX.md
  • CURRENT_STATE.md
  • tp_sp_ep_fsdp_h200_bringup_2026-04-07.md
  • fa4_fsdp2_scaling_2026-03-22.json
  • 11-adaptive-sharding-auto-fit.md
  • TRAINING_PLAN.md
  • training_plan_en.md
  • CHANGELOG.md
David Gornshtein • Datasunrise OÜ