【神经网络】python实现神经网络(一)——数据集获取

一.概述

        在文章【机器学习】一个例子带你了解神经网络是什么中,我们大致了解神经网络的正向信息传导、反向传导以及学习过程的大致流程,现在我们正式开始进行代码的实现,首先我们来实现第一步的运算过程模拟讲解:正向传导。本次代码实现将以“手写数字识别”为例子。

二.测试训练数据集的获取

        首先我们需要通过官网获取到手写数字识别数据集,数据集一共分为四个部分,分别是训练集的图片(六万张)、训练集的标签、测试集的图片(一万张)以及测试集的标签。所以我们在代码中可以使用键值表示对应的key-value:

url_base = 'http://yann.lecun.com/exdb/mnist/'
key_file = {'train_img':'train-images-idx3-ubyte.gz','train_label':'train-labels-idx1-ubyte.gz','test_img':'t10k-images-idx3-ubyte.gz','test_label':'t10k-labels-idx1-ubyte.gz'
}

        同时,我们需要将下载的文件保存到与代码同一级目录下:

dataset_dir = os.path.dirname(os.path.abspath(__file__))

        下载部分十分简单么,就不在此赘述,需要注意的是代码使用了python的urlretrieve函数,该函数需要使用头文件urllib.request,需要自行下载:

def download_mnist():for filename in key_file.values():file_path = dataset_dir + "/" + filenameif os.path.exists(file_path):returnprint("Downloading " + filename + " ... ")urllib.request.urlretrieve(url_base + filename, file_path)print("Done")

三.测试训练数据集的加载

        下载完数据集后,我们需要将其加载到我们的程序中以供后续的使用,首先是判断一下我们是否已经下载过数据集,如果没有下载,则先进行下载操作,再执行其他步骤:

    if not os.path.exists(save_file) :download_mnist()dataset = _convert_numpy()print("Creating pickle file ...")with open(save_file, 'wb') as f:pickle.dump(dataset, f, -1)print("Done!")

        以上代码有个需要注意的地方,因为下载完数据集之后无法直接给到python使用,所以还需要对数据进行格式处理,处理成python可以识别的格式,这一步交由函数_convert_numpy实现:

def _convert_numpy():    dataset = {}dataset['train_img'] = _load_img(key_file['train_img'])dataset['train_label'] = _load_label(key_file['train_label'])dataset['test_img'] = _load_img(key_file['test_img'])dataset['test_label'] = _load_label(key_file['test_label'])return dataset

       其中,_load_img函数负责处理图片数据:

def _load_img(file_name):file_path = dataset_dir + "\\MNIST\\" + file_nameprint("Converting " + file_name + " to NumPy Array ...")with gzip.open(file_path, 'rb') as f:data = np.frombuffer(f.read(), np.uint8, offset=16)data = data.reshape(-1, img_size)print("Done")return data

        其中,_load_label函数负责处理标签数据:

def _load_label(file_name):file_path = dataset_dir + "\\MNIST\\" + file_nameprint("Converting " + file_name + " to NumPy Array ...")with gzip.open(file_path, 'rb') as f:labels = np.frombuffer(f.read(), np.uint8, offset=8)print("Done")return labels

        函数中使用到的都是一些python常用的函数,所以具体作用不在赘述,可自行查询。介绍完_convert_numpy函数,我们继续回到数据集加载函数本身,为了方便后续数据集的批量调用等操作,我们需要在加载数据后对其进行进一步的数据清洗整理等预处理,分别为数据归一化(normalize)、图像展开(flatten)以及图像标签对应(one_hot_label),先将三个功能代码贴上,然后我们再详细讲解各个功能的具体作用:

    with open(save_file,'rb') as f:dataset = pickle.load(f)if normalize:for key in ['train_img','test_img']:dataset[key] = dataset[key].astype(np.float32)if not flatten:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].reshape(-1, 1, 28, 28)if one_hot_label:dataset['train_label'] = _change_one_hot_label(dataset['train_label'])dataset['test_label'] = _change_one_hot_label(dataset['test_label'])

