大模型训练ZeRO内存优化原理详解
大模型训练ZeRO内存优化原理详解
0. 引言
Zero Redundancy Optimizer (ZeRO)的主要目标是减少内存使用并加速大规模模型的训练过程。它通过在多个GPU或节点之间分散模型的状态(如梯度和参数)来实现这一目标。这种分散减少了每个计算节点上存储的冗余数据量,从而降低了内存占用。
论文:《ZeRO: Memory Optimizations Toward Training Trillion Parameter Models》
1. GPU 内存分布
1.1 模型状态
模型状态包括:
- 优化器状态(Optimizer States),例如使用Adam优化器时的动量和梯度方差
- 梯度(Gradients)
- 参数(Parameters)
上面的模型状态通常占据了大部分的内存,在混合精度训练中,还需要额外的内存来存储fp32的参数和优化器状态。
比如GPT-2(具有1.5B参数)模型,模型状态的保存要求至少24GB的内存。
1.2 剩余内存
除了模型状态外,剩余的内存包含:
- 激活内存。用于正向传播以执行反向传播的存储,可以通过激活检查点(checkpointing)来减少,但会提升计算量;
- 临时缓冲区。用于存储中间结果,其大小随着模型大小的增加而增加
- 不可用的碎片化内存
以上统称为除保存模型状态之外的剩余内存。
2. ZeRO 优化
2.1 ZeRO-DP 优化
ZeRO-DP(ZeRO数据并行),优化三个阶段的内存消耗情况:
Ψ为模型大小(参数个数),K为优化器状态的内存乘数,Nd为数据并行度,可以理解为GPU卡数。
在本例中,假设基于Adam优化器的混合精度训练,模型大小为7.5B,Nd=64,K=12。
下面分别介绍ZeRO-DP优化的三个阶段的具体情况。
2.1.1 ZeRO-Stage1: 优化器状态划分
(1)Pos(Optimizer State Partitioning,优化器状态划分)
ZeRO通过将优化器状态划分为Nd个数据并行进程,每个进程仅存储、更新其对应分区的优化器状态,即整体优化器状态的1/Nd,从而减少了每个设备上所需的内存量。在每个训练步骤结束时,再收集每一个进程的结果,以获取整体更新后的状态参数。
(2)ZeRO-Stage1内存优化后的结果,主要针对优化器状态(请参考上图):
(2+2)Ψ+K*Ψ/Nd
可见,优化器状态内存在原始基础上有一个Nd的除数。
(3)举例
在7.5B的模型上,标准的情况下要求120GB的内存,但是使用Pos后,Nd=64的情况下,仅要求31.4GB的内存。
而当Nd非常大时,内存消耗:
(2+2)Ψ+K*Ψ/Nd≈4Ψ
与原始的比例:
4/(4+K)
当K=12时,为1/4,即内存是原始的1/4
2.1.2 ZeRO-Stage2: 优化器状态+梯度划分
(1)Pg(Gradient Partitioning,梯度划分)
每个数据并行进程只存储和更新其对应的参数分区所需的梯度,减少了存储全部梯度的内存需求
(2)Pos+g,优化器状态+梯度划分
即在ZeRO-Stage1的Pos基础上,增加了Pg,则是ZeRO-Stage2
(3)ZeRO-Stage2内存优化后的结果,主要针对优化器状态+梯度(请参考上图):
(2+2+K)*Ψ/Nd
(4)举例
在7.5B的模型上,标准的情况下要求120GB的内存,但是使用Pos+g后,Nd=64的情况下,仅要求16.6GB的内存。
而当Nd非常大时,内存消耗:
(2+2+K)*Ψ/Nd≈0
这意味着,理论情况下,当设备足够多时,可以训练任意大的
2.1.3 ZeRO-Stage3: 优化器状态+梯度+参数划分
(1)Pp(Parameter Partitioning,参数划分)
类似于优化器状态和梯度的划分,每个进程只存储其参数分区的参数,在需要时通过广播从其他进程接收非本分区的参数。
(2)Pos+g+p,优化器状态+梯度划分+参数划分
即在ZeRO-Stage2的Pos+g基础上,增加了Pp,则是ZeRO-Stage3
(3)ZeRO-Stage3内存优化后的结果,主要针对优化器状态+梯度+参数(请参考上图):
2Ψ+(2+K)*Ψ/Nd
(4)举例
在7.5B的模型上,标准的情况下要求120GB的内存,但是使用Pos+g+p后,Nd=64的情况下,仅要求1.9GB的内存。
而当Nd非常大时,内存消耗:
2Ψ+(2+K)*Ψ/Nd≈2Ψ
与原始的比例:
2/(4+K)
当K=12时,为1/8,即内存是原始的1/8
2.2 ZeRO-R 优化
2.2.1 减少激活内存
(1)Pa(Partitioned Activation Checkpointing,划分激活检查点)
ZeRO-R通过Pa操作来减少因模型并行化(MP)导致的激活内存冗余。在正向传播过程中,每一层的输入激活被分割并存储在所有模型并行进程中,仅存储分区的激活检查点,而不是复制副本。ZeRO-R使用all-gather操作在反向传播需要时重新生成激活的复制副本。
(2)Pa+cpu
对于非常大的模型,ZeRO-R可以将分割的激活检查点卸载到CPU内存中,几乎将激活内存开销降至零,但是要额外的通信成本。
(3)举例
例如,对于一个100B参数的模型,如果每个Transformer层仅检查点一个激活,那么仅存储激活检查点就需要一个GPU约33GB的内存。但是,使用ZeRO-R中的Pa优化,可以将其降低到每GPU约2GB。此外,这2GB可以卸载到CPU上,将激活的内存占用减少到几乎为零。
2.2.2 管理临时缓冲区
ZeRO-R通过使用固定大小的缓冲区来避免临时缓冲区随着模型大小增加而膨胀,同时确保缓冲区足够大以保持效率。
2.2.3 管理碎片化内存
内存碎片化是由于短期和长期存活内存对象的交错导致的。ZeRO-R执行即时内存碎片整理,通过将激活检查点和梯度移动到预先分配的连续内存缓冲区中,不仅增加了内存的可用性,还通过减少内存分配器寻找连续内存块的时间来提高效率。
3. ZeRO 通讯分析
3.1 ZeRO-DP通讯分析
3.1.1 Pos+g的通讯量
使用梯度分区,每个进程只存储更新其相应参数分区所需的梯度部分。
(1)ZeRO只需要在梯度上进行分散缩减操作,从而产生Ψ的通信量。
(2)在每个进程更新其负责的参数分区后,执行全收集以从所有数据并行进程中收集所有更新的参数。这也会产生Ψ的通信量。
(3)因此,每个训练步骤的总通信量为Ψ+Ψ=2Ψ与标准DP情况完全相同。
3.1.2 Pos+g+p的通讯量
加入Pp参数划分之后,ZeRO-DP的通信量最多增加到标准DP的1.5倍,即3Ψ。这是因为在前向传播和反向传播中,参数需要在进程间进行广播和收集。尽管如此,Pp阶段进一步将内存占用减少,且减少程度与数据并行度Nd成线性关系。
3.2 ZeRO-R通讯分析
Pa在ZeRO-R中的通信开销与传统的MP模型并行方法相比,增加量通常不到10%。
但是由于ZeRO-R还提供了将激活分区卸载到CPU内存的选项Pa+cpu,这可以在保持效率不降低太多的同时,进一步减少GPU上的内存需求。