tf.train.Saver

将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情。tf里面提供模型保存的是tf.train.Saver()模块。

模型保存,先要创建一个Saver对象:如

saver=tf.train.Saver()

在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:

saver=tf.train.Saver(max_to_keep=0)

但是这样做除了多占用硬盘,并没有实际多大的用处,因此不推荐。

当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即

saver=tf.train.Saver(max_to_keep=1)

创建完saver对象后,就可以保存训练好的模型了,如:

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

第一个参数sess,这个就不用说了。第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。

saver.save(sess,'my-model', global_step=0) ==>      filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

看一个mnist实例:

 

# -*- coding:utf-8 -*-

"""

Created on SunJun  4 10:29:48 2017

 

@author:Administrator

"""

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

mnist =input_data.read_data_sets("MNIST_data/", one_hot=False)

 

x =tf.placeholder(tf.float32, [None, 784])

y_=tf.placeholder(tf.int32,[None,])

 

dense1 =tf.layers.dense(inputs=x,

                      units=1024,

                      activation=tf.nn.relu,

                     kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

                      kernel_regularizer=tf.nn.l2_loss)

dense2=tf.layers.dense(inputs=dense1,

                      units=512,

                      activation=tf.nn.relu,

                     kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

                      kernel_regularizer=tf.nn.l2_loss)

logits=tf.layers.dense(inputs=dense2,

                        units=10,

                        activation=None,

                       kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),

                       kernel_regularizer=tf.nn.l2_loss)

 

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)

train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)

correct_prediction= tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)   

acc=tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

 

sess=tf.InteractiveSession() 

sess.run(tf.global_variables_initializer())

 

saver=tf.train.Saver(max_to_keep=1)

for i in range(100):

  batch_xs, batch_ys = mnist.train.next_batch(100)

  sess.run(train_op, feed_dict={x: batch_xs,y_: batch_ys})

  val_loss,val_acc=sess.run([loss,acc],feed_dict={x: mnist.test.images, y_: mnist.test.labels})

  print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

  saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

sess.close()

 

代码中红色部分就是保存模型的代码,虽然我在每训练完一代的时候,都进行了保存,但后一次保存的模型会覆盖前一次的,最终只会保存最后一次。因此我们可以节省时间,将保存代码放到循环之外(仅适用max_to_keep=1,否则还是需要放在循环内).

在实验中,最后一代可能并不是验证精度最高的一代,因此我们并不想默认保存最后一代,而是想保存验证精度最高的一代,则加个中间变量和判断语句就可以了。

 

saver=tf.train.Saver(max_to_keep=1)

max_acc=0

for i in range(100):

  batch_xs, batch_ys =mnist.train.next_batch(100)

  sess.run(train_op, feed_dict={x: batch_xs,y_: batch_ys})

  val_loss,val_acc=sess.run([loss,acc],feed_dict={x: mnist.test.images, y_: mnist.test.labels})

  print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

  ifval_acc>max_acc:

      max_acc=val_acc

     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

sess.close()

 

如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。

 

saver=tf.train.Saver(max_to_keep=3)

max_acc=0

f=open('ckpt/acc.txt','w')

for i in range(100):

  batch_xs, batch_ys =mnist.train.next_batch(100)

  sess.run(train_op, feed_dict={x: batch_xs,y_: batch_ys})

  val_loss,val_acc=sess.run([loss,acc],feed_dict={x: mnist.test.images, y_: mnist.test.labels})

  print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

  f.write(str(i+1)+', val_acc:'+str(val_acc)+'\n')

  if val_acc>max_acc:

      max_acc=val_acc

      saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

f.close()

sess.close()

 

 

模型的恢复用的是restore()函数,它需要两个参数restore(sess,save_path)save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:

model_file=tf.train.latest_checkpoint('ckpt/')

saver.restore(sess,model_file)

则程序后半段代码我们可以改为:

 

sess=tf.InteractiveSession() 

sess.run(tf.global_variables_initializer())

 

is_train=False

saver=tf.train.Saver(max_to_keep=3)

 

#训练阶段

