基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试

基于 PyTorch 的树叶分类任务:从数据准备到模型训练与测试


1. 引言

在计算机视觉领域,图像分类是一个经典的任务。本文将详细介绍如何使用 PyTorch 实现一个树叶分类任务。我们将从数据准备开始,逐步构建模型、训练模型,并在测试集上进行预测,最终生成提交文件。
在这里插入图片描述


2. 环境准备

首先,确保已安装以下 Python 库:

pip install torch torchvision pandas d2l
  • torch:PyTorch 核心库。
  • torchvision:提供计算机视觉相关的工具。
  • pandas:用于处理 CSV 文件。
  • d2l:深度学习工具库,提供辅助函数。

3. 数据准备

竞赛链接:https://www.kaggle.com/competitions/classify-leaves/leaderboard?tab=public

3.1 数据集结构

假设数据集位于 classify-leaves 目录下,包含以下文件:

classify-leaves/
├── train.csv
├── test.csv
├── images/├── image1.jpg├── image2.jpg...
  • train.csv:包含训练图像的路径和标签。
  • test.csv:包含测试图像的路径。

3.2 数据加载与预处理

import os
import pandas as pd
import randomimgpath = "classify-leaves"
trainlist = pd.read_csv(f"{imgpath}/train.csv")
num2name = list(trainlist["label"].value_counts().index)
random.shuffle(num2name)
name2num = {}
for i in range(len(num2name)):name2num[num2name[i]] = i
  • num2name:获取所有类别标签,并按类别数量排序。
  • name2num:将类别名称映射到数字编号。

4. 自定义数据集类

为了加载数据,我们需要定义一个自定义数据集类 Leaf_data

from torch.utils.data import Dataset
from d2l import torch as d2lclass Leaf_data(Dataset):def __init__(self, path, train, transform=lambda x: x):super().__init__()self.path = pathself.transform = transformself.train = trainif train:self.datalist = pd.read_csv(f"{path}/train.csv")else:self.datalist = pd.read_csv(f"{path}/test.csv")def __getitem__(self, index):res = ()tmplist = self.datalist.iloc[index, :]for i in tmplist.index:if i == "image":res += (self.transform(d2l.Image.open(f"{self.path}/{tmplist[i]}")),)else:res += (name2num[tmplist[i]],)if len(res) < 2:res += (tmplist[i],)return resdef __len__(self):return len(self.datalist)
  • __getitem__:根据索引返回一个样本,包括图像和标签。
  • __len__:返回数据集的长度。

5. 模型定义与初始化

我们使用预训练的 ResNet34 模型,并修改最后一层以适应分类任务:

import torch
import torchvision
from torch import nndef init_weight(m):if type(m) in [nn.Linear, nn.Conv2d]:nn.init.xavier_normal_(m.weight)net = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)
net.fc = nn.Linear(in_features=512, out_features=len(name2num), bias=True)
net.fc.apply(init_weight)
net.to(try_gpu())
  • init_weight:使用 Xavier 初始化方法初始化全连接层的权重。
  • net:加载预训练的 ResNet34 模型,并修改最后一层全连接层。

6. 训练过程

6.1 优化器与损失函数

lr = 1e-4
parames = [parame for name, parame in net.named_parameters() if name not in ["fc.weight", "fc.bias"]]
trainer = torch.optim.Adam([{"params": parames}, {"params": net.fc.parameters(), "lr": lr * 10}], lr=lr)
LR_con = torch.optim.lr_scheduler.CosineAnnealingLR(trainer, 1, 0)
loss = nn.CrossEntropyLoss(reduction='none')
  • trainer:使用 Adam 优化器,对全连接层使用更高的学习率。
  • LR_con:使用余弦退火学习率调度器。
  • loss:使用交叉熵损失函数。

6.2 训练函数


