pytorch08:学习率调整策略

在这里插入图片描述

目录

  • 一、为什么要调整学习率?
    • 1.1 class _LRScheduler
  • 二、pytorch的六种学习率调整策略
    • 2.1 StepLR
    • 2.2 MultiStepLR
    • 2.3 ExponentialLR
    • 2.4 CosineAnnealingLR
    • 2.5 ReduceLRonPlateau
    • 2.6 LambdaLR
  • 三、学习率调整小结
  • 四、学习率初始化

一、为什么要调整学习率?

学习率(learning rate):控制更新的步伐
一般在模型训练过程中,在开始训练的时候我们会设置学习率大一些,随着模型训练epoch的增加,学习率会逐渐设置小一些。

1.1 class _LRScheduler

学习率调整的父类函数
在这里插入图片描述
主要属性:
• optimizer:关联的优化器
• last_epoch:记录epoch数
• base_lrs:记录初始学习率
主要方法:
• step():更新下一个epoch的学习率,该操作必须放到epoch循环下面
• get_lr():虚函数,计算下一个epoch的学习率

二、pytorch的六种学习率调整策略

2.1 StepLR

在这里插入图片描述

功能:等间隔调整学习率
主要参数:
• step_size:调整间隔数
• gamma:调整系数
调整方式:lr = lr * gamma

代码实现:

import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plttorch.manual_seed(1)LR = 0.1
iteration = 10
max_epoch = 200
# ------------------------------ fake data and optimizer  ------------------------------weights = torch.randn((1), requires_grad=True)
target = torch.zeros((1))optimizer = optim.SGD([weights], lr=LR, momentum=0.9)# ------------------------------ 1 Step LR ------------------------------
# flag = 0
flag = 1
if flag:scheduler_lr = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)  # 设置学习率下降策略,50轮下降一次,每次下降10倍lr_list, epoch_list = list(), list()for epoch in range(max_epoch):lr_list.append(scheduler_lr.get_lr())epoch_list.append(epoch)for i in range(iteration):loss = torch.pow((weights - target), 2)loss.backward()optimizer.step()optimizer.zero_grad()scheduler_lr.step()  # 学习率更新策略plt.plot(epoch_list, lr_list, label="Step LR Scheduler")plt.xlabel("Epoch")plt.ylabel("Learning rate")plt.legend()plt.show()

输出结果:
在这里插入图片描述

因为我们设置每50个epoch降低一次学习率,所以在7774554

2.2 MultiStepLR

在这里插入图片描述

功能:按给定间隔调整学习率
主要参数:
• milestones:设定调整时刻数
• gamma:调整系数
调整方式:lr = lr * gamma

代码实现

flag = 1
if flag:milestones = [50, 125, 160]  # 设置学习率下降的位置scheduler_lr = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)lr_list, epoch_list = list(), list()for epoch in range(max_epoch):lr_list.append(scheduler_lr.get_lr())epoch_list.append(epoch)for i in range(iteration):loss = torch.pow((weights - target), 2)loss.backward()optimizer.step()optimizer.zero_grad()scheduler_lr.step()plt.plot(epoch_list, lr_list, label="Multi Step LR Scheduler\nmilestones:{}".format(milestones))plt.xlabel("Epoch")plt.ylabel("Learning rate")plt.legend()plt.show()

输出结果
在这里插入图片描述

根据我们设置milestones = [50, 125, 160],发现学习率在这三个地方发生下降。

2.3 ExponentialLR

在这里插入图片描述

功能:按指数衰减调整学习率
主要参数:
• gamma:指数的底
调整方式:lr = lr * gamma^epoch;这里的gamma通常设置为接近1的数值,例如:0.95

代码实现

