torch_xla · pjrt · tpu · v6e · spmd · lazy-tensor · nanochat

torch_xla and PJRT on TPU v6e: What Actually Worked

An honest engineering account of running torch_xla / PJRT on Google TPU v6e during the nanochat POC: version pinning, lazy-tensor tracing traps, persistent cache, SPMD, and the pybind we had to ship ourselves.

David Gornshtein · 9 min read

The nanochat research stream runs a serious fraction of its training on Google TPU v6e, specifically v6e-4 and v6e-8 slices. That path is torch-first (we reuse a single model definition across CUDA and TPU), which means the real runtime is torch_xla sitting on PJRT, sitting on libtpu. This post is a specific, unglamorous account of what held up, what broke, and where the lazy-tensor abstraction quietly bit us.

It is not "TPU vs GPU". It is "what you actually have to do to keep a PJRT / torch_xla stack honest on a custom training codebase".

The stack we ship

After enough pain, we stopped chasing PyPI. The repo-preferred TPU install is built from a pinned source tree:

  • Python 3.13
  • custom torch 2.9.0a0+git21fec65 (and on one validate host, torch 2.11.0a0+git7afdbae)
  • custom torch_xla 2.9.0+gitc04e61c / 2.11.0+gitc04e61c
  • libtpu 0.0.36 as the production floor; 0.0.37.dev20260224+nightly where we need Mode A compile on dense models
  • jax 0.9.0, jaxlib 0.9.0
  • PJRT_DEVICE=TPU as the launch signal

The stock wheels (torch_xla 2.9.0 + libtpu 0.0.23.1) are kept only as a historical reference. They are not the current path. The reason is not purity: it is that two of the patches below do not exist upstream, and without them our TPU runs were either silently broken or six to sixteen times slower than they should be.

The persistent compilation cache was a lie

We first noticed when a v6e-8 training job took 47 minutes to reach step 0 on NAM52 - and took 47 minutes again after a restart. torch_xla happily wrote files into the cache directory after compile; on restart it read them back, failed deserialization, and recompiled from scratch. The error surfaced as:

UNIMPLEMENTED: Deserializing serialized executable not supported

Root cause: torch_xla's DeserializeComputation() was calling PjRtClient::DeserializeExecutable(). That is a base-class method which returns UNIMPLEMENTED on TPU. The method libtpu actually implements is PjRtClient::LoadSerializedExecutable(), reached through the PJRT C API entry point PJRT_Executable_DeserializeAndLoad. JAX has always called the right one. torch_xla had not.

The secondary problem was that SerializeComputation() was persisting an HLO proto, not the compiled XLAEXE executable, so even with the correct deserialization entry point there was nothing useful on disk to load.

Our patch (persistent_cache_fix.diff against torch_xla/csrc/runtime/pjrt_computation_client.cpp) switches serialization to XLAEXE (with HLO fallback) and replaces the deserialization call with client_->LoadSerializedExecutable(). On a small d=8 model, cold step 0 is 11.5s; warm step 0 after cache hit is 1.7s, a 6.7x improvement. On NAM52 the same fix is expected to take 47 minutes down to roughly 7 minutes. We have not yet observed a full warm NAM52 restart inside the window this post covers - the estimate is proportional, not measured - but every receipted cold-vs-warm pair on smaller shapes has held.
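
For context, enabling the cache from the training script looks roughly like this (a sketch assuming torch_xla's `runtime.initialize_cache` API; the function name `enable_persistent_cache` is illustrative, and without the patch above the files it writes are unloadable on TPU):

```python
import os


def enable_persistent_cache(path: str) -> bool:
    """Point torch_xla's persistent compilation cache at `path`.

    Returns False when torch_xla is not importable, so the same training
    script runs unchanged on CUDA hosts. On an unpatched torch_xla the
    cache is write-only in practice: restarts hit UNIMPLEMENTED on
    deserialization and recompile from scratch.
    """
    try:
        import torch_xla.runtime as xr
    except ImportError:
        return False
    os.makedirs(path, exist_ok=True)
    xr.initialize_cache(path, readonly=False)
    return True
```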

The companion gotcha is AdamW. Even with the cache fix, our first warm restart still saw a cache miss on step 1. The culprit is torch.optim.AdamW._get_value(step_t), which calls .item() on the step counter every step. On XLA that materialises a fresh Python float constant into the compiled graph, so the graph hash changes on every step and every step cache-misses. The fix in gpt.py is to pass capturable=True on XLA:

optim_cls = partial(torch.optim.AdamW, capturable=_on_xla)

capturable=True keeps the step counter on-device, the constant disappears from the graph, and the cache actually hits across steps.
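
End-to-end, the fix is one construction-time flag (a sketch; `_on_xla` is the hypothetical device flag nanochat sets at startup, hardcoded False here so the snippet also runs on CPU):

```python
from functools import partial

import torch

# Hypothetical flag: True when the training device is XLA. Hardcoded for
# illustration; in the real codebase it is detected at startup.
_on_xla = False

# capturable=True keeps AdamW's step counter on-device, so the optimizer
# never calls .item() on it and the traced graph hash is stable across steps.
optim_cls = partial(torch.optim.AdamW, capturable=_on_xla)

param = torch.nn.Parameter(torch.zeros(4))
opt = optim_cls([param], lr=1e-3)
```

Note that `capturable=True` is only valid when the parameters live on a device that supports it (XLA or CUDA), which is exactly why the flag is conditional rather than unconditional.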

SPMD and the memory-info pybind we had to write

We run torch_xla SPMD for tensor and data parallel. The model sees xla:0; under the hood PJRT reports a virtual device SPMD:0 and the real runtime devices are TPU:0 through TPU:7 on a v6e-8. This is usually invisible, but it breaks one specific thing we need: HBM telemetry.

The existing _xla_memory_info pybind goes through c10::Device(...), which rejects raw runtime device strings like TPU:18 and also rejects the virtual SPMD:0 device:

MemoryInfo not supported for SPMD virtual device.

That means the running training process cannot ask "how much HBM is allocated on physical chip 18 right now?" without going around the whole SPMD abstraction, which defeats the point of having it.

Our additive patch (xla_runtime_memory_info.diff against torch_xla/csrc/init_python_bindings.cpp) adds:

.def("_xla_runtime_memory_info",
     [](const std::string& device) { return GetRuntimeMemoryInfo(device); })

This bypasses GetDeviceOrCurrent() and forwards the raw runtime device string directly to the computation client. Nanochat calls _XLAC._xla_runtime_memory_info("TPU:18") from the already-running training loop and exposes it on the /memory endpoint. The original _xla_memory_info is untouched, so older callers are unaffected.
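
A minimal caller for the patched binding might look like this (a sketch, not nanochat's actual code; it degrades to `None` on a stock torch_xla build or a non-TPU host):

```python
def runtime_hbm_info(device: str):
    """Query per-chip HBM usage via the patched pybind, if present.

    `device` is a raw runtime device string such as "TPU:18"; the patched
    entry point deliberately bypasses c10::Device parsing, so both real
    runtime devices and strings the stock binding rejects are accepted.
    Returns None when the patched _XLAC binding is unavailable.
    """
    try:
        from torch_xla import _XLAC
        return _XLAC._xla_runtime_memory_info(device)
    except (ImportError, AttributeError):
        return None
```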

Two build artifacts made this bearable: the patched _XLAC.cpython-313-x86_64-linux-gnu.so (about 306 MB) and its source diff are cached on GCS, so bringing a new VM up is a gcloud storage cp and a venv drop-in rather than a five-minute rebuild on every host.

Lazy-tensor traps that actually cost us

Every torch_xla user learns the first rule - don't call .item() in hot code paths - and then cheerfully violates it. The nanochat POC spent weeks hunting down cases that only manifested on TPU, where host/device syncs fragmented the graph and destroyed throughput. A representative batch:

  • flash_attention.py: _doc_ids_are_uniform() was calling .item() to check whether all doc ids in a batch were equal. On TPU we now take an XLA fast path that stays on-device.
  • mamba2.py: has_boundaries = mask.any() branched on a host-side bool. Rewritten to keep the branch on-device.
  • mamba2.py: pad_val = doc_ids.max().item() + 1 synced the max back to the host. Replaced with a static sentinel pad_val = 2**30 - semantically equivalent for our doc-id range and no sync.
  • gpt.py: if aux_sum.abs() > 0: forced XLA to materialise a scalar to a Python bool every forward pass. The branch was always safe to take (zeros contribute nothing), so the branch was simply removed.

Each of these was a small fix. Collectively they removed roughly 20 host-device syncs per training step and were the difference between torch_xla.compile() producing a clean graph and producing a graph stitched together out of dozens of fragments.
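
The `pad_val` rewrite from the list above illustrates the general pattern (a sketch; the function name and sentinel constant are illustrative):

```python
import torch

# Static sentinel, strictly above any real doc id in our range. Because it
# is a trace-time constant, the graph hash is identical every step.
PAD_SENTINEL = 2**30


def pad_doc_ids(doc_ids: torch.Tensor) -> torch.Tensor:
    # Before: pad_val = doc_ids.max().item() + 1 -- a host-device sync per
    # step, and a fresh Python constant baked into every traced graph.
    # After: no sync, one graph.
    return torch.nn.functional.pad(doc_ids, (0, 1), value=PAD_SENTINEL)
```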

A second class of trap was data-dependent trace variants. In StructureEmbedding.forward(), a dynamic component list with None checks could produce up to 32 distinct XLA trace variants as metadata columns came and went. The fix is to always compute all N components, let absent inputs contribute zero embeddings, and mask them at softmax time with -inf logits. One graph, always; no combinatorial recompilation.
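
A minimal sketch of that shape (names illustrative, not StructureEmbedding's actual code): score every component unconditionally and let the mask do the work at softmax time:

```python
import torch


def component_weights(scores: torch.Tensor, present: torch.Tensor) -> torch.Tensor:
    # scores: (N,) logits for all N metadata components, always computed;
    # present: (N,) bool, False where the column is absent in this batch.
    # Absent components are pushed to -inf, so softmax assigns them exactly
    # zero weight -- one trace variant regardless of which columns appear.
    masked = torch.where(present, scores, torch.full_like(scores, float("-inf")))
    return torch.softmax(masked, dim=-1)
```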

A third class was O(N^2) structures expressed as Python loops or fancy indexing. build_chunk_relation_mask was a per-pair Python nest; we rewrote it as torch.eye, tensor scatter, and broadcast pairwise differences. TreeFFN's O(C*T) binary mask was replaced with searchsorted plus scatter_add. On one NAM12 profile run, enabling TreeFFN without that rewrite collapsed throughput from about 387K tok/sec to 5K tok/sec - a 70x regression from a single feature. After the rewrite, TreeFFN sat in the normal feature-overhead band.
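
The searchsorted-plus-scatter shape, reduced to a toy per-chunk reduction (a sketch; nanochat's actual TreeFFN kernels are more involved):

```python
import torch


def chunk_sums(values: torch.Tensor, starts: torch.Tensor) -> torch.Tensor:
    # values: (T,) per-token values; starts: (C,) sorted chunk start
    # offsets with starts[0] == 0. Instead of materialising a (C, T)
    # membership mask or looping per pair, map each token to its chunk in
    # O(T log C) with searchsorted, then reduce with a single scatter_add.
    positions = torch.arange(values.numel())
    chunk_id = torch.searchsorted(starts, positions, right=True) - 1
    out = torch.zeros(starts.numel(), dtype=values.dtype)
    return out.scatter_add(0, chunk_id, values)
```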

Compile modes and when to use each

We settled on two torch_xla compile modes, with an explicit contract.

Mode B is per-micro-step torch_xla.compile() around fwd+bwd, with the optimizer step compiled separately. It is the safe default. With gradient accumulation it typically compiles two variants on step 0 (one for "grad create", one for "grad accumulate") before the cache settles. XLA batch warmup is intentionally disabled to keep graphs static.

Mode A is whole-graph compile via torch_xla.compile() across the full step. On dense models, Mode A is roughly 30-56% faster than Mode B on our v6e-8 runs (for NAM12 dense d12: 482K vs 309K tok/sec). It also eliminates LR-change recompilation, which in Mode B quietly happens at every schedule step-change. The catch: Mode A consistently crashes libtpu on MoE training graphs at compile time. MoE routing (gather/scatter/top-k) hits a compilation-time path that SIGKILLs the process. Mamba-only and dense-only Mode A work fine.

The contract we ended up with: Mode A for dense / Mamba-only, Mode B for any model with MoE. Model-level torch.compile(...) stays off on TPU entirely - we rely on the XLA JIT. That decision is not new; it is simply the one that stopped producing hangs.

We also abandoned the idea of a single giant graph for TPU. The older "one-graph" narrative in our own docs is dead. The current truth is per-step fwd+bwd compile, with an optional --xla_barrier_every_n_layers=N switch that inserts xm.mark_step() barriers inside the model when the full fwd+bwd is too large to compile in one chunk. That flag is a safety valve, not a default.
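
The mode contract above can be written down as a guard (an illustrative sketch; `pick_compile_mode` is not a real nanochat function):

```python
def _libtpu_tuple(version: str) -> tuple:
    # "0.0.37.dev20260224" -> (0, 0, 37); ignore the nightly suffix.
    parts = []
    for piece in version.split(".")[:3]:
        digits = "".join(ch for ch in piece if ch.isdigit())
        parts.append(int(digits or 0))
    return tuple(parts)


def pick_compile_mode(has_moe: bool, libtpu_version: str) -> str:
    # Mode A (whole-graph compile) crashes libtpu at compile time on MoE
    # routing graphs, and only became usable on dense models with 0.0.37.
    if has_moe:
        return "B"
    if _libtpu_tuple(libtpu_version) >= (0, 0, 37):
        return "A"
    return "B"
```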

SPMD mesh shapes and XLA_NO_SPECIAL_SCALARS

The mesh is not just 2D TP. Current shapes are:

  • ("data",) for TP=1 EP=1
  • ("data", "model") for TP>1
  • ("data", "expert") for EP>1 TP=1
  • ("data", "expert", "model") for EP>1 TP>1
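
The table above reduces to a small helper (illustrative; not nanochat's actual mesh-construction code):

```python
def mesh_axis_names(tp: int, ep: int) -> tuple:
    # Axis-name contract: the mesh grows one named axis per parallelism
    # kind actually in use (tensor parallel "model", expert parallel
    # "expert"), with "data" always present.
    if ep > 1 and tp > 1:
        return ("data", "expert", "model")
    if ep > 1:
        return ("data", "expert")
    if tp > 1:
        return ("data", "model")
    return ("data",)
```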

We pass XLA_NO_SPECIAL_SCALARS=1 unconditionally in the TPU environment. It is part of the core run contract on this stack; omitting it produced subtle wrong-answer regressions on scalar broadcast ops under SPMD that only showed up as higher loss, not as explicit errors. If someone on the team proposes removing it, the answer is: not without a fresh receipt.
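
In script form, the contract is two environment variables applied before anything imports torch_xla (a sketch; `TPU_ENV` is illustrative and deliberately minimal):

```python
import os

# Run-contract environment for this stack. XLA_NO_SPECIAL_SCALARS=1 is
# non-negotiable: without it we saw silent wrong-answer regressions on
# scalar broadcast ops under SPMD.
TPU_ENV = {
    "PJRT_DEVICE": "TPU",
    "XLA_NO_SPECIAL_SCALARS": "1",
}


def apply_tpu_env() -> None:
    # setdefault so an explicit operator override in the shell still wins.
    for key, val in TPU_ENV.items():
        os.environ.setdefault(key, val)
```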

libtpu version sensitivity

libtpu 0.0.36 and 0.0.37.dev20260224 are close enough to be confused in CI logs, and different enough to destroy experiments.

On the 0.0.36 -> 0.0.37 bump:

  • --xla_disable_hlo_passes was removed. Its replacement is --xla_jf_vmem_memory_space_assignment=false. A 0.0.36 run script that carried the old flag fails closed on 0.0.37 with "unknown flag".
  • For a time, both 0.0.36 and 0.0.37 produced NaN at step 1 on bare NAM52 with FSDP. Swapping between the two versions was not sufficient to isolate the bug; the eventual fix was on the MSA pass side, not in the libtpu version.
  • Mode A becomes available on dense models on 0.0.37. On 0.0.36 we stayed on Mode B for dense too.
  • XLA cache deserialization remains broken upstream on both versions; the patched torch_xla fix above is what makes the cache work at all.
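
The flag rename above can be guarded mechanically, so a 0.0.36-era run script is translated instead of failing closed with "unknown flag" (illustrative sketch, not a real nanochat helper):

```python
# Known renames across the 0.0.36 -> 0.0.37 bump.
FLAG_MIGRATION = {
    "--xla_disable_hlo_passes": "--xla_jf_vmem_memory_space_assignment=false",
}


def migrate_libtpu_flags(flags: list) -> list:
    # Map a removed flag (with any =value payload) to its replacement;
    # pass everything else through untouched.
    out = []
    for flag in flags:
        name = flag.split("=", 1)[0]
        out.append(FLAG_MIGRATION.get(name, flag))
    return out
```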

The operator contract is: pin libtpu explicitly per host, log the pinned version in every run manifest, and do not treat nightly as equivalent to stable. In our fleet the typical manifest row is torch 2.9.0a0+git21fec65 | torch_xla 2.9.0+gitc04e61c | libtpu 0.0.37.dev20260224 | jax 0.9.0.

The failure mode you will not see in a benchmark

One more lesson worth recording. A portion of our early "TPU pipeline" runs on tpu_full_pipeline.py were training on torch.randint instead of real data. The loader, reward function, and eval metrics were all placeholder. Throughput numbers looked fine; loss curves looked fine; the receipts were meaningless. TPU time is cheap per chip-hour only if the chips are doing real work. Our current discipline is to gate any "TPU training" claim on a checked data-contract receipt (parquet loader, BOS-aligned packing, packed docs, enriched metadata) before the run is considered evidence of anything. Lazy-tensor tracing will happily compile your random noise.

Takeaways

torch_xla / PJRT is viable for serious training on TPU v6e. It is not a drop-in replacement for CUDA - it is a separate discipline, with its own compile modes, its own cache story, and its own set of lazy-tensor traps that only manifest on-device. The three investments that paid for themselves were: upstream patches against the real pain points (persistent cache, SPMD memory-info), a pinned and logged version matrix per host, and an allergic reaction to .item() in the forward pass.

Everything else - SPMD mesh shapes, XLA flag profiles, compile-retry policies - is secondary. Get the patches and the version discipline right first.

References

  • TPU_SETUP.md
  • BACKEND_STOPLIGHT_MATRIX.md
  • CURRENT_STATE.md
  • persistent_cache_fix.diff (README)
  • xla_runtime_memory_info.diff (README)
  • CHANGELOG.md
  • review_gcp_tpu.md
  • training_plan_en.md
  • training_review.md
  • tpu_backend_provenance_v6e8_2026-03-16.json
  • dist_optimizer_stress_tpu_v6e_2026-03-22.md
David Gornshtein • Datasunrise OÜ