pytorch-模型训练

目录

  • 1. 模型训练的基本步骤
    • 1.1 train、test数据下载
    • 1.2 train、test数据加载
    • 1.3 Lenet5实例化、初始化loss函数、初始化优化器
    • 1.4 开始train和test
  • 2. 完整代码

1. 模型训练的基本步骤

以cifar10和Lenet5为例

1.1 train、test数据下载

使用torchvision中的datasets可以方便下载cifar10数据

cifar_train = datasets.CIFAR10('cifa', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)

transforms.Resize((32, 32)) 将数据图形数据resize为32x32,这里可不用因为cifar10本身就是32x32
transforms.ToTensor()是将numpy或者numpy数组或PIL图像)转换为PyTorch的Tensor格式,以便输入网络。
transforms.Normalize()根据指定的均值和标准差对每个颜色通道进行图像归一化,可以提高神经网络训练过程中的收敛速度

1.2 train、test数据加载

使用pytorch torch.utils.data中的DataLoader用来加载数据

cifar_train = DataLoader(cifar_train, batch_size=batchz, shuffle=True)

batch_size=batchz: 这里batchz是一个变量,代表每个批次的样本数量。
shuffle=True: 这个参数设定为True意味着在每次训练循环(epoch)开始前,数据集中的样本会被随机打乱顺序。这样做可以增加训练过程中的随机性,帮助模型更好地泛化,避免过拟合特定的样本排列顺序。

1.3 Lenet5实例化、初始化loss函数、初始化优化器

    device = torch.device('cuda')model = Lenet5().to(device)crition = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(), lr=1e-3)

注意:网络和模型一定要搬到GPU上

1.4 开始train和test

  • 循环epoch
  • 加载train数据、输入模型、计算loss、backward、调用优化器
  • 加载test数据、输入模型、计算prediction、计算正确率
  • 输出正确率
 for epoch in range(1000):model.train()for batch, (x, label) in enumerate(cifar_train):x, label = x.to(device), label.to(device)logits = model(x)loss = crition(logits, label)optimizer.zero_grad()loss.backward()optimizer.step()# testmodel.eval()with torch.no_grad():total_correct = 0total_num = 0for x, label in cifar_test:x, label = x.to(device), label.to(device)logits = model(x)pred = logits.argmax(dim=1)correct = torch.eq(pred, label).float().sum().item()total_correct += correcttotal_num += x.size(0)acc = total_correct / total_numprint(epoch, 'test acc:', acc)

2. 完整代码

import torch
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import nn, optim
import syssys.path.append('.')
from Lenet5 import Lenet5def main():batchz = 128cifar_train = datasets.CIFAR10('cifa', True, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_train = DataLoader(cifar_train, batch_size=batchz, shuffle=True)cifar_test = datasets.CIFAR10('cifa', False, transform=transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]), download=True)cifar_test = DataLoader(cifar_test, batch_size=batchz, shuffle=True)device = torch.device('cuda')model = Lenet5().to(device)crition = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(model.parameters(), lr=1e-3)for epoch in range(1000):model.train()for batch, (x, label) in enumerate(cifar_train):x, label = x.to(device), label.to(device)logits = model(x)loss = crition(logits, label)optimizer.zero_grad()loss.backward()optimizer.step()# testmodel.eval()with torch.no_grad():total_correct = 0total_num = 0for x, label in cifar_test:x, label = x.to(device), label.to(device)logits = model(x)pred = logits.argmax(dim=1)correct = torch.eq(pred, label).float().sum().item()total_correct += correcttotal_num += x.size(0)acc = total_correct / total_numprint(epoch, 'test acc:', acc)if __name__ == '__main__':main()

model.train()和model.eval()的区别和作用
model.train()
作用:当调用模型的model.train()方法时,模型会进入训练模式。这意味着:
启用 Dropout层和BatchNorm层:在训练模式下,Dropout层会按照设定的概率随机“丢弃”一部分神经元以防止过拟合,而Batch Normalization(批规范化)层会根据当前批次的数据动态计算均值和方差进行归一化。
梯度计算:允许梯度计算,这是反向传播和权重更新的基础。
应用场景:在模型的训练循环中,每次迭代开始之前调用,以确保模型处于正确的训练状态。

