从零构建属于自己的GPT系列2:模型训练1(预训练中文模型加载、中文语言模型训练、逐行代码解读)

🚩🚩🚩Hugging Face 实战系列 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在PyCharm中进行
本篇文章配套的代码资源已经上传

从零构建属于自己的GPT系列1:数据预处理
从零构建属于自己的GPT系列2:模型训练1
从零构建属于自己的GPT系列3:模型训练2
从零构建属于自己的GPT系列4:模型训练3

0 运行参数

指定运行配置参数后运行 :

–epochs 5
–batch_size 8
–device 0
–train_path data/train_novel.pkl
–save_model_path ./model/novel

1 参数设置

def set_args():parser = argparse.ArgumentParser()parser.add_argument('--device', default='0,1', type=str, required=False, help='设置使用哪些显卡')parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行训练')parser.add_argument('--vocab_path', default='vocab/chinese_vocab.model', type=str, required=False, help='sp模型路径')parser.add_argument('--model_config', default='config/cpm-small.json', type=str, required=False, help='需要从头训练一个模型时,模型参数的配置文件')parser.add_argument('--train_path', default='data/train.pkl', type=str, required=False, help='经过预处理之后的数据存放路径')parser.add_argument('--max_len', default=200, type=int, required=False, help='训练时,输入数据的最大长度')parser.add_argument('--log_path', default='log/train.log', type=str, required=False, help='训练日志存放位置')parser.add_argument('--ignore_index', default=-100, type=int, required=False, help='对于ignore_index的label token不计算梯度')parser.add_argument('--epochs', default=100, type=int, required=False, help='训练的最大轮次')parser.add_argument('--batch_size', default=16, type=int, required=False, help='训练的batch size')parser.add_argument('--gpu0_bsz', default=6, type=int, required=False, help='0号卡的batch size')parser.add_argument('--lr', default=1.5e-4, type=float, required=False, help='学习率')parser.add_argument('--eps', default=1.0e-09, type=float, required=False, help='AdamW优化器的衰减率')parser.add_argument('--log_step', default=10, type=int, required=False, help='多少步汇报一次loss')parser.add_argument('--gradient_accumulation_steps', default=6, type=int, required=False, help='梯度积累的步数')parser.add_argument('--max_grad_norm', default=1.0, type=float, required=False)parser.add_argument('--save_model_path', default='model', type=str, required=False, help='模型输出路径')parser.add_argument('--pretrained_model', default='model/zuowen_epoch40', type=str, required=False, help='预训练的模型的路径')parser.add_argument('--seed', type=int, default=1234, help='设置随机种子')parser.add_argument('--num_workers', type=int, default=0, help="dataloader加载数据时使用的线程数量")parser.add_argument('--warmup_steps', type=int, default=4000, help='warm up步数')args = parser.parse_args()return args

由于这里很多地方,在help中已经解释过意思,我只解释部分内容

  1. ‘–device’,如果只有单卡设置成0,多卡设置成0,1…
  2. ‘–no_cuda’,不使用GPU进行训练
  3. ‘–vocab_path’,中文预训练分词模型路径,这个模型是用来分词的,不是用来
  4. ‘–vocab_path’,模型就是用的cpm现成的,完全没有改
  5. ‘–max_len’,文本中的一句话,可能是指逗号或者句号隔开是一句话,但是当前的NLP任务中,是换行符后才是一句话,所以可能等到换行符的时候已经有几十行了 ,这里的max_len就是不管一句话多长,都按照200个词进行分割,就和逗号句号没有关系了,到一句话结束时,如果不到50词就不要了,有50词就加上一句话再补上0
  6. ‘–ignore_index’,-100表示在任务中,有一些特殊字符和一些没用的东西是不想要的,要忽略的ID是多少
  7. ‘–seed’,设置随机种子,设置随机种子在机器学习和深度学习中是非常重要的,在训练模型时,如果不设置随机种子,每次运行代码得到的模型参数初始化、数据集划分等都可能不同,导致实验结果的差异
  8. ‘–gradient_accumulation_steps’,梯度累加步数,正常情况下是一次迭代更新,但是可以攒几次,在pytorch中每次迭代完成后都需要进行一次梯度清零,实际上就相当于间接增加了batch_size,
  9. warmup_steps,刚开始缓慢训练,然后逐步增加训练速度,再然后再平稳训练,最后再进行学习率的衰减。

