持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

觉得有用的话,欢迎一起讨论相互学习~Follow Me

参考文献Tensorflow实战Google深度学习框架
实验平台:
Tensorflow1.4.0
python3.5.0
MNIST数据集将四个文件下载后放到当前目录下的MNIST_data文件夹下

定义模型框架与前向传播

import tensorflow as tf# 定义神经网络结构相关参数
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500# 设置权值函数
# 在训练时会创建这些变量,在测试时会通过保存的模型加载这些变量的取值
# 因为可以在变量加载时将滑动平均变量均值重命名,所以这个函数可以直接通过同样的名字在训练时使用变量本身
# 而在测试时使用变量的滑动平均值,在这个函数中也会将变量的正则化损失加入损失集合def get_weight_variable(shape, regularizer):weights = tf.get_variable("weights", shape, initializer=tf.truncated_normal_initializer(stddev=0.1))# 如果使用正则化方法会将该张量加入一个名为'losses'的集合if regularizer != None: tf.add_to_collection('losses', regularizer(weights))return weights# 定义神经网络前向传播过程
def inference(input_tensor, regularizer):with tf.variable_scope('layer1'):weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)with tf.variable_scope('layer2'):weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))layer2 = tf.matmul(layer1, weights) + biasesreturn layer2

模型训练与模型框架及参数持久化

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import os# 配置神经网络参数
BATCH_SIZE = 100  # 批处理数据大小
LEARNING_RATE_BASE = 0.8  # 基础学习率
LEARNING_RATE_DECAY = 0.99  # 学习率衰减速度
REGULARIZATION_RATE = 0.0001  # 正则化项
TRAINING_STEPS = 30000  # 训练次数
MOVING_AVERAGE_DECAY = 0.99  # 平均滑动模型衰减参数
# 模型保存的路径和文件名
MODEL_SAVE_PATH = "MNIST_model/"
MODEL_NAME = "mnist_model"def train(mnist):# 定义输入输出placeholderx = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')  # 可以直接引用mnist_inference中的超参数y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')# 定义L2正则化器regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)# 在前向传播时使用L2正则化y = mnist_inference.inference(x, regularizer)global_step = tf.Variable(0, trainable=False)# 在可训练参数上定义平均滑动模型variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)# tf.trainable_variables()返回的是图上集合GraphKeys.TRAINABLE_VARIABLES中的元素。这个集合中的元素是所有没有指定trainable=False的参数variables_averages_op = variable_averages.apply(tf.trainable_variables())cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))cross_entropy_mean = tf.reduce_mean(cross_entropy)# 在交叉熵函数的基础上增加权值的L2正则化部分loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))# 设置学习率,其中学习率使用逐渐递减的原则learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,mnist.train.num_examples/BATCH_SIZE, LEARNING_RATE_DECAY,staircase=True)# 使用梯度下降优化器train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)# with tf.control_dependencies([train_step, variables_averages_op]):# train_op = tf.no_op(name='train')# 在反向传播的过程中,不仅更新神经网络中的参数还更新每一个参数的滑动平均值train_op = tf.group(train_step, variables_averages_op)# 定义Saver模型保存器saver = tf.train.Saver()with tf.Session() as sess:tf.global_variables_initializer().run()for i in range(TRAINING_STEPS):xs, ys = mnist.train.next_batch(BATCH_SIZE)_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})# 每1000轮保存一次模型if i%1000 == 0:# 输出当前的训练情况,这里只输出了模型在当前训练batch上的损失函数大小# 通过损失函数的大小可以大概了解训练的情况,# 在验证数据集上的正确率信息会有一个单独的程序来生成print("After %d training step(s), loss on training batch is %g."%(step, loss_value))# 模型保存saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)def main(argv=None):mnist = input_data.read_data_sets("./MNIST_data", one_hot=True)train(mnist)if __name__ == '__main__':tf.app.run()

模型恢复与评价测试集上的效果