model.eval()
作用:调用model.eval()方法后,模型会进入评估模式。此时:
禁用 Dropout层:Dropout层在评估时不发挥作用,所有的神经元都会被保留,以确保预测的确定性和可重复性。
固定 BatchNorm层:BatchNorm层使用训练过程中积累的统计量(全局均值和方差)进行归一化,而不是当前批次的统计量,这有助于模型输出更加稳定和一致。
应用场景:在验证或测试模型性能时使用,确保模型输出是确定性的,不受训练时特有的随机操作影响,以便于准确评估模型的泛化能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/862422.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何在Postman中模拟表单数据

在Postman中模拟表单数据主要涉及到在发送HTTP请求时,正确地设置请求的Body部分以模拟表单的提交。以下是在Postman中模拟表单数据的步骤,按照清晰的格式进行分点表示和归纳: 一、创建新的POST请求 打开Postman工具,点击左上角的…

Kubernetes面试整理-不同CNI插件的作用和区别

容器网络接口(Container Network Interface,CNI)插件在 Kubernetes 中负责管理 Pod 的网络连接。CNI 插件提供了 Pod 之间以及 Pod 与外部世界之间的网络通信能力。以下是一些常见的 CNI 插件及其作用和区别: 1. Flannel 作用: ● Flannel 是一种简单的网络层实现,专注于…

第4章 客户端-客户端管理

1. 客户端API 1.1client list client list命令能列出与Redis服务端相连的所有客户端连接信息。 127.0.0.1:6379> client list id254487 addr10.2.xx.234:60240 fd1311 name age8888581 idle8888581 flagsN db0 sub0 psub0 multi-1 qbuf0 qbuf-free0 obl0 oll0 omem0 events…

MySQL数据库基础练习系列——教务管理系统

项目名称与项目简介 教务管理系统是一个旨在帮助学校或教育机构管理教务活动的软件系统。它涵盖了学生信息管理、教师信息管理、课程管理、成绩管理以及相关的报表生成等功能。通过该系统,学校可以更加高效地处理教务数据,提升教学质量和管理水平。 1.…

5款提高工作效率的免费工具推荐

SimpleTex SimpleTex是一款用于创建和编辑LaTeX公式的简单工具。它能够识别图片中的复杂公式并将其转换为可编辑的数据格式。该软件提供了一个直观的界面,用户可以在编辑LaTeX代码的同时实时预览公式的效果,无需额外的编译步骤。此外,SimpleT…

生命在于学习——Python人工智能原理(2.5.1)

五、Python的类与继承 5.1 Python面向对象编程 在现实世界中存在各种不同形态的事物,这些事物之间存在各种各样的联系。在程序中使用对象来映射现实中的事物,使用对象之间的关系描述事物之间的联系,这种思想用在编程中就是面向对象编程。 …

C++笔记:实现一个字符串类(构造函数、拷贝构造函数、拷贝赋值函数)

实现一个字符串类String&#xff0c;为其提供可接受C风格字符串的构造函数、析构函数、拷贝构造函数和拷贝赋值函数。 声明依赖文件 其中ostream库用于打印标准输入输出&#xff0c;cstring库为C风格的字符串库 #include <iostream> #include <cstring> 声明命…

关于代购系统带来的便利性的研究报告

摘要&#xff1a; 本研究报告旨在深入探讨代购系统为消费者带来的显著便利性。通过对相关数据的分析、用户体验的调查以及行业发展的研究&#xff0c;详细阐述了代购系统在商品获取、价格优势、服务质量等方面的突出表现&#xff0c;以及其对现代消费模式的积极影响。 一、引言…

【python】OpenCV—Color Correction

文章目录 cv2.aruco 介绍imutils.perspective.four_point_transform 介绍skimage.exposure.match_histograms 介绍牛刀小试遇到的问题 参考学习来自 OpenCV基础&#xff08;18&#xff09;使用 OpenCV 和 Python 进行自动色彩校正 cv2.aruco 介绍 一、cv2.aruco模块概述 cv2.…

kubeadm 部署k8s集群[单master]

kubeadm 部署k8s集群 主机名地址角色配置kub-k8s-master192.168.75.100主节点2核2G 50Gkub-k8s-node1192.168.75.103工作节点1核2G 50Gkub-k8s-node2192.168.75.104工作节点/haproxy1核2G 50G 一、安装docker[集群] # 一键安装docker脚本 curl -L https://gitea.beyourself.…

