GPT-2 语言模型 - 模型训练

本节代码是一个完整的机器学习工作流程,用于训练一个基于GPT-2的语言模型。下面是对这段代码的详细解释:

文件目录如下

1. 初始化和数据准备

  • 设置随机种子

    random.seed(1002)

    确保结果的可重复性。

  • 定义参数

    test_rate = 0.2
    context_length = 128
    • test_rate:测试集占总数据集的比例。

    • context_length:模型处理的文本长度。

  • 获取数据文件

    all_files = glob(pathname=os.path.join("data","*"))

    使用 glob 获取 data 目录下的所有文件。

  • 划分数据集

    test_file_list = random.sample(all_files, int(len(all_files) * test_rate))
    train_file_list = [i for i in all_files if i not in test_file_list]

    将数据集随机划分为训练集和测试集。

  • 加载数据集

    raw_datasets = load_dataset("csv", data_files={"train": train_file_list, "vaild": test_file_list}, cache_dir="cache_data")

    使用 datasets 库加载 CSV 格式的数据集,并缓存到 cache_data 目录。

2. 数据预处理

  • 初始化分词器

    tokenizer = BertTokenizerFast.from_pretrained("D:/bert-base-chinese")
    tokenizer.add_special_tokens({"bos_token":"[begin]","eos_token":"[end]"})

    从本地路径加载预训练的 BERT 分词器,并添加自定义的开始和结束标记。

  • 数据预处理

    tokenize_datasets = raw_datasets.map(tokenize, batched=True, remove_columns=raw_datasets["train"].column_names)

    使用 map 方法对数据集进行预处理,将文本转换为模型可接受的格式。

    • tokenize 函数对文本进行分词和截断。

    • batched=True 表示批量处理数据。

    • remove_columns 删除原始数据集中的列。

3. 模型配置和初始化

  • 模型配置

    config = GPT2Config.from_pretrained("config",vocab_size=len(tokenizer),n_ctx=context_length,bos_token_id=tokenizer.bos_token_id,eos_token_id=tokenizer.eos_token_id,)

    加载预训练的 GPT-2 配置,并根据分词器的词汇表大小和上下文长度进行调整。

  • 初始化模型

    model = GPT2LMHeadModel(config)
    model_size = sum([t.numel() for t in model.parameters()])
    print(f"model_size: {model_size/1000/1000} M")

    根据配置初始化 GPT-2 语言模型,并计算模型参数的总数,打印模型大小(以兆字节为单位)。

4. 训练设置

  • 数据整理器

    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    使用 DataCollatorForLanguageModeling 整理训练数据,设置 mlm=False 表示不使用掩码语言模型。

  • 训练参数

    args = TrainingArguments(learning_rate=1e-5,num_train_epochs=100,per_device_train_batch_size=10,per_device_eval_batch_size=10,eval_steps=2000,logging_steps=2000,gradient_accumulation_steps=5,weight_decay=0.1,warmup_steps=1000,lr_scheduler_type="cosine",save_steps=100,output_dir="model_output",fp16=True,
    )

    配置训练参数,包括学习率、训练轮数、批大小、评估间隔等。

  • 初始化训练器

    trianer = Trainer(model=model,args=args,tokenizer=tokenizer,data_collator=data_collator,train_dataset=tokenize_datasets["train"],eval_dataset=tokenize_datasets["vaild"]
    )
  • 启动训练

    trianer.train()

    使用 Trainer 类启动模型训练。

需复现完整代码

