Training the MegaCpp SLM Ensemble on GB10: A Grace Blackwell War Story

The GB10 / DGX Spark looks, from the marketing slide, like a small Blackwell. It is not. It is a different ISA wearing the Blackwell logo, with a desktop silicon die, a 273 GB/s LPDDR5X memory bus, and a software stack that assumes you are running on a B200 until you prove otherwise. This post is the unvarnished account of what it took to train the MegaCpp SLM Ensemble — our NAM56R MIMO hybrid stack — on this box: the silicon traps, the NaN hunts that turned out to be our own patches, and the software-stack recipe that finally held.
If you came here for a clean "10x speedup with one weird trick", close the tab. If you came here because your own GB10 is also misbehaving, read on.
What GB10 actually is
The first thing to internalise is that sm_121a is not a small sm_100a.
The Blackwell umbrella covers two architecturally distinct chips: datacenter
(sm_100a, B200) and consumer (sm_120a RTX 5090, sm_121a GB10). NVIDIA's
own forum reps put it bluntly: GB10's tensor cores are "closer to the GeForce
Ampere-style MMA model." RT cores and DLSS silicon took the die budget that
would have gone to TMEM and tcgen05 on the datacenter parts.
Concretely, on sm_121a:
- 48 SMs (vs the 68 that PyTorch's `is_big_gpu` wants, vs 132 on H100)
- 128 GB unified LPDDR5X at 273 GB/s (H100 has 3,350 GB/s; B200 has 8,000)
- ~128 KiB physical SMEM/SM, 99 KiB dynamic budget for CUTLASS
- Peak ~100 TFLOPS BF16, ~400 TFLOPS NVFP4 dense
- Compute capability 12.1, requires PTX ISA 8.8 and the `a` suffix on the arch flag for block-scaled MMA
And what it does not have:
- No `tcgen05.*` family (no `tcgen05.mma`, `.ld`, `.st`, `.alloc`, `.cp`)
- No Tensor Memory (TMEM) — the silicon is simply absent from the die
- No Hopper `wgmma.mma_async` (deprecated across all Blackwell)
- No 2-SM TMA multicast (the instruction exists, but the cluster size is capped at 1, so it's effectively a no-op)
- No FlashAttention-4 — and there will never be one. FA4 wants TMEM and `tcgen05`; the verifier rejects the cubin at the driver level.
The folk wisdom that "FP4 doesn't work on consumer Blackwell" is wrong but
understandable. FP4 does work on sm_121a via warp-level OMMA. What
doesn't work is the tcgen05-coupled UTCOMMA path, which is what 90% of
CUTLASS's NVFP4 examples hard-code. CUTLASS example 79 is the OMMA-based
reference that actually runs on GB10.
The toolchain dance
Bringing up nanochat-style pretraining on GB10 took five separate fixes
before the first iteration produced a finite loss. None of them were exotic;
all of them were undocumented in the obvious places.
ptxas. Triton ships its own ptxas (12.8), which does not know what
sm_121a is and bails with `Value 'sm_121a' is not defined for option 'gpu-name'`. The fix is to point Triton at the system CUDA 13.0+ ptxas:

```python
import os
import shutil

if not os.environ.get("TRITON_PTXAS_PATH"):
    for ptxas in ["/usr/local/cuda/bin/ptxas", shutil.which("ptxas")]:
        if ptxas and os.path.exists(ptxas):
            os.environ["TRITON_PTXAS_PATH"] = ptxas
            break
```
is_big_gpu. PyTorch's inductor refuses max_autotune_gemm if your GPU
has fewer than 68 SMs. GB10 has 48. The patch is a small monkey-patch:

```python
import os
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE_GEMM"] = "1"

import torch
import torch._inductor.utils as inductor_utils
inductor_utils.is_big_gpu = lambda index=0: True
```
Some Triton configs then fail with shared-memory errors during autotune.
That is fine; autotune handles it. The 99 KiB SMEM budget on sm_121a is
lower than the SM100 default tile shapes assume, and every kernel that was
sized for B200 will overflow until you re-tile it.
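The overflow is predictable from arithmetic alone. A minimal sketch of the budget check for a double-buffered BF16 GEMM tile — the tile shapes below are illustrative, not the actual CUTLASS defaults:

```python
# Hedged sketch: why B200-sized tiles overflow GB10's 99 KiB dynamic SMEM
# budget. Staged BF16 A (bm x bk) and B (bk x bn) tiles; shapes illustrative.
def smem_bytes(bm: int, bn: int, bk: int, stages: int, dtype_bytes: int = 2) -> int:
    """Shared-memory footprint of the staged A and B tiles, in bytes."""
    return (bm * bk + bk * bn) * dtype_bytes * stages

GB10_BUDGET = 99 * 1024  # the sm_121a dynamic SMEM budget from the text

assert smem_bytes(128, 128, 64, stages=2) <= GB10_BUDGET  # 64 KiB: fits
assert smem_bytes(128, 256, 64, stages=4) > GB10_BUDGET   # B200-ish tiling: overflows
```

The same shape that pipelines comfortably in 228 KiB on a datacenter part is roughly double the GB10 budget, which is why every B200-tuned kernel needs a re-tile rather than a flag flip.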
MFU. The default MFU calculation in nanochat divides by H100's 989
TFLOPS BF16 peak, which makes GB10 report a depressing ~6%. With the correct
denominator (62 TFLOPS BF16, 500 TFLOPS NVFP4), MFU comes out around 11%,
which is roughly what the silicon can actually do given the bandwidth wall.
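The fix is one constant, but it is worth making the denominator device-keyed rather than hard-coded. A minimal sketch — the peak table uses the numbers from this text, while the `achieved` figure is an illustrative placeholder, not a measurement from our runs:

```python
# Hedged sketch: MFU with a device-appropriate peak instead of a
# hard-coded H100 constant. Peaks mirror the text; 'achieved' is illustrative.
BF16_PEAK_TFLOPS = {
    "H100": 989.0,  # the default nanochat denominator
    "GB10": 62.0,   # the corrected denominator used in this post
}

def mfu(achieved_tflops: float, device: str) -> float:
    """Model FLOPs utilisation: achieved throughput over this device's peak."""
    return achieved_tflops / BF16_PEAK_TFLOPS[device]

achieved = 6.8  # TFLOPS, illustrative placeholder
print(f"vs H100 peak: {mfu(achieved, 'H100'):.1%}")  # misleadingly tiny
print(f"vs GB10 peak: {mfu(achieved, 'GB10'):.1%}")  # prints "vs GB10 peak: 11.0%"
```

Same measured throughput, two very different stories — which is exactly how a healthy GB10 run ends up labelled "6% MFU".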
Liger graph break. LigerFusedLinearCrossEntropyFunction calls
`target_mask.sum().item()` internally, which forces a torch.compile graph
break and tanks Liger's throughput below the unfused baseline. One config
knob removes it:

```python
import torch
torch._dynamo.config.capture_scalar_outputs = True
```
With that flag set, the Liger/Triton FLCE backend goes from 12,095 tok/s
(slower than baseline) to 16,950 tok/s — a real 1.17x end-to-end gain on
depth-20 with B=32, T=2048, grad_accum=8, plus 2.6 GB of memory back
because the B×T×V logits tensor never materialises.
NVFP4 RHT crash. TransformerEngine's NVFP4 recipe defaults to a Random
Hadamard Transform pre-scale, which on sm_120/sm_121 consistently dies
inside hadamard_transform_cast_fusion.cu with `CUDA Error: invalid argument` whenever M > 32. The fix lives in the recipe construction:

```python
recipe = NVFP4BlockScaling(
    disable_rht=True,  # required for GB10
    fp4_format=Format.E2M1,
    override_linear_precision=(False, False, True),  # WGrad in BF16
)
```
With RHT disabled, NVFP4 GEMMs measure 1.23x–1.36x over BF16 at common shapes (4096³ goes from 87 to 117 TFLOPS). That is far short of the 8x the spec sheet implies, because GB10 is memory-bandwidth-bound, not compute-bound. A 67%/33% backward/forward split in our profiler tells the same story: the SMs spend most of their time waiting for LPDDR5X.
The kernel layer: TileLang wins, cuTile is a dead end on this box
The MegaCpp ensemble's hot path is the Mamba-3 MIMO backward-of-backward
(bwd_bwd) kernel. We tried three independent paths to beat the TileLang
baseline of 167 µs on GB10. All three lost.
The cuTile Python rewrite was the most thorough attempt. Five algorithmic
variants — fused monolithic, nested @ct.function per phase, 3-kernel split,
hoisted loop invariants, full ct.static_iter unroll — all regressed against
the 2-kernel A/B split baseline of 624 µs. The full unroll was 5.2x slower.
The 3-kernel split that won by 33% on B200 (TMEM, 228 KiB SMEM) regressed
by 9% on GB10. The launch-overhead vs live-set trade-off flips on you the
moment you change SMEM budget. The lesson, which we now treat as a hard
rule: never assume a cuTile algorithmic variant that wins on one GPU will
win on another. Re-sweep on the target hardware.
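The rule is mechanical enough to encode. A sketch of per-device variant selection — the 624 µs and 167 µs figures are from this post, the 3-kernel timings are derived from the quoted ±33%/−9% deltas, and the registry/function names are our illustration, not cuTile API:

```python
# Hedged sketch: pick the kernel variant from a sweep run on the *target*
# device; refuse to extrapolate from another GPU's results.
SWEEPS = {  # measured bwd_bwd latency in microseconds (approximate/derived)
    "B200": {"cutile_3k_split": 418.0, "cutile_2k_split": 624.0},
    "GB10": {"cutile_2k_split": 624.0, "cutile_3k_split": 680.0,
             "tilelang_fused": 167.0},
}

def best_variant(device: str) -> str:
    """Fastest measured variant on this device; error out if never swept."""
    if device not in SWEEPS:
        raise KeyError(f"no sweep for {device}: benchmark before shipping")
    timings = SWEEPS[device]
    return min(timings, key=timings.get)
```

The point of the `KeyError` branch is the whole lesson: an empty entry for a new GPU should block deployment, not fall back to the last GPU's winner.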
The CuTe DSL hot-path port was the most fun and the most humbling. We got
cute.nvgpu.warp.MmaF16BF16Op + TMA + persistent scheduler running on
sm_121a out of the box (via the blackwell_geforce/dense_gemm.py pattern
— pass "sm_120" as the SmemAllocator capacity key, and do not use
CUTE_DSL_ARCH=sm_120a overrides, which the cubin loader rejects). The
resulting hand-written batched GEMM at L=256 ran in 10.28 µs. Then we
benchmarked torch.bmm on the same shape: 10.33 µs. cuBLAS on GB10
already matches a hand-written CuTe DSL kernel at 64×64×64 BF16. The
TileLang advantage is not GEMM efficiency; it is that TileLang fuses 10
GEMMs plus ~150 elementwise ops plus rotary plus reductions into one CUDA
kernel with 16 CTAs each running 16 chunks in on-chip state. cuTile Python
structurally cannot do that — it has to split into at least two kernels with
gmem temps. The 4x gap is the kernel-structure tax, not the instruction
tax.
The Triton M²RNN autotune sweep was the most anticlimactic. Twenty-five
configs across three shapes; the autotuner had already converged to
num_warps=8. num_stages moved performance by less than 0.03% because the
M²RNN R-layer is a sequential recurrence that holds state in registers and
has nothing to pipeline.
So we shipped TileLang, kept the cuTile 2-kernel split as a Python-readable
reference, and moved bwd_bwd optimisation effort to H200, where WGMMA, TMA,
swizzled SMEM, and warp specialisation via setmaxnreg actually unlock
different wins. GB10 mamba3 bwd_bwd is Pareto-optimal at 167 µs. Every
further hour we spent on it was wasted; we say so plainly so you don't have
to repeat it.
One genuine deployment opportunity surfaced: cuTile mamba3_mimo_fwd is
17.7% faster than TileLang on B200 (0.054 ms vs 0.064 ms). A hybrid
cuTile fwd + TileLang bwd wrapper is a free throughput win on B200. On
GB10 the same hybrid changes nothing, because the forward isn't the
bottleneck.
NaN, NaN, NaN: a bisect that wasn't
The hardest two days of this project had nothing to do with kernels and
everything to do with grad norm: nan.
Symptom: bench3 H200 NAM56R training was producing finite gradients on
2026-04-13 ("golden 268 TFLOP/s"), and grad norm: nan on every iteration
by 2026-04-15. The obvious suspect was commit dd4da34, which had
rewritten the MTP and main-head Liger CE patches from reduction="none" to
reduction="mean" plus broadcast — explicitly to fix the silent grad
corruption from Liger issue #968. The hypothesis was clean: the "fix"
broke training.
Empirical reality: not the fix. We ran five mutations at HEAD —
MTP_DEPTHS=0, CPPMEGA_PREFER_NATIVE_HOPPER_CE=1, vanilla CE on logits,
revert mamba3 regional compile, drop selective recompute. All five produced
lm loss ~12 (finite), grad norm: nan on iter 1. The bug was upstream of
all five candidates.
So we set up a proper bisect on a clean worktree, with PYTHONPATH
precedence and a cppmega.__file__ import pre-check. Then we tested the
claimed-golden commit 0ce8a3a itself under the same env. NaN. Iter 1:
lm loss 11.88, grad norm nan. Same on 0038ad4 (Apr 13 19:00, the other
claimed "last known finite"). Same on HEAD. Three commits, including the
one that allegedly produced 269.4 TFLOP/s with finite gradients 24 hours
earlier, all NaN under the same env.
Conclusion forced by the data: the regression is not in cppmega source. The bench3 environment had drifted between the golden measurement and our bisect. Likely candidates, in order:
- `/mnt/data/cppmega-root/megatron-lm` had no `.git` directory. The Megatron-LM version was unverifiable; anyone could have rsynced over it.
- The local `state-spaces-mamba` fork carried uncommitted patches (an fp32 upcast of `dd_dt + self.dt_bias` before softplus, GQA branch changes in the bwd kernels). File mtimes showed 2026-04-13 23:10 — before the golden run — so the fork state itself wasn't the new variable, but nothing else was pinned either.
- TE 2.13 FP8 tensorwise on this exact stack was already on our internal "fragile" list from the earlier europe regression notes.
The mistake, in hindsight, was institutional: bench3 had no SHA pinning for Megatron-LM, no environment lockfile for the venv, and no snapshot of the mamba_ssm fork at the moment of the golden measurement. Any of those would have made the bisect tractable. None of them existed. We aborted the bisect, wrote the post-mortem, and moved the test surface to a single-GPU GB10 reproduction we could control.
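That snapshot is cheap to automate. A hedged sketch of what bench3 never recorded — the SHA of every tree (or the fact that it has none) plus a best-effort package lockfile at each golden measurement; `snapshot_env` and the file layout are our illustration, not an existing cppmega tool:

```python
# Hedged sketch: pin everything a "golden" measurement depends on.
# A tree without .git is recorded as UNVERIFIABLE -- the exact bench3 gap.
import json
import subprocess
import sys
from pathlib import Path

def git_sha(path: str) -> str:
    """HEAD SHA of the tree at path, or an explicit 'unverifiable' marker."""
    try:
        return subprocess.check_output(
            ["git", "-C", path, "rev-parse", "HEAD"],
            text=True, stderr=subprocess.DEVNULL).strip()
    except (subprocess.CalledProcessError, FileNotFoundError):
        return "UNVERIFIABLE (no .git)"

def snapshot_env(repos: dict, out_path: str) -> dict:
    """Write {repo_name: sha, pip_freeze: [...]} next to the measurement."""
    state = {name: git_sha(p) for name, p in repos.items()}
    try:  # best-effort venv lockfile
        state["pip_freeze"] = subprocess.check_output(
            [sys.executable, "-m", "pip", "freeze"], text=True).splitlines()
    except Exception:
        state["pip_freeze"] = []
    Path(out_path).write_text(json.dumps(state, indent=2))
    return state
```

Run against the bench3 layout, this would have immortalised "UNVERIFIABLE (no .git)" next to the 268 TFLOP/s number — which is exactly the warning the bisect needed.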
The other NaN: a mutation we made ourselves
While the bench3 NaN was being investigated, GB10 itself produced a
different and equally instructive failure: cudaErrorMisalignedAddress at
mamba_mimo_fwd_kernel, immediately on iter 0, during forward.
Root cause traced in an afternoon, once we stopped trusting the env gate.
Commit 4f115ea ("P1 — enable TMA + warp specialization in Mamba3 MIMO
kernels") had at some point been applied to the installed mamba_ssm
site-packages files, not to a copy. The apply_mamba3_mimo_p1_patches.py
helper had no restore path. The env gate CPPMEGA_MAMBA3_P1=0 correctly
skipped applying the patch on the next run — but the disk state was
already mutated. Every subsequent Python import picked up the patched
kernels regardless of the env var.
The diff was tiny:
```diff
- tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
- tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+ tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False,  # cppmega P1
+ tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False,  # cppmega P1
+ tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,
```
Plus @autotune(...) enabled. On TileLang f309d814, the TMA-lower path
produces bulk-copy descriptors that assume aligned multi-byte boundaries.
Combined with the tile shapes in mamba_mimo_fwd_kernel, that produces
unaligned addresses on sm_121a — the exact alignment-bug class that the
CUTLASS sm_120/sm_121 issues #2800/#3144 catalogue.
Two lessons. First: never patch installed site-packages in place. The
linear-CE patch already does this correctly via monkey-patch at import time;
the mamba3 P1 patch needs to be reworked to write to a mamba_ssm_p1/
shadow tree and patch at import. Second: env gates do not protect against
irreversible disk mutations. If your "off" path leaves the system in the
"on" state, your gate is a label, not a switch.
What we did finally validate on GB10
After the P1 disk state was reverted, we ran 13-layer NAM56R cuts (1 MLA
+ 3 DSA + 4 MoE + 4 Mamba3/M2RNN + 1 MTP) end-to-end on a single GB10 and got finite gradients across every config we could reasonably build:
- BF16 with unfused attention, 5 iterations: finite, healthy loss decay
- FP8 tensorwise + MBS=1 and MBS=8: finite across 60 iterations
- TileLang SparseMLA BF16 + Liger MTP + Liger main-head + DSA indexer fused: finite, byte-identical iter-1 grad to the no-SparseMLA run
- full bench3 env (CPPMEGA_NGRAM_HASH_ENABLED=1, CPPMEGA_STRUCTURE_ENABLED=1, MAMBA3_MIMO=1, MAMBA_NUM_GROUPS=8, MAMBA_RECOMPUTE=1): finite, grad 61.6 → 48.6 across 10 iters
- true NAM56R per-layer dims (hidden=3584, ffn=18944, heads=28) at MBS=4: finite 10/10
- MBS=10, the bench3 golden batch size, with CPPMEGA_INDEX_CACHE=1: finite 10/10, 87 GB peak memory, validation PPL 389/500
That last one is as close to the bench3 golden config as a single GB10 can
physically run. Every component that bench3 uses and that fits on sm_121a
— FP8 tensorwise, Liger MTP+main-head, TileLang SparseMLA BF16, DSA indexer
fused, IndexCache, ngram_hash, structure, selective recompute — produces
finite gradients on GB10. The NaN that haunted bench3 lives in the
intersection of multi-GPU EP=8 collective backward, megatron-lm SHA drift,
and TE 2.13 FP8 tensorwise behaviour — none of which a single GB10 can
exercise.
What this box is actually good for
After two weeks of shaking GB10, the honest assessment:
Use it for: kernel validation, single-node smoke tests, end-to-end architectural sanity checks, BF16 functional verification, NVFP4 dense GEMM development, and as a cheap surface for catching the class of bug that silently breaks on consumer Blackwell. The unified 128 GB memory makes it possible to run a 13-layer NAM56R cut at full bench3 dims on one device, which is genuinely useful.
Do not use it for: production training. The 273 GB/s bandwidth is the
ceiling on everything; the 10x peak-FLOP gap to H100 BF16 is not closeable;
and FA2/FA3 source-builds give literally zero speedup over PyTorch SDPA.
Production training goes to H200 / B200, where WGMMA, TMEM, tcgen05, and
real memory bandwidth do work the silicon can keep up with.
Default compile flag for any cppmega kernel you build for this box:
nvcc -arch=sm_120f. The family variant is 9x faster than sm_121a for
CUTLASS example 79 because of how SMEM carveout is calculated. Never use
the bare sm_120 / sm_121 — that strips block-scaled MMA entirely.
If we had a do-over: pin megatron-lm to a SHA in a real .git checkout,
snapshot the mamba_ssm fork state at every "golden" measurement, and never
let a patch helper write to installed site-packages. Those three rules
would have saved most of the time behind this writeup.