问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

FlashAttention:具有IO感知,快速且内存高效的新型注意力算法

创作时间:
作者:
@小白创作中心

FlashAttention:具有IO感知,快速且内存高效的新型注意力算法

引用
1
来源
1.
https://www.high-flyer.cn/en/blog/flash_attn/

Transformer模型的核心是自注意力机制(self attention),其在序列长度上时间和存储的复杂度都在级别。随着大语言模型(LLMs)规模的不断扩大,为LLM配备更长的上下文背景,在工程实现上面临着非常大的挑战。

来自斯坦福大学计算机系与纽约州立大学布法罗分校的科研团队发表了一种新型的注意力算法,名叫FlashAttention,其不仅拥有比PyTorch标准注意力快24倍的运行速度,所需内存还减少了520倍。后续发布的FlashAttention2,Flash Decoding还拥有更夸张的性能加速表现。

幻方基础研究研发的大模型训练工具HAI-LLM全系采用FlashAttention,大幅提高显卡利用率,实现了非常优异的训练表现。本系列文章将为大家深入浅出聊聊FlashAttention背后的技术和我们的实践经验。

论文地址https://arxiv.org/abs/2205.14135

项目源码https://github.com/Dao-AILab/flash-attention

背景

传统的注意力算法其内存效率是的。过去一些优化注意力机制的方法是采用近似值,例如稀疏近似、低秩近似以及它们的组合。尽管这些方法可以将计算降低到线性或接近线性(),但它们过于关注降低每秒所执行的浮点运算次数(FLops),并且倾向于忽略来自内存访问(IO)的开销。

多年来GPU FLOPS的增长速度一直比内存吞吐量(TB/s)的增长更快。我们在A100或同级别显卡上优化模型训练的实践中发现,内存吞吐量才是影响训练进一步提效的重要瓶颈。FLOPS和内存吞吐量需要紧密结合,才能充分提高的训练效率。这就需要我们在软件层面上进行更加细致的设计。

如下图所示:

上图展示了CPU和GPU不同层级内存的吞吐量和容量。可以看到内存不是一个单一的部件,它在本质上是分层的,一般的规则是:内存越快,越昂贵,容量越小。

以A100为例:A100 GPU有40~80GB的高带宽内存(HBM),带宽为1.5-2.0 TB/s,而每108个流处理器有192KB的SRAM,带宽估计在19TB/s左右。可以看到虽然SRAM容量小了很多,但是速度却提升了10倍,所以如何高效的利用SRAM是提速注意力算法的关键。

标准注意力算法

我们首先看看标准注意力算法背后的计算逻辑:

可以看到标准注意力算法基本上将HBM加载/存储操作视为0成本(它并不能感知IO)。

下图展示了GPT-2模型中一个Attention算子的完整计算耗时统计:

可以看到,masking,softmax和dropout操作占用了大量时间,而主要利用FLOPS的矩阵乘法(Matmul)却只占用了一部分时间。因此,感知硬件IO进行优化的FlashAttention算法被提出,其可以大幅减少冗余的HBM IO并充分利用SRAM进行计算加速。

FlashAttention

FlashAttention思路是:既然标准注意力算法要将S写回HBM,而这个步骤只为了重新加载计算Softmax,那么我们可以将其保存在SRAM中,等执行完所有中间步骤后,再将最终结果写回HBM。如下图所示:

可以看到FlashAttention将多个操作融合在一起,其只从HBM加载一次,执行融合的算子操作,然后将结果写回HBM。融合操作主要采用了如下两种技术:

  • Tiling:矩阵分块计算,在不访问整个输入的情况下计算Softmax函数的缩减,在前向和后向传播时都使用;
  • Recomputation:时间换空间,不存储中间注意力矩阵而采用重计算的方式,仅在后向传播时使用。

完整的伪代码如下:

1. Tiling分块计算

对于有限的SRAM容量,的存储用量使得序列长度(N)限定在了一定范围,因此我们要进行矩阵分块计算。对于矩阵乘法与逐点操作(scale,masking,dropout)的分块计算是比较容易实现的,主要障碍是Softmax函数,因为其需要将所有的分数列耦合在一起。为此研究者使用了一个技巧:既然Softmax与注意力K\mathbf{K}K的列是耦合的,通过引入了两个额外的统计量来进行解耦,实现了分块计算。具体如下:

m(x):=max⁡i xi, f(x):=[...], l(x):=∑i f(x)i, softmax(x):=m(x):=\max_i x_i, ~~f(x):=[e^{x_i-m(x)}...e^{x_B-m(x)}], ~~l(x):=\sum_i f(x)_i, ~~softmax(x):=\frac{f(x)}{l(x)}maxi xi , f(x):=[...], l(x):=∑i f(x)i , softmax(x):=

对于两个向量x^{(1)},x^{(2)}\in R^B,解耦拼接向量x=[x(1),x^{(2)}]\in R^{2B}x=[x^{(1)},x^{(2)}]\in R^{2B}x= R2B的Softmax计算:

m(x)=m([x(1),x^{(2)}])=\max(x^{(1)},x^{(2)}), f(x)=[e^{m(x^{(1)})-m(x)}f(x^{(1)}) ~~ e^{m(x^{(2)})-m(x)}f(x^{(2)})]max(x(1),x(2)), f(x)=[f(x(1)) f(x(2))]l(x)=l([x(1),x^{(2)}])=l(x(1)) +l(x(2)), softmax(x)=l(x)=l([x^{(1)},x^{(2)}])=e^{m(x^{(1)})-m(x)}l(x^{(1)}) + e^{m(x^{(2)})-m(x)}l(x^{(2)}), ~~ softmax(x)=\frac{f(x)}{l(x)}l(x(1)) +l(x(2)), softmax(x)=

需要注意的是,可以利用GPU多线程同时并行计算多个块的Softmax。为了充分利用硬件性能,多个块的计算不是串行的,而是并行。

2. 重计算

为了避免产生冗余的HBM读写次数,FlashAttention没有为后向传递保存很大的中间结果矩阵。

在标准注意力实现中,后向传递计算Q,K,V的梯度时,需要用到NxN的中间矩阵S,P,但这两个矩阵并没有保存下来。研究用的技巧是重计算,保存了两个统计量,后向传递时在高速的SRAM上快速地重新计算,通过分块的方式重新计算出注意力矩阵S,P。这种方式比标准方法要快很多。

实验

相比于标准的注意力算法,FlashAttention虽然由于反向传播需要重新计算导致GFLOPs增加,但是FlashAttention有效减少了HBM的I/O,运行时间显著减少,如下图左所示:

同时从上图右也可以看到,随着Block Size增大,HBM的访问次数减少,运行时间也随之减少。当Block Size超过256时,尽管HBM访问次数在减少,但运行时间并没有减少。这时性能受到了其他因素的限制,例如,计算受限。另外需要注意的是,更大的Block Size可能导致执行一次融合算子操作需要的显存超出SRAM的大小。

在A100显卡上进行实验,FlashAttention的加速效果如下图所示:

内存的变化为:

可以看到,在不同的序列长度下组合dropout和masking,都有不同程度的加速效果;随着序列长度的增加,FlashAttention对于内存消耗有着不断优化的效果。

总结

多数大语言模型输入输出的最大序列长度只有2K或4K,本质原因是transformer的核心组件self-attention块的计算复杂度和空间复杂度是的。FlashAttention的成功启发我们,可以通过分块计算、算子融合和重计算技术,实现深度学习模型的优化与加速,这对于AI工业实践走向深水区有很大的借鉴意义。

本文原文来自高飞机器学习平台

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号