GoogleNet的简易实现

这里使用GooleNet对MNIST手写数据集进行分类,最后的效果达到了在测试集98%的准确率。这里关于该网络的细节可以在网络上搜索到,相关原理也可以搜索到,这里仅展示网络的代码实现,这里是基于pytorch实现的,详细的代码如下:

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optimbatch_size = 64
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))
])
train_dataset = datasets.MNIST(root=r"C:\Users\pszpszpsz\Desktop\dataset\mnist\MNIST\raw",train=True,download=True,transform=transform)
train_loader = DataLoader(train_dataset,shuffle=True,batch_size=batch_size)
test_dataset = datasets.MNIST(root=r"C:\Users\pszpszpsz\Desktop\dataset\mnist\MNIST\raw",train=False,download=True,transform=transform)
test_loader = DataLoader(test_dataset,shuffle=False,batch_size=batch_size)class InceptionA(torch.nn.Module):def __init__(self,in_channels):super(InceptionA,self).__init__()self.branch1x1 = torch.nn.Conv2d(in_channels,16,kernel_size=1)self.branch5x5_1 = torch.nn.Conv2d(in_channels,16,kernel_size=1)self.branch5x5_2 = torch.nn.Conv2d(16, 24, kernel_size=5,padding=2)self.branch3x3_1 = torch.nn.Conv2d(in_channels, 16, kernel_size=1)self.branch3x3_2 = torch.nn.Conv2d(16, 24, kernel_size=3,padding=1)self.branch3x3_3 = torch.nn.Conv2d(24, 24, kernel_size=3, padding=1)self.branch_pool = torch.nn.Conv2d(in_channels, 24, kernel_size=1)def forward(self,x):branch1x1 = self.branch1x1(x)branch5x5 = self.branch5x5_1(x)branch5x5 = self.branch5x5_2(branch5x5)branch3x3 = self.branch3x3_1(x)branch3x3 = self.branch3x3_2(branch3x3)branch3x3 = self.branch3x3_3(branch3x3)branch_pool = F.avg_pool2d(x,kernel_size=3,stride=1,padding=1)branch_pool = self.branch_pool(branch_pool)outputs = [branch1x1,branch5x5,branch3x3,branch_pool]return torch.cat(outputs,dim=1)class Net(torch.nn.Module):def __init__(self):super(Net,self).__init__()self.conv1 = torch.nn.Conv2d(1,10,kernel_size=5)self.conv2 = torch.nn.Conv2d(88,20,kernel_size=5)self.incep1 = InceptionA(in_channels=10)self.incep2 = InceptionA(in_channels=20)self.mp = torch.nn.MaxPool2d(2)self.fc = torch.nn.Linear(1408,10)def forward(self,x):in_size = x.size(0)x = F.relu(self.mp(self.conv1(x)))x = self.incep1(x)x = F.relu(self.mp(self.conv2(x)))x = self.incep2(x)x = x.view(in_size,-1)x = self.fc(x)return xmodel = Net()
device = torch.device("cude:0"if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),lr=0.01,momentum=0.5)def train(epoch):running_loss = 0.0for batch_idx,data in enumerate(train_loader,0):inputs,target = datainputs,target = inputs.to(device),target.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs,target)loss.backward()optimizer.step()running_loss += loss.item()if batch_idx % 300 == 299:print('[%d,%5d] loss:%.3f'%(epoch + 1,batch_idx + 1,running_loss / 300))running_loss = 0.0def test():correct = 0total = 0with torch.no_grad():for data in test_loader:images,labels = dataimages, labels = images.to(device), labels.to(device)outputs = model(images)_,predicted = torch.max(outputs.data,dim=1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy on test set: %d %%' %(100 * correct / total))if __name__ == '__main__':for epoch in range(10):train(epoch)test()

如果喜欢内容不妨点个关注,后续会持续更新

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/75508.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

javaweb自用笔记:Mybatis

目录 mybatis 配置sql书写提示 JDBC 数据库连接池 lombok mybatis 只需要定义Mapper接口就好,不需要有实现类,因为框架底层会自动生成实现类 配置sql书写提示 JDBC 数据库连接池 lombok XML映射文件 动态SQL

Rust从入门到精通之精通篇:22.Unsafe Rust 详解

Unsafe Rust 详解 在 Rust 的设计哲学中,安全性是核心原则之一。Rust 的所有权系统、借用检查器和类型系统共同保证了内存安全和线程安全。然而,有些底层操作无法通过 Rust 的安全检查机制进行验证,这就是 unsafe Rust 存在的原因。在本章中,我们将深入探讨 unsafe Rust,…

比手动备份快 Iperius全自动加密备份,NAS/云盘/磁带机全兼容

IperiusBackupFull是一款专为服务器和工作站设计的备份解决方案,它同时也是一款针对Windows 7/8/10/11/Server系统的简洁且可靠的备份软件。该软件支持增量备份、数据同步以及驱动器镜像,确保能够实现完全的系统恢复。在备份存储方面,Iperius…

deepseek实战教程-第六篇查找源码之仓库地址与deepseek-R1、deepseek-LLM仓库内容查看

上一篇讲了支持deepseek的模型应用的本地安装和部署以及使用。再上一篇讲解了deepseek提供的开放api,便于开发者基于deepseek提供的接口来编写属于自己的业务应用程序。但是前面几篇我们都是在用模型,我们知道deepseek是开源的,那么deepseek的源码在哪里,具体源码是什么样的…

ES 加入高亮设置

searchTextQueryOne new MatchQuery.Builder().field(searchFieldOne).query(searchText).build();// 帮助中心文档切分 只查询6条Integer finalTopK 10;List<String> newReturnFileds returnFields;newReturnFileds.add("kid"); // 需要返回kidHighlight h…

mapbox进阶,添加鹰眼图控件

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:mapbox 从入门到精通 文章目录 一、🍀前言1.1 ☘️mapboxgl.Map 地图对象1.2 ☘️mapboxgl.Map style属性1.3 ☘️mapboxgl-minimap 鹰眼控件二、🍀添加…

亮数据爬取API爬取亚马逊电商平台实战教程

前言 在当今数据驱动的商业环境中&#xff0c;企业需要快速、精准地获取互联网上的公开数据以支持市场分析、竞品调研和用户行为研究。然而&#xff0c;传统的手动网页爬取方式面临着诸多挑战&#xff1a;IP封锁、验证码干扰、网站结构频繁变更&#xff0c;以及高昂的运维成本…

「Unity3D」使用C#获取Android虚拟键盘的高度

原理是&#xff1a;利用getWindowVisibleDisplayFrame方法&#xff0c;获取Android窗口可见区域的Rect&#xff0c;这个Rect剔除了状态栏与导航栏&#xff0c;并且在有虚拟键盘遮挡的时候&#xff0c;会剔除这个遮挡区域。 接着&#xff0c;Unity的safeArea也剔除了状态栏与导…

“城市超级智能体”落地,联想智慧城市4.0“功到自然成”

作者 | 曾响铃 文 | 响铃说 交通摄像头捕捉到车流量数据&#xff0c;进入一套“自动化”的城市整体管理体系中&#xff0c;交通路况信息、天气变化情况以及城市大型活动安排等看似分散的数据被整合&#xff0c;根据预测的路况精准调控交通信号灯&#xff0c;让自动驾驶清扫车…

每日总结3.24

第十届蓝桥杯大赛软件赛省赛C/C 大学 B 组 183.完全二叉树的权值&#xff08;找规律&#xff0c;临界值&#xff09; #include <bits/stdc.h> using namespace std; int a[1000005]; int main() { int m;int d; cin>>m; int sum;int maxn0; for(int i1;i&…

计算机复试面试

数据库 1.设计过程/设计步骤 1.需求分析&#xff1a;明确客户需求&#xff0c;确定系统边界&#xff0c;生成数据字典 2.概念结构设计&#xff1a;将用户需求抽象为概念模型&#xff0c;绘制e-r图 3.逻辑结构设计&#xff1a;将e-r图转化为dbms相符合的逻辑结构&#xff0c;db…

模型 拆屋效应

系列文章分享模型&#xff0c;了解更多&#x1f449; 模型_思维模型目录。先过分后合理&#xff0c;易被接受。 1 拆屋效应的应用 1.1 高端手表销售案例 一、案例背景 在高端手表销售领域&#xff0c;销售人员面临顾客对价格敏感且购买决策谨慎的挑战。如何引导顾客接受较高…

Windows系统下Pycharm+Minianaconda3连接教程【成功】

0.引言 PycharmMinianaconda3开发组合的好处 优点类别具体优点描述环境管理便捷独立环境创建 环境复制与共享Miniconda3可创建多独立Python环境&#xff0c;支持不同版本与依赖&#xff0c;避免冲突。 能复制、分享环境配置&#xff0c;方便团队搭建相同开发环境。依赖管理高…

4、pytest常用插件

pytest 是一个功能非常强大的测试框架&#xff0c;支持丰富的插件系统。插件可以扩展 pytest 的功能&#xff0c;从而使测试过程更加高效和便捷。以下是一些常用的 pytest 插件及其作用&#xff1a; pytest-cov: 作用: 提供测试覆盖率报告&#xff0c;帮助你了解代码的表现情况…

python每日十题(10)

在Python语言中&#xff0c;源文件的扩展名&#xff08;后缀名&#xff09;一般使用.py。 保留字&#xff0c;也称关键字&#xff0c;是指被编程语言内部定义并保留使用的标识符。Python 3.x有35个关键字&#xff0c;分别为&#xff1a;and&#xff0c;as&#xff0c;assert&am…

Clio:具备锁定、用户认证和审计追踪功能的实时日志记录工具

在网络安全工具不断发展的背景下&#xff0c;Clio 作为一款革命性的实时日志记录解决方案&#xff0c;由 CyberLock Technologies 的网络安全工程师开发&#xff0c;于 2025 年 1 月正式发布。这款先进的工具通过提供对系统事件的全面可见性&#xff0c;同时保持强大的安全协议…

内核编程十三:进程状态详解

进程如同数字世界中的生命体&#xff0c;诞生时被系统母体赋予初始资源&#xff0c;在CPU的脉搏中呼吸&#xff0c;于内存的疆域里生长。它睁开线程之眼观察世界&#xff0c;伸出系统调用之手与环境互动&#xff0c;时而如幼童般单纯执行指令&#xff0c;时而如哲人般陷入阻塞沉…

GitLab 中文版17.10正式发布,27项重点功能解读【一】

GitLab 是一个全球知名的一体化 DevOps 平台&#xff0c;很多人都通过私有化部署 GitLab 来进行源代码托管。极狐GitLab 是 GitLab 在中国的发行版&#xff0c;专门为中国程序员服务。可以一键式部署极狐GitLab。 学习极狐GitLab 的相关资料&#xff1a; 极狐GitLab 官网极狐…

哈尔滨工业大学DeepSeek公开课人工智能:大模型原理 技术与应用-从GPT到DeepSeek|附视频下载方法

导 读INTRODUCTION 今天继续哈尔滨工业大学车万翔教授带来了一场主题为“DeepSeek 技术前沿与应用”的报告。 本报告深入探讨了大语言模型在自然语言处理&#xff08;NLP&#xff09;领域的核心地位及其发展历程&#xff0c;从基础概念出发&#xff0c;延伸至语言模型在机器翻…

web爬虫笔记:js逆向案例十一 某数cookie(补环境流程)

web爬虫笔记:js逆向案例十一 某数cookie(补环境流程) 一、获取网页数据请求流程 二、目标网址、cookie生成(逐步分析) 1、目标网址:aHR0cHM6Ly9zdWdoLnN6dS5lZHUuY24vSHRtbC9OZXdzL0NvbHVtbnMvNy9JbmRleC5odG1s 2、快速定位入口方法 1、通过脚本监听、hook_cookie等操作可…