MATLAB深度学习(七)——ResNet残差网络

一、ResNet网络

        ResNet是深度残差网络的简称。其核心思想就是在,每两个网络层之间加入一个残差连接,缓解深层网络中的梯度消失问题

二、残差结构

        在多层神经网络模型里,设想一个包含诺干层自网络,子网络的函数用H(x)来表示,其中x是子网络的输入。残差学习是通过重新设定这个参数,让一个参数层表达一个残差函数                 ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        F(x)=H(x)-x

        因此这个子网络的输出y 就是 

        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​   y=F(x)+x

        其中 +x 的操作,是通过一个相当于恒等映射的跳跃连接来完成的,它将残差块的输入直接与输出连接,这就是残差结构。按照上面的结构递推,根据前向传播,第i个残差块的输出就是第i+1个残差块的输入

        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        x_{\ell+1}=F(x_\ell)+x_\ell

        根据递归公式,可以推导出

        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        x_L=x_\ell+\sum_{i=l}^{L-1}F(x_i)

        这里的L表示任意后续残差块,i是靠前块,那么公式就说明了总会有信号能从浅层到深层。

        从反向传播来看,根据上面的公式对 Xl 进行求导,可以得到

        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        ​​​​​​​        \begin{aligned} \frac{\partial\mathcal{E}}{\partial x_\ell} & =\frac{\partial\mathcal{E}}{\partial x_L}\frac{\partial x_L}{\partial x_\ell} \\ & =\frac{\partial\mathcal{E}}{\partial x_L}\left(1+\frac{\partial}{\partial x_\ell}\sum_{i=l}^{L-1}F(x_i)\right) \\ & =\frac{\partial\mathcal{E}}{\partial x_L}+\frac{\partial\mathcal{E}}{\partial x_L}\frac{\partial}{\partial x_\ell}\sum_{i=l}^{L-1}F(x_i) \end{aligned}

        这里的 \mathcal{E} 是最小损失化函数。以上说明,浅层的梯度计算 \frac{\partial\mathcal{E}}{\partial x_{\ell}},总会直接加上上一个项 \frac{\partial\mathcal{E}}{\partial x_L}因为存在额外的一项,所以就想 F(xi)很小,总的梯度都不会消失。

三、基于ResNet识别实现步骤

        其主要步骤为1.加载图像数据,并将数据分为训练集合与验证集;2.加载MATLAB训练好的ResNet50;3.和Alexnet一样替换最后几层;4.按照网络配置调整图像数据;5.对网络进行训练

3.1 调整ResNet实现迁移学习

        针对ImageNet的数据任务,原本最后三层 FC SOFTMAX 输出层是针对1000个类别的物体进行识别,针对图像问题,继续调整这三层,首先冻结上面的层。将下面的三层进行替换。由于ResNet50需要输入的图像大小为224 * 224 *3.

