Tensor, Sequence, Expert, FSDP: The Hybrid Sharding Stack and the v6e-8 Bugs
When people say "we do tensor parallelism" they usually mean one of two things: the Megatron-style column/row split of MLP and attention, or the PyTorch DTensor variant of the same pattern. In the MegaCpp nanochat POC we run four sharding dimensions simultaneously and we run them on two unrelated backends. This post walks through the topology we actually ship, the contract on what stays replicated, and the specific class of bugs that only showed up once we left a single-host H200 box and ran the same model on TPU v6e-8.
What the stack actually is
The four dimensions in the hybrid stack:
- TP (tensor parallel): column/row split of attention Q/K/V/O and MLP up/gate/down, on an intra-host axis with fast interconnect.
- SP (sequence parallel): the activation/norm axis hanging off TP, so that layer norms and residual adds operate on T/tp tokens per rank, not on full sequences.
- EP (expert parallel): expert banks sharded across a separate mesh axis so MoE tokens dispatch via AlltoAll instead of per-expert all-reduce.
- FSDP (ZeRO-3 style): data-parallel axis that also shards parameters, gradients, and optimizer states across DP ranks.
There is no single implementation of the above. On CUDA we combine
DTensor TP (or Megatron TP when we want explicit parallel-linear
wrappers), SP, EP via DeepEP dispatch, and FSDP2 through
fully_shard. On XLA we do SPMD TP via _apply_tensor_parallel_sharding
over a named mesh, combined with apply_fsdp_sharding and an
expert axis when MoE is enabled. They are two entirely different
pieces of code that happen to honour the same replicate/shard
contract.
Mesh shapes
The XLA mesh topology is described explicitly in the base trainer:
("data",)whenTP=1andEP=1("data", "model")whenTP>1("data", "expert")whenEP>1andTP=1("data", "expert", "model")whenEP>1andTP>1
So TPU work is one-axis, two-axis, or three-axis sharding depending
on the run. On CUDA the mesh is analogous but built via
torch.distributed.device_mesh, with FSDP2 wrapping each block on
the DP axis and TP sharding applied through parallelize_module.
The important invariant is that TP and EP axes are fast-local: NVLink/NVSwitch on H200, ICI on v6e. DP can be either intra-host or cross-host, but in the POC we deliberately kept it single-host for the first receipts to pin down the compute-and-comm contract before adding DCN variance.
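The axis-name rules above are mechanical enough to sketch. This is a hypothetical helper, not the trainer's actual API; the function name and signature are illustrative, and the real code builds the mesh from these names afterwards:

```python
def mesh_axis_names(tp: int, ep: int) -> tuple:
    """Return the named XLA mesh axes for a given TP/EP configuration,
    following the four cases listed above."""
    if ep > 1 and tp > 1:
        return ("data", "expert", "model")
    if ep > 1:
        return ("data", "expert")
    if tp > 1:
        return ("data", "model")
    # TP=1 and EP=1: pure data parallelism, one-axis mesh
    return ("data",)
```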
What stays replicated
After a year of chasing "why is rank 3 drifting" bugs we have a short, blunt replicate list:
- Mamba weights and state-space internals
- Engram parameters
- mHC parameters
- ngram_hash projection tables
- shared router / shared-expert / null-expert paths
- layer norm and QK-norm weights (with explicit post-backward all-reduce hooks on the TP group to keep them bit-identical)
- LoRA adapters (excluded via FSDP2 ignored_params or resynced via register_lora_grad_hooks)
Anything "sharp" or routing-sensitive is pinned replicated unless
the runtime explicitly implements a safe parallel rule. This is not
theoretical conservatism: most of the longest bisects we ran
traced to a router or a norm that the TP pass had silently sharded
because the tensor was a simple nn.Linear that looked indistinguishable
from a dense FFN.
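One way to make such a list enforceable rather than aspirational is a name-based predicate checked at wrap time. This is a sketch only: the substrings below are assumptions about module naming, not the POC's real parameter paths, and the real check would run before the TP pass shards anything:

```python
# Assumed naming substrings for the replicate list; illustrative only.
REPLICATED_SUBSTRINGS = (
    "mamba", "engram", "mhc", "ngram_hash",
    "shared_router", "shared_expert", "null_expert",
    "norm",   # covers layer norms and QK-norms
    "lora",
)

def must_stay_replicated(param_name: str) -> bool:
    """True if this parameter is pinned replicated under the contract."""
    name = param_name.lower()
    return any(s in name for s in REPLICATED_SUBSTRINGS)
```

A check like this is exactly what would have caught the "router that looked like a dense FFN" class of bug before a multi-day bisect.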
Divisibility contract
The canonical shape rules live in the base trainer, enforced at startup:
- num_heads % tensor_parallel == 0
- effective n_kv_head % tensor_parallel == 0
- fused (n_head + 2 * n_kv_head) % tensor_parallel == 0
- dsa_indexer_heads % tensor_parallel == 0 when DSA is active
These are not advisory. Violating any of them produces silent correctness issues that only surface several hundred steps in, or on specific pipeline/EP combinations. We treat them as part of the sharding contract, not as performance hints.
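A startup check that encodes the contract can look like the following sketch. The function name and argument names are illustrative, not the base trainer's actual signature:

```python
def check_tp_divisibility(n_head, n_kv_head, tensor_parallel,
                          dsa_indexer_heads=None):
    """Fail fast at startup instead of silently corrupting step ~300."""
    assert n_head % tensor_parallel == 0, \
        "num_heads % tensor_parallel != 0"
    assert n_kv_head % tensor_parallel == 0, \
        "effective n_kv_head % tensor_parallel != 0"
    assert (n_head + 2 * n_kv_head) % tensor_parallel == 0, \
        "fused QKV width not divisible by tensor_parallel"
    if dsa_indexer_heads is not None:  # only when DSA is active
        assert dsa_indexer_heads % tensor_parallel == 0, \
            "dsa_indexer_heads % tensor_parallel != 0"
```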
The hybrid with LoRA is a real path, not a footnote
Adapter-aware TP was the part that almost did not make it into the POC. Early wrappers treated LoRA as "inject after TP, hope for the best" and broke on DTensor TP + LoRA. The current contract:
- DTensor TP and Megatron TP both have explicit LoRA-aware wrapping.
- Both have checked-in test coverage.
- FSDP2 wrapping treats LoRA as ignored_params when it exists at wrap time, and otherwise relies on register_lora_grad_hooks to all-reduce adapter gradients on the DP group.
This is boring plumbing. It is also the only reason finetuning survives the hybrid mesh.
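The wrap-time split itself is simple. This sketch assumes LoRA parameters are identifiable by name (a naming assumption, not the POC's guaranteed convention); the real path feeds the first set into FSDP2's ignored_params and leaves the second to be sharded:

```python
def split_lora_params(named_params):
    """Partition parameter names into (ignored_by_fsdp, sharded_by_fsdp).

    Assumes adapter params carry 'lora_' in their name, e.g. lora_A /
    lora_B, which is a common but not universal convention.
    """
    ignored, sharded = set(), set()
    for name, _param in named_params:
        if "lora_" in name:
            ignored.add(name)   # stays replicated; DP-group grad hooks resync it
        else:
            sharded.add(name)   # normal ZeRO-3 sharding
    return ignored, sharded
```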
The validated CUDA lane
For the record, the specific H200 geometry that passes the bring-up matrix is:
- depth=16, n_embd=128, head_dim=32, n_kv_head=2, max_seq_len=128, device_batch_size=2, total_batch_size=256
- tensor_parallel=2, sequence_parallel=on, fsdp_cuda=on
- MoE with moe_token_choice, moe_n_routed_experts=8, moe_top_k=2, moe_expert_size=64, expert_parallel=2, deepep_dispatch=on
- FSDP_USE_ORIG_PARAMS=true, TORCH_NCCL_AVOID_RECORD_STREAMS=1, TRITON_DEFAULT_NUM_STAGES=2
On this geometry, TP + SP + EP + FSDP2 eager is alive, multi-step is alive, and compile is compile-stable after a specific MoE fix described below. Depth expansion to 16 on the same 4-GPU box was straightforward once the eager lane was clean.
Bugs that only appeared on v6e-8
Single-host H200 is forgiving in one important way: every rank shares a clock domain, and collective ordering is mostly implicit. TPU v6e-8 is less forgiving. Four classes of bug passed every CUDA smoke and then lit up on v6e-8:
1. EP-active predicate based on the wrong mesh
The CUDA-side fix is documented in the bring-up report, but the same
root cause affected XLA differently. The symptom on v6e-8 was
Block+MoE tokens going through the non-EP path when the user
asked for --expert_parallel=2 --expert_tensor_parallel=0 ("follow
TP"). The old predicate _expert_parallel_active = expert_tp_mesh is not None was wrong: EP can be active while expert_tp_mesh is
None, because "follow TP" does not allocate a separate expert-TP
mesh. The block then installed manual SP hooks on inner.mlp and
expanded router_in from local shard to full sequence. On a single
host this did not error; on v6e-8 it produced subtly inconsistent
token layouts across chips that only manifested as a loss plateau
around step 3000.
Fix: compute _expert_parallel_active from expert_parallel_degree > 1
and plumb expert_parallel_degree explicitly through the CUDA TP
wrapper (and the XLA mesh factory). After the fix, tracing on the
passing lane prints ep_active=True ep_local=True manual_sp=False ep_lane=True, and block_moe_mlp_after_clear pre=0 post=0 confirms
there are no stale hooks.
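The predicate change is small enough to show side by side. A minimal sketch, with variable names taken from the text but the function boundaries invented for illustration:

```python
def ep_active_old(expert_tp_mesh):
    # Wrong: "follow TP" (--expert_tensor_parallel=0) never allocates a
    # separate expert-TP mesh, so this reports False while EP is active.
    return expert_tp_mesh is not None

def ep_active_fixed(expert_parallel_degree):
    # Fixed: derive activity from the degree the user actually requested,
    # plumbed explicitly through the TP wrapper and the XLA mesh factory.
    return expert_parallel_degree > 1
```

The "follow TP" repro is exactly expert_tp_mesh=None with expert_parallel_degree=2: the old predicate says EP is off, the fixed one says it is on.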
2. PP group duplication hang
Not strictly a v6e-8 bug, but it behaved like one. On CUDA, PP + TP + SP could hang inside dist.new_group(_pp_group_ranks) because the existing dp_process_group already had the same membership, and creating a duplicate NCCL subgroup deadlocked. On TPU the equivalent manifested as the SPMD mesh rejecting a redundant submesh with no useful error. The fix on both sides is the same idea: reuse the existing DP group as the PP group when memberships are identical, instead of creating a new one. We also added all-rank PP group tracing around the new-group call so future hangs produce visible output instead of silence.
3. MoE + compile on the Megatron-permute padded path
The longest bisect of the rollout. On single-rank CUDA, standalone
TokenChoiceMoELayer compiled fine. On 4x H200 with
TP + SP + FSDP2 + compile + MoE, it failed deep in Inductor with
both a fallback and a decomp for same op: aten.index_add.default.
The inner-path ablations narrowed it all the way down: replacing
routing_weights.scatter_add did not help; replacing the final
padded_routed_out_flat.scatter_add with index_add did not help;
zeroing aux/z-loss bookkeeping did not help; expert-identity did not
help. What did help was forcing prefer_megatron_permute = False.
That removed the failure entirely and unblocked the full-lane
compile.
We committed the change as the default for compiled CUDA MoE and
kept the override available via NANOCHAT_FORCE_MEGATRON_PERMUTE_COMPILE=1
for future upstream retesting. On v6e-8 the failure class was
different (XLA does not use Inductor), but the same permute path was
implicated in a distinct tokenization symmetry bug: the Megatron
padded permute assumes a specific dispatch order that does not match
what TPU AlltoAll produces under --expert_tensor_parallel=0.
Killing the permute on both backends kept the shared code clean.
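The override gate itself is a one-liner. This is a sketch of the shape of that gate, assuming the conventional "env var unset means off" reading of the default described above; only the variable name comes from the text:

```python
import os

def megatron_permute_enabled_for_compile() -> bool:
    """Compiled CUDA MoE keeps the Megatron padded permute off unless the
    retest override is explicitly set."""
    return os.environ.get("NANOCHAT_FORCE_MEGATRON_PERMUTE_COMPILE", "0") == "1"
```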
4. Unbatched P2P init warning turning into real overhead
On the H200 side, Schedule1F1B in
torch.distributed.pipelining.schedules degrades homogeneous batches
(all isend or all irecv) into raw dist.isend / dist.irecv.
That behaviour lights up a ProcessGroupNCCL.cpp unbatched P2P
warning during pipeline eager init. We patched it in a narrow
schedule.step(...) window behind NANOCHAT_FORCE_BATCHED_PP_INIT_P2P
(default 1) to force dist.batch_isend_irecv(p2p_ops). After the
patch, the warning disappeared and the clean PP + TP + SP + compile
lane ran at ~1280 tok/sec at depth 4, ~1516 tok/sec steady-state on
depth 16. On TPU v6e-8 the analogous effect was worse: each
unbatched send/recv materialised as a separate XLA op and stretched
compile time, because XLA does not hide P2P in the same way NCCL
does.
What single-host never caught
Across all four bug classes, the pattern is the same. Single-host masks two things: (a) it serialises order implicitly via a shared scheduler and a single driver, and (b) it silently absorbs predicate/mesh mistakes because the ranks sit on the same device memory. v6e-8 removes both props. Any code that assumed "the mesh shape I asked for is the mesh shape I got" instead of "prove it at runtime" eventually failed.
Our concrete response:
- Every mesh predicate (expert_parallel_active, sp_on_block, ep_lane) now prints a one-line trace on step 0 of every rank, tagged with the rank and the mesh id.
- Any code path that creates a new process group first checks whether an existing group has identical membership, and reuses it.
- Compile flag pairs (MoE + permute, MoE + FSDP2, EP + compile) that we know are fragile have explicit env gates with conservative defaults, so upstream retests do not silently reactivate a known bad configuration.
- The receipt format for every bring-up log records requested_backend, actual_backend, runtime_mode, and fallback_reason, so "I requested FA4" versus "FA4 actually executed" is never ambiguous.
What we kept out of the sharding contract
Three things deliberately do not live in the sharding doc:
- One-off benchmark numbers. Those belong in dated reports, not in a contract that changes with each model revision.
- Old batch-math rules. The old num_devices / tp_degree formula is no longer the general truth. Effective DP size is world_size / (pp * tp * ep) and is owned by runtime helpers.
- Historical OOM tables. Outdated hardware capacity claims caused more bugs than they prevented, because they encouraged people to skip actual trial runs.
Where we are
The validated compile-stable lane is TP + SP + EP + FSDP2 + compile
on 4x H200, small geometry, five-step run, recompile_limit_count=0,
capacity_error_count=0, no aten.index_add.default. Depth
expansion to 16 is clean. PP + TP + SP + compile at depth 16 is
clean. The v6e-8 equivalent is stable on the same sharding contract
once the predicate and permute fixes land; the optimizer-stress
receipt on that host reports zero divergence across 1000 steps on
both the DistAdamW + DistMuon scenario and the MoD/ReDo/MTP
conditional scenario.
None of this is finished. The tp_comm_overlap userbuffer path is
still on the TODO list; moe_shared_expert_overlap is not wired up;
pre-collective NaN checks still happen post-collective. What it is
is reproducible, which for a multi-dimensional sharding stack on two
backends is the only honest form of progress.
The MegaCpp nanochat POC is built and operated by David Gornshtein and Boris Tamarkin. If your own hybrid stack is in the "single-host green, multi-host yellow" state ours was three months ago, the shortest path through is: pin the replicate list, make every mesh predicate traceable on step 0, and assume any compile-stable single-host lane will need one more round of debugging the moment you hit real v6e-8 or multi-host CUDA.
References
- TENSOR_PARALLELISM.md
- BACKEND_STOPLIGHT_MATRIX.md
- CURRENT_STATE.md
- tp_sp_ep_fsdp_h200_bringup_2026-04-07.md
- 11-adaptive-sharding-auto-fit.md
- dist_optimizer_stress_tpu_v6e_2026-03-22.md
- TRAINING_PLAN.md
- training_plan_en.md
- training_review.md
- review_gcp_tpu.md
- CHANGELOG.md