使用pytorch保存和加载预训练的模型方法

需要使用到的函数

在 PyTorch 中,torch.save()torch.load() 是用于保存和加载模型的核心函数。

torch.save() 函数

  • 主要用途:将模型或模型的状态字典(state_dict)保存到文件中。

  • 语法

torch.save(obj, f, pickle_module=pickle, pickle_protocol=None, _use_new_zipfile_serialization=True)
  • obj: 要保存的对象,可以是整个模型(nn.Module)或模型的状态字典(state_dict)。

  • f: 保存文件的路径。可以是一个字符串路径(如 'model.pth''model.pkl')或一个打开的文件对象。

  • pickle_module: 默认是 pickle,用于序列化对象。你可以使用其他兼容的序列化模块。

  • pickle_protocol: pickle 协议版本。默认值为 None,表示使用最高可用协议版本。

  • _use_new_zipfile_serialization: 默认值为 True,控制是否使用新的序列化格式(推荐使用)。

# 保存整个模型
torch.save(model, 'model.pth')# 保存模型的状态字典
torch.save(model.state_dict(), 'model_state_dict.pth')

torch.load() 函数

  • 主要用途:从文件中加载保存的模型或模型的状态字典。

  • 语法

torch.load(f, map_location=None, pickle_module=pickle)
  • f: 要加载的文件路径。可以是一个字符串路径或一个打开的文件对象。

  • map_location: 控制如何将存储位置映射到当前设备。例如,map_location='cuda:0' 表示将模型加载到 GPU 上。

  • pickle_module: 默认是 pickle,用于反序列化对象。

# 加载整个模型
model = torch.load('model.pth', map_location='cpu')  # 加载到 CPU# 加载模型的状态字典
model_state_dict = torch.load('model_state_dict.pth', map_location='cuda:0')  # 加载到 GPU

加载状态字典到模型

  • 加载状态字典后,通常需要将其加载到一个已经实例化的模型中。可以使用 model.load_state_dict() 方法:

  • 语法

model.load_state_dict(state_dict, strict=True)
  • state_dict: 从文件中加载的模型状态字典。

  • strict: 默认为 True,表示严格加载状态字典中的所有键。如果设置为 False,可以忽略不匹配的键。

# 实例化模型
model = SimpleModel()# 加载状态字典
model_state_dict = torch.load('model_state_dict.pth')# 将状态字典加载到模型中
model.load_state_dict(model_state_dict)

 注意事项

  • 设备映射:使用 torch.load() 时,可以指定 map_location 参数来控制模型加载到的设备(如 CPU 或 GPU)。

  • 自定义类:保存和加载整个模型时,需要确保自定义的模型类在加载代码中已经定义,否则会报错。

  • 兼容性torch.save()torch.load() 使用 pickle 序列化,可能会受到 Python 版本和 PyTorch 版本的影响。建议使用相同版本的 PyTorch 和 Python 进行保存和加载。

  • 推荐使用状态字典:保存和加载状态字典(state_dict)比保存整个模型更灵活和可移植。这样可以避免保存自定义类的依赖关系。

通过以上方法,你可以灵活地保存和加载 PyTorch 模型,无论是 .pth 还是 .pkl 格式,都可以根据需要选择合适的保存方式。

保存和读取.pth格式的预训练模型

保存

import torch
import torch.nn as nn# 定义一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):x = self.fc(x)return x# 创建模型实例
model = SimpleModel()# 假设已经训练了模型,这里只是演示保存
# 保存整个模型
torch.save(model, 'model.pth')
# 或者只保存模型的参数
torch.save(model.state_dict(), 'model_state_dict.pth')

读取

