【学习记录】pytorch载入模型的部分参数

需要从PointNet网络框架中提取encoder部分的参数,然后赋予自己的模型。因此,需要从一个已有的.pth文件读取部分参数,加载到自定义模型上面。做了一些尝试,记录如下。

关于模型保存与载入

torch.save(): 使用Python的pickle实用程序将对象进行序列化,然后将序列化的对象保存到disk,可以保存各种对象,包括模型、张量和字典等。
torch.load(): 使用pickle unpickle工具将pickle的对象文件反序列化为内存。
可以看出,pth文件本质上是一个序列化的dict。

我们在save时,代码如下:

state = {'epoch': epoch,'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),
}

然后以下代码load进来:

checkpoint = torch.load(args.model_file,  map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

查看checkpoint,可以看到包含的就是自己保存时的3个dict,分别是epoch,model_state_dict,和optimizer信息。
在这里插入图片描述

这里我们重点关注 model_state_dict,数据类型是一个 OrderedDict,有序字典。展开如下:
在这里插入图片描述
可以看到里面包含了自己定义的encoder,bn1-3,mlp 1-4层,以及每个层对应的参数(权重、bias,对于bn层还有mean, var等)。
这个Dict的顺序就是在Model中我们定义的顺序,这个和模型是一致的。
因此,如果载入时的模型和保存模型完全一致,直接用load_state_dict()就可以按顺序把数据载入进来。但如,如果定义不同怎么办?这就需要手动载入。

方法1:手动载入指定层的参数

从debug的断点可以看到,每个参数就是存在dict中的一个tensor。因此,我们只要读取对应的dict即可。
例如,encoder的conv1的权重,就是 checkpoint['model_state_dict']['encoder.conv1.weight'],那么我们在自己的模型对应的位置读取这个dict即可。
具体载入方式如下:

# 定义模型
model = MyPointNetSegmentation(channel=3, get_feature=True, batch_size=1)
model.to('cpu')# 载入其他模型的参数
checkpoint = torch.load(model_file, map_location='cpu')
model_dict = checkpoint['model_state_dict']# 将其他模型的参数,赋值给自己模型对应参数
model.encoder.conv1.weight.data.copy_(model_dict['encoder.conv1.weight'])
model.encoder.conv1.bias.data.copy_(model_dict['encoder.conv1.bias'])

把所有有用的参数都赋值过来就好,但要注意参数对应的tensor维度是一样的。
在这里插入图片描述

方法2:一次性载入key值相同的参数

如果说两个model的某些key值相同,可以用python的字典推导方式,将名称相关的参数提取出来。例如:

def load_dict_from_pointnet(model : Point2VoxelNet, checkpoint):my_model_dict = model.state_dict()pretrained_dict =  checkpoint['model_state_dict']# 只将pretraind_dict中那些在model_dict中的参数,提取出来state_dict = {k:v for k,v in pretrained_dict.items() if k in my_model_dict .keys()}my_model_dict.update(state_dict)		# 注意要更新state的变量,如果直接赋值,会出现某些key没有定义,导致运行失败model.load_state_dict(my_model_dict)# 对比参数是否一致print(f"{checkpoint['model_state_dict']['feat.stn.conv1.weight'][1]}")print(f"{model.feat.stn.conv1.weight[1]}")return model

看到这里,可以知道如果自己的模型改了名称,例如.pth的参数是:feat.stn.conv1,我这边叫做了 encoder.stn.conv1,那么是无法直接赋值的。可以用方法1,一个个载入,但是太慢了。另一种方式,是做一个键值映射,如果读到的是 feat.xxx,则赋予自定义模型中的 encoder.xxx ,简单处理即可。

注意事项

  • conv层需要载入的参数有:weight 和 bias
  • BN层涉及的参数有:
    1. weight,bias
    2. running_mean,running_var:这两个参数用于归一化的均值和方差, 因此也需要载入
    3. num_batches_tracked:在训练时需要载入,在test时不需要载入
  • 载入参数后,如果用于测试,需要调用 eval()。注意不能在载入参数前调用 eval。eval 会将 bn 层的training参数设置为 false ,这样在测试时 batch_size 时如果是 1 也能够正常运行。

测试

用默认方式载入参数,以及手动方式载入后的两个模型,预测结果一致。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/74186.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【蓝桥杯14天冲刺课题单】Day 8

1.题目链接:19714 数字诗意 这道题是一道数学题。 先考虑奇数,已知奇数都可以表示为两个相邻的数字之和,2k1k(k1) ,那么所有的奇数都不会被计入。 那么就需要考虑偶数什么情况需要被统计。根据打表,其实可以发现除了…

鸿蒙ArkTS开发:微信/系统来电通话监听功能实现

本文将介绍如何在鸿蒙应用中使用ArkTS实现通话监听和录音功能,利用harmony-utils工具库简化开发流程。 工具库地址 一、功能概述 本实现包含以下核心功能: 通话状态监听:检测来电、去电和通话中状态 音频流监控:通过麦克风使用…

NFS 重传次数速率监控

这张图展示的是 NFS 重传次数速率监控,具体解释如下: 1. 指标含义 监控指标 node_nfs_rpc_retransmissions_total 统计 NFS(网络文件系统)通信中 RPC(远程过程调用)的重传次数,rate(node_nfs_…

【 <二> 丹方改良:Spring 时代的 JavaWeb】之 Spring Boot 中的国际化:支持多语言的 RESTful API

<前文回顾> 点击此处查看 合集 https://blog.csdn.net/foyodesigner/category_12907601.html?fromshareblogcolumn&sharetypeblogcolumn&sharerId12907601&sharereferPC&sharesourceFoyoDesigner&sharefromfrom_link <今日更新> 一、开篇整…

黑帽SEO之搜索引擎劫持-域名劫持原理分析

问题起源 这是在《Web安全深度剖析》的第二章“深入HTTP请求流程”的2.3章节“黑帽SEO之搜索引擎劫持”提到的内容&#xff0c;但是书中描述并不详细&#xff0c;没有讲如何攻击达到域名劫持的效果。 书中对SEO搜索引擎劫持的现象描述如下&#xff1a;直接输入网站的域名可以进…

theos工具来编译xcode的swiftUI项目为ipa文件

Theos 是一个开源的开发工具套件&#xff0c;主要用于为 iOS/macOS 平台开发和编译 越狱插件&#xff08;Tweaks&#xff09;、动态库、命令行工具等。它由 Dustin Howett 创建&#xff0c;并被广泛用于越狱社区的开发中。但这里我主要使用它的打包ipa功能&#xff0c;因为我的…

25.4.1学习总结【Java】

动态规划题 2140. 解决智力问题https://leetcode.cn/problems/solving-questions-with-brainpower/ 给你一个下标从 0 开始的二维整数数组 questions &#xff0c;其中 questions[i] [pointsi, brainpoweri] 。 这个数组表示一场考试里的一系列题目&#xff0c;你需要 按顺…

计算机网络知识点汇总与复习——(二)物理层

Preface 计算机网络是考研408基础综合中的一门课程&#xff0c;它的重要性不言而喻。然而&#xff0c;计算机网络的知识体系庞大且复杂&#xff0c;各类概念、协议和技术相互关联&#xff0c;让人在学习时容易迷失方向。在进行复习时&#xff0c;面对庞杂的的知识点&#xff0c…

string的底层原理

一.构造函数 我们来看一下&#xff0c;string的底层就是一个字符型指针和一个size来表示string的大小&#xff0c;capacity来表示分配的内存大小。 我们来看我们注释掉的第一个构造函数&#xff0c;我们是通过初始化列表来初始化size的大小&#xff0c;再通过size的大小来初始化…

Python FastAPI + Celery + RabbitMQ 分布式图片水印处理系统

FastAPI 服务器Celery 任务队列RabbitMQ 作为消息代理定时任务处理 首先创建项目结构&#xff1a; c:\Users\Administrator\Desktop\meitu\ ├── app/ │ ├── __init__.py │ ├── main.py │ ├── celery_app.py │ ├── tasks.py │ └── config.py…

【蓝桥杯】每日练习 Day18

目录 前言 动态求连续区间和 分析 代码 数星星 分析 代码 星空之夜 分析 代码 前言 接下来是今天的题目&#xff08;本来是有四道题的但是有一道题是前面讲过&#xff08;逆序数的&#xff0c;感兴趣的小伙伴可以去看我归并排序的那一篇&#xff09;的我就不再过多赘…

基于银河麒麟桌面服务器操作系统的 DeepSeek本地化部署方法【详细自用版】

一、3种方式使用DeepSeek 1.本地部署 服务器操作系统环境进行,具体流程如下(桌面环境步骤相同): 本例所使用银河麒麟高级服务器操作系统版本信息: (1)安装ollama 方式一:按照ollama官网的下载指南,执行如下命令: curl -fsSL https://ollama.com/install.sh | sh方…

Python入门(7):Python序列结构-字典

字典Dictionary 字典(dictionary)和列表类似&#xff0c;也是可变序列&#xff0c;不过与列表不同&#xff0c;它是无序的可变序列&#xff0c;保存的为容是以“键-值对”的形式存放的。 Python 中的字典相当于 Java 或者 C中的 Map 对象。在C#中,就是Dictionary<TKey,TVa…

Flutter项目之构建打包分析

目录&#xff1a; 1、准备部分2、构建Android包2.1、配置修改部分2.2、编译打包 3、构建ios包3.1、配置修改部分3.2、编译打包 1、准备部分 2、构建Android包 2.1、配置修改部分 2.2、编译打包 执行flutter build apk命令进行打包。 3、构建ios包 3.1、配置修改部分 3.2、编译…

不用再付费~全网书源一键下载,实现阅读自由!!!

现在市面上有许多免费你看书的软件&#xff0c;但都软件内太多广告弹窗&#xff0c;这无疑是很烦&#xff0c;有事一不小心点进去就下载了软件&#xff0c;简直让人头大&#xff01; 如果你遇到这样的难题那么就应该看下本文~ 这是一款能一键将在线连载小说整合下载成标准格式&…

GCC RISCV 后端 -- GIMPLE IR 表示的一些理解

C/C源代码经过 GCC 解析&#xff08;Parse&#xff09;及转换后&#xff0c;通过 GIMPLE IR 予以表示&#xff08;Representation&#xff09;。其中&#xff0c;一个C/C源文件&#xff0c;通过 宏处理后&#xff0c;形成一个 转译单元&#xff08;Translation Unit&#xff09…

JAVA设计模式之适配器模式《太白金星有点烦》

太白金星握着月光凝成的鼠标&#xff0c;第108次检查南天门服务器的运行日志。这个刚从天枢院调来的三等仙官&#xff0c;此刻正盯着瑶池主机房里的青铜鼎发愁——鼎身上"天地同寿"的云纹间&#xff0c;漂浮着三界香火系统每分钟吞吐的十万条功德数据。看着居高不下的…

以太坊DApp开发脚手架:Scaffold-ETH 2 详细介绍与搭建教程

一、什么是Scaffold-ETH 2 Scaffold-ETH 2是一个开源的最新工具包&#xff0c;类似于脚手架。用于在以太坊区块链上构建去中心化应用程序 &#xff08;DApp&#xff09;。它旨在使开发人员更容易创建和部署智能合约&#xff0c;并构建与这些合约交互的用户界面。 Scaffold-ETH…

毕业设计:实现一个基于Python、Flask和OpenCV的人脸打卡Web系统(六)

毕业设计:实现一个基于Python、Flask和OpenCV的人脸打卡Web系统(六) Flask Flask是一个使用 Python 编写的轻量级 Web 应用框架。其 WSGI 工具箱采用 Werkzeug ,模板引擎则使用 Jinja2 。Flask使用 BSD 授权。 Flask也被称为 “microframework” ,因为它使用简单的核心,…

第十一章 VGA显示图片(还不会)

FPGA至简设计实例 前言 一、项目背景 1. IP核概述 IP 核(Intellectual Property core)指的是知识产权核或知识产权模块,其是具有特定电路功能的硬件描述语言程序,在EDA技术开发中具有十分重要的地位。美国著名的Dataquest咨询公司将 半导体产业的IP定义为“用于ASIC或FPGA…