import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import mnist_train# 每10秒加载一次最新的模型
# 加载的时间间隔。
EVAL_INTERVAL_SECS = 10def evaluate(mnist):with tf.Graph().as_default() as g:x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}# 直接通过调用封装好的函数来计算前向传播的结果,因为测试时不关注正则化损失的值所以这里用于计算正则化损失的函数被设置为Noney = mnist_inference.inference(x, None)correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))# 如果需要离线预测未知数据的类别,只需要将计算正确率的部分改为答案的输出即可。accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))# 通过获取变量重命名的方式来加载模型,这样在前向传播的过程中就不需要调用滑动平均的函数来获取平均值# 这样可以完全共用mnist_inference.py重定义的前向传播过程variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)variables_to_restore = variable_averages.variables_to_restore()saver = tf.train.Saver(variables_to_restore)while True:with tf.Session() as sess:# tf.train.get_checkpoint_state函数会通过checkpoint文件自动找到目录中最新模型的文件名ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)if ckpt and ckpt.model_checkpoint_path:# 加载模型saver.restore(sess, ckpt.model_checkpoint_path)# 通过文件名得到模型保存是迭代的轮数global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]accuracy_score = sess.run(accuracy, feed_dict=validate_feed)print("After %s training step(s), validation accuracy = %g"%(global_step, accuracy_score))else:print('No checkpoint file found')returntime.sleep(EVAL_INTERVAL_SECS)# 每次运行都是读取最新保存的模型,并在MNIST验证数据集上计算模型的正确率# 每隔EVAL_INTERVAL_SECS秒来调用一侧计算正确率的过程以检验训练过程中的正确率变化# ###  主程序
def main(argv=None):mnist = input_data.read_data_sets("./MNIST_data", one_hot=True)evaluate(mnist)if __name__ == '__main__':main()# After 29001 training step(s), validation accuracy = 0.9854

转载于:https://www.cnblogs.com/cloud-ken/p/9318037.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/277643.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

怎样制作滴滴截图_滴滴老了吗?

作者 / 薛静 来源 / 盒饭财经(ID:daxiongfan)滴滴最近有点忙。6月11日,滴滴地图与公交事业部负责人柴华还在忙于解答消费者对于滴滴司机绕路的质疑,网上就流传出了滴滴司机直播性侵的消息。当晚,滴滴急忙在官方微博中做出回应称已…

mysql Backup recovery

