PyTorch搭建LeNet训练集详细实现

一、下载训练集

导包

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

 ToTensor()函数:

把图像[heigh x width x channels] 转换为 [channels x height x width]

Normalize() 数据标准化函数:

最后一行是标准化数值计算公式

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 50000张训练图片
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)

参数解释: 

root='./data':数据集下载的路径,我下载到当前目录下的data文件夹,下载完成后会自动创建 

train=True:当前为训练集

download=True:下载数据集时设置为True,下载完成后改为False

transform=transform :设置对图像进行预处理的函数

运行下载数据集结果为: 

 下载完成后生成了data文件夹

二、导入训练集 

# 导入训练集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=36,shuffle=True, num_workers=0)

参数解释: 

        trainset:把刚刚下载的数据导入进来

        batch_size=36:一批数据的大小

        shuffle=True:训练集中的数据是否打乱(一般默认打乱)

        num_workers=0:载入数据的现成数,在lunix操作系统下,可以设置为别的参数,在windows操作系统系统下,默认为0.

三、下载测试集

# 10000张测试图片
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=10000,shuffle=False, num_workers=0)
test_data_iter = iter(testloader)
test_image, test_lable = test_data_iter.next()classes = ('plane', 'car', 'bird', 'cat',   # 数据集中的分类,设置为元组,不可变类'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

参数解释:

test_data_iter = iter(testloader):通过iter()函数把testloader转化成可迭代的迭代器
test_image, test_lable = test_data_iter.next():通过next()方法可以获得测试的图像和图像对应的标签值。

 四、查看导入的图片

在中间过程打印图片进行查看,后续会注释掉

def imshow(img):img = img / 2 + 0.5nping = img.numpy()plt.imshow(np.transpose(nping, (1, 2, 0)))plt.show()# print labels
print(' '.join('%5s' % classes[test_lable[j]] for j in range(4)))
# show images
imshow(torchvision.utils.make_grid(test_image))

运行结果

 图片很模糊,因为像素很低。

上面识别出来的结果都对了。

 我遇到的问题:
一开始有结果但是没有图片,我以为时matplotlib的问题,我重新安装并且更新了版本,但是我再运行后报错更多了,报错提示我 AttributeError: module 'numpy' has no attribute 'bool',我就知道是numpy的问题了,我重新安装并且更新了版本结果还是不行,我百度了一下,发现不是越新的版本越好,我重新下载了1.23.2这个版本的numpy,下载完成后运行就出来结果了。

pip install numpy==1.23.2

这个也只是中间过程,后续会注释或者删了。


五、将创建的模型实例化

创建模型请看PyTorch搭建LeNet神经网络-CSDN博客

# 将创建的模型实例化
net = LeNet()  # 实例化
loss_fuction = nn.CrossEntropyLoss()  # 定义损失函数# 通过优化器将所有可训练的参数都进行训练,lr是learningrate学习率
optimizer = optim.Adam(net.parameters(), lr=0.001)#通过for循环实现训练过程,循环几次就是将训练集迭代多少次
for epoch in range(5):running_loss = 0.0  # 用来累加在学习过程中的损失for step, data in enumerate(trainloader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()   # 历时损失梯度清零。# forward + backward + optimizeoutputs = net(inputs)loss = loss_fuction(outputs, labels)  # 计算神经网络的预测值和真实标签之间的损失loss.backward()optimizer.step()  # step()函数实现参数更新# print statistics  打印数据的过程running_loss += loss.item()if step % 500 == 499:  # 每隔500步打印一次数据的信息with torch.no_grad():  # 上下文管理器outputs = net(test_image)predict_y = torch.max(outputs, dim=1)[1]accuracy = (predict_y == test_lable).sum().item() / test_lable.size(0)print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')# 将模型保存到文件夹中
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)

 详细解释:

比较重点的单独解释了,其他的在注释中。

 optimizer.zero_grad()   # 历时损失梯度清零。

 ? 为什么每计算一个batch,就要调用一次 optimizer.zero_grad()函数

=> 通过清楚历史梯度,就会对计算的历史梯度进行累加。通过这个特性,能变相的实现一个很大的batch数值的训练(因为batch数值越大,训练效果越好)

 with torch.no_grad():  # 上下文管理器

 上下文管理器: 在接下来的计算过程中,不再去计算每个节点的误差损失梯度。

