Tensorflow框架:卷积神经网络实战--Cifar训练集

Cifar-10数据集包含10类共60000张32*32的彩色图片,每类6000张图。包括50000张训练图片和 10000张测试图片
在这里插入图片描述

代码分为数据处理部分和卷积网络训练部分:

数据处理部分:

#该文件负责读取Cifar-10数据并对其进行数据增强预处理
import os
import tensorflow as tf
num_classes=10#设定用于训练和评估的样本总数
num_examples_pre_epoch_for_train=50000
num_examples_pre_epoch_for_eval=10000#定义一个空类,用于返回读取的Cifar-10的数据
class CIFAR10Record(object):pass#定义一个读取Cifar-10的函数read_cifar10(),这个函数的目的就是读取目标文件里面的内容
def read_cifar10(file_queue):result=CIFAR10Record()label_bytes=1                                            #标签占用一字节,如果是Cifar-100数据集,则此处为2result.height=32result.width=32result.depth=3                                           #因为是RGB三通道,所以深度是3image_bytes=result.height * result.width * result.depth  #图片样本总元素数量record_bytes=label_bytes + image_bytes                   #因为每一个样本包含图片和标签,所以最终的元素数量还需要图片样本数量加上一个标签值reader=tf.FixedLengthRecordReader(record_bytes=record_bytes)  #使用tf.FixedLengthRecordReader()创建一个文件读取类。该类的目的就是读取文件result.key,value=reader.read(file_queue)                 #使用该类的read()函数从文件队列里面读取文件record_bytes=tf.decode_raw(value,tf.uint8)               #读取到文件以后,将读取到的文件内容从字符串形式解析为图像对应的像素数组#因为该数组第一个元素是标签,所以我们使用strided_slice()函数将标签提取出来,并且使用tf.cast()函数将这一个标签转换成int32的数值形式result.label=tf.cast(tf.strided_slice(record_bytes,[0],[label_bytes]),tf.int32)#剩下的元素再分割出来,这些就是图片数据,因为这些数据在数据集里面存储的形式是depth * height * width,我们要把这种格式转换成[depth,height,width]#这一步是将一维数据转换成3维数据depth_major=tf.reshape(tf.strided_slice(record_bytes,[label_bytes],[label_bytes + image_bytes]),[result.depth,result.height,result.width])  #我们要将之前分割好的图片数据使用tf.transpose()函数转换成为高度信息、宽度信息、深度信息这样的顺序#这一步是转换数据排布方式,变为(h,w,c)result.uint8image=tf.transpose(depth_major,[1,2,0])return result                                 #返回值是已经把目标文件里面的信息都读取出来def inputs(data_dir,batch_size,distorted):               #这个函数就对数据进行预处理---对图像数据是否进行增强进行判断,并作出相应的操作filenames=[os.path.join(data_dir,"data_batch_%d.bin"%i)for i in range(1,6)]   #拼接地址file_queue=tf.train.string_input_producer(filenames)     #根据已经有的文件地址创建一个文件队列read_input=read_cifar10(file_queue)                      #根据已经有的文件队列使用已经定义好的文件读取函数read_cifar10()读取队列中的文件reshaped_image=tf.cast(read_input.uint8image,tf.float32)   #将已经转换好的图片数据再次转换为float32的形式num_examples_per_epoch=num_examples_pre_epoch_for_trainif distorted != None:                         #如果预处理函数中的distorted参数不为空值,就代表要进行图片增强处理cropped_image=tf.random_crop(reshaped_image,[24,24,3])          #首先将预处理好的图片进行剪切,使用tf.random_crop()函数flipped_image=tf.image.random_flip_left_right(cropped_image)    #将剪切好的图片进行左右翻转,使用tf.image.random_flip_left_right()函数adjusted_brightness=tf.image.random_brightness(flipped_image,max_delta=0.8)   #将左右翻转好的图片进行随机亮度调整,使用tf.image.random_brightness()函数adjusted_contrast=tf.image.random_contrast(adjusted_brightness,lower=0.2,upper=1.8)    #将亮度调整好的图片进行随机对比度调整,使用tf.image.random_contrast()函数float_image=tf.image.per_image_standardization(adjusted_contrast)          #进行标准化图片操作,tf.image.per_image_standardization()函数是对每一个像素减去平均值并除以像素方差float_image.set_shape([24,24,3])                      #设置图片数据及标签的形状read_input.label.set_shape([1])min_queue_examples=int(num_examples_pre_epoch_for_eval * 0.4)print("Filling queue with %d CIFAR images before starting to train.    This will take a few minutes."%min_queue_examples)images_train,labels_train=tf.train.shuffle_batch([float_image,read_input.label],batch_size=batch_size,num_threads=16,capacity=min_queue_examples + 3 * batch_size,min_after_dequeue=min_queue_examples,)#使用tf.train.shuffle_batch()函数随机产生一个batch的image和labelreturn images_train,tf.reshape(labels_train,[batch_size])else:                               #不对图像数据进行数据增强处理resized_image=tf.image.resize_image_with_crop_or_pad(reshaped_image,24,24)   #在这种情况下,使用函数tf.image.resize_image_with_crop_or_pad()对图片数据进行剪切float_image=tf.image.per_image_standardization(resized_image)          #剪切完成以后,直接进行图片标准化操作float_image.set_shape([24,24,3])read_input.label.set_shape([1])min_queue_examples=int(num_examples_per_epoch * 0.4)images_test,labels_test=tf.train.batch([float_image,read_input.label],batch_size=batch_size,num_threads=16,capacity=min_queue_examples + 3 * batch_size)#这里使用batch()函数代替tf.train.shuffle_batch()函数return images_test,tf.reshape(labels_test,[batch_size])

