Mamba-3 fused trapezoidal scan on TPU v6e
How we took the Mamba-3 trapezoidal SSM update from a CUDA Triton kernel to a Pallas/XLA-friendly scan on TPU v6e, and what survived the deployment port.

MegaCpp uses Mamba-style blocks because they let a small model spend part of its budget on sequence mixing that is not just another attention layer. On TPU, that only works if the state-space update stays inside a compile-stable path. That compile-stability requirement is the TPU-side counterpart of the runtime ownership story in libtpu, PJRT, JAX, and ownership boundaries.
This article is about that TPU constraint, not about claiming a universal "best" kernel.
For first-touch readers, four terms matter here. Pallas is JAX's lower-level kernel language for custom TPU kernels, not a generic name for every XLA path. A trapezoidal scan here means a Mamba-style state update whose prologue, scan body, and epilogue are fused so the recurrent path does not spill avoidable temporary tensors between steps. doc_ids are the per-token labels that preserve which packed source document each token came from.
segment_ids are the kernel-facing boundary labels derived from those document IDs so a scan or attention helper can reset state or mask cross-document pairs without rebuilding a dense mask every step. The smallest checked-in public-safe proof surfaces are document-mask segment ID sample, trace_pallas scalar-prefetch sample, Pallas softcap attention sample, call_jax bridge runtime, and Mamba-3 porting note.
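As a minimal sketch of the doc_ids to segment_ids conversion described above, the following plain-Python function maps arbitrary per-token document IDs to dense, row-local segment labels of the same width. The padding marker `PAD_DOC_ID`, the function name, and the "0 means padding, segments count from 1" convention are assumptions made for illustration; they are not taken from the checked-in sample.

```python
from typing import List, Optional

PAD_DOC_ID = -1  # hypothetical padding marker, not from the article


def doc_ids_to_segment_ids(doc_ids: List[int]) -> List[int]:
    """Map per-token document IDs to dense, row-local segment labels.

    The output has the same length as the input, so the kernel-facing
    shape never depends on how many documents were packed into the row.
    Padding positions get segment 0; real segments count up from 1.
    """
    segment_ids: List[int] = []
    current = 0
    previous: Optional[int] = None
    for doc_id in doc_ids:
        if doc_id == PAD_DOC_ID:
            segment_ids.append(0)
            previous = None
            continue
        if doc_id != previous:
            # A change in document ID opens a new segment.
            current += 1
            previous = doc_id
        segment_ids.append(current)
    return segment_ids


row = [907, 907, 907, 12, 12, 555, PAD_DOC_ID, PAD_DOC_ID]
print(doc_ids_to_segment_ids(row))  # prints [1, 1, 1, 2, 2, 3, 0, 0]
```

The point of the dense relabeling is that a downstream scan or mask helper only has to compare neighbouring labels; it never needs the original document IDs or a dense pairwise mask.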
The engineering problem
The Mamba-3 paper is public. So are the PyTorch/XLA runtime docs and JAX's TPU Pallas docs. What those sources collectively support is a narrow claim:
- Mamba-3 is a real model family with a state-space core
- PyTorch/XLA rewards static-shape-friendly compiled execution
- Pallas can target TPU, but it is documented as experimental
MegaCpp's TPU porting story sits inside that triangle. The real work is not proving that one exact kernel is universally optimal. The real work is keeping the scan, the boundary metadata, and the surrounding elementwise work in a path that does not recompile or materialize avoidable temporary tensors.
Why the update shape matters
Once a state-space block depends on adjacent timestep information, a naive implementation quickly becomes launch-heavy. The high-level problem is easy to state:
- the scan itself is not the only cost
- prologue and epilogue work can dominate if they are split into many tiny ops
- TPU performance suffers if the path becomes shape-dynamic or mask-heavy
MegaCpp's TPU discipline is therefore:
- keep the scan body static-shape-friendly
- materialize boundary metadata explicitly
- let XLA fuse surrounding elementwise work by default
- only introduce a custom TPU kernel when it removes a real hot-path cost
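To make "depends on adjacent timestep information" concrete, here is a small sketch of a scalar recurrence in which each step reads both the current and the previous input. The coefficient names `a`, `b`, `c` and the exact form of the update are assumptions chosen for illustration; they are not the Mamba-3 paper's parameterization and not MegaCpp's code. The step has the `(carry, input) -> (carry, output)` shape that a compiled scan primitive such as `jax.lax.scan` expects, and the loop length is fixed by the input length rather than by any data-dependent condition.

```python
from typing import List, Tuple

Carry = Tuple[float, float]                # (previous state, previous input)
StepInput = Tuple[float, float, float, float]  # (a_t, b_t, c_t, u_t)


def scan_step(carry: Carry, inputs: StepInput) -> Tuple[Carry, float]:
    """One step: h_t = a_t * h_prev + b_t * u_prev + c_t * u_t."""
    h_prev, u_prev = carry
    a_t, b_t, c_t, u_t = inputs
    h_t = a_t * h_prev + b_t * u_prev + c_t * u_t
    # The new carry holds the current input so the next step can read it.
    return (h_t, u_t), h_t


def run_scan(a: List[float], b: List[float], c: List[float],
             u: List[float]) -> List[float]:
    """Run the recurrence over a fixed-length sequence from a zero carry."""
    if not (len(a) == len(b) == len(c) == len(u)):
        raise ValueError("all per-step inputs must share one length")
    carry: Carry = (0.0, 0.0)
    outputs: List[float] = []
    for inputs in zip(a, b, c, u):
        carry, h_t = scan_step(carry, inputs)
        outputs.append(h_t)
    return outputs


steps = 4
print(run_scan([0.5] * steps, [0.25] * steps, [0.25] * steps,
               [1.0, 2.0, 3.0, 4.0]))
# prints [0.25, 0.875, 1.6875, 2.59375]
```

Keeping the previous input in the carry, instead of shifting the input tensor in a separate op, is one way to keep the adjacent-timestep dependency inside the scan body.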
Why boundary metadata belongs inside the compiled path
Packed training data made this non-optional. If document boundaries are part of the training contract, the state-space path cannot treat them as an afterthought. MegaCpp therefore keeps sequence or segment identifiers explicit instead of relying on hidden implicit state.
That choice matters for two reasons:
- it makes the compile-time shape story easier to reason about
- it keeps document-boundary semantics aligned with the rest of the long-context pipeline
Publicly, the useful claim is not "our exact implementation uses variable X."
The useful claim is that sequence-boundary metadata must travel with the
compiled path if you want long-context training to stay correct. In practice
that usually means doc_ids upstream and segment_ids at the kernel-facing
boundary, as in document-mask segment ID sample and Pallas softcap attention sample. That is the
same data boundary defended in dataloader throughput and stalls
and tokenized enriched pipeline on TPU.
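A sketch of what "boundary metadata travels with the compiled path" can look like in a recurrent update: the segment labels enter the step as ordinary per-token inputs, and a multiplicative keep-mask zeroes whatever would otherwise be carried across a document boundary. The recurrence form, the coefficient names, and the function name are illustrative assumptions, not the project's implementation.

```python
from typing import List, Optional


def packed_scan(a: List[float], b: List[float], c: List[float],
                u: List[float], segment_ids: List[int]) -> List[float]:
    """Recurrence h_t = a_t*h_prev + b_t*u_prev + c_t*u_t with boundary resets.

    At the first token of every segment the carried state and the carried
    previous input are multiplied by 0.0, so nothing leaks across documents.
    The reset is a mask, not a branch that changes the work per step.
    """
    h_prev, u_prev = 0.0, 0.0
    seg_prev: Optional[int] = None
    outputs: List[float] = []
    for a_t, b_t, c_t, u_t, seg in zip(a, b, c, u, segment_ids):
        keep = 1.0 if seg == seg_prev else 0.0
        h_t = a_t * (keep * h_prev) + b_t * (keep * u_prev) + c_t * u_t
        outputs.append(h_t)
        h_prev, u_prev, seg_prev = h_t, u_t, seg
    return outputs


coef_a, coef_b, coef_c = [0.5] * 4, [0.25] * 4, [0.25] * 4
packed = packed_scan(coef_a, coef_b, coef_c, [1.0, 2.0, 3.0, 4.0], [1, 1, 2, 2])

# Running each document on its own should give the same per-token outputs.
first = packed_scan(coef_a[:2], coef_b[:2], coef_c[:2], [1.0, 2.0], [1, 1])
second = packed_scan(coef_a[:2], coef_b[:2], coef_c[:2], [3.0, 4.0], [1, 1])
print(packed == first + second)  # prints True
```

The equality check at the end is the correctness property that matters for packed training: a packed row must behave as if each document had been processed alone.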
Where XLA is enough and where Pallas is considered
The default MegaCpp TPU stance is conservative: plain XLA fusion around the compiled scan is the baseline. Pallas is considered only when a custom tiled kernel clearly removes extra passes or makes boundary handling materially cleaner.
That is a narrower and more defensible policy than saying "we rewrote the whole thing in Pallas." In practice, many TPU performance problems are better solved by keeping the path static and letting XLA do its job than by taking ownership of another kernel surface. The framework-side version of that same caution is the frontier discipline in The Torch 2.12 journey, and the data-contract side is the same one discussed in long context and attention sinks: if boundaries and masks are unstable upstream, another custom kernel is usually the wrong first fix. The checked-in proof split here is useful: trace_pallas scalar-prefetch sample shows the native custom-kernel lane, while call_jax bridge runtime shows the narrower bridge MegaCpp tries not to make its default for this class of update.
A useful way to read those checked-in TPU surfaces is by ownership boundary.
trace_pallas is the native custom-kernel lane: forward and backward stay on
one traced TPU path, with segment IDs and other sparse metadata passed as
runtime tensors in a fixed argument order. call_jax is the narrower bridge
lane, kept isolated so the model can stay in PyTorch while a TPU-native kernel
is called only where the default lowering is not good enough. That split keeps
bridge scope explicit and makes fallback provenance legible when the TPU-native
path is unavailable.
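One way to picture "explicit bridge scope with legible fallback provenance" is a tiny dispatcher that returns both the chosen lane and a human-readable record of why it was chosen. Everything in this sketch is hypothetical: the `Lane` type, the `select_lane` function, the `"xla_fusion"` baseline label, and the stand-in update function are illustration only, and the lane names are used purely as string labels rather than as calls into any real PyTorch/XLA or JAX API.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass(frozen=True)
class Lane:
    name: str
    available: bool
    run: Callable[[List[float]], List[float]]


def stand_in_update(xs: List[float]) -> List[float]:
    # Placeholder for the real state update; every lane must agree on it.
    return [2.0 * x for x in xs]


def select_lane(lanes: Dict[str, Lane], requested: str,
                baseline: str = "xla_fusion") -> Tuple[Lane, str]:
    """Return the requested lane if usable, else the baseline plus provenance."""
    lane = lanes.get(requested)
    if lane is not None and lane.available:
        return lane, f"{requested} (requested)"
    return lanes[baseline], f"{baseline} (fallback from {requested})"


lanes = {
    "xla_fusion": Lane("xla_fusion", True, stand_in_update),
    "trace_pallas": Lane("trace_pallas", False, stand_in_update),
    "call_jax": Lane("call_jax", True, stand_in_update),
}

lane, provenance = select_lane(lanes, "trace_pallas")
print(provenance)  # prints xla_fusion (fallback from trace_pallas)
```

Logging the provenance string next to each run makes it possible to tell, after the fact, whether a throughput number came from the custom-kernel lane or from the baseline.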
The rules that survived the port
The TPU-friendly rules are:
- keep chunking or scan-shape choices explicit and stable
- keep boundary identifiers materialized instead of implicit
- avoid dynamic per-step kernel configuration
- prefer one compiled path over a Python bridge plus extra runtime glue
- treat custom TPU kernels as opt-in, not as the default answer
These rules follow directly from the official TPU runtime and Pallas documentation, even though the exact MegaCpp code path is project-specific.
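The first rule, stable scan shapes, usually comes down to padding every row to one fixed width and carrying the real length as data. The sketch below shows that idea in plain Python; `SEQ_LEN`, `PAD_TOKEN`, and `pad_row` are hypothetical names and values chosen for the example.

```python
from typing import List, Tuple

SEQ_LEN = 8    # hypothetical fixed row width
PAD_TOKEN = 0  # hypothetical padding token


def pad_row(tokens: List[int], seq_len: int = SEQ_LEN) -> Tuple[List[int], int]:
    """Pad a row to the fixed width and return it with its valid-token count.

    The padded row always has length `seq_len`, so the shape seen by the
    compiled path is constant; the valid count is a runtime value.
    """
    if len(tokens) > seq_len:
        raise ValueError("row is longer than the fixed scan width")
    valid = len(tokens)
    return list(tokens) + [PAD_TOKEN] * (seq_len - valid), valid


short_row, short_valid = pad_row([5, 6, 7])
long_row, long_valid = pad_row([1, 2, 3, 4, 5, 6])
print(len(short_row), short_valid)  # prints 8 3
print(len(long_row), long_valid)    # prints 8 6
```

Rejecting over-long rows outright, rather than silently growing the width, keeps the shape decision in one explicit place.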
The same discipline applies to compile keys. Scan shape, tiling choice, and
mask contract belong in the static compiled surface; per-batch boundary
payloads such as doc_ids, segment_ids, or valid-token counts belong in
runtime tensors. Mixing those layers is how a useful boundary feature turns
into a retrace problem.
A checked-in local sample makes that split concrete. exact-mask contract cache sample keys the cached closure only by static window and mask semantics, then rebuilds per-batch payloads at runtime. That is the same graph-contract discipline described more broadly in graph recompilation hell: if batch-specific boundary tensors leak into the compiled key, TPU throughput degrades before the run necessarily fails outright.
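The compile-key split can be sketched with a cache keyed only by a frozen, hashable description of the static contract, while per-batch payloads are passed to the cached closure at call time. This is an illustration under assumptions, not the checked-in sample: `ScanContract`, `build_scan`, and `COMPILE_LOG` are hypothetical names, and appending to a list stands in for an expensive trace or compile.

```python
from dataclasses import dataclass
from functools import lru_cache
from typing import Callable, List

COMPILE_LOG: List["ScanContract"] = []  # one entry per simulated compile


@dataclass(frozen=True)
class ScanContract:
    """Static surface only: shapes and mask semantics, never batch data."""
    seq_len: int
    chunk_len: int
    mask_kind: str


@lru_cache(maxsize=None)
def build_scan(contract: ScanContract) -> Callable[[List[int]], int]:
    COMPILE_LOG.append(contract)  # stands in for tracing and compiling

    def run(segment_ids: List[int]) -> int:
        # Per-batch payload arrives here, at runtime, and must fit the shape.
        if len(segment_ids) != contract.seq_len:
            raise ValueError("payload does not match the compiled shape")
        # Toy runtime work: count document boundaries inside the row.
        return sum(1 for i in range(1, len(segment_ids))
                   if segment_ids[i] != segment_ids[i - 1])

    return run


contract = ScanContract(seq_len=4, chunk_len=2, mask_kind="document")
for batch in ([1, 1, 2, 2], [1, 2, 3, 3], [1, 1, 1, 1]):
    build_scan(contract)(batch)
print(len(COMPILE_LOG))  # prints 1
```

Three batches with different boundary layouts reuse one cached closure. If the segment labels were folded into the key instead, each batch would produce a new entry in `COMPILE_LOG`, which is the retrace pattern the paragraph above warns about.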
What we avoid claiming
MegaCpp does not use this article to claim:
- that TPU is the canonical home of Mamba-3
- that Pallas is the settled production path for every TPU kernel
- that one exact fused update is the only correct implementation
Those claims would go beyond the public sources. The safer statement is that MegaCpp adapted a Mamba-style scan to TPU by favoring compile stability, explicit boundary metadata, and narrow use of custom kernels.
Frequently asked questions
Why keep boundary metadata inside the compiled path?
doc_ids to segment_ids.
What are doc_ids and segment_ids in this article?
doc_ids preserve the original packed-document provenance token by token. segment_ids are the boundary labels a TPU kernel or scan wrapper actually consumes after packing. document-mask segment ID sample is the smallest checked-in explanation of that conversion.
When is Pallas justified instead of plain XLA fusion?
Why not make call_jax the default TPU path?
Does this article claim one universally best TPU kernel?
Where is the public-safe proof surface if the live runtime path is broader?
Which claims are external literature versus MegaCpp receipts?
What should the first TPU receipt prove before reaching for Pallas?
What belongs in the compile key versus runtime boundary tensors?
doc_ids, segment_ids, and valid-token counts should stay runtime tensors. exact-mask contract cache sample is the smallest checked-in example of that split, and graph recompilation hell is the broader TPU failure mode when the graph contract drifts.
Terms used in this article
Start here for quick definitions, then follow the linked posts for deeper context.
Pallas: JAX's kernel language for writing explicit TPU kernels when stock XLA lowering is not enough for the required tile, memory-layout, or masking contract.
trace_pallas: The native PyTorch/XLA custom-kernel lane that traces a Pallas kernel into a payload the XLA side can keep without crossing into a generic JAX bridge call.
call_jax: The Torch/XLA bridge lane that hands one narrowed TPU operation to JAX instead of moving the whole program into a JAX-owned frontend path.
segment_ids: The fixed-width segment labeling used to preserve document boundaries without changing the TPU kernel shape.
doc_ids: The fixed-width per-token document identifiers that keep packed rows auditable and let TPU masking respect document boundaries.
Attention: The token-mixing path that turns Q/K/V style projections into context-aware activations. On MLA pages here it refers to the concrete attention module boundary, not the A/M/E/R block-family shorthand.
XLA: The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.
TPU: Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.
libtpu: The TPU backend library that pairs with PJRT/XLA and owns device-side execution underneath the frontend.
PJRT: The TPU runtime interface between frontend code and the backend executor; it is the ownership seam between JAX/Torch-XLA frontends and libtpu.
JAX: A separate frontend above PJRT/libtpu. In these TPU posts it mainly matters as the owner of NamedSharding, PartitionSpec, and the optional call_jax or Pallas-adjacent bridge lanes.
Attention sinks: The long-context failure mode where a few tokens, often the first token, absorb disproportionate attention mass and behave like a null-attention valve.
CUDA: NVIDIA's GPU programming stack: compiler, runtime, driver, libraries, and kernel toolchain used by CUDA training and inference lanes.
A grounded look at why MegaCpp combines Mamba-style state-space blocks with a smaller number of attention blocks for long-context C++ work, and…
Packed rows: How the v6_enriched packed-rows pipeline feeds per-token structure IDs, chunk boundaries, and call edges into the XLA dataloader on TPU v6e without…