【Res模块学习】结合CIFAR-100分类任务学习

初次尝试训练CIFAR-100:【图像分类】CIFAR-100图像分类任务-CSDN博客

1.训练模型(MyModel.py)

import torch
import torch.nn as nnclass BasicRes(nn.Module):def __init__(self, in_cha, out_cha, stride=1, res=True):super(BasicRes, self).__init__()self.layer01 = nn.Sequential(nn.Conv2d(in_channels=in_cha, out_channels=out_cha, kernel_size=3, stride=stride, padding=1),nn.BatchNorm2d(out_cha),nn.ReLU(),)self.layer02 = nn.Sequential(nn.Conv2d(in_channels=out_cha, out_channels=out_cha, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(out_cha),)if res:self.res = resif in_cha != out_cha or stride != 1:  # 若x和f(x)维度不匹配:self.shortcut = nn.Sequential(nn.Conv2d(in_channels=in_cha, out_channels=out_cha, kernel_size=1, stride=stride),nn.BatchNorm2d(out_cha),)else:self.shortcut = nn.Sequential()def forward(self, x):residual = xx = self.layer01(x)x = self.layer02(x)if self.res:x += self.shortcut(residual)return x# 2.训练模型
class cifar100(nn.Module):def __init__(self):super(cifar100, self).__init__()# 初始维度3*32*32self.Stem = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=2),  # (32-5+2*2)/1+1=32nn.BatchNorm2d(64),nn.ReLU(),)self.layer01 = BasicRes(in_cha=64, out_cha=64)self.layer02 = BasicRes(in_cha=64, out_cha=64)self.layer11 = BasicRes(in_cha=64, out_cha=128)self.layer12 = BasicRes(in_cha=128, out_cha=128)self.layer21 = BasicRes(in_cha=128, out_cha=256)self.layer22 = BasicRes(in_cha=256, out_cha=256)self.layer31 = BasicRes(in_cha=256, out_cha=512)self.layer32 = BasicRes(in_cha=512, out_cha=512)self.pool_max01 = nn.MaxPool2d(1, 1)self.pool_max02 = nn.MaxPool2d(2)self.pool_avg = nn.AdaptiveAvgPool2d((1, 1))  # b*c*1*1self.fc = nn.Sequential(nn.Dropout(0.4),nn.Linear(512, 256),nn.ReLU(),nn.Linear(256, 100),)def forward(self, x):x = self.Stem(x)x = self.pool_max01(x)x = self.layer01(x)x = self.layer02(x)x = self.pool_max02(x)x = self.layer11(x)x = self.layer12(x)x = self.pool_max02(x)x = self.layer21(x)x = self.layer22(x)x = self.pool_max02(x)x = self.layer31(x)x = self.layer32(x)x = self.pool_max02(x)x = self.pool_avg(x).view(x.size()[0], -1)x = self.fc(x)return x

由于CIFAR-100图像维度为(3,32,32),适当修改了ResNet-18的设计框架加以应用。

2.正式训练

import torch
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import time
from MyModel import BasicRes, cifar100total_start = time.time()# 正式训练函数
def train_val(train_loader, val_loader, device, model, loss, optimizer, epochs, save_path):  # 正式训练函数model = model.to(device)plt_train_loss = []  # 训练过程loss值,存储每轮训练的均值plt_train_acc = []  # 训练过程acc值plt_val_loss = []  # 验证过程plt_val_acc = []max_acc = 0  # 以最大准确率来确定训练过程的最优模型for epoch in range(epochs):  # 开始训练train_loss = 0.0train_acc = 0.0val_acc = 0.0val_loss = 0.0start_time = time.time()model.train()for index, (images, labels) in enumerate(train_loader):images, labels = images.to(device), labels.to(device)optimizer.zero_grad()  # 梯度置0pred = model(images)bat_loss = loss(pred, labels)  # CrossEntropyLoss会对输入进行一次softmaxbat_loss.backward()  # 回传梯度optimizer.step()  # 更新模型参数train_loss += bat_loss.item()# 注意此时的pred结果为64*10的张量pred = pred.argmax(dim=1)train_acc += (pred == labels).sum().item()print("当前为第{}轮训练,批次为{}/{},该批次总loss:{} | 正确acc数量:{}".format(epoch+1, index+1, len(train_data)//config["batch_size"],bat_loss.item(), (pred == labels).sum().item()))# 计算当前Epoch的训练损失和准确率,并存储到对应列表中:plt_train_loss.append(train_loss / train_loader.dataset.__len__())plt_train_acc.append(train_acc / train_loader.dataset.__len__())model.eval()  # 模型调为验证模式with torch.no_grad():  # 验证过程不需要梯度回传,无需追踪gradfor index, (images, labels) in enumerate(val_loader):images, labels = images.cuda(), labels.cuda()pred = model(images)bat_loss = loss(pred, labels)  # 算交叉熵lossval_loss += bat_loss.item()pred = pred.argmax(dim=1)val_acc += (pred == labels).sum().item()print("当前为第{}轮验证,批次为{}/{},该批次总loss:{} | 正确acc数量:{}".format(epoch+1, index+1, len(val_data)//config["batch_size"],bat_loss.item(), (pred == labels).sum().item()))val_acc = val_acc / val_loader.dataset.__len__()if val_acc > max_acc:max_acc = val_acctorch.save(model, save_path)plt_val_loss.append(val_loss / val_loader.dataset.__len__())plt_val_acc.append(val_acc)print('该轮训练结束,训练结果如下[%03d/%03d] %2.2fsec(s) TrainAcc:%3.6f TrainLoss:%3.6f | valAcc:%3.6f valLoss:%3.6f \n\n'% (epoch+1, epochs, time.time()-start_time, plt_train_acc[-1], plt_train_loss[-1], plt_val_acc[-1], plt_val_loss[-1]))print(f'训练结束,最佳模型的准确率为{max_acc}')plt.plot(plt_train_loss)  # 画图plt.plot(plt_val_loss)plt.title('loss')plt.legend(['train', 'val'])plt.show()plt.plot(plt_train_acc)plt.plot(plt_val_acc)plt.title('Accuracy')plt.legend(['train', 'val'])# plt.savefig('./acc.png')plt.show()# 1.数据预处理
transform = transforms.Compose([transforms.RandomHorizontalFlip(),  # 以 50% 的概率随机翻转输入的图像,增强模型的泛化能力transforms.RandomCrop(size=(32, 32), padding=4),  # 随机裁剪transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 对图像张量进行归一化
])  # 数据增强
ori_data = dataset.CIFAR100(root="./Data_CIFAR100",train=True,transform=transform,download=True
)
print(f"各标签的真实含义:{ori_data.class_to_idx}\n")
# print(len(ori_data))
# # 查看某一样本数据
# image, label = ori_data[0]
# print(f"Image shape: {image.shape}, Label: {label}")
# image = image.permute(1, 2, 0).numpy()
# plt.imshow(image)
# plt.title(f'Label: {label}')
# plt.show()config = {"train_size_perc": 0.8,"batch_size": 64,"learning_rate": 0.001,"epochs": 50,"save_path": "model_save/Res_cifar100_model.pth"
}# 设置训练集和验证集的比例
train_size = int(config["train_size_perc"] * len(ori_data))  # 80%用于训练
val_size = len(ori_data) - train_size  # 20%用于验证
train_data, val_data = random_split(ori_data, [train_size, val_size])
# print(len(train_data))
# print(len(val_data))train_loader = DataLoader(dataset=train_data, batch_size=config["batch_size"], shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=config["batch_size"], shuffle=False)device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"{device}\n")
model = cifar100()
# model = torch.load(config["save_path"]).to(device)
print(f"我的模型框架如下:\n{model}")
loss = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = torch.optim.AdamW(model.parameters(), lr=config["learning_rate"], weight_decay=1e-4)  # L2正则化
# optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])  # 优化器train_val(train_loader, val_loader, device, model, loss, optimizer, config["epochs"], config["save_path"])print(f"\n本次训练总耗时为:{(time.time()-total_start) / 60 }min")

