LLM大模型训练Trick系列之拒绝采样
LLM大模型训练Trick系列之拒绝采样
拒绝采样是一种在LLM大模型训练中常用的技巧,它通过从一个复杂的分布中采样数据,帮助模型生成更高质量的输出。本文将详细介绍拒绝采样的原理、应用及其在LLM训练中的重要作用。
背景介绍
拒绝采样是一种蒙特卡洛算法,用于借助代理分布从一个复杂的(“难以采样的”)分布中采样数据。蒙特卡洛方法的核心思想是,如果不能直接从目标分布函数中采样,那么可以使用另一个分布函数(提议函数)来近似。
拒绝采样的基本步骤如下:
- 从提议分布q(x)中采样一个样本x。
- 从均匀分布U(0, Mq(x))中采样一个值y。
- 如果y < p(x),则接受x作为目标分布p(x)的一个样本;否则,重复上述过程。
这种方法通过均匀分布将Mq(x)提供的“封包”缩放到p(x)的概率密度函数,确保采样结果遵循目标分布。
在生成模型的背景下,拒绝采样通常是在一个微调过的模型基础上进行K个样本采样。然后使用一个拒绝或接受函数来过滤筛选出符合目标分布的样本,再进行模型微调。
相关研究
WebGPT: Browser-assisted question-answering with human feedback
WebGPT在推理阶段使用拒绝采样,但并未将其用于微调。研究发现,拒绝采样在某些情况下比强化学习(RL)表现更好,原因可能包括:
- 拒绝采样可以利用更多的推理计算资源。
- 在不可预测的环境中,模型可以尝试访问更多网站并评估信息。
- 奖励模型主要基于BC和拒绝采样策略训练的数据,可能使其对拒绝采样的过度优化更具鲁棒性。
Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback
在这个研究中,拒绝采样被用于生成样本,其中k值(采样数量)是一个参数,通常设置为16。研究发现,随着k值的增加,效果会更好,但在线RLHF模型的表现似乎优于拒绝采样。
Aligning Large Language Models through Synthetic Feedback
该研究使用拒绝采样得到的数据进行微调,通过ICL生成不同级别模型对prompt的response。使用合成的RM模型对生成的输出进行评分,选择最佳响应作为最终响应,即RM执行拒绝采样(最佳N采样)。
Llama 2: Open Foundation and Fine-Tuned Chat Models
Llama 2使用拒绝采样进行微调,与PPO(一种流行的on-policy RL算法)相比,主要区别在于:
- 广度:在拒绝采样中,模型为给定prompt探索K个样本,而PPO只生成一个样本。
- 深度:在PPO中,训练步骤t的样本是上一步梯度更新后模型策略的函数;而在拒绝采样微调中,所有输出都是基于初始策略采样的,然后进行类似SFT的微调。
Llama 2还使用RM模型进行拒绝采样生成的样本进行SFT训练,并更新策略模型的梯度。同时,他们还将拒绝采样生成的样本作为gold,在旧的checkpoint上重新训练RM模型,以加强RM模型的奖励。
SCALING RELATIONSHIP ON LEARNING MATHEMATICAL REASONING WITH LARGE LANGUAGE MODELS
该研究提出使用拒绝采样微调(RFT)来增强数据样本,通过监督模型生成和收集正确的推理路径作为增强微调数据集。实验表明,RFT可以显著提高数学推理性能,特别是在表现较差的LLM上。通过结合多个模型的拒绝样本,LLaMA-7B的准确率从35.9%提升到49.3%。
RAFT: Reward rAnked FineTuning for Generative Foundation Model Alignment
RAFT提出了一种新的框架,利用奖励模型和足够数量的样本,选择高质量样本并丢弃表现出不良行为的样本,从而构建一个流式数据集。这种方法在大语言模型和扩散模型中都表现出强大的性能。
总结与思考
拒绝采样通过拒绝/接受函数(可以是奖励模型或启发式规则)筛选SFT模型的输出结果分布,提高了最终返回的效果。在RLHF框架中,拒绝采样微调可以用于更新SFT模型的效果,对于PPO算法来说也很重要。同时,拒绝采样提供了更多的推理路径供模型学习,这对于模型的COT能力提升非常重要。