神经网络字符分类

按照题目要求修改了多层感知机

题目将图片的每个点作为输入,其中大小为28*28,中间有两个大小为100的隐藏层,激活函数是relu,然后输出大小是10,激活函数是softmax

优化器是Adam,结合了AdaGrad和RMSProp算法的优点,为每个参数计算自适应的学习率。

损失函数是交叉熵损失的函数,通常用于分类问题,交叉熵损失函数衡量的是实际输出(probability distribution)与期望输出(true labels)的相似程度,在多分类问题中特别有用。

准确率(Accuracy)指标衡量的是模型预测正确的样本数与总样本数之间的比例。

epochs:训练的轮数5

batch_size:每次训练时使用的样本数量64

---------------------------------------------------------------------------------------------------------------------------------

本实践使用多层感知器训练(DNN)模型,用于预测手写数字图片。

本次实验主要考查以下内容 (1)尝试调整隐藏层单元数量、激活函数、隐藏层数量对于模型性能的影响 激活函数参照https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/Overview_cn.html#activation-functional 或paddle.nn.functional (2)调整不同的训练的迭代轮次(epoch)、学习率、优化器并学会观察训练阶段与测试阶段loss变化,并依据此调整模型 优化器、学习率可参照https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/optimizer/Overview_cn.html (2)补全测试数据集上计算accuracy的过程,可以采用model下的evaluate,也可以利用predict之后的result结果进行计算 模型训练与评估相关API调用举例 https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/Model_cn.html

首先导入必要的包

numpy---------->python第三方库,用于进行科学计算

PIL------------> Python Image Library,python第三方图像处理库

matplotlib----->python的绘图库 pyplot:matplotlib的绘图框架

os------------->提供了丰富的方法来处理文件和目录

#导入需要的包
import numpy as np
import paddle as paddle
import paddle.nn as nn
import paddle.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
import paddle
from paddle.io import Dataset
import os
print("本教程基于Paddle的版本号为:"+paddle.__version__)
! python -m pip install visualdl -i https://mirror.baidu.com/pypi/simple

Step1:准备数据。

(1)数据集介绍

MNIST数据集包含60000个训练集和10000测试数据集。分为图片和标签,图片是28*28的像素矩阵,标签为0~9共10个数字。

(2)transform函数是定义了一个归一化标准化的标准

(3)train_dataset和test_dataset

paddle.vision.datasets.MNIST()中的mode='train'和mode='test'分别用于获取mnist训练集和测试集

transform=transform参数则为归一化标准

#导入数据集Compose的作用是将用于数据集预处理的接口以列表的方式进行组合。
#导入数据集Normalize的作用是图像归一化处理,支持两种方式: 1. 用统一的均值和标准差值对图像的每个通道进行归一化处理; 2. 对每个通道指定不同的均值和标准差值进行归一化处理。
from paddle.vision.transforms import Compose, Normalize
transform = Compose([Normalize(mean=[127.5],std=[127.5],data_format='CHW')])
# 使用transform对数据集做归一化
print('下载并加载训练数据')
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
#print(np.array(test_dataset).shape)
print('加载完成')
#让我们一起看看数据集中的图片是什么样子的
train_data0, train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
print(plt.imshow(train_data0, cmap=plt.cm.binary))
print('train_data0 的标签为: ' + str(train_label_0))
AxesImage(25,22;155x154)
train_data0 的标签为: [5]
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() insteada_max = np.asscalar(a_max.astype(scaled_dtype))
#让我们再来看看数据样子是什么样的吧
print(train_data0)

Step2.网络配置

以下的代码判断就是定义一个简单的多层感知器,一共有三层,两个大小为100的隐层和一个大小为10的输出层,因为MNIST数据集是手写0到9的灰度图像,类别有10个,所以最后的输出大小是10。最后输出层的激活函数是Softmax,所以最后的输出层相当于一个分类器。加上一个输入层的话,多层感知器的结构是:输入层-->>隐层-->>隐层-->>输出层。

