torch_xla and PJRT on TPU v6e: What Actually Worked
An honest engineering account of running torch_xla / PJRT on Google TPU v6e during the nanochat POC: version pinning, lazy-tensor tracing traps, persistent cache, SPMD, and the pybind we had to ship ourselves.

The nanochat research stream runs a serious fraction of its training on
Google TPU v6e, specifically v6e-4 and v6e-8 slices. That path is torch-first
(we reuse a single model definition across CUDA and TPU) which means the real
runtime is torch_xla sitting on PJRT, sitting on libtpu. This post is a
specific, unglamorous account of what held up, what broke, and where the
lazy-tensor abstraction quietly bit us.
It is not "TPU vs GPU". It is "what you actually have to do to keep a PJRT / torch_xla stack honest on a custom training codebase".
The stack we ship
After enough pain, we stopped chasing PyPI. The repo-preferred TPU install is built from a pinned source tree:
- Python 3.13
- custom torch 2.9.0a0+git21fec65 (and on one validate host, torch 2.11.0a0+git7afdbae)
- custom torch_xla 2.9.0+gitc04e61c / 2.11.0+gitc04e61c
- libtpu 0.0.36 as the production floor; 0.0.37.dev20260224 (nightly) where we need Mode A compile on dense models
- jax 0.9.0, jaxlib 0.9.0
- PJRT_DEVICE=TPU as the launch signal
The stock wheels (torch_xla 2.9.0 + libtpu 0.0.23.1) are kept only as a
historical reference. They are not the current path. The reason is not
purity: it is that two of the patches below do not exist upstream, and
without them our TPU runs were either silently broken or six to sixteen times
slower than they should be.
The persistent compilation cache was a lie
The first time we noticed, we had a v6e-8 training job that took 47 minutes
to reach step 0 on NAM52. The second time we restarted it, it took 47
minutes again. torch_xla happily wrote files into the cache directory
after compile; on restart it read them back, failed deserialization, and
recompiled from scratch. The error surfaced as:
UNIMPLEMENTED: Deserializing serialized executable not supported
Root cause: torch_xla's DeserializeComputation() was calling
PjRtClient::DeserializeExecutable(). That is a base-class method which
returns UNIMPLEMENTED on TPU. The method libtpu actually implements is
PjRtClient::LoadSerializedExecutable(), reached through the PJRT C API
entry point PJRT_Executable_DeserializeAndLoad. JAX has always called the
right one. torch_xla had not.
The secondary problem was that SerializeComputation() was persisting an
HLO proto, not the compiled XLAEXE executable, so even with the correct
deserialization entry point there was nothing useful on disk to load.
Our patch (persistent_cache_fix.diff against
torch_xla/csrc/runtime/pjrt_computation_client.cpp) switches serialization
to XLAEXE (with HLO fallback) and replaces the deserialization call with
client_->LoadSerializedExecutable(). On a small d=8 model, cold step 0 is
11.5s; warm step 0 after cache hit is 1.7s, a 6.7x improvement. On NAM52 the
same fix is expected to take 47 minutes down to roughly 7 minutes. We have
not yet observed a full warm NAM52 restart inside the window this post
covers - the estimate is proportional, not measured - but every receipted
cold-vs-warm pair on smaller shapes has held.
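The failure pattern generalises beyond torch_xla: a cache layer whose deserialize path throws must fall back loudly, or the recompile is invisible and the cache is decorative. A minimal sketch of that probe discipline, in generic Python (not torch_xla's actual cache API):

```python
import os
import time

def load_or_compile(cache_path, compile_fn, load_fn):
    """Try the persistent cache first; on a deserialization failure,
    fall back to a cold compile and report which path was taken.
    Generic sketch -- compile_fn/load_fn are stand-ins, not torch_xla APIs."""
    t0 = time.monotonic()
    if os.path.exists(cache_path):
        try:
            exe = load_fn(cache_path)
            return exe, "warm", time.monotonic() - t0
        except NotImplementedError:
            # e.g. "UNIMPLEMENTED: Deserializing serialized executable
            # not supported" -- the silent-recompile case from this post.
            pass
    exe = compile_fn()
    return exe, "cold", time.monotonic() - t0
```

Logging the warm/cold tag per restart is what turned our "47 minutes, twice" observation from a mystery into a bug report.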
The companion gotcha is AdamW. Even with the cache fix, our first warm
restart still saw a cache miss on step 1. The culprit is
torch.optim.AdamW._get_value(step_t), which calls .item() on the step
counter every step. On XLA that materialises a fresh Python float constant
into the compiled graph, so the graph hash changes on every step and every
step cache-misses. The fix in gpt.py is to pass capturable=True on XLA:
optim_cls = partial(torch.optim.AdamW, capturable=_on_xla)
capturable=True keeps the step counter on-device, the constant disappears
from the graph, and the cache actually hits across steps.
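A toy model makes the mechanism concrete. Here the hash function is a pure-Python stand-in for XLA's real graph hash, not torch_xla internals: any Python constant baked into the trace perturbs the hash, so a `.item()`-materialised step counter guarantees a miss per step.

```python
import hashlib

def graph_hash(ops, constants):
    """Stand-in for XLA's graph hash: any change in baked-in constants
    produces a different hash, and therefore a compile-cache miss."""
    payload = repr(ops) + repr(sorted(constants.items()))
    return hashlib.sha256(payload.encode()).hexdigest()

ops = ["mul", "add", "sub"]  # the (fixed) optimizer update graph

# capturable=False: _get_value() calls .item(), so the step counter
# becomes a fresh Python float constant inside every traced graph.
miss_hashes = {graph_hash(ops, {"step": float(s)}) for s in range(1, 6)}

# capturable=True: the step counter stays on-device, so the traced
# graph only contains a placeholder and the hash is stable.
hit_hashes = {graph_hash(ops, {"step": "<device tensor>"}) for s in range(1, 6)}

print(len(miss_hashes), len(hit_hashes))  # 5 distinct hashes vs 1
```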
SPMD and the memory-info pybind we had to write
We run torch_xla SPMD for tensor and data parallel. The model sees xla:0;
under the hood PJRT reports a virtual device SPMD:0 and the real runtime
devices are TPU:0 through TPU:7 on a v6e-8. This is usually invisible,
but it breaks one specific thing we need: HBM telemetry.
The existing _xla_memory_info pybind goes through c10::Device(...), which
rejects raw runtime device strings like TPU:18 and also rejects the
virtual SPMD:0 device:
MemoryInfo not supported for SPMD virtual device.
That means the running training process cannot ask "how much HBM is allocated on physical chip 18 right now?" without going around the whole SPMD abstraction, which defeats the point of having it.
Our additive patch (xla_runtime_memory_info.diff against
torch_xla/csrc/init_python_bindings.cpp) adds:
.def("_xla_runtime_memory_info",
[](const std::string& device) { return GetRuntimeMemoryInfo(device); })
This bypasses GetDeviceOrCurrent() and forwards the raw runtime device
string directly to the computation client. Nanochat calls
_XLAC._xla_runtime_memory_info("TPU:18") from the already-running training
loop and exposes it on the /memory endpoint. The original
_xla_memory_info is untouched, so older callers are unaffected.
Two build artifacts made this bearable: the patched
_XLAC.cpython-313-x86_64-linux-gnu.so (about 306 MB) is cached on GCS along
with the source diff, so bringing a new VM up is a gcloud storage cp and a
venv drop-in rather than a five-minute rebuild per host.
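For reference, here is roughly how a telemetry endpoint can poll the new binding across physical chips. This is a hedged sketch: the wrapper shape and the result field names (`bytes_used`, `bytes_limit`) are assumptions for illustration, not nanochat's exact code, and `query_fn` stands in for a call like `_XLAC._xla_runtime_memory_info`.

```python
def collect_hbm(query_fn, num_chips=8):
    """Poll per-chip HBM stats via a query function shaped like the
    patched _xla_runtime_memory_info binding (hypothetical wrapper).
    Raw runtime device strings ("TPU:0".."TPU:7") are passed through
    untouched -- no c10::Device round-trip, no SPMD:0 rejection."""
    stats = {}
    for i in range(num_chips):
        info = query_fn(f"TPU:{i}")  # assumed to return a dict of byte counts
        stats[f"TPU:{i}"] = {
            "used_gib": info["bytes_used"] / 2**30,
            "total_gib": info["bytes_limit"] / 2**30,
        }
    return stats
```

The point of the indirection is that the same loop works from inside the already-running SPMD training process, which is exactly what the stock `_xla_memory_info` could not do.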
Lazy-tensor traps that actually cost us
Every torch_xla user learns the first rule - don't call .item() in hot
code paths - and then cheerfully violates it. The nanochat POC spent weeks
hunting down cases that only manifested on TPU, where host/device syncs
fragmented the graph and destroyed throughput. A representative batch:
- flash_attention.py: _doc_ids_are_uniform() was calling .item() to check whether all doc ids in a batch were equal. On TPU we now take an XLA fast path that stays on-device.
- mamba2.py: has_boundaries = mask.any() branched on a host-side bool. Rewritten to keep the branch on-device.
- mamba2.py: pad_val = doc_ids.max().item() + 1 synced the max back to the host. Replaced with a static sentinel pad_val = 2**30, semantically equivalent for our doc-id range and no sync.
- gpt.py: if aux_sum.abs() > 0: forced XLA to materialise a scalar to a Python bool every forward pass. The branch was always safe to take (zeros contribute nothing), so it was simply removed.
Each of these was a small fix. Collectively they removed roughly 20
host-device syncs per training step and were the difference between
torch_xla.compile() producing a clean graph and producing a graph stitched
together out of dozens of fragments.
A second class of trap was data-dependent trace variants. In
StructureEmbedding.forward(), a dynamic component list with None checks
could produce up to 32 distinct XLA trace variants as metadata columns came
and went. The fix is to always compute all N components, let absent inputs
contribute zero embeddings, and mask them at softmax time with -inf
logits. One graph, always; no combinatorial recompilation.
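The variant explosion is easy to see in a toy model. Pure Python here, with tuples standing in for XLA trace signatures; the counts are the point, not the machinery:

```python
from itertools import product

N = 5  # metadata components that may be present or absent per batch

def trace_signature_dynamic(present):
    # None-checked components: the traced graph depends on exactly
    # which components exist -> one distinct variant per presence mask.
    return tuple(present)

def trace_signature_static(present):
    # Always compute all N components; absent ones contribute zero
    # embeddings and get -inf logits at softmax time -> one graph.
    return ("all",) * N

dynamic = {trace_signature_dynamic(p) for p in product([True, False], repeat=N)}
static = {trace_signature_static(p) for p in product([True, False], repeat=N)}
print(len(dynamic), len(static))  # 32 variants vs 1
```

Five optional components is all it takes to reach the 32-variant worst case; each variant is a full recompile plus a cache entry.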
A third class was O(N^2) structures expressed as Python loops or fancy
indexing. build_chunk_relation_mask was a per-pair Python nest; we
rewrote it as torch.eye, tensor scatter, and broadcast pairwise
differences. TreeFFN's O(C*T) binary mask was replaced with searchsorted
plus scatter_add. On one NAM12 profile run, enabling TreeFFN without that
rewrite collapsed throughput from about 387K tok/sec to 5K tok/sec - a 70x
regression from a single feature. After the rewrite, TreeFFN sat in the
normal feature-overhead band.
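The searchsorted pattern is worth internalising. A pure-Python sketch, with `bisect` standing in for `torch.searchsorted` and a plain accumulation standing in for `scatter_add`; `boundaries` holds exclusive chunk end indices:

```python
from bisect import bisect_right

def chunk_sums_naive(values, boundaries):
    # O(C*T): test every (chunk, token) pair -- the shape of the
    # binary-mask formulation that collapsed throughput.
    sums = [0.0] * len(boundaries)
    for c, end in enumerate(boundaries):
        start = boundaries[c - 1] if c else 0
        for t, v in enumerate(values):
            if start <= t < end:
                sums[c] += v
    return sums

def chunk_sums_searchsorted(values, boundaries):
    # O(T log C): locate each token's chunk once, then accumulate --
    # the searchsorted-plus-scatter_add rewrite, no C*T mask.
    sums = [0.0] * len(boundaries)
    for t, v in enumerate(values):
        sums[bisect_right(boundaries, t)] += v
    return sums
```

Same output, but the rewrite never materialises anything proportional to C*T, which is what XLA needs to keep the graph small.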
Compile modes and when to use each
We settled on two torch_xla compile modes, with an explicit contract.
Mode B is per-micro-step torch_xla.compile() around fwd+bwd, with the
optimizer step compiled separately. It is the safe default. With gradient
accumulation it typically compiles two variants on step 0 (one for "grad
create", one for "grad accumulate") before the cache settles. XLA batch
warmup is intentionally disabled to keep graphs static.
Mode A is whole-graph compile via torch_xla.compile() across the full
step. On dense models, Mode A is roughly 30-56% faster than Mode B on our
v6e-8 runs (for NAM12 dense d12: 482K vs 309K tok/sec). It also eliminates
LR-change recompilation, which in Mode B quietly happens at every schedule
step-change. The catch: Mode A consistently crashes libtpu on MoE
training graphs at compile time. MoE routing (gather/scatter/top-k) hits a
compilation-time path that SIGKILLs the process. Mamba-only and dense-only
Mode A work fine.
The contract we ended up with: Mode A for dense / Mamba-only, Mode B for
any model with MoE. Model-level torch.compile(...) stays off on TPU
entirely - we rely on the XLA JIT. That decision is not new; it is simply
the one that stopped producing hangs.
We also abandoned the idea of a single giant graph for TPU. The older
"one-graph" narrative in our own docs is dead. The current truth is
per-step fwd+bwd compile, with an optional
--xla_barrier_every_n_layers=N switch that inserts xm.mark_step()
barriers inside the model when the full fwd+bwd is too large to compile in
one chunk. That flag is a safety valve, not a default.
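The contract above is simple enough to encode as a guard rather than tribal knowledge. A sketch, under the assumption that the function name and call site are ours, not nanochat's real API:

```python
def pick_compile_mode(has_moe: bool, libtpu_version: str) -> str:
    """Encode the compile-mode contract from this post.
    Mode A: whole-graph compile; dense / Mamba-only, needs libtpu 0.0.37+.
    Mode B: per-micro-step compile; safe default, required for MoE."""
    if has_moe:
        return "B"  # Mode A SIGKILLs libtpu on MoE routing graphs
    if libtpu_version.startswith("0.0.36"):
        return "B"  # Mode A only became viable for dense on 0.0.37
    return "A"

print(pick_compile_mode(False, "0.0.37.dev20260224"))  # A
```

Putting the check in code means a config that pairs MoE with Mode A fails at launch instead of at compile time on the TPU.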
SPMD mesh shapes and XLA_NO_SPECIAL_SCALARS
The mesh is not just 2D TP. Current shapes are:
- ("data",) for TP=1, EP=1
- ("data", "model") for TP>1
- ("data", "expert") for EP>1, TP=1
- ("data", "expert", "model") for EP>1, TP>1
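That selection logic is mechanical and worth pinning down in one place. A sketch (the helper name is illustrative; nanochat's actual mesh-construction code may differ):

```python
def mesh_axis_names(tp: int, ep: int) -> tuple:
    """Return SPMD mesh axis names for a given tensor-parallel (tp)
    and expert-parallel (ep) degree, matching the shapes listed above."""
    axes = ["data"]
    if ep > 1:
        axes.append("expert")
    if tp > 1:
        axes.append("model")
    return tuple(axes)

print(mesh_axis_names(tp=2, ep=4))  # ('data', 'expert', 'model')
```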
We pass XLA_NO_SPECIAL_SCALARS=1 unconditionally in the TPU environment.
It is part of the core run contract on this stack; omitting it produced
subtle wrong-answer regressions on scalar broadcast ops under SPMD that
only showed up as higher loss, not as explicit errors. If someone on the
team proposes removing it, the answer is: not without a fresh receipt.
libtpu version sensitivity
libtpu 0.0.36 and 0.0.37.dev20260224 are close enough to be confused in
CI logs, and different enough to destroy experiments.
On the 0.0.36 -> 0.0.37 bump:
- --xla_disable_hlo_passes was removed. Its replacement is --xla_jf_vmem_memory_space_assignment=false. A 0.0.36 run script that carried the old flag fails closed on 0.0.37 with "unknown flag".
- Both 0.0.36 and 0.0.37 produced NaN at step 1 on bare NAM52 with FSDP for a time. Swapping between the two versions was not sufficient to isolate the bug; the eventual resolution was on the MSA pass side, not the libtpu version.
- Mode A becomes available on dense models on 0.0.37. On 0.0.36 we stayed on Mode B for dense too.
- XLA cache deserialization remains broken upstream on both versions; the patched torch_xla fix above is what makes the cache work at all.
The operator contract is: pin libtpu explicitly per host, log the pinned
version in every run manifest, and do not treat nightly as equivalent to
stable. In our fleet the typical manifest row is
torch 2.9.0a0+git21fec65 | torch_xla 2.9.0+gitc04e61c | libtpu 0.0.37.dev20260224 | jax 0.9.0.
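A manifest row in that shape is trivially machine-checkable, which is most of its value. A sketch parser (ours, not the fleet tooling):

```python
def parse_manifest_row(row: str) -> dict:
    """Parse a 'pkg version | pkg version | ...' manifest row, as
    logged per run, into a dict keyed by package name."""
    versions = {}
    for field in row.split("|"):
        name, version = field.strip().split(" ", 1)
        versions[name] = version.strip()
    return versions

row = ("torch 2.9.0a0+git21fec65 | torch_xla 2.9.0+gitc04e61c | "
       "libtpu 0.0.37.dev20260224 | jax 0.9.0")
print(parse_manifest_row(row)["libtpu"])  # 0.0.37.dev20260224
```

With this in place, "is this host on the nightly libtpu?" is a one-line assertion in CI rather than a log grep.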
The failure mode you will not see in a benchmark
One more lesson worth recording. A portion of our early "TPU pipeline"
runs on tpu_full_pipeline.py were training on torch.randint instead of
real data. The loader, reward function, and eval metrics were all
placeholder. Throughput numbers looked fine; loss curves looked fine; the
receipts were meaningless. TPU time is cheap per chip-hour only if the
chips are doing real work. Our current discipline is to gate any
"TPU training" claim on a checked data-contract receipt (parquet loader,
BOS-aligned packing, packed docs, enriched metadata) before the run is
considered evidence of anything. Lazy-tensor tracing will happily compile
your random noise.
Takeaways
torch_xla / PJRT is viable for serious training on TPU v6e. It is not a
drop-in replacement for CUDA - it is a separate discipline, with its own
compile modes, its own cache story, and its own set of lazy-tensor traps
that only manifest on-device. The three investments that paid for
themselves were: upstream patches against the real pain points (persistent
cache, SPMD memory-info), a pinned and logged version matrix per host, and
an allergic reaction to .item() in the forward pass.
Everything else - SPMD mesh shapes, XLA flag profiles, compile-retry policies - is secondary. Get the patches and the version discipline right first.
References
- TPU_SETUP.md
- BACKEND_STOPLIGHT_MATRIX.md
- CURRENT_STATE.md
- persistent_cache_fix.diff (README)
- xla_runtime_memory_info.diff (README)
- CHANGELOG.md
- review_gcp_tpu.md
- training_plan_en.md
- training_review.md
- tpu_backend_provenance_v6e8_2026-03-16.json
- dist_optimizer_stress_tpu_v6e_2026-03-22.md