Keras框架:Alexnet网络代码实现

网络思想:

在这里插入图片描述
1、一张原始图片被resize到(224,224,3);
2、使用步长为4x4,大小为11的卷积核对图像进行卷积,输出的特征层为96层, 输出的shape为(55,55,96);
3、使用步长为2的最大池化层进行池化,此时输出的shape为(27,27,96)
4、使用步长为1x1,大小为5的卷积核对图像进行卷积,输出的特征层为256层, 输出的shape为(27,27,256);
5、使用步长为2的最大池化层进行池化,此时输出的shape为(13,13,256);
6、使用步长为1x1,大小为3的卷积核对图像进行卷积,输出的特征层为384层, 输出的shape为(13,13,384);
7、使用步长为1x1,大小为3的卷积核对图像进行卷积,输出的特征层为384层, 输出的shape为(13,13,384);
8、使用步长为1x1,大小为3的卷积核对图像进行卷积,输出的特征层为256层, 输出的shape为(13,13,256);
9、使用步长为2的最大池化层进行池化,此时输出的shape为(6,6,256);
10、两个全连接层,最后输出为1000类

细节部分举例:

第一层
第一层输入数据为原始图像的2242243的图像,这个图像被11113(3代表 深度,例如RGB的3通道)的卷积核进行卷积运算,卷积核对原始图像的每次 卷积都会生成一个新的像素。 卷积核的步长为4个像素,朝着横向和纵向这两个方向进行卷积。 由此,会生成新的像素; 第一层有96个卷积核,所以就会形成555596个像素层。 pool池化层:这些像素层还需要经过pool运算(池化运算)的处理,池化运 算的尺度由预先设定为33,运算的步长为2,则池化后的图像的尺寸为: (55-3)/2+1=27。即经过池化处理过的规模为2727*96.

代码实现:

网络主体部分:(AlexNet.py)

from keras.models import Sequential
from keras.layers import Dense,Activation,Conv2D,MaxPooling2D,Flatten,Dropout,BatchNormalization
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam# 注意,为了加快收敛,我将每个卷积层的filter减半,全连接层减为1024
def AlexNet(input_shape=(224,224,3),output_shape=2):# AlexNetmodel = Sequential()# 使用步长为4x4,大小为11的卷积核对图像进行卷积,输出的特征层为96层,输出的shape为(55,55,96);# 所建模型后输出为48特征层model.add(Conv2D(filters=48, kernel_size=(11,11),strides=(4,4),padding='valid',input_shape=input_shape,activation='relu'))model.add(BatchNormalization())# 使用步长为2的最大池化层进行池化,此时输出的shape为(27,27,96)# 所建模型后输出为48特征层model.add(MaxPooling2D(pool_size=(3,3), strides=(2,2), padding='valid'))# 使用步长为1x1,大小为5的卷积核对图像进行卷积,输出的特征层为256层,输出的shape为(27,27,256);# 所建模型后输出为128特征层model.add(Conv2D(filters=128, kernel_size=(5,5), strides=(1,1), padding='same',activation='relu'))model.add(BatchNormalization())# 使用步长为2的最大池化层进行池化,此时输出的shape为(13,13,256);# 所建模型后输出为128特征层model.add(MaxPooling2D(pool_size=(3,3),strides=(2,2),padding='valid'))# 使用步长为1x1,大小为3的卷积核对图像进行卷积,输出的特征层为384层,输出的shape为(13,13,384);# 所建模型后输出为192特征层model.add(Conv2D(filters=192, kernel_size=(3,3),strides=(1,1), padding='same', activation='relu')) # 使用步长为1x1,大小为3的卷积核对图像进行卷积,输出的特征层为384层,输出的shape为(13,13,384);# 所建模型后输出为192特征层model.add(Conv2D(filters=192, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu'))# 使用步长为1x1,大小为3的卷积核对图像进行卷积,输出的特征层为256层,输出的shape为(13,13,256);# 所建模型后输出为128特征层model.add(Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu'))# 使用步长为2的最大池化层进行池化,此时输出的shape为(6,6,256);# 所建模型后输出为128特征层model.add(MaxPooling2D(pool_size=(3,3), strides=(2,2), padding='valid'))# 两个全连接层,最后输出为1000类,这里改为2类(猫和狗)# 缩减为1024model.add(Flatten())model.add(Dense(1024, activation='relu'))model.add(Dropout(0.25))model.add(Dense(1024, activation='relu'))model.add(Dropout(0.25))model.add(Dense(output_shape, activation='softmax'))return model

