PyTorch入门之【CNN】

参考:https://www.bilibili.com/video/BV1114y1d79e/?spm_id_from=333.999.0.0&vd_source=98d31d5c9db8c0021988f2c2c25a9620
书接上回的MLP故本章就不详细解释了

目录

  • train
  • test

train

import torch
from torchvision.transforms import ToTensor
from torchvision import datasets
import torch.nn as nn# load MNIST dataset
training_data = datasets.MNIST(root='../02_dataset/data',train=True,download=True,transform=ToTensor()
)train_data_loader = torch.utils.data.DataLoader(training_data, batch_size=64, shuffle=True)# define a CNN model
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1),nn.BatchNorm2d(32),nn.ReLU())self.conv_2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1),nn.BatchNorm2d(64),nn.ReLU(),)self.maxpool = nn.MaxPool2d(2)self.flatten = nn.Flatten()self.fc_1 = nn.Sequential(nn.Linear(9216, 128),nn.BatchNorm1d(128),nn.ReLU())self.fc_2 = nn.Linear(128, 10)def forward(self, x):x = self.conv_1(x)x = self.conv_2(x)x = self.maxpool(x)x = self.flatten(x)x = self.fc_1(x)logits = self.fc_2(x)return logits# create a CNN model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn = CNN().to(device)
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()# train the model
num_epochs = 20for epoch in range(num_epochs):print(f'Epoch {epoch+1}\n-------------------------------')for idx, (img, label) in enumerate(train_data_loader):size = len(train_data_loader.dataset)img, label = img.to(device), label.to(device)# compute prediction errorpred = cnn(img)loss = loss_fn(pred, label)# backpropagationoptimizer.zero_grad()loss.backward()optimizer.step()if idx % 400 == 0:loss, current = loss.item(), idx*len(img)print(f'loss: {loss:>7f} [{current:>5d}/{size:>5d}]')# save the model
torch.save(cnn.state_dict(), 'cnn.pth')
print('Saved PyTorch Model State to cnn.pth')

test

import torch
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor
from torchvision.datasets import ImageFolder
import torch.nn as nn# load test data
test_data = datasets.MNIST(root='../02_dataset/data',train=False,download=True,transform=ToTensor()
)
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True)transform = transforms.Compose([transforms.Grayscale(),transforms.RandomRotation(10),transforms.ToTensor()
])
my_mnist = ImageFolder(root='../02_dataset/my-mnist', transform=transform)
my_mnist_loader = torch.utils.data.DataLoader(my_mnist, batch_size=64, shuffle=True)# define a CNN model
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_1 = nn.Sequential(nn.Conv2d(1, 32, kernel_size=3, stride=1),nn.BatchNorm2d(32),nn.ReLU())self.conv_2 = nn.Sequential(nn.Conv2d(32, 64, kernel_size=3, stride=1),nn.BatchNorm2d(64),nn.ReLU(),)self.maxpool = nn.MaxPool2d(2)self.flatten = nn.Flatten()self.fc_1 = nn.Sequential(nn.Linear(9216, 128),nn.BatchNorm1d(128),nn.ReLU())self.fc_2 = nn.Linear(128, 10)def forward(self, x):x = self.conv_1(x)x = self.conv_2(x)x = self.maxpool(x)x = self.flatten(x)x = self.fc_1(x)logits = self.fc_2(x)return logits# load the pretrained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cnn = CNN()
cnn.load_state_dict(torch.load('cnn.pth', map_location=device))
cnn.eval().to(device)# test the pretrained model on MNIST test data
size = len(test_data_loader.dataset)
correct = 0with torch.no_grad():for img, label in test_data_loader:img, label = img.to(device), label.to(device)pred = cnn(img)correct += (pred.argmax(1) == label).type(torch.float).sum().item()correct /= size
print(f'Accuracy on MNIST: {(100*correct):>0.1f}%')# test the pretrained model on my MNIST test data
size = len(my_mnist_loader.dataset)
correct = 0with torch.no_grad():for img, label in my_mnist_loader:img, label = img.to(device), label.to(device)pred = cnn(img)correct += (pred.argmax(1) == label).type(torch.float).sum().item()correct /= size
print(f'Accuracy on my MNIST: {(100*correct):>0.1f}%')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/95970.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据结构】二叉树的基本操作

