理解 MHA、GQA、MQA 和 MLA:多头注意力的变种及其应用
理解 MHA、GQA、MQA 和 MLA:多头注意力的变种及其应用
多头注意力(Multi-Head Attention, MHA)是Transformer结构的核心组件,近年来产生了多个变种,如GQA(Group Query Attention)、MQA(Multi-Query Attention)和MLA(Multi-Layer Attention)。这些改进主要目的是提高计算效率和减少计算开销。本文将深入探讨这些注意力机制的工作原理、数学公式、优缺点及应用场景,帮助读者理解Transformer及其改进版本。
1. MHA(Multi-Head Attention,多头注意力)
1.1 MHA 的基本原理
多头注意力(MHA)是Transformer结构的核心组件之一,它的作用是:
- 让模型在不同的子空间(subspace)上学习不同的特征。
- 提高模型的表达能力,使其能够关注输入序列的不同部分。
- 并行计算,提高计算效率。
MHA的核心思想是将输入的Query(查询)、Key(键)和Value(值)分别投影到多个不同的头(head)上,每个头独立计算注意力,然后将多个头的结果拼接后投影回原始维度。
1.2 计算过程
给定输入矩阵X(形状为s×d),MHA计算如下:
线性变换:将输入X变换成Query(Q)、Key(K)、Value(V)。
其中WQi,WKi,WVi是不同头的权重矩阵。计算Scaled Dot-Product Attention(缩放点积注意力):
其中dk=d/h,是每个头的维度。拼接多个头的输出:
其中WO是最终的投影矩阵。
1.3 MHA 的优势和劣势
✅ 优势:
- 允许模型在不同的子空间上学习不同的注意力模式。
- 提高模型的表达能力,可以关注输入序列的不同部分。
- 并行计算,可以在GPU上高效执行。
❌ 劣势:
- 计算量较大:每个Query头都需要与多个Key计算注意力,导致计算开销较高。
- 内存占用大:MHA需要存储多个Query、Key、Value头,特别是在大模型中,占用大量显存。
2. GQA(Group Query Attention,分组查询注意力)
2.1 GQA 的核心思想
GQA(分组查询注意力)是为了降低计算成本而提出的一种优化方案。它的主要改动在于:
- 多个Query共享一个Key-Value组,减少计算复杂度。
- 在视觉Transformer(ViT)等任务中表现良好,适用于大规模数据处理。
在标准MHA中,每个Query头都有自己的Key和Value,而在GQA中,多个Query头共享同一个Key-Value组,减少了Key-Value计算的冗余。
2.2 GQA 计算过程
- 将Query分组,设总共有h个Query头,我们将它们分为g组,每组的Query共享同一个Key-Value组:G=h/g
- 每个组的Query共享Key-Value:
- 拼接多个组的结果:
2.3 GQA 的优势
✅ 计算量降低:比MHA少计算Key-Value的开销,提高计算效率。
✅ 适用于CV任务:减少视觉Transformer在图像数据上的计算复杂度。
❌ 可能降低表达能力:由于Query共享Key-Value,可能会损失一定的灵活性。
3. MQA(Multi-Query Attention,多查询注意力)
3.1 MQA 的核心思想
MQA(多查询注意力)是GQA的一种极端情况:
- 所有Query共享一个Key-Value,极大减少计算量。
- 适用于大规模推理任务,如ChatGPT的解码阶段。
3.2 MQA 计算过程
- 所有Query头共享Key-Value:Kshared,Vshared
- 计算注意力:
- 拼接结果:
3.3 MQA 的优势
✅ 极大减少计算成本:适用于推理阶段,减少Key-Value计算量。
✅ 内存占用降低:适合处理超长文本,如GPT-4等大模型。
❌ 可能损失部分表达能力:仅有一个Key-Value可能影响多样性。
4. MLA(Multi-Layer Attention,多层注意力)
4.1 MLA 的核心思想
MLA(多层注意力)关注的是在不同层之间如何融合注意力信息,而不是在单个注意力层内进行优化。主要有两种实现方式:
- 层级MHA(Hierarchical MHA):每一层的注意力结果影响下一层。
- 跨层注意力(Cross-Layer Attention):不同层的信息进行融合。
4.2 MLA 计算方式
- 引入跨层信息:
- 增强跨层表示:
4.3 MLA 的优势
✅ 跨层信息融合,减少每一层的冗余计算。
✅ 提高信息利用率,适合深层Transformer。
5. 总结
机制 | 计算量 | 适用场景 |
---|---|---|
MHA | O(s^2d) | 适用于通用Transformer |
GQA | O(s^2d/g) | 适用于长文本处理 |
MQA | O(sd) | 适用于推理优化 |
MLA | 适中 | 适用于跨层信息融合 |
不同Transformer版本中的计算顺序:
Transformer版本 | 计算顺序 | 存储优化 |
---|---|---|
标准Transformer | 计算Q, K, V → RoPE → 计算注意力 | KV存储完整矩阵,消耗大 |
MQA(多查询注意力) | 降维Key-Value(单KV组) → 计算时恢复 → RoPE | 极致降低KV存储 |
GQA(分组查询注意力) | Query分组,每组共享Key-Value → 计算时恢复 → RoPE | 适中存储消耗 |
MLA(多层注意力) | 低秩存储KV → 计算时恢复 → RoPE | 适用于超长上下文 |
GQA在MQA的基础上,将多个Query分成若干组,每组共享Key-Value,从而在计算量和表达能力之间找到更好的平衡点。由于所有Query共享同一Key-Value,模型的表达能力下降,不能很好地区分不同Query头的语义信息。解决方案:GQA让Query进行分组,不同组共享不同的Key-Value,从而在计算效率和表达能力之间找到平衡。
6. 论文
MHA(多头注意力)最早由Vaswani等人在2017年的论文“Attention Is All You Need”中提出:
Reference: Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., ... & Polosukhin, I. (2017). "Attention is all you need." Advances in neural information processing systems (NeurIPS).
论文链接:Attention Is All You Need
GQA由Ainslie等人在2023年的论文“GQA:Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints”提出:
Reference: Ainslie, Joshua, Santiago Ontañón, Chris Alberti, and Llion Jones. "Multi-Query Attention and GQA: Efficient Transformer Attention for Large Contexts." arXiv preprint arXiv:2305.13245 (2023).
论文链接:GQA:Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
MQA由Shazeer在2019年的论文“Fast Transformer Decoding”提出:
Reference: Shazeer, N. (2019). "Fast Transformer Decoding." arXiv preprint arXiv:1911.02150.
论文链接:Fast Transformer Decoding
MLA由Li等人在2024年的论文“Multi-Layer Attention for Efficient Transformer Models”提出:
Reference: Li, X., Zhou, X., Zhang, T., Wu, Y., Zhang, Y., & Fu, J. (2021). "Multi-Layer Attention for Efficient Transformer Models." arXiv preprint arXiv:2107.02192.
论文链接:Multi-Layer Attention for Efficient Transforer Models