Mamba 3 + Transformers: Why a Hybrid Stack Beats Pure Attention for C++
A deep dive into the Mamba 3 / Transformer hybrid MegaCpp trains on: layer interleaving, PSIV caching, the MIMO scan, and the register-split kernel that nearly shipped.

Pure-attention Transformers are a bad fit for the thing MegaCpp actually does for a living: read long, deeply nested C++ translation units, follow chains of #includes and template instantiations, and produce patches that stay consistent with code the model saw 30k tokens ago. Quadratic attention over 32k–128k tokens of C++ is not a minor tax; it is the tax. At the same time, pure state-space models lose precision on exactly the local lookups that matter for C++ - matching a } to its namespace, resolving an overload against an ADL candidate, or copying a symbol verbatim through a refactor.
The architecture we train is a hybrid: a Mamba 3 backbone interleaved with a minority of Transformer blocks. Most of the compute is an O(N) scan in state space; a handful of attention blocks, placed where they pay for themselves, handle sharp retrieval. This post walks through that choice from the kernel up: how the layers are interleaved, why the MIMO variant of Mamba 3 exists at all, what we cache between forward and backward (the "PsiV" cache), and why an ambitious kernel split we designed for the double-backward pass was ultimately rejected.
The post leans heavily on our internal engineering record. Specifically: the Mamba 3 MIMO P1 TMA + warp-specialization notes from [mamba3_mimo_p1_notes.md], the PsiV cache design in [mamba3_mimo_p2_psiv_cache_design.md], the register-split design in [mamba3_mimo_p3_register_split_design.md], the fork reconciliation record [mamba_fork_canonical_2026_04_14.md], and the author-pure integration seam in [cppmega/features/mamba3/config.py]. Context-graph sampling comes from the nanochat [v4_architecture.md].
Why Hybrid, Specifically for C++
A C++ token stream has two statistical regimes. Most of it is slowly varying context: the type environment, the current namespace, the project's macro vocabulary, the coding style of the file. That regime rewards compression into a running state, which is what a selective SSM does. A smaller fraction is high-precision retrieval: "what was the exact signature of Buffer::append that we declared 12k tokens ago?", "which of these seven operator<< overloads applies here?". That regime rewards content-addressable lookup, which is what attention does.
A pure Mamba stack fumbles the second regime; a pure Transformer pays the full quadratic price to do both. A hybrid lets each layer do what it is good at. Empirically, that is also what the frontier hybrid models (Nemotron Nano 3, Jamba, Zamba, Samba) converged on independently - attention is a minority, SSMs are the majority, and the exact ratio is tuned per target domain.
For MegaCpp, the target domain is C++ specifically, and the sampling regime makes the hybrid argument sharper. Our training corpus is built by the v4 context-graph packer described in [v4_architecture.md]: each training snippet is up to 64k tokens of Callers -> Target -> Callees, extracted with Tree-sitter and packed by a BFS priority queue until the budget is exhausted. Those snippets are long, semi-structured, and full of cross-references that look exactly like the two regimes above: slowly-varying structural context (the caller bodies, the includes) plus precise lookups (the callee signatures the Target modifies). A model trained on that corpus with pure attention burns almost all of its FLOPs revisiting structural context; a pure SSM blurs the callee signatures. The hybrid is not a stylistic preference, it is the cheapest model that actually represents the data.
The Layer Interleaving Pattern
Concretely, the backbone is a Mamba-majority stack with attention blocks sprinkled in at specific depths. The exact spec lives in [nam56r_full_spec.py], but the shape is: ~7 Mamba 3 layers per 1 attention layer, attention biased toward the middle and later third of the network rather than the first blocks. Early layers embed tokens and accumulate local state; attention is wasted there. By the middle of the network, representations are abstract enough that attention lookups hit meaningful keys, and the quadratic cost is paid against features that justify it.
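The interleave can be sketched as a tiny layer-pattern generator. This is a hypothetical illustration, not the actual spec in [nam56r_full_spec.py]: the function name, the `attn_every` and `skip_front` parameters, and the defaults are all made up to show the shape of the policy (attention sparse, and pushed past the early blocks).

```python
def interleave_pattern(n_layers: int, attn_every: int = 8, skip_front: int = 8):
    """Hypothetical sketch of the Mamba-majority interleave: one attention
    block per `attn_every` layers, none in the first `skip_front` layers,
    so attention lands from mid-stack onward. The real placement lives in
    nam56r_full_spec.py; these defaults are purely illustrative."""
    pattern = []
    for i in range(n_layers):
        late_enough = i >= skip_front
        if late_enough and (i - skip_front) % attn_every == attn_every - 1:
            pattern.append("attn")
        else:
            pattern.append("mamba3")
    return pattern

layers = interleave_pattern(32)
# Early blocks are all Mamba; attention appears only past the first quarter.
```

The point of the sketch is the asymmetry: the early layers never get attention, because their representations are too raw for content-addressable lookup to pay for itself.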
The Mamba layers themselves are configured through the author-pure seam in [cppmega/features/mamba3/config.py]. The AuthorMamba3Config dataclass pins the contract: d_model, d_state, expand, headdim, ngroups, plus Mamba-3-specific fields (rope_fraction, dt_min, dt_max, dt_init_floor, A_floor, is_outproj_norm, is_mimo, mimo_rank, chunk_size). build_author_mamba3_config maps Megatron's config surface onto that contract and refuses any inconsistent override - for example, it rejects a custom mamba_num_heads that disagrees with hidden_size * expand // mamba_head_dim, because the author kernel assumes a specific head count. Failing loudly at config time is the right trade here; silent mismatches in SSM head geometry corrupt gradients in ways that only show up after hours of training.
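The refuse-inconsistent-overrides behavior can be sketched in a few lines. This is a simplified stand-in for `AuthorMamba3Config`, not the real class: the field subset and the error message are illustrative, but the check it performs is the one described above.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Mamba3ConfigSketch:
    """Illustrative stand-in for AuthorMamba3Config (field subset only).
    The validation mirrors the seam's contract: the author kernel assumes
    a specific head count, so any disagreeing override is rejected."""
    d_model: int
    expand: int
    headdim: int
    mamba_num_heads: int

    def __post_init__(self):
        # Head count the author kernel derives from the geometry.
        derived = self.d_model * self.expand // self.headdim
        if self.mamba_num_heads != derived:
            raise ValueError(
                f"mamba_num_heads={self.mamba_num_heads} disagrees with "
                f"hidden_size * expand // mamba_head_dim = {derived}"
            )

# 2048 * 2 // 64 = 64 heads: consistent, accepted.
cfg = Mamba3ConfigSketch(d_model=2048, expand=2, headdim=64, mamba_num_heads=64)
```

Passing `mamba_num_heads=32` with the same geometry raises at construction time, which is the whole point: the mismatch dies at config time instead of corrupting gradients hours into a run.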
The NAM56R shape we actually run is B=1, S=8192, H=16, G=1, N=64, P=64, R=4, chunk=16 (per [mamba3_mimo_p1_notes.md]). That R=4 is the MIMO rank, and it is the reason the Mamba-3 layer in our stack is not just "Mamba 2 with RoPE".
MIMO: Why the Rank Exists
Mamba 2 already generalizes Mamba 1 by reframing the selective scan as a structured state-space duality (SSD) with a matrix SSM. Mamba 3 MIMO goes one step further: instead of each head producing a single output channel per state update, each head produces R outputs from R input projections, sharing the same scan. In our config, R=4.
Mechanically, MIMO replaces what would be a rank-1 outer product in the state update with a rank-R one. Concretely, the "PsiV" tensor that dominates the kernel - computed as psi_v[cs, r, p] = v[b, chunk_start+cs, h, p] * psi[h, r, p] ([mamba3_mimo_p2_psiv_cache_design.md]) - has an explicit R axis. For NAM56R that means H=16, R=4, P=64 per head, so each state update carries four up-projections of V through the scan simultaneously. For code tokens, the practical effect is that the same head can track several lightweight "roles" in one pass - for instance, a type-identifier channel and a scope-nesting channel - without needing a wider head dimension or more heads. MIMO is how we get richer per-head state without blowing up the scan's arithmetic intensity.
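The PsiV formula is simple enough to state as a broadcast. The sketch below uses NumPy as a stand-in for the kernel's per-chunk math; shapes follow the NAM56R chunk geometry from the post (chunk=16, H=16, R=4, P=64), with B=1 for brevity.

```python
import numpy as np

# NAM56R-like per-chunk shapes: chunk=16, H=16, R=4, P=64 (B=1 here).
B, CS, H, R, P = 1, 16, 16, 4, 64
rng = np.random.default_rng(0)
v = rng.standard_normal((B, CS, H, P))    # activation: one chunk of V
psi = rng.standard_normal((H, R, P))      # parameter: R up-projections per head

# psi_v[b, cs, h, r, p] = v[b, cs, h, p] * psi[h, r, p]
# The R axis is the rank-R generalization of the rank-1 outer product:
# the same V token rides through the scan under four per-head projections.
psi_v = v[:, :, :, None, :] * psi
```

Note what is cheap here: PsiV is a pointwise product, so the cost of MIMO is not in computing it but in how many times the kernels recompute it - which is exactly what the P2 cache attacks.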
The MIMO scan is the hot kernel. The P1 work in [mamba3_mimo_p1_notes.md] was about making that kernel use the hardware it was sitting on: flipping TL_DISABLE_TMA_LOWER and TL_DISABLE_WARP_SPECIALIZED to False on the six MIMO kernels, and adding TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE everywhere it was missing. That lets the TileLang compiler emit Hopper-class TMA descriptors and warp-group pipelining instead of falling back to plain mma.sync. The patch is env-gated behind CPPMEGA_MAMBA3_P1=1, idempotent, and crashes loudly if upstream moves the decorator block. Correctness is verified on GB10 (11/11 forward shapes pass at rel_err < 0.01; combined backward passes at stable_max_rel 0.004-0.012).
The catch: on H200, enabling TMA on the backward kernels exposed a real upstream bug. mamba_mimo_bwd_fwd and mamba_mimo_bwd_bwd used three rank-3 shared-memory descriptors, which TileLang's TMA lowering cannot handle ("Cannot detect TMA layout" - InputDim() == 2 assertion). The fix in branch tma-layout-fix-3d-to-2d @ 31dc695 flattens the 3D smem to 2D via zero-copy reshapes (qk_dot_shared[c, r1, r2] -> [c, r1 * R + r2], Q[B, S, R, G, N] -> [B, S*R, G, N]). Correctness survived: 14 gradient tensors at rel_err 0.0038-0.0116, bit-for-bit with the TMA-off baseline within BF16 rounding.
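The zero-copy flatten is easy to demonstrate outside the kernel. The sketch below uses NumPy as a stand-in for the shared-memory descriptors: for a contiguous buffer, collapsing `[c, r1, r2]` into `[c, r1 * R + r2]` is a reinterpretation of the same bytes, which is why correctness survived bit-for-bit.

```python
import numpy as np

C, R = 16, 4
qk_dot = np.arange(C * R * R, dtype=np.float32).reshape(C, R, R)

# The 3D-to-2D trick: [c, r1, r2] -> [c, r1 * R + r2]. On a contiguous
# buffer no bytes move, so the rank-2 descriptor TileLang's TMA lowering
# can handle addresses exactly the memory the rank-3 one did.
qk_dot_2d = qk_dot.reshape(C, R * R)

assert np.shares_memory(qk_dot, qk_dot_2d)
assert qk_dot_2d[3, 1 * R + 2] == qk_dot[3, 1, 2]
```

The same index algebra covers the Q case in the fix: `Q[B, S, R, G, N] -> [B, S*R, G, N]` fuses two adjacent contiguous axes, leaving the layout untouched.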
The honest outcome of P1 on H200 is in the addendum: selective-forward-only P1 was a wash (-0.006% throughput on bench3 at MBS=8 NAM56R), because mamba_mimo_fwd at 1192 ms/step is a small slice of the 5540 ms iteration. The full P1 win requires the TMA layout fix on both backward kernels, and that measurement is still pending an H200 slot.
The PsiV Cache
PsiV is the tensor that shows up everywhere in the MIMO kernels. mamba_mimo_fwd uses it three times (intra-chunk qk · PsiV MMA, diag qk · PsiV MMA, interchunk state accumulation). mamba_mimo_bwd_fwd recomputes it. mamba_mimo_bwd_bwd recomputes it again, plus holds it through the state-reverse scan. Three recomputations per chunk × head × batch, for a tensor that is a simple pointwise product of a parameter (psi) and an activation (v).
The P2 design in [mamba3_mimo_p2_psiv_cache_design.md] is about killing two of those three recomputations. The key dependency analysis is worth stating plainly:
PsiV depends on psi, which is a module parameter and does not change within a step, and on v, which is a per-step activation. That means PsiV cannot be cached across training steps, and it cannot be cached across CUDA-graph replays (the buffer would hold the previous replay's activation). But it is a perfectly well-defined intra-step tensor, and the same v flows through fwd -> bwd_fwd -> bwd_bwd inside a single forward+backward iteration. The cache is an activation checkpoint, not a hash table: save PsiV to gmem during fwd, pass it into bwd_fwd and bwd_bwd as an extra input, skip the recompute there.
The storage is straightforward - an extra output tensor on mamba_mimo_fwd, saved via ctx.save_for_backward on the autograd op. The shape is (B, S, H, R, P), dtype BF16, chunk-contiguous layout so the backward kernels' per-chunk access pattern stays coherent. For NAM56R MBS=8 that is about 5.6 GiB extra per rank, inside the 132 GiB H200 peak budget. The design explicitly refuses to pool buffers in v1, on the theory that a one-line torch.empty() is easier to reason about than a pool allocator interacting with CUDA graphs. Pool comes later if needed.
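The activation-checkpoint shape of the cache can be sketched with a toy autograd op. This is a hypothetical stand-in, not the real MIMO op: the class name and the placeholder "scan" (a sum over R) are invented, but the mechanism is the one described - compute PsiV once in forward, stash it via `ctx.save_for_backward`, and hand it to backward instead of recomputing it there.

```python
import torch

class PsiVCachedScan(torch.autograd.Function):
    """Sketch of the P2 intra-step cache (hypothetical op, stand-in math).
    PsiV is an activation checkpoint: valid within one fwd+bwd iteration,
    never across steps or CUDA-graph replays."""

    @staticmethod
    def forward(ctx, v, psi):
        # psi_v[b, s, h, r, p] = v[b, s, h, p] * psi[h, r, p]
        psi_v = v.unsqueeze(-2) * psi
        ctx.save_for_backward(v, psi, psi_v)   # cached once, intra-step only
        return psi_v.sum(dim=-2)               # placeholder for the real scan

    @staticmethod
    def backward(ctx, grad_out):
        v, psi, psi_v = ctx.saved_tensors
        # The real bwd_fwd / bwd_bwd kernels would consume psi_v directly
        # here, skipping the v * psi recompute. For this stand-in output:
        grad_v = (grad_out.unsqueeze(-2) * psi).sum(dim=-2)
        grad_psi = (grad_out * v).sum(dim=(0, 1)).unsqueeze(-2).expand_as(psi)
        return grad_v, grad_psi
```

This is also roughly what the pending "Phase A Python prototype" measures: the ceiling of the win when the pointwise product is hoisted out of the kernels entirely.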
The expected win is 1.5-2.3% TFLOP/s, and - importantly - the P3 register-split design independently identified the same "Hoist-PsiV" pattern as its preferred escape hatch, because removing three fragment tiles from bwd_bwd's inner live-set is valuable on its own merits. As of the P2 doc, the work is at "skeleton committed, Phase A Python prototype pending": a drop-in subclass that computes psi_v = v * psi in Python before calling the kernel, purely to measure the ceiling. If Phase A does not move the needle, the whole pursuit is archived. The env gate CPPMEGA_MAMBA3_P2_PSIV_CACHE reads as "OFF, and raises NotImplementedError if you try to flip it" - we would rather fail loudly than silently pretend the cache is active.
Register Split: The Kernel That Did Not Ship
The double-backward kernel mamba_mimo_bwd_bwd is the tall pole of the Mamba backward pass at 2110 ms per step, versus 1192 ms for mamba_mimo_fwd and 1034 ms for mamba_mimo_bwd_fwd. It runs at 12.5% occupancy, 255 registers per thread, 228 KiB of shared memory - right at the compiler's register ceiling (H200's per-thread cap at threads=128 is 65536 / (2 × 128) = 256), with spilling on top.
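The occupancy arithmetic behind those numbers is worth making explicit. The sketch below is a simplified model (registers as the only limiter; it ignores allocation granularity, shared memory, and other caps), using the H100/H200-class figures of 65536 registers and 2048 resident threads per SM.

```python
REGFILE_PER_SM = 65536     # 32-bit registers per SM (H100/H200 class)
MAX_THREADS_PER_SM = 2048  # resident-thread ceiling per SM

def reg_limited_occupancy(regs_per_thread: int, threads_per_block: int = 128) -> float:
    """Occupancy if registers are the only limiter (smem etc. ignored)."""
    resident_blocks = REGFILE_PER_SM // (regs_per_thread * threads_per_block)
    return resident_blocks * threads_per_block / MAX_THREADS_PER_SM

# 255 regs/thread: 2 resident blocks = 256 threads -> 12.5% occupancy,
# which is why 65536 / (2 * 128) = 256 acts as the per-thread ceiling.
# Dropping to 128 regs/thread admits 4 blocks -> 25%.
```

This is the frame in which the P3 pitch reads as plausible: halve the registers, double the occupancy. The addendum's point is that the halving was never actually on the table.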
The P3 design in [mamba3_mimo_p3_register_split_design.md] proposed splitting bwd_bwd into two kernels connected by a gmem tensor. Pass 1 would run the state-reverse scan end-to-end, producing a dstates_per_chunk buffer (~540 MiB at MBS=8), free of PsiV and qk_dot fragments. Pass 2 would consume that buffer plus re-derive PsiV and qk_dot for the chunk-local gradients. The pitch: both passes fit in ~130 registers, occupancy doubles from 12.5% to 25%, and the 1.3-1.8x throughput bump on a compute-bound kernel (AI=479 >> H200 ridge 206) turns into a 500-900 ms saving per step, ~1% total TFLOP/s.
The design did not survive a careful second read. The addendum is candid: three blockers.
First, the split point is not clean. The loop-carried dstates_frag is updated at the end of each reverse chunk via T.gemm(q_shared, dPhiO_scaled_frag, dstates_frag, clear_accum=False) and carried into the next reverse iteration. That means Pass 1 must still hold q_shared, dPhiO_shared, and dstates_frag live - the exact fragments the split was supposed to drop. Separating them would cost an extra [B, H, nchunks, chunk_size · R, P] buffer (3× bigger than dstates_per_chunk, ~200 MiB per sample) and re-derivation of rotated-and-trap-scaled Q/K inside Pass 1. Realistic post-split register count is 200-220, not the 140 the design claimed, and the occupancy bump is marginal.
Second, GB10 is not a viable correctness platform. At the shapes we would need to diff the split kernel against the baseline, even the baseline forward kernel fails to compile on GB10's 99 KiB smem - "TMA desc init error 716" on small shapes, "Auto-tuning failed: No configuration successfully compiled" on NAM56R small. Any correctness work has to run on H200, which is gated.
Third, bench3 SSH keys were broken on the day the audit ran. No H200 means no perf measurement, which means no ROI validation for a week of kernel work.
The decision: do not ship P3. Pursue Hoist-PsiV (the P2 cache) instead, because it removes ~3 fragment tiles from bwd_bwd's inner live-set for 2-3 days of implementation versus 8-12 days for the full split, with the same expected ~1-2% TFLOP/s envelope. This is worth stating plainly: the cleanest engineering win on the Mamba kernel path in 2026-04 came from rejecting a plausible-sounding optimization after reading the kernel line-by-line, not from implementing it.
Fork Discipline
Running a Mamba 3 hybrid in production means running a lightly-forked mamba_ssm in production. The reconciliation note [mamba_fork_canonical_2026_04_14.md] captures the current state: both training machines (bench3 and europe H200) are on upstream commit 31f3d7baba, with three identical working-tree patches on mamba3.py, mamba3_mimo_bwd.py, and mamba3_mimo_bwd_varlen.py, and one real divergence on mamba3_siso_combined.py where bench3 carries the PR #909 "cache ctx.saved_tensors for checkpoint compat" tweak. Accessing ctx.saved_tensors twice on a recomputed node raises under gradient checkpointing, so the one-line cache is a correctness patch, not a perf patch. The canonical working tree is bench3's superset; an earlier "bench3=31f3d7b vs europe=4f4857f" split in a brief turned out to be stale metadata.
The reason that tiny fork discipline matters is exactly the hybrid. We apply patches on top of mamba_ssm via apply_mamba3_mimo_p1_patches.py and (in the future) apply_mamba3_p2_psiv_patches.py, env-gated, idempotent, and lock-synchronized so rank-0 patches while other ranks wait on a sentinel. If the Mamba half of the stack drifts silently between machines, the only symptom is divergent loss curves three days into a run. The patch pipeline, the md5 reconciliation, and the author-pure AuthorMamba3Config seam are the machinery that keeps the hybrid reproducible.
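The rank-0-patches-while-others-wait pattern can be sketched in a few lines. This is a hypothetical helper, not the real apply_mamba3_mimo_p1_patches.py: the function name, the sentinel-file protocol, and the polling interval are illustrative stand-ins for the lock synchronization described above.

```python
import os
import pathlib
import tempfile
import time

def apply_patches_rank_synchronized(rank: int, apply_fn, sentinel_path: str):
    """Sketch of the lock-synchronized patch step (hypothetical helper):
    rank 0 applies the idempotent patch set and drops a sentinel file;
    every other rank spins until the sentinel appears, so no rank ever
    imports a half-patched mamba_ssm tree."""
    sentinel = pathlib.Path(sentinel_path)
    if rank == 0:
        apply_fn()            # idempotent; crashes loudly if upstream moved
        sentinel.touch()
    else:
        while not sentinel.exists():
            time.sleep(0.5)

# Single-process demo: "rank 0" patches and drops the sentinel.
demo_sentinel = os.path.join(tempfile.mkdtemp(), "patch.done")
applied = []
apply_patches_rank_synchronized(0, lambda: applied.append("p1"), demo_sentinel)
```

In a real multi-rank launch the sentinel would live on shared storage and be namespaced per patch version, so a stale sentinel from a previous run cannot satisfy the wait.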
What We Actually Believe
A hybrid Mamba 3 + Transformer backbone beats pure attention for C++ for compounding reasons. The O(N) scan costs less per token at 32k+ context, which is where most of our training snippets live. MIMO packs multiple up-projections into a single scan, so Mamba layers carry more per-head information per FLOP. A minority of attention blocks, placed at the right depths, handle the precise lookups SSMs blur. Kernel-level work - P1 TMA + warp specialization, the TMA layout 3D-to-2D fix, the planned PsiV cache - pushes the scan closer to its hardware ceiling without changing the math. The pieces we chose not to ship, like P3, are as much a part of the story: avoided kernel weeks are weeks spent on cheaper wins that actually move throughput. That is the hybrid MegaCpp trains - because the shape of the C++ data, the H200 we run on, and the kernels we are willing to maintain all point at the same design.