一起学Hugging Face Transformers(14)- “自定义训练循环”问题解答

文章目录

  • 前言
  • 问题一:可以详解下面这两段代码吗?
    • 1. 训练步骤的计算
    • 2. 学习率调度器的定义
    • 3. 作用总结
    • 4. 示例详细解释
  • 问题二:学习率是什么
      • 学习率的重要性
      • 例子
      • 学习率调度器
      • 学习率调度策略
      • 示例代码
  • 问题三:什么是 num_warmup_steps 预热步数呢
      • 预热步数的作用
      • 如何设置预热步数
  • 问题四:总训练步数num_training_steps 应该怎么确定呢
    • 1. 数据集大小和批次大小
    • 2. 训练轮数(Epochs)
    • 3. 学习率调度器的策略
    • 示例计算
    • 4. 小结
  • 问题六:训练轮数是怎么确定的呢
    • 1. 训练收敛性
    • 2. 训练时间和计算资源
    • 3. 训练效果监控
    • 4. 经验和实验
    • 5. 示例
  • 问题七:训练轮数在哪里配置


前言

前一篇文章 自定义训练循环 收到不少提问,在这里统一解答一下。

问题一:可以详解下面这两段代码吗?

num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

1. 训练步骤的计算

num_training_steps = num_epochs * len(train_dataloader)
  • num_epochs:训练的轮数(epoch),表示整个训练数据集将被迭代多少次。
  • len(train_dataloader):训练数据加载器中的批次数(batch),表示一个epoch中有多少个批次。

这段代码的目的是计算训练的总步骤数(total training steps),即训练过程中将执行的前向和反向传播步骤的总次数。这是通过将每个epoch中的批次数与epoch数相乘得到的。这个值在设置学习率调度器时很重要。

2. 学习率调度器的定义

lr_scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)
  • get_scheduler:这是一个用于获取学习率调度器的函数。学习率调度器(Learning Rate Scheduler)在训练过程中调整学习率,以便更好地控制模型的优化过程。
  • name=“linear”:指定调度器的类型为线性调度器(linear scheduler),表示学习率将在训练期间线性地从初始值下降到最终值。
  • optimizer=optimizer:指定优化器,调度器将与这个优化器一起工作。
  • num_warmup_steps=0:指定学习率预热步骤(warmup steps)的数量。在预热阶段,学习率从0逐渐增加到初始值。这里设置为0,表示没有预热阶段。
  • num_training_steps=num_training_steps:指定训练的总步骤数,即之前计算的num_training_steps。这告诉调度器整个训练过程中有多少步,这样它就能在训练过程中正确地调整学习率。

3. 作用总结

  • 计算总训练步骤数:这是为了让学习率调度器知道整个训练过程的总步数,以便合理地调整学习率。
  • 定义学习率调度器:学习率调度器根据总训练步骤数和指定的调度策略(如线性下降)来调整优化器的学习率,从而改善模型的训练效果。

4. 示例详细解释

假设我们有一个数据集,通过train_dataloader可以得到每个epoch中的批次数是500,我们打算训练3个epoch。则总的训练步骤数为:

num_epochs = 3
len(train_dataloader) = 500num_training_steps = num_epochs * len(train_dataloader)
num_training_steps = 3 * 500 = 1500

然后我们定义一个线性学习率调度器,这个调度器将在训练的1500步内从初始学习率线性下降到最终学习率。

lr_scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=1500
)

这样,在训练过程中,调度器将根据当前训练步数调整学习率,从而可能提高训练稳定性和模型的最终性能。

问题二:学习率是什么

学习率(Learning Rate)是深度学习模型训练中的一个超参数,它决定了在每次迭代更新模型参数时步长的大小。具体来说,学习率控制了每次参数更新的幅度,从而影响模型的收敛速度和训练稳定性。

学习率的重要性

  • 步长大小

    • 较高的学习率:如果学习率过高,模型参数更新的步长过大,可能会导致训练过程不稳定,甚至无法收敛,模型的损失函数会在高值和低值之间大幅波动。
    • 较低的学习率:如果学习率过低,模型参数更新的步长过小,训练过程会非常缓慢,可能需要更多的迭代次数才能达到收敛。此外,过低的学习率可能会使模型陷入局部最优,无法找到全局最优解。
  • 收敛速度和稳定性

    • 一个适当的学习率可以加速训练过程,使模型更快达到较好的性能,同时保持训练过程的稳定性,避免震荡或发散。

