大模型中 KV Cache 原理及显存占用分析
大模型中 KV Cache 原理及显存占用分析
在大模型推理过程中,KV Cache 是一个关键的技术细节,它不仅优化了计算效率,还显著影响了显存占用。本文将深入探讨 KV Cache 的工作原理及其对显存的影响,以帮助读者更好地理解大模型的推理机制。
Self-Attention 与 KV Cache
如图,当新生成的 token x 进到模型计算 Attention 时,先分别乘上参数矩阵W q W_qWq 、W k W_kWk 、W v W_vWv 得到向量 q,以及矩阵 K、V。然后根据下面公式计算当前 token 跟前面 tokens 的注意力权重(本文为了简化,不考虑多头 MHA)。
自回归生成过程中,K和V矩阵并没有太大变化,比如下图中 cold 这个词对应了 K 的某一列和 V 的某一行,算完就放那里不再变了。
轮到生成 chill 这个词时,其实只需要在原始 K 矩阵追加一列,原始 V 矩阵追加一行,而没必要每生成一个 token 都重新计算一遍 K、V 矩阵,这便是 KV Cache 的意义。
因此在推理的时候,不用每次传入前面全部 token 序列的 embedding,而只需传入 KV Cache 以及当前 token x 的 embedding。Transformer 在算完当前 token x 的 Attention 之后,会把新的 K’ 和 V’ 更新到 GPU 显存中。 左图中 Masked Multi Self Attention 这块也是唯一和前面序列有交互的模块,其他模块(比如 Layer Norm、FFN、位置编码等)都不涉及跟已生成 token 的交互。
KV Cache 显存占用分析
KV Cache 显存计算方式如下:
2 ∗ p r e c i s i o n ∗ n l a y e r ∗ d m o d e l ∗ s e q _ l e n ∗ b a t c h _ s i z e 2 * precision * n_{layer} * d_{model} * seq_len * batch_size2∗precision∗nlayer ∗dmodel ∗seq_len∗batch_size
- 2 22是指 K 跟 V 俩矩阵。
- p r e c i s i o n precisionprecision是模型每个参数的字节数,比如 fp32 精度下每个参数 4 字节。
- n l a y e r n_{layer}nlayer 和n m o d e l n_{model}nmodel 分别是模型 Decoder layer 层数和 embedding 维度大小。
- s e q _ l e n seq_lenseq_len、b a t c h _ s i z e batch_sizebatch_size顾名思义分别是最大序列长度和 global batch size。
比如以 OPT-30B 模型(bf16,48层,7168维,1024上下文,128 batch size)为例,KV Cache 占的显存是:
2 ∗ 2 ∗ 48 ∗ 7168 ∗ 1024 ∗ 128 = 180 , 388 , 626 , 432 b y t e s ≈ 180 G B 224871681024*128 \=180,388,626,432 bytes \≈ 180GB2∗2∗48∗7168∗1024∗128=180,388,626,432bytes≈180GB
模型本身仅占显存:2 ∗ 30 B = 60 B b y t e s = 60 G B 2*30B=60Bbytes=60GB2∗30B=60Bbytes=60GB
光 KV Cache 就顶模型本身占显存的3倍。(当然一般推理时 batch size是1,这时候KV Cache显存占用就砍到 1/128 了,不过 batch 模式能够最大化利用显存,所以这也是为啥各个大模型厂商 batch 模型都比较便宜了)
参考资料:油管《The KV Cache: Memory Usage in Transformers》