Activation Checkpointing Policy: The Per-Block Pareto That Held Up
Selective versus full activation checkpointing across attention, MoE, Mamba-style, and recurrent blocks, and why the best policy depends on where each block actually spends memory and compute.

Activation Checkpointing Policy
Activation checkpointing is easy to enable and hard to tune. Turn it on everywhere and memory drops, but throughput often falls too much. Turn it off and the model may not fit at all. For hybrid architectures, neither extreme is right. The practical answer is a per-block policy.
What "checkpointing" actually means
Several different mechanisms are often grouped under the same label.
Manual block checkpointing wraps a forward callable with torch.utils.checkpoint.checkpoint(..., use_reentrant=False). In eager mode, this is the straightforward path.
The use_reentrant=False part is not a side detail. PyTorch recommends passing that flag explicitly and prefers the non-reentrant path because it keeps the autograd graph visible, supports keyword arguments and torch.autograd.grad, and can stop replay once the needed intermediates have been rebuilt. In practice that makes it the cleaner default for compiled or sharded training stacks.
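A minimal sketch of the manual path, assuming a plain PyTorch eager setup. The wrapper name CheckpointedBlock is hypothetical and not taken from the project; only torch.utils.checkpoint.checkpoint and its use_reentrant flag are real PyTorch API.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):
    """Hypothetical wrapper: recompute the wrapped block's activations in backward."""

    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def forward(self, x, **kwargs):
        # Only pay the recompute cost when a backward pass will actually run.
        if self.training and torch.is_grad_enabled():
            # Non-reentrant mode forwards keyword arguments to the block.
            return checkpoint(self.block, x, use_reentrant=False, **kwargs)
        return self.block(x, **kwargs)
```

The wrapper changes when activations are materialized, not the math, so gradients should match the unwrapped block whenever the block itself is deterministic.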
Compiled rematerialization is different. When the compile pipeline is given an activation-memory budget below full retention, the compiler inserts recompute into the graph directly. That path should own the recompute decision inside compiled regions. Stacking manual checkpointing on top of compiler rematerialization tends to duplicate work.
CPU offload checkpointing replaces recompute with copies to pinned CPU memory and restores tensors during backward. It trades compute for host-link bandwidth and is only useful in narrow cases.
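The compute-versus-bandwidth trade can be framed as a rough break-even check. This is a sketch under stated assumptions: the function name, the round-trip model (one copy out in forward, one copy back in backward), and the overlap_fraction knob are illustrative, and any numbers passed in are placeholders rather than measurements.

```python
def offload_beats_recompute(activation_bytes, host_link_bytes_per_s,
                            recompute_seconds, overlap_fraction=0.0):
    """Return True when the exposed copy time is lower than the recompute time.

    overlap_fraction is the assumed share of copy time hidden behind compute
    (0.0 means fully exposed, 1.0 means fully hidden).
    """
    if host_link_bytes_per_s <= 0:
        raise ValueError("host link bandwidth must be positive")
    if not 0.0 <= overlap_fraction <= 1.0:
        raise ValueError("overlap_fraction must be in [0, 1]")
    # One device-to-host copy during forward, one host-to-device copy during backward.
    transfer_seconds = 2 * activation_bytes / host_link_bytes_per_s
    exposed_seconds = transfer_seconds * (1.0 - overlap_fraction)
    return exposed_seconds < recompute_seconds
```

With the copy fully exposed, offload only wins when recompute is expensive relative to the link, which is why the useful cases stay narrow.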
FP8-safe checkpointing is its own category. If recompute crosses an FP8-autocast boundary, the checkpoint path must preserve amax history correctly. Otherwise forward and recompute can quantize differently.
The safe public rule is to treat precision bookkeeping as checkpoint state, not as an incidental side effect. If a recompute path would update FP8 scale history twice, use the FP8-aware checkpoint surface for that block instead of wrapping it with a generic callable.
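A toy model of why the scale history has to be treated as state. This is a sketch only: AmaxHistory and block_scale are hypothetical names, the rolling-max scaling rule is a simplification, and none of it is the API of a real FP8 library. The one hard fact used is that 448 is the largest finite value in the E4M3 format.

```python
from collections import deque

E4M3_MAX = 448.0  # largest finite E4M3 value


class AmaxHistory:
    """Hypothetical rolling window of observed absolute maxima."""

    def __init__(self, length=2):
        self.history = deque(maxlen=length)

    def record(self, amax):
        self.history.append(float(amax))

    def scale(self):
        # Assumed rule: scale so the largest recent amax maps to the FP8 maximum.
        return E4M3_MAX / max(self.history) if self.history else 1.0


def block_scale(values, state, *, is_recompute):
    """Return the quantization scale a block would use for these values."""
    if not is_recompute:
        # Only the original forward may advance the history window.
        state.record(max(abs(v) for v in values))
    return state.scale()
```

If the recompute pass records its amax a second time, the older entry is evicted from the window and the replay can pick a different scale than the original forward did.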
Per-block policy
Attention blocks
Checkpoint by default. Attention is usually the largest activation consumer on the dense path, and its recompute cost is manageable. In eager mode, full-block checkpointing is acceptable. In compiled mode, framework-level selective recompute around core attention is often the better fit. Under FP8, use an FP8-safe checkpoint path.
MoE blocks
Use selective expert-GEMM recompute, not full-block checkpointing. Full-block MoE checkpointing reruns dispatch, permutation, communication, expert compute, and combine on backward, which is too expensive. Selective expert recompute captures most of the memory win while avoiding the expensive dispatch side of the block.
Mamba-style blocks
Do not checkpoint the whole block by default. Selective scan is expensive to rerun, so full-block checkpointing gives a poor throughput tradeoff. A narrow conv-plus-projection recompute path is much better and avoids FP8 issues that can appear when recompute re-enters packed-token logic.
Recurrent blocks
Checkpoint the block, but keep a narrow recurrence recompute enabled as well. The recurrence chain can hold a surprising amount of memory, and rerunning it is much cheaper than rerunning the entire block payload.
Last layer
Do not checkpoint the last layer. Its activations are consumed immediately by backward, so the recompute adds cost without relieving peak memory.
Mechanism selection by runtime
On compiled CUDA paths with an activation-memory budget below full retention, let the compiler own rematerialization.
That ownership rule is easier to keep in a real launcher if compiled blocks stay visible as the same blocks. The checked-in regional-compile samples mark compiled inner blocks to skip manual checkpoint wrappers and prefer in-place block compilation when the runtime exposes it, precisely so compiler-driven rematerialization does not stack on top of a second hand-written recompute layer.
On eager CUDA paths, use manual block checkpointing plus the per-block rules above.
On TPU-class systems, autotuned rematerialization does much of the work already. Manual checkpointing still helps on attention-heavy regions, but CPU offload is unavailable and MoE tradeoffs differ because dispatch buffers dominate differently.
CPU offload should stay narrow. It can help when recompute is expensive and host-link bandwidth is available, but it is not a general-purpose default.
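The runtime rules in this section can be sketched as a single lookup. Everything named here is hypothetical (RecomputePlan, plan_for_runtime, the string labels). The case of a compiled CUDA path at full retention is not covered by the text, so the fallback to manual wrappers in that branch is an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RecomputePlan:
    owner: str                 # "compiler" or "manual"
    manual_wrappers: str       # "none", "per_block", or "attention_only"
    cpu_offload_allowed: bool  # allowed as a narrow option, never as a default


def plan_for_runtime(backend, compiled, activation_budget=1.0):
    """activation_budget is the fraction of full activation retention (1.0 keeps all)."""
    if not 0.0 <= activation_budget <= 1.0:
        raise ValueError("activation_budget must be in [0, 1]")
    if backend == "tpu":
        # Autotuned rematerialization leads; manual wrappers only on attention-heavy regions.
        return RecomputePlan("compiler", "attention_only", False)
    if backend != "cuda":
        raise ValueError(f"unknown backend: {backend!r}")
    if compiled and activation_budget < 1.0:
        # The compiler owns recompute; do not stack manual wrappers on top of it.
        return RecomputePlan("compiler", "none", True)
    # Eager CUDA (and, by assumption, compiled CUDA at full retention).
    return RecomputePlan("manual", "per_block", True)
```

Returning one frozen plan per runtime keeps the ownership decision in a single place, so a launcher cannot quietly combine compiler rematerialization with a second set of manual wrappers.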
The Pareto that mattered
The key lesson from the measurements was not that one mode won everywhere. It was that each block family had a different efficient frontier.
- Attention favored either full-block checkpointing or framework-level selective recompute, depending on whether the runtime was eager or compiled.
- MoE favored selective expert recompute.
- Mamba-style layers favored narrow recompute only.
- Recurrent blocks favored full-block checkpointing plus a small in-module recompute.
That mix moved the stack from out-of-memory regimes into usable training regimes without paying the full throughput cost of blanket checkpointing.
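For readers who want to reproduce the per-block comparison, a frontier can be computed from any set of (peak memory, step time) measurements. This is a generic sketch: pareto_front is a hypothetical helper, and the numbers in the usage example are placeholders in arbitrary units, not results from the measurements described here.

```python
def pareto_front(candidates):
    """Return the names of non-dominated options, sorted alphabetically.

    candidates maps a name to (peak_memory, step_time); lower is better for both.
    """
    front = []
    for name, (mem, time) in candidates.items():
        dominated = any(
            other != name
            and m2 <= mem and t2 <= time
            and (m2 < mem or t2 < time)
            for other, (m2, t2) in candidates.items()
        )
        if not dominated:
            front.append(name)
    return sorted(front)


# Placeholder values in arbitrary units, for illustration only.
example = {
    "no_recompute": (10.0, 1.00),
    "full_block": (4.0, 1.40),
    "selective": (5.0, 1.10),
    "selective_wide": (6.0, 1.20),
}
print(pareto_front(example))
```

Running the same function separately per block family is what surfaces the different frontiers: an option that is dominated for one block kind can sit on the frontier for another.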
Failure modes worth remembering
Non-reentrant checkpointing can require relaxed determinism checks on token-routed MoE paths because tiny numerical differences in scatter-style updates can change routing metadata even when the backward signal is still acceptable.
Backend-specific runtime hooks matter. TPU and CUDA do not share the same rematerialization vocabulary, so policies that look similar at a high level still need backend-specific handling.
Automatic retry logic should preserve the intended recompute mode when it searches for a smaller batch geometry. If retries silently disable checkpointing, they can invalidate the very lane they are supposed to rescue.
The same rule applies to compile retries: if only batch geometry changes, keep the rematerialization owner fixed. A retry that switches from compiler-owned rematerialization to manual wrappers is a different experiment, not a safer version of the same one.
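One way to make that rule hard to violate is to keep the recompute fields out of the retry search entirely. A minimal sketch, assuming a hypothetical LaunchConfig and a caller-supplied try_launch callback that returns True when a configuration fits:

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class LaunchConfig:
    micro_batch: int
    grad_accum: int
    recompute_owner: str  # "compiler" or "manual"
    recompute_mode: str   # e.g. "per_block"


def shrink_candidates(cfg, min_micro_batch=1):
    """Yield smaller batch geometries; global batch and recompute settings stay fixed."""
    micro_batch, grad_accum = cfg.micro_batch, cfg.grad_accum
    while micro_batch % 2 == 0 and micro_batch // 2 >= min_micro_batch:
        micro_batch //= 2
        grad_accum *= 2
        # replace() copies every other field, so the recompute owner cannot drift.
        yield replace(cfg, micro_batch=micro_batch, grad_accum=grad_accum)


def retry_until_fits(cfg, try_launch):
    for candidate in (cfg, *shrink_candidates(cfg)):
        if try_launch(candidate):
            return candidate
    raise RuntimeError("no batch geometry fit with the requested recompute mode")
```

Because the search only touches micro_batch and grad_accum, switching the recompute owner has to be an explicit, separately logged decision.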
What we threw away
- One global gradient_checkpointing story as the source of truth.
- Full-block checkpointing for MoE by default.
- Full-block checkpointing for Mamba-style layers.
- Manual checkpointing inside compiled regions that already use compiler-driven rematerialization.
Policy snapshot
| Block kind | Mechanism | Notes |
|---|---|---|
| Attention | manual checkpoint or framework-level selective | FP8-safe checkpointing under FP8 |
| MoE | selective expert-GEMM recompute | avoid full-block replay of dispatch |
| Mamba-style | never full-block by default | narrow conv-plus-projection recompute only |
| Recurrent | full block plus narrow recurrence recompute | good memory win at modest cost |
| Last layer | never | little or no peak-memory benefit |
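# Eager-mode sketch: decide whether to wrap a whole block in a manual checkpoint.
def should_checkpoint(layer_idx, n_layers, block_kind, spacing):
    # The last layer's activations are consumed immediately by backward.
    if layer_idx == n_layers - 1:
        return False
    # Mamba-style blocks use narrow in-module recompute, never a full-block wrapper.
    if block_kind == "MambaStyle":
        return False
    # With spacing > 0, wrap only layers whose index is a multiple of spacing.
    if spacing > 0 and layer_idx % spacing:
        return False
    return True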
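The boolean sketch above folds every selective mode into a plain "do not wrap". A fuller version could return the mechanism from the policy snapshot instead. The function name, block-kind keys, and mode labels below are hypothetical and chosen only to mirror the table.

```python
# Hypothetical mode labels mirroring the policy snapshot table.
_MODES = {
    "attention": {"eager": "full_block", "compiled": "selective_core_attention"},
    "moe": {"eager": "selective_expert_gemm", "compiled": "selective_expert_gemm"},
    "mamba_style": {"eager": "narrow_conv_projection", "compiled": "narrow_conv_projection"},
    "recurrent": {"eager": "full_block_plus_narrow_recurrence",
                  "compiled": "full_block_plus_narrow_recurrence"},
}


def recompute_mode(block_kind, layer_idx, n_layers, compiled=False):
    """Return the recompute mechanism label for one block."""
    if layer_idx == n_layers - 1:
        # Last layer: recompute adds cost without lowering peak memory.
        return "none"
    try:
        per_runtime = _MODES[block_kind]
    except KeyError:
        raise ValueError(f"unknown block kind: {block_kind!r}") from None
    return per_runtime["compiled" if compiled else "eager"]
```

Under FP8, the attention entries would additionally need to route through an FP8-safe checkpoint path, which this lookup does not model.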
In practice this policy only stays honest when the launch planner, batch geometry, and offload policy agree. "Gradient accumulation and microbatching under FSDP2" is the nearby batch-side companion, while "CPU offload and startup memory calibration on H200 and GB10" covers the memory escape hatches.
Frequently asked questions
Why is selective expert-side recompute still the safer MoE default?
… moe and moe_act instead of treating whole-layer replay as the first knob.
What should I measure before changing this policy?
Why avoid mixing manual wrappers with compiler rematerialization?
Is TPU/XLA checkpointing just CUDA eager checkpointing with different hardware?
Is CPU activation offload a blanket model-offload switch?
What does an activation-memory budget actually control?
Terms used in this article
Start here for quick definitions, then follow the linked posts for deeper context.
- MoE: Token Choice vs Expert Choice, null-expert debugging, gating stability, and the production routing decisions behind the MegaCpp SLM Ensemble.
- FSDP2: PyTorch's Fully Sharded Data Parallel v2 wrapper API. On CUDA it shards parameters, gradients, and optimizer state across the data-parallel group; in the TPU/XLA posts here it is usually a memory-goal analogy, not the actual eager wrapper mechanism.
- Mamba: A grounded look at why MegaCpp combines Mamba-style state-space blocks with a smaller number of attention blocks for long-context C++ work, and…
- CUDA: NVIDIA's GPU programming stack: compiler, runtime, driver, libraries, and kernel toolchain used by CUDA training and inference lanes.
- TPU: Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.
- Training: A grounded walkthrough of how the project approaches small-language-model training: explicit stack specs, memory-first patches, hybrid blocks, and…
- H200: NVIDIA's Hopper H200 GPU platform, typically discussed here as an 8-GPU training node with large HBM capacity and NVLink-connected ranks.
- GB10: Consumer Grace Blackwell GB10 / DGX Spark bring-up lane used to separate driver-visible gates, patched cubin signals, and real execution proof.
- FP8: Eight-bit floating-point training and inference formats used to trade precision for throughput and memory on recent accelerator lanes.