Distributed Optimizer Stress: Drift, All-Gather vs Reduce-Scatter, and Muon Gotchas

Optimizer bugs in a distributed trainer are the worst class of bug to
own. They do not fail fast. They do not raise. They compound silently
over thousands of steps and then surface as "the loss curve looks
weird around step 3000" on a $50K run. For the MegaCpp nanochat POC
we run a hybrid DistAdamW + DistMuon optimizer across CUDA and
XLA backends, and we spent real time building a stress harness whose
entire job is to catch drift before it compounds. This post is about
what that harness does, what it found, and the Muon-specific
failure modes that have nothing to do with orthogonalization math.
Why a dedicated stress harness
The unit tests pass. The integration tests pass. The training runs also look fine for a while. So what does the stress harness do that those do not?
It exercises the one failure mode that regular training never
reliably produces: rank-asymmetric grad presence. In real
training, every rank computes a forward and backward for every
parameter on every step, so every p.grad is non-None on every rank
simultaneously. The distributed optimizer code path that handles
{rank 0: grad, rank 1: None} almost never runs in practice. But it
exists. It has to exist, because conditional compute (MoD, ReDo,
MTP toggles, disabled experts, inactive adapters) lets a rank
legitimately have grad=None for a parameter while another rank
has a real gradient.
The stress harness runs scripts/distributed_optimizer_stress.py
under a 1000-step schedule with --sample-every=100, explicitly
cycling through:
- {0: 1.0, 1: None}: rank 0 has grad, rank 1 does not
- {0: None, 1: -0.5}: the opposite
- {0: 0.25, 1: 0.75}: both have grads
- {0: None, 1: None}: both missing
Across both scenarios (the pure DistAdamW + DistMuon scenario and
the megatron_conditional scenario that toggles MoD/ReDo/MTP with
rank-asymmetric patterns), the latest CPU/gloo run on the TPU v6e-8
host reports
max_adam_abs_diff = 0.0, max_mu_a_abs_diff = 0.0,
max_mu_b_abs_diff = 0.0, and max_abs_param_diff = 0.0 at every
sampled checkpoint. That is not "within tolerance". That is
bit-identical parameter values on both ranks after 1000 steps.
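As a rough illustration of the schedule described above, the sketch below cycles through the four grad-presence patterns and computes a cross-rank maximum absolute difference. It is not the code in scripts/distributed_optimizer_stress.py; the helper names (schedule, max_abs_diff, sampled_steps) and the convention of sampling on multiples of --sample-every are assumptions made for the example.

```python
from itertools import cycle, islice

# The four grad-presence patterns from the text: rank -> grad value or None.
PATTERNS = [
    {0: 1.0, 1: None},
    {0: None, 1: -0.5},
    {0: 0.25, 1: 0.75},
    {0: None, 1: None},
]


def schedule(num_steps):
    """Yield (step, pattern) pairs, cycling through PATTERNS in order."""
    return enumerate(islice(cycle(PATTERNS), num_steps))


def max_abs_diff(values_by_rank):
    """Largest elementwise |a - b| between rank 0 and every other rank."""
    ref = values_by_rank[0]
    return max(
        (abs(a - b) for other in values_by_rank[1:] for a, b in zip(ref, other)),
        default=0.0,
    )


def sampled_steps(num_steps, sample_every):
    """Steps at which a diff would be recorded (hypothetical convention)."""
    return [s for s in range(1, num_steps + 1) if s % sample_every == 0]
```

With 1000 steps and a sampling interval of 100, this convention yields ten checkpoints, and a bit-identical run is one where max_abs_diff returns exactly 0.0 at each of them.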
Tolerances still exist: the harness thresholds are <= 2e-5 for
Adam and <= 3e-4 for Muon. We keep them because future numerics
changes (fp32 fallbacks, new precision policies) can legitimately
drop us into "within tolerance but not zero". The current state is
zero, and the day it stops being zero we want to find out on a
stress run, not on step 3000 of a real training.
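The distinction between "zero" and "within tolerance" can be made explicit in a small classifier. This is only a sketch: the two thresholds are the ones quoted above, the metric names follow the report fields named earlier, and classify and check_report are hypothetical helpers rather than part of the harness.

```python
ADAM_TOL = 2e-5  # threshold quoted in the text for Adam
MUON_TOL = 3e-4  # threshold quoted in the text for Muon


def classify(diff, tol):
    """Separate bit-identical results from merely tolerable ones."""
    if diff == 0.0:
        return "bit-identical"
    if diff <= tol:
        return "within-tolerance"
    return "fail"


def check_report(report):
    """Classify each optimizer-state metric against its tolerance."""
    tolerances = {
        "max_adam_abs_diff": ADAM_TOL,
        "max_mu_a_abs_diff": MUON_TOL,
        "max_mu_b_abs_diff": MUON_TOL,
    }
    return {name: classify(report[name], tol) for name, tol in tolerances.items()}
```

Reporting "within-tolerance" separately from "bit-identical" is what lets a numerics change show up as a visible state transition instead of a silent pass.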
The all-gather vs reduce-scatter decision
DistAdamW is a ZeRO-2-style sharded optimizer: each rank owns a
slice of the optimizer state (exp_avg, exp_avg_sq) for each
large parameter, reduces grads into its own slice, updates the
slice, and all-gathers the updated parameter so every rank sees the
full tensor. For a parameter p with first dim divisible by world
size, the pattern is:
- Gather grads onto the owner rank via reduce_scatter_tensor on a
  flattened view, producing one grad_slice per rank.
- Each rank runs Adam locally on its slice, updating its shard of
  exp_avg/exp_avg_sq and the param shard itself.
- all_gather_into_tensor rebuilds the full param into a shared
  buffer, which is sliced back into p.
The alternative, all_reduce on the full gradient followed by a
local update on the replicated optimizer state, is what
DistAdamW falls back to for small params or params whose first
dim does not divide the world size. The shape check is explicit:
shape[0] % world_size != 0 forces the all_reduce path to avoid
a reduce_scatter crash on indivisible shapes.
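The two paths should produce the same parameters, which a pure-Python simulation can make concrete. The sketch below models collectives as operations on lists indexed by rank, not real torch.distributed calls. It assumes gradients are averaged across ranks (the text does not say whether the reduction is a sum or a mean), and all function names are illustrative.

```python
import math


def adam_update(p, g, m, v, step, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam step over parallel lists; returns new (p, m, v)."""
    new_p, new_m, new_v = [], [], []
    for pi, gi, mi, vi in zip(p, g, m, v):
        mi = b1 * mi + (1 - b1) * gi
        vi = b2 * vi + (1 - b2) * gi * gi
        m_hat = mi / (1 - b1 ** step)
        v_hat = vi / (1 - b2 ** step)
        new_p.append(pi - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(mi)
        new_v.append(vi)
    return new_p, new_m, new_v


def average(grads_by_rank):
    """Elementwise mean across ranks (assumed reduction)."""
    world = len(grads_by_rank)
    return [sum(col) / world for col in zip(*grads_by_rank)]


def reduce_scatter_avg(grads_by_rank):
    """Simulated reduce-scatter: rank r receives the r-th averaged slice."""
    world = len(grads_by_rank)
    avg = average(grads_by_rank)
    shard = len(avg) // world
    return [avg[r * shard:(r + 1) * shard] for r in range(world)]


def all_gather(shards):
    """Simulated all-gather: concatenate the per-rank shards."""
    return [x for shard in shards for x in shard]


def sharded_step(param, grads_by_rank, m_shards, v_shards, step):
    """reduce_scatter -> local Adam on each slice -> all_gather."""
    world = len(grads_by_rank)
    shard = len(param) // world
    grad_slices = reduce_scatter_avg(grads_by_rank)
    new_shards = []
    for r in range(world):
        p_slice = param[r * shard:(r + 1) * shard]
        p_new, m_shards[r], v_shards[r] = adam_update(
            p_slice, grad_slices[r], m_shards[r], v_shards[r], step
        )
        new_shards.append(p_new)
    return all_gather(new_shards)


def replicated_step(param, grads_by_rank, m, v, step):
    """all_reduce -> Adam on the full tensor with replicated state."""
    return adam_update(param, average(grads_by_rank), m, v, step)
```

Because both paths apply the same elementwise arithmetic in the same order, the simulated results match exactly; the choice between them is about communication and bookkeeping cost, not numerics.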
Why keep both patterns? Because reduce_scatter + all_gather only
wins when the parameter is large enough for the bandwidth savings
to dominate kernel-launch and state-management overhead. On our
geometry the crossover is around 1024 elements. Below that, the
per-param kernel launches and the bookkeeping for a sharded
optimizer state cost more than the full-tensor all-reduce. The
is_small list in the optimizer step code is explicit about which
param took which path; we trace it on step 0 and confirm on
every regression that the split did not drift.
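The path selection described above reduces to a small predicate. The sketch below is an assumption-laden illustration: the 1024-element crossover is the approximate figure quoted for this geometry, the treatment of exactly 1024 elements as "large" is a guess, and uses_all_reduce and build_is_small are hypothetical names.

```python
from math import prod

# Approximate crossover quoted in the text; it depends on the geometry.
SMALL_PARAM_NUMEL = 1024


def uses_all_reduce(shape, world_size, small_numel=SMALL_PARAM_NUMEL):
    """True when a param should take the replicated all_reduce path."""
    if len(shape) == 0:
        return True  # a 0-D param has no leading dim to shard
    indivisible = shape[0] % world_size != 0
    small = prod(shape) < small_numel
    return indivisible or small


def build_is_small(shapes, world_size):
    """Per-param path decisions, in parameter order."""
    return [uses_all_reduce(shape, world_size) for shape in shapes]
```

Recording this list on step 0 and comparing it against a stored copy is one way to notice when a refactor silently moves a parameter from one path to the other.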
On DTensor-like params, both paths are short-circuited: the DTensor
runtime handles the collective internally, so DistAdamW skips
its own and consumes the already-reduced local shard. This was not
obvious when we first integrated DTensor TP — our early
implementation double-reduced gradients for DTensor params and
produced exactly the kind of "loss drifts at step 3000" symptom
the stress harness now catches. The current code gates on
_is_dtensor_like_param(p) and takes the local shard directly.
Grad-none symmetry: the quiet correctness property
The single hardest correctness property to preserve is that
collectives run in identical order on all ranks, even when local
grad presence differs. If rank 0 skips a reduce_scatter because
its local grad is None while rank 1 runs it, NCCL either hangs
or — worse — silently consumes the next unrelated collective and
produces garbage.
The fix in both DistAdamW and DistMuon is the same shape: build
a rank-symmetric "has grad" mask with a leading all_reduce(MAX)
at the top of the step, and then always run the same collectives
on every rank, substituting a zero-filled tensor for missing
local grads. This is slightly more work than strictly necessary on
the "everyone has a grad" case, but the cost is a single
all_reduce of an int32 tensor the size of the param list. We
measured it and stopped worrying about it.
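A minimal simulation of the mask, with ranks modelled as list entries instead of processes, looks roughly like this. The helper names are hypothetical and the "collectives" are plain Python, so this shows only the ordering logic, not the real DistAdamW or DistMuon code.

```python
def all_reduce_max(flags_by_rank):
    """Simulated all_reduce(MAX): every rank receives the elementwise max."""
    reduced = [max(column) for column in zip(*flags_by_rank)]
    return [list(reduced) for _ in flags_by_rank]


def plan_collectives(grads_by_rank, numels):
    """Return, per rank, the ordered (param_index, grad) collectives to run.

    grads_by_rank[r][i] is rank r's grad for param i (a list) or None.
    numels[i] is the element count of param i, used for zero-filling.
    """
    local_flags = [
        [0 if grad is None else 1 for grad in grads] for grads in grads_by_rank
    ]
    masks = all_reduce_max(local_flags)

    plans = []
    for grads, mask in zip(grads_by_rank, masks):
        plan = []
        for index, (grad, any_rank_has_grad) in enumerate(zip(grads, mask)):
            if not any_rank_has_grad:
                continue  # no rank has a grad, so every rank skips together
            # A rank without a local grad still participates, with zeros.
            plan.append((index, grad if grad is not None else [0.0] * numels[index]))
        plans.append(plan)
    return plans
```

The property that matters is that the sequence of parameter indices is identical on every rank, whatever the local grad presence was.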
The stress harness's reason for existing is to verify exactly this: with the mask in place, 1000 steps of rank-asymmetric grad presence produce zero parameter divergence. Without the mask, the same harness used to deadlock in ~50 steps.
Muon-specific gotchas
Muon is a fundamentally different optimizer from Adam. It runs SGD-momentum followed by a Newton-Schulz iteration to orthogonalise the update matrix before applying it. The math is well-behaved on paper. The distributed implementation has sharp edges that do not show up in any of the Adam literature.
1. Two-dimensional only
DistMuon asserts all(p.ndim == 2 for p in params) at construction
time. This is not a performance optimization — it is a correctness
requirement. Orthogonalising a rank-3 tensor as if it were a matrix
silently mixes subspaces that should stay independent. The most
common offender in our stack was fused QKV. A tensor of shape
[3*d, d] is 2D, but orthogonalising it as a single matrix mixes
the query, key, and value subspaces through Newton-Schulz in a way
that regresses depth-52 runs by a visible margin on the loss curve.
The fix is explicit qkv_split_sizes metadata on the param group.
Muon respects the split and orthogonalises each sub-matrix
independently. Losing that metadata during a refactor was the most
recent regression we caught; it showed up immediately on the
depth-52 Muon sanity run, not on the stress harness, because
grad-none symmetry has nothing to do with the orthogonalization
path. That is itself an argument for keeping both kinds of
regression guard in place.
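The difference between orthogonalising a fused QKV matrix as a whole and per sub-matrix can be seen in a small pure-Python sketch. It uses the simple cubic Newton-Schulz iteration for clarity, which is an assumption: the production kernel (polar_express in the text) is not shown here and may use different coefficients. orthogonalize_split and gram_error are hypothetical helpers; only the qkv_split_sizes idea comes from the text.

```python
def matmul(a, b):
    """Dense matrix product on lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def newton_schulz(mat, iters=30):
    """Cubic Newton-Schulz: X <- 1.5 X - 0.5 X (X^T X).

    The input is scaled by its Frobenius norm so every singular value
    starts in (0, 1], inside the iteration's convergence region.
    """
    norm = sum(v * v for row in mat for v in row) ** 0.5
    x = [[v / norm for v in row] for row in mat]
    for _ in range(iters):
        x_xtx = matmul(x, matmul(transpose(x), x))
        x = [
            [1.5 * a - 0.5 * b for a, b in zip(row_x, row_c)]
            for row_x, row_c in zip(x, x_xtx)
        ]
    return x


def orthogonalize_split(mat, split_sizes):
    """Orthogonalise each row block independently, then restack."""
    assert sum(split_sizes) == len(mat)
    out, start = [], 0
    for size in split_sizes:
        out.extend(newton_schulz(mat[start:start + size]))
        start += size
    return out


def gram_error(block):
    """Max |B B^T - I| entry; near zero when the rows are orthonormal."""
    gram = matmul(block, transpose(block))
    n = len(gram)
    return max(
        abs(gram[i][j] - (1.0 if i == j else 0.0))
        for i in range(n)
        for j in range(n)
    )
```

With d = 2 and split sizes [2, 2, 2], each block of the split result is orthogonal on its own. Orthogonalising the fused 6x2 matrix instead gives orthonormal columns for the whole stack, so the Q, K, and V blocks share that budget and none of them is orthogonal individually.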
2. The cautious-update gate
Muon's fused step is momentum -> polar_express -> variance_reduction -> cautious_update. The cautious update masks out components of
the orthogonalised update whose sign disagrees with the raw
gradient's sign. Earlier in the rollout we reintroduced a raw-grad
gate that looked equivalent but used the pre-momentum gradient for
the sign check. On depth-52 runs this regressed Muon back to
pre-cautious-update behaviour. The current code keeps the gate on
the post-variance-reduction update, not on the raw grad, and the
regression receipt is preserved as a test case.
3. The reduce-scatter + all-gather dance
DistMuon uses the same two-pass collective pattern as DistAdamW
but on stacked parameter groups. All params of the same shape
within a group are stacked along a new leading axis, and the
optimizer runs one fused kernel on the whole stack. The benefit
is a single Newton-Schulz launch per group instead of one per
param, which dominated step time on deep models.
The hazard is that the chunk math must be identical on every rank.
chunk_size = (len(group_params) + world_size - 1) // world_size
is computed at init time and burned in. Adding a parameter to a
group after init (for example, when an adapter gets enabled
mid-run) requires add_param_group to recompute chunk_size.
We had exactly one bug where the chunk size was stale and rank 1
operated on a different slice of the stacked tensor than rank 0.
It produced a single non-zero value in the stress harness at
step 100 and we caught it before it hit real training.
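The chunk arithmetic, and what a stale value does to it, can be sketched as follows. The chunk_size formula is the one quoted above; rank_slice and its optional size override are hypothetical, added only to show the stale case.

```python
def chunk_size(num_params, world_size):
    """Ceil division: stacked params owned by each rank."""
    return (num_params + world_size - 1) // world_size


def rank_slice(rank, num_params, world_size, size=None):
    """Half-open [start, stop) range of stacked params owned by `rank`.

    Passing `size` simulates a chunk size that was computed at init
    and not refreshed after the group grew.
    """
    if size is None:
        size = chunk_size(num_params, world_size)
    start = min(rank * size, num_params)
    return start, min(start + size, num_params)
```

For example, a group of 4 params on 2 ranks has a chunk size of 2. After a fifth param is added, the fresh chunk size is 3 and rank 1 owns (3, 5); with the stale size of 2, rank 1 would instead work on (2, 4), leaving the last param unowned and disagreeing with any rank that did recompute.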
4. XLA runtime scalars
The prepare_xla_step_scalars method on both DistMuon and DistAdamW
exists because naive float Python scalars force XLA recompilation
every time the LR schedule advances. We cache materialised 0-D
tensors for lr, momentum, weight_decay, beta2, and
lr_multiplier on the XLA device, and rewrite them in place
between steps. This is not numerically interesting — the values
are the same either way — but it is the difference between a
compile cache hit and a full recompile on every LR-schedule
transition. On TPU v6e-8 that matters.
5. FSDP2-native Muon vs DistMuon
The POC actually ships three Muon variants: single-device Muon,
ZeRO-2-style DistMuon, and an FSDP2Muon that consumes
FSDP2-sharded DTensor parameters directly. The last one is
mathematically equivalent to DistMuon: each rank owns a chunk
of the stacked parameters and orthogonalises its slice. The
reason it exists as a separate class is that FSDP2 hands
parameters to the optimizer as sharded DTensors whose local
shards do not match the naive "split along leading dim" layout
DistMuon assumes. The adapter layer is
_match_grad_to_local_shard, which maps the DTensor grad onto
the local shard shape FSDP2Muon consumes.
Keeping three Muon implementations is tech debt, and we are
aware of it. Collapsing them into one would require either
making DistMuon DTensor-native (costly rewrite) or abandoning
the non-FSDP2 path (premature, because TPU XLA SPMD does not
use FSDP2). For now the stress harness runs both DistMuon and
FSDP2Muon through the same grad-none asymmetry scenarios and
confirms bit-identical behaviour across them.
What the TPU v6e-8 receipt actually covers
The most recent checked-in receipt is from the CPU/gloo leg on the
v6e-8 host, not the XLA leg. That is a deliberate scope
restriction: the stress harness currently supports gloo and
nccl backends, not XLA. The CPU/gloo run exercises the exact
same Python optimizer code (DistAdamW, DistMuon,
_group_muon_params, prepare_xla_step_scalars paths) on the
same host, just with collectives running over gloo instead of
over XLA.
This matters because the distributed logic and the grad-none mask
are backend-independent. If they pass on gloo, they pass on any
backend that correctly implements the same collective semantics. The XLA
distributed optimizer path is exercised indirectly through real
multi-chip training runs via the base trainer with
--tensor_parallel, where divergence would show up as a loss
curve regression.
Extending the stress harness to XLA is a tracked follow-up; the
main blocker is that the harness expects a torch.distributed
process group, and the XLA SPMD runtime does not expose one at
the same API level. We have not yet decided whether the right
fix is to wrap XLA's collective primitives or to run the stress
harness against a synthetic gloo surrogate that replays XLA
SPMD grad layouts.
The cheap lesson
The single cheapest thing we did on the distributed optimizer
rollout was adding the rank-symmetric all_reduce(MAX) mask at
the top of every step. It is five lines of code. It removed the
entire class of "deadlock on conditional compute" failure. The
single most expensive thing we did was chasing a depth-52
regression caused by fused QKV being treated as a flat matrix
rather than three sub-matrices. It took a week.
The pattern is consistent: the bugs that hurt are the ones where distributed correctness quietly degrades the optimizer math, not the ones where the collective hangs. Hangs are loud. Drift is silent. The harness above covers the handful of silent-drift sources we have actually hit.
The MegaCpp nanochat POC, by David Gornshtein and Boris Tamarkin,
treats the stress harness as a gate. If max_adam_abs_diff or any
of the Muon diffs drifts above zero on the checked-in scenarios,
the build does not ship. That is the kind of contract that is
easy to write down, occasionally annoying to satisfy, and worth
every minute the first time it catches a real regression.
References
- dist_optimizer_stress_tpu_v6e_2026-03-22.md
- TENSOR_PARALLELISM.md
- BACKEND_STOPLIGHT_MATRIX.md
- CURRENT_STATE.md
- tp_sp_ep_fsdp_h200_bringup_2026-04-07.md
- fa4_fsdp2_scaling_2026-03-22.json
- 11-adaptive-sharding-auto-fit.md
- TRAINING_PLAN.md
- training_plan_en.md
- training_review.md
- CHANGELOG.md