Torch 2.12 TPU/XLA breakage matrix: wheel pain, cache misses, and the workarounds that actually mattered
A repo-grounded account of where the TPU/XLA stack broke, which failures needed upstream-facing patches, and which ones were better handled as explicit MegaCpp runtime policy.

The practical TPU/XLA story around a Torch 2.12-class stack was not "upgrade the wheel and rerun." It was a breakage matrix. Some failures came from version skew across PyTorch, torch_xla, OpenXLA, PJRT, and libtpu. Others came from backend behavior that looked acceptable at import time but failed on the real training path.
What mattered in MegaCpp was separating three categories:
- failures that needed a patch against the TPU/XLA stack itself
- failures that were real but better handled as launch or model policy
- features that only existed in newer torch_xla lines and therefore had to be gated by version
The matrix was broader than a version pin
The most useful build record was not a single package list but a compatibility bundle. One validated TPU/XLA build recorded PyTorch 2.9.0a0+git21fec65, torch_xla 2.9.0+gitc04e61c, OpenXLA a76a9a858, Python 3.13, and a set of local API patches needed just to keep that newer OpenXLA/PJRT layer buildable together. That is the right framing: on TPU/XLA, a version bump is really an ABI and runtime-contract check, not only a packaging event.
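To make "the lane is the bundle" concrete, here is a minimal sketch that records every component and derives a lane key from all of them. XlaBundle, changed_components, and the patch label are hypothetical names for illustration, not a MegaCpp or torch_xla API; only the version strings come from the build described above.

```python
import hashlib
import json
from dataclasses import asdict, dataclass, replace


@dataclass(frozen=True)
class XlaBundle:
    """Hypothetical record of one validated TPU/XLA lane (illustrative field names)."""

    torch: str
    torch_xla: str
    openxla: str
    python: str
    local_patches: tuple = ()

    def key(self) -> str:
        # Hash every component, so a change to any single one yields a new lane key.
        payload = json.dumps(asdict(self), sort_keys=True)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:12]


def changed_components(a: XlaBundle, b: XlaBundle) -> list:
    """List the fields that differ between two bundles, in declaration order."""
    return [name for name in asdict(a) if getattr(a, name) != getattr(b, name)]


# Versions from the validated build described above; the patch label is a placeholder.
VALIDATED = XlaBundle(
    torch="2.9.0a0+git21fec65",
    torch_xla="2.9.0+gitc04e61c",
    openxla="a76a9a858",
    python="3.13",
    local_patches=("openxla-pjrt-api-compat",),
)

if __name__ == "__main__":
    # Same torch wheel, different OpenXLA commit: still a different lane.
    candidate = replace(VALIDATED, openxla="0000000")
    print(changed_components(VALIDATED, candidate), VALIDATED.key() != candidate.key())
```

The design choice is that nothing in the key privileges torch: swapping only the OpenXLA commit changes the key just as much as swapping the wheel.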
The same repo history also documents a later custom stack used for cache testing that moved to PyTorch 2.11.0a0+git7afdbae with the same torch_xla commit family plus local fixes. In other words, even before talking about Torch 2.12 in public, the operational lesson was already clear: a TPU/XLA lane is defined by the whole bundle, not by torch alone.
That same framing is why the public-safe support surfaces in this repo are runtime-first rather than wheel-name-first. The TPU runtime probe sample separates "package installed" from "backend actually alive," the XLA compile/runtime controls sample keeps cache and startup policy explicit, and TPU backend ownership records which layer owns frontend, runtime, and fallback decisions.
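The "package installed" versus "backend actually alive" split can be sketched as a probe that reports the two facts separately. This is a hypothetical helper, not the repo's probe sample: the liveness callable is an assumption, supplied by the caller, and should exercise the real backend (for example, move a small tensor to the XLA device and read it back).

```python
import importlib.util
from typing import Callable, Optional


def probe_backend(package: str, liveness: Optional[Callable[[], object]] = None) -> dict:
    """Separate 'package installed' from 'backend actually alive'.

    `liveness` is a caller-supplied callable (hypothetical hook) that should touch the
    real backend; any exception it raises means 'installed but not alive'.
    """
    report = {"package": package, "installed": False, "alive": False, "error": None}
    try:
        report["installed"] = importlib.util.find_spec(package) is not None
    except (ImportError, ValueError):
        report["installed"] = False
    if not report["installed"]:
        report["error"] = "package not importable"
        return report
    if liveness is None:
        report["error"] = "installed, but no liveness check was run"
        return report
    try:
        liveness()
        report["alive"] = True
    except Exception as exc:  # backend init failures surface here, not at import time
        report["error"] = f"{type(exc).__name__}: {exc}"
    return report
```

A probe that only checks importability would report success on a host whose runtime never comes up; requiring an explicit liveness call keeps that failure visible.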
Breakage matrix
| Surface | Symptom | What MegaCpp did |
|---|---|---|
| persistent compilation cache | cache files were written but restarts still recompiled from scratch | patched torch_xla to load serialized executables through the PJRT C API path |
| SPMD memory reporting | in-process HBM reporting failed on virtual SPMD:0 devices | added a raw runtime-device memory binding and queried physical TPU:* devices |
| optimizer graph stability | cache behavior and graph hashes were polluted by scalar extraction in AdamW | forced capturable=True for XLA AdamW |
| checkpoint resume on XLA | loading with assign=True would replace XLA parameters with CPU tensors | used assign=False on XLA resume paths |
| scan_layers availability | compile-time improvement feature depended on the torch_xla line in use | gated the feature and warned when the local torch_xla build did not expose it |
| torch.compile on TPU | compile path could OOM during TPU compilation | kept TPU in eager mode and relied on XLA JIT instead |
That table is the real upgrade artifact. It says which failures were packaging or backend defects and which ones were simply the wrong runtime policy for TPU.
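One way to keep that distinction machine-checkable is to store the matrix as data with an explicit fix category per row. The Fix enum, Row, and by_fix are hypothetical names; the category assignments follow the article's own split between patches, launch/model policy, and version gates.

```python
from dataclasses import dataclass
from enum import Enum


class Fix(Enum):
    UPSTREAM_PATCH = "patch against the torch_xla / PJRT stack"
    POLICY = "launch or model policy"
    VERSION_GATE = "gate on the torch_xla line"


@dataclass(frozen=True)
class Row:
    surface: str
    fix: Fix
    action: str


# Rows mirror the breakage matrix above; actions are abbreviated.
MATRIX = [
    Row("persistent compilation cache", Fix.UPSTREAM_PATCH, "load serialized executables via PJRT C API"),
    Row("SPMD memory reporting", Fix.UPSTREAM_PATCH, "raw runtime-device memory binding"),
    Row("optimizer graph stability", Fix.POLICY, "capturable=True for XLA AdamW"),
    Row("checkpoint resume on XLA", Fix.POLICY, "assign=False on XLA resume paths"),
    Row("scan_layers availability", Fix.VERSION_GATE, "warn and ignore when unavailable"),
    Row("torch.compile on TPU", Fix.POLICY, "eager mode, XLA JIT owns optimization"),
]


def by_fix(rows, fix: Fix) -> list:
    """Surfaces whose failure is handled by the given kind of fix."""
    return [r.surface for r in rows if r.fix is fix]


if __name__ == "__main__":
    for kind in Fix:
        print(kind.value, "->", by_fix(MATRIX, kind))
```

Filtering on Fix.UPSTREAM_PATCH gives the short list worth carrying to upstream; everything else belongs in launch configuration.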
The most expensive bug was a persistent cache that only wrote
The clearest upstream-facing defect was the persistent compilation cache. The patch write-up documents that cache files were successfully written but never read back, so every restart still paid the compile bill. The root cause in the patch write-up is precise: torch_xla used PjRtClient::DeserializeExecutable(), which returned UNIMPLEMENTED, while the working path was PjRtClient::LoadSerializedExecutable(), the PJRT C API route that reaches PJRT_Executable_DeserializeAndLoad in libtpu.
The patch changed two things. It serialized executables first instead of only HLO, and on load it switched from DeserializeExecutable() to LoadSerializedExecutable(). The documented result was a restart-time improvement from 11.5 seconds to 1.7 seconds on a small validation model, with the larger-model expectation dropping from roughly 47 minutes to roughly 7 minutes.
That distinction matters because the repo also preserves the earlier dead end: HLO-cache writes existed, but they mostly saved tracing overhead rather than the real TPU compile cost. MegaCpp therefore treated executable-cache loading as the actual fix and plain HLO caching as insufficient.
For readers, the practical proof is warm-restart behavior on the same workload, not the presence of cache artifacts on disk. The XLA compile/runtime controls sample keeps that contract explicit, and Graph recompilation hell is the companion read when a cache miss is really a graph-stability problem.
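A minimal sketch of treating restart timing, not cache files, as the evidence. The helper names are hypothetical, and the min_speedup threshold of 2.0 is an arbitrary illustrative choice, not a MegaCpp setting; the 11.5 s and 1.7 s figures are the small-model numbers from the patch write-up.

```python
def restart_speedup(cold_s: float, warm_s: float) -> float:
    """Ratio of cold-start to warm-restart wall time on the same workload."""
    if cold_s <= 0 or warm_s <= 0:
        raise ValueError("timings must be positive seconds")
    return cold_s / warm_s


def cache_reuse_proven(cold_s: float, warm_s: float, cache_files: int, min_speedup: float = 2.0) -> bool:
    """Cache files on disk are necessary but not sufficient: require a measured speedup.

    min_speedup is an illustrative threshold (assumption), not a documented setting.
    """
    if cache_files == 0:
        return False
    return restart_speedup(cold_s, warm_s) >= min_speedup


if __name__ == "__main__":
    # Small validation model from the patch write-up: 11.5 s cold, 1.7 s warm.
    print(round(restart_speedup(11.5, 1.7), 2))  # 6.76
    # Write-only cache: files exist, but the restart recompiles and barely gets faster.
    print(cache_reuse_proven(11.5, 11.2, cache_files=40))  # False
```

The write-only failure mode from the patch write-up is exactly the second case: a non-empty cache directory with no restart-time win.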
The SPMD memory problem was an API mismatch, not a dashboard bug
Another failure looked small until it blocked observability. Under XLA SPMD, the Python-visible device could be SPMD:0, but memory queries needed the physical runtime devices such as TPU:18. The existing binding path rejected those physical strings before they reached the computation client.
The local patch added _xla_runtime_memory_info(device_str), a pybind that bypassed the usual device parsing and forwarded the raw runtime device string directly to the computation client. On the MegaCpp side, the training code was updated to prefer that runtime binding, enumerate physical runtime devices, and fail loudly with an actionable message if the build did not include the patch.
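The caller-side logic can be sketched as "filter to physical devices, then fail loudly if the binding is missing." The binding name _xla_runtime_memory_info comes from the patch described above, but the article does not say which module carries it or what it returns, so this sketch takes the module as an argument and treats the result as opaque. The SimpleNamespace object is only a stand-in for demonstration.

```python
import re
from types import SimpleNamespace

_PHYSICAL_TPU = re.compile(r"^TPU:\d+$")


def physical_devices(device_strings):
    """Keep physical runtime devices such as 'TPU:18'; drop virtual ones such as 'SPMD:0'."""
    return [d for d in device_strings if _PHYSICAL_TPU.match(d)]


def memory_by_device(binding_module, device_strings):
    """Query per-device memory through the patched raw-device binding, or fail loudly."""
    query = getattr(binding_module, "_xla_runtime_memory_info", None)
    if query is None:
        raise RuntimeError(
            "this torch_xla build does not expose _xla_runtime_memory_info; "
            "rebuild with the runtime-device memory patch or disable in-process HBM reporting"
        )
    devices = physical_devices(device_strings)
    if not devices:
        raise RuntimeError(f"no physical TPU:* devices in {list(device_strings)!r}")
    # Forward raw runtime device strings; no SPMD-level device parsing happens here.
    return {d: query(d) for d in devices}


if __name__ == "__main__":
    fake = SimpleNamespace(_xla_runtime_memory_info=lambda d: {"device": d})
    print(memory_by_device(fake, ["SPMD:0", "TPU:18", "TPU:19"]))
```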
This is a good example of the right public lesson: sometimes a backend feature is not missing, but the Python-visible API is aimed at the wrong abstraction layer. In this case, the correct fix was not a profiler workaround disguised as observability. It was a small binding that exposed the physical-device query path the runtime already understood.
Some "breakages" were really policy mistakes
Not every TPU failure needed a framework patch.
MegaCpp ended up treating several issues as launch-policy or model-policy corrections:
- XLA checkpoint resume used assign=False so CPU-loaded checkpoint tensors would copy into existing XLA parameters instead of replacing them.
- TPU runs stayed in eager mode because torch.compile with the OpenXLA backend could OOM during compilation, while XLA JIT already owned the optimization path.
- XLA AdamW was configured with capturable=True, because scalar extraction inside the optimizer step could perturb graph hashing and sabotage cache reuse.
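Because these are policy rather than patches, they fit naturally in one place in the launch path. A minimal sketch with a hypothetical helper name: it only encodes the three XLA-lane choices listed above and deliberately returns nothing for other lanes, whose defaults are out of scope here. The assign and capturable keyword arguments are the standard PyTorch ones.

```python
def xla_launch_policy(device_type: str) -> dict:
    """Hypothetical helper returning the XLA-lane policy choices listed above.

    Non-XLA lanes get an empty dict: their own defaults are not covered here.
    """
    if device_type.upper() not in {"TPU", "XLA"}:
        return {}
    return {
        # Copy CPU-loaded checkpoint tensors into existing XLA parameters.
        "load_state_dict_kwargs": {"assign": False},
        # Keep scalar extraction out of the optimizer step so graph hashes stay stable.
        "adamw_kwargs": {"capturable": True},
        # Stay in eager mode; XLA's lazy JIT owns the optimization path.
        "use_torch_compile": False,
    }


# Intended use on an XLA lane (not executed here):
#   policy = xla_launch_policy("TPU")
#   model.load_state_dict(state, **policy["load_state_dict_kwargs"])
#   opt = torch.optim.AdamW(model.parameters(), lr=3e-4, **policy["adamw_kwargs"])
```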
These are important because they change how a breakage matrix should be written. A useful matrix does not label everything as "Torch 2.12 is broken." It distinguishes upstream defects from backend-appropriate policy.
The TPU-side torch.compile choice also belongs with the broader stack split, not as a generic compile-speed recommendation. XLA vs CUDA stack decisions explains why the CUDA and TPU lanes keep different compile boundaries, while OOM on v6e is the companion read for turning a TPU memory failure into a chip-level retry frontier instead of a random knob hunt.
scan_layers was a version and structure gate, not a generic knob
The training stack also documents a practical compile-time optimization via torch_xla.experimental.scan_layers, but it is explicitly version-gated and structurally gated. Older torch_xla lines did not expose scan_layers, so the launch path warns and ignores the flag unless the import succeeds. Even on supported versions, the optimization is only valid when the stacked blocks are structurally homogeneous.
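The "warn and ignore unless the import succeeds" gate can be sketched with importlib. The function name is hypothetical; the module path torch_xla.experimental.scan_layers is the one named above, and it is parameterized only so the gate can be exercised without torch_xla installed.

```python
import importlib
import warnings

SCAN_MODULE = "torch_xla.experimental.scan_layers"


def scan_layers_enabled(requested: bool, module_name: str = SCAN_MODULE) -> bool:
    """Honor a scan_layers flag only if the module imports; otherwise warn and ignore it."""
    if not requested:
        return False
    try:
        importlib.import_module(module_name)
    except ImportError as exc:
        warnings.warn(
            f"scan_layers requested, but {module_name} is not importable ({exc}); "
            "falling back to the plain layer loop",
            RuntimeWarning,
        )
        return False
    return True
```

This covers only API availability. Passing the gate says nothing about whether the model's blocks satisfy the scan contract; that is a separate structural check, and the call signature of scan_layers should be confirmed on the torch_xla line in use.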
The official docs make the current ceiling even narrower: the traced block has to stay AOTAutograd-traceable, and per the limitation documented for 2.6, scan / scan_layers still cannot trace functions that contain custom Pallas kernels. A decoder whose hot path already depends on a Pallas-backed attention lane therefore needs the plain loop, which is why the XLA backend index keeps Pallas-style kernel paths and runtime-probe paths separate in the public-safe bundle.
Newer PyTorch/XLA scan docs also make the cache gate explicit: scan can become tracing-bound during backward AOTAutograd work, and the scan cache is only enabled when the caller marks the function or layer as pure. For MegaCpp, that turns is_layer_pure=True into a receipt-backed claim, not a cosmetic performance flag: if layer state, dropout behavior, custom kernels, or shape-dependent branches differ across blocks, the safe answer is still the plain loop.
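The structural side of the gate can be sketched as a decision over per-block descriptors. BlockSpec and scan_decision are hypothetical; you would fill the descriptors from your own model introspection. The second return value is what you would be willing to assert as is_layer_pure. Refusing purity for blocks with dropout or shape-dependent branches is a conservative assumption of this sketch, not a documented torch_xla rule.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BlockSpec:
    """Hypothetical per-block descriptor, filled in from your own model introspection."""

    cls_name: str
    param_shapes: tuple
    has_dropout: bool = False
    uses_pallas: bool = False
    shape_dependent_branch: bool = False


def scan_decision(blocks) -> tuple:
    """Return (use_scan, mark_pure); (False, False) means keep the plain loop."""
    if not blocks:
        return False, False
    first = blocks[0]
    # Every field must match: class, parameter shapes, dropout, kernels, and branching.
    if any(b != first for b in blocks):
        return False, False
    # scan / scan_layers cannot trace custom Pallas kernels (documented limitation).
    if first.uses_pallas:
        return False, False
    # Conservative assumption: only claim purity with no RNG-driven dropout
    # and no shape-dependent control flow inside the block.
    mark_pure = not (first.has_dropout or first.shape_dependent_branch)
    return True, mark_pure
```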
That is the right way to write wheel and version notes. "Feature exists" is too vague. The real statement is: this feature exists only on newer torch_xla builds, and it only helps if the model structure obeys the scan contract.
What Torch 2.12 adds to the story
The GPU-side MegaCpp docs repeatedly record a Torch 2.12 nightly stack as a moving compatibility surface rather than a settled platform. That same habit should be applied to TPU/XLA. The useful question is not "are we on 2.12?" The useful question is which TPU/XLA lane is validated on which full bundle, and which breakages still require local patches or explicit launch policy.
That is why a Torch 2.12 TPU/XLA note should end with a matrix and a stance:
- wheel names alone are not the compatibility story
- cache writes are not proof of cache reuse
- SPMD virtual devices are not the right abstraction for every runtime query
- some failures are backend bugs, but others are just the wrong default policy
What we would preserve in any future upgrade
If this stack moves again, the MegaCpp habits worth preserving are narrow:
- Record the full TPU/XLA bundle, not just the torch wheel.
- Separate upstream patch candidates from launch-policy fixes.
- Keep version-gated features honest about both API availability and structural preconditions.
- Treat restart-time evidence, not cache-directory existence, as the proof that caching works.
- Make missing TPU observability bindings fail loudly and specifically.
That is the difference between an install note and an operational compatibility document.
Frequently asked questions
How should I debug an SPMD OOM before assuming the backend needs a patch?
torch_xla.real_devices() tells you which physical TPU devices sit behind the virtual SPMD:0 view, and visualize_tensor_sharding lets you confirm whether the tensors are actually placed the way your mesh suggests. If that path already explains the hotspot, treat it as a sharding or model-policy problem before assuming the TPU/XLA stack itself needs surgery.
Does capturable=True make AdamW faster on TPU by itself?
Not by itself. PyTorch's optimizer documentation frames capturable=True around graph capture safety, and the TPU/XLA receipt is narrower: the optimizer step should avoid scalar-driven graph churn when it participates in the warm-restart and cache-reuse contract. The local cross-checks are XLA-safe AdamW and TPU runtime flags and the XLA-safe AdamW example; use the flag when the optimizer state is part of the XLA step boundary, then prove the decision with stable restart behavior rather than assuming a standalone throughput win.
How do I choose between torch.compile, normal XLA lazy execution, and scan_layers?
Keep normal XLA lazy execution as the default, and use scan_layers only when the repeated blocks are homogeneous, AOTAutograd-traceable, and not already depending on a Pallas-backed attention path; otherwise the plain loop is the safer choice. Use torch.compile(..., backend="openxla") as a separate validation lane rather than a piecemeal loss-wrapper fix, because the Dynamo path can split the training step into more compiled regions and change memory behavior. The local cross-checks are Torch/XLA 2.11 expectations vs TPU reality, Pallas kernels on TPU, and XLA vs CUDA stack decisions.
Terms used in this article
- PJRT: The TPU runtime interface between frontend code and the backend executor; it is the ownership seam between JAX/Torch-XLA frontends and libtpu.
- libtpu: The TPU backend library that pairs with PJRT/XLA and owns device-side execution underneath the frontend.
- mesh: The named device grid that defines which logical axis maps to which TPU or distributed-device axis before sharding annotations make sense.
- Pallas: JAX's kernel language for writing explicit TPU kernels when stock XLA lowering is not enough for the required tile, memory-layout, or masking contract.
- XLA: The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.
- TPU: Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.