问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

通俗易懂的KVcache图解

创作时间:
2025-03-14 08:30:13
作者:
@小白创作中心

通俗易懂的KVcache图解

引用
CSDN
1.
https://m.blog.csdn.net/wlxsp/article/details/143575031

在分享之前先提出三个问题:

  1. 为什么KVCache不保存Q
  2. KVCache如何减少计算量
  3. 为什么模型回答的长度不会影响回答速度?

本文将带着这3个问题来详解KVcache

KVcache是什么

kv cache是指一种用于提升大模型推理性能的技术,通过缓存注意力机制中的键值(Key-Value)对来减少冗余计算,从而提高模型推理的速度。

不懂Self Attention的可以先去看这篇文章:

原因

首先要知道大模型进行推理任务时,是一个token一个token进行输出的。

例:给GPT一个任务 “对这个句子进行扩充:我爱“

GPT的输出为:

我爱
我爱中
我爱中国
我爱中国美
我爱中国美食
我爱中国美食,
我爱中国美食,因
我爱中国美食,因为
我爱中国美食,因为它
我爱中国美食,因为它好
我爱中国美食,因为它好吃
我爱中国美食,因为它好吃。

通过这个例子可以看出它生成句子是按token输出的(为了方便理解,假设一个字为一个token)。输出的token会与输入的tokens 拼接在一起,然后作为下一次推理的输入,这样不断反复直到遇到终止符后结束。自回归任务中,token只能和之前的文本做attention计算。

KVcache图解原理

将这个prompt通过embedding生成QKV三个向量。

“我”只能对自己做attention。得到

“爱”的Q向量对“我”和“爱”的K向量进行计算后再对V进行加权求和算得新向量后输出

输入到模型后得到新的token“中国”

重复上述过程

可以发现 在此过程中,新token只与之前token的KV有关系,和之前的Q没关系,因此可以将之前的KV进行保存,就不用再次计算。这就是KVcache。

问题回答

问题1:为什么不保存Q

因为每次运算只有当前token的Q向量,之前token的Q根本不需要计算,所以缓存Q没意义。

问题2:KVCache如何减少计算量

减少的就是不用重复计算之前token的KV向量,但是每个新词的Attention还得计算。

问题3:每次推理过程的输入tokens都变长了,为什么推理FLOPs不随之增大而是保持恒定呢?

因为使用了KVcache导致第i+1 轮输入数据只比第i轮输入数据新增了一个token,其他全部相同!因此第i+1轮推理时必然包含了第 i 轮的部分计算。

代码实现

这是自己实现的一个简单的多头注意力机制+KVcache

具体如何实现多头注意力机制可以看这篇文章:


import torch
import torch.nn as nn
import math
class MyMultiheadAttentionKV(nn.Module):
    def __init__(self, hidden_dim: int = 1024, heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.heads_num = heads
        self.dropout = nn.Dropout(dropout)
        self.head_dim = hidden_dim // self.heads_num
        self.Wq = nn.Linear(self.hidden_dim, self.hidden_dim)  # (hidden_dim, heads_num * head_dim)
        self.Wk = nn.Linear(self.hidden_dim, self.hidden_dim)  # (hidden_dim, heads_num * head_dim)
        self.Wv = nn.Linear(self.hidden_dim, self.hidden_dim)  # (hidden_dim, heads_num * head_dim)
        self.outputlayer = nn.Linear(self.hidden_dim, self.hidden_dim)
    def forward(self, x, mask=None, key_cache=None, value_cache=None):
        # x = (batch_size, seq_len, hidden_dim)
        query = self.Wq(x)
        key = self.Wk(x)
        value = self.Wv(x)
        bs, seq_len, _ = x.size()
        # Reshape to (batch_size, heads_num, seq_len, head_dim)
        query = query.view(bs, seq_len, self.heads_num, self.head_dim).transpose(1, 2)
        
        key = key.view(bs, seq_len, self.heads_num, self.head_dim).transpose(1, 2)
        value = value.view(bs, seq_len, self.heads_num, self.head_dim).transpose(1, 2)
        # Cache key and value if provided
        if key_cache is not None and value_cache is not None:
            key = torch.cat([key_cache, key], dim=2)  # Append along sequence dimension
            value = torch.cat([value_cache, value], dim=2)
        
        # Update caches
        key_cache = key
        value_cache = value
        # Calculate attention scores
        score = query @ key.transpose(-1, -2) / math.sqrt(self.head_dim)  # (batch_size, heads_num, seq_len, seq_len)
        if mask is not None:
            # Mask size should match the updated sequence length after cache concatenation
            mask = mask[:, :, :key.size(3), :key.size(3)]  # Crop the mask to the new size
        
        score = torch.softmax(score, dim=-1)
        score = self.dropout(score)
        output = score @ value  # (batch_size, heads_num, seq_len, head_dim)
        # Reshape back to (batch_size, seq_len, hidden_dim)
        output = output.transpose(1, 2).contiguous().view(bs, seq_len, -1)
        output = self.outputlayer(output)
        return output, key_cache, value_cache  # Return output and updated caches  

测试代码


def test_kvcache():
    
    torch.manual_seed(42)
   
    batch_size, seq_len, hidden_dim, heads_num = 3000, 100, 128, 8
    x = torch.rand(batch_size, seq_len, hidden_dim)  # Random input data
    attention_mask = torch.randint(0, 2, (batch_size, 1, seq_len, seq_len))  # Attention mask
    
    net = MyMultiheadAttentionKV(hidden_dim, heads_num)
    
    output, key_cache, value_cache = net(x, attention_mask)
    
    new_x = torch.rand(batch_size, seq_len, hidden_dim)  
    output, key_cache, value_cache = net(new_x, attention_mask, key_cache, value_cache)
    third_x = torch.rand(batch_size, seq_len, hidden_dim) 
    output, key_cache, value_cache = net(third_x, attention_mask, key_cache, value_cache)
    
    print(f"Output shape: {output.shape}")
    print(f"Key cache shape: {key_cache.shape}")
    print(f"Value cache shape: {value_cache.shape}")
    
# Run the test
if __name__ == "__main__":
    test_kvcache()  

使用KVcache后:

其实,KV Cache 配置开启后,推理过程可以分为2个阶段:

  1. 预填充阶段:发生在计算第一个输出token过程中,这时Cache是空的,计算时需要为每个 transformer layer 计算并保存key cache和value cache,在输出token时Cache完成填充;FLOPs同KV Cache关闭一致,存在大量gemm操作,推理速度慢。

  2. 使用KV Cache阶段:发生在计算第二个输出token至最后一个token过程中,这时Cache是有值的,每轮推理只需读取Cache,同时将当前轮计算出的新的Key、Value追加写入至Cache;FLOPs降低,gemm变为gemv操作,推理速度相对第一阶段变快,这时属于Memory-bound类型计算。

总结

KV Cache是Transformer推理性能优化的一项重要工程化技术,各大推理框架都已实现并将其进行了封装(例如 transformers库 generate 函数已经将其封装,用户不需要手动传入past_key_values)并默认开启(config.json文件中use_cache=True)。

参考:https://zhuanlan.zhihu.com/p/630832593

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号