Block-sparse attention on TPU v6e: block masks, MXU-friendly tiles, and stable contracts
How to frame block-sparse attention on TPU honestly: explicit mask contracts, MXU-aligned tile choices, and a preference for stable sparse layouts over data-dependent retracing.

Block-sparse attention matters on TPU because long context quickly makes dense score tensors too expensive. "Block-sparse" here means skipping whole score tiles, not single tokens, so the planner and kernel can still operate on regular shapes. The useful TPU story is not "rewrite everything." It is "keep the sparse contract explicit enough that the compiled path stays stable."
Why sparse attention matters on TPU
At long context, dense attention becomes a memory problem before it becomes anything else. Sparse layouts help only if they preserve a clean execution contract:
- the block-selection logic must not trigger recompilation every step
- the mask contract must be explicit enough to test
- tile choices should match the hardware well enough to avoid falling back to a de facto dense path
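To put rough numbers on the memory pressure, the sketch below counts 256x256 score tiles per head for a dense, a causal, and a causal sliding-window layout. The function name, the `window_tiles` parameter, and the 131072-token sequence length are illustrative assumptions, not values taken from the MegaCpp code.

```python
def score_tile_counts(seq_len: int, tile: int = 256, window_tiles: int = 4) -> dict:
    """Count score tiles per head for three layouts (illustrative only)."""
    if seq_len % tile:
        raise ValueError("this sketch assumes seq_len is a multiple of tile")
    n = seq_len // tile            # query (and key) blocks per side
    dense = n * n                  # every (query block, key block) pair
    causal = n * (n + 1) // 2      # lower triangle including the diagonal
    # Hypothetical causal sliding window: each query-block row keeps at most
    # `window_tiles` key blocks, ending at the diagonal.
    windowed = sum(min(row + 1, window_tiles) for row in range(n))
    return {"dense": dense, "causal": causal, "windowed": windowed}


counts = score_tile_counts(seq_len=131072)
print(counts)
```

For 131072 tokens this gives 262144 dense tiles, 131328 causal tiles, and 2042 tiles for the assumed four-block window. The saving only materializes if the kernel really skips the missing tiles, which is what the conditions listed above are meant to protect.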
For readers coming from GPU sparse-attention work, the TPU distinction is mostly about where the decision gets made and in what order the surviving blocks run. A GPU article can stay centered on one kernel family; a TPU article usually has to explain the planner, the compact mask metadata, and the compile boundary together. "Clustered sparse on TPU: the planner stages" and "Pallas grid shrinking sample" are the shortest local-safe receipts for that difference.
On TPU v6e, that hardware match usually starts with 256x256 score tiles because v6e MXUs are 256x256 systolic arrays, whereas earlier TPU generations used 128x128 arrays. That does not make 256 tokens a semantic masking rule; it means smaller logical blocks need a lowering reason, because the physical kernel may still pay for a larger tile.
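A small arithmetic sketch makes the "may still pay for a larger tile" point concrete. It assumes, as a simplification, that the kernel rounds each logical block up to a whole 256x256 tile in both dimensions; real lowerings can pack or split tiles differently, so treat `padded_tile_cost` as a hypothetical helper rather than a model of the compiler.

```python
import math

MXU_TILE = 256  # v6e MXU dimension, as stated in the text


def padded_tile_cost(logical_block: int, tile: int = MXU_TILE) -> dict:
    """Compare the score entries a logical block needs with what a padded tile pays for."""
    physical = math.ceil(logical_block / tile) * tile  # round up to whole tiles
    useful = logical_block * logical_block             # entries the mask cares about
    paid = physical * physical                         # entries the assumed kernel computes
    return {"physical_block": physical, "useful_fraction": useful / paid}


for block in (64, 128, 256, 512):
    print(block, padded_tile_cost(block))
```

Under that assumption a 128-token logical block uses a quarter of the tile it occupies and a 64-token block uses one sixteenth, which is why a smaller block size needs a concrete lowering reason.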
What the sparse contract should contain
For TPU, the safest structure is a contract that separates selection from execution:
- choose candidate blocks
- classify which blocks are valid or need finer masking
- pass stable mask metadata into the compiled kernel path
That separation matters because data-dependent decisions inside the hot loop are often what turn a sparse idea into a compile problem. The checked-in sparse receipts keep the static mask meaning outside the batch payload and pass compact scalar-prefetch worklists such as data_next, block_mask, and grid_width into the compiled lane, with transposed backward variants for the reverse path, so one traced contract can be reused across calls instead of rebuilding legality inside each invocation. Splash mask cache sample, Clustered sparse forward-cache sample, and Trace-Pallas scalar-prefetch sample are the local-safe receipts for that handoff.
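One way to picture that handoff is a worklist whose shape never depends on how many blocks survive. The sketch below borrows the names `block_mask`, `data_next`, and `grid_width` from the text, but the field semantics are an assumption made for illustration: the checked-in layout may differ. `build_worklist` and `shape_signature` are hypothetical helpers.

```python
def build_worklist(block_mask, grid_width):
    """Compact each block-mask row into a fixed-width list of key-block indices.

    Every row of the returned data_next has exactly grid_width entries, so the
    metadata shape is the same for any mask that fits the contract.
    """
    data_next, counts = [], []
    for row in block_mask:
        live = [kb for kb, keep in enumerate(row) if keep]
        if len(live) > grid_width:
            # More live blocks than the compiled grid allows is a new contract.
            raise ValueError("live blocks exceed grid_width")
        pad = live[-1] if live else 0  # repeat a valid index so padding stays in range
        data_next.append(live + [pad] * (grid_width - len(live)))
        counts.append(len(live))       # tells the kernel how many slots are real
    return data_next, counts


def shape_signature(data_next, counts):
    """The shapes a tracing compiler would key on (assumed, for illustration)."""
    return (len(data_next), len(data_next[0]), len(counts))


mask_a = [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [0, 0, 0, 1]]
mask_b = [[1, 0, 0, 0], [1, 1, 0, 0], [0, 1, 1, 0], [0, 0, 1, 1]]
next_a, counts_a = build_worklist(mask_a, grid_width=3)
next_b, counts_b = build_worklist(mask_b, grid_width=3)
print(next_a, counts_a)
print(shape_signature(next_a, counts_a) == shape_signature(next_b, counts_b))
```

The two masks select different blocks, yet both produce metadata of the same shape, so only the values change between calls. A mask that needs more than `grid_width` live blocks per row raises instead of silently changing shape.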
Why explicit mask semantics matter
A good block-sparse implementation makes the legality rules auditable. That usually means distinguishing at least:
- blocks that are valid to attend to
- blocks that are fully safe without a finer token mask
- blocks that still need token-level cleanup
That is a better contract than relying on one implicit mask representation to carry all meaning.
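The three-way distinction above can be sketched as a small classifier. The legality rule used here (causal attention restricted to one document, driven by per-token document IDs) is an assumption chosen for illustration, and `BlockKind`, `token_allowed`, and `classify_blocks` are hypothetical names. The brute-force pair count is meant for auditing small cases, not for production planning.

```python
from enum import IntEnum


class BlockKind(IntEnum):
    EMPTY = 0    # no legal (query, key) pair: the tile can be skipped
    PARTIAL = 1  # some legal pairs: the tile still needs a token-level mask
    FULL = 2     # every pair legal: the tile needs no token mask


def token_allowed(q: int, k: int, doc_ids: list) -> bool:
    """Assumed legality rule: causal attention within one document."""
    return k <= q and doc_ids[q] == doc_ids[k]


def classify_blocks(doc_ids: list, block: int) -> list:
    """Return an n x n grid of BlockKind (rows: query blocks, columns: key blocks)."""
    if len(doc_ids) % block:
        raise ValueError("this sketch assumes the sequence is a multiple of block")
    n = len(doc_ids) // block
    grid = []
    for qb in range(n):
        row = []
        for kb in range(n):
            legal = sum(
                token_allowed(q, k, doc_ids)
                for q in range(qb * block, (qb + 1) * block)
                for k in range(kb * block, (kb + 1) * block)
            )
            if legal == 0:
                row.append(BlockKind.EMPTY)
            elif legal == block * block:
                row.append(BlockKind.FULL)
            else:
                row.append(BlockKind.PARTIAL)
        grid.append(row)
    return grid


# Two packed documents; the boundary sits inside the third block of size 2.
grid = classify_blocks([0, 0, 0, 0, 0, 1, 1, 1], block=2)
for row in grid:
    print([kind.name for kind in row])
```

In this example only one tile is FULL. Every diagonal tile is PARTIAL because of causality, and the tiles touching the document boundary are PARTIAL because the boundary does not align to the block size.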
TPU versus GPU mental model
The TPU path and the GPU path do not need to look identical. GPU stacks often lean on Triton-heavy mask and kernel surfaces. TPU paths are more likely to succeed when they keep the sparse contract explicit and the compiled kernel interface stable.
That means a TPU article should stay focused on TPU concerns:
- static or quasi-static tile choices
- stable sparse metadata
- recompilation avoidance
- correctness checks on mask construction
What should be tested
For sparse attention, contract tests matter more than slogans. A useful test surface checks:
- block classification is correct for mixed causal and document boundaries
- sparse metadata preserves the intended legal region
- lower-context cases still match a trusted reference path
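A minimal sketch of such a contract test, in plain Python so it runs anywhere: a block-driven path (skip empty blocks, trust full blocks, mask partial ones) is compared with a dense masked reference on a short packed sequence. The mask rule, the helper names, and the tolerances are assumptions for illustration. A real test would call the compiled TPU path in place of `block_sparse`.

```python
import math
import random


def allowed(q, k, doc_ids):
    """Assumed token rule: causal, and query and key share a document."""
    return k <= q and doc_ids[q] == doc_ids[k]


def span(b, block):
    return range(b * block, (b + 1) * block)


def attend(qv, keys, K, V):
    """Softmax attention for one query vector over an explicit key list."""
    scores = [sum(a * b for a, b in zip(qv, K[k])) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * V[k][d] for w, k in zip(weights, keys)) / total
            for d in range(len(V[0]))]


def dense_reference(Q, K, V, doc_ids):
    """Trusted path: token mask applied to every key, no block logic."""
    n = len(Q)
    return [attend(Q[q], [k for k in range(n) if allowed(q, k, doc_ids)], K, V)
            for q in range(n)]


def classify(qb, kb, doc_ids, block):
    pairs = [allowed(q, k, doc_ids) for q in span(qb, block) for k in span(kb, block)]
    if not any(pairs):
        return "empty"
    return "full" if all(pairs) else "partial"


def block_sparse(Q, K, V, doc_ids, block, classify=classify):
    """Path under test: skip empty blocks, trust full blocks, mask partial ones."""
    n = len(Q) // block
    out = []
    for qb in range(n):
        kinds = [classify(qb, kb, doc_ids, block) for kb in range(n)]
        for q in span(qb, block):
            keys = [k for kb, kind in enumerate(kinds) if kind != "empty"
                    for k in span(kb, block)
                    if kind == "full" or allowed(q, k, doc_ids)]
            out.append(attend(Q[q], keys, K, V))
    return out


def max_abs_diff(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


def sloppy(qb, kb, doc_ids, block):
    """Deliberately wrong classifier: promotes every partial block to full."""
    kind = classify(qb, kb, doc_ids, block)
    return "full" if kind == "partial" else kind


rng = random.Random(0)
doc_ids = [0, 0, 0, 0, 0, 1, 1, 1]  # the document boundary falls inside a block
Q = [[rng.gauss(0, 1) for _ in range(4)] for _ in doc_ids]
K = [[rng.gauss(0, 1) for _ in range(4)] for _ in doc_ids]
V = [[rng.gauss(0, 1) for _ in range(4)] for _ in doc_ids]

reference = dense_reference(Q, K, V, doc_ids)
good_diff = max_abs_diff(block_sparse(Q, K, V, doc_ids, block=2), reference)
bad_diff = max_abs_diff(block_sparse(Q, K, V, doc_ids, 2, classify=sloppy), reference)
assert good_diff < 1e-9   # correct classification matches the reference
assert bad_diff > 1e-6    # a misclassified block must be caught
```

The negative case is the part worth keeping: a test that only shows agreement on the happy path says little about whether a misclassified block would be detected.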
The public claim worth making
The safe public statement is simple: MegaCpp uses block-sparse TPU attention only where the sparse contract is explicit enough to test and stable enough to compile repeatedly. That is a narrower claim than "sparse attention is solved," and it is the more useful one.
Frequently asked questions
When does the block-mask metadata itself become the problem?
Can the active block count change without recompiling?
... block_mask, data_next, and grid_width. If the maximum grid width or mask semantics change, that is a different contract and should be tested as one. See the Clustered sparse forward-cache sample and the Pallas grid shrinking sample.
Do 256x256 tiles force document boundaries to align to 256 tokens?
... partial_mask_blocks beside block_mask and data_next, which is the shape worth testing. See the Pallas grid shrinking sample.
Terms used in this article
Start here for quick definitions, then follow the linked posts for deeper context.
Pallas: JAX's kernel language for writing explicit TPU kernels when stock XLA lowering is not enough for the required tile, memory-layout, or masking contract.
The stable TPU attention family used for dense or local-mask lanes before MegaCpp drops to narrower planner-driven sparse contracts.
The fixed-width per-token document identifiers that keep packed rows auditable and let TPU masking respect document boundaries.
XLA: The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.
TPU: Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.