tpu
v6e
performance
mfu
sharding
spmd
fsdp
moe

TPU v6e Performance Deep Dive: Real MFU, Sharding Topology, and the Things That Pretended to Help

David Gornshtein • 10 min read

TPU v6e is, on paper, a delight. Each chip is rated at roughly 918 TFLOPS bf16 peak. A v6e-16 pod offers about 14.7 PFLOPS of bf16 compute and just shy of 500 GB of HBM. The interconnect is fast, SPMD is stable in the current torch_xla stack, and the price-performance is attractive against a comparable H200 cluster.

And yet the first NAM52 4 B-parameter MoE training run we landed on v6e-8 ran at 0.5 % MFU. The second one, after a week of work, ran at 8.6 % MFU. The eventual production-shaped configuration on v6e-32 ran at 24,100 tok/sec - a real number, but corresponding to roughly 0.46 % MFU. This post is about why that gap exists, what closed parts of it, and what we learned about where v6e is and is not the right tool.

All numbers below are from the nanochat POC: NAM52 with 52 blocks, fused QKV, GQA with 8 KV heads, MoE with 64 routed experts plus 1 shared, optional MTP heads, plus our usual auxiliary surfaces. Stack: Python 3.13, custom torch 2.9.0a0+git21fec65, custom torch_xla 2.9.0+gitc04e61c, libtpu 0.0.36, jax 0.9.0. SPMD on, PJRT_DEVICE=TPU, model torch.compile off.

Headline numbers

The clean numbers we have, by topology:

Topology                  Config                     Tokens/step   tok/sec   MFU
v6e-8, NAM12 bare AEME    TP=1 dp=8 dbs=8, 4K seq    262,144       638,000   ~50 %
v6e-8, NAM52 bare AEME    TP=4 dp=2 dbs=1, 4K seq    8,192         27,800    8.6 %
v6e-8, NAM52 +features    TP=4 dp=2 dbs=1, 4K seq    8,192         19,900    6.2 %
v6e-8, NAM52 +FSDP        TP=1 dp=8 dbs=2, 4K seq    65,536        48,900    3.8 %
v6e-16, NAM52 +EP=4       EP=4 TP=2 dp=2, 4K seq     —             13,500    ~0.84 %
v6e-32, NAM52 +EP=4       EP=4 TP=2 dp=4, 4K seq     —             24,100    0.46 %

The gap between the NAM12 row and every NAM52 row is the entire story. NAM12 at depth 12 fits comfortably under TP=1 dp=8 dbs=8 with 262 K tokens per step and saturates the chips. NAM52 at depth 52 does not fit in the same shape, so it gets pushed into shapes that are memory-feasible but compute-starved. The MFU collapse is structural, not a kernel-quality problem.
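The tokens-per-step column is just arithmetic over the config column, which makes the structural point easy to verify. A minimal sketch (the 4K sequence length is the one used throughout the post):

```python
def tokens_per_step(dp: int, dbs: int, seq_len: int = 4096) -> int:
    """Global tokens per optimizer step: data-parallel width x
    per-chip device batch size x sequence length."""
    return dp * dbs * seq_len

# NAM12 on v6e-8: TP=1 dp=8 dbs=8 -> the compute-saturating shape
print(tokens_per_step(dp=8, dbs=8))   # 262144
# NAM52 on v6e-8: TP=4 dp=2 dbs=1 -> memory-feasible but compute-starved
print(tokens_per_step(dp=2, dbs=1))   # 8192
# NAM52 + FSDP: TP=1 dp=8 dbs=2 -> 8x the tokens per step
print(tokens_per_step(dp=8, dbs=2))   # 65536
```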

Why MFU collapses on big NAM52

The arithmetic is sobering. A v6e chip is 918 TFLOPS bf16. NAM52 is about 5.6 B FLOPs per token (from estimate_flops()), so a single v6e-16 pod could in principle process about 2.6 M tokens per second just from compute. The number we actually achieved on v6e-16 EP=4 TP=2 dp=2 was about 13.5 K tok/sec. That is roughly 0.5 % of peak. Where does the other 99.5 % go?

Per chip, per step, the breakdown was:

  • Effective batch per chip: 1 sequence × 4096 tokens = 4096 tok
  • FLOPs per token: 5.6 B
  • Compute per chip per step: 2.87 TFLOP, 3.1 ms at peak
  • Actual step time: ~600 ms
  • Compute-bound fraction of step: ~0.5 %
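The breakdown above falls out of a few lines of arithmetic; the peak, FLOP, and step-time figures are the ones quoted in the text:

```python
PEAK_TFLOPS_BF16 = 918    # per v6e chip, from the spec sheet
FLOPS_PER_TOKEN = 5.6e9   # NAM52, from estimate_flops()

# Per-chip share of the step's compute; the post reports 2.87 TFLOP
# per chip per step under the EP=4 TP=2 dp=2 sharding.
compute_tflop = 2.87
ideal_ms = compute_tflop / PEAK_TFLOPS_BF16 * 1000   # ~3.1 ms at peak
step_ms = 600                                        # measured step time
fraction = ideal_ms / step_ms                        # ~0.5 % compute-bound
print(f"{ideal_ms:.1f} ms ideal, {fraction:.1%} compute-bound")
```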

The remaining 99.5 % is split, in roughly this order, between MoE all-to-all (EP=4 dispatch + combine), TP all-reduce on the model axis, gradient sync across the dp axis, the optimizer step (Muon Polar Express + AdamW on 4 B params), and HBM bandwidth shuffling parameters and activations through the chip. The tensor cores sit idle for almost the entire step.

The fundamental reason is that NAM52 cannot fit any larger per-chip batch. The optimizer state alone consumes about 22 GB of the 31.25 GB usable HBM at TP=2. That leaves ~9 GB for forward and backward live state, which is enough for exactly one 4 K sequence per chip. A bigger batch is the obvious lever for MFU; on v6e for a model this size, you do not get one.
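A back-of-the-envelope for that budget. The post only states the 22 GB total; the ~11 bytes/param figure below is an assumption (fp32 moments plus a master copy is a common layout for this kind of optimizer state, but the real breakdown is not given):

```python
def optimizer_state_gb(params_b: float, bytes_per_param: float, tp: int) -> float:
    """Per-chip optimizer-state HBM, with state sharded across the TP axis."""
    return params_b * bytes_per_param / tp

# Assumption: ~11 bytes/param of optimizer state lands on the
# reported 22 GB at TP=2 for a 4 B-parameter model.
state = optimizer_state_gb(params_b=4.0, bytes_per_param=11, tp=2)
usable = 31.25  # usable HBM per v6e chip, from the text
print(f"{state:.0f} GB state, {usable - state:.2f} GB left for live activations")
```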

Sharding topology - what actually worked

SPMD on v6e gives you several axes to play with: ("data",), ("data", "model"), ("data", "expert"), and ("data", "expert", "model"). The current scripts/base_train.py constructs whichever of these matches the requested TP and EP. The "right" topology depends entirely on whether your bottleneck is HBM, collective working set, or compute imbalance.
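To make the axis factorization concrete, here is a plain-numpy sketch of how a flat device list folds into a (data, expert, model) grid. This is illustrative only; the actual scripts/base_train.py goes through the torch_xla SPMD mesh machinery rather than raw numpy:

```python
import numpy as np

def build_mesh(num_devices: int, ep: int, tp: int) -> np.ndarray:
    """Arrange flat device ids into a (data, expert, model) grid."""
    dp = num_devices // (ep * tp)
    assert dp * ep * tp == num_devices, "axes must factor the device count"
    return np.arange(num_devices).reshape(dp, ep, tp)

# v6e-16 with EP=4 TP=2 leaves dp=2: the (2, 4, 2) shape from the post
mesh = build_mesh(16, ep=4, tp=2)
print(mesh.shape)  # (2, 4, 2)
```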

The sharding decisions that materially moved throughput on v6e:

FSDP over TP at depth 52. The single biggest jump on v6e-8 was enabling FSDP for the optimizer state. Without it, NAM52 only fit at TP=4 dp=2 dbs=1 (8,192 tokens per step, 27.8 K tok/sec, 8.6 % MFU). With it, NAM52 fit at TP=1 dp=8 dbs=2 (65,536 tokens per step, 48.9 K tok/sec). That is 8× the tokens per step, 1.76× the throughput, and a configuration where TP no longer steals the model dimension. The MFU number itself dropped to 3.8 % - because we are now amortising compile and collective overhead over a much larger batch - but wall-clock training time per token dropped by almost half. MFU is the wrong metric to optimise here; tokens per wall-clock second is the right one.

EP=4 over EP=8 on v6e-16. Expert parallelism is appealing because it shards the MoE bank across chips, but the cost is an all-to-all collective on every forward and backward pass. EP=8 on NAM52 produced a 22.11 GB compiled program graph that did not fit in 31.25 GB of HBM, failing with CompileTimeHbmOom before the first step. EP=4 with 16 experts per chip fit, using the 3D HybridMesh (2, 4, 2) for (data, expert, model). EP=2 with 32 experts per chip OOMed at runtime allocation (23.92 GB needed). The sweet spot was narrow.
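The sweep itself is simple arithmetic: experts per chip is the routed-expert count divided by EP, and only one value threaded the needle between the two OOM failure modes described above:

```python
ROUTED_EXPERTS = 64  # NAM52's routed-expert count (the 1 shared expert is replicated)

for ep in (8, 4, 2):
    per_chip = ROUTED_EXPERTS // ep
    # EP=8 ->  8 experts/chip: compile-time OOM (22.11 GB program graph)
    # EP=4 -> 16 experts/chip: fits under the (2, 4, 2) HybridMesh
    # EP=2 -> 32 experts/chip: runtime-allocation OOM (23.92 GB needed)
    print(f"EP={ep}: {per_chip} experts per chip")
```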

TP=2 stayed worth it on v6e-16. Going to TP=1 with EP=4 dp=4 on v6e-16 would have given more dp width but pushed every per-layer activation through dp instead of the cheaper TP all-reduce. We did not see a clean win on either side; TP=2 EP=4 dp=2 gave us the most stable baseline.

Scaling to v6e-32. Doubling chips from v6e-16 to v6e-32 took throughput from 13,500 tok/sec to 24,100 tok/sec - a 1.78× scaling factor for 2× the chips, or 89 % scaling efficiency. For a MoE with EP=4 spanning 32 chips, that is a respectable number. The dominant overhead at v6e-32 was the all-to-all working set rather than dp all-reduce; further scaling would likely require revisiting EP rather than just adding chips.
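Scaling efficiency as used here is the measured speedup divided by the chip-count ratio; a minimal sketch reproducing the 89 % figure:

```python
def scaling_efficiency(tok_sec_small: float, tok_sec_large: float,
                       chips_small: int, chips_large: int) -> float:
    """Measured speedup relative to ideal linear scaling."""
    speedup = tok_sec_large / tok_sec_small
    return speedup / (chips_large / chips_small)

# v6e-16 -> v6e-32: 13,500 -> 24,100 tok/sec on 2x the chips
eff = scaling_efficiency(13_500, 24_100, 16, 32)
print(f"{eff:.0%}")  # 89%
```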

The bottleneck stack, in order

Ranked list for NAM52-class models on v6e:

  1. Optimizer-state memory. Ceiling on per-chip batch. FSDP, bf16 Polar Express. Nothing else moves the budget shape.
  2. MoE all-to-all working set. Pick EP for the all-to-all program size, not just for expert sharding. EP=4 was almost always right.
  3. Compile-cache stability. A single recompile per N steps is a throughput tax invisible in averages but ugly in tail latency.
  4. TP collective volume. TP=2 vs TP=4 changes the all-reduce volume per layer. The cheaper option that still fits wins.
  5. Attention backend. The bounded XLA Pallas flash-attention path (xla_flash_pallas + softcap variant) is fast and stable; we never beat it on v6e for layouts inside its support region. The Splash variant via Pallas trace ran at 4-5 ms per forward at 512 seq, 4 layers, 1024 hidden in the smoke harness.
  6. Per-step Python overhead. Small once the cache is warm; before then, every .item() and Python if on tensor values shows up.

Bottleneck analysis

v6e does not give you nsys, but the libtpu profiler plus torch_xla.debug.metrics shows the shape of the problem. The pattern that recurred across every NAM52 v6e run: tensor-core utilisation in the low single digits, collectives 35-55 % depending on EP/TP split, optimizer step 25-40 %, fwd+bwd arithmetic ~10-15 %, the rest split between host, dataloader, and sync barriers.

The H200 lane on the same model showed the same shape with different absolute numbers - H200 nsys captures had 88 % of GPU time on elementwise ops and only 2.7 % on matmul. NAM52's 52 blocks generate millions of small elementwise ops (residual adds, gating, normalisation) that are the wrong workload for either H200 tensor cores or v6e MXU. The chip family is incidental; the model shape is the bottleneck.

MTP and the per-depth tax

MTP (Multi-Token Prediction) was higher cost on v6e than on H200. The depth-attribution sweep on v6e-8 showed MTP=1 at ~80 compilations (~72 min compile wall-clock), MTP=3 at ~196 compilations (~137 min).

The K=3 sweep produced the per-depth contribution we expected: depth 1 at 51.1 %, depth 2 at 30.5 %, depth 3 at 18.4 %, matching the H200 reference and the exponential decay weights (0.51, 0.31, 0.18). The quality story transferred cleanly: depth 1 captures roughly half the MTP benefit at well under half the compile cost.
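The quoted weights are consistent with a geometric decay of ratio roughly 0.6, normalized over the three depths. The 0.6 ratio is inferred from the numbers in the text, not stated explicitly, so treat it as an assumption:

```python
def mtp_depth_weights(k: int, decay: float = 0.6) -> list[float]:
    """Geometric per-depth loss weights, normalized to sum to 1.
    decay=0.6 is inferred from the reported (0.51, 0.31, 0.18)."""
    raw = [decay ** d for d in range(k)]
    total = sum(raw)
    return [w / total for w in raw]

weights = mtp_depth_weights(3)
print([round(w, 2) for w in weights])  # [0.51, 0.31, 0.18]
```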

The compile-cost story did not transfer. On H200 with a warm cache, MTP=3 is about 4 % slower than MTP=1 at steady state. On v6e the 2.4× compile blowup is the dominant cost difference, because each MTP head adds new attention and projection sub-graphs that XLA compiles separately. With a persistent disk cache this is a one-time cost that amortises; without one, it is paid every restart. Practical recommendation: MTP=1 in production on v6e, MTP=3 only in research configurations with a warm persistent cache.

Things that pretended to help and didn't

  • MaxText async-collective-fusion XLA flag bundles (--xla_tpu_enable_async_collective_fusion=true, --xla_tpu_enable_data_parallel_all_reduce_opt=true, --xla_tpu_use_minor_sharding_for_major_trivial_input=true) caused NaN from step 1 on every NAM12 and NAM52 configuration across five v6e-8 hosts, and drove subtle compile-cache key drift. Upstream throughput claims did not reproduce. We default to --xla_flag_profile=none and add narrow overrides only with documented reason.

  • Activation sharding (mark_sharding) on MLP intermediates and QKV post-projection. No measurable throughput change at our shapes, and a fresh surface for sharding bugs (we hit one with Muon producing NaN on zero-initialised RowwiseParallel c_proj weights through this path). Rolled back.

  • Ring-folding mesh layout for the physical (2, 4) v6e chip topology. Difference within noise, complicated mesh reasoning, not worth keeping.

  • reuse_batch micro-optimisation. Small win on H200, no-op on v6e, where the dataloader is not on the critical path.

  • barrier_every_n_layers=2 for throughput. Barriers are a memory tool, not a throughput tool: more compiled programs, more per-step round-trips, slightly less throughput. Use them for HBM headroom or not at all.

  • XLA_USE_BF16=1. Incompatible with current Pallas kernels; produces an extra set of fallback graphs and bloats the cache. Cast to bf16 explicitly in the model and leave the env var unset.

What v6e is and is not good for

v6e earns its keep on dense small-and-medium models where per-chip batch fits comfortably. NAM12 at TP=1 dp=8 dbs=8 hit 638 K tok/sec on v6e-8 at ~50 % MFU - a healthy number that competes well with comparably-priced GPU options. SPMD is stable, the persistent compile cache is fast once warm, and the bounded Pallas attention kernels are good.

v6e struggles on deep MoE models that cannot fit a meaningful per-chip batch. NAM52 4 B at depth 52 with our optimizer state landed in a topology where MFU is structurally bounded below 1 %. The chip is not the problem and the kernel is not the problem - the memory budget is. The clean answer is more chips (v6e-32 was a 1.78× scaling step from v6e-16) or a smaller model. The dirty answer is sharding tricks until you claw back HBM for a second sequence per chip.

v6e is not a drop-in replacement for an H200 cluster of similar nominal compute. H200's 141 GB of HBM per device changes the game for 4-10 B models. Do not train a 4 B+ MoE on v6e unless you are price-sensitive enough to accept the MFU gap, or you can shrink to NAM12-class shapes.

What we would do differently

  • Decide per-chip HBM budget before model architecture. At 31.25 GB usable, anything that pushes optimizer state past ~22 GB forces compute-starved topologies.
  • Enable FSDP from day 1, not from week 4.
  • Default to the bounded Pallas attention path; only reach for custom kernels when its support region does not cover the layout.
  • Pin a persistent compile cache on fast local disk before any benchmarking.
  • Keep --xla_flag_profile=none as the production default; treat any flag bundle as a per-config hypothesis.
  • Measure tokens per wall-clock second, not MFU, and only chase MFU after the topology is locked.

The 24,100 tok/sec NAM52 v6e-32 number is a real, reproducible production training rate, and a fraction of what the same chips could do for a model that actually fits the v6e budget shape. Match the model to the silicon before you fight the silicon to fit the model.

References

  • review_gcp_tpu.md
  • docs/CURRENT_STATE.md
  • docs/TPU_SETUP.md
  • docs/TENSOR_PARALLELISM.md
  • docs/BACKEND_STOPLIGHT_MATRIX.md
  • reports/dist_optimizer_stress_tpu_v6e_2026-03-22.md
  • reports/mtp_depth_attribution_tpu_v6e_2026-03-22.md
  • reports/tpu_backend_provenance_v6e8_2026-03-16.json
  • reports/fa4_fsdp2_scaling_2026-03-22.json
  • CHANGELOG.md
  • training_review.md
  • speed_rep_xx.md
David Gornshtein • Datasunrise OÜ