从零实现 REINFORCE/GRPO —— 大模型推理强化微调实践

news/2025/11/17 23:00:51/文章来源:https://www.cnblogs.com/fangpin/p/19234732

一文吃透:不依赖成熟 RL 库,如何实现 REINFORCE、REINFORCE-baseline 与 GRPO;数理推理模型的强化学习微调实践;训练/参考/采样模型的多卡调度。

引言

你是否也遇到过:模型“会思考”,但少数题正确,格式还常常不合规?我在 Qwen/Qwen2.5-Math-1.5B 上亲历这一痛点——zero-shot 在 GSM8K 只有约 1%。本文分享我从零实现的llm-from-scratch 仓库中的 alignment 模块,从零实现 REINFORCE、带基线的 REINFORCE 与 GRPO,把准确率稳定拉升到 63.4%,并把训练策略模型、参考模型与采样模型拆到不同 GPU 上高效协同。

读完你将掌握:奖励设计与计算、方差降低的工程化做法、GRPO 的分组偏好更新与裁剪、以及一套可复现的多卡调度与评估 pipeline。

阅读本文前,强烈建议先阅读我的前一篇介绍 SFT 的文章:从 1.56% 到 62.9%:SFT 推理微调优化实战

如果你对强化学习的知识还不熟悉的话,可以参考我之前的文章:强化学习从入门到放弃 —— 跟着 OpenAI 学强化学习

我的其他系列文章:

  • 从 0 搭建 LLM 不再难!这个 PyTorch 项目帮你吃透大模型底层逻辑
  • 突破性能瓶颈:深入解析 Flash Attention 内核优化技术
  • 深入解析:使用 Triton 实现 Flash Attention2 - 让大模型训练飞起来
  • 手撸大模型的分布式训练:深刻理解大模型训练的“起飞”原理
  • 从0到1:揭秘LLM预训练前的海量数据清洗全流程
  • 从 1.56% 到 62.9%:SFT 推理微调实战

问题与目标

  • 问题:zero-shot 推理准确率极低(~1%),且格式不稳定,难以可靠评估与训练。
  • 目标:设计严格且高召回的奖励函数,配合从零实现的策略梯度与 GRPO,逐步把 Qwen2.5-Math-1.5B 在 GSM8K 的 zero-shot 准确率提升到 63.4%。

我只在必要处简述 SFT,上文已分析过;本文将把重点落在强化学习微调(RLFT)的训练循环与实现细节。


奖励与格式:答案正确 + 格式遵循的双指标

本仓库采用 R1 风格的格式约束与数学习题的严格判定。核心在 alignment/drgrpo_grader.pyr1_zero_reward_fn

  • 格式要求:必须出现 </think> <answer></answer>。不合格式直接奖励为 0。
  • 答案正确:借助多种规范化与符号等价检查(含 LaTeX 解析、Sympy 简化、数值近似),保障较高召回率。
  • 组合奖励:format_rewardanswer_reward 都满足则总奖励 reward=1,否则为 0。评估脚本会统计三者的平均值。

代码摘录(路径与作用标注):

  • 文件:llm-from-scratch/alignment/drgrpo_grader.py
  • 作用:计算 GSM8K 的格式与正确性奖励
  • 完整代码参考:llm-from-scratch
def r1_zero_reward_fn(response, ground_truth, fast=True):# We are strict about format to evaluate our models.if "</think> <answer>" in response and "</answer>" in response:model_answer = response.split("<answer>")[-1].replace("</answer>", "")if "\\boxed" in model_answer:model_answer = extract_answer(model_answer)if model_answer is None:return {"format_reward": 1.0, "answer_reward": 0.0, "reward": 0.0}# 严格的数学等价判断(字符串规范化、Sympy、可选 math_verify)is_correct = grade(model_answer, str(ground_truth), fast)if is_correct:return {"format_reward": 1.0, "answer_reward": 1.0, "reward": 1.0}else:# 格式正确但答案错:不给格式奖励以避免投机return {"format_reward": 1.0, "answer_reward": 0.0, "reward": 0.0}else:# 未按格式输出return {"format_reward": 0.0, "answer_reward": 0.0, "reward": 0.0}