from glob import glob
import os
from torch.utils.data import Dataset
from datasets import load_dataset
import random
from transformers import BertTokenizerFast
from transformers import GPT2Config
from transformers import GPT2LMHeadModel
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer,TrainingArgumentsdef tokenize(element):outputs = tokenizer(element["content"],truncation=True,max_length=context_length,return_overflowing_tokens=True,return_length=True)input_batch = []for length,input_ids in zip(outputs["length"],outputs["input_ids"]):if length == context_length:input_batch.append(input_ids)return {"input_ids":input_batch}if __name__ == "__main__":random.seed(1002)test_rate = 0.2context_length = 128all_files = glob(pathname=os.path.join("data","*"))test_file_list = random.sample(all_files,int(len(all_files)*test_rate))train_file_list = [i for i in all_files if i not in test_file_list]raw_datasets = load_dataset("csv",data_files={"train":train_file_list,"vaild":test_file_list},cache_dir="cache_data")tokenizer = BertTokenizerFast.from_pretrained("D:/bert-base-chinese")tokenizer.add_special_tokens({"bos_token":"[begin]","eos_token":"[end]"})tokenize_datasets = raw_datasets.map(tokenize,batched=True,remove_columns=raw_datasets["train"].column_names)config = GPT2Config.from_pretrained("config",vocab_size=len(tokenizer),n_ctx=context_length,bos_token_id = tokenizer.bos_token_id,eos_token_id = tokenizer.eos_token_id,)model = GPT2LMHeadModel(config)model_size = sum([ t.numel() for t in model.parameters()])print(f"model_size: {model_size/1000/1000} M")data_collator = DataCollatorForLanguageModeling(tokenizer,mlm=False)args = TrainingArguments(learning_rate=1e-5,num_train_epochs=100,per_device_train_batch_size=10,per_device_eval_batch_size=10,eval_steps=2000,logging_steps=2000,gradient_accumulation_steps=5,weight_decay=0.1,warmup_steps=1000,lr_scheduler_type="cosine",save_steps=100,output_dir="model_output",fp16=True,)trianer = Trainer(model=model,args=args,tokenizer=tokenizer,data_collator=data_collator,train_dataset=tokenize_datasets["train"],eval_dataset=tokenize_datasets["vaild"])trianer.train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/75718.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

架构师面试(二十九):TCP Socket 编程

问题 今天考察网络编程的基础知识。 在基于 TCP 协议的网络 【socket 编程】中可能会遇到很多异常,在下面的相关描述中说法正确的有哪几项呢? A. 在建立连接被拒绝时,有可能是因为网络不通或地址错误或 server 端对应端口未被监听&#x…

HTTP实现心跳模块

HTTP实现心跳模块 使用轻量级的cHTTP库cpp-httplib重现实现HTTP心跳模块 头文件HttplibHeartbeat.h #ifndef HTTPLIB_HEARTBEAT_H #define HTTPLIB_HEARTBEAT_H#include <string> #include <thread> #include <atomic> #include <chrono> #include …

openharmony—release—4.1开发环境搭建(踩坑记录)

环境开发需要分别在window以及ubuntu下进行相应设置 一、window 1.安装DevEco Device Tool OpenAtom OpenHarmony 二、ubuntu 1.将Ubuntu Shell环境修改为bash ls -l /bin/sh 2.打开终端工具&#xff0c;执行如下命令&#xff0c;输入密码&#xff0c;然后选择No&#xff0…

Go学习系列文章声明

本次学习是基于B站的视频&#xff0c;【Udemy高分热门付费课程】Golang&#xff1a;完整开发者指南&#xff08;基础知识和高级特性&#xff09;中英文字幕_哔哩哔哩_bilibili 本人会尝试输出视频中的内容&#xff0c;如有错误欢迎指出 next page: Go installation process

error: RPC failed; HTTP 408 curl 22 The requested URL returned error: 408

在git push时报错&#xff1a;error: RPC failed; HTTP 408 curl 22 The requested URL returned error: 408 原因&#xff1a;可能是推送的文件太大&#xff0c;要么是缓存不够&#xff0c;要么是网络不行。 解决方法&#xff1a; 将本地 http.postBuffer 数值调整到500MB&…

Android.bp中添加条件判断编译方式

背景&#xff1a; 马哥学员朋友以前在vip群里&#xff0c;有问道如何在Android.bp中添加条件判断&#xff0c;在工作中经常需要一套代码兼容发货目标版本&#xff0c;即代码都是公共的一套&#xff0c;但是需要用这一套代码集成到各个产品设备上 但是这个产品设备可能面临比…

swift ui基础

一个朴实无华的目录 今日学习内容&#xff1a;1.三种布局&#xff08;可以相互包裹&#xff09;1.1 vstack&#xff08;竖直&#xff09;&#xff1a;先写的在上面1.1 hstack&#xff08;水平&#xff09;&#xff1a;先写的在左边1.1 zstack&#xff08;前后&#xff09;&…

第16届蓝桥杯单片机模拟试题Ⅲ

试题 代码 sys.h #ifndef __SYS_H__ #define __SYS_H__#include <STC15F2K60S2.H> //sys.c extern unsigned char UI; //界面标志(0湿度界面、1参数界面、2时间界面) extern unsigned char time; //时间间隔(1s~10S) extern bit ssflag; //启动/停止标志…

Node.js中URL模块详解

Node.js 中 URL 模块全部 API 详解 1. URL 类 const { URL } require(url);// 1. 创建 URL 对象 const url new URL(https://www.example.com:8080/path?queryvalue#hash);// 2. URL 属性 console.log(协议:, url.protocol); // https: console.log(主机名:, url.hos…

