Tags: libtpu · jax · torch_xla · pjrt · tpu · v6e · nanochat

libtpu, JAX and torch_xla in One Container: Device Init Races and Env-Var Landmines

What it looks like when a single TPU host runs JAX and torch_xla side by side: libtpu initialization races, VFIO contention, PJRT platform resolution, and the env-var surface you only learn about when it breaks.

9 min read · David Gornshtein

The nanochat POC that feeds the MegaCpp SLM ensemble ships one container per TPU host that contains both torch_xla (for training) and jax (for tokenizer utilities, numerical reference tests, a few SPMD helpers, and upstream Pallas kernels we wrap). On a GPU box that would be unremarkable. On a TPU v6e host it is a minefield, because torch_xla and jax both link the same libtpu and both try to own the physical device.

This post documents what actually breaks when those two runtimes share a host, the env-var surface that controls the fight, and the debug patterns that finally made it reproducible.

The shared object underneath

Both stacks route through PJRT and both ultimately call into libtpu. On our pinned v6e-8 hosts:

  • torch 2.9.0a0+git21fec65 or torch 2.11.0a0+git7afdbae
  • torch_xla 2.9.0+gitc04e61c or torch_xla 2.11.0+gitc04e61c
  • jax 0.9.0, jaxlib 0.9.0
  • libtpu 0.0.36 (production) or 0.0.37.dev20260224+nightly
  • PJRT API version: plugin 0.94, framework 0.91 (forward-compat OK)

libtpu is the single consumer of /dev/vfio/0. There is one, and it is not re-entrant across processes. The moment two processes try to claim it at the same time, one wins and the other receives an opaque initialization failure. That is the primitive under everything that follows.
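When chasing this kind of contention, the first question is always "who holds the device right now?". A minimal sketch of that check, scanning /proc/<pid>/fd for open handles on the VFIO node (this helper is ours, not part of the nanochat repo; it needs the permissions to read other processes' fd tables):

```python
import os

def find_vfio_holders(dev_path="/dev/vfio/0"):
    """Best-effort scan of /proc/<pid>/fd for processes holding dev_path."""
    holders = []
    try:
        pids = [p for p in os.listdir("/proc") if p.isdigit()]
    except OSError:
        return holders  # no procfs (non-Linux host): nothing to report
    for pid in pids:
        fd_dir = f"/proc/{pid}/fd"
        try:
            for fd in os.listdir(fd_dir):
                try:
                    if os.readlink(f"{fd_dir}/{fd}") == dev_path:
                        holders.append(int(pid))
                        break
                except OSError:
                    continue  # fd closed while we were scanning
        except OSError:
            continue  # process exited, or fds unreadable (permissions)
    return holders
```

Run as root on a training host, an empty result means the device is free; one PID means a framework owns it; the interesting failures are when a second process is about to ask.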

The race we kept losing

The symptom was that our test suite would pass cleanly on GPU, pass most of the time on TPU, and fail intermittently on TPU when run against a host that was already training. The failing tests were TestAutogradGuard cases on clustered sparse attention. They used jax only as a numerical reference - pure CPU math, no TPU intent. They were also the first tests imported in their module, which is why they drew the short straw.

The actual failure was that jax imported, saw a TPU-capable libtpu on the system, and tried to grab /dev/vfio/0 for TPU initialization. torch_xla already owned it from the training process next door. jax failed to init, and the test harness treated the failure as a crash. Worse, on a quiet host the tests would pass because jax won the race, then torch_xla would later fail to init during the next training launch.

The fix is to tell jax explicitly that it is CPU-only for these tests:

import os

def _import_jax_cpu():
    # Must run before the first `import jax` anywhere in the process;
    # setdefault preserves an explicit override from the shell.
    os.environ.setdefault("JAX_PLATFORMS", "cpu")
    import jax
    return jax

Every numerical-reference test in the suite now goes through _import_jax_cpu(). JAX_PLATFORMS=cpu has to be set before import jax, and it has to be set in every subprocess that imports jax - setdefault handles both cases without clobbering explicit overrides. On test runs we set JAX_PLATFORMS=cpu at the pytest invocation level too, as a belt-and-braces measure:

JAX_PLATFORMS=cpu python -m pytest tests/...

After that change the v6e-8 validate host ran 903 passed, 1 skipped, 0 failures on the cycle 5 suite. The previous run had the single JAX init failure we chased for most of a day.
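The same guard can be expressed once for the whole suite as a pytest hook. A hypothetical conftest.py sketch (the hook below is our illustration; the repo's actual mechanism is _import_jax_cpu plus the shell-level variable):

```python
# conftest.py - runs at pytest startup, before any test module can
# `import jax`; setdefault keeps an explicit shell override intact.
import os

def pytest_configure(config):
    os.environ.setdefault("JAX_PLATFORMS", "cpu")
```

Because pytest imports conftest.py before collecting test modules, this fires early enough even for tests that import jax at module scope.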

Env-var sensitivity, documented

Once we understood that env-vars were the real public API of libtpu, we wrote them down. The ones that matter in practice:

  • PJRT_DEVICE=TPU: the launch signal for torch_xla's TPU path. Must be set before import torch_xla. Absence silently falls back to CPU.
  • JAX_PLATFORMS: cpu for numerical reference tests; unset (or tpu) for the rare JAX-on-TPU helpers. Must be set before import jax.
  • XLA_NO_SPECIAL_SCALARS=1: part of the TPU run contract. Under SPMD, not setting it produced subtle wrong-answer regressions on scalar broadcasts that only showed up as higher loss.
  • XLA_COMPILATION_CACHE_DIR: the authoritative early-init knob for the persistent XLA cache on scripts/base_train.py. Default is /data/.xla_cache if /data exists, else ~/.cache/xla_compilation. There is also a --xla_cache_dir CLI flag, but it does not override the first early cache init; if you care where the cache lives, set the env var.
  • NANOCHAT_NO_XLA_CACHE=1: disables the persistent cache entirely.
  • LIBTPU_INIT_ARGS: the catch-all for libtpu-level flags. We now prefer --xla_flags_extra for one-off overrides rather than copying a full LIBTPU_INIT_ARGS blob between run scripts, because the libtpu flag vocabulary changes between versions (see below).
  • TPU_LOG_DIR / TPU_STDERR_LOG_LEVEL: off by default; bumped during device-init debugging. The log noise is high enough that you do not want it on by default on a training host.

Order matters. PJRT_DEVICE, JAX_PLATFORMS, and XLA_COMPILATION_CACHE_DIR must be set before the first import of the corresponding framework. In scripts/base_train.py the XLA cache init happens very early in main() precisely because a later init does not override the first one.
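The ordering constraint can be made mechanical with a small bootstrap that runs before any framework import. A sketch under the defaults described above (the function name is ours; the variable names and default cache paths are from this post):

```python
import os

def bootstrap_tpu_training_env(cache_dir=None):
    # All three must land in os.environ before the first framework
    # import; setdefault keeps explicit shell overrides intact.
    os.environ.setdefault("PJRT_DEVICE", "TPU")
    os.environ.setdefault("XLA_NO_SPECIAL_SCALARS", "1")
    if cache_dir is None:
        cache_dir = ("/data/.xla_cache" if os.path.isdir("/data")
                     else os.path.expanduser("~/.cache/xla_compilation"))
    os.environ.setdefault("XLA_COMPILATION_CACHE_DIR", cache_dir)

bootstrap_tpu_training_env()
# Only now is it safe to `import torch_xla`.
```

The point of putting this in one function at the top of the entry script is that there is exactly one place to audit when a host misbehaves.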

The libtpu flag vocabulary drift

Between libtpu 0.0.36 and 0.0.37.dev20260224, a flag we relied on quietly changed:

  • --xla_disable_hlo_passes=... - valid on 0.0.36, unknown flag on 0.0.37.
  • The replacement on 0.0.37 is --xla_jf_vmem_memory_space_assignment=false.

A run script that carried the old flag fails closed on 0.0.37 with "unknown flag" and the run does not start. That is, fortunately, loud - the failure is explicit at launch rather than silently ignored. The contract we moved to: log the libtpu version in every run manifest, and let nanochat/xla_flags.py pick the flag name based on the detected libtpu version rather than hard-coding either one.
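The version-dispatch in nanochat/xla_flags.py can be sketched roughly like this (the function name and version parsing here are our assumptions; the two flag spellings are the ones from this post, and the 0.0.36 pass list stays a caller-supplied parameter because it is elided above):

```python
def msa_disable_flags(libtpu_version, legacy_pass_list):
    """Pick the MSA-disable flag spelling for a libtpu version string.

    legacy_pass_list is the pass list used with --xla_disable_hlo_passes
    on 0.0.36 (the exact list is not reproduced in this post).
    """
    core = libtpu_version.split("+")[0]  # drop '+nightly' local tag
    nums = []
    for part in core.split("."):
        if part.isdigit():
            nums.append(int(part))
        else:
            break  # stop at suffixes like 'dev20260224'
    if tuple(nums) >= (0, 0, 37):
        return "--xla_jf_vmem_memory_space_assignment=false"
    return f"--xla_disable_hlo_passes={legacy_pass_list}"
```

Dispatching on the detected version, rather than hard-coding either spelling, is what keeps a single run script valid across a libtpu bump.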

The broader lesson is that the LIBTPU_INIT_ARGS flag surface is not stable across nightlies. Treat it as internal to each libtpu release, not as an API.

MSA and the NaN hunt

The most unpleasant libtpu interaction we hit was not a race; it was a numerical regression. In early March, every NAM52 bare FSDP run on torch 2.11 + libtpu 0.37 produced NaN at step 1. So did the same config on torch 2.9 + libtpu 0.37. So did the same config with FSDP disabled. So did the same config with the "disable MSA" flag removed.

The bisect matrix looked like this:

Lane  torch  libtpu  Config                        Result
T1    2.11   0.37    NAM52 FSDP                    NaN at step 1
T2b   2.9    0.37    NAM52 FSDP                    NaN at step 1
T6    2.11   0.37    NAM52 no-FSDP                 NaN at step 1
T7    2.11   0.37    NAM52 FSDP, no MSA-disable    NaN at step 1
T9    2.11   0.37    d=12 FSDP + MSA-disable       libtpu 0.37 unknown flag
T11   2.9    0.36    NAM52 FSDP                    Compiling (no NaN)

What T11 told us: the NaN was not exclusively a libtpu 0.37 regression, but 0.36 was the only version where we could run the known-good MSA workaround at all. The 0.37 flag rename (--xla_disable_hlo_passes -> --xla_jf_vmem_memory_space_assignment=false) meant a script mechanically ported from 0.36 could not even express the workaround. We ended up with a rule: bare NAM52 FSDP trains on 0.36 until 0.37 has an equivalent receipt, and MoE configs can ride on 0.37 because they don't hit the same memory-space-assignment pass in the same way.

That is not a satisfying answer. It is the honest one. libtpu nightlies change HLO-pass behaviour between builds, and bisecting the combined surface of torch / torch_xla / libtpu / model config is slow.

Mode A, MoE, and the libtpu SIGKILL

Two TPU compile modes, as documented elsewhere: Mode B (per-micro-step compile around fwd+bwd) and Mode A (whole-step compile through torch_xla.compile()). On dense models libtpu 0.0.37 enables Mode A, and it is 30-56% faster than Mode B. On MoE models Mode A consistently SIGKILLs during compilation:

  • Dense d12 Mode A: 482K tok/sec (works)
  • Dense d12 + Mamba AAM Mode A: 280-634K tok/sec (works)
  • NAM12 MoE(64) Mode A: SIGKILL during compile
  • NAM12 MoE(8) Mode A: SIGKILL during compile

The crash is internal to libtpu on the MoE routing sub-graph (gather/scatter/top-k). torch_xla hands the HLO down; libtpu dies. From the Python side there is no useful traceback; the process is terminated by the kernel with SIGKILL.

The operational contract we wrote: Mode A for dense / Mamba-only, Mode B for anything with MoE. If you need Mode A speed on MoE, you need to wait for a libtpu revision that does not kill on those routing graphs; there is no Python-side workaround.
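That contract is simple enough to encode as a guard rather than a comment. An illustrative sketch (the predicate name and mode labels are our stand-ins for however the run scripts actually spell this):

```python
def choose_compile_mode(model_has_moe: bool) -> str:
    """Mode A (whole-step compile) for dense / Mamba-only models;
    Mode B (per-micro-step compile) for anything with MoE, because
    libtpu 0.0.37 SIGKILLs during compile on MoE routing graphs."""
    return "B" if model_has_moe else "A"
```

Putting the rule in code means a new config cannot accidentally opt a MoE model into Mode A and burn an hour rediscovering the SIGKILL.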

The XLA persistent cache is per-libtpu

One more interaction worth documenting. The patched torch_xla persistent cache (covered in the companion PJRT post) stores XLAEXE executables on disk. Those executables are produced by a specific libtpu build. Bumping libtpu - even between 0.36 and 0.37 within the same torch_xla build - should in principle invalidate the cache entry and force a recompile.

In practice we have seen two cases worth knowing:

  1. libtpu 0.37 still returned UNIMPLEMENTED on cache deserialization against unpatched torch_xla. Our patch works against both libtpu versions, but the upstream path is broken across both.
  2. A libtpu bump while an old cache is on disk is not guaranteed to reject every stale entry. We have not seen a wrong-answer outcome from this, but we also do not rely on libtpu to garbage-collect our cache. Our discipline: when we bump libtpu on a host, we wipe XLA_COMPILATION_CACHE_DIR by hand and let the next run repopulate.

That is a small operational cost - a warm cache is a few GB and repopulates in one or two compile cycles - and it is much cheaper than debugging a stale-cache incident during a live training run.
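The wipe-on-bump discipline is easy to automate: stamp the cache directory with the libtpu version that produced it and clear it when the stamp no longer matches. A sketch (the stamp-file name is our invention; the repo does this by hand):

```python
import os
import shutil

def ensure_cache_matches_libtpu(cache_dir, libtpu_version):
    """Wipe cache_dir if it was populated by a different libtpu build."""
    stamp = os.path.join(cache_dir, ".libtpu_version")
    current = None
    if os.path.exists(stamp):
        with open(stamp) as f:
            current = f.read().strip()
    if current != libtpu_version:
        # Stale (or unstamped) XLAEXE entries: drop them all and let the
        # next run repopulate over one or two compile cycles.
        shutil.rmtree(cache_dir, ignore_errors=True)
        os.makedirs(cache_dir, exist_ok=True)
        with open(stamp, "w") as f:
            f.write(libtpu_version)
```

Calling this at launch, right after the libtpu version is detected for the run manifest, turns a manual checklist item into a no-op on matched hosts.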

Debug patterns that paid off

Three patterns actually moved debugging forward when the runtime was opaque.

PT_XLA_DEBUG=1 on the first step, then off. Leaving it on destroys throughput and floods the log. Running it for the first 10 steps of an ablation, captured into a dedicated log, was enough to see which ops caused host-device syncs and which metadata columns were triggering new trace variants. We then disabled it for the steady-state run.

Explicit _xla_runtime_memory_info("TPU:N") on each rank. Our pybind patch (see the companion torch_xla post) lets a running training process ask libtpu for HBM usage per physical chip. We wire this into /memory and poll from the host. Without it, you get OOM crashes with no forensic data about which chip ran out first or how close you were. With it, auto-fit can narrow dbs proactively and the retry loop has signal to act on.
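The polling side of that wiring can be sketched framework-agnostically. Here query_fn stands in for the patched binding, e.g. lambda dev: _xla_runtime_memory_info(dev); the generator shape and parameter names are our own, and the binding's return format is whatever the patch provides:

```python
import time

def poll_hbm(query_fn, ranks, interval_s=30.0, iterations=None):
    """Yield (rank, info) pairs by querying each "TPU:<rank>" device.

    iterations=None polls forever (the /memory endpoint case);
    a finite count is convenient for tests and one-shot snapshots.
    """
    n = 0
    while iterations is None or n < iterations:
        for rank in ranks:
            yield rank, query_fn(f"TPU:{rank}")
        n += 1
        if iterations is None or n < iterations:
            time.sleep(interval_s)
```

A host-side endpoint can iterate this generator and serve the latest snapshot; the useful part is that the per-chip attribution survives into the crash forensics.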

Provenance JSON per run. Every training launch writes a provenance record with hostname, python, torch, torch_xla, jax, and libtpu versions plus the set of backend modes exercised (xla_flash, splash, softcap on/off). A tiny example from a March validate run:

{
  "provenance": {
    "python": "3.13.12",
    "torch": "2.9.0a0+git21fec65",
    "torch_xla": "2.9.0+gitc04e61c",
    "jax": "0.9.0"
  },
  "tests": [
    { "mode": "xla_flash", "backend_used": "xla_flash_pallas", "status": "success" },
    { "mode": "splash",    "backend_used": "xla_splash_via_trace_pallas", "status": "success" }
  ]
}

Three months of those records is what lets us say "this regression is new" rather than "this regression is vibes". It is also what lets us correlate a failure on one v6e-8 host with the same stack on another, which turns out to matter when four machines in a four-machine fleet each run a slightly different torch build.
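Writing such a record is a few lines. A minimal sketch (the function name is ours; torch / torch_xla / jax / libtpu version strings come from the installed packages at launch and are passed in by the caller here):

```python
import json
import platform

def write_provenance(path, extra=None):
    """Write a provenance JSON like the example above.

    extra: dict of framework version strings resolved at launch,
    e.g. {"torch": ..., "torch_xla": ..., "jax": ..., "libtpu": ...}.
    """
    record = {"provenance": {"python": platform.python_version(),
                             **(extra or {})}}
    with open(path, "w") as f:
        json.dump(record, f, indent=2)
```

The only discipline required is that every launch path goes through it unconditionally; a provenance file that exists for 90% of runs answers none of the interesting questions.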

Fleet discipline, not per-host tweaks

The eventual operational rule is dull and non-negotiable.

One pinned Python, one pinned torch, one pinned torch_xla, one pinned libtpu, one pinned jax, per host. Recorded in a provenance JSON. Logged in every training manifest. JAX_PLATFORMS=cpu unless a test explicitly needs TPU JAX. PJRT_DEVICE=TPU and XLA_NO_SPECIAL_SCALARS=1 set in the training environment before the first framework import. Persistent cache wiped on libtpu bump. Mode A only on dense; Mode B on MoE.

Every time we relaxed one of those rules under time pressure, the cost was higher than doing it right would have been. The interactions between jax, torch_xla, and libtpu on a single TPU host are not well-documented because most users run exactly one of the three. If you run all three, the contract is yours to enforce.

References

  • TPU_SETUP.md
  • CURRENT_STATE.md
  • BACKEND_STOPLIGHT_MATRIX.md
  • GCLOUD_AUTH_REFRESH.md
  • TPU_BACKUPS.md
  • CHANGELOG.md
  • persistent_cache_fix.diff (README)
  • xla_runtime_memory_info.diff (README)
  • tpu_backend_provenance_v6e8_2026-03-16.json
  • dist_optimizer_stress_tpu_v6e_2026-03-22.md
  • review_gcp_tpu.md
  • training_review.md
David Gornshtein • Datasunrise OÜ