基于Bert模型的增量微调3-使用csv文件训练

我们使用weibo评价数据,8分类的csv格式数据集。

一、创建数据集合

使用csv格式的数据作为数据集。

1、创建MydataCSV.py

from  torch.utils.data import Dataset
from datasets import load_datasetclass MyDataset(Dataset):#初始化数据集def __init__(self, split):# 加载csv数据self.dataset=load_dataset(path="csv",data_files=f"D:\Test\LLMTrain\day03\data\Weibo/{split}.csv", split= "train")# 返回数据集长度def __len__(self):return len(self.dataset)# 对每条数据单独进行数据处理def __getitem__(self, idx):text=self.dataset[idx]["text"]label=self.dataset[idx]["label"]return  text,labelif __name__== "__main__":train_dataset=MyDataset("test")for i in range(10):print(train_dataset[i])

二、处理模型

我们使用8分类任务

创建netCSV.py

import torch
from transformers import BertModel#定义设备信息
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)#加载预训练模型
path1=r"D:\Test\LLMTrain\day03\model\bert-base-chinese\models--google-bert--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f"
pretrained = BertModel.from_pretrained(path1).to(DEVICE)
print(pretrained)#定义下游任务(增量模型)
class Model(torch.nn.Module):def __init__(self):super().__init__()#设计全连接网络,实现8分类任务self.fc = torch.nn.Linear(768,8)#使用模型处理数据(执行前向计算)def forward(self,input_ids,attention_mask,token_type_ids):#冻结Bert模型的参数,让其不参与训练with torch.no_grad():out = pretrained(input_ids=input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)#增量模型参与训练out = self.fc(out.last_hidden_state[:,0])return out

8分类任务,所以 self.fc=torch.nn.Liner768,8) 。

我们是对大模型做增量微调训练,所以需要冻结Bert模型的参数,让其不参与训练。所以使用 

with torch.no_grad()。

我们定义一个下游任务增量模型Model类,继承 torch.nn.Module。

三、训练的代码

1、创建目录params

存放训练后的结果。

2、写代码

创建train_val_csv.py

#模型训练
import torch
from MyDataCSV import MyDataset
from torch.utils.data import DataLoader
from netCSV import Model
from transformers import BertTokenizer,AdamW#定义设备信息
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#定义训练的轮次(将整个数据集训练完一次为一轮)
EPOCH = 30000#加载字典和分词器
token = BertTokenizer.from_pretrained(r"D:\Test\LLMTrain\day03\model\bert-base-chinese\models--google-bert--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f")#将传入的字符串进行编码
def collate_fn(data):sents = [i[0]for i in data]label = [i[1] for i in data]#编码data = token.batch_encode_plus(batch_text_or_text_pairs=sents,# 当句子长度大于max_length(上限是model_max_length)时,截断truncation=True,max_length=512,# 一律补0到max_lengthpadding="max_length",# 可取值为tf,pt,np,默认为listreturn_tensors="pt",# 返回序列长度return_length=True)input_ids = data["input_ids"]attention_mask = data["attention_mask"]token_type_ids = data["token_type_ids"]label = torch.LongTensor(label)return input_ids,attention_mask,token_type_ids,label#创建数据集
train_dataset = MyDataset("train")
train_loader = DataLoader(dataset=train_dataset,#训练批次batch_size=50,#打乱数据集shuffle=True,#舍弃最后一个批次的数据,防止形状出错drop_last=True,#对加载的数据进行编码collate_fn=collate_fn
)
#创建验证数据集
val_dataset = MyDataset("validation")
val_loader = DataLoader(dataset=val_dataset,#训练批次batch_size=50,#打乱数据集shuffle=True,#舍弃最后一个批次的数据,防止形状出错drop_last=True,#对加载的数据进行编码collate_fn=collate_fn
)
if __name__ == '__main__':#开始训练print(DEVICE)model = Model().to(DEVICE)#定义优化器optimizer = AdamW(model.parameters())#定义损失函数loss_func = torch.nn.CrossEntropyLoss()#初始化验证最佳准确率best_val_acc = 0.0for epoch in range(EPOCH):for i,(input_ids,attention_mask,token_type_ids,label) in enumerate(train_loader):#将数据放到DVEVICE上面input_ids, attention_mask, token_type_ids, label = input_ids.to(DEVICE),attention_mask.to(DEVICE),token_type_ids.to(DEVICE),label.to(DEVICE)#前向计算(将数据输入模型得到输出)out = model(input_ids,attention_mask,token_type_ids)#根据输出计算损失loss = loss_func(out,label)#根据误差优化参数optimizer.zero_grad()loss.backward()optimizer.step()#每隔5个批次输出训练信息if i%5 ==0:out = out.argmax(dim=1)#计算训练精度acc = (out==label).sum().item()/len(label)print(f"epoch:{epoch},i:{i},loss:{loss.item()},acc:{acc}")#验证模型(判断模型是否过拟合)#设置为评估模型model.eval()#不需要模型参与训练with torch.no_grad():val_acc = 0.0val_loss = 0.0for i, (input_ids, attention_mask, token_type_ids, label) in enumerate(val_loader):# 将数据放到DVEVICE上面input_ids, attention_mask, token_type_ids, label = input_ids.to(DEVICE), attention_mask.to(DEVICE), token_type_ids.to(DEVICE), label.to(DEVICE)# 前向计算(将数据输入模型得到输出)out = model(input_ids, attention_mask, token_type_ids)# 根据输出计算损失val_loss += loss_func(out, label)#根据数据,计算验证精度out = out.argmax(dim=1)val_acc+=(out==label).sum().item()val_loss/=len(val_loader)val_acc/=len(val_loader)print(f"验证集:loss:{val_loss},acc:{val_acc}")# #每训练完一轮,保存一次参数# torch.save(model.state_dict(),f"params/{epoch}_bert.pth")# print(epoch,"参数保存成功!")#根据验证准确率保存最优参数if val_acc > best_val_acc:best_val_acc = val_acctorch.save(model.state_dict(),"params1/best_bert.pth")print(f"EPOCH:{epoch}:保存最优参数:acc{best_val_acc}")#保存最后一轮参数torch.save(model.state_dict(), "params1/last_bert.pth")print(f"EPOCH:{epoch}:最后一轮参数保存成功!")

