Graph Recompilation Hell on torch_xla: Shape Polymorphism, Dynamic Fallbacks, and Stabilizing the Cache

10 min read • David Gornshtein
The first time we ran a NAM52 MoE step on TPU v6e-8 it took eighteen minutes. The second step took 1.5 seconds. The third step took 1.5 seconds. The fortieth step took 73 seconds, then 1.5 again. Step 196 took another minute. Step 312 took two minutes. By step 1000 we had spent more wall time inside XlaCompile than inside the actual training loop.

This post is about why that happened, what we did about it, and where the limits of torch_xla shape polymorphism actually are in 2026.

The model that triggered it

The numbers below come from the nanochat POC: a 4.04B-parameter NAM52 stack with 52 blocks in an AEME pattern (Attention / Expert / Mamba / Engram), GQA with 8 KV heads, fused QKV, MoE with 64 routed experts plus 1 shared expert, optional Multi-Token Prediction (MTP) heads at depths 1 and 3, and DSA (sparse attention) on a subset of layers. The stack on TPU was Python 3.13, custom torch 2.9.0a0+git21fec65, custom torch_xla 2.9.0+gitc04e61c, libtpu 0.0.36, jax 0.9.0. SPMD was on, PJRT_DEVICE=TPU was set, and torch.compile on the model was explicitly off (it is unsafe on XLA in this stack, so we leave the JIT to XLA).

The reference numbers are concrete. An MTP attribution sweep on v6e-8 showed the mtp_depth=1 lane needed about 80 XLA compilations and 72 minutes of compile wall-clock to settle a 20-step run; the mtp_depth=3 lane needed 196 compilations and 137 minutes. Same model code, same shapes, only the MTP-depth-conditional branches changing. That is the cost we were paying for unstable graph keys.

What torch_xla actually keys on

The TL;DR of torch_xla graph caching: a compiled program is keyed by the traced graph plus the shapes and dtypes of every input tensor. It is not keyed by their values. So far, so good. The trap is what counts as "the traced graph" when you are using torch_xla.compile() around a Python-level fwd+bwd in 2026.

Three things end up baked into that key in ways that are easy to miss:

  1. Any Python int, float, or bool that flows into a tensor op as a constant becomes part of the HLO. Change the value, get a new graph. This is the same trap as Dynamo on CUDA, but torch_xla surfaces it more silently because there is no recompile log line by default.
  2. Any tensor.item() or .cpu() call inside the traced region cuts the trace at that point and forces a host-device sync. The ops after it start a new graph fragment with whatever shapes the downstream tensors take. Two such syncs per step split the step into separate graph fragments, each of which needs its own cache entry.
  3. Any non-static tensor shape - typically the result of a bool-mask gather, a top-k whose k is data-dependent, or a varlen attention layout - either forces dynamic-shape fallback (slow on v6e) or triggers a per-shape recompile.
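
To make the keying rule concrete, here is a toy model of a compile cache in plain Python. It is only an illustration under a simplifying assumption: the real runtime hashes the lowered graph, not a Python tuple, and ToyCompileCache and scaled_add_ops are hypothetical names invented for this sketch. It shows that input values never reach the key, while a baked-in Python constant or a changed shape does.

```python
from typing import Dict, Tuple

Op = Tuple[str, ...]
InputSpec = Tuple[Tuple[int, ...], str]  # (shape, dtype name)


class ToyCompileCache:
    """Toy stand-in: key = (traced ops incl. literals, input shapes/dtypes)."""

    def __init__(self) -> None:
        self._programs: Dict[tuple, str] = {}
        self.compiles = 0

    def lookup(self, ops: Tuple[Op, ...], inputs: Tuple[InputSpec, ...]) -> str:
        key = (ops, inputs)
        if key not in self._programs:
            # Cache miss: this is where a real runtime would pay compile time.
            self.compiles += 1
            self._programs[key] = f"program_{self.compiles}"
        return self._programs[key]


def scaled_add_ops(scale: float) -> Tuple[Op, ...]:
    # The Python float ends up inside the op list as a literal.
    return (("mul_const", repr(scale)), ("add",))


cache = ToyCompileCache()
spec = (((8, 1024), "bf16"),)
cache.lookup(scaled_add_ops(0.5), spec)    # miss: first compile
cache.lookup(scaled_add_ops(0.5), spec)    # hit: tensor values are not in the key
cache.lookup(scaled_add_ops(0.25), spec)   # miss: the baked-in constant changed
cache.lookup(scaled_add_ops(0.5), (((8, 2048), "bf16"),))  # miss: shape changed
print(cache.compiles)
```

Four lookups produce three compiles here; the only free one is the call that repeats both the constant and the input shape.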

On a model the size of NAM52, every one of these traps was hit by at least one block. Finding them was a multi-week archaeology project.

The recompilation patterns we hit

Python counters in compiled blocks

The most spectacular case was the MoE overflow accumulator. The MoE block was tracking how many tokens had been dropped due to capacity overflow, and it stored the running total as a Python int attribute on the module:

self._overflow_total += int(dropped.sum().item())

Every forward call mutated self._overflow_total. Dynamo / torch_xla treated each new value as a fresh constant in the trace. On 8-GPU DDP we saw up to 64 recompiles per training step before the cache stabilised. On TPU the same bug showed up as compile-storm hangs and, eventually, NaNs.

The fix is the Megatron-Core pattern: hold counters in register_buffer tensors, and do the increment as an in-place tensor op rather than a Python increment. Do this for every counter that gets touched in a compiled region. We had to chase three of them across MoE, the GPT block, and our sparse-attention path before the storm stopped.
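
A minimal sketch of the buffer-based counter, assuming the dropped-token indicator is a boolean mask. OverflowTracker and record are hypothetical names for illustration, not the module from our codebase; the sketch runs on CPU and the same pattern applies on an XLA device.

```python
import torch
from torch import nn


class OverflowTracker(nn.Module):
    """Hypothetical stand-in for the MoE block's overflow bookkeeping."""

    def __init__(self) -> None:
        super().__init__()
        # A 0-dim int64 buffer follows the module across devices and is a
        # tensor from the tracer's point of view, not a Python constant.
        self.register_buffer("overflow_total", torch.zeros((), dtype=torch.int64))

    @torch.no_grad()
    def record(self, dropped: torch.Tensor) -> None:
        # In-place tensor add: no .item(), no host sync, no new literal.
        self.overflow_total.add_(dropped.sum())


tracker = OverflowTracker()
tracker.record(torch.tensor([True, False, True]))
tracker.record(torch.tensor([False, False, True]))
# Read the value back only outside the compiled region, e.g. at logging time.
print(int(tracker.overflow_total))
```

The one remaining host read is the int(...) at logging time, which should live outside anything wrapped by torch_xla.compile().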

Grad-create vs grad-accumulate as two graphs

torch_xla.compile() traces what it sees. On step 0 of training with gradient accumulation, the autograd state on rank 0 has no existing gradient tensors, so the first micro-step traces a grad-create graph (acc.grad = new_grad). On every subsequent micro-step the gradients already exist, so the trace produces a grad-accumulate graph (acc.grad.add_(new_grad)). These are two distinct HLO programs, both needed.

If you wrap each micro-step in its own torch_xla.compile(), the cache gets both variants the first time around and you pay two compile bills at startup. If you also call optimizer.zero_grad(set_to_none=True) between accumulation cycles, you flip the optimizer back into the grad-create shape on every cycle, which causes the optimizer step itself to alternate between two compiled programs forever. The fix is to put a single torch_xla.compile() around the whole accumulation chain (so create and accumulate land in one HLO) and to avoid set_to_none=True in the inner loop. We documented this in the training script as a hard requirement.
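
The shape of that loop can be sketched as follows. This is an illustration, not our training script: make_accumulation_cycle is a hypothetical helper, the model and data are toys, and the sketch falls back to an identity wrapper when torch_xla is not installed so that it also runs on CPU.

```python
import torch
from torch import nn

try:
    import torch_xla
    compile_region = torch_xla.compile  # wraps a callable in one traced region
except ImportError:
    def compile_region(fn):
        # CPU fallback for this sketch only: no tracing, just call through.
        return fn


def make_accumulation_cycle(model: nn.Module, loss_fn, optimizer: torch.optim.Optimizer):
    def cycle(micro_batches):
        # All micro-steps, the optimizer step, and the zeroing share one region.
        for inputs, targets in micro_batches:
            loss = loss_fn(model(inputs), targets) / len(micro_batches)
            loss.backward()
        optimizer.step()
        # Keep the grad tensors alive so the next cycle accumulates into them.
        optimizer.zero_grad(set_to_none=False)

    return compile_region(cycle)


model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
cycle = make_accumulation_cycle(model, nn.functional.mse_loss, optimizer)
batches = [(torch.randn(3, 4), torch.randn(3, 2)) for _ in range(2)]
cycle(batches)
```

After the first cycle every parameter still has a (zeroed) .grad tensor, which is the property that keeps later cycles on the accumulate path.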

Enriched-batch metadata as accidental polymorphism

Our dataloader produces "enriched" batches that carry optional per-document metadata: doc-id arrays, FIM/IFIM boundary tensors, per-token validity counts, occasional auxiliary masks. Some batches have all of those keys; some are missing one or two; some have empty tensors.

The first version of the model entry path branched on a "key in batch" check for each optional field. That meant the Python control flow taken inside the traced region depended on which optional keys were present, which meant every distinct subset of present keys traced its own HLO. Across a real shuffle that produced a long tail of one-shot graphs.

The fix was to canonicalise the structure: a single helper (_canonicalize_structure_meta_for_xla) that pads or zeroes every optional field to a fixed shape and dtype, so the traced graph is structurally identical regardless of which keys the dataloader emitted. The downstream model then takes the canonical batch and uses tensor masks (not Python ifs) to decide what is active. That alone removed about a dozen recompiles per epoch on the data we were running.
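
A minimal sketch of that idea, assuming a small fixed schema of optional fields. The schema, field sizes, and the canonicalize_batch name are hypothetical and are not the contents of _canonicalize_structure_meta_for_xla. The helper runs on the host before the traced region, so its Python branches never reach the graph.

```python
from typing import Dict, Mapping, Tuple

import torch

# Hypothetical schema: field name -> (per-example shape, dtype).
OPTIONAL_FIELDS: Dict[str, Tuple[Tuple[int, ...], torch.dtype]] = {
    "doc_ids": ((16,), torch.int32),
    "fim_boundaries": ((4,), torch.int32),
}


def canonicalize_batch(batch: Mapping[str, torch.Tensor], batch_size: int) -> Dict[str, torch.Tensor]:
    out = dict(batch)
    for name, (shape, dtype) in OPTIONAL_FIELDS.items():
        value = batch.get(name)
        present = value is not None and value.numel() > 0
        canon = torch.zeros((batch_size, *shape), dtype=dtype)
        if present:
            # Copy into the leading corner; assumes the value fits the schema.
            corner = tuple(slice(0, n) for n in value.shape)
            canon[corner] = value.to(dtype)
        out[name] = canon
        # Presence travels as a tensor so the model can mask instead of branch.
        out[name + "_present"] = torch.tensor(present)
    return out


full = canonicalize_batch(
    {"tokens": torch.zeros(2, 16, dtype=torch.int64),
     "doc_ids": torch.ones(2, 10, dtype=torch.int32)},
    batch_size=2,
)
sparse = canonicalize_batch({"tokens": torch.zeros(2, 16, dtype=torch.int64)}, batch_size=2)
print(sorted(full) == sorted(sparse))
```

Both batches come out with the same keys, shapes, and dtypes, so they trace to the same graph; only the mask values differ.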

.item() inside fast paths

Two tiny pieces of code were each costing about ten host-device syncs per step:

  • flash_attention.py's _doc_ids_are_uniform() was doing (doc_ids == doc_ids[..., :1]).all().item() to decide whether the GQA expansion path could be skipped. On XLA we replaced it with a branch-free (doc_ids - doc_ids[..., :1]).any() whose result stays on device (note that it is the negation of the original check: true when the ids are not uniform), plus a static fast path keyed on a launch-time flag. We also gated the dynamic check on device.type == "xla" so it never executes on TPU.
  • mamba2.py's boundary check was calling .any().item() on a per-step boundary mask. Same pattern: replace with a tensor-only path on XLA.

Each of these was one line. Together they cut roughly twenty syncs and the matching mid-step graph fragments per step.
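
The tensor-only pattern looks roughly like this. It is a sketch, not the code from flash_attention.py or mamba2.py: doc_ids_nonuniform and pick_on_device are hypothetical helpers, and the "compute both candidates, select with torch.where" step is one common way to avoid a Python branch, assumed here for illustration.

```python
import torch


def doc_ids_nonuniform(doc_ids: torch.Tensor) -> torch.Tensor:
    # 0-dim bool tensor; True when any row mixes document ids. No .item().
    return (doc_ids - doc_ids[..., :1]).any()


def pick_on_device(flag: torch.Tensor, if_true: torch.Tensor, if_false: torch.Tensor) -> torch.Tensor:
    # Both candidates are already computed; the choice happens on device,
    # so the traced graph is identical whichever way the flag goes.
    return torch.where(flag, if_true, if_false)


uniform = torch.tensor([[3, 3, 3], [7, 7, 7]])
mixed = torch.tensor([[3, 3, 4], [7, 7, 7]])
masked_scores = torch.ones(2, 3)
plain_scores = torch.zeros(2, 3)

out_uniform = pick_on_device(doc_ids_nonuniform(uniform), masked_scores, plain_scores)
out_mixed = pick_on_device(doc_ids_nonuniform(mixed), masked_scores, plain_scores)
```

The trade-off is that both candidates are always computed, which is usually cheaper than a mid-step sync plus an extra graph fragment.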

MoE capacity that wasn't quite static

MoE capacity is min(max_tokens_per_expert_init, BT), where BT is batch * tokens. We had a code path that recomputed BT from the batch's actual token count rather than the configured one. For a fixed dataloader the two normally agree, but whenever the collator padded slightly differently the recomputed BT came out as a different int. That int flowed into the trace as a constant, so each unique BT value got its own compiled MoE dispatch program. The fix was to clamp at the configured BT and never let the collator's padding leak into the compiled graph (test_xla_capacity_clamped_to_bt covers it).
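
In code the distinction is small. The sketch below assumes capacity is computed once from configuration; static_moe_capacity and its parameter names other than max_tokens_per_expert_init are hypothetical.

```python
def static_moe_capacity(max_tokens_per_expert_init: int, batch_size: int, seq_len: int) -> int:
    # BT comes from the configured batch geometry, never from the tensor
    # the collator happened to produce for this step.
    configured_bt = batch_size * seq_len
    return min(max_tokens_per_expert_init, configured_bt)


# Computed once at config time and reused for every step.
capacity = static_moe_capacity(max_tokens_per_expert_init=4096, batch_size=2, seq_len=1024)

# The per-step token counts may wobble with padding; the capacity does not.
observed_token_counts = [2048, 2040, 2048, 2032]
per_step_capacity = [capacity for _ in observed_token_counts]
print(capacity, len(set(per_step_capacity)))
```

With these example numbers BT is 2048, so the capacity is 2048 for every step and only one distinct value ever reaches the trace.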

Where shape polymorphism actually ends

torch_xla does support a degree of dynamic-shape compilation, but it is not the load-bearing surface most people assume. In the 2.9 stack we were on, shape-polymorphic compilation works for:

  • broadcastable elementwise ops with one polymorphic dimension
  • contractions whose contracted dimension is static
  • gather/scatter ops where index counts are static and only the input payload size is dynamic

It does not work, in the sense of producing a single cached HLO program, for:

  • topk with a non-constant k
  • mask-driven gathers whose output count varies per call
  • attention layouts where valid_token_counts differs per batch
  • MoE dispatch where capacity is allowed to vary

When you ask for any of those, torch_xla falls back to per-shape recompilation, and on a model with the structural variability of NAM52 that fallback path is cost-prohibitive. The pragmatic answer is to make everything static: pad to a fixed length, clamp k, fix the per-step capacity. Pay the constant-overhead cost. Skip the long tail of compiles.
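
Here is a small sketch of "make everything static" for two of the cases above: padding a variable-length sequence to a fixed length, and running topk with a constant k while masking out the slots that should not count. The helper names pad_to_fixed and topk_static are hypothetical, and the padding step is assumed to run on the host in the dataloader.

```python
import torch


def pad_to_fixed(seq: torch.Tensor, fixed_len: int, pad_value: int = 0):
    # Host-side padding: the Python check is fine here because this runs
    # in the dataloader, before anything is traced.
    n = seq.shape[0]
    if n > fixed_len:
        raise ValueError(f"sequence of length {n} exceeds fixed_len={fixed_len}")
    padded = torch.full((fixed_len,), pad_value, dtype=seq.dtype)
    padded[:n] = seq
    valid = torch.arange(fixed_len) < n
    return padded, valid


def topk_static(scores: torch.Tensor, valid: torch.Tensor, k_max: int):
    # Always ask for k_max results; invalid slots are pushed to -inf and
    # reported through a keep mask instead of shrinking the output.
    masked = scores.masked_fill(~valid, float("-inf"))
    values, indices = torch.topk(masked, k_max, dim=-1)
    keep = torch.isfinite(values)
    return values, indices, keep


tokens, valid = pad_to_fixed(torch.tensor([11, 12, 13]), fixed_len=4)
scores = torch.tensor([0.1, 0.9, 0.5, 0.7])
values, indices, keep = topk_static(scores, valid, k_max=4)
```

The outputs always have length k_max, so the graph is the same for every batch; downstream code multiplies by keep rather than slicing to a data-dependent length.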

The other useful surface is the bounded XLA flash-attention dispatcher (xla_flash_pallas / xla_splash_via_trace_pallas), which is itself a statically-shaped Pallas kernel. The provenance receipt we have on v6e-8 shows it dispatching cleanly with and without softcap, with no fallback path. If your attention layout fits inside that bounded contract you keep the cache stable for free; if it does not, you should expect varlen recompiles even from the "fast" backend.

What stabilised the cache

The combination of fixes that finally produced a clean cache after the first 5-10 steps:

  • Persistent disk cache via XLA_COMPILATION_CACHE_DIR pointing at a local SSD path (/data/.xla_cache when available). Repeated startups of the same configuration drop from ~18 minutes to ~90 seconds. Do not point this at GCS - the early-init path in base_train.py does not safely handle remote URIs.
  • All counters and accumulators in register_buffer, none in Python attributes that flow into compiled regions.
  • A single torch_xla.compile() per accumulation cycle (the full chain of micro-steps), not per micro-step.
  • set_to_none=False for grad zeroing inside accumulation loops.
  • Canonicalised batch structure: optional keys are always present, always the same shape and dtype, masking decides what is live.
  • Static MoE capacity, clamped at config time, never touched by the collator.
  • --xla_barrier_every_n_layers=N for memory headroom only after the cache is clean - barriers split the trace into more programs, so they multiply the cost of any remaining cache instability. Do not use barriers as a debugging tool for recompilation.
  • XLA_USE_BF16=0 (we cast the model to bf16 ourselves; the env-driven cast is incompatible with current Pallas kernels and causes an extra set of fallback graphs).

After all of this, the MTP=1 sweep cache settled in roughly 5 steps with ~15 cached programs and stayed flat for the rest of the run. The MTP=3 sweep still needed ~30 cached programs - the per-depth heads each contribute their own attention and projection programs - but stayed flat after step ~10. Compile time becomes a one-time cost once the persistent cache is warm, and we re-use it across runs of the same config.
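
The "local path only" rule for the persistent cache is easy to enforce in a few lines. This is a sketch, not the early-init code from base_train.py: configure_local_xla_cache and REMOTE_PREFIXES are hypothetical, the environment variable name is the one used in this post, and whether a given runtime reads it depends on the stack. It has to run before XLA is initialised.

```python
import os
from pathlib import Path

# Assumed list of URI schemes to reject; extend as needed.
REMOTE_PREFIXES = ("gs://", "s3://", "http://", "https://")


def configure_local_xla_cache(preferred: str, fallback: str) -> str:
    for candidate in (preferred, fallback):
        if candidate.startswith(REMOTE_PREFIXES):
            raise ValueError(f"compile cache must be a local path, got {candidate!r}")
    # Use the preferred location only if its parent volume actually exists
    # (for example /data for /data/.xla_cache); otherwise use the fallback.
    chosen = preferred if Path(preferred).parent.is_dir() else fallback
    Path(chosen).mkdir(parents=True, exist_ok=True)
    os.environ["XLA_COMPILATION_CACHE_DIR"] = chosen
    return chosen
```

A call such as configure_local_xla_cache("/data/.xla_cache", some_local_fallback) at the top of the launcher mirrors the "when available" behaviour described above, where some_local_fallback is whatever scratch directory your environment provides.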

Things that pretended to help and didn't

  • torch.compile() on the model on TPU. Stock guidance over the years has gone back and forth on this; the current torch_xla story is that XLA's own JIT does the work and adding torch.compile on top produces hangs, NaNs, or duplicate compile work. Leave it off.
  • Aggressive xla_flag_profile=auto bundles inherited from MaxText. The async-collective-fusion flags and the data-parallel all-reduce optimisation flag both produced NaN from step 1 on every NAM12 and NAM52 config we tried. They also caused subtle compile-cache key drift when we toggled them. We default to --xla_flag_profile=none now and add overrides explicitly through --xla_flags_extra when we need them.
  • Speculative warmup of "common" batch shapes. Pre-tracing a small set of canonical batches at startup seeded the cache for the seeded shapes and helped nothing else, because dataloader sharding meant rank-local batch shapes drifted slightly from the seeds. Pin the shapes the dataloader produces; do not preheat the cache.

Practical checklist

If you want to stop seeing surprise compiles after step 0 on a torch_xla 2.9-class stack:

  • Turn on PT_XLA_DEBUG=1 and count unique HLO programs across the loop. If the count keeps growing past step 20, you have an instability somewhere.
  • Remove every .item(), .cpu(), int(...)-of-tensor, and Python-level if on tensor values inside any region wrapped by torch_xla.compile().
  • Move every Python counter in a compiled module to register_buffer.
  • Canonicalise dataloader output: optional fields always present, always the same shape, masked when inactive.
  • Pin a persistent local cache. Do not point it at remote storage.
  • Do not stack torch.compile() on top of the XLA JIT.
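
To turn the first checklist item into something you can assert on, a small plateau check helps. The sketch assumes you can sample a cumulative compile count once per step (for example from torch_xla's debug metrics, or by parsing the PT_XLA_DEBUG output); how you collect the counts is left open, and first_stable_step is a hypothetical helper.

```python
from typing import Optional, Sequence


def first_stable_step(cumulative_compiles: Sequence[int], window: int = 20) -> Optional[int]:
    """Return the step at which the compile count reached its final value,
    or None if it has not stayed flat for at least `window` steps."""
    if not cumulative_compiles:
        return None
    final = cumulative_compiles[-1]
    step = len(cumulative_compiles) - 1
    # Walk back to the first step that already shows the final count.
    while step > 0 and cumulative_compiles[step - 1] == final:
        step -= 1
    flat_for = len(cumulative_compiles) - step
    return step if flat_for >= window else None


settled = [1, 5, 9, 12, 15] + [15] * 30
still_growing = list(range(40))
late_recompile = [15] * 25 + [16] * 5

print(first_stable_step(settled))
print(first_stable_step(still_growing))
print(first_stable_step(late_recompile))
```

The first trace settles at step 4; the other two return None, one because the count never stops growing and one because a late recompile reset the flat window.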

The good news: this is one-time work. After the cache stabilised, the same NAM52 configuration warmed from cold disk in under two minutes and ran indefinitely without further compiles. The bad news: the torch_xla runtime does not tell you which line caused a recompile, so the fix is mechanical and the discovery curve is unforgiving.

References

  • review_gcp_tpu.md
  • docs/CURRENT_STATE.md
  • docs/TPU_SETUP.md
  • docs/TENSOR_PARALLELISM.md
  • docs/BACKEND_STOPLIGHT_MATRIX.md
  • reports/dist_optimizer_stress_tpu_v6e_2026-03-22.md
  • reports/mtp_depth_attribution_tpu_v6e_2026-03-22.md
  • reports/tpu_backend_provenance_v6e8_2026-03-16.json
  • CHANGELOG.md
  • training_review.md
  • speed_rep_xx.md
David Gornshtein • Datasunrise OÜ