ZeRO-3-shaped sharding on the XLA backend: what transfers from FSDP2 and what does not
How to think about TPU XLA sharding honestly: keep the ZeRO-3 memory goal, drop the assumption that TPU uses the same eager FSDP2 wrapper model as CUDA.

Teams often use "FSDP2 on TPU" as shorthand for a memory goal rather than a literal implementation. That shorthand is easy to misuse. On CUDA, FSDP2 is an eager wrapper and hook-based abstraction. On TPU XLA, the practical analogue is usually SPMD parameter sharding with ZeRO-3-like memory behavior, not the same wrapper mechanism.
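To make the memory goal concrete, here is a back-of-envelope sketch of per-device model-state memory when parameters, gradients, and optimizer state are all sharded evenly across the data-parallel group. The byte counts are assumptions (2-byte parameters and gradients plus 12 bytes of fp32 Adam state per parameter, the usual mixed-precision accounting), the function name is hypothetical, and activations and compiler temporaries are ignored.

```python
def zero3_bytes_per_device(num_params, num_devices,
                           param_bytes=2, grad_bytes=2, optim_bytes=12):
    """Approximate per-device bytes for model states under full sharding.

    Assumes every state tensor is split evenly across num_devices.
    """
    if num_devices < 1:
        raise ValueError("num_devices must be at least 1")
    total = num_params * (param_bytes + grad_bytes + optim_bytes)
    return total / num_devices


# Hypothetical 7B-parameter model, unsharded versus sharded over 8 devices.
replicated = zero3_bytes_per_device(7_000_000_000, 1)
sharded = zero3_bytes_per_device(7_000_000_000, 8)
print(f"unsharded: {replicated / 2**30:.1f} GiB per device")
print(f"sharded over 8 devices: {sharded / 2**30:.1f} GiB per device")
```

The arithmetic is backend-neutral, which is the point: the target number is the same whether an eager wrapper or a compiled SPMD program produces it.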
The important distinction
On CUDA, fully_shard rewrites module structure and installs runtime hooks for all-gather and reduce-scatter. On TPU XLA, sharding is generally expressed through SPMD annotations and compiler-owned collective placement. The memory objective may be similar, but the mechanism is different.
That is the right public framing:
- CUDA path: eager FSDP or FSDP2-style wrapper semantics
- TPU path: XLA SPMD sharding that aims for similar memory savings
Treating them as identical leads to bad debugging assumptions.
What transfers cleanly
Some ideas do transfer across backends:
- classify which parameters should be sharded versus replicated
- keep the sharding policy stable across steps
- gate launches on whether the intended shard plan is actually valid
- separate memory goals from wrapper-specific implementation details
These are policy ideas, not proof that the same API surface exists on both backends.
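The first and third bullets can be sketched without touching either backend. The example below is a minimal illustration, not an API from FSDP2 or PyTorch/XLA: the function names, the size threshold, and the rule "shard the first evenly divisible dimension" are all assumptions chosen for clarity.

```python
from typing import Dict, List, Optional, Tuple

Shape = Tuple[int, ...]


def plan_shards(shapes: Dict[str, Shape], num_shards: int,
                min_numel: int = 65536) -> Dict[str, Optional[int]]:
    """Map each parameter name to the dimension to shard, or None to replicate."""
    plan: Dict[str, Optional[int]] = {}
    for name, shape in shapes.items():
        numel = 1
        for d in shape:
            numel *= d
        dim: Optional[int] = None
        # Assumption: small tensors (biases, norm scales) stay replicated.
        if numel >= min_numel:
            # Assumption: only shard a dimension that divides evenly.
            for i, d in enumerate(shape):
                if d % num_shards == 0:
                    dim = i
                    break
        plan[name] = dim
    return plan


def validate_plan(plan: Dict[str, Optional[int]], shapes: Dict[str, Shape],
                  num_shards: int) -> List[str]:
    """Return a list of problems; an empty list means the launch gate passes."""
    problems = []
    for name, shape in shapes.items():
        if name not in plan:
            problems.append(f"{name}: no shard decision recorded")
            continue
        dim = plan[name]
        if dim is None:
            continue
        if dim >= len(shape) or shape[dim] % num_shards != 0:
            problems.append(f"{name}: dim {dim} of {shape} not divisible by {num_shards}")
    return problems


shapes = {
    "embed.weight": (32000, 4096),
    "norm.weight": (4096,),
    "odd.weight": (1001, 4097),
}
plan = plan_shards(shapes, num_shards=8)
print(plan)
print(validate_plan(plan, shapes, num_shards=8))
```

Keeping the plan as plain data is a deliberate choice here: the same dictionary can feed a wrapper call on one backend and a placement annotation on the other, and the gate runs before either.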
What does not transfer cleanly
Several familiar CUDA knobs do not map directly to TPU XLA:
- eager hook timing
- reshard_after_forward
- prefetch knobs tied to Python wrapper execution
- assumptions about local wrapper state being visible at every block boundary
On TPU XLA, collective placement and resharding behavior are compiler-shaped. The relevant debugging surfaces are graph stability, mesh construction, annotation correctness, and recompilation risk.
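A minimal sketch of a graph-stability check, assuming you already record a cumulative compile count once per training step (how that number is read from the runtime is left out). The function name and the warmup threshold are hypothetical.

```python
def steps_that_recompiled(compile_counts, warmup_steps=2):
    """Return step indices past warmup where the cumulative compile count grew.

    compile_counts[i] is the cumulative count observed after step i.
    """
    offenders = []
    for step in range(1, len(compile_counts)):
        grew = compile_counts[step] > compile_counts[step - 1]
        if step >= warmup_steps and grew:
            offenders.append(step)
    return offenders


# Hypothetical readings: compiles during warmup, then two late recompiles.
counts = [1, 2, 2, 2, 3, 3, 4]
print(steps_that_recompiled(counts))
```

Any step this returns is a reason to look at shapes and annotations before reaching for a sharding knob.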
Why this matters operationally
If a team says "FSDP2 on TPU" but is really using XLA SPMD sharding, then launch, profiling, and failure interpretation should follow the TPU model:
- confirm mesh and sharding annotations early
- keep the shard contract stable across steps
- treat recompilation and memory-space assignment as first-class risks
- avoid copying CUDA-only tuning vocabulary into TPU launch documentation
That keeps the operational story honest. It also avoids implying official parity where the underlying implementation model is different.
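One way to make "keep the shard contract stable across steps" checkable is to fingerprint the mesh and the per-parameter layout, then compare fingerprints between launches or checkpoints. This is a sketch under assumptions: the function name and the data layout (ordered mesh axes, tuples of axis names or None per parameter) are hypothetical, not a PyTorch/XLA interface.

```python
import hashlib
import json


def contract_fingerprint(mesh_axes, specs):
    """Hash a shard contract so that any change to it is easy to detect.

    mesh_axes: ordered sequence of (axis_name, size) pairs; order matters.
    specs: dict of parameter name -> tuple of axis names or None per dimension.
    """
    payload = {
        "mesh": [[name, size] for name, size in mesh_axes],
        "specs": {name: list(spec) for name, spec in specs.items()},
    }
    # sort_keys makes the hash independent of dict insertion order.
    blob = json.dumps(payload, sort_keys=True).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()


mesh_axes = [("fsdp", 8), ("model", 1)]
specs = {"embed.weight": ("fsdp", None), "norm.weight": (None,)}
baseline = contract_fingerprint(mesh_axes, specs)

# A changed layout for one parameter produces a different fingerprint.
changed = dict(specs, **{"embed.weight": (None, "fsdp")})
print(baseline == contract_fingerprint(mesh_axes, changed))
```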
A safer naming convention
For public docs, a safer pattern is:
- "FSDP2 on CUDA"
- "ZeRO-3-shaped sharding on TPU XLA"
That keeps the memory intent visible without claiming identical runtime machinery.
In PyTorch/XLA, that contract is usually made concrete with mark_sharding on a named mesh plus an explicit PartitionSpec-style layout. If you come from the JAX side, the closest sibling is NamedSharding: same question about owned placement, different frontend surface and different debugging vocabulary.
DTensor is the closer PyTorch term for "one logical tensor plus explicit shard or replica metadata", but on TPU XLA the operational seam in this article is still the compiler-owned XLA SPMD contract. That is why this lane stays grounded in mesh construction, mark_sharding placement, and graph stability instead of assuming CUDA FSDP2 wrapper mechanics transferred unchanged.
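To show what a partition-spec-style layout implies, the sketch below computes the per-device shard shape from a global shape, a tuple spec, and a named mesh. It is plain Python with a hypothetical function name and does not call PyTorch/XLA. It assumes one mesh axis per tensor dimension and even divisibility, which is stricter than what a real SPMD partitioner may accept.

```python
def local_shard_shape(global_shape, partition_spec, mesh_shape):
    """Per-device shape for a tensor laid out by a tuple-style partition spec.

    partition_spec: one entry per tensor dimension, a mesh axis name or None.
    mesh_shape: dict of mesh axis name -> number of devices along that axis.
    """
    if len(partition_spec) != len(global_shape):
        raise ValueError("spec must have one entry per tensor dimension")
    used = set()
    local = []
    for dim, axis in zip(global_shape, partition_spec):
        if axis is None:
            local.append(dim)  # replicated along this dimension
            continue
        if axis not in mesh_shape:
            raise ValueError(f"unknown mesh axis: {axis}")
        if axis in used:
            raise ValueError(f"mesh axis used twice: {axis}")
        used.add(axis)
        if dim % mesh_shape[axis] != 0:
            raise ValueError(f"{dim} not divisible by axis {axis}")
        local.append(dim // mesh_shape[axis])
    return tuple(local)


# Hypothetical 8-way mesh with a size-1 model axis.
mesh_shape = {"fsdp": 8, "model": 1}
print(local_shard_shape((32000, 4096), ("fsdp", None), mesh_shape))
# In PyTorch/XLA, a tuple like ("fsdp", None) is the kind of layout that
# would be passed to mark_sharding together with the mesh.
```

Checking the expected local shape on paper first makes a wrong annotation visible before the compiler has to report it through a memory or placement failure.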
Frequently asked questions
What should I inspect before changing the shard policy?
If CompileTime or transfer counters keep rising after warmup, the problem is probably graph churn or host synchronization rather than an FSDP2-style reshard knob. If the counters stay quiet but placement is wrong, inspect the mark_sharding mesh and partition spec directly; PyTorch/XLA's SPMD contract is carried by those annotations, not by hidden CUDA-style hook timing. The local-safe companion path is TPU runtime probe sample, Canonical XLA flag profile, and XLA SPMD sharding annotations.
What should I check before bootstrapping a large TPU shard run?
[...] mark_sharding policy still match the placement plan before the first real training step. That keeps host memory, parameter residency, and compiler placement separate enough to debug. The local-safe companion path is Distributed memory notes, FSDP sharding sample, and XLA SPMD sharding annotations.
What CUDA habits should I remove before blaming XLA sharding?
[...] .item() or .nonzero(), and manual layout reshapes that collide with the mesh dimension plan. Move those checks into a preflight or post-step receipt, then debug mark_sharding and mesh placement with the stable graph intact. The local-safe companion path is Graph recompilation hell, XLA SPMD sharding annotations, TPU runtime probe sample, and Canonical XLA flag profile.
Terms used in this article
Start here for quick definitions, then follow the linked posts for deeper context.
- FSDP2: PyTorch's Fully Sharded Data Parallel v2 wrapper API. On CUDA it shards parameters, gradients, and optimizer state across the data-parallel group; in the TPU/XLA posts here it is usually a memory-goal analogy, not the actual eager wrapper mechanism.
- mark_sharding(...): PyTorch/XLA's explicit tensor-placement annotation API: attach a mesh plus partition spec to a tensor so one TPU XLA program lowers with stable owned placement.
- XLA SPMD: The explicit TPU sharding mode where one compiled program carries placement rules instead of rank-local imperative code.
- mesh: The named device grid that defines which logical axis maps to which TPU or distributed-device axis before sharding annotations make sense.
- PartitionSpec: The tuple-style sharding layout that says which tensor axis maps to which mesh axis and which axes stay replicated.
- CUDA: NVIDIA's GPU programming stack: compiler, runtime, driver, libraries, and kernel toolchain used by CUDA training and inference lanes.
- XLA: The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.
- NamedSharding: JAX's frontend sharding object that pairs a mesh with a PartitionSpec; similar goal to PyTorch/XLA placement annotations, but not the same frontend API.
- DTensor: PyTorch's mesh-backed distributed-tensor abstraction: one logical tensor with explicit shard or replica metadata across ranks.
- TPU: Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.
- JAX: A separate frontend above PJRT/libtpu. In these TPU posts it mainly matters as the owner of NamedSharding, PartitionSpec, and the optional call_jax or Pallas-adjacent bridge lanes.