3.测试文件

import torch
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time
from MyModel import BasicRes, cifar100total_start = time.time()
# 测试函数
def test(save_path, test_loader, device, loss):  # 测试函数best_model = torch.load(save_path).to(device)test_loss = 0.0test_acc = 0.0start_time = time.time()with torch.no_grad():for index, (images, labels) in enumerate(test_loader):images, labels = images.cuda(), labels.cuda()pred = best_model(images)bat_loss = loss(pred, labels)  # 算交叉熵losstest_loss += bat_loss.item()pred = pred.argmax(dim=1)test_acc += (pred == labels).sum().item()print("正在最终测试:批次为{}/{},该批次总loss:{} | 正确acc数量:{}".format(index + 1, len(test_data) // config["batch_size"],bat_loss.item(), (pred == labels).sum().item()))print('最终测试结束,测试结果如下:%2.2fsec(s) TestAcc:%.2f%%  TestLoss:%.2f \n\n'% (time.time() - start_time, test_acc/test_loader.dataset.__len__()*100, test_loss/test_loader.dataset.__len__()))# 1.数据预处理
transform = transforms.Compose([transforms.RandomHorizontalFlip(),  # 以 50% 的概率随机翻转输入的图像,增强模型的泛化能力transforms.RandomCrop(size=(32, 32), padding=4),  # 随机裁剪transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 对图像张量进行归一化
])  # 数据增强
test_data = dataset.CIFAR100(root="./Data_CIFAR100",train=False,transform=transform,download=True
)
# print(len(test_data))  # torch.Size([3, 32, 32])
config = {"batch_size": 64,"save_path": "model_save/Res_cifar100_model.pth"
}
test_loader = DataLoader(dataset=test_data, batch_size=config["batch_size"], shuffle=True)
loss = nn.CrossEntropyLoss()  # 交叉熵损失函数
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"{device}\n")test(config["save_path"], test_loader, device, loss)print(f"\n本次训练总耗时为:{time.time()-total_start}sec(s)")