3.1.数据归一化(normalize)

        数据归一化normalize如果设置为True,可以将输入图像归一化为0.0~1.0 的值。如果将该参数设置为False,则输入图像的像素会保持原来的0~255。函数实现是使用了python函数中的astype功能将数据,用于将数据集指定字段的数据转换为 float32 类型,常见于深度学习模型输入前的数据预处理。

dataset[key] = dataset[key].astype(np.float32)

3.2.图像展开(flatten)

        图像展开flatten用于设置是否展开输入图像使其变成一维数组。如果将该参数设置为False,则输入图像为1 × 28 × 28 的三维数组;若设置为True,则输入图像会保存为由784 个元素构成的一维数组。函数实现也只是使用到深度学习中常用的reshape函数:

 dataset[key] = dataset[key].reshape(-1, 1, 28, 28)

3.3.图像标签对应(one_hot_label)

        图像标签对应one_hot_label用于设置是否将标签保存为onehot表示(one-hot representation)。one-hot 表示是仅正确解标签为1,其余皆为0 的数组,就像[0,0,1,0,0,0,0,0,0,0]这样。当one_hot_label为False时,就是像7、2这样简单保存正确解标签,函数_change_one_hot_label的实现如下:

def _change_one_hot_label(X):T = np.zeros((X.size, 10))for idx, row in enumerate(T):row[X[idx]] = 1return T

        以上即为测试训练数据集加载函数的全部内容,我们将在下面正式调用一下看看是否能够正常工作,在此贴上函数全文:

ef load_mnist(normalize=True, flatten=True, one_hot_label=False):if not os.path.exists(save_file) :download_mnist()dataset = _convert_numpy()print("Creating pickle file ...")with open(save_file, 'wb') as f:pickle.dump(dataset, f, -1)print("Done!")with open(save_file,'rb') as f:dataset = pickle.load(f)if normalize:for key in ['train_img','test_img']:dataset[key] = dataset[key].astype(np.float32)if not flatten:for key in ('train_img', 'test_img'):dataset[key] = dataset[key].reshape(-1, 1, 28, 28)if one_hot_label:dataset['train_label'] = _change_one_hot_label(dataset['train_label'])dataset['test_label'] = _change_one_hot_label(dataset['test_label'])return (dataset['train_img'],dataset['train_label']),(dataset['test_img'],dataset['test_label'])

四.测试训练数据集的使用测试

        我们可以加载数据集并且查看到各个数据集的形状:

(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True,normalize=False)
# 输出各个数据的形状
print(x_train.shape) # (60000, 784)
print(t_train.shape) # (60000,)
print(x_test.shape) # (10000, 784)
print(t_test.shape) # (10000,)

        根据输出我们可以看到,训练集图片有六万张,每张图片有784各像素(28*28),训练集标签和照片数量一样(那是肯定的),测试集图片和标签数量比训练集的少,主要用来验证模型学习后的正确性。

        我们甚至还能随机从数据集中抽取一张照片查看一下实际样子,具体实现如下:

def img_show(img):
pil_img = Image.fromarray(np.uint8(img))
pil_img.show()
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True,normalize=False)
img = x_train[0]
label = t_train[0]
print(label) # 5
print(img.shape) # (784,)
img = img.reshape(28, 28) # 把图像的形状变成原来的尺寸
print(img.shape) # (28, 28)
img_show(img)

        输出的图片如图下所示:

        在后面的文章中,我们将开始正式步入主题,讲解神经网络如何学习,各层次之间如何传递数值,如何反向传导,计算损失,又在重新学习,最终实现传入一张手写数字就能自动识别出具体的数字的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/72819.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Sentinel 笔记

