Mamba linear CE parity deep dive
Why output-layer swaps in Mamba-style stacks need explicit CE parity checks, not just shape compatibility checks.

An output path can preserve tensor shapes while still drifting on the logits-to-loss contract. That is why a CE parity reproducer is worth publishing. It narrows the question to the only thing that matters: does the alternate path preserve the same logits-to-loss behavior?
The near-copy version keeps the output-layer contract visible: one path uses a fused linear-plus-cross-entropy module and the other keeps a plain column-parallel output layer until the loss-path contract is restored. That is closer to the real failure than a generic cross_entropy(hidden @ W.T, targets) toy. This makes the post the output-side sibling of Porting to Megatron friction: the useful work is in keeping the seam visible rather than pretending shape parity proves enough.
What the checked-in examples prove
If you want the local proof surfaces before the article prose, start with MegaCpp example index, Mamba linear CE parity sample, Mamba linear CE parity near-copy, and Megatron FLCE Hopper near-copy. The compact sample keeps the logits-to-CE interface tiny enough to inspect directly. The checked-in near-copy is more useful when you care about the real integration seam because it keeps the class-level boundary visible: one path owns a fused output-and-loss surface, the other exposes a plain column-parallel layer until the patch restores the expected contract.
Why shape parity is not enough
Shape checks are necessary and still too weak. Two paths can agree on output dimensions and still drift if one accumulates its running max and denominator in a different dtype, or if one loss path averages and reduces tokens at a different point than the other. That is why CE parity here is a behavioral property, not a shape property. The real question is whether the runtime still consumes the same logits-to-loss boundary it was designed for.
The distributed case is stricter still. Once the output layer is vocab-parallel or sequence-parallel, parity is no longer only a local logits question. A serious parity proof has to name where the max reduction happens, where the denominator sum reduction happens, and whether the final loss is normalized by local tokens or by the global token count. On sequence-parallel lanes it also helps to state the layout explicitly, usually [S, B, V/TP] or [S/SP, B, V/TP], because "parity passed" otherwise often means only that the local tensor shapes looked plausible.
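The reduction points named above can be sketched in a single process. This is a minimal illustration, assuming a contiguous vocab partition and one token at a time; the function names and the list-of-shards stand-in for TP ranks are hypothetical and not taken from the checked-in files. Each `max` or `sum` over shards marks where a real lane would issue a cross-rank reduction.

```python
import math

def full_ce(logits, target):
    """Reference: cross-entropy over the full, materialized vocab row."""
    m = max(logits)
    return math.log(sum(math.exp(x - m) for x in logits)) + m - logits[target]

def vocab_parallel_ce(shards, target):
    """Same loss computed from per-rank vocab slices (simulated ranks)."""
    # 1. Local max per rank, then a MAX reduction across ranks.
    global_max = max(max(s) for s in shards)
    # 2. Local sum of exponentials, then a SUM reduction across ranks.
    denom = sum(sum(math.exp(x - global_max) for x in s) for s in shards)
    # 3. Target logit: ranks that do not own the target id contribute 0,
    #    then a SUM reduction recovers the single owned value.
    target_logit = 0.0
    start = 0
    for s in shards:
        end = start + len(s)
        if start <= target < end:          # target falls in this partition
            target_logit += s[target - start]
        start = end
    return math.log(denom) + global_max - target_logit

logits = [0.5, -1.0, 2.0, 0.0, 1.5, -0.5, 3.0, 1.0]
shards = [logits[i:i + 2] for i in range(0, len(logits), 2)]  # 4 simulated ranks
reference = full_ce(logits, target=6)
sharded = vocab_parallel_ce(shards, target=6)
```

If the partition ranges in step 3 are wrong, the shapes stay identical and only the loss value changes, which is why the comparison has to be on the returned loss rather than on tensor dimensions.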
That accumulator story deserves to be named plainly. In vocab-parallel or chunked CE, the running max and running denominator are part of the contract, not an invisible implementation detail. If one path upcasts those accumulators and another keeps them in the execution dtype, the loss can drift before the final token reduction ever happens. That is why a useful parity note says where accumulator ownership lives: inside a fused kernel, across TP reductions, or only after a materialized-logits fallback. Megatron FLCE on Hopper is the closest sibling when the question becomes how that boundary is implemented on the fused side.
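A deliberately worst-case sketch of that accumulator drift, assuming a chunked CE with a running max and running denominator. The bfloat16 rounding here is emulated with the standard library (round to nearest, ties to even) and `chunked_ce` is a hypothetical helper, not the kernel from any of the referenced posts; real fused kernels may accumulate differently.

```python
import math
import struct

def round_bf16(x):
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]

def chunked_ce(logits, target, chunk, acc=lambda v: v):
    """Streaming CE for one token; `acc` rounds the running accumulators."""
    running_max = float("-inf")
    denom = 0.0
    for start in range(0, len(logits), chunk):
        block = logits[start:start + chunk]
        new_max = acc(max(running_max, max(block)))
        # Rescale the old denominator to the new max, then add this chunk.
        rescale = math.exp(running_max - new_max)
        denom = acc(denom * rescale + sum(math.exp(x - new_max) for x in block))
        running_max = new_max
    return math.log(denom) + running_max - logits[target]

# Uniform logits are the adversarial case: every chunk adds the same small
# amount, and the low-precision denominator stops growing once the increment
# falls below half its spacing.
logits = [0.0] * 4096
loss_upcast = chunked_ce(logits, target=0, chunk=4)                  # log(4096)
loss_bf16 = chunked_ce(logits, target=0, chunk=4, acc=round_bf16)    # log(1024)
drift = loss_upcast - loss_bf16                                      # log(4)
```

Both calls consume the same logits and return a scalar of the same shape; only the accumulator dtype differs. In this constructed case the bfloat16 denominator stalls at 1024 because adding 4 to it is a rounding tie that resolves back to 1024.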
Sequence-parallel lanes add one more place to be explicit. Once the tensor is [S/SP, B, V/TP], the parity claim should also say whether the final mean divides by local visible tokens or by the global non-padding token count after cross-rank reductions. Without that statement, a reproducer can pass the obvious shape checks and still change gradient scale across ranks, which is exactly the kind of seam drift that later gets misread as a generic training-health problem rather than an output-layer problem. Liger FLCE reduction none is the neighboring post when the reduction contract itself is the thing under test.
The easiest concrete example is small enough to keep in your head. If one lane averages over 1024 visible local tokens while a TP=8 fused lane divides only after cross-rank reductions over 8192 non-padding tokens, the per-token logits can still look fine while the returned gradient scale differs by 8x. That is why this post keeps reduction ownership next to Megatron FLCE on Hopper instead of treating it as an implementation footnote.
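The 8x figure can be reproduced with plain arithmetic. The sketch below assumes 8 ranks with 1024 tokens each and assumes the per-rank results are summed across ranks (as reduce-summed gradient contributions would be); the helper names are illustrative only.

```python
def sum_of_local_means(loss_sums, token_counts):
    """Each rank divides by its own visible-token count; results are summed."""
    return sum(s / n for s, n in zip(loss_sums, token_counts))

def global_mean(loss_sums, token_counts):
    """Reduce sums and counts across ranks first, then divide once."""
    return sum(loss_sums) / sum(token_counts)

ranks = 8
tokens_per_rank = 1024
per_token_loss = 2.0  # identical on every token, so per-token values look fine

loss_sums = [per_token_loss * tokens_per_rank] * ranks
token_counts = [tokens_per_rank] * ranks

local_scale = sum_of_local_means(loss_sums, token_counts)   # 16.0
global_scale = global_mean(loss_sums, token_counts)         # 2.0
ratio = local_scale / global_scale                          # 8.0
```

Because the loss is linear in the divisor, the gradient inherits the same factor. With uneven padding across ranks the two conventions stop differing by a clean constant, which is one more reason to state the divisor explicitly.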
That is also why the near-copy reproducer matters more than a toy script. A toy hidden @ W.T check can validate arithmetic in isolation, but it cannot show whether a path swap quietly changed the owned output-and-loss seam. The checked-in near-copy files keep that seam visible enough to debug, which is the same operational reason this article belongs next to Megatron FLCE on Hopper, Liger FLCE reduction none, Author Mamba3 spec inside Megatron, and Loss curves and divergence playbook.
Frequently asked questions
Why is shape parity not enough?
Why can a distributed CE parity check fail even when the single-GPU version passes?
Why does vocab-partition target masking belong in the parity receipt?
Under vocab parallelism each rank holds only its local [S, B, V/TP] slice, while the target ids still describe the global vocabulary. The parity receipt therefore has to say where out-of-partition targets are masked or remapped before the predicted-logit gather and backward update happen. A full-logits PyTorch baseline can skip that local-vocab step, so it may hide wrong partition ranges or ignore-token handling that only fails when the all-gather is removed. Keep that target-mask contract next to the max/sum reductions; the neighboring fused-loss checks are Megatron FLCE on Hopper and Liger FLCE reduction none.
Why does local-token versus global-token averaging deserve its own check?
CrossEntropyLoss(reduction="mean") averages over the visible tokens in the local call site. A vocab-parallel fused path can instead normalize after cross-rank reductions over the global non-padding token count. If the parity receipt never states which count owns the final divide, two lanes can agree on shape and even on per-token logits while still backpropagating different update magnitudes.
Why can a parity probe OOM even when the fused training lane fits?
The fused training lane keeps logits sharded as [S, B, V/TP] or, under sequence parallel, [S/SP, B, V/TP] and resolves the loss with cross-rank reductions. A plain PyTorch parity rewrite often all-gathers until it materializes [S, B, V], so the first failure becomes full-logit residency rather than CE math. That is why the checked-in near-copy lane is more useful than a full materialization baseline: it preserves the output-and-loss seam instead of replacing it with a different memory regime. The neighboring receipts are Megatron FLCE on Hopper and Porting to Megatron friction. The arithmetic is large enough to be a separate receipt check. With B=4, S=4096, V=151936, and BF16 logits, a materialized [B, S, V] tensor is about 4.98 GB before gradients and about 9.96 GB if an equally large gradient buffer is live. A chunked fused path with a 4096-entry vocab chunk is about 134 MB for the same BF16 working logits. If a parity probe crosses from the second regime into the first, the OOM says more about the probe than about the fused CE contract.
Why use a near-copy reproducer instead of a toy CE script?
What if the compared loss path also adds label smoothing or z-loss?
Which neighboring posts help when the parity drift turns into training-health drift?
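The memory figures quoted in the OOM answer above can be checked in a few lines. This is only the byte arithmetic for the logits tensor, assuming 2 bytes per BF16 element and decimal gigabytes; `logits_bytes` is a hypothetical helper and ignores every other allocation a real probe would make.

```python
def logits_bytes(batch, seq, vocab, bytes_per_elem=2):
    """Bytes needed to hold one logits tensor of shape [batch, seq, vocab]."""
    return batch * seq * vocab * bytes_per_elem

B, S, V = 4, 4096, 151936

full_logits = logits_bytes(B, S, V)            # materialized [B, S, V]
with_grad = 2 * full_logits                    # plus an equally large gradient
vocab_chunk = logits_bytes(B, S, 4096)         # one 4096-entry vocab chunk

print(f"full logits: {full_logits / 1e9:.2f} GB")   # 4.98 GB
print(f"with grad:   {with_grad / 1e9:.2f} GB")     # 9.96 GB
print(f"vocab chunk: {vocab_chunk / 1e6:.0f} MB")   # 134 MB
```

The ratio between the two regimes is simply V divided by the chunk width, roughly 37x here, so a probe that materializes full logits is testing a different memory regime from the fused lane.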
Terms used in this article
Start here for quick definitions, then follow the linked posts for deeper context.
Tensor parallelism splits each linear's weights (QKV, O, MLP gate/up/down) across GPUs. On 8× H200 with TP=8 each GPU owns 1/8 of every matmul's columns or rows, so one big matmul becomes 8 smaller ones that all-reduce at the layer boundary. Cost: one all-reduce per attention and per MLP — heavy bandwidth, so TP is usually bound to a single NVLink/NVSwitch island (1 node of up to 8 GPUs). Embeddings, layernorms, and optimizer state stay replicated across the TP GPUs. Use TP when a single layer's weights don't fit on one GPU, not to scale past one node.
Sequence parallelism is a TP-region activation saver — not a separate mesh. Plain TP leaves layernorm / dropout / residual activations replicated on every TP GPU; SP keeps those intermediates sharded along the sequence axis so each TP GPU holds only 1/TP of them. Cost: same bandwidth as plain TP — the single all-reduce becomes an all-gather + reduce-scatter pair. Weights identical to plain TP; only the activation tensors shrink. Turn on whenever TP is on — near-free memory savings, which is what makes long contexts fit under TP.
A grounded look at why MegaCpp combines Mamba-style state-space blocks with a smaller number of attention blocks for long-context C++ work, and…