python3+TensorFlow 2.x(三)手写数字识别

目录

代码实现

模型解析:

1、加载 MNIST 数据集:

2、数据预处理:

3、构建神经网络模型:

4、编译模型:

5、训练模型:

6、评估模型:

7、预测和可视化结果:

输出结果:

总结:


代码实现

TensorFlow 2.x 实现手写数字识别(MNIST 数据集)。MNIST 数据集包含了 28x28 像素的手写数字图像,任务是将这些图像分类为 10 个类别(0-9) 

import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt# 1. 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()# 2. 数据预处理:归一化和改变形状
train_images = train_images / 255.0  # 将图像像素值归一化到 [0, 1]
test_images = test_images / 255.0# 调整形状,使得每张图片的维度是 [28, 28, 1],因为模型需要3D输入
train_images = train_images.reshape((train_images.shape[0], 28, 28, 1))
test_images = test_images.reshape((test_images.shape[0], 28, 28, 1))# 3. 构建神经网络模型
model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.Flatten(),layers.Dense(64, activation='relu'),layers.Dense(10, activation='softmax')  # 10类分类问题
])# 4. 编译模型:选择优化器、损失函数和评价指标
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',  # 因为标签是整数,所以使用 sparse_categorical_crossentropymetrics=['accuracy'])# 5. 训练模型
history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))# 6. 评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"Test accuracy: {test_acc}")# 7. 可视化训练过程中的损失和准确率变化
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()# 8. 使用模型进行预测
predictions = model.predict(test_images)# 显示一些预测结果
for i in range(5):plt.imshow(test_images[i].reshape(28, 28), cmap='gray')plt.title(f"Predicted Label: {predictions[i].argmax()}, Actual Label: {test_labels[i]}")plt.show()

模型解析:

1、加载 MNIST 数据集:

使用 tf.keras.datasets.mnist.load_data() 函数来加载 MNIST 数据集。返回的数据包括训练集和测试集。训练集有 60,000 张图像,测试集有 10,000 张图像。

2、数据预处理:

将图像的像素值从 [0, 255] 归一化到 [0, 1],使每个像素的值在 0 到 1 之间,提升模型的训练效果。将每张图像的形状调整为 (28, 28, 1),即每个图像是 28x28 的灰度图像。

3、构建神经网络模型:

使用卷积神经网络(CNN)构建模型:Conv2D 层用于提取图像的特征,使用了 ReLU 激活函数。MaxPooling2D 层用于下采样,减少计算量。Flatten 层将卷积层的输出展平,进入全连接层。Dense 层用于输出分类结果,其中最后一层使用了 softmax 激活函数,将模型的输出转换为 10 类的概率分布。

4、编译模型:

使用 adam 优化器,sparse_categorical_crossentropy 作为损失函数(适用于类别标签是整数的情况),并使用 accuracy 作为评价指标。

5、训练模型:

使用 model.fit 训练模型,设置了 5 个 epoch,使用训练集进行训练,并验证模型在测试集上的表现。

6、评估模型:

使用 model.evaluate 在测试集上评估模型的准确性。并可视化训练过程中的损失和准确率变化:使用 matplotlib 绘制训练过程中的损失和准确率变化曲线,查看模型的学习进度。

7、预测和可视化结果

使用训练好的模型对测试集进行预测,展示一些预测结果,并与真实标签进行对比。

输出结果

训练和验证准确率:随着训练的进行,准确率应该逐渐提高。
测试准确率:训练完成后,模型在测试集上的准确率会显示出来,通常可以达到 98% 以上。
预测图像:展示一些手写数字图像,标注预测的标签和实际标签。

预测可视化展示

总结:

该模型使用了卷积层、池化层以及全连接层,在 MNIST 数据集上训练,最终达到了很好的分类效果。你可以调整模型的超参数(例如卷积层的数量、神经元的数量等)以提高性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/69339.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

《深度揭秘:TPU张量计算架构如何重塑深度学习运算》

在深度学习领域,计算性能始终是推动技术发展的关键因素。从传统CPU到GPU,再到如今大放异彩的TPU(张量处理单元),每一次硬件架构的革新都为深度学习带来了质的飞跃。今天,就让我们深入探讨TPU的张量计算架构…

Queries Acceleration -Tuning- Tuning Execution 学习笔记

1 Adjustment of RuntimeFilter Wait Time 1.1 Case: Too Short RuntimeFilter Wait Time 1.1.1 没有看懂,好像是等待时间过小也会导致性能下降 1.1.2 set runtime_filter_wait_time_ms = 3000; 2 Data Skew Handling 2.1 Case 1: Bucket Data Skew Leading to Suboptimal …

React应用深度优化与调试实战指南

一、渲染性能优化进阶 1.1 精细化渲染控制 typescript 复制 // components/HeavyComponent.tsx import React, { memo, useMemo } from react;interface Item {id: string;complexData: {// 复杂嵌套结构}; }const HeavyComponent memo(({ items }: { items: Item[] }) &g…

Python3 OS模块中的文件/目录方法说明十三

一. 简介 前面文章简单学习了 Python3 中 OS模块中的文件/目录的部分函数。 本文继续来学习 OS 模块中文件、目录的操作方法:os.rmdir() 方法、os.stat() 方法。 二. Python3 OS模块中的文件/目录方法说明十三 1. os.rmdir() 方法 os.rmdir() 方法用于删除指定路…

SFTP 使用方法

SFTP(SSH File Transfer Protocol)是一种安全的文件传输协议,通过 SSH(Secure Shell)提供加密的文件传输服务。SFTP 比传统的 FTP 更安全,因为它使用加密来保护传输的数据。 1. 连接到远程主机 首先&#…