# 如果保存的是整个模型
loaded_model = torch.load('model.pth')
# 如果保存的是模型参数
model_load = SimpleModel()  # 先实例化模型结构
model_load.load_state_dict(torch.load('model_state_dict.pth'))
###########################################################################
# 检查 GPU 是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 加载预训练模型
model = SimpleModel()
model.load_state_dict(torch.load('model_state_dict.pth', map_location=device))# 将模型转移到 GPU
model.to(device)# 示例输入数据
input_data = torch.randn(1, 10).to(device)  # 确保输入数据也在 GPU 上# 前向传播
output = model(input_data)
print(output)

在使用 model.load_state_dict(torch.load('model_state_dict.pth', map_location=device)) 读取模型时,已经指定了 map_location=device,这确保了模型的参数(张量)被加载到指定的设备上。但是,是否还需要调用 model.to(device) 取决于具体的情况。

详细分析

  1. map_location=device 的作用

    • map_location=device 参数用于指定加载的张量应该被放置到哪个设备上。当你加载模型的状态字典时,这个参数确保所有张量(如模型的权重和偏置)被加载到指定的设备(CPU 或 GPU)。

    • 这个参数主要用于处理加载时的设备映射,特别是在加载存储在不同设备上的模型时(例如,从 GPU 上保存的模型加载到 CPU 上或反之)。

2 .model.to(device) 的作用

  • model.to(device) 用于将整个模型(包括模型的参数、缓冲区等)转移到指定的设备上。这是一个递归操作,会遍历模型的所有子模块并将其转移到目标设备。

  • 如果模型在加载时已经将所有张量加载到了正确的设备上(通过 map_location=device),那么调用 model.to(device) 是冗余的,但它不会产生负面影响。

具体情况分析

  • 加载到 CPU: -你在 如果 CPU 上加载模型,并且使用 map_location='cpu',那么模型的张量已经被加载到 CPU 上。在这种情况下,调用 model.to('cpu') 是不必要的,因为模型已经在 CPU 上了。

  • 加载到 GPU

    • 如果你在 GPU 上加载模型,并且使用 map_location='cuda'map_location=device(其中 device 是 GPU),那么模型的张量已经被加载到 GPU 上。但是,模型对象本身(如模型的结构)可能仍然在 CPU 上。

    • 此,调用 model.to(device) 可以确保模型的所有部分(包括模型的结构和参数)都正确地在 GPU 上。

推荐做法

为了确保模型及其所有组成部分都在正确的设备上,建议在加载模型后调用 model.to(device)。这样可以避免潜在的设备不一致问题。

保存和读取.pkl格式的预训练模型

保存

import torch
import torch.nn as nn# 义一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 2)def forward(self, x):x = self.fc(x)return x# 创建模型实例
model = SimpleModel()# 保存整个模型
with open('model.pkl', 'wb') as f:torch.save(model, f)
# 或者只保存模型的参数
with open('model_state_dict.pkl', 'wb') as f:torch.save(model.state_dict(), f)

读取

# 如果保存的是整个模型
with open('model.pkl', 'rb') as f:loaded_model = torch.load(f)
# 如果保存的是模型参数
model_load = SimpleModel()  # 先实例化模型结构
with open('model_state_dict.pkl', 'rb') as f:model_load.load_state_dict(torch.load(f))

两种格式的区别

  • pth 格式

    • 是 PyTorch 推荐的模型保存格式。它使用 Python 的 pickle 模块来序列化模型对象。对于模型的存储来说,它能够较好地保存和加载模型的结构以及参数。当你想要完整地保存和恢复一个模型的训练状态(包括模型结构、参数、优化器等时),使用.pth 格式很方便。

  • pkl 格式

    • 本质上也是使用 pickle 序列化对象。它是一种通用的 Python 对象序列化格式。在 PyTorch 的早期版本中,pkl 格式被广泛用于保存模型。但是使用 pkl 格式时,可能会受到 Python 版本的限制。因为不同 Python 版本之间,pickle 序列化后的对象在反序列化时可能会出现兼容性问题。例如,你在 Python 3.7 环境下用 pickle 保存了一个模型,然后在 Python 3.8 环境下尝试加载时,可能会因为 pickle 协议版本或者对象结构差异等原因导致加载失败。而.pth 格式会更好地处理这些兼容性问题。

