学习记录:初次学习使用transformers进行大模型微调

初次使用transformers进行大模型微调

环境:

电脑配置:
笔记本电脑:I5(6核12线程) + 16G + RTX3070(8G显存)
需要自行解决科学上网

Python环境:
python版本:3.8.8
大模型:microsoft/DialoGPT-medium(微软的对话大模型,模型小,笔记本也能学习微调)
数据集:daily_dialog (日常对话数据集)

其他:
模型及数据集:使用来源于抱抱脸

微调大模型

准备工作:

下载模型:

找到自己想要的模型:

  1. 打开抱抱脸官网——点击Model:
    在这里插入图片描述

  2. 输入要搜索的模型(这里以DialoGPT-medium为例):
    在这里插入图片描述

  3. 复制名称到代码中替换要下载的模型名称:

在这里插入图片描述
模型下载:

import os
from transformers import AutoModel, AutoTokenizer# 因为使用了科学上网,需要进行处理
os.environ["HTTP_PROXY"] = "http://127.0.0.1:xxxx"
os.environ["HTTPS_PROXY"] = "http://127.0.0.1:xxxx"if __name__ == '__main__':# model_name = 'google-t5/t5-small'  # 要下载的模型名称model_name = 'microsoft/DialoGPT-medium'  # 要下载的模型名称 需要到抱抱脸进行复制cache_dir = r'xxxx'  # 模型保存位置# 加载模型时指定下载路径model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
下载数据集:

找到自己想要的模型:

  1. 打开抱抱脸官网——点击Datasets:List item
  2. 输入要搜索的内容,点击对应数据集进入:
    在这里插入图片描述
  3. 找到适合用的模型后,点击复制
    在这里插入图片描述

开始微调训练

代码示例:

# 系统模块
import os# 第三方库
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from datasets import load_dataset# 设置代理(注意:可能需要根据实际网络环境调整或移除)
os.environ["HTTP_PROXY"] = "http://127.0.0.1:xxxx"  # HTTP代理设置
os.environ["HTTPS_PROXY"] = "http://127.0.0.1:xxxx"  # HTTPS代理设置if __name__ == '__main__':# 数据准备阶段 --------------------------------------------------------------# 加载完整数据集(daily_dialog包含日常对话数据集)full_dataset = load_dataset("daily_dialog", trust_remote_code=True)# 创建子数据集(仅使用训练集前500条样本,用于快速实验)dataset = {"train": full_dataset["train"].select(range(500))  # select保持数据集结构}# 模型加载阶段 --------------------------------------------------------------# 模型配置参数model_name = "microsoft/DialoGPT-medium"  # 使用微软的对话生成预训练模型cache_dir = r'xxx'  # 本地模型缓存路径# 加载分词器(重要:设置填充token与EOS token一致)tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)tokenizer.pad_token = tokenizer.eos_token  # 将填充token设置为与EOS相同# 加载预训练模型(使用因果语言模型结构)model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)# 数据预处理阶段 ------------------------------------------------------------def tokenize_function(examples):"""将对话数据转换为模型输入格式的预处理函数"""# 将多轮对话用EOS token连接,并在结尾添加EOSdialogues = [tokenizer.eos_token.join(dialog) + tokenizer.eos_tokenfor dialog in examples["dialog"]]# 对文本进行分词处理tokenized = tokenizer(dialogues,truncation=True,  # 启用截断max_length=512,  # 最大序列长度padding="max_length"  # 填充到最大长度(静态填充))# 创建标签(对于因果语言模型,标签与输入相同)tokenized["labels"] = tokenized["input_ids"].copy()return tokenized# 应用预处理(保留数据集结构)tokenized_dataset = {"train": dataset["train"].map(tokenize_function,batched=True,  # 批量处理提升效率batch_size=50,  # 每批处理50个样本remove_columns=["dialog", "act", "emotion"]  # 移除原始文本列)}# 数据验证(检查预处理结果)print("Sample keys:", tokenized_dataset["train"][0].keys())  # 应包含input_ids, attention_mask, labelsprint("Input IDs:", tokenized_dataset["train"][0]["input_ids"][:5])  # 检查前5个token# 训练配置阶段 --------------------------------------------------------------training_args = TrainingArguments(output_dir="./dialo_finetuned",  # 输出目录per_device_train_batch_size=2,  # 每个设备的批次大小(根据显存调整)gradient_accumulation_steps=8,  # 梯度累积步数(模拟更大batch size)learning_rate=1e-5,  # 初始学习率(可调超参数)num_train_epochs=3,  # 训练轮次(根据需求调整)fp16=True,  # 启用混合精度训练(需要GPU支持)logging_steps=10,  # 每10步记录日志# 可添加的优化参数:# evaluation_strategy="steps",    # 添加验证策略# save_strategy="epoch",          # 保存策略# warmup_steps=100,               # 学习率预热步数)# 创建训练器trainer = Trainer(model=model,args=training_args,train_dataset=tokenized_dataset["train"],  # 训练数据集# 可扩展功能:# eval_dataset=tokenized_dataset["validation"],  # 添加验证集# data_collator=...,             # 自定义数据整理器# compute_metrics=...,           # 添加评估指标)# 训练执行阶段 --------------------------------------------------------------trainer.train()  # 启动训练# 模型保存阶段 --------------------------------------------------------------model.save_pretrained("./dialo_finetuned")  # 保存模型权重tokenizer.save_pretrained("./dialo_finetuned")  # 保存分词器# 推荐使用以下方式统一保存:trainer.save_model("./dialo_finetuned")       # 官方推荐保存方式

