Distributed Optimizer Stress: Drift, All-Gather vs Reduce-Scatter, and Muon Gotchas

Optimizer bugs in a distributed trainer are the worst class of bug to own. They
do not fail fast. They do not raise. They compound silently over thousands of
steps and then surface as "the loss curve looks weird around step 3000" on a
long run. For the MegaCpp trainer we run a hybrid DistAdamW + DistMuon
optimizer across CUDA and XLA backends, and we spent real time building a
stress harness whose entire job is to catch drift before it compounds.
The nearest adjacent reads are NCCL and collective hangs, MoE routing we actually shipped, and FSDP on CUDA and Megatron DDP, because the same collective-order mistakes can show up as hangs, routing drift, or optimizer divergence depending on where they surface.
Why a dedicated stress harness
The unit tests pass. The integration tests pass. The training runs also look fine for a while. So what does the stress harness do that those do not?
It exercises the one failure mode that regular training never reliably
produces: rank-asymmetric grad presence. In real training, every rank
computes a forward and backward for every parameter on every step, so every
p.grad is non-None on every rank simultaneously. The distributed optimizer
code path that handles {rank 0: grad, rank 1: None} almost never runs in
practice. But it has to exist, because conditional compute can legitimately let
one rank see grad=None while another rank sees a real gradient.
The stress harness runs a 1000-step schedule with --sample-every=100,
explicitly cycling through:
- {0: 1.0, 1: None}: rank 0 has grad, rank 1 does not
- {0: None, 1: -0.5}: the opposite
- {0: 0.25, 1: 0.75}: both have grads
- {0: None, 1: None}: both missing
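The cycle above can be written down as a small schedule generator. This is a minimal sketch, not the harness itself: `GRAD_PATTERNS`, `grad_schedule`, and `sample_steps` are hypothetical names, and sampling on exact multiples of the interval is an assumption. Only the four patterns, the 1000-step length, and the interval of 100 come from the text.

```python
from itertools import cycle, islice

# The four rank -> grad-value patterns listed above (None means "no grad").
GRAD_PATTERNS = [
    {0: 1.0, 1: None},
    {0: None, 1: -0.5},
    {0: 0.25, 1: 0.75},
    {0: None, 1: None},
]


def grad_schedule(num_steps):
    """Yield (step, pattern) pairs, cycling through the patterns in order."""
    return enumerate(islice(cycle(GRAD_PATTERNS), num_steps), start=1)


def sample_steps(num_steps, sample_every):
    """Steps at which parameters would be compared across ranks (assumed)."""
    return [s for s in range(1, num_steps + 1) if s % sample_every == 0]


schedule = list(grad_schedule(1000))
samples = sample_steps(1000, 100)
```

With four patterns and 1000 steps, each pattern is exercised 250 times, so every asymmetric case is hit repeatedly between sampled checkpoints.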
Across both scenarios (the pure DistAdamW + DistMuon case and the
megatron_conditional case that toggles conditional features with
rank-asymmetric patterns), the latest TPU v6e-8 CPU/gloo run reports
max_adam_abs_diff = 0.0, max_mu_a_abs_diff = 0.0,
max_mu_b_abs_diff = 0.0, and max_abs_param_diff = 0.0 at every sampled
checkpoint. That is not "within tolerance". That is bit-identical parameter
values on both ranks after 1000 steps.
Tolerances still exist: the harness thresholds are <= 2e-5 for Adam and
<= 3e-4 for Muon. We keep them because future numerics changes can
legitimately drop us into "within tolerance but not zero". The current state is
zero, and the day it stops being zero we want to find out on a stress run, not
on step 3000 of a real training run.
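The distinction between "zero" and "within tolerance" can be made explicit in the check itself. The sketch below is illustrative only: `max_abs_diff` and `classify_drift` are hypothetical helpers, and the only values taken from the text are the two thresholds.

```python
ADAM_TOL = 2e-5  # threshold quoted above for Adam
MUON_TOL = 3e-4  # threshold quoted above for Muon


def max_abs_diff(a, b):
    """Largest elementwise absolute difference between two flat sequences."""
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


def classify_drift(diff, tol):
    """Separate exact agreement from agreement that merely passes the threshold."""
    if diff == 0.0:
        return "bit-identical"
    return "within-tolerance" if diff <= tol else "drift"
```

Reporting three states instead of a pass/fail boolean is the design choice that matters here: a move from "bit-identical" to "within-tolerance" is still a passing run, but it is the signal worth investigating.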
The all-gather vs reduce-scatter decision
DistAdamW is a ZeRO-2-style sharded optimizer: each rank owns a slice of the
optimizer state (exp_avg, exp_avg_sq) for each large parameter, reduces
grads into its own slice, updates the slice, and all-gathers the updated
parameter so every rank sees the full tensor.
For a parameter p with first dim divisible by world size, the pattern is:
- Reduce grads onto their owner ranks via reduce_scatter_tensor on a flattened view, producing one grad_slice per rank.
- Each rank runs Adam locally on its slice, updating its shard of exp_avg and exp_avg_sq and the parameter shard itself.
- all_gather_into_tensor rebuilds the full parameter into a shared buffer, which is sliced back into p.
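The three steps above can be simulated in a single process to see why the sharded path should match a replicated update. This is a toy sketch under stated assumptions: plain Python lists stand in for tensors, the reduction is a SUM (a real setup may average), optimizer state starts at zero, and only one step is taken. None of the function names below come from the trainer.

```python
import math


def adam_step(p, g, m, v, step, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One elementwise Adam update on plain lists; returns new (p, m, v)."""
    m = [b1 * mi + (1 - b1) * gi for mi, gi in zip(m, g)]
    v = [b2 * vi + (1 - b2) * gi * gi for vi, gi in zip(v, g)]
    mhat = [mi / (1 - b1 ** step) for mi in m]
    vhat = [vi / (1 - b2 ** step) for vi in v]
    p = [pi - lr * mh / (math.sqrt(vh) + eps) for pi, mh, vh in zip(p, mhat, vhat)]
    return p, m, v


def reduce_scatter(per_rank_grads, world_size):
    """Simulated reduce_scatter: rank r receives the summed r-th chunk."""
    n = len(per_rank_grads[0])
    chunk = n // world_size
    summed = [sum(g[i] for g in per_rank_grads) for i in range(n)]
    return [summed[r * chunk:(r + 1) * chunk] for r in range(world_size)]


def sharded_step(param, per_rank_grads, world_size):
    """Reduce-scatter grads, update each shard locally, all-gather the result."""
    chunk = len(param) // world_size
    grad_slices = reduce_scatter(per_rank_grads, world_size)
    new_param = []
    for r in range(world_size):
        shard = param[r * chunk:(r + 1) * chunk]
        zeros = [0.0] * chunk  # fresh exp_avg / exp_avg_sq for this shard
        shard, _, _ = adam_step(shard, grad_slices[r], zeros, zeros, step=1)
        new_param.extend(shard)  # simulated all_gather: concatenate shards
    return new_param


def replicated_step(param, per_rank_grads):
    """All-reduce the full gradient, then update replicated state."""
    n = len(param)
    full_grad = [sum(g[i] for g in per_rank_grads) for i in range(n)]
    zeros = [0.0] * n
    new_param, _, _ = adam_step(param, full_grad, zeros, zeros, step=1)
    return new_param


param = [0.5, -1.0, 2.0, 0.25]
grads = [[0.1, 0.2, -0.3, 0.4], [0.0, -0.1, 0.2, 0.1]]
assert sharded_step(param, grads, 2) == replicated_step(param, grads)
```

Because Adam is elementwise, sharding the state changes who does the arithmetic but not the arithmetic itself, which is why exact equality is a reasonable expectation in this toy.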
The alternative, all_reduce on the full gradient followed by a local update
on replicated optimizer state, is what DistAdamW falls back to for small
params or params whose first dim does not divide the world size. The shape
check is explicit: shape[0] % world_size != 0 forces the all_reduce path to
avoid a reduce_scatter crash on indivisible shapes.
Why keep both patterns? Because reduce_scatter + all_gather only wins when
the parameter is large enough for the bandwidth savings to dominate
kernel-launch and state-management overhead. On our geometry the crossover is
around 1024 elements. Below that, the per-param kernel launches and the
bookkeeping for a sharded optimizer state cost more than the full-tensor
all-reduce.
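A minimal sketch of the path selection described above. `choose_collective_path` and `MIN_SHARDED_NUMEL` are hypothetical names. Treating exactly 1024 elements as "large enough" is an assumption, since the text only says the crossover is around that size.

```python
from math import prod

MIN_SHARDED_NUMEL = 1024  # measured crossover from the text; geometry-specific


def choose_collective_path(shape, world_size, min_numel=MIN_SHARDED_NUMEL):
    """Pick the collective pattern for one parameter from its shape."""
    # Scalars and indivisible leading dims cannot be reduce-scattered evenly.
    if len(shape) == 0 or shape[0] % world_size != 0:
        return "all_reduce"
    # Small params: sharding overhead outweighs the bandwidth savings.
    if prod(shape) < min_numel:
        return "all_reduce"
    return "reduce_scatter+all_gather"
```

For example, `choose_collective_path((4096, 1024), 8)` selects the sharded path, while `(4097, 1024)` and `(16, 8)` both fall back to all-reduce, the first on divisibility and the second on size.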
On DTensor-like params, both paths are short-circuited: the runtime handles the
collective internally, so DistAdamW skips its own and consumes the already
reduced local shard. That was not obvious in our first DTensor TP integration;
the early version double-reduced gradients for DTensor params and produced
exactly the kind of "loss drifts at step 3000" symptom the stress harness now
catches.
Treat that ~1024-element figure as a measured crossover, not as optimizer
folklore. Transport, launch overhead, and padding policy move it, which is why
we keep it as a harness receipt rather than as a universal rule.
Grad-none symmetry: the quiet correctness property
The single hardest correctness property to preserve is that collectives run in
identical order on all ranks, even when local grad presence differs. If rank 0
skips a reduce_scatter because its local grad is None while rank 1 runs it,
NCCL either hangs or silently consumes the next unrelated collective and
produces garbage.
The fix in both DistAdamW and DistMuon is the same shape: build a
rank-symmetric "has grad" mask with a leading all_reduce(MAX) at the top of
the step, and then always run the same collectives on every rank, substituting
a zero-filled tensor for missing local grads. This is slightly more work than
strictly necessary on the "everyone has a grad" case, but the cost is a single
all_reduce of an int32 tensor the size of the param list. We measured it
and stopped worrying about it.
The stress harness exists to verify exactly this: with the mask in place, 1000 steps of rank-asymmetric grad presence produce zero parameter divergence. Without the mask, the same harness used to deadlock in about fifty steps.
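The mask logic can be illustrated with a single-process simulation. This is a sketch under assumptions, not the DistAdamW or DistMuon code: `all_reduce_max` imitates the collective with plain lists, scalars stand in for gradient tensors, and skipping a parameter that has no grad on any rank is an assumption the text does not confirm.

```python
def all_reduce_max(per_rank_flags):
    """Simulated all_reduce(MAX): every rank receives the elementwise max."""
    reduced = [max(col) for col in zip(*per_rank_flags)]
    return [list(reduced) for _ in per_rank_flags]


def plan_step(per_rank_grads):
    """Return, per rank, the (param_index, value) collectives it will issue.

    per_rank_grads[r][i] is rank r's gradient for param i, or None.
    """
    flags = [[0 if g is None else 1 for g in grads] for grads in per_rank_grads]
    masks = all_reduce_max(flags)  # identical on every rank after the reduce
    plans = []
    for grads, mask in zip(per_rank_grads, masks):
        plan = []
        for i, (g, has_grad_anywhere) in enumerate(zip(grads, mask)):
            if not has_grad_anywhere:
                continue  # skipped on every rank, so the order stays symmetric
            plan.append((i, 0.0 if g is None else g))  # zero replaces a missing grad
        plans.append(plan)
    return plans


plans = plan_step([[1.0, None, None], [None, None, -0.5]])
```

Both ranks end up issuing collectives for params 0 and 2 in the same order, even though each rank holds a real gradient for only one of them. A naive "skip if my grad is None" rule would have rank 0 issue only param 0 and rank 1 only param 2, which is the mismatch that hangs or cross-wires collectives.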
Muon-specific gotchas
Muon is a fundamentally different optimizer from Adam. It runs SGD momentum followed by a Newton-Schulz iteration to orthogonalize the update matrix before applying it. The math is well-behaved on paper. The distributed implementation has sharp edges that do not show up in any of the Adam literature.
1. Two-dimensional only
DistMuon asserts all(p.ndim == 2 for p in params) at construction time.
This is not a performance optimization. It is a correctness requirement.
Orthogonalizing a rank-3 tensor as if it were a matrix silently mixes subspaces
that should stay independent. The most common offender in our stack was fused
QKV. A tensor of shape [3*d, d] is 2D, but orthogonalizing it as a single
matrix mixes the query, key, and value subspaces through Newton-Schulz in a way
that regresses deep runs by a visible margin.
The fix is explicit qkv_split_sizes metadata on the param group. Muon
respects the split and orthogonalizes each sub-matrix independently.
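The difference between orthogonalizing a fused matrix and orthogonalizing its sub-matrices shows up even on a tiny example. The sketch below uses the textbook cubic Newton-Schulz iteration on plain Python lists, which is a stand-in for whatever iteration the trainer actually runs. `orthogonalize_split` and `gram_error` are hypothetical helpers, and how the real optimizer consumes qkv_split_sizes is not shown in the text.

```python
def matmul(a, b):
    """Dense matrix product on lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def newton_schulz(m, steps=30):
    """Cubic Newton-Schulz: X <- 1.5 X - 0.5 X X^T X, after Frobenius scaling."""
    norm = sum(v * v for row in m for v in row) ** 0.5
    x = [[v / norm for v in row] for row in m]
    for _ in range(steps):
        xxtx = matmul(matmul(x, transpose(x)), x)
        x = [[1.5 * a - 0.5 * b for a, b in zip(ra, rb)] for ra, rb in zip(x, xxtx)]
    return x


def orthogonalize_split(w, split_sizes):
    """Orthogonalize each row-block of w independently."""
    assert sum(split_sizes) == len(w)
    out, start = [], 0
    for size in split_sizes:
        out.extend(newton_schulz(w[start:start + size]))
        start += size
    return out


def gram_error(block):
    """Max |(B^T B - I)_ij|: zero when the block's columns are orthonormal."""
    gram = matmul(transpose(block), block)
    n = len(gram)
    return max(abs(gram[i][j] - (1.0 if i == j else 0.0))
               for i in range(n) for j in range(n))


# A fused [3*d, d] matrix with d = 2: rows 0-1 are Q, 2-3 are K, 4-5 are V.
w = [[2.0, 1.0], [0.0, 1.0], [1.0, -1.0], [1.0, 2.0], [3.0, 0.0], [1.0, 1.0]]
split = orthogonalize_split(w, [2, 2, 2])
fused = newton_schulz(w)
```

With the split, `gram_error` is near zero for each of the three blocks. The fused result is orthonormal only as a whole: `gram_error(fused)` is near zero, but `gram_error(fused[0:2])` is not, because the three blocks have to share a single identity between them.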
2. The cautious-update gate
Muon's fused step is momentum -> polar_express -> variance_reduction -> cautious_update. The easy mistake is to import the generic
"cautious optimizer" story too literally and gate the final update against the
raw gradient sign. That sounds reasonable on paper and was unstable in practice
once the update had already been orthogonalized and variance-reduced.
The stable contract in our receipts is empirical: the conservative gate is parameter-sign agreement rather than raw-gradient agreement. Reintroducing the raw-grad gate pushed the deep Muon lane back toward immediate NaNs, so we keep the parameter-sign contract until a better large-scale receipt proves otherwise.
3. The reduce-scatter + all-gather dance
DistMuon uses the same two-pass collective pattern as DistAdamW but on
stacked parameter groups. All params of the same shape within a group are
stacked along a new leading axis, and the optimizer runs one fused kernel on
the whole stack. The benefit is a single Newton-Schulz launch per group instead
of one per param.
The hazard is that the chunk math must be identical on every rank.
chunk_size = (len(group_params) + world_size - 1) // world_size is computed
at init time and burned in. Adding a parameter to a group after init requires
add_param_group to recompute chunk_size. We hit exactly one bug where the
chunk size was stale and rank 1 operated on a different slice of the stacked
tensor than rank 0.
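The stale-chunk hazard is easy to reproduce with the formula above. In this sketch `rank_slice` is a hypothetical helper, and the optional `size` argument exists only to model a rank that is still using a chunk size computed before the group grew.

```python
def chunk_size(num_params, world_size):
    """Ceil-divide the stacked params across ranks, as in the formula above."""
    return (num_params + world_size - 1) // world_size


def rank_slice(rank, num_params, world_size, size=None):
    """Half-open [start, stop) range of stacked-param indices owned by `rank`."""
    size = chunk_size(num_params, world_size) if size is None else size
    start = min(rank * size, num_params)
    return start, min(start + size, num_params)


# Group of 5 params on 2 ranks: chunk size 3.
before = [rank_slice(r, 5, 2) for r in range(2)]

# Group grows to 7 params: the correct chunk size is now 4.
fresh = [rank_slice(r, 7, 2) for r in range(2)]

# Rank 1 still uses the stale chunk size of 3 while rank 0 recomputed.
mixed = [rank_slice(0, 7, 2), rank_slice(1, 7, 2, size=3)]
```

In the `mixed` case rank 0 owns indices 0 through 3 and rank 1 owns 3 through 5, so index 3 is updated twice and index 6 is never updated. Nothing raises; the ranks simply disagree about the layout.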
4. XLA runtime scalars
The prepare_xla_step_scalars method on both DistMuon and DistAdamW exists
because naive Python floats force XLA recompilation every time the learning-rate
schedule advances. We cache materialized 0-D tensors for lr, momentum,
weight_decay, beta2, and lr_multiplier on the XLA device, and rewrite
them in place between steps. This is not numerically interesting. It is the
difference between a compile-cache hit and a full recompile on every
LR-schedule transition.
The failure mode here is compile churn, not bad math. A host read like
loss.item() or an accidental Python float inside the step can force a sync
and invalidate the cached XLA program just to materialize a scalar that could
have waited until after the step.
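A toy model makes the compile-churn argument concrete. `ToyTraceCache` is not XLA and does not reflect its real caching internals. It only models the claim above: a value passed as a Python constant becomes part of the program's identity, while a value read from a buffer does not. All names here are illustrative.

```python
class ToyTraceCache:
    """Toy stand-in for a trace-and-compile cache; not XLA's real mechanism."""

    def __init__(self):
        self.compiles = 0
        self._seen = set()

    def run(self, baked, inputs):
        # Python-level constants are baked into the cache key.
        key = tuple(sorted(baked.items()))
        if key not in self._seen:
            self.compiles += 1
            self._seen.add(key)
        lr = baked["lr"] if "lr" in baked else inputs["lr"]
        return inputs["param"] - lr * inputs["grad"]


def step_with_python_float(cache, param, grad, lr):
    # lr is a Python float, so every new value looks like a new program.
    return cache.run({"lr": lr}, {"param": param, "grad": grad})


def step_with_cached_scalar(cache, param, grad, lr_buffer):
    # lr lives in a one-element buffer that is rewritten in place between steps.
    return cache.run({}, {"param": param, "grad": grad, "lr": lr_buffer[0]})


lrs = [0.1, 0.09, 0.08]

naive = ToyTraceCache()
naive_out = [step_with_python_float(naive, 1.0, 0.5, lr) for lr in lrs]

cached = ToyTraceCache()
lr_buffer = [0.0]
cached_out = []
for lr in lrs:
    lr_buffer[0] = lr  # in-place rewrite, same buffer every step
    cached_out.append(step_with_cached_scalar(cached, 1.0, 0.5, lr_buffer))
```

Both variants compute the same updates, but the first "compiles" once per distinct learning rate and the second compiles once in total.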
5. FSDP2-native Muon vs DistMuon
MegaCpp actually ships three Muon variants: single-device Muon,
ZeRO-2-style DistMuon, and FSDP2Muon, which consumes FSDP2-sharded
DTensor parameters directly. FSDP2Muon is mathematically equivalent to
DistMuon; the reason it exists as a separate class is that FSDP2 hands the
optimizer sharded DTensors whose local shards do not match the naive
"split along leading dim" layout that DistMuon assumes.
Keeping three Muon implementations is tech debt, and we know it. For now the
stress harness runs both DistMuon and FSDP2Muon through the same
grad-none-asymmetry scenarios and confirms bit-identical behavior across them.
What the TPU v6e-8 receipt actually covers
The most recent checked-in receipt is from the CPU/gloo leg on the v6e-8 host,
not the XLA leg. That is a deliberate scope restriction: the stress harness
currently supports gloo and nccl backends, not XLA. The CPU/gloo run
exercises the exact same Python optimizer code on the same host, just with
collectives running over gloo instead of over XLA.
This matters because the distributed logic and the grad-none mask are backend-independent. If they pass on gloo, they pass on any backend that correctly implements the collectives.
Extending the stress harness to XLA is a tracked follow-up. The main blocker is
that the harness expects a torch.distributed process group and the XLA SPMD
runtime does not expose one at the same API level.
The cheap lesson
The single cheapest thing we did on the distributed optimizer rollout was
adding the rank-symmetric all_reduce(MAX) mask at the top of every step. It
is a small patch. It removed an entire class of "deadlock on conditional
compute" failures. The single most expensive thing we did was chasing a deep
Muon regression caused by fused QKV being treated as one flat matrix rather
than three sub-matrices.
The pattern is consistent: the bugs that hurt are the ones where distributed correctness quietly degrades the optimizer math, not the ones where the collective hangs. Hangs are loud. Drift is silent.
Frequently asked questions
Why keep a gloo or CPU receipt on a TPU host?
Why does fixed-capacity routing not replace the grad-none stress harness?
What bug does the grad-none mask actually prevent?
Why does Muon need explicit qkv_split_sizes metadata?
Does GQA or MQA change the qkv_split_sizes rule?
Why is Muon's cautious gate keyed to parameter sign instead of raw gradient sign?
Why is FSDP2Muon still a separate class if the Muon math matches DistMuon?
FSDP2 hands the optimizer DTensor-backed local shards, and those local views do not always line up with the naive leading-dimension split that DistMuon assumes. The math stays the same, but the shard adapter work changes. That seam gets sharper on fused-QKV shards. The checked-in FSDP2 Muon local shard sample first recovers the rank's Shard(0) row bounds and then rescales the full-row qkv_split_sizes metadata down to the local row count before the step runs. Muon on Hopper and Blackwell is the adjacent read for why that rescaling matters even when the parameter is still 2D.
Why keep host scalar reads out of the XLA optimizer step?
Where is the checked-in proof for the grad-none and local-shard claims?
Terms used in this article
Start here for quick definitions, then follow the linked posts for deeper context.
DTensor: PyTorch's mesh-backed distributed-tensor abstraction: one logical tensor with explicit shard or replica metadata across ranks.
FSDP2: PyTorch's Fully Sharded Data Parallel v2 wrapper API. On CUDA it shards parameters, gradients, and optimizer state across the data-parallel group; in the TPU/XLA posts here it is usually a memory-goal analogy, not the actual eager wrapper mechanism.
TP: Tensor parallelism splits each linear's weights (QKV, O, MLP gate/up/down) across GPUs. On 8× H200 with TP=8 each GPU owns 1/8 of every matmul's columns or rows, so one big matmul becomes 8 smaller ones that all-reduce at the layer boundary. Cost: one all-reduce per attention and per MLP, which is heavy on bandwidth, so TP is usually bound to a single NVLink/NVSwitch island (1 node of up to 8 GPUs). Embeddings, layernorms, and optimizer state stay replicated across the TP GPUs. Use TP when a single layer's weights don't fit on one GPU, not to scale past one node.
XLA SPMD: The explicit TPU sharding mode where one compiled program carries placement rules instead of rank-local imperative code.
NCCL: NVIDIA's collective-communication library for all-reduce, all-gather, reduce-scatter, and point-to-point transport on CUDA multi-GPU lanes.
H200: NVIDIA's Hopper H200 GPU platform, typically discussed here as an 8-GPU training node with large HBM capacity and NVLink-connected ranks.
TPU: Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.
MoE routing: Token Choice vs Expert Choice, null-expert debugging, gating stability, and the production routing decisions behind the MegaCpp SLM Ensemble.
CUDA: NVIDIA's GPU programming stack: compiler, runtime, driver, libraries, and kernel toolchain used by CUDA training and inference lanes.