3、执行代码

这个过程需等待很久,若是使用cuda环境,显存越大,速度越快。

train_loader的训练批次batch_size=50,这个数值是根据电脑的配置来的,数值越大越好,只要不超过显存或者内存的90%即可。

四、使用训练好的模型

我们写一个控制台程序,也可以使用FastAPI。创建run.py文件。

#模型使用接口(主观评估)
#模型训练
import torch
from net import Model
from transformers import BertTokenizer#定义设备信息
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")#加载字典和分词器
token = BertTokenizer.from_pretrained(r"D:\Test\LLMTrain\day03\model\bert-base-chinese\models--google-bert--bert-base-chinese\snapshots\c30a6ed22ab4564dc1e3b2ecbf6e766b0611a33f")
model = Model().to(DEVICE)
names = ["负向评价","正向评价"]#将传入的字符串进行编码
def collate_fn(data):sents = []sents.append(data)#编码data = token.batch_encode_plus(batch_text_or_text_pairs=sents,# 当句子长度大于max_length(上限是model_max_length)时,截断truncation=True,max_length=512,# 一律补0到max_lengthpadding="max_length",# 可取值为tf,pt,np,默认为listreturn_tensors="pt",# 返回序列长度return_length=True)input_ids = data["input_ids"]attention_mask = data["attention_mask"]token_type_ids = data["token_type_ids"]return input_ids,attention_mask,token_type_idsdef test():#加载模型训练参数model.load_state_dict(torch.load("params/best_bert.pth"))#开启测试模型model.eval()while True:data = input("请输入测试数据(输入‘q’退出):")if data=='q':print("测试结束")breakinput_ids,attention_mask,token_type_ids = collate_fn(data)input_ids, attention_mask, token_type_ids = input_ids.to(DEVICE),attention_mask.to(DEVICE),token_type_ids.to(DEVICE)#将数据输入到模型,得到输出with torch.no_grad():out = model(input_ids,attention_mask,token_type_ids)out = out.argmax(dim=1)print("模型判定:",names[out],"\n")if __name__ == '__main__':test()

运行程序 ,输入test测试集里的数据进行验证,或许输入其他的文本验证。

 正确率还是非常棒的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/73244.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

flowable新增或修改单个任务的历史变量

简介 场景:对历史任务进行关注,所以需要修改流程历史任务的本地变量 方法包含2个类 1)核心方法,flowable command类:HistoricTaskSingleVariableUpdateCmd 2)执行command类:BpmProcessCommandS…

Netty基础—4.NIO的使用简介一

大纲 1.Buffer缓冲区 2.Channel通道 3.BIO编程 4.伪异步IO编程 5.改造程序以支持长连接 6.NIO三大核心组件 7.NIO服务端的创建流程 8.NIO客户端的创建流程 9.NIO优点总结 10.NIO问题总结 1.Buffer缓冲区 (1)Buffer缓冲区的作用 (2)Buffer缓冲区的4个核心概念 (3)使…

python元组(被捆绑的列表)