卷积网络训练部分:

#该文件的目的是构造神经网络的整体结构,并进行训练和测试(评估)过程
import tensorflow as tf
import numpy as np
import time
import math
import Cifar10_datamax_steps=4000
batch_size=100
num_examples_for_eval=10000
data_dir="Cifar_data/cifar-10-batches-bin"#创建一个variable_with_weight_loss()函数,该函数的作用是:
#   1.使用参数w1控制L2 loss的大小
#   2.使用函数tf.nn.l2_loss()计算权重L2 loss
#   3.使用函数tf.multiply()计算权重L2 loss与w1的乘积,并赋值给weights_loss
#   4.使用函数tf.add_to_collection()将最终的结果放在名为losses的集合里面,方便后面计算神经网络的总体loss,
def variable_with_weight_loss(shape,stddev,w1):var=tf.Variable(tf.truncated_normal(shape,stddev=stddev))if w1 is not None:weights_loss=tf.multiply(tf.nn.l2_loss(var),w1,name="weights_loss")tf.add_to_collection("losses",weights_loss)return var#使用上一个文件里面已经定义好的文件序列读取函数读取训练数据文件和测试数据从文件.
#其中训练数据文件进行数据增强处理,测试数据文件不进行数据增强处理
images_train,labels_train=Cifar10_data.inputs(data_dir=data_dir,batch_size=batch_size,distorted=True)
images_test,labels_test=Cifar10_data.inputs(data_dir=data_dir,batch_size=batch_size,distorted=None)#创建x和y_两个placeholder,用于在训练或评估时提供输入的数据和对应的标签值。
#要注意的是,由于以后定义全连接网络的时候用到了batch_size,所以x中,第一个参数不应该是None,而应该是batch_size
x=tf.placeholder(tf.float32,[batch_size,24,24,3])
y_=tf.placeholder(tf.int32,[batch_size])#创建第一个卷积层 shape=(kh,kw,ci,co)
kernel1=variable_with_weight_loss(shape=[5,5,3,64],stddev=5e-2,w1=0.0)
conv1=tf.nn.conv2d(x,kernel1,[1,1,1,1],padding="SAME")
bias1=tf.Variable(tf.constant(0.0,shape=[64]))
relu1=tf.nn.relu(tf.nn.bias_add(conv1,bias1))
pool1=tf.nn.max_pool(relu1,ksize=[1,3,3,1],strides=[1,2,2,1],padding="SAME")#创建第二个卷积层
kernel2=variable_with_weight_loss(shape=[5,5,64,64],stddev=5e-2,w1=0.0)
conv2=tf.nn.conv2d(pool1,kernel2,[1,1,1,1],padding="SAME")
bias2=tf.Variable(tf.constant(0.1,shape=[64]))
relu2=tf.nn.relu(tf.nn.bias_add(conv2,bias2))
pool2=tf.nn.max_pool(relu2,ksize=[1,3,3,1],strides=[1,2,2,1],padding="SAME")#因为要进行全连接层的操作,所以这里使用tf.reshape()函数将pool2输出变成一维向量,并使用get_shape()函数获取扁平化之后的长度
reshape=tf.reshape(pool2,[batch_size,-1])    #这里面的-1代表将pool2的三维结构拉直为一维结构
dim=reshape.get_shape()[1].value             #get_shape()[1].value表示获取reshape之后的第二个维度的值#建立第一个全连接层
weight1=variable_with_weight_loss(shape=[dim,384],stddev=0.04,w1=0.004)
fc_bias1=tf.Variable(tf.constant(0.1,shape=[384]))
fc_1=tf.nn.relu(tf.matmul(reshape,weight1)+fc_bias1)#建立第二个全连接层
weight2=variable_with_weight_loss(shape=[384,192],stddev=0.04,w1=0.004)
fc_bias2=tf.Variable(tf.constant(0.1,shape=[192]))
local4=tf.nn.relu(tf.matmul(fc_1,weight2)+fc_bias2)#建立第三个全连接层
weight3=variable_with_weight_loss(shape=[192,10],stddev=1 / 192.0,w1=0.0)
fc_bias3=tf.Variable(tf.constant(0.1,shape=[10]))
result=tf.add(tf.matmul(local4,weight3),fc_bias3)#计算损失,包括权重参数的正则化损失和交叉熵损失
cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(logits=result,labels=tf.cast(y_,tf.int64))weights_with_l2_loss=tf.add_n(tf.get_collection("losses"))
loss=tf.reduce_mean(cross_entropy)+weights_with_l2_losstrain_op=tf.train.AdamOptimizer(1e-3).minimize(loss)#函数tf.nn.in_top_k()用来计算输出结果中top k的准确率,函数默认的k值是1,即top 1的准确率,也就是输出分类准确率最高时的数值
top_k_op=tf.nn.in_top_k(result,y_,1)init_op=tf.global_variables_initializer()
with tf.Session() as sess:sess.run(init_op)#启动线程操作,这是因为之前数据增强的时候使用train.shuffle_batch()函数的时候通过参数num_threads()配置了16个线程用于组织batch的操作tf.train.start_queue_runners()      #每隔100step会计算并展示当前的loss、每秒钟能训练的样本数量、以及训练一个batch数据所花费的时间for step in range (max_steps):start_time=time.time()image_batch,label_batch=sess.run([images_train,labels_train])_,loss_value=sess.run([train_op,loss],feed_dict={x:image_batch,y_:label_batch})duration=time.time() - start_timeif step % 100 == 0:examples_per_sec=batch_size / durationsec_per_batch=float(duration)print("step %d,loss=%.2f(%.1f examples/sec;%.3f sec/batch)"%(step,loss_value,examples_per_sec,sec_per_batch))#计算最终的正确率num_batch=int(math.ceil(num_examples_for_eval/batch_size))  #math.ceil()函数用于求整true_count=0total_sample_count=num_batch * batch_size#在一个for循环里面统计所有预测正确的样例个数for j in range(num_batch):image_batch,label_batch=sess.run([images_test,labels_test])predictions=sess.run([top_k_op],feed_dict={x:image_batch,y_:label_batch})true_count += np.sum(predictions)#打印正确率信息print("accuracy = %.3f%%"%((true_count/total_sample_count) * 100))

