混合精度训练(FP16/BF16)
一句话总结
混合精度训练使用半精度(FP16/BF16)进行前向和反向计算以节省显存并加速训练,同时保留FP32的主权重副本确保数值精度。
核心概念
FP32(32位浮点数)精度高但显存大、计算慢。FP16(半精度)显存减半、GPU Tensor Core加速2-3倍,但表示范围小(最大值65504),容易溢出。BF16(Brain Float 16)与FP32有相同的指数位(8bit)和表示范围,但精度低于FP16。混合精度训练策略:(1)维护FP32的主权重(master weights);(2)前向和反向传播用FP16/BF16;(3)FP16需配合Loss Scaling防止梯度下溢——将loss乘以一个大数(如1024)放大梯度,更新前再缩小;(4)BF16因范围足够大通常不需要Loss Scaling,已成为LLM训练的主流选择(需Ampere及以上GPU)。
为什么重要
没有混合精度训练,LLM训练的显存需求和时间成本将翻倍。以7B模型为例:FP32需要28GB存参数,FP16/BF16只需14GB。再加上激活值、优化器状态都可以用低精度,总体显存节省40-50%。A100/H100的BF16算力是FP32的2倍。
实践要点
BF16优先于FP16——数值稳定性更好,无需Loss Scaling;PyTorch的torch.cuda.amp提供了自动混合精度(AMP)接口;LayerNorm和Softmax等敏感操作通常保持FP32计算;训练中如果出现loss为NaN,首先检查是否是精度溢出问题。
常见误区
误区一:FP16和BF16精度相同——FP16有更高的尾数精度(10bit vs 7bit),BF16有更大的动态范围(指数8bit vs 5bit)。误区二:混合精度会显著损失模型质量——正确实现的混合精度训练与FP32训练结果几乎无差别。误区三:所有操作都可以用半精度——某些数值敏感操作(如归一化、softmax)需要保持FP32。