MegaCpp Engineering · Applied C++ model systems
Grounded engineering note from the MegaCpp stack
Published · 6 min read · David Gornshtein
Pallas
TPU
V6e
JAX

Pallas kernels on TPU v6e: what we ship and what we deleted

Where Pallas beats the XLA lowering on TPU v6e, where it loses, the debugging workflow that keeps us sane, and the kernel deltas we kept versus the ones we reverted.


Pallas is useful on TPU, but the public documentation is clear about one thing: it is still an experimental kernel-writing surface. The practical split here is trace_pallas versus call_jax, meaning a native PyTorch/XLA custom-kernel lane versus a narrower JAX bridge.
MegaCpp therefore treats Pallas as a narrow, TPU-only tool. It is not the default answer to every optimization problem, and it is not a substitute for NVIDIA or NVFP4 inference-specific kernel work. That ownership split is why this post links Pallas to backend receipts and to Torch XLA and PJRT reality instead of presenting it as a generic TPU speed knob.

For first-touch readers: TPU is the accelerator/runtime surface, XLA is the compiler/runtime layer that compiles tensor programs for TPU, and Pallas is the narrower JAX-side kernel language that MegaCpp reaches for only when plain XLA lowering is no longer enough.

The rule that actually matters

The important question is not "can we write this in Pallas?" It is "does writing this in Pallas buy us something XLA lowering does not already give us?"

MegaCpp keeps a short decision rule:

The compact checked-in version of that rule is the Pallas kernel selection note. In practice, trace_pallas is the preferred native path and call_jax is the narrower bridge. The code-level proof of that split belongs in the XLA Pallas bridge receipt sample and the XLA backend dispatch sample.
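A minimal sketch of how that rule could be encoded, assuming a hypothetical `KernelCandidate` record and a placeholder `min_speedup` threshold. Neither comes from the selection note itself. The point is the ordering: XLA is the default, a structural win is required before any custom lane is considered, and the native trace_pallas lane is preferred over the call_jax bridge.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class KernelCandidate:
    """Hypothetical summary of what a custom kernel buys over stock XLA lowering."""
    name: str
    avoids_large_intermediate: bool  # e.g. never materializes a score or mask tensor
    keeps_state_in_tile: bool        # e.g. online-softmax or softcap bookkeeping
    measured_speedup: float          # custom-path throughput / XLA throughput, real shapes
    needs_jax_bridge: bool           # kernel can only run through call_jax


def choose_lane(c: KernelCandidate, min_speedup: float = 1.15) -> str:
    """Return "xla", "trace_pallas", or "call_jax" for one hotspot."""
    structural = c.avoids_large_intermediate or c.keeps_state_in_tile
    if not structural or c.measured_speedup < min_speedup:
        return "xla"           # default: keep the stock lowering
    if c.needs_jax_bridge:
        return "call_jax"      # narrower bridge, wider cache/runtime surface
    return "trace_pallas"      # preferred native PyTorch/XLA lane
```

A norm kernel with a nice speedup but no structural reason still lands on `"xla"`. A softcap attention kernel that avoids the score intermediate and wins clearly lands on `"trace_pallas"`.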

That rule is stricter than it sounds. A custom kernel is not just another function. It becomes part of the compile contract, the sharding story, and the upgrade surface. That is exactly the class of failure that turns into graph recompilation hell once shape or backend assumptions stop being explicit.

Where Pallas earns its keep

Publicly defensible use cases are the ones where the custom kernel surface is obviously doing something structural:

Those are the kinds of cases where a custom TPU kernel can earn its maintenance budget.

The clearest public examples are Pallas FA softcap on TPU and block-sparse attention on TPU, where the kernel exists because the structure is real, not because the default lowering was merely unfashionable.

The checked-in Pallas softcap attention sample and Pallas grid shrinking sample show the two recurring wins: keeping softcap or window metadata inside the kernel, and shrinking sparse worklists before launch.
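Grid shrinking can be sketched without any TPU code. Start from a block-level mask, keep only the tiles that contain at least one valid entry, and hand the kernel a compacted worklist plus per-query-block offsets. The helper names and the CSR-style offsets layout below are hypothetical. They show one plausible way to shape a worklist for scalar prefetch and are not the checked-in sample.

```python
def causal_block_mask(num_blocks):
    """Block-level causal mask: query block i needs kv blocks 0..i."""
    return [[j <= i for j in range(num_blocks)] for i in range(num_blocks)]


def compact_worklist(block_mask):
    """Keep only (q_block, kv_block) tiles with a valid entry, in row-major order."""
    return [
        (i, j)
        for i, row in enumerate(block_mask)
        for j, live in enumerate(row)
        if live
    ]


def worklist_offsets(worklist, num_q_blocks):
    """CSR-style layout: kv_indices[offsets[i]:offsets[i + 1]] are q block i's tiles.

    Assumes the worklist is grouped by q block, which compact_worklist guarantees.
    """
    counts = [0] * num_q_blocks
    for i, _ in worklist:
        counts[i] += 1
    offsets = [0]
    for c in counts:
        offsets.append(offsets[-1] + c)
    kv_indices = [j for _, j in worklist]
    return offsets, kv_indices
```

For four causal blocks, the dense grid would launch 16 tiles. The compacted worklist has 10, with offsets `[0, 1, 3, 6, 10]`. The kernel then iterates only over real work instead of branching on empty tiles inside the grid.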

The useful bandwidth lesson is qualitative, not a portable benchmark ratio. Pallas is interesting when it avoids materializing a large score or mask intermediate, keeps online-softmax or softcap bookkeeping inside the tiled attention path, and leaves a receipt that proves the specialized backend ran. That is the same boundary drawn in Pallas FA softcap on TPU, where the custom path is justified by avoiding an extra pass over the score matrix rather than by treating one shape-specific measurement as a universal speedup.
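To make that bookkeeping concrete, here is a pure-Python sketch of the online-softmax recurrence with soft-capping. It is not a Pallas kernel, and it assumes the common tanh form `cap * tanh(s / cap)`. The reference version builds the whole row of capped scores before normalizing. The tiled version only ever holds one kv block of scores plus a running max, running denominator, and unnormalized accumulator, which is the state a fused kernel would keep in its tile. Function names are illustrative.

```python
import math


def softcap(score, cap):
    """Tanh soft-capping: smoothly bounds a score to (-cap, cap)."""
    return cap * math.tanh(score / cap)


def reference_attention_row(q, keys, values, cap):
    """Materializes every capped score for one query row, then normalizes."""
    scores = [softcap(sum(a * b for a, b in zip(q, k)), cap) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


def tiled_attention_row(q, keys, values, cap, block):
    """Online softmax over kv blocks: only one block of scores exists at a time."""
    dim = len(values[0])
    m = -math.inf          # running max of capped scores
    l = 0.0                # running softmax denominator
    acc = [0.0] * dim      # unnormalized weighted sum of values
    for start in range(0, len(keys), block):
        k_blk = keys[start:start + block]
        v_blk = values[start:start + block]
        s_blk = [softcap(sum(a * b for a, b in zip(q, k)), cap) for k in k_blk]
        m_new = max(m, max(s_blk))
        scale = math.exp(m - m_new)  # rescale old state to the new max (0.0 on first block)
        l *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(s_blk, v_blk):
            p = math.exp(s - m_new)
            l += p
            for d in range(dim):
                acc[d] += p * v[d]
        m = m_new
    return [a / l for a in acc]
```

Both functions agree to floating-point tolerance for any block size. That equivalence is what a real kernel must preserve while never writing the full score matrix to memory.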

Where XLA is the right answer

MegaCpp leaves many paths alone:

That is an important part of the public claim. A TPU stack becomes harder to trust when every path is rewritten just because a lower-level tool exists.

The safe hardware-facing rule is simpler: sparse work only pays when the skipped regions are large enough to survive the block and padding contract. JAX documents BlockSpec as the mapping from grid invocations to input or output blocks, and blocks that do not evenly divide the array are padded on input and have their out-of-bounds part discarded on output. That is why fixed-shape data paths and OOM on v6e-style memory budgets still matter more than the mere existence of some sparsity.
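A rough way to see why skipped regions must survive the block contract is to count how many (query block, kv block) tiles are completely empty under a causal sliding-window mask at different block sizes. This pure-Python sketch assumes a tile can be skipped only when every entry in it is masked. The mask shape and numbers are illustrative, not MegaCpp measurements.

```python
def skippable_fraction(seq_len, window, block):
    """Fraction of (q_block, kv_block) tiles with no valid entry.

    Mask: key j is visible to query i iff i - window < j <= i (causal sliding window).
    A ragged final block is treated as padded, mirroring BlockSpec edge handling.
    """
    nb = -(-seq_len // block)  # ceil division: padded grid extent per axis
    empty = 0
    for qb in range(nb):
        q_lo, q_hi = qb * block, min((qb + 1) * block, seq_len) - 1
        for kb in range(nb):
            k_lo, k_hi = kb * block, min((kb + 1) * block, seq_len) - 1
            # Some (i, j) in the tile satisfies the mask iff both bounds overlap.
            live = k_lo <= q_hi and k_hi > q_lo - window
            if not live:
                empty += 1
    return empty / (nb * nb)
```

With `seq_len=4096` and `window=512`, 128-wide blocks leave 874 of 1024 tiles skippable (about 85%), while 2048-wide blocks leave only 1 of 4. The sparsity in the mask is identical. Only the block contract changed.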

That static-shape rule is not special to Pallas, but Pallas tends to feel it sooner because block and grid choices are part of the kernel contract. If the sequence or batch surface drifts every step, the team pays both the general XLA recompile cost and a worse tile fit for the custom path. That is why MegaCpp pairs Pallas work with the guardrails from graph recompilation hell and the checked-in TPU compile/runtime control sample instead of treating dynamic batching as a free optimization.
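A minimal sketch of "bucketing first, kernel second," assuming a hypothetical fixed bucket set: round every sequence length up to the nearest bucket and pad. The compile cache, and any BlockSpec tuned for those shapes, then sees a handful of shapes instead of one per step.

```python
import bisect

BUCKETS = (512, 1024, 2048, 4096)  # hypothetical fixed shape set


def bucket_length(n, buckets=BUCKETS):
    """Smallest bucket >= n; refuse to invent a new shape above the largest one."""
    i = bisect.bisect_left(buckets, n)
    if i == len(buckets):
        raise ValueError(f"sequence length {n} exceeds largest bucket {buckets[-1]}")
    return buckets[i]


def pad_to_bucket(tokens, pad_id=0, buckets=BUCKETS):
    """Right-pad a token list to its bucket so the compiled shape stays fixed."""
    target = bucket_length(len(tokens), buckets)
    return list(tokens) + [pad_id] * (target - len(tokens))
```

Lengths from 300 to 4000 in steps of 37 collapse to four compiled shapes. Lengths above the largest bucket raise an error instead of silently creating a new shape. Whether to truncate, split, or add a bucket is a policy decision the sketch deliberately leaves open.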

The workflow that survived

The debugging workflow is simple:

  1. keep mask or layout logic reproducible on CPU first
  2. compare custom-kernel outputs against a trusted reference path
  3. record which backend actually executed
  4. only promote the kernel if it wins clearly enough to justify its maintenance cost
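Steps 2 to 4 can be sketched as a small gate that compares against the reference, records which backend produced the output, and decides promotion. Every name here is hypothetical. The thresholds (`atol=1e-3`, `min_speedup=1.15`) are placeholder assumptions: real tolerances depend on dtype and op, and "wins clearly enough" is a team judgment, not a constant.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class KernelReceipt:
    """Hypothetical record of one validation run for a candidate kernel."""
    op: str
    backend: str        # which backend actually produced the output
    max_abs_err: float  # worst elementwise deviation from the reference path
    speedup: float      # measured custom / reference throughput
    promoted: bool


def max_abs_error(out, ref):
    if len(out) != len(ref):
        raise ValueError(f"shape mismatch: {len(out)} vs {len(ref)}")
    return max((abs(a - b) for a, b in zip(out, ref)), default=0.0)


def validate_and_gate(op, backend, out, ref, speedup, atol=1e-3, min_speedup=1.15):
    """Compare against the reference, record the backend, and gate promotion."""
    err = max_abs_error(out, ref)
    promoted = err <= atol and speedup >= min_speedup
    return KernelReceipt(op, backend, err, speedup, promoted)
```

A kernel that is correct but only marginally faster, or fast but outside tolerance, produces a receipt with `promoted=False`. It does not become a silent default.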

That receipt-first habit is the same one described in profiler and receipts: if the team cannot show which backend ran, the optimization story is not ready to ship.

This is why MegaCpp prefers explicit backend receipts. A TPU configuration should never need log archaeology to answer the question "did the custom kernel actually run?" The same receipt should also make it obvious that this is a TPU and JAX/Pallas lane, not an NVIDIA precision or CUDA-kernel lane.

Routing and fallback are kept explicit in the XLA backend dispatch sample and the XLA backend fallback sample, so backend identity does not depend on reading opaque logs after the fact. The same examples make the ownership split concrete: trace_pallas keeps the hotspot inside PyTorch/XLA, while call_jax widens the cache and runtime surface into a JAX bridge.
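Here is one way to keep routing and fallback explicit, sketched in plain Python with hypothetical names. The dispatcher tries backends in a fixed order and returns a receipt naming the backend that executed plus every fallback taken. Only `NotImplementedError` triggers a fallback. That is a deliberate choice so a genuinely broken custom kernel fails loudly instead of being silently routed around.

```python
from typing import Callable, Sequence, Tuple


def dispatch(op_name: str, backends: Sequence[Tuple[str, Callable]], *args):
    """Run the first backend that accepts the call; return (result, receipt)."""
    fallbacks = []
    for name, fn in backends:
        try:
            result = fn(*args)
        except NotImplementedError as exc:
            # Declared "unsupported case": record it and try the next lane.
            fallbacks.append({"backend": name, "reason": str(exc)})
            continue
        receipt = {"op": op_name, "executed": name, "fallbacks": fallbacks}
        return result, receipt
    raise RuntimeError(f"no backend could run {op_name}: {fallbacks}")


def pallas_lane(x):
    """Hypothetical custom path with a narrow validated envelope."""
    if len(x) % 2:
        raise NotImplementedError("odd length outside validated cases")
    return [2 * v for v in x]


def xla_lane(x):
    """Hypothetical stock-lowering path."""
    return [2 * v for v in x]


result, receipt = dispatch("double", [("trace_pallas", pallas_lane), ("xla", xla_lane)], [1, 2, 3])
```

Here the receipt states that `xla` executed and that `trace_pallas` was skipped with a stated reason. That is the question the receipt exists to answer without reading logs.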

Why MegaCpp is conservative here

Pallas can be excellent for narrow cases, but TPU execution already has enough moving parts: frontend behavior, PJRT, runtime versions, compile caches, and sharding behavior. Every unnecessary custom kernel makes that matrix harder to reason about.

MegaCpp therefore keeps a bias toward deletion:

The public claim

The useful public statement is:

That is consistent with the official Pallas docs and avoids presenting an experimental surface as if it were a settled platform guarantee.

The same discipline carries into the companion posts on libtpu, PJRT, JAX, and ownership boundaries; on Torch XLA, PJRT, and runtime policy reality; and on determinism and bit-exact runs. TPU wins are only real when backend ownership, compile receipts, and correctness contracts all stay explicit.

FAQ


What is the practical difference between trace_pallas and call_jax?
trace_pallas keeps the hotspot inside the PyTorch/XLA-owned custom-kernel lane, while call_jax briefly crosses into a narrower JAX bridge. Both can be useful, but the bridge has a wider cache and runtime surface, which is why MegaCpp asks for explicit backend receipts before calling it a stable optimization path.
Does Pallas make dynamic sequence lengths free?
No. Pallas changes who owns the tile and memory contract; it does not remove JAX/XLA compile specialization. MegaCpp therefore treats dynamic sequence or batch choices as a bucketing problem first and a kernel problem second. If the shape surface keeps changing, the custom path can pay both the normal compilation-cache miss and a worse BlockSpec fit. The local guardrails are the TPU compile/runtime control sample and graph recompilation hell.
Should norms or other small reduction-heavy ops move to Pallas too?
Usually no. MegaCpp's public rule is to leave mostly elementwise or reduction-heavy work in XLA unless there is a demonstrated structural win. Otherwise the custom kernel mostly expands the compile and maintenance surface without changing the real bottleneck. The shortest local proof surfaces are the Pallas kernel selection note and Pallas FA softcap on TPU.
What should a public-safe Pallas backend receipt actually prove?
It should prove the chosen backend and the handoff contract without forcing the reader to depend on private dump paths or internal log spelunking. In this article the checked-in proof surfaces are the XLA backend dispatch sample, the XLA Pallas bridge receipt sample, and the trace_pallas scalar-prefetch sample. They keep the backend choice, validated cases, and traced custom-call payload visible in a public-safe form.
What makes a Pallas BlockSpec risky enough to review separately?
A BlockSpec is not just bookkeeping. It maps each grid invocation to the input or output block the kernel sees, so changing the block shape or index map changes the compiled kernel contract. For TPU, JAX documents additional block-shape restrictions and out-of-bounds padding behavior. That is why MegaCpp reviews grid shrinking, validity normalization, and fixed-bucket shape choices together instead of treating mask compaction as a harmless prelude. The local proof surfaces are the Pallas grid shrinking sample, the Pallas softcap attention sample, and the TPU compile/runtime control sample.
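The review concern is easier to see with a pure-Python model of the documented semantics. This is neither Pallas nor its API. At each grid point, an index map returns block indices. Those indices are scaled by the block shape into element offsets, and any overhang past the array edge becomes padding. Changing either the block shape or the index map changes what every grid step reads.

```python
import itertools
import math


def emulate_blockspec(array_shape, block_shape, index_map, grid):
    """Model which element window each grid step sees under a BlockSpec-like contract.

    index_map(*grid_idx) returns block indices; multiplying by block_shape gives
    element offsets. Any part of a block past the array edge is counted as padding.
    """
    views = {}
    for grid_idx in itertools.product(*(range(g) for g in grid)):
        block_idx = index_map(*grid_idx)
        starts = tuple(b * s for b, s in zip(block_idx, block_shape))
        valid = tuple(
            max(0, min(st + bs, n) - st)
            for st, bs, n in zip(starts, block_shape, array_shape)
        )
        padded = math.prod(block_shape) - math.prod(valid)
        views[grid_idx] = (starts, valid, padded)
    return views


# A (300, 256) array tiled by (128, 128) needs a (3, 2) grid; the last row of
# blocks overhangs by 84 rows.
views = emulate_blockspec((300, 256), (128, 128), lambda i, j: (i, j), grid=(3, 2))
```

Grid step `(2, 0)` sees only 44 valid rows, so 10,752 of its 16,384 elements are padding. Swapping the index map to `lambda i, j: (j, i)` keeps the grid identical but changes which block every step reads. That is the kind of edit that looks small in a diff and still changes the kernel contract.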
Glossary

Terms used in this article

Start here for quick definitions, then follow the linked posts for deeper context.

Pallas

JAX's kernel language for writing explicit TPU kernels when stock XLA lowering is not enough for the required tile, memory-layout, or masking contract.

trace_pallas

The native PyTorch/XLA custom-kernel lane that traces a Pallas kernel into a payload the XLA side can keep without crossing into a generic JAX bridge call.

call_jax

The Torch/XLA bridge lane that hands one narrowed TPU operation to JAX instead of moving the whole program into a JAX-owned frontend path.

Attention

The token-mixing path that turns Q/K/V style projections into context-aware activations. On MLA pages here it refers to the concrete attention module boundary, not the A/M/E/R block-family shorthand.

TPU

Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.

libtpu

The TPU backend library that pairs with PJRT/XLA and owns device-side execution underneath the frontend.

PJRT

The TPU runtime interface between frontend code and the backend executor; it is the ownership seam between JAX/Torch-XLA frontends and libtpu.

segment_ids

The fixed-width segment labeling used to preserve document boundaries without changing the TPU kernel shape.

JAX

A separate frontend above PJRT/libtpu. In these TPU posts it mainly matters as the owner of NamedSharding, PartitionSpec, and the optional call_jax or Pallas-adjacent bridge lanes.

NVFP4

NVIDIA's four-bit floating-point inference/training format family used when the lane can tolerate more aggressive quantization than FP8.

CUDA

NVIDIA's GPU programming stack: compiler, runtime, driver, libraries, and kernel toolchain used by CUDA training and inference lanes.