图像预处理部分:(utils.py)

import matplotlib.image as mpimg
import numpy as np
import cv2
import tensorflow as tf
from tensorflow.python.ops import array_opsdef load_image(path):# 读取图片,rgbimg = mpimg.imread(path)# 将图片修剪成中心的正方形short_edge = min(img.shape[:2])yy = int((img.shape[0] - short_edge) / 2)xx = int((img.shape[1] - short_edge) / 2)crop_img = img[yy: yy + short_edge, xx: xx + short_edge]return crop_imgdef resize_image(image, size):with tf.name_scope('resize_image'):images = []for i in image:i = cv2.resize(i, size)images.append(i)images = np.array(images)return imagesdef print_answer(argmax):with open("./data/model/index_word.txt","r",encoding='utf-8') as f:synset = [l.split(";")[1][:-1] for l in f.readlines()]# print(synset[argmax])return synset[argmax]

训练部分:(train.py)

from keras.callbacks import TensorBoard, ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from keras.utils import np_utils
from keras.optimizers import Adam
from model.AlexNet import AlexNet
import numpy as np
import utils
import cv2
from keras import backend as K
#K.set_image_dim_ordering('tf')
K.image_data_format() == 'channels_first'def generate_arrays_from_file(lines,batch_size):# 获取总长度n = len(lines)i = 0while 1:X_train = []Y_train = []# 获取一个batch_size大小的数据for b in range(batch_size):if i==0:np.random.shuffle(lines)name = lines[i].split(';')[0]# 从文件中读取图像img = cv2.imread(r".\data\image\train" + '/' + name)img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)img = img/255X_train.append(img)Y_train.append(lines[i].split(';')[1])# 读完一个周期后重新开始i = (i+1) % n# 处理图像X_train = utils.resize_image(X_train,(224,224))X_train = X_train.reshape(-1,224,224,3)Y_train = np_utils.to_categorical(np.array(Y_train),num_classes= 2)   yield (X_train, Y_train)if __name__ == "__main__":# 模型保存的位置log_dir = "./logs/"# 打开数据集的txtwith open(r".\data\dataset.txt","r") as f:lines = f.readlines()# 打乱行,这个txt主要用于帮助读取数据来训练# 打乱的数据更有利于训练np.random.seed(10101)np.random.shuffle(lines)np.random.seed(None)# 90%用于训练,10%用于估计。num_val = int(len(lines)*0.1)num_train = len(lines) - num_val# 建立AlexNet模型model = AlexNet()# 保存的方式,3代保存一次checkpoint_period1 = ModelCheckpoint(log_dir + 'ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5',monitor='acc', save_weights_only=False, save_best_only=True, period=3)# 学习率下降的方式,acc三次不下降就下降学习率继续训练reduce_lr = ReduceLROnPlateau(monitor='acc', factor=0.5, patience=3, verbose=1)# 是否需要早停,当val_loss一直不下降的时候意味着模型基本训练完毕,可以停止early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1)# 交叉熵model.compile(loss = 'categorical_crossentropy',optimizer = Adam(lr=1e-3),metrics = ['accuracy'])# 一次的训练集大小batch_size = 128print('Train on {} samples, val on {} samples, with batch size {}.'.format(num_train, num_val, batch_size))# 开始训练model.fit_generator(generate_arrays_from_file(lines[:num_train], batch_size),steps_per_epoch=max(1, num_train//batch_size),validation_data=generate_arrays_from_file(lines[num_train:], batch_size),validation_steps=max(1, num_val//batch_size),epochs=50,initial_epoch=0,callbacks=[checkpoint_period1, reduce_lr])model.save_weights(log_dir+'last1.h5')#保存模型

