基于 Python 的自然语言处理系列(82):Transformer Reinforcement Learning

🔗 本文所用工具:trltransformerspeftbitsandbytes
📘 官方文档参考:https://huggingface.co/docs/trl

一、引言:从有监督微调到 RLHF 全流程

        随着语言大模型的发展,如何在大规模预训练模型基础上更精细地对齐人类偏好,成为了研究与应用的热点。本文将介绍一套完整的 RLHF(Reinforcement Learning with Human Feedback)训练流程,基于 Hugging Face 推出的 trl 库,从 SFT(Supervised Fine-tuning)、RM(Reward Modeling)、到 PPO(Proximal Policy Optimization)三大阶段,逐步实现对 Transformer 模型的强化学习优化。

        本篇聚焦于 SFT 阶段的实现,并以 Hugging Face 提供的 instruction-dataset 为例,介绍如何使用 trl 和 PEFT(参数高效微调)技术训练一个高效对齐指令的语言模型。

二、安装与环境准备

        确保安装以下库(建议使用 PyTorch + CUDA 环境):

pip install trl transformers datasets peft bitsandbytes accelerate

三、加载并准备数据集

        本例使用 HuggingFaceH4 团队整理的 instruction-dataset

from datasets import load_datasetdataset = load_dataset("HuggingFaceH4/instruction-dataset")
dataset = dataset.remove_columns("meta")  # 移除无用字段
dataset

四、构建模型及量化配置(4-bit)

        使用 BitsAndBytesConfig 对模型进行 4-bit 量化,可大幅降低显存占用:

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import prepare_model_for_kbit_trainingmodel_name = "lmsys/fastchat-t5-3b-v1.0"bnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.bfloat16,
)base_model = AutoModelForCausalLM.from_pretrained(model_name,torch_dtype=torch.bfloat16,quantization_config=bnb_config
)base_model.config.use_cache = False
base_model = prepare_model_for_kbit_training(base_model)

五、注入 LoRA 参数高效微调机制

        首先识别所有 4-bit 线性模块并定义 LoRA 参数配置:

import bitsandbytes as bnb
from peft import get_peft_model, LoraConfigdef find_all_linear_names(model):cls = bnb.nn.Linear4bitlora_module_names = set()for name, module in model.named_modules():if isinstance(module, cls):names = name.split(".")lora_module_names.add(names[0] if len(names) == 1 else names[-1])return list(lora_module_names)peft_config = LoraConfig(r=128,lora_alpha=16,target_modules=find_all_linear_names(base_model),lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",
)base_model = get_peft_model(base_model, peft_config)

        打印可训练参数占比:

def print_trainable_parameters(model):trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)total = sum(p.numel() for p in model.parameters())print(f"Trainable params: {trainable} / {total} ({trainable / total:.2%})")print_trainable_parameters(base_model)

六、定义 Prompt 格式化函数

        将数据集中的 promptcompletion 格式化为统一格式:

def formatting_prompts_func(example):return [f"### Input: ```{prompt}```\n ### Output: {completion}"for prompt, completion in zip(example["prompt"], example["completion"])]

七、训练参数设置与 SFTTrainer 训练器

        使用 SFTTrainer 执行指令微调训练,支持 gradient checkpointing、cosine 学习率调度等高级策略:

from transformers import TrainingArguments
from trl import SFTTraineroutput_dir = "./results"training_args = TrainingArguments(output_dir=output_dir,per_device_train_batch_size=4,gradient_accumulation_steps=4,gradient_checkpointing=True,max_grad_norm=0.3,num_train_epochs=15,learning_rate=2e-4,bf16=True,save_total_limit=3,logging_steps=10,optim="paged_adamw_32bit",lr_scheduler_type="cosine",warmup_ratio=0.05,
)tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"trainer = SFTTrainer(model=base_model,train_dataset=dataset,tokenizer=tokenizer,max_seq_length=2048,formatting_func=formatting_prompts_func,args=training_args
)

        执行训练:

trainer.train()
trainer.save_model(output_dir)

        保存最终模型权重与 tokenizer:

