问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

通过 PyTorch 分布式异步检查点将模型检查点时间缩短 10 倍以上

创作时间:
作者:
@小白创作中心

通过 PyTorch 分布式异步检查点将模型检查点时间缩短 10 倍以上

引用
1
来源
1.
https://pytorch.ac.cn/blog/reducing-checkpointing-times/

PyTorch分布式异步检查点功能的推出,为大模型训练带来了革命性的性能提升。通过将检查点过程从关键路径中分离,该功能能够将模型检查点时间缩短10-20倍,使得7B模型的检查点"停机时间"从平均148.8秒降至6.3秒。这一突破不仅显著减少了训练中断时间,还为更频繁的检查点提供了可能,从而提高了训练的鲁棒性和效率。

模型检查点是大模型训练的重要组成部分,但检查点是一个昂贵的过程,因为每个检查点过程都涉及阻止训练进度,以便保存最新的模型权重。然而,不进行检查点或降低检查点频率可能会导致训练进度的重大损失。例如,死锁、落后者和gpu错误等故障需要重新启动训练过程。为了从故障中重新启动,所有(训练)工作进程都必须停止其训练过程,并从上次保存的检查点重新启动。

因此,对故障的鲁棒性与训练进度之间的内在张力表现为一种权衡,但现在通过异步检查点,PyTorch Distributed能够显着缓解这种张力,并以对整体训练时间的最小影响实现频繁检查点。

作为背景,几乎就在一年前,我们展示了分布式检查点如何将检查点时间从最初的torch.save()功能大幅加快。正如IBM研究所指出的那样,torch.save最多可能需要30分钟才能检查点单个11B模型(PyTorch 1.13)。

随着分布式检查点的进步,对于高达30B的模型大小,检查点可以在4分钟内完成。

借助异步检查点,由于检查点而损失的训练时间现在缩短到30秒以下,通常甚至短至6秒。

需要明确的是,异步检查点并没有像之前的更新所展示的那样压缩实际的序列化检查点时间。相反,它将最终的检查点过程从关键路径(转移到cpu线程)移开,以便GPU训练可以在单独的线程下完成检查点时继续进行。

但是,对于用户而言,效果几乎相同,因为由于检查点造成的训练停机时间大大减少,在许多情况下减少了10倍甚至20倍。

正如上面的加速图表所示,异步检查点在一年前的巨大改进的基础上,又产生了10倍到23倍的进一步改进。

异步检查点如何工作?

异步检查点将检查点过程模块化为两个部分,而不是一个整体过程。第一阶段将每个gpu/rank的数据从GPU复制到CPU。这是用户可见的停机时间,对于7B-13B模型大小,可能需要6-14秒。第二阶段异步地将数据从CPU内存复制到磁盘以持久化检查点。

一旦数据在第一阶段复制到CPU,GPU就可以立即恢复训练。因此,借助异步检查点,检查点的停机时间仅仅是将最新模型状态复制到CPU所需的时间。

在训练恢复的同时,非阻塞CPU线程与内存中新到达的数据协同工作,以完成到磁盘的完整检查点/序列化过程(即持久化保存)。

请注意,PyTorch的分布式检查点器依赖于每个rank元数据的集体通信调用来优化保存,以及一个最终同步,它将检查点标记为完成并使操作具有原子性。如果检查点线程使用与训练相同的进程组,这可能会干扰分布式训练(因为分布式训练也依赖于类似的调用来同步多个GPU上的训练)。

具体而言,调用之间的竞争条件可能会导致训练和异步检查点保存线程同时等待集体调用,从而导致真正的集体挂起。

我们通过为异步检查点初始化单独的进程组来避免这种情况。这会将检查点集体通信分离到它们自己的逻辑进程组中,从而确保它不会干扰主训练线程中的集体调用。

如何在我的训练中使用异步检查点?

异步检查点的使用相对简单。使用最新版本的PyTorch nightly版本,您需要使用nccl和gloo初始化您的进程组。cpu线程部分需要Gloo。

从那里,创建一个异步检查点将使用的重复进程组。然后像往常一样进行训练,但在您想要检查点时,使用异步保存api,传入要保存的状态、检查点ID和检查点进程组。

异步检查点也已在torchtitan中完全实现。在这里,它被实现用于预训练您自己的Llama2或Llama3模型。使用它就像更新toml配置文件一样简单

未来工作

在过去一年中,检查点取得了巨大进步。从近半小时的检查点到分布式检查点的5分钟以下,再到现在的异步检查点的30秒以下。

最后的边界-零开销检查点,其中甚至<30秒也被消除,方法是在反向传播期间流式传输更新的权重,以便在异步检查点开始时检查点数据已在cpu上。

这将有效地将大型模型训练转移到检查点没有中断或停机时间的地方,从而实现更高的鲁棒性(因为可以更频繁地进行检查点)和更快的训练进度,因为检查点没有停机时间。

源代码链接:https://github.com/pytorch/pytorch/blob/main/torch/distributed/checkpoint/state_dict_saver.py

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号