XLA-safe AdamW and TPU runtime flags on v6e
How to keep optimizer math graph-friendly on TPU, treat runtime flags as explicit launch policy, and recalibrate after stack changes.

On accelerator backends, the AdamW step often sits near the boundary between traced tensor math and Python control flow. On Cloud TPU v6e, that boundary is one of the first places where scalar extraction, host-device sync, or shape drift can trigger recompilation. The practical fix is simple: keep scalar handling graph-friendly, treat TPU runtime flags as explicit launch policy, and recalibrate after stack changes.
Why the optimizer step is the canary
The optimizer step sits at an awkward boundary between Python control flow and tensor math. On TPU, that boundary matters because scalar extraction and changing-shape control flow are both compile-sensitive in XLA.
The rule is straightforward: values that change every step should remain visible to the traced program as tensors, not escape to Python scalars inside the hot path.
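    # Stylized TPU-safe sketch. Entries in `scalars` are assumed to be 0-dim device tensors.
    def adamw_step_xla(p, g, exp_avg, exp_avg_sq, scalars):
        # Decoupled weight decay; the key name suggests the factor 1 - lr * wd.
        p.mul_(scalars["one_minus_lr_wd"])
        # First-moment moving average: exp_avg + (1 - b1) * (g - exp_avg).
        exp_avg.lerp_(g, scalars["one_minus_b1"])
        # Second-moment moving average. `value=` on addcmul_ expects a Python number,
        # so the (1 - b2) factor is applied as a tensor multiply instead.
        exp_avg_sq.mul_(scalars["b2"]).add_(g * g * scalars["one_minus_b2"])
        # The bias-corrected parameter update is omitted in this sketch.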
The exact implementation can vary. The important point is that the graph sees tensors and stable shapes rather than a stream of fresh Python values.
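For reference, the three in-place operations in the sketch correspond to the updates below. The symbols are assumptions read off the dictionary keys rather than taken from the checked-in example: $\eta$ is the learning rate, $\lambda$ the weight decay, $\beta_1$ and $\beta_2$ the moment decay rates, $m$ and $v$ the first and second moment buffers. The final bias-corrected parameter update is left out, as in the sketch.

```latex
\begin{aligned}
p &\leftarrow (1 - \eta\lambda)\, p \\
m &\leftarrow m + (1 - \beta_1)(g - m) \;=\; \beta_1 m + (1 - \beta_1)\, g \\
v &\leftarrow \beta_2\, v + (1 - \beta_2)\, g \odot g
\end{aligned}
```

Each coefficient is shared across all parameters, so it can be computed once per step as a 0-dim device tensor and reused, which is what lets the traced program see the same graph on every step.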
If the lane also uses AMP, keep the step path graph-friendly too. TPU AMP uses bf16 and does not need gradient scaling, but PyTorch/XLA still calls out sync-free optimizers as the safer performance path because they avoid extra device-host sync in the optimizer step. In this repo, the checked-in public-safe surface is the XLA-safe AdamW example, which keeps that contract visible without leaning on private launcher details.
Runtime flags should be policy, not shell state
The public XLA flag sample is intentionally small, but it captures the main point: flag changes should be deliberate, reviewable, and grouped by purpose rather than scattered across shell wrappers.
| Group | Representative policy | Why it matters |
|---|---|---|
| SPMD enablement | explicit startup flagging | makes mesh assumptions visible |
| compile cache policy | explicit cache mode | separates cold-start from steady-state effects |
| shape guard policy | strict input contract | reduces accidental recompile drift |
| launch profile selection | named runtime profile | keeps runs comparable |
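As one illustration of the shape guard row, a strict input contract can be a small check that runs before a batch reaches the traced step. This is a minimal sketch under assumptions: `InputContract` and `check_batch` are hypothetical names, and the shapes are invented for the example.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class InputContract:
    """Hypothetical strict input contract: one exact shape and one dtype name."""
    shape: Tuple[int, ...]
    dtype: str


def check_batch(contract: InputContract, shape: Tuple[int, ...], dtype: str) -> None:
    """Raise before the batch reaches the traced step if it breaks the contract."""
    if tuple(shape) != contract.shape or dtype != contract.dtype:
        raise ValueError(
            f"batch {tuple(shape)}/{dtype} violates contract "
            f"{contract.shape}/{contract.dtype}"
        )


contract = InputContract(shape=(8, 2048), dtype="int32")
check_batch(contract, (8, 2048), "int32")      # conforming batch passes silently

try:
    check_batch(contract, (7, 2048), "int32")  # e.g. a short final batch
    violated = False
except ValueError:
    violated = True
```

Failing loudly here is a design choice: a rejected batch is easier to diagnose than a silent recompile caused by a new shape.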
The takeaway is not that one magic flag profile solves TPU performance. It is that runtime policy should be explicit and narrow enough that when performance changes, you can tell whether the cause was the graph, the inputs, or the runtime profile.
That is also why the launcher should own the runtime profile before the TPU/XLA runtime is initialized. Startup order is part of correctness here, not only style.
In practice that means settings such as PJRT_DEVICE, compile-cache policy, and similar TPU/XLA runtime choices belong on the pre-import side of the launch boundary. The checked-in XLA flag profile example keeps that contract visible: once the runtime is initialized, those settings are no longer normal hot-path tuning knobs.
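A launcher that owns the profile might look roughly like the sketch below. It is not the checked-in flag profile example: the profile names, the grouping, and `apply_profile` are hypothetical. Only PJRT_DEVICE, XLA_IR_DEBUG, and XLA_HLO_DEBUG are real PyTorch/XLA environment variables. The sketch clears every key the launcher manages before writing the named profile, so inherited shell state cannot leak in, and it refuses to run once `torch_xla` has been imported.

```python
import os
import sys
from typing import Dict, Mapping, MutableMapping

# Hypothetical named profiles, grouped by purpose.
PROFILES: Dict[str, Dict[str, Dict[str, str]]] = {
    "v6e-steady": {
        "device": {"PJRT_DEVICE": "TPU"},
        "debug": {},
    },
    "v6e-debug": {
        "device": {"PJRT_DEVICE": "TPU"},
        "debug": {"XLA_IR_DEBUG": "1", "XLA_HLO_DEBUG": "1"},
    },
}

# Every key that any profile sets is owned by the launcher.
MANAGED_KEYS = sorted(
    {key for groups in PROFILES.values() for group in groups.values() for key in group}
)


def apply_profile(
    name: str,
    env: MutableMapping[str, str] = os.environ,
    loaded_modules: Mapping[str, object] = sys.modules,
) -> Dict[str, str]:
    """Write one named profile into env, before the runtime is imported."""
    if "torch_xla" in loaded_modules:
        raise RuntimeError("apply the runtime profile before importing torch_xla")
    for key in MANAGED_KEYS:
        env.pop(key, None)  # drop inherited shell state for managed keys
    applied: Dict[str, str] = {}
    for group in PROFILES[name].values():
        applied.update(group)
    env.update(applied)
    return applied  # worth logging next to the run so runs stay comparable


# Stand-in mappings so the sketch runs without a TPU host.
inherited = {"XLA_IR_DEBUG": "1", "HOME": "/tmp"}
applied = apply_profile("v6e-steady", env=inherited, loaded_modules={})
```

In a real launcher the call would sit at the top of the entry point, ahead of any import that initializes the runtime.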
Calibration matters after stack changes
A small startup calibration is cheaper than repeating a long failing launch. What matters is recording predicted versus observed memory behavior and feeding that back into the next run.
That loop is what makes runtime-policy changes survivable across stack upgrades. When the TPU runtime changes behavior, the calibration record should catch the mismatch before a large run turns into an avoidable OOM or recompilation storm.
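The feedback loop can be sketched in a few lines. Everything here is an assumption for illustration: the record fields, the idea of keying records by a stack signature string, and the simple ratio correction are not taken from the checked-in code.

```python
import json
from dataclasses import asdict, dataclass
from typing import List


@dataclass
class CalibrationRecord:
    """Hypothetical calibration record; field names are illustrative."""
    stack_signature: str   # e.g. runtime stack versions plus the profile name
    predicted_bytes: int   # memory estimate made before launch
    observed_bytes: int    # peak memory measured during the short calibration run

    @property
    def ratio(self) -> float:
        return self.observed_bytes / self.predicted_bytes


def corrected_prediction(
    raw_estimate: int, history: List[CalibrationRecord], signature: str
) -> int:
    """Scale a raw estimate by the latest observed/predicted ratio for this stack."""
    matching = [r for r in history if r.stack_signature == signature]
    if not matching:
        # A new stack signature means the old calibration no longer applies.
        raise LookupError(f"no calibration for {signature!r}; recalibrate first")
    return int(raw_estimate * matching[-1].ratio)


history = [CalibrationRecord("stack-A", predicted_bytes=20_000, observed_bytes=25_000)]
estimate = corrected_prediction(16_000, history, "stack-A")

# Records serialize cleanly, so they can be stored alongside the run.
serialized = json.dumps(asdict(history[0]))
```

Raising on an unknown signature is deliberate: after a stack upgrade, the launch should be forced through a fresh calibration instead of reusing a ratio measured on the old stack.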
Takeaway
The TPU optimizer story is not really about one exotic optimizer trick. It is about respecting the graph contract, keeping runtime policy explicit, and recalibrating when the stack changes.
Frequently asked questions
Why mention sync-free optimizers on TPU if bf16 AMP does not need gradient scaling?
Why treat TPU startup calibration as a launch signature instead of a one-off OOM workaround?
Should XLA_IR_DEBUG and XLA_HLO_DEBUG stay enabled in the normal TPU profile?
Terms used in this article
Start here for quick definitions, then follow the linked posts for deeper context.
PJRT: The TPU runtime interface between frontend code and the backend executor; it is the ownership seam between JAX/Torch-XLA frontends and libtpu.
XLA: The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.
Mesh: The named device grid that defines which logical axis maps to which TPU or distributed-device axis before sharding annotations make sense.
TPU: Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.