# 定义多层感知器  
#动态图定义多层感知器
class mnist(paddle.nn.Layer):def __init__(self):super(mnist,self).__init__()#输入通道784,输出通道100self.conv1=nn.Linear(in_features=784,out_features=100)#输入通道100,输出通道100self.conv2=nn.Linear(in_features=100,out_features=100)#输入通道100,输出通道10self.conv3=nn.Linear(in_features=100,out_features=10)def forward(self, input_):x = paddle.reshape(input_, [input_.shape[0], -1])# print(x.shape)[64, 784]y=F.relu(self.conv1(x))y=F.relu(self.conv2(y))y=F.softmax(self.conv3(y))return y

 


from paddle.metric import Accuracy# 用Model封装模型
model = paddle.Model(mnist())   # 定义损失函数
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())# 配置模型
model.prepare(optim,paddle.nn.CrossEntropyLoss(),Accuracy())

Step3.模型训练及评估

callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')
# 训练保存并验证模型
model.fit(train_dataset,test_dataset,epochs=5,batch_size=64,save_dir='multilayer_perceptron',verbose=1)#模型预测
result = model.predict(test_dataset, batch_size=1)#请补全模型性能验证代码,可使用model下的evaluate函数或者利用上面的预测出来的结果model.evaluate(test_dataset,verbose=1)
test_data0, test_label_0 = test_dataset[0][0],test_dataset[0][1]ress=model.predict_batch(test_data0)test_data0 = test_data0.reshape([28,28])
plt.figure(figsize=(2,2))
#展示测试集中的第一个图片
print(plt.imshow(test_data0, cmap=plt.cm.binary))
print('test_data0 的标签为: ' + str(test_label_0))print('test_data0 预测的数值为:' ,end='')
print(ress)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/26830.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Segmentation fault的原因和例子

最近有用cpp写点东西,然后就碰到Segmentation fault了,调试的时候,ide指出报错的地方看着没问题。后来研究发现,是递归层数太多导致的。 “Segmentation fault”(简称"segfault")是一个常见的计…

机器学习python实践——关于ward聚类分层算法的一些个人心得

最近在利用python跟着参考书进行机器学习相关实践,相关案例用到了ward算法,但是我理论部分用的是周志华老师的《西瓜书》,书上没有写关于ward的相关介绍,所以自己网上查了一堆资料,都很难说清楚ward算法,幸…

Python之Pandas详解

Pandas是Python语言的一个扩展程序库,用于数据分析。 Pandas是一个开放源码、BSD许可的库,提供高性能、易于使用的数据结构和数据分析工具。 Pandas名字衍生自术语 “panel data”(面板数据)和 “Python data analysis”&#x…

AIGC绘画设计:Midjourney V6 来袭,该版本有哪些新功能?