微调后使用

代码:

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers import TextStreamer
from collections import deque
import torchdef optimized_generation(text, tokenizer, model):inputs = tokenizer(text, return_tensors="pt").to(model.device)outputs = model.generate(**inputs,max_new_tokens=150,temperature=0.9,  # 越高越有创意 (0-1)top_k=50,  # 限制候选词数量top_p=0.95,  # 核采样阈值repetition_penalty=1.2,  # 抑制重复num_beams=3,  # 束搜索宽度early_stopping=True,do_sample=True)return tokenizer.decode(outputs[0], skip_special_tokens=True)# 单轮对话
def simple_chat(model_path, text, max_length=100):"""单轮对话:param text::param max_length::return:"""# 加载模型和分词器tokenizer = AutoTokenizer.from_pretrained(model_path)model = AutoModelForCausalLM.from_pretrained(model_path)# 确保pad_token设置正确tokenizer.pad_token = tokenizer.eos_token# inputs = tokenizer(text + tokenizer.eos_token, return_tensors="pt")# outputs = model.generate(#     inputs.input_ids,#     max_length=max_length,#     pad_token_id=tokenizer.eos_token_id,#     temperature=0.7,#     do_sample=True# )# response = tokenizer.decode(outputs[0], skip_special_tokens=True)response = optimized_generation(text + tokenizer.eos_token, tokenizer, model)return response[len(text):]  # 去除输入文本# 多轮对话
class DialogueBot:def __init__(self, model_path, max_history=3):self.tokenizer = AutoTokenizer.from_pretrained(model_path)self.model = AutoModelForCausalLM.from_pretrained(model_path).to("cuda")self.max_history = max_historyself.history = deque(maxlen=max_history * 2)  # 每轮包含用户和机器人各一条# 确保pad_token设置if self.tokenizer.pad_token is None:self.tokenizer.pad_token = self.tokenizer.eos_tokendef generate_response(self, user_input):# 添加用户输入(带EOS)self.history.append(f"User: {user_input}{self.tokenizer.eos_token}")# 构建prompt并编码prompt = self._build_prompt()inputs = self.tokenizer(prompt,return_tensors="pt",max_length=512,truncation=True).to(self.model.device)# 流式输出# streamer = TextStreamer(self.tokenizer)# 生成回复outputs = self.model.generate(inputs.input_ids,attention_mask=inputs.attention_mask,max_new_tokens=150,temperature=0.85,top_p=0.95,eos_token_id=self.tokenizer.eos_token_id,pad_token_id=self.tokenizer.eos_token_id,do_sample=True,# streamer=streamer,early_stopping=True)# 解码并处理回复full_response = self.tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:],skip_special_tokens=True)# 清理无效内容(按第一个EOS截断)clean_response = full_response.split(self.tokenizer.eos_token)[0].strip()# 添加机器人回复到历史(带EOS)self.history.append(f"Bot: {clean_response}{self.tokenizer.eos_token}")return clean_responsedef _build_prompt(self):return "".join(self.history)if __name__ == '__main__':# 指定模型路径model_path = "./dialo_finetuned"# 测试单轮对话print(simple_chat(model_path, "Hello, how are you?"))# 使用示例 多轮对话# bot = DialogueBot(model_path)# while True:#     user_input = input("You: ")#     if user_input.lower() == "exit":#         break#     print("Bot:", bot.generate_response(user_input))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/70834.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Java学习】Object类与接口

面向对象系列五 一、引用 1.自调传自与this类型 2.类变量引用 3.重写时的发生 二、Object类 1.toString 2.equals 3.hashCode 4.clone 三、排序规则接口 1.Comparable 2.Comparator 一、引用 1.自调传自与this类型 似复刻变量调用里面的非静态方法时,都…

OpenEuler学习笔记(三十五):搭建代码托管服务器

以下是主流的代码托管软件分类及推荐,涵盖自托管和云端方案,您可根据团队规模、功能需求及资源情况选择: 一、自托管代码托管平台(可私有部署) 1. GitLab 简介: 功能全面的 DevOps 平台,支持代码托管、C…

Vscode无法加载文件,因为在此系统上禁止运行脚本

1.在 vscode 终端执行 get-ExecutionPolicy 如果返回是Restricted,说明是禁止状态。 2.在 vscode 终端执行set-ExecutionPolicy RemoteSigned 爆红说明没有设置成功 3.在 vscode 终端执行Set-ExecutionPolicy -Scope CurrentUser RemoteSigned 然后成功后你再在终…

Transformer 架构 理解

大家读完觉得有帮助记得关注和点赞!!! Transformer 架构:encoder/decoder 内部细节。 的介绍,说明 Transformer 架构相比当时主流的 RNN/CNN 架构的创新之处: 在 transformer 之前,最先进的架构…

事务的4个特性和4个隔离级别

事务的4个特性和4个隔离级别 1. 什么是事务2. 事务的ACID特性2.1 原子性2.2 一致性2.3 持久性2.4 隔离性 3. 事务的创建4. 事务并发时出现的问题4.1 DIRTY READ 脏读4.2 NON - REPEATABLR READ 不可重复读4.3 PHANTOM READ 幻读 5. 事务的隔离级别5.1 READ UNCOMMITTED 读未提交…

LeetCode热题100- 字符串解码【JavaScript讲解】

古语有云:“事以密成,语以泄败”! 关于字符串解码: 题目:题解:js代码:代码中遇到的方法:repeat方法:为什么这里不用this.strstack.push(result)? 题目&#x…

水利工程安全包括哪几个方面

水利工程安全培训的内容主要包括以下几个方面: 基础知识和技能培训 : 法律法规 :学习水利工程相关的安全生产法律法规,了解安全生产标准及规范。 事故案例 :通过分析事故案例,了解事故原因和教训&#x…

浅谈新能源汽车充电桩建设问题分析及解决方案

摘要: 在全球倡导低碳减排的大背景下,新能源成为热门行业在全球范围内得以开展。汽车尾气排放会在一定程度上加重温室效应,并且化石能源的日渐紧缺也迫切对新能源汽车发展提出新要求。现阶段的新能源汽车以电力汽车为主,与燃油汽…

05-1基于vs2022的c语言笔记——运算符

目录 前言 5.运算符和表达式 5-1-1 加减乘除运算符 1.把变量进行加减乘除运算 2.把常量进行加减乘除运算 3.对于比较大的数(往数轴正方向或者负方向),要注意占位符的选取 4.浮点数的加减乘除 5-1-2取余/取模运算符 1.基本规则 2.c语…

ubuntu:换源安装docker-ce和docker-compose

更新apt源 apt换源:ubuntu:更新阿里云apt源-CSDN博客 安装docker-ce 1、更新软件源 sudo apt update2、安装基本软件 sudo apt-get install apt-transport-https ca-certificates curl software-properties-common lrzsz -y3、指定使用阿里云镜像 su…

0—QT ui界面一览

2025.2.26,感谢gpt4 1.控件盒子 1. Layouts(布局) 布局控件用于组织界面上的控件,确保它们的位置和排列方式合理。 Vertical Layout(垂直布局) :将控件按垂直方向排列。 建议:适…

Apache Doris 索引的全面剖析与使用指南

搞大数据开发的都知道,想要在海量数据里快速查数据,就像在星图里找一颗特定的星星,贼费劲。不过别慌,数据库索引就是咱们的 “定位神器”,能让查询效率直接起飞!就拿 Apache Doris 这个超火的分析型数据库来…

docker file中ADD命令的介绍

在 Docker 的世界里,Dockerfile 是一个用于定义镜像内容和行为的脚本文件。其中,ADD 指令是 Dockerfile 中一个非常重要的命令,用于将文件或目录从主机文件系统复制到容器的文件系统中。本文将详细介绍 ADD 指令的作用、使用方式以及一些最佳…

从零到一:如何用阿里云百炼和火山引擎搭建专属 AI 助手(DeepSeek)?

本文首发:从零到一:如何用阿里云百炼和火山引擎搭建专属 AI 助手(DeepSeek)? 阿里云百炼和火山引擎都推出了免费的 DeepSeek 模型体验额度,今天我和大家一起搭建一个本地的专属 AI 助手。  阿里云百炼为 …

cpp中的继承

一、继承概念 在cpp中,封装、继承、多态是面向对象的三大特性。这里的继承就是允许已经存在的类(也就是基类)的基础上创建新类(派生类或者子类),从而实现代码的复用。 如上图所示,Person是基类&…

【QT】QLinearGradient 线性渐变类简单使用教程

目录 0.简介 1)qtDesigner中 2)实际执行 1.功能详述 3.举一反三的样式 0.简介 QLinearGradient 是 Qt 框架中的一个类,用于定义线性渐变效果(通过样式表设置)。它可以用来填充形状、背景或其他图形元素&#xff0…

