XLA vs CUDA: The Decision Matrix For Our Two Training Stacks
Where we keep one model definition, where the kernels diverge, what determinism we can give on each, how comms differ between NCCL and XLA collectives, and the operator surface that has to stay portable.

We run the same model on two radically different stacks: TPUs through torch_xla / XLA SPMD, and NVIDIA GPUs through CUDA / NCCL / Transformer Engine / tensor-parallel training libraries. Keeping both alive is expensive and people reasonably ask why we do not pick one. The answer is that the paths give us different things, and the portability discipline is what makes the duplication sustainable. This post is the decision matrix we use day to day in a dual-stack training workflow: what we keep unified, what we let diverge, what determinism we can guarantee, and where the operator surface has to stay clean.
Why two stacks at all
Recent TPU generations can be a better cost fit for some training shapes, and the XLA compiler owns graph partitioning, collective placement, and memory tradeoffs differently from a Python-scheduled CUDA or NCCL stack. Recent NVIDIA accelerators are where FP8 and modern FlashAttention-class kernels arrive first, and where architecture iteration often moves fastest because low-level kernel libraries evolve quickly. Picking only one narrows the operating envelope and locks the stack into a single vendor story. We pay the portability tax on purpose. The interesting question is how much of that tax is necessary. Our answer: the model definition stays unified, the kernels live in two stacks, the communication pattern follows two worldviews, and the determinism story has to be written twice.
What stays unified, what diverges
| Layer | TPU path | CUDA path | Shared? |
|---|---|---|---|
| Model definition | shared model runtime modules | same modules | yes |
| Config | GPTConfig | GPTConfig | yes |
| Sharding | XLA SPMD sharding annotations | parallelize_module plus tensor-parallel wrappers | no |
| Optimizer | XLA-safe AdamW variant with device-resident scalars | fused AdamW under torch.compile | shared math, two impls |
| Collectives | XLA-inserted HLO ops | NCCL launched from explicit Python scheduling | no |
| Attention | XLA Pallas flash | FA3 / CuTe-backed FA4 | no |
| FP8 | not used | TE fp8_autocast per zone | no |
| Compile | torch_xla.compile per micro-step | regional torch.compile per block | no |
The boundary sits below model.__call__: anything inside the call is shared, anything below is path-specific. When we introduce a new operator (a new attention variant, a new norm, a new routing scheme) it goes into the shared modules with a pure-PyTorch reference, then picks up a Triton/Pallas kernel for CUDA and an XLA-blessed implementation for TPU. If either path needs bespoke Python plumbing inside __call__ to make the op work, we refactor until it does not. We have broken this rule a few times and paid for it, which is why the call boundary is now a hard portability rule.
Determinism is two stories
On TPU, we treat determinism as a compile-stability problem first. Under a fixed seed, a fixed mesh, and fixed runtime setup, XLA can keep replaying the same compiled graph shape, but that stability depends on fixed shapes, fixed control flow, and avoiding host round trips or per-step Python scalars inside compiled regions. The gotchas are the familiar PyTorch/XLA ones: warmup policies that change the graph, scalar leaks such as .item()-style host reads, and runtime setup that drifts after tensors are already initialized.
That is why the XLA lane uses the steady-state batch and accumulation schedule from step 0 instead of a batch-ramp warmup. A harmless-looking early-step ramp can still change the compiled graph before the real run has started.
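A minimal sketch of why a ramp is not harmless, in plain Python with no torch_xla involved. It assumes one compiled graph per distinct input shape; `graph_signatures` and both schedules are hypothetical illustrations, not the lane's real configuration.

```python
def graph_signatures(batch_schedule, seq_len):
    """Return the distinct (batch, seq_len) shapes a schedule would trace.

    Assumption: every distinct input shape costs one compilation.
    """
    return {(batch, seq_len) for batch in batch_schedule}


# Hypothetical schedules: a batch-ramp warmup versus steady state from step 0.
ramp = [8, 16, 32, 64] + [64] * 96
steady = [64] * 100

ramp_graphs = graph_signatures(ramp, seq_len=2048)
steady_graphs = graph_signatures(steady, seq_len=2048)

print(len(ramp_graphs), len(steady_graphs))  # 4 1
```

Under that assumption the ramp asks for four graphs before the steady-state shape is ever reached, while the steady schedule asks for one.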
On CUDA we get algorithmic determinism rather than bitwise determinism. Flash Attention is non-deterministic under FLASH_ATTN_DETERMINISTIC=0; NCCL reductions are order-sensitive at the least-significant bits; FP8 is a noisy format by design. For bit-exact runs we set torch.use_deterministic_algorithms(True) plus deterministic NCCL plus deterministic FA variants, and pay a throughput cost measured in tens of percent. We use this for golden regression tests, not for deployment training.
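The order sensitivity comes from floating-point addition not being associative. The toy below is an illustration in plain Python floats, not a model of NCCL's actual reduction schedule; the function names are hypothetical and only borrow the ring/tree vocabulary.

```python
def ring_style_sum(values):
    """Accumulate left to right, as if ranks contributed in index order."""
    total = 0.0
    for value in values:
        total += value
    return total


def tree_style_sum(values):
    """Accumulate pairwise, as a tree-shaped reduction would."""
    if len(values) == 1:
        return values[0]
    mid = len(values) // 2
    return tree_style_sum(values[:mid]) + tree_style_sum(values[mid:])


grads = [0.1, 0.2, 0.3]  # made-up per-rank contributions
ring = ring_style_sum(grads)
tree = tree_style_sum(grads)

# Same inputs, different association order, different last bit.
print(ring == tree, abs(ring - tree) < 1e-15)  # False True
```

The two results agree to well within any training tolerance but are not bit-identical, which is the gap between algorithmic and bitwise determinism.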
Communication is two worldviews
NCCL collectives and XLA collectives are different worldviews. NCCL is explicit, synchronous by default, scheduled from Python, and overlapped with compute by deliberate hook placement and bucket batching. In practice that means real gradient-bucket state machines to make overlap happen. XLA collectives are compiler-inserted and overlapped through graph-level scheduling rather than Python bucket logic. We have repeatedly seen the same workload bottleneck solved with completely different knobs: on CUDA by retuning bucket sizes and overlap policy, on TPU by adjusting memory limits and collective behavior. Same problem, completely different tools.
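To make the bucket-size knob concrete, here is a hedged sketch of size-capped gradient bucketing in plain Python. It is not our bucket state machine and not PyTorch's implementation; `build_buckets`, the sizes, and the caps are all hypothetical.

```python
def build_buckets(sizes_mb, cap_mb):
    """Group parameter indices into buckets of at most cap_mb.

    Parameters are visited in reverse, approximating the order in which
    gradients become ready during backward. A parameter larger than the
    cap gets a bucket of its own.
    """
    buckets, current, current_size = [], [], 0
    for index in reversed(range(len(sizes_mb))):
        size = sizes_mb[index]
        if current and current_size + size > cap_mb:
            buckets.append(current)
            current, current_size = [], 0
        current.append(index)
        current_size += size
    if current:
        buckets.append(current)
    return buckets


sizes = [4, 4, 16, 2, 2]  # made-up per-parameter gradient sizes in MB
small_cap = build_buckets(sizes, cap_mb=8)
large_cap = build_buckets(sizes, cap_mb=32)

print(small_cap)  # [[4, 3], [2], [1, 0]]
print(large_cap)  # [[4, 3, 2, 1, 0]]
```

A smaller cap yields more, earlier collectives that can overlap with the remaining backward pass; a larger cap yields fewer, bigger ones. On the XLA side there is no equivalent Python structure to retune, which is the point of the paragraph above.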
The XLA-safe optimizer pattern (and why it generalizes)
Our XLA-safe AdamW exists because torch.optim.AdamW eventually materializes Python scalars inside torch_xla.compile(), which forces a graph break and then recompiles every step with different float constants. The fix is to replace those scalars with 0-D device tensors filled in place under an XLA-safe scalar policy:
    import torch


    class XLAAdamW(torch.optim.Optimizer):
        def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
            super().__init__(params, dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay))
            # 0-D tensors that replace per-step Python scalars (device placement not shown).
            self._bc2_sqrt_t = torch.tensor(1.0)
            self._step_size_t = torch.tensor(0.0)
            self._wd_scale_t = torch.tensor(1.0)

        @torch.no_grad()
        def step(self, closure=None):
            # compute bc1/bc2/step_size on host, then fill_() device tensors
            # and run the per-parameter update as pure tensor ops.
            ...
The pattern matters outside AdamW too: any scalar that varies per step has to become a 0-D tensor under XLA_NO_SPECIAL_SCALARS=1 if you want a stable compile cache. On one longer hybrid preset that was the difference between a stable cache and a recurring recompilation spiral.
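The host-side half of that step can be sketched in plain Python. This assumes the standard AdamW bias-correction formulas and is a guess at what the elided step() body computes; `adamw_host_scalars` and its dictionary keys are hypothetical names chosen to mirror the three tensors in the class.

```python
import math


def adamw_host_scalars(step, lr, betas, weight_decay):
    """Per-step AdamW scalars computed on the host (step counts from 1)."""
    beta1, beta2 = betas
    bc1 = 1.0 - beta1 ** step  # first-moment bias correction
    bc2 = 1.0 - beta2 ** step  # second-moment bias correction
    return {
        "step_size": lr / bc1,
        "bc2_sqrt": math.sqrt(bc2),
        "wd_scale": 1.0 - lr * weight_decay,  # decoupled weight decay factor
    }


scalars = adamw_host_scalars(step=1, lr=1e-3, betas=(0.9, 0.999), weight_decay=0.1)
# In the optimizer these values would be written into the 0-D tensors, for
# example self._step_size_t.fill_(scalars["step_size"]), so the compiled
# graph sees tensors rather than fresh float constants every step.
```

The values change every step, but because they reach the graph through tensors of fixed shape and dtype, the traced program stays the same.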
On this path, XLA_NO_SPECIAL_SCALARS=1 is part of the startup contract, not a late tuning toggle after torch_xla has already initialized.
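One way to enforce that ordering is a guard that runs before anything imports torch_xla. The sketch below is an assumption about how such a guard could look, not our launcher; `apply_xla_startup_contract` is a hypothetical helper, and the injectable `env` and `loaded_modules` arguments exist only so it can be exercised without touching the real process.

```python
import os
import sys


def apply_xla_startup_contract(env=None, loaded_modules=None):
    """Set XLA_NO_SPECIAL_SCALARS=1, refusing if torch_xla is already imported."""
    env = os.environ if env is None else env
    loaded_modules = sys.modules if loaded_modules is None else loaded_modules
    if "torch_xla" in loaded_modules:
        raise RuntimeError("set XLA_NO_SPECIAL_SCALARS before importing torch_xla")
    env["XLA_NO_SPECIAL_SCALARS"] = "1"
    return env


# Exercised here against stand-in mappings instead of the live process.
fake_env = apply_xla_startup_contract(env={}, loaded_modules={})
print(fake_env)  # {'XLA_NO_SPECIAL_SCALARS': '1'}
```

Failing loudly when the import has already happened turns a silent late toggle into a startup error.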
The same scalar policy applies to schedulers and gradient clipping. If a per-step learning-rate or clip value reappears as a host scalar inside the compiled TPU lane, the optimizer fix did not actually close the compile-stability seam. That is why XLA flag profile sample and XLA compile runtime controls sample belong in the startup-policy surface rather than in an after-the-fact tuning checklist.
Operator-surface rules and the audit
Our portability discipline, learned the hard way: no .item() anywhere that runs under torch_xla.compile(); no Python-scalar graph-constants that change step to step; tuple/list mutations stay outside compiled regions; every op called from model code exists as a pure-PyTorch reference that runs correctly on CPU, XLA and CUDA; fused kernels sit behind feature flags and are tested for parity against the reference. We added a CI gate that walks the model on CPU under a fake XLA device wrapper and asserts no .item() calls and no varying Python-scalar graph constants. The gate has caught two regressions since January.
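Our gate is dynamic, but a much cheaper static pre-check can catch the most obvious case. The sketch below is a different technique from the CI gate described above: it scans source text with the standard `ast` module for `.item()` calls. `find_item_calls` is a hypothetical helper, and it will also flag unrelated `.item()` methods such as NumPy's.

```python
import ast


def find_item_calls(source):
    """Return the line numbers of `<expr>.item()` calls in Python source."""
    hits = []
    for node in ast.walk(ast.parse(source)):
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and node.func.attr == "item"
        ):
            hits.append(node.lineno)
    return sorted(hits)


sample = '''
def training_step(loss, lr):
    scaled = loss * lr
    log_value = loss.item()  # host read inside the step
    return scaled
'''

flagged = find_item_calls(sample)
print(flagged)  # [4]
```

A static scan cannot see varying Python-scalar graph constants, so it complements a dynamic walk rather than replacing it.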
The other audit is the receipt cross-check. Every CUDA receipt and every TPU receipt records the same step-loss values for the first 100 steps on a small canary preset; if those diverge across stacks beyond the documented determinism budget, the receipt fails and the offending path is investigated.
On TPU, the first metric to read after warmup is UncachedCompile. If that counter is still climbing, the supposedly shared lane is not actually compile-stable yet.
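A hedged sketch of that check: read a counter, run a few steps, and see whether it moved. `compile_is_stable` and `FakeLane` are hypothetical; the counter source is injected so the example runs without a TPU. On a real host one might pass something like `lambda: met.counter_value("UncachedCompile")` after `import torch_xla.debug.metrics as met`, assuming the counter is exposed under that name in your torch_xla version.

```python
def compile_is_stable(read_counter, run_step, probe_steps=5):
    """Return True if the counter does not move across `probe_steps` steps."""
    before = read_counter()
    for _ in range(probe_steps):
        run_step()
    return read_counter() == before


class FakeLane:
    """Stand-in for a training lane; a leaky lane recompiles on every step."""

    def __init__(self, leaky):
        self.leaky = leaky
        self.uncached_compiles = 1  # the initial compile during warmup

    def step(self):
        if self.leaky:
            self.uncached_compiles += 1


stable = FakeLane(leaky=False)
leaky = FakeLane(leaky=True)

stable_ok = compile_is_stable(lambda: stable.uncached_compiles, stable.step)
leaky_ok = compile_is_stable(lambda: leaky.uncached_compiles, leaky.step)
print(stable_ok, leaky_ok)  # True False
```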
What we kept and threw away
We kept one model module tree, two sharding implementations, two optimizers with shared math, two collective worldviews (NCCL explicit, XLA compiler-inserted), the rule that any new op ships with a pure-PyTorch reference, the CI gate against .item() in compiled regions, the receipt cross-check between paths, and the rule that path-specific code never lives inside model.__call__.
We threw away the "one big graph" TPU compile story in favor of a per-micro-step boundary, XLA_USE_SPMD=1 as a user-facing toggle, whole-model torch.compile(model) on either path (we compile per block on CUDA and per micro-step on TPU), bitwise determinism on the CUDA training path (we pay its cost only in regression gates), and any attempt to share collective code between NCCL and XLA. The throughline is short: pick the duplication carefully, keep the model itself unified, and let the paths be themselves below the call boundary.
How a new feature lands across both paths
The lifecycle of a new feature is sequenced to keep the two paths from drifting. The feature lands first as a pure-PyTorch reference in the shared modules, with unit tests that run on CPU. Once the reference is correct, the CUDA-specific kernel lands behind a feature flag with parity tests against the reference. Once the CUDA path is stable, the TPU-specific implementation lands with its own parity tests. Only after both paths have parity-tested implementations does the feature get wired into a training preset.
The order matters. Landing the kernel first encourages design choices that are easy on one path and hard on the other; landing the reference first forces the design to be portable from day one. We have caught at least three feature designs at the reference stage that would have required Python-side variable shapes on the TPU path, and rewriting them at that stage was an order of magnitude cheaper than discovering it after the CUDA kernel had landed.
What the receipt cross-check actually catches
The cross-check compares the first 100 step-loss values on a small canary preset between CUDA and TPU. The determinism budget is documented per feature, and the cross-check passes when the differences sit inside that budget. When it fails, the diagnosis path is short: which path changed, which feature flag flipped, and which receipt's stack line differs.
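The comparison itself is small. The sketch below assumes the budget is an absolute tolerance on per-step loss, which the text above does not specify; `cross_check` and the loss values are hypothetical and shortened to four steps for readability.

```python
def cross_check(cuda_losses, tpu_losses, budget):
    """Return the step indices whose absolute loss gap exceeds the budget."""
    if len(cuda_losses) != len(tpu_losses):
        raise ValueError("receipts cover different step counts")
    return [
        step
        for step, (cuda, tpu) in enumerate(zip(cuda_losses, tpu_losses))
        if abs(cuda - tpu) > budget
    ]


# Made-up receipts for illustration only.
cuda_receipt = [2.50, 2.41, 2.33, 2.27]
tpu_receipt = [2.50, 2.42, 2.33, 2.40]

failing_steps = cross_check(cuda_receipt, tpu_receipt, budget=0.02)
print(failing_steps)  # [3]
```

Returning the failing step indices rather than a bare pass/fail keeps the first divergent step visible, which is usually where the diagnosis starts.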
We keep the first pass of that gate in BF16 on both paths before re-enabling CUDA FP8 zones. If the BF16 receipts already disagree, the failure is in the shared math or the path-specific plumbing; if the BF16 pass is clean and only the FP8 pass moves, we treat that as quantization drift until a more specific receipt proves otherwise.
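Written as a decision function, the two-pass triage reads as follows. This is only a restatement of the paragraph above in code; `triage` and its return strings are hypothetical.

```python
def triage(bf16_passes, fp8_passes):
    """Map the two cross-check outcomes to a first suspect."""
    if not bf16_passes:
        # FP8 is not yet in the picture, so it cannot be the cause.
        return "shared math or path-specific plumbing"
    if not fp8_passes:
        return "quantization drift (until a more specific receipt says otherwise)"
    return "clean"


print(triage(bf16_passes=False, fp8_passes=False))
print(triage(bf16_passes=True, fp8_passes=False))
print(triage(bf16_passes=True, fp8_passes=True))
```

The BF16 result is checked first on purpose: an FP8 failure is only meaningful once the BF16 pass is known to be clean.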
Before blaming a stack, we also align the dtype at shared attention ingress. A norm path that hands FP32 activations to a BF16 attention block can make CUDA insert visible casts while XLA folds the same casts into the compiled graph, so the BF16 canary first proves the two paths are comparing the same math surface.
What lives below the call boundary on each path
On the CUDA path, below model.__call__ we have FSDP2 wrappers, tensor-parallel module replacements, the Transformer Engine bridge, attention-kernel selectors, Triton call sites, FP8 helpers, NCCL bucket plumbing, and Inductor regional compile setup. On the TPU path, below the same boundary we have XLA SPMD sharding annotations, sharding audits, runtime-flag loading, the per-micro-step torch_xla.compile() boundary, persistent cache hookup, and chip-memory helpers. The boundaries are different shapes; the rule that they live below the call is the same.
On CUDA, that regional-compile surface also has an ordering contract. Tensor- and sequence-parallel setup need to happen before regional compile so torch.compile still traces the distributed block shape, while outer wrappers such as FSDP2 still belong after the compiled leaves are chosen. That is why "Regional compile without losing the plot" belongs next to this decision matrix rather than under a generic compile-speed story.
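The ordering contract can be checked mechanically over a list of setup steps. The sketch below is an assumption about how such a check might be written; `order_violations` and the step names are hypothetical labels, not identifiers from our setup code.

```python
def order_violations(steps):
    """Return human-readable violations of the CUDA setup ordering contract."""
    position = {name: index for index, name in enumerate(steps)}
    rules = [
        ("tensor_parallel", "regional_compile"),
        ("sequence_parallel", "regional_compile"),
        ("regional_compile", "fsdp2_wrap"),
    ]
    return [
        f"{first} must run before {second}"
        for first, second in rules
        if first in position
        and second in position
        and position[first] > position[second]
    ]


good = order_violations(
    ["tensor_parallel", "sequence_parallel", "regional_compile", "fsdp2_wrap"]
)
bad = order_violations(["regional_compile", "tensor_parallel", "fsdp2_wrap"])

print(good)  # []
print(bad)   # ['tensor_parallel must run before regional_compile']
```

Steps that are absent from a given preset are skipped rather than reported, so the same check works for presets without sequence parallelism.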
The model itself is the same file. That is what makes the duplication tractable. Without that rule, two stacks become two models, two models become two test suites, and two test suites become two products. The rule keeps us in one product.
What we would change if we could only have one path
If we had to drop a path today we would keep CUDA, because FA4 and FP8 are where the architectural iteration speed is highest. We would lose the cost-per-token advantage of TPU. We would also lose the discipline that the cross-check enforces: with only one path there is no second source of truth for the canary loss, and a regression that happens to be path-specific becomes much harder to detect.
The fact that we keep both paths is not free, but the cost is bounded by the rules above and the value is real. Two paths is the right number for our workload at this size; one path would be cheaper and weaker, three paths would be too much surface to maintain. We do not see the calculus changing soon.
Frequently asked questions
What should the portability audit read first when compile behavior looks wrong?
On TPU, start with the UncachedCompile reading. On CUDA, start with graph-break inspection and the smallest canary that still reproduces the fallback.
Why does the 100-step cross-check start in BF16 before CUDA FP8 zones return?
Because the BF16 pass separates the failure classes: if the BF16 receipts already disagree, the problem is in the shared math or the path-specific plumbing, and only a clean BF16 pass lets an FP8 difference be read as quantization drift.
Terms used in this article
Start here for quick definitions, then follow the linked posts for deeper context.
XLA: The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.
CUDA: NVIDIA's GPU programming stack: compiler, runtime, driver, libraries, and kernel toolchain used by CUDA training and inference lanes.
TPU: Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.
PJRT: The TPU runtime interface between frontend code and the backend executor; it is the ownership seam between JAX/Torch-XLA frontends and libtpu.
libtpu: The TPU backend library that pairs with PJRT/XLA and owns device-side execution underneath the frontend.
H200: NVIDIA's Hopper H200 GPU platform, typically discussed here as an 8-GPU training node with large HBM capacity and NVLink-connected ranks.