Graph Recompilation Hell on torch_xla: Shape Polymorphism, Dynamic Fallbacks, and Stabilizing the Cache

The first time we ran a NAM52 MoE step on TPU v6e-8 it took eighteen
minutes. The second step took 1.5 seconds. The third step took 1.5
seconds. The fortieth step took 73 seconds, then 1.5 again. Step 196 took
another minute. Step 312 took two minutes. By step 1000 we had spent more
wall time inside XlaCompile than inside the actual training loop.
This is the post about why that happened, what we did about it, and where
the limits of torch_xla shape polymorphism actually are in 2026.
The model that triggered it
The numbers below come from the nanochat POC: a 4.04B-parameter NAM52
stack with 52 blocks in an AEME pattern (Attention / Expert / Mamba /
Engram), GQA with 8 KV heads, fused QKV, MoE with 64 routed experts plus
1 shared expert, plus optional Multi-Token Prediction (MTP) heads at
depths 1 and 3, plus DSA (sparse attention) on a subset of layers. The
stack on TPU was Python 3.13, custom torch 2.9.0a0+git21fec65, custom
torch_xla 2.9.0+gitc04e61c, libtpu 0.0.36, and jax 0.9.0, with SPMD on,
PJRT_DEVICE=TPU, and torch.compile explicitly off for the model (it is
unsafe on XLA in this stack, so we leave JIT compilation to XLA).
The reference numbers are concrete. An MTP attribution sweep on v6e-8
showed the mtp_depth=1 lane needed about 80 XLA compilations and
72 minutes of compile wall-clock to settle a 20-step run; the
mtp_depth=3 lane needed 196 compilations and 137 minutes. Same model
code, same shapes, only the MTP-depth-conditional branches changing.
That is the cost we were paying for unstable graph keys.
What torch_xla actually keys on
The TL;DR of torch_xla graph caching: a compiled program is keyed by
the traced graph plus the shapes and dtypes of every input tensor.
It is not keyed by their values. So far, so good. The trap is what
counts as "the traced graph" when you are using torch_xla.compile()
around a Python-level fwd+bwd in 2026.
Three things end up baked into that key in ways that are easy to miss:
- Any Python int, float, or bool that flows into a tensor op as a constant becomes part of the HLO. Change the value, get a new graph. This is the same trap as Dynamo on CUDA, but torch_xla is quieter about it: by default there is no log line announcing the recompile.
- Any tensor.item() or .cpu() call inside the traced region cuts the trace at that point and forces a host-device sync. The next op starts a new graph fragment with whatever shape the downstream tensors take. Two such syncs per step split the step into three graph fragments, each of which needs its own cache entry.
- Any non-static tensor shape (typically the result of a bool-mask gather, a top-k whose k is data-dependent, or a varlen attention layout) either forces a dynamic-shape fallback (slow on v6e) or triggers a per-shape recompile.
On a model the size of NAM52, every one of these traps was hit by at least one block. Finding them was a multi-week archaeology project.
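To make the keying rule concrete, here is a toy compile cache in plain Python. It is a minimal sketch, not torch_xla's implementation: FakeTensor, trace_key, and run are hypothetical names, and the "graph text" is a string standing in for HLO. Tensors contribute only their shape and dtype, while Python scalars are inlined into the graph text, so their values end up in the key.

```python
from dataclasses import dataclass
from typing import Dict, Sequence, Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Stand-in for a device tensor: the tracer only sees shape and dtype."""
    shape: Tuple[int, ...]
    dtype: str


# Toy compile cache: key -> label of the "compiled program".
cache: Dict[tuple, str] = {}


def trace_key(op: str, tensors: Sequence[FakeTensor], scalars: Sequence[float]) -> tuple:
    # Python scalars are inlined as literals, so their values enter the graph text.
    graph_text = f"{op}({', '.join(repr(s) for s in scalars)})"
    # Tensors contribute their signature only, never their contents.
    signature = tuple((t.shape, t.dtype) for t in tensors)
    return (graph_text, signature)


def run(op: str, tensors: Sequence[FakeTensor], scalars: Sequence[float] = ()) -> str:
    key = trace_key(op, tensors, scalars)
    if key not in cache:
        cache[key] = f"program#{len(cache)}"  # a real stack would compile here
    return cache[key]


x = FakeTensor((8, 4096), "bf16")
for _ in range(3):
    run("scale", [x])                          # same shape/dtype: one program
for step in range(3):
    run("scale", [x], [1.0 / (step + 1)])      # changing Python float: three more
run("scale", [FakeTensor((8, 4097), "bf16")])  # new shape: one more
print(len(cache))  # 5
```

Three calls with identical signatures share one entry; every new scalar value or shape adds another, which is exactly the long tail described above.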
The recompilation patterns we hit
Python counters in compiled blocks
The most spectacular case was the MoE overflow accumulator. The MoE
block was tracking how many tokens had been dropped due to capacity
overflow, and it stored the running total as a Python int attribute on
the module:
self._overflow_total += int(dropped.sum().item())
Every forward call mutated self._overflow_total. Dynamo / torch_xla
treated each new value as a fresh constant in the trace. On 8-GPU DDP we
saw up to 64 recompiles per training step before the cache stabilised,
and on TPU it manifested as compile-storm hangs and eventually NaN.
The fix is the Megatron-Core pattern: hold counters in
register_buffer tensors, and do the increment as an in-place tensor op
rather than a Python increment. Do this for every counter that gets
touched in a compiled region. We had to chase three of them across MoE,
the GPT block, and our sparse-attention path before the storm stopped.
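The difference between the two counter styles can be modelled with a toy tracer that, as described above for Dynamo / torch_xla, bakes the value of a plain Python attribute into the key but sees only the shape and dtype of a buffer. Everything here (FakeBuffer, trace_key, and the two module classes) is hypothetical and stands in for an nn.Module using register_buffer.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class FakeBuffer:
    """Stand-in for a registered buffer: shape/dtype are static, data mutates."""
    shape: Tuple[int, ...]
    dtype: str
    data: List[int]

    def add_(self, amount: int) -> None:
        self.data[0] += amount  # in-place update; shape and dtype never change


def trace_key(module) -> tuple:
    """Toy cache key built from a module's attributes."""
    parts = []
    for name, value in sorted(vars(module).items()):
        if isinstance(value, FakeBuffer):
            parts.append((name, value.shape, value.dtype))  # contents ignored
        else:
            parts.append((name, value))  # Python value baked into the key
    return tuple(parts)


class MoEWithPythonCounter:
    def __init__(self):
        self._overflow_total = 0  # plain Python int: the anti-pattern

    def forward(self, dropped: int) -> None:
        self._overflow_total += dropped


class MoEWithBufferCounter:
    def __init__(self):
        self._overflow_total = FakeBuffer((), "int64", [0])  # register_buffer-style

    def forward(self, dropped: int) -> None:
        self._overflow_total.add_(dropped)


def count_programs(module, dropped_per_step: List[int]) -> int:
    seen = set()
    for dropped in dropped_per_step:
        seen.add(trace_key(module))  # key as seen when this step is traced
        module.forward(dropped)
    return len(seen)


drops = [3, 1, 5, 2, 7]
print(count_programs(MoEWithPythonCounter(), drops))  # 5
print(count_programs(MoEWithBufferCounter(), drops))  # 1
```

In a real module the buffer would be created with self.register_buffer(...) and updated with an in-place tensor op such as add_, so the counter's value never becomes a trace constant.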
Grad-create vs grad-accumulate as two graphs
torch_xla.compile() traces what it sees. On step 0 of training with
gradient accumulation, the autograd state on rank 0 has no existing
gradient tensors, so the first micro-step traces a grad-create graph
(acc.grad = new_grad). On every subsequent micro-step the gradients
already exist, so the trace produces a grad-accumulate graph
(acc.grad.add_(new_grad)). These are two distinct HLO programs, both
needed.
If you wrap each micro-step in its own torch_xla.compile(), the cache
gets both variants the first time around and you pay two compile bills
at startup. If you also do optimizer.zero_grad(set_to_none=True)
between accumulation cycles, you flip the optimizer back into the
grad-create shape on every cycle, which causes the optimizer step
itself to alternate between two compiled programs forever. The fix is a
single torch_xla.compile() around the whole accumulation chain (so
create and accumulate land in one HLO), and to avoid set_to_none=True
in the inner loop. We documented this in the training script as a hard
requirement.
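A toy model of a single parameter's .grad slot shows the create/accumulate split and why set_to_none=True keeps both variants alive. This is a hypothetical sketch: micro_step_program and programs_per_cycle are illustrative names, and a float stands in for the gradient tensor. It models per-micro-step tracing, which is the layout the fix above moves away from.

```python
from typing import Dict, List, Optional, Set


def micro_step_program(param: Dict[str, Optional[float]]) -> str:
    """Which graph variant one micro-step traces for a single parameter."""
    if param["grad"] is None:
        param["grad"] = 0.0          # acc.grad = new_grad
        return "grad_create"
    return "grad_accumulate"         # acc.grad.add_(new_grad)


def programs_per_cycle(cycles: int, micro_steps: int, set_to_none: bool) -> List[Set[str]]:
    param: Dict[str, Optional[float]] = {"grad": None}
    history = []
    for _ in range(cycles):
        used = {micro_step_program(param) for _ in range(micro_steps)}
        history.append(used)
        # zero_grad between accumulation cycles
        param["grad"] = None if set_to_none else 0.0
    return history


# set_to_none=True: every cycle needs both the create and the accumulate variant.
print(programs_per_cycle(3, 4, set_to_none=True))
# set_to_none=False: after the first cycle, only the accumulate variant is traced.
print(programs_per_cycle(3, 4, set_to_none=False))
```

With set_to_none=False the zeroed gradient still exists, so the create variant is only ever needed once, on the very first cycle.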
Enriched-batch metadata as accidental polymorphism
Our dataloader produces "enriched" batches that carry optional per-document metadata: doc-id arrays, FIM/IFIM boundary tensors, per-token validity counts, occasional auxiliary masks. Some batches have all of those keys; some are missing one or two; some have empty tensors.
The first version of the model entry path branched on key in batch for
each optional field. That meant the Python control flow taken inside
the traced region depended on which optional keys were present, which
meant every distinct subset of present keys traced its own HLO. Across a
real shuffle that produced a long tail of one-shot graphs.
The fix was to canonicalise the structure: a single helper
(_canonicalize_structure_meta_for_xla) that pads or zeroes every
optional field to a fixed shape and dtype, so the traced graph is
structurally identical regardless of which keys the dataloader emitted.
The downstream model then takes the canonical batch and uses tensor
masks (not Python ifs) to decide what is active. That alone removed
about a dozen recompiles per epoch on the data we were running.
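The canonicalisation idea can be sketched without any framework. The schema, field names, fill values, and the canonicalize / structure_signature helpers below are hypothetical (this is not the real _canonicalize_structure_meta_for_xla), and lists stand in for tensors. Every optional field is emitted at a fixed length together with a 0/1 mask, so the structure a tracer sees is the same whichever keys the dataloader produced.

```python
from typing import Dict, List, Tuple

# Hypothetical schema: field -> (fixed length, fill value for padding).
SCHEMA: Dict[str, Tuple[int, int]] = {
    "doc_ids": (8, -1),
    "fim_boundaries": (4, 0),
    "valid_counts": (2, 0),
}


def canonicalize(batch: Dict[str, List[int]]) -> Dict[str, List[int]]:
    out: Dict[str, List[int]] = {}
    for name, (length, fill) in SCHEMA.items():
        # Missing and empty fields both become zero live entries.
        # Truncation keeps the shape fixed; size the schema so it never drops data.
        values = list(batch.get(name, []))[:length]
        live = len(values)
        out[name] = values + [fill] * (length - live)
        # The mask lets the model ignore padding with tensor math instead of a Python if.
        out[name + "_mask"] = [1] * live + [0] * (length - live)
    return out


def structure_signature(batch: Dict[str, List[int]]) -> tuple:
    """What a tracer keys on here: which fields exist and how long each is."""
    return tuple(sorted((name, len(values)) for name, values in batch.items()))


raw_batches = [
    {"doc_ids": [0, 0, 1, 1, 1, 2, 2, 2]},
    {"doc_ids": [5] * 8, "fim_boundaries": [2, 6]},
    {"doc_ids": [], "valid_counts": [7, 8], "fim_boundaries": []},
]
print(len({structure_signature(b) for b in raw_batches}))                # 3
print(len({structure_signature(canonicalize(b)) for b in raw_batches}))  # 1
```

Three raw batch layouts collapse to one canonical structure; the masks carry the information that used to live in Python branches.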
.item() inside fast paths
Two tiny pieces of code were each costing about ten host-device syncs per step:
- flash_attention.py's _doc_ids_are_uniform() was doing (doc_ids == doc_ids[..., :1]).all().item() to decide whether the GQA expansion path could be skipped. On XLA we replaced it with a branch-free (doc_ids - doc_ids[..., :1]).any() whose result stays on device, plus a static fast path keyed on a launch-time flag, and pushed the dynamic check behind a device.type == "xla" guard so it never executes on TPU.
- mamba2.py's boundary check was calling .any().item() on a per-step boundary mask. Same pattern: replace it with a tensor-only path on XLA.
Each of these was one line. Together they cut roughly twenty syncs and the matching mid-step graph fragments per step.
MoE capacity that wasn't quite static
MoE capacity is min(max_tokens_per_expert_init, BT) where BT is
batch * tokens. We had a code path where BT was being recomputed
from the batch's actual token count rather than the configured one,
which was right within tolerance for a fixed dataloader but wrong
enough to occasionally produce a different int value when the
collator padded slightly differently. That int flowed into the trace
as a constant, so each unique BT value got its own compiled MoE
dispatch program. The fix was to clamp at the configured BT and never
let the collator's padding leak into the compiled graph
(test_xla_capacity_clamped_to_bt covers it).
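A minimal sketch of the clamp, assuming a config object with hypothetical field names. Each distinct capacity integer would be a distinct constant in the dispatch graph, so the leaky variant costs one program per padding pattern, while the static one always yields the same value.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MoEConfig:
    # Hypothetical config fields; only their product and the cap matter here.
    batch_size: int
    seq_len: int
    max_tokens_per_expert_init: int


def static_capacity(cfg: MoEConfig) -> int:
    """min(max_tokens_per_expert_init, BT) with BT taken from the config only."""
    configured_bt = cfg.batch_size * cfg.seq_len
    return min(cfg.max_tokens_per_expert_init, configured_bt)


def leaky_capacity(cfg: MoEConfig, collated_tokens: int) -> int:
    """The buggy variant: BT follows whatever the collator produced."""
    return min(cfg.max_tokens_per_expert_init, collated_tokens)


cfg = MoEConfig(batch_size=2, seq_len=64, max_tokens_per_expert_init=256)
collated = [120, 126, 128, 124]  # token counts that vary with padding
print(sorted({leaky_capacity(cfg, n) for n in collated}))  # [120, 124, 126, 128]
print({static_capacity(cfg) for _ in collated})            # {128}
```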
Where shape polymorphism actually ends
torch_xla does support a degree of dynamic-shape compilation, but it
is not the load-bearing surface most people assume. In the 2.9 stack we
were on, shape-polymorphic compilation works for:
- broadcastable elementwise ops with one polymorphic dimension
- contractions whose contracted dimension is static
- gather/scatter ops where index counts are static and only the input payload size is dynamic
It does not work, in the sense of producing a single cached HLO program, for:
- topk with a non-constant k
- mask-driven gathers whose output count varies per call
- attention layouts where valid_token_counts differs per batch
- MoE dispatch where capacity is allowed to vary
When you ask for any of those, torch_xla falls back to per-shape
recompilation, and on a model with the structural variability of NAM52
that fallback path is cost-prohibitive. The pragmatic answer is to make
everything static: pad to a fixed length, clamp k, fix the per-step
capacity. Pay the constant-overhead cost. Skip the long tail of
compiles.
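Clamping k in practice means choosing a static upper bound and masking the surplus instead of changing the output size. The sketch below uses plain lists and a hypothetical static_topk helper; in a torch model the analogous move would be calling torch.topk with the fixed bound and carrying a validity mask alongside the result.

```python
from typing import List, Tuple


def static_topk(scores: List[float], k_max: int, k_live: int) -> Tuple[List[int], List[int]]:
    """Top-k whose output length is always k_max, whatever k_live is.

    Entries past k_live stay in the output but are marked 0 in the mask,
    so downstream shapes never depend on the data-dependent k.
    """
    # Clamp the data-dependent k to what is statically allowed and available.
    k_live = max(0, min(k_live, k_max, len(scores)))
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    indices = order[:k_max]
    indices += [0] * (k_max - len(indices))  # pad when there are fewer candidates
    mask = [1 if j < k_live else 0 for j in range(k_max)]
    return indices, mask


scores = [0.1, 0.9, 0.4, 0.7, 0.2]
for k in (1, 2, 3, 7):
    idx, mask = static_topk(scores, k_max=4, k_live=k)
    print(idx, mask)
# Every call returns 4 indices and a 4-entry mask; only the mask values change.
```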
The other useful surface is the bounded XLA flash-attention dispatcher
(xla_flash_pallas / xla_splash_via_trace_pallas), which is itself a
statically shaped Pallas kernel. The provenance receipt we have on
v6e-8 shows it dispatching cleanly with and without softcap, with no
fallback path. If your attention layout fits inside that bounded
contract you keep the cache stable for free; if it does not, you should
expect varlen recompiles even from the "fast" backend.
What stabilised the cache
The combination of fixes that finally produced a clean cache after the first 5-10 steps:
- Persistent disk cache via XLA_COMPILATION_CACHE_DIR pointing at a local SSD path (/data/.xla_cache when available). Repeated startups of the same configuration drop from ~18 minutes to ~90 seconds. Do not point this at GCS: the early-init path in base_train.py does not safely handle remote URIs.
- All counters and accumulators in register_buffer, none in Python attributes that flow into compiled regions.
- A single torch_xla.compile() per gradient-accumulation cycle, not per micro-step.
- set_to_none=False for grad zeroing inside accumulation loops.
- Canonicalised batch structure: optional keys are always present, always the same shape and dtype, and masking decides what is live.
- Static MoE capacity, clamped at config time, never touched by the collator.
- --xla_barrier_every_n_layers=N for memory headroom, and only after the cache is clean. Barriers split the trace into more programs, so they multiply the cost of any remaining cache instability. Do not use barriers as a debugging tool for recompilation.
- XLA_USE_BF16=0 (we cast the model to bf16 ourselves; the env-driven cast is incompatible with current Pallas kernels and causes an extra set of fallback graphs).
After all of this, the MTP=1 sweep cache settled in roughly 5 steps with ~15 cached programs and stayed flat for the rest of the run. The MTP=3 sweep still needed ~30 cached programs - the per-depth heads each contribute their own attention and projection programs - but stayed flat after step ~10. Compile time becomes a one-time cost once the persistent cache is warm, and we re-use it across runs of the same config.
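For reference, the environment side of that list can be written down as a launch fragment. It is a sketch that uses only the variables named in this post, with the local-SSD path from above; the training-script flags (--xla_flag_profile, --xla_barrier_every_n_layers) still go on whatever command line launches your script.

```shell
# Launch-environment sketch for the settings listed above.
export PJRT_DEVICE=TPU
# Persistent compile cache on local SSD; never a remote (gs://) URI.
export XLA_COMPILATION_CACHE_DIR=/data/.xla_cache
# Leave the bf16 cast to the model code, not the env-driven cast.
export XLA_USE_BF16=0
```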
Things that pretended to help and didn't
- torch.compile() on the model on TPU. Stock guidance over the years has gone back and forth on this; the current torch_xla story is that XLA's own JIT does the work and adding torch.compile on top produces hangs, NaNs, or duplicate compile work. Leave it off.
- Aggressive xla_flag_profile=auto bundles inherited from MaxText. The async-collective-fusion flags and the data-parallel all-reduce optimisation flag both produced NaN from step 1 on every NAM12 and NAM52 config we tried. They also caused subtle compile-cache key drift when we toggled them. We default to --xla_flag_profile=none now and add overrides explicitly through --xla_flags_extra when we need them.
- Speculative warmup of "common" batch shapes. Pre-tracing a small set of canonical batches at startup seeded the cache for the seeded shapes and helped nothing else, because dataloader sharding meant rank-local batch shapes drifted slightly from the seeds. Pin the shapes the dataloader produces; do not preheat the cache.
Practical checklist
If you want to stop seeing surprise compiles after step 0 on a
torch_xla 2.9-class stack:
- Turn on PT_XLA_DEBUG=1 and count unique HLO programs across the loop. If the count keeps growing past step 20, you have an instability somewhere.
- Remove every .item(), .cpu(), int(...)-of-tensor, and Python-level if on tensor values inside any region wrapped by torch_xla.compile().
- Move every Python counter in a compiled module to register_buffer.
- Canonicalise dataloader output: optional fields always present, always the same shape, masked when inactive.
- Pin a persistent local cache. Do not point it at remote storage.
- Do not stack torch.compile() on top of the XLA JIT.
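The first checklist item can be automated with a small check over per-step compile counts. This sketch assumes you already record a cumulative count of compiled programs after each step (for example from torch_xla's debug metrics; which counter to read is an assumption to verify against your version). first_unstable_step is a hypothetical helper, and its step-20 default mirrors the rule of thumb above.

```python
from typing import List, Optional


def first_unstable_step(cumulative_compiles: List[int], settle_step: int = 20) -> Optional[int]:
    """Return the first step after settle_step at which the compile count grew.

    cumulative_compiles[i] is the total number of compiled programs observed
    after training step i. Returns None if the count stays flat after settling.
    """
    for step in range(settle_step + 1, len(cumulative_compiles)):
        if cumulative_compiles[step] > cumulative_compiles[step - 1]:
            return step
    return None


healthy = [3, 7, 11, 14, 15] + [15] * 35  # warm-up compiles, then flat
unstable = healthy + [16, 16, 16]         # a late compile at step 40
print(first_unstable_step(healthy))   # None
print(first_unstable_step(unstable))  # 40
```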
The good news: this is one-time work. After the cache stabilised, the
same NAM52 configuration warmed from cold disk in under two minutes
and ran indefinitely without further compiles. The bad news: the
torch_xla runtime does not tell you which line caused a recompile,
so the fix is mechanical and the discovery curve is unforgiving.
References
- review_gcp_tpu.md
- docs/CURRENT_STATE.md
- docs/TPU_SETUP.md
- docs/TENSOR_PARALLELISM.md
- docs/BACKEND_STOPLIGHT_MATRIX.md
- reports/dist_optimizer_stress_tpu_v6e_2026-03-22.md
- reports/mtp_depth_attribution_tpu_v6e_2026-03-22.md
- reports/tpu_backend_provenance_v6e8_2026-03-16.json
- CHANGELOG.md
- training_review.md
- speed_rep_xx.md