2 main()函数

def main():args = set_args()os.environ["CUDA_VISIBLE_DEVICES"] = args.deviceargs.cuda = not args.no_cudalogger = set_logger(args.log_path)args.cuda = torch.cuda.is_available() and not args.no_cudadevice = 'cuda:0' if args.cuda else 'cpu'args.device = devicelogger.info('using device:{}'.format(device))set_random_seed(args.seed, args.cuda)tokenizer = CpmTokenizer(vocab_file="vocab/chinese_vocab.model")args.eod_id = tokenizer.convert_tokens_to_ids("<eod>")  # 文档结束符args.pad_id = tokenizer.pad_token_idif not os.path.exists(args.save_model_path):os.mkdir(args.save_model_path)if args.pretrained_model:  # 加载预训练模型model = GPT2LMHeadModel.from_pretrained(args.pretrained_model)else:  # 初始化模型model_config = GPT2Config.from_json_file(args.model_config)model = GPT2LMHeadModel(config=model_config)model = model.to(device)logger.info('model config:\n{}'.format(model.config.to_json_string()))assert model.config.vocab_size == tokenizer.vocab_sizeif args.cuda and torch.cuda.device_count() > 1:model = BalancedDataParallel(args.gpu0_bsz, model, dim=0).cuda()logger.info("use GPU {} to train".format(args.device))num_parameters = 0parameters = model.parameters()for parameter in parameters:num_parameters += parameter.numel()logger.info('number of model parameters: {}'.format(num_parameters))logger.info("args:{}".format(args))# 加载训练集和验证集# ========= Loading Dataset ========= #train_dataset = load_dataset(logger, args)train(model, logger, train_dataset, args)
  1. main函数
  2. 初始化参数(命令行中已经制定好了参数)
  3. 训练设备、显卡参数
  4. 训练设备、显卡参数
  5. 创建日志对象
  6. 训练设备、显卡参数
  7. 训练设备、显卡参数
  8. 训练设备、显卡参数
  9. 训练设备、显卡参数加入日志
  10. 设置随机种子
  11. 读进来CpmTokenizer
  12. end_id=7,索引为7代表一个句子的终止符
  13. 添加padding的id,padding索引是5
  14. 如果保持模型的文件路径不存在
  15. 新建一个路径
  16. 加载预训练模型
  17. 指定的是一个GPT2的模型
  18. 如果没有模型
  19. 从json文件中导入配置
  20. 加载gpt2模型(也就是你给了预训练模型,就直接加载模型,没有就需要下载模型)
  21. 模型放入训练设备中
  22. 内存开始占用
  23. 在命令行中可以看到日志信息了
  24. 多卡训练
  25. 多卡训练
  26. 多卡训练
  27. 计算模型参数的变量
  28. 导入计算参数的函数
  29. 用for循环变量层
  30. 累加参数量
  31. 记录参数日志信息
  32. 记录参数设置
  33. 通过加载数据函数加载数据
  34. 通过训练函数训练模型

从零构建属于自己的GPT系列1:数据预处理
从零构建属于自己的GPT系列2:模型训练1
从零构建属于自己的GPT系列3:模型训练2
从零构建属于自己的GPT系列4:模型训练3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/200460.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

leetcode203. 移除链表元素

