问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

FlashAttention全解:Transformer模型的高效注意力机制

创作时间:
作者:
@小白创作中心

FlashAttention全解:Transformer模型的高效注意力机制

引用
1
来源
1.
https://www.cnblogs.com/mudou/p/18321760

FlashAttention是一种在GPU上加速注意力机制的方法,通过减少内存读写次数来提高计算效率。它在Transformer模型的训练和推理中发挥着重要作用,使得大语言模型的上下文长度在过去两年中大幅增加。本文将详细介绍FlashAttention算法及其优化版本FlashAttention-2和FlashAttention-3的核心原理和实现细节。

一、FlashAttention

1.1 硬件基础

在讨论FlashAttention之前,我们先了解一下GPU的硬件基础。以A100 80G为例,80G指的是GPU中的HBM存储,其上还有更为快速的SRAM,大小约为20MB。在注意力计算过程中,每一步计算产生的中间结果都需要存储到HBM中,复杂度为O(N^2)。由于SRAM非常有限,无法存储所有数据,因此需要寻找优化方法。

1.2 FlashAttention 核心思想

Flash Attention的核心思想是将计算模块化,将QKV分为若干个模块进行计算,在计算过程中不存储N×N的矩阵,最终只有输出O1涉及存储到HBM中。

1.3 计算前提

在实现FlashAttention时,需要考虑数值稳定性问题。为了避免数值溢出,可以将每个元素都减去最大值。同时,由于softmax操作按行计算,需要维护每行的最大值用于分块计算softmax的纠正偏差。

1.4 FlashAttention 算法

简要描述:外层循环遍历K^T,内层循环遍历Q。通过分块计算和在线softmax技术,减少内存读写次数,提高计算效率。

二、FlashAttention-2

FlashAttention-2相比第一代实现了2倍的速度提升,比PyTorch上的标准注意力操作快5~9倍。它通过减少非矩阵乘法FLOPs来进一步优化性能,特别是在现代GPU上,矩阵乘法的计算速度远高于非矩阵乘法。

2.1 硬件特性

GPU存在大量的线程(被称为kernel)用于执行一个操作。线程被组织为线程块,线程块被调度在 streaming multiprocessors (SMs) 上运行。在每个线程块内部,线程被分组为 warps (包含32个线程的线程组)。warp 内的线程可以通过 fast shuffle instructions 进行通信或协同执行矩阵乘法。线程块内的warps 可以通过对共享内存读写进行通信。

2.2 标准的注意力实现

标准的注意力实现将矩阵S和P存储到HBM,这需要O(N^2)内存。通常N≫d(通常N在1k-8k左右,d在64-128左右)。标准注意实现需要大量的内存访问,导致慢的wall-clock time。

2.3 Flash Attention-1

FlashAttention 应用传统的 tiling 来减少内存IO。通过显著减少内存读取/写入量,FlashAttention 实现了优化基线注意实现的 2-4 倍的 wall-clock 加速。

2.4 FlashAttention-2

FlashAttention-2对FlashAttention的算法进行了调整,以减少 the number of non-matmul FLOPs。这是因为现代 GPUs 有专门的计算单元(例如,Nvidia图形处理器上的 Tensor Cores ),可以使乘法更快。

2.4.1 算法调整

重新回到online softmax trick,做出了两个小的调整以减少非乘法FLOPs。在反向传播过程中, 不再需要保存最大值m(j)和指数和l(j), 只需要存储 对数指数和L(j)=m(j)+log(l(j))。

2.4.2 反向传播

FlashAttention-2的反向传播过程与FlashAttention几乎相同, 做出了一个小的调整, 仅使用每行的对数指数和L代替每行的最大值和每行的指数和。

3. 并行性

4. Warps之间的工作划分

三、FlashAttention-3

FlashAttention-3相比第二代又实现了1.5~2倍的速度提升。FlashAttention-3 将H100的FLOP利用率再次拉到75%。这一版的FlashAttention专攻H100 GPU,只能在H100或H800上运行,不支持其他GPU型号。

四、动画演示

动画演示链接

1. 标准注意力算法--短序列

初始时,S、P存储中间计算结果,O存储输出结果,以及Q、K、V均存储在HBM中。

计算过程包括将Q、K移动到SRAM,计算QK^T,依次计算Q_0和K_0, K_1, K_2, K_3, K_4的乘积,依次向后填充,将计算好的中间结果从SRAM移动到HBM,对SRAM中的中间结果计算softmax,每次计算一行,softmax结果P计算完成后,从SRAM移动到HBM,将SRAM中存储的Q和K清空,并将V从HBM移动到SRAM,进行后续计算,计算O_i = P_iV,并将结果从SRAM移动到HBM,最终将SRAM置空。

2. 标准注意力算法--长序列

计算过程与短序列一致。

3.标准注意力算法--中等长度序列

计算过程与短序列一致。

4. Flash Attention算法-长序列

初始状态与标准注意力算法类似,但通过分块计算和在线softmax技术,减少了内存读写次数,提高了计算效率。

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号