大模型推理中的KV Cache原理详解
大模型推理中的KV Cache原理详解
KV Cache是大模型推理中的一个重要优化技术,通过缓存注意力机制中的键值对,可以减少重复计算,加速模型的推理过程。本文将从基本原理出发,通过矩阵运算的性质,详细解释KV Cache的工作机制,并通过具体的计算步骤和示例图解,展示其在自回归生成过程中的应用。
基本原理
两句话说明 KV Cache 的原理:
- 由于Attention中的mask机制(前面token的Q,看不到后面token的K和V),当前生成的token只依赖于 当前输入的token以及之前所有token的KV矩阵
- 因此 可以通过 缓存 注意力机制中的键值对,在自回归生成过程中减少重复计算,加速大模型的推理过程
矩阵运算
要想理解 KV Cache,需要先了解一下矩阵运算的一个基本性质:分块矩阵计算法则,也就是两个矩阵相乘可以拆分成 列向量与行向量运算的和,在张量并行中也有用到:
$$
\left[\begin{array}{ll}
X_1 & X_2
\end{array}\right] \times\left[\begin{array}{l}
A_1 \
A_2
\end{array}\right]=\left[X_1 A_1+X_2 A_2\right]=Y
$$
在这个性质的基础上,此时如果矩阵X是一个下三角矩阵,就有:
$$
\begin{aligned}
& \left[\begin{array}{cccc}
X_{1,1} & 0 & \cdots & 0 \
X_{2,1} & X_{2,2} & \cdots & 0 \
\vdots & \vdots & \ddots & \vdots \
X_{m, 1} & X_{m, 2} & \cdots & X_{m, n}
\end{array}\right] \cdot\left[\begin{array}{cccc}
Y_{1,1} & Y_{1,2} & \cdots & Y_{1, p} \
Y_{2,1} & Y_{2,2} & \cdots & Y_{2, p} \
\vdots & \vdots & \ddots & \vdots \
Y_{n, 1} & Y_{n, 2} & \cdots & Y_{n, p}
\end{array}\right] \
& =\left[\begin{array}{l}
X_{1,1} \vec{Y}1 \
X{2,1} \vec{Y}1+X{2,2} \vec{Y}2 \
\vdots \
X{m, 1} \vec{Y}1+X{m, 2} \vec{Y}2+\cdots+X{m, n} \vec{Y}_n
\end{array}\right]
\end{aligned}
$$
可以看到,结果矩阵的第k行只用到了矩阵X的 第k个行向量。所以X不需要进行全部的矩阵乘法,每一步只取第k个行向量即可,这就很大程度上减少了计算量,也就是 KV Cache 的数学原理。
如果把矩阵X认为是 Attention 计算得到的 QK score,矩阵Y认为是V,上述计算流程基本上就是 KV Cache 的框架了。这里要注意一点,因为 现在大模型基本都是Decoder-only的架构 ,自回归生成的过程中,当前token在做attention计算时是看不到后面的,后面的内容都被 mask 了。所以 KV Cache 只有 Decoder-only架构才有,Encoder-only 比如 BERT模型 是没有的
KV Cache原理
下面用一个具体的例子来说明,KV Cache假设当前文本为
"我要"
, 要 大模型 输出的内容为
"学习AI"
第一次计算
模型第一次计算,生成
"学"
时的过程如下:
为了方便演示,忽略scale项$\sqrt{d}$,最终Attention的计算公式如下,(softmaxed 表示已经按行进行了softmax):
$$
\begin{aligned}
Att_{\text {step } 1}(Q, K, V)&=\operatorname{softmax}\left(\left[\begin{array}{cc}
Q_1 K_1^T & -\infty \
Q_2 K_1^T & Q_2 K_2^T
\end{array}\right]\right)\left[\begin{array}{l}
\overrightarrow{V_1} \
\overrightarrow{V_2}
\end{array}\right] \
& =\left(\left[\begin{array}{cc}
\operatorname{softmaxed}\left(Q_1 K_1^T\right) & 0 \
\operatorname{softmaxed}\left(Q_2 K_1^T\right) & \operatorname{softmaxed}\left(Q_2 K_2^T\right)
\end{array}\right]\right)\left[\begin{array}{l}
\overrightarrow{V_1} \
\overrightarrow{V_2}
\end{array}\right] \
& =\left(\left[\begin{array}{c}
\operatorname{softmaxed}\left(Q_1 K_1^T\right) \times \overrightarrow{V_1} \
\operatorname{softmaxed}\left(Q_2 K_1^T\right) \times \overrightarrow{V_1}+\operatorname{softmaxed}\left(Q_2 K_2^T\right) \times \overrightarrow{V_2}
\end{array}\right]\right)
\end{aligned}
$$
假设$Att_1$表示Attention结果的第一行,$Att_2$表示Attention结果的第二行,那么有下面的表示:
$$
\begin{aligned}
& Att_1(Q, K, V)=\operatorname{softmaxed}\left(Q_1 K_1^T\right) \vec{V}_1 \
& Att_2(Q, K, V)=\operatorname{softmaxed}\left(Q_2 K_1^T\right) \vec{V}_1+\operatorname{softmaxed}\left(Q_2 K_2^T\right) \vec{V}_2
\end{aligned}
$$
可以看到:
- 在计算$Att_1$时,$Q_1K_2^T$这个score会被mask掉,也就是当前token计算Attention时看不到后面的token信息
- 在计算$Att_2$时,仅仅依赖于$Q_2$以及$K_1^T, K_2^T$和$\vec{V}_1, \vec{V}_2$,与$Q_1$无关
第二次计算
如果没有KV Cache,模型进行第二次前向推理计算,生成
"习"
时,计算过程如下:
此时,Attention的计算公式为:
$$
\begin{aligned}
& \operatorname{Att}_1(Q, K, V)=\operatorname{softmaxed}\left(Q_1 K_1^T\right) \vec{V}_1 \
& \operatorname{Att}_2(Q, K, V)=\operatorname{softmaxed}\left(Q_2 K_1^T\right) \vec{V}_1+\operatorname{softmaxed}\left(Q_2 K_2^T\right) \vec{V}_2 \
& \operatorname{Att}_3(Q, K, V)=\operatorname{softmaxed}\left(Q_3 K_1^T\right) \vec{V}_1+\operatorname{softmaxed}\left(Q_3 K_2^T\right) \vec{V}_2+\operatorname{softmaxed}\left(Q_3 K_3^T\right) \vec{V}_3
\end{aligned}
$$
同样的,$Att_3$的计算结果,之和$Q_3$有关,与$Q_1, Q_2$无关,因此,可以模型的推理可以简化为:
这就是KV Cache的计算过程
第三次计算
同理,生成最后
"AI"
时,使用KV Cache计算过程如下:
$Att_4$的计算公式为:
$$
\begin{aligned}
\operatorname{Att}_4(Q, K, V)&=\operatorname{softmaxed}\left(Q_4 K_1^T\right) \overrightarrow{V_1} +\operatorname{softmaxed}\left(Q_4 K_2^T\right) \overrightarrow{V_2} \
& +\operatorname{softmaxed}\left(Q_4 K_3^T\right) \overrightarrow{V_3} +\operatorname{softmaxed}\left(Q_4 K_4^T\right) \overrightarrow{V_4}
\end{aligned}
$$
因此,可以看出:
- 不使用KV Cache的方法,存在大量冗余的计算,也就是要生成$Att_k$时,还需要重复计算$Att_1, .., Att_{k-1}$
- 计算$Att_k$时,之和$Q_k$有关,与之前的$Q_1, ... , Q_{k-1}$都没关系
- 生成第$x_k$个token时,只需要输入上一轮生成的$x_{k-1}$即可
所以每一步其实只需要根据$Q_k$计算$Att_k$就可以,但是K和V是全程参与计算的。从优化推理速度角度来看,只需要把每一步的K,V缓存起来就可以, 所以叫 KV Cache。
最后需要注意当 sequence 比较长,或者 batch 特别大的时候,KV Cache 其实还是个Memory刺客,所以如何减少 KV 的内存变得尤为重要。
目前各种框架,针对 KV Cache 做了优化,比如 vLLM 的 Page Attention, Prefix Caching,Token 的稀疏化,KV 共享或者压缩(MQA、GQA 和 MLA),LayerSkip,Mooncake 等等,可以说KV Cache 目前是推理的基石,各种基于 KV Cache 的优化方法撑起了大模型推理加速的半壁江山。