tensor-parallel
sequence-parallel
expert-parallel
fsdp
tpu
v6e
h200

Tensor, Sequence, Expert, FSDP: The Hybrid Sharding Stack and the v6e-8 Bugs

9 min read · David Gornshtein

When people say "we do tensor parallelism" they usually mean one of two things: the Megatron-style column/row split of MLP and attention, or the PyTorch DTensor variant of the same pattern. In the MegaCpp nanochat POC we run four sharding dimensions simultaneously and we run them on two unrelated backends. This post walks through the topology we actually ship, the contract on what stays replicated, and the specific class of bugs that only showed up once we left a single-host H200 box and ran the same model on TPU v6e-8.

What the stack actually is

The four dimensions in the hybrid stack:

  • TP (tensor parallel): column/row split of attention Q/K/V/O and MLP up/gate/down, on an intra-host axis with fast interconnect.
  • SP (sequence parallel): the activation/norm axis hanging off TP so that layer norms and residual adds operate on T/tp tokens per rank, not on full sequences.
  • EP (expert parallel): expert banks sharded across a separate mesh axis so MoE tokens dispatch via AlltoAll instead of per-expert all-reduce.
  • FSDP (ZeRO-3 style): data-parallel axis that also shards parameters, gradients, and optimizer states across DP ranks.

There is no single implementation of the above. On CUDA we combine DTensor TP (or Megatron TP when we want explicit parallel-linear wrappers), SP, EP via DeepEP dispatch, and FSDP2 through fully_shard. On XLA we do SPMD TP via _apply_tensor_parallel_sharding over a named mesh, combined with apply_fsdp_sharding and an expert axis when MoE is enabled. They are two entirely different pieces of code that happen to honour the same replicate/shard contract.

Mesh shapes

The XLA mesh topology is described explicitly in the base trainer:

  • ("data",) when TP=1 and EP=1
  • ("data", "model") when TP>1
  • ("data", "expert") when EP>1 and TP=1
  • ("data", "expert", "model") when EP>1 and TP>1

So TPU work is one-axis, two-axis, or three-axis sharding depending on the run. On CUDA the mesh is analogous but built via torch.distributed.device_mesh, with FSDP2 wrapping each block on the DP axis and TP sharding applied through parallelize_module.
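The axis-name selection above is simple enough to sketch as a pure function. This is an illustrative reconstruction of the rule, not the actual trainer code; the function name and signature are assumptions.

```python
def mesh_axis_names(tp: int, ep: int) -> tuple[str, ...]:
    """Pick the named-mesh axes for a run, mirroring the rules above.

    tp: tensor-parallel degree, ep: expert-parallel degree.
    The data axis is always present; model/expert axes appear only
    when the corresponding degree exceeds 1.
    """
    if ep > 1 and tp > 1:
        return ("data", "expert", "model")
    if ep > 1:
        return ("data", "expert")
    if tp > 1:
        return ("data", "model")
    return ("data",)
```

For example, `mesh_axis_names(2, 2)` yields the three-axis `("data", "expert", "model")` topology, while a plain data-parallel run collapses to `("data",)`.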

The important invariant is that TP and EP axes are fast-local: NVLink/NVSwitch on H200, ICI on v6e. DP can be either intra-host or cross-host, but in the POC we deliberately kept it single-host for the first receipts to pin down the compute-and-comm contract before adding DCN variance.

What stays replicated

After a year of chasing "why is rank 3 drifting" bugs we have a short, blunt replicate list:

  • Mamba weights and state-space internals
  • Engram parameters
  • mHC parameters
  • ngram_hash projection tables
  • shared router / shared-expert / null-expert paths
  • layer norm and QK-norm weights (with explicit post-backward all-reduce hooks on the TP group to keep them bit-identical)
  • LoRA adapters (excluded via FSDP2 ignored_params or resynced via register_lora_grad_hooks)

Anything "sharp" or routing-sensitive is pinned replicated unless the runtime explicitly implements a safe parallel rule. This is not theoretical conservatism: most of the longest bisects we ran traced to a router or a norm that the TP pass had silently sharded because the tensor was a simple nn.Linear that looked indistinguishable from a dense FFN.
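One way to make the replicate list executable is a name-based exclusion predicate that the sharding pass consults before touching a parameter. The patterns below are hypothetical (the real trainer can match on module types, not just names), but they capture the contract above.

```python
import re

# Hypothetical name patterns mirroring the replicate list above.
_REPLICATED_PATTERNS = [
    r"\.mamba\.", r"\.engram\.", r"\.mhc\.", r"ngram_hash",
    r"shared_router", r"shared_expert", r"null_expert",
    r"\bnorm\b", r"qk_norm", r"lora_[ab]",
]

def must_stay_replicated(param_name: str) -> bool:
    """True if the parameter is pinned replicated under the contract.

    Conservative by construction: anything matching a "sharp" or
    routing-sensitive pattern is excluded from TP/EP sharding unless
    a safe parallel rule is explicitly implemented for it.
    """
    return any(re.search(p, param_name) for p in _REPLICATED_PATTERNS)
```

The point of a single predicate is exactly the bisect scenario described above: a router implemented as a plain nn.Linear should hit the exclusion list by name before the TP pass ever sees it.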

Divisibility contract

The canonical shape rules live in the base trainer, enforced at startup:

  • num_heads % tensor_parallel == 0
  • effective n_kv_head % tensor_parallel == 0
  • fused (n_head + 2 * n_kv_head) % tensor_parallel == 0
  • dsa_indexer_heads % tensor_parallel == 0 when DSA is active

These are not advisory. Violating any of them produces silent correctness issues that only surface several hundred steps in, or on specific pipeline/EP combinations. We treat them as part of the sharding contract, not as performance hints.

The hybrid with LoRA is a real path, not a footnote

Adapter-aware TP was the part that almost did not make it into the POC. Early wrappers treated LoRA as "inject after TP, hope for the best" and broke on DTensor TP + LoRA. The current contract:

  • DTensor TP and Megatron TP both have explicit LoRA-aware wrapping.
  • Both have checked-in test coverage.
  • FSDP2 wrapping treats LoRA as ignored_params when it exists at wrap time, and otherwise relies on register_lora_grad_hooks to all-reduce adapter gradients on the DP group.

This is boring plumbing. It is also the only reason finetuning survives the hybrid mesh.

The validated CUDA lane

For the record, the specific H200 geometry that passes the bring-up matrix is:

  • depth=16, n_embd=128, head_dim=32, n_kv_head=2, max_seq_len=128, device_batch_size=2, total_batch_size=256
  • tensor_parallel=2, sequence_parallel=on, fsdp_cuda=on
  • MoE with moe_token_choice, moe_n_routed_experts=8, moe_top_k=2, moe_expert_size=64, expert_parallel=2, deepep_dispatch=on
  • FSDP_USE_ORIG_PARAMS=true, TORCH_NCCL_AVOID_RECORD_STREAMS=1, TRITON_DEFAULT_NUM_STAGES=2

On this geometry, TP + SP + EP + FSDP2 eager is alive, multi-step is alive, and compile is stable after a specific MoE fix described below. Depth expansion to 16 on the same 4-GPU box was straightforward once the eager lane was clean.

Bugs that only appeared on v6e-8

Single-host H200 is forgiving in one important way: every rank shares a clock domain, and collective ordering is mostly implicit. TPU v6e-8 is less forgiving. Four classes of bug passed every CUDA smoke and then lit up on v6e-8:

1. EP-active predicate based on the wrong mesh

The CUDA-side fix is documented in the bring-up report, but the same root cause affected XLA differently. The symptom on v6e-8 was Block+MoE tokens going through the non-EP path when the user asked for --expert_parallel=2 --expert_tensor_parallel=0 ("follow TP"). The old predicate _expert_parallel_active = expert_tp_mesh is not None was wrong: EP can be active while expert_tp_mesh is None, because "follow TP" does not allocate a separate expert-TP mesh. The block then installed manual SP hooks on inner.mlp and expanded router_in from local shard to full sequence. On a single host this did not error; on v6e-8 it produced subtly inconsistent token layouts across chips that only manifested as a loss plateau around step 3000.

Fix: compute _expert_parallel_active from expert_parallel_degree > 1 and plumb expert_parallel_degree explicitly through the CUDA TP wrapper (and the XLA mesh factory). After the fix, tracing on the passing lane prints ep_active=True ep_local=True manual_sp=False ep_lane=True, and block_moe_mlp_after_clear pre=0 post=0 confirms there are no stale hooks.
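The shape of the fix is small enough to show directly. This is a sketch of the corrected predicate, not the exact wrapper code:

```python
def expert_parallel_active(expert_parallel_degree: int,
                           expert_tp_mesh=None) -> bool:
    """EP is active whenever the requested degree exceeds 1.

    The old (buggy) predicate tested `expert_tp_mesh is not None`,
    which is False under --expert_tensor_parallel=0 ("follow TP"),
    because that mode never allocates a separate expert-TP mesh.
    """
    del expert_tp_mesh  # deliberately unused: mesh presence is not the signal
    return expert_parallel_degree > 1
```

The key property is that the predicate depends only on the explicitly plumbed degree, so "follow TP" and "dedicated expert-TP mesh" give the same answer for the same requested EP.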

2. PP group duplication hang

Not strictly a v6e-8 bug, but it behaved like one. On CUDA, PP + TP + SP could hang inside `dist.new_group(_pp_group_ranks)` because the existing `dp_process_group` already had the same membership and creating a duplicate NCCL subgroup deadlocked. On TPU the equivalent manifested as the SPMD mesh rejecting a redundant submesh with no useful error. The fix on both sides is the same idea: reuse the existing DP group as the PP group when memberships are identical, instead of creating a new one. We also added all-rank PP group tracing around the new-group call so future hangs produce visible output instead of silence.
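The reuse-if-identical rule reduces to a membership-keyed lookup in front of group creation. This is a backend-agnostic sketch under assumed bookkeeping (`existing_groups` is a mapping we maintain ourselves; `create_group` stands in for a thin wrapper around dist.new_group):

```python
def get_or_reuse_group(new_ranks, existing_groups, create_group):
    """Return an existing process group when memberships are identical.

    existing_groups: dict mapping frozenset-of-ranks -> group handle.
    create_group: callback that actually builds a new group.
    """
    key = frozenset(new_ranks)
    if key in existing_groups:
        # Identical membership: creating a duplicate NCCL subgroup can
        # deadlock, so hand back the group we already have.
        return existing_groups[key]
    group = create_group(sorted(new_ranks))
    existing_groups[key] = group
    return group
```

Keying on the frozenset makes rank order irrelevant, which matters because PP and DP code paths often enumerate the same ranks in different orders.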

3. MoE + compile on the Megatron-permute padded path

The longest bisect of the rollout. On single-rank CUDA, standalone TokenChoiceMoELayer compiled fine. On 4x H200 with TP + SP + FSDP2 + compile + MoE, it failed deep in Inductor with both a fallback and a decomp for same op: aten.index_add.default. The inner-path ablations narrowed it all the way down: replacing routing_weights.scatter_add did not help; replacing the final padded_routed_out_flat.scatter_add with index_add did not help; zeroing aux/z-loss bookkeeping did not help; expert-identity did not help. What did help was forcing prefer_megatron_permute = False. That removed the failure entirely and unblocked the full-lane compile.

We committed the change as the default for compiled CUDA MoE and kept the override available via NANOCHAT_FORCE_MEGATRON_PERMUTE_COMPILE=1 for future upstream retesting. On v6e-8 the failure class was different (XLA does not use Inductor), but the same permute path was implicated in a distinct tokenization symmetry bug: the Megatron padded permute assumes a specific dispatch order that does not match what TPU AlltoAll produces under --expert_tensor_parallel=0. Killing the permute on both backends kept the shared code clean.

4. Unbatched P2P init warning turning into real overhead

On the H200 side, Schedule1F1B in torch.distributed.pipelining.schedules degrades homogeneous batches (all isend or all irecv) into raw dist.isend / dist.irecv. That behaviour lights up a ProcessGroupNCCL.cpp unbatched P2P warning during pipeline eager init. We patched it in a narrow schedule.step(...) window behind NANOCHAT_FORCE_BATCHED_PP_INIT_P2P (default 1) to force dist.batch_isend_irecv(p2p_ops). After the patch, the warning disappeared and the clean PP + TP + SP + compile lane ran at ~1280 tok/sec at depth 4, ~1516 tok/sec steady-state on depth 16. On TPU v6e-8 the analogous effect was worse: each unbatched send/recv materialised as a separate XLA op and stretched compile time, because XLA does not hide P2P in the same way NCCL does.
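The env gate itself is the only part worth showing; the batched submission is a one-line swap to dist.batch_isend_irecv inside the patched schedule.step window. A sketch of the gate, assuming the default-on semantics described above:

```python
import os

def should_batch_pp_init_p2p() -> bool:
    """Env gate for forcing dist.batch_isend_irecv during pipeline init.

    Defaults to on ("1"), matching the conservative default above; set
    NANOCHAT_FORCE_BATCHED_PP_INIT_P2P=0 to retest the upstream
    unbatched behaviour.
    """
    return os.environ.get("NANOCHAT_FORCE_BATCHED_PP_INIT_P2P", "1") != "0"
```

Default-on with an explicit opt-out is the same pattern used for the Megatron-permute override: the known-good behaviour ships as the default, and the fragile upstream path stays one env var away for retesting.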

What single-host never caught

Across all four bug classes, the pattern is the same. Single-host masks two things: (a) it serialises order implicitly via a shared scheduler and a single driver, and (b) it silently absorbs predicate/mesh mistakes because the ranks sit on the same device memory. v6e-8 removes both props. Any code that assumed "the mesh shape I asked for is the mesh shape I got" instead of "prove it at runtime" eventually failed.

Our concrete response:

  • Every mesh predicate (expert_parallel_active, sp_on_block, ep_lane) now prints a one-line trace on step 0 of every rank, tagged with the rank and the mesh id.
  • Any code path that creates a new process group first checks if an existing group has identical membership, and reuses it.
  • Compile flag pairs (MoE + permute, MoE + FSDP2, EP + compile) that we know are fragile have explicit env gates with conservative defaults, so upstream retests do not silently reactivate a known bad configuration.
  • The receipt format for every bring-up log records requested_backend, actual_backend, runtime_mode, and fallback_reason so "I requested FA4" versus "FA4 actually executed" is never ambiguous.

What we kept out of the sharding contract

Three things deliberately do not live in the sharding doc:

  • One-off benchmark numbers. Those belong in dated reports, not in a contract that changes with each model revision.
  • Old batch-math rules. The old num_devices / tp_degree formula is no longer the general truth. Effective DP size is world_size / (pp * tp * ep) and is owned by runtime helpers.
  • Historical OOM tables. Outdated hardware capacity claims caused more bugs than they prevented, because they encouraged people to skip actual trial runs.
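The current batch-math rule is small enough to sketch; this mirrors the world_size / (pp * tp * ep) formula above, with an illustrative function name (the real logic lives in runtime helpers).

```python
def effective_dp_size(world_size: int, pp: int, tp: int, ep: int) -> int:
    """Effective data-parallel size under the hybrid mesh.

    Replaces the old num_devices / tp_degree rule; raises if the
    requested degrees do not tile the world size evenly.
    """
    denom = pp * tp * ep
    if world_size % denom:
        raise ValueError(
            f"world_size={world_size} not divisible by pp*tp*ep={denom}")
    return world_size // denom
```

For example, the validated 4x H200 lane with tp=2, ep=2, pp=1 gives an effective DP size of 1, which is why gradient accumulation carries the entire total_batch_size=256 on that box.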

Where we are

The validated compile-stable lane is TP + SP + EP + FSDP2 + compile on 4x H200, small geometry, five-step run, recompile_limit_count=0, capacity_error_count=0, no aten.index_add.default. Depth expansion to 16 is clean. PP + TP + SP + compile at depth 16 is clean. The v6e-8 equivalent is stable on the same sharding contract once the predicate and permute fixes land; the optimizer-stress receipt on that host reports zero divergence across 1000 steps on both the DistAdamW + DistMuon scenario and the MoD/ReDo/MTP conditional scenario.

None of this is finished. The tp_comm_overlap userbuffer path is still on the TODO list; moe_shared_expert_overlap is not wired up; pre-collective NaN checks still happen post-collective. What it is is reproducible, which for a multi-dimensional sharding stack on two backends is the only honest form of progress.

The MegaCpp nanochat POC is built and operated by David Gornshtein and Boris Tamarkin. If your own hybrid stack is in the "single-host green, multi-host yellow" state ours was three months ago, the shortest path through is: pin the replicate list, make every mesh predicate traceable on step 0, and assume any compile-stable single-host lane will need one more round of debugging the moment you hit real v6e-8 or multi-host CUDA.

References

  • TENSOR_PARALLELISM.md
  • BACKEND_STOPLIGHT_MATRIX.md
  • CURRENT_STATE.md
  • tp_sp_ep_fsdp_h200_bringup_2026-04-07.md
  • 11-adaptive-sharding-auto-fit.md
  • dist_optimizer_stress_tpu_v6e_2026-03-22.md
  • TRAINING_PLAN.md
  • training_plan_en.md
  • training_review.md
  • review_gcp_tpu.md
  • CHANGELOG.md
David Gornshtein • Datasunrise OÜ