基于 Python 的自然语言处理系列(85):PPO 原理与实践

📌 本文介绍如何在 RLHF(Reinforcement Learning with Human Feedback)中使用 PPO(Proximal Policy Optimization)算法对语言模型进行强化学习微调。

🔗 官方文档:trl PPOTrainer

一、引言:PPO 在 RLHF 中的角色

        PPO(Proximal Policy Optimization)是一种常用的强化学习优化算法,它在 RLHF 的第三阶段发挥核心作用:通过人类偏好训练出的奖励模型对语言模型行为进行优化。我们将在本篇中详细介绍如何基于 Hugging Face 的 trl 库,结合 IMDb 数据集、情感分析奖励模型,完成完整的 PPO 训练流程。

二、环境依赖

pip install peft trl accelerate datasets transformers

三、配置 PPOConfig

from trl import PPOConfigppo_config = PPOConfig(model_name="lvwerra/gpt2-imdb",query_dataset="imdb",reward_model="sentiment-analysis:lvwerra/distilbert-imdb",learning_rate=1.41e-5,log_with=None,mini_batch_size=128,batch_size=128,target_kl=6.0,kl_penalty="kl",seed=0,
)

四、构建数据集与 Tokenizer

from datasets import load_dataset
from transformers import AutoTokenizer
from trl.core import LengthSamplerdef build_dataset(config, query_dataset, input_min_text_length=2, input_max_text_length=8):tokenizer = AutoTokenizer.from_pretrained(config.model_name, use_fast=True)tokenizer.pad_token = tokenizer.eos_tokends = load_dataset(query_dataset, split="train")ds = ds.rename_columns({"text": "review"})ds = ds.filter(lambda x: len(x["review"]) > 200)input_size = LengthSampler(input_min_text_length, input_max_text_length)def tokenize(sample):sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]sample["query"] = tokenizer.decode(sample["input_ids"])return sampleds = ds.map(tokenize)ds.set_format(type="torch")return dsdataset = build_dataset(ppo_config, ppo_config.query_dataset)

五、加载模型与参考模型(Ref Model)

from trl import AutoModelForCausalLMWithValueHeadmodel_cls = AutoModelForCausalLMWithValueHead
model = model_cls.from_pretrained(ppo_config.model_name)
ref_model = model_cls.from_pretrained(ppo_config.model_name)tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id

六、构建 PPOTrainer 与奖励模型

from trl import PPOTrainer
from transformers import pipelinedef collator(data):return dict((key, [d[key] for d in data]) for key in data[0])ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)

构建情感奖励模型

task, model_name = ppo_config.reward_model.split(":")
sentiment_pipe = pipeline(task, model=model_name, device=1 if torch.cuda.is_available() else "cpu", return_all_scores=True, function_to_apply="none", batch_size=16
)# 确保 tokenizer 设置 pad_token_id
sentiment_pipe.tokenizer.pad_token_id = tokenizer.pad_token_id
sentiment_pipe.model.config.pad_token_id = tokenizer.pad_token_id

七、执行 PPO 训练循环

 
from tqdm.auto import tqdm
import torchgeneration_kwargs = {"min_length": -1,"top_k": 0.0,"top_p": 1.0,"do_sample": True,"pad_token_id": tokenizer.eos_token_id,"max_new_tokens": 32,
}for step, batch in enumerate(tqdm(ppo_trainer.dataloader)):query_tensors = batch["input_ids"]response_tensors, ref_response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, generate_ref_response=True, **generation_kwargs)batch["response"] = tokenizer.batch_decode(response_tensors)batch["ref_response"] = tokenizer.batch_decode(ref_response_tensors)texts = [q + r for q, r in zip(batch["query"], batch["response"])]rewards = [torch.tensor(output[1]["score"]) for output in sentiment_pipe(texts)]ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]ref_rewards = [torch.tensor(output[1]["score"]) for output in sentiment_pipe(ref_texts)]batch["ref_rewards"] = ref_rewardsstats = ppo_trainer.step(query_tensors, response_tensors, rewards)ppo_trainer.log_stats(stats, batch, rewards, columns_to_log=["query", "response", "ref_response", "ref_rewards"])

八、总结与展望

        在本篇文章中,我们实现了以下核心步骤:

阶段描述
数据构建利用 IMDb 构造简短语料用于语言生成
模型构建加载 GPT2 并构建 Value Head 以评估奖励
奖励模型使用 DistilBERT 进行情感打分作为奖励信号
PPO 训练利用 TRL 中的 PPOTrainer 实现语言强化优化

        PPO 是 RLHF 中至关重要的一环,在人类反馈基础上不断微调模型的输出质量,是当前 ChatGPT、Claude 等大模型背后的关键技术之一。

        📘 下一篇预告:《基于 Python 的自然语言处理系列(86):DPO(Direct Preference Optimization)原理与实战》
        相比传统 RLHF 流程,DPO 提供了一种更简洁、无需奖励模型与 PPO 的替代方案,敬请期待!

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/78308.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

珍爱网:从降本增效到绿色低碳,数字化新基建价值凸显

2024年12月24日,法大大联合企业绿色发展研究院发布《2024签约减碳与低碳办公白皮书》,深入剖析电子签在推动企业绿色低碳转型中的关键作用,为企业实现环境、社会和治理(ESG)目标提供新思路。近期,法大大将陆…

Java实现HTML转PDF(deepSeekAi->html->pdf)

Java实现HTML转PDF,主要为了解决将ai返回的html文本数据转为PDF文件方便用户下载查看。 一、deepSeek-AI提问词 基于以上个人数据。总结个人身体信息,分析个人身体指标信息。再按一个月为维度,详细列举一个月内训练计划,维度详细至每周每天…

Estimands与Intercurrent Events:临床试验与统计学核心框架

1. Estimands(估计目标)概述 1.1 定义与作用 1.1.1 定义 Estimand是临床试验中需明确提出的科学问题,即研究者希望通过数据估计的“目标量”,定义“治疗效应”具体含义,确保分析结果与临床问题一致。 例如,在研究某种新药对高血压患者降压效果时,Estimand可定义为“在…

Jsp技术入门指南【十】IDEA 开发环境下实现 MySQL 数据在 JSP 页面的可视化展示,实现前后端交互

Jsp技术入门指南【十】IDEA 开发环境下实现 MySQL 数据在 JSP 页面的可视化展示,实现前后端交互 前言一、JDBC 核心接口和类:数据库连接的“工具箱”1. 常用的 2 个“关键类”2. 必须掌握的 5 个“核心接口” 二、创建 JDBC 程序的步骤1. 第一步&#xf…

深入理解HotSpot JVM 基本原理

关于JAVA Java编程语言是一种通用的、并发的、面向对象的语言。它的语法类似于C和C++,但它省略了许多使C和C++复杂、混乱和不安全的特性。 Java 是几乎所有类型的网络应用程序的基础,也是开发和提供嵌入式和移动应用程序、游戏、基于 Web 的内容和企业软件的全球标准。. 从…

【HTTP/3:互联网通信的量子飞跃】

HTTP/3:互联网通信的量子飞跃 如果说HTTP/1.1是乡村公路,HTTP/2是现代高速公路系统,那么HTTP/3就像是一种革命性的"传送门"技术,它彻底重写了数据传输的底层规则,让信息几乎可以瞬间抵达目的地,…

Apipost免费版、企业版和私有化部署详解

Apipost是企业级的 API 研发协作一体化平台,为企业提供 API研发测试管理全链路解决方案,不止于API研发场景,增强企业API资产管理。 Apipost 基于同一份数据源,同时提供给后端开发、前端开发、测试人员使用的接口调试、Mock、自动化…

使用若依二次开发商城系统-1:搭建若依运行环境

前言 若依框架有很多版本,这里使用的是springboot3vue3这样的一个前后端分离的版本。 一.操作步骤 1 下载springboot3版本的后端代码 后端springboot3的代码路径,https://gitee.com/y_project/RuoYi-Vue 需要注意我们要的是springboot3分支。 先用g…

速成GO访问sql,个人笔记

更多个人笔记:(仅供参考,非盈利) gitee: https://gitee.com/harryhack/it_note github: https://github.com/ZHLOVEYY/IT_note 本文是基于原生的库 database/sql进行初步学习 基于ORM等更多操作可以关注我…

【C++指南】告别C字符串陷阱:如何实现封装string?

🌟 各位看官好,我是egoist2023! 🌍 种一棵树最好是十年前,其次是现在! 💬 注意:本章节只详讲string中常用接口及实现,有其他需求查阅文档介绍。 🚀 今天通过了…