if is_train:

    max_acc=0

    f=open('ckpt/acc.txt','w')

    for i in range(100):

      batch_xs, batch_ys = mnist.train.next_batch(100)

      sess.run(train_op, feed_dict={x:batch_xs, y_: batch_ys})

      val_loss,val_acc=sess.run([loss,acc],feed_dict={x: mnist.test.images, y_: mnist.test.labels})

      print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))

      f.write(str(i+1)+', val_acc:'+str(val_acc)+'\n')

      if val_acc>max_acc:

          max_acc=val_acc

         saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)

    f.close()

 

#验证阶段

else:

    model_file=tf.train.latest_checkpoint('ckpt/')

    saver.restore(sess,model_file)

    val_loss,val_acc=sess.run([loss,acc],feed_dict={x: mnist.test.images, y_: mnist.test.labels})

    print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))

sess.close()

 

标红的地方,就是与保存、恢复模型相关的代码。用一个bool型变量is_train来控制训练和验证两个阶段。

整个源程序:

 

# -*- coding:utf-8 -*-"""Created on SunJun  4 10:29:48 2017@author:Administrator"""import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnist =input_data.read_data_sets("MNIST_data/", one_hot=False)x =tf.placeholder(tf.float32, [None, 784])y_=tf.placeholder(tf.int32,[None,])dense1 =tf.layers.dense(inputs=x,units=1024,activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss)dense2=tf.layers.dense(inputs=dense1,units=512,activation=tf.nn.relu,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss)logits=tf.layers.dense(inputs=dense2,units=10,activation=None,kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),kernel_regularizer=tf.nn.l2_loss)loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)correct_prediction= tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)   acc=tf.reduce_mean(tf.cast(correct_prediction, tf.float32))sess=tf.InteractiveSession() sess.run(tf.global_variables_initializer())is_train=Truesaver=tf.train.Saver(max_to_keep=3)#训练阶段if is_train:max_acc=0f=open('ckpt/acc.txt','w')for i in range(100):batch_xs, batch_ys =mnist.train.next_batch(100)sess.run(train_op, feed_dict={x:batch_xs, y_: batch_ys})val_loss,val_acc=sess.run([loss,acc],feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')if val_acc>max_acc:max_acc=val_accsaver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)f.close()#验证阶段else:model_file=tf.train.latest_checkpoint('ckpt/')saver.restore(sess,model_file)val_loss,val_acc=sess.run([loss,acc],feed_dict={x: mnist.test.images, y_: mnist.test.labels})print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))sess.close()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/350638.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Neo4j:空值如何工作?

