MegaCpp Engineering: Applied C++ model systems
Grounded engineering note from the MegaCpp stack
Published · 3 min read · David Gornshtein
Tags: TPU, v6e, XLA, AdamW, PJRT, Calibration

XLA-safe AdamW and TPU runtime flags on v6e

How to keep optimizer math graph-friendly on TPU, treat runtime flags as explicit launch policy, and recalibrate after stack changes.


On accelerator backends, the AdamW step often sits near the boundary between traced tensor math and Python control flow. On Cloud TPU v6e, that boundary is one of the first places where scalar extraction can force a host-device sync and shape drift can trigger recompilation. The practical fix is simple: keep scalar handling graph-friendly, treat TPU runtime flags as explicit launch policy, and recalibrate after stack changes.

Why the optimizer step is the canary

The optimizer step sits at an awkward boundary between Python control flow and tensor math. On TPU, that boundary matters because scalar extraction and changing-shape control flow are both compile-sensitive in XLA: reading a value back with .item() cuts the lazy trace and forces a device-to-host sync, and each new input shape produces a different graph that has to be compiled.

The rule is straightforward: values that change every step should remain visible to the traced program as tensors, not escape to Python scalars inside the hot path.

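```python
# Stylized TPU-safe sketch; scalar keys are illustrative 0-dim device tensors.
import torch

@torch.no_grad()
def adamw_step_xla(p, g, exp_avg, exp_avg_sq, scalars):
    p.mul_(scalars["one_minus_lr_wd"])                # decoupled weight decay
    exp_avg.lerp_(g, scalars["one_minus_b1"])         # m = b1*m + (1-b1)*g
    exp_avg_sq.lerp_(g * g, scalars["one_minus_b2"])  # v = b2*v + (1-b2)*g^2; tensor weight, unlike addcmul_(value=...)
    denom = exp_avg_sq.sqrt().div_(scalars["bc2_sqrt"]).add_(scalars["eps"])
    p.sub_(scalars["step_size"] * exp_avg / denom)    # step_size = lr / bias_correction1
```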

The exact implementation can vary. The important point is that the graph sees tensors and stable shapes rather than a stream of fresh Python values.
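One way to produce a scalars dict like the one the sketch above reads, assuming the step counter and the scheduled learning rate already live on the device as 0-dim tensors. The helper name make_adamw_scalars, its key names, and the default hyperparameters are hypothetical; the point is that bias correction is computed as traced tensor math, so no fresh Python value enters the graph each step.

```python
import torch

def make_adamw_scalars(step, lr, *, weight_decay=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical helper: AdamW per-step scalars as 0-dim tensors on step.device.

    step (already incremented for this update) and lr are 0-dim float tensors
    on the device, so bias correction is tensor math instead of a new Python
    number every step. Constant hyperparameters are wrapped as tensors on the
    same device so every value in the dict has the same kind.
    """
    def as_t(x):
        return torch.as_tensor(x, dtype=step.dtype, device=step.device)

    b1, b2 = as_t(beta1), as_t(beta2)
    bc1 = 1 - b1 ** step  # first-moment bias correction
    bc2 = 1 - b2 ** step  # second-moment bias correction
    return {
        "one_minus_lr_wd": 1 - lr * as_t(weight_decay),
        "one_minus_b1": 1 - b1,
        "b2": b2,
        "one_minus_b2": 1 - b2,
        "bc2_sqrt": bc2.sqrt(),
        "eps": as_t(eps),
        "step_size": lr / bc1,  # same split as torch.optim.AdamW: lr / bias_correction1
    }
```

Because every entry is a tensor with a fixed shape and dtype, step 1 and step 1000 trace to the same graph; only the tensor values differ.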

If the lane also uses AMP, keep the step path graph-friendly too. TPU AMP uses bf16 and does not need gradient scaling, but PyTorch/XLA still calls out sync-free optimizers as the safer performance path because they avoid an extra device-host sync in the optimizer step. In this repo, the checked-in public-safe surface for this is the XLA-safe AdamW example, which keeps that contract visible without leaning on private launcher details.
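A minimal sketch of the sync-free idea, assuming you want to skip an update whenever any gradient is non-finite. The helper names are hypothetical and this is not PyTorch/XLA's own sync-free optimizer code; it only shows how the skip decision can stay on the device instead of going through a Python if on a value read back from the host.

```python
import torch

def found_inf_flag(grads):
    """Return a 0-dim float tensor: 1.0 if any gradient has inf/NaN, else 0.0.

    Built from tensor ops only, so nothing is read back to the host.
    """
    flag = torch.zeros((), dtype=torch.float32, device=grads[0].device)
    for g in grads:
        bad = (~torch.isfinite(g)).any().to(flag.dtype)
        flag = torch.maximum(flag, bad)
    return flag

@torch.no_grad()
def guarded_sub_(param, update, found_inf):
    """In place: param -= update, unless found_inf is nonzero.

    torch.where replaces a host-side `if found_inf.item():` branch, so the
    skip decision stays inside the traced graph.
    """
    param.copy_(torch.where(found_inf == 0, param - update, param))
    return param
```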

Runtime flags should be policy, not shell state

The public XLA flag sample is intentionally small, but it captures the main point: flag changes should be deliberate, reviewable, and grouped by purpose rather than scattered across shell wrappers.

| Group | Representative policy | Why it matters |
|---|---|---|
| SPMD enablement | explicit startup flagging | makes mesh assumptions visible |
| Compile cache policy | explicit cache mode | separates cold-start from steady-state effects |
| Shape guard policy | strict input contract | reduces accidental recompile drift |
| Launch profile selection | named runtime profile | keeps runs comparable |

The takeaway is not that one magic flag profile solves TPU performance. It is that runtime policy should be explicit and narrow enough that, when performance changes, you can tell whether the cause was the graph, the inputs, or the runtime profile.
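To make the policy table and the attribution argument concrete, here is a minimal sketch of named profiles grouped by purpose, plus a diff that shows exactly which settings changed between two runs. PJRT_DEVICE, XLA_IR_DEBUG, and XLA_HLO_DEBUG are real PyTorch/XLA environment variables; the EXAMPLE_* keys, the group names, and the helper functions are hypothetical placeholders for whatever a real launcher sets.

```python
# Hypothetical launch-profile registry, grouped like the policy table.
# EXAMPLE_* keys are placeholders, not real PyTorch/XLA variables.
STEADY = {
    "spmd": {"EXAMPLE_SPMD": "1"},
    "compile_cache": {"EXAMPLE_CACHE_MODE": "readwrite"},
    "shape_guard": {"EXAMPLE_SHAPE_GUARD": "strict"},
    "runtime": {"PJRT_DEVICE": "TPU"},
}
PROFILES = {
    "steady": STEADY,
    # Diagnostic profile: steady settings plus IR/HLO debug metadata.
    "recompile_triage": {**STEADY, "debug": {"XLA_IR_DEBUG": "1", "XLA_HLO_DEBUG": "1"}},
}

def flatten(profile):
    """Merge one profile's groups into a flat env mapping; reject duplicate keys."""
    env = {}
    for group, settings in profile.items():
        for key, value in settings.items():
            if key in env:
                raise ValueError(f"{key} is set by more than one group (again in {group!r})")
            env[key] = value
    return env

def diff_profiles(a, b):
    """Return {key: (value_in_a, value_in_b)} for every setting that differs."""
    env_a, env_b = flatten(PROFILES[a]), flatten(PROFILES[b])
    return {
        key: (env_a.get(key), env_b.get(key))
        for key in sorted(env_a.keys() | env_b.keys())
        if env_a.get(key) != env_b.get(key)
    }
```

If two runs differ in speed and diff_profiles returns an empty dict, the runtime profile is ruled out and the search moves to the graph or the inputs.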

That is also why the launcher should own the runtime profile before the TPU/XLA runtime is initialized. Startup order is part of correctness here, not just style.

In practice, that means settings such as PJRT_DEVICE, compile-cache policy, and similar TPU/XLA runtime choices belong on the pre-import side of the launch boundary. The checked-in XLA flag profile example keeps that contract visible: once the runtime is initialized, those settings are no longer normal hot-path tuning knobs.
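A minimal sketch of a pre-import guard, assuming the launcher is plain Python. apply_runtime_env is a hypothetical helper; it treats "torch_xla is already in sys.modules" as a conservative stand-in for "the runtime may already be initialized" and refuses to change settings after that point.

```python
import os
import sys

def apply_runtime_env(settings, environ=None):
    """Hypothetical launcher helper: write runtime settings before torch_xla loads.

    Raises once torch_xla has been imported, because settings such as
    PJRT_DEVICE are read at startup rather than per step.
    """
    if "torch_xla" in sys.modules:
        raise RuntimeError("apply runtime settings before importing torch_xla")
    target = os.environ if environ is None else environ
    for key, value in settings.items():
        target[key] = str(value)
    return target

# Usage in a launcher entry point:
#   apply_runtime_env({"PJRT_DEVICE": "TPU"})
#   import torch_xla  # only after the settings are in place
```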

Calibration matters after stack changes

A small startup calibration is cheaper than repeating a long failing launch. What matters is recording predicted versus observed memory behavior and feeding that back into the next run.

That loop is what makes runtime-policy changes survivable across stack upgrades. When the TPU runtime changes behavior, the calibration record should catch the mismatch before a large run turns into an avoidable OOM or recompilation storm.
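A minimal sketch of that feedback loop, assuming a raw startup-memory estimator already exists and that the correction is simply the worst observed-to-predicted ratio seen so far. That rule is a deliberately conservative choice for illustration, not the repo's method; CalibrationRecord, corrected_estimate, and fits are hypothetical names, and capacity and headroom are supplied by the caller.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class CalibrationRecord:
    """Hypothetical record of one startup probe: predicted vs observed peak memory."""
    signature: str
    predicted_gib: float
    observed_gib: float

    @property
    def ratio(self) -> float:
        # Above 1.0 means the estimator under-predicted startup memory.
        return self.observed_gib / self.predicted_gib


def corrected_estimate(raw_estimate_gib, history):
    """Scale a fresh estimate by the worst observed/predicted ratio seen so far.

    Ratios below 1.0 are ignored, so past over-predictions never shrink the margin.
    """
    worst = max((r.ratio for r in history), default=1.0)
    return raw_estimate_gib * max(worst, 1.0)


def fits(raw_estimate_gib, history, capacity_gib, headroom=0.10):
    """True if the corrected estimate stays under capacity minus a headroom fraction."""
    return corrected_estimate(raw_estimate_gib, history) <= capacity_gib * (1 - headroom)
```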

Takeaway

The TPU optimizer story is not really about one exotic optimizer trick. It is about respecting the graph contract, keeping runtime policy explicit, and recalibrating when the stack changes.

FAQ


Why mention sync-free optimizers on TPU if bf16 AMP does not need gradient scaling?
Because on TPU the benefit is about the step boundary, not loss scaling. The real question is whether the optimizer can stay inside the traced step without extra host-side reads or sync seams, which is the same graph-contract story described in Graph recompilation hell and Torch XLA and PJRT reality.
Why treat TPU startup calibration as a launch signature instead of a one-off OOM workaround?
Because the startup memory frontier moves with more than batch size. The checked-in startup calibration catalog example models the launch signature as code state, hardware, model shape, parallelism, and feature toggles, because any of those can change whether the first compile window fits in memory. That is why calibration belongs to the explicit launch policy described here and in OOM on v6e, not only to ad hoc debugging. The matching startup calibration record sample also keeps predicted versus observed startup memory, and on the next retry it pushes known-bad signatures behind unseen or known-good ones. The point is not to predict final model quality; it is to stop paying for the same avoidable startup failure after a stack or runtime-profile change.
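A minimal sketch of the signature and retry-ordering idea, assuming signatures are frozen dataclasses hashed to a short stable key. LaunchSignature, order_retries, and the field names are hypothetical stand-ins for the catalog and record samples mentioned above, not their actual schema; the ordering only moves known-bad signatures to the back and otherwise keeps the candidate order.

```python
import hashlib
import json
from dataclasses import asdict, dataclass

@dataclass(frozen=True)
class LaunchSignature:
    """Hypothetical launch signature covering the fields listed in the answer above."""
    code_rev: str
    hardware: str
    model_shape: tuple
    parallelism: tuple
    features: tuple = ()

    def key(self) -> str:
        # Stable digest so records from different runs can be matched.
        blob = json.dumps(asdict(self), sort_keys=True)
        return hashlib.sha256(blob.encode()).hexdigest()[:16]


def order_retries(candidates, known_bad):
    """Stable sort: keep candidate order, but push known-bad signatures last."""
    return sorted(candidates, key=lambda sig: sig.key() in known_bad)
```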
Should XLA_IR_DEBUG and XLA_HLO_DEBUG stay enabled in the normal TPU profile?
No. Treat them as diagnostic launch profiles for recompilation triage, not as steady-state policy. PyTorch/XLA documents these variables as debug controls that capture Python stack frames into IR and HLO metadata, and the same troubleshooting documentation warns that debugging variables can degrade performance. Use them when a run is proving where shape drift or host-side materialization entered the graph, then return to the smaller launch profile once the seam is fixed.
Glossary

Terms used in this article

Start here for quick definitions, then follow the linked posts for deeper context.

PJRT

The TPU runtime interface between frontend code and the backend executor; it is the ownership seam between JAX/Torch-XLA frontends and libtpu.

XLA

The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.

mesh

The named device grid that defines which logical axis maps to which TPU or distributed-device axis before sharding annotations make sense.

TPU

Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.
