使用PyTorch进行热狗图像分类模型微调

本教程将演示如何使用PyTorch框架对预训练模型进行微调,实现热狗与非热狗图像的分类任务。我们将从数据准备开始,逐步完成数据加载、可视化等关键步骤。


1. 环境配置与库导入

%matplotlib inline
import os
import torch
from torch import nn
from d2l import torch as d2l
import torchvision

2. 热狗数据集准备

# 热狗数据集配置
d2l.DATA_HUB['hotdog'] = (d2l.DATA_URL + 'hotdog.zip','fba480ffa8aa7e0febbb511d181409f899b9baa5')# 下载并加载数据集
data_dir = d2l.download_extract('hotdog')
train_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train'))
test_imgs = torchvision.datasets.ImageFolder(os.path.join(data_dir, 'test'))

3. 数据可视化

# 可视化训练集样本
import matplotlib.pyplot as plt# 设置画布大小
plt.figure(figsize=(12, 8))# 绘制前16张图片
for i, (image, label) in enumerate(train_imgs[:16]):plt.subplot(4, 4, i+1)plt.imshow(image)plt.title('hotdog' if label == 0 else 'not hotdog')plt.axis('off')plt.tight_layout()
plt.show()

输出结果

array([<Axes: >, <Axes: >, <Axes: >, <Axes: >, <Axes: >, <Axes: >,<Axes: >, <Axes: >, <Axes: >, <Axes: >, <Axes: >, <Axes: >,<Axes: >, <Axes: >, <Axes: >, <Axes: >], dtype=object)

(实际运行时将显示4x4网格排列的16张图像,包含热狗和其他食品的图片) 

4.数据增强 

normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]
)train_augs = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ToTensor(),normalize
])test_augs = torchvision.transforms.Compose([torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),normalize
])

5.定义并修改预训练模型

# 使用预训练的ResNet18模型
pretrained_net = torchvision.models.resnet18(pretrained=True)
print(pretrained_net.fc)  # 最后一层全连接层查看

输出结果:

Linear(in_features=512, out_features=1000, bias=True)
# 修改最后一层,以适应我们二分类任务
finetune_net = torchvision.models.resnet18(pretrained=True)
finetune_net.fc = nn.Linear(finetune_net.fc.in_features, 2)
nn.init.xavier_uniform_(finetune_net.fc.weight)

6.微调模型

定义微调函数:

from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolderdef train_fine_tuning(net, lr, batch_size=128, num_epochs=5, param_group=True):train_iter = DataLoader(ImageFolder(os.path.join(data_dir,'train'), transform=train_augs),batch_size=batch_size,shuffle=True)test_iter = DataLoader(ImageFolder(os.path.join(data_dir,'test'), transform=test_augs),batch_size=batch_size,shuffle=False)devices = d2l.try_all_gpus()loss = nn.CrossEntropyLoss(reduction='mean')if param_group:params_lx = [param for name, param in net.named_parameters()if name not in ['fc.weight', 'fc.bias']]optim = torch.optim.SGD([{'params': params_lx},{'params': net.fc.parameters(), 'lr': lr * 10}], lr=lr, weight_decay=0.001)else:optim = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=0.001)d2l.train_ch13(net, train_iter, test_iter, loss, optim, num_epochs, devices)

使用小的学习率进行微调:

train_fine_tuning(finetune_net, 5e-5)

输出:

loss 0.006, train acc 0.606, test acc 0.599
18.3 examples/sec on [device(type='cuda', index=0)]

为了进行比较,所有模型参数初始化为随机值 

scratch_net=torchvision.models.resnet18() # 没有预训练参数
scratch_net.fc=nn.Linear(scratch_net.fc.in_features,2) # 修改最后一层全连接层,输出为2
train_fine_tuning(scratch_net,5e-4,param_group=False) # param_group=False使得所有层的参数都为默认的学习率   

输出:

loss 0.005, train acc 0.752, test acc 0.750
10.6 examples/sec on [device(type='cuda', index=0)]

7.总结

本文完整展示了从数据准备到模型训练的热狗分类任务流程。关键步骤包括:

  1. 使用torchvision加载和预处理图像数据

  2. 可视化数据集样本

  3. 构建数据加载管道

  4. 修改预训练模型进行微调

  5. 训练和评估分类模型

