深度学习(1)-简单神经网络示例

我们来看一个神经网络的具体实例:使用Python的Keras库来学习手写数字分类。在这个例子中,我们要解决的问题是,将手写数字的灰度图像(28像素×28像素)划分到10个类别中(从0到9)​。我们将使用MNIST数据集,图2-1给出了MNIST数据集的一些样本。
在这里插入图片描述
在机器学习中,分类问题中的某个类别叫作类(class)​,数据点叫作样本(sample)​,与某个样本对应的类叫作标签(label)​。你不需要现在就尝试在计算机上运行这个例子。如果你想这么做,那么首先需要建立深度学习工作区(见第3章)​。MNIST数据集已预先加载在Keras库中,其中包含4个NumPy数组,如代码清单2-1所示。

from tensorflow.keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images和train_labels组成了训练集,模型将从这些数据中进行学习。然后,我们在测试集(包括test_images和test_labels)上对模型进行测试。图像被编码为NumPy数组,而标签是一个数字数组,取值范围是0~9。图像和标签一一对应。我们来看一下训练数据:

>>> train_images.shape
(60000, 28, 28)
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

再来看一下测试数据:

>>> test_images.shape
(10000, 28, 28)
>>> len(test_labels)
10000
>>> test_labels
array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)

工作流程如下:首先,将训练数据(train_images和train_labels)输入神经网络;然后,神经网络学习将图像和标签关联在一起;最后,神经网络对test_images进行预测,我们来验证这些预测与test_labels中的标签是否匹配。下面我们来构建神经网络,如代码清单2-2所示。

代码清单2-2 神经网络架构

from tensorflow import keras
from tensorflow.keras import layers
model = keras.Sequential([layers.Dense(512, activation="relu"),layers.Dense(10, activation="softmax")
])

**神经网络的核心组件是层(layer)**​。你可以将层看成数据过滤器:进去一些数据,出来的数据变得更加有用。具体来说,层从输入数据中提取表示——我们期望这种表示有助于解决手头的问题。大多数深度学习工作涉及将简单的层链接起来,从而实现渐进式的数据蒸馏(data distillation)​。深度学习模型就像是处理数据的筛子,包含一系列越来越精细的数据过滤器(也就是层)​。**本例中的模型包含2个Dense层,它们都是密集连接(也叫全连接)的神经层。第2层(也是最后一层)是一个10路softmax分类层,它将返回一个由10个概率值(总和为1)组成的数组。**每个概率值表示当前数字图像属于10个数字类别中某一个的概率。在训练模型之前,我们还需要指定编译(compilation)步骤的3个参数。优化器(optimizer)​:模型基于训练数据来自我更新的机制,其目的是提高模型性能。损失函数(loss function)​:模型如何衡量在训练数据上的性能,从而引导自己朝着正确的方向前进。在训练和测试过程中需要监控的指标(metric)​:本例只关心精度(accuracy)​,即正确分类的图像所占比例。

后面两章会详细介绍损失函数和优化器的确切用途。代码清单2-3展示了编译步骤。

代码清单2-3 编译步骤

model.compile(optimizer="rmsprop",loss="sparse_categorical_crossentropy",metrics=["accuracy"])

在开始训练之前,我们先对数据进行预处理,将其变换为模型要求的形状,并缩放到所有值都在[0, 1]区间。前面提到过,训练图像保存在一个uint8类型的数组中,其形状为(60000, 28, 28),取值区间为[0, 255]。我们将把它变换为一个float32数组,其形状为(60000, 28 * 28),取值范围是[0, 1]。下面准备图像数据,如代码清单2-4所示。

代码清单2-4 准备图像数据

train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype("float32") / 255
test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype("float32") / 255

现在我们准备开始训练模型。在Keras中,这一步是通过调用模型的fit方法来完成的——我们在训练数据上拟合(fit)模型,如代码清单2-5所示。
代码清单2-5 拟合模型

>>> model.fit(train_images, train_labels, epochs=5, batch_size=128)
Epoch 1/5
60000/60000 [===========================] - 5s - loss: 0.2524 - acc: 0.9273
Epoch 2/5
51328/60000 [=====================>.....] - ETA: 1s - loss: 0.1035 - acc: 0.9692

训练过程中显示了两个数字:一个是模型在训练数据上的损失值(loss)​,另一个是模型在训练数据上的精度(acc)​。我们很快就在训练数据上达到了0.989(98.9%)的精度。现在我们得到了一个训练好的模型,可以利用它来预测新数字图像的类别概率(见代码清单2-6)​。这些新数字图像不属于训练数据,比如可以是测试集中的数据。

代码清单2-6 利用模型进行预测

