MegaCpp Engineering: Applied C++ model systems
Article
Grounded engineering note from the MegaCpp stack
By David Gornshtein, 3 min read
XLA
SPMD
TPU
Sharding

XLA SPMD sharding annotations we actually rely on

Why explicit mark_sharding annotations matter on TPU XLA, what should be pinned explicitly, and why propagation is not a substitute for a stable sharding contract.

This post is about the annotation surface itself: where explicit mark_sharding calls matter, where replication should be stated out loud, and why leaving parameter placement to inference is often the wrong tradeoff on TPU XLA. If you come from the JAX side, the closest sibling contract is NamedSharding on a named mesh with an explicit PartitionSpec, not a vague hope that propagation will recover the placement you meant.
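A minimal sketch of what such an annotation looks like in PyTorch/XLA, assuming a two-axis mesh. The axis names `fsdp` and `model`, both helper names, and the choice to shard dim 0 are illustrative assumptions, not the MegaCpp configuration.

```python
from typing import Tuple


def mesh_shape_for(num_devices: int, model_parallel: int) -> Tuple[int, int]:
    """Split devices into a (fsdp, model) grid; hypothetical helper."""
    if model_parallel <= 0 or num_devices % model_parallel != 0:
        raise ValueError(
            f"{num_devices} devices cannot be split into groups of {model_parallel}"
        )
    return (num_devices // model_parallel, model_parallel)


def shard_rows(weight, model_parallel: int = 1):
    """Pin dim 0 of a 2D XLA tensor to the 'fsdp' axis and replicate dim 1."""
    # Imported lazily so mesh_shape_for can be used without a TPU runtime.
    import numpy as np
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    xr.use_spmd()  # enable SPMD mode before tensors are annotated
    num_devices = xr.global_runtime_device_count()
    mesh = xs.Mesh(
        np.arange(num_devices),
        mesh_shape_for(num_devices, model_parallel),
        ("fsdp", "model"),
    )
    # One entry per tensor dimension: a mesh axis name shards, None replicates.
    xs.mark_sharding(weight, mesh, ("fsdp", None))
    return mesh
```

The partition spec has one entry per tensor dimension, which is why a replicated tensor still needs a full tuple of `None` values rather than no annotation at all.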

Why propagation is not enough

SPMD propagation is useful, but it is not a replacement for an explicit parameter contract. If a parameter is left unannotated, XLA is free to infer a placement from surrounding graph structure. Sometimes that is fine. Sometimes it is wrong in a way that hurts correctness or compile stability rather than crashing immediately.

That is why the conservative rule is simple: explicitly annotate parameters you own, including the ones you intend to replicate.

The practical reason is often small tensors. When those get inferred as sharded, the result is usually not useful math parallelism but a stream of tiny latency-heavy collectives.
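To make the small-tensor point concrete, here is a sketch of a spec chooser that replicates anything under a size threshold instead of leaving it to inference. The threshold, the default axis name and size, and the function name are assumptions chosen for illustration.

```python
from math import prod
from typing import Optional, Sequence, Tuple

Spec = Tuple[Optional[str], ...]


def choose_spec(
    shape: Sequence[int],
    axis_name: str = "fsdp",
    axis_size: int = 8,
    min_numel: int = 1 << 16,
) -> Spec:
    """Return an explicit partition spec for a tensor of `shape`."""
    replicated: Spec = tuple(None for _ in shape)
    # Small tensors: sharding buys little math and costs tiny collectives.
    if prod(shape) < min_numel:
        return replicated
    # Shard the first dimension that divides evenly across the mesh axis.
    for dim, size in enumerate(shape):
        if size % axis_size == 0:
            return tuple(axis_name if i == dim else None for i in range(len(shape)))
    # Nothing divides cleanly: stay replicated, but say so explicitly.
    return replicated
```

Note that a bias of length 8 on an 8-way axis divides evenly and would look shardable, yet the size check returns an all-`None` spec for it first.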

What should be explicit

Three classes benefit most from explicit annotations:

The important idea is not one exact project-specific whitelist. It is the operational discipline: if a tensor layout matters, pin it.
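One way to encode that discipline is a policy table that must match every parameter, so nothing silently falls through to inference. The patterns, parameter names, and axis name below are hypothetical; the source does not publish a whitelist.

```python
from fnmatch import fnmatchcase
from typing import Dict, Iterable, List, Optional, Tuple

Spec = Tuple[Optional[str], ...]

# Hypothetical policy: the first matching pattern wins.
POLICY: List[Tuple[str, Spec]] = [
    ("*.norm.weight", (None,)),
    ("*.bias", (None,)),
    ("*.proj.weight", ("fsdp", None)),
    ("embed.weight", ("fsdp", None)),
]


def resolve_specs(names_and_ndims: Iterable[Tuple[str, int]]) -> Dict[str, Spec]:
    """Map each parameter name to a spec, failing loudly on any gap."""
    resolved: Dict[str, Spec] = {}
    for name, ndim in names_and_ndims:
        for pattern, spec in POLICY:
            if fnmatchcase(name, pattern):
                if len(spec) != ndim:
                    raise ValueError(f"{name}: spec {spec} does not match ndim {ndim}")
                resolved[name] = spec
                break
        else:
            # No pattern matched: refuse to leave placement to propagation.
            raise KeyError(f"{name}: no explicit sharding policy")
    return resolved
```

The resolved mapping can then drive a `named_parameters()` loop that calls `xs.mark_sharding(param, mesh, spec)` for every entry, replicated ones included.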

Why replication should also be stated explicitly

Small tensors are often the easiest place to make a bad assumption. A projection or auxiliary parameter may coincidentally match a mesh axis and look shardable even when the intended semantics are replicated. In that situation, explicit replication is safer than letting inference guess.

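import torch_xla.distributed.spmd as xs

def replicated(ndim: int) -> tuple:
    # One None per tensor dimension means "replicate along every mesh axis".
    return tuple(None for _ in range(ndim))

# model, mesh, and REPLICATED_NAMES are defined by the surrounding setup code.
for name, param in model.named_parameters():
    if name in REPLICATED_NAMES:
        # State replication explicitly instead of leaving it to propagation.
        xs.mark_sharding(param, mesh, replicated(param.ndim))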

What a good audit looks like

If sharding is part of the startup contract, it should be auditable. A useful audit checks that:

This matters because a sharding mistake can show up as compile instability or subtle training drift rather than as a loud startup failure.

The stronger version of that audit checks the compiled sharding surface too, not only the Python call sites. On PyTorch/XLA, the useful question is whether the final traced graph still shows replicated small tensors and the intended boundary placements, because a startup contract is only real once the compiled layout still matches the policy you thought you applied.
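A sketch of one such check, assuming the sharding string is read back per parameter. `torch_xla._XLAC._get_xla_sharding_spec` is an internal binding that may change between releases, and the two string forms handled here (`{replicated}`, and an empty string for a tensor with no annotation) are assumptions to verify against your installed version. It reports what the runtime records for each tensor, which is a first approximation of the compiled layout rather than a full graph inspection.

```python
from typing import Dict, Iterable, List


def audit_replicated(
    sharding_by_name: Dict[str, str], must_replicate: Iterable[str]
) -> List[str]:
    """Return human-readable violations; an empty list means the audit passed."""
    problems: List[str] = []
    for name in must_replicate:
        spec = sharding_by_name.get(name)
        if spec is None:
            problems.append(f"{name}: parameter not found")
        elif spec == "":
            problems.append(f"{name}: no sharding annotation recorded")
        elif spec.strip() != "{replicated}":
            problems.append(f"{name}: expected replicated, got {spec}")
    return problems


def read_shardings(model) -> Dict[str, str]:
    """Collect the sharding string for each parameter (needs torch_xla)."""
    import torch_xla  # lazy import: only needed on an XLA host

    return {
        name: torch_xla._XLAC._get_xla_sharding_spec(param)
        for name, param in model.named_parameters()
    }
```

On an XLA host, `audit_replicated(read_shardings(model), REPLICATED_NAMES)` could run once at startup and abort if the returned list is non-empty.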

The official PyTorch/XLA readback tools for that check are visualize_sharding, which renders a sharding string, and visualize_tensor_sharding, which renders the placement of a tensor directly. They are the right first pass when you need to confirm that the reported layout still matches the startup policy.
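A minimal usage sketch for the tensor-level tool, assuming it is importable from `torch_xla.distributed.spmd.debugging` and that the `rich` package it renders with is installed. The wrapper name and its injectable `render` parameter are hypothetical, added only so the function can be exercised without a TPU.

```python
from typing import Callable, Iterable, List, Optional, Tuple


def show_layouts(
    named_tensors: Iterable[Tuple[str, object]],
    render: Optional[Callable[[object], object]] = None,
) -> List[str]:
    """Render the placement of each tensor and return the names visited."""
    if render is None:
        # Lazy import so the wrapper can be tested with a stand-in renderer.
        from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding

        render = visualize_tensor_sharding
    visited: List[str] = []
    for name, tensor in named_tensors:
        print(f"== {name} ==")
        render(tensor)
        visited.append(name)
    return visited
```

On an XLA host, `show_layouts(model.named_parameters())` would print one placement table per parameter, which is usually enough to spot a small tensor that ended up tiled.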

Bounded dynamic shapes do not replace that audit. They mean one declared upper-bounded family of shapes, not permission to keep reading real tensor sizes into Python on the hot path.

A safer public claim

The useful public claim is not that one exact sharding table is universal. The useful claim is narrower:

FAQ

Frequently asked questions

Can bounded dynamic shapes replace explicit sharding?
No. They can reduce recompiles from some value-dependent operators, but they do not tell the compiler which owned parameters should be replicated or tiled. Treat them as a graph-stability tool, not as a placement policy.

Should boundary activations be annotated too?
Yes, when the activation layout is part of the memory or compile contract. Parameters are the first surface to pin, but outputs crossing a sharded boundary should get the same treatment: choose the mesh and partition spec intentionally, then confirm the compiled sharding string instead of assuming propagation preserved the intended placement. The memory-side version of that rule is covered in FSDP2 on XLA TPU.

Glossary

Terms used in this article

Start here for quick definitions, then follow the linked posts for deeper context.

mark_sharding(...)

PyTorch/XLA's explicit tensor-placement annotation API: attach a mesh plus partition spec to a tensor so one TPU XLA program lowers with stable owned placement.

mesh

The named device grid that defines which logical axis maps to which TPU or distributed-device axis before sharding annotations make sense.

NamedSharding

JAX's frontend sharding object that pairs a mesh with a PartitionSpec; similar goal to PyTorch/XLA placement annotations, but not the same frontend API.

PartitionSpec

The tuple-style sharding layout that says which tensor axis maps to which mesh axis and which axes stay replicated.

XLA

The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.

PJRT

The TPU runtime interface between frontend code and the backend executor; it is the ownership seam between JAX/Torch-XLA frontends and libtpu.

libtpu

The TPU backend library that pairs with PJRT/XLA and owns device-side execution underneath the frontend.

TPU

Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.
