TensorFlow 2.0 mnist手写数字识别(CNN卷积神经网络)

TensorFlow 2.0 (五) - mnist手写数字识别(CNN卷积神经网络)

源代码/数据集已上传到 Github - tensorflow-tutorial-samples

卷积神经网络gif动图

大白话讲解卷积神经网络工作原理,推荐一个bilibili的讲卷积神经网络的视频,up主从youtube搬运过来,用中文讲了一遍。

这篇文章是 TensorFlow 2.0 Tutorial 入门教程的第五篇文章,介绍如何使用卷积神经网络(Convolutional Neural Network, CNN)来提高mnist手写数字识别的准确性。之前使用了最简单的784x10的神经网络,达到了 0.91 的正确性,而这篇文章在使用了卷积神经网络后,正确性达到了0.99

卷积神经网络(Convolutional Neural Network, CNN)是一种前馈神经网络,它的人工神经元可以响应一部分覆盖范围内的周围单元,对于大型图像处理有出色表现。

卷积神经网络由一个或多个卷积层和顶端的全连通层(对应经典的神经网络)组成,同时也包括关联权重和池化层(pooling layer)。这一结构使得卷积神经网络能够利用输入数据的二维结构。与其他深度学习结构相比,卷积神经网络在图像和语音识别方面能够给出更好的结果。这一模型也可以使用反向传播算法进行训练。相比较其他深度、前馈神经网络,卷积神经网络需要考量的参数更少,使之成为一种颇具吸引力的深度学习结构。

——维基百科

1. 安装TensorFlow 2.0

Google与2019年3月发布了TensorFlow 2.0,TensorFlow 2.0 清理了废弃的API,通过减少重复来简化API,并且通过Keras能够轻松地构建模型,从这篇文章开始,教程示例采用TensorFlow 2.0版本。

1
pip install tensorflow==2.0.0-beta0

或者在这里下载whl包安装:https://pypi.tuna.tsinghua.edu.cn/simple/tensorflow/

2. 代码目录结构

1
2
3
4
5
6
7
8
9
10
11
12
13
data_set_tf2/  # TensorFlow 2.0的mnist数据集|--mnist.npz  
test_images/   # 预测所用的图片|--0.png|--1.png|--4.png
v4_cnn/|--ckpt/   # 模型保存的位置|--checkpoint|--cp-0005.ckpt.data-00000-of-00001|--cp-0005.ckpt.index|--predict.py  # 预测代码|--train.py    # 训练代码

3. CNN模型代码(train.py)

模型定义的前半部分主要使用Keras.layers提供的Conv2D(卷积)与MaxPooling2D(池化)函数。

CNN的输入是维度为 (image_height, image_width, color_channels)的张量,mnist数据集是黑白的,因此只有一个color_channel(颜色通道),一般的彩色图片有3个(R,G,B),熟悉Web前端的同学可能知道,有些图片有4个通道(R,G,B,A),A代表透明度。对于mnist数据集,输入的张量维度就是(28,28,1),通过参数input_shape传给网络的第一层。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import os
import tensorflow as tf
from tensorflow.keras import datasets, layers, modelsclass CNN(object):def __init__(self):model = models.Sequential()# 第1层卷积,卷积核大小为3*3,32个,28*28为待训练图片的大小model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))model.add(layers.MaxPooling2D((2, 2)))# 第2层卷积,卷积核大小为3*3,64个model.add(layers.Conv2D(64, (3, 3), activation='relu'))model.add(layers.MaxPooling2D((2, 2)))# 第3层卷积,卷积核大小为3*3,64个model.add(layers.Conv2D(64, (3, 3), activation='relu'))model.add(layers.Flatten())model.add(layers.Dense(64, activation='relu'))model.add(layers.Dense(10, activation='softmax'))model.summary()self.model = model

model.summary()用来打印我们定义的模型的结构。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 26, 26, 32)        320       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 11, 11, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 3, 3, 64)          36928     
_________________________________________________________________
flatten (Flatten)            (None, 576)               0         
_________________________________________________________________
dense (Dense)                (None, 64)                36928     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650       
=================================================================
Total params: 93,322
Trainable params: 93,322
Non-trainable params: 0
_________________________________________________________________

我们可以看到,每一个Conv2DMaxPooling2D层的输出都是一个三维的张量(height, width, channels)。height和width会逐渐地变小。输出的channel的个数,是由第一个参数(例如,32或64)控制的,随着height和width的变小,channel可以变大(从算力的角度)。