>>> test_digits = test_images[0:10]
>>> predictions = model.predict(test_digits)
>>> predictions[0]
array([1.0726176e-10, 1.6918376e-10, 6.1314843e-08, 8.4106023e-06,2.9967067e-11, 3.0331331e-09, 8.3651971e-14, 9.9999106e-01,2.6657624e-08, 3.8127661e-07], dtype=float32)

这个数组中每个索引为i的数字对应数字图像test_digits[0]属于类别i的概率。第一个测试数字在索引为7时的概率最大(0.99999106,几乎等于1)​,所以根据我们的模型,这个数字一定是7。

>>> predictions[0].argmax()
7
>>> predictions[0][7]
0.99999106

我们可以检查测试标签是否与之一致:

>>> test_labels[0]
7

平均而言,我们的模型对这种前所未见的数字图像进行分类的效果如何?我们来计算在整个测试集上的平均精度,如代码清单2-7所示。

代码清单2-7 在新数据上评估模型

>>> test_loss, test_acc = model.evaluate(test_images, test_labels)
>>> print(f"test_acc: {test_acc}")
test_acc: 0.9785

测试精度约为97.8%,比训练精度(98.9%)低不少。训练精度和测试精度之间的这种差距是过拟合(overfit)造成的。**过拟合是指机器学习模型在新数据上的性能往往比在训练数据上要差,**它是第4章的核心主题。第一个例子到这里就结束了。你刚刚看到了如何用不到15行Python代码构建和训练一个神经网络,对手写数字进行分类。在本章和第3章中,我们会详细了解这个例子中的每一个步骤及其原理。接下来,你将学到张量(输入模型的数据存储对象)​、张量运算(层的组成要素)与梯度下降(可以让模型从训练示例中进行学习)​。

需要记住的名词:
1.类
2.样本
3.标签
4.训练集
5.测试集
6.层(layer)
7.dense
8.softmax
9.损失函数
10.指标
11.过拟合

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/71237.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入探索 C++17 中的 std::hypot:从二维到三维的欧几里得距离计算

文章目录 1. std::hypot 的起源与背景2. 三维空间中的 std::hypot3. 为什么需要 std::hypot 而不是手动计算?4. 使用 std::hypot 的示例4.1 二维空间中的应用4.2 三维空间中的应用4.3 处理浮点数溢出问题 5. std::hypot 的性能与精度6. 实际应用场景6.1 计算机图形学…

面基Spring Boot项目中实用注解一

在Spring Boot项目中,实用注解根据功能可以分为多个类别。以下是常见的注解分类、示例说明及对比分析: 1. 核心配置注解 SpringBootApplication 作用:标记主启动类,组合了Configuration、EnableAutoConfiguration和ComponentScan…

【每日论文】Latent Radiance Fields with 3D-aware 2D Representations

下载论文或阅读原文,请点击:每日论文 摘要 中文 潜在3D重建技术在赋予3D语义理解和3D生成能力方面展现出巨大的潜力,它通过将2D特征提炼到3D空间来实现。然而,现有的方法在2D特征空间和3D表示之间的领域差距问题上挣扎&#xff…

CPP集群聊天服务器开发实践(七):Github上传项目

github链接:GitHub - arduino-ctrl/ClusterServer: 基于jsonmuduomysqlnginxredis的集群服务器与客户端通信源码 步骤如下: 1. github新建代码仓库,复制url 2. git clone https://github.com/arduino-ctrl/ClusterServer.git 3. 将项目文件…

作业。。。。。

顺序表按元素删除 参数:删除元素,顺序表 1.调用元素查找的函数 4.根据下表删除 delete_sub(list,sub); //删除元素 void delete_element(int element, Sqlist *list) …

二、从0开始卷出一个新项目之瑞萨RZT2M双核架构通信和工程构建

一、概述 RZT2M双核架构是同构多核,但双核针对不同应用 扩展多核架构和通信知识可参见嵌入式科普(30)一文看懂嵌入式MCU/MPU多核架构与通信 二、参考资料 用户手册:RZ/T2M Group Users Manual: Hardware R52内核手册:arm_cortex_r52_proc…

【HF设计模式】07-适配器模式 外观模式

声明:仅为个人学习总结,还请批判性查看,如有不同观点,欢迎交流。 摘要 《Head First设计模式》第7章笔记:结合示例应用和代码,介绍适配器模式和外观模式,包括遇到的问题、采用的解决方案、遵循…

RDMA 高性能通信技术原理

目录 文章目录 目录DMA 与 RDMARDMA 特性和优势大带宽低延时 RDMA 协议栈标准RDMA 运行原理通信通路通信模型通信方式内存注册QP 建链常规流程双向控制 Send-Receive API 流程单向数据 Write API 流程单向数据 Read API 流程 RDMA Verbs API 编程基础网络连通性RDMA C/S 程序 D…