def train_batch(features, labels, net, loss, trainer, device):# 将数据移动到指定设备(如 GPU)features, labels = features.to(device), labels.to(device)# 前向传播outputs = net(features)l = loss(outputs, labels).mean()  # 计算损失# 反向传播和优化trainer.zero_grad()  # 梯度清零l.backward()         # 反向传播trainer.step()      # 更新参数# 计算准确率acc = (outputs.argmax(dim=1) == labels).float().mean()return l.item(), acc.item()def train(train_data, test_data, net, loss, trainer, num_epochs, device=try_gpu()):best_acc = 0timer = d2l.Timer()plot = d2l.Animator(xlabel="epoch", xlim=[1, num_epochs], legend=['train loss', 'train acc', 'test loss'], ylim=[0, 1])for epoch in range(num_epochs):metric = d2l.Accumulator(4)for i, (features, labels) in enumerate(train_data):timer.start()l, acc = train_batch(features, labels, net, loss, trainer, device)metric.add(l, acc, labels.shape[0], labels.numel())timer.stop()test_acc = d2l.evaluate_accuracy_gpu(net, test_data, device=device)if test_acc > best_acc:save_model(net)best_acc = test_accplot.add(epoch + 1, (metric[0] / metric[2], metric[1] / metric[3], test_acc))print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')print(f"best acc {best_acc}")return metric[0] / metric[2], metric[1] / metric[3], test_acc
  • train:训练模型,记录损失和准确率,并在验证集上评估模型。

7. 测试与结果保存

在测试集上进行预测,并保存结果到 CSV 文件:

net.load_state_dict(torch.load(model_path))
augs = torchvision.transforms.Compose([torchvision.transforms.Resize(224),torchvision.transforms.ToTensor(), norm
])
test_data = Leaf_data(imgpath, False, augs)
test_dataloader = Data.DataLoader(test_data, batch_size=64, shuffle=False)
res = pd.DataFrame(columns=["image", "label"], index=range(len(test_data)))
net = net.cpu()
count = 0
for X, y in test_dataloader:preds = net(X).detach().argmax(dim=-1).numpy()preds = pd.DataFrame(y, index=map(lambda x: num2name[x], preds))preds.loc[:, 1] = preds.indexpreds.index = range(count, count + len(y))res.iloc[preds.index] = predscount += len(y)print(f"loaded {count}/{len(test_data)} datas")
res.to_csv('./submission.csv', index=False)
  • test_dataloader:加载测试数据。
  • res:保存预测结果到 CSV 文件。

8. 总结

本文详细介绍了如何使用 PyTorch 实现一个树叶分类任务,包括数据准备、模型定义、训练、验证和测试。通过本文,您可以掌握以下技能:

  1. 自定义数据集类的实现。
  2. 使用预训练模型进行迁移学习。
  3. 训练模型并保存最佳模型。
  4. 在测试集上进行预测并生成提交文件。

希望本文对您有所帮助!如果有任何问题,欢迎在评论区留言讨论。😊

完整代码

