问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

最透彻的大模型PPO原理和源码解读

创作时间:
作者:
@小白创作中心

最透彻的大模型PPO原理和源码解读

引用
1
来源
1.
https://www.wehelpwin.com/article/5152

本文深入解析了大模型PPO(Proximal Policy Optimization)在RLHF(Reinforcement Learning from Human Feedback)中的原理和源码实现。从强化学习的基础概念到NLP中的具体应用,再到详细的loss计算过程,通过直观的解释和代码实践,帮助读者全面理解这一复杂算法在实际中的应用。

一、强化学习概述

1.1 强化学习整体流程

强化学习主要涉及两个实体:智能体(Agent)与环境(Environment)。它们之间的交互包括:

  • 状态空间S(State):环境中所有可能状态的集合
  • 动作空间A(Action):智能体所有可能动作的集合
  • 奖励R(Reward):智能体在环境的某一状态下所获得的奖励

以图为例,智能体与环境的交互过程如下:

  1. 在时刻,环境的状态为,达到这一状态所获得的奖励为
  2. 智能体观测到与,采取相应动作
  3. 智能体采取后,环境状态变为,得到相应的奖励

智能体的目标是找到一个策略,根据当前观测到的环境状态和奖励反馈,选择最佳的动作。

1.2 价值函数

在强化学习中,奖励值表示环境进入状态下的即时奖励。为了考虑未来收益的影响,引入了总收益的概念:

[ V(s) = R(s) + \gamma V(s') ]

其中:

  • ( V(s) ):时刻的总收益
  • ( R(s) ):时刻的即时收益
  • ( \gamma ):折扣因子,决定了未来收益的权重

二、NLP中的强化学习

在NLP任务中,我们需要将强化学习框架映射到语言生成任务中。具体来说:

  • 智能体(Agent):语言模型本身
  • 环境(Environment):语言模型生成的语料
  • 状态(State):当前的上下文
  • 动作(Action):生成的token
  • 奖励(Reward):对生成token的即时反馈

三、RLHF中的四个重要角色

在RLHF中,主要涉及四个模型:

  • Actor Model(演员模型):目标语言模型,需要训练
  • Critic Model(评论家模型):预测总收益,需要训练
  • Reward Model(奖励模型):计算即时收益,参数冻结
  • Reference Model(参考模型):防止模型训歪,参数冻结

3.1 Actor Model (演员模型)

Actor模型初始化为SFT模型,用于生成符合人类喜好的response。

3.2 Reference Model(参考模型)

Ref模型也用SFT模型初始化,通过KL散度防止Actor模型训歪。

3.3 Critic Model(评论家模型)

Critic模型用于预测期望总收益,从RW阶段的Reward Model初始化。

3.4 Reward Model(奖励模型)

Reward模型用于计算生成token的即时收益,参数冻结。

四、RLHF中的loss计算

4.1 Actor loss

(1)直观设计

[ \text{actor_loss} = -\log(\pi(a|s)) \cdot V(s) ]

(2)引入优势(Advantage)

[ A(s, a) = Q(s, a) - V(s) ]

(3)重新设计

[ \text{reward} = \text{kl_ctl} \cdot (\text{ref_log_probs} - \text{log_probs}) + \text{reward_score} ]

(4)重新设计优势

[ A_t = \delta_t + \gamma \lambda A_{t+1} ]

(5)PPO-epoch: 引入新约束

[ \text{actor_loss} = -\min(r_t \cdot A_t, \text{clip}(r_t, 1-\epsilon, 1+\epsilon) \cdot A_t) ]

(6)Actor loss小结

最终的actor_loss综合考虑了即时收益、未来收益、KL散度约束和剪裁机制。

4.2 Critic loss

Critic loss通过MSE损失函数计算实际收益和预估收益的差异,并引入剪裁机制。

总结

本文详细介绍了RLHF-PPO的核心原理和具体实现,通过直观的解释和代码实践,帮助读者深入理解这一复杂算法在实际应用中的具体实现。建议读者在理解理论的基础上,结合源码进行实践,以加深理解。

源码地址:https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号