例子

假设你在训练一个神经网络,损失函数(Loss Function)用于衡量模型预测值与真实值之间的差距。学习率决定了每次迭代中参数调整的幅度。

  • 公式
    • 更新后的参数 = 当前参数 - 学习率 × 梯度

学习率调度器

在训练过程中,固定的学习率可能并不能始终有效。为此,通常会使用学习率调度器(Learning Rate Scheduler)来动态调整学习率。例如,开始时使用较高的学习率,加速模型的训练;在训练的中后期逐步降低学习率,以精细调整模型参数,提高模型的最终性能。

学习率调度策略

  1. 固定学习率:整个训练过程中保持不变。
  2. 阶梯式衰减(Step Decay):每隔一定的迭代次数将学习率降低一定比例。
  3. 指数衰减(Exponential Decay):学习率按指数规律逐步衰减。
  4. 余弦退火(Cosine Annealing):学习率按照余弦函数曲线进行变化,通常在训练后期逐步减小。
  5. 线性调度(Linear Scheduler):学习率在整个训练过程中线性下降。

示例代码

在使用Hugging Face Transformers库进行训练时,可以使用学习率调度器来调整学习率。以下是一个简化的示例:

from transformers import get_scheduler
import torch# 假设我们使用Adam优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)# 设置训练的总步数和预热步数
num_training_steps = 1000  # 总步数
num_warmup_steps = 100  # 预热步数# 使用线性学习率调度器
lr_scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
)# 训练循环
for epoch in range(num_epochs):for step, batch in enumerate(train_dataloader):# 前向传播和损失计算outputs = model(**batch)loss = outputs.loss# 反向传播和参数更新loss.backward()optimizer.step()lr_scheduler.step()optimizer.zero_grad()

在这个例子中,我们使用了一个线性学习率调度器。在预热阶段(前100步),学习率从0逐步增加到初始学习率(5e-5),之后在整个训练过程中线性下降,直到训练结束。这种方法有助于加速前期的训练并在后期进行更精细的调整。

问题三:什么是 num_warmup_steps 预热步数呢

在深度学习中,特别是在使用学习率调度器(Learning Rate Scheduler)时,预热步数(num_warmup_steps)是指在训练初期逐步增加学习率的步数。预热步数的设定可以帮助模型在训练初期更快地找到合适的参数配置,从而加速收敛过程。

预热步数的作用

  1. 加速收敛:在训练初期,模型参数通常处于较差的初始化状态。通过逐步增加学习率,可以帮助模型更快地适应训练数据,加速模型参数的调整和优化,从而加快收敛速度。

  2. 稳定训练:预热步数还有助于确保训练过程的稳定性。过早的高学习率可能导致模型参数更新过大,从而影响训练的稳定性和收敛性。通过逐步增加学习率,可以在训练初期避免这种问题。

如何设置预热步数

预热步数通常是一个超参数,需要根据具体的任务和模型进行调整。一般来说,预热步数的设置不宜过长或过短,一般占总训练步数的一小部分,比如总步数的5%到20%之间。

在使用学习率调度器时,可以通过设置 num_warmup_steps 参数来指定预热步数。例如,在 Hugging Face Transformers 中,使用 get_scheduler 函数设置线性调度器时可以指定预热步数:

from transformers import get_scheduler
import torchoptimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
num_training_steps = 1000  # 总步数
num_warmup_steps = 100  # 预热步数lr_scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
)

在这个例子中,num_warmup_steps=100 表示在训练的前100步中,学习率将从0逐步增加到设定的初始学习率。这样做有助于平稳地启动训练过程,并为模型提供足够的时间适应训练数据。

问题四:总训练步数num_training_steps 应该怎么确定呢

确定总训练步数 num_training_steps 在深度学习中是非常重要的,它直接影响到模型的训练时长和效果。通常,确定总训练步数需要考虑以下几个因素:

