Tags: Mamba 3, CUDA, TileLang, Pallas, CuTe DSL, Kernels, H200, TPU

The Mamba 3 Kernel Journey: CUDA, Pallas, TileLang, and an Honest Look at CuTe DSL

How the Mamba 3 kernel stack actually shipped in our nanochat POC: TileLang on H200, Pallas on TPU v6e, a CuTe DSL port that never made it, and the verdicts that came out of each attempt.

10 min read · David Gornshtein

Shipping a hybrid Mamba 3 plus Transformer backbone for a C++ codegen model forces the same conversation three times, once per backend: CUDA on H200, Pallas on TPU v6e, and the DSL layer on top of each. This post is that conversation written down - what we tried, what carried to production, what we abandoned.

Short version: TileLang is how the MIMO kernels ship on H200 today, Pallas is how the SSD scan runs on TPU v6e, a CuTe DSL port of the MIMO kernel exists on disk and has not proven out, and our one serious TileLang versus CuTe comparison killed the CuTe path on ROI.

What We Actually Ship on H200

The production kernel stack for the Mamba 3 half of the hybrid is upstream mamba_ssm on commit 31f3d7baba, with local working-tree patches on three TileLang files (mamba_ssm/modules/mamba3.py, mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_bwd.py, mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_bwd_varlen.py) plus one correctness patch on mamba_ssm/ops/triton/mamba3/mamba3_siso_combined.py that caches ctx.saved_tensors for gradient checkpointing compatibility.

That fork is small on purpose. Patches go through an idempotent in-place applier (apply_mamba3_mimo_p1_patches.py) that crashes loudly if the upstream decorator block moves, and we md5-reconcile the working tree across both H200 hosts before any long run. The mamba3_siso_combined.py tweak exists because accessing ctx.saved_tensors twice on a recomputed node raises under gradient checkpointing.
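The actual patch isn't reproduced here, but the pattern it implements is small enough to sketch. A minimal, hedged illustration of the memoize-on-first-access idea (the `Doubler` op and `_saved_once` helper are hypothetical stand-ins, not the real SISO code):

```python
import torch

def _saved_once(ctx):
    # Read ctx.saved_tensors exactly once and memoize the tuple; under
    # gradient checkpointing, a second access on a recomputed node raises.
    if not hasattr(ctx, "_saved_cache"):
        ctx._saved_cache = ctx.saved_tensors
    return ctx._saved_cache

class Doubler(torch.autograd.Function):
    # Toy stand-in for the real op, just to exercise the caching pattern.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = _saved_once(ctx)
        (x2,) = _saved_once(ctx)   # second access hits the cache, not ctx
        assert x is x2
        return grad_out * 2.0
```

The real patch applies the same one-read-then-cache discipline inside `mamba3_siso_combined.py`'s backward.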

The CUDA path for Mamba 3 lives inside TileLang, not hand-written CUDA C++. The kernels that matter are mamba_mimo_fwd, mamba_mimo_bwd_fwd, and mamba_mimo_bwd_bwd from upstream's tilelang/mamba3/, plus varlen variants. Everything else - in_proj fusion, RMSNorm, residual - runs through Transformer Engine and regular PyTorch.

TileLang P1: What It Buys, What It Costs

P1 is our internal label for a targeted perf pass on the MIMO kernels: flip the disabled-by-default TMA and warp-specialization flags to enabled. Upstream ships them off:

```python
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
```

Flipping them lets the TileLang compiler emit Hopper-class TMA descriptors for bulk gmem-to-smem async copies, and warp-group pipelining instead of plain mma.sync. We also add TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True on every kernel that did not already have it, so the downstream smem pressure does not push us past the dynamic cap on smaller GPUs.
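Concretely, the shipped change amounts to inverting those keys behind the env gate. A hedged sketch of the shape of the patch (the gate name is from this post; `import tilelang` and the surrounding decorator context are assumed, not shown):

```python
import os

# Sketch only -- this slots into the upstream pass_configs decorator block;
# assumes `import tilelang` in the surrounding file.
_p1 = os.environ.get("CPPMEGA_MAMBA3_P1", "0") == "1"

pass_configs = {
    # Flipped from upstream's hard-coded True when the P1 gate is on:
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: not _p1,
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: not _p1,
    # Added everywhere to keep merged smem under the dynamic cap:
    tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,
}
```

With the gate off, the dict is byte-identical to upstream, which is what lets the applier stay idempotent.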

This is eight TL_DISABLE_* flips and five new aggressive-merge insertions across four files. The patch is idempotent and env-gated behind CPPMEGA_MAMBA3_P1=1, default OFF. On GB10 (sm_121a, 99 KiB smem per SM) all three kernel groups compile cleanly with TMA and warp-spec enabled, and all eleven forward parametrized shapes pass the standard correctness tolerance. The combined backward passes at rel_err 0.004 - 0.012, well inside our 0.05 gate.

Then we tried to run it on H200 with backward enabled. Forward compiled fine. mamba_mimo_bwd_fwd and mamba_mimo_bwd_bwd blew up inside TileLang's TMA lowering pass:

```
tvm.error.InternalError: Check failed: (shared_layout->InputDim() == 2) is false: Cannot detect TMA layout.
```

Root cause: the backward kernels use three rank-3 shared-memory descriptors (qk_dot_shared[c, r1, r2] and (B, S, R, G, N) reads of Q); TileLang's TMA lowering only handles 2D layouts. Fix is mechanical: flatten to 2D via zero-copy reshape (qk_dot_shared[c, r1, r2] -> [c, r1 * R + r2], Q[B, S, R, G, N] -> [B, S*R, G, N]). No arithmetic changes, just a view.
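The flatten really is just a view, which is easy to convince yourself of outside the kernel. A quick check with torch on illustrative shapes (the real descriptors live in smem inside the kernel; these dimensions are made up for the demo):

```python
import torch

# Illustrative shapes only, not the kernel's actual extents.
c, R1, R2 = 3, 4, 5
qk = torch.arange(c * R1 * R2, dtype=torch.float32).reshape(c, R1, R2)

flat = qk.view(c, R1 * R2)                 # rank-3 -> rank-2, zero-copy
assert flat.data_ptr() == qk.data_ptr()    # same storage, no data movement

# Index mapping [c, r1, r2] -> [c, r1 * R2 + r2] is exact.
for r1 in range(R1):
    for r2 in range(R2):
        assert torch.equal(flat[:, r1 * R2 + r2], qk[:, r1, r2])
```

Because the view shares storage and the index map is affine, the generated loads are unchanged; only the layout rank seen by the TMA lowering pass drops to 2.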

We shipped it on branch tma-layout-fix-3d-to-2d, with an applier and a unified diff. Correctness on GB10 survived: 14 gradient tensors at rel_err 0.0038 - 0.0116, bit-for-bit with the TMA-off baseline inside BF16 rounding.

On H200 we then hedged and tried shipping only the forward flip, with backward kernels left unpatched. Over a 25-iter bench at MBS=8 NAM56R, the measured delta was -0.006 percent TFLOP/s. Inside noise. The honest read is that mamba_mimo_fwd at 1192 ms is a small fraction of the 5540 ms iteration, and a 20 - 30 percent speedup on forward moves the step by roughly one percent, which our variance swallows. The full P1 win needs both the TMA layout fix and the backward flips together, and that measurement is still pending an H200 slot.

PsiV and the Register Ceiling

While P1 was sitting on the H200 queue, we wrote up two follow-on designs. One shipped as scaffolding; one is shelved.

The PsiV cache (P2 in our internal sequence) is the straightforward one. PsiV appears five times in the MIMO kernel loop body across fwd, bwd_fwd, and bwd_bwd, recomputed from scratch each time even though its two inputs (psi as a module parameter, V as a per-step activation) are stable within a single forward-backward iteration. The plan is to save PsiV to gmem inside fwd, pass it to the backward kernels via ctx.save_for_backward, and skip two of the three recomputes. Shape (B, S, H, R, P), BF16, about 5.6 GiB extra per rank at NAM56R MBS=8, fine inside our 132 GiB peak. Modeled envelope is 1.5 - 2.3 percent total TFLOP/s.
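The save-in-forward, reuse-in-backward plumbing is standard autograd. A toy, hedged sketch of the pattern (the `PsiVCached` class and its reduction output are hypothetical stand-ins; the real kernel returns the MIMO outputs, not a sum):

```python
import torch

class PsiVCached(torch.autograd.Function):
    # Toy stand-in: materialize psi_v = v * psi once in fwd and hand the
    # product to backward via save_for_backward instead of recomputing it.
    @staticmethod
    def forward(ctx, v, psi):
        psi_v = v * psi                      # computed exactly once
        ctx.save_for_backward(v, psi, psi_v)
        return psi_v.sum(dim=-1)             # placeholder for the kernel output

    @staticmethod
    def backward(ctx, grad_out):
        v, psi, psi_v = ctx.saved_tensors    # cached product, no recompute
        g = grad_out.unsqueeze(-1).expand_as(psi_v)
        return g * psi, g * v
```

In the real kernel the saved tensor is the 5.6 GiB gmem buffer described above, and the two backward kernels read it instead of re-deriving the product.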

The non-obvious part is the failure mode we have to check first. TileLang's scheduler may already be CSE-ing psi_v = v * psi across the kernel's internal stages - hoisting the load once, keeping the product in a register between back-to-back ct.mma calls. If it is, the cache buys nothing. That is why the plan starts with a Phase A Python prototype: materialize psi_v at the Python level before the kernel call, hand it in as if it were V, and measure. If the Python-level materialization does not move nsys, the whole pursuit is archived and a superseded note goes into the design doc. The env gate raises NotImplementedError by design until a perf number justifies turning it on.

The register-pressure split (P3) is the one we decided not to ship. It proposed cutting mamba_mimo_bwd_bwd in two - a state-reverse pass producing a dstates_per_chunk gmem tensor, and a chunk-local pass consuming it plus re-deriving PsiV and qk_dot - to drop both passes under 150 registers and double the 12.5 percent occupancy. Pitch: 500 - 900 ms per step, about one percent total TFLOP/s, on a compute-bound kernel at AI 479 with the H200 ridge at 206.

On a careful read of mamba3_mimo_bwd.py the split did not survive:

  • The loop-carried dstates_frag is updated at the end of each reverse chunk via T.gemm(q_shared, dPhiO_scaled_frag, dstates_frag, clear_accum=False) and carried into the next iteration. Pass 1 still needs q_shared, dPhiO_shared, and dstates_frag live - the exact fragments the split was supposed to drop.
  • Honestly separating the passes costs an extra [B, H, nchunks, chunk_size * R, P] buffer (3x bigger than dstates_per_chunk, ~200 MiB per sample) plus re-derivation of rotated-and-trap-scaled Q/K inside Pass 1. Realistic post-split register count is 200 - 220, not the 140 the design claimed.
  • GB10 is not a viable correctness platform for this work. At the shapes we would need for a baseline-versus-split diff, the upstream forward kernel itself fails to compile on GB10 at TMA desc init error 716 (small shape) and Auto-tuning failed: No configuration successfully compiled (NAM56R small). Any correctness validation has to run on H200, and H200 time is gated.
  • H200 SSH access was broken on the day we ran the audit. No H200 means no perf measurement, which meant no ROI validation for a week of kernel work.

Decision: do not ship P3. Pursue the PsiV cache instead, which removes three fragment tiles from bwd_bwd's inner live-set for 2 - 3 days of implementation rather than 8 - 12 days for the split, with the same expected ~1 - 2 percent envelope. The clean engineering win here was not implementing the optimization, which sounds like a joke but is genuinely how we saved a week of kernel time we did not have.

Pallas on TPU v6e

The TPU side of the same stack runs under XLA, on v6e-4 for 4k-context rapid ablations and v6e-8 for the 16k and 64k context phases. There is no TileLang path on TPU; the kernels are a mix of torch_xla.experimental.scan for the SSD recurrence and a chunked matmul-based reference for within-chunk mixing.

The chunked reference is where we caught most of our correctness bugs before the CUDA kernel did. The SSD dual decomposition must include both components: cross-chunk state accumulation (sequential over nchunks = seq / chunk_size) and within-chunk attention-like mixing. Drop the within-chunk term and tokens can only see information compressed into a chunk-boundary state; local context mixing is gone. Every reviewer proposed dropping it at some point, and the chunked reference test caught each attempt.
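The decomposition is easy to state precisely on a scalar-decay diagonal SSM, which is far simpler than the real MIMO kernel but exhibits the same two terms. A minimal sketch, assuming the recurrence h_t = a_t h_{t-1} + B_t x_t, y_t = C_t h_t (all names here are illustrative, not the kernel's):

```python
import torch

def ssd_sequential(x, a, B, C):
    # x: (S, P) inputs, a: (S,) per-step decay, B/C: (S, N) in/out projections.
    S, P = x.shape
    h = x.new_zeros(B.shape[1], P)
    y = x.new_zeros(S, P)
    for t in range(S):
        h = a[t] * h + B[t, :, None] * x[t, None, :]   # (N, P) state update
        y[t] = C[t] @ h
    return y

def ssd_chunked(x, a, B, C, chunk):
    # Dual decomposition: within-chunk attention-like mixing PLUS a
    # sequential cross-chunk state carry. Dropping either term is wrong.
    S, P = x.shape
    h = x.new_zeros(B.shape[1], P)
    y = x.new_zeros(S, P)
    for s0 in range(0, S, chunk):
        sl = slice(s0, min(s0 + chunk, S))
        xc, ac, Bc, Cc = x[sl], a[sl], B[sl], C[sl]
        A = torch.cumprod(ac, dim=0)                   # cumulative decay in chunk
        # Cross-chunk term: every token reads the carried-in state.
        y[sl] = A[:, None] * (Cc @ h)
        # Within-chunk term: causal attention-like mixing, decay-weighted.
        M = torch.tril((A[:, None] / A[None, :]) * (Cc @ Bc.T))
        y[sl] += M @ xc
        # Carry the chunk-final state forward.
        h = A[-1] * h + ((A[-1] / A)[:, None] * Bc).T @ xc
    return y
```

The chunked reference test is exactly an `allclose` between these two paths at several chunk sizes; deleting the `M @ xc` line makes it fail immediately, which is how each "drop the within-chunk term" proposal died.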

We use F.rms_norm (parameterless) for B/C QK-norm to match nanochat's attention QK-norm rather than nn.RMSNorm - mixing the two leaves orphan learnable weights that never receive gradients and silently break training. Complex RoPE uses per-dimension frequencies, not a single scalar angle per position, because a single frequency collapses the rotation to one dimension.

A Pallas kernel does live in the TPU tree, but for a different purpose. We maintain a content-dependent sparse attention prototype - importance scoring, query-tile union selection, and a Pallas sparse attention kernel with online softmax, parameters aligned to the v6e MXU (Bq=256, l'=256, H=128, Bk=1024). It is for the attention minority of the hybrid, not the SSM majority. Supports up to 128k sequence length with a top-n=8 selection. Prototype committed; hardware validation receipt still open.
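The Pallas source isn't reproduced here, but the online-softmax accumulation it relies on fits in a few lines. A hedged torch sketch with toy sizes instead of the MXU-aligned Bq/Bk (single query vector, no sparsity or tile selection):

```python
import torch

def online_softmax_attention(q, K, V, block=4):
    # Stream over key blocks keeping a running max (m), denominator (l),
    # and rescaled output accumulator (acc) -- one pass, never materializing
    # the full score row. This is the standard online-softmax recurrence.
    m = torch.tensor(float("-inf"))
    l = torch.tensor(0.0)
    acc = torch.zeros(V.shape[1])
    for s in range(0, K.shape[0], block):
        scores = K[s:s + block] @ q                # (b,) partial logits
        m_new = torch.maximum(m, scores.max())
        scale = torch.exp(m - m_new)               # rescale old running sums
        p = torch.exp(scores - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ V[s:s + block]
        m = m_new
    return acc / l
```

In the sparse kernel the inner loop runs only over the selected key tiles from the query-tile union, but the rescaling identity is the same, which is what makes block-skipping exact rather than approximate.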

Main Pallas trade-off: compile time. torch_xla.experimental.scan rewrites the SSM loop to avoid @while_loop overhead. The fused HLO is cheaper per step but much more expensive to compile the first time. We eat one long compile on rank-0 at process start; without it, the Python-level loop over chunks dominates at 4k context on v6e-8.

TileLang versus CuTe DSL

The one honest A-B we ran on the kernel stack itself was between TileLang and a CuTe DSL port of the MIMO kernel. Our CuTe DSL artifact lives at .tmp/mamba3_mimo_cutile/ - a "cuTile" port mirroring the upstream TileLang MIMO kernels. The intent was to have a non-TileLang code path for the MIMO kernels, with a CUTLASS-style template structure instead of a DSL compiler, so we could make targeted changes (for example adding PsiV as an extra output of fwd) without fighting a scheduler.

Verdict after a few weeks of on-and-off work: not worth shipping. Three reasons, none about correctness.

First, compile-time cost. The MIMO kernel is parameterized across (N, P, R, chunk_size, BB) and instantiating it through CUTLASS templates took minutes per cold build. TileLang's JIT takes seconds per shape because it caches at the TIR level. In a nightly sweep with ~20 configurations the tax is real.

Second, the targeted changes we actually wanted (PsiV cache, 3D-to-2D smem flatten, TMA flag flips) are all expressible inside TileLang's existing decorator block. The P1 patches are literally two lines per kernel. CuTe's advantage of "do whatever you want" mattered less than TileLang's advantage of "land a fix in an hour" - when the TMA layout bug hit, the TileLang fix landed the same day. A CuTe port would have required reflowing the smem allocation by hand.

Third, the CuTe artifact did not match TileLang on correctness once we pushed R > 4 or P=128. TileLang passes all eleven parametrized shapes; our CuTe port passed five. Chasing the last six was a kernel-author week the ROI math did not justify, given the low-hanging wins all live on the TileLang path anyway.

The port stays on disk as a reference implementation: the psi_v = v * psi identity for the P2 cache is trivially checkable side by side against it. Killing a candidate path while keeping its reference value is a reasonable end-state.

What We Learned About DSL Choice

DSL choice is an operational cost, not a performance cost. Both TileLang and CuTe land the same PTX for the shapes we run. The difference is iteration speed: flag flips, shape tuning, and small correctness fixes land cheaper in TileLang because the compiler's surface area is closer to what we want to change. For a training loop patched in place via an env-gated applier, TileLang pays for itself on maintenance alone.

Kernel perf work is measurement-bound, not design-bound. P1 selective-forward was a wash (-0.006 percent) because forward is a small fraction of the iteration. P3 looked like 30 - 50 percent on paper; a line-by-line read collapsed that estimate to 1 - 2 percent. Real kernel wins come from an nsys capture that names the actual bottleneck, not a design that names a plausible one.

Current state on H200: upstream TileLang MIMO kernels, three working-tree patches, env-gated P1 pipeline, P2 PsiV cache scaffolded and waiting for a Phase A nsys number, P3 rejected. On TPU v6e: torch_xla.experimental.scan for the SSM, Pallas sparse attention prototype for the attention minority, chunked matmul reference as the correctness ground truth. Nothing glamorous, which is exactly the sign the kernel stack is converging.

References

  • mamba3_mimo_p1_notes.md
  • mamba3_mimo_p2_psiv_cache_design.md
  • mamba3_mimo_p3_register_split_design.md
  • mamba_fork_canonical_2026_04_14.md
  • mamba_integration_log.md
  • mamba_review_followup_plan.md
  • v4_architecture.md
  • TENSOR_PARALLELISM.md
  • CURRENT_STATE.md
David Gornshtein • Datasunrise OÜ