MegaCpp Engineering: Applied C++ model systems
Grounded engineering note from the MegaCpp stack
Published · 5 min read · David Gornshtein
Torch XLA
PJRT
XLA
TPU
Training
Evaluation

Torch XLA and PJRT reality: what actually matters

A grounded look at the current TPU stack: PJRT contracts, SPMD setup order, reduction semantics, and the failure modes that still shape training and evaluation.


The current TPU lane is not a generic "install XLA and go" setup. It depends on the modern PJRT runtime, early SPMD initialization, and careful reduction behavior in evaluation code. The practical rule is simple: use a frontend and runtime that agree on the same contract, set runtime policy before imports, enable SPMD before tensors exist, and never assume a TPU metric is globally reduced unless the code path proves it.

The substrate boundary matters. PyTorch/XLA is the PyTorch path for TPU. JAX is a separate frontend with its own tracing and execution model. Pallas is a kernel-authoring extension of JAX. None of these should be collapsed into an NVIDIA precision or CUDA-kernel story.

There is a lot of stale advice around Torch XLA. Some of it was valid for older XRT-era setups, some of it assumes stock wheels are enough, and some of it ignores how training code mixes evaluation, optimizer state, and mesh construction. The decisive details are runtime ownership, import order, graph stability, and reduction semantics.

That is also why this post sits next to the companion notes on libtpu, PJRT, and JAX ownership boundaries and on XLA SPMD sharding annotations: most failures here come from mixing frontend ownership, runtime selection, and sharding vocabulary that belong to different layers of the stack.

The first reality: the stack is a runtime contract

PyTorch/XLA has moved from the older XRT runtime to PJRT, and current TPU guidance assumes PJRT by default when no older runtime is configured. That makes the version boundary between torch, torch_xla, and the TPU runtime layer part of the execution contract rather than a background packaging detail.
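Because the torch/torch_xla pairing is part of that contract, it is worth checking before anything imports the runtime. A minimal sketch follows, assuming a local policy that torch and torch_xla must come from the same major.minor release line (PyTorch/XLA release numbers track PyTorch's). It reads versions through importlib.metadata, so the check never imports torch_xla itself. The helper names are hypothetical, and pairing with the libtpu build is not covered.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def major_minor(version_string: str) -> Tuple[int, int]:
    """Return (major, minor) from strings such as '2.5.1+cpu' or '2.6.0.dev20241201'."""
    match = re.match(r"^(\d+)\.(\d+)", version_string)
    if match is None:
        raise ValueError(f"unrecognized version string: {version_string!r}")
    return int(match.group(1)), int(match.group(2))


def versions_agree(torch_version: str, torch_xla_version: str) -> bool:
    """Assumed policy: both packages must share the same major.minor release line."""
    return major_minor(torch_version) == major_minor(torch_xla_version)


def installed_pair() -> Optional[Tuple[str, str]]:
    """Read installed versions from package metadata without importing either package."""
    try:
        return version("torch"), version("torch_xla")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    pair = installed_pair()
    if pair is None:
        print("torch or torch_xla is not installed")
    else:
        print(pair, "agree" if versions_agree(*pair) else "MISMATCH")
```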

That explains most TPU confusion. When people say "XLA is unstable," they often mean they mixed a current TPU runtime with an older software contract. When they say "PJRT changed everything," they are partly right, but the practical lesson is operational rather than philosophical: keep the frontend and runtime on the same contract boundary.

Public Torch XLA guidance reinforces the same picture from another angle: frontend tracing, PJRT runtime behavior, and device-runtime compatibility have to align before model-level debugging even starts.

| Question | Public answer | Why it matters |
| --- | --- | --- |
| Is PJRT the current runtime surface? | Yes | Runtime expectations start there |
| Is TPU support bolt-on? | No | Initialization order and mesh semantics are central |
| Can I trust a local TPU metric by default? | No, only after the reduction path proves it | Silent fallback can misreport eval quality or throughput |

If you take only one thing from this, make it this: TPU stability starts with a correct contract between the framework frontend, the PJRT runtime, and the TPU runtime layer.

The second reality: import order and SPMD timing are part of correctness

The training startup path is explicit about setup order. Runtime flags must be applied before importing torch_xla, and xr.use_spmd() (where xr is torch_xla.runtime) must be called before any Torch XLA tensor exists. This is not cosmetic startup sequencing; it is part of the runtime contract.

That means TPU setup has two early gates:

  1. Set the environment and runtime policy before importing the runtime.
  2. Enable SPMD before constructing tensors or letting helper paths import XLA indirectly.
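The first gate can be enforced in code rather than by convention. The sketch below is a hypothetical helper, not a torch_xla API: it refuses to apply environment policy once any torch_xla module is already loaded, which is the usual way an indirect helper import slips in ahead of the launcher.

```python
import os
import sys
from typing import Dict, List


def loaded_runtime_modules() -> List[str]:
    """List torch_xla modules that are already imported in this process."""
    return sorted(
        name for name in list(sys.modules)
        if name == "torch_xla" or name.startswith("torch_xla.")
    )


def apply_runtime_policy(env_overrides: Dict[str, str]) -> None:
    """Apply environment policy, failing loudly if the runtime was imported first."""
    too_late = loaded_runtime_modules()
    if too_late:
        raise RuntimeError(
            "runtime policy applied after torch_xla was imported: " + ", ".join(too_late)
        )
    os.environ.update(env_overrides)
```

Calling apply_runtime_policy(...) as the first statement of the entry point, ahead of `import torch_xla.runtime as xr` and `xr.use_spmd()`, turns a silent ordering bug into an immediate failure.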

A compilation cache belongs in that same startup contract. Caching is not just a convenience; it changes how repeated runs behave and helps separate cold compile cost from actual execution regressions.
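One way to keep cold compile cost and execution cost apart is to report them separately. A minimal sketch, assuming per-step wall-clock times are already recorded after each step is synchronized (with lazy XLA tensors, an unsynchronized timer mostly measures tracing). The two-step warmup window and the median are assumptions to tune, and later recompilations triggered by shape changes are not detected here. With a warm cache, warmup_excess_s should shrink while steady_median_s stays put, so a shift in steady_median_s points at an execution regression rather than at compilation.

```python
from statistics import median
from typing import Dict, Sequence


def split_compile_and_steady(step_seconds: Sequence[float], warmup_steps: int = 2) -> Dict[str, float]:
    """Separate early (compile-heavy) steps from steady-state step time."""
    if warmup_steps < 0 or len(step_seconds) <= warmup_steps:
        raise ValueError("need more recorded steps than warmup_steps")
    warmup = list(step_seconds[:warmup_steps])
    steady = list(step_seconds[warmup_steps:])
    steady_median = median(steady)
    return {
        "warmup_total_s": sum(warmup),
        "steady_median_s": steady_median,
        # Time the warmup steps spent above the steady-state cost; on a cold
        # cache this is typically dominated by graph compilation.
        "warmup_excess_s": sum(max(0.0, t - steady_median) for t in warmup),
    }
```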

On shared multi-host mounts, the safe pattern is either a writable cache per process or host, or one writer that warms a shared cache while the other workers stay read-only against that mount. When every worker writes into the same directory, the symptom often looks like "PJRT startup is slow" even though the real bug is cache ownership and file-lock contention.
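Both safe patterns reduce to one decision per worker: which directory, and whether it may write. The sketch below makes that decision explicit. The mode names, directory layout, and the choice of worker (0, 0) as the single writer are assumptions. In recent PyTorch/XLA releases the resulting pair can be passed to torch_xla.runtime.initialize_cache(path, readonly=...), which is meant to be called before anything is compiled.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class CachePlan:
    path: str
    readonly: bool


def plan_compile_cache(root: str, host_index: int, process_index: int, mode: str) -> CachePlan:
    """Decide where this worker's compilation cache lives and whether it may write to it."""
    if mode == "per_process":
        # Every process owns a private writable directory: no shared writers.
        return CachePlan(os.path.join(root, f"host{host_index}", f"proc{process_index}"), False)
    if mode == "per_host":
        # One writable directory per host; assumes one runtime process per host.
        return CachePlan(os.path.join(root, f"host{host_index}"), False)
    if mode == "single_writer":
        # One designated worker warms the shared cache; everyone else only reads it.
        is_writer = host_index == 0 and process_index == 0
        return CachePlan(os.path.join(root, "shared"), readonly=not is_writer)
    raise ValueError(f"unknown cache mode: {mode!r}")
```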

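Example TPU startup contract (a minimal sketch; the explicit runtime profile and the training entry point are left to the launcher):

```python
import os

# 1. Launch intent and any other TPU runtime policy go into the environment first.
os.environ["PJRT_DEVICE"] = "TPU"

# 2. Only now import the runtime, so it starts with the policy above.
import torch_xla.runtime as xr

# 3. Enable SPMD before any XLA tensor exists, then build the mesh and train.
xr.use_spmd()
```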

The exact launcher can vary, but the principle does not. If a helper imports torch_xla too early, the rest of the run becomes hard to reason about.

The third reality: evaluation can lie if reduction semantics are weak

One of the most useful surfaces for understanding TPU behavior is the evaluation path used for loss aggregation. Any metric that depends on a collective has to prove that the collective really happened. Otherwise a run can look healthy while reporting local rather than global totals.

This is not a niche bookkeeping error. In practice it affects how you interpret evaluation curves, cost-per-token calculations, and any claim about TPU scaling efficiency. If a global metric silently degraded into a local metric, the dashboard is not merely noisy; it is wrong.

For operators, the rule should stay simple: if the metric depends on a collective, confirm the collective path explicitly.

In practice that usually means combining the value across replicas with xm.all_reduce(...) (on device tensors) or xm.mesh_reduce(...) (on host values) before a loss.item() value is logged or used in a host-side conditional. Once a purely local read is reported first, the run can keep looking healthy while the reported metric no longer describes the global TPU step.
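A small arithmetic example shows why a local read is not a rounding error. The sketch simulates two ranks with plain Python lists; on a real run the per-rank sums would be combined with xm.all_reduce(xm.REDUCE_SUM, ...) on device tensors, or xm.mesh_reduce(tag, value, sum) on host values, before anything is logged. The function names are illustrative.

```python
from typing import Sequence


def local_mean_loss(loss_sums: Sequence[float], token_counts: Sequence[int], rank: int) -> float:
    """What a rank reports if it reads its own values before any collective."""
    return loss_sums[rank] / token_counts[rank]


def global_mean_loss(loss_sums: Sequence[float], token_counts: Sequence[int]) -> float:
    """Token-weighted global value: sum numerators and denominators across ranks, then divide."""
    return sum(loss_sums) / sum(token_counts)
```

With per-rank loss sums of 20.0 and 60.0 over 10 and 20 tokens, rank 0 alone would report 2.0, averaging the two local means would give 2.5, and the token-weighted global value is 80 / 30 ≈ 2.67. Only the last one describes the global step.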

Mesh construction is the real TPU mental model

The right way to think about PJRT in this stack is not as a magical optimizer. It is the runtime contract underneath the TPU execution model. The engineering task is to build the right mesh, shard the right tensors, and preserve those assumptions through the training and eval stack.

That matters especially in a project that mixes different training and evaluation paths. Two TPU lanes may both use XLA and PJRT, but they may not be testing the same execution surface. The runtime contract may be global, but the failure surfaces are still local.
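Making the mesh explicit is mostly bookkeeping, and it is worth checking before any tensor is annotated. The sketch below plans a 2D mesh with hypothetical "data" and "model" axis names and checks that the shape covers every device exactly once. In PyTorch/XLA the three returned values line up with the arguments of torch_xla.distributed.spmd.Mesh(device_ids, mesh_shape, axis_names), and the device count would come from the runtime (for example xr.global_runtime_device_count()) rather than a constant. The grouping helper assumes row-major placement, as a numpy reshape of the device ids would give.

```python
from typing import List, Tuple


def plan_mesh(num_devices: int, model_parallel: int) -> Tuple[List[int], Tuple[int, int], Tuple[str, str]]:
    """Return (device_ids, mesh_shape, axis_names) for a 2D data x model mesh."""
    if model_parallel <= 0 or num_devices % model_parallel != 0:
        raise ValueError(
            f"{num_devices} devices cannot be split into model-parallel groups of {model_parallel}"
        )
    mesh_shape = (num_devices // model_parallel, model_parallel)
    return list(range(num_devices)), mesh_shape, ("data", "model")


def model_parallel_groups(device_ids: List[int], mesh_shape: Tuple[int, int]) -> List[List[int]]:
    """Row-major layout: each mesh row is one group of devices sharing the 'model' axis."""
    rows, cols = mesh_shape
    return [device_ids[r * cols:(r + 1) * cols] for r in range(rows)]
```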

The setup story is mature enough to be useful, but not simple enough to ignore

A good sign is that the public documentation is concrete here. The runtime and PJRT notes define ownership and setup order clearly. The XLA profile sample keeps policy visible instead of hiding it in shell history.

The mature posture is therefore neither optimism nor panic. It is to treat TPU startup, reduction semantics, and mesh ownership as first-class engineering surfaces rather than background details.


Frequently asked questions

What does a safe shared compilation cache look like on multi-host TPU?
Either segmented writable caches per process or host, or one writer that populates a shared cache while the other workers use it read-only. The dangerous pattern is many writable workers racing on one directory, because the resulting lock contention looks like runtime slowness instead of a cache-ownership bug.
Does PJRT_DEVICE=TPU prove the TPU lane is healthy?
No. It states launch intent, not runtime proof. A healthy lane still needs a real runtime probe, which is why the safer local pair is the TPU runtime probe sample plus the XLA compile/runtime controls sample. If that probe fails, stay at runtime bring-up instead of debugging model code first.
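The distinction between launch intent and runtime proof can be written down as a small classifier. A minimal sketch, assuming the probe is any callable that returns a visible device count or raises; the function name and status strings are hypothetical. In a real lane the probe might wrap xr.global_runtime_device_count(), though a device count is a weak signal and running a tiny compiled computation is a stronger one.

```python
from typing import Callable, Mapping


def classify_tpu_lane(env: Mapping[str, str], probe: Callable[[], int]) -> str:
    """Separate 'TPU was requested' from 'the TPU runtime actually answered'."""
    if env.get("PJRT_DEVICE") != "TPU":
        return "not-requested"
    try:
        device_count = probe()
    except Exception as exc:  # any probe failure means runtime bring-up, not model debugging
        return f"runtime-bringup: probe failed ({type(exc).__name__})"
    if device_count <= 0:
        return "runtime-bringup: no devices visible"
    return f"healthy: {device_count} devices visible"
```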
Why can a TPU metric look plausible and still be wrong?
Because the dangerous failure mode is often local rather than obviously broken. If a host-side read happens before the intended collective, each rank can materialize its own lazy value and the run still "looks healthy" while the dashboard is now summarizing per-rank numbers instead of the global TPU step. That is why this lane keeps xm.mesh_reduce(...) and xm.all_reduce(...) in the same correctness story as startup order: Torch XLA and PJRT reality is not only about getting a runtime, it is also about proving that the reported metric still matches the distributed contract.
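One way to keep that ordering honest is to make the unreduced value unreadable. The class below is a hypothetical wrapper, not part of torch_xla: item() raises until reduce_with(...) has run. On a multi-process PyTorch/XLA run the reducer could be `lambda v: xm.mesh_reduce("eval_loss_sum", v, sum)`, since mesh_reduce passes the list of per-replica values to its reduce function.

```python
from typing import Callable, Optional


class GlobalMetric:
    """Hypothetical wrapper: a metric that cannot be read on the host until it is reduced."""

    def __init__(self, name: str, local_value: float) -> None:
        self.name = name
        self._local = local_value
        self._global: Optional[float] = None

    def reduce_with(self, reducer: Callable[[float], float]) -> "GlobalMetric":
        """Run the cross-rank reduction; reducer maps this rank's value to the global value."""
        self._global = reducer(self._local)
        return self

    def item(self) -> float:
        """Return the reduced value, refusing a purely local read."""
        if self._global is None:
            raise RuntimeError(f"{self.name}: host read before cross-rank reduction")
        return self._global
```

With the other ranks simulated by a list, `GlobalMetric("eval_loss_sum", 2.0).reduce_with(lambda v: v + sum([3.0, 5.0])).item()` returns 10.0, while calling item() before reduce_with(...) raises.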
When does pin_layout matter for XLA collectives?
Treat it as a collective-layout guardrail, not a tuning knob to flip casually. PyTorch/XLA pins collective layouts by default (pin_layout=True on xm.all_reduce(...)) to keep participants in the communication on compatible layouts, but the same constraint can become the thing that exposes a layout-mixing compile failure. If an xm.all_reduce(...) lane fails with a layout-constrained HLO error, the safer diagnosis is not "collectives are broken"; it is that the program shape, SPMD timing, and collective layout contract need to be checked together.
What changes if a Torch/XLA lane intentionally bridges into JAX or Pallas?
The ownership rule gets stricter, not looser. Keep PyTorch/XLA as the main frontend, keep the bridge narrow, and if the lane uses the PyTorch/XLA Pallas path, run the JAX import guard before importing JAX so TPU ownership does not deadlock at the client boundary. In MegaCpp, the public-safe pattern is a small checked-in bridge plus a checked-in receipt that keeps backend ownership visible instead of flattening everything into "the TPU path". The relevant references are libtpu, PJRT, and JAX ownership boundaries; Pallas kernels on TPU v6e; the call_jax TPU bridge runtime sample; and the XLA Pallas bridge receipt sample.
Glossary

Terms used in this article

Start here for quick definitions, then follow the linked posts for deeper context.

PJRT

The TPU runtime interface between frontend code and the backend executor; it is the ownership seam between JAX/Torch-XLA frontends and libtpu.

libtpu

The TPU backend library that pairs with PJRT/XLA and owns device-side execution underneath the frontend.

JAX

A separate frontend above PJRT/libtpu. In these TPU posts it mainly matters as the owner of NamedSharding, PartitionSpec, and the optional call_jax or Pallas-adjacent bridge lanes.

XLA

The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.

TPU

Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.

XLA SPMD

The explicit TPU sharding mode where one compiled program carries placement rules instead of rank-local imperative code.
