深度学习中保存最优模型的实践与探索:以食物图像分类为例

深度学习中保存最优模型的实践与探索:以食物图像分类为例

在深度学习的模型训练过程中,训练一个性能良好的模型往往需要耗费大量的时间和计算资源。而保存最优模型不仅可以避免重复训练,还能方便后续使用和部署。本文将结合食物图像分类的代码实例,深入探讨如何在深度学习项目中保存最优模型,以及不同保存方式的特点和适用场景。

一、保存最优模型的重要性

在深度学习模型的训练过程中,随着训练轮次(epoch)的增加,模型的性能(如准确率、损失值等指标)会不断变化。由于数据分布、模型复杂度、超参数设置等多种因素的影响,模型并非在训练的最后一轮就能达到最佳性能。因此,我们需要一种机制来记录和保存模型在训练过程中表现最优的状态,以便后续在实际应用中使用该模型进行预测和推理。

保存最优模型可以帮助我们:

  1. 避免过度训练:如果不保存最优模型,模型可能会在后续训练过程中出现过拟合,导致性能下降。通过保存最优模型,我们可以确保使用的是在验证集或测试集上表现最好的模型版本。
  2. 提高开发效率:在后续的项目迭代、模型优化或实际部署中,直接使用保存的最优模型,无需重新训练,节省大量的时间和计算资源。
  3. 便于模型评估与比较:保存不同训练阶段或不同参数设置下的最优模型,有助于我们对比分析模型的性能差异,为进一步优化模型提供依据。

二、代码实例解析

在食物图像分类的代码中,我们通过以下方式实现了最优模型的保存:

def test(dataloader,model,loss_fn):best_acc = 0size=len(dataloader.dataset)num_batches=len(dataloader)  #打包的数量model.eval()  #测试,w就不能再更新。test_loss,correct=0,0with torch.no_grad():    #一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。for X,y in dataloader:X,y=X.to(device),y.to(device)pred=model.forward(X)test_loss+=loss_fn(pred,y).item()   #test_loss是会自动累加每一个批次的损失值correct+=(pred.argmax(1)==y).type(torch.float).sum().item()a=(pred.argmax(1)==y)   #dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值b=(pred.argmax(1)==y).type(torch.float)result = zip(pred.argmax(1).tolist(), y.tolist())for i in result:print(f"当前测试的结果为:{food_type[i[0]]}\t,当前真实的结果为:{food_type[i[1]]}")test_loss /=num_batchescorrect /=sizeprint(f'Test result: \n Accuracy: {(100*correct)}%, Avg loss: {test_loss}')
#保存最优模型的两种方法:if correct > best_acc:best_acc=correct# 1、保存模型参数方法: torch.save(model.state_dict(),path)print(model.state_dict().keys())torch.save(model.state_dict(),'best.pth')# # 2、保存完整模型(w, b, 模型cnn),#     torch.save(model, 'best1.pt')

在上述代码的test函数中,我们定义了一个变量best_acc用于记录当前最优的准确率。每次进行测试时,计算模型在测试集上的准确率correct,并与best_acc进行比较。如果当前的准确率高于best_acc,则更新best_acc,并保存当前的模型。

1. 保存模型参数

torch.save(model.state_dict(), 'best.pth')
这种方式只保存模型的参数(如卷积层的权重、偏置,全连接层的权重等),而不保存模型的结构。保存的文件格式通常为.pth.pt。其优点在于:

  • 文件体积小:由于只保存参数,文件大小相对较小,便于存储和传输。
  • 灵活性高:在加载模型时,我们可以先定义相同结构的模型,然后将保存的参数加载到模型中。这样可以方便地在不同的代码环境或项目中使用,只要模型结构一致即可。

例如,在加载保存的模型参数时,可以使用以下代码:

model = CNN()  # 定义与训练时相同结构的模型
model.load_state_dict(torch.load('best.pth'))  # 加载模型参数
model.eval()  # 将模型设置为评估模式

2. 保存完整模型

torch.save(model, 'best1.pt')
这种方式会将整个模型(包括模型的结构和参数)一起保存。其优点是加载模型时非常方便,无需重新定义模型结构,直接加载即可使用:

