Mamba 3 Parallel Performance: What Beat Attention in the nanochat POC, and Where It Lost
MIMO scaling, block sizes, the PsiV cache trade-off, and an honest tally of where a Mamba 3 hybrid outran pure attention on H200 and where it did not.

The question that gates everything else for a C++ specialist model is blunt: does a Mamba 3 block actually go faster than an attention block, at the sequence lengths we care about, on the hardware we have? The nanochat POC gave us a first answer. This post is the performance side of that answer - MIMO scaling, chunk-size behavior, the PsiV cache trade-off, and a frank tally of where the hybrid pulled ahead of pure attention and where it did not.
All numbers below are from two H200 training hosts, an internal GB10 correctness box, and the v6e TPU lane we use for XLA ablations. Configuration labels (NAM52, NAM56R) are internal names; shapes are real.
The Shapes That Matter
Before percentages, the geometry. The MIMO scan we run in the hybrid is parameterized by (H, G, N, P, R, chunk_size, B, S):
- H=16 heads per Mamba layer
- G=1 B/C group (we keep ngroups=1 for the author-pure contract)
- N=64 state dimension
- P=64 head width
- R=4 MIMO rank - four up-projections sharing one scan
- chunk_size=16 for the MIMO kernel (not the 256 we used in the Mamba 2 reference)
- B=1, S=8192 per rank at the NAM56R reference shape, MBS=8 in practice
Those numbers are locked by AuthorMamba3Config. The config layer refuses overrides that do not satisfy H = hidden_size * expand / head_dim, because the author kernel assumes a specific head count; silent mismatches on SSM head geometry corrupt gradients in ways that only show up after hours of training.
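A minimal sketch of that refusal-to-override invariant follows. The class and field names here are illustrative stand-ins, not the real AuthorMamba3Config, and the hidden_size/expand values are assumed for the arithmetic:

```python
# Illustrative sketch of the head-geometry invariant described above.
# Class and field names are assumptions; the real AuthorMamba3Config differs.
from dataclasses import dataclass

@dataclass(frozen=True)
class Mamba3Shape:
    hidden_size: int = 1024   # assumed value for illustration
    expand: int = 1           # assumed
    head_dim: int = 64        # P
    num_heads: int = 16       # H

    def __post_init__(self):
        derived = self.hidden_size * self.expand // self.head_dim
        if derived != self.num_heads:
            # Refuse a silently-mismatched SSM head count up front,
            # instead of corrupting gradients hours into training.
            raise ValueError(
                f"num_heads={self.num_heads}, but "
                f"hidden_size * expand / head_dim = {derived}"
            )

shape = Mamba3Shape()   # 1024 * 1 / 64 == 16 heads: accepted
```

The point of failing at construction time is the one the post makes: a wrong head count is not an error the kernel can detect, so the config layer is the last place to catch it cheaply.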
MIMO Is Where Mamba 3 Earns Its FLOPs
Mamba 2 already reframed the selective scan as a structured state-space duality and made it a matrix operation. Mamba 3 MIMO adds a rank-R outer product to the state update. Mechanically, the "PsiV" tensor that dominates the kernel is a per-chunk pointwise product:
psi_v[cs, r, p] = v[b, chunk_start+cs, h, p] * psi[h, r, p]
where psi is the learned MIMO_V parameter of shape (H, R, P). At R=4, each head carries four up-projected channels of V through the scan at once. Arithmetic intensity goes up without widening the head or adding heads.
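The product is simple enough to pin down as an executable oracle. A NumPy sketch with the head index made explicit (shapes follow the (H, R, P) convention above; a single chunk of length CS stands in for the chunk loop, and the values are random rather than real activations):

```python
import numpy as np

# Reference semantics of the PsiV product for one batch and one chunk.
# psi is the learned MIMO_V parameter of shape (H, R, P); v holds the
# chunk's values with shape (CS, H, P). Broadcasting does the R fan-out.
H, R, P, CS = 16, 4, 64, 16
rng = np.random.default_rng(0)
psi = rng.standard_normal((H, R, P)).astype(np.float32)
v = rng.standard_normal((CS, H, P)).astype(np.float32)

# psi_v[cs, h, r, p] = v[cs, h, p] * psi[h, r, p]
psi_v = v[:, :, None, :] * psi[None, :, :, :]        # (CS, H, R, P)

# Equivalent einsum spelling, useful as a correctness oracle for the kernel.
psi_v_einsum = np.einsum("chp,hrp->chrp", v, psi)
assert np.allclose(psi_v, psi_v_einsum)
```

At R=4 this is the whole MIMO trick in one line: the same v is fanned out across four learned channels before the scan, instead of widening P or adding heads.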
In practice, MIMO is how we get attention-like representational width out of an O(N) kernel. For C++ tokens - where one head needs to track both "what scope am I in" and "what type does this identifier bind to" - a single scan with four channels behaves closer to four narrow scans than to one wider scan, and the perf profile stays linear.
The measured price: the MIMO scan is register-heavy. On NAM56R at MBS=8, nsys captures on the H200 reference show:
| Kernel | Time | Regs | Smem | Occupancy |
|---|---|---|---|---|
| mamba_mimo_fwd | 1192 ms | 239 | 196 KiB | 6.2 % |
| mamba_mimo_bwd_fwd | 1034 ms | 255 | 196 KiB | 6.2 % |
| mamba_mimo_bwd_bwd | 2110 ms | 255 | 228 KiB | 12.5 % |
Three things jump out. First, the double-backward kernel is the tall pole; it runs at 12.5 percent occupancy at 255 regs per thread, which is the H200 compiler ceiling (65536 / (2 * 128) = 256). Second, forward and first-backward both sit at 6.2 percent occupancy, meaning the scan kernels are not memory-bound; they are register-bound on a compute-bound workload (arithmetic intensity on bwd_bwd is ~479 against H200's ~206 ridge). Third, the combined backward kernels take more than 2x the forward's time, which changes the math on every optimization: wins on the forward kernel alone rarely move total step time by more than one percent.
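The register arithmetic behind the 255-reg ceiling and the 12.5 percent figure is mechanical enough to write out. This is back-of-envelope math from the numbers in the text, not profiler output; the two-block co-residency, the 2048-thread Hopper maximum, and the H200 peak specs are assumptions of the sketch:

```python
# Back-of-envelope register math from the text, not a profiler capture.
REGS_PER_SM = 65536              # H200 register file per SM
threads_resident = 2 * 128       # two 128-thread blocks co-resident (assumed)

ceiling = REGS_PER_SM // threads_resident
print(ceiling)                   # 256: why the compiler caps the kernels at 255 regs

# At 255 regs/thread, 256 threads fill the register file. Against a
# Hopper-class maximum of 2048 resident threads, that is 12.5 % occupancy.
print(256 / 2048)                # 0.125

# The ~206 FLOP/byte ridge quoted for H200 is peak dense BF16 compute over
# HBM bandwidth; the ~989 TFLOP/s and 4.8 TB/s figures are assumed specs.
print(round(989e12 / 4.8e12))    # 206
```

With bwd_bwd's arithmetic intensity at ~479, more than twice the ridge, extra occupancy buys little; the kernel is already on the compute side of the roofline, which is why the optimization paths below target register pressure and recompute rather than bandwidth.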
Chunk Size Behavior
We ran the MIMO forward kernel across eleven parameter shapes to sanity-check correctness and pick a chunk size. The target tolerance was rel_err < 0.1; in practice every shape came in below 0.01:
| shape (N, P, R, chunk, BB) | stable_max_rel | max_abs |
|---|---|---|
| 16, 64, 4, 8, 128 | 0.006 | 0.28 |
| 32, 64, 4, 16, 256 | 0.007 | 0.90 |
| 64, 64, 4, 16, 256 | 0.008 | 0.54 |
| 128, 64, 4, 16, 256 | 0.008 | 1.04 |
| 256, 64, 4, 8, 256 | 0.009 | 1.34 |
| 64, 128, 4, 16, 256 | 0.005 | 0.58 |
| 128, 32, 4, 16, 256 | 0.007 | 0.96 |
| 128, 128, 4, 8, 256 | 0.008 | 0.85 |
| 128, 64, 8, 8, 256 | 0.006 | 0.32 |
| 128, 64, 2, 32, 256 | 0.005 | 2.56 |
| 128, 64, 1, 64, 256 | 0.009 | 6.41 |
Two observations drove the chunk-size choice. Larger chunks (chunk=64, R=1) pushed max_abs up roughly 10x without materially helping throughput, because the register window grew with R * P. Smaller chunks (chunk=8) were fine on correctness but spent more time on launch overhead and inter-chunk state plumbing. The sweet spot for NAM56R is chunk=16, which lets us keep R=4 without asking the compiler for more smem than the H200 SM has (228 KiB dynamic, right at the cap).
On the fourteen-gradient backward test at the smallest shape, stable_max_rel landed between 0.004 and 0.012 with bad_frac < 0.05 everywhere. That is the tolerance we carry forward as our correctness gate when we patch the kernel.
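The metric names here are ours, not library calls, so a sketch of plausible definitions helps make the gate concrete. The denominator floor and the exact masking are assumptions; the real stable_max_rel and bad_frac in the test harness may differ in detail:

```python
import numpy as np

def stable_max_rel(got, ref, floor=1e-3):
    # Max relative error with a denominator floor, so near-zero reference
    # entries do not blow the ratio up. The floor value is an assumption.
    got = np.asarray(got, dtype=np.float64)
    ref = np.asarray(ref, dtype=np.float64)
    denom = np.maximum(np.abs(ref), floor)
    return float(np.max(np.abs(got - ref) / denom))

def bad_frac(got, ref, tol=0.1, floor=1e-3):
    # Fraction of elements whose floored relative error exceeds tol.
    got = np.asarray(got, dtype=np.float64)
    ref = np.asarray(ref, dtype=np.float64)
    rel = np.abs(got - ref) / np.maximum(np.abs(ref), floor)
    return float(np.mean(rel > tol))

ref = np.linspace(-2.0, 2.0, 1024)
got = ref * 1.004                 # uniform 0.4 % relative error
assert stable_max_rel(got, ref) < 0.01
assert bad_frac(got, ref) == 0.0
```

The floored denominator is the reason the metric is called "stable": a raw max-rel over BF16 gradients is dominated by whichever element happens to land nearest zero.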
The PsiV Recomputation Tax
PsiV shows up five times inside the kernel loop body across the three MIMO kernels:
| Kernel | PsiV touches |
|---|---|
| mamba_mimo_fwd | intra-chunk qk * PsiV MMA, diag qk * PsiV MMA, interchunk state accumulation |
| mamba_mimo_bwd_fwd | recomputes psi_v = v * psi |
| mamba_mimo_bwd_bwd | recomputes psi_v_3d_bf = v_bc_r * psi_bc_cs, plus direct psi_bf in dv_pre/dPsi_pre |
For every chunk x head x batch in a training step, V and psi are loaded from gmem, broadcast, multiplied, and thrown away - three times. The product is a pointwise op; it is recomputed only because it was never saved.
The dependency analysis is the whole story. psi is a module parameter, V is a per-step activation. PsiV cannot be cached across training steps (the activation changes every forward), it cannot be cached across CUDA graph replays (the buffer would hold the previous replay's activation), but it is a perfectly well-defined intra-step tensor, and the same V flows through fwd -> bwd_fwd -> bwd_bwd inside one iteration.
So the cache is an activation checkpoint, not a hash table. Save PsiV to gmem during fwd, pass it into the backward kernels as an extra input, skip the recompute. Shape (B, S, H, R, P), BF16, chunk-contiguous layout. For NAM56R at MBS=8 that is about 5.6 GiB of extra per-rank memory - fine inside the 132 GiB H200 peak we run at. The modeled gain is 1.5 - 2.3 percent of total TFLOP/s.
There is a real failure mode this design is ready to accept: if the TileLang compiler is already CSE-ing psi_v = v * psi across its scheduling stages (hoisting the load and keeping the product in a register across back-to-back ct.mma calls), the runtime cost is already near-zero and we get nothing. That is why the first step is a Python-level materialization to measure the ceiling. If the Python hack does not move nsys numbers, the whole pursuit is archived. The env gate reads NotImplementedError until a perf number justifies flipping it; we would rather fail loudly than silently pretend the cache is on.
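The Phase A Python-level materialization is conceptually just this checkpoint lifecycle: compute PsiV once in the forward, hand the saved tensor to both backward passes, drop it at the step boundary. A NumPy sketch of the bookkeeping (the real prototype sits on the PyTorch side and threads the tensor into the TileLang kernels; every name here is illustrative):

```python
import numpy as np

class PsiVCheckpoint:
    # Intra-step activation checkpoint: PsiV is valid only between one
    # forward and its matching backward passes, never across steps.
    def __init__(self):
        self._psi_v = None

    def forward(self, v, psi):
        # v: (S, H, P) activations for this step; psi: (H, R, P) parameter.
        self._psi_v = v[:, :, None, :] * psi[None, :, :, :]   # (S, H, R, P)
        return self._psi_v

    def backward(self):
        # bwd_fwd / bwd_bwd read the saved product instead of recomputing.
        assert self._psi_v is not None, "backward before forward"
        return self._psi_v

    def step_done(self):
        # Drop the buffer at the optimizer step so a CUDA-graph-style
        # replay can never see a stale activation from the previous step.
        self._psi_v = None

S, H, R, P = 32, 16, 4, 64
rng = np.random.default_rng(1)
v, psi = rng.standard_normal((S, H, P)), rng.standard_normal((H, R, P))
ckpt = PsiVCheckpoint()
cached = ckpt.forward(v, psi)
recomputed = v[:, :, None, :] * psi[None, :, :, :]
assert np.array_equal(ckpt.backward(), recomputed)   # identical, zero recompute
ckpt.step_done()
```

The step_done() reset is the load-bearing line: it encodes the dependency analysis above, where PsiV is a well-defined intra-step tensor but poison across steps or graph replays.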
What Beat Pure Attention
Three places, measurable.
First, context length. At 32k - 64k tokens of C++, the quadratic attention cost becomes the dominant line item in an otherwise well-tuned stack. A Mamba 3 layer at the same shape runs O(N) per token with a constant that is not small - 1192 ms per step for MIMO forward at NAM56R is nothing to brag about - but it does not grow with sequence length. On our v4 context-graph snippets (up to 64k tokens of Callers -> Target -> Callees), the hybrid spends most of its FLOPs on the scan and reserves attention for the handful of layers that actually need content-addressable lookup.
Second, per-head information density. MIMO at R=4 means each head carries four channels through the same scan. Equivalent attention-based width would require either more heads (more KV cache) or a wider head dimension (more per-op compute). On H200, the MIMO path is strictly cheaper at the same representational capacity.
Third, training-side memory stability on long sequences. With a minority of attention blocks, KV-cache growth at long context is small. For inference with a 64k prompt, the peak live set stays well under what an equivalent all-attention stack would demand, which lets us keep a larger micro-batch at eval time.
Where It Lost
Also three places, and we do not pretend otherwise.
First, the fwd-kernel-only P1 optimization was a wash. We flipped TL_DISABLE_TMA_LOWER and TL_DISABLE_WARP_SPECIALIZED from True to False on the MIMO forward kernel and added TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE, expecting 20 - 30 percent forward speedup. We measured over 19 samples at MBS=8 on an 8x H200 host:
| Metric | Baseline | Selective P1 | Delta |
|---|---|---|---|
| Throughput (TFLOP/s) | 183.016 | 183.005 | -0.006 % |
| Iter 1 lm loss | 11.8775 | 11.8775 | identical |
| Iter 25 lm loss | 5.3296 | 5.1818 | -0.15 |
| Val test iter 25 | 5.2686 | 5.1094 | -0.16 |
| Peak reserved (GiB) | 131.924 | 132.686 | +0.76 |
Throughput delta is inside measurement noise. The reason is straightforward: forward is a small fraction of the iteration (1192 ms of 5540 ms total), so a 25 percent speedup on forward-only moves the whole step by about one percent, which is exactly what the noise envelope swallows. The loss delta is BF16 FMA-ordering noise from a different kernel schedule; iter-1 loss is bit-identical, which is the sanity check we actually trust.
Second, the backward kernels hit a real TileLang bug when TMA was enabled. mamba_mimo_bwd_fwd and mamba_mimo_bwd_bwd used three rank-3 shared-memory descriptors, which TileLang's TMA lowering cannot handle (InputDim() == 2 assertion, "Cannot detect TMA layout"). We fixed it by flattening 3D smem to 2D via zero-copy reshapes (qk_dot_shared[c, r1, r2] -> [c, r1 * R + r2], Q[B, S, R, G, N] -> [B, S*R, G, N]). Correctness survived - 14 gradient tensors at rel_err 0.0038 - 0.0116, bit-for-bit with TMA-off within BF16 rounding - but the measurement on H200 for the combined P1 + layout-fix stack is still pending a slot. Until we have that number, the env gate stays OFF.
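The 3D-to-2D flattening works because it is a pure layout change over a contiguous buffer, which is also why correctness survives bit-for-bit. A NumPy sketch of the reshape identity (kernel-side the same trick applies to the smem descriptor, not to NumPy arrays; shapes are small stand-ins):

```python
import numpy as np

# Zero-copy collapse of the trailing rank pair, mirroring
# qk_dot_shared[c, r1, r2] -> [c, r1 * R + r2] from the TMA layout fix.
C, R = 16, 4
rng = np.random.default_rng(2)
qk_dot = rng.standard_normal((C, R, R)).astype(np.float32)

flat = qk_dot.reshape(C, R * R)          # a view: no data movement at all
assert flat.base is qk_dot               # same underlying buffer
for c in range(C):
    for r1 in range(R):
        for r2 in range(R):
            assert flat[c, r1 * R + r2] == qk_dot[c, r1, r2]

# The rank-5 Q tensor flattens the same way: (B, S, R, G, N) -> (B, S*R, G, N).
B, S, G, N = 2, 8, 1, 64
Q = rng.standard_normal((B, S, R, G, N))
Q2 = Q.reshape(B, S * R, G, N)
assert Q2.base is Q
```

Because the element order is unchanged, every address the kernel computes lands on the same byte it did before; only the descriptor rank that TileLang's TMA lowering sees drops from 3 to 2.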
Third, pure attention has a floor at small context that MIMO does not beat. On our 4k-context ablation sweep on TPU v6e-x4, the dense Transformer baseline (nam52_h200_dense_no_mtp_v1) landed at loss 5.43 after 100 steps at 508 tok/sec, while the AdamW hybrid (nam52_hybrid_md_h200_dense_v1, AEMEAEDE pattern) stopped at 7.06 at 512 tok/sec over the same budget. That gap closes at longer context and with the MIMO path enabled, but at 4k tokens the Mamba blocks are spending their O(N) cheapness on sequence lengths that do not exercise it. The hybrid becomes clearly dominant only once we run the same model at 16k or 64k context, which is where our v4 data actually lives.
What Comes After
The concrete list of perf work still on the table, sized by realistic gain:
- Full P1: TMA + warp specialization on fwd and both backward kernels, gated by the 3D-to-2D TMA layout fix. Modeled gain is 5 - 10 percent; measurement pending an H200 slot.
- PsiV cache (P2): intra-step activation checkpoint removes two of three recompute passes, modeled at 1.5 - 2.3 percent. Phase A Python prototype first, abandon if it does not move nsys.
- MBS=10 at fp8 param-gather: orthogonal win, paired with the Liger main-head backward fix. +1 - 2 percent if the micro-batch headroom actually opens.
- We do not ship the P3 register-split of bwd_bwd. The analysis that killed it is in the companion kernel-journey post; the short version is that the design's claimed 30 - 50 percent kernel speedup did not survive a careful line-by-line read of the reverse-scan live set.
Overall the POC answered the gating question with a qualified "yes": Mamba 3 MIMO is the cheaper kernel at the context lengths we train at, it represents per-head information at a better density than attention does, and the costs are concentrated in bwd_bwd where we already have optimization paths ready. The wins are real; we just have to keep reporting them honestly, including the ones that came in at noise.
References
- mamba3_mimo_p1_notes.md
- mamba3_mimo_p2_psiv_cache_design.md
- mamba3_mimo_p3_register_split_design.md
- mamba_fork_canonical_2026_04_14.md
- mamba_integration_log.md
- mamba3_adoption_report_2026-03-18.md
- v4_architecture.md
- CHANGELOG.md