Pallas kernels on TPU v6e: what we ship and what we deleted
Where Pallas beats the XLA lowering on TPU v6e, where it loses, the debugging workflow that keeps us sane, and the kernel deltas we kept versus the ones we reverted.

Pallas is useful on TPU, but the public documentation is clear about one thing:
it is still an experimental kernel-writing surface. The practical split here is
trace_pallas versus call_jax: a native PyTorch/XLA custom-kernel lane
versus a narrower JAX bridge. MegaCpp therefore treats Pallas as a narrow TPU-only
tool, not as the default answer to every optimization problem and not as a
substitute for NVIDIA or NVFP4-specific kernel work. That ownership split is
why this post links Pallas to backend receipts and Torch XLA and PJRT
reality instead of presenting it as a generic TPU
speed knob.
For first-touch readers, TPU is the accelerator/runtime surface, XLA is the compiler/runtime layer that lowers tensor programs into TPU graphs, and Pallas is the narrower JAX-side kernel language MegaCpp reaches for only when plain XLA lowering is no longer enough.
The rule that actually matters
The important question is not "can we write this in Pallas?" The important question is "does writing this in Pallas buy us something XLA lowering does not already give us?"
MegaCpp keeps a short decision rule:
- prefer XLA lowering when the default path is already good enough
- prefer Pallas when tile control, local-window structure, or segment-aware masking removes a real hot-path cost
The compact checked-in version of that rule is the Pallas kernel selection note. In practice, trace_pallas is the preferred native path, call_jax is the narrower bridge, and the code-level proof of that split belongs in the XLA Pallas bridge receipt sample and the XLA backend dispatch sample.
That rule is stricter than it sounds. A custom kernel is not just another function. It becomes part of the compile contract, the sharding story, and the upgrade surface. A mistake in any of those is exactly the class of failure that turns into graph recompilation hell once shape or backend assumptions stop being explicit.
Where Pallas earns its keep
Publicly defensible use cases are the ones where the custom kernel surface is obviously doing something structural:
- keeping softcap or local-window logic inside the hot loop
- using segment ids instead of materializing a dense mask
- keeping block structure explicit rather than relying on a generic dense fallback
- holding tile sizes fixed so the step does not recompile
Those are the kinds of cases where a custom TPU kernel can earn its maintenance budget.
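The segment-id case from the list above can be sketched in plain Python. This is a minimal illustration, assuming a causal document mask over one packed row; the function names are hypothetical and are not taken from the checked-in samples. It shows that a length-T id vector answers the same visibility question as a T x T mask.

```python
def dense_document_mask(segment_ids):
    """Materialize the full T x T causal document mask (the costly form)."""
    t = len(segment_ids)
    return [
        [j <= i and segment_ids[i] == segment_ids[j] for j in range(t)]
        for i in range(t)
    ]


def visible(segment_ids, i, j):
    """Decide visibility from the ids alone, as a kernel could do per tile."""
    return j <= i and segment_ids[i] == segment_ids[j]


# Three packed documents in one fixed-width row of 8 tokens (assumed layout).
segment_ids = [0, 0, 0, 1, 1, 2, 2, 2]
mask = dense_document_mask(segment_ids)

# Same answers everywhere, with T ids stored instead of T * T mask entries.
same = all(
    mask[i][j] == visible(segment_ids, i, j)
    for i in range(len(segment_ids))
    for j in range(len(segment_ids))
)
ids_stored = len(segment_ids)
mask_entries_stored = len(segment_ids) ** 2
```

The row width stays fixed while document boundaries move, which is what lets the kernel shape stay constant.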
The clearest public examples are Pallas FA softcap on TPU and block-sparse attention on TPU, where the kernel exists because the structure is real, not because the default lowering was merely unfashionable.
The checked-in samples around Pallas softcap attention sample and Pallas grid shrinking sample show the two recurring wins: keeping softcap or window metadata inside the kernel, and shrinking sparse worklists before launch.
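As a rough illustration of shrinking a sparse worklist before launch, the sketch below enumerates only the block pairs a causal local window can touch. The block counts and the helper name are assumptions made for the example; the real samples may organize the worklist differently.

```python
def active_block_pairs(num_blocks, window_blocks):
    """List (q_block, kv_block) pairs a causal local window can touch.

    Each q block attends to itself and up to `window_blocks - 1` earlier blocks.
    """
    pairs = []
    for q in range(num_blocks):
        first_kv = max(0, q - window_blocks + 1)
        for kv in range(first_kv, q + 1):
            pairs.append((q, kv))
    return pairs


num_blocks = 8       # assumed: e.g. 4096 tokens in 512-token blocks
window_blocks = 2    # assumed: the local window spans two blocks

dense_grid = num_blocks * num_blocks          # every tile visited
worklist = active_block_pairs(num_blocks, window_blocks)
shrunk_grid = len(worklist)                   # only tiles with live work
```

With these assumed sizes the grid drops from 64 tiles to 15, and the worklist length is known before launch rather than discovered inside the kernel.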
The useful bandwidth lesson is qualitative, not a portable benchmark ratio: Pallas is interesting when it avoids materializing a large score or mask intermediate, keeps online-softmax or softcap bookkeeping inside the tiled attention path, and leaves a receipt that proves the specialized backend ran. That is the same boundary drawn in Pallas FA softcap on TPU, where the custom path is justified by avoiding an extra pass over the score matrix rather than by treating one shape-specific measurement as a universal speedup.
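The bookkeeping described above can be shown for a single query row in plain Python. This is a sketch under simplifying assumptions: values are scalars, softcap is taken to be `cap * tanh(score / cap)`, and the tile size is arbitrary. It is not the Pallas kernel itself, only the arithmetic that lets a tiled path skip the full capped-score intermediate.

```python
import math


def softcap(score, cap):
    """Squash a raw score into (-cap, cap) with tanh."""
    return cap * math.tanh(score / cap)


def reference_attention(scores, values, cap):
    """Materialize every capped score, then softmax and weight the values."""
    capped = [softcap(s, cap) for s in scores]
    peak = max(capped)
    weights = [math.exp(c - peak) for c in capped]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def tiled_attention(scores, values, cap, tile):
    """Visit one tile at a time, carrying only (running max, sum, accumulator)."""
    run_max, run_sum, acc = -math.inf, 0.0, 0.0
    for start in range(0, len(scores), tile):
        capped = [softcap(s, cap) for s in scores[start:start + tile]]
        new_max = max(run_max, max(capped))
        rescale = math.exp(run_max - new_max)  # exp(-inf) == 0.0 on the first tile
        run_sum *= rescale
        acc *= rescale
        for c, v in zip(capped, values[start:start + tile]):
            w = math.exp(c - new_max)
            run_sum += w
            acc += w * v
        run_max = new_max
    return acc / run_sum


scores = [0.5, 12.0, -3.0, 40.0, 7.5, -20.0, 2.0, 31.0]
values = [1.0, -2.0, 0.5, 3.0, -1.0, 4.0, 0.25, 2.0]
full = reference_attention(scores, values, cap=30.0)
tiled = tiled_attention(scores, values, cap=30.0, tile=3)
```

The two results agree up to floating-point rounding, while the tiled version never holds more than one tile of capped scores at a time.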
Where XLA is the right answer
MegaCpp leaves many paths alone:
- plain dense attention
- short-sequence cases where compile overhead dominates
- norms and similar reduction-heavy operations that XLA already fuses well
- dynamic-shape stories that would turn every step into a retrace or recompile
That is an important part of the public claim. A TPU stack becomes harder to trust when every path is rewritten just because a lower-level tool exists.
The safe hardware-facing rule is simpler: sparse work only pays when the skipped regions are large enough to survive the block and padding contract. JAX documents BlockSpec as the mapping from grid invocations to input or output blocks, and uneven block shapes are padded on input and discarded on output. That is why fixed-shape data paths and OOM on v6e-style memory budgets still matter more than the mere existence of some sparsity.
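A small arithmetic sketch makes the block and padding contract concrete. The document lengths and block sizes below are invented for illustration, causality is ignored to keep the code short, and the helpers are hypothetical rather than part of any JAX API.

```python
def cdiv(a, b):
    """Ceiling division: number of blocks needed to cover `a` elements."""
    return -(-a // b)


def padded_cost(seq_len, block):
    """Return (num_blocks, padded_len, wasted_elements) along one axis."""
    num_blocks = cdiv(seq_len, block)
    padded_len = num_blocks * block
    return num_blocks, padded_len, padded_len - seq_len


def skippable_tiles(doc_lengths, block):
    """Count (q_block, kv_block) tiles that no single document overlaps.

    Returns (skippable, total). A tile can be skipped only when the set of
    documents in its q block is disjoint from the set in its kv block.
    """
    ids = [d for d, n in enumerate(doc_lengths) for _ in range(n)]
    num_blocks = cdiv(len(ids), block)
    docs_in = [set(ids[b * block:(b + 1) * block]) for b in range(num_blocks)]
    skipped = sum(
        1
        for q in range(num_blocks)
        for kv in range(num_blocks)
        if docs_in[q].isdisjoint(docs_in[kv])
    )
    return skipped, num_blocks * num_blocks


ragged = padded_cost(1000, 128)                  # uneven length pays padding
docs = [200, 300, 100, 424]                      # assumed packed documents
fine_blocks = skippable_tiles(docs, 128)
coarse_blocks = skippable_tiles(docs, 512)
```

With 128-token blocks roughly half the tiles are skippable; with 512-token blocks none are, even though the mask is exactly as sparse as before.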
That static-shape rule is not special to Pallas, but Pallas tends to feel it sooner because block and grid choices are part of the kernel contract. If the sequence or batch surface drifts every step, the team is paying both the general XLA recompile cost and a worse tile fit for the custom path. That is why MegaCpp pairs Pallas work with graph recompilation hell and the checked-in TPU compile/runtime control sample instead of treating dynamic batching as a free optimization.
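One common way to keep the shape surface from drifting is to round lengths up to a small set of fixed buckets. The sketch below assumes four bucket sizes and a handful of raw lengths; both are illustrative, and the actual control sample may choose buckets differently.

```python
import bisect

BUCKETS = (512, 1024, 2048, 4096)  # assumed fixed sequence buckets


def bucket_for(seq_len, buckets=BUCKETS):
    """Round a raw length up to the smallest bucket that fits it."""
    idx = bisect.bisect_left(buckets, seq_len)
    if idx == len(buckets):
        raise ValueError(f"sequence length {seq_len} exceeds largest bucket")
    return buckets[idx]


raw_lengths = [300, 517, 1024, 1999, 640, 3000, 301, 2048]

# Each distinct shape is a candidate for a separate trace or compile.
raw_shapes = set(raw_lengths)
bucketed_shapes = {bucket_for(n) for n in raw_lengths}

# The price of stable shapes is padding that does no useful work.
padding_tokens = sum(bucket_for(n) - n for n in raw_lengths)
```

Eight distinct raw shapes collapse to four bucketed ones here, at the cost of 2459 padded tokens, which is the trade the paragraph above describes.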
The workflow that survived
The debugging workflow is simple:
- keep mask or layout logic reproducible on CPU first
- compare custom-kernel outputs against a trusted reference path
- record which backend actually executed
- only promote the kernel if it wins clearly enough to justify its maintenance cost
That receipt-first habit is the same one described in profiler and receipts: if the team cannot show which backend ran, the optimization story is not ready to ship.
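The compare-and-record steps of that workflow can be sketched as a single check that returns a receipt rather than a bare boolean. The `KernelReceipt` type, its fields, and the tolerance are hypothetical choices for this example, not the format MegaCpp ships.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class KernelReceipt:
    """Hypothetical record of what ran and whether it matched the reference."""
    backend: str
    max_abs_error: float
    passed: bool


def check_against_reference(backend, candidate, reference, atol=1e-3):
    """Compare flat output lists and return a receipt instead of a bare bool."""
    if len(candidate) != len(reference):
        raise ValueError("output shapes differ")
    max_err = max(abs(c - r) for c, r in zip(candidate, reference))
    return KernelReceipt(backend, max_err, max_err <= atol)


# Made-up outputs standing in for a trusted path and a custom-kernel path.
reference_out = [0.125, -1.5, 2.0, 0.0]
kernel_out = [0.1251, -1.5002, 1.9999, 0.0]
receipt = check_against_reference("pallas", kernel_out, reference_out)
```

Keeping the backend name and the measured error in the same object means a promotion decision can cite one artifact instead of two logs.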
This is why MegaCpp prefers explicit backend receipts. A TPU configuration should never need log archaeology to answer the question "did the custom kernel actually run?" The same receipt should also make it obvious that this is a TPU and JAX/Pallas lane, not an NVIDIA precision or CUDA-kernel lane.
Routing and fallback are kept explicit in the XLA backend dispatch sample and the XLA backend fallback sample, so backend identity does not depend on reading opaque logs after the fact. The same examples make the ownership split concrete: trace_pallas keeps the hotspot inside PyTorch/XLA, while call_jax widens the cache and runtime surface into a JAX bridge.
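A minimal sketch of explicit routing might return the chosen backend together with the reason it was chosen. The backend labels here are plain strings used for illustration; nothing below calls PyTorch/XLA or JAX, and the fallback conditions are assumptions rather than the rules in the checked-in samples.

```python
def select_backend(requested, available, shape_is_static):
    """Return (backend, reason) so the choice is recorded, not inferred."""
    if requested not in available:
        return "xla", f"requested backend {requested!r} unavailable"
    if requested != "xla" and not shape_is_static:
        return "xla", "dynamic shape, custom kernel skipped"
    return requested, "requested backend selected"


# Assumed build where the trace_pallas lane is missing.
available = {"xla", "call_jax"}

missing = select_backend("trace_pallas", available, shape_is_static=True)
dynamic = select_backend("call_jax", available, shape_is_static=False)
direct = select_backend("call_jax", available, shape_is_static=True)
```

Because every path returns a reason string, a silent fallback to the default lowering cannot be mistaken for a custom-kernel run.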
Why MegaCpp is conservative here
Pallas can be excellent for narrow cases, but TPU execution already has enough moving parts: frontend behavior, PJRT, runtime versions, compile caches, and sharding behavior. Every unnecessary custom kernel makes that matrix harder to reason about.
MegaCpp therefore keeps a bias toward deletion:
- if a Pallas path is only tied with XLA, it should probably be removed
- if a Pallas path helps only at one very narrow shape, it should stay experimental until the use case is stable
- if a Pallas path requires dynamic per-step shape choices, it should probably stay out of the training hot loop
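Those three rules can be written down as a small verdict function. The 1.10 speedup threshold, the argument names, and the verdict strings are all assumptions for this sketch; the source only says that a tie with XLA is not enough.

```python
def kernel_verdict(speedup, shapes_won, shapes_tested, needs_dynamic_shapes,
                   min_speedup=1.10):
    """Map the three deletion rules onto a verdict string.

    `speedup` is custom-kernel throughput divided by the XLA baseline.
    """
    if needs_dynamic_shapes:
        return "keep out of training hot loop"
    if speedup < min_speedup:
        return "remove"                 # tied with XLA, not worth maintaining
    if shapes_won <= 1 < shapes_tested:
        return "experimental"           # wins at one narrow shape only
    return "promote"


tied = kernel_verdict(1.02, shapes_won=4, shapes_tested=4,
                      needs_dynamic_shapes=False)
narrow = kernel_verdict(1.40, shapes_won=1, shapes_tested=5,
                        needs_dynamic_shapes=False)
dynamic_case = kernel_verdict(1.40, shapes_won=5, shapes_tested=5,
                              needs_dynamic_shapes=True)
clear_win = kernel_verdict(1.40, shapes_won=5, shapes_tested=5,
                           needs_dynamic_shapes=False)
```

Only the last case is promoted; the other three end in removal, experimental status, or exclusion from the hot loop.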
The public claim
The useful public statement is:
- MegaCpp uses Pallas selectively on TPU
- it keeps XLA lowering as the default for many paths
- it promotes a custom TPU kernel only when the structural win is clear
- it treats backend receipts and correctness checks as part of the deployment contract
That is consistent with the official Pallas docs and avoids presenting an experimental surface as if it were a settled platform guarantee.
The same discipline carries into three related posts: libtpu, PJRT, JAX, and ownership boundaries; Torch XLA, PJRT, and runtime policy reality; and Determinism and bit-exact runs. TPU wins are only real when backend ownership, compile receipts, and correctness contracts all stay explicit.
Frequently asked questions
What is the practical difference between trace_pallas and call_jax?
trace_pallas keeps the hotspot inside the PyTorch/XLA-owned custom-kernel lane, while call_jax briefly crosses into a narrower JAX bridge. Both can be useful, but the bridge has a wider cache and runtime surface, which is why MegaCpp asks for explicit backend receipts before calling it a stable optimization path.
Does Pallas make dynamic sequence lengths free?
[...] BlockSpec fit. The local guardrails are TPU compile/runtime control sample and graph recompilation hell.
Should norms or other small reduction-heavy ops move to Pallas too?
What should a public-safe Pallas backend receipt actually prove?
What makes a Pallas BlockSpec risky enough to review separately?
BlockSpec is not just bookkeeping: it maps each grid invocation to the input or output block the kernel sees, so changing the block shape or index map changes the compiled kernel contract. For TPU, JAX documents additional block-shape restrictions and out-of-bounds padding behavior, which is why MegaCpp reviews grid shrinking, validity normalization, and fixed-bucket shape choices together instead of treating mask compaction as a harmless prelude. The local proof surfaces are Pallas grid shrinking sample, Pallas softcap attention sample, and TPU compile/runtime control sample.
Terms used in this article
Start here for quick definitions, then follow the linked posts for deeper context.
- Pallas: JAX's kernel language for writing explicit TPU kernels when stock XLA lowering is not enough for the required tile, memory-layout, or masking contract.
- trace_pallas: The native PyTorch/XLA custom-kernel lane that traces a Pallas kernel into a payload the XLA side can keep without crossing into a generic JAX bridge call.
- call_jax: The Torch/XLA bridge lane that hands one narrowed TPU operation to JAX instead of moving the whole program into a JAX-owned frontend path.
- Attention: The token-mixing path that turns Q/K/V style projections into context-aware activations. On MLA pages here it refers to the concrete attention module boundary, not the A/M/E/R block-family shorthand.
- TPU: Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.
- libtpu: The TPU backend library that pairs with PJRT/XLA and owns device-side execution underneath the frontend.
- PJRT: The TPU runtime interface between frontend code and the backend executor; it is the ownership seam between JAX/Torch-XLA frontends and libtpu.
- segment_ids: The fixed-width segment labeling used to preserve document boundaries without changing the TPU kernel shape.
- JAX: A separate frontend above PJRT/libtpu. In these TPU posts it mainly matters as the owner of NamedSharding, PartitionSpec, and the optional call_jax or Pallas-adjacent bridge lanes.
- NVFP4: NVIDIA's four-bit floating-point inference/training format family used when the lane can tolerate more aggressive quantization than FP8.
- CUDA: NVIDIA's GPU programming stack: compiler, runtime, driver, libraries, and kernel toolchain used by CUDA training and inference lanes.