《深度学习》——ResNet网络

文章目录

  • ResNet网络
    • ResNet网络实例
        • 导入所需库
        • 下载训练数据和测试数据
        • 设置每个批次的样本个数
        • 判断是否使用GPU
        • 定义残差模块
        • 定义ResNet网络
        • 模型导入GPU
        • 定义训练函数
        • 定义测试函数
        • 创建损失函数和优化器
        • 训练测试数据
        • 结果

ResNet网络

ResNet(Residual Network,残差网络)是深度学习领域中非常重要且具有影响力的一种卷积神经网络(CNN)架构,由何恺明等人于 2015 年提出,在图像识别、目标检测等诸多计算机视觉任务中取得了巨大成功。
1. 产生背景:在深度学习发展过程中,随着网络深度的增加,会出现梯度消失或梯度爆炸的问题,导 致网络难以训练。即使通过归一化等方法解决了梯度问题,还会面临退化问题,即网络深度增加时,模型的训 练误差和测试误差反而增大。ResNet 的提出就是为了解决深度神经网络中的退化问题。
在这里插入图片描述
在这里插入图片描述

  • ResNet-18:是 ResNet 家族中相对较浅的网络,由 4 个残差块组构成,每个残差块组包含不同数量的残差块。它的结构简单,计算量相对较小,适合计算资源有限或对模型复杂度要求不高的场景,如一些小型图像数据集的分类任务。它在一些对实时性要求较高的应用中,如移动设备上的图像识别,也有一定的应用。
  • ResNet-34:同样由 4 个残差块组组成,但相比 ResNet-18,它在某些残差块组中包含更多的残差块,网络深度更深,因此能够学习到更复杂的特征表示。它在中等规模的图像数据集上表现良好,在一些对模型性能有一定要求但又不过分追求极致精度的任务中较为常用。
  • ResNet-50:是一个比较常用的 ResNet 模型,在许多计算机视觉任务中都有广泛应用。它使用了瓶颈结构(Bottleneck)的残差块,这种结构通过先降维、再卷积、最后升维的方式,在减少计算量的同时保持了模型的表达能力。该模型在图像分类、目标检测、语义分割等任务中,都能作为性能不错的骨干网络,为后续的任务提供有效的特征提取。
  • ResNet-101:比 ResNet-50 的网络层数更多,拥有更强大的特征提取能力。它适用于大规模图像数据集和复杂的计算机视觉任务,如在大型目标检测数据集中,能够更好地捕捉目标的细节特征,提升检测的准确性。由于其深度和复杂度,在处理高分辨率图像或需要精细特征表示的任务时表现出色。
  • ResNet-152:是 ResNet 系列中深度较深的网络,具有极高的特征提取能力。但由于其深度很大,计算量和参数量也相应增加,训练和推理所需的时间和资源较多。它通常用于对精度要求极高的场景,如学术研究中的图像识别挑战、大规模图像搜索引擎的图像特征提取等。

18层残差网络:

在这里插入图片描述

ResNet网络实例

项目需求:对手写数字进行识别。
数据集:此项目数据集来自MNIST 数据集由美国国家标准与技术研究所(NIST)整理而成,包含手写数字的图像,主要用于数字识别的训练和测试。该数据集被分为两部分:训练集和测试集。训练集包含 60,000 张图像,用于模型的学习和训练;测试集包含 10,000 张图像,用于评估训练好的模型在未见过的数据上的性能。
图像格式:数据集中的图像是灰度图像,即每个像素只有一个值表示其亮度,取值范围通常为 0(黑色)到 255(白色)。
图像尺寸:每张图像的尺寸为 28x28 像素,总共有 784 个像素点。
标签信息:每个图像都有一个对应的标签,标签是 0 到 9 之间的整数,表示图像中手写数字的值。