4.训练结果

设计learning rate=0.001先训练了30轮,模型在测试集上的准确率已经来到了62.61%;
 

后续引入学习率衰减策略对同一网络进行再次训练,初始lr=0.001,衰减系数0.2,每20轮衰减一次,训练60轮,结果如下:

最终训练模型在测试集上的准确率 达到了65.20%

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/79406.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

爱胜品ICSP YPS-1133DN Plus黑白激光打印机报“自动进纸盒进纸失败”处理方法之一

故障现象如下图提示: 用户的爱胜品ICSP YPS-1133DN Plus黑白激光打印机在工作过程中提示自动进纸盒进纸失败并且红色故障灯闪烁; 给出常见故障一般处理建议如下: 当您的爱胜品ICSP YPS-1133DN Plus 黑白激光打印机出现“自动进纸盒进纸失败”…

Flinkcdc 实现 MySQL 写入 Doris

Flinkcdc 实现 MySQL 写入 Doris Flinkcdc 实现 MySQL 写入 Doris 一、环境配置 Doris:3.0.4 JDK 17 MySQL (业务数据库):5.7 MySQL(本地数据库):5.7 Flink:flink-1.19.1 flinkc…

【Linux庖丁解牛】—环境变量!

目录 1. 环境变量 1.1 概念介绍 1.2 命令行参数 1.3 一个例子,一个环境变量 1.4 认识更多的环境变量 1.5 获取环境变量的方法 a. 指令操作 b. 代码操作 1.6 理解环境变量的特性 a.环境变量具有全局特性 b.补充两个概念(为后面埋一个伏笔) 1. 环境变量 …

LangChain4j +DeepSeek大模型应用开发——7 项目实战 创建硅谷小鹿

这部分我们实现硅谷小鹿的基本聊天功能,包含聊天记忆、聊天记忆持久化、提示词 1. 创建硅谷小鹿 创建XiaoLuAgent package com.ai.langchain4j.assistant;import dev.langchain4j.service.*; import dev.langchain4j.service.spring.AiService;import static dev…

普通 html 项目也可以支持 scss_sass

项目结构示例 下载vscode的插件Live Sass Compiler 自动监听编译scss 下载插件Live Server 用于 web 服务器,打开 html 文件到浏览器,也可以不用这个,自己用 nginx 或者宝塔其他 web 工具 新建一个 index.scss打开,点击 vscode 底…

网工_IP协议

2025.02.17:小猿网&网工老姜学习笔记 第19节 IP协议 9.1 IP数据包的格式(首部数据部分)9.1.1 IP协议的首部格式(固定部分可变部分) 9.2 IP数据包分片(找题练)9.3 TTL生存时间的应用9.4 常见…

SQL语句练习 自学SQL网 在查询中使用表达式 统计

目录 Day 9 在查询中使用表达式 Day 10 在查询中进行统计 聚合函数 Day 11 在查询中进行统计 HAVING关键字 Day12 查询执行顺序 Day 9 在查询中使用表达式 SELECT id , Title , (International_salesDomestic_sales)/1000000 AS International_sales FROM moviesLEFT JOIN …

基于机器学习的舆情分析算法研究

标题:基于机器学习的舆情分析算法研究 内容:1.摘要 随着互联网的飞速发展,舆情信息呈现爆炸式增长,如何快速准确地分析舆情成为重要课题。本文旨在研究基于机器学习的舆情分析算法,以提高舆情分析的效率和准确性。方法上,收集了近…

