Tokenized enriched packed rows on TPU: feeding structure to XLA without recompiles
How the v6_enriched packed-rows pipeline feeds per-token structure IDs, chunk boundaries, and call edges into the XLA dataloader on TPU v6e without triggering compile cache misses, and how that contract lifts into the main path.

The point of an enriched dataset is to teach a transformer that code has structure that is not in the bytes: function boundaries, dependency levels, call edges, type references. The point of a packed-row layout is to keep the long-context step at maximum tokens-per-second by packing many short documents and a few large ones into a fixed-shape window. Doing both on TPU means letting an MXU-bound model see structure-aware data without the XLA compile cache exploding every batch. This post covers how we used that pipeline on the v6_enriched parquet format, what the dataloader has to canonicalise to keep the compiled graph stable, and how the contract lifts into the deployment builder.
Why MegaCpp cares about this
Pretraining on raw C++ tokens leaves a lot on the table. The byte-level model has to relearn that class, };, and indented function bodies mean something; that #include hints at which symbols to expect; that the call graph beats lexical proximity. Our enriched parquet turns that into per-token data: token_structure_ids (9 categories), token_dep_levels (BFS depth in the dependency graph), token_ast_depth, token_sibling_index, token_ast_node_type, plus chunk-level token_chunk_starts/ends/kinds/dep_levels, plus token_call_edges and token_type_edges between chunks. On TPU v6e you cannot pay for this with host-side recompilation: every Python branch on column presence traces a different graph and therefore costs a compile cache miss.
The win, when the pipeline holds, is twofold. First, switching from char-level enrichment to pretokenized columns removes avoidable runtime alignment work from the dataloader hot path. Second, the structure embedding and TreeFFN paths become materially easier to run on TPU once the chunk mapping is precomputed at packing time and threaded through as static-shape tensors.
Public pipeline contract
Four public codepaths carry the contract: the tokenized-enriched schema, the offline materializer, the packed-rows schema, and the structure-embedding consumer. The schema modules hold column names and coercion helpers; the materializer holds the offline char-to-token conversion; the structure module is the model-side consumer.
The public-safe example lane mirrors those seams: the token materialization example shows the offline conversion into token-aligned columns, the packed-row schema excerpt freezes required columns and dense fallback fills, the token chunk layout example remaps chunk and edge metadata into token offsets, and the structure graph consumer example shows the fixed-shape chunk graph interface. Treat those as interface receipts, not as permission to infer missing boundaries at runtime; the compiled path still depends on a fixed (B, T) recipe and explicit metadata defaults.
The tokenized-enriched schema is the source of truth for column names: TOKEN_IDS_COLUMN, TOKEN_STRUCTURE_IDS_COLUMN, TOKEN_DEP_LEVELS_COLUMN, TOKEN_AST_DEPTH_COLUMN, TOKEN_SIBLING_INDEX_COLUMN, TOKEN_AST_NODE_TYPE_COLUMN, TOKEN_SYMBOL_IDS_COLUMN, TOKEN_CALL_TARGETS_COLUMN, TOKEN_TYPE_REFS_COLUMN, TOKEN_DEF_USE_COLUMN, plus chunk-level TOKEN_CHUNK_STARTS/ENDS/KINDS/DEP_LEVELS_COLUMN and TOKEN_CALL_EDGES_COLUMN / TOKEN_TYPE_EDGES_COLUMN. TypedDicts and predicates (_is_token_value_sequence, _extract_span_bounds) keep producer and consumer agreeing on shapes.
The tokenized-enriched pipeline is the offline materializer. materialize_tokenized_enriched_batch(docs, tokenizer) encodes texts with encode_batch to recover per-token character spans, then walks each char-level metadata array assigning each token the value of its first character (_chars_to_tokens_structure_ids). Chunk boundaries become token offsets (_chunk_boundaries_to_token_offsets); per-token dep levels come from chunk-level dep levels and the token-to-chunk mapping (_compute_token_dep_levels); edges are remapped from char-level to token-level chunk indices (_remap_token_edges). The materializer refuses to proceed if the tokenizer does not produce per-token character spans; silently emitting unaligned metadata would be a correctness bug invisible until the model misbehaved.
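A minimal pure-Python sketch of the first-character rule and the chunk remapping, assuming each token's character span is a half-open (start, end) pair and chunks are half-open character ranges. The function names are hypothetical stand-ins for _chars_to_tokens_structure_ids, _chunk_boundaries_to_token_offsets and _compute_token_dep_levels; their real signatures are not shown in this post.

```python
from bisect import bisect_left
from typing import List, Optional, Sequence, Tuple

Span = Tuple[int, int]  # half-open (char_start, char_end) of one token or chunk


def chars_to_tokens(char_meta: Sequence[int], token_spans: Optional[Sequence[Span]]) -> List[int]:
    """Assign each token the metadata value of its first character."""
    if token_spans is None:
        # Fail closed: without spans the metadata cannot be aligned.
        raise ValueError("tokenizer did not return per-token character spans")
    return [char_meta[start] for start, _ in token_spans]


def chunk_chars_to_token_offsets(chunks: Sequence[Span], token_spans: Sequence[Span]) -> List[Span]:
    """Map char-level chunks to half-open token offsets.

    A token belongs to the chunk containing its first character, which keeps
    this remap consistent with chars_to_tokens.
    """
    token_starts = [start for start, _ in token_spans]
    return [(bisect_left(token_starts, cs), bisect_left(token_starts, ce)) for cs, ce in chunks]


def token_dep_levels(chunk_offsets: Sequence[Span], chunk_dep_levels: Sequence[int],
                     num_tokens: int, fill: int = 0) -> List[int]:
    """Broadcast each chunk's dep level to the tokens it covers."""
    levels = [fill] * num_tokens
    for (ts, te), level in zip(chunk_offsets, chunk_dep_levels):
        for i in range(ts, te):
            levels[i] = level
    return levels
```

With token spans [(0, 3), (3, 4), (4, 9)] and chunks [(0, 4), (4, 9)], the chunks land on token offsets [(0, 2), (2, 3)]; a token that straddles a chunk boundary goes to the chunk holding its first character.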
The packed-rows schema defines the runtime contract. Packed rows have a fixed layout: pack_id, input_ids, target_ids, loss_mask, doc_ids, valid_token_count, num_docs, plus optional provenance. PACKED_ROWS_TOKEN_ALIGNED_COLUMNS and PACKED_ROWS_CHUNK_METADATA_COLUMNS are explicit tuples so consumers never guess. PACKED_ROWS_DENSE_FALLBACK_FILL_VALUES splits fills deliberately: zero for category-style columns (structure, dep level, change masks), -1 for true sentinels (ast depth, sibling index, ast node type, hunk id). The dataloader uses these to keep batch shapes stable.
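A sketch of how a loader might apply the split fills per row. The dictionary is a hypothetical subset of PACKED_ROWS_DENSE_FALLBACK_FILL_VALUES covering only the five token-aligned structure columns named in this post (the change-mask and hunk-id column names are not given, so they are omitted). Absent columns are filled; a present column at the wrong length is rejected rather than padded.

```python
from typing import Dict, List, Mapping, Optional, Sequence

# Hypothetical subset of the fill table: 0 for category-style columns,
# -1 for true sentinels.
DENSE_FALLBACK_FILLS: Dict[str, int] = {
    "token_structure_ids": 0,
    "token_dep_levels": 0,
    "token_ast_depth": -1,
    "token_sibling_index": -1,
    "token_ast_node_type": -1,
}


def densify_row(row: Mapping[str, Optional[Sequence[int]]], seq_len: int) -> Dict[str, List[int]]:
    """Return every fallback-filled column at exactly seq_len entries."""
    out: Dict[str, List[int]] = {}
    for column, fill in DENSE_FALLBACK_FILLS.items():
        values = row.get(column)
        if values is None:
            out[column] = [fill] * seq_len  # column absent in this shard
        elif len(values) != seq_len:
            raise ValueError(f"{column}: expected {seq_len} entries, got {len(values)}")
        else:
            out[column] = list(values)
    return out
```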
The structure-embeddings path is the model-side consumer. The TPU fast path in StructureGraphEnricher.forward exists because TPU is allergic to scatter, searchsorted, topk, and nonzero: each lowers to host-syncing or shape-fragile ops. The TPU path replaces them with bmm-based pooling and cumsum-based top-K neighbour selection.
Feeding structure_ids, chunk_boundaries, call_edges through XLA without recompiles
Three sources of recompilation bite enriched data: variable presence of optional metadata, variable shapes per batch, and Python branches inside the model that read tensor values to choose a path. We addressed each.
The first is solved at the boundary, in _canonicalize_structure_meta_for_xla(x, structure_meta) in the TPU training launcher. On the XLA branch this helper materialises every optional enriched tensor at a stable shape when the corresponding feature is enabled. With structure embeddings on, token_structure_ids, token_dep_levels, token_ast_depth, token_sibling_index, token_ast_node_type are present as (B, T) every batch; missing columns get the fallback fills. With TreeFFN on, token_chunk_ids, token_chunk_valid, semantic_block_starts/ends/valid, and the semantic_block_keep_mask / semantic_block_edge_weights matrices materialise at fixed max_semantic_blocks = max(128, T // 32). Optional tensors the model does not need are popped explicitly so the dict has the same key set every batch.
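The key-set half of that contract can be sketched with NumPy: the recipe flags decide one fixed key-to-shape table, absent keys are filled at that shape, and anything outside the table is dropped. The flag names, the (B, S, S) shape for the two semantic-block matrices, and the zero default fill are assumptions for illustration, not the launcher's code.

```python
from typing import Dict, Mapping, Optional, Tuple

import numpy as np

STRUCTURE_KEYS = ("token_structure_ids", "token_dep_levels", "token_ast_depth",
                  "token_sibling_index", "token_ast_node_type")


def max_semantic_blocks(seq_len: int) -> int:
    return max(128, seq_len // 32)


def expected_meta_shapes(batch: int, seq_len: int, *, structure_on: bool,
                         tree_ffn_on: bool) -> Dict[str, Tuple[int, ...]]:
    """The full key set and shapes for one recipe; identical for every batch."""
    shapes: Dict[str, Tuple[int, ...]] = {}
    if structure_on:
        shapes.update({key: (batch, seq_len) for key in STRUCTURE_KEYS})
    if tree_ffn_on:
        s = max_semantic_blocks(seq_len)
        shapes.update({
            "token_chunk_ids": (batch, seq_len),
            "token_chunk_valid": (batch, seq_len),
            "semantic_block_starts": (batch, s),
            "semantic_block_ends": (batch, s),
            "semantic_block_valid": (batch, s),
            "semantic_block_keep_mask": (batch, s, s),      # assumed (B, S, S)
            "semantic_block_edge_weights": (batch, s, s),   # assumed (B, S, S)
        })
    return shapes


def canonicalize(meta: Mapping[str, np.ndarray], shapes: Mapping[str, Tuple[int, ...]],
                 fills: Optional[Mapping[str, float]] = None) -> Dict[str, np.ndarray]:
    fills = fills or {}
    out: Dict[str, np.ndarray] = {}
    for key, shape in shapes.items():
        value = meta.get(key)
        if value is None:
            value = np.full(shape, fills.get(key, 0))  # absent -> dense fill
        elif tuple(value.shape) != shape:
            raise ValueError(f"{key}: shape {value.shape} != {shape}")
        out[key] = value
    return out  # keys outside `shapes` are dropped, so the key set never drifts
```

At T = 4096 the semantic-block cap is max(128, 128) = 128; at T = 8192 it becomes 256, which is why the cap has to be recomputed in lockstep everywhere it appears.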
The second is solved at packing time. Chunk boundaries vary per row, so the dataloader pads them to FIXED_MAX_CHUNKS = max(128, T // 32) (matching the canonicaliser) with zero starts/ends. Valid-chunk counts are recoverable from ends > starts. chunk_relation_mask (call/type edges as a per-chunk relation matrix) is padded to (B, R, FIXED_MAX_CHUNKS, FIXED_MAX_CHUNKS). Padding is the cheapest fix; the alternative is per-batch dynamic shapes the partitioner cannot cache.
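A NumPy sketch of the padding step, assuming chunks arrive per row as half-open (start, end) token offsets with end > start and edges as (relation, src_chunk, dst_chunk) triples; those input layouts and the helper names are hypothetical. Because every real chunk is non-empty, the validity mask falls out of ends > starts with no extra column.

```python
from typing import Sequence, Tuple

import numpy as np


def fixed_max_chunks(seq_len: int) -> int:
    # Must match the canonicaliser and the model; bump all three together.
    return max(128, seq_len // 32)


def pad_chunk_boundaries(rows: Sequence[Sequence[Tuple[int, int]]], seq_len: int):
    """Pad per-row (start, end) token offsets to (B, FIXED_MAX_CHUNKS) with zeros."""
    cap = fixed_max_chunks(seq_len)
    starts = np.zeros((len(rows), cap), dtype=np.int32)
    ends = np.zeros((len(rows), cap), dtype=np.int32)
    for b, chunks in enumerate(rows):
        if len(chunks) > cap:
            raise ValueError(f"row {b} has {len(chunks)} chunks, cap is {cap}")
        for i, (start, end) in enumerate(chunks):
            starts[b, i], ends[b, i] = start, end
    return starts, ends


def relation_mask(edge_rows: Sequence[Sequence[Tuple[int, int, int]]],
                  num_relations: int, seq_len: int) -> np.ndarray:
    """(B, R, C, C) bool matrix from per-row (relation, src_chunk, dst_chunk) edges."""
    cap = fixed_max_chunks(seq_len)
    mask = np.zeros((len(edge_rows), num_relations, cap, cap), dtype=bool)
    for b, edges in enumerate(edge_rows):
        for rel, src, dst in edges:
            mask[b, rel, src, dst] = True
    return mask
```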
A better packer does not need a looser compile contract. Better document placement can raise token density inside the same (B, T) window and the same FIXED_MAX_CHUNKS cap, but letting chunk caps or enabled columns drift by batch just trades padding savings for extra HLO variants, cache pressure, and warmup work. On TPU, a little masked zero work is usually cheaper than more graph keys.
That is where best-fit packing in data prep belongs: it improves row utilization before the loader sees the batch. The emitted packed row still carries fixed input_ids, doc_ids, valid_token_count, and chunk tensors, while seed discipline keeps the changed row composition reproducible.
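Best-fit decreasing is one standard way to do that placement. The sketch below is an illustration of the technique under the assumption that oversized documents were already split upstream; it is not the data-prep packer itself. It is fully deterministic (stable sort, lowest-index tie-break), so row composition is reproducible without relying on a seed, and every row it emits still has capacity exactly T.

```python
from typing import List, Sequence


def best_fit_pack(doc_lengths: Sequence[int], capacity: int) -> List[List[int]]:
    """Best-fit decreasing: per row, the indices of the documents placed in it."""
    # Stable sort: equal-length documents keep their input order.
    order = sorted(range(len(doc_lengths)), key=lambda i: -doc_lengths[i])
    rows: List[List[int]] = []
    remaining: List[int] = []
    for i in order:
        n = doc_lengths[i]
        if n > capacity:
            raise ValueError("document longer than the row; split it before packing")
        best = None
        for r, free in enumerate(remaining):
            # Tightest row that still fits; ties go to the lowest row index.
            if n <= free and (best is None or free < remaining[best]):
                best = r
        if best is None:
            rows.append([i])
            remaining.append(capacity - n)
        else:
            rows[best].append(i)
            remaining[best] -= n
    return rows
```

For document lengths [5, 3, 4, 2, 2] and T = 8 this yields rows [[0, 1], [2, 3, 4]], both with valid_token_count 8; a naive first-come packer would open a third row.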
The third is solved inside the model. StructureGraphEnricher.forward keys on _use_bmm = token_chunk_ids is not None or x.device.type == "xla", so the XLA path always uses bmm pooling. When the dataloader supplies token_chunk_ids / token_chunk_valid (the common case after wiring) we skip the searchsorted reconstruction; otherwise we fall back to a sort-then-searchsorted path that is still XLA-clean (sort over a sentinel-padded array, gather to invert the permutation) but slower. The TPU bmm pooling reads F.one_hot(chunk_ids, C).permute(0, 2, 1) masked by valid, then bmm(membership, x) and normalises; it lands on the MXU.
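The same pooling can be sketched in NumPy, with a broadcast comparison standing in for F.one_hot and batched `@` standing in for bmm. Treating "normalises" as a mean over valid member tokens is an assumption; the point is that chunk pooling becomes a dense (B, C, T) x (B, T, D) matmul with no scatter.

```python
import numpy as np


def one_hot(ids: np.ndarray, num_classes: int) -> np.ndarray:
    return (ids[..., None] == np.arange(num_classes)).astype(np.float32)


def chunk_mean_pool(x: np.ndarray, chunk_ids: np.ndarray, valid: np.ndarray,
                    num_chunks: int) -> np.ndarray:
    """x: (B, T, D); chunk_ids: (B, T) int; valid: (B, T) bool -> (B, C, D)."""
    membership = one_hot(chunk_ids, num_chunks).transpose(0, 2, 1)  # (B, C, T)
    membership = membership * valid[:, None, :].astype(np.float32)  # drop invalid tokens
    summed = membership @ x                                         # batched matmul, (B, C, D)
    counts = membership.sum(axis=-1, keepdims=True)                 # (B, C, 1)
    return summed / np.maximum(counts, 1.0)                         # empty chunks stay zero
```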
Neighbour selection is similarly XLA-shaped. CUDA uses topk(adj, K, dim=-1); on TPU, topk lowers to a full sort plus slice on the VPU, which is expensive and shape-fragile. The XLA path uses cumsum over the boolean adjacency to compute a 1-based rank, masks cumsum <= K, then takes argsort(descending=True, stable=True)[..., :K]. On a boolean adjacency this matches topk up to the order of tied entries, which topk leaves unspecified; unlike topk, the result is deterministic and static-shaped. The neighbour scatter for incoming aggregation becomes bmm(neighbor_membership_t, msg_flat) instead of scatter_add, which keeps that step on the MXU.
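A NumPy sketch of the cumsum rank trick, assuming a boolean (..., C) adjacency. A stable ascending argsort of the negated keep mask plays the role of argsort(descending=True, stable=True): kept neighbours come first in index order. Rows with fewer than K neighbours are padded with non-neighbour indices, so the sketch also returns a validity mask for downstream masking.

```python
import numpy as np


def select_first_k_neighbours(adj: np.ndarray, k: int):
    """adj: (..., C) bool adjacency rows -> (indices (..., k), valid (..., k))."""
    adj = adj.astype(bool)
    rank = np.cumsum(adj, axis=-1)      # 1-based rank among True entries
    keep = adj & (rank <= k)            # first k neighbours by index
    # Stable sort of 0/-1 keys: kept neighbours first, each group in index order.
    order = np.argsort(-keep.astype(np.int8), axis=-1, kind="stable")[..., :k]
    valid = np.take_along_axis(keep, order, axis=-1)
    return order, valid
```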
Varlen handling
Packed rows mix many small documents with a few large ones. The model sees a single (B, T) tensor with doc_ids and valid_token_count; the varlen contract lives at three layers. At the parquet layer, the packer guarantees that doc_ids is a non-decreasing per-token int array within a row (constant inside a document, stepping up at each boundary) and that valid_token_count is the prefix length of non-pad tokens. PACKED_ROWS_PACKER_REQUIRED_COLUMNS enforces the presence of these columns; a row missing any required column is rejected at load time, not zero-filled.
The same rule is why doc_ids and valid_token_count stay non-negotiable even when older shards degrade optional enriched payloads to deterministic empty defaults. Backward-readable rows are fine; row-level ambiguity about segment boundaries is not, because recovering those boundaries inside the compiled region is exactly the kind of batch drift that turns one recipe into many cached graphs.
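A fail-closed row check can be sketched as below. REQUIRED_COLUMNS is a hypothetical stand-in that mirrors the fixed packed-row layout listed earlier; the real PACKED_ROWS_PACKER_REQUIRED_COLUMNS tuple may differ. Only the valid prefix of doc_ids is checked, so whatever the packer writes into pad positions does not matter.

```python
from typing import Any, Mapping

# Hypothetical required set mirroring the fixed packed-row layout.
REQUIRED_COLUMNS = ("pack_id", "input_ids", "target_ids", "loss_mask",
                    "doc_ids", "valid_token_count", "num_docs")


def validate_packed_row(row: Mapping[str, Any], seq_len: int) -> None:
    """Reject, never repair, a row that breaks the varlen contract."""
    missing = [c for c in REQUIRED_COLUMNS if c not in row]
    if missing:
        raise ValueError(f"missing required columns: {missing}")
    for column in ("input_ids", "target_ids", "loss_mask", "doc_ids"):
        if len(row[column]) != seq_len:
            raise ValueError(f"{column} is not at the fixed row length {seq_len}")
    valid = int(row["valid_token_count"])
    if not 0 <= valid <= seq_len:
        raise ValueError(f"valid_token_count {valid} outside [0, {seq_len}]")
    prefix = list(row["doc_ids"])[:valid]
    if any(nxt < cur for cur, nxt in zip(prefix, prefix[1:])):
        raise ValueError("doc_ids must be non-decreasing within the valid prefix")
```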
At the attention layer, dense paths see doc_ids and build a per-token same-doc causal mask. The CUDA varlen FA path derives cu_seqlens from doc_ids (flash_attention.get_cu_seqlens_from_doc_ids) outside the compiled region; the cu_seqlens tensor itself is small.
On TPU the Pallas FA kernel takes segment_ids derived from doc_ids once at the boundary.
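A pure-Python sketch of the three derivations from one row's doc_ids and valid_token_count. These helpers are hypothetical and do not reproduce the signature of flash_attention.get_cu_seqlens_from_doc_ids. Two assumptions are labelled in the comments: the pad tail is treated as its own segment so that cu_seqlens offsets index the flattened fixed-shape (B * T) tensor, and pad positions get segment id 0 while documents are numbered from 1.

```python
from typing import List, Sequence


def row_segment_lengths(doc_ids: Sequence[int], valid: int) -> List[int]:
    """Run lengths of equal doc_ids in the valid prefix, plus the pad tail."""
    lengths: List[int] = []
    i = 0
    while i < valid:
        j = i
        while j < valid and doc_ids[j] == doc_ids[i]:
            j += 1
        lengths.append(j - i)
        i = j
    if valid < len(doc_ids):
        lengths.append(len(doc_ids) - valid)  # assumption: pad tail is one segment
    return lengths


def cu_seqlens(doc_id_rows: Sequence[Sequence[int]], valid_counts: Sequence[int]) -> List[int]:
    """Cumulative offsets over the flattened (B * T) batch."""
    cu = [0]
    for ids, valid in zip(doc_id_rows, valid_counts):
        for n in row_segment_lengths(ids, valid):
            cu.append(cu[-1] + n)
    return cu


def segment_ids(doc_ids: Sequence[int], valid: int) -> List[int]:
    """Documents numbered from 1 within the row; pad positions get 0 (assumption)."""
    out: List[int] = []
    seg = 0
    for t in range(len(doc_ids)):
        if t >= valid:
            out.append(0)
            continue
        if t == 0 or doc_ids[t] != doc_ids[t - 1]:
            seg += 1
        out.append(seg)
    return out


def same_doc_causal_mask(seg: Sequence[int]) -> List[List[bool]]:
    """Dense-path mask: query q may attend key k iff k <= q and both share a segment."""
    n = len(seg)
    return [[k <= q and seg[k] == seg[q] for k in range(n)] for q in range(n)]
```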
At the structure layer, valid_token_count drives normalize_attention_validity to produce an AttentionValidity carrier. The blockized sparse path consumes this via classify_selected_block_masks; the structure embedding consumes it implicitly via token_chunk_valid. Neither infers validity inside the compiled region. The kernels run on dense (B, T) tensors with masks from precomputed metadata; the partitioner caches one graph per recipe.
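A NumPy sketch of the idea: turn per-row counts into a (B, T) token mask once, then classify fixed-size key blocks as full, partial or empty so a blockized kernel can skip or mask them. TokenValidity, classify_blocks and the 2/1/0 codes are hypothetical stand-ins; the real AttentionValidity and classify_selected_block_masks interfaces are not shown in this post.

```python
from dataclasses import dataclass
from typing import Sequence

import numpy as np


@dataclass(frozen=True)
class TokenValidity:  # hypothetical stand-in for AttentionValidity
    token_mask: np.ndarray  # (B, T) bool, True for real (non-pad) tokens


def validity_from_counts(valid_token_count: Sequence[int], seq_len: int) -> TokenValidity:
    counts = np.asarray(valid_token_count)
    return TokenValidity(np.arange(seq_len)[None, :] < counts[:, None])


FULL, PARTIAL, EMPTY = 2, 1, 0


def classify_blocks(validity: TokenValidity, block_size: int) -> np.ndarray:
    """(B, T // block_size) codes: FULL, PARTIAL or EMPTY."""
    b, t = validity.token_mask.shape
    if t % block_size:
        raise ValueError("T must be a multiple of block_size")
    per_block = validity.token_mask.reshape(b, t // block_size, block_size).sum(axis=-1)
    return np.where(per_block == block_size, FULL, np.where(per_block > 0, PARTIAL, EMPTY))
```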
How it lands in deployment
The deployment builder ships the same contract with a Megatron-shaped consumer. The embedding layer ports the additive structure embedding as a dedicated deployment module; the config layer carries the fail-closed StructureConfig that translates training args into a frozen dataclass. The component allowlist is ("structure", "dep_level", "ast_depth", "sibling_index", "ast_node_type"); the deployment default is "core" (structure + dep_level).
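A fail-closed config of that shape could look like the sketch below. StructureConfigSketch, from_arg and the comma-separated spec format are hypothetical; only the component allowlist and the "core" default come from this post. Unknown names raise instead of being ignored, and components are stored in canonical allowlist order so two equivalent recipes compare equal.

```python
from dataclasses import dataclass
from typing import Dict, Tuple

ALLOWED_COMPONENTS = ("structure", "dep_level", "ast_depth", "sibling_index", "ast_node_type")
PRESETS: Dict[str, Tuple[str, ...]] = {"core": ("structure", "dep_level")}


@dataclass(frozen=True)
class StructureConfigSketch:
    components: Tuple[str, ...]

    @classmethod
    def from_arg(cls, spec: str = "core") -> "StructureConfigSketch":
        names = PRESETS.get(spec) or tuple(s.strip() for s in spec.split(",") if s.strip())
        unknown = [n for n in names if n not in ALLOWED_COMPONENTS]
        if unknown or not names:
            raise ValueError(f"invalid structure components spec: {spec!r}")
        # Canonical order, no duplicates.
        return cls(tuple(c for c in ALLOWED_COMPONENTS if c in names))
```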
The custom-embedding path subclasses Megatron's LanguageModelEmbedding and adds the structure embedding (and the n-gram hash embedding) as additive contributions before the Megatron forward. The sharded-state-dict walker is patched so custom submodules get distinct replica_id stamps when MTP replicates the embedding on a non-first pipeline stage; without that fix, the default walker stamps a duplicate "main replica" and the checkpoint becomes ambiguous.
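Stripped of Megatron, the additive contribution can be sketched as a forward-only NumPy module: one small table per enabled component at the bottleneck width (64 in production, per the ablations below), summed and projected to the model width, then added to the token embeddings. Summing before the projection, the +1 index shift that gives the -1 sentinel its own row, and the random initialisation are all assumptions for illustration, not the deployment module.

```python
from typing import Dict, Mapping

import numpy as np


class AdditiveStructureEmbeddingSketch:
    def __init__(self, vocab_sizes: Mapping[str, int], d_model: int,
                 bottleneck: int = 64, seed: int = 0):
        rng = np.random.default_rng(seed)
        # One extra row per table so the -1 sentinel (shifted to index 0) has its own embedding.
        self.tables: Dict[str, np.ndarray] = {
            name: rng.normal(0.0, 0.02, (size + 1, bottleneck)).astype(np.float32)
            for name, size in vocab_sizes.items()
        }
        self.proj = rng.normal(0.0, 0.02, (bottleneck, d_model)).astype(np.float32)

    def __call__(self, token_emb: np.ndarray, ids: Mapping[str, np.ndarray]) -> np.ndarray:
        """token_emb: (B, T, D); ids: component -> (B, T) int, -1 as sentinel."""
        z = np.zeros(token_emb.shape[:2] + (self.proj.shape[0],), dtype=np.float32)
        for name, table in self.tables.items():
            z += table[np.asarray(ids[name]) + 1]  # every enabled component must be present
        return token_emb + z @ self.proj
```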
The structure-batch bridge extracts structure inputs from a batch dict (structure_ids, dep_levels, ast_depth_ids, sibling_index_ids, node_type_ids) and threads them onto the model via set_structure_inputs. Conditional setting is fine here; the equivalent of _canonicalize_structure_meta_for_xla on the TPU deployment path is a recipe-level guarantee that the loader emits the columns every batch.
- Lifted as-is: schema column names, dense fallback fills, the varlen contract on doc_ids + valid_token_count, the AttentionValidity normalisation, bmm-based pooling and cumsum-based neighbour selection.
- Rewritten: the offline materializer becomes a Megatron preprocessing job; packing moves to the deployment data pipeline; the TreeFFN loop is recipe-flagged and defaults off pending more ablation at deployment scale.
- Dropped: the legacy char-level enriched_code_v3 format does not enter the main path; we migrated to pretokenized v6_enriched_* because the char-to-token conversion was the dataloader bottleneck.
- Feature-flagged: tree_ffn_enabled, relation_bias_enabled, platform_embed_enabled stay as recipe flags; the additive structure embedding is the default-on minimum.
Ablations and what we kept
The wins that survived the migration into the deployment loader:
| Change | Where | Effect |
|---|---|---|
| Pretokenized over char-level enriched | tokenized-enriched pipeline | removes per-batch char-to-token alignment from the hot path |
| F.conv1d regardless of doc_ids | mixer | ~22% recovered on CUDA |
| Bottleneck dim 64 | structure-embedding path | narrower structure features without changing the fixed-shape contract |
| Precomputed token_chunk_ids/valid | GPT.forward | skips searchsorted in TreeFFN |
| _canonicalize_structure_meta_for_xla at boundary | the TPU training launcher | eliminates per-batch graph cache misses |
The decisions that matter here are architectural rather than release-note-shaped:
- Pretokenized over char-level. Pretokenized columns are authoritative because they move expensive alignment work into offline materialization and keep the runtime loader in fixed-shape tensor territory.
- Conv1d on CUDA regardless of doc_ids. The 4-iteration manual depthwise conv was triggering whenever doc_ids is not None (always true on enriched data). Switching back to F.conv1d recovered ~22%; cross-doc leakage of at most 3 tokens for kernel=4 is negligible.
- Bottleneck dim 64. Initial cost was around 23% throughput, but the absolute number stayed competitive once paired with the precomputed chunk mapping. Production keeps 64.
- Precomputed token_chunk_ids and token_chunk_valid. GPT.forward threads them into StructureGraphEnricher, so TreeFFN skips the searchsorted reconstruction. The fast path was wired through CUDA TP broadcast, CP sharding, and XLA/CUDA structure-meta canonicalisation in the same window.
- Canonicalisation at the XLA boundary. Per-batch presence-conditional Python branches inside the model produced cache misses; pulling canonicalisation to the train script and materialising every optional tensor at fixed shape eliminated them.
- Bench harness. The enriched-loader benchmark emits a JSON report with B, T, iters, token rate; re-run on every dataloader change.
- Validator coverage. The packed-row schema, metadata invariants, and loader contract all need dedicated tests. The requirement is not a specific internal test quota; it is that every new column and every loader assumption is guarded by schema and consumer checks.
- Schema split. Producer-local provenance fields stay in the offline packer until producer and runtime contracts unify; the runtime schema only depends on what the dataloader actually consumes.
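The leakage bound for the conv1d decision is simple arithmetic: a causal depthwise conv of width k sees the current token plus the k - 1 before it, so only the first k - 1 tokens after a document boundary can read the previous document, and at most k - 1 foreign tokens each. A small counting sketch (the helper is hypothetical):

```python
from typing import List, Sequence


def cross_doc_leakage(doc_ids: Sequence[int], kernel_size: int) -> List[int]:
    """Per position, how many other-document tokens fall inside its causal conv window."""
    leak: List[int] = []
    for t, doc in enumerate(doc_ids):
        window = doc_ids[max(0, t - kernel_size + 1):t]  # the k - 1 preceding tokens
        leak.append(sum(1 for w in window if w != doc))
    return leak
```

For doc_ids [0, 0, 0, 0, 1, 1, 1, 1, 1] and kernel 4 the counts are [0, 0, 0, 0, 3, 2, 1, 0, 0]: six leaked token pairs per boundary, never more than 3 for a single position.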
Production checklist
The boundary canonicaliser is the load-bearing line on the XLA path:
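```python
# TPU training launcher, XLA branch only.
# _ensure_token_aligned / _ensure_chunk_metadata are launcher helpers (not shown);
# tree_ffn_enabled is the static recipe flag, fixed for the whole run.
def _canonicalize_structure_meta_for_xla(x, structure_meta):
    # Materialise every optional enriched tensor at fixed shape.
    # Missing columns get PACKED_ROWS_DENSE_FALLBACK_FILL_VALUES.
    B, T = x.shape[:2]  # static (B, T) of the compiled recipe
    structure_meta = _ensure_token_aligned(structure_meta, B, T)
    if tree_ffn_enabled:  # branches on the recipe, never on per-batch column presence
        structure_meta = _ensure_chunk_metadata(
            structure_meta, max_chunks=max(128, T // 32),  # FIXED_MAX_CHUNKS
        )
    return structure_meta  # shape-stable; no Python branches downstream
```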
- Lock the parquet schema per recipe; reject loads where PACKED_ROWS_PACKER_REQUIRED_COLUMNS are missing.
- Run _canonicalize_structure_meta_for_xla at the train-script boundary on the TPU path; no presence-conditional Python branches inside the model.
- Pin FIXED_MAX_CHUNKS = max(128, T // 32) between dataloader, canonicaliser, and model; bump in lockstep.
- Default structure embedding to "core" with bottleneck dim 64; gate ast_depth, sibling_index, ast_node_type, and TreeFFN behind explicit recipe flags.
- Keep the precomputed token_chunk_ids / token_chunk_valid path on; reject data drops that lose these columns.
- Use bmm-based pooling and cumsum-based neighbour selection on TPU; no topk, scatter_add, searchsorted, or nonzero inside the compiled region.
- Derive cu_seqlens (CUDA) and segment_ids (TPU) from doc_ids outside the compiled region; cache per batch.
- Run a fixed-shape enriched-loader benchmark on every data-side change and treat regressions as loader bugs until proven otherwise.
- Refuse the legacy char-level enriched format; force the pretokenized path.
- Validate per-token character spans before the offline materializer runs; do not silently emit unaligned metadata.
Frequently asked questions
Why not drop FIXED_MAX_CHUNKS once bounded dynamic shapes or newer TPU v6e tooling exist?
Keep the fixed (B, T) contract, materialize optional structure tensors at fixed bounds, and let newer compiler or hardware features reduce overhead inside that envelope instead of replacing the envelope.
Terms used in this article
- Packed rows: Why packed rows are the real boundary between the data pipeline and the model, and why MegaCpp treats row packing as a schema contract rather than a…
- loss_mask: The per-token training mask that decides which positions contribute to loss after packing, FIM rearrangement, or documentation-aware masking.
- valid_token_count: The per-row prefix length of non-pad tokens; runtimes use it as the cheap validity receipt instead of rescanning variable-length packed payloads.
- doc_ids: The fixed-width per-token document identifiers that keep packed rows auditable and let TPU masking respect document boundaries.
- segment_ids: The fixed-width segment labeling used to preserve document boundaries without changing the TPU kernel shape.
- cu_seqlens: The cumulative sequence-length offsets passed to varlen attention kernels so packed subsequences stay isolated without computing then masking cross-document attention.
- AttentionValidity: The validity carrier built from row-level counts or masks so sparse or structured attention paths know which token prefix is real without re-inferring it inside the compiled region.
- Pallas: JAX's kernel language for writing explicit TPU kernels when stock XLA lowering is not enough for the required tile, memory-layout, or masking contract.
- XLA: The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.
- TPU: Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.