预测部分:(predict.py)

import numpy as np
import utils
import cv2
from keras import backend as K
from model.AlexNet import AlexNet# K.set_image_dim_ordering('tf')
K.image_data_format() == 'channels_first'if __name__ == "__main__":model = AlexNet()model.load_weights("./logs/last1.h5")img = cv2.imread("./test2.jpg")img_RGB = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)img_nor = img_RGB/255img_nor = np.expand_dims(img_nor,axis = 0)img_resize = utils.resize_image(img_nor,(224,224))#utils.print_answer(np.argmax(model.predict(img)))print('the answer is: ',utils.print_answer(np.argmax(model.predict(img_resize))))cv2.imshow("ooo",img)cv2.waitKey(0)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/389484.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PHP对象传递方式

<?phpheader(content-type:text/html;charsetutf-8);class Person{public $name;public $age;}$p1 new Person;$p1->name 金角大王;$p1->age 400;//这个地方&#xff0c;到底怎样?$p2 $p1;$p2->name 银角大王;echo <pre>;echo p1 name . $p1->n…

微软Azure CDN现已普遍可用

微软宣布Azure CDN一般可用&#xff08;GA&#xff09;&#xff0c;客户现在可以从微软的全球CDN网络提供内容。最新版本是对去年五月份发布的公众预览版的跟进。\\今年5月&#xff0c;微软与Verizon和Akamai一起推出了原生CDN产品。现在推出了GA版本&#xff0c;根据发布博文所…

数据科学生命周期_数据科学项目生命周期第1部分

数据科学生命周期This is series of how to developed data science project.这是如何开发数据科学项目的系列。 This is part 1.这是第1部分。 All the Life-cycle In A Data Science Projects-1. Data Analysis and visualization.2. Feature Engineering.3. Feature Selec…

Keras框架:VGG网络代码实现

VGG概念&#xff1a; VGG之所以经典&#xff0c;在于它首次将深度学习做得非常“深”&#xff0c;达 到了16-19层&#xff0c;同时&#xff0c;它用了非常“小”的卷积核&#xff08;3X3&#xff09;。 网络框架&#xff1a; VGG的结构&#xff1a; 1、一张原始图片被resize…

Django笔记1

内容整理1.创建django工程django-admin startproject 工程名2.创建APPcd 工程名python manage.py startapp cmdb3.静态文件project.settings.pySTATICFILES_dirs {os.path.join(BASE_DIR, static),}4.模板路径DIRS > [os.path.join(BASE_DIR, templates),]5.settings中mid…

BZOJ 2003 [Hnoi2010]Matrix 矩阵

题目链接 https://www.lydsy.com/JudgeOnline/problem.php?id2003 题解 考虑搜索。 确定了第一行和第一列&#xff0c;那么就确定了整个矩阵&#xff0c;因此搜索的范围可以降到399个位置。 首先搜索第一行&#xff0c;显然每个不是第一行第一列的位置都可以由三个位置唯一确定…

Keras框架:resent50代码实现

Residual net概念 概念&#xff1a; Residual net(残差网络)&#xff1a;将靠前若干层的某一层数据输出直接跳过多层引入到后面数据层的输入 部分。 残差神经单元&#xff1a;假定某段神经网络的输入是x&#xff0c;期望输出是H(x)&#xff0c;如果我们直接将输入x传到输出作…

MySQL数据库的回滚失败(JAVA)

这几天在学习MySQL数据的知识&#xff0c;有一个小测试&#xff0c;用来测试数据库的提交和回滚。 刚开始的时候真的没把这个当回事&#xff0c;按照正常的步骤来讲的话&#xff0c;如下所示&#xff0c;加载驱动&#xff0c;获取数据库的连接&#xff0c;并且把数据库的自动提…

条件概率分布_条件概率