Java接口性能优化面试问题集锦:高频考点与深度解析

1. 如何定位接口性能瓶颈&#xff1f;常用哪些工具&#xff1f; 考察点&#xff1a;性能分析工具的使用与问题定位能力。 核心答案&#xff1a; 工具&#xff1a;Arthas&#xff08;在线诊断&#xff09;、JProfiler&#xff08;内存与CPU分析&#xff09;、VisualVM、Prometh…

WheatA小麦芽:农业气象大数据下载器

今天为大家介绍的软件是WheatA小麦芽&#xff1a;专业纯净的农业气象大数据系统。下面&#xff0c;我们将从软件的主要功能、支持的系统、软件官网等方面对其进行简单的介绍。 主要内容来源于软件官网&#xff1a;WheatA小麦芽的官方网站是http://www.wheata.cn/ &#xff0c;…

Python10天突击--Day 2: 实现观察者模式

以下是 Python 实现观察者模式的完整方案&#xff0c;包含同步/异步支持、类型注解、线程安全等特性&#xff1a; 1. 经典观察者模式实现 from abc import ABC, abstractmethod from typing import List, Anyclass Observer(ABC):"""观察者抽象基类""…

CST1019.基于Spring Boot+Vue智能洗车管理系统

计算机/JAVA毕业设计 【CST1019.基于Spring BootVue智能洗车管理系统】 【项目介绍】 智能洗车管理系统&#xff0c;基于 Spring Boot Vue 实现&#xff0c;功能丰富、界面精美 【业务模块】 系统共有三类用户&#xff0c;分别是&#xff1a;管理员用户、普通用户、工人用户&…

Windows上使用Qt搭建ARM开发环境

在 Windows 上使用 Qt 和 g++-arm-linux-gnueabihf 进行 ARM Linux 交叉编译(例如针对树莓派或嵌入式设备),需要配置 交叉编译工具链 和 Qt for ARM Linux。以下是详细步骤: 1. 安装工具链 方法 1:使用 MSYS2(推荐) MSYS2 提供 mingw-w64 的 ARM Linux 交叉编译工具链…

Python爬虫教程011:scrapy爬取当当网数据开启多条管道下载及下载多页数据

文章目录 3.6.4 开启多条管道下载3.6.5 下载多页数据3.6.6 完整项目下载3.6.4 开启多条管道下载 在pipelines.py中新建管道类(用来下载图书封面图片): # 多条管道开启 # 要在settings.py中开启管道 class DangdangDownloadPipeline:def process_item(self, item, spider):…

Mysql -- 基础

SQL SQL通用语法&#xff1a; SQL分类&#xff1a; DDL: 数据库操作 查询&#xff1a; SHOW DATABASES&#xff1b; 创建&#xff1a; CREATE DATABASE[IF NOT EXISTS] 数据库名 [DEFAULT CHARSET字符集] [COLLATE 排序规则]&#xff1b; 删除&#xff1a; DROP DATABA…

实操(环境变量)Linux

环境变量概念 我们用语言写的文件编好后变成了程序&#xff0c;./ 运行的时候他就会变成一个进程被操作系统调度并运行&#xff0c;运行完毕进程相关资源被释放&#xff0c;因为它是一个bash的子进程&#xff0c;所以它退出之后进入僵尸状态&#xff0c;bash回收他的退出结果&…

torch.cat和torch.stack的区别

torch.cat 和 torch.stack 是 PyTorch 中用于组合张量的两个常用函数&#xff0c;它们的核心区别在于输入张量的维度和输出张量的维度变化。以下是详细对比&#xff1a; 1. torch.cat (Concatenate) 作用&#xff1a;沿现有维度拼接多个张量&#xff0c;不创建新维度 输入要求…

深入解析@Validated注解:Spring 验证机制的核心工具

一、注解出处与核心定位 1. 注解来源 • 所属框架&#xff1a;Validated 是 Spring Framework 提供的注解&#xff08;org.springframework.validation.annotation 包下&#xff09;。 • 核心定位&#xff1a; 作为 Spring 对 JSR-380&#xff08;Bean Validation 2.0&#…

2025年认证杯数学建模竞赛A题完整分析论文(含模型、可运行代码)(共32页)

2025年认证杯数学建模竞赛A题完整分析论文 目录 摘要 一、问题分析 二、问题重述 三、模型假设 四、 模型建立与求解 4.1问题1 4.1.1问题1解析 4.1.2问题1模型建立 4.1.3问题1样例代码&#xff08;仅供参考&#xff09; 4.1.4问题1求解结果分析&#xff08…