import os
import torch
from torch.utils import data as Data
import torchvision
from torch import nn
from d2l import torch as d2l
import pandas as pd
import random# 数据准备
imgpath = "classify-leaves"
trainlist = pd.read_csv(f"{imgpath}/train.csv")
num2name = list(trainlist["label"].value_counts().index)
random.shuffle(num2name)
name2num = {}
for i in range(len(num2name)):name2num[num2name[i]] = i# GPU 检查
def try_gpu():if torch.cuda.device_count() > 0:return torch.device('cuda')return torch.device('cpu')# 模型保存路径
model_dir = './models'
if not os.path.exists(model_dir):os.makedirs(model_dir)
model_path = os.path.join(model_dir, 'pre_res_model.ckpt')def save_model(net):torch.save(net.state_dict(), model_path)# 自定义数据集类
class Leaf_data(Data.Dataset):def __init__(self, path, train, transform=lambda x: x):super().__init__()self.path = pathself.transform = transformself.train = trainif train:self.datalist = pd.read_csv(f"{path}/train.csv")else:self.datalist = pd.read_csv(f"{path}/test.csv")def __getitem__(self, index):res = ()tmplist = self.datalist.iloc[index, :]for i in tmplist.index:if i == "image":res += (self.transform(d2l.Image.open(f"{self.path}/{tmplist[i]}")),)else:res += (name2num[tmplist[i]],)if len(res) < 2:res += (tmplist[i],)return resdef __len__(self):return len(self.datalist)def train_batch(features, labels, net, loss, trainer, device):# 将数据移动到指定设备(如 GPU)features, labels = features.to(device), labels.to(device)# 前向传播outputs = net(features)l = loss(outputs, labels).mean()  # 计算损失# 反向传播和优化trainer.zero_grad()  # 梯度清零l.backward()         # 反向传播trainer.step()      # 更新参数# 计算准确率acc = (outputs.argmax(dim=1) == labels).float().mean()return l.item(), acc.item()# 训练函数
def train(train_data, test_data, net, loss, trainer, num_epochs, device=try_gpu()):best_acc = 0timer = d2l.Timer()plot = d2l.Animator(xlabel="epoch", xlim=[1, num_epochs], legend=['train loss', 'train acc', 'test loss'], ylim=[0, 1])for epoch in range(num_epochs):metric = d2l.Accumulator(4)for i, (features, labels) in enumerate(train_data):timer.start()l, acc = train_batch(features, labels, net, loss, trainer, device)metric.add(l, acc, labels.shape[0], labels.numel())timer.stop()test_acc = d2l.evaluate_accuracy_gpu(net, test_data, device=device)if test_acc > best_acc:save_model(net)best_acc = test_accplot.add(epoch + 1, (metric[0] / metric[2], metric[1] / metric[3], test_acc))print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'loss {metric[0] / metric[2]:.3f}, train acc {metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on {str(device)}')print(f"best acc {best_acc}")return metric[0] / metric[2], metric[1] / metric[3], test_acc# 模型初始化
def init_weight(m):if type(m) in [nn.Linear, nn.Conv2d]:nn.init.xavier_normal_(m.weight)net = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)
net.fc = nn.Linear(in_features=512, out_features=len(name2num), bias=True)
net.fc.apply(init_weight)
net.to(try_gpu())# 优化器和损失函数
lr = 1e-4
parames = [parame for name, parame in net.named_parameters() if name not in ["fc.weight", "fc.bias"]]
trainer = torch.optim.Adam([{"params": parames}, {"params": net.fc.parameters(), "lr": lr * 10}], lr=lr)
LR_con = torch.optim.lr_scheduler.CosineAnnealingLR(trainer, 1, 0)
loss = nn.CrossEntropyLoss(reduction='none')# 数据增强和数据加载
batch = 64
num_epochs = 10
norm = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
augs = torchvision.transforms.Compose([torchvision.transforms.Resize(224),torchvision.transforms.RandomHorizontalFlip(p=0.5),torchvision.transforms.ToTensor(), norm
])
train_data, valid_data = Data.random_split(dataset=Leaf_data(imgpath, True, augs),lengths=[0.8, 0.2]
)
train_dataloder = Data.DataLoader(train_data, batch, True)
valid_dataloder = Data.DataLoader(valid_data, batch, True)# 训练模型
train(train_dataloder, valid_dataloder, net, loss, trainer, num_epochs)# 测试模型
net.load_state_dict(torch.load(model_path))
augs = torchvision.transforms.Compose([torchvision.transforms.Resize(224),torchvision.transforms.ToTensor(), norm
])
test_data = Leaf_data(imgpath, False, augs)
test_dataloader = Data.DataLoader(test_data, batch_size=64, shuffle=False)
res = pd.DataFrame(columns=["image", "label"], index=range(len(test_data)))
net = net.cpu()
count = 0
for X, y in test_dataloader:preds = net(X).detach().argmax(dim=-1).numpy()preds = pd.DataFrame(y, index=map(lambda x: num2name[x], preds))preds.loc[:, 1] = preds.indexpreds.index = range(count, count + len(y))res.iloc[preds.index] = predscount += len(y)print(f"loaded {count}/{len(test_data)} datas")
res.to_csv('./submission.csv', index=False)

参考链接:

  • PyTorch 官方文档
  • torchvision 官方文档
  • d2l 深度学习工具库

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/69757.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

