6-6 卷积神经网络LeNet

news/2025/9/21 23:39:36/文章来源:https://www.cnblogs.com/xxb667/p/19104384

1.LeNet

import torch
from torch import nn
from d2l import torch as d2lnet = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16*5*5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10)
)
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape: \t', X.shape)
Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])

2.模型训练

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)
def evaluate_accuracy_gpu(net, data_iter, device=None):'''使用GPU计算模型在数据集上的精度'''if isinstance(net, nn.Module):net.eval() # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需的X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):'''用GPU训练模型'''def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():# metric.add(1*X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])metric.add(l.sum().item(), d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_1 = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i+1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch+(i+1) / num_batches,(train_1, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch+1, (None, None, test_acc))print(f'loss {train_1:.3f}, train acc {train_acc:.3f},'f'test acc {test_acc:.3f}')print(f'{metric[2]*num_epochs / timer.sum():.1f} examples/sec'f'on {str(device)}')
lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.002, train acc 0.827,test acc 0.803
9020.8 examples/secon cpu

image


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/909127.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

5-5读写文件

本章主要介绍将训练后的数据保存到文件中1.加载和保存张量 import torch from torch import nn from torch.nn import functional as Fx = torch.arange(4) # 把 Python 对象 x 打包成字节流,原封不动地写进文件 x-fi…

6-2图像卷积

本章主要介绍二维卷积和图像卷积的计算1.二维卷积计算 import torch from torch import nn from d2l import torch as d2l定义二维卷积函数 def corr2d(X, K):计算二维互相关运算h, w = K.shapeY = torch.zeros((X.sha…

二叉树的高度和判断平衡二叉树

LCR 176. 判断是否为平衡二叉树 利用递归得出结果,平衡二叉树成立的条件:左子树和右子树之差的绝对值小于等于 1,也就是当左子树高度 - 右子树高度的差值等于 0或者等于1的时候该平衡二叉树成立。 那么我们可以利用…

20250921 之所思 - 人生如梦

20250921 之所思一大早就收到老板要求每天晚上十点开会的信息,顿时心情很糟,因为晚上十点开会,开完就已经接近十二点,很害怕自己会彻夜难眠,然后起床就一直在想这件事,顿时整个早上的心情都受到了影响。还为这个…

UE5 Cook数据结构

UE Cook 数据结构 本篇讲非 MPCook 的数据结构1. FPackageDatas核心类 FPackageDatas 是管理 CookOnTheFlyServer 里的所有 PackageData 的列表的类。PackageDatas 是一个关联数组,存储 COTFS 需要的 package(如 coo…

通过微信对客服系统客户进行消息提醒,比如客户快过期了,访客发来的消息也是通过模板消息通知给客服

vx: llike620我的客服系统已经通过自己开发的形式实现了对接 客户服务到期提醒​​和​​客服消息通知​​——正是模板消息功能的典型和优秀应用案例。 作为开发者,您肯定关心如何将现有系统做得更健壮、更高效。以…

WPF治具软件模板分享 - Dragonet

目录WPF治具软件模板分享程序功能介绍功能实现导航功能程序配置日志功能界面介绍 WPF治具软件模板分享 运行环境:VS2022 .NET 8.0 完整项目:Gitee仓库 项目重命名方法参考:网页概要:针对治具单机软件制作了一个设…

基于WOA鲸鱼优化的XGBoost序列预测算法matlab仿真

1.算法运行效果图预览 (完整程序运行后无水印)2.算法运行软件版本 matlab2024b3.部分核心程序 (完整版代码包含详细中文注释和操作步骤视频)%最大迭代次数 paramters.maxiter = 50; paramters.tr…

软件工程第二次作业——个人项目

这个作业属于哪个课程 https://edu.cnblogs.com/campus/gdgy/Class12Grade23ComputerScience这个作业要求在哪里 https://edu.cnblogs.com/campus/gdgy/Class12Grade23ComputerScience/homework/13468这个作业的目标 &…

微信扫码二维码,关注绑定公众号提醒,利用微信公众号的模板消息进行消息通知的推送

gofly.v1kf.com vx: llike620我的客服系统已经通过自己开发的形式实现了对接希望通过微信扫码关注公众号,并利用模板消息功能实现消息推送。这是一个非常实用的需求,尤其在服务通知和用户互动方面能极大提升体验。其…

Arch下实现人脸识别登录:howdy的配置与使用

安装Howdy 查阅Arch Linux中文Wiki的教程[1]可知:howdy包已无法在最新的Arch Linux上正常使用,推荐安装howdy-git包那么,就 yay -S howdy-git查看摄像头路径 如果电脑只有一个摄像头的话,一般而言,摄像头的路径是…

fedora无法看视频?编解码器详细安装教程【转发】

fedora无法看视频?编解码器详细安装教程【转发】原文:https://zhuanlan.zhihu.com/p/26494803528 启用rpm fusion 包 free包: sudo dnf install https://download1.rpmfusion.org/free/fedora/rpmfusion-free-relea…

Winform的Formborder.None情况下,解决不能拖动的问题

using System; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; using System.Text; using System.Threading.Tasks; using System.Windows; using System.Windows.Cont…

Salephpscripts Web_Directory_Free SQL注入漏洞利用分析(CVE-2024-3552)

本文详细分析了Salephpscripts Web_Directory_Free插件中存在的SQL注入漏洞(CVE-2024-3552),包含漏洞环境搭建方法和利用步骤,涉及Docker容器部署和本地测试环境配置。Exploit for SQL Injection in Salephpscript…

12306高并发架构设计:基于区间计数器的网关层拒单方案

引言 在上一篇文章《重新理解12306:它卖的从来不是“库存”,而是“状态”》,我们深入探讨了12306的业务模型核心:它不是简单的库存管理系统,而是基于座位段的状态管理。每个座位被拆分为多个段(例如A-B、B-C、C-…

各位同学,大家好!我想请大家回忆一段我们在刘集中学的故事,和我单独联系。我想把这些故事写出来保存。欢迎与我分享!谢谢!

各位同学,大家好!我想请大家回忆一段我们在刘集中学的故事,和我单独联系。我想把这些故事写出来保存。欢迎与我分享!谢谢! 初三时周杰伦的歌曲开始出现,有段时间教室早上会播放他的歌曲,又《双截棍》《霍元甲》…

实用指南:centos sshd:xxx.xxx.xxx.xxx:allow 如何设置

实用指南:centos sshd:xxx.xxx.xxx.xxx:allow 如何设置pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas&…

fedora无法看视频?编解码器详细安装教程

fedora无法看视频?编解码器详细安装教程启用rpm fusion 包 free包: sudo dnf install https://download1.rpmfusion.org/free/fedora/rpmfusion-free-release-$(rpm -E %fedora).noarch.rpm nofree包: sudo dnf ins…

vite7-vue3-os网页os管理|vue3+vite7+arco.design网页pc版webos系统

最新研发Vite7+Vue3+Pinia3+Arco仿macos/windows网页版webos管理系统。 vite7-webos原创基于vite7.1+vue3.5+pinia3+arco-design+echarts从0-1搭建pc网页版os式管理系统模板。支持macos+windows两种桌面布局风格、自定…

高并发高吞吐量

Java实现高并发需从底层机制、并发控制、资源调度、架构设计、编码细节等多维度系统优化,每个维度聚焦特定技术方向,覆盖从底层到应用的全链路性能提升: 一、底层IO与网络优化(提升数据传输效率)IO模型升级网络通…