unzip('MerchData.zip');
img_ds = imageDatastore('MerchData', ...'IncludeSubfolders',true, ...'LabelSource','foldernames');total_split = countEachLabel(img_ds); %返回一个包含每个标签和相应图像数量的表格num_images = length(img_ds.Labels); %返回的图像的个数,5个类型都有15张照片
perm =  randperm(num_images,10);  %随机取出10个
figure
for i = 1:9subplot(3,3,i)imshow(imread(img_ds.Files{perm(i)})); %从图像数据集 (img_ds) 中读取一张图像并显示它%可以用alexnet的方法
endtest_idx = randperm(num_images,9); %随机取5张做样本
img_ds_Test = subset(img_ds,test_idx);
train_idx = setdiff(1:length(img_ds.Files),test_idx);
img_ds_Train = subset(img_ds,train_idx);%% 步骤2:加载预训练好的网络% 加载ResNet50网络(注:该网络需要提前下载,当输入下面命令时按要求下载即可)
net = resnet50;%% 步骤3:对网络结构进行调整,替换最后几层% 获取网络图结构
LayerGraph = layerGraph(net);
clear net;% 确定训练数据中新冠图片标签类别数量:5类
numClasses = numel(categories(img_ds_Train.Labels));
disp(numClasses);% 保留ResNet50倒数第三层之前的网络,并替换后3层
% 倒数第三层的全连接层,这里修改为5类
newLearnableLayer = fullyConnectedLayer(numClasses,...'Name','new_fc',...'WeightLearnRateFactor',10,...
'BiasLearnRateFactor',10);
%numClasses 分类任务。通过设置 WeightLearnRateFactor 和 BiasLearnRateFactor
%来控制学习率的调整,使得这些层在训练过程中能更快速地学习
% 分别替换最后3层:fc1000、softmax和分类输出层
LayerGraph = replaceLayer(LayerGraph,'fc1000',newLearnableLayer);
%替换了原网络中的 fc1000 层(ResNet50 中的全连接层)为 new_fc 层
% (即刚才定义的新的全连接层)。replaceLayer 函数通过层的名字来替换图层newSoftmaxLayer = softmaxLayer('Name','new_softmax');
LayerGraph = replaceLayer(LayerGraph,'fc1000_softmax',newSoftmaxLayer);newClassLayer = classificationLayer('Name','new_classoutput');
LayerGraph = replaceLayer(LayerGraph,'ClassificationLayer_fc1000',newClassLayer);%% 步骤4:按照网络配置调整图像数据% 输入图像格式转换,这里调用了自定义函数preprocess
img_ds_Train.ReadFcn = @(filename)preprocess(filename);
img_ds_Test.ReadFcn  = @(filename)preprocess(filename);% 数据增强的参数
augmenter = imageDataAugmenter(...'RandRotation',[-5 5],...'RandXReflection',1,...'RandYReflection',1,...'RandXShear',[-0.05 0.05],...'RandYShear',[-0.05 0.05]);
% 将批量训练图像的大小调整为与输入层的大小相同
aug_img_ds_train = augmentedImageDatastore([224 224],img_ds_Train,'DataAugmentation',augmenter);
% 将批量测试图像的大小调整为与输入层的大小相同
aug_img_ds_test = augmentedImageDatastore([224 224],img_ds_Test);%% 步骤5:对网络进行训练% 对训练参数进行设置
options = trainingOptions('adam',...'MaxEpochs',10,...'MiniBatchSize',8,...'Shuffle','every-epoch',...'InitialLearnRate',1e-4,...'Verbose',false,...'Plots','training-progress',...'ExecutionEnvironment','cpu');% 用训练图像对网络进行训练
netTransfer = trainNetwork(aug_img_ds_train,LayerGraph,options);%% 步骤6:进行测试并查看结果% 对训练好的网络采用验证数据集进行验证
[YPred,scores] = classify(netTransfer,aug_img_ds_test);% 随机显示验证效果
idx = randperm(numel(img_ds_Test.Files),4);
figure
for i = 1:4subplot(2,2,i)I = readimage(img_ds_Test,idx(i));imshow(I)label = YPred(idx(i));title(string(label));
end%% 计算分类准确率
YValidation = img_ds_Test.Labels;
accuracy = mean(YPred == YValidation);%% 创建并显示混淆矩阵
figure
confusionchart(YValidation,YPred) 

实现效果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/62651.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端入门之VUE--vue组件化编程

前言 VUE是前端用的最多的框架;这篇文章是本人大一上学习前端的笔记;欢迎点赞 收藏 关注,本人将会持续更新。 文章目录 2、Vue组件化编程2.1、组件2.2、基本使用2.2.1、VueComponent 2、Vue组件化编程 2.1、组件 组件:用来实现…

设计模式-装饰器模式(结构型)与责任链模式(行为型)对比,以及链式设计

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言1.装饰器模式1.1概念1.2作用1.3应用场景1.4特点1.5类与对象关系1.6实现 2责任链模式2.1概念2.2作用2.3应用场景2.4特点2.5类与对象关系2.6实现 3.对比总结 前言…

交叉熵损失函数(Cross-Entropy Loss)

原理 交叉熵损失函数是深度学习中分类问题常用的损失函数,特别适用于多分类问题。它通过度量预测分布与真实分布之间的差异,来衡量模型输出的准确性。 交叉熵的数学公式 交叉熵的定义如下: C r o s s E n t r o y L o s s − ∑ i 1 N …

操作系统:死锁与饥饿

目录 死锁概念 饥饿与饿死概念 饥饿和死锁对比 死锁类型 死锁条件(Coffman条件) 死锁恢复方法 死锁避免 安全状态与安全进程序列: 银行家算法: 死锁检测时机(了解): 死锁检测 死锁案…

Prisoner’s Dilemma

囚徒困境博弈论解析 什么是囚徒困境? 囚徒困境(Prisoner’s Dilemma)是博弈论中的一个经典模型,用来分析两名玩家在非合作环境下的决策行为。 其核心在于玩家既可以选择合作也可以选择背叛,而最终的结果取决于双方的…

RPO: Read-only Prompt Optimization for Vision-Language Few-shot Learning

文章汇总 想解决的问题对CoOp的改进CoCoOp尽管提升了性能,但却增加了方差(模型的准确率波动性较大)。 模型的框架一眼看去,跟maple很像(maple跟这篇文章都是2023年发表的),但maple的视觉提示是由文本提示经过全连接转换而来的,而这里是文本提示和视觉提示是独立的。另外m…

『MySQL 实战 45 讲』24 - MySQL是怎么保证主备一致的?