Ubuntu 顶部状态栏 配置,gnu扩展程序

顶部状态栏 默认没有配置、隐藏的地方 安装使用Hide Top Bar 或Just Perfection等进行配置 1 安装 sudo apt install gnome-shell-extension-manager2 打开 安装的“扩展管理器” 3. 对顶部状态栏进行配置 使用Hide Top Bar 智能隐藏,或者使用Just Perfection 直…

【信息系统项目管理师-选择真题】2011上半年综合知识答案和详解

更多内容请见: 备考信息系统项目管理师-专栏介绍和目录 文章目录 【第1题】【第2题】【第3题】【第4题】【第5题】【第6题】【第7题】【第8题】【第9题】【第10题】【第11题】【第12题】【第13题】【第14题】【第15题】【第16题】【第17题】【第18题】【第19题】【第20题】【第…

spark运行流程

spark运行流程 任务提交后,先启动 Driver 程序随后 Driver 向集群管理器注册应用程序集群管理器根据此任务的配置文件分配 Executor 并启动Driver 开始执行 main 函数,Spark 查询为懒执行,当执行到 Action 算子时开始反向推 算,根…

Formality:时序变换(二)(不可读寄存器移除)

相关阅读 Formalityhttps://blog.csdn.net/weixin_45791458/category_12841971.html?spm1001.2014.3001.5482 一、引言 时序变换在Design Compiler的首次综合和增量综合中都可能发生,它们包括:时钟门控(Clock Gating)、寄存器合并(Register Merging)、…

QGIS3.34绿色版更新

我打包的QGIS3.34在实际工作中方便了很多初次接触GIS的朋友,感到十分欣慰!但由于初次推出也发现了一些问题,本次对该版本进行了一个更新! 还是秉承咱一贯理念,方便您使用也方便您不用!该工具还是被打包为绿…

参数是模型学会的东西,预训练是让它学习的东西

参数 就是模型“学会的东西”。这些参数是模型在训练过程中通过调整其权重来存储的知识。它们代表了模型如何处理输入数据、做出决策和生成输出。每个参数都是模型用来预测和理解语言的一部分。 预训练 就是让模型“学习的过程”。预训练阶段,模型通过大量的文本数…

寒假1.26

题解 web:[极客大挑战 2019]Havefun 打开是一个猫猫的图片 查看源代码 就是一个简单的get传参,直接在url后面输入catdog即可 有点水,再来一题 [极客大挑战 2019]LoveSQL 熟悉的界面,不熟悉的注入 尝试上次的方法,注…

Python GUI 开发 | Qt Designer — 工具介绍

关注这个框架的其他相关笔记:Python GUI 开发 | PySide6 & PyQt6 学习手册-CSDN博客 Qt Designer 即 Qt 设计师,是一个强大、灵活的可视化 GUI 设计工具,可以帮助用户加快开发 PySide6 程序的速度。 Qt Designer 是专门用来制作 PySide6…

【第九天】零基础入门刷题Python-算法篇-数据结构与算法的介绍-六种常见的图论算法(持续更新)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、Python数据结构与算法的详细介绍1.Python中的常用的图论算法2. 图论算法3.详细的图论算法1)深度优先搜索(DFS)2&#xf…

基于回归分析法的光伏发电系统最大功率计算simulink建模与仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 5.完整工程文件 1.课题概述 基于回归分析法的光伏发电系统最大功率计算simulink建模与仿真。选择回归法进行最大功率点的追踪,使用光强和温度作为影响因素,电压作为输出进行建模。…

使用Ollama部署deepseek大模型

Ollama 是一个用于部署和管理大模型的工具,而 DeepSeek 是一个特定的大模型。以下是如何使用 Ollama 部署 DeepSeek 大模型的步骤: 1. 安装 Ollama 首先,你需要在你的系统上安装 Ollama。你可以通过以下命令来安装: # 假设你已…

嵌入式蓝桥杯电子赛嵌入式(第14届国赛真题)总结

打开systic 生成工程编译查看是否有问题同时打开对应需要的文档 修改名称的要求 5.简单浏览赛题 选择题,跟单片机有关的可以查相关手册 答题顺序 先从显示开始看 1,2 所以先打开PA1的定时器这次选TIM2 从模式、TI2FP2二通道、内部时钟、1通道设为直接2通道设置…

SuperAGI - 构建、管理和运行 AI Agent

文章目录 一、关于 SuperAGI💡特点🛠 工具包 二、⚙️安装☁️SuperAGI云🖥️本地🌀 Digital Ocean 三、架构1、SuperAGI 架构2、代理架构3、代理工作流架构4、Tools 架构5、ER图 一、关于 SuperAGI SuperAGI 一个开发优先的开源…

FLTK - FLTK1.4.1 - demo - adjuster.exe

文章目录 FLTK - FLTK1.4.1 - demo - adjuster.exe概述笔记根据代码,用fluid重建一个adjuster.fl 备注 - fluid生成的代码作为参考代码好了修改后可用的代码END FLTK - FLTK1.4.1 - demo - adjuster.exe 概述 想过一遍 FLTK1.4.1的demo和测试工程,工程…

缓存策略通用分布式缓存解决方案

Cache Aside(旁路缓存)策略 Cache Aside(旁路缓存)策略是一种在应用程序中协调缓存与数据库交互的常用策略,是使用最多的策略。 基本原理 读操作:应用程序首先尝试从缓存中读取数据,如果缓存…