2-2 MATLAB鮣鱼优化算法ROA优化CNN超参数回归预测

本博客来源于CSDN机器鱼,未同意任何人转载。

更多内容,欢迎点击本专栏目录,查看更多内容。

目录

0.引言

1.ROA优化CNN

2.主程序调用

3.结语


0.引言

在博客【ROA优化LSTM超参数回归】中,我们采用ROA对LSTM的学习率、迭代次数、batchsize、两个lstmlayer的节点数进行寻优,在优化过程中我们不必知道ROA的具体优化原理,只需要修改lb、ub、维度D、边界判断、适应度函数即可。今天这边博客,我们依旧采用此前提到的步骤对CNN的超参数进行回归,话不多说,首先我们定义一个超级简单的CNN网络进行回归预测,代码如下:

clc;clear;close all;rng(0)
%% 数据的提取
load data%数据是4输入1输出的简单数据
train_x;%4*98
train_y;%1*98
test_x;%4*42
test_y;%1*98
%转成CNN的输入格式
feature=size(train_x,1);
num_train=size(train_x,2);
num_test=size(test_x,2);
trainD=reshape(train_x,[feature,1,1,num_train]);
testD=reshape(test_x,[feature,1,1,num_test]);
targetD = train_y';
targetD_test  = test_y';%% 网络构建
layers = [imageInputLayer([size(trainD,1) size(trainD,2) size(trainD,3)]) % 输入convolution2dLayer(3,4,'Stride',1,'Padding','same')%核3*1 数量4 步长1 填充为samereluLayer%relu激活convolution2dLayer(3,8,'Stride',1,'Padding','same')%核3*1 数量8 步长1 填充为samereluLayer%relu激活fullyConnectedLayer(20) % 全连接层1 20个神经元reluLayerfullyConnectedLayer(20) % 全连接层2 20个神经元reluLayerfullyConnectedLayer(size(targetD,2)) %输出层regressionLayer];
%% 网络训练
options = trainingOptions('adam', ...'ExecutionEnvironment','cpu', ...'MaxEpochs',30, ...'MiniBatchSize',16, ...'InitialLearnRate',0.01, ...'GradientThreshold',1, ...'shuffle','every-epoch',...'Verbose',false);
train_again=1;% 为1就代码重新训练模型,为0就是调用训练好的网络
if train_again==1[net,traininfo] = trainNetwork(trainD,targetD,layers,options);save result/cnn_net net traininfo
elseload result/cnn_net
end
figure;
plot(traininfo.TrainingLoss,'b')
hold on;grid on
ylabel('损失')
xlabel('训练次数')
title('CNN')
%% 结果评价
YPred = predict(net,testD);YPred=double(YPred);

观察网络构建与训练,我们发现至少有9个参数需要优化,分别是:迭代次数MaxEpochs、MiniBatchSize、第一层卷积层的核大小和数量、第2层卷积层的核大小和数量,以及两个全连接层的神经元数量,还有学习率InitialLearnRate(学习率放最后是因为其他的都是整数,只有这个是小数,要么放最前要么放最后,方便我们写边界判断函数与初始化种群的程序)

1.ROA优化CNN

步骤1:知道要优化的参数与优化范围。显然就是上面提到的9个参数。代码如下,首先改写lb与ub,然后初始化的时候注意除了学习率,其他的都是整数。并将原来里面的边界判断,改成了Bounds函数,方便在计算适应度函数值的时候转化成整数与小数。如果学习率的位置不在最后,而是在其他位置,就需要改随机初始化位置和Bounds函数与fitness函数里对应的地方,具体怎么改就不说了,很简单。

