第七章 手写数字识别V4

news/2025/9/27 16:19:26/文章来源:https://www.cnblogs.com/lidadudu/p/19115244
# 优化:
# 增加父类Module,输出每层信息
# 增加ReLU类,Tanh类
# 增加Dropout类,随机失活,防止过拟合,提高泛化能力
# 增加Parameter类,保存权重和梯度# 导入必要的库
import numpy as np
import os
import struct# 定义导入函数
def load_images(path):with open(path, "rb") as f:data = f.read()magic_number, num_items, rows, cols = struct.unpack(">iiii", data[:16])return np.asanyarray(bytearray(data[16:]), dtype=np.uint8).reshape(num_items, 28, 28)def load_labels(file):with open(file, "rb") as f:data = f.read()return np.asanyarray(bytearray(data[8:]), dtype=np.int32)# 定义sigmoid函数
def sigmoid(x):result = np.zeros_like(x)positive_mask = x >= 0result[positive_mask] = 1 / (1 + np.exp(-x[positive_mask]))negative_mask = x < 0exp_x = np.exp(x[negative_mask])result[negative_mask] = exp_x / (1 + exp_x)return result# 定义softmax函数
def softmax(x):max_x = np.max(x, axis=-1, keepdims=True)x = x - max_xex = np.exp(x)sum_ex = np.sum(ex, axis=1, keepdims=True)result = ex / sum_exresult = np.clip(result, 1e-10, 1e10)return result# 定义独热编码函数
def make_onehot(labels, class_num):result = np.zeros((labels.shape[0], class_num))for idx, cls in enumerate(labels):result[idx, cls] = 1return result# 定义dataset类
class Dataset:def __init__(self, all_images, all_labels):self.all_images = all_imagesself.all_labels = all_labelsdef __getitem__(self, index):image = self.all_images[index]label = self.all_labels[index]return image, labeldef __len__(self):return len(self.all_images)# 定义dataloader类
class DataLoader:def __init__(self, dataset, batch_size, shuffle=True):self.dataset = datasetself.batch_size = batch_sizeself.shuffle = shuffleself.idx = np.arange(len(self.dataset))def __iter__(self):# 如果需要打乱,则在每个 epoch 开始时重新排列索引if self.shuffle:np.random.shuffle(self.idx)self.cursor = 0return selfdef __next__(self):if self.cursor >= len(self.dataset):raise StopIteration# 使用索引来获取数据batch_idx = self.idx[self.cursor : min(self.cursor + self.batch_size, len(self.dataset))]batch_images = self.dataset.all_images[batch_idx]batch_labels = self.dataset.all_labels[batch_idx]self.cursor += self.batch_sizereturn batch_images, batch_labels# 定义Module类
class Module:  # 父类Moduledef __init__(self):self.info = "Module:\n"def __repr__(self):return self.info# 定义Parameters类
class Parameters:def __init__(self, weight):self.weight = weightself.grad = np.zeros_like(weight)# 定义linear类
class Linear(Module):def __init__(self, in_features, out_features):super().__init__()  # 调用父类初始化方法self.info += f"**    Linear({in_features}, {out_features})"  # 打印信息self.W = Parameters(np.random.normal(0, 1, size=(in_features, out_features)))self.B = Parameters(np.random.normal(0, 1, size=(1, out_features)))def forward(self, x):self.x = xreturn np.dot(x, self.W.weight) + self.B.weightdef backward(self, G):self.W.grad = np.dot(self.x.T, G)self.B.grad = np.mean(G, axis=0, keepdims=True)self.W.weight -= lr * self.W.gradself.B.weight -= lr * self.B.gradreturn np.dot(G, self.W.weight.T)# 定义Sigmoid类
class Sigmoid(Module):def __init__(self):super().__init__()self.info += "**    Sigmoid()"  # 打印信息return self.infodef forward(self, x):self.result = sigmoid(x)return self.resultdef backward(self, G):return G * self.result * (1 - self.result)# 定义Tanh类
class Tanh(Module):def __init__(self):super().__init__()self.info += "**    Tanh()"  # 打印信息def forward(self, x):self.result = 2 * sigmoid(2 * x) - 1return self.resultdef backward(self, G):return G * (1 - self.result**2)# 定义Softmax类
class Softmax(Module):def __init__(self):super().__init__()self.info += "**    Softmax()"  # 打印信息def forward(self, x):self.p = softmax(x)return self.pdef backward(self, G):G = (self.p - G) / len(G)return G# 定义ReLU类
class ReLU(Module):def __init__(self):super().__init__()self.info += "**    ReLU()"  # 打印信息def forward(self, x):self.x = xreturn np.maximum(0, x)def backward(self, G):grad = G.copy()grad[self.x <= 0] = 0return grad# 定义Dropout类
class Dropout(Module):def __init__(self, p=0.3):super().__init__()self.info += f"**    Dropout(p={p})"  # 打印信息self.p = pdef forward(self, x):r = np.random.rand(*x.shape)  # 矩阵r与x的shape相同,值在0-1之间随机生成self.nagtive = r < self.px[self.nagtive] = 0return xdef backward(self, G):G[self.nagtive] = 0return G# 定义ModelList类
class ModelList:def __init__(self, layers):self.layers = layersdef forward(self, x):for layer in self.layers:x = layer.forward(x)return xdef backward(self, G):for layer in self.layers[::-1]:G = layer.backward(G)def __repr__(self):info = ""for layer in self.layers:info += layer.info + "\n"return info# 主函数
if __name__ == "__main__":# 加载训练集图片、标签train_images = (load_images(os.path.join("Python", "NLP basic", "data", "minist", "train-images.idx3-ubyte"))/ 255)train_labels = make_onehot(load_labels(os.path.join("Python", "NLP basic", "data", "minist", "train-labels.idx1-ubyte")),10,)# 加载测试集图片、标签dev_images = (load_images(os.path.join("Python", "NLP basic", "data", "minist", "t10k-images.idx3-ubyte"))/ 255)dev_labels = load_labels(os.path.join("Python", "NLP basic", "data", "minist", "t10k-labels.idx1-ubyte"))# 设置超参数epochs = 10lr = 0.08  # V2版本调整了学习率batch_size = 200# 展开图片数据train_images = train_images.reshape(60000, 784)dev_images = dev_images.reshape(-1, 784)# 调用dataset类和dataloader类train_dataset = Dataset(train_images, train_labels)train_dataloader = DataLoader(train_dataset, batch_size)dev_dataset = Dataset(dev_images, dev_labels)dev_dataloader = DataLoader(dev_dataset, batch_size)# 定义模型model = ModelList([Linear(784, 512),ReLU(),Dropout(0.2),Linear(512, 256),Tanh(),Dropout(0.1),Linear(256, 10),Softmax(),])print(model)# 训练集训练过程for e in range(epochs):for x, l in train_dataloader:# 前向传播x = model.forward(x)# 计算损失loss = -np.mean(l * np.log(x))# 反向传播G = model.backward(l)# 验证集验证并输出预测准确率right_num = 0for x, batch_labels in dev_dataloader:x = model.forward(x)pre_idx = np.argmax(x, axis=-1)  # 预测类别right_num += np.sum(pre_idx == batch_labels)  # 统计正确个数acc = right_num / len(dev_images)  # 计算准确率print(f"Epoch {e}, Acc: {acc:.4f}")

