Pallas FlashAttention with logit softcap on TPU v6e
Why softcap attention on TPU needs a dedicated kernel surface: fuse the nonlinearity, keep masking contract-friendly, and avoid turning a stability trick into a second full pass over the score matrix.

Softcap attention is a good example of when TPU kernel work becomes justified. "Softcap" here means bounding attention logits before softmax so outlier scores do not dominate the whole row. The point is not to replace every stock kernel. The point is to avoid paying for an extra pass over the score matrix when a fused path can keep that stability trick inside the main attention loop.
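A minimal sketch of what "bounding" means, assuming the widely used tanh form of logit softcapping with cap value `c` applied to already-scaled logits: `c * tanh(s / c)`. The map is close to the identity for small scores and saturates toward ±c for outliers. The helper name `softcap` and the cap value 30.0 are illustrative, not MegaCpp's configuration.

```python
import numpy as np

def softcap(scores: np.ndarray, cap: float) -> np.ndarray:
    """Smoothly bound logits to (-cap, cap).

    For |s| much smaller than cap, tanh(s / cap) ~ s / cap, so small scores
    pass through almost unchanged; large outliers saturate near +/-cap.
    In float32 a huge outlier can round to exactly cap, so checks use <=.
    """
    return cap * np.tanh(scores / cap)

# One attention row with a single outlier logit.
row = np.array([0.5, 1.0, 2.0, 400.0], dtype=np.float32)
capped = softcap(row, cap=30.0)
print(capped)
```

The map is monotonic, so it never reorders scores; it only limits how far one outlier can pull ahead of the rest of the row.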
For first-touch readers, the useful split is this: trace_pallas is the native TPU custom-kernel route, call_jax is the narrower bridge into a JAX helper, and Splash is the stable TPU attention family for standard causal or local-mask surfaces rather than a synonym for Pallas. segment_ids are the compact per-token boundary labels derived from packed-row doc_ids, and a runtime receipt is the checked record of which backend family actually executed. Pallas kernels on TPU, XLA backend index, TPU backend ownership, and Profiler and receipts are the checked-in map for that split.
For this article, segment_ids are the compact per-token integers derived from packed-row doc_ids; they tell the kernel which tokens belong to the same document or span without materializing a dense (T, T) mask. Vocab and Tokenizer Plumbing on TPU is the upstream row-contract story, and Document-mask segment ID sample is the smallest checked-in receipt for the conversion.
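A minimal sketch of that conversion, assuming a packed row is a fixed-width array of document ids in which padding carries its own id. The function names `doc_ids_to_segment_ids` and `allowed` are hypothetical. Each contiguous run of equal doc ids is renumbered to a dense 0, 1, 2, ... label, and the kernel then only needs the per-token labels plus the rule "same segment, and key not after query". The dense (T, T) mask at the end is built purely to check that rule; it is exactly what the kernel avoids materializing.

```python
import numpy as np

def doc_ids_to_segment_ids(doc_ids: np.ndarray) -> np.ndarray:
    """Relabel contiguous runs of equal doc ids as 0, 1, 2, ... per row.

    Works on a (B, T) batch of packed rows. Using run boundaries instead of
    the raw ids means two separate runs that reuse a doc id still get
    distinct segments.
    """
    starts = np.zeros(doc_ids.shape, dtype=np.int32)
    # A new segment starts wherever the doc id differs from the previous token.
    starts[:, 1:] = (doc_ids[:, 1:] != doc_ids[:, :-1]).astype(np.int32)
    return np.cumsum(starts, axis=1, dtype=np.int32)

def allowed(seg: np.ndarray, q: int, k: int) -> bool:
    """Per-element mask rule: same segment and causal (key not after query)."""
    return bool(seg[q] == seg[k] and k <= q)

doc_ids = np.array([[7, 7, 7, 3, 3, 9, 9, 9]])
seg = doc_ids_to_segment_ids(doc_ids)
print(seg)  # [[0 0 0 1 1 2 2 2]]

# Reference only: the dense (T, T) mask the kernel should never build.
T = doc_ids.shape[1]
dense = np.array([[allowed(seg[0], q, k) for k in range(T)] for q in range(T)])
```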
Why softcap needs special treatment
Softcap changes attention logits before softmax. On paper the math is simple. In practice, an unfused implementation can turn one useful nonlinearity into extra memory traffic and another expensive walk over the score tiles.
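The difference can be sketched in NumPy, standing in for the Pallas kernel. The block size, cap value, and function names are illustrative assumptions. The unfused reference materializes the full (Tq, Tk) score matrix, rewrites it with the cap, and then runs softmax over it. The fused version walks key tiles once with FlashAttention-style online softmax and applies the cap to each score tile while it is still local, so the cap never costs its own pass.

```python
import numpy as np

def softcap(s, cap):
    return cap * np.tanh(s / cap)

def attention_unfused(q, k, v, cap):
    # Full score matrix, then a separate cap pass, then softmax, then PV.
    s = q @ k.T / np.sqrt(q.shape[-1])
    s = softcap(s, cap)
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v

def attention_fused(q, k, v, cap, block_k=16):
    # Single walk over key tiles with online-softmax bookkeeping.
    scale = 1.0 / np.sqrt(q.shape[-1])
    m = np.full((q.shape[0], 1), -np.inf, dtype=np.float32)  # running row max
    l = np.zeros((q.shape[0], 1), dtype=np.float32)          # running denominator
    acc = np.zeros((q.shape[0], v.shape[-1]), dtype=np.float32)
    for start in range(0, k.shape[0], block_k):
        kt = k[start:start + block_k]
        vt = v[start:start + block_k]
        s = softcap(q @ kt.T * scale, cap)   # cap applied in-tile, no extra pass
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        correction = np.exp(m - m_new)       # rescale earlier partial sums
        p = np.exp(s - m_new)
        l = l * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ vt
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
# Inputs are scaled up so some raw scores exceed the cap and it actually matters.
q = rng.standard_normal((8, 32)).astype(np.float32) * 4.0
k = rng.standard_normal((64, 32)).astype(np.float32) * 4.0
v = rng.standard_normal((64, 32)).astype(np.float32)
out_ref = attention_unfused(q, k, v, cap=30.0)
out_fused = attention_fused(q, k, v, cap=30.0)
print(np.max(np.abs(out_ref - out_fused)))
```

Both paths compute the same softcapped attention; only the fused one avoids writing and rereading the full score matrix.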
That is why the public TPU claim should stay narrow: use a custom Pallas path when softcap, masking, and local-window behavior together justify it. The same justification line is why block-sparse attention on TPU stays focused on stable sparse metadata instead of treating every mask variant as a kernel opportunity.
The checked-in local surfaces for that claim are Pallas softcap attention sample, trace-pallas scalar-prefetch sample, Pallas grid shrinking sample, XLA Pallas bridge receipt sample, and XLA backend dispatch sample. They keep the proof tied to explicit mask metadata, scalar-prefetch wiring, and backend receipts. The dispatch receipt also makes the cutoff explicit: plain causal requests stay on Splash-style helpers, while modified softcap or sliding-window variants route to xla_pallas.
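The routing rule in that dispatch receipt can be sketched as a small pure function. Everything here is hypothetical: the request fields, the `splash` and `xla_default` labels, and the non-causal fallback branch. It only illustrates the cutoff described above: plain causal stays on Splash-style helpers, while softcap or sliding-window variants go to xla_pallas.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class AttentionRequest:
    # Hypothetical request fields; a real dispatcher carries more metadata.
    causal: bool = True
    softcap: Optional[float] = None   # None means no logit cap
    window: Optional[int] = None      # None means no sliding window

def select_backend(req: AttentionRequest) -> str:
    """Pick a backend family label for one attention request (illustrative)."""
    if req.softcap is not None or req.window is not None:
        # Modified variants need the fused softcap-and-mask kernel.
        return "xla_pallas"
    if req.causal:
        # Plain causal stays on the stable Splash-style path.
        return "splash"
    # Anything else falls back to stock XLA lowering in this sketch.
    return "xla_default"
```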
What the kernel should keep explicit
A good TPU softcap kernel keeps four things explicit:
- softcap fused into the main score path
- local-window behavior expressed without dynamic per-step retracing
- masking carried as stable metadata rather than ad hoc dense materialization
- document-boundary handling kept compatible with packed-sequence segment identifiers
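One way to read that list is as the body of a single score-tile function. The sketch below uses NumPy to stand in for in-kernel code; `score_tile`, its parameters, and the -1e30 fill value are illustrative assumptions. It computes one (block_q, block_k) tile: the cap is applied to the raw scores, and the causal, sliding-window, and segment rules are derived from tile offsets and per-token segment ids, so no dense (T, T) mask exists anywhere. `window` is a fixed integer per call, matching the no-per-step-changes rule.

```python
import numpy as np

NEG_INF = -1e30  # large negative fill instead of -inf to avoid inf - inf NaNs

def score_tile(q_tile, k_tile, q_start, k_start, q_seg, k_seg, cap, window):
    """Softcapped, masked scores for one (block_q, block_k) tile.

    q_seg / k_seg are segment ids for just the tokens in this tile.
    window is how many keys (including the query itself) each query may
    see; None disables the sliding window.
    """
    s = q_tile @ k_tile.T / np.sqrt(q_tile.shape[-1])
    s = cap * np.tanh(s / cap)                      # fused softcap
    q_pos = q_start + np.arange(q_tile.shape[0])[:, None]
    k_pos = k_start + np.arange(k_tile.shape[0])[None, :]
    keep = k_pos <= q_pos                           # causal
    if window is not None:
        keep &= (q_pos - k_pos) < window            # local window
    keep &= q_seg[:, None] == k_seg[None, :]        # document boundary
    return np.where(keep, s, NEG_INF)

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))
k = rng.standard_normal((4, 8))
seg = np.array([0, 0, 1, 1])
tile = score_tile(q, k, q_start=0, k_start=0, q_seg=seg, k_seg=seg,
                  cap=30.0, window=2)
```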
The point is not one exact implementation detail. The point is a stable compiled contract. On TPU that contract is easier to audit when boundary metadata stays explicit, the same way packed rows as the real training contract keeps token and document ownership aligned through the data path.
On v6e that usually means collapsing many caller-side requests onto a small set of validated block or window classes instead of retracing one kernel per exact variant. The planner does the normalization, the kernel keeps the fused softcap-and-mask contract, and the compile cache stays small enough that the custom path still wins.
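A minimal sketch of that normalization, under stated assumptions: the validated classes are powers of two counted in key blocks, the exact window still travels as a runtime value (for example a scalar-prefetch operand) so masking stays exact, and only the rounded class enters the compile key. `plan_window`, `kv_blocks_needed`, `VALIDATED_KV_BLOCKS`, and `compile_kernel` are hypothetical names, and `functools.lru_cache` stands in for the compile cache.

```python
import functools
import math

# Hypothetical set of KV-block extents the kernel has been validated for.
VALIDATED_KV_BLOCKS = (1, 2, 4, 8, 16, 32, 64)

def kv_blocks_needed(window: int, block_q: int, block_k: int) -> int:
    """Conservative number of key blocks one query block can touch.

    A query block [q0, q0 + block_q) with a window of `window` keys needs
    keys [q0 - window + 1, q0 + block_q), i.e. window + block_q - 1 keys,
    which can straddle one extra block when unaligned.
    """
    return math.ceil((window + block_q - 1) / block_k) + 1

def plan_window(window: int, block_q: int = 128, block_k: int = 128):
    """Map an exact window to (compile_key, runtime_window)."""
    need = kv_blocks_needed(window, block_q, block_k)
    cls = next((c for c in VALIDATED_KV_BLOCKS if c >= need), None)
    if cls is None:
        raise ValueError(f"window {window} exceeds validated classes")
    # Only the class is static; the exact window is passed at run time.
    return (block_q, block_k, cls), window

@functools.lru_cache(maxsize=None)
def compile_kernel(key):
    # Stand-in for tracing/compiling one kernel variant per static key.
    return f"kernel{key}"

for w in range(1, 4097):
    key, runtime_window = plan_window(w)
    compile_kernel(key)

print(compile_kernel.cache_info().currsize)
```

Rounding up trades a few visited-but-masked key blocks for a bounded number of compiled variants; the exact window still decides which scores survive.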
The numerical side matters too. Softcap only pays off if the logits stay in a higher-precision lane long enough for the bound to mean what it is supposed to mean. On TPU that usually means keeping the score accumulation and online-softmax bookkeeping in fp32 scratch, then casting back down only after the bounded normalization work is done.
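A small NumPy illustration of why the bookkeeping lane matters. float16 stands in for a low-precision lane because NumPy has no bfloat16, and the row length of 4096 is arbitrary. With softcapped logits that are all equal, every term of the softmax denominator is exp(0) = 1, so the correct sum is 4096. A float16 accumulator stops growing at 2048, because 2048 + 1 rounds back to 2048 at that magnitude, which would silently double every attention weight; a float32 accumulator gets it right. The sum is taken one key at a time so the rounding is easy to see.

```python
import numpy as np

def denominator(logits: np.ndarray, acc_dtype) -> float:
    """Online-softmax denominator accumulated one key at a time in acc_dtype."""
    m = logits.max()
    l = acc_dtype(0)
    for s in logits:
        # Each term exp(s - m) is cast to the accumulator lane before adding.
        l = acc_dtype(l + acc_dtype(np.exp(s - m)))
    return float(l)

cap = 30.0
# All-equal capped logits, so every exp(s - m) term is exactly 1.
row = (cap * np.tanh(np.full(4096, 5.0) / cap)).astype(np.float32)
print(denominator(row, np.float16))  # 2048.0, stuck
print(denominator(row, np.float32))  # 4096.0, exact
```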
Why the masking contract matters here too
Softcap alone is not the whole kernel story. The moment local windows, segment boundaries, or sparse grid decisions enter the picture, the kernel stops being a pure numerical tweak and becomes part of the model's masking contract, the same boundary described in document masking and the curriculum and packed rows as the real training contract. That is where a custom TPU path becomes easier to justify.
For the packed-sequence side, document-mask segment ID sample is the smallest checked-in readback of how segment_ids stay explicit instead of turning into one dense mask per batch.
For the sparse-attention version of that decision, clustered sparse on TPU: the planner stages and block-sparse attention on TPU are the closest adjacent articles.
What should stay out
A TPU softcap path becomes harder to trust if it also tries to absorb every adjacent idea. In practice, it is safer to avoid:
- dynamic per-step window changes
- optional fallback trees with radically different semantics
- unrelated fused epilogues that do not clearly pay for themselves
The measurement boundary should stay narrow too. Otherwise the instrumentation starts to distort the kernel path, which is the same observability warning described in attention sinks and telemetry on TPU.
Backend selection and fallback are deliberately separate in XLA backend dispatch sample and XLA backend fallback sample, so "softcap requested" is not confused with "softcap path actually executed."
The proof standard is the same as in profiler and receipts: a custom kernel claim is weak until the backend receipt is explicit.
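A minimal sketch of keeping "requested" and "executed" apart, assuming a hypothetical receipt record. The field names and backend labels are illustrative, not the format of the checked-in receipt samples. The dispatcher records what was asked for, the runtime records what ran, and a softcap claim only passes when both agree.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class BackendReceipt:
    # Hypothetical receipt fields.
    requested_backend: str
    executed_backend: str
    softcap_requested: bool
    softcap_fused: bool          # set by the kernel that actually ran
    fallback_reason: str = ""

def softcap_claim_holds(r: BackendReceipt) -> bool:
    """True only if a requested softcap actually ran on the fused path."""
    if not r.softcap_requested:
        return True  # nothing to claim
    return r.executed_backend == r.requested_backend and r.softcap_fused

ok = BackendReceipt("xla_pallas", "xla_pallas", True, True)
fell_back = BackendReceipt("xla_pallas", "xla_default", True, False,
                           fallback_reason="unsupported block shape")
```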
The useful public summary
The public statement is simple: MegaCpp uses a selective Pallas softcap attention path on TPU where fused softcap and stable masking semantics buy a real runtime win. It does not present that path as the default answer to every TPU attention problem.
Frequently asked questions
Why not implement softcap as a separate TPU pass over the score matrix?
What makes a TPU softcap kernel production-safe?
What are segment_ids in this TPU context?
Where do segment_ids come from before the softcap kernel sees them?
They are derived from the fixed-width doc_ids materialized in the TPU data path. Vocab and Tokenizer Plumbing on TPU explains why TPU insists on rectangular doc_ids, and Document-mask segment ID sample shows the compact conversion.
What is the easy correctness bug when segment masks reach the softcap kernel?
Why shrink the sparse grid instead of passing one dense mask into the TPU kernel?
… data_next, block_mask, and grid_width. That keeps the forward and backward call contract shape-stable and avoids spending TPU work on empty columns.
Why link softcap to masking and segment boundaries so aggressively?
Terms used in this article
Start here for quick definitions, then follow the linked posts for deeper context.
segment_ids: The fixed-width segment labeling used to preserve document boundaries without changing the TPU kernel shape.
Pallas: JAX's kernel language for writing explicit TPU kernels when stock XLA lowering is not enough for the required tile, memory-layout, or masking contract.
doc_ids: The fixed-width per-token document identifiers that keep packed rows auditable and let TPU masking respect document boundaries.
trace_pallas: The native PyTorch/XLA custom-kernel lane that traces a Pallas kernel into a payload the XLA side can keep without crossing into a generic JAX bridge call.
call_jax: The Torch/XLA bridge lane that hands one narrowed TPU operation to JAX instead of moving the whole program into a JAX-owned frontend path.
Splash: The stable TPU attention family used for dense or local-mask lanes before MegaCpp drops to narrower planner-driven sparse contracts.
Attention: The token-mixing path that turns Q/K/V style projections into context-aware activations. On MLA pages here it refers to the concrete attention module boundary, not the A/M/E/R block-family shorthand.
TPU: Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.
JAX: A separate frontend above PJRT/libtpu. In these TPU posts it mainly matters as the owner of NamedSharding, PartitionSpec, and the optional call_jax or Pallas-adjacent bridge lanes.
Attention sink: The long-context failure mode where a few tokens, often the first token, absorb disproportionate attention mass and behave like a null-attention valve.