function [Rbest,Convergence_curve,process]= roa_cnn(X1,y1,Xt,yt)
D=9;%一共有9个参数需要优化,分别是迭代次数、batchsize、第一层卷积层的核大小、和数量、第2层卷积层的核大小、和数量,以及两个全连接层的神经元数量,学习率
lb= [10 16  1 1 1 1 1 1 0.001];    % 下边界
ub= [50 256 3 20 3 20 50 50 0.01];    % 上边界
% 迭代次数的范围是10-50 batchsize的范围是16-256 核大小的范围是1-3 核数量的范围是1-20 全连接层的范围是1-50% 学习率的范围是0.001-0.01
sizepop=5;
maxgen=10;% maxgen 为最大迭代次数,
% sizepop 为种群规模
%记D为维度,lb、 ub分别为搜索上、下限
R=ones(sizepop,D);%预设种群
for i=1:sizepop%随机初始化位置for j=1:Dif j==D%除了学习率 其他的都是整数R( i, j ) = (ub(j)-lb(j))*rand+lb(j);elseR( i, j ) = round((ub(j)-lb(j))*rand+lb(j));endend
endfor k= 1:sizepopFitness(k)=fitness(R(k,:),X1,y1,Xt,yt);%个体适应度
end
[Fbest,elite]= min(Fitness);%Fbest为最优适应度值
Rbest= R(elite,:);%最优个体位置
H=zeros(1,sizepop);%控制因子%主循环
for iter= 1:maxgenRpre= R;%记录上一代的位置V=2*(1-iter/maxgen);B= 2*V*rand-V;a=-(1 + iter/maxgen);alpha=rand*(a-1)+ 1;for i= 1:sizepopif H(i)==0dis = abs(Rbest-R(i,:));R(i,:)= R(i,:)+ dis* exp(alpha)*cos(2*pi* alpha);elseRAND= ceil(rand*sizepop);%随机选择一个个体R(i,:)= Rbest -(rand*0.5*(Rbest + R(RAND,:))- R(RAND,:));endRatt= R(i,:)+ (R(i,:)- Rpre(i,:))*randn;%作出小幅度移动%边界吸收R(i, : ) = Bounds( R(i, : ), lb, ub );%对超过边界的变量进行去除Ratt = Bounds( Ratt, lb, ub );%对超过边界的变量进行去除Fitness(i)=fitness(R(i,:),X1,y1,Xt,yt);Fitness_Ratt= fitness(Ratt,X1,y1,Xt,yt);if Fitness_Ratt < Fitness(i)%改变寄主if H(i)==1H(i)=0;elseH(i)=1;endelse %不改变寄主A= B*(R(i,:)-rand*0.3*Rbest);R(i,:)=R(i,:)+A;endR(i, : ) = Bounds( R(i, : ), lb, ub );%对超过边界的变量进行去除end%更新适应度值、位置[fbest,elite] = min(Fitness);%更新最优个体if fbest< FbestFbest= fbest;Rbest= R(elite,:);endprocess(iter,:)=Rbest;Convergence_curve(iter)= Fbest;iter,Fbest,Rbest
endendfunction s = Bounds( s, Lb, Ub)
temp = s;
dim=length(Lb);
for i=1:length(s)if i==dim%除了学习率 其他的都是整数temp(:,i) =temp(:,i);elsetemp(:,i) =round(temp(:,i));end
end% 判断参数是否超出设定的范围for i=1:length(s)if temp(:,i)>Ub(i) | temp(:,i)<Lb(i) if i==dim%除了学习率 其他的都是整数temp(:,i) =rand*(Ub(i)-Lb(i))+Lb(i);elsetemp(:,i) =round(rand*(Ub(i)-Lb(i))+Lb(i));endend
end
s = temp;
end
function s = Bounds( s, Lb, Ub)
temp = s;
for i=1:length(s)if i==1%除了学习率 其他的都是整数temp(:,i) =temp(:,i);elsetemp(:,i) =round(temp(:,i));end
end% 判断参数是否超出设定的范围for i=1:length(s)if temp(:,i)>Ub(i) | temp(:,i)<Lb(i) if i==1%除了学习率 其他的都是整数temp(:,i) =rand*(Ub(i)-Lb(i))+Lb(i);elsetemp(:,i) =round(rand*(Ub(i)-Lb(i))+Lb(i));endend
end
s = temp;
end

步骤2:知道优化的目标。优化的目标是提高的网络的准确率,而ROA代码我们这个代码是最小值优化的,所以我们的目标可以是最小化CNN的预测误差。预测误差具体是,测试集(或验证集)的预测值与真实值之间的均方差。

步骤3:构建适应度函数。通过步骤2我们已经知道目标,即采用ROA去找到9个值,用这9个值构建的CNN网络,误差最小化。观察下面的代码,首先我们将ROA的值传进来,然后转成需要的9个值,然后构建网络,训练集训练、测试集预测,计算预测值与真实值的mse,将mse作为结果传出去作为适应度值。