Sentinel 笔记 1 介绍 Sentinel 是阿里开源的分布式系统流量防卫组件,专注于 流量控制、熔断降级、系统保护。 官网:https://sentinelguard.io/zh-cn/index.html wiki:https://github.com/alibaba/Sentinel/wiki 对比同类产品&#xff1…

manus本地部署方法研究测试

Manus本地部署方法,Manus邀请码实在太难搞了,昨晚看到有一个团队,5个人3个小时,一个完全免费、无需排队等待的OpenManus就做好了。 由于也是新手,找了好几轮,实在是没有找到合适的部署方法,自己…

FreeRTOS第15篇:FreeRTOS链表实现细节03_List_t与ListItem_t的奥秘

文/指尖动听知识库-星愿 文章为付费内容,商业行为,禁止私自转载及抄袭,违者必究!!! 文章专栏:深入FreeRTOS内核:从原理到实战的嵌入式开发指南 1 FreeRTOS列表的核心数据结构 FreeRTOS的列表实现由两个关键结构体组成:List_t(列表)和ListItem_t(列表项)。它们共同…

gzip压缩

什么是Gzip 前端优化:开启Gzip压缩_前端开启gzip压缩-CSDN博客 Gzip是一种文件压缩算法,减少文件大小,节省带宽从而提减少网络传输时间,网站会更快地加载。 如何判断是否开启: 请求头:服务端会通过客户…

机器学习在地图制图学中的应用

原文链接:https://www.tandfonline.com/doi/full/10.1080/15230406.2023.2295948#abstract CSDN/2025/Machine learning in cartography.pdf at main keykeywu2048/CSDN GitHub 核心内容 本文是《制图学与地理信息科学》特刊的扩展评论,系统探讨了机…

智慧消防新篇章:4G液位/压力传感器,筑牢安全防线!

火灾无情,防患未“燃”!在智慧消防时代,如何实现消防水系统的实时监测、预警,保障人民生命财产安全?山东一二三物联网深耕物联网领域,自主研发4G液位、4G压力智能传感器,为智慧消防水位、水压无…

set、LinkedHashSet和TreeSet的区别、Map接口常见方法、Collections 工具类使用

DAY7.2 Java核心基础 想学习Collection、list、ArrayList、Set、HashSet部分的小伙伴可以转到 7.1集合框架、Collection、list、ArrayList、Set、HashSet和LinkedHashSet、判断两个对象是否相等文章查看 set集合 在set集合中,处理LinkedHashSet是有序的&#xf…

windows:curl: (60) schannel: SEC_E_UNTRUSTED_ROOT (0x80090325)

目录 1. git update-git-for-windows 报错2. 解决方案2.1. 更新 CA 证书库2.2. 使用 SSH 连接(推荐)2.3 禁用 SSL 验证(不推荐) 1. git update-git-for-windows 报错 LenovoLAPTOP-EQKBL89E MINGW64 /d/YHProjects/omni-channel-…

《深度剖析架构蒸馏与逻辑蒸馏:探寻知识迁移的差异化路径》

在人工智能模型优化的前沿领域,架构蒸馏与逻辑蒸馏作为知识蒸馏的关键分支,正引领着模型小型化与高效化的变革浪潮。随着深度学习模型规模与复杂度的不断攀升,如何在资源受限的情况下,实现模型性能的最大化,成为了学术…

先序二叉树的线索化,并找指定结点的先序后继

#include<stdio.h> #include<stdlib.h> #define elemType char //线索二叉树结点 typedef struct ThreadNode{ elemType data; struct ThreadNode *lchild,*rchild; int ltag,rtag;//用来判断一个结点是否有线索 }ThreadNode,*ThreadTree; //全局变量…

蚂蚁集团转正实习大模型算法岗内推

