MegaCpp Engineering: Applied C++ model systems
Grounded engineering note from the MegaCpp stack
By David Gornshtein · 7 min read
TPU
V6e
XLA
SPMD
Tokenizer
Vocab
Embeddings

Vocab and Tokenizer Plumbing on TPU: What XLA SPMD Makes You Decide Up Front

Vocab-size constraints under XLA, the padding choices that keep the compile cache stable, sharded embedding init under SPMD, and the per-specialist platform vocab story.


The C++ specialist family in this codebase runs the same tokenizer and the same model definition on H200 and TPU v6e. The interesting part is not the tokenizer itself; it is a HuggingFace BPE artefact and a thin Python adapter. The interesting part is everything around it on the TPU path: how the vocab size becomes a static compile constant, how the embedding row count gets padded so the compile cache does not blow up, how the embedding parameter is sharded under SPMD without falling into XLA's propagation traps, and how the per-document platform vocab gets pulled through the dataloader without forcing a synchronization point.

Why the TPU path has more rules

On the GPU path the tokenizer story is boring: build the BPE, pad the vocab to a multiple of 64 for tensor cores, register the embedding, train. On the TPU path the same sequence runs into three constraints the GPU path mostly ignores. First, the vocab size is part of the XLA HLO graph; a change in vocab_size (or in the padded vocab size) is a recompile. Padding decisions that are throwaway on H200 become load-bearing on v6e. Second, the embedding parameter is the largest single tensor in the model on small specialists, sitting in the propagation graph immediately under both wte and lm_head; whatever sharding XLA infers for one propagates into the other through tied or untied weights, and through the loss path. Third, the tokenizer is not the only vocabulary the dataloader pushes onto the device: a per-document platform vocabulary (113 IDs covering OS, RTOS, GPU, arch, compiler, C++ standard) gets summed into a single embedding per document and added to every token in that document's row.

The two tokenizers and one padding shim

The tokenizer sample is the generic GPT-style HuggingFace BPE wrapper: a train_from_iterator(text_iterator, vocab_size) that builds a BPE with byte-fallback, a GPT-4-style pre-tokenizer regex, a ByteLevel decoder, and a fixed special-token list. The pre-tokenizer split was deliberately tuned to \p{N}{1,2} rather than \p{N}{1,3} because the wider form wastes vocabulary on multi-digit number tokens that small models never recover. The C++-aware variant adds a fixed C++ vocabulary, special-token aliases, and a decode path that preserves the expected surface for code rows.
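The effect of the digit quantifier is easy to see in isolation. The sketch below is an illustration only: it uses the standard-library re module with \d as a stand-in for \p{N} (an assumption, since re has no \p{...} classes and \d covers only decimal digits), and it shows just the digit branch rather than the full GPT-4-style pattern. The helper name digit_chunks is hypothetical.

```python
import re

# Stand-in patterns: \d approximates \p{N} because the stdlib `re` module
# has no \p{...} classes. Only the digit branch of the pre-tokenizer is shown.
NARROW = re.compile(r"\d{1,2}")   # analogous to \p{N}{1,2}
WIDE = re.compile(r"\d{1,3}")     # analogous to \p{N}{1,3}

def digit_chunks(pattern, text):
    """Return the digit runs that would reach BPE as separate pre-token pieces."""
    return pattern.findall(text)

narrow_chunks = digit_chunks(NARROW, "int n = 1234567;")
wide_chunks = digit_chunks(WIDE, "int n = 1234567;")

print(narrow_chunks)  # digit pieces of at most two characters
print(wide_chunks)    # digit pieces of at most three characters
```

Because BPE merges stay inside a pre-token chunk, the narrow form caps the number of distinct ASCII digit pieces at 10 + 100 = 110, while the wide form allows up to 10 + 100 + 1000 = 1110.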

The vocab size for the target specialist is 65536. Picking a power of two is not aesthetic; it is the largest XLA optimisation we get for free. 65536 is divisible by every TP degree we use (2, 4, 8), is already a multiple of 64 so the padding shim leaves it unchanged, and leaves the per-shard vocab dimension at a round size (8192 rows at TP=8) when we shard rows of lm_head across the model axis. A vocab of 65000 would not be as clean: it still divides by 8, but into 8125 rows per rank, which is not a multiple of 64, and the pad would grow the allocation to 65024, so the configured and allocated sizes diverge and every sharding change is a recompile against an awkward per-rank shape.
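The divisibility argument can be checked with a few lines of arithmetic. This is a minimal sketch, assuming the TP degrees (2, 4, 8) and the pad multiple of 64 mentioned in this article; the helper names pad_to_multiple and shard_report are hypothetical and not part of the codebase.

```python
def pad_to_multiple(n, multiple):
    """Round n up to the next multiple (ceil-divide, then multiply back)."""
    return ((n + multiple - 1) // multiple) * multiple

def shard_report(vocab_size, tp_degrees=(2, 4, 8), pad_to=64):
    """Per TP degree, return (padded_vocab, rows_per_rank, leftover_rows)."""
    padded = pad_to_multiple(vocab_size, pad_to)
    return {tp: (padded, padded // tp, padded % tp) for tp in tp_degrees}

report_65536 = shard_report(65536)
report_65000 = shard_report(65000)

for name, report in (("65536", report_65536), ("65000", report_65000)):
    for tp, (padded, rows, leftover) in report.items():
        print(f"vocab={name} tp={tp} padded={padded} rows_per_rank={rows} leftover={leftover}")
```

For 65536 the padded size equals the configured size and each rank at TP=8 holds 8192 rows; for 65000 the allocation grows to 65024, so the configured and allocated sizes differ.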

The padding shim is in the main model runtime module. GPT.__init__ takes a pad_vocab_size_to=64 argument and rounds the configured vocab_size up to the next multiple before allocating wte and lm_head:

padded = ((cfg.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to) * pad_vocab_size_to
self.wte    = nn.Embedding(padded, cfg.n_embd)
self.lm_head = nn.Linear(cfg.n_embd, padded, bias=False)
# Slice logits back to cfg.vocab_size at the loss boundary.

The forward slices logits back to vocab_size at the loss boundary so the extra rows never participate in cross-entropy. On the GPU path that pad keeps the matmul on the tensor-core fast path. On the TPU path it does the same job for VMEM tile sizes and, more importantly, makes the embedding shard-friendly: every TP degree we use divides 64, so a row-sharded lm_head always sees an even per-rank slice.
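Why the slice matters can be shown with a toy cross-entropy in plain Python. This is a sketch under stated assumptions, not the model's forward: the sizes (5 real columns, 8 padded) and the logit values are invented for illustration, and the pad columns are given large values on purpose to show the worst case.

```python
import math

def cross_entropy(logits, target):
    """Numerically stable softmax cross-entropy for a single token."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]

vocab_size, padded = 5, 8  # toy sizes; the article uses 65536 with a pad multiple of 64
# The last three columns stand for pad rows that no token ID ever maps to.
logits = [0.1, 2.0, -1.0, 0.5, 0.3, 9.0, 9.0, 9.0]
target = 1

loss_sliced = cross_entropy(logits[:vocab_size], target)  # pad columns excluded
loss_unsliced = cross_entropy(logits, target)             # pad columns leak into the softmax

print(loss_sliced, loss_unsliced)
```

Without the slice, whatever values the pad columns happen to hold enter the softmax normaliser and inflate the loss, even though no target can ever point at them.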

Embedding sharding under SPMD

There are two annotation surfaces. On the CUDA-native TP path with sequence-parallel attention on, the embedding uses parallelize_module with RowwiseParallel(input_layouts=Replicate(), output_layouts=sp_layout, use_local_output=True), the final RMSNorm runs as a SequenceParallel module, and the LM head accepts the seq-sharded input and produces vocab-sharded output via ColwiseParallel. Without SP, the embedding stays replicated and the LM head still gets row-sharded across TP ranks; "row-sharded" here refers to the rows of lm_head.weight, which is the vocab dimension that ColwiseParallel splits. A _vocab_parallel_enabled marker attribute is set on lm_head.weight so the fused loss path knows it is operating on per-rank vocab shards.

On the XLA SPMD path under _apply_tensor_parallel_sharding, the embedding stays replicated across the model axis. The comment in the TPU training launcher is explicit: the LM head row-shard is a CUDA-only optimisation, and the TPU loss path operates on the full vocab tensor with the model-axis collective fused into the cross-entropy. The reason for the asymmetry is that XLA's sharding propagation around the loss is more reliable than the analogous DTensor path on CUDA, and the TPU compile cache is much more sensitive to sharding-spec churn than the eager CUDA path is. We pin both decisions explicitly via xs.mark_sharding; an unannotated embedding under propagation is a future precision bug waiting to happen. That is the axis mismatch behind the policy: tokens arrive batch-sharded, while a row-sharded vocab table is partitioned over vocab rows, so the first lookup already asks XLA to reconcile two different ownership stories.
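To make the ownership mismatch concrete, here is a toy model in plain Python. It is not the torch_xla API and not the project's sharding code: it assumes an even, contiguous split of vocab rows across model-axis ranks, and the function names owning_rank and split_lookups are hypothetical.

```python
def owning_rank(token_id, padded_vocab, tp):
    """Model-axis rank that would hold this row under an even, contiguous row shard."""
    rows_per_rank = padded_vocab // tp
    return token_id // rows_per_rank

def split_lookups(token_ids, padded_vocab, tp, model_rank):
    """Partition one device's token IDs into locally held rows and rows held elsewhere."""
    local, remote = [], []
    for t in token_ids:
        if owning_rank(t, padded_vocab, tp) == model_rank:
            local.append(t)
        else:
            remote.append(t)
    return local, remote

# A device sees every token of its batch slice, but only 1/tp of the table rows.
tokens_on_device = [17, 8191, 8192, 40000, 65535]
local, remote = split_lookups(tokens_on_device, padded_vocab=65536, tp=8, model_rank=0)

print("local rows:", local)
print("rows held by other ranks:", remote)
```

With a row-sharded table most lookups land on rows another rank holds, so the very first operation needs a model-axis collective. With a replicated table every lookup is local on every rank.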

The platform vocab and how it crosses the data-loader boundary

The platform-vocabulary sample defines a flat 113-ID space (ID 0 is padding, IDs 1..112 cover the six categories). The categories share one ID space deliberately so the embedding is a single bounded-sum metadata lane rather than six separate tables. MAX_PLATFORM_IDS = 20 is the per-document buffer cap; any document with more than 20 active labels gets truncated, which has not happened on observed rows but is enforced for buffer-shape determinism.
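The bounded-sum lane can be sketched in plain Python. The 113-ID space and the cap of 20 come from the article; everything else is an assumption for illustration: the embedding width, the random table values, the all-zero padding row (so padded slots add nothing to the sum), and the helper names platform_vector and add_to_tokens.

```python
import random

NUM_PLATFORM_IDS = 113   # from the article: ID 0 is padding, IDs 1..112 are labels
MAX_PLATFORM_IDS = 20    # per-document buffer cap from the article
EMBED_DIM = 4            # toy width, chosen only for this sketch

rng = random.Random(0)
# Hypothetical table: row 0 is all zeros so padding slots contribute nothing.
table = [[0.0] * EMBED_DIM] + [
    [rng.uniform(-1.0, 1.0) for _ in range(EMBED_DIM)]
    for _ in range(NUM_PLATFORM_IDS - 1)
]

def platform_vector(padded_ids):
    """Sum the rows for one document's fixed-width ID buffer into a single vector."""
    if len(padded_ids) != MAX_PLATFORM_IDS:
        raise ValueError("platform ID buffer must be exactly MAX_PLATFORM_IDS wide")
    return [sum(table[i][d] for i in padded_ids) for d in range(EMBED_DIM)]

def add_to_tokens(token_embeddings, doc_vector):
    """Broadcast the document vector onto every token embedding in the row."""
    return [[x + v for x, v in zip(tok, doc_vector)] for tok in token_embeddings]

ids = [3, 41, 97] + [0] * (MAX_PLATFORM_IDS - 3)
doc_vec = platform_vector(ids)
row = add_to_tokens([[0.0] * EMBED_DIM for _ in range(5)], doc_vec)
```

The point of the single shared ID space is visible here: one table, one fixed-width buffer, one sum, regardless of how many of the six categories a document populates.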

The dataloader plumbing pulls platform IDs through three call sites in the public dataloader sample. On the parquet path, when the row group has a platform_ids column, IDs are read directly; when only the raw platform_info JSON is present, platform_info_to_ids(pi) walks the six categories and emits a sorted unique list. On the materialisation path, the per-batch buffer padded_platform = torch.zeros((B, MAX_PLATFORM_IDS), dtype=torch.long, device=device) is filled row by row with truncation to MAX_PLATFORM_IDS. The shape is fixed at construction time so the platform path sees the same (B, MAX_PLATFORM_IDS) rectangle on every step. The table being only 113 IDs does not relax that rule: on TPU the dangerous part is shape drift in the live batch, not raw vocabulary size.
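The fixed-rectangle rule can be sketched without any tensor library. This is an illustration, not the dataloader's code: it uses lists of lists in place of the torch buffer, the helper names to_fixed_width and build_platform_buffer are hypothetical, and keeping the lowest IDs on truncation is an assumption, since the article does not say which labels are dropped.

```python
MAX_PLATFORM_IDS = 20  # cap from the article

def to_fixed_width(ids, width=MAX_PLATFORM_IDS, pad_id=0):
    """Sort, de-duplicate, truncate to `width`, then right-pad with the padding ID."""
    unique_sorted = sorted(set(ids))[:width]
    return unique_sorted + [pad_id] * (width - len(unique_sorted))

def build_platform_buffer(batch_ids):
    """Return a rectangular B x MAX_PLATFORM_IDS list of lists for one batch."""
    return [to_fixed_width(ids) for ids in batch_ids]

# Short, empty, and over-long documents all come out the same width.
batch = [[7, 3, 3, 55], [], list(range(1, 31))]
buffer = build_platform_buffer(batch)

for row in buffer:
    print(len(row), row)
```

Every row has the same width no matter how many labels the document carried, which is the property that keeps the batch shape, and therefore the compiled graph, unchanged from step to step.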

The text-prefix variant is the third call site. When the dataloader runs in metadata-prefix mode, platform_info_to_prefix(...) builds a single-line C++ comment and the tokenizer encodes it. That path is GPU-eager-friendly but on the TPU path we explicitly forbid it for input rows: strict_token_only_train raises if a row arrives without its platform prefix already encoded into token_ids. The reason is that runtime tokenisation is variable-length and would force a synchronization point per batch on TPU; the offline materialisation path puts the encoded prefix into the parquet so the TPU dataloader only ever sees fixed-shape token arrays.

What the TPU dataloader has to guarantee

| Constraint | What it forces |
| --- | --- |
| Vocab size in the HLO graph | Pad once at model build, never change at runtime |
| Embedding annotation | Pin via xs.mark_sharding; never let propagation infer |
| Per-row platform IDs | Pre-padded to MAX_PLATFORM_IDS; no synchronization point per batch |
| Token rows | Fixed (B, T) int32; no Python-side variable shapes |
| Doc IDs | Materialised at (B, T) int32; no dynamic doc boundaries |
| Packing policy | single_doc_block by default to keep rows rectangular |
| Shard list | Fixed index-to-shard mapping up to the configured max |

The TPU dataloader branch in the public dataloader sample and the public dataset sample is otherwise the same as the CUDA branch: parquet files from a fixed shard list, pyarrow row-group iteration, and a packing policy that the TPU side runs in single_doc_block mode by default. The doc_ids tensor is materialised at (B, T) int32 so the attention masking path on TPU has document boundaries without dynamic shapes.
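A rectangular packer is short to sketch. The article does not spell out what single_doc_block does, so the reading here (one document per row, truncated or padded to T) is an assumption, as are the doc_ids convention (0 for padding, a 1-based document index for real tokens) and the function name pack_single_doc_rows.

```python
def pack_single_doc_rows(docs, seq_len, pad_token=0):
    """Pack one document per row into fixed (B, T) token and doc_id grids.

    Assumptions for this sketch: doc_ids uses 0 for padding and a 1-based
    document index for real tokens; over-long documents are truncated.
    """
    tokens, doc_ids = [], []
    for index, doc in enumerate(docs, start=1):
        kept = doc[:seq_len]
        pad = seq_len - len(kept)
        tokens.append(kept + [pad_token] * pad)
        doc_ids.append([index] * len(kept) + [0] * pad)
    return tokens, doc_ids

tokens, doc_ids = pack_single_doc_rows(
    [[11, 12, 13], [21, 22, 23, 24, 25, 26]], seq_len=4
)

print(tokens)
print(doc_ids)
```

Both grids always come out as (B, T), so a mask derived from doc_ids can distinguish real tokens from padding without the batch shape ever depending on document length.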

The last shape-changing step has to finish before the live TPU batch loop. In the checked-in public-safe surfaces that means materializing token fields offline, then building fixed-width packed rows and doc_ids before handoff to the device-facing loader. Materialize tokenized enriched parquet, Packed row builder example, and Packed rows schema sample are the narrow receipts for that boundary.

Byte-fallback variability is part of why padding and materialization have to happen before Python ever sees a live TPU batch. The same rare fragment can expand to a different token count from one row to the next, so the rectangular row contract has to absorb that variability early instead of letting it mutate the runtime batch shape.
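The length variability is easy to reproduce with a toy model of byte-fallback. This sketch is an assumption-heavy simplification, not the real tokenizer: it treats every in-vocabulary character as one token (ignoring BPE merges) and every other character as one token per UTF-8 byte; the function name byte_fallback_length and the tiny known-character set are hypothetical.

```python
def byte_fallback_length(text, known_chars):
    """Token count when known characters cost one token and unknown ones cost one per UTF-8 byte."""
    return sum(1 if ch in known_chars else len(ch.encode("utf-8")) for ch in text)

known = set("abcdefghijklmnopqrstuvwxyz ")

samples = {
    "cafe": "cafe",
    "cafe_acute": "caf\u00e9",        # e with acute accent: 2 UTF-8 bytes
    "caf_euro": "caf\u20ac",          # euro sign: 3 UTF-8 bytes
    "caf_emoji": "caf\U0001F600",     # emoji: 4 UTF-8 bytes
}
lengths = {name: byte_fallback_length(s, known) for name, s in samples.items()}

for name, n in lengths.items():
    print(name, n)
```

Four strings of four characters each produce four different token counts, which is why the row width has to be fixed by offline padding rather than inferred from live text.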

What we kept and threw away

We kept the 65536 vocab, the pad_vocab_size_to=64 shim, the _vocab_parallel_enabled marker, the asymmetric "row-shard lm_head on CUDA, replicated on TPU" decision, the EmbeddingBag-style bounded metadata semantics for the 113-ID platform lane, the MAX_PLATFORM_IDS=20 cap, and the rule that runtime tokenisation is forbidden in the TPU critical path.

We threw away \p{N}{1,3} in the pre-tokenizer regex (wasted vocab), the temptation to let XLA propagate the embedding sharding (precision bug surface), runtime metadata-prefix tokenisation on TPU (synchronization point per batch), variable-length doc_ids (dynamic shapes break the cache), and any vocab size not divisible by every TP degree we use.

The throughline is small. On TPU you decide tokenizer plumbing once, statically, and never let a Python-side variable shape into the critical path. Everything else, including the model, mostly takes care of itself.

Frequently asked questions

Why does TPU keep the embedding replicated while CUDA row-shards lm_head?
Because the ownership axes differ. On this lane the token batch already arrives batch-sharded, while a row-sharded vocab table is partitioned over vocab rows. Keeping the TPU embedding replicated avoids forcing that mismatch into the first lookup and leaves the model-axis collective at the loss boundary, where the compile contract is more stable.

Why keep platform labels as fixed side IDs on TPU instead of routing them through live token prefixes?
Because the two paths carry different ownership. The platform lane is document metadata: a fixed (B, MAX_PLATFORM_IDS) side buffer that can be materialized before the live TPU loop and then broadcast as one bounded signal, which is the contract shown by Platform embedding sample, Materialize tokenized enriched parquet, and Packed rows schema sample. A live text prefix rewrites the token row itself, so byte-fallback and prefix-length drift would push shape mutation back into the runtime path we are explicitly keeping out of the TPU batch step.

Does "EmbeddingBag-style" mean the TPU path uses live EmbeddingBag?
No. In this article it is only a semantic label for "sum the document-level side IDs into one broadcast-ready metadata vector." The public-safe implementation surface is the fixed-width platform-ID buffer plus bounded sum/broadcast contract in Platform embedding sample, with token fields materialized before the live TPU loop by Materialize tokenized enriched parquet. We do not rely on a live ragged bag operator to repair shape variance after the batch reaches XLA.

Can MpDeviceLoader fix runtime tokenization or ragged packed rows on TPU?
No. PyTorch/XLA documents MpDeviceLoader(..., input_sharding=...) as the sharding-aware host-to-device loader, but that sits after the tokenizer and packer boundary. In this stack the public-safe contract is to materialize token fields and build fixed-width packed rows before the live TPU loop; the loader is allowed to shard and prefetch those rows, not to make runtime tokenization shape-stable after the fact. That is the same boundary carried by Tokenized enriched pipeline on TPU and Packed row builder example.

Glossary

Terms used in this article

Start here for quick definitions, then follow the linked posts for deeper context.

XLA SPMD

The explicit TPU sharding mode where one compiled program carries placement rules instead of rank-local imperative code.

Platform vocab

The compact per-document platform-ID vocabulary that travels beside token IDs and is embedded separately from the BPE rows.

doc_ids

The fixed-width per-token document identifiers that keep packed rows auditable and let TPU masking respect document boundaries.

segment_ids

The fixed-width segment labeling used to preserve document boundaries without changing the TPU kernel shape.

XLA

The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.

TPU

Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.