MySQL是怎么保证主备一致的? MySQL 主备的基本原理 基本的主备切换流程 状态 1:客户端的读写都直接访问节点 A,而节点 B 是 A 的备库状态 2:切换时,读写访问的都是节点 B,而节点 A 是 B 的备库注意&…

自荐一部IT方案架构师回忆录

作者本人毕业于一个不知名大专院校,所读专业计算机科学技术。2009年开始IT职业生涯,至今工作15年。擅长TSQL/Shell/linux等技术,曾经就职于超万人大型集团、国内顶级云厂商、央国企公司。参与过运营商大数据平台、大型智慧城市ICT、云计算、人…

python数据分析之爬虫基础:selenium详细讲解

目录 1、selenium介绍 2、selenium的作用: 3、配置浏览器驱动环境及selenium安装 4、selenium基本语法 4.1、selenium元素的定位 4.2、selenium元素的信息 4.3、selenium元素的交互 5、Phantomjs介绍 6、chrome handless模式 1、selenium介绍 (1…

【数据结构——查找】顺序查找(头歌实践教学平台习题)【合集】

目录😋 任务描述 相关知识 测试说明 我的通关代码: 测试结果: 任务描述 本关任务:实现顺序查找的算法。 相关知识 为了完成本关任务,你需要掌握:1.根据输入数据建立顺序表,2.顺序表的输出,…

光伏电站建设成本利润估算

​截至2024年9月底,全国光伏发电装机容量达到7.7亿千瓦,同比增长48.4%。其中集中式光伏4.3亿千瓦,分布式光伏3.4亿千瓦。2024年前三季度,全国光伏发电量6359亿千瓦时,同比增长45.5%。全国光伏发电利用率97.2%,同比下降1.1个百分点.早在今年2月份,中国光伏行业协会名誉理…

create-react-app react19 搭建项目报错

报错截图 此时运行会报错: 解决方法: 1.根据提示安装依赖法 执行npm i web-vitals然后重新允许 2.删除文件法 在index.js中删除对报错文件的引入,删除报错文件

scala的集合性能2

可变集合\n可变集合允许在原地修改数据,适合需要频繁更新的场景。Scala 的可变集合包括 ArrayBuffer、HashSet和HashMap。 1. ArrayBuffer\nArrayBuffer 是一个可变的动态数组,提供高效的随机访问和添加操作。 import scala.collection.mutable.ArrayB…

【Ubuntu】脚本自动化控制终端填充

1.sh脚本文件控制终端写入命令 在SLAM算法中,每次启动vins都需要起很多终端,尽管使用了超级终端Terminator可以终端内划分看起来更加便捷,但是每次起算法的命令还是要自己输入,已经被麻烦了两年了,今天突然想写写一个…

【自学】Vues基础

学习目录 Vues基础本地应用网络应用综合应用 工具的准备 我个人比较喜欢使用HTMLDROWNER,学习资料推荐使用VC,仅供选择吧 前置知识 HTMLCSSJSAJAX:这个是学习资料博主推荐的 个人感觉认真学好HTMLCSSJS理解vues基础很容易上手 官方网址…

Scratch 消灭字母小游戏

背景 最近尝试一边自学Scratch,一边尝试教给小孩,看他打字时在键盘上乱打一气,想起来自己小时候玩过的学习机打字母游戏,就想给他下载一个。结果网上看到的代码,要么质量太差(有26个字母就要写 26 个判断&…

python调用matlab函数(内置 + 自定义) —— 安装matlab.engine

文章目录 一、简介二、安装matlab.engine2.1、基于 CMD 安装2.2、基于 MATLAB 安装(不建议) 三、python调用matlab函数(内置 自定义) 一、简介 matlab.engine(MATLAB Engine API for Python):…

pytroch环境安装-pycharm

环境介绍 安装pycharm 官网下载即可,我这里已经安装,就不演示了 安装anaconda 【官网链接】点击下载 注意这一步选择just me 这一步全部勾上 打开 anaconda Prompt 输入conda create -n pytorch python3.8 命令解释:创建一个叫pytorch&…

Photoshop提示错误弹窗dll缺失是什么原因?要怎么解决?

Photoshop提示错误弹窗“DLL缺失”:原因分析与解决方案 在创意设计与图像处理领域,Photoshop无疑是众多专业人士和爱好者的首选工具。然而,在使用Photoshop的过程中,有时会遇到一些令人头疼的问题,比如突然弹出的错误…

自己总结:selenium高阶知识

全篇大概10000字(含代码),建议阅读时间30min 一、等待机制 如果有一些内容是通过Ajax加载的内容,那就需要等待内容加载完毕才能进行下一步操作。 为了避免人为操作等待,会遇到的问题, selenium将等待转换…