import os
final_output_dir = os.path.join(output_dir, "final_checkpoint")
trainer.model.save_pretrained(final_output_dir)
tokenizer.save_pretrained(final_output_dir)

八、小结与展望

        通过本文,我们使用 trl 工具链完成了 RLHF 的第一阶段:SFT 有监督微调。你可以根据项目实际需求,替换为自定义数据集或更大规模模型。后续步骤(RM 奖励建模 + PPO 策略优化)将在下一篇继续介绍。

📌 下一篇预告

        📘《基于 Python 的自然语言处理系列(83):InstructGPT》

        敬请期待!

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/80227.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JAVA猜数小游戏

import java.util.Random; import java.util.Scanner;public class HelloWorld {public static void main(String[] args) {Random rnew Random();int luck_number r.nextInt(100)1;while (true){System.out.println("输入猜数字");Scanner sc new Scanner(System…

GPU渲染阶段介绍+Shader基础结构实现

GPU是什么 (CPU)Center Processing Unit:逻辑编程 (GPU)Graphics Processing Unit:图形处理(矩阵运算,数据公式运算,光栅化) 渲染管线 渲染管线也称为渲染流水线&#x…

Spring Boot + MyBatis 动态字段更新方法

在Spring Boot和MyBatis中,实现动态更新不固定字段的步骤如下: 方法一:使用MyBatis动态SQL(适合字段允许为null的场景) 定义实体类 包含所有可能被更新的字段。 Mapper接口 定义更新方法,参数为实体对象&…

单例模式:确保唯一实例的设计模式

单例模式:确保唯一实例的设计模式 一、模式核心:保证类仅有一个实例并提供全局访问点 在软件开发中,有些类需要确保只有一个实例(如系统配置类、日志管理器),避免因多个实例导致状态混乱或资源浪费。 单…

UnoCSS原子CSS引擎-前端福音

UnoCSS是一款原子化的即时按需 CSS 引擎,其中没有核心实用程序,所有功能都是通过预设提供的。默认情况下UnoCSS应用通过预设来实现相关功能。 UnoCSS中文文档: https://www.unocss.com.cn 前有很多种原子化的框架,例如 Tailwind…

【Qwen2.5-VL 踩坑记录】本地 + 海外账号和国内账号的 API 调用区别(阿里云百炼平台)

API 调用 阿里云百炼平台的海内外 API 的区别: 海外版:需要进行 API 基础 URL 设置国内版:无需设置。 本人的服务器在香港,采用海外版的 API 时,需要进行如下API端点配置 / API基础URL设置 / API客户端配置&#xf…

C语言笔记(鹏哥)上课板书+课件汇总(结构体)-----数据结构常用

结构体 目录: 1、结构体类型声明 2、结构体变量的创建和初始化 3、结构体成员访问操作符 4、结构体内存对齐*****(重要指数五颗星) 5、结构体传参 6、结构体实现位段 一、结构体类型声明 其实在指针中我们已经讲解了一些结构体内容了&…

UV: Python包和项目管理器(从入门到不放弃教程)

目录 UV: Python包和项目管理器(从入门到不放弃教程)1. 为什么用uv,而不是conda或者pip2. 安装uv(Windows)2.1 powershell下载2.2 winget下载2.3 直接下载安装包 3. uv教程3.1 创建虚拟环境 (uv venv) 4. uvx5. 此pip非…

网络开发基础(游戏方向)之 概念名词

前言 1、一款网络游戏分为客户端和服务端两个部分,客户端程序运行在用户的电脑或手机上,服务端程序运行在游戏运营商的服务器上。 2、客户端和服务端之间,服务端和服务端之间一般都是使用TCP网络通信。客户端和客户端之间通过服务端的消息转…

java将pdf转换成word

1、jar包准备 在项目中新增lib目录&#xff0c;并将如下两个文件放入lib目录下 aspose-words-15.8.0-jdk16.jar aspose-pdf-22.9.jar 2、pom.xml配置 <dependency><groupId>com.aspose</groupId><artifactId>aspose-pdf</artifactId><versi…

【C/C++】插件机制:基于工厂函数的动态插件加载

本文介绍了如何通过 C 的 工厂函数、动态库&#xff08;.so 文件&#xff09;和 dlopen / dlsym 实现插件机制。这个机制允许程序在运行时动态加载和调用插件&#xff0c;而无需在编译时知道插件的具体类型。 一、 动态插件机制 在现代 C 中&#xff0c;插件机制广泛应用于需要…

【音视频】AAC-ADTS分析

AAC-ADTS 格式分析 AAC⾳频格式&#xff1a;Advanced Audio Coding(⾼级⾳频解码)&#xff0c;是⼀种由MPEG-4标准定义的有损⾳频压缩格式&#xff0c;由Fraunhofer发展&#xff0c;Dolby, Sony和AT&T是主 要的贡献者。 ADIF&#xff1a;Audio Data Interchange Format ⾳…

机器学习 Day12 集成学习简单介绍

1.集成学习概述 1.1. 什么是集成学习 集成学习是一种通过组合多个模型来提高预测性能的机器学习方法。它类似于&#xff1a; 超级个体 vs 弱者联盟 单个复杂模型(如9次多项式函数)可能能力过强但容易过拟合 组合多个简单模型(如一堆1次函数)可以增强能力而不易过拟合 集成…

通过爬虫方式实现头条号发布视频(2025年4月)

1、将真实的cookie贴到代码目录中toutiaohao_cookie.txt文件里,修改python代码里的user_agent和video_path, cover_path等变量的值,最后运行python脚本即可; 2、运行之前根据import提示安装一些常见依赖,比如requests等; 3、2025年4月份最新版; 代码如下: import js…

Linux ssh免密登陆设置

使用 ssh-copy-id 命令来设置 SSH 免密登录&#xff0c;并确保所有相关文件和目录权限正确设置&#xff0c;可以按照以下步骤进行&#xff1a; 步骤 1&#xff1a;在源服务器&#xff08;198.120.1.109&#xff09;生成 SSH 密钥对 如果还没有生成 SSH 密钥对&#xff0c;首先…

《让机器人读懂你的心:情感分析技术融合奥秘》

机器人早已不再局限于执行简单机械的任务&#xff0c;人们期望它们能像人类伙伴一样&#xff0c;理解我们的喜怒哀乐&#xff0c;实现更自然、温暖的互动。情感分析技术&#xff0c;正是赋予机器人这种“理解人类情绪”能力的关键钥匙&#xff0c;它的融入将彻底革新机器人与人…

Linux笔记---进程间通信:匿名管道

1. 管道通信 1.1 管道的概念与分类 管道&#xff08;Pipe&#xff09; 是进程间通信&#xff08;IPC&#xff09;的一种基础机制&#xff0c;主要用于在具有亲缘关系的进程&#xff08;如父子进程、兄弟进程&#xff09;之间传递数据&#xff0c;其核心特性是通过内核缓冲区实…

Ollama API 应用指南

1. 基础信息 默认地址: http://localhost:11434/api数据格式: application/json支持方法: POST&#xff08;主要&#xff09;、GET&#xff08;部分接口&#xff09; 2. 模型管理 API (1) 列出本地模型 端点: GET /api/tags功能: 获取已下载的模型列表。示例:curl http://lo…

【OSCP-vulnhub】Raven-2

目录 端口扫描 本地/etc/hosts文件解析 目录扫描&#xff1a; 第一个flag 利用msf下载exp flag2 flag3 Mysql登录 查看mysql的运行权限 MySql提权&#xff1a;UDF 查看数据库写入条件 查看插件目录 查看是否可以远程登录 gcc编译.o文件 创建so文件 创建临时监听…

Podman Desktop:现代轻量容器管理利器(Podman与Docker)

前言 什么是 Podman Desktop&#xff1f; Podman Desktop 是基于 Podman CLI 的图形化开源容器管理工具&#xff0c;运行在 Windows&#xff08;或 macOS&#xff09;上&#xff0c;默认集成 Fedora Linux&#xff08;WSL 2 环境&#xff09;。它提供与 Docker 类似的使用体验…