pytorch
cuda
rocm
torch.compile
abi
nanochat

Living on PyTorch 2.12 Nightly: The nanochat ABI and Wheel-Matrix Tax

What it actually cost to sit on PyTorch 2.12 nightlies across CUDA, ROCm and custom TPU builds during the nanochat POC - ABI breaks, API churn, and the backward-compat patches that kept the fleet training.

8 min read · David Gornshtein

The nanochat research stream that feeds the MegaCpp SLM ensemble is a moving target on purpose: FlashAttention 4 CuTe, Mamba 3 hybrid blocks, MTP heads, ReDo routing, FSDP2, and a TPU lane that wants modern XLA. That combination pinned us to PyTorch 2.12 nightlies on GPU and to custom PyTorch 2.9 / 2.11 builds on TPU for most of Q1 2026. This post is a specific, unglamorous account of what that costs.

It is not a "PyTorch is great, here is a benchmark" post. It is the ABI, API, and wheel-matrix tax we paid.

The stack we ended up with

At the point the dust settled, the fleet looked roughly like this:

  • H200 (primary CUDA target): torch 2.12.0.dev20260304+cu130, Python 3.13, Triton bundled, flash_attn 2.8.3+cu130torch2.10 force-installed on top, mamba-ssm + causal-conv1d rebuilt against the nightly headers.
  • H100 / A100 staging: torch 2.10.0+cu128, Python 3.13. Kept behind by one minor so we could reproduce old receipts.
  • GB10 / DGX Spark: torch 2.10.0+cu130, local flash_attn build against sm_121a.
  • TPU v6e lane: custom torch 2.9.0a0+git21fec65 or torch 2.11.0a0+git7afdbae with torch_xla 2.9.0+gitc04e61c or 2.11.0+gitc04e61c, libtpu 0.0.36 (production) / 0.0.37.dev20260224 (nightly), jax 0.9.0.

Nothing in that matrix is "pip install torch". Every row is either a nightly channel or a locally built wheel, and every row has a distinct ABI surface against the things we link in: FlashAttention, Triton, Mamba SSM, Transformer Engine, _XLAC.cpython-313-x86_64-linux-gnu.so.

That is the real bill: not the version numbers, but that the wheel matrix has to be coherent across roughly a dozen hosts and three accelerator families before a single training step runs.

Why 2.12 at all

Two things pushed us onto 2.12 before it was released. First, torch.compile on 2.6 / 2.10 has a documented reduce_scatter_tensor bug that makes our Megatron conditional-grad tests fail; the same tests pass cleanly on 2.12. That is an upstream defect, not our code, and we needed the distributed optimizer tests green before we could stress anything else.

Second, the Dynamo story for Mamba SSM's mamba_chunk_scan_combined only became workable on 2.12. On 2.10 the accepted pattern is torch.compiler.disable(MBlock) - you keep Mamba blocks out of the graph entirely. On 2.12 you can call torch._dynamo.allow_in_graph(mamba_chunk_scan_combined) and Dynamo traces the surrounding linear projections while leaving the Triton kernel opaque. The version gate in scripts/base_train.py ended up being explicit:

if torch.__version__ >= "2.12":
    # 2.12+: Dynamo traces the surrounding projections and treats the
    # Triton kernel as an opaque call.
    torch._dynamo.allow_in_graph(mamba_chunk_scan_combined)
else:
    # <= 2.10: keep Mamba blocks out of the graph entirely, otherwise
    # Dynamo runs the kernel under FakeTensors and it reads a data pointer.
    MBlock.forward = torch.compiler.disable(MBlock.forward)

Get that gate wrong and you either crash with a TorchRuntimeError because Dynamo runs the Triton kernel under FakeTensors and the kernel reads a data pointer, or you silently insert 13 graph breaks into NAM52 and eat a 24% throughput regression on H200. We saw both in production runs before the gate was in place.

The Triton stream bug that ate a session

Between torch 2.10 and early torch 2.12 nightlies there was a Triton codegen regression where launcher() was handed stream both positionally and as a keyword:

TypeError: launcher() got multiple values for argument 'stream'
at triton_heuristics.py:1417

It only fires during autotune benchmarking of certain generated kernels, so it does not reproduce until your Inductor FX graph hash changes and forces re-autotune. We hit it hard during a Modal H200 optimization sprint: the best known run (n19l0, 37,598 tok/sec) had a warm Inductor cache from an earlier code revision, so it never re-autotuned. The moment we touched gpt.py, the cache invalidated, re-autotune ran, and every subsequent test crashed with the stream error.

The fix that kept us alive was TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC=1. Autotune runs in a forked subprocess; if Triton crashes, Inductor falls back to the default kernel and training continues. It is not a fix, it is a firebreak, but it moved the failure from "training dies" to "first compile is a bit slower".
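The placement matters: the variable has to be in the environment before torch is imported so Inductor's config module picks it up. A minimal sketch of where it sits in a launcher:

```python
import os

# Firebreak, not a fix: run Triton autotune benchmarking in a forked
# subprocess so a crashing candidate kernel makes Inductor fall back to
# the default config instead of killing the training process.
# setdefault keeps an explicit operator override intact.
os.environ.setdefault("TORCHINDUCTOR_AUTOTUNE_IN_SUBPROC", "1")

# `import torch` comes only after this point.
```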

ABI mismatches we actually hit

Three separate ABI edges bit us.

CuTe softcap. Current FlashAttention 4 CuTe kernels expose a score_mod entry point whose softcap ABI did not line up with the flash_attn wheel we could build against torch 2.12 cu130. Calling CuTe softcap directly caused either segfaults or silently wrong answers on H200. The fix was to route softcap through our repo-local tolerant score_mod rather than the CuTe fast path, documented in the backend stoplight matrix. The direct fa4_gather path had a related, smaller issue: it returns a tuple, and you have to unwrap it before the dtype cast, or the next op fails with a cryptic type error.
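The tuple unwrap is trivial once you know it is needed. A hedged sketch of the tolerant wrapper (the (output, lse) tuple shape and the helper name are assumptions, shown for illustration):

```python
def unwrap_attn_output(out):
    """fa4_gather-style entry points return a tuple (assumed here to be
    (output, lse)); unwrap before any dtype cast so the cast lands on
    the tensor, not the tuple."""
    if isinstance(out, tuple):
        out = out[0]
    return out
```

Routing every backend through one unwrap point means the dtype cast downstream never has to know which attention path produced the output.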

flash_attn wheel labels lie. The wheel marked flash_attn-2.8.3+cu130torch2.10 installs and imports happily on torch 2.12.0.dev...+cu130. It does not crash at import. It fails later, inside a kernel, under specific shapes, because the C++ extension was compiled against torch 2.10 C++ ABI structs. We now treat a fresh .venv313 with a new torch nightly as insufficient by default and reuse a known-good nanochat-exact/venv313 image where flash_attn, mamba-ssm, and causal-conv1d have been rebuilt against the exact nightly.
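Since import success proves nothing, a cheap pre-flight is to compare the torch tag baked into the wheel's local version against the running torch before trusting the kernels. A sketch of that check (the label format follows the flash_attn wheels above; the helper names are ours):

```python
import re

def wheel_torch_tag(local_version):
    """Extract the torch version a wheel claims it was built against
    from a label like '2.8.3+cu130torch2.10' -> '2.10'."""
    m = re.search(r"torch(\d+\.\d+)", local_version)
    return m.group(1) if m else None

def abi_plausible(wheel_version, torch_version):
    """Pre-flight only: the wheel's torch tag must equal the running
    torch's major.minor. Passing is necessary, not sufficient - a real
    smoke test still runs one kernel on production shapes."""
    tag = wheel_torch_tag(wheel_version)
    running = ".".join(torch_version.split("+")[0].split(".")[:2])
    return tag is not None and tag == running
```

This is exactly the check that fails for flash_attn-2.8.3+cu130torch2.10 on a 2.12 nightly, even though the import succeeds.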

ROCm wheels are a separate universe. We looked at ROCm for an MI300 lane. The PyTorch nightly cadence for ROCm does not line up with CUDA 13.0, and Triton's ROCm backend in the 2.12 nightly window could not host the Mamba SSM kernels we needed. We formally parked ROCm for this POC rather than pretend it was ready. Honest engineering: not every wheel in the matrix was worth finishing.

Backward-compat patches we had to carry

Even with a coherent wheel set, we still kept a small patch queue in-tree against upstream behaviour that was wrong for us.

allow_in_graph version gate. Described above; lives in base_train.py and reads torch.__version__ explicitly.

NCCL heartbeat during compile. torch.compile on NAM52 with H200 takes 15-20 minutes for Triton JIT. NCCL's default TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=600 kills ranks that are not running collectives during that window, so every multi-GPU compile died at the 10-minute mark. base_train.py now auto-sets TORCH_NCCL_ENABLE_MONITORING=0, TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=7200, and TORCH_DISTRIBUTED_DEFAULT_TIMEOUT=7200 whenever it sees LOCAL_RANK. Combined with TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=0 to avoid the Triton autotune workspace OOM, this is what actually keeps an 8x H200 compile alive.
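The env block base_train.py applies can be sketched as a small helper; values mirror the ones above, and setdefault preserves explicit operator overrides:

```python
import os

def relax_watchdogs_for_compile():
    """Stretch NCCL's watchdog past the 15-20 minute Triton JIT window
    so idle ranks are not killed mid-compile. Only applies under a
    torchrun-style launcher (LOCAL_RANK present); must run before
    init_process_group."""
    if "LOCAL_RANK" not in os.environ:
        return  # single-process run: nothing to relax
    os.environ.setdefault("TORCH_NCCL_ENABLE_MONITORING", "0")
    os.environ.setdefault("TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC", "7200")
    os.environ.setdefault("TORCH_DISTRIBUTED_DEFAULT_TIMEOUT", "7200")
    # Sidestep the Triton autotune workspace OOM during GEMM search.
    os.environ.setdefault("TORCHINDUCTOR_MAX_AUTOTUNE_GEMM", "0")
```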

Lazy NCCL init for retry re-execs. Startup retries (for example when auto-fit shrinks dbs 8 -> 4) spawn a new process. On torch 2.12 H200 hosts, the retry child keeps the cached ProcessGroup Gloo debug wrapper, and then dies with Gloo connectFullMesh ... Connection refused before step 00000. We now downgrade TORCH_DISTRIBUTED_DEBUG=DETAIL to INFO in retry children, force the live debug level down in-process, force lazy NCCL init on any CUDA retry re-exec (not only expert-parallel), and skip the immediate bootstrap barrier for expert-parallel CUDA lanes. Every one of those four things was a separate regression wave before it was a patch.
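The env surgery on the retry child reduces to a small pure function. A sketch (the NANOCHAT_LAZY_NCCL_INIT flag name is hypothetical; the DETAIL -> INFO downgrade is the load-bearing part):

```python
import os

def retry_child_env(parent_env):
    """Build the environment for a startup-retry re-exec child. DETAIL
    debug wraps the ProcessGroup in a Gloo wrapper the child cannot
    reconnect, so downgrade it; the lazy-init flag (hypothetical name)
    tells our startup path to defer NCCL init to the first collective."""
    child = dict(parent_env)
    if child.get("TORCH_DISTRIBUTED_DEBUG") == "DETAIL":
        child["TORCH_DISTRIBUTED_DEBUG"] = "INFO"
    child["NANOCHAT_LAZY_NCCL_INIT"] = "1"
    return child
```

Keeping it a pure dict-to-dict function makes each of the four regression waves above testable without spawning a process.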

Per-block compile, not whole-model compile. torch.compile(model) on NAM52 produces 13 MBlock graph breaks because Mamba is torch.compiler.disabled. Each break is a CPU/GPU sync. The fix is to compile each ABlock / EBlock individually and leave MBlock uncompiled - zero graph breaks, self-contained graphs. This is not a PyTorch bug, but it is a pattern that only becomes economical once your graph-break story is stable enough to predict.
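The per-block pattern generalizes to any model whose block list mixes compilable and Dynamo-hostile types. A torch-free sketch (model.blocks and the ABlock/EBlock/MBlock names are our repo's layout; compile_fn stands in for torch.compile):

```python
def compile_per_block(model, compile_fn, skip_types=()):
    """Compile each block in place, skipping Dynamo-hostile block types
    (e.g. Mamba MBlocks) so their Triton kernels never enter a graph:
    zero graph breaks, one self-contained graph per compiled block."""
    for i, block in enumerate(model.blocks):
        if isinstance(block, skip_types):
            continue  # stays eager
        model.blocks[i] = compile_fn(block)
    return model
```

In our trainer this amounts to compile_fn=torch.compile with skip_types=(MBlock,), instead of one torch.compile(model) over the whole graph.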

MoE weighted multiply in float32. On torch 2.12 with 64 experts and top-6 routing, a bf16 weighted multiply in the padded MoE path accumulated enough rounding error across scatter_adds to explode loss (15 -> 73 -> 3187 -> 5014 in four steps). Reverted to fp32 multiply plus accumulate. This reads like a numerical bug, but it only became visible once we turned torch.compile on and the compiled graph exposed the rounding path; on the previous unfused kernels the error was below the noise floor.
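The failure mode can be reproduced on CPU without torch: simulate bfloat16 by truncating the float32 mantissa (truncation instead of round-to-nearest is a simplifying assumption) and watch small weighted contributions vanish once the accumulator's ulp exceeds them:

```python
import struct

def to_bf16(x):
    """Truncate a float to bfloat16 precision: keep sign, exponent and
    the top 7 mantissa bits of the float32 encoding. A crude CPU
    stand-in for the GPU dtype, good enough to show lost increments."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

# Accumulate 10,000 expert-weighted contributions of ~0.01 each.
acc_bf16 = 0.0
acc_fp32 = 0.0
for _ in range(10_000):
    term = to_bf16(0.01)                 # weighted expert output in bf16
    acc_bf16 = to_bf16(acc_bf16 + term)  # bf16 accumulate path
    acc_fp32 += term                     # fp32 accumulate of same terms

# acc_fp32 lands near 100; acc_bf16 stalls once its ulp exceeds the term.
```

The toy shows lost increments rather than our exact loss explosion, but it is the same rounding budget that the compiled bf16 scatter_add path exhausted.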

The boring errors that cost the most time

Two of the costliest regressions were not ABI issues at all; they were ordinary Python mistakes that the 2.12 migration made expensive.

A commit added os.environ.get("PJRT_DEVICE") to gpt.py without importing os. Every Modal H200 run in that window crashed with NameError: name 'os' is not defined at import time, and because modal_train.py was not yet capturing subprocess stdout, the crash was invisible in the orchestrator. A two-line fix (use the already-imported _is_tpu_requested() helper) cost roughly a day of mystified debugging.

Separately, a surface-binding helper used Path.resolve() on canonical paths but not on observed paths. On H200 hosts where the venv python symlinks into a UV-managed CPython, the comparison always failed and every test that checked binding identity broke. Both sides now get normalized. Neither of these is a PyTorch defect - they are the sort of shear that happens when the interpreter, venv layout, and symlink story shift under you while you are also changing the torch version.
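The path fix itself is one line once stated: resolve both sides before comparing. A minimal sketch (same_binding is an illustrative name, not the repo helper):

```python
from pathlib import Path

def same_binding(canonical, observed):
    """Compare two interpreter paths only after resolving BOTH sides.
    Resolving just the canonical side fails on hosts where the venv
    python is a symlink chain into a UV-managed CPython."""
    return Path(canonical).resolve() == Path(observed).resolve()
```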

What we would do differently

If we were starting the POC today, with the hindsight of the last two quarters, three things would change.

We would freeze the wheel matrix earlier. Not "standardize on torch 2.12", but "freeze exactly one venv image per accelerator family, rebuild the linked C++ extensions against it once, and treat every host as immutable from that point until we explicitly bump the image". Our worst regressions came from fleets where three hosts were on torch 2.9 and one was on torch 2.11 with the same code. Cross-machine comparisons stopped being meaningful.

We would write the version-gate contract for torch.compile before enabling compile anywhere. The allow_in_graph vs compiler.disable decision is upstream-version-sensitive, and we learned that the expensive way.

We would keep a dedicated "rollback wheel" for each host. When a nightly breaks an ABI, the recovery path is not "pip install something" - it is "copy a pre-built .so from cold storage". Having that as a first-class artifact, with a checksum and a known-good kernel, is cheaper than any debugging session we ran.
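Treating the rollback artifact as first-class mostly means a checksum check before it ever touches a venv. A sketch of the verification step (the registry layout and helper name are ours):

```python
import hashlib
from pathlib import Path

def verify_rollback_wheel(path, expected_sha256):
    """Check a cold-storage .so/.whl against its recorded checksum
    before copying it into a live venv; refuse silently-corrupt
    artifacts rather than debug them three hosts later."""
    digest = hashlib.sha256(Path(path).read_bytes()).hexdigest()
    return digest == expected_sha256
```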

None of this makes PyTorch 2.12 a bad bet. It made everything downstream possible - FA4, Mamba 3 Dynamo tracing, FSDP2, clean torch.compile over attention blocks. But "runs on the latest nightly" is a claim with a cost, and the cost is paid in ABI audits and patch queues, not in elegant diffs.

References

  • CHANGELOG.md
  • review_gcp_tpu.md
  • TRAINING_PLAN.md
  • training_review.md
  • BACKEND_STOPLIGHT_MATRIX.md
  • CURRENT_STATE.md
  • TPU_SETUP.md
  • persistent_cache_fix.diff (README)
David Gornshtein • Datasunrise OÜ