loaded_model = torch.load('best1.pt')  # 直接加载完整模型
loaded_model.eval()  # 设置为评估模式

然而,这种方式也存在一些缺点:

  • 文件体积大:由于包含了模型结构和参数,文件大小通常比只保存参数的方式大很多。
  • 代码依赖性强:如果代码中的模型定义发生了变化,可能会导致加载失败。因为保存的模型是基于特定的代码结构保存的。

三、选择合适的保存方式

在实际应用中,我们需要根据具体需求来选择合适的模型保存方式:

  • 如果注重文件大小和灵活性:例如在模型部署到资源受限的设备(如嵌入式设备),或者需要在不同代码库中共享模型参数时,保存模型参数的方式更为合适。
  • 如果追求加载的便捷性:在快速测试、演示或者代码结构相对固定的项目中,保存完整模型的方式可以简化加载过程,提高开发效率。

四、总结

保存最优模型是深度学习项目中不可或缺的环节。通过合理选择保存方式,我们可以更好地管理和利用训练好的模型。在食物图像分类的案例中,我们展示了两种常见的保存模型的方法,并分析了它们的优缺点和适用场景。在实际项目开发中,开发者应根据具体情况灵活运用这些方法,确保模型能够在后续的应用中发挥最佳性能,推动深度学习技术在各个领域的落地与发展。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/78492.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

护理岗位技能比赛主持稿串词

男:尊敬的各位老师 女:亲爱的各位同学 合:大家下午好。 男:在这鸟语花香,诗意盎然的季节里 女:在这阳光灿烂,激情似火的日子里 合:我们欢聚一堂,共同庆祝五一二国际护士节…

【翻译、转载】MCP 核心架构

核心架构 了解 MCP 如何连接客户端、服务器和 LLM 模型上下文协议 (MCP) 构建在一个灵活、可扩展的架构之上,能够实现 LLM 应用程序与集成之间的无缝通信。本文档涵盖了核心的架构组件和概念。 概述 MCP 遵循客户端-服务器 (client-server) 架构,其中…

Python 数据智能实战 (11):LLM如何解决模型可解释性

写在前面 —— 不只知其然,更要知其所以然:借助 LLM,揭开复杂模型决策的神秘面纱 在前面的篇章中,我们学习了如何利用 LLM 赋能用户分群、购物篮分析、流失预测以及个性化内容生成。我们看到了 LLM 在理解数据、生成特征、提升模型效果和自动化内容方面的巨大潜力。 然而…

Linux:进程优先级及环境

一:孤儿进程 在Linux系统中,当一个进程创建了子进程后,如果父进程执行完毕或者提前退出而子进程还在运行,那么子进程就会成为孤儿进程。子进程就会被systemd(系统)进程收养,其pid为1 myproces…

Java大厂面试:Java技术栈中的核心知识点

Java技术栈中的核心知识点 第一轮提问:基础概念与原理 技术总监:郑薪苦,你对JVM内存模型了解多少?能简单说说吗?郑薪苦:嗯……我记得JVM有堆、栈、方法区这些区域,堆是存放对象的地方&#xf…

CF1000E We Need More Bosses