在评估侧,alignment/evaluate.pyevaluate_vllm 会将 avg_format_rewardsavg_answer_rewardsavg_all_rewards 打印出来,其中 avg_all_rewards 近似于最终的准确率。


算法与实现:REINFORCE、REINFORCE-baseline、GRPO

1) 分组优势(baseline)与方差降低

  • 分组思想:针对每道题的一个 prompt,我们采样 group_size 个响应;每组共享同一 ground truth。
  • 优势计算:先把每组的 reward 减去组内均值(可选再除以组内标准差),得到“相对优势”。这就是 REINFORCE-baseline 在本代码中的实现思想。

代码摘录:

  • 文件:llm-from-scratch/alignment/grpo.py
  • 作用:把原始奖励转换为分组归一化优势(baseline)
  • 完整代码参考:llm-from-scratch
def compute_group_normalized_rewards(reward_fn, rollout_responses, repeated_ground_truths,group_size, advantage_eps, normalize_by_std=True,
):# 逐个样本计算原始 rewardrewards = [reward_fn(resp, gt) for resp, gt in zip(rollout_responses, repeated_ground_truths)]raw_rewards = torch.tensor([r["reward"] for r in rewards], dtype=torch.float32)# 折叠成 [n_prompts, group_size]advantages = raw_rewards.view(-1, group_size)# 基线:减去组均值(可选再除以组内 std)mean_advantages = einx.mean("n_prompts group_size -> n_prompts 1", advantages)advantages = advantages - mean_advantagesif normalize_by_std:std_advantages = torch.std(advantages, dim=1, unbiased=True, keepdim=True)advantages = advantages / (std_advantages + advantage_eps)return advantages.view(-1), raw_rewards, {"mean_advantages": mean_advantages}

这里的“分组减均值”就是减少方差的经典做法:在一个小团队中,我们只关注“比团队平均更好/更差”的相对表现,从而让梯度更稳定。

一个生动类比:

  • 想象一队球员在同一场馆、同一光照下投篮,每人投 10 球。当天的“场馆状态”可能会让所有人整体发挥偏高或偏低。如果我们用“每个球员的命中率减去团队平均命中率”来评价个人表现,这样就抵消了当天环境的整体波动。这就是 baseline 的直觉来源。

2) 三种损失的并行实现选择

在本仓库里,三种损失类型通过统一的入口 compute_policy_gradient_loss 分发:

  • no_baseline: 纯 REINFORCE,用原始 reward 直接乘 log_prob
  • reinforce_with_baseline: 带基线的 REINFORCE,用 advantageslog_prob
  • grpo_clip: GRPO 风格裁剪,计算 policy_log_probs - old_log_probs 的比率,并按 cliprange 做截断。

代码摘录:

  • 文件:llm-from-scratch/alignment/grpo.py
  • 作用:三种损失的核心分发逻辑(对应三种算法)
  • 完整代码参考:llm-from-scratch
def compute_policy_gradient_loss(policy_log_probs, loss_type,raw_rewards=None, advantages=None,old_log_probs=None, cliprange=None,
):if loss_type == "no_baseline":# 纯 REINFORCEper_token_loss = compute_naive_policy_gradient_loss(raw_rewards, policy_log_probs)meta = {}elif loss_type == "reinforce_with_baseline":# REINFORCE + baseline(方差更低)per_token_loss = compute_naive_policy_gradient_loss(advantages, policy_log_probs)meta = {}else:  # grpo_clip# GRPO 的裁剪型偏好优化(近似 PPO 风格)assert advantages is not None and old_log_probs is not None and cliprange is not Noneper_token_loss, meta = compute_grpo_clip_loss(advantages, policy_log_probs, old_log_probs, cliprange)return per_token_loss, meta

注意:本实现对 REINFORCE 与带基线的 REINFORCE 以“逐 token 的 log_prob 乘以标量优势/奖励”的统一形式实现;GRPO 则引入参考策略的 old_log_probs 与裁剪,避免策略更新过激。

3) 训练微批与掩码:只在响应段回传梯度

