Multi-Head Cross fused on Blackwell: from reference einsum to Triton
How the MegaCpp Multi-Head Cross branch mixer went from a readable PyTorch reference to a fused Triton path on Hopper and Blackwell, and how it lands in deployment through a narrow feature contract.

Multi-Head Cross, which we call mHC, is the part of the MegaCpp hybrid recipe
that mixes multiple residual streams between blocks. The algebra is four
einsums and a Sinkhorn normalisation; the pain is that those einsums run
per block for every layer, on every token, across Hopper and Blackwell alike,
and the reference PyTorch path is too launch-heavy to ignore at our depth. This
post is the engineering story of collapsing the mHC reference into a fused
Triton path, keeping a safe fallback, and shipping the result as a narrow
feature contract in the MegaCpp deployment stack. For where mHC sits in the
architecture rather than the kernel lane, the closest companion is Hybrid
layer interleaving, with the broader
keep-or-drop rule in Kernels that pay for
themselves.
If you want the shortest checked-in reading path before the narrative, start with mHC branch mixer sample, then mHC fused static sample, then mHC stream residual sample, and finally Megatron args sample. That sequence keeps the math, the fused fast path, and the deployment seam visible in that order.
One boundary matters up front: this is not a general hardware-feature article. The fused mHC path here is ordinary CUDA/Triton tensor math around a static 4-stream mixing contract. If the vocabulary itself is the blocker, open MegaCpp model glossary first. If your real question is low-level memory movement or descriptor ownership, use the dedicated Blackwell kernel companions instead of trying to infer that story from this small residual-mixing lane.
Why MegaCpp cares about this
The dense baseline uses 4-stream hyper-connections in the cross-layer
HyperConnections sense, with mHC doing the branch mixing. Per block, the math is
cheap but launch-dominated: a pre-mix (bn,btnd->btd), the block body, a
residual mix-back (bnm,btmd->btnd), and an add. At a depth-52 dense preset
that is roughly 200 tiny einsum launches per forward, and an early benchmark
on H200 showed mHC overhead as the dominant non-attention cost once Mamba-3 and
MoE were compiled out of the way. The public reference implementations we
looked at made different trade-offs: batch-dim packing, a different
(B, S, K, D) layout, and a small CUDA extension. We picked a fused-but-narrow
Triton path that keeps the dynamic dispatch simple.
The useful framing constraint is that the win is a launch-and-bandwidth story, not a new algorithm. That is why the article keeps returning to "narrow hot path" language instead of pretending mHC turned into a generic new mixer family.
What we built in the MegaCpp training stack
The checked-in mHC branch mixer sample
is the reference surface. It projects each pooled branch representation through
a small hidden layer, scores it, builds a dense (N, N) affinity matrix via
bmm on low-dim keys, and runs Sinkhorn normalisation in fp32 to project onto
the Birkhoff polytope. For N=2 it detects Sinkhorn's degenerate
doubly-stochastic case and routes through a direct softmax instead.
Put plainly: a 2x2 doubly-stochastic matrix has only one free parameter, so
once there are only two branches the iterative balancing step buys no extra
structure, and the reference path uses a plain two-way softmax directly.
blend_alpha blends the learned weights with uniform so slot-order bias does
not dominate when the early training signal is weak. That reference path is
still the default.
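The normalisation pieces above can be sketched in NumPy. This is a minimal sketch, not the sample's code: it assumes a standard Sinkhorn iteration (alternating row and column normalisation of exp(logits)), and the iteration count, eps, temperature, and the blend direction (blend_alpha = 1 keeps the learned weights, 0 is uniform) are all assumptions. The helper names are hypothetical, and because the text does not spell out how the sample combines scores with the affinity matrix, the sketch keeps the three steps as separate helpers.

```python
import numpy as np

def sinkhorn(logits, n_iters=50, eps=1e-6):
    # Hypothetical helper. Work in fp32; alternate row and column normalisation.
    # eps clamps the sums so an all-zero row or column cannot divide by zero.
    z = logits.astype(np.float32)
    m = np.exp(z - z.max())
    for _ in range(n_iters):
        m = m / np.maximum(m.sum(axis=1, keepdims=True), eps)
        m = m / np.maximum(m.sum(axis=0, keepdims=True), eps)
    return m

def two_branch_mix(scores, temperature=1.0):
    # Hypothetical N == 2 short-circuit: a 2x2 doubly-stochastic matrix is
    # [[p, 1-p], [1-p, p]], so one temperature-scaled softmax fixes it.
    z = scores.astype(np.float32) / temperature
    e = np.exp(z - z.max())
    p = e[0] / e.sum()
    return np.array([[p, 1 - p], [1 - p, p]], dtype=np.float32)

def blend_with_uniform(w, blend_alpha):
    # Hypothetical blend: renormalise onto the simplex, then mix with uniform
    # so slot-order bias cannot dominate while the learned signal is weak.
    w = w / w.sum(axis=-1, keepdims=True)
    uniform = np.full_like(w, 1.0 / w.shape[-1])
    return blend_alpha * w + (1.0 - blend_alpha) * uniform
```

For example, blend_with_uniform applied to weights [3, 1, 0, 0] with blend_alpha = 0.5 gives [0.5, 0.25, 0.125, 0.125]: the learned preference survives, but no slot is pinned at zero.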
The checked-in mHC fused static sample
mirrors the fast path. The scope is intentionally narrow: only the static
4-stream cross-layer HC path is wired through the Triton contract. Everything
else, including variable N, stays on the native PyTorch path. The reason is
simple: the hot surface is only a few shape-stable mixing steps, while the
other variants would each need a separate kernel family.
The forward math in its Torch form is four einsum-style steps:
- branch input: branch_input = einsum('bn,btnd->btd', pre_mix_weights, hidden)
- stream mix: residual_mix = einsum('bnm,btmd->btnd', residual_weights, hidden)
- residual add: new_hidden = residual_mix + post_mix_weights[:, None, :, None] * branch_output[:, :, None, :]
- fused mix-add: the residual mix-back fused with the add, which is what actually avoids holding the residual tensor across the block body on the critical path
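A NumPy sketch of these steps for one block follows, assuming the shapes implied by the einsum strings: hidden is (B, T, N, D), pre- and post-mix weights are (B, N), and residual weights are (B, N, N). The names hc_block_forward and block are hypothetical, not the sample's API. The sketch computes residual_mix only after the block body, next to the add, in the spirit of the fused mix-add ordering.

```python
import numpy as np

def hc_block_forward(hidden, pre_mix_weights, residual_weights, post_mix_weights, block):
    # Hypothetical reference step; hidden holds N streams: (B, T, N, D).
    # Pre-mix: collapse the N streams into one branch input, (B, T, D).
    branch_input = np.einsum('bn,btnd->btd', pre_mix_weights, hidden)
    # Block body (attention, Mamba, MLP, ...), shape-preserving: (B, T, D).
    branch_output = block(branch_input)
    # Stream mix: an N x N mixing of the residual streams, (B, T, N, D).
    residual_mix = np.einsum('bnm,btmd->btnd', residual_weights, hidden)
    # Residual add: broadcast the branch output back onto every stream,
    # equivalent to einsum('bn,btd->btnd', post_mix_weights, branch_output).
    return residual_mix + post_mix_weights[:, None, :, None] * branch_output[:, :, None, :]
```

As a sanity check, an identity residual matrix with a block body that returns zeros leaves the streams unchanged, which is the behaviour you would expect from a residual connection.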
Each reference step keeps its backward as an explicit einsum chain. The
Triton fast path is forward-only; backward falls back to the PyTorch formulas.
That matches the pattern we use elsewhere for fused CUDA primitives where the
custom backward is not yet worth the complexity.
The fused-mix-add primitive, inlined for reference:
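```python
import torch

def fused_mix_add(residual_weights, hidden, post_mix_weights, branch_output):
    # reference fused mix-add: (B,N,N), (B,T,N,D), (B,N), (B,T,D) -> (B,T,N,D)
    residual_mix = torch.einsum('bnm,btmd->btnd', residual_weights, hidden)
    return residual_mix + post_mix_weights[:, None, :, None] * branch_output[:, :, None, :]
```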
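Because the Triton path is forward-only, the backward of the reference mix-add is just an explicit einsum chain. Below is a sketch of that chain in NumPy rather than inside a torch.autograd.Function, so it stands alone; the function names are hypothetical. Since the mix-add is linear in each input when the others are held fixed, a finite difference along any direction matches these gradients up to rounding, which makes the formulas easy to check.

```python
import numpy as np

def mix_add_forward(R, H, P, O):
    # R: residual weights (B, N, N), H: hidden (B, T, N, D),
    # P: post-mix weights (B, N), O: branch output (B, T, D).
    return np.einsum('bnm,btmd->btnd', R, H) + P[:, None, :, None] * O[:, :, None, :]

def mix_add_backward(G, R, H, P, O):
    # G is the upstream gradient w.r.t. the (B, T, N, D) output.
    dR = np.einsum('btnd,btmd->bnm', G, H)  # contract over tokens and features
    dH = np.einsum('bnm,btnd->btmd', R, G)  # transpose of the stream mix
    dP = np.einsum('btnd,btd->bn', G, O)    # per-stream post-mix gradient
    dO = np.einsum('btnd,bn->btd', G, P)    # sum the broadcast back over streams
    return dR, dH, dP, dO
```

Keeping these four lines as the backward is the trade-off the article describes: the forward gets the fused launch, while the gradient stays readable and easy to audit.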
Backend resolution at a glance:
| Check | Result |
|---|---|
| is_cuda and N == 4 | Triton |
| N != 4 or token-wise dynamic | torch (reference) |
| non-CUDA / shape drift | torch + one-time warn |
| fp8 autocast inside group | group enters once |
| Sinkhorn (all shapes) | fp32 with eps clamp |
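The dispatch rows of the table can be read as a small decision function. This is a sketch under stated assumptions: resolve_mhc_backend, its arguments, and the warning text are hypothetical; fused_ops stands in for the deployment toggle; "shape drift" is modelled as a boolean the caller computes. The fp8 and Sinkhorn rows describe scoping and numerics rather than backend choice, so they are left out.

```python
import warnings

_warned_once = False  # latch so the fallback warning fires only once per process

def resolve_mhc_backend(is_cuda, n_streams, dynamic, shape_ok=True, fused_ops=True):
    """Hypothetical resolver: Triton only for the static 4-stream CUDA lane."""
    global _warned_once
    if not fused_ops or dynamic or n_streams != 4:
        # Toggle off, token-wise dynamic routing, or N != 4: reference path by design.
        return "torch"
    if not is_cuda or not shape_ok:
        # The fused lane was expected but cannot run: fall back, warning only once.
        if not _warned_once:
            warnings.warn("mHC fused path unavailable; using the torch reference")
            _warned_once = True
        return "torch"
    return "triton"
```

The split between silent and warned fallbacks is the point: N != 4 is an intended reference-path case, while a non-CUDA tensor or a drifted shape on the 4-stream lane is a surprise worth surfacing once.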
The checked-in samples keep the proof surface narrow. mHC fused static sample
is the kernel-shape receipt: it expects [B, T, S, D] and rejects any
S != 4. mHC branch mixer sample
is the reference-math receipt: it keeps Sinkhorn normalisation, the N=2
short-circuit, and blend_alpha visible. mHC stream residual sample,
hybrid pattern sample, and
Megatron args sample are the
deployment receipts: they keep stream count, dynamic mode, residual
interaction, mhc_fused_ops, and the "mHC remains custom" seam visible without
claiming they all map onto one fused kernel.
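The fail-closed shape receipt of the fused static sample amounts to a guard like the one below. This is a minimal sketch; check_static_mhc_shape is a hypothetical name, and the only contract it encodes is the one stated above: a [B, T, S, D] layout with S == 4.

```python
def check_static_mhc_shape(shape):
    """Hypothetical guard: the fused static lane accepts only [B, T, S, D] with S == 4."""
    shape = tuple(shape)
    if len(shape) != 4:
        raise ValueError(f"expected a rank-4 [B, T, S, D] tensor, got rank {len(shape)}")
    b, t, s, d = shape
    if s != 4:
        # Reject rather than silently reshape: other stream counts belong to the reference path.
        raise ValueError(f"fused static mHC requires S == 4 streams, got S={s}")
    return b, t, s, d
```

For example, (2, 16, 4, 64) passes unchanged, while (2, 16, 3, 64) raises instead of being coerced onto the kernel.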
Numerical guards are the part the checked-in samples show most clearly. The
branch-mixer sample keeps Sinkhorn in fp32, clamps row and column sums with
epsilon, renormalizes weights onto the simplex, and falls back to a direct
softmax when N=2. That is the checked-in form of the same lesson:
keep the routing math numerically narrow and do not pretend every branch-mixing
case belongs on one fused kernel.
The boundary that keeps this lane honest
The right mental model is not "Blackwell gets a special mHC algorithm." It is "the checked-in fast path only claims a narrow fixed-shape fused lane, and everything else stays explicit."
That boundary has three parts:
- The static 4-stream fused path is the real deployment lane. The win comes from collapsing a shape-stable hot surface that is worth tuning once and keeping.
- Dynamic routing remains a valid research surface, but it widens the optimizer and recompute story. This article keeps that lane visible without pretending it is the same contract as the static fused path.
- The Megatron side still treats mHC as custom. The checked-in args sample says that directly, which is exactly the seam this article needs to keep honest.
How it lands in MegaCpp
The MegaCpp deployment stack is built on Megatron-Core. The relevant checked-in public surfaces are mHC stream residual sample, hybrid pattern sample, and Megatron args sample. Their role is fail-closed: the hybrid samples keep stream count, dynamic mode, fused-ops, and residual interaction visible, while the Megatron-args sample keeps the separate note that mHC remains custom rather than Megatron-native.
The fused_ops toggle is the deployment flag. When it is on, the mixer
resolves the Triton backend and uses the fused forward surface. When it is off,
or when the tensor shape falls outside the four-stream contract, it falls back
to the Torch reference.
One useful research-backed addition belongs here: the grouped mHC path is also why the FP8 autocast scope is treated as a group-level concern rather than a per-layer one. The article keeps that statement narrow on purpose. The real claim is only that the precision boundary belongs to the grouped runtime contract, not that every mixed-precision variant has already been fused and proved.
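The structural difference between a group-level and a per-layer precision boundary can be shown with a stand-in context manager. This is only a sketch: precision_scope is hypothetical, changes no numerics, and merely records each entry; it is not an FP8 autocast API. It illustrates the "group enters once" row of the backend table and nothing more.

```python
from contextlib import contextmanager

@contextmanager
def precision_scope(log, name):
    # Hypothetical stand-in for an FP8 autocast context: it only records entry.
    log.append(name)
    yield

def run_group_scoped(layers, x, log):
    # Group-level boundary: one scope wraps every layer in the grouped run.
    with precision_scope(log, "group"):
        for layer in layers:
            x = layer(x)
    return x

def run_per_layer_scoped(layers, x, log):
    # Per-layer alternative: a separate scope is entered around each layer.
    for i, layer in enumerate(layers):
        with precision_scope(log, f"layer{i}"):
            x = layer(x)
    return x
```

With three layers, the grouped version logs a single entry while the per-layer version logs three, which is the boundary placement the paragraph above is about.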
Ablations and what we kept
The ablation story is useful here only where it changes the public contract. Two parts survived:
- Dynamic mHC stayed on the PyTorch path intentionally. The checked-in branch mixer already handles the general N case with the N=2 short-circuit and the blend_alpha ramp. The fused lane stays shape-stable on purpose.
- The Muon interaction turned out to be an optimizer boundary, not proof that the four-stream kernel itself was wrong. The useful public-safe receipt is the same one described in Muon on Hopper and Blackwell: split-QKV fixed the unstable update geometry without changing the narrow fused mHC contract.
Frequently asked questions
Is the Triton fast path generic N-stream infrastructure?
No. Only the static 4-stream contract is fused; any other N stays on the reference PyTorch path.
Why keep backward in explicit PyTorch?
Does GB10 depend on the fused path the same way H200 does?
Why treat mHC as a narrow feature contract instead of a generic fused-mixer layer?
Where do the real config and residual-interaction knobs live?
Which checked-in files should I compare first for reference versus fused behavior?
Why is the fast path hard-coded to four streams?
Because the static 4-stream lane is the shape-stable hot surface that is worth tuning once. Fixing the shape keeps the contract honest instead of implying that a generic N-stream fused mixer already exists.
Why does the reference path switch to softmax when only two branches remain?
When n_branches == 2 it goes straight to a temperature-scaled softmax, then applies the same normalization and optional blend step used elsewhere. That keeps the smallest case readable and avoids pretending the full Sinkhorn path is required for every branch count.
Is this the same thing as the newer mHC-lite permutation route?
Why did raw Muon fail once mHC moved into the deeper hybrid runs?
Terms used in this article
Start here for quick definitions, then follow the linked posts for deeper context.
- Megatron Core: The NVIDIA framework surface MegaCpp ports into through narrow adapters, layer specs, and runtime ownership bridges.
- GB10: Consumer Grace Blackwell GB10 / DGX Spark bring-up lane used to separate driver-visible gates, patched cubin signals, and real execution proof.
- H200: NVIDIA's Hopper H200 GPU platform, typically discussed here as an 8-GPU training node with large HBM capacity and NVLink-connected ranks.
- Attention: The token-mixing path that turns Q/K/V style projections into context-aware activations. On MLA pages here it refers to the concrete attention module boundary, not the A/M/E/R block-family shorthand.
- FP8: Eight-bit floating-point training and inference formats used to trade precision for throughput and memory on recent accelerator lanes.
- MoE: Token Choice vs Expert Choice, null-expert debugging, gating stability, and the production routing decisions behind the MegaCpp SLM Ensemble.