菲索旋转齿轮法:首次地面光速测量的科学魔术

一、当齿轮邂逅光束:19世纪的光速实验室 1849年,法国物理学家阿曼德菲索(Armand Fizeau)在巴黎郊外的一座庄园里,用一组旋转齿轮、一面镜子和一盏油灯,完成了人类首次地面光速测量。他的实验测得光速为315…

上位机知识篇---PSRAM和RAM

文章目录 前言一、RAM(Random Access Memory)1. 核心定义分类:SRAM(静态RAM)DRAM(动态RAM) 2. 关键特性SRAM优点缺点应用 DRAM优点缺点应用 3. 技术演进DDR SDRAMLPDDR(低功耗DRAM&a…

Qt QComboBox 下拉复选多选(multicombobox)

Qt QComboBox 下拉复选多选(multicombobox),备忘,待更多测试 【免费】QtQComboBox下拉复选多选(multicombobox)资源-CSDN文库

ElasticSearch深入解析(五):如何将一台电脑上的Elasticsearch服务迁移到另一台电脑上

文章目录 0.安装数据迁移工具1.导出数据2.导出mapping3.导出查询模板4.拷贝插件5.拷贝配置6.导入到目标电脑上 0.安装数据迁移工具 Elasticsearch dump是一个用于将Elasticsearch索引数据导出为JSON格式的工具。你可以使用Elasticsearch dump通过命令行或编程接口来导出数据。…

微服务中组件扫描(ComponentScan)的工作原理

微服务中组件扫描(ComponentScan)的工作原理 你的问题涉及到Spring框架中ComponentScan的工作原理以及Maven依赖管理的影响。我来解释为什么能够扫描到common模块的bean而扫描不到其他模块的bean。 根本原因 关键在于**类路径(Classpath)**的包含情况: Maven依赖…

Python镜像源配置:

1.用命令进行配置: 1. 使用命令行方式更改镜像源 可以直接通过 pip config 命令来设置全局或用户级别的镜像源地址。例如,使用清华大学开源软件镜像站作为新的索引 URL: pip config set global.index-url https://pypi.tuna.tsinghua.edu.…

【SpringBoot】Spring中事务的实现:声明式事务@Transactional、编程式事务

1. 准备工作 1.1 在MySQL数据库中创建相应的表 用户注册的例子进行演示事务操作,索引需要一个用户信息表 (1)创建数据库 -- 创建数据库 DROP DATABASE IF EXISTS trans_test; CREATE DATABASE trans_test DEFAULT CHARACTER SET utf8mb4;…

javascript 深拷贝和浅拷贝的区别及具体实现方案

一、核心区别 特性浅拷贝深拷贝复制层级仅复制对象的第一层属性递归复制对象的所有层级属性(包括嵌套对象和数组)引用关系嵌套对象/数组与原对象共享内存(引用拷贝)嵌套对象/数组与原对象完全独立(值拷贝)…

pytorch对应gpu版本是否可用判断逻辑

# gpu_is_ok.py import torchdef check_torch_gpu():# 打印PyTorch版本print(f"PyTorch version: {torch.__version__}")# 检查CUDA是否可用cuda_available torch.cuda.is_available()print(f"CUDA available: {cuda_available}")if cuda_available:# 打印…

国内无法访问GitHub官网的问题解决

作为一名程序员,在国内访问GitHub官网经常会遇到打开过慢或者访问失败的问题,但通过一些技巧可以改善访问体验。GitHub访问问题的根源在于GitHub官网访问不稳定的主要原因在于DNS解析过程。当我们直接访问github.com时,需要通过DNS服务器将域…

使用 MediaPipe 和 OpenCV 快速生成人脸掩膜(Face Mask)

在实际项目中,尤其是涉及人脸识别、换脸、图像修复等任务时,我们经常需要生成人脸区域的掩膜(mask)。这篇文章分享一个简单易用的小工具,利用 MediaPipe 和 OpenCV,快速提取人脸轮廓并生成二值掩膜图像。 …

【动态导通电阻】GaN功率器件中动态导通电阻退化的机制、表征及建模方法

2019年,浙江大学的Shu Yang等人在《IEEE Journal of Emerging and Selected Topics in Power Electronics》上发表了一篇关于GaN(氮化镓)功率器件动态导通电阻(Dynamic On-Resistance, RON)的研究论文。该文深入探讨了GaN功率器件中动态导通电阻退化的机制、表征方法、建模…