实现结果:

在这里插入图片描述
准确率在74%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/389488.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机科学速成课36:自然语言处理

词性 短语结构规则 分析树 语音识别 谱图 快速傅里叶变换 音素 语音合成 转载于:https://www.cnblogs.com/davidliu2018/p/9149252.html

linux内存初始化初期内存分配器——memblock

2019独角兽企业重金招聘Python工程师标准>>> 1.1.1 memblock 系统初始化的时候buddy系统,slab分配器等并没有被初始化好,当需要执行一些内存管理、内存分配的任务,就引入了一种内存管理器bootmem分配器。 当buddy系统和slab分配器初始化好后&…

数据科学学习心得_学习数据科学

数据科学学习心得苹果 | GOOGLE | 现货 | 其他 (APPLE | GOOGLE | SPOTIFY | OTHERS) Editor’s note: The Towards Data Science podcast’s “Climbing the Data Science Ladder” series is hosted by Jeremie Harris. Jeremie helps run a data science mentorship startup…

Keras框架:Alexnet网络代码实现

网络思想: 1、一张原始图片被resize到(224,224,3); 2、使用步长为4x4,大小为11的卷积核对图像进行卷积,输出的特征层为96层, 输出的shape为(55,55,96); 3、使用步长为2的最大池化层进行池化,此时…

PHP对象传递方式

<?phpheader(content-type:text/html;charsetutf-8);class Person{public $name;public $age;}$p1 new Person;$p1->name 金角大王;$p1->age 400;//这个地方&#xff0c;到底怎样?$p2 $p1;$p2->name 银角大王;echo <pre>;echo p1 name . $p1->n…

微软Azure CDN现已普遍可用

微软宣布Azure CDN一般可用&#xff08;GA&#xff09;&#xff0c;客户现在可以从微软的全球CDN网络提供内容。最新版本是对去年五月份发布的公众预览版的跟进。\\今年5月&#xff0c;微软与Verizon和Akamai一起推出了原生CDN产品。现在推出了GA版本&#xff0c;根据发布博文所…

数据科学生命周期_数据科学项目生命周期第1部分

数据科学生命周期This is series of how to developed data science project.这是如何开发数据科学项目的系列。 This is part 1.这是第1部分。 All the Life-cycle In A Data Science Projects-1. Data Analysis and visualization.2. Feature Engineering.3. Feature Selec…