模型的后半部分,是定义输出张量的。layers.Flatten会将三维的张量转为一维的向量。展开前张量的维度是(3, 3, 64) ,转为一维(576)的向量后,紧接着使用layers.Dense层,构造了2层全连接层,逐步地将一维向量的位数从576变为64,再变为10。

后半部分相当于是构建了一个隐藏层为64,输入层为576,输出层为10的普通的神经网络。最后一层的激活函数是softmax,10位恰好可以表达0-9十个数字。

最大值的下标即可代表对应的数字,使用numpy很容易计算出来:

1
2
3
4
5
6
import numpy as npy1 = [0, 0.8, 0.1, 0.1, 0, 0, 0, 0, 0, 0]
y2 = [0, 0.1, 0.1, 0.1, 0.5, 0, 0.2, 0, 0, 0]
np.argmax(y1) # 1
np.argmax(y2) # 4

4. mnist数据集预处理(train.py)

1
2
3
4
5
6
7
8
9
10
11
12
13
class DataSource(object):def __init__(self):# mnist数据集存储的位置,如何不存在将自动下载data_path = os.path.abspath(os.path.dirname(__file__)) + '/../data_set_tf2/mnist.npz'(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data(path=data_path)# 6万张训练图片,1万张测试图片train_images = train_images.reshape((60000, 28, 28, 1))test_images = test_images.reshape((10000, 28, 28, 1))# 像素值映射到 0 - 1 之间train_images, test_images = train_images / 255.0, test_images / 255.0self.train_images, self.train_labels = train_images, train_labelsself.test_images, self.test_labels = test_images, test_labels

因为mnist数据集国内下载不稳定,因此数据集也同步到了Github仓库。

对mnist数据集的介绍,大家可以参考这个系列的第一篇文章TensorFlow入门(一) - mnist手写数字识别(网络搭建)。

5. 开始训练并保存训练结果(train.py)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
class Train:def __init__(self):self.cnn = CNN()self.data = DataSource()def train(self):check_path = './ckpt/cp-{epoch:04d}.ckpt'# period 每隔5epoch保存一次save_model_cb = tf.keras.callbacks.ModelCheckpoint(check_path, save_weights_only=True, verbose=1, period=5)self.cnn.model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])self.cnn.model.fit(self.data.train_images, self.data.train_labels, epochs=5, callbacks=[save_model_cb])test_loss, test_acc = self.cnn.model.evaluate(self.data.test_images, self.data.test_labels)print("准确率: %.4f,共测试了%d张图片 " % (test_acc, len(self.data.test_labels)))if __name__ == "__main__":app = Train()app.train()

在执行python train.py后,会得到以下的结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 45s 749us/sample - loss: 0.1477 - accuracy: 0.9536
Epoch 2/5
60000/60000 [==============================] - 45s 746us/sample - loss: 0.0461 - accuracy: 0.9860
Epoch 3/5
60000/60000 [==============================] - 50s 828us/sample - loss: 0.0336 - accuracy: 0.9893
Epoch 4/5
60000/60000 [==============================] - 50s 828us/sample - loss: 0.0257 - accuracy: 0.9919
Epoch 5/5
59968/60000 [============================>.] - ETA: 0s - loss: 0.0210 - accuracy: 0.9930
Epoch 00005: saving model to ./ckpt/cp-0005.ckpt
60000/60000 [==============================] - 51s 848us/sample - loss: 0.0210 - accuracy: 0.9930
10000/10000 [==============================] - 3s 290us/sample - loss: 0.0331 - accuracy: 0.9901
准确率: 0.9901,共测试了10000张图片

可以看到,在第一轮训练后,识别准确率达到了0.9536,5轮之后,使用测试集验证,准确率达到了0.9901

在第五轮时,模型参数成功保存在了./ckpt/cp-0005.ckpt。接下来我们就可以加载保存的模型参数,恢复整个卷积神经网络,进行真实图片的预测了。

6. 图片预测(predict.py)