CF1000E We Need More Bosses 题目描述 题目大意: 给定一个 n n n 个点 m m m 条边的无向图,保证图连通。找到两个点 s , t s,t s,t,使得 s s s到 t t t必须经过的边最多(一条边无论走哪条路线都经过ta,这条边就是…

imx6uLL应用-v4l2

Linux V4L2 视频采集 JPEG 解码 LCD 显示实践 本文记录一个完整的嵌入式视频处理项目:使用 V4L2 接口从摄像头采集 MJPEG 图像,使用 libjpeg 解码为 RGB 格式,并通过 framebuffer 显示在 LCD 屏幕上。适用于使用 ARM Cortex-A 系列开发板进…

强化学习机器人模拟器——QAgent:一个支持多种强化学习算法的 Python 实现

QAgent 是一个灵活的 Python 类,专为实现经典的强化学习(Reinforcement Learning, RL)算法而设计,支持 Q-learning、SARSA 和 SARSA(λ) 三种算法。本篇博客将基于提供的 q_agent.py 代码,详细介绍 QAgent 类的功能、结构和使用方法,帮助您理解其在强化学习任务中的应用,…

Feign的原理

为什么 SpringCloud 中的Feign,可以帮助我们像使用本地接口一样调用远程 HTTP服务? Feign底层是如何实现的?这篇文章,我们一起来聊一聊。 1. Feign 的基本原理 Feign 的核心思想是通过接口和注解定义 HTTP 请求,将接…

探索正态分布:交互式实验带你体验统计之美

探索正态分布:交互式实验带你体验统计之美 正态分布,这条优美的钟形曲线,可以说是统计学中最重要、最无处不在的概率分布。从自然现象(如身高、测量误差)到金融市场,再到机器学习,它的身影随处…

使用 IDEA + Maven 搭建传统 Spring MVC 项目的详细步骤(非Spring Boot)

搭建Spring MVC项目 第一步:创建Maven项目第二步:配置pom.xml第三步:配置web.xml第四步:创建Spring配置文件第五步:创建控制器第六步:创建JSP视图第七步:配置Tomcat并运行目录结构常见问题解决与…

AI日报 · 2025年5月04日|Hugging Face 启动 MCP 全球创新挑战赛

1、Hugging Face 启动 MCP 全球创新挑战赛 Hugging Face 于 5 月 3 日发布 MCP Global Innovation Challenge,面向全球开发者征集基于模型上下文协议(MCP)的创新工具与应用,赛事持续至 5 月 31 日,设立多档…

学习spring boot-拦截器Interceptor,过滤器Filter

目录 拦截器Interceptor 过滤器Filter 关于过滤器的前置知识可以参考: 过滤器在springboot项目的应用 一,使用WebfilterServletComponentScan 注解 1 创建过滤器类实现Filter接口 2 在启动类中添加 ServletComponentScan 注解 二,创建…

汇编常用语法

GNU汇编语句: [lable:] instruction [comment] lable 表示标号,表示地址位置,可选. instruction即指令,也就是汇编指令或伪指令。 comment 就是注释内容。 用户使用.section 伪操作来定义一个段,汇编系统预定义了一些…

terraform resource创建了5台阿里云ecs,如要使用terraform删除其中一台主机,如何删除?

在 Terraform 中删除阿里云 5 台 ECS 实例中的某一台,具体操作取决于你创建资源时使用的 多实例管理方式(count 或 for_each)。以下是详细解决方案: 方法一:使用 for_each(推荐) 如果创建时使…

pycharm terminal 窗口打不开了

参考添加链接描述powershell.exe改为cmd.exe发现有一个小正方形,最大化可以看见了。

百度「心响」:左手“多智能体”右手“保姆级服务”,C端用户能看懂这技术告白吗?

——当技术名词撞上“傻瓜式”需求,谁是赢家? 「多智能体」是什么?用户:不重要,能一键搞定就行 百度最新推出的多智能体平台“心响”,号称能用自然语言交互一键托管复杂任务。 从旅游攻略到法律咨询&#x…

57认知干货:AI机器人产业

机器人本质上由可移动的方式和可交互万物的机构组成,即适应不同环境下不同场景的情况,机器人能够做到根据需求调整交互机构和移动方式。因此,随着人工智能技术的发展,AI机器人的产业也将在未来逐步从单一任务的执行者,发展为能够完成复杂多样任务的智能体。 在未来的社会…

在两个bean之间进行数据传递的解决方案

简介 在日常开发中,在两个bean之间进行数据传递是常见的操作,例如在日常开发中,将数据从VO类转移到DO类等。在两个bean之间进行数据传递,最常见的解决方案,就是手动复制,但是它比较繁琐,充斥着…

基于开闭原则优化数据库查询语句拼接方法

背景 在开发实践中,曾有同事在实现新功能时,因直接修改一段数据库查询条件拼接方法的代码逻辑,导致生产环境出现故障。 具体来看,该方法通过在函数内部直接编写条件判断语句实现查询拼接,尽管从面向对象设计的开闭原…