北斗导航 | 基于多假设解分离(MHSS)模型的双星故障监测算法(MATLAB代码实现——ARAIM)

===================================================== github:https://github.com/MichaelBeechan CSDN:https://blog.csdn.net/u011344545 ===================================================== 双星故障监测算法 一、多星故障MHSS模型流程1、数据预处理2、构建假设模…

pytest测试专题 - 1.2 如何获得美观的测试报告

<< 返回目录 1 pytest测试专题 - 1.2 如何获得美观的测试报告 1.1 背景 虽然pytest命令的报文很详细&#xff0c;用例在执行调试时还算比较方便阅读和提取失败信息&#xff0c; 但对于大量测试用例运行时&#xff0c;可能会存在以下不足 报文被冲掉测试日志没法归档 …

压缩stl文件大小

1、MeshLab下载界面&#xff0c;从MeshLab下载适合自己系统的最新版本。 2、打开 MeshLab软件&#xff0c;将stl文件拖入其中。 3、 4、Percentage reduction参数即为缩放比例&#xff0c;根据自身想要将文件压缩到多大来。 然后点击apply 5、CtrlE弹出窗口保存文件后&…

FPGA 28 ,基于 Vivado Verilog 的呼吸灯效果设计与实现( 使用 Vivado Verilog 实现呼吸灯效果 )

目录 前言 一. 设计流程 1.1 需求分析 1.2 方案设计 1.3 PWM解析 二. 实现流程 2.1 确定时间单位和精度 2.2 定义参数和寄存器 2.3 实现计数器逻辑 2.4 控制 LED 状态 三. 整体流程 3.1 全部代码 3.2 代码逻辑 1. 参数定义 2. 分级计数 3. 状态切换 4. LED 输…

百度舆情优化:百度下拉框中的负面如何清除?

百度的下拉词、相关搜索、大家还在搜有负面词条&#xff0c;一直是企业公关经理头疼的问题&#xff0c;小马识途营销顾问深耕网络营销领域十几年&#xff0c;对百度SEO优化、百度下拉框、百度相关搜索、自媒体营销、短视频营销等等技巧方面积累了一定的方法和技巧。 对于百度下…

【云安全】云原生-K8S- API Server 未授权访问

API Server 是 Kubernetes 集群的核心管理接口&#xff0c;所有资源请求和操作都通过 kube-apiserver 提供的 API 进行处理。默认情况下&#xff0c;API Server 会监听两个端口&#xff1a;8080 和 6443。如果配置不当&#xff0c;可能会导致未授权访问的安全风险。 8080 端口…

大模型Agent开发框架概览

一、低代码框架 无需代码即可完成Agent开发热门框架&#xff1a;Coze、Dify、langFlow 二、基础框架 借助大模型原生能力进行Agent开发function calling、tools use 三、代码框架 借助代码完成Agent开发热门框架&#xff1a;LangChain、LangGraph、LIamaIndex 四、Multi-…

对前端的技术进行分层

前端相比较后端而言&#xff0c;由于其发展历史和浏览器的标准不一&#xff0c;导致其看上去简单&#xff0c;但是深入起来又很复杂&#xff0c;在最开始学习的时候&#xff0c;我们往往是了解一下三剑客和vue、react的api就开始上手工作了&#xff0c;但是到后面会发现&#x…

DeepSeek的大模型介绍

文章目录 DeepSeek是什么DeepSeek平台使用DeepSeek的使用场景DeepSeek的本地部署 DeepSeek是什么 DeepSeek是一家2023/7月年成立的人工智能公司&#xff0c;致力于开发高效、高性能的生成式AI模型&#xff0c;在短短一年多的时间里推出了多款强大的开源模型&#xff0c;包括De…

yolov8涨点系列之多头自注意力引入与FasterNet融合生成新模块

文章目录 多头自注意力介绍原理特点yolov8增加MultiHeadSelfAttention具体步骤融合新模块代码(1)在_init_.py+__conv.py文件的__all__内添加‘MultiHeadSelfAttention’(2)conv.py文件复制粘贴新模块代码MultiHeadSelfAttentionFasterNetBlockFasterNetBlockWithSelfAttention代…