flag = 1
if flag:gamma = 0.95scheduler_lr = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)lr_list, epoch_list = list(), list()for epoch in range(max_epoch):lr_list.append(scheduler_lr.get_lr())epoch_list.append(epoch)for i in range(iteration):loss = torch.pow((weights - target), 2)loss.backward()optimizer.step()optimizer.zero_grad()scheduler_lr.step()plt.plot(epoch_list, lr_list, label="Exponential LR Scheduler\ngamma:{}".format(gamma))plt.xlabel("Epoch")plt.ylabel("Learning rate")plt.legend()plt.show()

输出结果
在这里插入图片描述

可以发现学习率是呈指数下降的。

2.4 CosineAnnealingLR

在这里插入图片描述

功能:余弦周期调整学习率
主要参数:
• T_max:下降周期
• eta_min:学习率下限
调整方式:
在这里插入图片描述

代码实现

flag = 1
if flag:t_max = 50scheduler_lr = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max, eta_min=0.)lr_list, epoch_list = list(), list()for epoch in range(max_epoch):lr_list.append(scheduler_lr.get_lr())epoch_list.append(epoch)for i in range(iteration):loss = torch.pow((weights - target), 2)loss.backward()optimizer.step()optimizer.zero_grad()scheduler_lr.step()plt.plot(epoch_list, lr_list, label="CosineAnnealingLR Scheduler\nT_max:{}".format(t_max))plt.xlabel("Epoch")plt.ylabel("Learning rate")plt.legend()plt.show()

输出结果
在这里插入图片描述

2.5 ReduceLRonPlateau

在这里插入图片描述

功能:监控指标,当指标不再变化则调整,例如:可以监控我们的loss或者准确率,当其不发生变化的时候,调整学习率。
主要参数:
• mode:min/max 两种模式
min模式:当某一个值不下降的时候我们调整学习率,通常用于监控损失
max模型:当某一个值不上升的时候我们调整学习率,通常用于监控精确度
• factor:调整系数
• patience:“耐心”,接受几次不变化
• cooldown:“冷却时间”,停止监控一段时间
• verbose:是否打印日志
• min_lr:学习率下限
• eps:学习率衰减最小值

代码实现

flag = 1
if flag:loss_value = 0.5accuray = 0.9factor = 0.1  # 学习率变换参数mode = "min"patience = 10  # 能接受多少轮不变化cooldown = 10  # 停止监控多少轮min_lr = 1e-4  # 设置学习率下限verbose = True  # 打印更新日志scheduler_lr = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=factor, mode=mode, patience=patience,cooldown=cooldown, min_lr=min_lr, verbose=verbose)for epoch in range(max_epoch):for i in range(iteration):# train(...)optimizer.step()optimizer.zero_grad()#if epoch == 5:# loss_value = 0.4scheduler_lr.step(loss_value) #监控的标量是否下降

输出结果
在这里插入图片描述

2.6 LambdaLR

在这里插入图片描述
功能:自定义调整策略
主要参数:
• lr_lambda:function or list

代码实现

flag = 1
if flag:lr_init = 0.1weights_1 = torch.randn((6, 3, 5, 5))weights_2 = torch.ones((5, 5))optimizer = optim.SGD([{'params': [weights_1]},{'params': [weights_2]}], lr=lr_init)# 设置两种不同的学习率调整方法lambda1 = lambda epoch: 0.1 ** (epoch // 20)  # 每到20轮的时候学习率变为原来的0.1倍lambda2 = lambda epoch: 0.95 ** epoch  # 将学习率进行指数下降scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])lr_list, epoch_list = list(), list()for epoch in range(max_epoch):for i in range(iteration):# train(...)optimizer.step()optimizer.zero_grad()scheduler.step()lr_list.append(scheduler.get_lr())epoch_list.append(epoch)print('epoch:{:5d}, lr:{}'.format(epoch, scheduler.get_lr()))plt.plot(epoch_list, [i[0] for i in lr_list], label="lambda 1")plt.plot(epoch_list, [i[1] for i in lr_list], label="lambda 2")plt.xlabel("Epoch")plt.ylabel("Learning Rate")plt.title("LambdaLR")plt.legend()plt.show()