1. 数据集大小和批次大小

  • 数据集大小:首先要考虑数据集中样本的总数,因为一个完整的训练步骤涉及整个数据集的多次迭代。

  • 批次大小:每个训练步骤中处理的样本数量。批次大小越大,每个epoch中的训练步骤就越少,总训练步数也会相应减少。

2. 训练轮数(Epochs)

  • 训练轮数:通常情况下,我们会设定一个训练的轮数,每个epoch表示将整个数据集训练一遍的次数。总训练步数应该考虑到每个epoch的批次数和训练轮数的乘积。

3. 学习率调度器的策略

  • 学习率调度器:如果使用了学习率调度器,需要确保总训练步数能够覆盖整个学习率策略所需的步数。例如,如果使用了一个线性调度器,就需要知道需要多少步来完全降低学习率。

示例计算

假设以下情况:

  • 数据集大小:10000个样本
  • 批次大小:32
  • 训练轮数:3个epoch
  • 学习率调度器:线性调度器,从初始学习率到最终学习率的过渡需要1000步

计算总训练步数的步骤如下:

1) 计算每个epoch中的训练步数:

num_batches_per_epoch = len(dataset) / batch_size

其中,len(dataset) 是数据集中的样本数量,batch_size 是每个批次中的样本数量。

假设 len(dataset) = 10000batch_size = 32,则:

num_batches_per_epoch = 10000 / 32 = 312.5

因为批次大小必须是整数,所以每个epoch中的实际批次数为 312

2) 计算总训练步数:

total_training_steps = num_batches_per_epoch * num_epochs

假设 num_epochs = 3,则:

total_training_steps = 312 * 3 = 936

3) 考虑学习率调度器:

如果还有学习率调度器,在计算总训练步数时要确保它能够完整执行其策略所需的步骤。例如,如果需要额外的1000步来完成学习率从初始到最终的过渡,则总训练步数应为:

total_training_steps = total_training_steps + 1000

4. 小结

确定总训练步数需要结合数据集大小、批次大小、训练轮数和任何使用的学习率调度器策略。这样可以确保训练过程充分覆盖所有数据,并根据需要调整学习率以优化模型的训练效果。

问题六:训练轮数是怎么确定的呢

确定训练轮数(epochs)通常需要考虑以下几个因素:

1. 训练收敛性

训练轮数应该足够多,使得模型能够在训练过程中逐渐收敛到一个较好的状态。一般来说,随着训练轮数的增加,模型的性能(如损失函数的减少、精度的提高)会逐步稳定。

2. 训练时间和计算资源

训练轮数的增加会导致训练时间的增加,尤其是在数据集较大或模型复杂的情况下。因此,需要在训练时间和计算资源之间进行权衡,选择一个合适的训练轮数。

3. 训练效果监控

可以通过监控训练过程中的指标变化来确定是否需要增加训练轮数。例如,可以观察损失函数的下降曲线是否趋于平稳,或者验证集上的性能是否达到了一个稳定的水平。

4. 经验和实验

通常,选择训练轮数也具有一定的经验性和试验性质。可以先尝试一些常见的训练轮数,如10、20轮,然后观察模型的表现。根据实际情况调整训练轮数,以达到最佳的训练效果。

5. 示例

假设你正在训练一个图像分类模型,通常情况下,可以按照以下步骤来确定训练轮数:

1) 初始尝试:开始时,可以选择一个相对较小的训练轮数(如5轮),观察模型在训练集和验证集上的表现。

2) 监控训练进展:通过训练过程中的损失函数变化、准确率等指标来评估模型的收敛情况。如果模型在几轮训练后仍在改善,可以继续增加训练轮数。

3) 早停策略:一种常见的做法是使用早停策略(Early Stopping),即当验证集上的性能不再提升时停止训练。这样可以避免过度拟合,并节省计算资源。

4) 超参数调整:在确定训练轮数的同时,还应该考虑其他超参数(如学习率、批次大小等)的调整,以优化模型的训练效果。

总之,确定训练轮数是一个根据实际情况进行调整和优化的过程,需要综合考虑模型的收敛速度、训练时间和计算资源等因素,以达到最佳的训练效果。

问题七:训练轮数在哪里配置

