DeepGEMM如何实现FP8通用矩阵乘法
DeepGEMM如何实现FP8通用矩阵乘法
DeepGEMM是一种细粒度混合精度训练框架,通过创新的FP8通用矩阵乘法(GEMM)实现方式,有效解决了低精度训练中的精度损失问题。本文将详细介绍DeepGEMM的核心技术原理及其在实际应用中的效果。
(还是从DeepSeek-V3技术报告那窜出来的一个小课题)在技术报告中作者提出一种细粒度混合精度的训练框架,来看看它是怎么高效实现FP8通用矩阵乘法的。链接:技术报告;代码库(ง •_•)ง
虽说低精度训练很有前景,但对激活值、权重、梯度中异常值很敏感。虽说在推理过程进行量化的研究已有不少成果,但将低精度技术成功应用到LLM预训练过程中的研究还是少。
为了解决这些挑战并对FP8数据格式的动态范围进行有效扩展,作者推出一个细粒度量化策略:分成
的小块(tile-wise grouping)或
的完整方块(block-wise grouping)。
作者通过引入increased-precision accumulation策略,用更高精度的数据格式来计算和累积模型参数的更新,缓解去量化的开销,这对于达到精确的FP8通用矩阵乘法(GEMM)至关重要。
为了进一步减少MoE训练过程中内存和通信开销,作者以FP8的数据格式缓存和分发激活值,以FP16存储低精度的优化器状态。
作者在DeepSeek-V2-Lite和DeepSeek-V2这两尺寸相近的模型上对这个框架进行了验证,训练了将近1万亿token。结果就是,相较于采用BF16的基线模型,这里的FP8训练模型的相对误差始终保持在0.25%以下,在训练随机性的可接受范围内。
混合精度框架
图1. NVIDIA Tensor Cores支持的所有数据格式(这张图是2025年3月问GPT得到的结果,不能展现最新发展)
用于FP8混合精度训练框架整体如下图,简化起见只展示线性操作,其中多数计算密集型操作采用FP8,然后战略性地将少数关键操作保持在原始数据格式,以平衡训练效率和数值稳定性。
图2. 混合精度训练框架
图中3个和线性操作相关的GEMM操作分别是Fprop(前向)、Dgrad(激活值反向)、Wgrad(权重反向),都接收FP8的张量作为输入,产生BF16或FP32的输出。这种设计理论上可以将原BF16的方法提速一倍,其中Wgrad GEMM(在反向传播过程中计算权重梯度的矩阵乘法操作)使得激活值能以FP8格式进行存储,然后在反向传播的过程中被直接使用。
尽管种种FP8格式带来的效率优势,一些对低精度敏感的重要操作还是得要更高的精度。另外一些低成本操作也可以以微乎其微的开销采用高精度。经过作者的仔细调研后,会为以下模块保留原始精度:嵌入层模块、输出头、MoE门控模块、归一化算子、注意力算子,以进行稳定的训练。另外某些master权重、权重梯度和优化器状态会采用更高精度来保证数值稳定性。这些高精度组件虽会带来一些内存开销,但在分布式训练系统中,可以通过数据并行进行有效分片来消解。
精度提升
作者提出几个提升低精度训练准确度的策略,这些策略专注于量化方法和矩阵乘过程。
1. 细粒度量化
在低精度训练框架中,FP8格式的指数位少,表示范围受限,容易出现溢出(overflow超过所能表示的最大值变inf)和下溢出(underflow小于所能表示的最小正值变0)的问题。于是常把输入分布对齐到FP8格式的可表示范围内(将输入张量中最大的绝对值缩放到FP8格式能表示的最大值),但当出现异常激活值(极端大或小)时,量化精度将受到极大损害。
PS:何为激活值?前向传播时每一层神经网络的输出,通常由线性变换和非线性激活函数计算得到,用作下一层的输入,影响模型的学习和决策能力,并在反向传播过程中用于计算梯度,调整模型权重。为了节约内存,可以在反向传播时重新计算而不存储(activation checkpointing),或用更低精度进行存储(mixed precision training)。
作者提出一种细粒度量化方法,在更细粒度的层次上应用放缩。对于激活值,将其分成
的小块(tile-wise grouping)(每个token的128个通道共享一个缩放因子,减少误差);对于权重,将其分成
的完整方块(block-wise grouping)(每128个输入通道x128个输出通道共享一个缩放因子,加速推理)。简化起见,下图中只展示了前向传播的Fprop过程,通过更细粒度的划分,对元素更少的组进行缩放,更能容纳异常值,量化误差更小。
图3. Fprop过程中对激活值和权重采用不同的划分方法
PS:虽说小块尺度(tile-wise)量化可以有效缓解异常特征值带来的误差,但在后文也将看到,对于激活值量化就会有不同的分组方式,在前向时
,而反向时
,对于激活值梯度也是同样的处理。一个直接的想法是,干嘛要这样重新划来划去?为啥不直接对每
个元素做大块尺度的量化呢?这样在反向传播时转置一下就行了。于是作者做了个实验,把所有和激活值反向传播的Dgrad过程相关的张量都做大块尺度量化,结果就是,这个计算激活值梯度、以链式法则向浅层反向传播的Dgrad操作,对精度巨敏感!甚至让一个总参数量约16B、训练在300B token上的MoE模型发散了。作者猜测这种敏感性产生的原因在于token间的激活值梯度极不平衡,存在和token强关联的异常值,用大块尺度量化可不能有效管理它们。
标准的FP8 GEMM是不支持沿着累加维度引入按组缩放因子的,但作者通过结合FP32累加策略就能有效实现了。作者还提到,这个细粒度量化策略和microscaling格式的思路高度一致,NVIDIA下一代GPU(Blackwell系列)称其Tensor Cores将以更小的粒度支持microscaling格式,作者希望这里的设计可以作为未来工作的参考,以跟上最新GPU架构的步伐。
2. 高精度累加
低精度的GEMM常面临下溢的问题,其准确性很大程度上依赖于高精度累加(通常是FP32),然而FP8 GEMM在NVIDIA H800 GPU上的累加精度仅限于保留14位左右,明显低于FP32的累加精度,当累加维度K较大时就会出问题,在大规模模型训练中的一个典型场景就是增加批量大小和模型宽度。随机拿两个K=4096的矩阵当例子,在初步测试中,Tensor Cores有限的累加精度导致的最大相对误差接近2%。尽管存在问题,有限的累加精度依旧是少数FP8框架的默认选项,这严重限制了训练精度。
详细了解一下Tensor Cores:
Tensor Cores是NVIDIA GPU体系结构中的专用硬件单元,旨在加速深度学习和高性能计算任务。首先在Volta架构(V100,代表GPU有Tesla V100,支持FP16计算)中引入,随后在Turing(T4、RTX20xx,代表GPU有RTX 2080 Ti、T4,支持INT8/INT4推理,提高AI性能)、Ampere(A100、RTX30xx,代表GPU有A100、RTX 3090,新增TF32和FP64 Tensor Cores HPC)和Hopper(H100,代表GPU H100,提供转为LLM设计的Transformer Engine,支持FP8训练)架构中不断优化,极大提升了矩阵乘法和卷积计算的速度。
特别说明一下,在TF32中,和FP32采用了相同的32位结构,可以兼容FP32代码,但尾数仅10位,提升速度,再除了1位符号、8位指数外,其余13位为扩展的尾数位,允许计算过程中做更多调整,提高吞吐量。
Tensor Cores是专门为矩阵计算设计的硬件单元,用于混合精度计算,能在一次运算中执行多个FMA(Fused Multiply-Add,融合乘加)操作。例如在FP16下,一个Tensor Cores可以在一个时钟周期内执行64次FMA运算(对应4x4x4维度的矩阵乘法),在INT8、TF32、BF16精度下计算吞吐量显著提高。针对矩阵-矩阵运算(如矩阵乘法、卷积),通过Warp-level并行计算提高效率。
Tensor Cores加速GEMM,执行操作D = A x B + C,其中A和B是输入矩阵,C是累积矩阵,用于存储部分计算结果,D是最终输出矩阵。Tensor Cores支持混合精度计算,输入可以使用低精度(如FP16/INT8)来减少存储带宽需求,计算内部使用更高精度(如FP32/TF32)来保持数值稳定性,输出也能是FP32以满足精度要求。
Tensor Cores通过矩阵乘法加速,提高神经网络训练效率,例如ResNet/EfficientNet训练速度提升3-6倍,GPT/BERT/ViT预训练速度提升5-10倍(相比于没有Tensor Cores的GPU)。其在推理过程中使用INT8、INT4精度,降低在ASR、NLP、CV等任务上的计算开销。其也用于科学计算(需要高精度,如FP64 tensor Cores在A100/H100中提升HPC性能),如流体力学模拟、天气预报、药物模拟。
NVIDIA提供cuBLAS(矩阵运算库)和cuDNN(深度学习优化库),自动利用Tensor Cores;TensorRT也可自动将AI模型转换为Tensor Core计算模式,加速推理;在CUDA c++代码中,可以使用WMMA(Warp Matrix Multiply-Accumulate) API调用Tensor Cores。
为了解决这个问题,作者选择调换到CUDA Cores以获取更高的精度。如下图所示,在Tensor Cores上执行MMA(矩阵乘-累加)时,先使用有限位进行低精度累加,每处理完
个元素时,这些部分结果将被复制到CUDA Cores的FP32寄存器中,进行全精度的FP32累加。然后如前文所说,细粒度量化会沿累加维度K应用按组缩放因子,这些缩放因子可以在CUDA Cores上执行高效的乘法运算,以极小的计算成本实现去量化过程。
图4. 每隔一段时间就将Tensor Cores上的中间结果复制到CUDA Cores上
再比对着Tensor Cores详细了解一下CUDA Cores:
和Tensor Cores的专用设计不同,CUDA Cores是NVIDIA的通用计算单元。作为GPU架构的基础,它们可以处理多种计算任务,适用于图形渲染、科学计算和机器学习等领域。CUDA Cores虽也支持并行计算,但在处理特定的矩阵计算任务时,其性能远无法与Tensor Cores相比,而是在执行较复杂的控制流和算法时表现更优,适合处理多样化的应用需求。Tensor Cores主要支持低精度计算,尤其是FP16和TF32格式,而CUDA Cores则会执行更高精度的数值计算。Tensor Cores会通过有效利用高效缓存和数据局部性来减少内存带宽消耗,提高数据访问速度,而CUDA Cores由于其通用性,需要频繁地进行数据传输,可能会增加延迟。Tensor Cores执行低精度矩阵乘法快速降低运算延迟,通过细粒度量化策略来减少计算中的量化误差,在特定情况下,数据可能被实时晋升至CUDA Cores进行高精度累加,以确保计算结果的稳定性和准确性。
值得注意的是,这样调整会降低一个warp组(在GPU中指令是以warp,通常是32个线程为基本单位执行的,多个warp组成一个warpgroup)中WGMMA指令的发送速率。但在NVIDIA H800架构上,通常会同时执行两个WGMMA指令:当一个warpgroup进行promotion(将数据加载到更快的缓存区域)操作时,另一个warpgroup执行MMA(矩阵乘加)操作,让两种操作重叠,高效利用Tensor Cores。
基于实验,设置
个元素(相当于4个WGMMA指令),便是最小的理想间隔,能在不引入大量开销的情况下显著提升精度。
3. 尾数先于指数
先前工作中的混合FP8格式是怎么样的呢?在Fprop操作中采用E4M3格式,也就是4位指数3位尾数,因为前向传播的过程主要是计算激活值,这些值通常不会有太大的范围变化,所以用更多的尾数位来提高精度;在反向传播数据梯度Dgrad和权重梯度Wgrad时采用E5M2格式,也就是5位指数2位尾数,因为梯度的动态范围大(梯度可能很大或很小),需要更多的指数位进行表示。
作者选择将所有的FP8张量统一采用E4M3格式,以追求更高精度,并将其可行性归功于前文说的细粒度量化策略(tile-wise scaling和block-wise scaling):通过在更小的元素组上实时缩放,能更有效地在组内元素间共享指数位,缓解有限动态范围的影响。
4. 在线量化
在对张量进行量化的框架中,延迟量化(delayed quantization)时有应用,在迭代过程中不断保留最大的绝对值,来推断当前值。为了确保缩放的准确性和框架的简单性,作者为每个
的激活值小块或
的权重大块在线计算最大绝对值,在此基础上产生缩放因子,然后实时将激活值或权重量化为FP8格式。
低精度存储和通信
有了上述的FP8训练框架后,作者以更低的精度对激活值和优化器状态进行压缩、缓存,进一步减少内存和通信开销。
低精度优化器状态
作者采用BF16数据格式代替FP32来追踪AdamW优化器的一阶和二阶矩,这样并不会引起明显的性能下降。而由优化器维护的核心模型参数(the master weights stored by the optimizer)和在多个小批量间进行累加的梯度(gradients used for batch size accumulation)还是存成FP32的格式,以确保训练过程中的数值稳定性。
低精度激活值
在图2中,Wgrad和Dgrad操作都采用FP8格式。但实际上为了低成本高精度训练,有几个操作得特殊考虑:
- 注意力操作之后线性层的输入。这些激活值也会用于注意力算子的反向传播,对精度十分敏感。作者为它们专门采用定制的E5M6数据格式,另外在反向传播时中会将这些激活值进行维度转换,从
的量化块变成
的小块(为了匹配计算模式、提高计算并行性、优化内存访问?)。为了避免引入额外的量化误差,所有缩放因子都采用整数幂进行缩放,例如2的幂次,可直接用移位运算实现,避免浮点数误差,提高计算精度。
- MoE架构中SwiGLU操作的输入。为了进步减少内存消耗,作者将SwiGLU操作(Swish-Gated Linear Unit,一种改进版的激活函数,比ReLU更平滑,比普通GLU更高效)的输入进行缓存,在反向传播的过程中重计算输出。这些激活值也会通过细粒度量化方法存储成FP8数据格式,以取得内存效率和计算精度之间的平衡。
低精度通信
通信带宽是MoE模型训练过程中的关键瓶颈,为了缓解这个问题,在MoE上投影前将激活值量化为FP8格式,然后应用调度组件,这些组件负责将量化后的激活值有效分派给相应专家,在设计上考虑了与MoE上投影中的Fprop操作的协调性。与注意力操作后的线性层的输入一样,这里激活值的缩放因子也是2的幂次。类似策略也会在MoE下投影前被应用到激活值的梯度上。对于前向传播和反向传播过程中需要合并的部分,可能是梯度累积、权重更新等关键计算步骤,会继续使用BF16数据格式,以确保训练精度不会受到影响。
写在最后:师兄派任务给我的时候,他比较关注隔段时间将张量从Tensor Cores搬到CUDA Cores上使用FP32全精度这个思路,就像二级累加,有点分级存储那味儿了,他问能不能做到三级实现FP4呢?我比较菜,感觉这要对GPU的核心计算单元做较底层的研究了,有点畏难,想先找个实现了FP8 GEMM的小模型研究一下,佬们可以给我些指点吗?🙇