前端项目配置 Nginx 全攻略

在前端开发中,项目开发完成后,如何高效、稳定地将其部署到生产环境是至关重要的一步。Nginx 作为一款轻量级、高性能的 Web 服务器和反向代理服务器,凭借其出色的性能和丰富的功能,成为了前端项目部署的首选方案。本文将详细介绍在…

网络安全学习-常见web漏洞的渗xxx透以及防护方法

渗XX透测试 弱口令漏洞 漏洞描述 目标网站管理入口(或数据库等组件的外部连接)使用了容易被猜测的简单字符口令、或者是默认系统账号口令。 渗XX透测试 如果不存在验证码,则直接使用相对应的弱口令字典使用burpsuite 进行爆破如果存在验证…

网络安全 机器学习算法 计算机网络安全机制

(一)网络操作系统 安全 网络操作系统安全是整个网络系统安全的基础。操作系统安全机制主要包括访问控制和隔离控制。 访问控制系统一般包括主体、客体和安全访问政策 访问控制类型: 自主访问控制强制访问控制 访问控制措施: 入…

2025网络安全等级测评报告,信息安全风险评估报告(Word模板)

一、概述 1.1工作方法 1.2评估依据 1.3评估范围 1.4评估方法 1.5基本信息 二、资产分析 2.1 信息资产识别概述 2.2 信息资产识别 三、评估说明 3.1无线网络安全检查项目评估 3.2无线网络与系统安全评估 3.3 ip管理与补丁管理 3.4防火墙 四、威胁细类分析 4.1威胁…