MegaCpp Engineering
Applied C++ model systems
Grounded engineering note from the MegaCpp stack
Published 8 min read · MegaCpp Engineering
Optimizer
Muon
AdamW
Distributed
Numerical Stability
TPU
H200

Distributed Optimizer Stress: Drift, All-Gather vs Reduce-Scatter, and Muon Gotchas

Optimizer bugs in a distributed trainer are the worst class of bug to own. They do not fail fast. They do not raise. They compound silently over thousands of steps and then surface as "the loss curve looks weird around step 3000" on a long run. For the MegaCpp trainer we run a hybrid DistAdamW + DistMuon optimizer across CUDA and XLA backends, and we spent real time building a stress harness whose entire job is to catch drift before it compounds.

The nearest adjacent reads are NCCL and collective hangs, MoE routing we actually shipped, and FSDP on CUDA and Megatron DDP, because the same collective-order mistakes can show up as hangs, routing drift, or optimizer divergence depending on where they surface.

Why a dedicated stress harness

The unit tests pass. The integration tests pass. The training runs also look fine for a while. So what does the stress harness do that those do not?

It exercises the one failure mode that regular training never reliably produces: rank-asymmetric grad presence. In real training, every rank computes a forward and backward for every parameter on every step, so every p.grad is non-None on every rank simultaneously. The distributed optimizer code path that handles {rank 0: grad, rank 1: None} almost never runs in practice. But it has to exist, because conditional compute can legitimately let one rank see grad=None while another rank sees a real gradient.

The stress harness runs a 1000-step schedule with --sample-every=100, explicitly cycling through:

  • {0: 1.0, 1: None}: rank 0 has grad, rank 1 does not
  • {0: None, 1: -0.5}: the opposite
  • {0: 0.25, 1: 0.75}: both have grads
  • {0: None, 1: None}: both missing

Across both scenarios (the pure DistAdamW + DistMuon case and the megatron_conditional case that toggles conditional features with rank-asymmetric patterns), the latest TPU v6e-8 CPU/gloo run reports max_adam_abs_diff = 0.0, max_mu_a_abs_diff = 0.0, max_mu_b_abs_diff = 0.0, and max_abs_param_diff = 0.0 at every sampled checkpoint. That is not "within tolerance". That is bit-identical parameter values on both ranks after 1000 steps.

Tolerances still exist: the harness thresholds are <= 2e-5 for Adam and <= 3e-4 for Muon. We keep them because future numerics changes can legitimately drop us into "within tolerance but not zero". The current state is zero, and the day it stops being zero we want to find out on a stress run, not on step 3000 of a real training run.
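As a rough illustration of the schedule and the threshold check described above, here is a minimal sketch. The function names, the 1-based step counting, and the mapping of the mu_a and mu_b metrics onto the Muon threshold are assumptions made for the example; only the four patterns, the step counts, and the two tolerances come from the text.

```python
from itertools import cycle, islice

# Per-rank gradient patterns from the list above (None = no local grad).
PATTERNS = [
    {0: 1.0, 1: None},
    {0: None, 1: -0.5},
    {0: 0.25, 1: 0.75},
    {0: None, 1: None},
]

# Thresholds quoted in the text.
ADAM_TOL = 2e-5
MUON_TOL = 3e-4


def schedule(steps=1000):
    """Yield (step, pattern) pairs, cycling through PATTERNS (hypothetical helper)."""
    return enumerate(islice(cycle(PATTERNS), steps), start=1)


def sampled_steps(steps=1000, sample_every=100):
    """Steps at which drift metrics would be recorded (1-based counting is assumed)."""
    return [s for s in range(1, steps + 1) if s % sample_every == 0]


def within_tolerance(receipt):
    """Check one sampled receipt against the Adam and Muon thresholds.

    Assumption: the mu_a / mu_b metrics are Muon state and share MUON_TOL.
    """
    return (
        receipt["max_adam_abs_diff"] <= ADAM_TOL
        and receipt["max_mu_a_abs_diff"] <= MUON_TOL
        and receipt["max_mu_b_abs_diff"] <= MUON_TOL
    )
```

A receipt of all zeros passes trivially; the point of keeping the thresholds is that a future nonzero but small value would still pass while a real regression would not.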

The all-gather vs reduce-scatter decision

DistAdamW is a ZeRO-2-style sharded optimizer: each rank owns a slice of the optimizer state (exp_avg, exp_avg_sq) for each large parameter, reduces grads into its own slice, updates the slice, and all-gathers the updated parameter so every rank sees the full tensor.

For a parameter p with first dim divisible by world size, the pattern is:

  1. Reduce grads across ranks via reduce_scatter_tensor on a flattened view, so each rank receives the reduced grad_slice for the shard it owns.
  2. Each rank runs Adam locally on its slice, updating its shard of exp_avg and exp_avg_sq and the parameter shard itself.
  3. all_gather_into_tensor rebuilds the full parameter into a shared buffer, which is sliced back into p.
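The three steps above can be sketched as a single-process simulation in which each "rank" is just a Python list. This is not the DistAdamW code: the helper names are hypothetical, the reduction is assumed to be a mean, and the optimizer state starts from zeros on every call. The sketch only shows that the sharded path and a replicated all-reduce path arrive at the same parameter values.

```python
import math


def adam_update(p, g, m, v, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, t=1):
    """Elementwise Adam on plain lists; returns new (p, m, v)."""
    m = [b1 * mi + (1 - b1) * gi for mi, gi in zip(m, g)]
    v = [b2 * vi + (1 - b2) * gi * gi for vi, gi in zip(v, g)]
    mhat = [mi / (1 - b1 ** t) for mi in m]
    vhat = [vi / (1 - b2 ** t) for vi in v]
    p = [pi - lr * mh / (math.sqrt(vh) + eps) for pi, mh, vh in zip(p, mhat, vhat)]
    return p, m, v


def reduce_scatter_mean(grads_per_rank):
    """Simulated reduce_scatter: rank r receives the averaged r-th slice."""
    world = len(grads_per_rank)
    n = len(grads_per_rank[0])
    assert n % world == 0, "world size must divide the flattened length"
    size = n // world
    avg = [sum(g[i] for g in grads_per_rank) / world for i in range(n)]
    return [avg[r * size:(r + 1) * size] for r in range(world)]


def all_gather(shards):
    """Simulated all_gather: concatenate every rank's shard in rank order."""
    return [x for shard in shards for x in shard]


def sharded_step(param, grads_per_rank):
    world = len(grads_per_rank)
    size = len(param) // world
    grad_slices = reduce_scatter_mean(grads_per_rank)      # step 1
    new_shards = []
    for r in range(world):                                  # step 2
        p_slice = param[r * size:(r + 1) * size]
        zeros = [0.0] * size                                # fresh exp_avg / exp_avg_sq
        p_new, _, _ = adam_update(p_slice, grad_slices[r], zeros, zeros)
        new_shards.append(p_new)
    return all_gather(new_shards)                           # step 3


def replicated_step(param, grads_per_rank):
    """all_reduce-style path: average the full gradient, update replicated state."""
    world = len(grads_per_rank)
    avg = [sum(g[i] for g in grads_per_rank) / world for i in range(len(param))]
    zeros = [0.0] * len(param)
    p_new, _, _ = adam_update(param, avg, zeros, zeros)
    return p_new
```

Because Adam is elementwise, slicing before the update and concatenating afterwards does not change any individual value; the sharded path only changes how much optimizer state each rank has to hold.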

The alternative, all_reduce on the full gradient followed by a local update on replicated optimizer state, is what DistAdamW falls back to for small params or params whose first dim does not divide the world size. The shape check is explicit: shape[0] % world_size != 0 forces the all_reduce path to avoid a reduce_scatter crash on indivisible shapes.

Why keep both patterns? Because reduce_scatter + all_gather only wins when the parameter is large enough for the bandwidth savings to dominate kernel-launch and state-management overhead. On our geometry the crossover is around 1024 elements. Below that, the per-param kernel launches and the bookkeeping for a sharded optimizer state cost more than the full-tensor all-reduce.

On DTensor-like params, both paths are short-circuited: the runtime handles the collective internally, so DistAdamW skips its own and consumes the already reduced local shard. That was not obvious in our first DTensor TP integration; the early version double-reduced gradients for DTensor params and produced exactly the kind of "loss drifts at step 3000" symptom the stress harness now catches.

Treat that ~1024 elements number as a measured crossover, not as optimizer folklore. Transport, launch overhead, and padding policy move it, which is why we keep it as a harness receipt rather than as a universal rule.
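The path selection described in this section can be summarized in a small decision function. This is a sketch under stated assumptions: the function name, the returned labels, and the is_dtensor flag are hypothetical, and whether a parameter of exactly 1024 elements takes the sharded path is a guess. The divisibility rule and the approximate crossover are the only parts taken from the text.

```python
from math import prod

# Measured crossover quoted in the text; treat it as a tunable, not a constant.
MIN_SHARDED_NUMEL = 1024


def choose_collective_path(shape, world_size, is_dtensor=False,
                           min_sharded_numel=MIN_SHARDED_NUMEL):
    """Pick which collective pattern a parameter would use (hypothetical helper)."""
    if is_dtensor:
        # The runtime already reduced the local shard; running our own
        # collective here would double-reduce the gradient.
        return "runtime_handled"
    if len(shape) == 0 or shape[0] % world_size != 0:
        # reduce_scatter needs the first dim to divide evenly across ranks.
        return "all_reduce"
    if prod(shape) < min_sharded_numel:
        # Too small for the bandwidth savings to beat launch overhead.
        return "all_reduce"
    return "reduce_scatter_all_gather"
```

Keeping the threshold as a parameter reflects the point above: transport, launch overhead, and padding policy all move the crossover.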

Grad-none symmetry: the quiet correctness property

The single hardest correctness property to preserve is that collectives run in identical order on all ranks, even when local grad presence differs. If rank 0 skips a reduce_scatter because its local grad is None while rank 1 runs it, NCCL either hangs or silently consumes the next unrelated collective and produces garbage.

The fix in both DistAdamW and DistMuon is the same shape: build a rank-symmetric "has grad" mask with a leading all_reduce(MAX) at the top of the step, and then always run the same collectives on every rank, substituting a zero-filled tensor for missing local grads. This is slightly more work than strictly necessary in the "everyone has a grad" case, but the cost is a single all_reduce of an int32 tensor the size of the param list. We measured it and stopped worrying about it.

The stress harness exists to verify exactly this: with the mask in place, 1000 steps of rank-asymmetric grad presence produce zero parameter divergence. Without the mask, the same harness used to deadlock in about fifty steps.
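The effect of the mask can be shown with a single-process simulation that only tracks which collectives each rank would issue. The helper names are hypothetical, and the choice to skip a parameter when no rank has a grad is an assumption; the real optimizers may handle that case differently.

```python
def all_reduce_max(flags_per_rank):
    """Simulated all_reduce(MAX) over per-rank int flag lists."""
    reduced = [max(col) for col in zip(*flags_per_rank)]
    return [list(reduced) for _ in flags_per_rank]


def plan_collectives(local_grads_per_rank, use_mask=True):
    """Return, per rank, the (param_index, payload) collectives it would issue."""
    flags = [[int(g is not None) for g in grads] for grads in local_grads_per_rank]
    if use_mask:
        flags = all_reduce_max(flags)  # every rank now holds the same mask
    plans = []
    for rank_flags, grads in zip(flags, local_grads_per_rank):
        plan = []
        for idx, (has_grad, g) in enumerate(zip(rank_flags, grads)):
            if has_grad:
                # Substitute zeros when the grad is missing locally but
                # present on some other rank.
                plan.append((idx, 0.0 if g is None else g))
        plans.append(plan)
    return plans


def collective_order(plan):
    """Sequence of parameter indices a rank would communicate, in order."""
    return [idx for idx, _ in plan]


# Four params, one per pattern from the stress schedule; rows are ranks.
grads = [[1.0, None, 0.25, None],
         [None, -0.5, 0.75, None]]

masked = plan_collectives(grads)
unmasked = plan_collectives(grads, use_mask=False)
```

With the mask, both ranks issue collectives for params 0, 1, and 2 in the same order. Without it, rank 0 would issue [0, 2] and rank 1 would issue [1, 2], which on a real backend pairs up unrelated collectives or hangs.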

Muon-specific gotchas

Muon is a fundamentally different optimizer from Adam. It runs SGD momentum followed by a Newton-Schulz iteration to orthogonalize the update matrix before applying it. The math is well-behaved on paper. The distributed implementation has sharp edges that do not show up in any of the Adam literature.

1. Two-dimensional only

DistMuon asserts all(p.ndim == 2 for p in params) at construction time. This is not a performance optimization. It is a correctness requirement. Orthogonalizing a rank-3 tensor as if it were a matrix silently mixes subspaces that should stay independent. The most common offender in our stack was fused QKV. A tensor of shape [3*d, d] is 2D, but orthogonalizing it as a single matrix mixes the query, key, and value subspaces through Newton-Schulz in a way that regresses deep runs by a visible margin.

The fix is explicit qkv_split_sizes metadata on the param group. Muon respects the split and orthogonalizes each sub-matrix independently.
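To make the difference concrete, here is a toy comparison on a 6 x 2 fused weight in pure Python. It uses the classic cubic Newton-Schulz iteration as a stand-in; the tuned iteration inside Muon is not shown in this article, and all function names here are hypothetical. Only the idea of splitting by qkv_split_sizes before orthogonalizing comes from the text.

```python
def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def newton_schulz(w, iters=20):
    """Classic cubic Newton-Schulz: X <- 1.5 X - 0.5 X (X^T X).

    Stand-in for whatever tuned iteration a production Muon uses.
    """
    norm = sum(x * x for row in w for x in row) ** 0.5
    x = [[v / norm for v in row] for row in w]  # singular values now <= 1
    for _ in range(iters):
        xtx = matmul(transpose(x), x)
        xxtx = matmul(x, xtx)
        x = [[1.5 * a - 0.5 * b for a, b in zip(ra, rb)] for ra, rb in zip(x, xxtx)]
    return x


def orthogonalize_split(w, split_sizes):
    """Orthogonalize each row block of w independently and restack."""
    assert sum(split_sizes) == len(w)
    out, start = [], 0
    for size in split_sizes:
        out.extend(newton_schulz(w[start:start + size]))
        start += size
    return out


def max_dev_from_identity(block):
    """max |block block^T - I| for a square block."""
    g = matmul(block, transpose(block))
    n = len(g)
    return max(abs(g[i][j] - (1.0 if i == j else 0.0))
               for i in range(n) for j in range(n))


# Toy fused QKV with d = 2, so the packed weight is [3*d, d] = 6 x 2.
W = [[2.0, 1.0], [0.0, 1.0],    # Q rows
     [1.0, 0.0], [1.0, 3.0],    # K rows
     [0.5, -1.0], [1.0, 0.5]]   # V rows

split = orthogonalize_split(W, [2, 2, 2])   # each slice handled on its own
fused = newton_schulz(W)                    # whole matrix treated as one
```

In the split result each 2 x 2 slice is orthogonal on its own. In the fused result only the full 6 x 2 matrix has orthonormal columns, so the Q slice is rescaled by a factor that depends on K and V, which is the subspace mixing the section describes.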

2. The cautious-update gate

Muon's fused step is momentum -> polar_express -> variance_reduction -> cautious_update. The easy mistake is to import the generic "cautious optimizer" story too literally and gate the final update against the raw gradient sign. That sounds reasonable on paper and was unstable in practice once the update had already been orthogonalized and variance-reduced.

The stable contract in our receipts is empirical: the conservative gate is parameter-sign agreement rather than raw-gradient agreement. Reintroducing the raw-grad gate pushed the deep Muon lane back toward immediate NaNs, so we keep the parameter-sign contract until a better large-scale receipt proves otherwise.

3. The reduce-scatter + all-gather dance

DistMuon uses the same two-pass collective pattern as DistAdamW but on stacked parameter groups. All params of the same shape within a group are stacked along a new leading axis, and the optimizer runs one fused kernel on the whole stack. The benefit is a single Newton-Schulz launch per group instead of one per param.

The hazard is that the chunk math must be identical on every rank. chunk_size = (len(group_params) + world_size - 1) // world_size is computed at init time and burned in. Adding a parameter to a group after init requires add_param_group to recompute chunk_size. We hit exactly one bug where the chunk size was stale and rank 1 operated on a different slice of the stacked tensor than rank 0.
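A small sketch of the chunk math shows how a stale value produces mismatched slices. The chunk_size formula is the one quoted above; owned_range and the five-parameter example are hypothetical illustrations, not the DistMuon code.

```python
def chunk_size(num_params, world_size):
    """Ceil-divide the stacked group across ranks (formula from the text)."""
    return (num_params + world_size - 1) // world_size


def owned_range(rank, num_params, world_size, size=None):
    """Half-open [start, stop) of stacked indices this rank updates."""
    size = chunk_size(num_params, world_size) if size is None else size
    start = min(rank * size, num_params)
    return start, min(start + size, num_params)


world_size = 2
stale = chunk_size(4, world_size)   # computed at init, when the group had 4 params
fresh = chunk_size(5, world_size)   # after a fifth param joins the group

# Rank 0 recomputed, rank 1 did not: the slices overlap and leave a gap.
rank0 = owned_range(0, 5, world_size, size=fresh)
rank1 = owned_range(1, 5, world_size, size=stale)
```

Here rank 0 owns indices 0 to 2 while rank 1, still using the stale size, owns indices 2 to 3: index 2 is updated twice and index 4 is never updated. Recomputing chunk_size on every rank whenever the group changes removes the mismatch.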

4. XLA runtime scalars

The prepare_xla_step_scalars method on both DistMuon and DistAdamW exists because naive Python floats force XLA recompilation every time the learning-rate schedule advances. We cache materialized 0-D tensors for lr, momentum, weight_decay, beta2, and lr_multiplier on the XLA device, and rewrite them in place between steps. This is not numerically interesting. It is the difference between a compile-cache hit and a full recompile on every LR-schedule transition.

The failure mode here is compile churn, not bad math. A host read like loss.item() or an accidental Python float inside the step can force a sync and invalidate the cached XLA program just to materialize a scalar that could have waited until after the step.

5. FSDP2-native Muon vs DistMuon

MegaCpp actually ships three Muon variants: single-device Muon, ZeRO-2-style DistMuon, and FSDP2Muon, which consumes FSDP2-sharded DTensor parameters directly. FSDP2Muon is mathematically equivalent to DistMuon; the reason it exists as a separate class is that FSDP2 hands the optimizer sharded DTensors whose local shards do not match the naive "split along leading dim" layout that DistMuon assumes.

Keeping three Muon implementations is tech debt, and we know it. For now the stress harness runs both DistMuon and FSDP2Muon through the same grad-none-asymmetry scenarios and confirms bit-identical behavior across them.

What the TPU v6e-8 receipt actually covers

The most recent checked-in receipt is from the CPU/gloo leg on the v6e-8 host, not the XLA leg. That is a deliberate scope restriction: the stress harness currently supports gloo and nccl backends, not XLA. The CPU/gloo run exercises the exact same Python optimizer code on the same host, just with collectives running over gloo instead of over XLA.

This matters because the distributed logic and the grad-none mask are backend-independent. If they pass on gloo, they pass on any backend that correctly implements the collectives.

Extending the stress harness to XLA is a tracked follow-up. The main blocker is that the harness expects a torch.distributed process group and the XLA SPMD runtime does not expose one at the same API level.

The cheap lesson

The single cheapest thing we did on the distributed optimizer rollout was adding the rank-symmetric all_reduce(MAX) mask at the top of every step. It is a small patch. It removed an entire class of "deadlock on conditional compute" failures. The single most expensive thing we did was chasing a deep Muon regression caused by fused QKV being treated as one flat matrix rather than three sub-matrices.

The pattern is consistent: the bugs that hurt are the ones where distributed correctness quietly degrades the optimizer math, not the ones where the collective hangs. Hangs are loud. Drift is silent.

FAQ

Frequently asked questions

Why keep a gloo or CPU receipt on a TPU host?
Because the stress harness is validating the Python-side distributed optimizer contract before the expensive TPU compile and multi-chip run. If that contract is wrong, backend choice will not save it, and the cheap host receipt catches the drift earlier.
Why does fixed-capacity routing not replace the grad-none stress harness?
Because it only removes one source of dynamic shape drift. Conditional compute, stage-local features, and optional modules can still make local gradient presence diverge across ranks, so the optimizer still has to preserve collective order explicitly.
What bug does the grad-none mask actually prevent?
It keeps collective order identical on every rank even when local gradient presence differs. Without it, one rank can skip a collective that another rank executes, which leads to either the loud failure (a hang) or the quiet one (rank-local optimizer drift).
Why does Muon need explicit qkv_split_sizes metadata?
Because fused QKV is one 2D tensor that still contains three different semantic subspaces. Orthogonalizing it as one matrix mixes query, key, and value updates, so the optimizer can look numerically stable while model quality regresses.
Does GQA or MQA change the qkv_split_sizes rule?
Yes. Splitting fused QKV stops Muon from mixing semantic subspaces, but GQA and MQA add a second geometry problem: query rows and key/value rows often have different aspect ratios. If the packed weight is still treated as one matrix, the smaller KV side inherits an update geometry tuned to the larger Q slice. The public-safe rule is to keep Q, K, and V as separate slices through orthogonalization and preserve per-slice geometry even when the packed weight is still 2D. Muon on Hopper and Blackwell is the adjacent read if you want the optimizer-facing version of that shape story.
Why is Muon's cautious gate keyed to parameter sign instead of raw gradient sign?
Because the update that reaches the gate is no longer the raw gradient; it has already gone through momentum, orthogonalization, and variance reduction. Gating that final update against the pre-orthogonalization raw-grad sign was the unstable contract in our receipts and pushed deep Muon runs back toward immediate NaNs. The narrower public-safe rule is to keep the gate aligned with parameter-sign agreement until a better large-scale receipt proves otherwise.
Why is FSDP2Muon still a separate class if the Muon math matches DistMuon?
Because FSDP2 hands the optimizer DTensor-backed local shards, and those local views do not always line up with the naive leading-dimension split that DistMuon assumes. The math stays the same, but the shard adapter work changes. That seam gets sharper on fused-QKV shards. The checked-in FSDP2 Muon local shard sample first recovers the rank's Shard(0) row bounds and then rescales the full-row qkv_split_sizes metadata down to the local row count before the step runs. Muon on Hopper and Blackwell is the adjacent read for why that rescaling matters even when the parameter is still 2D.
Why keep host scalar reads out of the XLA optimizer step?
Because they change the execution contract without changing the math. Pulling a scalar to host inside the step introduces a sync point and can trigger a new compiled program where a cached one should have been reused.
Where is the checked-in proof for the grad-none and local-shard claims?
The local shape legend is FSDP2 Muon local shard sample, and the surrounding receipt discipline is runtime optimization receipts plus measured optimization receipts.

Glossary

Terms used in this article

Start here for quick definitions, then follow the linked posts for deeper context.

DTensor

PyTorch's mesh-backed distributed-tensor abstraction: one logical tensor with explicit shard or replica metadata across ranks.

FSDP2

PyTorch's Fully Sharded Data Parallel v2 wrapper API. On CUDA it shards parameters, gradients, and optimizer state across the data-parallel group; in the TPU/XLA posts here it is usually a memory-goal analogy, not the actual eager wrapper mechanism.

TP

Tensor parallelism splits each linear's weights (QKV, O, MLP gate/up/down) across GPUs. On 8× H200 with TP=8 each GPU owns 1/8 of every matmul's columns or rows, so one big matmul becomes 8 smaller ones that all-reduce at the layer boundary. Cost: one all-reduce per attention and per MLP — heavy bandwidth, so TP is usually bound to a single NVLink/NVSwitch island (1 node of up to 8 GPUs). Embeddings, layernorms, and optimizer state stay replicated across the TP GPUs. Use TP when a single layer's weights don't fit on one GPU, not to scale past one node.

XLA SPMD

The explicit TPU sharding mode where one compiled program carries placement rules instead of rank-local imperative code.

NCCL

NVIDIA's collective-communication library for all-reduce, all-gather, reduce-scatter, and point-to-point transport on CUDA multi-GPU lanes.

H200

NVIDIA's Hopper H200 GPU platform, typically discussed here as an 8-GPU training node with large HBM capacity and NVLink-connected ranks.

TPU

Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.

MoE

Token Choice vs Expert Choice, null-expert debugging, gating stability, and the production routing decisions behind the MegaCpp SLM Ensemble.

CUDA

NVIDIA's GPU programming stack: compiler, runtime, driver, libraries, and kernel toolchain used by CUDA training and inference lanes.