1.负责以大模型为代表的A转术能力的建设和优化&#xff0c;打造业界领先的A(技术系统&#xff0c;主要职责包括A系统结构设计、RAG 系统开发、大模型凯练数据构建、大模型能力评测、大模型准理效果和效率优化等 2.紧密跟踪、探索大模型方向前沿技术&#xff0c;依托丰富目体系化…

未授权漏洞大赏

ActiveMQ未授权访问漏洞 漏洞描述 Apache ActiveMQ是美国阿帕奇&#xff08;Apache&#xff09;软件基金会所研发的一套开源的消息中间件&#xff0c;它支持Java消息服务、集群、Spring Framework等。 Apache ActiveMQ管理控制台的默认管理用户名和密码分别为admin和admin&am…

Python包结构与 `__init__.py` 详解

1. 什么是 __init__.py&#xff1f; __init__.py 是Python包的标识文件&#xff0c;它告诉Python解释器这个目录应该被视为一个包&#xff08;Package&#xff09;。这个文件可以为空&#xff0c;也可以包含初始化代码。 1.1 基本作用 包的标识 将普通目录转换为Python包允许…

Web前端开发——HTML基础下

HTML语法 一表格1.基本格式2.美化表格合并居中属性 二表单1.input2.select3.textarea4.button5.date6.color7.checkbox8.radio9.range10.number 一表格 1.基本格式 HTML表格由<table>标签定义 其中行由<tr>标签定义&#xff0c;单元格由<td>定义。我们先来…

小程序事件系统 —— 33 事件传参 - data-*自定义数据

事件传参&#xff1a;在触发事件时&#xff0c;将一些数据作为参数传递给事件处理函数的过程&#xff0c;就是事件传参&#xff1b; 在微信小程序中&#xff0c;我们经常会在组件上添加一些自定义数据&#xff0c;然后在事件处理函数中获取这些自定义数据&#xff0c;从而完成…

安卓设备root检测与隐藏手段

安卓设备root检测与隐藏手段 引言 安卓设备的root权限为用户提供了深度的系统控制能力&#xff0c;但也可能带来安全风险。因此&#xff0c;许多应用&#xff08;如银行软件、游戏和流媒体平台&#xff09;会主动检测设备是否被root&#xff0c;并限制其功能。这种对抗催生了ro…

如何在Ubuntu上直接编译Apache Doris

以下是在 Ubuntu 22.04 上直接编译 Apache Doris 的完整流程&#xff0c;综合多个版本和环境的最佳实践&#xff1a; 注意&#xff1a;Ubuntu的数据盘VMware默认是20G&#xff0c;编译不够用&#xff0c;给到50G以上吧 一、环境准备 1. 安装系统依赖 # 基础构建工具链 apt i…

vuejs相关链接和格式化插件推荐

vue官网&#xff1a; https://cn.vuejs.org/ 配合路由设置&#xff1a; https://router.vuejs.org/zh/guide/ element plus (vue3) | element UI (vue2)&#xff1a; https://element-plus.org/zh-CN/#/zh-CN 构建工具vite&#xff1a; https://cn.vitejs.dev/ 右键选择…

IDEA中Git版本回退终极指南:Reset与Revert双方案详解

目录 前言一、版本回退前置知识二、Reset方案&#xff1a;整体改写历史1、IDEA图形化操作&#xff08;推荐&#xff09;1.1、查看提交历史1.2、选择目标版本1.3、选择回退模式1.3.1、Soft&#xff08;推荐&#xff09;1.3.2、Mixed1.3.3、Hard&#xff08;慎用&#xff09;1.3.…

PHP并发请求优化:使用`curl_multi_select()`实现高效的多请求处理

PHP并发请求优化&#xff1a;使用curl_multi_select()实现高效的多请求处理 背景 最近在项目中遇到一个需求&#xff0c;需要从多个 1 级网站&#xff08;超过 200 个&#xff09;获取数据&#xff0c;并且是通过 POST 请求瞬间发送到这些网站上。开始时我直接使用了 curl_ex…