从头开始复现GRPO【关键模块解析】
从头开始复现GRPO【关键模块解析】
近日,AI工程师和技术作家Andriy Burkov发布了一份「从头开始写GRPO代码」的教程,其中介绍了如何基于Qwen2.5-1.5B-Instruct模型构建一个使用GRPO的分布式强化学习流程。以下是项目对应的GitHub代码仓库:
theLMbook/GRPO_From_Scratch_Multi_GPU_DataParallel_Qwen_2_5_1_5B_Instruct.ipynb at main · aburkov/theLMbook · GitHub
这篇文章主要从数据格式角度讲解其中的关键模块作用,部分内容来自《丁师兄大模型》
爆肝GRPO算法,终于被我从头跑通了!
如有问题,欢迎评论区指正
背景知识:
本文将展示如何使用GRPO方法构建分布式强化学习(RL)流程,从而可以针对数学、逻辑和编程任务对语言模型进行微调,上述任务的特点在于存在一个唯一且正确的ground truth答案,可通过简单的字符串比较轻松加以验证,所以本教程的目标是将通用语言模型Qwen2.5-1.5B-Instruct转换为数学问题求解器。
GRPO相对于普通的PPO创新点在于省去了critic model的构建与训练,大大节省了内存开销,降低了训练难度。具体来说,它将token粒度上的预期收益(依托critic model预测得到)转变为nums_generations个输出对应的rewards的均值。
一般做PPO/DPO/GRPO等强化学习训练,关键点分为三部分:
- 输入输出数据与答案标签处理
- loss设置(PPO/GRPO关键点在于rewards model)
- 调参训练
本文也将按照这个结构来分析
一、输入输出数据与答案标签处理
项目定义了数据格式,以及模型如何从输出和数据集中提取答案段落。为了确保模型输出格式一致,项目还定义了一个系统提示。该提示指示模型生成包含
extract_answer_from_model_output:此函数获取模型的输出文本,并提取
extract_answer_from_dataset:此函数从GSM8K数据集中提取预期答案,该数据集使用“####”分隔符来分隔答案:
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""
def extract_answer_from_model_output(text):
"""
Extracts the value from the last <answer> tag in the text.
Args:
text (str): The model-generated text containing XML-style <answer> tags.
Returns:
str or None: The content inside the <answer> tags, or None if no valid answer is found.
Explanation:
1. Splits the text on the <answer> tag to isolate content after the tag.
2. Checks if at least one <answer> tag exists in the text.
3. For the last <answer> segment:
- Verifies it contains a closing </answer> tag.
- Extracts only the content between the tags.
4. Returns None if the answer is empty (just "...") or if tags are missing.
"""
# Split on <answer> and take everything after the last occurrence
parts = text.split("<answer>")
if len(parts) < 2: # No <answer> tag found
return None
last_part = parts[-1]
# Extract content up to </answer>
if "</answer>" not in last_part:
return None
answer = last_part.split("</answer>")[0].strip()
return None if answer == "..." else answer
def extract_answer_from_dataset(text):
"""
Extracts the answer from the GSM8K dataset examples.
Args:
text (str): The dataset example text containing a question and answer.
Returns:
str or None: The extracted answer part after the '####' delimiter, or None if not found.
Explanation:
1. Checks if the text contains the '####' delimiter that separates question from answer.
2. If found, splits the text at this delimiter and returns the second part (the answer).
3. The answer is stripped of leading/trailing whitespace.
4. Returns None if no delimiter is present.
"""
if "####" not in text:
return None
return text.split("####")[1].strip()
- 加载数据集GSM8K,格式化每个示例,包括系统提示和用户提示。
二、Loss设置
在PPO/GRPO中,loss设置的关键在于advantages,而advantages的关键在于rewards的计算,在GRPO中,rewards的计算与之前不同,不再需要用初始化sft后的模型,依靠构造的打分数据集进行fine-tuning,而是简单的通过答案准确性和格式准确性进行打分
1. 奖励模型构建
correctness_reward:这个函数根据生成的答案是否正确来分配奖励。
采用两种方式:精确的字符串匹配和数值等价检查,将模型输出的答案与预期答案进行比较。
完全匹配会获得更高的奖励(2.0),而基于数值等价的匹配会获得较小的奖励(1.5)
format_reward:这个函数鼓励模型遵循所需的类似XML的输出格式。
它为生成文本中存在
## part of correctness_reward
for r, a in zip(extracted, answer):
if r == a: # Exact match case
rewards.append(2.0)
else:
# Try numeric equivalence
r_num = extract_single_number(str(r))
a_num = extract_single_number(str(a))
if r_num is not None and a_num is not None and r_num == a_num:
rewards.append(1.5)
else:
rewards.append(0.0)
## part of format_reward
for response in responses:
score = 0.0
if "<reasoning>" in response: score += 0.2
if "</reasoning>" in response: score += 0.2
if "<answer>" in response: score += 0.2
if "</answer>" in response: score += 0.2
rewards.append(score)
format_scores.append(score)
## combined reward
combined_rewards = []
for c_score, f_score in zip(correctness_scores, format_scores):
# Correctness score range: 0.0 to 2.0
# Format score range: 0.0 to 0.8
# Total range: 0.0 to 2.8
combined_rewards.append(c_score + f_score)
2. 优势advantages计算
有了每个completion的rewards之后,advantages就很好计算了!
GRPO对应的advantages计算如下:
简单举个例子:
假设batch_size = 2, num_generations = 3,也就是GRPO模型每次产生三个输出
rewards = tensor([[1.0, 2.0, 3.0], # 样本 1 的 3 个奖励值
[4.0, 5.0, 6.0]]) # 样本 2 的 3 个奖励值
rewards.mean(dim=1) = tensor([2.0, 5.0]) # 样本 1 的平均奖励为 2.0,样本 2 的平均奖励为 5.0
mean_rewards = tensor([2.0, 2.0, 2.0, 5.0, 5.0, 5.0])
std_rewards = tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
advantages = tensor([-1.0, 0.0, 1.0, -1.0, 0.0, 1.0])
最后advantages形状为(batch_size * num_generations,1)
3. 根据目标函数构建loss
DeepSeekMath的技术报告里给出了GRPO的目标函数(省略部分符号细节):
所以这个项目中的grpo_loss也是完全参照这个公式进行复现的,函数中的关键部分如下:
其中old_log_probs, ref_log_probs, token_log_probs都是通过compute_log_probs函数计算对数概率得来的,而这一函数的关键参数在于attention_mask(decoder模型生成需要的mask)以及logit_to_keep(指示了最终需要参与计算loss的token位置)
至于为何在softmax之后要接对数概率:
- 对数概率将概率值映射到对数空间,避免了数值下溢问题。
- 对数概率是单调递增的函数,因此比较对数概率的大小等价于比较原始概率的大小。
per_token_loss = surrogate_loss - beta * kl
这里的beta代表对于kl散度参与loss更新的惩罚系数
最后loss在(batch_size * num_generations,)维度上进行求和并平均。
三、调参开训
- num_iterations=1:从当前策略模型创建新参考模型的外部迭代次数。一次迭代是指在整个数据集上执行一次通过。
- num_steps=500:训练循环将执行最多500步,每个步骤处理一批样本。
- num_generations=4:对于训练数据中的每个提示词,训练器将生成4个不同的完成结果。如果你的GPU的VRAM较少,请减少此数字。
- max_completion_length=400:在生成完成结果(序列的response部分)时,生成上限为400个token。
- mu=3:对每个batch数据执行的策略更新次数。这里表示每个batch更新三次策略函数。
- epsilon=0.2:GRPO的PPO组件的clipping参数。这可以防止策略在单次更新中发生太大的变化。
训练后的一些指标如下:
本文原文来自CSDN