MegaCpp Engineering: Applied C++ model systems
Grounded engineering note from the MegaCpp stack
Published · 4 min read · David Gornshtein
TPU
Pallas
Sparse Attention
Kernels
XLA

Clustered sparse on TPU: the planner stages

How MegaCpp decomposes clustered sparse TPU attention into planner stages, legality checks, and backend dispatch rather than treating sparse attention as one giant kernel.


The most misleading way to describe sparse TPU attention is as one kernel. MegaCpp's public examples show something more useful: a planner pipeline. Coarse routing, union selection, legality checks, causal windowing, and block expansion all happen before the final sparse kernel even becomes the interesting part.

That is the right public story because it explains why sparse TPU work is hard. The challenge is not only writing a kernel; it is keeping the planner stages shape-safe and backend-safe enough that the sparse lane cannot silently drift into an invalid or misleading execution path.

One useful external grounding is the scalar-prefetch model in JAX Pallas: a small control payload can be loaded before the pipeline starts, so block index maps can do data-dependent routing without changing the compiled grid shape. That lines up with MegaCpp keeping union selection, legality, and block expansion as planner outputs instead of hiding them inside the final kernel.
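To make the scalar-prefetch idea concrete without a TPU, here is a pure-Python simulation, not Pallas code. It assumes a hypothetical planner output made of `kv_block_table[q][slot]` plus `kv_block_count[q]`; the grid stays fixed at `(num_q_blocks, max_slots)` and the index map reads the prefetched table to decide which KV block each grid step loads. In actual Pallas TPU code this role is played by scalar-prefetch operands (see `PrefetchScalarGridSpec`), which index maps receive after the grid indices; all names below are illustrative.

```python
from typing import Callable, List, Tuple

IndexMap = Callable[[int, int], Tuple[int, bool]]

def make_index_map(kv_block_table: List[List[int]], kv_block_count: List[int]) -> IndexMap:
    """Build an index map (q_block, slot) -> (kv_block, is_live) from a prefetched table.

    Padded slots (slot >= count) re-point at the last live block so the
    load stays in bounds; is_live=False tells the body to skip compute.
    """
    def index_map(q: int, slot: int) -> Tuple[int, bool]:
        count = kv_block_count[q]
        if count == 0:
            return 0, False  # nothing routed to this query block
        if slot < count:
            return kv_block_table[q][slot], True
        return kv_block_table[q][count - 1], False
    return index_map

def run_fixed_grid(num_q_blocks: int, max_slots: int, index_map: IndexMap) -> List[List[int]]:
    """Walk every step of a fixed (num_q_blocks, max_slots) grid; record live KV blocks."""
    visited = []
    for q in range(num_q_blocks):
        visited.append([kv for kv, live in (index_map(q, s) for s in range(max_slots)) if live])
    return visited

# Two batches with different routing reuse the same grid shape (3, 2).
batch_a = make_index_map([[0, 0], [0, 1], [1, 2]], [1, 2, 2])
batch_b = make_index_map([[0, 0], [1, 1], [0, 2]], [1, 1, 2])
print(run_fixed_grid(3, 2, batch_a))  # [[0], [0, 1], [1, 2]]
print(run_fixed_grid(3, 2, batch_b))  # [[0], [1], [0, 2]]
```

Only the table changes between batches; the grid shape, and therefore the compiled program, does not.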

The planner is where the sparse contract becomes real

The kernel examples in this repo already expose the main stages: coarse routing, union selection, legality checks, causal windowing, and block expansion.

That is enough to support a strong architectural claim: sparse TPU attention is not one feature flag. It is a planner plus an execution stack.

The checked-in examples make that stack concrete enough to test: Clustered sparse three-phase sample, Union selection query mask sample, Exact mask contract cache sample, and XLA backend dispatch sample.

Why MegaCpp separates the stages publicly

The separation solves two problems.

First, it makes correctness legible. If legality and windowing are buried inside one opaque sparse kernel, it becomes hard to prove which boundary failed. Breaking the pipeline into planner stages makes it possible to reason about each contract individually.

The three-phase sample also keeps routing behind a stop-gradient while the final sparse attention stage remains differentiable. That is another reason not to collapse planner logic into the kernel surface.
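One way to keep each boundary attributable is to run the planner as an ordered list of named stages, each followed by its own contract check, so a failure names the stage that broke. This is a hypothetical sketch: `PlannerStageError`, `Stage`, `run_planner`, and the two toy stages are illustrative, not MegaCpp's actual API.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

class PlannerStageError(RuntimeError):
    """Raised when a stage's output violates that stage's own contract."""

@dataclass(frozen=True)
class Stage:
    name: str
    run: Callable[[Dict], Dict]    # produces the next plan
    check: Callable[[Dict], bool]  # contract the output must satisfy

def run_planner(stages: List[Stage], plan: Dict) -> Tuple[Dict, List[str]]:
    """Run stages in order, checking each contract; return the plan and the stages that passed."""
    passed: List[str] = []
    for stage in stages:
        plan = stage.run(plan)
        if not stage.check(plan):
            raise PlannerStageError(f"contract violated after stage '{stage.name}'")
        passed.append(stage.name)
    return plan, passed

def union(plan: Dict) -> Dict:
    return {**plan, "workset": sorted(set(plan["picks"]))}

def causal(plan: Dict) -> Dict:
    return {**plan, "workset": [kv for kv in plan["workset"] if kv <= plan["q_block"]]}

stages = [
    Stage("union_selection", union, lambda p: p["workset"] == sorted(set(p["workset"]))),
    Stage("causal_window", causal, lambda p: all(kv <= p["q_block"] for kv in p["workset"])),
]
plan, passed = run_planner(stages, {"q_block": 1, "picks": [3, 0, 1, 1]})
print(plan["workset"], passed)  # [0, 1] ['union_selection', 'causal_window']

# A causal stage that forgets to filter is caught and named.
broken = [stages[0], Stage("causal_window", lambda p: p, stages[1].check)]
try:
    run_planner(broken, {"q_block": 1, "picks": [3, 0]})
except PlannerStageError as err:
    print(err)  # contract violated after stage 'causal_window'
```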

Second, it makes backend ownership honest. A request for clustered sparse attention does not mean one exact backend will always execute it. The backend dispatch and fallback examples keep that visible. Sometimes the sparse request stays on the intended Pallas path; sometimes a different backend or fallback is more honest for the current shape and mask contract.

Two TPU ownership labels are useful here. trace_pallas is the native Pallas payload lane that keeps the sparse kernel on the PyTorch/XLA side, while call_jax is the narrower bridge path that hands one operation to JAX. If the planner crosses from the first to the second, that backend decision should stay visible instead of being flattened into one generic "TPU backend."
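A minimal sketch of keeping that decision visible: the dispatcher returns a record naming the lane and the reason, not a bare "TPU" label. The lane strings mirror the labels above; the capability flags, the 128-alignment rule, and all field names are invented for illustration, since the real eligibility rules depend on shapes, masks, and what each lane supports.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class SparseRequest:
    head_dim: int
    block_size: int
    trace_pallas_available: bool  # hypothetical capability flag
    call_jax_available: bool      # hypothetical capability flag

@dataclass(frozen=True)
class BackendDecision:
    lane: str    # "trace_pallas", "call_jax", or "fallback"
    reason: str  # human-readable, meant to be logged next to the run

def choose_backend(req: SparseRequest) -> BackendDecision:
    """Pick a lane and say why, instead of returning a bare 'TPU' label."""
    # Hypothetical legality rule for illustration: tiles must be 128-aligned.
    if req.head_dim % 128 or req.block_size % 128:
        return BackendDecision("fallback", f"tile {req.head_dim}x{req.block_size} is not 128-aligned")
    if req.trace_pallas_available:
        return BackendDecision("trace_pallas", "native Pallas payload stays on the PyTorch/XLA side")
    if req.call_jax_available:
        return BackendDecision("call_jax", "native lane unavailable; bridging one op to JAX")
    return BackendDecision("fallback", "no sparse lane available for this request")

decision = choose_backend(SparseRequest(128, 128, trace_pallas_available=False, call_jax_available=True))
print(decision.lane, "|", decision.reason)  # call_jax | native lane unavailable; bridging one op to JAX
```

Because the reason travels with the lane, a run that silently moved from trace_pallas to call_jax or to a fallback shows up in logs rather than hiding behind one backend name.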

The planner stages that matter most

The local examples suggest a practical decomposition.

Union selection is where top-k block choices are converted into a compact work set the later sparse kernel can actually consume. Causal windowing is where future-illegal tiles are removed before they pollute the sparse plan. Hierarchical block expansion is where a coarse routing choice is refined into a more precise sparse workset. Exact mask-contract helpers ensure that cache keys and legality decisions remain tied to the actual runtime mask semantics.
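A minimal sketch of union selection, assuming top-k picks arrive as one list of KV block ids per query row and that query rows are grouped into query blocks of a fixed size. The union of picks within each query block becomes a sorted workset padded to a fixed width, plus a count, rather than a dense block mask. The function name, the `-1` padding convention, and the shapes are assumptions.

```python
from typing import List, Tuple

def union_select(topk: List[List[int]], rows_per_q_block: int, max_width: int) -> Tuple[List[List[int]], List[int]]:
    """Union per-row top-k KV block picks into a padded per-query-block workset."""
    worksets: List[List[int]] = []
    counts: List[int] = []
    for start in range(0, len(topk), rows_per_q_block):
        rows = topk[start:start + rows_per_q_block]
        union = sorted({kv for row in rows for kv in row})
        if len(union) > max_width:
            # Refuse to truncate silently; an oversized union is a planner error.
            raise ValueError(f"query block {start // rows_per_q_block} needs {len(union)} > {max_width} slots")
        counts.append(len(union))
        worksets.append(union + [-1] * (max_width - len(union)))
    return worksets, counts

worksets, counts = union_select([[0, 2], [2, 1], [3, 0], [3, 3]], rows_per_q_block=2, max_width=3)
print(worksets, counts)  # [[0, 1, 2], [0, 3, -1]] [3, 2]
```

Raising on overflow instead of dropping blocks keeps the workset auditable: the final phase either gets every selected block or the run fails at the union stage.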

Causal windowing and hierarchical expansion are planner work for one shared reason: they should happen once in the planner, not reappear as per-tile branching inside the final phase.
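A sketch of both steps as one-time planner passes, under assumed conventions: coarse routing picks clusters of `fine_per_coarse` consecutive KV blocks, expansion maps each cluster to its fine blocks, and the causal window drops future blocks (and, optionally, blocks older than a local window). Query and KV block sizes are assumed equal, so KV block `kv` is causally legal for query block `q` when `kv <= q`. Function names are hypothetical.

```python
from typing import List, Optional

def expand_clusters(coarse_ids: List[int], fine_per_coarse: int) -> List[int]:
    """Refine coarse cluster picks into the sorted fine KV blocks they cover."""
    return [c * fine_per_coarse + i for c in sorted(set(coarse_ids)) for i in range(fine_per_coarse)]

def causal_window(q_block: int, kv_blocks: List[int], window_blocks: Optional[int] = None) -> List[int]:
    """Drop future KV blocks and, if a window is set, blocks too far in the past."""
    lo = 0 if window_blocks is None else q_block - window_blocks + 1
    return [kv for kv in kv_blocks if lo <= kv <= q_block]

fine = expand_clusters([2, 0, 3], fine_per_coarse=2)
print(fine)                                     # [0, 1, 4, 5, 6, 7]
print(causal_window(5, fine))                   # [0, 1, 4, 5]
print(causal_window(5, fine, window_blocks=3))  # [4, 5]
```

After these passes the kernel only sees legal blocks. The diagonal block (`kv == q`) still needs an element-level causal mask inside the kernel, but no tile needs a "should I run at all" branch.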

This is a better engineering story than saying "we implemented clustered sparse attention on TPU." It shows where the planner can fail and why the sparse lane needs more than one test surface.

Why Pallas and Splash both appear in the story

The point is not that Splash and Pallas are interchangeable. The point is that they own different parts of the TPU attention story. Splash is the stable path for more standard dense or local attention surfaces. Pallas matters when the mask, sparse plan, or softcap behavior needs lower-level control; the Pallas FA softcap note is the adjacent example for that narrower handoff.
MegaCpp's public examples keep that distinction visible instead of flattening everything into "the TPU backend."

That is also why clustered sparse examples belong beside backend-dispatch and bridge examples. The planner stages are only meaningful if the runtime can make a defensible backend choice after the planner has spoken.

Prior art and context

The broad ideas are not unique. The Pallas TPU docs and sparse TPU docs explain the kernel-language side, Splash kernel sources show the stable TPU attention lane, and MoBA and related sparse-attention work provide the broader block-sparse routing context.
MegaCpp's public contribution is the narrower planner view: examples that show how legality, routing, block expansion, and backend dispatch are kept as explicit stages around clustered sparse TPU attention.

FAQ

Frequently asked questions

Where should I look first if a clustered sparse TPU run is wrong but the final kernel still looks fine?
Start with Clustered sparse three-phase sample, Union selection query mask sample, Exact mask contract cache sample, and XLA backend dispatch sample. They expose legality, unioning, and backend choice before the final kernel becomes a black box.
What exactly does union selection hand to the final sparse stage?
It hands over a compact workset, not another dense mask. The union-selection sample reduces per-query block picks into a smaller payload the final phase can audit and consume.
Why is exact-mask caching part of the planner story?
Because the cache key should describe the static mask contract, not the current batch. The exact-mask sample keeps semantic knobs such as window shape and local-window mode in the key, then rebuilds runtime fields such as doc_ids and valid_token_counts only when the batch payload arrives.
Glossary

Terms used in this article

Start here for quick definitions, then follow the linked posts for deeper context.

trace_pallas

The native PyTorch/XLA custom-kernel lane that traces a Pallas kernel into a payload the XLA side can keep without crossing into a generic JAX bridge call.

call_jax

The Torch/XLA bridge lane that hands one narrowed TPU operation to JAX instead of moving the whole program into a JAX-owned frontend path.

Pallas

JAX's kernel language for writing explicit TPU kernels when stock XLA lowering is not enough for the required tile, memory-layout, or masking contract.

XLA

The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.

TPU

Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.

Splash

The stable TPU attention family used for dense or local-mask lanes before MegaCpp drops to narrower planner-driven sparse contracts.

JAX

A separate frontend above PJRT/libtpu. In these TPU posts it mainly matters as the owner of NamedSharding, PartitionSpec, and the optional call_jax or Pallas-adjacent bridge lanes.