导入所需库
import torch
from torch import nn  # 导入神经网络模块
from torch.utils.data import DataLoader  # 数据包管理工具,打包数据
from torchvision import datasets  # 封装了很对与图像相关的模型,数据集
from torchvision.transforms import ToTensor  # 数据转换,张量,将其他类型的数据转换成tensor张量
import torch.nn.functional as F # 用于应用 ReLU 激活函数
下载训练数据和测试数据
'''下载训练数据集(包含训练集图片+标签)'''
training_data = datasets.MNIST(  # 跳转到函数的内部源代码,pycharm 按下ctrl+鼠标点击root='data',  # 表示下载的手写数字 到哪个路径。60000train=True,  # 读取下载后的数据中的数据集download=True,  # 如果你之前已经下载过了,就不用再下载了transform=ToTensor(),  # 张量,图片是不能直接传入神经网络模型# 对于pytorch库能够识别的数据一般是tensor张量
)'''下载测试数据集(包含训练图片+标签)'''
test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor(),  # Tensor是在深度学习中提出并广泛应用的数据类型,它与深度学习框架(如pytorch,TensorFlow)
)  # numpy数组只能在cpu上运行。Tensor可以在GPU上运行,这在深度学习应用中可以显著提高计算速度。
print(len(training_data))
print(len(test_data))
设置每个批次的样本个数
train_dataloader = DataLoader(training_data, batch_size=64)  # 建议用2的指数当作一个包的数量
test_dataloader = DataLoader(test_data, batch_size=64)
判断是否使用GPU
'''判断是否支持GPU'''
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')
定义残差模块
# 定义残差块类,继承自 nn.Module
class ResBlock(nn.Module):def __init__(self, channels_in):# 调用父类的构造函数super().__init__()# 定义第一个卷积层,输入通道数为 channels_in,输出通道数为 30,卷积核大小为 5,填充为 2self.conv1 = torch.nn.Conv2d(channels_in, 30, 5, padding=2)# 定义第二个卷积层,输入通道数为 30,输出通道数为 channels_in,卷积核大小为 3,填充为 1self.conv2 = torch.nn.Conv2d(30, channels_in, 3, padding=1)def forward(self, x):# 输入数据通过第一个卷积层out = self.conv1(x)# 经过第一个卷积层的输出再通过第二个卷积层out = self.conv2(out)# 将输入 x 与卷积输出 out 相加,并通过 ReLU 激活函数return F.relu(out + x)
定义ResNet网络
# 定义 ResNet 网络类,继承自 nn.Module
class ResNet(nn.Module):def __init__(self):# 调用父类的构造函数super().__init__()# 定义第一个卷积层,输入通道数为 1,输出通道数为 20,卷积核大小为 5self.conv1 = torch.nn.Conv2d(1, 20, 5)# 定义第二个卷积层,输入通道数为 20,输出通道数为 15,卷积核大小为 3self.conv2 = torch.nn.Conv2d(20, 15, 3)# 定义最大池化层,池化核大小为 2self.maxpool = torch.nn.MaxPool2d(2)# 定义第一个残差块,输入通道数为 20self.resblock1 = ResBlock(channels_in=20)# 定义第二个残差块,输入通道数为 15self.resblock2 = ResBlock(channels_in=15)# 定义全连接层,输入特征数为 375,输出特征数为 10self.full_c = torch.nn.Linear(375, 10)def forward(self, x):# 获取输入数据的批次大小size = x.shape[0]# 输入数据通过第一个卷积层,然后进行最大池化,最后通过 ReLU 激活函数x = F.relu(self.maxpool(self.conv1(x)))# 经过第一个卷积和池化的输出通过第一个残差块x = self.resblock1(x)# 经过第一个残差块的输出通过第二个卷积层,然后进行最大池化,最后通过 ReLU 激活函数x = F.relu(self.maxpool(self.conv2(x)))# 经过第二个卷积和池化的输出通过第二个残差块x = self.resblock2(x)# 将输出数据展平为一维向量x = x.view(size, -1)# 展平后的向量通过全连接层x = self.full_c(x)return x
模型导入GPU
model = ResNet().to(device)
定义训练函数
# 定义训练函数
def train(dataloader, model, loss_fn, optimizer):# 将模型设置为训练模式,这会影响一些层(如 Dropout、BatchNorm 等)的行为model.train()# 初始化批次编号batch_size_num = 1# 遍历数据加载器中的每个批次for x, y in dataloader:# 将输入数据和标签移动到指定设备(如 GPU)x, y = x.to(device), y.to(device)# 前向传播,计算模型的预测结果pred = model.forward(x)# 通过交叉熵损失函数计算预测结果与真实标签之间的损失值loss = loss_fn(pred, y)# 反向传播步骤:# 清零优化器中的梯度信息,防止梯度累积optimizer.zero_grad()# 反向传播计算每个参数的梯度loss.backward()# 根据计算得到的梯度更新模型的参数optimizer.step()# 从张量中提取损失值的标量loss_value = loss.item()# 每 100 个批次打印一次损失值if batch_size_num % 100 == 0:print(f'loss:{loss_value:7f}  [number:{batch_size_num}]')# 批次编号加 1batch_size_num += 1
定义测试函数
# 定义测试函数
def test(dataloader, model, loss_fn):# 获取数据集的总样本数size = len(dataloader.dataset)# 获取数据加载器中的批次数量num_batches = len(dataloader)# 将模型设置为评估模式,这会影响一些层(如 Dropout、BatchNorm 等)的行为model.eval()# 初始化测试损失和正确预测的样本数test_loss, correct = 0, 0# 上下文管理器,关闭梯度计算,减少内存消耗with torch.no_grad():# 遍历数据加载器中的每个批次for x, y in dataloader:# 将输入数据和标签移动到指定设备(如 GPU)x, y = x.to(device), y.to(device)# 前向传播,计算模型的预测结果pred = model.forward(x)# 累加每个批次的损失值test_loss += loss_fn(pred, y).item()# 计算每个批次中预测正确的样本数并累加correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均测试损失test_loss /= num_batches# 计算平均准确率correct /= size# 打印测试结果print(f'Test result: \n Accuracy:{(100 * correct)}%,Avg loss:{test_loss}')
创建损失函数和优化器
# 创建交叉熵损失函数对象
loss_fn = nn.CrossEntropyLoss()
# 创建 Adam 优化器,用于更新模型的参数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=3,gamma=0.1)
训练测试数据
# 定义训练的轮数
epochs = 26
# 开始训练循环
for t in range(epochs):print(f'epoch{t + 1}\n--------------------')# 调用训练函数进行一轮训练train(train_dataloader, model, loss_fn, optimizer)
print('Done!')
# 调用测试函数进行测试
test(test_dataloader, model, loss_fn)
结果

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/895840.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

