MegaCpp Engineering: Applied C++ model systems
Grounded engineering note from the MegaCpp stack
Published · 9 min read · David Gornshtein
Tags: TPU, V6e, XLA, Transformer Engine, Layer Spec, FP8

Transformer Engine replacements on TPU: keeping one model definition across paths

Transformer Engine is an NVIDIA Hopper and Blackwell story. On TPU v6e it does not exist. This note covers the layer-spec abstraction and the XLA-friendly substitutes that let one model definition ship across both paths.

Transformer Engine is the load-bearing fast path on the GPU path: fused norm+linear, FP8 autocast for the QKV matmul, cuDNN flash attention, and fused MoE permute. None of it ports to TPU v6e. The TPU path has a different stack (XLA, Pallas, JAX), a different precision story (no FP8 in deployment today), and a different sharding model.

The interesting engineering question is not "how do we get TE on TPU"; it is "how do we keep one model definition that sees TE on the GPU path and a clean XLA-traceable substitute on the TPU path, without forking any block code." This post is about that abstraction and the substitutes we ended up with.

Why This Matters

Two paths ship the same specialists: CUDA hosts and TPU v6e slices. The training loop, the attention modules, the MoE router, the long-context curriculum, the data pipeline, and the tokenizer are all single-source. The only places allowed to diverge are kernel implementations and the precision plan. If the model definition forks, every feature lands twice and the ablations stop being comparable.

The hardest piece of that constraint is the Transformer Engine surface. On the GPU path, TE owns the highest-throughput path for at least four primitives: pre-norm fused into the next matmul, the QKV/MLP column-parallel linears, FP8 attention via fp8_autocast, and the fused MoE permute kernel. On the TPU path, none of those exist as TE modules.
The substitute has to be (a) numerically equivalent at bf16 precision, (b) traceable by XLA without dynamic shapes, (c) shardable under SPMD without surprise propagation, and (d) selectable at construction time so the same block code lives in both worlds.
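Requirement (a) is the easiest of the four to turn into a test. Below is a minimal sketch of a parity check. It assumes both paths' outputs for the same input have been flattened to Python floats. `to_bf16`, `assert_bf16_parity`, and the two-ulp default tolerance are illustrative choices, not part of the MegaCpp code.

```python
import struct
from typing import Sequence

# Spacing between adjacent bf16 values in [1, 2): bf16 keeps 7 explicit mantissa bits.
BF16_EPS = 2.0 ** -7


def to_bf16(x: float) -> float:
    """Round a finite value (via float32) to the nearest bfloat16, ties to even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bf16 is the top 16 bits of the float32 pattern. Adding 0x7FFF plus the
    # lowest kept bit before truncating gives round-to-nearest, ties-to-even.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def assert_bf16_parity(reference: Sequence[float], candidate: Sequence[float], ulps: float = 2.0) -> None:
    """Raise AssertionError if any pair differs by more than `ulps` bf16 steps.

    The step is scaled by max(|reference|, 1), so small values get an absolute
    tolerance and large values get a relative one.
    """
    if len(reference) != len(candidate):
        raise AssertionError(f"length mismatch: {len(reference)} vs {len(candidate)}")
    for i, (r, c) in enumerate(zip(reference, candidate)):
        r16, c16 = to_bf16(r), to_bf16(c)
        tol = ulps * BF16_EPS * max(abs(r16), 1.0)
        if abs(r16 - c16) > tol:
            raise AssertionError(f"element {i}: {r16} vs {c16} (tolerance {tol})")
```

In practice `reference` would come from the TE module on a GPU host, and `candidate` from the XLA substitute run with the same weights and inputs.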

The Shared Layer-Spec Approach

The pivot is the public TE layer-spec sample. It is a smaller adaptation of Megatron-Core's ModuleSpec pattern: a plain dict[str, type | None] that maps seven block component names (norm, linear_qkv, linear_proj, linear_fc1, linear_fc2, attention, layernorm_mlp) to TE classes when TE is importable and to None otherwise. The mapping mirrors Megatron's TE submodule selection: linear_qkv and linear_fc1 both resolve to LayerNormLinear (pre-norm folded into the next GEMM), linear_proj and linear_fc2 are Linear, attention is DotProductAttention, and layernorm_mlp is the full-block LayerNormMLP fusion.

The contract that makes this useful on TPU is import safety. The module-level _TE_AVAILABLE flag is set inside a try/except that swallows every failure: missing package, ABI mismatch, broken transitive dependency. When TE is unavailable, calling te_layer_spec(use_te=True) returns a dict of None values; the caller substitutes natively and never sees an exception. The same import line lives at the top of a block file regardless of platform.

We deliberately did not copy Megatron's full ModuleSpec machinery: the blocks (Block, ABlock, MBlock, EBlock in the main model runtime module) take module references at __init__, not spec objects with params/submodules fields. A dict is the minimal shim that matches the existing surface and composes with both paths.

On the GPU path, te_layer_spec(use_te=True) returns the seven TE classes and _TEAttentionBlock assembles them into a fused QKV -> DPA -> output-proj block: pre-norm folded into LayerNormLinear, GQA via num_gqa_groups=num_kv_heads, bshd layout, residual sum at the end.
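The shape bookkeeping behind a fused QKV projection with GQA can be sketched as follows. `QKVLayout` is a hypothetical helper. It assumes the fused output stores all query heads, then all key heads, then all value heads. Megatron's own fused QKV weight interleaves heads per KV group instead, so the split must match whichever layout the weights were initialized with.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class QKVLayout:
    num_heads: int
    num_kv_heads: int  # the value handed to TE DPA as num_gqa_groups
    head_dim: int

    def __post_init__(self) -> None:
        if self.num_heads % self.num_kv_heads != 0:
            raise ValueError("num_heads must be a multiple of num_kv_heads for GQA")

    @property
    def queries_per_kv_head(self) -> int:
        return self.num_heads // self.num_kv_heads

    def fused_qkv_width(self) -> int:
        """Output features of the single fused QKV projection (LayerNormLinear)."""
        return (self.num_heads + 2 * self.num_kv_heads) * self.head_dim

    def split_sizes(self) -> Tuple[int, int, int]:
        """Feature-axis widths for splitting the fused output into q, k, v (non-interleaved)."""
        q = self.num_heads * self.head_dim
        kv = self.num_kv_heads * self.head_dim
        return q, kv, kv

    def bshd_shapes(self, batch: int, seq: int) -> Tuple[Tuple[int, ...], ...]:
        """Shapes of q, k, v after reshaping to the (batch, seq, heads, dim) layout."""
        return (
            (batch, seq, self.num_heads, self.head_dim),
            (batch, seq, self.num_kv_heads, self.head_dim),
            (batch, seq, self.num_kv_heads, self.head_dim),
        )
```

For example, 16 query heads over 4 KV heads at head_dim 64 gives a 1536-wide fused projection, split 1024/256/256.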

On the TPU path, where the dict is all None, the substitutes are:

- F.rms_norm + F.linear for the fused-norm-linear keys (XLA HLO fuses the two into a single GEMM).
- The Pallas softcap kernel for attention, exposing the same (q, k, v, doc_ids, softcap) surface as TE DPA.
- A SwiGLU MLP (two F.linear + F.silu(gate) * up) for layernorm_mlp, within single-digit percent of TE's fused kernel.
- nn.Linear + mark_sharding for linear_proj/linear_fc2, letting XLA SPMD pick the collective.
- An equal-split bf16 MoE dispatch in the MoE dispatch runtime module (no .item(), no data-dependent shapes), in place of the fused moe_permute_with_probs/moe_unpermute from the public TE permutation sample.
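To make the first and third substitutes concrete, here is a pure-Python sketch of what the norm-plus-linear and SwiGLU paths compute for a single token. It mirrors F.rms_norm/F.linear semantics on plain lists rather than calling PyTorch. The eps value and the gate-then-up ordering of the fused projection's rows are assumptions.

```python
import math
from typing import List, Optional, Sequence

Vector = List[float]
Matrix = Sequence[Sequence[float]]  # (out_features, in_features), like nn.Linear.weight


def rms_norm(x: Sequence[float], weight: Sequence[float], eps: float = 1e-6) -> Vector:
    """x * rsqrt(mean(x^2) + eps) * weight over the last dim, as F.rms_norm does."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def linear(x: Sequence[float], weight: Matrix, bias: Optional[Sequence[float]] = None) -> Vector:
    """y = x @ weight.T (+ bias), matching F.linear's weight layout."""
    y = [sum(w * xi for w, xi in zip(row, x)) for row in weight]
    return y if bias is None else [yi + b for yi, b in zip(y, bias)]


def silu(z: float) -> float:
    return z / (1.0 + math.exp(-z))  # z * sigmoid(z); fine for moderate |z|


def norm_linear(x: Sequence[float], norm_weight: Sequence[float], weight: Matrix) -> Vector:
    """Substitute for LayerNormLinear: RMSNorm, then a plain projection."""
    return linear(rms_norm(x, norm_weight), weight)


def swiglu_mlp(x: Sequence[float], norm_weight: Sequence[float], w_gate_up: Matrix, w_down: Matrix) -> Vector:
    """Substitute for LayerNormMLP: norm, fused gate/up F.linear, SiLU gating, down F.linear."""
    h = linear(rms_norm(x, norm_weight), w_gate_up)
    ffn = len(h) // 2
    gate, up = h[:ffn], h[ffn:]  # assumption: gate rows come first in w_gate_up
    return linear([silu(g) * u for g, u in zip(gate, up)], w_down)
```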

The public Megatron block sample is the block-level adapter: a Megatron TransformerLayer wrapper that uses the TE layer spec on the GPU path and degrades silently on TPU via MEGATRON_AVAILABLE. It keeps the same (config, layer_idx) constructor signature and the same forward signature.

FP8-on-TPU is the explicit asterisk. We do not ship FP8 on the TPU path today. TE DPA under fp8_autocast is the cleanest FP8 attention path on H200, and the TPU equivalent is waiting for libtpu's per-tensor FP8 to mature for our shapes. The TPU path stays bf16 with the same clipping, the same Muon/AdamW split, and the same loss target.
FP8 is a precision-plan object, not a structural choice: on at construction on the GPU path, a no-op on the TPU path.
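Here is one possible shape for such a precision-plan object. It assumes the plan is built once at startup from the backend and a TE-availability flag. `PrecisionPlan` and its fields are hypothetical. `te.fp8_autocast(enabled=True)` is the Transformer Engine context manager. It is only imported when FP8 is actually active, so the TPU path never touches it.

```python
import contextlib
from dataclasses import dataclass


@dataclass(frozen=True)
class PrecisionPlan:
    """Hypothetical precision plan: the FP8 decision is made once, at construction time."""

    backend: str  # "cuda" or "tpu"
    want_fp8: bool = False  # what the run asked for
    te_available: bool = False  # result of the import-safe TE probe

    @property
    def fp8_active(self) -> bool:
        # FP8 only applies on the GPU path, and only if TE actually imported.
        return self.backend == "cuda" and self.want_fp8 and self.te_available

    def describe(self) -> str:
        """One-line summary meant to be logged at startup."""
        if self.fp8_active:
            return "precision plan: FP8 via TE fp8_autocast on selected modules, bf16 elsewhere"
        return f"precision plan: bf16 everywhere on {self.backend}"

    def autocast(self):
        """Context manager wrapping the forward pass; a no-op unless FP8 is active."""
        if self.fp8_active:
            import transformer_engine.pytorch as te  # only reached on the GPU path
            return te.fp8_autocast(enabled=True)
        return contextlib.nullcontext()
```

Block code calls `with plan.autocast():` unconditionally. The structural model is identical on both paths, and only the plan's answer differs.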

How it lands in deployment

The lift-as-is parts: the seven-key dict surface in the public TE layer-spec sample, the _TEAttentionBlock assembly, the import-safety pattern. They become the canonical "TE or native" selector and the model definition stops carrying ad-hoc branching anywhere else.

Rewritten on the way in:

  1. The native fallback is consolidated. MegaCpp has a dozen places in the main model runtime module where the block code checks whether spec[key] is None and assembles the substitute inline. Production centralizes that into a NativeLayerSpec factory with the same key set. The block code calls one factory regardless of platform; the factory returns either TE classes or the XLA-friendly substitutes.
  2. The TPU-path substitutes get explicit SPMD partition specs at construction time. Relying on XLA sharding propagation to infer the spec from the surrounding graph has bitten this code repeatedly. Production pins every substitute parameter with mark_sharding at the same call site that constructs it.
  3. The FP8 plan becomes a precision-plan object rather than a constructor flag. On the GPU path it wraps the relevant linears with fp8_autocast; on the TPU path it is bf16 everywhere with a clear log line saying so.
  4. The post-hoc rewrite of every nn.Linear with te.Linear from the public TE linear-replacement sample becomes the GPU-path initialization step that runs after model construction and before FSDP wrapping. The exclusion list (wte, wpe, lm_head, router, shared_expert_gate, engram, mhc, ngram_hash, structure_emb, platform_emb, temporal_, lora_) lifts as-is. On the TPU path the rewrite is a no-op.
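Below is a sketch of the exclusion filter from item 4. It assumes patterns match as substrings of the dotted module name; the post lists the patterns but not the matching rule. `should_swap_to_te` and `plan_te_swaps` are hypothetical names. In a PyTorch model, the input list would come from `model.named_modules()` filtered to `nn.Linear`.

```python
from typing import Iterable, List

# Exclusion patterns as listed in the post. Substring matching is an assumption.
TE_LINEAR_EXCLUDE = (
    "wte", "wpe", "lm_head", "router", "shared_expert_gate", "engram",
    "mhc", "ngram_hash", "structure_emb", "platform_emb", "temporal_", "lora_",
)


def should_swap_to_te(module_name: str, backend: str) -> bool:
    """True if the nn.Linear at `module_name` may be rewritten to te.Linear."""
    if backend != "cuda":
        return False  # the rewrite is a no-op on the TPU path
    return not any(pattern in module_name for pattern in TE_LINEAR_EXCLUDE)


def plan_te_swaps(linear_names: Iterable[str], backend: str) -> List[str]:
    """Linears to replace; run after model construction and before FSDP wrapping."""
    return [name for name in linear_names if should_swap_to_te(name, backend)]
```

Substring matching is deliberately coarse. A name like `router_bias` is excluded along with `router`, which errs on the side of keeping full precision.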

Dropped: the per-experiment ad-hoc branching for "if TE installed do X else do Y" that grew up across half a dozen modules. It is replaced with the single dict surface.

Moved to a kernel/Pallas path: the attention substitute. The block code calls a single attention adapter; on the GPU path that adapter is TE DPA, on the TPU path it is the Pallas softcap kernel. Same call signature, two backends, one selector.
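A shared call signature is only useful if both backends compute the same thing. A slow reference like the one below can check that. It is neither the Pallas kernel nor TE DPA. It is a single-head pure-Python sketch that assumes causal masking plus a same-document mask for packed sequences, with tanh logit softcapping. The (q, k, v, doc_ids, softcap) argument order follows the post; everything else is an assumption.

```python
import math
from typing import List, Sequence

Rows = Sequence[Sequence[float]]  # (seq, head_dim) for one head


def reference_softcap_attention(
    q: Rows, k: Rows, v: Rows, doc_ids: Sequence[int], softcap: float
) -> List[List[float]]:
    """Slow single-head reference for the (q, k, v, doc_ids, softcap) adapter surface.

    Assumes causal masking, a same-document mask, and softcap > 0. Logits are
    squashed into (-softcap, softcap) with tanh before the softmax.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out_dim = len(v[0])
    out: List[List[float]] = []
    for i in range(len(q)):
        logits = []
        for j in range(i + 1):  # causal: key positions 0..i
            if doc_ids[j] != doc_ids[i]:
                continue  # never attend across a document boundary
            s = scale * sum(a * b for a, b in zip(q[i], k[j]))
            logits.append((j, softcap * math.tanh(s / softcap)))
        m = max(logit for _, logit in logits)  # position i always sees itself, so non-empty
        weights = [(j, math.exp(logit - m)) for j, logit in logits]
        z = sum(w for _, w in weights)
        out.append([sum(w * v[j][d] for j, w in weights) / z for d in range(out_dim)])
    return out
```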

Become feature flags: --use_te_block_layers, --use_te_all_linears, and the FP8 precision plan. All three are no-ops on the TPU path, and the entry point logs which subset is active so an ablation across paths is unambiguous.
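Here is one way to wire this, sketched with argparse. `--use_te_block_layers` and `--use_te_all_linears` are the flag names from the post. `--fp8` and `--backend` are hypothetical stand-ins, since the post does not name the FP8 switch or say how the backend is chosen. The point is that the backend masks the requested flags before anything is built, and the resolved subset is logged as one line.

```python
import argparse
import logging
from typing import Dict


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser()
    p.add_argument("--use_te_block_layers", action="store_true")
    p.add_argument("--use_te_all_linears", action="store_true")
    p.add_argument("--fp8", action="store_true", help="hypothetical FP8 precision-plan switch")
    p.add_argument("--backend", choices=("cuda", "tpu"), default="cuda")  # hypothetical
    return p


def resolve_te_features(args: argparse.Namespace, te_available: bool) -> Dict[str, bool]:
    """Requested flags masked by what the backend can honour; TPU masks all three off."""
    on_gpu_with_te = args.backend == "cuda" and te_available
    return {
        "use_te_block_layers": args.use_te_block_layers and on_gpu_with_te,
        "use_te_all_linears": args.use_te_all_linears and on_gpu_with_te,
        "fp8": args.fp8 and on_gpu_with_te,
    }


def log_active(features: Dict[str, bool], backend: str) -> str:
    """Emit (and return) one unambiguous line naming the active subset."""
    active = sorted(name for name, on in features.items() if on) or ["none"]
    line = f"backend={backend} active_te_features={','.join(active)}"
    logging.getLogger("entry").info(line)
    return line
```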

Ablations and what we kept

The seven-key spec, summarised across paths:

| Spec key | GPU path (TE) | TPU path substitute | Notes |
|---|---|---|---|
| norm + linear_qkv | LayerNormLinear | F.rms_norm + F.linear | XLA fuses to a single GEMM |
| norm + linear_fc1 | LayerNormLinear | F.rms_norm + F.linear | Same fusion path |
| linear_proj, linear_fc2 | te.Linear | nn.Linear + mark_sharding | XLA SPMD picks the collective |
| attention | DotProductAttention | Pallas FA softcap kernel | One adapter, two backends |
| layernorm_mlp | LayerNormMLP | SwiGLU MLP (two F.linear) | Off by default; harder to TP-shard |
| MoE permute | Fused permute (public TE permutation sample) | bf16 equal-split path | Same (idx, gates) surface |
| FP8 plan | FP8 attn + FP8 experts + bf16 rest | bf16 everywhere | Logged on startup |
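The "bf16 equal-split path" row can be sketched as a fixed-capacity dispatch. We read "equal-split" as every expert getting the same number of slots, with capacity = num_tokens * top_k // num_experts and overflow assignments dropped. Both that reading and the drop policy are assumptions. The Python loops stand in for the one-hot and cumulative-sum tensor ops an XLA version would use. What matters is that every shape comes from config, never from routing results, so nothing needs `.item()`.

```python
from typing import List, Sequence, Tuple


def equal_split_dispatch(
    expert_idx: Sequence[Sequence[int]],  # (num_tokens, top_k) routed expert ids
    gates: Sequence[Sequence[float]],  # (num_tokens, top_k) routing weights
    num_experts: int,
) -> Tuple[List[List[int]], List[List[float]]]:
    """Pack (token, k) assignments into fixed [num_experts, capacity] slot tables.

    Empty slots hold token -1 and gate 0.0; assignments beyond an expert's
    capacity are dropped.
    """
    num_tokens, top_k = len(expert_idx), len(expert_idx[0])
    capacity = max(1, num_tokens * top_k // num_experts)  # static: depends only on shapes
    slots = [[-1] * capacity for _ in range(num_experts)]
    slot_gates = [[0.0] * capacity for _ in range(num_experts)]
    fill = [0] * num_experts  # running per-expert count (a cumsum over one-hot in tensor form)
    for t in range(num_tokens):
        for k in range(top_k):
            e = expert_idx[t][k]
            pos = fill[e]
            fill[e] += 1
            if pos < capacity:  # in tensor form this is a mask, not a host-side branch
                slots[e][pos] = t
                slot_gates[e][pos] = gates[t][k]
    return slots, slot_gates


def combine(
    expert_out: Sequence[Sequence[Sequence[float]]],  # (num_experts, capacity, dim)
    slots: Sequence[Sequence[int]],
    slot_gates: Sequence[Sequence[float]],
    num_tokens: int,
) -> List[List[float]]:
    """Gate-weighted scatter of expert outputs back to token order."""
    dim = len(expert_out[0][0])
    out = [[0.0] * dim for _ in range(num_tokens)]
    for e, row in enumerate(slots):
        for p, t in enumerate(row):
            if t >= 0:
                for d in range(dim):
                    out[t][d] += slot_gates[e][p] * expert_out[e][p][d]
    return out
```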

The ablation history on this path is mostly about what failed silently. A few items shaped the current form:

The TE in_proj fusion for the Mamba mixer (from the public TE input projection sample) is the cleanest "TE win on GPU, no-op on TPU" change we have. Replacing the Mamba nn.Linear in_proj with TELayerNormColumnParallelLinear folds the LN into the column-parallel projection, drops one kernel launch, and matches the surrounding TE block precision plan. On the TPU path the same module is a plain F.rms_norm + F.linear; the XLA fuser does the right thing and the loss curves overlay. Sentinel values inside the wrapper let the block code stay agnostic.

tp_comm_overlap=True did not survive contact. The config builder in the public Megatron block sample removed it after an audit of the TE extension layer. Setting the flag requires a matching te.initialize_ub(...) call before model construction, which our GPU-path bring-up does not make. Leaving the flag on would have produced a latent crash the moment someone ran --use_megatron_block --megatron_tp --tensor_parallel=2 --sequence_parallel. We kept the five real fusions (masked_softmax_fusion, persist_layer_norm, attention_softmax_in_fp32=False, apply_rope_fusion, gradient_accumulation_fusion), which do not depend on user-buffer overlap.
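Here is a sketch of a config builder that keeps the five fusions and pins the overlap flag off. The field names are the Megatron-Core config fields quoted above. `megatron_fusion_kwargs` is hypothetical. Treating the four unqualified fusions as enabled, and failing loudly when overlap is requested, are our additions; the post only says the flag was removed.

```python
from typing import Any, Dict

# The five fusions kept in the post; none of them depends on TE user-buffer overlap.
KEPT_FUSIONS: Dict[str, Any] = {
    "masked_softmax_fusion": True,
    "persist_layer_norm": True,
    "attention_softmax_in_fp32": False,
    "apply_rope_fusion": True,
    "gradient_accumulation_fusion": True,
}


def megatron_fusion_kwargs(request_tp_comm_overlap: bool = False) -> Dict[str, Any]:
    """Fusion-related keyword arguments for a Megatron-Core TransformerConfig (sketch).

    tp_comm_overlap=True needs te.initialize_ub(...) before model construction,
    which this bring-up never calls, so asking for it fails here instead of
    crashing later under tensor + sequence parallelism.
    """
    if request_tp_comm_overlap:
        raise ValueError("tp_comm_overlap requires te.initialize_ub(...) before model construction")
    kwargs = dict(KEPT_FUSIONS)
    kwargs["tp_comm_overlap"] = False  # pinned off explicitly
    return kwargs
```

The returned dict would be passed to `TransformerConfig(...)` with `**` alongside the model-shape arguments.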

The LayerNormMLP full-block fuse is exposed but not the default. Megatron does not use it by default because it makes fc1/fc2 sharding harder for TP; for single-GPU and FSDP-only runs it is the fastest path on the CUDA side. On the TPU path the substitute is the SwiGLU MLP we already had. The block code reads spec["layernorm_mlp"]; if non-None it uses the full-block fuse, otherwise it uses the substitute.

The "use TE everywhere via post-hoc replacement" path (the public TE linear-replacement sample) survived contact, but only with the exclusion list. Replacing embeddings with te.Linear clobbers the tied lm_head weight. Replacing LoRA adapters wraps an existing Linear and breaks the rank decomposition. Replacing the MoE router silently FP8-quantizes a 1024-wide projection that needs full bf16 precision to keep routing decisions stable. The exclusion list is a load-bearing piece of the contract.
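One way to keep the exclusion list explicit is to decide replacements from module names before touching any weights. The sketch below is minimal and assumes naming tags (`embed`, `lm_head`, `lora_`, `router`). Those tags are hypothetical, and real module names depend on the model code. In PyTorch the `(name, is_linear)` pairs would come from `model.named_modules()` plus an `isinstance(m, torch.nn.Linear)` check.

```python
# Hypothetical exclusion tags; real module names depend on the model code.
EXCLUDE_TAGS = ("embed", "lm_head", "lora_", "router")


def should_replace(name, is_linear):
    """True only for plain Linear layers outside the exclusion list."""
    if not is_linear:
        return False
    lowered = name.lower()
    return not any(tag in lowered for tag in EXCLUDE_TAGS)


def plan_replacements(modules):
    """modules: iterable of (qualified_name, is_linear) pairs.

    Returns the names that are safe to swap for te.Linear, in input order.
    """
    return [name for name, is_linear in modules if should_replace(name, is_linear)]
```

Computing the plan up front also makes it loggable, so a reviewer can see that the router and the tied head were skipped.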

What we tried and did not keep: a "single FP8 plan that works everywhere" that conflated H200 FP8 with a hypothetical TPU FP8. The two are not the same. We split them back into separate precision plans. The TPU plan is explicitly bf16 today. The GPU plan is FP8 attention plus FP8 experts, with bf16 for everything else. Each plan logs its full configuration on startup; reading "FP8 plan: bf16 everywhere" on a TPU run is the right kind of unsurprising.
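Separate precision plans can be plain per-backend value objects that describe themselves at startup. This is a minimal sketch, not MegaCpp's precision-plan class. `PrecisionPlan`, `plan_for`, `log_plan`, and the `"tpu"`/`"cuda"` backend keys are hypothetical. Only the "FP8 plan: bf16 everywhere" wording comes from the text above, and the GPU line format is illustrative.

```python
import logging
from dataclasses import dataclass


@dataclass(frozen=True)
class PrecisionPlan:
    """Hypothetical per-backend precision plan; FP8 toggles are explicit."""
    backend: str
    fp8_attention: bool
    fp8_experts: bool
    default_dtype: str = "bf16"

    def describe(self):
        fp8 = [name for name, on in (("attention", self.fp8_attention),
                                     ("experts", self.fp8_experts)) if on]
        if not fp8:
            return f"FP8 plan: {self.default_dtype} everywhere"
        return f"FP8 plan: fp8 {'+'.join(fp8)}, {self.default_dtype} rest"


def plan_for(backend):
    # One plan per backend; no shared "FP8 everywhere" object.
    if backend == "tpu":
        return PrecisionPlan("tpu", fp8_attention=False, fp8_experts=False)
    if backend == "cuda":
        return PrecisionPlan("cuda", fp8_attention=True, fp8_experts=True)
    raise ValueError(f"no precision plan for backend {backend!r}")


def log_plan(plan, logger=None):
    # Log the full plan once at startup so a bf16-only TPU run is obvious.
    line = plan.describe()
    (logger or logging.getLogger(__name__)).info("[%s] %s", plan.backend, line)
    return line
```

An unknown backend raises instead of silently inheriting another backend's FP8 settings. This prevents the conflation the single plan allowed.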

Production checklist

The block-side selector, in one place:

from te_layer_spec import te_layer_spec

spec = te_layer_spec(use_te=True)  # all-None on TPU; TE classes on GPU
norm_qkv = spec["linear_qkv"] or RmsNormLinear     # substitute on TPU
attn     = spec["attention"]   or pallas_fa_softcap_adapter
ln_mlp   = spec["layernorm_mlp"] or SwiGLUMLP
# FP8 is a separate precision-plan object, no-op on TPU.
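The `spec[key] or Substitute` pattern resolves one slot at a time. If a new key lands in the spec without a substitute, nothing fails until that slot is used. The sketch below checks the whole surface at once. `resolve_layer_spec` is a hypothetical helper, and the test uses strings as stand-ins for the classes named in the checklist.

```python
def resolve_layer_spec(spec, substitutes):
    """Hypothetical helper: fill each None slot with its substitute, or fail.

    spec: key -> class or None (None means "use the substitute on this path").
    substitutes: key -> substitute class for the non-TE path.
    """
    unknown = sorted(set(substitutes) - set(spec))
    if unknown:
        raise ValueError(f"substitutes for keys not in the spec: {unknown}")
    gaps = sorted(k for k, v in spec.items() if v is None and k not in substitutes)
    if gaps:
        raise ValueError(f"no substitute for layer-spec keys: {gaps}")
    return {k: v if v is not None else substitutes[k] for k, v in spec.items()}
```

Running this once at model construction makes a gap between the TE and non-TE paths a startup error rather than a mid-run surprise.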
FAQ

Frequently asked questions

What would replace nn.Linear if it stops preserving layout cleanly?
The next candidate is not another TPU-only branch in the block. It is a native projection wrapper, likely shaped around an explicit einsum, that keeps the batch, sequence, and model axes visible to XLA instead of flattening them behind a generic matmul. If that path lands, it should still be selected behind the same seven-key layer-spec surface and carry the same regression proof: matching signatures, explicit sharding, and a loss-curve overlay against the TE path.
Glossary

Terms used in this article

Start here for quick definitions, then follow the linked posts for deeper context.

Pallas

JAX's kernel language for writing explicit TPU kernels when stock XLA lowering is not enough for the required tile, memory-layout, or masking contract.

Transformer Engine

NVIDIA's Transformer Engine library path for accelerated Transformer modules and lower-precision training surfaces such as FP8, kept behind optional adapter seams in these posts.

mark_sharding(...)

PyTorch/XLA's explicit tensor-placement annotation API: attach a mesh plus partition spec to a tensor so one TPU XLA program lowers with stable owned placement.

XLA

The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.

FP8

Eight-bit floating-point training and inference formats used to trade precision for throughput and memory on recent accelerator lanes.

ablock

The attention-heavy block family in MegaCpp's A/M/E/R notation.

mblock

The state-space or Mamba-family block in MegaCpp's A/M/E/R notation.

eblock

The expert / MoE block family in MegaCpp's A/M/E/R notation.

Attention

The token-mixing path that turns Q/K/V style projections into context-aware activations. On MLA pages here it refers to the concrete attention module boundary, not the A/M/E/R block-family shorthand.

TPU

Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.

TP

Tensor parallelism splits each linear's weights (QKV, O, MLP gate/up/down) across GPUs. On 8× H200 with TP=8 each GPU owns 1/8 of every matmul's columns or rows, so one big matmul becomes 8 smaller ones that all-reduce at the layer boundary. Cost: one all-reduce per attention and per MLP — heavy bandwidth, so TP is usually bound to a single NVLink/NVSwitch island (1 node of up to 8 GPUs). Embeddings, layernorms, and optimizer state stay replicated across the TP GPUs. Use TP when a single layer's weights don't fit on one GPU, not to scale past one node.

libtpu

The TPU backend library that pairs with PJRT/XLA and owns device-side execution underneath the frontend.

JAX

A separate frontend above PJRT/libtpu. In these TPU posts it mainly matters as the owner of NamedSharding, PartitionSpec, and the optional call_jax or Pallas-adjacent bridge lanes.

Megatron Core

The NVIDIA framework surface MegaCpp ports into through narrow adapters, layer specs, and runtime ownership bridges.

doc_ids

The fixed-width per-token document identifiers that keep packed rows auditable and let TPU masking respect document boundaries.

XLA SPMD

The explicit TPU sharding mode where one compiled program carries placement rules instead of rank-local imperative code.

MoE

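Mixture of Experts: layers that route each token to a subset of expert MLPs. In these posts the term covers Token Choice vs Expert Choice, null-expert debugging, gating stability, and the production routing decisions behind the MegaCpp SLM Ensemble.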

key set

The selected sparse key positions that survive routing and stay visible to the later score or mask update path.

CUDA

NVIDIA's GPU programming stack: compiler, runtime, driver, libraries, and kernel toolchain used by CUDA training and inference lanes.

H200

NVIDIA's Hopper H200 GPU platform, typically discussed here as an 8-GPU training node with large HBM capacity and NVLink-connected ranks.