function y=fitness(x,trainD,targetD,testD,targetD_test)
rng(0)
%% 将传进来的值 转换为需要的超参数
iter=x(1);
minibatch=x(2);
kernel1_size=x(3);
kernel1_num=x(4);
kernel2_size=x(5);
kernel2_num=x(6);
fc1_num=x(7);
fc2_num=x(8);
lr=x(9);feature=size(trainD,1);
%% 利用寻优得到参数重新训练CNN与预测 
layers = [imageInputLayer([size(trainD,1) size(trainD,2) size(trainD,3)]) convolution2dLayer(kernel1_size,kernel1_num,'Stride',1,'Padding','same')reluLayerconvolution2dLayer(kernel2_size,kernel2_num,'Stride',1,'Padding','same')reluLayerfullyConnectedLayer(fc1_num) reluLayerfullyConnectedLayer(fc2_num) reluLayerfullyConnectedLayer(size(targetD,2))regressionLayer];
options = trainingOptions('adam', ...'ExecutionEnvironment','cpu', ...'MaxEpochs',iter, ...'MiniBatchSize',minibatch, ...'InitialLearnRate',lr, ...'GradientThreshold',1, ...'shuffle','every-epoch',...'Verbose',false);
net = trainNetwork(trainD,targetD,layers,options);
YPred = predict(net,testD);
%% 适应度值计算
YPred=double(YPred);
%以CNN的预测值与实际值的均方误差最小化作为适应度函数,SSA的目的就是找到一组超参数
%用这组超参数训练得到的CNN的误差能够最小化
[m,n]=size(YPred);
YPred=reshape(YPred,[1,m*n]);
targetD_test=reshape(targetD_test,[1,m*n]);
y=mse(YPred,targetD_test);rng(sum(100*clock))

2.主程序调用

clc;clear;close all;format compact;rng(0)%% 数据的提取
load data
load data%数据是4输入1输出的简单数据
train_x;%4*98
train_y;%1*98
test_x;%4*42
test_y;%1*98
feature=size(train_x,1);
num_train=size(train_x,2);
num_test=size(test_x,2);
trainD=reshape(train_x,[feature,1,1,num_train]);
testD=reshape(test_x,[feature,1,1,num_test]);
targetD = train_y';
targetD_test  = test_y';%% ROA优化CNN的超参数
%一共有9个参数需要优化,分别是学习率、迭代次数、batchsize、第一层卷积层的核大小、和数量、第2层卷积层的核大小、和数量,以及两个全连接层的神经元数量
optimaztion=1;  
if optimaztion==1[x,trace,process]=roa_cnn(trainD,targetD,testD,targetD_test);save result/roa_result x trace process
elseload result/roa_result
end
%%figure
plot(trace)
title('适应度曲线')
xlabel('优化次数')
ylabel('适应度值')disp('优化后的各超参数')iter=x(1)%迭代次数
minibatch=x(2)%batchsize 
kernel1_size=x(3)
kernel1_num=x(4)%第一层卷积层的核大小与核数量
kernel2_size=x(5)
kernel2_num=x(6)%第2层卷积层的核大小与核数量
fc1_num=x(7)
fc2_num=x(8)%两个全连接层的神经元数量
lr=x(9)%学习率%% 利用寻优得到参数重新训练CNN与预测
rng(0)
layers = [imageInputLayer([size(trainD,1) size(trainD,2) size(trainD,3)])convolution2dLayer(kernel1_size,kernel1_num,'Stride',1,'Padding','same')reluLayerconvolution2dLayer(kernel2_size,kernel2_num,'Stride',1,'Padding','same')reluLayerfullyConnectedLayer(fc1_num)reluLayerfullyConnectedLayer(fc2_num)reluLayerfullyConnectedLayer(size(targetD,2))regressionLayer];
options = trainingOptions('adam', ...'ExecutionEnvironment','cpu', ...'MaxEpochs',iter, ...'MiniBatchSize',minibatch, ...'InitialLearnRate',lr, ...'GradientThreshold',1, ...'Verbose',false);train_again=1;% 为1就重新训练模型,为0就是调用训练好的网络 
if train_again==1[net,traininfo] = trainNetwork(trainD,targetD,layers,options);save result/roacnn_net net traininfo
elseload result/roacnn_net
endfigure;
plot(traininfo.TrainingLoss,'b')
hold on;grid on
ylabel('损失')
xlabel('训练次数')
title('roa-CNN')%% 结果评价
YPred = predict(net,testD);YPred=double(YPred);