如果您要在MySQL数据库中存储任何您不想丢失的内容,那么定期备份数据以保护数据免受损失非常重要。本教程将向您展示两种简单的方法来备份和恢复MySQL数据库中的数据。您还可以使用此过程将数据移动到新的Web服务器。 从命令行备份(使用mysqldump&#x…

Kinect开发笔记之三Kinect开发环境配置详解

0、前言:首先说一下我的开发环境,Visual Studio是2013的,系统是win8的64位版本,SDK是Kinect for windows SDK 1.8版本。虽然前一篇博文费了半天劲,翻译了2.0SDK的新特性,但我还是决定要回退一个版本。其实我…

opencv python 图像缩放/图像平移/图像旋转/仿射变换/透视变换

Geometric Transformations of Images 1图像转换 OpenCV提供了两个转换函数cv2.warpAffine和cv2.warpPerspective,可以使用它们进行各种转换。 cv2.warpAffine采用2x3变换矩阵,而cv2.warpPerspective采用3x3变换矩阵作为输入。 2图像缩放 缩放只是调整图…

.net调用c++方法时如何释放c++中分配的内存_C/C++编程笔记:C语言编程知识要点总结!大一C语言知识点(全)...

一、C语言程序的构成与C、Java相比,C语言其实很简单,但却非常重要。因为它是C、Java的基础。不把C语言基础打扎实,很难成为程序员高手。1、C语言的结构先通过一个简单的例子,把C语言的基础打牢。C语言的结构要掌握以下几点&#x…

Django 使用 mysql 数据库连接

启用 mysql 数据库连接 修改 app01 下的 __init__.py import pymysqlpymysql.install_as_MySQLdb() 修改 settings.py DATABASES {default: {ENGINE: django.db.backends.mysql,NAME: django,USER: django,PASSWORD: django,HOST: 192.168.0.200,PORT: 3306,} } 测试 #生成同步…

Kinect开发笔记之四检测并调试Kinect设备

之前我们已经装好了Developer Toolkit 1.8,下面我们来做进一步的测试。首先到开始菜单中找到Kinect for Windows SDK v1.8,点击其中的Developer Toolkit Browser v1.8.0。打开后,有许多东西,我们选择最右边的Tools来筛选一下&…

c语言双引号和单引号的区别_Python中的单引号和双引号有什么区别?

在Python中使用单引号或双引号是没有区别的,都可以用来表示一个字符串。但是这两种通用的表达方式可以避免出错之外,还可以减少转义字符的使用,使程序看起来更清晰。举两个例子:1、包含单引号的字符串定义一个字符串m…

mysql 开发基础系列22 SQL Model(带迁移事项)

一.概述 与其它数据库不同,mysql 可以运行不同的sql model 下, sql model 定义了mysql应用支持的sql语法,数据校验等,这样更容易在不同的环境中使用mysql。 sql model 常用来解决下面几类问题: (1) 通过设置sql mode, …

五月28学习笔记

<!DOCTYPE html><html> <head> <meta charset"UTF-8"> <title></title> </head> <body> <!--链接标签--> <!--核心属性就是href 属性值可以是一个跳转的地址--&…

Kinect开发笔记之五使用PowerShell控制Kinect

这是第一次用MarkDown编辑器来写博客&#xff0c;挺喜欢这种没有任何格式舒服的编辑器&#xff0c;自由洒脱更加易读&#xff0c;留一个不自然的自然段纪念下找到舒服的编辑器。 这次要记录使用win7/win8内建的PowerShell来控制Kinect&#xff0c;改变Kinect的俯仰角度。 在我…

可转债数据一览表集思录_可转债股票数据一览表

128107交科转债720612061浙江交科-11.90%25113578全筑转债754030603030全筑股份-1.26%3.84113573纵横转债754602603602纵横通信5.79%2.7113577春秋转债754890603890春秋电子-9.46%2.4123050聚飞转债370303300303聚飞光电2.52%7.05110070凌钢转债733231600231凌钢股份24.44%4.41…

国标流媒体H5实现无插件视频监控按需直播

介绍 按需直播肯定是为了减少带宽流量和服务器性能占用。安防行业GB28181协议天生就是按需播放的&#xff0c;有人请求播放时服务端才从设备端获取设备的直播流或录像视频&#xff0c;停止播放时就会停止获取视频流。同时GB28181协议又是目前安防设备厂商都支持的统一的协议&am…

ipa 安装包不用市场如果扫码下载安装 免费IOS安装API

在做开发过程中可能会用于生成测试包的情况,不过测试包不能直接安装,非常不方便,所以我提供给大家一下可通过链接下载安装的方法也可以把链接生成二维码扫码下载 api地址: https://tool.bitefu.net/ipa/ 文件地址:http://tool.bitefu.net/showdoc/web/#/3 源码下载:http://tado…

Kinect开发笔记之六Kinect Studio的应用

这一次我们来操作一下Kinect Studio&#xff0c;体验一下它给我们带来的功能。 首先我们需要打开Developer Toolkit Browser 1.8&#xff0c;打开后在默认情况下&#xff0c;光标是选择在All选项卡上的&#xff0c;即我们现在所有Developer Toolkit Browser中的部件都可以看得…

antd picker 使用 如何_如何打造 Serverless JavaScript 全栈商业级应用?

2019 年底我们发布过一篇《O’Reilly 1500 份问卷调研&#xff1a;2019 年 Serverless 落地到底香不香&#xff1f;》&#xff0c;揭示了海外 Serverless 的落地情况&#xff0c;但中国 Serverless 的落地实践分享相对较少&#xff0c;似乎谁都在喊 Serverless&#xff0c;谁都…

【Android Studio安装部署系列】十三、Android studio添加和删除Module 2

版权声明&#xff1a;本文为HaiyuKing原创文章&#xff0c;转载请注明出处&#xff01; 概述 新建、导入、删除Module是常见的操作&#xff0c;这里简单介绍下。 新建Module File——New——New Module... 选中Android Library 修改Library名称 在项目工程中修改依赖 和添加下面…

Kinect开发笔记之七Visual Studio结合C#调控Kinect俯仰角度

总感觉自己前面啰啰嗦嗦写了好多&#xff0c;却一直都没有使用用开发kinect的重型武器——Visual Studio。 那么本次我们就借助于Visual Studio&#xff0c;写一个C#程序&#xff0c;连接Kinect并调用Kinect SDK标准函数库来改变Kinect的俯仰角。 首先我们打开VS创建一个项目…

hadoop HDFS常用文件操作命令

命令基本格式: hadoop fs -cmd < args >1.ls hadoop fs -ls /列出hdfs文件系统根目录下的目录和文件 hadoop fs -ls -R /列出hdfs文件系统所有的目录和文件 2.put hadoop fs -put < local file > < hdfs file >hdfs file的父目录一定要存在&#xff0c;否则…

定量库存控制模型_探索全面流动管理TFM 库存控制与低减的理性策略

库存乃万恶之源库存不仅占用了资金&#xff0c;还占用了各种管理性资源&#xff0c;形成了“财务性显性成本“而且过多的库存导致“缓冲区”的存在&#xff0c;还使得各类问题变得不那么紧迫&#xff0c;从而掩盖了各类隐藏的问题&#xff0c;这被称为“隐形成本”零库存不仅做…