题目描述 给你一个链表的头节点 head 和一个整数 val &#xff0c;请你删除链表中所有满足 Node.val val 的节点&#xff0c;并返回 新的头节点 。 示例 1&#xff1a; 输入&#xff1a;head [1,2,6,3,4,5,6], val 6 输出&#xff1a;[1,2,3,4,5]示例 2&#xff1a; 输入…

NSS [NSSCTF 2022 Spring Recruit]babyphp

NSS [NSSCTF 2022 Spring Recruit]babyphp 考点&#xff1a;PHP特性 开局源码直接裸奔 <?php highlight_file(__FILE__); include_once(flag.php);if(isset($_POST[a])&&!preg_match(/[0-9]/,$_POST[a])&&intval($_POST[a])){if(isset($_POST[b1])&&…

Git篇如何在自己服务器搭建自己的git私有仓库

要在自己的服务器上搭建自己的Git私有仓库&#xff0c;可以按照以下步骤进行操作&#xff1a; 安装Git服务器软件&#xff1a;选择一款适合的Git服务器软件&#xff0c;如GitLab或GitHub&#xff0c;并按照官方文档进行安装和配置。创建数据库&#xff1a;如果使用GitLab&…

Doris 集成 ElasticSearch

Doris-On-ES将Doris的分布式查询规划能力和ES(Elasticsearch)的全文检索能力相结合,提供更完善的OLAP分析场景解决方案: (1)ES中的多index分布式Join查询 (2)Doris和ES中的表联合查询,更复杂的全文检索过滤 1 原理 (1)创建ES外表后,FE会请求建表指定的主机,获取所有…

MATLAB算法实战应用案例精讲-【图像处理】边缘检测(补充篇)(附MATLAB代码实现)

目录 前言 几个相关概念 知识储备 数字图像处理(Digital Image Processing)

Qt 软件调试——windbg初篇(一)

在上一篇《Qt 软件调试&#xff08;二&#xff09;使用dump捕获崩溃信息》中我们结尾处提示大家先准备好windbg&#xff0c;windbg是非常强大的调试工具&#xff0c;对于我们进行代码调试和分析异常有着非常重要的意义。 在Qt软件调试这个系列的首篇&#xff0c;我们介绍了《Qt…

RPG项目01_层级设置

基于“RPG项目01_UI面板Game”&#xff0c; 找到狼人 添加组件&#xff0c;让狼人一定区域自动跟随主角进行攻击 解释&#xff1a;【烘培蓝色】因为如果什么都不做就会被烘培成蓝色对应的功能就是 可修改区域功能 当将区域设置成不可行走状态&#xff0c;则不为蓝色 烘培&…

手机备忘录在哪里找出来?

谈及手机备忘录&#xff0c;每一个品牌的手机大家都能找到很多&#xff0c;现在各大手机品牌都开发的有自带的手机备忘录&#xff0c;所以说&#xff1a;手机备忘录在哪里找出来并不难&#xff0c;即便是手机自带的没有备忘录工具&#xff0c;大家也是可以通过手机应用市场搜索…

在AWS EC2中部署和使用Apache Superset的方案

大纲 1 Superset部署1.1 启动AWS EC21.2 下载Superset Docker文件1.3 修改Dockerfile1.4 配置管理员1.5 结果展示1.6 检查数据库驱动1.7 常见错误处理 2 Glue&#xff08;可选参考&#xff09;3 IAM与安全组3.1 使用AWS Athena3.2 使用AWS RedShift或AWS RDS3.2.1 查看AWS Reds…

【电子取证篇】汽车取证数据提取与汽车取证实例浅析(附标准下载)

【电子取证篇】汽车取证数据提取与汽车取证实例浅析&#xff08;附标准下载&#xff09; 关键词&#xff1a;汽车取证&#xff0c;车速鉴定、声像资料鉴定、汽车EDR提取分析 汽车EDR一般记录车辆碰撞前后的数秒&#xff08;5s左右&#xff09;相关数据&#xff0c;包括车辆速…

Redis击穿(热点key失效)

