Attention sinks and telemetry on TPU: measure without turning observability into the bug
Why TPU telemetry has to be gated carefully: scalar reads can become host-device syncs, so sink and outlier tracking must be designed as explicit low-cadence instrumentation.

Telemetry is useful only if it does not distort the run it is trying to measure. On TPU/XLA, that is a real risk because scalar extraction inside the hot path can trigger host-device synchronization.
A clean TPU telemetry lane keeps sink or outlier summaries as fixed-shape device tensors, reduces them on device, and hands the host one deferred readout on a chosen cadence. In practice that means a step closure or a post-step metrics boundary, not a live scalar read in the compiled path. The narrow checked-in proof surfaces here are: Exact-token sparse telemetry sample, XLA flag profile sample, TPU runtime probe sample, and Profiler and receipts.
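A minimal sketch of that lane in PyTorch, assuming attention probabilities are available as a [batch, heads, q_len, k_len] tensor. The helper names (sink_summary, maybe_queue_sink_readout, sink_log) and the 0.5 threshold are hypothetical. On a TPU run the add_step_closure argument would be torch_xla.core.xla_model.add_step_closure; passing it in as a parameter keeps the sketch runnable on CPU with a stand-in that calls the closure immediately.

```python
import torch


def sink_summary(attn_probs: torch.Tensor, num_sink_tokens: int = 1,
                 threshold: float = 0.5) -> torch.Tensor:
    """Fixed-shape [3] summary of attention mass on the first key positions.

    attn_probs: [batch, heads, q_len, k_len], softmax-normalized over k_len.
    Returns [mean sink mass, max sink mass, fraction of rows above threshold].
    The threshold value is a hypothetical default, not a recommendation.
    """
    # Static slice: the output shape depends on num_sink_tokens, never on values.
    sink_mass = attn_probs[..., :num_sink_tokens].sum(dim=-1)  # [b, h, q]
    over = (sink_mass > threshold).to(attn_probs.dtype)
    return torch.stack([sink_mass.mean(), sink_mass.max(), over.mean()])


def maybe_queue_sink_readout(attn_probs: torch.Tensor, step: int, every_n: int,
                             add_step_closure, sink_log: list,
                             num_sink_tokens: int = 1) -> None:
    """Queue at most one host read per cadence; never call .item() inline."""
    if step % every_n != 0:
        return  # non-logging step: no summary ops traced, no closure, no read

    summary = sink_summary(attn_probs, num_sink_tokens)

    def _format(s: torch.Tensor, st: int) -> None:
        # With xm.add_step_closure this runs after the step executes,
        # so it is the single deferred host read for this cadence.
        mean, peak, frac = s.detach().cpu().tolist()
        sink_log.append({"step": st, "sink_mean": mean,
                         "sink_max": peak, "frac_over": frac})

    add_step_closure(_format, args=(summary, step))
```

On TPU the call would pass xm.add_step_closure after import torch_xla.core.xla_model as xm; the closure then formats the summary once the step's graph has run, instead of forcing a scalar out of the compiled path.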
Why TPU telemetry needs stricter discipline
A naive attention-sink tracker can look harmless in Python while quietly introducing step-time regressions. On XLA, calls such as .item() are often exactly the kind of scalar boundary that should not live inside a high-frequency compiled path.
The bad boundary is wider than .item(): print(tensor), .nonzero(), boolean indexing, or a Python branch that depends on a live tensor value can all force the host to answer too early. That is the same graph-contract problem described in Graph recompilation hell.
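The difference shows up in a small CPU-runnable contrast, assuming a plain activation tensor x and a hypothetical outlier threshold. The first function has the host-bound shape of the check (boolean indexing, a branch on a value-dependent size, and .item()); the second keeps the same information in a fixed-shape tensor built with torch.where, so it can stay in the device program until the chosen readout boundary. Function names are illustrative only.

```python
import torch


def outlier_peak_host_bound(x: torch.Tensor, thresh: float) -> float:
    # Anti-pattern on XLA: each marked line needs concrete values on the host.
    big = x[x.abs() > thresh]          # boolean indexing: value-dependent shape
    if big.numel() > 0:                # Python branch on that dynamic size
        return big.abs().max().item()  # .item(): scalar readback mid-step
    return 0.0


def outlier_summary_device(x: torch.Tensor, thresh: float) -> torch.Tensor:
    # Same information as a fixed-shape [2] tensor: [count over, peak over].
    mag = x.abs()
    mask = mag > thresh
    # Non-outliers are replaced by 0, so the max is the outlier peak (or 0).
    peak = torch.where(mask, mag, torch.zeros_like(mag)).max()
    return torch.stack([mask.to(mag.dtype).sum(), peak])
```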
That leads to a simple operational rule: high-detail telemetry should be gated, not left permanently on.
What the instrumentation should do
The useful split is:
- a sink-oriented stream for attention concentration
- an activation-oriented stream for outliers or spikes
The important part is not one exact implementation. It is the execution discipline:
- no-op on non-logging steps
- bounded work on logging steps
- explicit operator-visible cadence
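One way to make that discipline explicit is a small frozen policy object that is printed once at startup. This is a hypothetical sketch (TelemetryCadence and sampled_layers are invented names): non-logging steps get nothing to instrument, and logging steps always instrument the same bounded set of layers. On XLA, keeping logging-step work identical from one logging step to the next typically means the run settles into a small set of cached graphs (logging and non-logging) rather than a new one per logging step.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TelemetryCadence:
    """Hypothetical cadence policy; its repr is the operator-visible record."""

    every_n_steps: int
    # Fixed, bounded set of instrumented layers: same work on every logging step.
    sampled_layers: tuple[int, ...] = (0,)

    def __post_init__(self) -> None:
        if self.every_n_steps < 1:
            raise ValueError("every_n_steps must be >= 1")
        if not self.sampled_layers:
            raise ValueError("sampled_layers must be non-empty")

    def layers_for_step(self, step: int) -> tuple[int, ...]:
        # Non-logging steps return (), so callers trace no telemetry ops at all.
        if step % self.every_n_steps != 0:
            return ()
        return self.sampled_layers
```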
On multi-replica TPU lanes, the clean path is device aggregate first and host summarize last. Reduce the fixed-shape summary once per cadence, then read it on the chosen host boundary.
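The ordering can be modeled without a TPU. In this hypothetical sketch each replica contributes a two-number summary (mean and max sink mass). On a real torch_xla run the reduction step would be a device collective, for example xm.all_reduce(xm.REDUCE_SUM, ...) for the mean and xm.all_reduce(xm.REDUCE_MAX, ...) for the peak, and the host record would be built in a step closure; the plain-Python reduction below only stands in for that collective.

```python
from typing import Sequence, Tuple

# (mean sink mass, max sink mass) for one replica's fixed-shape summary.
Summary = Tuple[float, float]


def reduce_across_replicas(local: Sequence[Summary]) -> Summary:
    """Stand-in for the on-device collective, run once per cadence.

    Means combine as SUM / world_size, which assumes every replica averaged
    over the same number of rows; peaks combine as MAX.
    """
    world_size = len(local)
    mean = sum(s[0] for s in local) / world_size
    peak = max(s[1] for s in local)
    return (mean, peak)


def host_record(global_summary: Summary, step: int) -> dict:
    # The single host boundary: format a summary that is already global.
    mean, peak = global_summary
    return {"step": step, "sink_mean": mean, "sink_max": peak}
```

With local = [(0.2, 0.9), (0.4, 0.5)], reading rank 1's tuple before the reduction would report a peak of 0.5 while the global peak is 0.9, and nothing in the run would look wrong. The SUM / world_size rule for the mean is an assumption that only holds when replicas average over equal row counts.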
What a good TPU telemetry surface looks like
For TPU training, the cleanest telemetry API is usually:
- attach once
- enable or disable per logging cadence
- expose structured summaries rather than ad hoc scalar prints
That keeps observability compatible with compiled execution.
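A hypothetical shape for that API in plain PyTorch, assuming the instrumented module returns attention probabilities of shape [..., q_len, k_len]. SinkTelemetry, enabled, and to_record are illustrative names, not an existing library API; register_forward_hook is the standard nn.Module hook. The hook is registered once, the training loop flips enabled on logging steps, and the host only ever sees a named record built from a fixed-shape tensor.

```python
from typing import Optional

import torch
from torch import nn


class SinkTelemetry:
    """Hypothetical handle: attach once, toggle per cadence, read structured."""

    def __init__(self, attn_module: nn.Module, num_sink_tokens: int = 1):
        self.enabled = False
        self.num_sink_tokens = num_sink_tokens
        self.latest: Optional[torch.Tensor] = None
        # Attached once for the whole run; toggling never re-registers hooks.
        self._handle = attn_module.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        if not self.enabled:
            return None  # disabled: the hook adds no tensor ops
        # Assumes the module returns attention probabilities [..., q, k].
        sink = output.detach()[..., : self.num_sink_tokens].sum(dim=-1)
        self.latest = torch.stack([sink.mean(), sink.max()])  # fixed shape [2]
        return None  # returning None leaves the module's output unchanged

    def to_record(self, step: int) -> Optional[dict]:
        # Host side only: call from a step closure or after the step boundary.
        if self.latest is None:
            return None
        mean, peak = self.latest.cpu().tolist()
        return {"step": step, "sink_mean": mean, "sink_max": peak}

    def remove(self) -> None:
        self._handle.remove()
```

In a torch_xla loop, the sketch assumes enabled is set before the forward pass on a logging step and to_record is queued through xm.add_step_closure, so the only host read happens after the step; setting enabled back to False returns the hook to a no-op.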
What should be preserved
The core idea worth preserving is that telemetry is part of the runtime contract. If instrumentation can materially change step time, then cadence, gating, and summary shape should all be treated as first-class policy rather than debugging leftovers.
Frequently asked questions
What is the fastest official readout before I capture a full TPU profile?
Use torch_xla.debug.metrics.short_metrics_report() or metrics_report(). If those counters point at repeated CompileTime, elevated TransferFromDeviceTime, or unexpected aten:: fallback counters, capture an XProf or Cloud TPU profile next.
Why is a host read before the intended collective a correctness bug, not just a slowdown?
… xm.all_reduce(...) or xm.mesh_reduce(...) ever happens. The run can keep looking healthy while the dashboard has silently degraded from a global TPU metric into per-rank telemetry. That is why the safe order is device aggregate first, then one deferred host read on the chosen cadence: the same reduction-order rule described in Torch XLA and PJRT reality.
Should the threshold check itself live in Python if I already log only every N steps?
Is torch.where wasting work by keeping both telemetry outcomes tensor-shaped?
torch.where keeps the result as a tensor selection with the same broadcasted shape, so the logging decision can remain inside the device program instead of becoming a Python branch on a live XLA value. For this telemetry lane, bounded extra masking work is less dangerous than changing the compiled path or forcing the host to answer before the intended cadence.
What should the attention-sink metric actually summarize?
Does xm.add_step_closure() make scalar reads safe anywhere?
… .item() inside the compiled telemetry path. The safe order is still: tensor-shaped sink summary, optional replica reduction, queued closure, then host formatting on the chosen cadence.
Which local files show the telemetry contract versus the runtime receipt?
Terms used in this article
Start here for quick definitions, then follow the linked posts for deeper context.
Attention sink: The long-context failure mode where a few tokens, often the first token, absorb disproportionate attention mass and behave like a null-attention valve.
Attention: The token-mixing path that turns Q/K/V style projections into context-aware activations. On MLA pages here it refers to the concrete attention module boundary, not the A/M/E/R block-family shorthand.
PJRT: The TPU runtime interface between frontend code and the backend executor; it is the ownership seam between JAX/Torch-XLA frontends and libtpu.
XLA: The compiler/runtime layer that lowers frontend tensor programs into executable TPU or accelerator graphs, with shape stability and ownership boundaries as the main operational concerns here.
TPU: Google's Tensor Processing Unit accelerator/runtime surface, where the important boundary in these posts is usually XLA or PJRT ownership rather than handwritten GPU kernels.
Continue with a curated reading path
TPU Sparse Attention and Pallas Kernels
A curated TPU sparse-attention reading path: block-sparse contracts, Pallas kernel choices, SPMD sharding, and the runtime surfaces that keep long-context TPU work stable.
TPU v6e and XLA Runtime Surfaces
A curated reading order for TPU work: bring-up, PJRT and Torch/XLA boundaries, SPMD sharding, and the kernel/runtime traps that made TPU performance non-obvious.