gb10
blackwell
sm121a
nvfp4
tilelang
transformer-engine
training

Training the MegaCpp SLM Ensemble on GB10: A Grace Blackwell War Story

11 min read • David Gornshtein

The GB10 / DGX Spark looks, from the marketing slide, like a small Blackwell. It is not. It is a different ISA wearing the Blackwell logo, with a desktop silicon die, a 273 GB/s LPDDR5X memory bus, and a software stack that assumes you are running on a B200 until you prove otherwise. This post is the unvarnished account of what it took to train the MegaCpp SLM Ensemble — our NAM56R MIMO hybrid stack — on this box: the silicon traps, the NaN hunts that turned out to be our own patches, and the software-stack recipe that finally held.

If you came here for a clean "10x speedup with one weird trick", close the tab. If you came here because your own GB10 is also misbehaving, read on.

What GB10 actually is

The first thing to internalise is that sm_121a is not a small sm_100a. The Blackwell umbrella covers two architecturally distinct chips: datacenter (sm_100a, B200) and consumer (sm_120a RTX 5090, sm_121a GB10). NVIDIA's own forum reps put it bluntly: GB10's tensor cores are "closer to the GeForce Ampere-style MMA model." RT cores and DLSS silicon took the die budget that would have gone to TMEM and tcgen05 on the datacenter parts.

Concretely, on sm_121a:

  • 48 SMs (vs the 68 that PyTorch's is_big_gpu heuristic wants, vs 132 on H100)
  • 128 GB unified LPDDR5X at 273 GB/s (H100 has 3,350 GB/s; B200 has 8,000)
  • ~128 KiB physical SMEM/SM, 99 KiB dynamic budget for CUTLASS
  • Peak ~100 TFLOPS BF16, ~400 TFLOPS NVFP4 dense
  • Compute capability 12.1; requires PTX ISA 8.8 and the a suffix on the arch flag (sm_121a, not sm_121) for block-scaled MMA

And what it does not have:

  • No tcgen05.* family (no tcgen05.mma, .ld, .st, .alloc, .cp)
  • No Tensor Memory (TMEM) — the silicon is simply absent from the die
  • No Hopper wgmma.mma_async (deprecated across all Blackwell)
  • No 2-SM TMA multicast (the instruction exists, but cluster size is capped at 1, so it's effectively a no-op)
  • No FlashAttention-4 — and there will never be one. FA4 wants TMEM and tcgen05. The verifier rejects the cubin at the driver level.

The folk wisdom that "FP4 doesn't work on consumer Blackwell" is wrong but understandable. FP4 does work on sm_121a via warp-level OMMA. What doesn't work is the tcgen05-coupled UTCOMMA path, which is what 90% of CUTLASS's NVFP4 examples hard-code. CUTLASS example 79 is the OMMA-based reference that actually runs on GB10.

The toolchain dance

Bringing up nanochat-style pretraining on GB10 took five separate fixes before the first iteration produced a finite loss. None of them were exotic; all of them were undocumented in the obvious places.

ptxas. Triton ships its own ptxas (12.8) which does not know what sm_121a is and bails with Value 'sm_121a' is not defined for option 'gpu-name'. The fix is to point Triton at the system CUDA 13.0+ ptxas:

import os
import shutil

if not os.environ.get("TRITON_PTXAS_PATH"):
    for ptxas in ["/usr/local/cuda/bin/ptxas", shutil.which("ptxas")]:
        if ptxas and os.path.exists(ptxas):
            os.environ["TRITON_PTXAS_PATH"] = ptxas
            break

is_big_gpu. PyTorch's inductor refuses max_autotune_gemm if your GPU has fewer than 68 SMs. GB10 has 48. The patch is two lines:

import os
os.environ["TORCHINDUCTOR_MAX_AUTOTUNE_GEMM"] = "1"
import torch
import torch._inductor.utils as inductor_utils
inductor_utils.is_big_gpu = lambda index=0: True  # GB10 has 48 SMs; the stock heuristic wants 68+

Some Triton configs then fail with shared-memory errors during autotune. That is fine; autotune handles it. The 99 KiB SMEM budget on sm_121a is lower than the SM100 default tile shapes assume, and every kernel that was sized for B200 will overflow until you re-tile it.

MFU. The default MFU calculation in nanochat divides by H100's 989 TFLOPS BF16 peak, which on GB10 yields a depressing ~6%. With the correct denominator (62 TFLOPS BF16, 500 TFLOPS NVFP4), MFU comes out around 11%, which is roughly what the silicon can actually do given the bandwidth wall.
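
The fix is just a matter of using the right denominator. A minimal sketch of the MFU arithmetic; `mfu`, the throughput, and the FLOPs-per-token figure are illustrative placeholders, not nanochat's actual accounting:

```python
def mfu(tokens_per_sec: float, flops_per_token: float, peak_flops: float) -> float:
    """Model FLOPs Utilization: achieved model FLOPs per second over hardware peak."""
    return tokens_per_sec * flops_per_token / peak_flops

# Illustrative numbers only: ~6e9 model FLOPs/token at 1,000 tok/s.
achieved = mfu(1_000, 6e9, 989e12)   # H100 BF16 denominator: misleadingly low on GB10
corrected = mfu(1_000, 6e9, 62e12)   # GB10 BF16 denominator
```

The same achieved throughput moves by ~16x depending only on which peak you divide by, which is why the stock 6% number was meaningless.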

Liger graph break. LigerFusedLinearCrossEntropyFunction calls target_mask.sum().item() internally, which forces a torch.compile graph break and tanks Liger's throughput below the unfused baseline. One env knob removes it:

torch._dynamo.config.capture_scalar_outputs = True

With that flag set, the Liger/Triton FLCE backend goes from 12,095 tok/s (slower than baseline) to 16,950 tok/s — a real 1.17x end-to-end gain on depth-20 with B=32, T=2048, grad_accum=8, plus 2.6 GB of memory back because the B×T×V logits tensor never materialises.

NVFP4 RHT crash. TransformerEngine's NVFP4 recipe defaults to a Random Hadamard Transform pre-scale, which on sm_120/sm_121 consistently dies inside hadamard_transform_cast_fusion.cu with CUDA Error: invalid argument whenever M > 32. The fix lives in the recipe construction:

from transformer_engine.common.recipe import Format, NVFP4BlockScaling  # TE 2.x import path

recipe = NVFP4BlockScaling(
    disable_rht=True,                                # required for GB10
    fp4_format=Format.E2M1,
    override_linear_precision=(False, False, True),  # WGrad in BF16
)

With RHT disabled, NVFP4 GEMMs measure 1.23x–1.36x over BF16 at common shapes (4096³ goes from 87 to 117 TFLOPS). That is far short of the 8x the spec sheet implies, because GB10 is memory-bandwidth-bound, not compute-bound. A 67%/33% backward/forward split in our profiler tells the same story: the SMs spend most of their time waiting for LPDDR5X.
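
A back-of-envelope roofline check makes the bandwidth-wall claim concrete. A minimal sketch using the 273 GB/s figure and the ~100 TFLOPS BF16 peak from the spec list above; `machine_balance` and `gemm_intensity` are our own helper names:

```python
def machine_balance(peak_flops: float, bandwidth_bytes: float) -> float:
    """FLOPs per byte the chip can sustain; ops below this intensity are bandwidth-bound."""
    return peak_flops / bandwidth_bytes

def gemm_intensity(n: int, bytes_per_elem: int = 2) -> float:
    """Arithmetic intensity of an n^3 GEMM: 2n^3 FLOPs over ~3n^2 operand/result tiles."""
    return (2 * n**3) / (3 * n**2 * bytes_per_elem)

balance = machine_balance(100e12, 273e9)   # ~366 FLOPs/byte on GB10
big_gemm = gemm_intensity(4096)            # ~1365 FLOPs/byte: compute-bound in isolation
elementwise = 0.25                         # rough intensity of a BF16 elementwise op
```

A 4096³ GEMM clears the balance point, but the elementwise, normalization, and recurrence traffic around it sits two orders of magnitude below it, which is why the end-to-end profile looks like SMs waiting on LPDDR5X.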

The kernel layer: TileLang wins, cuTile is a dead end on this box

The MegaCpp ensemble's hot path is the Mamba-3 MIMO backward-of-backward (bwd_bwd) kernel. We tried three independent paths to beat the TileLang baseline of 167 µs on GB10. All three lost.

The cuTile Python rewrite was the most thorough attempt. Five algorithmic variants — fused monolithic, nested @ct.function per phase, 3-kernel split, hoisted loop invariants, full ct.static_iter unroll — all regressed against the 2-kernel A/B split baseline of 624 µs. The full unroll was 5.2x slower. The 3-kernel split that won by 33% on B200 (TMEM, 228 KiB SMEM) regressed by 9% on GB10. The launch-overhead vs live-set trade-off flips on you the moment you change SMEM budget. The lesson, which we now treat as a hard rule: never assume a cuTile algorithmic variant that wins on one GPU will win on another. Re-sweep on the target hardware.
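
The "re-sweep on the target hardware" rule is cheap to mechanize. A minimal sweep harness sketch; the variant names are stand-ins for the real cuTile kernels, which would also need a device synchronization inside the timed region:

```python
import time

def sweep(variants, *args, warmup=3, iters=20):
    """Time each kernel variant on THIS machine; return (best_name, mean_times)."""
    results = {}
    for name, fn in variants.items():
        for _ in range(warmup):
            fn(*args)                       # warm caches / trigger JIT
        t0 = time.perf_counter()
        for _ in range(iters):
            fn(*args)
        results[name] = (time.perf_counter() - t0) / iters
    best = min(results, key=results.get)
    return best, results

# Hypothetical stand-ins for the algorithmic variants; real ones launch kernels.
variants = {
    "fused_monolithic": lambda: sum(range(1000)),
    "two_kernel_split": lambda: sum(range(100)),
}
best, timings = sweep(variants)
```

Running this once per target GPU, instead of trusting a ranking measured on B200, is what the hard rule amounts to in practice.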

The CuTe DSL hot-path port was the most fun and the most humbling. We got cute.nvgpu.warp.MmaF16BF16Op + TMA + persistent scheduler running on sm_121a out of the box (via the blackwell_geforce/dense_gemm.py pattern — pass "sm_120" as the SmemAllocator capacity key, do not use CUTE_DSL_ARCH=sm_120a overrides which the cubin loader rejects). The resulting hand-written batched GEMM at L=256 ran in 10.28 µs. Then we benchmarked torch.bmm on the same shape: 10.33 µs. cuBLAS on GB10 already matches a hand-written CuTe DSL kernel at 64×64×64 BF16. The TileLang advantage is not GEMM efficiency; it is that TileLang fuses 10 GEMMs plus ~150 elementwise ops plus rotary plus reductions into one CUDA kernel with 16 CTAs each running 16 chunks in on-chip state. cuTile Python structurally cannot do that — it has to split into at least two kernels with gmem temps. The 4x gap is the kernel-structure tax, not the instruction tax.

The Triton M²RNN autotune sweep was the most anticlimactic. Twenty-five configs across three shapes; the autotuner had already converged to num_warps=8. num_stages moved performance by less than 0.03% because the M²RNN R-layer is a sequential recurrence that holds state in registers and has nothing to pipeline.

So we shipped TileLang, kept the cuTile 2-kernel split as a Python-readable reference, and moved bwd_bwd optimisation effort to H200, where WGMMA, TMA, swizzled SMEM, and warp specialisation via setmaxnreg actually unlock different wins. GB10 mamba3 bwd_bwd is pareto-optimal at 167 µs. Every further hour we spent on it was wasted; we say so plainly so you don't have to repeat it.

One genuine deployment opportunity surfaced: cuTile mamba3_mimo_fwd is 17.7% faster than TileLang on B200 (0.054 ms vs 0.064 ms). A hybrid cuTile fwd + TileLang bwd wrapper is a free throughput win on B200. On GB10 the same hybrid changes nothing, because the forward isn't the bottleneck.
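
One way such a hybrid wrapper could be structured is a custom torch.autograd.Function whose forward and backward call different kernel backends. The kernel stand-ins below are placeholders, not the real cuTile/TileLang entry points:

```python
import torch

def cutile_fwd(x):           # hypothetical stand-in for the cuTile forward kernel
    return x * 2

def tilelang_bwd(grad_out):  # hypothetical stand-in for the fused TileLang backward
    return grad_out * 2

class HybridMamba3(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Forward dispatches to the fast-on-B200 kernel...
        return cutile_fwd(x)

    @staticmethod
    def backward(ctx, grad_out):
        # ...while backward stays on the kernel that wins there.
        return tilelang_bwd(grad_out)

x = torch.ones(4, requires_grad=True)
y = HybridMamba3.apply(x)
y.sum().backward()
```

The autograd contract only requires that backward produce a mathematically correct gradient for forward; nothing forces the two to come from the same kernel stack.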

NaN, NaN, NaN: a bisect that wasn't

The hardest two days of this project had nothing to do with kernels and everything to do with grad norm: nan.

Symptom: bench3 H200 NAM56R training was producing finite gradients on 2026-04-13 ("golden 268 TFLOP/s"), and grad norm: nan on every iteration by 2026-04-15. The obvious suspect was commit dd4da34, which had rewritten the MTP and main-head Liger CE patches from reduction="none" to reduction="mean" plus broadcast — explicitly to fix the silent grad corruption from Liger issue #968. The hypothesis was clean: the "fix" broke training.

Empirical reality: not the fix. We ran five mutations at HEAD — MTP_DEPTHS=0, CPPMEGA_PREFER_NATIVE_HOPPER_CE=1, vanilla CE on logits, revert mamba3 regional compile, drop selective recompute. All five produced lm loss ~12 (finite), grad norm: nan on iter 1. The bug was upstream of all five candidates.

So we set up a proper bisect on a clean worktree, with PYTHONPATH precedence and a cppmega.__file__ import pre-check. Then we tested the claimed-golden commit 0ce8a3a itself under the same env. NaN. Iter 1: lm loss 11.88, grad norm nan. Same on 0038ad4 (Apr 13 19:00, the other claimed "last known finite"). Same on HEAD. Three commits, including the one that allegedly produced 269.4 TFLOP/s with finite gradients 24 hours earlier, all NaN under the same env.

Conclusion forced by the data: the regression is not in cppmega source. The bench3 environment had drifted between the golden measurement and our bisect. Likely candidates, in order:

  1. /mnt/data/cppmega-root/megatron-lm had no .git directory. The Megatron-LM version was unverifiable. Anyone could have rsynced over it.
  2. The local state-spaces-mamba fork carried uncommitted patches (an fp32 upcast of dd_dt + self.dt_bias before softplus, GQA branch changes in the bwd kernels). File mtimes showed 2026-04-13 23:10 — before the golden run — so the fork state itself wasn't the new variable, but nothing else was pinned either.
  3. TE 2.13 FP8 tensorwise on this exact stack was already on our internal "fragile" list from the earlier europe regression notes.

The mistake, in hindsight, was institutional: bench3 had no SHA pinning for Megatron-LM, no environment lockfile for the venv, and no snapshot of the mamba_ssm fork at the moment of the golden measurement. Any of those would have made the bisect tractable. None of them existed. We aborted the bisect, wrote the post-mortem, and moved the test surface to a single-GPU GB10 reproduction we could control.
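
The pinning discipline we were missing can be captured in a few lines run at every "golden" measurement. A sketch of such a snapshot helper; `snapshot_env` and the output filename are our own inventions, and it deliberately falls back to a content hash when, as on bench3, the tree has no .git:

```python
import hashlib
import json
import pathlib
import subprocess
import sys

def snapshot_env(repo_paths, out_file="golden_manifest.json"):
    """Record a git SHA (or a content hash when .git is missing) per tree,
    plus the interpreter, at the moment of a 'golden' run."""
    manifest = {"python": sys.version}
    for path in repo_paths:
        p = pathlib.Path(path)
        try:
            sha = subprocess.check_output(
                ["git", "-C", str(p), "rev-parse", "HEAD"],
                text=True, stderr=subprocess.DEVNULL,
            ).strip()
        except (subprocess.CalledProcessError, FileNotFoundError):
            # The bench3 failure mode: no .git, so hash the sources instead.
            h = hashlib.sha256()
            for f in sorted(p.rglob("*.py")):
                h.update(f.read_bytes())
            sha = "contenthash:" + h.hexdigest()
        manifest[str(p)] = sha
    pathlib.Path(out_file).write_text(json.dumps(manifest, indent=2))
    return manifest
```

Pair it with a `pip freeze` dump of the venv and a copy of the mamba_ssm fork, and the bisect we aborted becomes tractable.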

The other NaN: a mutation we made ourselves

While the bench3 NaN was being investigated, GB10 itself produced a different and equally instructive failure: cudaErrorMisalignedAddress at mamba_mimo_fwd_kernel, immediately on iter 0, during forward.

Root cause traced in an afternoon, once we stopped trusting the env gate. Commit 4f115ea ("P1 — enable TMA + warp specialization in Mamba3 MIMO kernels") had at some point been applied to the installed mamba_ssm site-packages files, not to a copy. The apply_mamba3_mimo_p1_patches.py helper had no restore path. The env gate CPPMEGA_MAMBA3_P1=0 correctly skipped applying the patch on the next run — but the disk state was already mutated. Every subsequent Python import picked up the patched kernels regardless of the env var.

The diff was tiny:

-        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: False,    # cppmega P1
+        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False,  # cppmega P1
+        tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,

Plus @autotune(...) enabled. On TileLang f309d814, the TMA-lower path produces bulk-copy descriptors that assume aligned multi-byte boundaries. Combined with the tile shapes in mamba_mimo_fwd_kernel, that produces unaligned addresses on sm_121a — the exact alignment-bug class that the CUTLASS sm_120/sm_121 issues #2800/#3144 catalogue.

Two lessons. First: never patch installed site-packages in place. The linear-CE patch already does this correctly via monkey-patch at import time; the mamba3 P1 patch needs to be reworked to write to a mamba_ssm_p1/ shadow tree and patch at import. Second: env gates do not protect against irreversible disk mutations. If your "off" path leaves the system in the "on" state, your gate is a label, not a switch.
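
The reversible, import-time shape that the P1 patch should have had looks roughly like this. All names here (`apply_p1_patches`, the patched `fwd`) are illustrative, not the real helper; the point is that the gate checks the env var on every import and keeps a restore path, and the disk is never touched:

```python
import os

def apply_p1_patches(mod):
    """Reversible import-time patch: swap attributes on the imported module
    object, never rewrite files under site-packages."""
    if os.environ.get("CPPMEGA_MAMBA3_P1", "0") != "1":
        return mod                       # gate off: module and disk untouched

    originals = {}

    def patched_fwd(*args, **kwargs):    # stand-in for the TMA/warp-specialized kernel
        return ("p1",) + args

    originals["fwd"] = getattr(mod, "fwd", None)
    mod.fwd = patched_fwd
    mod._p1_originals = originals        # keep a restore path
    return mod

def revert_p1_patches(mod):
    for name, fn in getattr(mod, "_p1_originals", {}).items():
        setattr(mod, name, fn)
    return mod
```

With this shape, CPPMEGA_MAMBA3_P1=0 actually means "off": the next interpreter start imports pristine kernels, and there is no disk state to forget to revert.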

What we did finally validate on GB10

After the P1 disk state was reverted, we ran 13-layer NAM56R cuts (1 MLA + 3 DSA + 4 MoE + 4 Mamba3/M2RNN + 1 MTP) end-to-end on a single GB10 and got finite gradients across every config we could reasonably build:

  • BF16 with unfused attention, 5 iterations: finite, healthy loss decay
  • FP8 tensorwise + MBS=1 and MBS=8: finite across 60 iterations
  • TileLang SparseMLA BF16 + Liger MTP + Liger main-head + DSA indexer fused: finite, byte-identical iter-1 grad to the no-SparseMLA run
  • full bench3 env (CPPMEGA_NGRAM_HASH_ENABLED=1, CPPMEGA_STRUCTURE_ENABLED=1, MAMBA3_MIMO=1, MAMBA_NUM_GROUPS=8, MAMBA_RECOMPUTE=1): finite, grad 61.6 → 48.6 across 10 iters
  • true NAM56R per-layer dims (hidden=3584, ffn=18944, heads=28) at MBS=4: finite 10/10
  • MBS=10, the bench3 golden batch size, with CPPMEGA_INDEX_CACHE=1: finite 10/10, 87 GB peak memory, validation PPL 389/500

That last one is as close to the bench3 golden config as a single GB10 can physically run. Every component that bench3 uses and that fits on sm_121a — FP8 tensorwise, Liger MTP+main-head, TileLang SparseMLA BF16, DSA indexer fused, IndexCache, ngram_hash, structure, selective recompute — produces finite gradients on GB10. The NaN that haunted bench3 lives in the intersection of multi-GPU EP=8 collective backward, megatron-lm SHA drift, and TE 2.13 FP8 tensorwise behaviour — none of which a single GB10 can exercise.

What this box is actually good for

After two weeks of shaking GB10, the honest assessment:

Use it for: kernel validation, single-node smoke tests, end-to-end architectural sanity checks, BF16 functional verification, NVFP4 dense GEMM development, and as a cheap surface for catching the class of bug that silently breaks on consumer Blackwell. The unified 128 GB memory makes it possible to run a 13-layer NAM56R cut at full bench3 dims on one device, which is genuinely useful.

Do not use it for: production training. The 273 GB/s bandwidth is the ceiling on everything; the 10x peak-FLOP gap to H100 BF16 is not closeable; and FA2/FA3 source-builds give literally zero speedup over PyTorch SDPA. Production training goes to H200 / B200, where WGMMA, TMEM, tcgen05, and real memory bandwidth do work the silicon can keep up with.

Default compile flag for any cppmega kernel you build for this box: nvcc -arch=sm_120f. The family variant is 9x faster than sm_121a for CUTLASS example 79 because of how SMEM carveout is calculated. Never use the bare sm_120 / sm_121 — that strips block-scaled MMA entirely.

If we had a do-over: pin megatron-lm to a SHA in a real .git checkout, snapshot the mamba_ssm fork state at every "golden" measurement, and never let a patch helper write to installed site-packages. Those three rules would have saved most of the time behind this writeup.

David Gornshtein • Datasunrise OÜ