Transformer KV Cache原理深入浅出
Transformer KV Cache原理深入浅出
Transformer KV Cache(键值缓存)是一种应用于仅解码器Transformer中的工程优化技术,通过增加内存利用率来减少计算量。本文将深入浅出地介绍KV Cache的原理,帮助读者理解这一重要优化技术。
1. 前言
Transformer KV Cache(键值缓存)是一种应用于仅解码器Transformer中的工程优化技术,通过增加内存利用率来减少计算量。在讨论KV Cache之前,我想起之前与行业内的大佬交流时提到,大模型时代的到来要求我们放下既有的思维模式,才能更好地接受并理解大模型的理念和实践。对此我部分赞同。从宏观角度看,大模型的表现确实是颠覆性的,难以用传统的认知框架来解释。我曾在《压缩泛化-对大语言模型智能涌现的理解》一文中探讨过,大模型正在做的“压缩即智能”这件事,是相当伟大的创新。然而,从技术方案和技术细节的角度来看,大模型依然在沿用大量已有的方法,并不是凭空产生的。因此,仍然可以从原有的知识体系出发,迁移学习大模型的相关知识。这两种观点并不矛盾,只是视角和立场不同,从整体上看是统一的,因此不必对大模型产生畏难情绪或认知偏差。
其实,从事过隐私计算相关算法实践的人会理解,隐私计算同样需要用新的认知方式去审视,不能完全套用传统的数据安全保护方法,否则难以抓住其核心。然而,当我们从算法实现的细节来看待隐私计算时,会发现它其实将现有的机器学习、分布式计算、密码学等知识在隐私计算的层面实现了统一。
回到本文讨论的KV Cache,实话实说,其技术实现和原理都相对简单,有一定大模型工程优化经验的人一般都能想到这种方法。因此,在学习和理解大模型时,保持平和的心态至关重要。
2. KV Cache算法原理介绍
2.1 知识背景
首先回顾GPT中Transformer的结构,可以回看《GPT系列预训练模型原理讲解》:
- Transformer的输入是一个序列的tokens(或这些tokens的批处理)。
- 在GPT Transformer中,每一块都有一个注意力层和一个前馈层。
- Transformer中的几乎每一层/操作都是基于每个token的,除了注意力层(仅其中的一部分)。
- 注意力层有多个头,通常情况下,d_model = d_head * n_heads。
在生成推理阶段:每次前向传递都会生成一个token,将其附加到输入中,然后将输入(现在的序列长度+1)再次传递回模型进行下一次前向传递。然而,这种朴素的方法显然非常低效,因为:
- 它重新计算了先前的键、值和注意力行。
- 网络中基于每个token的其他部分也会再次被使用,浪费了计算资源。
2.2 自注意力层计算量的二次增长问题分析
首先分析在Transformer模型中的多头注意力(MHA)层的处理过程,假设只处理一个长度为t的序列(即batchsize大小为1):
- 在整个过程中,输入序列中的每个token都由一个稠密向量表示。
- 注意力层的输入是一系列稠密向量,每个输入token对应一个,由前一解码器块生成。
- 对于每个输入向量,注意力层生成一个相同维度的输出稠密向量。
考虑单个注意力头的情况:
- 首先,使用三个不同的投影为每个输入向量生成三个低维稠密向量:查询向量(query)、键向量(key)和值向量(value)。因此会有t个查询向量、t个键向量和t个值向量。
- 对于每个查询向量,生成的输出向量等于所有值向量的线性组合,线性组合的系数为注意力得分。对于每个查询,对应的输出向量是这些值向量的注意力加权平均。注意力得分是通过查询向量与每个键向量的点积计算得到的。通过这种方式,为序列中的每个token生成了一个包含其他tokens信息的上下文表示。
- 在自回归解码的情况下,不能使用所有可能的值向量来为特定查询生成输出表示。实际上,在计算与特定token相关的查询输出时,不能使用序列中后续token的值向量。这种限制通过一种Mask掩码的方式来实现,该技术本质上将后续token的注意力得分设置为零。
- 最后,每个注意力头的输出被连接起来,并通过最后的线性变换得到最终输出。
注意力计算的二次增长
计算注意力得分所需的浮点运算数(FLOPs)。对于给定的注意力头,对于一个批量大小为batch且总长度为t的序列(包括提示和生成的完成序列),注意力得分矩阵通过一个形状为(t, d_head)的查询张量与一个转置的形状为(d_head, t)的键张量相乘而生成。
单次矩阵乘法需要多少FLOPs?形状为(m,n)的矩阵与另一个形状为(n,p)的矩阵乘法大约涉及2mnp次操作,这里加法次数需要注意一下。
在本例子中,单头单序列的注意力得分计算大约需要
次FLOPs。总体而言,注意力得分计算所需的FLOPs为
。显然,计算量随t的二次增长显而易见。
2.3 KV缓存是什么
接下来我们引出本文的焦点:KV Cache。
先来看一下,当一个新token被添加到输入x中时:
- 在q、k、v中各增加一行【4】。
- attention矩阵(att)的尺寸从(T, T)变为(T+1, T+1),即新增了一行和一列。
- 新增的这一行表示新token与所有之前token之间的注意力(attention)。
- 由于att矩阵被掩码为下三角矩阵,新增的这一列除了最后一行的值之外,其余位置都是零。
- att矩阵中的新增行将导致在v中产生一个新的输出,这个输出是att矩阵的最后一行与v的所有列相乘得到的结果,即
att[-1, :] @ v
。
大家是否发现了规律,当新添加一个token后,查询矩阵、键矩阵和值矩阵只是多了一行,而之前的行不受影响。注意力矩阵也只是多了一行,因此,输出也只有一行额外的行。由于自注意力掩码的作用,所有先前的行都不受影响。
【2】中给出了可视化的例子:
以下是将最初的两个token传递给单个注意力头后,查询(Query)、键(Key)、值(Value)、注意力矩阵(Attention Matrix)和输出值(Output Values)的样子:(这里的头维度 d_head = 4)。
现在再添加一个token后:
即使再添加一个新的token,attention weights的前面几行也都不受影响。
因此:
现在可以理解 KV 缓存是如何工作的:对于生成的每个新token,不需要传入整个序列,因此可以避免重新计算整个注意力矩阵。只需要以下面的方式对新token进行操作:
- 仅为新token计算新的 q、k、v 行。
- 新的 q 行将立即被使用。(这也解释了为什么没有查询缓存的原因)
- 将新的键、值条目附加到现有的 K、V 缓存中。
- 通过新的 q 行和 k_cache 的转置进行矩阵向量乘法来计算新的注意力行。
- 通过新的注意力行和 v_cache 的转置进行矩阵向量乘法来计算新的 v 行。
- 输出(仅针对最新标记)被传递到下一层。
这是一种通过增加内存使用来节省重复计算的权衡。
2.4 KV Cache的内存占用分析
既然KV Cache时通过增加内存来降低重复计算的量,那么有必要分析一下内存占用大小,其实后续针对KV Cache又提出了很多节省内存占用的方法,比如【5,6,7】,有兴趣可以看下优化缓存的分析。本文先主要关注内存占用的分析。
假设模型的参数配置信息如下:
- Transformer 中有
n_layers
个层块。 - 每个层块中有一个多头注意力层。
- 每个多头注意力层有
n_heads
个注意力头,每个头的
k
和
v
的尺寸为
d_head
。 - 需要为
K
和
V
都缓存一份。 - 最大上下文长度为
n_context
。 - 精度为
n_bytes
,例如对于 FP32 是 4。 - 推理时的批量大小为
batch_size
。
那么总的内存大小:
kv_cache_size = n_layers * n_heads * 2 * n_context * d_head * n_bytes * batch_size
简化后:
kv_cache_size = 2 * n_bytes * n_layers * d_model * n_context * batch_size
例如,针对 OPT-30B 模型的内存量级计算【4】:
- n_bytes = 2
(FP16) - n_layers = 48
- d_model = 7168
- n_context = 1024
- batch = 128
计算结果为 180,388,626,432 字节,约为 180 GB。
2.5 KV缓存实现线性注意力增长
那么注意力模块的计算量随着缓存的引入发生了什么变化?
转置后的键张量仍然是形状(t, d_head),但查询张量现在的形状是(d_head, 1)。因此,单头单序列的注意力得分计算现在需要
次FLOPs,而总体的注意力计算需要
次FLOPs。注意力计算现在随序列总长度线性增长。
2.6 结论
注意力得分的计算量随着序列总长度呈二次增长。由于注意力计算中的掩码机制,在每次生成步骤中,实际上可以避免为过去的token重新计算键和值向量。每次计算新的键和值向量时,可以将它们缓存到GPU内存中,以便在后续的迭代中重复使用,从而避免重新计算。引入这种优化策略后,注意力机制的FLOPs随总序列长度实现了线性增长。
3. 参考材料
【1】LLM Inference Series: 3. KV caching explained
【2】What is the Transformer KV Cache?
【3】Transformers KV Caching Explained
【4】The KV Cache: Memory Usage in Transformers
【5】LLM profiling guides KV cache optimization
【6】vLLM: Easy, Fast, and Cheap LLM Serving with PagedAttention
【7】PagedAttention