MegaCpp Engineering: Applied C++ model systems
Article
Grounded engineering note from the MegaCpp stack
Published · 4 min read · David Gornshtein
Mamba
Cross Entropy
Parity
Deep Dive

Mamba linear CE parity deep dive

Why output-layer swaps in Mamba-style stacks need explicit CE parity checks, not just shape compatibility checks.


An output path can preserve tensor shapes while still drifting on the logits-to-loss contract. That is why a CE parity reproducer is worth publishing. It narrows the question to the only thing that matters: does the alternate path preserve the same logits-to-loss behavior?

The near-copy version keeps the output-layer contract visible: one path uses a fused linear-plus-cross-entropy module and the other keeps a plain column-parallel output layer until the loss-path contract is restored. That is closer to the real failure than a generic cross_entropy(hidden @ W.T, targets) toy. This makes the post the output-side sibling of Porting to Megatron friction: the useful work is in keeping the seam visible rather than pretending shape parity proves enough.

What the checked-in examples prove

If you want the local proof surfaces before the article prose, start with MegaCpp example index, Mamba linear CE parity sample, Mamba linear CE parity near-copy, and Megatron FLCE Hopper near-copy. The compact sample keeps the logits-to-CE interface tiny enough to inspect directly. The near-copy is more useful when you care about the real integration seam because it keeps the class-level boundary visible: one path owns a fused output-and-loss surface, the other exposes a plain column-parallel layer until the patch restores the expected contract.

Why shape parity is not enough

Shape checks are necessary and still too weak. Two paths can agree on output dimensions and still drift if one accumulates its running max and denominator in a different dtype, or if one loss path averages and reduces tokens at a different point than the other. That is why CE parity here is a behavioral property, not a shape property. The real question is whether the runtime still consumes the same logits-to-loss boundary it was designed for.
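A minimal sketch of the reduction-point case, assuming made-up per-token losses and using None as a hypothetical padding marker: both functions return a scalar, so any shape check passes, yet they disagree because one averages per sequence first and the other divides once by the total valid-token count.

```python
# Hypothetical per-token CE losses for two sequences; None marks a padded position.
token_losses = [
    [2.0, 2.0, 2.0, 2.0],      # 4 valid tokens
    [6.0, None, None, None],   # 1 valid token, 3 padded
]

def mean_of_sequence_means(batch):
    """Average within each sequence first, then across sequences."""
    per_seq = []
    for seq in batch:
        valid = [x for x in seq if x is not None]
        per_seq.append(sum(valid) / len(valid))
    return sum(per_seq) / len(per_seq)

def global_token_mean(batch):
    """Sum every valid token loss, then divide once by the valid-token count."""
    valid = [x for seq in batch for x in seq if x is not None]
    return sum(valid) / len(valid)

loss_a = mean_of_sequence_means(token_losses)  # (2.0 + 6.0) / 2
loss_b = global_token_mean(token_losses)       # (8.0 + 6.0) / 5
print(loss_a, loss_b)
```

Both results are scalars computed from the same per-token values, which is why a parity check has to compare the loss itself and not only its shape.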

The distributed case is stricter still. Once the output layer is vocab-parallel or sequence-parallel, parity is no longer only a local logits question. A serious parity proof has to name where the max reduction happens, where the denominator sum reduction happens, and whether the final loss is normalized by local tokens or by the global token count. On sequence-parallel lanes it also helps to state the layout explicitly, usually [S, B, V/TP] or [S/SP, B, V/TP], because "parity passed" otherwise often means only that the local tensor shapes looked plausible.
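The two reductions named above can be sketched without any distributed runtime. This is a single-process simulation, not the Megatron or MegaCpp implementation: each inner list stands in for one TP rank's V/TP slice of a single token's logits, and the Python max and sum calls over ranks stand in for all-reduce(MAX) and all-reduce(SUM). The logit values and function names are illustrative assumptions.

```python
import math

def full_logsumexp(logits):
    """Reference: log-sum-exp over the full vocabulary row."""
    m = max(logits)
    return m + math.log(sum(math.exp(x - m) for x in logits))

def vocab_parallel_logsumexp(shards):
    """Simulated TP ranks: each shard is one rank's V/TP slice of one token's logits."""
    # Reduction 1: stand-in for all-reduce(MAX) over the per-rank local maxima.
    global_max = max(max(shard) for shard in shards)
    # Reduction 2: stand-in for all-reduce(SUM) over the per-rank exp-sums,
    # each shifted by the global max so the shards share one reference point.
    local_sums = [sum(math.exp(x - global_max) for x in shard) for shard in shards]
    global_sum = sum(local_sums)
    return global_max + math.log(global_sum)

logits = [0.5, -1.0, 3.0, 2.0, -0.5, 1.5, 0.0, 4.0]  # V = 8, hypothetical values
tp = 4
width = len(logits) // tp
shards = [logits[r * width:(r + 1) * width] for r in range(tp)]

ref = full_logsumexp(logits)
par = vocab_parallel_logsumexp(shards)
print(abs(ref - par))
```

If the max reduction were skipped and each shard used its own local max, the partial sums would no longer share a reference point, which is the kind of drift a local shape check never sees.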

That accumulator story deserves to be named plainly. In vocab-parallel or chunked CE, the running max and running denominator are part of the contract, not an invisible implementation detail. If one path upcasts those accumulators and another keeps them in the execution dtype, the loss can drift before the final token reduction ever happens. That is why a useful parity note says where accumulator ownership lives: inside a fused kernel, across TP reductions, or only after a materialized-logits fallback. Megatron FLCE on Hopper is the closest sibling when the question becomes how that boundary is implemented on the fused side.
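A deliberately extreme sketch of accumulator drift, under stated assumptions: `to_bf16` is a hypothetical helper that emulates bfloat16 rounding through float32 bits, Python floats stand in for an upcast accumulator, and the chunk width of 8 is artificially small so the effect shows up in a few lines. With all logits equal to zero, every chunk adds 8.0 to the running denominator, and the bfloat16 accumulator stops growing once the spacing between representable values exceeds what a chunk contributes.

```python
import math
import struct

def to_bf16(x):
    """Round a float to the nearest bfloat16 value (round-half-to-even) via float32 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # rounding bias for round-half-to-even
    bits &= 0xFFFF0000                    # keep only the top 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def chunked_lse(logits, chunk, round_acc):
    """Online log-sum-exp over vocab chunks; round_acc models the accumulator dtype."""
    running_max, denom = float("-inf"), 0.0
    for start in range(0, len(logits), chunk):
        block = logits[start:start + chunk]
        new_max = max(running_max, max(block))
        block_sum = sum(math.exp(x - new_max) for x in block)
        # Rescale the old denominator to the new max, add the chunk, then round.
        denom = round_acc(denom * math.exp(running_max - new_max) + block_sum)
        running_max = new_max
    return running_max + math.log(denom)

logits = [0.0] * 4096                                 # hypothetical flat logits row
lse_upcast = chunked_lse(logits, 8, lambda v: v)      # accumulator kept in high precision
lse_bf16 = chunked_lse(logits, 8, to_bf16)            # accumulator rounded to bfloat16
drift = lse_upcast - lse_bf16
print(lse_upcast, lse_bf16, drift)
```

Real kernels use much wider chunks and realistic logits, so the drift is normally far smaller, but the mechanism is the same: the rounding happens inside the accumulator, before any token-level reduction.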

Sequence-parallel lanes add one more place to be explicit. Once the tensor is [S/SP, B, V/TP], the parity claim should also say whether the final mean divides by local visible tokens or by the global non-padding token count after cross-rank reductions. Without that statement, a reproducer can pass the obvious shape checks and still change gradient scale across ranks, which is exactly the kind of seam drift that later gets misread as a generic training-health problem rather than an output-layer problem. Liger FLCE reduction none is the neighboring post when the reduction contract itself is the thing under test.

The easiest concrete example is small enough to keep in your head. If one lane averages over 1024 visible local tokens while a TP=8 fused lane divides only after cross-rank reductions over 8192 non-padding tokens, the per-token logits can still look fine while the returned gradient scale differs by 8x. That is why this post keeps reduction ownership next to Megatron FLCE on Hopper instead of treating it as an implementation footnote.
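The same arithmetic as a short sketch. The per-token loss value is hypothetical and identical on every token, and the sketch assumes nothing downstream (such as a later gradient average across ranks) compensates for the divisor. The gradient of a mean with respect to each per-token loss is simply one over the divisor.

```python
def mean_loss_and_token_grad(token_loss_sum, divisor):
    """Return the mean loss and d(mean)/d(per-token loss), which is 1/divisor."""
    return token_loss_sum / divisor, 1.0 / divisor

local_tokens = 1024
ranks = 8
global_tokens = local_tokens * ranks      # 8192 non-padding tokens
per_token_loss = 2.5                      # hypothetical value

local_sum = per_token_loss * local_tokens
global_sum = local_sum * ranks

loss_local, grad_local = mean_loss_and_token_grad(local_sum, local_tokens)
loss_global, grad_global = mean_loss_and_token_grad(global_sum, global_tokens)

print(loss_local, loss_global)            # the reported losses agree
print(grad_local / grad_global)           # the per-token gradient weight does not
```

The reported loss is identical in both lanes, so a loss-value comparison alone would pass; only the divisor, and therefore the backward scale, differs.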

That is also why the near-copy reproducer matters more than a toy script. A toy hidden @ W.T check can validate arithmetic in isolation, but it cannot show whether a path swap quietly changed the owned output-and-loss seam. The checked-in near-copy files keep that seam visible enough to debug, which is the same operational reason this article belongs next to Megatron FLCE on Hopper, Liger FLCE reduction none, Author Mamba3 spec inside Megatron, and Loss curves and divergence playbook.

FAQ

Frequently asked questions

Why is shape parity not enough?
Because logits can still drift relative to the loss path even when tensor dimensions match. The smallest checked-in proof is Mamba linear CE parity sample, which keeps the logits-to-CE interface tiny enough to inspect directly.
Why can a distributed CE parity check fail even when the single-GPU version passes?
Because the distributed path owns extra math the single-GPU check never sees: vocab-parallel max and sum reductions across ranks, sequence-parallel layout exchange, and sometimes a different global-versus-local loss scaling rule. A local parity script can validate the class seam and still miss the real multi-rank contract.
Why does vocab-partition target masking belong in the parity receipt?
Because the fused path does more than accept a smaller logits tensor. Each rank owns a [S, B, V/TP] slice, while the target ids still describe the global vocabulary. The parity receipt therefore has to say where out-of-partition targets are masked or remapped before the predicted-logit gather and backward update happen. A full-logits PyTorch baseline can skip that local-vocab step, so it may hide wrong partition ranges or ignore-token handling that only fails when the all-gather is removed. Keep that target-mask contract next to the max/sum reductions; the neighboring fused-loss checks are Megatron FLCE on Hopper and Liger FLCE reduction none.
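A minimal single-process sketch of that masking step, with hypothetical names and values: each simulated rank owns a contiguous vocab range, remaps the global target id to a local index when the target falls inside its range, and contributes zero otherwise. The final sum stands in for an all-reduce(SUM) across TP ranks.

```python
def local_target_logit(local_logits, target_id, vocab_start):
    """One rank's contribution to the predicted logit for a single token."""
    vocab_end = vocab_start + len(local_logits)
    in_partition = vocab_start <= target_id < vocab_end
    # Remap the global id to a local index; out-of-partition targets contribute 0.
    return local_logits[target_id - vocab_start] if in_partition else 0.0

logits = [0.5, -1.0, 3.0, 2.0, -0.5, 1.5, 0.0, 4.0]  # V = 8, hypothetical values
tp, width = 4, 2                                     # each rank owns V/TP = 2 entries
target = 5                                           # global vocabulary id

contribs = [
    local_target_logit(logits[r * width:(r + 1) * width], target, r * width)
    for r in range(tp)
]
predicted = sum(contribs)                            # stand-in for all-reduce(SUM)
print(contribs, predicted)
```

An off-by-one in `vocab_start` or `vocab_end` would return a neighboring logit or zero, and a full-logits baseline that indexes `logits[target]` directly would never exercise that path.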
Why does local-token versus global-token averaging deserve its own check?
Because it changes gradient scale even when the logits look fine. A plain PyTorch CrossEntropyLoss(reduction="mean") averages over the visible tokens in the local call site. A vocab-parallel fused path can instead normalize after cross-rank reductions over the global non-padding token count. If the parity receipt never states which count owns the final divide, two lanes can agree on shape and even on per-token logits while still backpropagating different update magnitudes.
Why can a parity probe OOM even when the fused training lane fits?
Because the probe can silently change the memory class. The fused distributed lane keeps logits split as [S, B, V/TP] or, under sequence parallelism, [S/SP, B, V/TP] and resolves the loss with cross-rank reductions. A plain PyTorch parity rewrite often all-gathers until it materializes [S, B, V], so the first failure becomes full-logit residency rather than CE math. That is why the checked-in near-copy lane is more useful than a full materialization baseline: it preserves the output-and-loss seam instead of replacing it with a different memory regime. The neighboring receipts are Megatron FLCE on Hopper and Porting to Megatron friction. The arithmetic is large enough to be a separate receipt check. With B=4, S=4096, V=151936, and BF16 logits, a materialized [B, S, V] tensor is about 4.98 GB before gradients and about 9.96 GB if an equally large gradient buffer is live. A chunked fused path with a 4096-entry vocab chunk needs about 134 MB for the same BF16 working logits. If a parity probe crosses from the chunked regime into the materialized one, the OOM says more about the probe than about the fused CE contract.
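The figures above can be reproduced with a few lines of arithmetic. The helper name is hypothetical, and the sketch counts only the logits buffer itself (2 bytes per BF16 element, decimal GB and MB), ignoring allocator overhead and any other live tensors.

```python
def logits_bytes(tokens, vocab_width, bytes_per_elem=2):
    """Bytes held by a [tokens, vocab_width] logits buffer; BF16 is 2 bytes per element."""
    return tokens * vocab_width * bytes_per_elem

B, S, V = 4, 4096, 151936
tokens = B * S

full = logits_bytes(tokens, V)            # materialized [B, S, V]
full_with_grad = 2 * full                 # plus an equally large gradient buffer
chunked = logits_bytes(tokens, 4096)      # one 4096-entry vocab chunk

print(f"materialized: {full / 1e9:.2f} GB")
print(f"with gradient buffer: {full_with_grad / 1e9:.2f} GB")
print(f"chunked working set: {chunked / 1e6:.0f} MB")
```

The ratio between the two regimes is V divided by the chunk width, roughly 37x here, which is why a probe that materializes full logits can fail on hardware where the fused lane runs comfortably.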
Why use a near-copy reproducer instead of a toy CE script?
Because the real bug lived at the output-and-loss boundary, not in a generic dense matrix multiply followed by cross entropy. The near-copy file Mamba linear CE parity near-copy shows the actual class-level seam and the fix that restores the expected runtime contract.
What if the compared loss path also adds label smoothing or z-loss?
Then the parity note has to name where those penalties enter the graph. A fused loss can inject them inside the normalized-probability path, while plain PyTorch often applies them as separate autograd nodes afterward. Two lanes can therefore agree on shapes and still differ in rounding, scaling, and backward ownership unless the parity check keeps that penalty boundary explicit too.
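One hypothetical way the penalty boundary can move, sketched for z-loss (a coefficient times the squared log-sum-exp of the logits). The coefficient, the logit rows, and both placements are illustrative assumptions, not a description of any particular fused kernel or PyTorch recipe: the first placement squares the log-sum-exp per token before averaging, the second squares the batch-averaged log-sum-exp afterward.

```python
import math

def logsumexp(row):
    """Numerically stable log-sum-exp of one token's logits."""
    m = max(row)
    return m + math.log(sum(math.exp(x - m) for x in row))

Z_COEFF = 1e-4                                   # hypothetical coefficient
rows = [[1.0, 2.0, 3.0], [4.0, 0.0, -1.0]]      # hypothetical logits, two tokens
lses = [logsumexp(r) for r in rows]

# Placement A: penalty computed per token inside the loss path, then averaged.
z_inside = Z_COEFF * sum(v * v for v in lses) / len(lses)

# Placement B: penalty applied afterward to the token-averaged log-sum-exp.
mean_lse = sum(lses) / len(lses)
z_after = Z_COEFF * mean_lse * mean_lse

print(z_inside, z_after)
```

The mean of squares is never smaller than the square of the mean, so the two placements agree only when every token has the same log-sum-exp. A parity note that states which placement each lane uses removes that ambiguity.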
Which neighboring posts help when the parity drift turns into training-health drift?
Start with Megatron FLCE on Hopper for the fused-loss sibling, Liger FLCE reduction none for reduction-contract drift, Author Mamba3 spec inside Megatron for the input-side seam, and Loss curves and divergence playbook for the operational failure-reading side.
Glossary

Terms used in this article

Start here for quick definitions, then follow the linked posts for deeper context.

TP

Tensor parallelism splits each linear's weights (QKV, O, MLP gate/up/down) across GPUs. On 8× H200 with TP=8 each GPU owns 1/8 of every matmul's columns or rows, so one big matmul becomes 8 smaller ones that all-reduce at the layer boundary. Cost: one all-reduce per attention and per MLP — heavy bandwidth, so TP is usually bound to a single NVLink/NVSwitch island (1 node of up to 8 GPUs). Embeddings, layernorms, and optimizer state stay replicated across the TP GPUs. Use TP when a single layer's weights don't fit on one GPU, not to scale past one node.

SP

Sequence parallelism is a TP-region activation saver — not a separate mesh. Plain TP leaves layernorm / dropout / residual activations replicated on every TP GPU; SP keeps those intermediates sharded along the sequence axis so each TP GPU holds only 1/TP of them. Cost: same bandwidth as plain TP — the single all-reduce becomes an all-gather + reduce-scatter pair. Weights identical to plain TP; only the activation tensors shrink. Turn on whenever TP is on — near-free memory savings, which is what makes long contexts fit under TP.

Mamba

A grounded look at why MegaCpp combines Mamba-style state-space blocks with a smaller number of attention blocks for long-context C++ work, and…

Mamba3

A grounded look at why MegaCpp combines Mamba-style state-space blocks with a smaller number of attention blocks for long-context C++ work, and…