条件概率分布If you’re currently in the job market or looking to switch careers, you’ve probably noticed an increase in popularity of Data Science jobs. In 2019, LinkedIn ranked “data scientist” the №1 most promising job in the U.S. based on job openin…

MP实战系列(十七)之乐观锁插件

声明&#xff0c;目前只是仅仅针对3.0以下版本&#xff0c;2.0以上版本。 意图&#xff1a; 当要更新一条记录的时候&#xff0c;希望这条记录没有被别人更新 乐观锁实现方式&#xff1a; 取出记录时&#xff0c;获取当前version 更新时&#xff0c;带上这个version 执行更新时…

二叉树删除节点,(查找二叉树最大值节点)

从根节点往下分别查找左子树和右子树的最大节点&#xff0c;再比较左子树&#xff0c;右子树&#xff0c;根节点的大小得到结果&#xff0c;在得到左子树和右子树最大节点的过程相似&#xff0c;因此可以采用递归的 //树节点结构 public class TreeNode { TreeNode left;…

Tensorflow框架:InceptionV3网络概念及实现

卷积神经网络迁移学习-Inception • 有论文依据表明可以保留训练好的inception模型中所有卷积层的参数&#xff0c;只替换最后一层全连接层。在最后 这一层全连接层之前的网络称为瓶颈层。 • 原理&#xff1a;在训练好的inception模型中&#xff0c;因为将瓶颈层的输出再通过…

View详解(4)

在上文中我们简单介绍了Canvas#drawCircle()的使用方式&#xff0c;以及Paint#setStyle(),Paint#setStrokeWidth(),Paint#setColor()等相关函数&#xff0c;不知道小伙伴们了解了多少&#xff1f;那么是不是所有的图形都能通过圆来描述呢&#xff1f;当然不行&#xff0c;那么熟…

成为一名真正的数据科学家有多困难

Data Science and Machine Learning are hard sports to play. It’s difficult enough to motivate yourself to sit down and learn some maths, let alone to becoming an expert on the matter.数据科学和机器学习是一项艰巨的运动。 激励自己坐下来学习一些数学知识是非常…

Ubuntu 装机软件

Ubuntu16.04 软件商店闪退打不开 sudo apt-get updatesudo apt-get dist-upgrade# 应该执行一下更新就好&#xff0c;不需要重新安装软件中心 sudo apt-get install –reinstall software-center Ubuntu16.04 深度美化 https://www.jianshu.com/p/4bd2d9b1af41 Ubuntu18.04 美化…

数据分析中的统计概率_了解统计和概率:成为专家数据科学家

数据分析中的统计概率Data Science is a hot topic nowadays. Organizations consider data scientists to be the Crme de la crme. Everyone in the industry is talking about the potential of data science and what data scientists can bring in their BigTech and FinT…

Keras框架:Mobilenet网络代码实现

Mobilenet概念&#xff1a; MobileNet模型是Google针对手机等嵌入式设备提出的一种轻量级的深层神经网络&#xff0c;其使用的核心思想便是depthwise separable convolution。 Mobilenet思想&#xff1a; 通俗地理解就是3x3的卷积核厚度只有一层&#xff0c;然后在输入张量上…

clipboard 在 vue 中的使用

简介 页面中用 clipboard 可以进行复制粘贴&#xff0c;clipboard能将内容直接写入剪切板 安装 npm install --save clipboard 使用方法一 <template><span>{{ code }}</span><iclass"el-icon-document"title"点击复制"click"co…

数据驱动开发_开发数据驱动的股票市场投资方法

数据驱动开发Data driven means that your decision are driven by data and not by emotions. This approach can be very useful in stock market investment. Here is a summary of a data driven approach which I have been taking recently数据驱动意味着您的决定是由数据…

前端之sublime text配置

接下来我们来了解如何调整sublime text的配置&#xff0c;可能很多同学下载sublime text的时候就是把它当成记事本来使用&#xff0c;也就是没有做任何自定义的配置&#xff0c;做一些自定义的配置可以让sublime text更适合我们的开发习惯。 那么在利用刚才的命令面板我们怎么打…