image

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/919656.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么?你的蓝牙用不了了?

什么?你的蓝牙用不了了?如果你的电脑蓝牙出现一下问题:蓝牙图标不见? 搜索不到任何设备? 只能搜到手机不能搜到耳机? 看看本篇文章给你的解决办法把!蓝牙图标不见 暂未解决。 搜索不到任何设备 首先,同时按下 …

个人可以做电影网站吗信用徐州网站建设情况

黑马程序员上海中心学姐微信&#xff1a;CZBKSH关注咳咳&#xff0c;今天学姐就来和你们说说Spring对于Java程序员的重要性。首先&#xff0c;Spring 官网首页是这么介绍自己的——“Spring: the source for modern Java”&#xff0c;这也意味着 Spring 与 Java 有着密切的关系…

做韦恩图的在线网站wordpress下载视频

简单选择排序的介绍&#xff1a;从给定的序列中&#xff0c;按照指定的规则选出某一个元素&#xff0c;再根据规定交换位置后达到有序的目的。简单选择排序的基本思想&#xff1a;假定我们的数组为int [] arr new int[n]&#xff0c;第一次我们从arr[0]~arr[n-1]中选择出最小的…

2025/9/27

2025/9/271.完成课后任务:验证码任务 2.完成课后任务:生成三十道四则运算题

30.Linux DHCP 服务器 - 详解

30.Linux DHCP 服务器 - 详解pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco"…

C# Smart3D Plate Part零件形状提取