输出结果
在这里插入图片描述

通过lambda方法定义了两种不同的学习率下降策略。

三、学习率调整小结

  1. 有序调整:Step、MultiStep、Exponential 和 CosineAnnealing
  2. 自适应调整:ReduceLROnPleateau
  3. 自定义调整:Lambda

四、学习率初始化

1、设置较小数:0.01、0.001、0.0001
2、搜索最大学习率: 参考该篇《Cyclical Learning Rates for Training Neural Networks》
方法:我们可以设置学习率逐渐从小变大观察精确度的一个变化,下面这幅图,当学习率为0.055左右的时候模型精确度最高,当学习率大于0.055的时候精确度出现下降情况,所以在模型训练过程中我们可以设置学习率为0.055作为我们的初始学习率。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/600288.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Leetcode2965. 找出缺失和重复的数字

Every day a Leetcode 题目来源:2965. 找出缺失和重复的数字 解法1:哈希 用哈希表统计数组 grid 中各元素的出现次数,其中出现次数为 2 的记为 a。 统计数组 grid 的元素之和为 sum。 数组 grid 其中的值在 [1, n2] 范围内,…

【AWS系列】巧用 G5g 畅游Android流媒体游戏

序言 Amazon EC2 G5g 实例由 AWS Graviton2 处理器提供支持,并配备 NVIDIA T4G Tensor Core GPU,可为 Android 游戏流媒体等图形工作负载提供 Amazon EC2 中最佳的性价比。它们是第一个具有 GPU 加速功能的基于 Arm 的实例。 借助 G5g 实例,游…

06.JAVA常用类方法(StringBuilder与StringBuffer、Arrays与Array比较和常用方法+代码示例)

