问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

大模型推理优化之 KV Cache

创作时间:
作者:
@小白创作中心

大模型推理优化之 KV Cache

引用
CSDN
1.
https://blog.csdn.net/javastart/article/details/138176497

在语言模型推理的过程中,性能优化一直是一个备受关注的话题。LLM(Large Language Models)的出现使得自然语言处理取得了显著的进展,但随之而来的是庞大的模型和复杂的计算过程,因此推理效率的提升变得至关重要。在这个背景下,KV Cache(键-值缓存)成为了一项被广泛应用的推理优化技术。

KV Cache 定义

KV Cache,即键-值缓存,是一种用于存储键值对数据的缓存机制。在语言模型的推理过程中,经常需要多次访问相同的数据,而KV Cache通过将这些数据缓存到内存中,提供了快速的数据访问速度,从而加速推理过程。该技术仅应用于解码阶段。如 decode only 模型(如 GPT3、Llama 等)、encode-decode 模型(如 T5)的 decode 阶段,像 Bert 等非生成式模型并不适用。

KV Cache 原理

推理过程:给定一个问题,模型会输出一个回答。生成回答的过程每次只生成一个 token,输出的 token会和问题拼接在一起,再次作为输入传给模型,这样不断重复直至生成终止符停止。

GPT-4推理过程图

下图是Scaled dot-product attention 有无 KV Cache 优化计算过程的比较。一般情况下,在每个生成步骤中,都会重新计算之前token的注意力,而实际上我们只想计算新 token 的注意力。而采用 KV Cache 方法后,会把之前 Token的 KV 值存下来,新 token 预测时只需要从缓存中读取结果就可以了。


Scaled dot-product attention 有无 KV Cache 比较,图片来源:https://medium.com/@joaolages/kv-caching-explained-276520203249

KV Cache 实现

huggingface的 transformer库已经实现了 KV cache,在推理时新增了past_key_values,设置 use_cache=True 或 config.use_cache=True 就可以了。

past_key_values(
Cache
or
tuple(tuple(torch.FloatTensor))
,optional) — Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used to speed up sequential decoding. This typically consists in the
past_key_values
returned by the model at a previous stage of decoding, when
use_cache=True
or
config.use_cache=True
.

import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "Llama-2-7b-chat-hf"
device = "cuda:7" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map=device)

for use_cache in (True, False):
    times = []
    for _ in range(10):  # measuring 10 generations
        start = time.time()
        input = tokenizer("What is KV caching?", return_tensors="pt").to(device)
        outputs = model.generate(**input, use_cache=use_cache, max_new_tokens=1000, temperature=0.00001)
        times.append(time.time() - start)
    print(f"{'With' if use_cache else 'Without'} KV caching: {round(np.mean(times), 3)} +- {round(np.std(times), 3)} seconds")

执行结果如下所示:

With KV caching: 8.946 +- 0.011 seconds
Without KV caching: 58.68 +- 0.012 seconds

从结果可以看出使用 KV Cache 方法进行大模型推理,推理速度增加了6.56倍,差异巨大。

KV Cache 显存占用分析

假设输入的序列长度是 𝑚,输出序列长度是 𝑛 , 𝑏 为数据批次大小, 𝑙 为层数, ℎ 为隐向量维度,以 FP16(2bytes) 来保存,那么 KV Cache的峰值显存占用大小为 𝑏(𝑚+𝑛)ℎ∗𝑙∗2∗2=4𝑏𝑙ℎ(𝑚+𝑛) ,第一个 2 代表 K、V,第二个 2 代表 2bytes。可见随着批次大小和长度的增加,KV Cache 的显存占用也会快速增大。

KV Cache 优化方法

这里主要介绍下 Multi Query Attention 和 Grouped Query Attention。

Multi Query Attention

Multi-query attention is identical except that the different heads share a single set of keys and values.
MQA 和 MHA的区别是:每个头共享相同的 K、V 权重而不共享Q的权重。

Grouped Query Attention

Grouped-query attention divides query heads intoG groups, each of which shares a single key head and value head.
分组注意力将查询头分为 G 组,每组共享一个键头和值头。GQA-G 是指有 G 组的分组查询。GQA-1,有一个组,因此有一个键头和值头,等同于 MQA,而 GQA-H,组数等于头数,等同于 MHA。

MHA、MQA、GQA比较可参考下图。

Multi Head Attention、Multi Query Attention、Grouped Query Attention 比较

使用MHA、MQA、GQA进行KV Cache 显存占用情况比较

MHA: 𝑏(𝑚+𝑛)ℎ∗𝑙∗2∗2=4𝑏𝑙ℎ(𝑚+𝑛) ;
MQA: 4𝑏𝑙ℎ(𝑚+𝑛)/𝐻 , 𝐻 代表头数;
GQA: 4𝑏𝑙ℎ(𝑚+𝑛)∗𝐺/𝐻 , 𝐻 代表头数, 𝐺 代表分组数;

MQA、GQA Huggingface 库都有实现,具体见llm_tutorial_optimization。

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号