DeepSeek中的多头潜在注意力(MLA)浅尝
DeepSeek中的多头潜在注意力(MLA)浅尝
多头潜在注意力(MLA)是DeepSeek团队提出的一种创新的注意力机制,它结合了多头注意力机制和潜在表示学习的优点,有效解决了传统多头注意力机制在计算效率和内存占用方面的局限性。本文将详细介绍MLA的原理和实现方法。
MHA(多头注意力)
在深入探讨MLA之前,我们先回顾一下多头注意力(MHA)的基本原理。MHA通过将输入向量分割成多个并行的注意力“头”,每个头独立地计算注意力权重并产生输出,然后将这些输出通过拼接和线性变换进行合并以生成最终的注意力表示。
具体来说,假设输入的查询(Q)、键(K)和值(V)的形状分别为:
- Q的形状:[seq, di]
- K的形状:[di, seq]
- V的形状:[seq, di]
则每个注意力头的计算过程可以表示为:
[ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) ]
其中,(W_i^Q)、(W_i^K)和(W_i^V)是可学习的权重矩阵。最终的多头注意力输出通过以下方式合并:
[ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h)W^O ]
MHA能够理解输入不同部分之间的关系。然而,这种复杂性是有代价的——对内存带宽的需求很大,尤其是在解码器推理期间。主要问题的关键在于内存开销。在自回归模型中,每个解码步骤都需要加载解码器权重以及所有注意键和值。这个过程不仅计算量大,而且内存带宽也大。随着模型规模的扩大,这种开销也会增加,使得扩展变得越来越艰巨。
MLA(多头潜在注意力)
概念
- 多头注意力机制:Transformer的核心模块,能够通过多个注意力头并行捕捉输入序列中的多样化特征。
- 潜在表示学习:通过将高维输入映射到低维潜在空间,可以提取更抽象的语义特征,同时有效减少计算复杂度。
问题
- 效率问题:传统多头注意力的计算复杂度为(O(n^2d)),即随着序列长度的增长,键值(Key-Value,KV)缓存的大小也会线性增加,这给模型带来了显著的内存负担。
- 表达能力瓶颈:难以充分捕捉复杂全局依赖。
MLA的提出
MLA将多头注意力机制与潜在表示学习相结合,解决MHA在高计算成本和KV缓存方面的局限性。
MLA的具体做法(创新点)
采用低秩联合压缩键值技术,优化了键值(KV)矩阵,显著减少了内存消耗并提高了推理效率。
如上图,在MHA、GQA中大量存在于keys values中的KV缓存——带阴影表示,到了MLA中时,只有一小部分的被压缩Compressed的Latent KV了。
并且,在推理阶段,MHA需要缓存独立的键(Key)和值(Value)矩阵,这会增加内存和计算开销。而MLA通过低秩矩阵分解技术,显著减小了存储的KV(Key-Value)的维度,从而降低了内存占用。
MLA的核心步骤
- 输入映射->潜在空间
给定输入(X \in \mathbb{R}^{n \times d})(其中n是序列长度,d是特征维度),通过映射函数f将其投影到潜在空间:
[ Z = f(X) \in \mathbb{R}^{n \times k}, \quad k \ll d ]
f(·)可为全连接层、卷积层等映射模块,潜在维度k是显著降低计算复杂度的关键。
- 潜在空间中的多头注意力计算
在潜在空间Z上进行多头注意力计算。对于第i个注意力头,其计算公式为:
[ \text{Attention}_i = \text{Softmax}\left(\frac{Q_i \cdot K_i^T}{\sqrt{d_k}}\right)V_i ]
其中:
- (Q_i = ZW_i^Q, K_i = ZW_i^K, V_i = ZW_i^V)分别为查询、键和值;
- (W_i^Q, W_i^K, W_i^V \in R^{k \times d_k})是可学习的投影矩阵;
- (d_k = k / h)是每个注意力头的维度(h是头数)。
将所有注意力头的输出拼接后再通过线性变换:
[ \text{MultiHead}(Z) = \text{Concat}(\text{Attention}_1, \ldots, \text{Attention}_h)W^O ]
其中(W^O \in R^{hd_k \times k})是输出投影矩阵。
- 映射回原始空间
将多头注意力结果从潜在空间映射回原始空间:
[ Y = g(\text{MultiHead}(Z)) \in \mathbb{R}^{n \times d} ]
g(·)为非线性变换,如全连接层。