tpu
v6e
xla
hbm
oom
sharding
training

OOM Hunting on TPU v6e: HBM Fragmentation, the XLA Allocator, and What Actually Moved Memory

9 min read · David Gornshtein

TPU v6e gives you 32 GB of HBM per chip, of which roughly 31.25 GB is addressable from XLA after the runtime takes its cut. That number sets the frame for everything in this post. We spent several weeks of the nanochat POC trying to fit a 4.04 B-parameter NAM52 MoE model onto v6e-8 and v6e-16 pods, and the punchline is unflattering: on v6e the optimizer is the OOM, not the model.

This is what we learned along the way - what HBM fragmentation actually looks like under the XLA allocator, which activation-checkpointing reshapes moved memory and which only moved compile time, and the handful of knobs that gave us back enough headroom to train.

The shape of the problem

The model under load was NAM52: 52 blocks, fused QKV, GQA with 8 KV heads, a 64-routed-expert + 1-shared MoE bank, plus our usual auxiliary surfaces (Mamba, Engram, mHC, n-gram hash). Total parameter count: 4,042,836,277. In bf16 that is ~8 GB of weights. With Muon on the matrix params and AdamW on the embeddings, optimizer state lands somewhere between 16 and 22 GB depending on what you fold into Polar Express. On a v6e chip with 31.25 GB usable HBM, the math has no slack.
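The arithmetic is worth making explicit. A back-of-the-envelope sketch, in decimal GB to match the numbers quoted above; the fp32-momentum line treats every parameter as Muon-managed, which slightly overstates the real Muon share (AdamW only covers the embeddings), so read it as a bound:

```python
# Back-of-the-envelope HBM budget for the NAM52 numbers quoted above.
PARAMS = 4_042_836_277   # total NAM52 parameter count
GB = 1e9                 # decimal gigabytes, matching the text

weights_gb = PARAMS * 2 / GB        # bf16 weights: ~8.1 GB
momentum_fp32_gb = PARAMS * 4 / GB  # one fp32 momentum slot per param: ~16.2 GB (bound)

usable_hbm_gb = 31.25
# Worst case from the quoted 16-22 GB optimizer-state range:
headroom_gb = usable_hbm_gb - weights_gb - 22.0   # ~1.2 GB left for everything else
```

With the optimizer range at its top end, weights plus optimizer state leave barely a gigabyte before a single activation is allocated, which is the "no slack" claim in numbers.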

The first OOMs we saw were the obvious ones: bump batch size, watch HBM overflow. The interesting OOMs came later, after we had tuned everything the textbook tells you to tune and the allocator was still failing on graphs that, on paper, fit.

The XLA allocator is not your friend

On v6e, the XLA allocator does not behave like a CUDA caching allocator with a nice nvidia-smi-shaped picture of the world. There are two separate failure modes that both surface as "OOM" and need very different responses.

The first is CompileTimeHbmOom. This fires before a single tensor lands on device. XLA has finished tracing your graph, computed the program allocation, and decided that the worst-case live set across the schedule exceeds the per-chip limit. Typical message: compiled program 22.11G on 31.25G HBM. The interesting bit is that this is computed against the scheduled peak, not the steady-state working set, so a graph that would actually run fine if the scheduler made a different choice still gets rejected at compile time.

The second is RuntimeOom / RuntimeAllocationFailure. This is the allocator failing to satisfy a request after the graph is already running. On v6e this is almost always the optimizer graph allocating its temporaries on top of an already-resident parameter and momentum set, rather than the forward-backward graph asking for activations. The characteristic signature: forward+backward fits, the first optimizer step OOMs, free HBM at the moment of failure is around 19-22 GB, requested allocation is around 21-23 GB.

Both failures will sometimes also show up with cryptic MSA (Memory Space Assignment) errors in libtpu, which is what XLA reports when its scheduler-aware allocator gives up. We saw this most often when MoE all-to-all temporaries grew past about 6 GB per chip; libtpu's MSA pass would refuse to place them and the run died inside jellyfish_msa. The only reliable fix was to shrink the all-to-all working set, which usually meant smaller EP or smaller per-expert capacity, not a bigger box.
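Because the three failure modes need different responses, it helps to triage them mechanically. The sketch below is hypothetical: the substrings are paraphrased from the errors described above, not guaranteed libtpu message formats, so adapt them to your own logs.

```python
def triage_tpu_oom(message: str) -> str:
    """Rough triage of an XLA OOM message into the three failure modes
    described above. Substrings are paraphrased, not exact libtpu output."""
    if "CompileTimeHbmOom" in message or "compiled program" in message:
        # Scheduled-peak rejection before anything runs: shrink the
        # per-program live set (barriers, checkpointing) or shard more.
        return "compile-time: reduce scheduled peak"
    if "RuntimeOom" in message or "RuntimeAllocationFailure" in message:
        # Allocation on top of already-resident state: on v6e this is
        # almost always the optimizer step, not activations.
        return "runtime: check the optimizer graph first"
    if "jellyfish_msa" in message or "Memory Space Assignment" in message:
        # MSA gave up placing a large collective temporary: shrink the
        # all-to-all working set (smaller EP or expert capacity).
        return "msa: shrink collective temporaries"
    return "unknown"
```

The point of the third branch is the one from the MoE story above: a bigger box does not fix an MSA failure, a smaller all-to-all working set does.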

The optimizer is the OOM

The single most important insight from this work, and the one that took us longest to internalise: on v6e, the compiled optimizer graph size is approximately constant in batch and sequence. Forward and backward activations scale with dbs * seq_len. The optimizer step does not. Muon with Polar Express keeps a stacked momentum buffer for every matrix parameter; AdamW keeps m and v for every embedding parameter; both sets need a fresh temporary at step time for the Newton-Schulz iterations and the bf16/fp32 cast traffic.

On NAM52 that compiled optimizer graph lands around 20-22 GB on a v6e chip after TP=2 sharding. That leaves about 9-11 GB for the entire forward+backward live set, which means any plausible production batch pushes you over the limit.
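The batch-invariance claim can be stated as a two-line budget model. This is a sketch, not profiler output: gb_per_token is a measured, model-specific constant, and the function names are illustrative.

```python
def activation_budget_gb(optimizer_graph_gb: float,
                         usable_hbm_gb: float = 31.25) -> float:
    """HBM left for the fwd+bwd live set once the batch-invariant
    compiled optimizer graph is resident."""
    return usable_hbm_gb - optimizer_graph_gb

def activation_live_set_gb(dbs: int, seq_len: int,
                           gb_per_token: float) -> float:
    """Activations scale with dbs * seq_len; gb_per_token is a
    model-specific constant you measure, not derive (hypothetical here)."""
    return dbs * seq_len * gb_per_token
```

At a 22 GB optimizer graph the budget function returns 9.25 GB, and the second function says doubling dbs or seq_len doubles what has to fit inside it, which is why every plausible production batch overflowed.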

The diagnostic that sold us on this was a sweep of OOM modes at the SyncTensorsGraph.2980 boundary, all of which failed inside the optimizer step:

dbs=4 seq=4096 TP=2 EP=4: CompileTimeHbmOom 35.45G > 31.25G
dbs=2 seq=4096 TP=2 EP=4: RuntimeOom 22.31G > 19.70G free
dbs=1 seq=8192 TP=2 EP=4: RuntimeOom 21.40G > 19.69G free

The compiled optimizer graph is the floor. Everything else is what little space is left above it.

There were three things that actually moved this number on v6e:

  1. Keeping Polar Express in bf16 instead of fp32. Doing PE in fp32 doubled the stacked all-reduce buffer (360 MB → 720 MB per param group) and pushed the compiled optimizer graph straight into CompileTimeHbmOom. bf16 PE was a hard requirement for any NAM52 fit on v6e-16.
  2. Sharding the optimizer state via FSDP. This is obvious in retrospect. It was not obvious at the time because the --fsdp flag had been sitting unused for months: enabling it moved a NAM52 v6e-8 run from "no fit at TP=4 dp=2 dbs=1" to "fit at TP=1 dp=8 dbs=2, 65 K tokens per step, 48.9 K tok/sec".
  3. EP placement. EP=4 TP=2 with 16 experts per chip fit; EP=2 TP=2 with 32 experts per chip OOMed at 23.92 G needed; EP=8 TP=2 OOMed at the compile stage because the all-to-all program graph itself blew past 22 G. EP is not a memory-free knob - it trades parameter sharding for collective working sets, and on v6e the collective working sets win the budget fight more often than you would expect.
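The FSDP win in item 2 is pure division. A minimal sketch of the budget effect only, under the ZeRO/FSDP-style assumption that each data-parallel rank holds 1/dp of the optimizer state; a real FSDP run also shards gradients and gathers parameters, which this ignores:

```python
def per_chip_optimizer_gb(total_opt_state_gb: float, dp: int) -> float:
    """Optimizer-state sharding across dp data-parallel ranks:
    each chip holds 1/dp of the state (budget effect only)."""
    return total_opt_state_gb / dp

# The v6e-8 flip described above, in these terms (illustrative numbers):
unsharded = per_chip_optimizer_gb(22.0, dp=1)   # 22.0 GB per chip: no fit
sharded   = per_chip_optimizer_gb(22.0, dp=8)   # 2.75 GB per chip: fits
```

Nothing else on the list moves the optimizer floor by an order of magnitude; this does, which is why it goes first in the checklist at the end.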

Activation checkpointing: what reshaped, what didn't

We tried five distinct activation-checkpointing strategies on the forward+backward graph. Two paid for themselves on v6e, one was neutral, one was marginal, and one was a quiet regression once you account for compile cost.

xla_barrier_every_n_layers=N paid off. Inserting xm.mark_step() barriers inside the model splits fwd+bwd into smaller XLA programs whose per-program live sets are bounded by what crosses the barrier, not by what the global scheduler thinks could be live at peak. On NAM52 with barrier=2 we cut peak forward HBM by roughly 40 %, at the cost of disabling fused fwd+bwd compilation and accepting a per-barrier mark_step round-trip. This was the difference between fitting and not fitting at TP=4. It does not help the optimizer graph at all, which is what limits the eventual win.
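The control flow behind xla_barrier_every_n_layers is simple enough to sketch. In the real run the barrier is torch_xla's xm.mark_step(); here it is injectable so the loop can be checked without a TPU, and the function name is illustrative:

```python
def run_with_barriers(blocks, x, barrier_fn, every_n=2):
    """Sketch of the barrier-splitting idea: after every `every_n`
    blocks, call a barrier that cuts the traced graph, so each
    compiled program's live set is bounded by what crosses the cut
    rather than by the global scheduled peak."""
    for i, block in enumerate(blocks, start=1):
        x = block(x)
        if i % every_n == 0:
            barrier_fn()   # graph cut: materialize, free dead buffers
    return x
```

With 52 blocks and every_n=2 you get 26 cuts and 26 smaller programs, which is where the roughly 40 % peak reduction on the forward came from; the price is losing the fused fwd+bwd program.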

Selective Mamba conv+BC recompute paid off. The Mamba state-space blocks were carrying surprisingly heavy intermediates (the per-head projection of the conv state plus the BC delta). Switching the 13 MBlocks to recompute their conv and BC tensors during backward saved roughly 6 GB per chip. Cheap to implement, no compile cost.

viewless_output + recomputed norms were neutral on v6e. They saved about 7 GB on H200 and almost nothing here, because XLA's scheduler already aliases the output of an RMSNorm with its input under the same buffer in most cases. We left it on for cross-platform consistency, but do not credit it on the v6e budget.

Per-block torch.utils.checkpoint on the AEME sequence was a regression. The XLA tracer split the recomputed forward into its own sub-program with its own activation working set, and the scheduler did not always free the original activations before the recompute landed - net effect was more peak HBM, not less, and a meaningful jump in compile time. This is one of the places where CUDA intuition leads you wrong on TPU.

MoE expert recompute was complicated. On the H200 lane this saved ~44 GB across 22 EBlocks. On v6e it saved roughly 3-4 GB and added a second compile of the expert dispatch program. We kept it on because the absolute saving was still positive, but the H200 numbers do not transfer.

Fragmentation, and what to do about it

The XLA allocator on v6e is in practice a buddy-style allocator with coarse granularity for program-allocated buffers and a finer arena for temporaries. We saw two practical fragmentation patterns.

The first is capture-time pinning: any tensor that survives across a mark_step boundary gets pinned into the program-allocated arena and stays there for the lifetime of the compiled program. If you build up small auxiliary tensors during early steps (capacity counters, MoE overflow accumulators, debug stats), those get pinned too. In one case, a Python int counter inside a MoE block was incremented every forward; the dynamo trace therefore baked a fresh constant into the graph each step, and the resulting churn fragmented the optimizer arena enough to OOM by step ~50. The fix was the Megatron-Core pattern: hold counters in register_buffer tensors, not Python ints.
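The register_buffer pattern, sketched from the description above; the class and buffer names are illustrative, not the actual NAM52 code:

```python
import torch
import torch.nn as nn

class MoEOverflowStats(nn.Module):
    """Megatron-Core-style counter holding: keep per-step counters as
    registered buffers so the traced graph sees one stable tensor
    rather than a fresh Python constant every step.
    (Names are illustrative, not the actual project code.)"""
    def __init__(self):
        super().__init__()
        # Lives on device and survives mark_step without retracing.
        self.register_buffer("overflow_count",
                             torch.zeros((), dtype=torch.long))

    def record_overflow(self, n_dropped: torch.Tensor):
        # In-place update on the buffer; no Python-int round trip
        # through the tracer.
        self.overflow_count += n_dropped.to(torch.long)
```

The same holds for any running statistic you want visible across steps: if it lives in a buffer, the trace is stable; if it lives in a Python scalar, every step is a slightly different graph.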

The second is compile-cache contention. With XLA_COMPILATION_CACHE_DIR pointed at a local SSD, repeated startup of the same training script with slightly different shapes accumulates compile artifacts that the allocator sometimes pre-faults during program load. We did not chase this to ground; the workaround was to keep the cache on /data/.xla_cache and clear it between architecture changes.
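The workaround, as a config fragment. XLA_COMPILATION_CACHE_DIR is the variable named in this post; torch_xla versions differ in which cache env var they honor, so check yours:

```shell
# Keep the XLA compilation cache on persistent storage, and clear it
# whenever the architecture changes, per the workaround above.
export XLA_COMPILATION_CACHE_DIR=/data/.xla_cache
mkdir -p "$XLA_COMPILATION_CACHE_DIR"

# After an architecture change:
rm -rf "${XLA_COMPILATION_CACHE_DIR:?}"/*
```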

The XLA flag bundles, briefly

We will not relitigate --xla_flag_profile here in detail because it deserves its own post, but two things matter for OOM specifically:

  • The MaxText-derived async-collective-fusion bundle (--xla_tpu_enable_async_collective_fusion=true and friends) was a NaN source on every NAM12 and NAM52 config we tried. It also tended to inflate compile-time peak HBM estimates by a few GB. We turned the whole bundle off and accepted the throughput cost.
  • The offload profile is the only profile that meaningfully reduces steady-state HBM, by spilling optimizer state to host memory across the PCIe bus. On v6e, with optimizer state pinned at 22 GB, this is the only knob that actually changes the budget shape, but the host-offload scheduler trades that headroom for ~30 % of step time. Use it when you cannot fit; do not use it as a default.
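The decision rule implied by that trade-off fits in one comparison. A sketch using the numbers quoted in this post; the ~30 % penalty is workload-dependent and the function name is illustrative:

```python
def should_offload(optimizer_gb: float, fwd_bwd_gb: float,
                   usable_hbm_gb: float = 31.25) -> bool:
    """Host-offload of optimizer state costs ~30% of step time on v6e,
    so reach for it only when the on-device budget does not close."""
    return optimizer_gb + fwd_bwd_gb > usable_hbm_gb
```

In other words: offload is a fit tool, not a throughput tool, and if the sum already fits you are strictly better off without it.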

What actually moved memory

If you are landing here because your own v6e job is OOMing and you want the short list:

  • Measure the optimizer graph in isolation. If it is more than ~22 GB on a v6e chip, no amount of activation-checkpointing will save you.
  • Force Polar Express to bf16 unless you have a documented numerical reason for fp32, and do not let --matrix_lr schedules silently flip it back.
  • Turn on --fsdp before you turn on anything else. Optimizer-state sharding is the only thing that meaningfully changes the v6e budget shape for models above ~2 B parameters.
  • Pick EP to fit the all-to-all program, not just the expert bank. EP=4 was the right answer for NAM52 on v6e-16; EP=8 compiled larger graphs than the ones it sharded.
  • Use --xla_barrier_every_n_layers to bound per-program activation peaks. Accept the per-barrier round-trip; it is worth the headroom.
  • Keep dynamic Python state out of compiled blocks. Counters belong in register_buffer. MoE overflow accumulators belong in tensors.
  • Trust the per-platform numbers. A reshape that saves 44 GB on H200 may save 3 GB on v6e, and a checkpointing strategy that helps on CUDA may hurt on XLA.

The headline win after all of this was a NAM52 4.04 B-parameter MoE training step running stably on v6e-16 at EP=4 TP=2 dp=2, dbs=1, seq=4096, with bf16 PE and barrier-split fwd+bwd. The HBM budget had about 1.5 GB of headroom. Most of the engineering work above was spent earning that 1.5 GB.

David Gornshtein • Datasunrise OÜ