Trajectory-straightness loss: span sampling, layer choices, and XLA-safe limits
How the STP-style trajectory-straightness auxiliary loss is implemented in the public sample, why it samples ordered triples instead of predicting future latents, and what the runtime should preserve.

The trajectory-straightness auxiliary objective in the public sample is small enough to be underestimated. It is not a big subsystem, not a new decoder head, and not a speculative latent-prediction tower. It is a narrow regularizer in the STP sample that asks a more modest question: if hidden states form a trajectory through representation space, do short local segments stay roughly straight? That choice matters because it keeps the feature cheap, backend-friendly, and easy to gate from the training loop.
Overview: The implemented STP design samples ordered triples (s, r, t) from existing hidden states and penalizes curvature with 1 - cos(h[t] - h[r], h[r] - h[s]). The design deliberately avoids predictor heads, data-dependent control flow, and shape polymorphism. The good part is that it stays easy to wire into the runtime contract already used by the public sample. The still-open part is policy: how many spans to sample, which layers to supervise, and when to turn the loss on during training.
What the actual objective does
The STP loss sample defines the trajectory-straightness loss directly. The module documentation describes the hypothesis in plain terms: hidden-state trajectories are assumed to be locally linear, so the loss should punish curvature rather than predict the next state explicitly. The implementation exposes one entry point, compute_stp_loss, which accepts either a single hidden-state tensor (B, T, D) or a list of such tensors for the multi-layer variant.
The scalar objective is simple:
L_STP = 1 - cos(h[t] - h[r], h[r] - h[s])
That formula tells you almost everything about the intended behavior.
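Evaluating the formula on a few hand-picked triples makes it easier to read. The helper below is a hypothetical plain-Python illustration, not code from the sample. It assumes both segments are nonzero; a real kernel would guard the norms with an epsilon, as cosine-similarity routines do.

```python
import math

def stp_triple_loss(h_s, h_r, h_t):
    """Hypothetical helper: 1 - cos(h[t] - h[r], h[r] - h[s]) for one triple of plain vectors."""
    d1 = [b - a for a, b in zip(h_s, h_r)]  # first segment:  h[r] - h[s]
    d2 = [b - a for a, b in zip(h_r, h_t)]  # second segment: h[t] - h[r]
    dot = sum(x * y for x, y in zip(d1, d2))
    norm = math.sqrt(sum(x * x for x in d1)) * math.sqrt(sum(y * y for y in d2))
    return 1.0 - dot / norm  # assumes both segments are nonzero

straight = stp_triple_loss([0, 0], [1, 1], [2, 2])  # ≈ 0.0 (floating-point rounding only)
turn = stp_triple_loss([0, 0], [1, 0], [1, 1])      # 1.0: right-angle turn
reverse = stp_triple_loss([0, 0], [1, 0], [0, 0])   # 2.0: the path doubles back
```

These three cases bracket the [0, 2] range: 0 for a straight continuation, 1 for a right-angle turn, and 2 for a full reversal.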
| Design choice | What the code does | Why it matters |
|---|---|---|
| Geometry target | compares consecutive direction vectors | regularizes local straightness instead of predicting a future latent |
| Sampling unit | ordered triples (s, r, t) | ensures there is an intermediate point and therefore a notion of curvature |
| Core math | gather, subtract, cosine similarity | keeps the kernel narrow and easy to transport across backends |
| Layer surface | single tensor or explicit list of tensors | allows final-layer-only or selected-layer supervision |
There is no learned predictor head in the implementation. There is no teacher model. There is no extra projector whose hidden cost later pollutes memory accounting. That absence is not an omission. It is the point. The design tries to get a useful geometry prior while staying close to the exact hidden-state surfaces the training loop already owns.
That absence also keeps the cost surface honest. A predictor-style auxiliary head would bring extra parameters, optimizer state, and retained activations of its own. The current span-sampling lane does not. It stays on gathers, vector differences, and cosine math over tensors the model already had to materialize for the main step.
The three documented variants are also explicit in the module documentation: Variant A is one triple on the last layer, Variant B is N triples on the last layer, and Variant C averages the same loss across a selected list of layers. That structure is narrow enough to reason about and broad enough to support real ablations.
Why span sampling looks the way it does
The most important code in _stp_loss_single is not the cosine itself. It is the sampling logic. For each span, the function samples a (B, 3) tensor of base indices with torch.randint, sorts along the last axis, adds offsets [0, 1, 2], and clamps to the valid sequence range. That sequence is what gives the feature its implementation character.
First, it guarantees ordered positions without host-side conditional logic. The function does not branch on special cases per sample. It builds a uniform tensor-shaped path that works the same way for every batch element. That is why the module can credibly describe its operations as XLA-safe.
Second, it keeps the estimator local. The loss is not asking whether the first token and the last token lie on a globally straight manifold path. It checks whether short consecutive segments align. In practice that makes the objective easier to interpret: it is a local curvature penalty, not a whole-sequence reconstruction signal.
Third, the cost scales with n_spans, not with vocabulary size, sequence-wide pairwise comparisons, or a separate prediction head. The docstring even calls the operation budget out as roughly three gathers, two subtractions, and one cosine per triple. The exact FLOP wording is informal, but the engineering intent is precise: this feature is supposed to be cheap enough to survive contact with real training.
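The sampling path described above can be sketched in PyTorch roughly as follows. This is a hypothetical reconstruction, not the sample's `_stp_loss_single`. Two details are assumptions: the upper bound passed to `torch.randint` (`T - 2`) and the plain average over spans. With that bound, sort-plus-offset yields strictly ordered s < r < t whenever T >= 3, so the clamp only matters as a guard for very short sequences.

```python
import torch
import torch.nn.functional as F

def stp_loss_single_sketch(h: torch.Tensor, n_spans: int = 1) -> torch.Tensor:
    """Hypothetical single-layer STP loss. h: (B, T, D). Returns a scalar in [0, 2]."""
    B, T, D = h.shape
    offsets = torch.arange(3, device=h.device)  # [0, 1, 2]
    high = max(T - 2, 1)                        # assumption: keeps s + 2 <= T - 1 when T >= 3
    losses = []
    for _ in range(n_spans):                    # plain Python loop: unrolled when traced
        base = torch.randint(0, high, (B, 3), device=h.device)
        idx = base.sort(dim=-1).values + offsets  # sorted base + [0, 1, 2] -> s < r < t
        idx = idx.clamp(0, T - 1)                 # guard for very short sequences
        # Gather the three positions for every batch row: (B, 3, D).
        pts = torch.gather(h, 1, idx.unsqueeze(-1).expand(B, 3, D))
        h_s, h_r, h_t = pts[:, 0], pts[:, 1], pts[:, 2]
        cos = F.cosine_similarity(h_t - h_r, h_r - h_s, dim=-1)  # (B,)
        losses.append(1.0 - cos.mean())
    return torch.stack(losses).mean()
```

Because the loop over `n_spans` is ordinary Python, a tracing backend records one copy of the gather/subtract/cosine pattern per span. Every shape in the traced graph is still fixed by (B, T, D).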
There is also an XLA-specific reason to keep n_spans modest. A handful of triples buys a cleaner curvature estimate without changing the feature's character, but aggressively large n_spans values turn a cheap auxiliary regularizer into a larger unrolled compiler problem. The practical first question is not "can we sample dozens of spans?" but "how few spans reduce variance enough to stay worth the step-time budget?"
The warning is concrete on XLA: the current Python loop behaves like an unrolled small-N kernel, not like a symbolic scan. Single-digit n_spans values still look like the same cheap regularizer; once the count climbs into the low teens, the compile story starts competing with the variance win. That is why the disciplined first sweep is still 1 -> 3 -> 5, not "keep adding spans until the curve looks smoother."
The tests in the STP loss sample and its accompanying note reinforce that design. They verify scalar output, the [0, 2] loss range implied by 1 - cos, near-zero loss on a synthetic straight-line trajectory, positive loss on random trajectories, correct behavior for short sequences, and gradient flow. That is a solid unit-level contract for a regularizer. It means the code is not just mathematically plausible; its edge conditions are intentionally covered.
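A pytest-style sketch of that unit contract might look like the block below. So that the block runs on its own, it uses a condensed, hypothetical stand-in for the loss rather than the sample's code. The thresholds (`1e-4`, `0.1`) are illustrative choices, not values taken from the sample's tests.

```python
import torch
import torch.nn.functional as F

def stp_loss(h, n_spans=1):
    # Condensed hypothetical stand-in: ordered triples, 1 - cos, averaged over spans.
    B, T, _ = h.shape
    offsets = torch.arange(3)
    total = h.new_zeros(())
    for _ in range(n_spans):
        idx = (torch.randint(0, max(T - 2, 1), (B, 3)).sort(-1).values + offsets).clamp(0, T - 1)
        pts = h[torch.arange(B).unsqueeze(1), idx]  # (B, 3, D) via advanced indexing
        cos = F.cosine_similarity(pts[:, 2] - pts[:, 1], pts[:, 1] - pts[:, 0], dim=-1)
        total = total + (1.0 - cos).mean()
    return total / n_spans

def test_scalar_and_range():
    loss = stp_loss(torch.randn(4, 32, 16), n_spans=3)
    assert loss.dim() == 0 and 0.0 <= loss.item() <= 2.0

def test_straight_line_is_near_zero():
    steps = torch.arange(32, dtype=torch.float32).view(1, 32, 1)
    h = torch.randn(4, 1, 16) + steps * torch.randn(4, 1, 16)  # points on a line per row
    assert stp_loss(h, n_spans=3).item() < 1e-4

def test_random_trajectory_is_positive():
    assert stp_loss(torch.randn(4, 32, 16), n_spans=3).item() > 0.1

def test_short_sequence_is_finite():
    assert torch.isfinite(stp_loss(torch.randn(2, 2, 8)))

def test_gradient_flows():
    h = torch.randn(2, 16, 8, requires_grad=True)
    stp_loss(h, n_spans=2).backward()
    assert h.grad is not None and torch.isfinite(h.grad).all()
```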
Layer selection is the real policy surface
The math kernel is simple. The real design question is where and when to apply it.
The STP implementation itself supports either a single hidden-state tensor or a list of tensors. In the multi-layer case, compute_stp_loss computes one scalar per layer, stacks them, sums them, and divides by the number of layers. That averaging rule matters because it rejects an implicit “all layers by default” story. The feature expects explicit layer selection.
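The dispatch and averaging rule can be sketched as follows. `compute_stp_loss_sketch` and `_stp_single` are hypothetical names standing in for the sample's `compute_stp_loss` and `_stp_loss_single`. Only two things come from the description above: the single-tensor-or-list contract and the stack/sum/divide rule.

```python
from typing import List, Union

import torch
import torch.nn.functional as F

def _stp_single(h: torch.Tensor, n_spans: int) -> torch.Tensor:
    # Condensed single-layer stand-in: ordered triples, 1 - cos, averaged over spans.
    B, T, _ = h.shape
    total = h.new_zeros(())
    for _ in range(n_spans):
        idx = (torch.randint(0, max(T - 2, 1), (B, 3)).sort(-1).values + torch.arange(3)).clamp(0, T - 1)
        pts = h[torch.arange(B).unsqueeze(1), idx]  # (B, 3, D)
        cos = F.cosine_similarity(pts[:, 2] - pts[:, 1], pts[:, 1] - pts[:, 0], dim=-1)
        total = total + (1.0 - cos).mean()
    return total / n_spans

def compute_stp_loss_sketch(hidden: Union[torch.Tensor, List[torch.Tensor]], n_spans: int = 1) -> torch.Tensor:
    if isinstance(hidden, torch.Tensor):
        return _stp_single(hidden, n_spans)           # Variants A/B: one layer
    if len(hidden) == 0:
        raise ValueError("explicit layer list is empty")
    per_layer = torch.stack([_stp_single(h, n_spans) for h in hidden])
    return per_layer.sum() / len(hidden)              # Variant C: uniform mean over chosen layers
```

The list path never decides which layers to use. It only averages whatever the caller passed in, which keeps layer choice a caller-side decision.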
That aligns with the broader training contract shown in the public sample. The TPU bringup note shows that the STP coefficient is logged explicitly, warns that pipeline-parallel training drops auxiliary losses including STP, and gates activation with a delayed start policy. So the system is already split into two pieces:
- the STP loss sample answers how trajectory-straightness curvature is measured.
- the training runtime answers when STP participates and how strongly it is weighted.
That separation is the right one to preserve. Auxiliary objectives become hard to maintain when training policy leaks into the math kernel. Here the current arrangement is healthier: compute_stp_loss stays reusable, while step gating and feature enablement remain runtime concerns.
The delayed start policy is not decorative. Early in training the base next-token objective is still busy forming the representation manifold itself, so forcing straightness from step zero can regularize noise rather than geometry. The safer interpretation is "turn STP on after the base lane stabilizes," not "always-on auxiliary loss from the first batch," which is exactly the operational lesson in STP after ten thousand steps.
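One way to keep that split visible is to put the gate in a small runtime-side schedule object and leave the loss kernel untouched. `STPSchedule` and `total_loss` are hypothetical, and their field names are illustrative. The branch tests a host-side Python step counter rather than tensor data, so a traced backend should see at most two graph variants (before and after `start_step`) instead of data-dependent control flow inside the kernel.

```python
from dataclasses import dataclass
from typing import Any, Callable

@dataclass(frozen=True)
class STPSchedule:
    enabled: bool
    weight: float
    start_step: int

    def weight_at(self, step: int) -> float:
        # Policy only: no tensors, no loss math.
        if not self.enabled or step < self.start_step:
            return 0.0
        return self.weight

def total_loss(lm_loss: Any, stp_loss_fn: Callable[[Any], Any], hidden: Any,
               step: int, sched: STPSchedule) -> Any:
    w = sched.weight_at(step)
    if w == 0.0:
        return lm_loss                      # skip the STP kernel entirely before start_step
    return lm_loss + w * stp_loss_fn(hidden)
```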
One practical implication is that any future default should stay explicit. If this feature gets promoted into a broader preset, that preset should state whether STP is applied to the last layer only, to a curated list of intermediate layers, or to some architecture-specific slice. It should not silently guess.
The research-side layer-selection argument points in the same direction, with a useful extra caution: the earliest layers are still untangling token and syntax structure, so "regularize every layer" is usually the wrong first default. A safer starting point is a few later semantic layers, followed by an architecture-aware expansion into selected A, M, or E blocks only if the later-layer receipt is good. For a short reference to that block vocabulary, see the MegaCpp model glossary.
Why the XLA-safe claim is not just marketing
The module documentation says “All operations are XLA-safe: static shapes, no data-dependent branching.” That line is easy to wave away unless you read it next to the TPU docs and the wider runtime code.
The TPU bringup note is very explicit that the TPU lane values static compiled graphs, per-micro-step compilation boundaries, and predictable shape behavior. The current TPU contract is narrower than a generic "just use torch on TPU" story. The runtime disables torch.compile(...) for the model on TPU, wraps forward and backward in torch_xla.compile() per micro-step, and treats changing shapes or host-driven scalar behavior as regressions. In that environment, a regularizer that introduces dynamic control flow or variable output structure would be expensive even if its math looked elegant.
STP avoids that trap.
| Runtime concern | STP posture |
|---|---|
| changing tensor ranks | none |
| predictor-head materialization | none |
| host-driven conditionals in the kernel | none |
| auxiliary outputs with irregular structure | none |
That is the part worth carrying forward. A trajectory-straightness objective is only operationally useful if it respects the same graph-stability rules as the rest of the training stack. The current implementation does.
There is one caveat: “XLA-safe” does not mean “free.” If the training loop has to collect multiple layer activations solely for STP, that collection cost is real. But that cost is visible and policy-controlled. It is not hidden inside a second prediction tower or a backend-hostile control path.
The practical version of that rule is simple: do not rebuild variable-length sub-span tensors on the fly. If a backend needs a padded gather buffer plus a mask to keep shapes static, that is still faithful to the intended STP contract. Recompiling because every batch sampled a differently shaped Python list of states is not.
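A minimal sketch of the padded-buffer-plus-mask idea follows. It assumes right-padded batches with a per-row `lengths` tensor; that is an assumption, since the sample's handling of padding is not described here. Every tensor has a shape fixed by (B, T, D). Triples are drawn inside each row's valid prefix with arithmetic instead of Python branching, and rows too short to hold a triple are zeroed by a mask rather than filtered out.

```python
import torch
import torch.nn.functional as F

def masked_stp_loss(h: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """h: (B, T, D) right-padded hidden states; lengths: (B,) int64 valid lengths."""
    B, T, D = h.shape
    # Largest base index per row is lengths - 3; the clamp keeps short rows well-defined.
    span = (lengths - 2).clamp(min=1).to(h.dtype)
    base = (torch.rand(B, 3, device=h.device) * span.unsqueeze(1)).floor().long()
    idx = base.sort(dim=-1).values + torch.arange(3, device=h.device)  # s < r < t in the prefix
    last_valid = (lengths - 1).clamp(min=0).unsqueeze(1)
    idx = torch.minimum(idx, last_valid)                               # never read padding
    pts = torch.gather(h, 1, idx.unsqueeze(-1).expand(B, 3, D))       # fixed (B, 3, D) buffer
    cos = F.cosine_similarity(pts[:, 2] - pts[:, 1], pts[:, 1] - pts[:, 0], dim=-1)
    valid = (lengths >= 3).to(h.dtype)                                 # mask, not a Python filter
    return ((1.0 - cos) * valid).sum() / valid.sum().clamp(min=1.0)
```

The output shape never depends on how many rows are valid, so a batch with short rows traces to the same graph as a batch without them.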
How this maps into the public sample
The current sample already has the right habits for introducing narrowly scoped runtime features. You can see that style in multiple places: the hybrid pattern sample makes layer-stack composition explicit, and the runtime surfaces in this article set expose auxiliary-loss weights such as the STP coefficient rather than hiding them in opaque presets. STP should land the same way.
The likely stable shape is:
```yaml
stp_enabled: true
stp_weight: 0.02
start_step: 1000
stp_n_spans: 4
stp_layers: "4,8,12"
```
That sort of config expresses the policy cleanly. It says when STP begins, how much estimator stability is purchased with extra spans, and which representation surfaces receive the geometric bias.
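A hypothetical loader for that fragment is sketched below. Two details are assumptions rather than documented behavior: layer indices are read as zero-based positions in a list of per-layer hidden states, and an empty `stp_layers` falls back to the final layer only. The sketch keeps the layer list explicit and fails loudly on out-of-range indices instead of guessing.

```python
from dataclasses import dataclass
from typing import List, Sequence, TypeVar, Union

H = TypeVar("H")  # whatever a per-layer hidden state is (a tensor in practice)

@dataclass(frozen=True)
class STPConfig:
    stp_enabled: bool = False
    stp_weight: float = 0.0
    start_step: int = 0
    stp_n_spans: int = 1
    stp_layers: str = ""  # assumption: empty means "final layer only"

    def layer_indices(self) -> List[int]:
        # "4,8,12" -> [4, 8, 12]; whitespace and empty tokens are ignored.
        return [int(tok) for tok in self.stp_layers.split(",") if tok.strip()]

def select_hidden_states(all_hidden: Sequence[H], cfg: STPConfig) -> Union[H, List[H]]:
    idx = cfg.layer_indices()
    if not idx:
        return all_hidden[-1]            # single tensor -> single-layer variant
    bad = [i for i in idx if not 0 <= i < len(all_hidden)]
    if bad:
        raise ValueError(f"stp_layers out of range for {len(all_hidden)} layers: {bad}")
    return [all_hidden[i] for i in idx]  # explicit list -> multi-layer variant
```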
The main things not to do are just as important.
| Bad landing choice | Why it is wrong |
|---|---|
| hiding STP behind an architecture-specific heuristic | makes comparisons impossible across runs |
| automatically supervising every hidden layer | adds opaque cost and muddies interpretation |
| folding step gating into the loss kernel | mixes policy with math |
| adding a predictor head “for research completeness” | breaks the cheap, narrow contract that makes STP practical |
The stack also has to stay honest about architecture differences. In hybrid layouts with attention, Mamba, expert, and recurrent blocks, not every layer family necessarily wants the same geometry prior. The local notation used elsewhere in the stack is helpful here: A means attention, M means Mamba, E means expert or MoE, and R means recurrent. A pattern like NAM56R or AEMEAEMEAEMR is not just branding. It is a reminder that layer selection is architecture-aware policy.
That does not mean STP must become block-type-specific on day one. It means the config should leave room for that discussion instead of pretending one global default is always correct.
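If the config ever does need block-type-aware selection, the pattern string can feed an explicit layer list rather than a hidden heuristic. The sketch below hypothetically assumes one character per layer, in stack order; real MegaCpp patterns may encode more than that. `layers_by_family` and `later_layers_of` are illustrative helpers, not part of the sample.

```python
from typing import Dict, Iterable, List

BLOCK_FAMILIES = {"A": "attention", "M": "mamba", "E": "expert", "R": "recurrent"}

def layers_by_family(pattern: str) -> Dict[str, List[int]]:
    """Map each block family to the layer positions that use it."""
    out: Dict[str, List[int]] = {name: [] for name in BLOCK_FAMILIES.values()}
    for pos, code in enumerate(pattern):
        if code not in BLOCK_FAMILIES:
            raise ValueError(f"unknown block code {code!r} at position {pos}")
        out[BLOCK_FAMILIES[code]].append(pos)
    return out

def later_layers_of(pattern: str, codes: Iterable[str], k: int) -> List[int]:
    """Last k positions whose block code is in `codes`, in stack order."""
    wanted = set(codes)
    positions = [pos for pos, code in enumerate(pattern) if code in wanted]
    return positions[-k:] if k > 0 else []

fams = layers_by_family("AEMEAEMEAEMR")
picked = later_layers_of("AEMEAEMEAEMR", "AM", 2)
stp_layers = ",".join(str(i) for i in picked)  # explicit string for the runtime config
```

The output is still an explicit `stp_layers` string, so the choice stays visible in the run config instead of being re-derived silently at launch.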
What should be ablated before calling it “done”
The current code is ready for disciplined ablations because the knobs are already separated.
| Ablation | Files that support it | What it answers |
|---|---|---|
| n_spans=1 vs n_spans>1 | STP loss sample and its accompanying note | how noisy the estimator is at fixed layer choice |
| final layer vs explicit list | STP loss sample, runtime config surfaces | whether late semantic states or intermediate states benefit more |
| early start vs delayed start | the delayed-start runtime notes in this article set | whether STP helps only after a base representation stabilizes |
| dense-only vs hybrid patterns | runtime presets and model pattern naming | whether the objective behaves differently across A/M/E/R mixtures |
Two engineering facts should guide those ablations.
First, pipeline-parallel runtime notes already say auxiliary losses such as STP are dropped in that mode. So any headline about STP effectiveness must name the runtime context; otherwise the comparison can be false even if the config looks the same.
Pipeline parallelism also changes the failure shape. If an intermediate stage computes STP locally but its auxiliary gradient never rejoins the main backward stream, later stages try to straighten a trajectory that earlier stages are still free to twist. That is why PP results need to say whether the auxiliary loss survives stage boundaries at all, not just whether STP was enabled in config.
If a pipeline-parallel lane ever claims STP survived those stage boundaries, the receipt should also say how. The minimum credible story is a forward-neutral backward injection seam: compute the local auxiliary loss on the stage, attach it to the crossing hidden states, and scale it with grad accumulation and any pipeline chunking so the auxiliary gradient rejoins the main backward stream at the right magnitude. Anything weaker is better described as local monitoring than as full STP training.
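A forward-neutral injection seam can be sketched as a custom autograd function. In the forward pass it returns the crossing hidden states unchanged. In the backward pass it also sends a constant gradient `scale` into the stage-local auxiliary loss, so that gradient rejoins the main backward stream through the same tensor. `AuxGradInjector` is a hypothetical name, and this is a sketch of the general pattern, not the sample's code.

```python
import torch

class AuxGradInjector(torch.autograd.Function):
    """Hypothetical seam: forward returns `hidden` unchanged; backward additionally
    feeds d(total)/d(aux_loss) = scale into the auxiliary-loss graph."""

    @staticmethod
    def forward(ctx, hidden: torch.Tensor, aux_loss: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.save_for_backward(aux_loss)
        ctx.scale = scale
        return hidden  # activations crossing the stage boundary are untouched

    @staticmethod
    def backward(ctx, grad_hidden: torch.Tensor):
        (aux_loss,) = ctx.saved_tensors
        grad_aux = torch.ones_like(aux_loss) * ctx.scale
        return grad_hidden, grad_aux, None  # no gradient for the Python float `scale`
```

On a stage this would be used as `hidden_out = AuxGradInjector.apply(hidden_out, local_stp_loss, scale)` before `hidden_out` is sent downstream. A divisor such as `stp_weight / num_microbatches` is one plausible `scale`, but that is an assumption: the correct value depends on how the main loss is normalized across grad accumulation and pipeline chunks.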
Second, the first ablation should usually be variance control, not layer maximalism. A small increase from one span to a few spans is the cleaner way to stabilize the curvature estimate while preserving the same narrow objective. Expanding immediately to many layers or aggressively large span counts mixes two questions at once: whether the geometry signal helps, and whether the runtime can still carry the extra estimator cost honestly.
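The variance argument can be checked on a toy case by measuring how much the STP estimate moves when only the sampled triples change. The sketch below holds one random trajectory fixed and compares the spread of the estimate for different `n_spans`. It uses a condensed, hypothetical stand-in for the loss, and it says nothing about training quality; it only illustrates why a few spans tighten the estimate.

```python
import torch
import torch.nn.functional as F

def stp_estimate(h: torch.Tensor, n_spans: int) -> torch.Tensor:
    # Condensed hypothetical stand-in for the single-layer loss; assumes T >= 3.
    B, T, _ = h.shape
    losses = []
    for _ in range(n_spans):
        idx = torch.randint(0, T - 2, (B, 3)).sort(dim=-1).values + torch.arange(3)
        pts = h[torch.arange(B).unsqueeze(1), idx]  # (B, 3, D)
        cos = F.cosine_similarity(pts[:, 2] - pts[:, 1], pts[:, 1] - pts[:, 0], dim=-1)
        losses.append((1.0 - cos).mean())
    return torch.stack(losses).mean()

def estimator_spread(h: torch.Tensor, n_spans: int, trials: int = 200) -> float:
    # Standard deviation of the estimate over re-sampled triples on one fixed trajectory.
    vals = torch.stack([stp_estimate(h, n_spans) for _ in range(trials)])
    return vals.std().item()

torch.manual_seed(0)
h = torch.randn(1, 64, 16)  # one fixed, deliberately non-straight trajectory
for n in (1, 3, 5):
    print(f"n_spans={n}: spread={estimator_spread(h, n):.4f}")
```

For roughly independent triples the spread should shrink like 1/sqrt(n_spans), which is why the first few extra spans buy most of the stabilization.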
Finally, the current tests are unit tests, not training-value proofs. They show that the objective is stable, differentiable, and shape-safe. They do not yet prove that a particular layer subset or warmup schedule improves downstream quality. That distinction should be preserved in the article and in any future preset docs.
The right conclusion
The strongest thing about the current trajectory-straightness-loss design is not novelty. It is restraint. The sample does not try to solve trajectory learning with a large auxiliary subsystem. It chooses a local curvature penalty that is mathematically legible, cheap to compute, and compatible with the backend constraints already enforced elsewhere in the stack.
That is exactly why it is a plausible feature for a production training stack. The landing should keep the kernel narrow, keep the layer list explicit, keep start-step policy in the runtime, and evaluate the feature as a configurable regularizer rather than a grand theory of representation geometry. If a future preset promotes it, the case should be made with architecture-specific receipts, not with generic claims.
Frequently asked questions
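Why sample ordered triples instead of predicting future latents?
An ordered triple (s, r, t) guarantees an intermediate point and therefore a notion of curvature, and penalizing that curvature needs no predictor head, teacher, or projector. The loss stays on gathers, differences, and cosine math over hidden states the model already materializes.
Why is delayed STP activation the safer default?
Early in training the next-token objective is still forming the representation manifold, so enforcing straightness from step zero can regularize noise rather than geometry. Turning STP on after the base lane stabilizes avoids that.
Which layers should carry STP first?
A few later semantic layers, then an architecture-aware expansion into selected A, M, or E slices only if the later-layer receipt is good.
Why not materialize full dynamic sub-spans on XLA and let the compiler sort it out?
Variable-length sub-span tensors change shapes from batch to batch and force recompilation. A padded gather buffer plus a mask keeps shapes static and stays faithful to the STP contract.
What makes a pipeline-parallel STP result credible?
More than an stp_enabled flag in config. It should show that the auxiliary loss rejoins the main backward stream at the stage boundary, usually through a forward-neutral gradient-injection seam whose scale matches grad accumulation and any pipeline chunking. If STP is only computed locally and its gradient is dropped before the next stage, that is better described as monitoring than as full STP training.
Terms used in this article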
- A/M/E/R: MegaCpp shorthand for the four main block families: attention, Mamba/state-space, expert/MoE, and recurrent tail layers.
- XLA: The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.
- AEMEAEMEAEMR: A concrete NAM56R-style hybrid pattern string that encodes the ordered A/M/E/R block mix.
- NAM56R: A concrete MegaCpp hybrid family name whose meaning lives in the launch pattern, feature placement, and runtime constraints rather than in one marketing label.
- PP (pipeline parallelism): Pipeline parallelism cuts the model by depth, so each GPU gets a contiguous range of layers. 32 transformer blocks on 8× H200 with PP=8 puts 4 layers on each GPU. Weights and optimizer state live only on the GPU owning that stage; activations flow GPU0→GPU1→... in the forward pass and back in reverse during the backward pass. Cost: a pipeline bubble of roughly 1/microbatches, so you need many microbatches per step to amortize it. Use PP to scale past a single NVLink island across nodes, because what crosses the wire is tiny stage-boundary activations, not full tensors.
- Attention: The token-mixing path that turns Q/K/V style projections into context-aware activations. On MLA pages here it refers to the concrete attention module boundary, not the A/M/E/R block-family shorthand.