实际应用中可以通过调整数据增强策略、尝试不同网络架构、优化超参数等方式进一步提升模型性能。后续可以扩展为部署到移动端的食品识别应用。


注意事项

  1. 确保GPU环境加速训练

  2. 根据显存调整batch_size大小

  3. 适当调整学习率等超参数

  4. 添加早停机制防止过拟合

希望本教程能帮助您快速上手PyTorch模型微调任务!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/903972.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

内容中台与企业内容管理核心差异剖析

功能定位与架构设计差异 在企业数字化进程中&#xff0c;内容中台与企业内容管理&#xff08;ECM&#xff09;的核心差异首先体现在功能定位层面。传统ECM系统以文档存储、版本控制及权限管理为核心&#xff0c;主要服务于企业内部知识库的静态管理需求&#xff0c;例如通过Ba…

使用PyMongo连接MongoDB的基本操作

MongoDB是由C语言编写的非关系型数据库&#xff0c;是一个基于分布式文件存储的开源数据库系统&#xff0c;其内容存储形式类似JSON对象&#xff0c;它的字段值可以包含其他文档、数组及文档数组。在这一节中&#xff0c;我们就来回顾Python 3下MongoDB的存储操作。 常用命令:…

第 12 届蓝桥杯 C++ 青少组中 / 高级组省赛 2021 年真题

一、选择题 第 1 题 题目&#xff1a;下列符号中哪个在 C 中表示行注释 ( )。 A. ! B. # C. ] D. // 正确答案&#xff1a;D 答案解析&#xff1a; 在 C 中&#xff0c;//用于单行注释&#xff08;行注释&#xff09;&#xff0c;从//开始到行末的内容会被编译器忽略。选项 A…

【python】【UV】一篇文章学完新一代 Python 环境与包管理器使用指南

&#x1f40d; UV&#xff1a;新一代 Python 环境与包管理器使用指南 一、UV 是什么&#xff1f; UV 是由 Astral 团队开发的高性能 Python 环境管理器&#xff0c;旨在统一替代 pyenv、pip、venv、pip-tools、pipenv 等工具。 1.1 UV 的主要功能 &#x1f680; 极速包安装&…

前端性能优化2:结合HTTPS与最佳实践,全面优化你的网站性能

点亮极速体验&#xff1a;结合HTTPS与最佳实践&#xff0c;为你详解网站性能优化的道与术 在如今这个信息爆炸、用户耐心极其有限的数字时代&#xff0c;网站的性能早已不是一个可选项&#xff0c;而是关乎生存和发展的核心竞争力。一个迟缓的网站&#xff0c;无异于在数字世界…

JavaWeb:vueaxios

一、简介 什么是vue? 快速入门 <!-- 3.准备视图元素 --><div id"app"><!-- 6.数据渲染 --><h1>{{ msg }}</h1></div><script type"module">// 1.引入vueimport { createApp, ref } from https://unpkg.com/vu…

Tauri联合Vue开发中Vuex与Pinia关系及前景分析

在 TauriVue 的开发场景中&#xff0c;Vuex 和 Pinia 是两种不同的状态管理工具&#xff0c;它们的关系和前景可以从以下角度分析&#xff1a; 一、Vuex 与 Pinia 的关系 继承与发展 Pinia 最初是作为 Vuex 5 的提案设计的&#xff0c;其目标是简化 Vuex 的复杂性并更好地适配 …

Linux中的时间同步

一、时间同步服务扩展总结 1. 时间同步的重要性 多主机协作需求&#xff1a;在分布式系统、集群、微服务架构中&#xff0c;时间一致性是日志排序、事务顺序、数据一致性的基础。 安全协议依赖&#xff1a;TLS/SSL证书、Kerberos认证等依赖时间有效性&#xff0c;时间偏差可能…

【算法基础】三指针排序算法 - JAVA

一、基础概念 1.1 什么是三指针排序 三指针排序是一种特殊的分区排序算法&#xff0c;通过使用三个指针同时操作数组&#xff0c;将元素按照特定规则进行分类和排序。这种算法在处理包含有限种类值的数组时表现出色&#xff0c;最经典的应用是荷兰国旗问题&#xff08;Dutch …

《操作系统真象还原》第十二章(2)——进一步完善内核