public class ExportPartShape : BaseModalCommand{public override void OnStart(int instanceId, object argument){base.OnStart(instanceId, argument);var symFile = @"C:\Program Files (x86)\Smart3D\Comm…

威海建设局网站首页图片编辑在线

在FTP协议中&#xff0c;可以通过配置服务器端的空闲连接超时时间来设置连接的过期时间。具体步骤如下&#xff1a; 登录FTP服务器&#xff0c;进入服务器的配置文件目录。通常配置文件位于/etc或/etc/vsftpd目录下。打开FTP服务器的配置文件&#xff0c;例如vsftpd.conf。在配…

网站使用微软雅黑小程序模板免费下载

C语言实验lab10C程序设计实验报告学院&#xff1a;国际商学院班级&#xff1a;14电商专业&#xff1a;电子商务姓名&#xff1a;熊靓男日期&#xff1a;15.5.25学号&#xff1a;1420070049实验目的复习一维数组掌握二维数组参数的传递掌握排序算法实验内容消灭怪物在阳光明媚月…

化妆品网站系统规划网站制作排名

import java.util.*;public class Solution {/*** 代码中的类名、方法名、参数名已经指定&#xff0c;请勿修改&#xff0c;直接返回方法规定的值即可** * param s string字符串 * param n int整型 * return string字符串*/public String trans (String s, int n) {// write co…

长春网站建设公司会展设计效果图

作者前言 欢迎小可爱们前来借鉴我的gtiee秦老大大 (qin-laoda) - Gitee.com —————————————————————————————— 目录 查询数据 条件 逻辑运算符 模糊查询 范围查询 in 判断空 UNION 排序 聚合 分组&#xff1a;group by —————————…

网站开发工程师的证件seo技术专员招聘

本文简单记录一次实践使用过程&#xff0c;涉及presto-mysql,presto-elasticsearch&#xff0c;文中参数未做注释&#xff0c;请参考官方文档&#xff0c;希望能帮到大家1 下载安装 presto-0.228<1>下载服务端客户端相关jar<2>安装&#xff1a;1> 解压tar -zxvf…

路飞和女帝做h的网站女装网站建设计划书

正则表达式是一个特殊的字符序列&#xff0c;它能帮助你方便的检查一个字符串是否与某种模式匹配。re 模块使 Python 语言拥有全部的正则表达式功能。compile 函数根据一个模式字符串和可选的标志参数生成一个正则表达式对象。该对象拥有一系列方法用于正则表达式匹配和替换。r…

潍坊市住房和城乡建设厅网站如何自己制作链接内容

第一步&#xff1a;安装svg-sprite-loader插件 <!-- svg-sprite-loader svg雪碧图 转换工具 --> <!-- <symbol> 元素中的 path 就是绘制图标的路径&#xff0c;这种一大串的东西我们肯定没办法手动的去处理&#xff0c; 那么就需要用到插件 svg-sprite-loader …

用户体验好的网站wordpress用户修改头像

在处理多个 Python 库依赖时&#xff0c;遇到依赖冲突是很常见的&#xff0c;特别是当项目依赖的库版本相互不兼容时。要解决这些冲突&#xff0c;可以采用以下方式。 1. 虚拟环境的使用 为了避免系统级和用户级包的冲突&#xff0c;建议你使用 虚拟环境。虚拟环境为每个项目…

题解:QOJ9619/洛谷13568 [CCPC 2024 重庆站] 乘积,欧拉函数,求和(数论+状压DP)

首先将 \(\phi(x)\) 拆成 \(\phi(x)= x \prod_{p | x} \frac {p-1}{p}\),发现我们要求的式子其实可以转化为 \(\sum_{S} (\prod a_i)\prod_{p|\prod a_i} \frac {p-1}{p}\)。 发现其实我们只关心哪些质数 \(p\) 在最终…

Momentum Gradient Descent(动量梯度下降)

Momentum Gradient Descent(动量梯度下降)是标准梯度下降(SGD)的一个重要改进版,旨在加速训练过程,并帮助模型更有效地找到最优解。 你可以将动量(Momentum)想象成物理学中的惯性。动量梯度下降(Momentum GD)…

Halcon算子——2D几何变换

齐次坐标 介绍仿射变换前,先介绍什么是齐次坐标。对于一个平面像素点,我们可以通过坐标(x,y)描述其位置。但是当涉及平移时,如果仅仅使用它对应的坐标向量[x,y],就必须通过向量加法来描述其位移。 而齐次坐标的引入…

深入解析:深度解析 CUDA-QX 0.4 加速 QEC 与求解器库

深入解析:深度解析 CUDA-QX 0.4 加速 QEC 与求解器库pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas&qu…

网站建设360网站建设优化汕头

Plupload有以下功能和特点&#xff1a; 1、拥有多种上传方式&#xff1a;HTML5、flash、silverlight以及传统的<input type”file” />。Plupload会自动侦测当前的环境&#xff0c;选择最合适的上传方式&#xff0c;并且会优先使用HTML5的方式。所以你完全不用去操心当前…

电子商务网站开发意义深圳商业网站建设案例

“工作三年&#xff0c;并不等于拥有三年的工作经验。”这句话告诉我们每天都要思考自己当天所遇到的问题&#xff0c;记录下来&#xff0c;并且思考这个问题的解决办法&#xff0c;每一周或两周总结这些问题和解决办法&#xff0c;归纳思考问题根源&#xff0c;学习解决问题的…