3.结语

优化网络超参数的格式都是这样的!只要会改一种,那么随便拿一份能跑通的优化算法,在不管原理的情况下,都能用来优化网络的超参数。更多内容【点击专栏】目录。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/75625.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

企业入驻成都国际数字影像产业园,可享150多项专业服务

企业入驻成都国际数字影像产业园&#xff0c;可享150多项专业服务 全方位赋能&#xff0c;助力影像企业腾飞 入驻成都国际数字影像产业园&#xff0c;企业将获得一个涵盖超过150项专业服务的全周期、一站式支持体系&#xff0c;旨在精准解决企业发展各阶段的核心需求&#xf…

线路板元器件介绍及选型指南:提高电路设计效率

电路板&#xff08;PCB&#xff09;是现代电子设备的核心&#xff0c;其上安装了各类电子元器件&#xff0c;这些元器件通过PCB的导电线路彼此连接&#xff0c;实现信号传输与功能执行。 元器件的选择与安装直接决定了电子产品的性能与稳定性。本文将为大家详细介绍电路板上的…

探究 Arm Compiler for Embedded 6 的 Clang 版本

原创标题&#xff1a;Arm Compiler for Embedded 6 的 Clang 版本 原创作者&#xff1a;庄晓立&#xff08;LIIGO&#xff09; 原创日期&#xff1a;20250218&#xff08;首发日期20250326&#xff09; 原创连接&#xff1a;https://blog.csdn.net/liigo/article/details/14653…

RedHat7.6_x86_x64服务器(最小化安装)搭建使用记录(二)

PostgreSQL数据库部署管理 1.rpm方式安装 挂载系统安装镜像&#xff1a; [rootlocalhost ~]# mount /dev/cdrom /mnt 进入安装包路径&#xff1a; [rootlocalhost ~]# cd /mnt/Packages 依次安装如下程序包&#xff1a; [rootlocalhost Packages]# rpm -ihv postgresql-libs-9…

浏览器存储 IndexedDB

IndexedDB 1. 什么是 IndexedDB&#xff1f; IndexedDB 是一种 基于浏览器的 NoSQL 数据库&#xff0c;用于存储大量的结构化数据&#xff0c;包括文件和二进制数据。它比 localStorage 和 sessionStorage 更强大&#xff0c;支持索引查询、事务等特性。 IndexedDB 主要特点…

panda3d 渲染

目录 安装 设置渲染宽高&#xff1a; 渲染3d 安装 pip install Panda3D 设置渲染宽高&#xff1a; import panda3d.core as pdmargin 100 screen Tk().winfo_screenwidth() - margin, Tk().winfo_screenheight() - margin width, height (screen[0], int(screen[0] / 1…

Node.js 包管理工具 - NPM 与 PNPM 清理缓存

NPM 清理缓存 1、基本介绍 npm 缓存是 npm 用来存储已下载包的地方&#xff0c;以加快后续安装速度 但是&#xff0c;有时缓存可能会损坏或占用过多磁盘空间&#xff0c;这时可以清理 npm 缓存 2、清理操作 执行如下指令&#xff0c;清理 npm 缓存 npm cache clean --for…

STM32F103_LL库+寄存器学习笔记05 - GPIO输入模式,捕获上升沿进入中断回调

导言 GPIO设置输入模式后&#xff0c;一般会用轮询的方式去查看GPIO的电平状态。比如&#xff0c;最常用的案例是用于检测按钮的当前状态&#xff08;是按下还是没按下&#xff09;。中断的使用一般用于计算脉冲的频率与计算脉冲的数量。 项目地址&#xff1a;https://github.…

【C++进阶二】string的模拟实现

【C进阶二】string的模拟实现 1.构造函数和C_strC_str: 2.operator[]3.拷贝构造3.1浅拷贝3.2深拷贝 4.赋值5.迭代器6.比较ascll码值的大小7.reverse扩容8.push_back尾插和append尾插9.10.insert10.1在pos位置前插入字符ch10.2在pos位置前插入字符串str 11.resize12.erase12.1从…

wokwi arduino mega 2560 - 点亮LED案例