目录: 二叉树的基本操作 1. 二叉树的创建 1.1. 顺序存储 2. 二叉树的初始化3. 二叉树插入节点4. 二叉树的遍历 4.1. 递归遍历4.2. 层序遍历4.3. 非递归遍历 二叉树的基本操作 1. 二叉树的创建 二叉树的存储方式哦同样有两种,一种是顺序存储&#x…

SpringBoot vue云办公系统

SpringBoot vue云办公系统 系统功能 云办公系统 登录 员工资料管理: 搜索员工 添加编辑删除员工 导入导出excel 薪资管理: 工资账套管理 添加编辑删除工资账套 员工账套设置 系统管理: 基础信息设置 部门管理 职位管理 职称管理 权限组管理 操作员管理 开发环境和技术 开发语…

选择适合户外篷房企业的企业云盘解决方案

“户外篷房企业用什么企业云盘好?Zoho WorkDrive企业网盘可以帮助户外篷房企业实现文档统一管理、提高工作效率、加强团队协作,并且支持各种文件类型的预览和编辑。” S公司是一家注重管理规范的大型户外篷房企业,已经有10余年的经验。作为设…

string和const char*参数类型选择的合理性对比

在编程中,我们经常需要处理字符串类型的参数。在C中,有两种常见的表示字符串的参数类型,即string和const char*。本文将对比这两种参数类型的特点,分析其在不同情况下的合理性,以便程序员能够根据实际需求做出正确的选…

Docker安装ActiveMQ

ActiveMQ简介 官网地址:https://activemq.apache.org/ 简介: ActiveMQ 是Apache出品,最流行的,能力强劲的开源消息总线。ActiveMQ 是一个完全支持JMS1.1和J2EE 1.4规范的 JMS Provider实现,尽管JMS规范出台已经是很久的事情了,…

次方计数的拆贡献法(考虑组合意义)+限定类问题善用值域与位置进行ds:1006T3

对于多次方的计数问题可以考虑拆贡献。 题目问 ∣ S ∣ 3 |S|^3 ∣S∣3, ∣ S ∣ |S| ∣S∣ 表示选的点数。相当于在 ∣ S ∣ |S| ∣S∣ 中选了3次,也就是选了3个可相同的点。 先考虑3个不相同点的贡献,对应任意3个点,必然会对…

【小工具-生成合并文件】使用python实现2个excel文件根据主键合并生成csv文件

1 小工具说明 1.1 功能说明 一般来说,我们会先有一个老的文件,这个文件内容是定制好相关列的表格,作为每天的报告。 当下一天来的时候,需要根据新的报表文件和昨天的报表文件做一个合并,合并的时候就会出现有些事新增…

【BI看板】Superset2.0+图表二次开发初探

Superset图表功能也很丰富了,但一些个性化的定制需求就需要二次开发了。网上二开的superset版本大多是0.xxx版本的或1.5xxx版本,本次用的是2.xxx。 源码相关说明 源码目录 superset-2.0\superset-frontend\plugins\plugin-chart-echarts 插件相关资料 官…

【重拾C语言】六、批量数据组织(二)线性表——分类与检索(主元排序、冒泡排序、插入排序、顺序检索、对半检索)

目录 前言 六、批量数据组织——数组 6.4 线性表——分类与检索 6.4.1 主元排序 6.4.2 冒泡排序 6.4.3 插入排序 6.4.4 顺序检索(线性搜索) 6.4.5 对半检索(二分查找) 算法比较 前言 线性表是一种常见的数据结构&#xf…

