Transformer Engine
Transformer Engine (TE) is NVIDIA's open-source library for FP8 training. It does the boring, fragile work that turns the FP8 hardware capability on H100 and B200 into something a real model can train on without diverging. If you are running FP8 outside TE or an equivalent (PyTorch's native FP8 in newer versions, JAX's experimental FP8 stack), you are signing up to debug numerics issues that look like hardware faults.
What TE actually does
The hardware gives you two FP8 formats: E4M3 (4 exponent bits, 3 mantissa bits) and E5M2 (5 exponent bits, 2 mantissa bits). Both have only 256 representable values, so each tensor needs a per-tensor scale factor that maps the full FP32 dynamic range onto the FP8 grid. TE wraps each transformer layer (Linear, LayerNorm, Attention) with a Python module that:
- Tracks amax (the absolute max value) of every input tensor in a small history buffer (default 16 steps).
- Picks a scale factor that fits the next iteration's expected amax.
- Casts the tensor to FP8 (E4M3 forward, E5M2 backward).
- Runs the matmul on tensor cores in FP8.
- Casts the result back to FP32 for accumulation.
- Updates the amax history with this step's actual amax for next time.
The amax history matters because amax can spike (one step has unusually large activations) without warning. Smoothing across 16 steps means a single spike does not trigger a scale jump that overflows. The CUDA implementation lives in TE's transformer_engine_torch extension; the Python API is transformer_engine.pytorch.Linear, LayerNorm, MultiheadAttention.
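To make the bookkeeping concrete, here is a minimal sketch of delayed scaling in plain PyTorch. It is not TE's implementation: the class name, the margin parameter, and the hard-coded E4M3 max are illustrative assumptions, and the final cast assumes a PyTorch build that exposes the native float8 dtypes.
import torch

FP8_E4M3_MAX = 448.0  # largest finite E4M3 value

class DelayedScalingSketch:
    """Toy per-tensor delayed scaling (illustrative, not TE's actual code)."""

    def __init__(self, history_len=16, margin=0):
        self.amax_history = torch.zeros(history_len)
        self.margin = margin
        self.scale = torch.tensor(1.0)

    def update(self, tensor):
        # Record this step's amax into the rolling history.
        self.amax_history = torch.roll(self.amax_history, shifts=1)
        self.amax_history[0] = tensor.abs().max()
        # Pick a scale that maps the recent worst case onto the FP8 grid,
        # backed off by an optional safety margin (in powers of two).
        amax = self.amax_history.max().clamp(min=1e-12)
        self.scale = FP8_E4M3_MAX / amax / (2.0 ** self.margin)
        return self.scale

    def quantize(self, tensor):
        # Cast to FP8 for the matmul; the inverse scale is applied to the output.
        return (tensor * self.scale).to(torch.float8_e4m3fn)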
Why two formats: E4M3 forward, E5M2 backward
Forward activations cluster around zero with limited dynamic range; E4M3 (max value 448, finer granularity near zero) fits them well. Backward gradients have a much wider dynamic range (some are tiny, some are large), so E5M2 (max value 57344, coarser near zero) keeps them representable. The TE recipe Format.HYBRID configures this split automatically. Picking the wrong format on either path is the most common source of "FP8 diverged" reports: E4M3 gradients overflow on the first big-loss step, and E5M2 activations lose too much precision once a layer normalizes them.
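The range difference is easy to inspect in a recent PyTorch build that exposes the native float8 dtypes (this snippet assumes torch.float8_e4m3fn and torch.float8_e5m2 are available):
import torch

# E4M3: more mantissa bits, narrower range -- suited to forward activations.
print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0
# E5M2: more exponent bits, wider range -- suited to backward gradients.
print(torch.finfo(torch.float8_e5m2).max)    # 57344.0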
Integration in a real training stack
Wrapping a model in TE looks like this:
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format
recipe = DelayedScaling(
    fp8_format=Format.HYBRID,
    amax_history_len=16,
    amax_compute_algo="max",
)
class TransformerBlock(torch.nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.attn = te.MultiheadAttention(d_model, n_heads)
        self.mlp = torch.nn.Sequential(
            te.Linear(d_model, 4 * d_model),
            torch.nn.GELU(),  # nonlinearity between the two FP8 projections
            te.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
            x = x + self.attn(x)
            x = x + self.mlp(x)
        return x
Three things to notice. First, the te.fp8_autocast context manager is what activates FP8; outside it, the layers run in their parent's dtype. Second, weights are stored in a higher-precision master copy (BF16 or FP32) and only cast to FP8 at matmul time. The "weights are FP8" mental model is wrong; only the operand tensors entering the matmul are FP8. Third, the recipe is per-region: most stacks wrap the entire forward and backward pass in one autocast region rather than wrapping each layer, as sketched below.
What goes wrong without TE
If you use raw FP8 tensor-core ops without amax tracking and per-tensor scaling, gradient underflow arrives within a few hundred steps. The loss curve looks normal until it does not, then diverges fast enough that the next checkpoint is poison. The failure mode is hard to distinguish from a flaky GPU or a corrupted data shard; teams that hit it often spend a week chasing hardware before identifying the numerics issue. See stragglers and blast radius for the operational angle on why FP8 numerical issues look like hardware faults.
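The underflow mechanism itself is easy to reproduce with the native float8 dtypes (again assuming a PyTorch build that has them); the values and the factor-of-two margin here are illustrative:
import torch

# Small gradients vanish when cast to FP8 with no scale factor ...
grad = torch.full((4,), 1e-6)
print(grad.to(torch.float8_e5m2).float())  # all zeros: underflow

# ... but survive when a per-tensor scale derived from amax is applied first.
scale = torch.finfo(torch.float8_e5m2).max / grad.abs().max() / 2  # /2 as margin
print((grad * scale).to(torch.float8_e5m2).float() / scale)  # ~1e-6 again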
The B200 ships even more aggressive formats (FP4, FP6, microscaling), and TE is being extended to handle them. The story is the same: hardware gives you a format, software has to do the per-tensor scaling, and the gap between "raw hardware FP8" and "trainable FP8" is exactly the library you are looking at.
Practical guidance
- Always use TE on H100 and B200 for transformer workloads in PyTorch. The integration cost is small; the throughput win is roughly 1.6x to 1.8x over BF16.
- Validate the loss curve against a BF16 baseline for the first 1000 steps before scaling out; a small divergence check is sketched after this list.
- Keep weights in BF16 master copy. FP8 is for activations and matmul throughput.
- If you migrate to JAX, use the equivalent stack in jax.experimental.fp8; the principle is the same, the API is different.
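One way to make the baseline comparison concrete is to log per-step losses from the BF16 and FP8 runs and flag the first step where they drift apart. The helper below and its 2% threshold are illustrative choices, not part of TE.
def first_divergence(bf16_losses, fp8_losses, rel_tol=0.02):
    """Return the first step where the FP8 loss drifts more than rel_tol
    from the BF16 baseline, or None if the curves stay together."""
    for step, (baseline, candidate) in enumerate(zip(bf16_losses, fp8_losses)):
        if abs(candidate - baseline) > rel_tol * abs(baseline):
            return step
    return None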
The takeaway: TE is the unglamorous library that makes FP8 trainable. Without it, the FP8 number on the H100 spec sheet is theoretical. With it, you actually get to use the silicon you paid for.
Updated 2026-05-10