截图&#xff1a; 点亮LED案例仿真截图 代码&#xff1a; unsigned long t[20]; // 定义一个数组t&#xff0c;用于存储20个LED的上次状态切换时间&#xff08;单位&#xff1a;毫秒&#xff09;void setup() {pinMode(13, OUTPUT); // 将引脚13设置为输出模式&#xff08;此…

vue3项目使用 python +flask 打包成桌面应用

server.py import os import sys from flask import Flask, send_from_directory# 获取静态文件路径 if getattr(sys, "frozen", False):# 如果是打包后的可执行文件base_dir sys._MEIPASS else:# 如果是开发环境base_dir os.path.dirname(os.path.abspath(__file…

后端学习day1-Spring(八股)--还剩9个没看

一、Spring 1.请你说说Spring的核心是什么 参考答案 Spring框架包含众多模块&#xff0c;如Core、Testing、Data Access、Web Servlet等&#xff0c;其中Core是整个Spring框架的核心模块。Core模块提供了IoC容器、AOP功能、数据绑定、类型转换等一系列的基础功能&#xff0c;…

LeetCode 第34、35题

LeetCode 第34题&#xff1a;在排序数组中查找元素的第一个和最后一个位置 题目描述 给你一个按照非递减顺序排列的整数数组nums&#xff0c;和一个目标值target。请你找出给定目标值在数组中的开始位置和结束位置。如果数组中不存在目标值target&#xff0c;返回[-1,1]。你必须…

告别分库分表,时序数据库 TDengine 解锁燃气监控新可能

达成效果&#xff1a; 从 MySQL 迁移至 TDengine 后&#xff0c;设备数据自动分片&#xff0c;运维更简单。 列式存储可减少 50% 的存储占用&#xff0c;单服务器即可支撑全量业务。 毫秒级漏气报警响应时间控制在 500ms 以内&#xff0c;提升应急管理效率。 新架构支持未来…

第十四届蓝桥杯真题

一.LED 先配置LED的八个引脚为GPIO_OutPut,锁存器PD2也是,然后都设置为起始高电平,生成代码时还要去解决引脚冲突问题 二.按键 按键配置,由原理图按键所对引脚要GPIO_Input 生成代码,在文件夹中添加code文件夹,code中添加fun.c、fun.h、headfile.h文件,去资源包中把lc…

《基于机器学习发电数据电量预测》开题报告

个人主页&#xff1a;大数据蟒行探索者 目录 一、选题背景、研究意义及文献综述 &#xff08;一&#xff09;选题背景 &#xff08;二&#xff09;选题意义 &#xff08;三&#xff09;文献综述 1. 国内外研究现状 2. 未来方向展望 二、研究的基本内容&#xff0c;拟解…

UWP程序用多页面实现应用实例多开

Windows 10 IoT ARM64平台下&#xff0c;UWP应用和MFC程序不一样&#xff0c;同时只能打开一个应用实例。以串口程序为例&#xff0c;如果用户希望同时打开多个应用实例&#xff0c;一个应用实例打开串口1&#xff0c;一个应用实例打开串口2&#xff0c;那么我们可以加载多个页…

Springboot整合Netty简单实现1对1聊天(vx小程序服务端)

本文功能实现较为简陋&#xff0c;demo内容仅供参考&#xff0c;有不足之处还请指正。 背景 一个小项目&#xff0c;用于微信小程序的服务端&#xff0c;需要实现小程序端可以和他人1对1聊天 实现功能 Websocket、心跳检测、消息持久化、离线消息存储 Netty配置类 /*** au…

GitLab 中文版17.10正式发布,27项重点功能解读【二】

GitLab 是一个全球知名的一体化 DevOps 平台&#xff0c;很多人都通过私有化部署 GitLab 来进行源代码托管。极狐GitLab 是 GitLab 在中国的发行版&#xff0c;专门为中国程序员服务。可以一键式部署极狐GitLab。 学习极狐GitLab 的相关资料&#xff1a; 极狐GitLab 官网极狐…

好消息!软航文档控件(NTKO WebOffice)在Chrome 133版本上提示扩展已停用的解决方案

软航文档控件现有版本依赖Manifest V2扩展技术支持才能正常运行&#xff0c;然而这个扩展技术到2025年6月在Chrome高版本上就彻底不支持了&#xff0c;现在Chrome 133开始的版本已经开始弹出警告&#xff0c;必须手工开启扩展支持才能正常运行。那么如何解决这个技术难题呢&…