注意事项

  • 当保存整个模型时,如果自定义了模型类,加载模型时也需要提供相同的自定义类定义。否则加载时会出现错误,因为无法识别自定义类的结构。

  • 如果只保存模型参数(state_dict),在加载时必须先实例化一个与保存时相同的模型结构,然后将保存的参数加载到这个结构中。这样可以避免保存自定义类的依赖关系,增加模型的可移植性,但前提是你要清楚地知道模型的结构。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/904714.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python从入门到高手8.3节-元组的常用操作方法

目录 11.3.1 元组的常用操作方法 11.3.2 元组的查找 11.3.3 祈祷明天不再打雷下雨 11.3.1 元组的常用操作方法 元组类型是一种抽象数据类型,抽象数据类型定义了数据类型的操作方法,在本节的内容中,着重介绍元组类型的操作方法。 ​ 元组是…

图书推荐(协同过滤)算法的实现:基于订单购买实现相似用户的图书推荐

代码部分 package com.ruoyi.system.service.impl;import com.ruoyi.system.domain.Book; import com.ruoyi.system.domain.MyOrder; import com.ruoyi.system.mapper.BookMapper; import com.ruoyi.system.mapper.MyOrderMapper; import com.ruoyi.system.service.IBookRecom…

JMeter快速指南:命令行生成HTML测试报告(附样例命令解析)

一、核心命令解析 jmeter -g Dash_CapacityTest_01_AllModules_1000.jtl -o report/ 参数 作用 示例文件说明 -g 指定.jtl结果文件路径 -o 指定报告输出目录 自动创建report文件夹 二、操作步骤(Windows/Linux/Mac通用) 进入JMe…

2025年渗透测试面试题总结-渗透岗位全职工作面试(附回答)(题目+回答)

网络安全领域各种资源,学习文档,以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具,欢迎关注。 目录 一、通用基础类问题 1. 自我介绍 2. 职业动机与规划 3. 加班/出差接受度 二、安全技术类问题 1. 漏…

使用DEEPSEEK快速修改QT创建的GUI

QT的GUI,本质上是使用XML进行描述的,在QT CREATOR的界面编辑处,按CTRL2 切换到代码视图,CTRL3切换到编辑器视图。 CTRL2 切换到代码视图 CTRL3 切换到编辑器视图 鼠标左键点击代码视图中,按CTRLA → CTRLC复制XML代码…

draw.io流程图使用笔记

文章目录 图形较少的问题安装版好还是非安装版好业务系统嵌入的draw.io如何导入呢?如何判断组合和取消组合如何快速选中框里面的内容有时候选不到文本怎么办连接线如何不走直角 航点和取消航点支持多少种图形多个连接点?多个图形对齐双向箭头如何画图形的大小 其他流程图图标…

音频相关基础知识

主要参考: 音频基本概念_音频和音调的关系-CSDN博客 音频相关基础知识(采样率、位深度、通道数、PCM、AAC)_音频2通道和8ch的区别-CSDN博客 概述 声音的本质 声音的本质是波在介质中的传播现象,声波的本质是一种波,是一…

MySQL中隔离级别那点事

引言 在MySQL中,事务隔离级别和二进制日志(binlog)的格式密切相关,直接影响数据的一致性和复制的正确性。尤其是在“已提交读”(Read Committed)隔离级别下,由于没有使用间隙锁,某些…

LeetCode 热题 100 238. 除自身以外数组的乘积

LeetCode 热题 100 | 238. 除自身以外数组的乘积 大家好,今天我们来解决一道经典的算法问题——除自身以外数组的乘积。这道题在 LeetCode 上被标记为中等难度,要求在不使用除法的情况下,计算数组中每个元素的乘积,其中每个元素的…

【网络编程】三、TCP网络套接字编程