系统架构师2025年论文《论软件架构评估2》

论软件系统架构评估 v2.0 摘要: 某市医院预约挂号系统建设推广应用项目是我市卫生健康委员会 2019 年发起的一项医疗卫生行业便民惠民信息化项目,目的是实现辖区内患者在辖区各公立医疗机构就诊时,可以通过多种线上渠道进行预约挂号,提升就医体验。我作为系统架构师参与此…

BEVDet4D: Exploit Temporal Cues in Multi-camera 3D Object Detection

背景 对于现有的BEVDet方法,它对于速度的预测误差要高于基于点云的方法,对于像速度这种与时间有关的属性,仅靠单帧数据很难预测好。因此本文提出了BEVDet4D,旨在获取时间维度上的丰富信息。它是在BEVDet的基础上进行拓展,保留了之前帧的BEV特征,并将其进行空间对齐后与当…

el-upload 上传逻辑和ui解耦,上传七牛

解耦的作用在于如果后面要我改成从阿里云oss上传文件,我只需要实现上传逻辑从七牛改成阿里云即可,其他不用动。实现方式有2部分组成,一部分是上传逻辑,一部分是ui。 上传逻辑 大概逻辑就是先去服务端拿上传token和地址&#xff0…

酒水类目电商代运营公司-品融电商:全域策略驱动品牌长效增长

酒水类目电商代运营公司-品融电商:全域策略驱动品牌长效增长 在竞争日益激烈的酒水市场中,品牌如何快速突围并实现长效增长?品融电商凭借「效品合一 全域增长」方法论与全链路运营能力,成为酒水类目代运营的领跑者。从品牌定位、视…

机器学习特征工程中的数值分箱技术:原理、方法与实例解析

标题:机器学习特征工程中的数值分箱技术:原理、方法与实例解析 摘要: 分箱技术作为机器学习特征工程中的关键环节,通过将数值数据划分为离散区间,能够有效提升模型对非线性关系的捕捉能力,同时增强模型对异…

【MySQL专栏】MySQL数据库的复合查询语句

文章目录 1、首先练习MySQL基本语句的练习①查询工资高于500或岗位为MANAGER的雇员,同时还要满足他们的姓名首字母为大写的J②按照部门号升序而雇员的工资降序排序③使用年薪进行降序排序④显示工资最高的员工的名字和工作岗位⑤显示工资高于平均工资的员工信息⑥显…

Python爬虫(5)静态页面抓取实战:requests库请求头配置与反反爬策略详解

目录 一、背景与需求‌二、静态页面抓取的核心流程‌三、requests库基础与请求头配置‌3.1 安装与基本请求3.2 请求头核心参数解析‌3.3 自定义请求头实战 四、实战案例:抓取豆瓣读书Top250‌1. 目标‌2. 代码实现3. 技术要点‌ 五、高阶技巧与反反爬策略‌5.1 动态…

HTML给图片居中

在不同的布局场景下&#xff0c;让 <img> 元素居中的方法有所不同。下面为你介绍几种常见的居中方式 1. 块级元素下的水平居中 如果 <img> 元素是块级元素&#xff08;可以通过 display: block 设置&#xff09;&#xff0c;可以使用 margin: 0 auto 来实现水平居…

【高频考点精讲】前端构建工具对比:Webpack、Vite、Rollup和Parcel

前端构建工具大乱斗:Webpack、Vite、Rollup和Parcel谁是你的菜? 【初级】前端开发工程师面试100题(一) 【初级】前端开发工程师面试100题(二) 【初级】前端开发工程师的面试100题(速记版) 最近在后台收到不少同学提问:“老李啊,现在前端构建工具这么多,我该选哪个?…

赶紧收藏!教您如何用 GitHub 账号,获取永久免费的 Docker 容器!!快速搭建我们的网站/应用!

文章目录 📖 介绍 📖🏡 演示环境 🏡📒 永久免费的 Docker 容器 📒🚀 注册与登录➕ 创建 Docker 容器💻 部署你的网站🔑 注意事项💡 使用场景⚓️ 相关链接 ⚓️📖 介绍 📖 还在为搭建个人网站寻找免费方案而烦恼? 今天发现一个宝藏平台!只需一个 Git…