StringBuilder 基本介绍 概述:可变的字符串类,可视为一个容器,StringBuilder对象中的内容是可变的,也可称之为字符串缓冲类。其常用的成员方法根据有无返回值可以分为查询和删改两大类。 构造方法:StringBuilder( …

阻止持久性攻击改善网络安全

MITRE ATT&CK框架是一个全球可访问的精选知识数据库,其中包含基于真实世界观察的已知网络攻击技术和策略。持久性是攻击者用来访问系统的众多网络攻击技术之一;在获得初始访问权限后,他们继续在很长一段时间内保持立足点,以窃取数据、修改…

Java实战项目六:进销存管理系统

文章目录 一、系统概述二、关键知识点梳理(一)数据库设计与连接(二)实体类与关系映射(三)集合框架的应用(四)CRUD操作与事务管理 三、系统架构与模块划分(一)…

集合-及其各种特征详解

集合 概念:是提供一种存储空间 可变 的存储模型,存储的数据容量可以发生改变。(也就是集合容量不固定) 集合关系图 绿色的代表接口,蓝色的代表接口的实现类 单列集合 Collection(接口) 概述:单列集合的…

spring模块(一)DI依赖注入

一、介绍:Dependency Injection,组件之间依赖关系由容器在运行期决定,形象的说,即由容器动态的将某个依赖关系注入到组件之中。依赖注入的目的并非为软件系统带来更多功能,而是为了提升组件重用的频率,并为…

【策】策略

为应对未来激烈的竞争,需要具备以下要素: 主动选择。每个人都有主动选择的权利,既然做出选择,应对其产生的结果而负责。反逆向思维。梳理过程,分析因果,优化决策。

How to implement anti-crawler strategies to protect site data

How to implement anti-crawler strategies to protect site data 信息校验型反爬虫User-Agent反爬虫Cookie反爬虫签名验证反爬虫WebSocket握手验证反爬虫WebSocket消息校验反爬虫WebSocket Ping反爬虫 动态渲染反爬虫文本混淆反爬虫图片伪装反爬虫CSS偏移反爬虫SVG映射反爬虫字…

VM与欧姆龙PLC通讯设置

1、欧姆龙PLC 进行网口通讯,协议用的Fins tcp,也可以用Fins UDP。 2、主要步骤如下; step1:设置IP地址、端口号默认是9600,根据需要设置寄存器首地址和寄存器数量 step2:鼠标移动到某个地址下&#xff0c…

Python入门-函数

1.函数的定义及调用 函数:函数是将一段实现功能的完整代码,使用函数名称进行封装,通过函数名称进行调用。 以此达到一次编写,多次调用的目的 def get_sum(num): #num叫形式参数s0for i in range(1,num1):siprint(f1到{num}之…

数字后端设计实现 | 数字后端PR工具Innovus中如何创建不同高度的row?

吾爱IC社区星球学员问题:Innovus后端实现时两种种不同高度的site能做在一个pr里面吗? 答案是可以的。 Innovus支持在同一个设计中中使用不同的row,但需要给各自子模块创建power domain。这里所说的不同高度的row,有两种情况。 1…

Docker一键极速安装Nacos,并配置数据库!

1 部署方式 1.1 DockerHub javaedgeJavaEdgedeMac-mini ~ % docker run --name nacos \ -e MODEstandalone \ -e JVM_XMS128m \ -e JVM_XMX128m \ -e JVM_XMN64m \ -e JVM_MS64m \ -e JVM_MMS64m \ -p 8848:8848 \ -d nacos/nacos-server:v2.2.3 a624c64a1a25ad2d15908a67316d…

指针大礼包6

第6题 (10.0分) 题号:5 难度:中 第8章 /*------------------------------------------------------- 【程序改错】 --------------------------------------------------------- 题目:下列给定程序中函数fun的功能是…

代码随想录刷题第三十九天| 62.不同路径 ● 63. 不同路径 II

代码随想录刷题第三十九天 不同路径 (LC 62) 题目思路: 代码实现: class Solution:def uniquePaths(self, m: int, n: int) -> int:dp [[0 for _ in range(n1)] for _ in range(m1)]dp[0][1] 1for i in range(1,m1):for j in range(1, n1):dp[i]…

使用IDEA官方docker插件构建镜像

此方法同样适用于jetbrains系列的其他开发软件 在IDEA中&#xff0c;如果是maven项目&#xff0c;可以使用插件 <plugin><groupId>com.spotify</groupId><artifactId>docker-maven-plugin</artifactId><version>1.2.2</version> &…

Linux入门攻坚——11、Linux网络属性配置相关知识1

网络基础知识&#xff1a; 局域网&#xff1a;以太网&#xff0c;令牌环网&#xff0c; Ethernet&#xff1a;CSMA/CD 冲突域 广播域 MAC&#xff1a;Media Access Control&#xff0c;共48bit&#xff0c;前24bit需要机构分配&#xff0c;后24bit自己…

WPF 如何知道当前有多少个 DispatcherTimer 在运行

在 WPF 调试中&#xff0c;对于 DispatcherTimer 定时器的执行&#xff0c;没有直观的调试方法。本文来告诉大家如何在 WPF 中调试当前主线程有多少个 DispatcherTimer 在运行 在 WPF 中&#xff0c;如果有 DispatcherTimer 定时器在执行&#xff0c;将会影响到主线程的执行&a…

Qt6入门教程 2:Qt6下载与安装

Qt6不提供离线安装包&#xff0c;下载和安装实际上是一体的了。 关于Qt简介&#xff0c;详见&#xff1a;Qt6入门教程1&#xff1a;Qt简介 一.下载在线安装器 Qt官网 地址&#xff1a;https://download.qt.io/ 在线下载器地址&#xff1a;https://download.qt.io/archive/on…

Unity | NGO网络框架

目录 一、相关属性及变量 1.ServerRpc属性 2.ClientRpc属性 3.NetworkVariable变量 二、相关组件 1.NetworkManager 2.Unity Transport 3.Network Object 4.NetworkBehaviour&#xff1a; 5.NetworkTransform Syncing(Synchronizing) Thresholds Interpolation 三…