Redis击穿是指在高并发情况下&#xff0c;一个键在缓存中过期失效时&#xff0c;同时有大量请求访问该键&#xff0c;导致所有请求都落到数据库上&#xff0c;对数据库造成压力。这种情况下&#xff0c;数据库可能无法及时处理这些请求&#xff0c;导致性能下降甚至崩溃。 为了…

熟悉tomcat的哪些配置?

Tomcat是一种常用的Java Web服务器&#xff0c;它提供了许多配置选项来控制其行为和性能。以下是一些常见的Tomcat配置&#xff1a; 端口配置&#xff1a;你可以配置Tomcat监听的端口号&#xff0c;通常用于指定HTTP和HTTPS服务的端口。连接池配置&#xff1a;Tomcat的连接池可…

基于openEuler20.03安装openGauss5.0.0及安装DBMind

基于openEuler20.03安装openGauss5.0.0及安装DBMind 一、环境说明二、安装部署三、问题及解决 一、环境说明 虚拟机&#xff1a;VirtualBox操作系统&#xff1a;openEuler20.3LTS &#xff08;x86&#xff09;数据库&#xff1a;openGauss5.0.0 (x86)DBMind&#xff1a;dbmind…

Pytest自动化测试数据驱动yaml/excel/csv/json

数据驱动 数据的改变从而驱动自动化测试用例的执行&#xff0c;最终引起测试结果的改变。简单说就是参数化的应用。 测试驱动在自动化测试中的应用场景&#xff1a; 测试步骤的数据驱动&#xff1b;测试数据的数据驱动&#xff1b;配置的数据驱动&#xff1b; 1、pytest结合…

Linux gtest单元测试

1 安装git sudo apt-get install git2 下载googletest git clone https://github.com/google/googletest.git3 安装googletest 注意1: 如果在 make 过程中报错,可在 CMakeLists.txt 中增加如下行,再执行下面的命令: SET(CMAKE_CXX_FLAGS “-std=c++11”) 注意2: CMakeLists…

Django回顾6

目录 一.Session 1.什么是Session 2.Django中Session相关方法 3.Django中的Session配置 二.中间件 1.什么是中间件 中间件的定义 2.中间件有什么用 3.自定义中间件 process_request和process_reponse &#xff08;1&#xff09;导入 &#xff08;2&#xff09;自定义…

5G常用简称

名称缩写全称缓冲区状态报告BSRBuffer Status Report&#xff08;主小区组MCGMaster Cell groupMCG的节点MNMasternode主小区PCellPrimary Cell&#xff0c;功率余量PHRPower Headroom Report主辅小区PSCellPrimary Secondary CellSCG的节点SNSecondarynode辅小区SCellSecondar…

centos安装node 、npm 、nvm

你好&#xff0c;这是Bing。我可以帮你用nodejs写一个http服务器。&#x1f60a; 根据我的搜索结果&#xff0c;你需要使用 require 指令来加载和引入 http 模块&#xff0c;然后使用 http.createServer 方法来创建一个服务器实例&#xff0c;最后使用 listen 方法来监听一个端…

优化您的Mac体验——System Dashboard Pro for Mac(系统仪表板)

作为Mac用户&#xff0c;我们都希望能够拥有一个高效、流畅的电脑体验。然而&#xff0c;在长时间使用后&#xff0c;我们的Mac可能会变得越来越慢&#xff0c;导致我们的工作效率下降。这时候&#xff0c;System Dashboard Pro for Mac(系统仪表板)就可以派上用场了。它是一款…

JAVA常见问题解答:解决Java 11新特性兼容性问题的六个步骤

引言&#xff1a; 随着技术的不断发展&#xff0c;Java作为一种被广泛使用的编程语言&#xff0c;也在不断更新和改进。Java 11作为Java的最新版本&#xff0c;带来了许多新的特性和改进。然而&#xff0c;对于一些老旧的Java应用程序来说&#xff0c;升级到Java 11可能会带来一…