# libtpu, JAX and torch_xla in One Container: Device Init Races and Env-Var Landmines
What it looks like when a single TPU host runs JAX and torch_xla side by side: libtpu initialization races, VFIO contention, PJRT platform resolution, and the env-var surface you only learn about when it breaks.

The nanochat POC that feeds the MegaCpp SLM ensemble ships one container per
TPU host that contains both torch_xla (for training) and jax (for
tokenizer utilities, numerical reference tests, a few SPMD helpers, and
upstream Pallas kernels we wrap). On a GPU box that would be unremarkable.
On a TPU v6e host it is a minefield, because torch_xla and jax both
link the same libtpu and both try to own the physical device.
This post documents what actually breaks when those two runtimes share a host, the env-var surface that controls the fight, and the debug patterns that finally made it reproducible.
## The shared object underneath
Both stacks route through PJRT and both ultimately call into libtpu. On
our pinned v6e-8 hosts:
- `torch` 2.9.0a0+git21fec65 or 2.11.0a0+git7afdbae
- `torch_xla` 2.9.0+gitc04e61c or 2.11.0+gitc04e61c
- `jax` 0.9.0, `jaxlib` 0.9.0
- `libtpu` 0.0.36 (production) or 0.0.37.dev20260224 (nightly)
- PJRT API version: plugin 0.94, framework 0.91 (forward-compat OK)
libtpu is the single consumer of /dev/vfio/0. There is exactly one such
device node, and it cannot be held by two processes at once. The moment
two processes try to claim it at the same time, one wins and the other
receives an opaque initialization failure. That is the primitive under
everything that follows.
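When a claim fails, the first question is which PID already holds the device. `lsof` answers it, but a dependency-free sketch is handy inside a slim container (Linux-only, best-effort given /proc permissions; the helper name is ours):

```python
import os

def vfio_holders(dev="/dev/vfio/0"):
    """Return PIDs holding an open fd on the given device node (Linux /proc scan)."""
    holders = []
    if not os.path.isdir("/proc"):
        return holders  # non-Linux host: nothing to scan
    for pid in filter(str.isdigit, os.listdir("/proc")):
        fd_dir = os.path.join("/proc", pid, "fd")
        try:
            for fd in os.listdir(fd_dir):
                if os.path.realpath(os.path.join(fd_dir, fd)) == dev:
                    holders.append(int(pid))
                    break
        except OSError:
            continue  # process exited, or we lack permission to inspect it
    return holders
```

Run as root on the training host, this points straight at the process that won the race.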
## The race we kept losing
The symptom was that our test suite would pass cleanly on GPU, pass most
of the time on TPU, and fail intermittently on TPU when run against a
host that was already training. The failing tests were TestAutogradGuard
cases on clustered sparse attention. They used jax only as a numerical
reference - pure CPU math, no TPU intent. They were also the first tests
imported in their module, which is why they drew the short straw.
The actual failure was that jax imported, saw a TPU-capable libtpu on
the system, and tried to grab /dev/vfio/0 for TPU initialization.
torch_xla already owned it from the training process next door. jax
failed to init, and the test harness treated the failure as a crash. Worse,
on a quiet host the tests would pass because jax won the race, then
torch_xla would later fail to init during the next training launch.
The fix is to tell jax explicitly that it is CPU-only for these tests:
```python
import os

def _import_jax_cpu():
    os.environ.setdefault("JAX_PLATFORMS", "cpu")
    import jax
    return jax
```
Every numerical-reference test in the suite now goes through
_import_jax_cpu(). JAX_PLATFORMS=cpu has to be set before
import jax, and it has to be set in every subprocess that imports
jax - setdefault handles both cases without clobbering explicit
overrides. On test runs we set JAX_PLATFORMS=cpu at the pytest
invocation level too, as a belt-and-braces measure:
```shell
JAX_PLATFORMS=cpu python -m pytest tests/...
```
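Why `setdefault` rather than a plain assignment: an explicit override set earlier in the environment must survive. In isolation:

```python
import os

# Fresh subprocess: nothing set yet, so the CPU default lands.
os.environ.pop("JAX_PLATFORMS", None)
os.environ.setdefault("JAX_PLATFORMS", "cpu")
assert os.environ["JAX_PLATFORMS"] == "cpu"

# An explicit override (e.g. a deliberate JAX-on-TPU test) set before
# _import_jax_cpu() runs is left alone by setdefault.
os.environ["JAX_PLATFORMS"] = "tpu"
os.environ.setdefault("JAX_PLATFORMS", "cpu")
assert os.environ["JAX_PLATFORMS"] == "tpu"
```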
After that change the v6e-8 validate host ran 903 passed, 1 skipped, 0 failures on the cycle 5 suite. The previous run had the single JAX init failure we chased for most of a day.
## Env-var sensitivity, documented
Once we understood that env-vars were the real public API of libtpu,
we wrote them down. The ones that matter in practice:
- `PJRT_DEVICE=TPU`: the launch signal for `torch_xla`'s TPU path. Must be set before `import torch_xla`. Absence silently falls back to CPU.
- `JAX_PLATFORMS`: `cpu` for numerical reference tests; unset (or `tpu`) for the rare JAX-on-TPU helpers. Must be set before `import jax`.
- `XLA_NO_SPECIAL_SCALARS=1`: part of the TPU run contract. Under SPMD, not setting it produced subtle wrong-answer regressions on scalar broadcasts that only showed up as higher loss.
- `XLA_COMPILATION_CACHE_DIR`: the authoritative early-init knob for the persistent XLA cache in `scripts/base_train.py`. Default is `/data/.xla_cache` if `/data` exists, else `~/.cache/xla_compilation`. There is also a `--xla_cache_dir` CLI flag, but it does not override the first early cache init; if you care where the cache lives, set the env var.
- `NANOCHAT_NO_XLA_CACHE=1`: disables the persistent cache entirely.
- `LIBTPU_INIT_ARGS`: the catch-all for libtpu-level flags. We now prefer `--xla_flags_extra` for one-off overrides rather than copying a full `LIBTPU_INIT_ARGS` blob between run scripts, because the libtpu flag vocabulary changes between versions (see below).
- `TPU_LOG_DIR` / `TPU_STDERR_LOG_LEVEL`: off by default; bumped during device-init debugging. The log noise is high enough that you do not want it on by default on a training host.
Order matters. PJRT_DEVICE, JAX_PLATFORMS, and
XLA_COMPILATION_CACHE_DIR must be set before the first import of the
corresponding framework. In scripts/base_train.py the XLA cache init
happens very early in main() precisely because a later init does not
override the first one.
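The cache-dir precedence described above, as a sketch (the helper name is ours; the real logic is inlined early in `main()` in `scripts/base_train.py`):

```python
import os

def resolve_xla_cache_dir():
    """Mirror the documented precedence: env var, then /data, then ~/.cache."""
    env = os.environ.get("XLA_COMPILATION_CACHE_DIR")
    if env:
        return env  # env var wins; the CLI flag cannot override first init
    if os.path.isdir("/data"):
        return "/data/.xla_cache"
    return os.path.expanduser("~/.cache/xla_compilation")
```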
## The libtpu flag vocabulary drift
Between libtpu 0.0.36 and 0.0.37.dev20260224, a flag we relied on
quietly changed:
- `--xla_disable_hlo_passes=...`: valid on 0.0.36, unknown flag on 0.0.37.
- The replacement on 0.0.37 is `--xla_jf_vmem_memory_space_assignment=false`.
A run script that carried the old flag fails closed on 0.0.37 with
"unknown flag" and the run does not start. That is, fortunately, loud -
the failure is explicit at launch rather than silently ignored. The
contract we moved to: log the libtpu version in every run manifest, and
let nanochat/xla_flags.py pick the flag name based on the detected
libtpu version rather than hard-coding either one.
The broader lesson is that the LIBTPU_INIT_ARGS flag surface is not
stable across nightlies. Treat it as internal to each libtpu release, not
as an API.
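A sketch of that version switch (function name and the pass-list argument are illustrative, not the real `nanochat/xla_flags.py` code; the 0.0.36 pass list itself is elided here as it is in the run scripts):

```python
def msa_disable_flag(libtpu_version: str, hlo_passes: str) -> str:
    """Pick the MSA-workaround flag spelling for the detected libtpu version."""
    # "0.0.37.dev20260224" -> (0, 0, 37); the ".devNNN" suffix is dropped by [:3].
    parts = tuple(int(p) for p in libtpu_version.split(".")[:3])
    if parts >= (0, 0, 37):
        return "--xla_jf_vmem_memory_space_assignment=false"
    return f"--xla_disable_hlo_passes={hlo_passes}"
```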
## MSA and the NaN hunt
The most unpleasant libtpu interaction we hit was not a race; it was a
numerical regression. In early March, every NAM52 bare FSDP run on
torch 2.11 + libtpu 0.37 produced NaN at step 1. So did the same config
on torch 2.9 + libtpu 0.37. So did the same config with FSDP disabled.
So did the same config with the "disable MSA" flag removed.
The bisect matrix looked like this:
| Lane | torch | libtpu | Config | Result |
|---|---|---|---|---|
| T1 | 2.11 | 0.37 | NAM52 FSDP | NaN at step 1 |
| T2b | 2.9 | 0.37 | NAM52 FSDP | NaN at step 1 |
| T6 | 2.11 | 0.37 | NAM52 no-FSDP | NaN at step 1 |
| T7 | 2.11 | 0.37 | NAM52 FSDP, no MSA-disable | NaN at step 1 |
| T9 | 2.11 | 0.37 | d=12 FSDP + MSA-disable | libtpu 0.37 unknown flag |
| T11 | 2.9 | 0.36 | NAM52 FSDP | Compiling (no NaN) |
What T11 told us: the NaN was not exclusively a libtpu 0.37 regression,
but 0.36 was the only version where we could run the known-good MSA
workaround at all. The 0.37 flag rename
(--xla_disable_hlo_passes -> --xla_jf_vmem_memory_space_assignment=false)
meant a script mechanically ported from 0.36 could not even express the
workaround. We ended up with a rule: bare NAM52 FSDP trains on 0.36 until
0.37 has an equivalent receipt, and MoE configs can ride on 0.37
because they don't hit the same memory-space-assignment pass in the same
way.
That is not a satisfying answer. It is the honest one. libtpu nightlies
change HLO-pass behaviour between builds, and bisecting the combined
surface of torch / torch_xla / libtpu / model config is slow.
## Mode A, MoE, and the libtpu SIGKILL
Two TPU compile modes, as documented elsewhere: Mode B (per-micro-step
compile around fwd+bwd) and Mode A (whole-step compile through
torch_xla.compile()). On dense models libtpu 0.0.37 enables Mode A,
and it is 30-56% faster than Mode B. On MoE models Mode A consistently
SIGKILLs during compilation:
- Dense d12 Mode A: 482K tok/sec (works)
- Dense d12 + Mamba AAM Mode A: 280-634K tok/sec (works)
- NAM12 MoE(64) Mode A: SIGKILL during compile
- NAM12 MoE(8) Mode A: SIGKILL during compile
The crash is internal to libtpu on the MoE routing sub-graph
(gather/scatter/top-k). torch_xla hands the HLO down; libtpu dies.
From the Python side there is no useful traceback; the process is
terminated by the kernel with SIGKILL.
The operational contract we wrote: Mode A for dense / Mamba-only, Mode B for anything with MoE. If you need Mode A speed on MoE, you need to wait for a libtpu revision that does not kill on those routing graphs; there is no Python-side workaround.
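Encoded as a guard in the launch path, the contract is a few lines (the config field name `num_experts` is an assumption):

```python
def compile_mode(cfg) -> str:
    """Pick the TPU compile mode for a model config (field name assumed)."""
    if getattr(cfg, "num_experts", 0) > 1:
        return "B"  # MoE routing graphs SIGKILL libtpu 0.0.37 under Mode A
    return "A"      # dense / Mamba-only: whole-step compile, 30-56% faster
```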
## The XLA persistent cache is per-libtpu
One more interaction worth documenting. The patched torch_xla persistent
cache (covered in the companion PJRT post) stores XLAEXE executables on
disk. Those executables are produced by a specific libtpu build. Bumping
libtpu - even between 0.36 and 0.37 within the same torch_xla
build - should in principle invalidate the cache entry and force a
recompile.
In practice we have seen two cases worth knowing:
- libtpu 0.37 still returned `UNIMPLEMENTED` on cache deserialization against unpatched `torch_xla`. Our patch works against both libtpu versions, but the upstream path is broken across both.
- A libtpu bump while an old cache is on disk is not guaranteed to reject every stale entry. We have not seen a wrong-answer outcome from this, but we also do not rely on libtpu to garbage-collect our cache. Our discipline: when we bump libtpu on a host, we wipe `XLA_COMPILATION_CACHE_DIR` by hand and let the next run repopulate.
That is a small operational cost - a warm cache is a few GB and repopulates in one or two compile cycles - and it is much cheaper than debugging a stale-cache incident during a live training run.
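The manual wipe can be made mechanical with a version stamp inside the cache directory; a sketch (the stamp filename and helper are our invention, not part of torch_xla):

```python
import os
import shutil

def ensure_cache_matches_libtpu(cache_dir: str, libtpu_version: str):
    """Wipe the XLA persistent cache if it was produced by a different libtpu."""
    stamp = os.path.join(cache_dir, ".libtpu_version")
    if os.path.isdir(cache_dir):
        try:
            with open(stamp) as f:
                if f.read().strip() == libtpu_version:
                    return  # cache came from this libtpu build; keep it warm
        except FileNotFoundError:
            pass  # unstamped cache: treat as stale
        shutil.rmtree(cache_dir)  # stale or unknown provenance: start over
    os.makedirs(cache_dir, exist_ok=True)
    with open(stamp, "w") as f:
        f.write(libtpu_version)
```

Called once at launch, before the first framework import, this turns the "wipe by hand" rule into something nobody can forget under time pressure.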
## Debug patterns that paid off
Three patterns actually moved debugging forward when the runtime was opaque.
**`PT_XLA_DEBUG=1` on the first steps, then off.** Leaving it on destroys
throughput and floods the log. Running it for the first 10 steps of an
ablation, captured into a dedicated log, was enough to see which ops
caused host-device syncs and which metadata columns were triggering new
trace variants. We then disabled it for the steady-state run.
**Explicit `_xla_runtime_memory_info("TPU:N")` on each rank.** Our pybind
patch (see the companion torch_xla post) lets a running training process
ask libtpu for HBM usage per physical chip. We wire this into `/memory`
and poll from the host. Without it, you get OOM crashes with no forensic
data about which chip ran out first or how close you were. With it,
auto-fit can narrow dbs proactively and the retry loop has signal to
act on.
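The poll loop itself is trivial once the entry point exists; a sketch that treats the per-chip result as a `(bytes_used, bytes_total)` pair, which is an assumption about the patched return shape:

```python
def hbm_watermark(memory_info_fn, num_chips=8, warn_frac=0.9):
    """Poll per-chip HBM usage and return chips above the warning watermark.

    memory_info_fn: callable taking "TPU:N" and returning (used, total) bytes
    (shape assumed from our pybind patch, not an upstream torch_xla API).
    """
    hot = []
    for n in range(num_chips):
        used, total = memory_info_fn(f"TPU:{n}")
        frac = used / total
        if frac >= warn_frac:
            hot.append((n, frac))  # candidate for auto-fit narrowing / alerting
    return hot
```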
**Provenance JSON per run.** Every training launch writes a provenance record
with hostname, python, torch, torch_xla, jax, and libtpu
versions plus the set of backend modes exercised (xla_flash, splash,
softcap on/off). A tiny example from a March validate run:
```json
{
  "provenance": {
    "python": "3.13.12",
    "torch": "2.9.0a0+git21fec65",
    "torch_xla": "2.9.0+gitc04e61c",
    "jax": "0.9.0"
  },
  "tests": [
    { "mode": "xla_flash", "backend_used": "xla_flash_pallas", "status": "success" },
    { "mode": "splash", "backend_used": "xla_splash_via_trace_pallas", "status": "success" }
  ]
}
```
Three months of those records is what lets us say "this regression is new" rather than "this regression is vibes". It is also what lets us correlate a failure on one v6e-8 host with the same stack on another, which turns out to matter when four machines in a four-machine fleet each run a slightly different torch build.
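Gathering the version half of that record takes a few lines of stdlib; a sketch (looking `libtpu` up by its wheel name is an assumption, and missing packages degrade to `null` rather than crashing the launch):

```python
import json
import platform
from importlib import metadata

def provenance_record():
    """Collect the version fields of a run-provenance record."""
    def ver(pkg):
        try:
            return metadata.version(pkg)
        except metadata.PackageNotFoundError:
            return None  # absent on this host: record the absence, don't crash
    return {
        "provenance": {
            "hostname": platform.node(),
            "python": platform.python_version(),
            "torch": ver("torch"),
            "torch_xla": ver("torch_xla"),
            "jax": ver("jax"),
            "libtpu": ver("libtpu"),
        }
    }

if __name__ == "__main__":
    print(json.dumps(provenance_record(), indent=2))
```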
## Fleet discipline, not per-host tweaks
The eventual operational rule is dull and non-negotiable.
One pinned Python, one pinned torch, one pinned torch_xla, one pinned
libtpu, one pinned jax, per host. Recorded in a provenance JSON. Logged
in every training manifest. JAX_PLATFORMS=cpu unless a test explicitly
needs TPU JAX. PJRT_DEVICE=TPU and XLA_NO_SPECIAL_SCALARS=1 set in
the training environment before the first framework import. Persistent
cache wiped on libtpu bump. Mode A only on dense; Mode B on MoE.
Every time we relaxed one of those rules under time pressure, the cost
was higher than doing it right would have been. The interactions between
jax, torch_xla, and libtpu on a single TPU host are not
well-documented because most users run exactly one of the three. If you
run all three, the contract is yours to enforce.
## References
- TPU_SETUP.md
- CURRENT_STATE.md
- BACKEND_STOPLIGHT_MATRIX.md
- GCLOUD_AUTH_REFRESH.md
- TPU_BACKUPS.md
- CHANGELOG.md
- persistent_cache_fix.diff (README)
- xla_runtime_memory_info.diff (README)
- tpu_backend_provenance_v6e8_2026-03-16.json
- dist_optimizer_stress_tpu_v6e_2026-03-22.md
- review_gcp_tpu.md
- training_review.md