MegaCpp Engineering: Applied C++ model systems
Grounded engineering note from the MegaCpp stack
Published · 3 min read · David Gornshtein
TPU
XLA
Sparse Attention
Pallas
Long Context

Block-sparse attention on TPU v6e: block masks, MXU-friendly tiles, and stable contracts

How to frame block-sparse attention on TPU honestly: explicit mask contracts, MXU-aligned tile choices, and a preference for stable sparse layouts over data-dependent retracing.


Block-sparse attention matters on TPU because long context quickly makes dense score tensors too expensive. "Block-sparse" here means skipping whole score tiles, not single tokens, so the planner and kernel can still operate on regular shapes. The useful TPU story is not "rewrite everything." It is "keep the sparse contract explicit enough that the compiled path stays stable."

Why sparse attention matters on TPU

At long context, dense attention becomes a memory problem before it becomes anything else, because the score tensor grows quadratically with sequence length. Sparse layouts help only if they preserve a clean execution contract:

  • the block-selection logic must not trigger recompilation every step
  • the mask contract must be explicit enough to test
  • tile choices should match the hardware well enough to avoid falling back to a de facto dense path
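The following minimal sketch makes the memory pressure concrete. It sizes one head's score matrix as if the matrix were fully materialized, and compares that with the tiles a block-sparse layout would keep. The fp32 element size, the 256-token tile, and the 10% active-tile rate are illustrative assumptions. Fused attention kernels usually avoid materializing the full matrix, so treat these numbers as the scale a dense layout implies, not as measured usage.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)


def dense_score_bytes(seq_len: int, bytes_per_elem: int = 4) -> int:
    """Bytes for one head's full [seq_len, seq_len] score matrix (fp32 by default)."""
    return seq_len * seq_len * bytes_per_elem


def kept_tile_bytes(seq_len: int, tile: int = 256, active_percent: int = 10,
                    bytes_per_elem: int = 4) -> int:
    """Bytes for only the score tiles a block-sparse layout keeps.

    active_percent is an illustrative assumption, not a measured sparsity.
    """
    n_blocks = ceil_div(seq_len, tile)
    kept_tiles = n_blocks * n_blocks * active_percent // 100
    return kept_tiles * tile * tile * bytes_per_elem


GiB = 2**30
for n in (32 * 1024, 128 * 1024):
    print(f"{n:>7} tokens: dense {dense_score_bytes(n) / GiB:6.1f} GiB, "
          f"kept {kept_tile_bytes(n) / GiB:5.2f} GiB")
```

Doubling the context quadruples the dense figure. Under these assumptions, a fixed active rate only scales it down by a constant factor, so the contract still has to keep the kept-tile count bounded.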

For readers coming from GPU sparse-attention work, the TPU distinction is mostly about where the decision gets made and in what order the surviving blocks run. A GPU article can stay centered on one kernel family; a TPU article usually has to explain the planner, the compact mask metadata, and the compile boundary together. Clustered sparse on TPU: the planner stages and the Pallas grid shrinking sample are the shortest local-safe receipts for that difference.

On TPU v6e, matching the hardware usually starts with 256x256 score tiles, because v6e MXUs are 256x256 systolic arrays, whereas earlier TPU generations used 128x128 arrays. That does not make 256 tokens a semantic masking rule. It means smaller logical blocks need a lowering reason, because the physical kernel may still pay for a full 256x256 tile.
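There is a rough way to see why smaller logical blocks need a lowering reason. If a kernel pads each logical block up to whole MXU tiles, a 128-token block on v6e keeps only a quarter of each 256x256 tile busy. The padding model below is a deliberate simplification. A real lowering may pack several small blocks into one tile, and that packing is exactly the kind of lowering reason the text asks for.

```python
import math

# Physical tile edges quoted in the text: v6e MXUs are 256 x 256,
# earlier TPU generations used 128 x 128.
V6E_TILE = 256
EARLIER_TILE = 128


def padded_tile_utilization(logical_block: int, physical_tile: int) -> float:
    """Fraction of a padded square score tile that carries real work.

    Simplifying assumption: each logical block is padded up to whole
    physical tiles along both the query and key axes.
    """
    padded = math.ceil(logical_block / physical_tile) * physical_tile
    return (logical_block / padded) ** 2


for block in (128, 256, 384):
    print(block,
          padded_tile_utilization(block, V6E_TILE),
          padded_tile_utilization(block, EARLIER_TILE))
```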

What the sparse contract should contain

For TPU, the safest structure is a contract that separates selection from execution:

  1. choose candidate blocks
  2. classify which blocks are valid or need finer masking
  3. pass stable mask metadata into the compiled kernel path
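The three steps above can be sketched as separate pure functions, so that everything data-dependent happens before the compiled path sees it. This minimal sketch assumes a causal sliding window expressed in whole blocks. The names SparsePlan, choose_candidates, classify, and build_plan are hypothetical. With a token-granular window, the blocks at the window's far edge would also become partial.

```python
from typing import List, NamedTuple, Tuple

Block = Tuple[int, int]  # (q_block, kv_block)


class SparsePlan(NamedTuple):
    """Hypothetical frozen metadata handed to the compiled path."""
    full: Tuple[Block, ...]     # legal without a token mask
    partial: Tuple[Block, ...]  # legal only with token-level masking


def choose_candidates(num_blocks: int, window_blocks: int) -> List[Block]:
    """Stage 1: causal sliding window, measured in whole blocks."""
    return [(q, k)
            for q in range(num_blocks)
            for k in range(max(0, q - window_blocks + 1), q + 1)]


def classify(candidates: List[Block]) -> Tuple[List[Block], List[Block]]:
    """Stage 2: below-diagonal blocks are fully causal; diagonal blocks need a mask."""
    full = [(q, k) for q, k in candidates if k < q]
    partial = [(q, k) for q, k in candidates if k == q]
    return full, partial


def build_plan(num_blocks: int, window_blocks: int) -> SparsePlan:
    """Stage 3: freeze the result once, outside the hot loop."""
    full, partial = classify(choose_candidates(num_blocks, window_blocks))
    return SparsePlan(tuple(full), tuple(partial))


plan = build_plan(num_blocks=4, window_blocks=2)
print(plan.full)     # ((1, 0), (2, 1), (3, 2))
print(plan.partial)  # ((0, 0), (1, 1), (2, 2), (3, 3))
```

Because the plan is an immutable, hashable value, it can also serve as a cache key for whatever compiled lane consumes it.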

That separation matters because data-dependent decisions inside the hot loop are often what turn a sparse idea into a compile problem. The checked-in sparse receipts keep the static mask meaning outside the batch payload. They pass compact scalar-prefetch worklists such as data_next, block_mask, and grid_width into the compiled lane, with transposed variants for the backward pass, so one traced contract can be reused across calls instead of rebuilding legality inside each invocation. The Splash mask cache sample, the Clustered sparse forward-cache sample, and the Trace-Pallas scalar-prefetch sample are the local-safe receipts for that handoff.
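Here is a minimal sketch of that handoff. It assumes a per-tile classification has already been computed outside the hot loop. The names data_next, block_mask, and grid_width follow the text. The layout is an illustrative assumption, not the checked-in format: one padded row per query block, a 0/1/2 skip/partial/full encoding, and padding that repeats the last live index.

```python
from typing import List, Tuple

# Assumed encoding for this sketch: 0 = skip, 1 = partial, 2 = full.
SKIP, PARTIAL, FULL = 0, 1, 2


def build_worklists(block_classes: List[List[int]], max_width: int
                    ) -> Tuple[List[List[int]], List[List[int]], List[int]]:
    """Pack per-row live kv blocks into fixed-width tables.

    Every row of data_next and block_mask has exactly max_width entries, so the
    table shapes depend only on (num_q_blocks, max_width), never on the data.
    """
    data_next, block_mask, grid_width = [], [], []
    for q, row in enumerate(block_classes):
        live = [(k, c) for k, c in enumerate(row) if c != SKIP]
        if len(live) > max_width:
            raise ValueError(f"q block {q}: {len(live)} live blocks > max_width={max_width}")
        pad = max_width - len(live)
        # Padding repeats the last live index so a prefetcher sees no new block,
        # and is marked SKIP so the kernel does no work for it.
        last_k = live[-1][0] if live else 0
        data_next.append([k for k, _ in live] + [last_k] * pad)
        block_mask.append([c for _, c in live] + [SKIP] * pad)
        grid_width.append(len(live))
    return data_next, block_mask, grid_width


classes = [[PARTIAL, SKIP, SKIP],
           [FULL, PARTIAL, SKIP],
           [SKIP, FULL, PARTIAL]]
data_next, block_mask, grid_width = build_worklists(classes, max_width=2)
print(data_next)   # [[0, 0], [0, 1], [1, 2]]
print(block_mask)  # [[1, 0], [2, 1], [2, 1]]
print(grid_width)  # [1, 2, 2]
```

A different mask with the same block count and max_width yields tables of the same shape with different values. That is the property that lets one traced contract be reused.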

Why explicit mask semantics matter

A good block-sparse implementation makes the legality rules auditable. That usually means distinguishing at least:

  • blocks that are valid to attend to
  • blocks that are fully safe without a finer token mask
  • blocks that still need token-level cleanup

That is a better contract than relying on one implicit mask representation to carry all meaning.
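The three categories fall out of one token-level predicate evaluated over each tile. The sketch below assumes the legal region is "causal and same document" over per-token doc_ids, and its function names are hypothetical. In the example, a packed-document boundary turns a tile that would be full under plain causality into a partial one.

```python
from typing import List


def token_allowed(doc_ids: List[int], q: int, k: int) -> bool:
    """Token-level rule assumed here: causal and within the same packed document."""
    return k <= q and doc_ids[q] == doc_ids[k]


def classify_blocks(doc_ids: List[int], block: int) -> List[List[str]]:
    """Brute-force reference classifier: 'skip', 'full', or 'partial' per tile.

    O(n^2) in tokens, so it is a test oracle, not something to run per step.
    """
    n = len(doc_ids)
    num_blocks = -(-n // block)  # ceil division
    grid = []
    for qb in range(num_blocks):
        row = []
        for kb in range(num_blocks):
            allowed = [token_allowed(doc_ids, q, k)
                       for q in range(qb * block, min((qb + 1) * block, n))
                       for k in range(kb * block, min((kb + 1) * block, n))]
            if all(allowed):
                row.append("full")      # fully safe without a finer mask
            elif any(allowed):
                row.append("partial")   # valid, but needs token-level cleanup
            else:
                row.append("skip")      # nothing legal in this tile
        grid.append(row)
    return grid


# One document: the lower off-diagonal tile is full.
print(classify_blocks([0] * 8, block=4))
# [['partial', 'skip'], ['full', 'partial']]

# Two packed documents; the boundary at token 6 makes that tile partial.
print(classify_blocks([0, 0, 0, 0, 0, 0, 1, 1], block=4))
# [['partial', 'skip'], ['partial', 'partial']]
```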

TPU versus GPU mental model

The TPU path and the GPU path do not need to look identical. GPU stacks often lean on Triton-heavy mask and kernel surfaces. TPU paths are more likely to succeed when they keep the sparse contract explicit and the compiled kernel interface stable.

That means a TPU article should stay focused on TPU concerns:

  • static or quasi-static tile choices
  • stable sparse metadata
  • recompilation avoidance
  • correctness checks on mask construction

What should be tested

For sparse attention, contract tests matter more than slogans. A useful test surface checks:

  • block classification is correct for mixed causal and document boundaries
  • sparse metadata preserves the intended legal region
  • shorter-context cases still match a trusted reference path
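The check that sparse metadata preserves the intended legal region can be made concrete. Expand a block plan back to the token mask it actually grants, then compare that mask against a dense reference. This minimal sketch assumes a causal, same-document legal region, and the plan format and function names are hypothetical. It catches both failure directions: a plan that grants too much and a plan that drops legal keys.

```python
from typing import Dict, List, Tuple

Plan = Dict[Tuple[int, int], str]  # (q_block, kv_block) -> "full" | "partial"


def reference_mask(doc_ids: List[int]) -> List[List[bool]]:
    """Trusted dense path: causal and same-document, token by token."""
    n = len(doc_ids)
    return [[k <= q and doc_ids[q] == doc_ids[k] for k in range(n)] for q in range(n)]


def expand_plan(plan: Plan, doc_ids: List[int], block: int) -> List[List[bool]]:
    """Rebuild the token-level legal region a block plan actually grants."""
    n = len(doc_ids)
    mask = [[False] * n for _ in range(n)]  # tiles absent from the plan stay False
    for (qb, kb), kind in plan.items():
        for q in range(qb * block, min((qb + 1) * block, n)):
            for k in range(kb * block, min((kb + 1) * block, n)):
                if kind == "full":
                    mask[q][k] = True
                else:  # partial tiles re-apply the token predicate
                    mask[q][k] = k <= q and doc_ids[q] == doc_ids[k]
    return mask


def plan_preserves_legal_region(plan: Plan, doc_ids: List[int], block: int) -> bool:
    return expand_plan(plan, doc_ids, block) == reference_mask(doc_ids)


doc_ids = [0, 0, 0, 0, 0, 0, 1, 1]
good = {(0, 0): "partial", (1, 0): "partial", (1, 1): "partial"}
too_wide = {(0, 0): "partial", (1, 0): "full", (1, 1): "partial"}  # leaks across docs
too_narrow = {(0, 0): "partial", (1, 1): "partial"}                # drops legal keys
print([plan_preserves_legal_region(p, doc_ids, 4) for p in (good, too_wide, too_narrow)])
# [True, False, False]
```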

The public claim worth making

The safe public statement is simple: MegaCpp uses block-sparse TPU attention only where the sparse contract is explicit enough to test and stable enough to compile repeatedly. That is a narrower claim than "sparse attention is solved," and it is the more useful one.

FAQ

Frequently asked questions

When does the block-mask metadata itself become the problem?
Usually much later than readers expect. At 256x256 blocks with a 10% active-block rate, the Block-COO coordinates are only about 13 KB at 32k context, 52 KB at 64k, and 210 KB at 128k for one head. That is why the practical TPU failure seam is usually not "metadata is too big," but "the sparse contract stopped being static enough to reuse the compiled lane." The Trace-Pallas scalar-prefetch sample and the Pallas kernel selection notes are the local receipts for that distinction.
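The figures in this answer are easy to reproduce. The sketch below assumes each active tile stores two int32 coordinates and reports decimal kilobytes. Under those assumptions, it lands on the same 13, 52, and 210 KB numbers.

```python
def block_coo_bytes(seq_len: int, tile: int = 256, active_percent: int = 10,
                    bytes_per_coord: int = 4) -> int:
    """Bytes for Block-COO coordinates of one head's active score tiles.

    Assumes two int32 coordinates (q_block, kv_block) per active tile.
    """
    blocks_per_side = seq_len // tile
    active_tiles = blocks_per_side * blocks_per_side * active_percent // 100
    return active_tiles * 2 * bytes_per_coord


for k in (32, 64, 128):
    print(f"{k}k context: ~{round(block_coo_bytes(k * 1024) / 1000)} KB")
# 32k context: ~13 KB
# 64k context: ~52 KB
# 128k context: ~210 KB
```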
Can the active block count change without recompiling?
Yes, but only when the shape of the sparse contract stays fixed. Treat the active layout as values inside precomputed worklists, not as a new traceable program: cache the static mask semantics, rebuild per-batch token details at call time, and feed the compiled lane bounded tables such as block_mask, data_next, and grid_width. If the maximum grid width or mask semantics change, that is a different contract and should be tested as one. The Clustered sparse forward-cache sample and the Pallas grid shrinking sample are the local receipts for this.
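Below is a toy model of that rule, with a dictionary standing in for the compile cache. Only the static contract is a cache key; per-step values such as grid_width merely have to fit inside it. SparseContract, LaneCache, and their fields are hypothetical. In JAX, jit's cache plays the equivalent role, keyed on argument shapes, dtypes, and static arguments.

```python
from typing import List, NamedTuple


class SparseContract(NamedTuple):
    """Hypothetical static part of the contract; changing any field means a new program."""
    num_q_blocks: int
    max_grid_width: int
    mask_semantics: str  # an identifier such as "causal+doc", not per-batch values


class LaneCache:
    """Stand-in for a compile cache keyed only by the static contract."""

    def __init__(self) -> None:
        self.lanes = {}
        self.compiles = 0

    def run(self, contract: SparseContract, grid_width: List[int]) -> str:
        # Per-step values must fit inside the static contract.
        if len(grid_width) != contract.num_q_blocks:
            raise ValueError("grid_width must have one entry per q block")
        if max(grid_width, default=0) > contract.max_grid_width:
            raise ValueError("a row exceeds max_grid_width: that is a new contract")
        if contract not in self.lanes:
            self.compiles += 1  # only an unseen contract triggers a "compile"
            self.lanes[contract] = f"lane for {contract}"
        return self.lanes[contract]


cache = LaneCache()
contract = SparseContract(num_q_blocks=4, max_grid_width=3, mask_semantics="causal+doc")
cache.run(contract, [1, 2, 3, 3])   # first call compiles
cache.run(contract, [1, 1, 2, 2])   # different active layout, same contract: reused
print(cache.compiles)  # 1
cache.run(contract._replace(max_grid_width=4), [1, 2, 3, 4])
print(cache.compiles)  # 2
```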
Do 256x256 tiles force document boundaries to align to 256 tokens?
No. The tile is an execution unit, not a semantic boundary. If a causal diagonal or packed-document edge cuts through a kept tile, the contract should skip empty tiles first and then carry focused partial-mask work for the surviving tiles. The Pallas grid shrinking sample keeps partial_mask_blocks beside block_mask and data_next, which is the shape worth testing.
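Here is a minimal sketch of the partial-mask side. Each surviving partial tile gets a local token mask, and identical masks are stored once, so a causal diagonal repeated across the sequence costs a single entry. The causal-plus-document predicate, the dedup scheme, and the function names are assumptions for illustration. partial_mask_blocks borrows its name from the text, not its exact layout.

```python
from typing import Dict, List, Tuple

TileMask = Tuple[Tuple[bool, ...], ...]


def tile_mask(doc_ids: List[int], qb: int, kb: int, block: int) -> TileMask:
    """Token-level mask inside one tile (assumes len(doc_ids) % block == 0)."""
    return tuple(
        tuple(k <= q and doc_ids[q] == doc_ids[k]
              for k in range(kb * block, (kb + 1) * block))
        for q in range(qb * block, (qb + 1) * block))


def dedupe_partial_masks(doc_ids: List[int], partial_tiles: List[Tuple[int, int]],
                         block: int) -> Tuple[List[TileMask], Dict[Tuple[int, int], int]]:
    """Store each distinct partial mask once and point tiles at it."""
    partial_mask_blocks: List[TileMask] = []
    seen: Dict[TileMask, int] = {}
    mask_index: Dict[Tuple[int, int], int] = {}
    for tile in partial_tiles:
        m = tile_mask(doc_ids, *tile, block)
        if m not in seen:
            seen[m] = len(partial_mask_blocks)
            partial_mask_blocks.append(m)
        mask_index[tile] = seen[m]
    return partial_mask_blocks, mask_index


# One document: every diagonal tile shares the same causal triangle.
masks, index = dedupe_partial_masks([0] * 8, [(0, 0), (1, 1)], block=4)
print(len(masks), index)  # 1 {(0, 0): 0, (1, 1): 0}

# A document boundary at token 6 cuts through kept tiles: three distinct masks.
masks, index = dedupe_partial_masks([0, 0, 0, 0, 0, 0, 1, 1],
                                    [(0, 0), (1, 0), (1, 1)], block=4)
print(len(masks), index)  # 3 {(0, 0): 0, (1, 0): 1, (1, 1): 2}
```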
Glossary

Terms used in this article

Start here for quick definitions, then follow the linked posts for deeper context.

Pallas

JAX's kernel language for writing explicit TPU kernels when stock XLA lowering is not enough for the required tile, memory-layout, or masking contract.

Splash

The stable TPU attention family used for dense or local-mask lanes before MegaCpp drops to narrower planner-driven sparse contracts.

doc_ids

The fixed-width per-token document identifiers that keep packed rows auditable and let TPU masking respect document boundaries.

XLA

The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.

TPU

Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.