元组(tuple) 1.元组一旦形成就不可更改,元组所指向的内存单元中内容不变 定义:定义元组使用小括号,并且使用逗号进行隔开,数据可以是不同的数据类型 定义元组自变量(元素,元素,元素…

输入:0.5元/百万tokens(缓存命中)或2元(未命中) 输出:8元/百万tokens

这句话描述了一种 定价模型,通常用于云计算、API 服务或数据处理服务中,根据资源使用情况(如缓存命中与否)来收费。以下是对这句话的详细解释: 1. 关键术语解释 Tokens:在自然语言处理(NLP&…

计算机视觉算法实战——驾驶员玩手机检测(主页有源码)

✨个人主页欢迎您的访问 ✨期待您的三连 ✨ ✨个人主页欢迎您的访问 ✨期待您的三连 ✨ ✨个人主页欢迎您的访问 ✨期待您的三连✨ ​ ​​​ 1. 领域简介:玩手机检测的重要性与技术挑战 驾驶员玩手机检测是智能交通安全领域的核心课题。根据NHTSA数据&#xff0…

Java糊涂包(Hutool)的安装教程并进行网络爬虫

Hutool的使用教程 1:在官网下载jar模块文件 Central Repository: cn/hutool/hutool-all/5.8.26https://repo1.maven.org/maven2/cn/hutool/hutool-all/5.8.26/ 下载后缀只用jar的文件 2:复制并到idea当中,右键这个模块点击增加到库 3&…

深度学习项目--基于DenseNet网络的“乳腺癌图像识别”,准确率090%+,pytorch复现

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 前言 如果说最经典的神经网络,ResNet肯定是一个,从ResNet发布后,很多人做了修改,denseNet网络无疑是最成功的…

优化用户体验:关键 Web 性能指标的获取、分析、优化方法

前言 在当今互联网高速发展的时代用户对于网页的加载速度和响应时间越来越敏感。一个性能表现不佳的网页不仅会影响用户体验,还可能导致用户流失。 因此,了解和优化网页性能指标是每个开发者的必修课。今天我们就来聊聊常见的网页性能指标以及如何获取这…

vs code配置 c/C++

1、下载VSCode Visual Studio Code - Code Editing. Redefined 安装目录可改 勾选创建桌面快捷方式 安装即可 2、汉化VSCode 点击确定 下载MinGW 由于vsCode 只是一个编辑器,他没有自带编译器,所以需要下载一个编译器"MinGW". https://…

Kotlin关键字`when`的详细用法

Kotlin关键字when的详细用法 在Kotlin中,when是一个强大的控制流语句,相当于其他语言中的switch语句,但更加强大且灵活。本文将详细讲解when的用法及其常见场景,并与Java的switch语句进行对比。 一、基本语法 基本的when语法如…

MFCday01、模式对话框

对话框类和应用程序类。 MFC中 Combo Box List Box List Control三种列表控件,日期控件Date Time Picker

接口测试笔记

4、接口测试自动化 接口自动化概述 HttpClient HttpClient开发过程 创建Java工程 新建libs库目录 HttpClient 工具下载及引入 https://hc.apache.org/index.html工程中引入jar包 Get请求 HttpGet方法---发起Get请求 创建HttpClient对象 CloseableHttpClient httpclient …

查找sql中涉及的表名称

import pandas as pd import datetime todaystr(datetime.date.today())filepath/Users/kangyongqing/Documents/kangyq/202303/分析模版/sql表引用提取/ file101试听课明细.txt newfilefile1.title().split(.)[0]with open(filepathfile1,r) as file:contentfile.read().lower…

如何在Ubuntu上构建编译LLVM和ISPC,以及Ubuntu上ISPC的使用方法

之前一直在 Mac 上使用 ISPC,奈何核心/线程太少了。最近想在 Ubuntu 上搞搞,但是 snap 安装的 ISPC不知道为什么只能单核,很奇怪,就想着编译一下,需要 Clang 和 LLVM。但是 Ubuntu 很搞,他的很多软件版本是…

【Spring IOC/AOP】

IOC 参考: Spring基础 - Spring核心之控制反转(IOC) | Java 全栈知识体系 (pdai.tech) 概述: Ioc 即 Inverse of Control (控制反转),是一种设计思想,就是将原本在程序中手动创建对象的控制权&#xff…

电感与电容的具体应用

文章目录 一、电感应用1.​电源滤波:2. 储能——平滑“电流波浪”​ ​3. 调谐——校准“频率乐器”​4. 限流——防止“洪水灾害”​二、电容应用1.核心特性理解2.应用场景 三.电容电感对比 一、电感应用 1.​电源滤波: ​场景:工业设备中…

前端面试:axios 请求的底层依赖是什么?

在前端开发中,Axios 是一个流行的 JavaScript 库,用于发送 HTTP 请求。它简化了与 RESTful APIs 的交互,并提供了许多便利的方法与配置选项。要理解 Axios 的底层依赖,需要从以下几个方面进行分析: 1. Axios 基于 XML…

springboot 3 集成Redisson

maven 依赖 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>3.2.12</version></parent><dependencies><dependency><groupId>org.red…

C#中继承的核心定义‌

1. 继承的核心定义‌ ‌继承‌ 是面向对象编程&#xff08;OOP&#xff09;的核心特性之一&#xff0c;允许一个类&#xff08;称为‌子类/派生类‌&#xff09;基于另一个类&#xff08;称为‌父类/基类‌&#xff09;构建&#xff0c;自动获得父类的成员&#xff08;字段、属…

Deep research深度研究:ChatGPT/ Gemini/ Perplexity/ Grok哪家最强?(实测对比分析)

目前推出深度研究和深度检索的AI大模型有四家&#xff1a; OpenAI和Gemini 的deep research&#xff0c;以及Perplexity 和Grok的deep search&#xff0c;都能生成带参考文献引用的主题报告。 致力于“几分钟之内生成一份完整的主题调研报告&#xff0c;解决人力几小时甚至几天…