为了将模型的训练和加载分开,预测的代码写在了predict.py中。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import tensorflow as tf
from PIL import Image
import numpy as npfrom train import CNN'''
python 3.7
tensorflow 2.0.0b0
pillow(PIL) 4.3.0
'''class Predict(object):def __init__(self):latest = tf.train.latest_checkpoint('./ckpt')self.cnn = CNN()# 恢复网络权重self.cnn.model.load_weights(latest)def predict(self, image_path):# 以黑白方式读取图片img = Image.open(image_path).convert('L')img = np.reshape(img, (28, 28, 1)) / 255.x = np.array([1 - img])# API refer: https://keras.io/models/model/y = self.cnn.model.predict(x)# 因为x只传入了一张图片,取y[0]即可# np.argmax()取得最大值的下标,即代表的数字print(image_path)print(y[0])print('        -> Predict digit', np.argmax(y[0]))if __name__ == "__main__":app = Predict()app.predict('../test_images/0.png')app.predict('../test_images/1.png')app.predict('../test_images/4.png')

最终,执行predict.py,可以看到:

1
2
3
4
5
6
7
8
9
10
$ python predict.py
../test_images/0.png
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]-> Predict digit 0
../test_images/1.png
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]-> Predict digit 1
../test_images/4.png
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]-> Predict digit 4

与TensorFlow1.0的区别总结

  1. 数据集从tensorflow.examples.tutorials.mnist切换到了tensorflow.keras.datasets
  2. Keras的接口成为了主力,datasets, layers, models都是从Keras引入的,而且在网络的搭建上,代码更少,更为简洁。

附: 推荐

  • 一篇文章入门 Python

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/547430.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机组成原理第06章在线测试,计算机组成原理第六章单元测试(二)(含答案)(4页)-原创力文档...

PAGEPAGE 1第六章单元测验 (二)书生1、用以指定待执行指令所在主存地址的寄存器是______。(单选)????A、指令寄存器IR????B、程序计数器PC????C、存储器地址寄存器MAR????D、数据缓冲寄存器2、下列关于微程序和微指令的叙述中______是正确的。(单选)????A、…

前端笔试练习一

前端笔试练习一 请编写一段程序&#xff0c;将一个对象和它直接、间接引用的所有对象的属性字符串放入一个数组。如var o {a:1,{b:2,c:{d:1}}}这里o经过处理后&#xff0c;应该得到["a","b","c","d"] 1 <!DOCTYPE html PUBLIC &qu…

职业梦想是计算机的英语作文,理想职业英语作文2篇

篇一&#xff1a;大学英语作文之我理想的工作my ideal jobMy Ideal JobAs college students, we will step into the society, and now we need to prepare for our future and arrange for our future career life, we need to take into consideration what to do in the fut…

C语言二维数组中的指针问题

#include "stdio.h" void main() {int a[5][5];int i,j;for (i0;i<5;i){for (j0;j<5;j){a[i][j] i;}} for (i0;i<5;i){for (j0;j<5;j){printf("%d ",a[i][j]);}printf("\n");} }转载于:https://blog.51cto.com/shamrock/12…

爬取微信小程序源码

爬取微信小程序源码 想知道爬取微信小程序有多简单吗&#xff1f;一张图、三个步骤&#xff0c;拿到你想要的任何微信小程序源码。

C#对称加密(AES加密)每次生成的密文结果不同思路代码分享

思路&#xff1a;使用随机向量&#xff0c;把随机向量放入密文中&#xff0c;每次解密时从密文中截取前16位&#xff0c;其实就是我们之前加密的随机向量。 代码 public static string Encrypt(string plainText, string AESKey){RijndaelManaged rijndaelCipher new Rijndael…

计算机音乐简谱图片,1(音乐简谱基本音级)_百度百科

1是指在音乐简谱中的音乐简谱基本音级。[1]1代表音阶中的1个基本音级&#xff0c;读音为Do(谐音汉字“哆”)&#xff0c;在C大调里唱Do。常用来表示音级第一位或首位。中文名哆外文名do术语范围音 高C大调里的Do英 文One在音乐简谱中&#xff0c;1代表音阶中的1个基本音级…

马老师 生产环境mysql主从复制、架构优化方案

Binlog日志(主服务器) > 中继日志(从服务器 运行一遍,保持一致)。从服务器是否要二进制日志取决于架构设计。如果二进制保存足够稳定&#xff0c;从性能上来说&#xff0c;从服务器不需要二进制日志。默认情况下&#xff0c;mysql主从复制是异步的。 异步&#xff1a;命令写…

10分钟带你学会微信小程序的反编译