文章目录 前言可变参数的原理实现系统调用write更新syscall.h更新syscall.c更新syscall-init.c 实现printf编写stdio.h编写stdio.c 第一次测试main.cmakefile结果截图 完善printf修改main.c 结语 前言 上部分链接&#xff1a;《操作系统真象还原》第十二章&#xff08;1&#…

ICML2021 | DeiT | 训练数据高效的图像 Transformer 与基于注意力的蒸馏

Training data-efficient image transformers & distillation through attention 摘要-Abstract引言-Introduction相关工作-Related Work视觉Transformer&#xff1a;概述-Vision transformer: overview通过注意力机制蒸馏-Distillation through attention实验-Experiments…

深度学习:AI 机器人时代

在科技飞速发展的当下&#xff0c;AI 机器人时代正以汹涌之势席卷而来&#xff0c;而深度学习作为其核心驱动力&#xff0c;正重塑着我们生活与工作的方方面面。 从智能工厂的自动化生产&#xff0c;到家庭中贴心服务的智能助手&#xff0c;再到复杂环境下执行特殊任务的专业机…

《告别试错式开发:TDD的精准质量锻造术》

深度解锁TDD&#xff1a;应用开发的创新密钥 在应用开发的复杂版图中&#xff0c;如何雕琢出高质量、高可靠性的应用&#xff0c;始终是开发者们不懈探索的核心命题。测试驱动开发&#xff08;TDD&#xff09;&#xff0c;作为一种颠覆性的开发理念与方法&#xff0c;正逐渐成…

应用层自定义协议序列与反序列化

目录 一、网络版计算器 二、网络版本计算器实现 2.1源代码 2.2测试结果 一、网络版计算器 应用层定义的协议&#xff1a; 应用层进行网络通信能否使用如下的协议进行通信呢&#xff1f; 在操作系统内核中是以这种协议进行通信的&#xff0c;但是在应用层禁止以这种协议进行…

Excel-CLI:终端中的轻量级Excel查看器

在数据驱动的今天&#xff0c;Excel 文件处理成为了我们日常工作中不可或缺的一部分。然而&#xff0c;频繁地在图形界面与命令行界面之间切换&#xff0c;不仅效率低下&#xff0c;而且容易出错。现在&#xff0c;有了 Excel-CLI&#xff0c;一款运行在终端中的轻量级Excel查看…

百度后端开发一面

mutex, rwmutex 在Go语言中&#xff0c;Mutex&#xff08;互斥锁&#xff09;和RWMutex&#xff08;读写锁&#xff09;是用于管理并发访问共享资源的核心工具。以下是它们的常见问题、使用场景及最佳实践总结&#xff1a; 1. Mutex 与 RWMutex 的区别 Mutex: 互斥锁&#xf…

STM32 IIC总线

目录 IIC协议简介 IIC总线系统结构 IIC总线物理层特点 IIC总线协议层 空闲状态 应答信号 数据的有效性 数据传输 STM32的IIC特性及架构 STM32的IIC结构体 0.96寸OLED显示屏 SSD1306框图及引脚定义 4针脚I2C接口模块原理图 字节传输-I2C 执行逻辑框图 命令表…

【unity游戏开发入门到精通——UGUI】整体控制一个UGUI面板的淡入淡出——CanvasGroup画布组组件的使用

注意&#xff1a;考虑到UGUI的内容比较多&#xff0c;我将UGUI的内容分开&#xff0c;并全部整合放在【unity游戏开发——UGUI】专栏里&#xff0c;感兴趣的小伙伴可以前往逐一查看学习。 文章目录 前言CanvasGroup画布组组件参数 实战专栏推荐完结 前言 如果我们想要整体控制…

大型语言模型个性化助手实现

大型语言模型个性化助手实现 目录 大型语言模型个性化助手实现PERSONAMEM,以及用户资料和对话模拟管道7种原位用户查询类型关于大语言模型个性化能力评估的研究大型语言模型(LLMs)已经成为用户在各种任务中的个性化助手,从提供写作支持到提供量身定制的建议或咨询。随着时间…

生成式 AI 的未来

在人类文明的长河中,技术革命始终是推动社会跃迁的核心引擎。从蒸汽机解放双手,到电力点亮黑夜,再到互联网编织全球神经网络,每一次技术浪潮都在重塑人类的生产方式与认知边界。而今天,生成式人工智能(Generative AI)正以一种前所未有的姿态登上历史舞台——它不再局限于…