大语言模型的持续预训练:如何(重新)预热模型?
大语言模型的持续预训练:如何(重新)预热模型?
大语言模型(LLM)通常会在数十亿个tokens上进行预训练,一旦有新数据可用,就必须重新启动该过程。一个更便宜、更有效的解决方案是启用这些模型的持续预训练,即用新数据更新预训练模型,而不是从头开始重新训练它们。然而,新数据引起的分布变化通常会导致过去数据的性能下降。为了朝着高效的持续预训练迈出一步,本文研究不同预热策略的效果。
研究背景与实验设置
大语言模型(LLM)通常会在数十亿个tokens上进行预训练,一旦有新数据可用,就必须重新启动该过程。一个更便宜、更有效的解决方案是启用这些模型的持续预训练,即用新数据更新预训练模型,而不是从头开始重新训练它们。然而,新数据引起的分布变化通常会导致过去数据的性能下降。为了朝着高效的持续预训练迈出一步,本文研究不同预热策略的效果。
数据集与模型
- 上游(或预训练)数据集:Pile(Gao,2020)
- 下游(或微调)数据集:SlimPajama(Soboleva,2023)
- 模型架构:Pythia 410M
- 优化器:AdamW
- 学习率:{1.5 · 10−4, 3 · 10−4, 6 · 10−4}
- 梯度裁剪:1.0
- 训练精度:半精度(FP16)
持续预热实验
预热时长的影响
在文献中,通常最多用 1% 的数据进行预热 (Zhao,2023)。在这个实验中,研究是否对这个超参敏感。
结果:本实验结果如图所示。它们表明,用于预热学习率的数据大小不会显著影响下游任务(学习)或上游任务(遗忘)的困惑度。这些结果推翻了原来的假设,即用更多 tokens 进行预热可以平滑过渡,线性预热毫无用处。然而,在没有任何渐进式预热的情况下,训练的模型会经历一个初始的“混沌阶段”,导致损失在训练的前几次迭代中激增,这种现象也称为稳定性差距(Lange,2023;Caccia,2022)。
要点:
- 预热阶段的长度似乎对 Pile 和 SlimPajama 验证损失没有显著影响。
最大学习率的影响
重新预热学习率的一个目标是实现计算效率高的持续预训练。学习率太小可能会导致下游数据集的学习效率低下,而学习率太大可能会导致上游数据集的灾难性遗忘。重新预热学习率的一个重要方面是决定将其增加到多高。因此,在这个实验中,改变最大学习率来评估其对性能的影响。
结果:本实验的结果如下三个图所示。在训练结束时,较大的最大学习率会提高下游数据的性能,而会损害上游数据的性能。相反,较小的最大学习率会提高上游数据的性能,同时限制对下游数据的适应性——导致性能下降。改变最大学习率可能是一种权衡下游和上游性能的有效方法。此外的一个普遍趋势:在 SlimPajama 上进行微调会导致模型忘记在 Pile 上学到的内容,从而造成 Pile 验证困惑度增加。最后,从恒定学习率训练的模型采用早期停止(类似于传统的微调)是一种经济的方式,可以适应新的数据分布,同时保持上游数据集的强大性能。
要点:
- 重新预热然后降低学习率似乎是在下游任务上学习良好的必要条件。此外,虽然保持恒定的学习率最初在 Pile 上是有利的,但当在 SlimPajama 上训练足够长的时间时,这种优势就会消失。
- 仅在 SlimPajama 上学习的模型在 SlimPajama 上的表现,比在 Pile 上预训练的模型更差,尽管它仅针对下游任务进行了优化,突出了两个数据集之间的正向迁移。
与从头开始训练的模型比较
在这个实验中,比较微调模型和从头开始训练的模型。
结果:所有经过预热的微调模型都比从头开始训练的模型表现更好。这表明,即使下游数据集与上游数据集规模相同且与上游数据集重叠,微调而不是重新训练也可能提高性能。在 200B 个 tokens 之后,从头开始训练的模型比使用恒定学习率微调的模型表现更好。
在同样数据重新预热
在之前的实验中,对新数据进行微调会导致旧数据的损失快速增加,随后减少。最大学习率越大,增量越大。损失增加的一个假设是,上游和下游数据之间的分布变化会干扰训练过程。为了评估这一假设,在没有分布变化的环境中应用了预热策略。也就是说,通过对 Pile 进行微调来复制实验。
结果:如图表明,在继续对 Pile 进行预训练的同时重新预热学习率,在查看下游验证损失时,与对 SlimPajama 数据重新预热的效果类似。这表明,Pile 和 SlimPajama 之间的分布漂移并不是前面重新预热学习率负面影响的唯一原因,优化的动态性也在损失增加中发挥作用。
要点:
- 重新调整学习率似乎是导致之前开始学习下游任务时出现性能下降的一个重要原因,这一点可以通过在对同一数据集训练时重新预热然后降温学习率来证明。
- 在同一数据集上进行训练时,模型似乎无法从重新预热学习率导致的性能损失中恢复过来。
评估早期检查点
设置:从模型预训练中选择三个检查点来测试预热策略是否受益于从非收敛检查点开始。假设是,选择距收敛较远的检查点可能有利于适应下游任务,因为这些检查点可能位于损失区域中更有利的点。
为了选择明显不同的检查点,将最后一个预训练检查点(即 143,000 次迭代后的 Pythia 410M)与一个更早的检查点进行比较,该检查点的 Pile 验证损失接近之前所有模型达到的最大 Pile 验证损失(∼ 2.5),以及两个其他检查点之间的第三个检查点。
结果:如图提供了 SlimPajama 上验证损失的演变。可以看到,在设置中,选择较早的检查点进行后期微调不会导致下游性能的改善。因此,选择最新的检查点是最佳选择。结论是,预训练不会导致模型失去可塑性,那将使模型难以重新预热。
局部结论:即使下游数据与上游数据的来源相似,在新数据上重新预热预训练模型也是一项艰巨的任务。结果表明,用于预热的tokens数不会显著改变性能,增加最大学习率会提高最终模型的下游性能,而降低最大学习率会提高上游性能,选择较早的检查点会降低上游和下游数据的性能。
要点:
- 在 Pile 上进行预训练时使用较早的检查点并不会导致在 SlimPajama 上学习得更快。