XLA SPMD sharding annotations we actually rely on
Why explicit mark_sharding annotations matter on TPU XLA, what should be pinned explicitly, and why propagation is not a substitute for a stable sharding contract.

This post is about the annotation surface itself: where explicit mark_sharding calls matter, where replication should be stated out loud, and why leaving parameter placement to inference is often the wrong tradeoff on TPU XLA.
If you come from the JAX side, the closest sibling contract is NamedSharding on a named mesh with an explicit PartitionSpec, not a vague hope that propagation will recover the placement you meant.
Why propagation is not enough
SPMD propagation is useful, but it is not a replacement for an explicit parameter contract. If a parameter is left unannotated, XLA is free to infer a placement from the surrounding graph structure. Sometimes that is fine. Sometimes it is wrong in a way that hurts correctness or compile stability rather than crashing immediately.
That is why the conservative rule is simple: explicitly annotate parameters you own, including the ones you intend to replicate.
The practical reason is often small tensors. When those get inferred as sharded, the result is usually not useful math parallelism but a stream of tiny latency-heavy collectives.
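One way to make that rule mechanical is to have startup code choose a partition spec for every owned parameter, so that nothing is left for propagation to guess. The sketch below is a hypothetical policy, not part of PyTorch/XLA: the helper name `choose_spec`, the axis name `"fsdp"`, and the `min_shard_elems` threshold are all assumptions made for illustration. It returns a tuple of mesh axis names and `None` entries, which is the shape of partition spec that `xs.mark_sharding` accepts on a mesh with named axes.

```python
from math import prod
from typing import Optional, Tuple

Spec = Tuple[Optional[str], ...]


def choose_spec(shape: Tuple[int, ...], axis_name: str, axis_size: int,
                min_shard_elems: int = 1 << 16) -> Spec:
    """Return an explicit partition spec for one parameter (hypothetical policy).

    Small tensors are replicated on purpose, even when a dimension happens to
    divide by the mesh axis. Larger tensors are sharded along the first
    dimension that divides evenly; if none does, they are replicated.
    """
    replicated: Spec = tuple(None for _ in shape)
    if prod(shape) < min_shard_elems:
        return replicated
    for dim, size in enumerate(shape):
        if size % axis_size == 0:
            return tuple(axis_name if i == dim else None
                         for i in range(len(shape)))
    return replicated


print(choose_spec((4096, 4096), "fsdp", 8))  # ('fsdp', None)
print(choose_spec((8,), "fsdp", 8))          # (None,)
print(choose_spec((1001, 1003), "fsdp", 8))  # (None, None)
```

The second call is the case the paragraph above warns about: a length-8 vector divides cleanly by an 8-way axis, and the policy still replicates it explicitly instead of leaving the decision to inference.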
What should be explicit
Three classes benefit most from explicit annotations:
- parameters that should be sharded along a known mesh axis
- parameters that must remain replicated even if their shape happens to divide cleanly
- boundary activations whose layout is part of the intended compile contract
The important idea is not one exact project-specific whitelist. It is the operational discipline: if a tensor layout matters, pin it.
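A minimal sketch of that discipline, assuming a project keeps its three classes in plain tables: every parameter name, tensor key, and layout below is hypothetical, and only `named_parameters()` and `xs.mark_sharding(tensor, mesh, partition_spec)` are real calls. The PyTorch/XLA import is kept inside the functions that need it so the plan itself can be built and inspected on any machine.

```python
from typing import Dict, Optional, Tuple

Spec = Tuple[Optional[str], ...]

# Hypothetical names and layouts; a real table is project-specific.
SHARDED_PARAMS: Dict[str, Spec] = {
    "mlp.up_proj.weight": ("fsdp", None),
}
REPLICATED_PARAMS: Dict[str, int] = {  # parameter name -> number of dims
    "final_norm.weight": 1,
}
BOUNDARY_ACTIVATIONS: Dict[str, Spec] = {
    "block_output": ("fsdp", None, None),
}


def build_param_plan() -> Dict[str, Spec]:
    """Merge sharded and replicated parameters into one explicit plan."""
    plan = dict(SHARDED_PARAMS)
    for name, ndim in REPLICATED_PARAMS.items():
        if name in plan:
            raise ValueError(f"{name} is listed as both sharded and replicated")
        plan[name] = (None,) * ndim  # all-None spec means replicated
    return plan


def apply_param_plan(model, mesh, plan: Dict[str, Spec]) -> None:
    """Annotate every planned parameter; fail loudly if one is missing."""
    import torch_xla.distributed.spmd as xs  # requires a PyTorch/XLA install

    params = dict(model.named_parameters())
    missing = sorted(set(plan) - set(params))
    if missing:
        raise KeyError(f"planned parameters not found in model: {missing}")
    for name, spec in plan.items():
        xs.mark_sharding(params[name], mesh, spec)


def pin_boundary(tensor, mesh, key: str):
    """Pin a boundary activation to the layout declared in the table."""
    import torch_xla.distributed.spmd as xs  # requires a PyTorch/XLA install

    xs.mark_sharding(tensor, mesh, BOUNDARY_ACTIVATIONS[key])
    return tensor


print(build_param_plan())
```

Keeping the tables as data rather than scattering calls through the model code is a design choice of this sketch: it gives the later audit one place to read the intended layout from.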
Why replication should also be stated explicitly
Small tensors are often the easiest place to make a bad assumption. A projection or auxiliary parameter may coincidentally match a mesh axis and look shardable even when the intended semantics are replicated. In that situation, explicit replication is safer than letting inference guess.

    import torch_xla.distributed.spmd as xs

    def replicated(ndim: int):
        # A partition spec of all None replicates every tensor axis.
        return tuple(None for _ in range(ndim))

    # model, mesh, and REPLICATED_NAMES come from the surrounding startup code.
    for name, param in model.named_parameters():
        if name in REPLICATED_NAMES:
            xs.mark_sharding(param, mesh, replicated(param.ndim))
What a good audit looks like
If sharding is part of the startup contract, it should be auditable. A useful audit checks that:
- every intended parameter annotation was applied
- replication is explicit where it is supposed to be explicit
- mesh shape and axis names match the launch configuration
This matters because a sharding mistake can show up as compile instability or subtle training drift rather than as a loud startup failure.
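The three checks can be sketched as a small pure-Python function. This assumes the startup code keeps its own record of the specs it passed to `mark_sharding` (the `applied` dict below); that record, the function name `audit_sharding`, and the example names are hypothetical and not something PyTorch/XLA provides.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Spec = Tuple[Optional[str], ...]


def audit_sharding(intended: Dict[str, Spec], applied: Dict[str, Spec],
                   mesh_shape: Sequence[int], mesh_axes: Sequence[str],
                   launch_shape: Sequence[int],
                   launch_axes: Sequence[str]) -> List[str]:
    """Return a list of human-readable problems; an empty list means clean."""
    problems: List[str] = []
    # Check 3: mesh shape and axis names match the launch configuration.
    if tuple(mesh_shape) != tuple(launch_shape):
        problems.append(f"mesh shape {tuple(mesh_shape)} != launch {tuple(launch_shape)}")
    if tuple(mesh_axes) != tuple(launch_axes):
        problems.append(f"mesh axes {tuple(mesh_axes)} != launch {tuple(launch_axes)}")
    for name, spec in sorted(intended.items()):
        kind = "replication" if all(a is None for a in spec) else "sharding"
        if name not in applied:
            # Checks 1 and 2: the annotation (including replication) was applied.
            problems.append(f"{name}: intended {kind} was never applied")
        elif applied[name] != spec:
            problems.append(f"{name}: applied {applied[name]}, intended {spec}")
        for axis in spec:
            if axis is not None and axis not in mesh_axes:
                problems.append(f"{name}: axis {axis!r} is not a mesh axis")
    return problems


intended = {"w": ("fsdp", None), "norm": (None,)}
applied = {"w": ("fsdp", None)}  # the replicated norm was forgotten
print(audit_sharding(intended, applied, (8,), ("fsdp",), (8,), ("fsdp",)))
# ['norm: intended replication was never applied']
```

Running this at startup and refusing to train on a non-empty result turns a quiet drift risk into a loud failure.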
The stronger version of that audit checks the compiled sharding surface too, not only the Python call sites. On PyTorch/XLA, the useful question is whether the final traced graph still shows replicated small tensors and the intended boundary placements, because a startup contract is only real once the compiled layout still matches the policy you thought you applied.
The official PyTorch/XLA readback tools for that check live in torch_xla.distributed.spmd.debugging: visualize_tensor_sharding renders the placement of a tensor across devices, and visualize_sharding renders the same kind of picture from a sharding string. They are the right first pass when you need to confirm that the reported layout still matches the startup policy.
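For a check that can run unattended, the sharding string itself can be classified instead of visualized. The sketch below makes several assumptions: it reads the string through `torch_xla._XLAC._get_xla_sharding_spec`, which is a private binding and may change between releases; it assumes an unannotated tensor reports an empty string and a replicated one reports `{replicated}`; and the helper names are hypothetical. The reader is injectable so the classification logic can be exercised without a TPU.

```python
from typing import Callable, Iterable, List, Tuple


def classify_sharding(spec: str) -> str:
    """Bucket a sharding string into unannotated / replicated / other."""
    text = spec.strip()
    if not text:
        return "unannotated"  # assumption: no annotation reads back as ""
    if text == "{replicated}":
        return "replicated"
    return "other"            # tiled, partially replicated, manual, ...


def read_sharding(tensor) -> str:
    """Read the sharding string PyTorch/XLA reports for one tensor."""
    import torch_xla  # requires a PyTorch/XLA install

    # Private binding; treat its availability as an assumption.
    return torch_xla._XLAC._get_xla_sharding_spec(tensor)


def not_replicated(named_tensors: Iterable[Tuple[str, object]],
                   reader: Callable[[object], str] = read_sharding) -> List[str]:
    """Return names whose reported sharding is not explicit replication."""
    return [name for name, tensor in named_tensors
            if classify_sharding(reader(tensor)) != "replicated"]


# Exercise the logic with a fake reader instead of real device tensors.
fake = {"norm": "{replicated}", "bias": ""}
print(not_replicated([("norm", "norm"), ("bias", "bias")], reader=fake.get))
# ['bias']
```

On a real run, the list of names that are supposed to be replicated would be passed with the default reader, and any name that comes back is a parameter whose reported layout no longer matches the policy.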
Bounded dynamic shapes do not replace that audit. They mean one declared upper-bounded family of shapes, not permission to keep reading real tensor sizes into Python on the hot path.
A safer public claim
The useful public claim is not that one exact sharding table is universal. The useful claim is narrower:
- on TPU XLA, parameter placement should be explicit where semantics matter
- replication is a real placement choice and should be stated explicitly
- propagation is helpful for intermediates, but risky as the sole source of truth for owned parameters
Frequently asked questions
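Can bounded dynamic shapes replace explicit sharding?
No. Bounded dynamic shapes declare one upper-bounded family of shapes; they do not replace the sharding audit, and they are not permission to keep reading real tensor sizes into Python on the hot path.
Should boundary activations be annotated too?
Yes, when their layout is part of the intended compile contract.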
Terms used in this article
Quick definitions for the terms used above.
- mark_sharding(...): PyTorch/XLA's explicit tensor-placement annotation API: attach a mesh plus partition spec to a tensor so one TPU XLA program lowers with stable owned placement.
- mesh: The named device grid that defines which logical axis maps to which TPU or distributed-device axis before sharding annotations make sense.
- NamedSharding: JAX's frontend sharding object that pairs a mesh with a PartitionSpec; similar goal to PyTorch/XLA placement annotations, but not the same frontend API.
- PartitionSpec: The tuple-style sharding layout that says which tensor axis maps to which mesh axis and which axes stay replicated.
- XLA: The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.
- PJRT: The TPU runtime interface between frontend code and the backend executor; it is the ownership seam between JAX/Torch-XLA frontends and libtpu.
- libtpu: The TPU backend library that pairs with PJRT/XLA and owns device-side execution underneath the frontend.
- TPU: Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.