以xxxxx小程序为例10分钟带你学会微信小程序的反编译 2019-11-28 12:59:26 以一个简单的例子介绍下小程序反编译操作流程 实验环境前置准备模拟器内软件安装获取小程序包开始解包导入开发者工具补充注意事项技术交流群有偿解包uniapp 逆向服务逆向教程小程序分包教程#实验环境…

中html倒入css那么套路,CSS常用套路

a标签去除原颜色(改为白色)和下划线text-decoration: none;color: #ffffff;列表标签去除默认小点:list-style-type:none;设置元素透明度&#xff1a;opacity:0.5;页面中文字无法被选中&#xff1a;user-select: none;鼠标悬停&#xff0c;样式变化的方法&#xff1a;a:hover {o…

try catch finally的执行顺序到底是怎样的?

首先执行try&#xff0c;如果有异常执行catch&#xff0c;无论如何都会执行finally一个函数中肯定会执行finally中的部分。 关于一个函数的执行过程是&#xff0c;当有return以后&#xff0c;函数就会把这个数据存储在某个位置&#xff0c;然后告诉主函数&#xff0c;我不执行了…

反编译Android APK详细操作指南

早在4年前我曾发表过一篇关于《Android开发之反编译与防止反编译》的文章&#xff0c;在该文章中我对如何在Windows平台反编译APK做了讲解&#xff0c;如今用Mac系统的同学越来越多&#xff0c;也有很多朋友问我能否出一篇关于如何在Mac平台上反编译APK的文章&#xff0c;今天呢…

Ext.grid.CheckboxSelectionModel状态设置

直接上代码&#xff1a; var model grid.getSelectionModel();model.selectAll();//选择所有行model.selectFirstRow();//选择第一行model.selectLastRow([flag]);//选择最后一行,flag为正的话保持当前已经选中的行数,不填则默认falsemodel.selectNext();//选择下一行model.se…

MySql PreparedStatement用法 及 Transaction处理

import java.sql.Connection; import java.sql.DriverManager; import java.sql.PreparedStatement; import java.sql.SQLException;public class TestJDBC {/*** 当银行转账时&#xff0c;需要在汇款人账户上扣除汇款金额&#xff0c;同时在收款人账户上存入汇款金额&#xff…

计算机硬件维修所需技能实习报告,计算机软硬件及网络维护技能实习报告.doc...

计算机软硬件及网络维护技能实习报告计软络维护主板硬盘内存CPU&#xff0c;光驱显卡声卡网卡主板&#xff0c;又叫主机板、系统板和母板&#xff1b;它安装在机箱内&#xff0c;是微机最基本的也是最重要的部件之一。光驱就是播放光盘的,一下安装程序,游戏程序的都是放在光盘里…

用idea新建springboot项目遇到的@Restcontroller不能导入的问题

我个人的解决方法如下&#xff1a; 1.springboot默认有 <dependencies><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter</artifactId></dependency><dependency><groupId>o…

Ext1.X的CheckboxSelectionModel默认全选之后不允许编辑的BUG解决方案

Ext1.X的CheckboxSelectionModel默认全选之后不允许编辑的BUG解决方案&#xff0c;ext 的CheckboxSelectionModel在后台默认选中之后&#xff0c;前台就不允许编辑的bug是存在的&#xff0c;因为CheckboxSelectionModel没有Disabled"true"的设置&#xff0c;只能想办…

广州海珠区计算机学校,2019广州海珠区电脑派位和对口直升表

点击即可领取期末各科试卷预约课程还可获赠免费的学习复习诊断— — 学而思爱智康课程优势 — —12年本地化教研沉淀个性化学习方式专属教学服务优质的教学系统2019广州海珠区电脑派位和对口直升表&#xff0c;各位的爸爸妈妈们看过来&#xff01;&#xff01;看看目标学校都招…

android的消息队列机制

android下的线程&#xff0c;Looper线程&#xff0c;MessageQueue&#xff0c;Handler&#xff0c;Message等之间的关系&#xff0c;以及Message的send/post及Message dispatch的过程。 Looper线程 我们知道&#xff0c;线程是进程中某个单一顺序的控制流&#xff0c;它是内核…

KNN算法检测手势动作

KNN算法原理&#xff1a; KNN&#xff08;k-nearest neighbor&#xff09;是一个简单而经典的机器学习分类算法&#xff0c;通过度量”待分类数据”和”类别已知的样本”的距离&#xff08;通常是欧氏距离&#xff09;对样本进行分类。 这话说得有些绕口&#xff0c;且来分解…