HCIA项目实践(网络)---NAT地址转化技术

十三 NAT网络地址转换技术 13.1 什么是NAT NAT(Network Address Translation)地址转换技术,是一种将内部网络的私有 IP 地址转换为外部网络的公有 IP 地址的技术。其主要作用是实现多个内部网络设备通过一个公有 IP 地址访问外部网络&#x…

【JAVA工程师从0开始学AI】,第四步:闭包与高阶函数——用Python的“魔法函数“重构Java思维

副标题:当严谨的Java遇上"七十二变"的Python函数式编程 历经变量战争、语法迷雾、函数对决,此刻我们将踏入Python最迷人的领域——函数式编程。当Java工程师还在用接口和匿名类实现回调时,Python的闭包已化身"智能机器人"…

el-tree选中数据重组成树

vueelement-ui 实现el-tree选择重新生成一个已选中的值组成新的数据树&#xff0c;效果如下 <template><div class"flex"><el-tree class"tree-row" :data"list" ref"tree" :props"{children: children, label: …

测试常见问题汇总-检查表(持续完善)

WEB页面常见的问题 按钮功能的实现&#xff1a;返回按钮是否可以正常返回 信息保存提交后&#xff0c;系统是否给出“成功”的提示信息&#xff0c;列表数据是否自动刷新 没有勾选任何记录直接点【删除】&#xff0c;是否给出“请先选择记录”的提示 删除是否有删除确认框 …

java后端开发day16--字符串(二)

&#xff08;以下内容全部来自上述课程&#xff09; 1.StringBuilder 因为StringBuilder是Java已经写好的类。 java在底层对他进行了一些特殊处理。 打印对象不是地址值而是属性值。 1.概述 StringBuilder可以看成是一个容器&#xff0c;创建之后里面的内容是可变的。 作用…

C++效率掌握之STL库:vector函数全解

文章目录 1.为什么要学习vector&#xff1f;什么是vector&#xff1f;2.vector类对象的常见构造3.vector类对象的容量操作4.vector类对象的迭代器5.vector类对象的元素修改6.vector类对象的元素访问7.vector迭代器失效问题希望读者们多多三连支持小编会继续更新你们的鼓励就是我…

人工智障的软件开发-容器化编码环境就绪-java-env

指令接收&#xff1a;「需要万能开发环境」 系统警报&#xff1a;检测到主人即将陷入"环境配置地狱" 启动救赎协议&#xff1a;构建量子化开发容器 终极目标&#xff1a;让"在我机器上能跑"成为历史文物 需求分析&#xff1a;碳基生物的先天缺陷 人类开发…

kkFileView二开之pdf转图片接口

kkFileView二开之Pdf转图片接口 kkFileView二开系列文章&#xff1a;1 kkFileView源码下载及编译2 Pdf转图片接口2.1 背景2.2 分析2.2 接口开发2.2.1 编写Pdf转图片方法2.2.2 编写转换接口 2.3 接口测试2.3.1 Pdf文件准备2.3.2 pdf2Image 3 部署 kkFileView二开系列文章&#x…

阅读论文笔记《Efficient Estimation of Word Representations in Vector Space》

这篇文章写于2013年&#xff0c;对理解 word2vec 的发展历程挺有帮助。 本文仅适用于 Word2Vect 的复盘 引言 这篇论文致力于探索从海量数据中学习高质量单词向量的技术。当时已发现词向量能保留语义特征&#xff0c;例如 “国王 - 男人 女人≈女王”。论文打算借助该特性&am…

SQL注入(SQL Injection)详解与实战

文章目录 一、什么是SQL注入&#xff1f;二、常见SQL注入类型三、手动注入步骤&#xff08;以CTF题目为例&#xff09;四、CTF实战技巧五、自动化工具&#xff1a;SQLMap六、防御措施七、CTF例题八、资源推荐 一、什么是SQL注入&#xff1f; SQL注入是一种通过用户输入构造恶意…

维护ceph集群

1. set: 设置标志位 # ceph osd set <flag_name> # ceph osd set noout # ceph osd set nodown # ceph osd set norecover 2. unset: 清除标志位 # ceph osd unset <flag_name> # ceph osd unset noout # ceph osd unset nodown # ceph osd unset norecover 3. 标志…

学习threejs,使用PointLight点光源

&#x1f468;‍⚕️ 主页&#xff1a; gis分享者 &#x1f468;‍⚕️ 感谢各位大佬 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍⚕️ 收录于专栏&#xff1a;threejs gis工程师 文章目录 一、&#x1f340;前言1.1 ☘️THREE.PointLight 二、&…