如果不调用这个函数,将会在测试过程中占用更多的算力,消耗更多的资源和占用更多的内存资源,导致内存容易崩。

print函数中打印参数解释:

print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))

epoch + 1:迭代到第几轮了

step + 1:某一轮的第几步

running_loss / 500:训练过程中500步平均训练误差

accuracy:准确率

运行结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/734824.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

git学习(创建项目提交代码)

操作步骤如下 git init //初始化git remote add origin https://gitee.com/aydvvs.git //建立连接git remote -v //查看git add . //添加到暂存区git push 返送到暂存区git status // 查看提交代码git commit -m初次提交git push -u origin "master"//提交远程分支 …

微信小程序(五十二)开屏页面效果

注释很详细&#xff0c;直接上代码 上一篇 新增内容&#xff1a; 1.使用控件模拟开屏界面 2.倒计时逻辑 3.布局方法 4.TabBar隐藏复现 源码&#xff1a; components/openPage/openPage.wxml <view class"openPage-box"><image src"{{imagePath}}"…

20个Python中列表(list)最常用的方法和函数。

本篇文章中,我们分别介绍10个Python中列表(list)最常用的方法和函数。 列表方法示例 1. 创建空列表和使用 append() 方法 # 创建一个空列表 my_list = []# 使用 append() 方法在列表末尾添加元素 my_list.append(5) my_list.append(10) print("After append:",…

离线数仓建设

一.数据仓库分层 ODS(Operation Data Store)层&#xff1a;原始数据层&#xff0c;存放加载原始日志、数据&#xff0c;数据保持原貌不做处理。 DWD(Data warehouse detail)层&#xff1a;对ODS层数据进行清洗&#xff08;去除空值&#xff0c;超过极限范围的数据&#xff09;、…

三维不同坐标系下点位姿态旋转平移变换

文章目录 前言正文计算方法思路Python实现总结前言 本文主要说明以下几种场景3D变换的应用: 3D相机坐标系下长方体物体,有本身坐标系,沿该物体长边方向移动一段距离,并绕长边轴正旋转方向转90度,求解当前物体中心点在相机坐标系下的位置和姿态多关节机器人末端沿工具坐标…

介绍Android UI绘制过程以及注意事项

Android UI绘制是一个复杂的过程&#xff0c;它涉及到多个步骤&#xff0c;从测量&#xff08;measure&#xff09;到布局&#xff08;layout&#xff09;再到绘制&#xff08;draw&#xff09;。以下是这个过程的简要介绍以及一些注意事项&#xff1a; 1. **测量&#xff08;…

计算机网络-网络应用服务器(四)

1.Samba服务器&#xff1a; Samba是Linux上实现和Windows系统局域网上共享文件和打印机的一种通信协议&#xff0c;由服务器及客户端程序构成。支持SMB/CIFS协议&#xff0c;实现共享资源。最主要的一个配置文件smb.conf&#xff0c;可以使用vi编辑器修改。守护进程&#xff1a…

STM32 利用FlashDB库实现在线扇区数据管理不丢失

STM32 利用FlashDB库实现在线扇区数据管理不丢失 &#x1f4cd;FalshDB地址:https://gitee.com/Armink/FlashDB ✨STM32没有片内EEPROM这样的存储区&#xff0c;虽然有备份寄存器&#xff0c;仅可以实现对少量数据的频繁存储&#xff0c;但是依赖备份电源&#xff08;BAT引脚&a…

美国签证|附面签相关事项√

小伙伴最近都忙着办签证吧&#xff01;但是需要注意的是&#xff0c;美国的签证跟其他任何国家的签证不同&#xff0c;并不是办理了就一定拿得到&#xff0c;据说概率是50%左右。所以办理美国签证&#xff0c;不要太着急啦&#xff01;先来了解一下美国签证的相片该怎么拍叭 ✅…

题目 2073: [STL训练]亲和串

题目描述: 人随着岁数的增长是越大越聪明还是越大越笨&#xff0c;这是一个值得全世界科学家思考的问题,同样的问题Eddy也一直在思考&#xff0c;因为他在很小的时候就知道亲和串如何判断了&#xff0c;但是发现&#xff0c;现在长大了却不知道怎么去判断亲和串了&#xff0c;…