语言模型之LLaMA

LLaMA&#xff08;Large Language Model Meta AI&#xff09;是由 Meta&#xff08;前 Facebook&#xff09;开发的大型语言模型&#xff0c;它是一种基于深度学习的自然语言处理&#xff08;NLP&#xff09;模型&#xff0c;旨在在多个语言理解和生成任务中达到高水平的性能。…

骨传导耳机哪个牌子值得入手?精选热销榜TOP5推荐!

短短数年&#xff0c;骨传导耳机的市场规模迅速扩大&#xff0c;其受欢迎程度可见一斑。但身为拥有十二年经验的音频专家&#xff0c;我在此有义务提醒大家&#xff0c;在选择骨传导耳机时一定要谨慎。面对市面上的众多品牌&#xff0c;一定不要盲目入手&#xff0c;不然很容易…

leetcode提速小技巧

据我所知&#xff0c;leetcode可能是按最难那个用例给你打分的&#xff0c;非难题的用时好坏不完全看复杂度&#xff0c;因为可能都差不多&#xff0c;O(n/2)和O(n)虽然都是O(n)&#xff0c;但是反应到成绩上是不同的&#xff0c;所以&#xff0c;尽可能的在条件足够的情况下提…

CVE-2018-8120漏洞提权:Windows 7的安全剖析与实战应用

CVE-2018-8120漏洞提权&#xff1a;Windows 7的安全剖析与实战应用 在网络安全的世界里&#xff0c;漏洞利用常常是攻击者用来获取系统控制权的捷径。2018年发现的CVE-2018-8120漏洞&#xff0c;针对Windows 7操作系统&#xff0c;提供了一个这样的途径。本文将深入分析这一漏…

Java鲜花下单预约系统源码小程序源码

让美好触手可及 &#x1f338;一、开启鲜花新篇章 在繁忙的都市生活中&#xff0c;我们总是渴望那一抹清新与美好。鲜花&#xff0c;作为大自然的馈赠&#xff0c;总能给我们带来无尽的惊喜与愉悦。但你是否曾因为工作繁忙、时间紧张而错过了亲自挑选鲜花的机会&#xff1f;今…

KVB交易平台: 美元兑日元升破161,这一趋势会继续吗?

在2024年6月28日&#xff0c;美元在亚洲交易市场中表现强劲&#xff0c;接近四十年来的新高&#xff0c;预计将连续第二个季度上涨。与此同时&#xff0c;日本日元持续走低&#xff0c;跌至38年以来的新低&#xff0c;首次突破161关口。在东京交易中&#xff0c;日元兑美元贬值…

小程序反编译后报错“_typeof3 is not a function”

详情->本地设置->取消勾选“将JS编译成ES5” 参考链接&#xff1a;https://blog.csdn.net/csl12919/article/details/131569914

Ubuntu网络管理命令:netstat

安装Ubuntu桌面系统&#xff08;虚拟机&#xff09;_虚拟机安装ubuntu桌面版-CSDN博客 顾名思义&#xff0c;netstat命令不是用来配置网络的&#xff0c;而是用来查看各种网络信息的&#xff0c;包括网络连接、路由表以及网络接口的各种统计数据等。 netstat命令的基本语法如…

湿气易藏在身体的这3个地方,1个方法“定位祛湿”,摆脱它!

6月的雨&#xff0c;滴滴答答的下了十几天了&#xff0c;体内湿气越来越重&#xff0c;皮肤发痒、身体沉重、四肢乏力、总犯困、还出现了消化不良&#xff0c;甚至腹泻的情况…… 为什么我都在积极喝祛湿汤&#xff0c;却不见效呢&#xff1f; 一个很重要的原因可能是&#xff…

融入云端的心跳:在Spring Cloud应用中集成Eureka Client

融入云端的心跳&#xff1a;在Spring Cloud应用中集成Eureka Client 引言 在微服务架构中&#xff0c;服务发现是一个关键组件&#xff0c;它允许服务实例之间相互发现并通信。Netflix Eureka是Spring Cloud体系中广泛使用的服务发现框架。Eureka提供了一个服务注册中心&…