FlashAttention:通过优化自注意力机制减少GPU全局内存访问
FlashAttention:通过优化自注意力机制减少GPU全局内存访问
FlashAttention是一种优化自注意力机制的计算方法,通过减少对GPU全局内存(HBM)的访问来提高计算效率。本文将详细介绍NVIDIA GPU内存的整体架构,并深入解析FlashAttention的具体优化技术。
NVIDIA GPU 内存的整体架构
在讨论FlashAttention之前,我们先了解一下NVIDIA GPU内存的整体架构。GPU内存主要分为以下几种类型:
1. 全局内存(Global Memory)
- 位置:位于GPU板卡上的RAM存储芯片上(外部存储)。
- 容量:很大,例如NVIDIA H100有80GB全局内存。
- 访问权限:所有GPU线程都可以访问全局内存。
- 速度:虽然可以达到很高的带宽(例如3.35TB/s),但因为所有线程同时访问时,实际带宽可能会很低。
全局内存适合存放大量数据,但由于每个线程可能都在争抢带宽,访问速度相对较慢。
2. 本地内存(Local Memory)
- 位置:和全局内存一样,也位于GPU板卡的RAM存储芯片上(外部存储)。
- 容量:与全局内存类似,受限于显存大小。
- 访问权限:仅限于当前线程(即一个线程只能访问自己所用的本地内存)。
本地内存通常存储每个线程自己的临时数据,访问速度和全局内存相当。
3. 共享内存(Shared Memory)
- 位置:位于GPU核心内部(在流式多处理器,SM上),所以速度比全局内存快很多。
- 容量:容量很小,比如NVIDIA H100中每个线程块最多只能使用228KB。
- 访问权限:只能在同一个线程块内的线程之间共享使用(一个线程块内的多个线程可以一起访问)。
共享内存的速度非常快,适合线程块内部的并行计算、数据交换等操作,避免频繁访问全局内存。
4. 寄存器内存(Register Memory)
- 位置:位于GPU核心内部,每个流处理器都有自己的寄存器。
- 容量:非常小,每个线程只能拥有极少量的寄存器。
- 访问权限:只能被单个线程访问(寄存器是线程私有的)。
寄存器是GPU中最快的内存类型,存储线程的局部变量和中间计算结果。
5. 常量内存(Constant Memory)
- 位置:在GPU板卡上的RAM存储芯片上(外部存储)。
- 访问权限:所有线程都可以读,但不能写。
- 用途:存放不变的数据(常量),如常量表等。
6. 纹理内存(Texture Memory)
- 位置:在GPU板卡上的RAM存储芯片上(外部存储)。
- 访问权限:所有线程可以访问,但通过特殊的硬件单元(纹理单元)进行高效的读操作。
- 用途:用于存储纹理数据,常在图形处理和机器学习中使用。
整体结构
- 高带宽显存(HBM):全局内存和本地内存都位于HBM上,容量较大,但速度相对较慢。
- 芯片内存:共享内存和寄存器内存位于GPU核心内部,虽然容量较小,但速度极快。
总结
- 全局内存和本地内存:容量大,速度较慢,位于外部显存。
- 共享内存和寄存器内存:容量小,速度快,位于GPU内部。
- 共享内存:线程块内的线程可以共享访问,适合线程间的数据交换。
- 寄存器:只能被单个线程访问,速度最快,适合存放临时数据。
图1:NVIDIA GPU的整体内存结构图
FlashAttention的工作原理
那么FlashAttention如何通过优化自注意力机制的计算过程,来减少对GPU全局内存(HBM)的访问,从而提高计算效率?传统自注意力机制在计算时引入了两个中间矩阵S和P,这导致大量的全局内存访问,而FlashAttention通过分块计算和使用共享存储(SRAM)来减少这种瓶颈。
1. 传统Attention机制的问题
在自注意力机制中,计算时涉及以下步骤:
- S = Q × K:Q(查询矩阵)和K(键矩阵)的乘积得到注意力分数矩阵S。
- P = Softmax(S):对矩阵S做Softmax归一化,得到矩阵P。
- O = P × V:再将P和V(值矩阵)相乘,得到最终的输出O。
这些步骤中,传统的实现会将中间矩阵S和P写入到全局内存中。由于GPU的全局内存带宽有限,访问全局内存的速度比计算速度慢很多,这会成为计算效率的瓶颈。此外,矩阵S和P的大小与输入序列长度的平方成正比,当输入序列长度增加时,这些中间矩阵会占用大量的显存。
2. FlashAttention的改进
FlashAttention的目标是:
- 减少全局内存的访问:尽量避免频繁读写中间矩阵S和P。
- 利用共享存储:更多地利用GPU芯片内的共享存储(SRAM),以提高计算速度。
FlashAttention关键技术:
- 分块处理输入:将输入矩阵分成多个小块,避免一次性处理整个输入,减少对全局内存的依赖。
- 增量计算Softmax:FlashAttention通过在每个输入块上逐步传递,增量式地进行Softmax计算,而不是先计算出整个矩阵S后再做Softmax。这种方式避免了将S和P存储在全局内存中。
- 存储归一化因子:在前向传播中,不保存整个注意力矩阵P,而是只保存Softmax的归一化因子,减少全局内存的消耗。在反向传播时,通过这些归一化因子快速重算注意力分数,而不需要从全局内存中读取P。
3. 优势
- 减少全局内存访问量:通过不存储中间矩阵S和P,FlashAttention极大减少了全局内存的访问。即便重新计算注意力矩阵会增加一些计算量(FLOPs),但由于减少了慢速全局内存的访问,整体速度反而更快。
- 节省显存:由于不再存储中间的大型矩阵S和P,内存占用大大减少,尤其对于长序列的输入更加明显。
图2:FlashAttention计算流程图
4. PyTorch 2.0中的支持
在PyTorch 2.0中,可以通过以下方式开启或关闭FlashAttention的功能:
torch.backends.cuda.enable_flash_sdp()
这为用户提供了直接使用这一高效优化的注意力机制的接口。
总结
FlashAttention通过在计算过程中尽可能利用GPU内部的共享存储来减少全局内存的访问,避免存储大规模的中间矩阵S和P,从而加快计算速度并减少显存消耗。这种改进方式特别适合处理长序列输入的自注意力机制。