问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

大模型推理中的KV Cache原理与显存占用分析

创作时间:
作者:
@小白创作中心

大模型推理中的KV Cache原理与显存占用分析

引用
CSDN
1.
https://blog.csdn.net/muyao987/article/details/140364179

在大模型推理过程中,KV Cache是一个关键的技术优化手段,它通过缓存K和V矩阵来减少重复计算,从而降低显存占用和提高推理效率。本文将详细解析KV Cache的工作原理及其显存占用情况。

Self-Attention 与 KV Cache

在Transformer模型中,Self-Attention机制是核心组成部分。当新生成的token x进入模型计算Attention时,会分别乘上参数矩阵$W_q$、$W_k$、$W_v$得到向量q,以及矩阵K、V。然后根据下面的公式计算当前token与前面tokens的注意力权重(本文为了简化,不考虑多头MHA):

$$
Attention(q, K, V) = softmax(\frac{qK^T}{\sqrt{d_k}})V
$$

在自回归生成过程中,K和V矩阵并没有太大变化。例如,当计算单词"chill"时,实际上只需要在原始K矩阵追加一列,原始V矩阵追加一行,而不需要每生成一个token都重新计算一遍K、V矩阵。这就是KV Cache的意义。

因此,在推理时,我们不需要每次传入前面全部token序列的embedding,而只需传入KV Cache以及当前token x的embedding。Transformer在计算完当前token x的Attention之后,会把新的K'和V'更新到GPU显存中。如下图所示,Masked Multi Self Attention这块是唯一和前面序列有交互的模块,其他模块(比如Layer Norm、FFN、位置编码等)都不涉及与已生成token的交互。

KV Cache 显存占用分析

KV Cache的显存占用计算方式如下:

$$
2 * precision * n_{layer} * d_{model} * seq_len * batch_size
$$

  • 2是指K跟V俩矩阵。
  • $precision$是模型每个参数的字节数,比如fp32精度下每个参数4字节。
  • $n_{layer}$和$d_{model}$分别是模型Decoder layer层数和embedding维度大小。
  • $seq_len$和$batch_size$分别是最大序列长度和global batch size。

以OPT-30B模型为例(bf16,48层,7168维,1024上下文,128 batch size),KV Cache占用的显存是:

$$
2 * 2 * 48 * 7168 * 1024 * 128 = 180,388,626,432 bytes ≈ 180GB
$$

而模型本身仅占显存:

$$
2 * 30B = 60B bytes = 60GB
$$

可以看出,光KV Cache就占了模型本身显存的3倍。当然,一般推理时batch size是1,这时候KV Cache显存占用会大幅减少(约1/128)。但是,batch模式能够最大化利用显存,这也是为什么各个大模型厂商的batch模式推理通常更便宜的原因。

参考资料:油管《The KV Cache: Memory Usage in Transformers》

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号