强化学习微调中,我们只希望对模型的“响应段”进行优化,而不是把 prompt 也算进损失。grpo_microbatch_train_step 通过 response_mask 做掩码平均,并在梯度累积场景下自动缩放 loss。

代码摘录:

  • 文件:llm-from-scratch/alignment/train_rl.py
  • 作用:构建只在响应部分为 True 的掩码
  • 完整代码参考:llm-from-scratch
# --- Response Mask for Microbatch ---
response_mask = torch.zeros_like(mb_input_ids, dtype=torch.bool)
for j in range(len(response_mask)):start = mb_prompt_lengths[j].item()# Use attention mask sum for the end to handle padding correctlyend_pos = mb_attention_mask[j].sum().item()response_mask[j, start:end_pos] = Trueresponse_mask &= mb_input_ids != tokenizer.pad_token_id

多卡调度:采样/参考/评估/训练的设备分工与数据流

alignment/train_rl.py 中,我们将四种角色拆到不同设备:

  • vLLM 采样模型:负责 rollout,用 args.sample_device(默认 cuda:7
  • 参考模型(旧策略):冻结、只做 old_log_probs,用 args.reference_model_device(默认 cuda:6
  • 评估模型:周期性评估,args.eval_device(默认 cuda:5
  • 训练政策模型:主力,手动构建 device_map 把层均衡分布到剩余 GPU 上

代码摘录(设备映射的关键片段):

  • 文件:llm-from-scratch/alignment/train_rl.py
  • 作用:为策略模型手工构造跨多 GPU 的平衡 device_map
  • 完整代码参考:llm-from-scratch
def partition_model_across_devices(args) -> dict[str, int]:total_gpu_count = torch.cuda.device_count()# 预留采样/参考/评估三张卡sample_device_idx = int(args.sample_device.split(":")[-1])ref_device_idx    = int(args.reference_model_device.split(":")[-1])eval_device_idx   = int(args.eval_device.split(":")[-1])reserved_indices  = {sample_device_idx, ref_device_idx, eval_device_idx}policy_gpu_indices = sorted(list(set(range(total_gpu_count)) - reserved_indices))# 主 GPU 放嵌入、lm_head、最终 norm;其余均匀分层main_gpu = policy_gpu_indices[0]layer_gpus = policy_gpu_indices[1:] or [main_gpu]num_layers = AutoConfig.from_pretrained(args.model, trust_remote_code=True).num_hidden_layerslayers_per_gpu = math.ceil(num_layers / len(layer_gpus))device_map = {"model.embed_tokens": main_gpu,"lm_head": main_gpu,"model.norm": main_gpu,}gpu_idx_for_layers = 0for i in range(num_layers):if i > 0 and i % layers_per_gpu == 0:gpu_idx_for_layers += 1device_map[f"model.layers.{i}"] = layer_gpus[gpu_idx_for_layers]return device_map

动手复现:数据获取与 uv 入口命令

请在仓库根目录准备数据集(GSM8K 原始 JSONL):

cd dataset
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl
wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl

三步命令跑通评估、SFT 与 RL 微调:

uv run -m alignment.evaluate
uv run -m alignment.sft
uv run -m alignment.train_rl
  • alignment.evaluate 会打印 avg_format_rewards / avg_answer_rewards / avg_all_rewards,其中 avg_all_rewards 即准确率估计。复现实验中,我们从 ~1% 提升到 63.4%。
  • alignment.sft 产出 checkpoints/math_sft,RL 微调默认以此为底座(见 train_rl.py 第 94–125 行)。
  • alignment.train_rl 支持三种算法,可通过 --loss_typeno_baseline | reinforce_with_baseline | grpo_clip 间切换;并可用 --group_size 控制每题采样个数(默认 4),--cliprange 控制 GRPO 裁剪强度(默认 0.2)。

算法对比速查表(本代码库默认与常用超参)

算法 损失入口 是否用 baseline 是否用参考策略 关键超参 适用场景与特点
REINFORCE loss_type=no_baseline 否(直接用 raw reward) rollout_batch_sizetrain_batch_size 实现最简单,但方差较大,稳定性受限
REINFORCE-baseline loss_type=reinforce_with_baseline 是(组内减均值,选配除以 std) group_sizeadvantage_epsuse_std_normalization 方差明显更低,收敛更稳;本库默认只减均值(use_std_normalization=False
GRPO-clip loss_type=grpo_clip 是(同上) 是(old_log_probs cliprange(默认 0.2)、group_size 偏好优化 + 裁剪,抑制过激更新,经验上在数学推理上更稳健

补充常用训练参数(见 alignment/args.py):

  • 采样温度/Top-p:--sampling_temperature=1.0--sampling_top_p=0.9
  • 微批与累积:--train_mini_batch_size=8,按总样本数自动计算 grad_acc_steps
  • 设备分工:--sample_device=cuda:7--reference_model_device=cuda:6--eval_device=cuda:5

结论与行动

  • 关键心得:
    • 奖励要“既严格又宽容”:格式必须满足,正确性用多路等价判断提高召回;
    • REINFORCE-baseline 的“分组减均值”能显著降低方差;GRPO 的裁剪进一步稳定更新;
    • 多卡把采样/参考/评估拆离训练主力卡,避免竞争与干扰;策略模型分层均衡放置,解决大模型训练的内存压力。
  • 结果:在本仓库代码下,我把 Qwen2.5-Math-1.5B 在 GSM8K 的 zero-shot 从约 1% 提升到 63.4%。

现在就试试吧:准备数据、跑三条 uv run 命令,观察 avg_all_rewards 的提升曲线。如果你有更激进或更细致的奖励设计、调度策略,欢迎在 issues 或评论区交流。完整代码参考:llm-from-scratch

一个开放问题:在更大模型或更复杂推理数据集上,GRPO 的分组大小与裁剪范围如何动态自适应?你的经验是什么?我很期待你的分享。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/968385.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java for linux 下载

在 Linux 系统上下载 Java,通常有以下几种常见方式:? 方法一:使用 yum(Debian/Ubuntu)安装 Java sudo apt update sudo apt install openjdk-17-jdkopenjdk-17-jdk 是 Java 17 的 JDK(Java Development Kit)。…

13 个 pytest 宝藏插件推荐!(存存存)

pytest-html 核心:一键生成直观的 HTML 测试报告,便于结果查看与分享。allure-pytest 核心:支持步骤描述、附件、分类等,比pytest-html更高级,让测试结果一目了然。pytest-xdist 核心:支持并行测试,通过将测试用…

java 1.8 linux

如果您是在使用 Java 1.8 在 Linux 系统上进行开发或运行,以下是一些常见的命令和脚本示例,帮助您管理 Java 1.8 环境:一、安装 Java 1.8 1. 使用 yum(适用于基于 RHEL/CentOS/Scientific Linux 的系统) sudo yum…

iOS开发Linux

你提到的“iOS开发Linux”可能是一个误解或混淆。iOS 是苹果公司开发的移动操作系统,主要用于苹果设备(如 iPhone、iPad、Apple Watch 等),而 Linux 是一个开源的操作系统,主要用于服务器、嵌入式系统、桌面计算机…

手撸大模型的分布式训练:深刻理解大模型训练的“起飞”原理

单卡不够?内存爆炸?训练太慢? 在大型语言模型(LLM)的训练过程中,单设备算力和内存往往成为性能瓶颈。如何高效地利用多GPU甚至多节点资源进行分布式训练,是每个LLM研究者和工程师必须面对的挑战。本文将深入剖析…

XHORSE XZBT42EN 2-Button HON.D PCBs for Honda Fit XR-V Jazz City 2018-2022 (5pcs/lot)

Solving Honda Remote Key PCB Issues: The XHORSE XZBT42EN Advantage Is your Honda Fit, XR-V, Jazz, or City struggling with unresponsive remote controls? For European and American automotive repair prof…

事件循环其实很简单!

一、概念 JavaScript 是单线程执行(基于执行栈 / 调用栈 call stack),事件循环负责不断地从各种任务队列里取任务执行——以保证异步任务的函数回调按规则有序运行,浏览器环境和 Node.js 环境都使用事件循环,尽管…

从0到1:揭秘LLM预训练前的海量数据清洗全流程

读完这篇文章,你将用监督微调(SFT)把一个 1.5B 规模的数学模型在 GSM8K 上的零样本推理正确率从 1.56% → 62.9%,同时把输出格式遵循率从 18.9% → 100%。我们将完整走通数据集下载、Prompt 架构、训练配置和评估方…

Upgrade Your Key Programming: New Style CG A22-3+1 Flip-4BTN Wire Remote for CGDI K2 (5pcs/lot)

The Frustration of Unreliable Key Remotes: A Problem for Mechanics and Car Owners Alike In the bustling world of automotive repair, few issues frustrate European and American mechanics more than unreli…

深入解析:使用 Triton 实现 Flash Attention2 - 让大模型训练飞起来

引言 你是否曾经在训练大型语言模型时,眼睁睁地看着 GPU 内存不断飙升,最终因为 OOM(Out of Memory)错误而前功尽弃?或者在处理长序列时,发现注意力机制的计算时间呈平方级增长,让人望而却步? 如果你有过这样的…

AI技术落地实践

好的,这是一个极具前瞻性的问题,充分体现了您对技术趋势的敏锐度。下面我将详细阐述我们在AI技术落地,特别是前端与AI结合方面的完整思考与实践。8. AI技术落地实践 第一部分:SQL编辑器集成LLM的完整实践 1. 技术选…

Day22flex布局

1.felx的组成<!DOCTYPE html> <html lang="en"> <head><meta charset="UTF-8"><meta name="viewport" content="width=device-width, initial-scale=1…

CF2169A题解

贪心传送门:https://codeforces.com/problemset/problem/2169/A 将数组排序,如下情况:\(11\ 12\ 13\ 14\ 14\ 15\),假设 \(a=14\),我们发现我们无论如何选择只能选取 \(a\) 左边或右边的数,又因为平局不算分,贪…

re.compile为什么能提高速度?

re.compile(pattern, flags=0) 的核心作用是 “编译正则表达式模式,生成可重复使用的 Pattern 对象”——本质是把正则字符串“编译”成正则引擎可直接执行的“字节码”,核心价值是 提升重复使用时的效率 + 简化代码…

从 0 搭建 LLM 不再难!这个 PyTorch 项目帮你吃透大模型底层逻辑

如果你曾想深入理解大语言模型(LLM)的 “五脏六腑”,却被框架封装的黑盒接口、复杂的源码结构劝退;如果你希望亲手实现 Transformer 的每一个组件,而非单纯调用transformers库 —— 那么今天推荐的这个开源项目,…

题解:P8819 [CSP-S 2022] 星战

CSP-S 2022 T3 和哈希 trick你说的对,但是, “不可以,总司令!” 这是一个神秘 trick,它的模板题是 P3560,可以先把这个题写了或者先把星战写了再写模板。 题意简述 题目链接 给出 \(n\) 个点 \(m\) 条边的有向图…

instr在mysql索引中作用是什么

在MySQL中,instr函数并不是直接用于创建或管理索引的。然而,instr函数可以用于查询字符串中的子串位置,这在某些情况下可能与索引的使用相关。instr函数用于返回子字符串在字符串中第一次出现的位置。如果子字符串不…

initrans参数在oracle高并发环境下的作用

initrans 参数在 Oracle 数据库中用于设置数据库实例启动时的事务处理并发控制器的初始数量。这个参数对于高并发环境下的数据库性能至关重要,因为它直接影响到数据库能够同时处理的事务数量。在高并发环境下,多个用…

Java集合之【CopyOnWrite和Collections.synchronizedList()的区别】

CopyOnWriteArrayList 介绍 什么是 CopyOnWriteArrayList 适合读多写少的场景 是一个线程安全的List实现,特点是写时复制 当CopyOnWriteArrayList进行修改操作(如add,set,remove)的时候,会复制原数组的值到创建的新…

20232324 2024-2025-1 《网络与系统攻防技术》实验六实验报告

20232324 2024-2025-1 《网络与系统攻防技术》实验六实验报告1.实验内容 1.1靶机探测:主机、端口及漏洞扫描 通过Metasploit的Aux模块中arp_sweep工具完成主机发现;端口扫描可选用nmap工具,或Metasploit的Aux模块中…