训练轮数通常在训练代码或训练脚本中进行配置和设定。具体来说,它可能涉及以下几个方面的设置:

  1. 训练循环中的epoch设置:在训练代码中,通常会有一个循环来迭代每个epoch。这个循环会设定训练轮数的具体次数。例如,在Python中的训练代码中可能会有如下形式的循环结构:

    num_epochs = 10  # 设置训练轮数为10
    for epoch in range(num_epochs):# 在每个epoch中执行训练步骤for batch in train_dataloader:# 训练逻辑
    

    这里的 num_epochs 变量就是设定的训练轮数。

  2. 训练脚本的参数配置:有时候训练脚本会接受命令行参数或配置文件来设定训练的相关参数,包括训练轮数。例如,可以通过命令行参数或配置文件中设置一个参数来指定训练轮数。

  3. 训练参数对象或类的属性:如果使用面向对象的方式编写训练代码,训练参数(如训练轮数)可能会作为对象的属性。在创建训练实例时,可以通过修改这些属性来设定训练轮数。

  4. 集成开发环境(IDE)或集成训练平台:一些集成开发环境或训练平台(如TensorBoard、PyTorch Lightning等)可能会提供图形化界面或配置文件来设置训练的各种参数,包括训练轮数。

具体到实际使用的情况,可以根据所用的深度学习框架和编程风格来确定如何设置和配置训练轮数。在编写和调试训练代码时,确保训练轮数的设置能够满足任务的需求,并在实验过程中根据训练结果进行必要的调整。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/43208.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Mysql Workbench的使用

本篇内容:对Mysql Workbench的常规使用学习 一、知识储备 1. Workbench 可以做什么 是mysql数据库可视化管理的一款免费工具,除了平常的通过sql语句,进行创建数据库表、增删改查外,还可以利用其进行建模创建数据库表。通过创建…

域名注册后还需要做什么?

在建立网站或在线业务时,域名注册是一个非常重要的步骤。但是,仅仅注册一个域名还不足以让您的网站或在线业务成功运营。在域名注册后,还需要进行一系列的步骤来确保您的网站能够正常运行,并吸引到访者。本文将介绍域名注册后的必…

人工智能的新时代:从模型到应用的转变

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

【Linux】记录一起网站劫持事件

故事很短,处理也简单。权当记录一下,各位安全大大们手下留情。 最近一位客户遇到官网被劫持的情况,想我们帮忙解决一下(本来不关我们的事,毕竟情面在这…还是无偿地协助一下),经过三四轮“谦让…

Conda修改默认环境创建路径

conda安装好后默认将新建环境安装在C盘 修改.condarc 配置文件 注 : Windows操作系统创建的 .condarc 文件通常在 C:\Users\User_name 这个目录下; 注 : Linux操作系统创建的 .condarc 文件通常在/home/User_name 这个目录下。 在.condarc文件中添加以下内容 有…

海康威视监控web实时预览解决方案

海康威视摄像头都试rtsp流,web页面无法加载播放,所以就得转换成web页面可以播放的hls、rtmp等数据流来播放。 一:萤石云 使用萤石云平台,把rtsp转化成ezopen协议,然后使用组件UIKit 最佳实践 萤石开放平台API文档 …

【ROS2】中级-编写动作服务器和客户端(Python)

目标:用 Python 实现一个动作服务器和客户端。 教程级别:中级 时间:15 分钟 目录 背景 先决条件 任务 1. 编写动作服务器2. 编写动作客户端 摘要 相关内容 背景 动作是 ROS 2 中异步通信的一种形式。动作客户端向动作服务器发送目标请求。动作…

SpringBoot整合MongoDB文档相关操作

文章目录 SpringBoot整合MongoDB文档操作添加文档查询文档更新文档删除文档 SpringBoot整合MongoDB 创建项目&#xff0c;添加依赖&#xff0c;配置连接 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-dat…

Python 数据容器的对比

五类数据容器 列表&#xff0c;元组&#xff0c;字符串&#xff0c;集合&#xff0c;字典 是否能下标索引 支持&#xff1a;列表&#xff0c;元组&#xff0c;字符串 不支持&#xff1a;集合&#xff0c;字典 是否能放重复元素 是&#xff1a;列表&#xff0c;元组&#…