文章目录 TCP通信流程Ⅰ. 服务器日志类实现Ⅱ. TCP服务端1、服务器创建流程2、创建套接字 -- socket3、绑定服务器 -- bind🎏4、服务器监听 -- listen🎏5、获取客户端连接请求 -- acceptaccept函数返回的套接字描述符是什么,不是已经有一个了…

STM32的SysTick

SysTick介绍 定义:Systick,即滴答定时器,是内核中的一个特殊定时器,用于提供系统级的定时服务。该定时器是一个24位的递减计数器,具有自动重载值寄存器的功能。当计数器到达自动重载值时,它会自动重新加载…

【Java项目脚手架系列】第一篇:Maven基础项目脚手架

【Java项目脚手架系列】第一篇:Maven基础项目脚手架 前言 在Java开发中,一个好的项目脚手架可以大大提高开发效率,减少重复工作。本系列文章将介绍各种常用的Java项目脚手架,帮助开发者快速搭建项目。今天,我们先从最基础的Maven项目脚手架开始。 什么是项目脚手架? …

Kafka的消息保留策略是怎样的? (基于时间log.retention.hours或大小log.retention.bytes,可配置删除或压缩策略)

Kafka 消息保留策略详解 1. 核心保留机制 # Broker 基础配置示例(server.properties) log.retention.hours168 # 默认7天保留时间 log.retention.bytes1073741824 # 1GB 大小限制2. 策略类型对比 策略类型配置参数执行逻辑适用场景时间删除log.re…

五一の自言自语 2025/5/5

今天开学了,感觉还没玩够。 假期做了很多事,弄了好几天的路由器、监控、录像机,然后不停的出现问题,然后问ai,然后解决问题。这次假期的实践,更像是计算机网络的实验,把那些交换机,…

安卓基础(静态方法)

静态方法的特点​​ ​​无需实例化​​:直接用 类名.方法名() 调用。 ​​不能访问实例成员​​:只能访问类的静态变量或静态方法。 ​​内存中只有一份​​:随类加载而初始化,生命周期与类相同。 // 工具类 MathUtils publi…

EasyRTC嵌入式音视频通话SDK驱动智能硬件音视频应用新发展

一、引言 在数字化浪潮下,智能硬件蓬勃发展,从智能家居到工业物联网,深刻改变人们的生活与工作。音视频通讯作为智能硬件交互与协同的核心,重要性不言而喻。但嵌入式设备硬件资源受限,传统音视频方案集成困难。EasyRT…

《数字图像处理(面向新工科的电工电子信息基础课程系列教材)》封面颜色空间一图的选图历程

禹晶、肖创柏、廖庆敏《数字图像处理(面向新工科的电工电子信息基础课程系列教材)》 学图像处理的都知道,彩色图像的颜色空间很多,而且又是三维,不同的角度有不同的视觉效果,MATLAB的图又有有box和没有box。…

Flutter 异步原理-Zone

前言 Zone 是 Dart 异步模型中的核心机制,主要用于: 隔离异步上下文,形成逻辑上的执行环境。捕获未处理的异步异常,保证系统稳定。自定义异步任务的调度行为(比如微任务、Timer)。 什么是 Zone&#xff1…

聊一聊自然语言处理在人工智能领域中的应用

目录 一、智能交互与对话系统 二、 信息提取与文本分析 三、机器翻译与跨语言应用 四、内容生成与创作辅助 五、 搜索与推荐系统 六、垂直领域的专业应用 七、关键技术支撑 自然语言处理NLP属于AI的一个子领域,专注于让机器理解和生成人类语言,比…

Redis的过期设置和策略

Redis设置过期时间主要有以下几个配置方式 expire key seconds 设置key在多少秒之后过期pexpire key milliseconds 设置key在多少毫秒之后过期expireat key timestamp 设置key在具体某个时间戳(timestamp:时间戳 精确到秒)过期pexpireat key millisecon…