RocketMQ的事务消息流程

什么是事务消息&#xff1f; 事务消息是一种在发送方和接收方之间保证消息传递的一致性和可靠性的消息传递机制。在消息发送过程中&#xff0c;生产者可以将消息发送到消息队列&#xff0c;但不会立即被消费者接收和处理。相反&#xff0c;消息会先进入一种“准备”状态&#x…

用chatgpt写insar地质灾害的论文,重复率只有1.8%,chatgpt4.0写论文不是梦

突发奇想&#xff0c;想用chatgpt写一篇论文&#xff0c;并看看查重率&#xff0c;结果很惊艳&#xff0c;说明是确实可行的&#xff0c;请看下图。 下面是完整的文字内容。 InSAR (Interferometric Synthetic Aperture Radar) 地质灾害监测技术是一种基于合成孔径雷达…

【JavaScript】JavaScript 变量 ① ( JavaScript 变量概念 | 变量声明 | 变量类型 | 变量初始化 | ES6 简介 )

文章目录 一、JavaScript 变量1、变量概念2、变量声明3、ES6 简介4、变量类型5、变量初始化 二、JavaScript 变量示例1、代码示例2、展示效果 一、JavaScript 变量 1、变量概念 JavaScript 变量 是用于 存储数据 的 容器 , 通过 变量名称 , 可以 获取 / 修改 变量 中的数据 ; …

第十五届蓝桥杯模拟赛(第三期)

大家好&#xff0c;我是晴天学长&#xff0c;本次分享&#xff0c;制作不易&#xff0c;本次题解只用于学习用途&#xff0c;如果有考试需要的小伙伴请考完试再来看题解进行学习&#xff0c;需要的小伙伴可以点赞关注评论一波哦&#xff01;蓝桥杯省赛就要开始了&#xff0c;祝…

【DimPlot】【FeaturePlot】使用小tips

目录 DimPlot函数参数解析 栅格化点图 放大 ggplot2 图例的点&#xff0c;修改图例的标题 FeaturePlot函数参数解析 调整FeaturePlot颜色 分组绘制featureplot 随手笔记&#xff0c;持续更新中。。。 Reference DimPlot函数参数解析 object: 一个Seurat对象&#xff0c;…

工作纪实46-关于微服务的上线发布姿势

蓝绿部署 在部署时&#xff0c;不需要将旧版本的服务停掉&#xff0c;而是将新版本与旧版本同时运行&#xff0c;新版本测试无误之后再将旧版本停掉。这样可以避免再升级的过程中如果失败服务不可用的问题&#xff0c;因为同时部署了两个版本的程序&#xff0c;使得硬件资源是…

【项目笔记】java微服务:黑马头条(day01)

文章目录 环境搭建、SpringCloud微服务(注册发现、服务调用、网关)1)课程对比2)项目概述2.1)能让你收获什么2.2)项目课程大纲2.3)项目概述2.4)项目术语2.5)业务说明 3)技术栈4)nacos环境搭建4.1)虚拟机镜像准备4.2)nacos安装 5)初始工程搭建5.1)环境准备5.2)主体结构 6)登录6.1…

Ubuntu用扩展分区加载home目录步骤

如果你想要将新的磁盘挂载到默认的 /home 目录下&#xff0c;可以按照以下步骤进行操作&#xff1a; 创建挂载点&#xff1a; 首先&#xff0c;确保新磁盘已连接并识别。然后&#xff0c;创建一个临时挂载点&#xff0c;以便将新磁盘挂载到该点。sudo mkdir /mnt/new_home挂载磁…

JavaScript中的Set和Map:理解与使用

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

C++:类和对象(三)——拷贝构造函数和运算符重载

目录 一、拷贝构造函数 1.概念 2.特性 二、赋值运算符重载 1.运算符重载 2.赋值运算符重载 &#xff08;1&#xff09;注意的点&#xff1a; &#xff08;2&#xff09;赋值运算符不允许被重载为全局函数&#xff0c;只能重载为类的成员函数 &#xff08;3&#xff09;…