Midjourney V6 支持更自然的语言输入,可以处理更自然地对话式(以前的版本是以关键字为中心的)提示,对复杂提示有了更好的解释能力。大幅增加了每个 /image 的内存,可以处理更长、更详细的提示(从40 直接提升…

Spark 面试题(七)

1. Spark中的Transform和Action,为什么Spark要把操作分为Transform 和Action?常用的列举一些,说下算子原理 ? 在Spark中,操作被分为转换(Transformation)和行动(Action)…

Android framework的Zygote源码分析

文章目录 Android framework的Zygote源码分析linux的fork Android framework的Zygote源码分析 init.rc 在Android系统中,zygote是一个native进程,是Android系统上所有应用进程的父进程,我们系统上app的进程都是由这个zygote分裂出来的。zyg…

Processing java 动态海报 地球日

【Processing java 动态海报 地球日】

12、云服务器上搭建环境

云服务器上搭建环境 12.1 选择一款远程连接工具(mobax) 有很多,比如mobax、xshll等等,我这里选择mobax,下载个免费版的即可 安装完成后,双击打开: 第一步,创建远程连接的用户,用户默认为root,密码为远程服务器的密码 第二步,输入远程公网IP,选择刚刚创建的用…

[C][数据结构][排序][下][快速排序][归并排序]详细讲解

文章目录 1.快速排序1.基本思想2.hoare版本3.挖坑法4.前后指针版本5.非递归版本改写 2.归并排序 1.快速排序 1.基本思想 任取待排序元素序列的某元素作为基准值,按照该排序码将待排序集合分割成两子序列,左子序列中所有元素均小于基准值,右…

技术选型考察哪些方面

在进行技术选型时,需要考虑多个方面,确保所选择的技术能够满足项目的需求,并且在实施过程中具备可行性和可维护性。以下是一些主要考察方面: 1. 业务需求匹配 功能需求:技术能否满足当前及未来的功能需求。性能需求&…

Leetcode.2862 完全子集的最大元素和

题目链接 Leetcode.2862 完全子集的最大元素和 rating : 2292 题目描述 给你一个下标从 1 1 1 开始、由 n n n 个整数组成的数组。你需要从 n u m s nums nums 选择一个 完全集,其中每对元素下标的乘积都是一个 完全平方数,例如选择 a i a_i ai​ 和…

目标检测中的anchor机制

目录 一、目标检测中的anchor机制 1.什么是anchor boxes? 二、什么是Anchor? ​编辑三、为什么需要anchor boxes? 四、anchor boxes是怎么生成的? 五、高宽比(aspect ratio)的确定 六、尺度(scale)的…

工业高温烤箱:现代工业的重要设备

工业高温烤箱,作为现代工业生产中不可或缺的关键设备,以其独特的高温烘烤能力,为各种工业产品的加工与制造提供了强有力的支持。斯博欣将对工业高温烤箱的原理、特点、应用领域及未来发展进行简要介绍。 一、工业高温烤箱的特点 1、高温性能优…

怎么修改Visual Studio Code中现在github账号

git config --global user.name “你的用户名” git config --global user.email “你的邮箱” git config --global --list git push -u origin your_branch_name git remote add origin

编程后端:深入探索其所属的行业领域

编程后端:深入探索其所属的行业领域 在数字化浪潮席卷全球的今天,编程后端作为技术领域的重要分支,其所属的行业领域一直备受关注。本文将从四个方面、五个方面、六个方面和七个方面,深入剖析编程后端所属的行业,并揭…

FastAPI 作为H5中流式输出的后端

FastAPI 作为H5中流式输出的后端 最近大家都在玩LLM,我也凑了热闹,简单实现了一个本地LLM应用,分享给大家,百分百可以用哦~^ - ^ 先介绍下我使用的三种工具: Ollama:一个免费的开源框架&…

2024年护网行动全国各地面试题汇总(1)作者:————LJS

目录 1. SQL注入原理 2. SQL注入分类: 3. SQL注入防御: 4. SQL注入判断注入点的思路: 5. 报错注入的函数有哪些: 6. SQL注入漏洞有哪些利用手法: 1. 文件上传漏洞的绕过方法有以下几种: 2. 文件上传时突破前…

centos7 xtrabackup mysql 基本测试(4)---虚拟机环境 mysql 修改datadir(有问题)

centos7 xtrabackup mysql 基本测试(4)—虚拟机环境 mysql 修改datadir 参考 centos更改mysql数据库目录 https://blog.csdn.net/sinat_33151213/article/details/125079593 https://blog.csdn.net/jx_ZhangZhaoxuan/article/details/129139499 创建目…

锌,能否成为下一个“铜”?

光大期货认为,今年以来,市场关注锌能否接棒铜价牛市。铜需求增长空间大,而锌消费结构传统,缺乏新亮点。虽然在供应的扰动上锌强于铜,但因需求乏善可陈,金融属性弱势,锌很难接棒铜,引…