我时不时地发现自己想将CSV文件导入Neo4j,而我总是对如何处理可能潜伏在其中的各种空值感到困惑。 让我们从一个没有CSV文件的示例开始。 考虑以下列表,以及我尝试仅返回空值的尝试: WITH [null, "null", "", "Ma…

楼层钢筋验收会议纪要_钢筋施工质量通病防治

一、钢筋原材1、钢筋表面出现黄色浮锈,严重转为红色,日久后变成暗褐色,甚至发生鱼鳞片剥落现象。图片原因保管不良,受到雨雪侵蚀,存放期长,仓库环境潮湿,通风不良。防 治 措 施1、钢筋原料应存放…

simulink代码生成(一)——环境搭建

一、安装C2000的嵌入式环境; 点击matlab附加功能, 然后搜索C2000,安装嵌入式硬件支持包;点击安装即可;(目前还不知道破解版的怎么操作,目前我用的是正版的这样,完全破解的可能操作…

五步法颈椎病自我按摩图解

​​1.揉捏颈、肩、臂 操作:自我按摩时取坐位。拇指张开,其余四指并拢,虎口相对用力,自枕部开始沿颈椎棘突两旁的肌肉向下揉捏,至上背部手能摸到之处为止。反复揉捏3分钟,然后以相同手法揉捏患侧上肢和颈部…

tf.one_hot

tf.one_hot(indices,#输入,这里是一维的depth,# one hotdimension.on_valueNone,#output 默认1off_valueNone,#output 默认0axisNone,#根据我的实验,默认为1dtypeNone,nameNone) 测试程序,一般说,有几类,depth等于分类…

使用get set方法添减属性_头皮银屑病“克星”使用方法,你GET了吗?

相信小伙伴们最近都了解了治疗头皮银屑病需要使用专业剂型。但...方法不对,心血白费。头皮银屑病专用剂型的正确使用方法,你真的知道吗?快来和利奥娜一起,Get√正确的使用方法吧!适合头皮银屑病的专用药剂1.复方制剂卡…

spring hsqldb_在Spring中嵌入HSQLDB服务器实例

spring hsqldb我一直在愉快地使用XAMPP进行开发,直到不得不将其托管在可通过Internet访问的某个地方,供客户端进行测试和使用。 我有一个仅具有384 RAM的VPS,并且需要快速找到方法,因此决定将XAMPP安装到VPS中。 由于内存较低&…

线性回归,logistic回归和一般回归

http://www.cnblogs.com/riskyer/p/3217601.html转载于:https://www.cnblogs.com/focus-z/p/10822757.html

double小数点后最多几位_基金理财买入后,不断亏损,是最多本金亏光,还是会出现负值...

投资基金不会把本金亏光,更不会倒贴钱,基金是一篮子股票,个别股票或许有黑天鹅事件,不可能全部同时出现,持续亏损的话,可能最后面临清盘,有结算流程,会将剩下的份额折算到投资者的账…

tf.layers.flatten

flatten( inputs, nameNone) 参数说明如下: inputs:必需,即输入数据。name:可选,默认为 None,即该层的名称。

JUnit 5 –参数化测试

JUnit 5令人印象深刻,尤其是当您深入研究扩展模型和体系结构时 。 但是从表面上讲,编写测试的地方,开发的过程比革命的过程更具进化性 – JUnit 4上没有杀手级功能吗? 幸运的是,至少有一个:参数化测试。 JU…

RMQ问题-ST表倍增处理静态区间最值

简介 ST表是利用倍增思想处理RMQ问题(区间最值问题)的一种工具。 它能够做到O(nlogn)预处理,O(1)查询的时间复杂度,效率相当不错。 算法 1.预处理 ST表利用倍增的思想。以洛谷的P3865作为例子。我们需要查询某一区间的最大值。 我…

tf.layers.dropout

dropout 是指在深度学习网络的训练过程中,对于神经网络单元,按照一定的概率将其暂时从网络中丢弃, 可以用来防止过拟合,layers 模块中提供了 tf.layers.dropout() 方法来实现这一操作,定义在 tensorflow/python/layers…

nginx是干嘛用的_nginx小技巧 -非root身份运行nginx

简直罪过,写这篇文章完全是一场毫无意义的口水仗引起的,我这人就这样,喜欢拿事实说话,而不是一句话说的让人摸不着头脑!下载源码文件:wget http://nginx.org/download/nginx-1.16.1.tar.gz解压:…

pooling池化

pooling,即池化,layers 模块提供了多个池化方法,这几个池化方法都是类似的,包括 tf.layers.max_pooling1d()、tf.layers.max_pooling2d()、tf.layers.max_pooling3d()、tf.layers.average_pooling1d()、tf.layers.average_pooling…

mysql 数据迁移_MySQL海量数据迁移

数据库迁移本主前一段时间写毕业设计主要使用MySQL,紧锣密鼓的开发了将近一个多月,项目数据层、接口层、数据采集层已经开发完成,算法还在不断的优化提速,由于请了几位大佬帮我做Code Review,所以不得已购买了一个阿里…

数组中重复的数字

解决问题思路1. 代码实现: package j2;import java.util.Arrays;/*** Created by admin on 2019/5/8.*/ public class FindDuplicate {public static void duplicate(int[] numbers,int length,int[]duplication){//边界条件的判断if (numbers null || length0) {r…

eclipse中junit_在Eclipse中有效使用JUnit

eclipse中junit最近,我被卷入了讨论1和一些受感染的同伴2,他们关于我们如何在Eclipse IDE中使用JUnit 。 令人惊讶的是,对话带来了并非所有人都知道的一些“技巧”。 这使我有了写这篇文章的想法,总结了我们的演讲。 谁知道–也许…

BFS迷宫问题模型(具体模拟过程见《啊哈算法》)

题目描述与DFS模型走迷宫那篇一样。小哈被困在迷宫里,小哼解救小哈。 这里用BFS来写。BFS(广搜)与DFS(深搜)的区别就在于,DFS是“不撞南墙不回头”,一条路走到不能再走之后才会回到起始点&#…

Spring Batch可重启性

首先,我要非常感谢Spring的优秀人员,他们花了无数时间来确保Spring Batch作业的可行性,以及发出重新启动作业的神奇能力! 感谢您提供的这个优雅的工具集,它使我们能够浏览大量数据集,同时使我们在跌倒时能够…