MegaCpp Engineering: Applied C++ model systems
Article
Grounded engineering note from the MegaCpp stack
Published · 2 min read · David Gornshtein
TPU
XLA
SPMD
FSDP2
ZeRO-3

ZeRO-3-shaped sharding on the XLA backend: what transfers from FSDP2 and what does not

How to think about TPU XLA sharding honestly: keep the ZeRO-3 memory goal, drop the assumption that TPU uses the same eager FSDP2 wrapper model as CUDA.

Teams often use "FSDP2 on TPU" as shorthand for a memory goal rather than a literal implementation. That shorthand is easy to misuse. On CUDA, FSDP2 is an eager wrapper and hook-based abstraction. On TPU XLA, the practical analogue is usually SPMD parameter sharding with ZeRO-3-like memory behavior, not the same wrapper mechanism.
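
To make the memory goal concrete, here is a back-of-the-envelope sketch. It assumes the usual mixed-precision Adam accounting (2 bytes each for half-precision parameters and gradients, 12 bytes of fp32 optimizer state per parameter) and ignores activations and temporary buffers. The function name, the 7B parameter count, and the 64-device group are illustrative assumptions, not measurements from this stack.

```python
def per_device_state_bytes(num_params: int, num_devices: int, sharded: bool) -> float:
    """Approximate model-state bytes held per device.

    Assumption: mixed-precision Adam, i.e. 2 B params + 2 B grads
    + 12 B optimizer state (fp32 master copy, momentum, variance).
    """
    bytes_per_param = 2 + 2 + 12
    total = num_params * bytes_per_param
    # ZeRO-3-shaped: every model-state component is split across devices.
    return total / num_devices if sharded else float(total)


GIB = 1024 ** 3
params = 7_000_000_000   # illustrative 7B-parameter model
devices = 64             # illustrative shard-group size

replicated = per_device_state_bytes(params, devices, sharded=False)
zero3 = per_device_state_bytes(params, devices, sharded=True)
print(f"replicated:    {replicated / GIB:.1f} GiB per device")
print(f"ZeRO-3-shaped: {zero3 / GIB:.2f} GiB per device")
```

This arithmetic is the part that carries across backends; how the split is actually achieved is where CUDA and TPU XLA diverge.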

The important distinction

On CUDA, fully_shard converts module parameters into sharded DTensors and installs runtime hooks that all-gather parameters and reduce-scatter gradients. On TPU XLA, sharding is generally expressed through SPMD annotations and compiler-owned collective placement. The memory objective may be similar, but the mechanism is different.

That is the right public framing: similar memory objective, different mechanism.

Treating them as identical leads to bad debugging assumptions.

What transfers cleanly

Some ideas do transfer across backends:

  • classify which parameters should be sharded versus replicated
  • keep the sharding policy stable across steps
  • gate launches on whether the intended shard plan is actually valid
  • separate memory goals from wrapper-specific implementation details

These are policy ideas, not proof that the same API surface exists on both backends.
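
The policy ideas above can be sketched without touching either backend. The example below is a hypothetical policy, not the one used in the MegaCpp stack: the size threshold, the "shard along dim 0" rule, the parameter names, and the helper names (`classify`, `plan_fingerprint`, `gate_launch`) are all assumptions made for illustration.

```python
import hashlib
import json
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def classify(shapes: Dict[str, Shape], num_shards: int,
             min_numel: int = 1 << 16) -> Dict[str, str]:
    """Return 'shard' or 'replicate' per parameter name.

    Hypothetical policy: shard along dim 0 only when the tensor is large
    enough to matter and dim 0 divides evenly across the shard group.
    """
    plan = {}
    for name, shape in shapes.items():
        numel = 1
        for d in shape:
            numel *= d
        big_enough = numel >= min_numel
        divisible = len(shape) > 0 and shape[0] % num_shards == 0
        plan[name] = "shard" if (big_enough and divisible) else "replicate"
    return plan


def plan_fingerprint(plan: Dict[str, str]) -> str:
    """Stable hash of the plan, so a later run can be checked against it."""
    blob = json.dumps(plan, sort_keys=True).encode()
    return hashlib.sha256(blob).hexdigest()[:16]


def gate_launch(plan: Dict[str, str], expected_fingerprint: str) -> None:
    """Refuse to launch if the plan drifted from the one that was reviewed."""
    actual = plan_fingerprint(plan)
    if actual != expected_fingerprint:
        raise RuntimeError(f"shard plan drifted: {actual} != {expected_fingerprint}")


shapes = {
    "embed.weight": (32000, 4096),
    "block0.attn.qkv.weight": (12288, 4096),
    "block0.norm.weight": (4096,),
}
plan = classify(shapes, num_shards=8)
reviewed = plan_fingerprint(plan)
gate_launch(classify(shapes, num_shards=8), reviewed)  # same policy, so this passes
```

Nothing here depends on a wrapper API, which is the point: the plan is data that either backend's mechanism can consume.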

What does not transfer cleanly

Several familiar CUDA knobs do not map directly to TPU XLA:

  • eager hook timing
  • reshard_after_forward
  • prefetch knobs tied to Python wrapper execution
  • assumptions about local wrapper state being visible at every block boundary

On TPU XLA, collective placement and resharding behavior are compiler-shaped. The relevant debugging surfaces are graph stability, mesh construction, annotation correctness, and recompilation risk.
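
Mesh construction and annotation correctness can be checked on the host before anything is compiled. The sketch below is a plain-Python consistency check under stated assumptions: the function name and the axis names are hypothetical, and the divisibility rule is a deliberately conservative choice for this example, not a statement about what the compiler will accept.

```python
from math import prod
from typing import List, Optional, Sequence


def check_annotation(tensor_shape: Sequence[int],
                     mesh_shape: Sequence[int],
                     axis_names: Sequence[str],
                     partition_spec: Sequence[Optional[str]],
                     num_devices: int) -> List[str]:
    """Return a list of human-readable problems; empty means consistent."""
    problems = []
    if prod(mesh_shape) != num_devices:
        problems.append(f"mesh {tuple(mesh_shape)} does not cover {num_devices} devices")
    if len(axis_names) != len(mesh_shape) or len(set(axis_names)) != len(axis_names):
        problems.append("axis names must be unique, one per mesh dimension")
    if len(partition_spec) != len(tensor_shape):
        problems.append("partition spec rank differs from tensor rank")
        return problems
    size_of = dict(zip(axis_names, mesh_shape))
    used = [a for a in partition_spec if a is not None]
    if len(set(used)) != len(used):
        problems.append("a mesh axis is used for more than one tensor dimension")
    for dim, axis in zip(tensor_shape, partition_spec):
        if axis is None:
            continue  # None means this tensor dimension stays replicated
        if axis not in size_of:
            problems.append(f"unknown mesh axis {axis!r}")
        elif dim % size_of[axis] != 0:
            problems.append(
                f"dim {dim} not divisible by axis {axis!r} of size {size_of[axis]}")
    return problems


ok = check_annotation((32000, 4096), (8, 1), ("fsdp", "model"), ("fsdp", None), 8)
bad = check_annotation((32000, 4096), (8, 1), ("fsdp", "model"), ("data", None), 8)
print(ok, bad)
```

A check like this catches the cheap mistakes (a misspelled axis, a mesh that does not cover the device count) before they show up as confusing placement or recompilation symptoms.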

Why this matters operationally

If a team says "FSDP2 on TPU" but is really using XLA SPMD sharding, then launch, profiling, and failure interpretation should follow the TPU model.

That keeps the operational story honest. It also avoids implying official parity where the underlying implementation model is different.

A safer naming convention

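For public docs, a safer pattern is to name the memory goal and the backend mechanism separately, as in "ZeRO-3-shaped sharding on the XLA backend" rather than "FSDP2 on TPU".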

That keeps the memory intent visible without claiming identical runtime machinery.

In PyTorch/XLA, that contract is usually made concrete with mark_sharding on a named mesh plus an explicit PartitionSpec-style layout. If you come from the JAX side, the closest sibling is NamedSharding: same question about owned placement, different frontend surface and different debugging vocabulary.
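
A minimal sketch of what that looks like in PyTorch/XLA follows. It assumes a working torch_xla install with SPMD support. The axis names ("fsdp", "model"), the `(n, 1)` mesh shape, and the helper names are choices made for this example rather than anything the article's stack prescribes. The torch_xla imports are kept inside the function so the pure helper can be read and tested on a machine without a TPU runtime.

```python
from typing import Optional, Tuple


def fsdp_style_spec(rank: int, shard_dim: int = 0,
                    axis: str = "fsdp") -> Tuple[Optional[str], ...]:
    """Partition spec that shards one tensor dim and replicates the rest."""
    if not 0 <= shard_dim < rank:
        raise ValueError(f"shard_dim {shard_dim} out of range for rank {rank}")
    return tuple(axis if i == shard_dim else None for i in range(rank))


def annotate_weight(out_features: int = 1024, in_features: int = 1024):
    """Create one weight on the XLA device and annotate its placement.

    Requires torch_xla; imported lazily on purpose.
    """
    import numpy as np
    import torch
    import torch_xla  # noqa: F401  (registers the "xla" device)
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    xr.use_spmd()  # switch PyTorch/XLA into SPMD mode before creating tensors
    n = xr.global_runtime_device_count()
    # One sharding axis plus a size-1 placeholder axis (an assumption of this sketch).
    mesh = xs.Mesh(np.arange(n), (n, 1), ("fsdp", "model"))
    weight = torch.zeros(out_features, in_features, device="xla")
    # The annotation, not a wrapper hook, is what tells the compiler where shards live.
    xs.mark_sharding(weight, mesh, fsdp_style_spec(weight.dim()))
    return weight


print(fsdp_style_spec(2))
```

Note what is absent: no module wrapper, no hook registration, no reshard flag. The placement lives on the tensor annotation, and the collectives are inserted by the compiler.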

DTensor is the closer PyTorch term for "one logical tensor plus explicit shard or replica metadata", but on TPU XLA the operational seam in this article is still the compiler-owned XLA SPMD contract. That is why this lane stays grounded in mesh construction, mark_sharding placement, and graph stability instead of assuming CUDA FSDP2 wrapper mechanics transferred unchanged.

FAQ

Frequently asked questions

What should I inspect before changing the shard policy?

Look at the XLA metrics and the mesh contract before touching the model wrapper. If CompileTime or transfer counters keep rising after warmup, the problem is probably graph churn or host synchronization rather than an FSDP2-style reshard knob. If the counters stay quiet but placement is wrong, inspect the mark_sharding mesh and partition spec directly; PyTorch/XLA's SPMD contract is carried by those annotations, not by hidden CUDA-style hook timing. The local-safe companion path is TPU runtime probe sample, Canonical XLA flag profile, and XLA SPMD sharding annotations.
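
One way to turn "CompileTime keeps rising after warmup" into a check is sketched below. It assumes torch_xla's debug metrics module and its CompileTime metric; the helper names, the warmup length, and the sample counts are hypothetical. The import is lazy so the analysis helper also works on recorded counts from a log.

```python
from typing import List, Optional


def compile_count() -> Optional[int]:
    """Number of XLA compilations so far, or None if torch_xla is unavailable."""
    try:
        import torch_xla.debug.metrics as met
    except ImportError:
        return None
    # metric_data returns (sample count, accumulator, samples), or None if unset.
    data = met.metric_data("CompileTime")
    return 0 if data is None else int(data[0])


def churn_steps(counts_per_step: List[int], warmup: int) -> List[int]:
    """Steps after warmup at which the cumulative compile count grew."""
    return [
        step
        for step in range(max(warmup, 1), len(counts_per_step))
        if counts_per_step[step] > counts_per_step[step - 1]
    ]


# Illustrative cumulative counts sampled once per step (not real measurements).
recorded = [1, 2, 2, 2, 3, 3]
print(churn_steps(recorded, warmup=2))
```

In a training loop, `compile_count()` would be sampled once per step into a list like `recorded`. Any step reported after warmup is a prompt to look for shape or graph changes before revisiting the shard policy.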

What should I check before bootstrapping a large TPU shard run?

Treat initialization as its own contract, not as an afterthought of the wrapper name. The portable preflight is: construct on meta or empty storage first, materialize only the shard-local tensors you intend to own, then confirm that the XLA mesh and mark_sharding policy still match the placement plan before the first real training step. That keeps host memory, parameter residency, and compiler placement separate enough to debug. The local-safe companion path is Distributed memory notes, FSDP sharding sample, and XLA SPMD sharding annotations.
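
The bookkeeping side of that preflight can be sketched in plain Python. This is host-side arithmetic only: the ceil-division chunking, the helper names, and the budget figure are assumptions of the example, not a description of how XLA lays out shards.

```python
from typing import Tuple


def shard_rows(num_rows: int, num_shards: int, shard_index: int) -> Tuple[int, int]:
    """Half-open [start, stop) row range for one shard, using ceil-division chunks."""
    if not 0 <= shard_index < num_shards:
        raise ValueError("shard_index out of range")
    chunk = -(-num_rows // num_shards)  # ceil(num_rows / num_shards)
    start = min(shard_index * chunk, num_rows)
    stop = min(start + chunk, num_rows)
    return start, stop


def preflight(num_rows: int, row_bytes: int, num_shards: int,
              host_budget_bytes: int) -> None:
    """Check coverage and the per-shard host footprint before materializing."""
    covered = 0
    for i in range(num_shards):
        start, stop = shard_rows(num_rows, num_shards, i)
        if start != covered:
            raise RuntimeError(f"shard {i} leaves a gap at row {covered}")
        covered = stop
        if (stop - start) * row_bytes > host_budget_bytes:
            raise RuntimeError(f"shard {i} exceeds the host budget")
    if covered != num_rows:
        raise RuntimeError("shards do not cover every row")


# Illustrative: a 32000 x 4096 half-precision table across 8 shards, 64 MiB budget.
preflight(num_rows=32000, row_bytes=4096 * 2, num_shards=8,
          host_budget_bytes=64 * 1024 ** 2)
print(shard_rows(32000, 8, 3))
```

Running a check like this before materialization keeps "how much will each host hold" a separate question from "what placement will the compiler choose".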

What CUDA habits should I remove before blaming XLA sharding?

Start with habits that change the TPU graph rather than the shard policy: wrapper-level process-group assumptions, per-step scalar readbacks such as .item() or .nonzero(), and manual layout reshapes that collide with the mesh dimension plan. Move those checks into a preflight or post-step receipt, then debug mark_sharding and mesh placement with the stable graph intact. The local-safe companion path is Graph recompilation hell, XLA SPMD sharding annotations, TPU runtime probe sample, and Canonical XLA flag profile.
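
A post-step receipt can be as small as the sketch below. The class name, the flush interval, and the stand-in scalar type are hypothetical. The only requirement on logged values is an `.item()` method, which is what a tensor provides. PyTorch/XLA also has a step-closure mechanism for deferring host work, which may be a better fit in a real loop.

```python
class StepReceipt:
    """Collects device-side scalars and reads them back in batches."""

    def __init__(self, every: int):
        self.every = every
        self._pending = []  # (step, name, tensor-like) tuples, not yet read back
        self.records = []   # (step, name, float) tuples after readback

    def log(self, step: int, name: str, value) -> None:
        # Keep the handle; do not call .item() in the middle of the step.
        self._pending.append((step, name, value))
        if (step + 1) % self.every == 0:
            self.flush()

    def flush(self) -> None:
        for step, name, value in self._pending:
            self.records.append((step, name, float(value.item())))
        self._pending.clear()


class FakeScalar:
    """Stand-in for a device tensor so the sketch runs without an accelerator."""

    def __init__(self, v: float):
        self.v = v

    def item(self) -> float:
        return self.v


receipt = StepReceipt(every=4)
for step in range(6):
    receipt.log(step, "loss", FakeScalar(1.0 / (step + 1)))
print(len(receipt.records))
receipt.flush()  # read back whatever is left at the end of the run
print(len(receipt.records))
```

The design choice is to make readback an explicit, batched event instead of a side effect scattered through the step body.
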
Glossary

Terms used in this article

Start here for quick definitions, then follow the linked posts for deeper context.

FSDP2

PyTorch's Fully Sharded Data Parallel v2 wrapper API. On CUDA it shards parameters, gradients, and optimizer state across the data-parallel group; in the TPU/XLA posts here it is usually a memory-goal analogy, not the actual eager wrapper mechanism.

mark_sharding(...)

PyTorch/XLA's explicit tensor-placement annotation API: attach a mesh plus partition spec to a tensor so one TPU XLA program lowers with stable owned placement.

XLA SPMD

The explicit TPU sharding mode where one compiled program carries placement rules instead of rank-local imperative code.

mesh

The named device grid that defines which logical axis maps to which TPU or distributed-device axis before sharding annotations make sense.

PartitionSpec

The tuple-style sharding layout that says which tensor axis maps to which mesh axis and which axes stay replicated.

CUDA

NVIDIA's GPU programming stack: compiler, runtime, driver, libraries, and kernel toolchain used by CUDA training and inference lanes.

XLA

The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.

NamedSharding

JAX's frontend sharding object that pairs a mesh with a PartitionSpec; similar goal to PyTorch/XLA placement annotations, but not the same frontend API.

DTensor

PyTorch's mesh-backed distributed-tensor abstraction: one logical tensor with explicit shard or replica metadata across ranks.

TPU

Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.

JAX

A separate frontend above PJRT/libtpu. In these TPU posts it mainly matters as the owner of NamedSharding, PartitionSpec, and the optional call_jax or Pallas-adjacent bridge lanes.
