Interview: Gradient Checkpointing的时空权衡比例与checkpoint层选择策略


题目解析

Gradient Checkpointing(激活重计算)是大模型训练中最常用的显存优化技术。理解其精确的时空权衡比例和最优checkpoint位置选择,是训练效率调优的重要技能。

解答思路

对于L层的Transformer,不用checkpoint需要存储O(L)层激活,反向计算量为1x。使用sqrt(L)个均匀checkpoint点时,显存从O(L)降至O(sqrt(L)),额外计算量约为33%(即总训练时间增加约1/3)。这是因为每个checkpoint段内的激活需要在反向时重新前向计算一遍,平均每层多算一次前向,而前向约占总计算的1/3。

关键要点

  1. 理想情况下显存节省约sqrt(L)倍,计算开销增加约33%;2. checkpoint点应选在计算量小但激活量大的层边界;3. Transformer中通常在每个Block的Self-Attention之前做checkpoint;4. FlashAttention已减少了Attention层的激活存储需求,checkpoint策略需相应调整。

加分回答

Selective Checkpointing是更精细的策略:只对激活内存占比大的操作(如Softmax、LayerNorm)做checkpoint,保留计算密集但激活小的操作(如矩阵乘法)的激活。这能在更少的计算开销下获得接近的显存节省。

常见踩坑

一个常见错误是对所有层都做checkpoint,这样虽然显存最省但计算开销最大。实际上大部分场景下选择性checkpoint就足够了。另外不要忽略checkpoint本身也需要存储输入,不是完全零开销。