为什么要学习AI、掌握AI技能有什么用?

随着人工智能的迅速的发展,DeepSeek的爆火,加之目前就业环境的走向,越来越多的职场朋友开始关注到AI的发展,重视AI技能的掌握。不少同学都会问:“职场人为什么要学习AI、掌握AI技能?” 为什么要学AI 现…

AIP-146 泛化域

编号146原文链接AIP-146: Generic fields状态批准创建日期2019-05-28更新日期2019-05-28 API中的大多数域,无论是在请求、资源还是自定义应答中,都有具体的类型或模式。这个模式是约定的一部分,开发者依此约定进行编码。 然而,偶…

vue3和vue2的组件开发有什么区别

Vue3和Vue2在组件开发上存在不少差异,下面从多个方面详细介绍: 响应式原理 Vue2:用Object.defineProperty()方法来实现响应式。打个比方,它就像给对象的每个属性都安排了一个“小管家”,属性被访问或修改时&#xff0…

【NLP 25、模型训练方式】

目录 一、按学习范式分类 1. 监督学习(Supervised Learning) 2. 无监督学习(Unsupervised Learning) 3. 半监督学习(Semi-supervised Learning) 4. 强化学习(Reinforcement Learning, RL&#x…

1-知识图谱-概述和介绍

知识图谱:浙江大学教授 陈华军 知识图谱 1课时 http://openkg.cn/datasets-type/ 知识图谱的价值 知识图谱是有什么用? 语义搜索 问答系统 QA问答对知识图谱:结构化图 辅助推荐系统 大数据分析系统 自然语言理解 辅助视觉理解 例…

零基础学QT、C++(一)安装QT

目录 如何快速学习QT、C呢? 一、编译器、项目构建工具 1、编译器(介绍2款) 2、项目构建工具 二、安装QT 1、下载QT安装包 2、运行安装包 3、运行QT creator 4、导入开源项目 总结 闲谈 如何快速学习QT、C呢? 那就是项目驱动法&…

STM32外设SPI FLASH应用实例

STM32外设SPI FLASH应用实例 1. 前言1.1 硬件准备1.2 软件准备 2. 硬件连接3. 软件实现3.1 SPI 初始化3.2 QW128 SPI FLASH 驱动3.3 乒乓存储实现 4. 测试与验证4.1 数据备份测试4.2 数据恢复测试 5 实例5.1 参数结构体定义5.2 存储参数到 SPI FLASH5.3 从 SPI FLASH 读取参数5…

Leetcode2080:区间内查询数字的频率

题目描述: 请你设计一个数据结构,它能求出给定子数组内一个给定值的 频率 。 子数组中一个值的 频率 指的是这个子数组中这个值的出现次数。 请你实现 RangeFreqQuery 类: RangeFreqQuery(int[] arr) 用下标从 0 开始的整数数组 arr 构造…

Spring Boot自动装配:约定大于配置的魔法解密

#### 一、自动装配的哲学思考 在传统Spring应用中,开发者需要手动配置大量的XML或JavaConfig。Spring Boot通过自动装配机制实现了**约定大于配置**的设计理念,其核心思想可以概括为: 1. **智能预设**:基于类路径检测自动配置 2…

Fiddler笔记

文章目录 一、与F12对比二、核心作用三、原理四、配置1.Rules:2.配置证书抓取https包3.设置过滤器4、抓取App包 五、模拟弱网测试六、调试1.线上调试2.断点调试 七、理论1.四要素2.如何定位前后端bug 注 一、与F12对比 相同点: 都可以对http和https请求进行抓包分析…

Python爬虫-猫眼电影的影院数据

前言 本文是该专栏的第46篇,后面会持续分享python爬虫干货知识,记得关注。 本文笔者以猫眼电影为例子,获取猫眼的影院相关数据。 废话不多说,具体实现思路和详细逻辑,笔者将在正文结合完整代码进行详细介绍。接下来,跟着笔者直接往下看正文详细内容。(附带完整代码) …

linux笔记:shell中的while、if、for语句

在Udig软件的启动脚本中使用了while循环、if语句、for循环,其他内容基本都是变量的定义,所以尝试弄懂脚本中这三部分内容,了解脚本执行过程。 (1)while循环 while do循环内容如下所示,在循环中还用了expr…

利用分治策略优化快速排序

1. 基本思想 分治快速排序(Quick Sort)是一种基于分治法的排序算法,采用递归的方式将一个数组分割成小的子数组,并通过交换元素来使得每个子数组元素按照特定顺序排列,最终将整个数组排序。 快速排序的基本步骤&#…

从零到一实现微信小程序计划时钟:完整教程

在本教程中,我们将一起实现一个微信小程序——计划时钟。这个小程序的核心功能是帮助用户添加任务、设置任务的时间范围,并且能够删除和查看已添加的任务。通过以下步骤,我们将带你从零开始实现一个具有基本功能的微信小程序计划时钟。 项目…

idea日常报错之UTF-8不可映射的字符

目录 一、UTF-8不可映射的字符的解决 1、出现这种报错的情形 2、具体解决办法 前言: 在我们日常代码编写的时候可能会遇到各式各样的错误,有时候并不是你改动了代码,而是莫名其妙就出现的报错,今天我就遇到一个在maven编译的时候…

人工智能技术-基于长短期记忆(LSTM)网络在交通流量预测中的应用

人工智能技术-基于长短期记忆(LSTM)网络在交通流量预测中的应用 基于人工智能的智能交通管理系统 随着城市化进程的加快,交通问题日益严峻。为了解决交通拥堵、减少交通事故、提高交通管理效率,人工智能(AI&#xff…

HTTP FTP SMTP TELNET 应用协议

1. 标准和非标准的应用协议 标准应用协议: 由标准化组织(如 IETF,Internet Engineering Task Force)制定和维护,具有广泛的通用性和互操作性。这些协议遵循严格的规范和标准,不同的实现之间可以很好地进行…

Matlab离线安装硬件支持包的方法

想安装支持树莓派的包,但是发现通过matlab安装需要续订维护服务 可以通过离线的方式安装。 1. 下载SupportSoftwareDownloader Support Software Downloader - MATLAB & Simulink 登录账号 选择对应的版本 2. 选择要安装的包 3.将下载的包copy到安装目录下 …

Django REST Framework (DRF) 中用于构建 API 视图类解析

Django REST Framework (DRF) 提供了丰富的视图类,用于构建 API 视图。这些视图类可以分为以下几类: 1. 基础视图类 这些是 DRF 中最基础的视图类,通常用于实现自定义逻辑。 常用类 APIView: 最基本的视图类,所有其…

MyBatis拦截器终极指南:从原理到企业级实战

在本篇文章中,我们将深入了解如何编写一个 MyBatis 拦截器,并通过一个示例来展示如何在执行数据库操作(如插入或更新)时,自动填充某些字段(例如 createdBy 和 updatedBy)信息。本文将详细讲解拦…