遥感分类产品精度验证之TIF验证TIF

KKB_2020.tif KKB_2020_JRC.tif kkb.geojson 所用到的包&#xff1a;&#xff08;我嫌geopandas安装太麻烦colab做的。。 import rasterio import geopandas as gpd import numpy as np import pandas as pd import matplotlib.pyplot as plt from sklearn.metrics import c…

【零基础】学JS之APIS(基于黑马)

喝下这碗鸡汤 披盔戴甲,一路勇往直前! 1. 什么是事件 事件是在编程时系统内发生的动作或者发生的事情 比如用户在网页上单击一个按钮 2. 什么是事件监听? 就是让程序检测是否有事件产生&#xff0c;一旦有事件触发&#xff0c;就立即调用一个函数做出响应&#xff0c;也称为 注…

MySQL怎么获取当前时间

在 MySQL 中&#xff0c;您可以使用以下几种方式获取当前时间&#xff1a; 使用 NOW() 函数&#xff1a; SELECT NOW();NOW() 函数返回当前的日期和时间&#xff0c;格式为 YYYY-MM-DD HH:MM:SS 。 使用 CURRENT_TIMESTAMP 函数&#xff1a; SELECT CURRENT_TIMESTAMP;其效果与…

如何用java语言开发一套数字化产科系统 数字化产科管理平台源码

如何用java语言开发一套数字化产科系统 数字化产科管理平台源码 要使用Java语言来开发一个数字化产科系统&#xff0c;你需要遵循一系列步骤&#xff0c;从环境搭建到系统设计与开发&#xff0c;再到测试与部署。 以下是一个大致的开发流程概览&#xff1a; 1. 环境搭建 Jav…

从Docker 网络看IaC

【引子】近来&#xff0c;老码农又一次有机会实施IaC 了&#xff0c; 但是环境有了新的变化&#xff0c;涵盖了云环境、虚拟机、K8S 以及Docker&#xff0c;而网络自动化则是IaC中的重要组成&#xff0c;温故知新&#xff0c;面向Docker 的网络是怎样的呢&#xff1f; Docker …

C++相关概念和易错语法(16)(list)

1.list易错点 &#xff08;1&#xff09;慎用list的sort&#xff0c;list的排序比vector慢得多&#xff0c;尽管两者时间复杂度一样&#xff0c;甚至不如先把list转为vector&#xff0c;用vector排完序后再转为list &#xff08;2&#xff09;splice是剪切链表&#xff0c;将…

指数增长远大于nlgn

在学习算法导论的时候&#xff0c;遇到了这么一行字把我难住了。我不理解为什么叶节点代价总和就为Ω(nlgn)了&#xff0c;后来经过学习之后了解了&#xff0c;因为n的指数严格大于1&#xff0c;只要指数函数的指数大于1就是指数增长&#xff0c;那么就远大于nlgn。

C++ | Leetcode C++题解之第22题完全二叉树的节点个数

题目&#xff1a; 题解&#xff1a; class Solution { public:int countNodes(TreeNode* root) {if (root nullptr) {return 0;}int level 0;TreeNode* node root;while (node->left ! nullptr) {level;node node->left;}int low 1 << level, high (1 <&…

【笔记】finalshell中使用nano编辑器GNU

ctrl O 保存 enter 确定 ctrl X 退出 nano编辑 能不用就不用吧 因为我真用不习惯 nano编辑的文件也可以用vim编辑的

Social to Sales全链路,数说故事专享会开启出海新视角

————瞎出海&#xff0c;必出局 TikTok&#xff0c;这个充满活力的短视频平台&#xff0c;已经成为全球范围内不可忽视的电商巨头。就在6月8日&#xff0c;TikTok美区带货直播诞生了首个“百万大场”。在此之前&#xff0c;百万GMV被视为一道难以逾越的高墙。以TikTok为首的…

224. 基本计算器

给你一个字符串表达式 s &#xff0c;请你实现一个基本计算器来计算并返回它的值。 注意:不允许使用任何将字符串作为数学表达式计算的内置函数&#xff0c;比如 eval() 。 示例 1&#xff1a; 输入&#xff1a;s "1 1" 输出&#xff1a;2示例 2&#xff1a; 输入…