GQA/MQA注意力优化
一句话总结
GQA和MQA通过减少Key/Value头的数量来降低KV Cache内存占用和推理开销,在几乎不损失性能的前提下显著加速推理。
核心概念
标准多头注意力(MHA)中每个注意力头有独立的Q、K、V投影。MQA(Multi-Query Attention)让所有注意力头共享同一组K和V,仅Q保持多头,大幅减少KV Cache。GQA(Grouped-Query Attention)是MHA和MQA的折中方案:将注意力头分成若干组,同组内共享K、V。例如LLaMA-2 70B用8个KV头服务64个Q头(每8个Q头共享1组KV)。GQA在减少内存的同时保持了更好的模型表达能力。
为什么重要
大模型推理的主要瓶颈是KV Cache的内存和带宽。标准MHA的KV Cache随序列长度和batch size线性增长,在长序列场景下极其昂贵。GQA能将KV Cache减少到原来的1/8甚至更少,直接提升推理吞吐量和最大支持序列长度。
实践要点
GQA的组数选择需要权衡:组数越少内存节省越多但性能可能下降。常见配置是KV头数为Q头数的1/4或1/8。从MHA模型转换为GQA模型时,可以通过mean pooling合并KV头权重进行初始化。训练时GQA的计算量与MHA相同(只是参数少了),加速体现在推理阶段。
常见误区
误区一:认为GQA会显著降低模型质量。实验表明GQA-8在绝大多数任务上与MHA持平。误区二:混淆训练加速和推理加速,GQA主要加速推理而非训练。误区三:忽视GQA需要与高效推理框架(如vLLM)配合才能充分发挥优势。