问卷数据分析|SPSS实操之单因素方差分析

适用条件&#xff1a; 检验分类变量和定量变量之间的差异 分类变量数量要大于等于三 具体操作&#xff1a; 1.选择分析--比较平均值--单因素ANOVA检验 2. 下方填分类变量&#xff0c;上方为各个量表数据Z1-Y2 3. 点击选项&#xff0c;选择描述和方差齐性检验 4.此处为结果数…

全排列II(力扣47)

这道题与全排列(力扣46)-CSDN博客 的不同就在于集合中有相同元素&#xff0c;我们唯一多的操作就是在同一层递归中也要去重&#xff0c;其他的都与上一题相同。大家可以结合我下面的代码及详细注释理解此题。 代码及详细注释如下&#xff1a; class Solution { public:vector…

信息收集-主机服务器系统识别IP资产反查技术端口扫描协议探针角色定性

知识点&#xff1a; 1、信息收集-服务器系统-操作系统&IP资产 2、信息收集-服务器系统-端口扫描&服务定性 一、演示案例-应用服务器-操作系统&IP资产 操作系统 1、Web大小写(windows不区分大小写&#xff0c;linux区分大小写) 2、端口服务特征(22就是linux上的服…

vmware安装win7

1、版本说明 vmware workstation 16 win7 X64 2、安装步骤 安装步骤有点独特&#xff0c;先配置虚拟机&#xff0c;然后再虚拟机的虚拟光驱里添加下载的win7。 配置完了之后&#xff0c;点击要运行的虚拟机&#xff0c;然后一直往下走就可以完成系统的安装。 3、配置系统以解…

【C++学习笔记】if 和 if constexpr

背景 在工作中&#xff0c;在一个模版函数里&#xff0c;需要判断 if (std::is_same<T, float>) 来选择走哪个分支&#xff0c;分支里的函数是只能处理相应的类型的&#xff0c;编译过程中产生了报错。 解释 if (std::is_same<T, float>::value)和if constexpr …

使用 Express 写接口

在现代 Web 开发中&#xff0c;构建高效的 RESTful API 是非常重要的。Node.js 和其上的 Express 框架为开发者提供了一种简便而强大的方式来创建这些接口。本文将详细介绍如何使用 Express 来编写和部署一个简单的 RESTful API&#xff0c;涵盖从安装到实现增删改查&#xff0…

【ThreeJS Basics 1-3】Hello ThreeJS,实现第一个场景

文章目录 环境创建一个项目安装依赖基础 Web 页面概念解释编写代码运行项目 环境 我的环境是 node version 22 创建一个项目 首先&#xff0c;新建一个空的文件夹&#xff0c;然后 npm init -y , 此时会快速生成好默认的 package.json 安装依赖 在新建的项目下用 npm 安装依…

Linux下的进程切换与调度

目录 1.进程的优先级 优先级是什么 Linux下优先级的具体做法 优先级的调整为什么要受限 2.Linux下的进程切换 3.Linux下进程的调度 1.进程的优先级 我们在使用计算机的时候&#xff0c;通常会启动多个程序&#xff0c;这些程序最后都会变成进程&#xff0c;但是我们的硬…

使用 EMQX 接入 LwM2M 协议设备

LwM2M 协议介绍 LwM2M 是一种轻量级的物联网设备管理协议&#xff0c;由 OMA&#xff08;Open Mobile Alliance&#xff09;组织制定。它基于 CoAP &#xff08;Constrained Application Protocol&#xff09;协议&#xff0c;专门针对资源受限的物联网设备设计&#xff0c;例…

2024年12月中国电子学会青少年软件编程(Python)等级考试试卷(五级)

青少年软件编程&#xff08;Python&#xff09;等级考试试卷&#xff08;五级&#xff09; 一、单选题(共25题&#xff0c;共50分) 1.已知x[3,5,7]&#xff0c;那么执行语句x[len(x):][1,2]后&#xff0c;x的值?(A) A. [3,5,7,1,2] B. [1,2,3,5,7] C. [3,5,7] D. [1,2] 2.以下…