A matrix multiplication that doesn’t need all those bits
Consider a transformer’s self-attention computation: query-key dot products, followed by softmax, followed by the value projection. In FP32, each of these operations uses 32 bits per element — full IEEE 754 floating-point precision, with roughly 7 decimal digits of accuracy.
But the softmax output is a probability distribution. Its values fall between 0 and 1, and the model’s downstream behavior depends on the relative ordering and rough magnitudes of these probabilities far more than on their exact values. Whether the softmax assigns 0.0312 or 0.0314 to a particular position almost never changes the model’s prediction. That difference — two parts in ten thousand — disappears in the noise of the subsequent matrix multiplication.
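A quick way to see this: simulate BF16 by truncating a float32 bit pattern (a crude stand-in for hardware rounding — truncation rather than round-to-nearest) and check that the ranking of softmax probabilities survives. The logits below are made-up attention scores for illustration.

```python
import math
import struct

def to_bf16(x):
    # Crude bfloat16 simulation: keep only the top 16 bits of the
    # float32 pattern (truncation, not round-to-nearest)
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.1, 0.3, -1.0, 1.7]          # made-up attention scores
full = softmax(logits)
low = [to_bf16(p) for p in full]

rank = lambda ps: sorted(range(len(ps)), key=ps.__getitem__)
print(rank(full) == rank(low))  # True: the ordering survives the rounding
```

The absolute values shift by rounding, but the relative ordering — the thing the downstream computation actually depends on — is unchanged.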
This observation is the foundation of mixed precision: not all operations require equal precision, and the operations that tolerate lower precision are often the most compute-intensive ones.
The numerical tolerance landscape
A neural network is not a uniform numerical system. Different components exhibit different sensitivities to precision changes:
Attention and feedforward layers — the bulk of computation in transformer models — are dominated by matrix multiplications where the outputs are subsequently normalized (by LayerNorm or similar operations). The normalization absorbs small numerical errors by rescaling the outputs. This makes these layers relatively tolerant of reduced-precision arithmetic. BF16 and FP16 work well here because the normalization step cleans up the rounding errors that accumulate during the lower-precision matrix multiply.
Loss computation and gradient accumulation — in training — require higher precision because they deal with small values that grow through summation over many elements. Accumulating thousands of small gradients in FP16 risks overflow at the top of FP16’s narrow range, or silent loss of updates once each small addend falls below the resolution of the growing sum. This is why training frameworks keep a master copy of weights in FP32 and perform gradient reduction in FP32, even when the forward and backward passes run in BF16.
Embedding lookups and final projection layers tend to be more sensitive in some architectures because they operate at the boundaries where small numerical changes can shift which token gets selected or which class gets predicted. These layers sometimes benefit from staying at higher precision even when the rest of the model has been reduced.
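The accumulation hazard above is easy to reproduce. This sketch uses Python’s stdlib support for IEEE 754 half precision (the struct `'e'` format) to round a running sum to FP16 after every add; the gradient value and step count are made up for illustration.

```python
import struct

def to_fp16(x):
    # Round-trip through IEEE 754 half precision (struct format 'e')
    return struct.unpack("<e", struct.pack("<e", x))[0]

grad = 1e-4             # a small per-step contribution (illustrative)
acc16, acc32 = 0.0, 0.0
for _ in range(10_000):
    acc16 = to_fp16(acc16 + to_fp16(grad))  # FP16 running sum
    acc32 += grad                           # full-precision running sum

print(acc32)  # ~1.0: the full-precision sum is fine
print(acc16)  # stalls far short of 1.0: once the running sum is large
              # enough, each 1e-4 addend rounds away entirely
```

No individual add overflows or misbehaves dramatically; the sum simply stops moving. This is exactly the failure an FP32 master copy and FP32 gradient reduction are designed to avoid.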
The practical consequence is that precision is not a global setting to be applied uniformly. It’s a resource to be allocated selectively — high precision where sensitivity demands it, low precision where tolerance allows it.
How frameworks implement it
Modern frameworks make mixed precision largely automatic for the common case. PyTorch’s torch.amp (Automatic Mixed Precision) wraps the forward pass in a context that casts operations to lower precision where it’s been determined to be safe, while keeping certain operations — cumulative sums, log operations, loss functions — in FP32.
Under the hood, the decision about which operations run at which precision is based on empirically validated allowlists. NVIDIA’s documentation classifies operations into categories: operations that are safe in FP16/BF16 (most matrix multiplies and convolutions), operations that should remain in FP32 (reductions, normalizations, log-domain math), and operations where either precision is acceptable depending on context.
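The common pattern can be sketched in a few lines. This is a minimal CPU example of torch.amp’s autocast context; the toy Linear layer and shapes are placeholders, not a real model, and the same pattern applies with `device_type="cuda"` on GPU.

```python
import torch

model = torch.nn.Linear(8, 4)  # toy stand-in for a full network
x = torch.randn(2, 8)

# Inside autocast, allowlisted ops (linear/matmul) run in bfloat16;
# ops classified as FP32-only are left at full precision.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)

print(y.dtype)             # torch.bfloat16: the matmul ran at low precision
print(model.weight.dtype)  # torch.float32: the master weights are untouched
```

Note that nothing about the model itself changed — the weights remain FP32, and the cast happens per-operation based on the allowlist, which is what makes the precision assignment selective rather than global.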
This automated approach works well for standard architectures. It becomes less reliable with custom operations, unusual architectures, or numerical edge cases specific to certain datasets. When we’ve evaluated non-standard architectures, we’ve sometimes found that the default allowlists are too aggressive or too conservative — a custom attention variant that needs FP32 for stability, or a normalization layer that works fine in BF16 despite being categorized as FP32-required.
The automation is a useful starting point, not a guarantee. Precision is a design parameter rather than a fixed quality level, and like any design parameter, the choice must be validated for each deployment on the target workload.
The performance case
The motivation for mixed precision is straightforward: lower precision means less memory, less bandwidth, and more throughput.
A BF16 matrix multiply uses half the memory bandwidth of FP32 and can execute up to 2× faster on hardware with dedicated BF16 tensor cores (Ampere, Hopper). FP8 on Hopper-generation hardware offers another 2× over BF16 for supported operations. The improvement comes from both reduced data movement (less bandwidth consumed per operation) and dedicated hardware units (tensor cores with native lower-precision arithmetic).
Memory savings compound the throughput benefit. A model stored in BF16 uses half the HBM of FP32. This means larger batch sizes fit in memory, which improves GPU utilization. Or it means larger models fit on a single GPU, avoiding the communication overhead of model parallelism.
For inference specifically, where the workload is often memory-bandwidth-bound (reading model weights from HBM for every token), reduced precision translates directly to higher tokens-per-second, because each token generation reads half (BF16 vs FP32) or a quarter (FP8 vs FP32) of the data from memory.
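The arithmetic behind that claim is simple. The sketch below is a back-of-envelope model that assumes every generated token streams all weights from HBM exactly once; the parameter count and bandwidth figure are illustrative assumptions, not measurements.

```python
def tokens_per_second(n_params, bytes_per_param, hbm_bandwidth):
    """Upper bound on decode speed for a purely bandwidth-bound step."""
    bytes_per_token = n_params * bytes_per_param
    return hbm_bandwidth / bytes_per_token

N_PARAMS = 7e9    # a 7B-parameter model (assumption)
HBM_BW = 2.0e12   # ~2 TB/s of HBM bandwidth (assumption)

fp32 = tokens_per_second(N_PARAMS, 4, HBM_BW)
bf16 = tokens_per_second(N_PARAMS, 2, HBM_BW)
fp8 = tokens_per_second(N_PARAMS, 1, HBM_BW)

# Halving bytes per parameter doubles the bandwidth-bound ceiling
print(f"{fp32:.0f} -> {bf16:.0f} -> {fp8:.0f} tokens/s")
```

Real systems fall short of this ceiling (KV-cache reads, kernel overheads, batching effects), but the proportionality — halve the bytes, double the bound — is the reason precision reduction pays off so directly for decode.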
Stability emerges from selective precision retention
The reason mixed precision works reliably in practice — despite reducing numerical precision for most computations — is that the precision reduction is selective, not uniform.
The operations that are most vulnerable to precision errors (gradient accumulation, loss computation, certain normalizations) retain full precision. The operations that generate the vast majority of compute load (matrix multiplications, convolutions) use reduced precision because the magnitude of their rounding errors is small relative to the signal they carry, and subsequent operations (normalization, activation functions) absorb or mask those errors.
This selective strategy means the model’s numerical behavior in mixed precision closely tracks its behavior in full precision. The errors introduced by lower-precision arithmetic are absorbed at every normalization boundary, and the accumulated effect on the final output is typically within the noise floor of other sources of inference variability.
When the strategy fails — when mixed precision produces materially different outputs than full precision — it’s almost always because a specific operation was incorrectly classified as precision-tolerant when it wasn’t. This is diagnosable (compare layer-by-layer outputs between mixed and full precision) and fixable (keep that layer at higher precision while leaving the rest at lower precision). The fix is surgical, not a retreat to full FP32.
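That layer-by-layer comparison can be sketched in a few lines. The toy “layers” and the truncation-based BF16 stand-in below are placeholders for real modules and hardware rounding, and the tolerance is an arbitrary choice.

```python
import math
import struct

def to_bf16(x):
    # Crude bfloat16 stand-in: truncate a float32 pattern to its top 16 bits
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def find_suspects(layers, x, rel_tol=1e-2):
    """Run each layer at full and reduced precision side by side;
    flag layers whose reduced-precision output drifts past the tolerance."""
    suspects = []
    hi = lo = x
    for name, fn in layers:
        hi = fn(hi)                    # full-precision path
        lo = to_bf16(fn(to_bf16(lo)))  # reduced-precision path
        if abs(hi - lo) > rel_tol * max(abs(hi), 1.0):
            suspects.append(name)
    return suspects

# Toy stand-ins for network layers (illustrative only)
layers = [("scale", lambda v: 1.7 * v), ("tanh", math.tanh)]
print(find_suspects(layers, 0.9))  # []: these layers tolerate bf16
```

A layer that lands on the suspect list is the candidate for the surgical fix: pin it to higher precision and re-run the comparison, leaving everything else at low precision.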
Mixed precision isn’t a hack or a shortcut. It’s an engineering exploitation of a real property of neural networks: uneven numerical sensitivity. Understanding where the tolerance lives — and validating that the framework’s assumptions match your workload’s reality, as discussed in how hardware constraints shape precision choices — is what makes it work reliably.