在linux下预览markdown的方法,转换成html和pdf

背景 markdown是一种便于编写和版本控制的格式,但却不便于预览——特别是包含表格等复杂内容时,单纯的语法高亮是远远不够的——这样就不能边预览边调整内容,需要找到一种预览方法。 思路 linux下有个工具,叫pandoc&#xff0c…

Go Gin Gorm Casbin权限管理实现 - 2. 使用Gorm存储Casbin权限配置以及`增删改查`

文章目录 0. 背景1. 准备工作2. 权限配置以及增删改查2.1 策略和组使用规范2.2 用户以及组关系的增删改查2.2.1 获取所有用户以及关联的角色2.2.2 角色组中添加用户2.2.3 角色组中删除用户 2.3 角色组权限的增删改查2.3.1 获取所有角色组权限2.3.2 创建角色组权限2.3.3 修改角色…

Spring MVC工作原理

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文…

Qt model/view 理解01

在 Qt 中对数据处理主要有两种方式:1)直接对包含数据的的数据项 item 进行操作,这种方法简单、易操作,现实方式单一的缺点,特别是对于大数据或在不同位置重复出现的数据必须依次对其进行操作,如果现实方式改…

10.1select并发服务器以及客户端

服务器&#xff1a; #include<myhead.h>//do-while只是为了不让花括号单独存在&#xff0c;并不循环 #define ERR_MSG(msg) do{\fprintf(stderr,"%d:",__LINE__);\perror(msg);\ }while(0);#define PORT 8888//端口号1024-49151 #define IP "192.168.2.5…

【16】c++设计模式——>建造者(生成器)模式

什么是建造者模式? 建造者模式&#xff08;Builder Pattern&#xff09;是一种创建型设计模式&#xff0c;它允许你构造复杂对象步骤分解。你可以不同的步骤中使用不同的方式创建对象&#xff0c;且对象的创建与表示是分离的。这样&#xff0c;同样的构建过程可以创建不同的表…

React Hooks —— ref hooks

什么是Hooks Hooks从语法上来说是一些函数。这些函数可以用于在函数组件中引入状态管理和生命周期方法。 React Hooks的优点 简洁 从语法上来说&#xff0c;写的代码少了上手非常简单 基于函数式编程理念&#xff0c;只需要掌握一些JavaScript基础知识与生命周期相关的知识不…

python:openpyxl 读取 Excel文件,显示在 wx.grid 表格中

pip install openpyxl openpyxl-3.1.2-py2.py3-none-any.whl (249 kB) et_xmlfile-1.1.0-py3-none-any.whl (4.7 kB) 摘要&#xff1a;A Python library to read/write Excel 2010 xlsx/xlsm files pip install wxpython4.2 wxPython-4.2.0-cp37-cp37m-win_amd64.whl (18.0 M…

微擎小程序获取不到头像和昵称解决方案

这是一个使用微擎小程序的代码示例&#xff0c;其中包含了获取用户头像和昵称的功能。以下是解决方案&#xff1a; 首先&#xff0c;在<button>标签上添加open-type"chooseAvatar"属性&#xff0c;并绑定bindchooseavatar事件&#xff1a; <button class&qu…

数据结构-快速排序-C语言实现

引言&#xff1a;快速排序作为一种非常经典且高效的排序算法&#xff0c;无论是工作还是面试中广泛用到&#xff0c;作为一种分治思想&#xff0c;需要熟悉递归思想。下面来讲讲快速排序的实现和改进。 老规矩&#xff0c;先用图解来理解一下&#xff1a;&#xff08;这里使用快…

MATLAB中syms函数使用

目录 语法 说明 示例 创建符号标量变量 创建符号标量变量的向量 创建符号标量变量矩阵 管理符号标量变量的假设 创建和评估符号函数 syms函数的作用是创建符号标量和函数&#xff0c;以及矩阵变量和函数。 语法 syms var1 ... varN syms var1 ... varN [n1 ... nM] …