Keras框架:VGG网络代码实现

VGG概念&#xff1a; VGG之所以经典&#xff0c;在于它首次将深度学习做得非常“深”&#xff0c;达 到了16-19层&#xff0c;同时&#xff0c;它用了非常“小”的卷积核&#xff08;3X3&#xff09;。 网络框架&#xff1a; VGG的结构&#xff1a; 1、一张原始图片被resize…

Django笔记1

内容整理1.创建django工程django-admin startproject 工程名2.创建APPcd 工程名python manage.py startapp cmdb3.静态文件project.settings.pySTATICFILES_dirs {os.path.join(BASE_DIR, static),}4.模板路径DIRS > [os.path.join(BASE_DIR, templates),]5.settings中mid…

BZOJ 2003 [Hnoi2010]Matrix 矩阵

题目链接 https://www.lydsy.com/JudgeOnline/problem.php?id2003 题解 考虑搜索。 确定了第一行和第一列&#xff0c;那么就确定了整个矩阵&#xff0c;因此搜索的范围可以降到399个位置。 首先搜索第一行&#xff0c;显然每个不是第一行第一列的位置都可以由三个位置唯一确定…

Keras框架:resent50代码实现

Residual net概念 概念&#xff1a; Residual net(残差网络)&#xff1a;将靠前若干层的某一层数据输出直接跳过多层引入到后面数据层的输入 部分。 残差神经单元&#xff1a;假定某段神经网络的输入是x&#xff0c;期望输出是H(x)&#xff0c;如果我们直接将输入x传到输出作…

MySQL数据库的回滚失败(JAVA)

这几天在学习MySQL数据的知识&#xff0c;有一个小测试&#xff0c;用来测试数据库的提交和回滚。 刚开始的时候真的没把这个当回事&#xff0c;按照正常的步骤来讲的话&#xff0c;如下所示&#xff0c;加载驱动&#xff0c;获取数据库的连接&#xff0c;并且把数据库的自动提…

条件概率分布_条件概率

条件概率分布If you’re currently in the job market or looking to switch careers, you’ve probably noticed an increase in popularity of Data Science jobs. In 2019, LinkedIn ranked “data scientist” the №1 most promising job in the U.S. based on job openin…

MP实战系列(十七)之乐观锁插件

声明&#xff0c;目前只是仅仅针对3.0以下版本&#xff0c;2.0以上版本。 意图&#xff1a; 当要更新一条记录的时候&#xff0c;希望这条记录没有被别人更新 乐观锁实现方式&#xff1a; 取出记录时&#xff0c;获取当前version 更新时&#xff0c;带上这个version 执行更新时…

二叉树删除节点,(查找二叉树最大值节点)

从根节点往下分别查找左子树和右子树的最大节点&#xff0c;再比较左子树&#xff0c;右子树&#xff0c;根节点的大小得到结果&#xff0c;在得到左子树和右子树最大节点的过程相似&#xff0c;因此可以采用递归的 //树节点结构 public class TreeNode { TreeNode left;…

Tensorflow框架:InceptionV3网络概念及实现

卷积神经网络迁移学习-Inception • 有论文依据表明可以保留训练好的inception模型中所有卷积层的参数&#xff0c;只替换最后一层全连接层。在最后 这一层全连接层之前的网络称为瓶颈层。 • 原理&#xff1a;在训练好的inception模型中&#xff0c;因为将瓶颈层的输出再通过…

View详解(4)

在上文中我们简单介绍了Canvas#drawCircle()的使用方式&#xff0c;以及Paint#setStyle(),Paint#setStrokeWidth(),Paint#setColor()等相关函数&#xff0c;不知道小伙伴们了解了多少&#xff1f;那么是不是所有的图形都能通过圆来描述呢&#xff1f;当然不行&#xff0c;那么熟…

成为一名真正的数据科学家有多困难

Data Science and Machine Learning are hard sports to play. It’s difficult enough to motivate yourself to sit down and learn some maths, let alone to becoming an expert on the matter.数据科学和机器学习是一项艰巨的运动。 激励自己坐下来学习一些数学知识是非常…

Ubuntu 装机软件

Ubuntu16.04 软件商店闪退打不开 sudo apt-get updatesudo apt-get dist-upgrade# 应该执行一下更新就好&#xff0c;不需要重新安装软件中心 sudo apt-get install –reinstall software-center Ubuntu16.04 深度美化 https://www.jianshu.com/p/4bd2d9b1af41 Ubuntu18.04 美化…

数据分析中的统计概率_了解统计和概率:成为专家数据科学家

数据分析中的统计概率Data Science is a hot topic nowadays. Organizations consider data scientists to be the Crme de la crme. Everyone in the industry is talking about the potential of data science and what data scientists can bring in their BigTech and FinT…