Pytorch基础:torch.load_state_dict()方法在加载时不会检查类型

相关阅读

Pytorch基础icon-default.png?t=N7T8https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482


        笔者在使用torch.nn.module的load_state_dict中出现了一个问题,一个被注册的张量在加载后居然没有变化,一开始以为是加载出现了问题,但发现其他参数加载成功,思索后发现是注册的张量的类型是整型而checkpoint中保存为浮点数类型,恰好注册时的默认值给的是0,而checkpoint中的浮点数又在0到1之间,因此出现了这个令人困惑的bug。

        下面首先复现这个bug。

import torch
import torch.nn as nn# 定义一个简单的线性模型,参数类型为整数
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.register_buffer('test', torch.tensor(0)) # 注册一个整型张量# 创建一个简单模型实例
model = SimpleModel()# 创建一个浮点数作为参数
float_parameter = torch.tensor(0.6)# 将注册名指向另一个浮点型张量
model.test = float_parameter# 保存模型
torch.save(model.state_dict(), 'model.pth')# 直接使用原模型加载
checkpoint = torch.load('model.pth')
model.load_state_dict(checkpoint)# 打印加载后的参数
print(model.test)# 直接使用新模型加载
model_1 = SimpleModel()
model_1.load_state_dict(checkpoint)# 打印加载后的参数
print(model_1.test)
输出:
tensor(0.6000)
tensor(0)

        可以看到,当模型中注册的名字(test),指向了一个类型不符的张量后,并不会导致浮点型张量被截断为整型,这是因为此处是直接使用赋值号=,使名字指向了另一个张量。

        但使用load_state_dict()方法与使用赋值号是不同的,load_state_dict()方法的实现中,调用了_load_from_state_dict()方法,其中调用了copy_()方法,进行了原位(in-place)数据替换,这可能会进行截断,下面是原位替换的一个例子。

import torch# 创建两个张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5.1, 6.1], [7.1, 8.1]])# 查看张量对象的id
print(id(a))
print(id(b))# 查看底层存储的内存地址
print(a.storage().data_ptr())
print(b.storage().data_ptr())# 将张量 b 中的值复制到张量 a 中
a.copy_(b)# 打印复制后的结果
print(a)# 查看张量对象的id
print(id(a))
print(id(b))# 查看底层存储的内存地址
print(a.storage().data_ptr())
print(b.storage().data_ptr())
输出:
2604425272672
2604426953808  
2604511348096  
2602930352832  
tensor([[5, 6],[7, 8]])
2604425272672
2604426953808
2604511348096
2602930352832

        在保存了模型的状态字典后,使用load_state_dict()方法加载后,也不会有任何截断问题,因为对于原模型而言,名字test指向的是一个浮点型张量,此时原位替换,类型吻合。但是对于一个新的模型,此时的test指向的是一个整型张量,此时原位替换,会发生截断。

        因此,在注册一个张量时,需要确保其在注册时和保存时的类型吻合,此处除了指形状,还有类型,否则可能会出现意想不到的bug。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/831725.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

漏洞挖掘之某厂商OAuth2.0认证缺陷

0x00 前言 文章中的项目地址统一修改为: a.test.com 保护厂商也保护自己 0x01 OAuth2.0 经常出现的地方 1:网站登录处 2:社交帐号绑定处 0x02 某厂商绑定微博请求包 0x02.1 请求包1: Request: GET https://www.a.test.com/users/auth/weibo?…

SpringCloud微服务:Eureka 和 Nacos 注册中心

共同点 都支持服务注册和服务拉取都支持服务提供者心跳方式做健康检测 不同点 Nacos 支持服务端主动检测提供者状态:临时实例采用心跳模式,非临时(永久)实例采用主动检测模式Nacos 临时实例心跳不正常会被剔除,非临时实…

深度学习中权重初始化的重要性

深度学习模型中的权重初始化经常被人忽略,而事实上这是非常重要的一个步骤,模型的初始化权重的好坏关系到模型的训练成功与否,以及训练速度是否快速,效果是否更好等等,这次我们专门来看看深度学习中的权重初始化问题。…

my-room-in-3d中的电脑,电视,桌面光带发光原理

1. my-room-in-3d中的电脑,电视,桌面光带发光原理 最近在github中,看到了这样的一个项目; 项目地址 我看到的时候,蛮好奇他这个光带时怎么做的。 最后发现,他是通过,加载一个 lightMap.jpg这个…

让我们一起来领悟带环问题的核心思想

一、带环的链表: 本质还是快慢指针来解决 关于如下一个带环链表怎么去找到他们想碰到的节点呢????我们可以想到快慢指针,第一个快点走,若是有环就会进入环,此时快指针每次走2步&am…

2.1 上海雷卯电子PLC

PLC(可编程逻辑控制器)像是工厂自动化系统的“大脑”,负责监控和控制各种生产过程。PLC 能够精确地协调各类设备的操作,实现生产流程的自动化和优化。通过编程,它可以根据不同的生产需求灵活调整控制逻辑,提…

可视化大屏应用场景:智慧安防,保驾护航

hello,我是大千UI工场,本篇分享智慧安防的大屏设计,关注我们,学习N多UI干货,有设计需求,我们也可以接单。 实时监控与预警 可视化大屏可以将安防系统中的监控画面、报警信息、传感器数据等实时展示在大屏上…

快速幂笔记

快速幂即为快速求出一个数的幂&#xff0c;这样可以避免TLE&#xff08;超时&#xff09;的错误。 传送门&#xff1a;快速幂模板 前置知识&#xff1a; 1) 又 2) 代码&#xff1a; #include <bits/stdc.h> using namespace std; int quickPower(int a, int b) {int…

TiDB系列之:部署TiDB集群常见报错解决方法

TiDB系列之&#xff1a;部署TiDB集群常见报错解决方法 一、部署TiDB集群二、unsupported filesystem ext3三、soft limit of nofile四、THP is enabled五、numactl not usable六、net.ipv4.tcp_syncookies 1七、service irqbalance not found,八、登陆TiDB数据库 一、部署TiDB…

搜款网商品列表API接口:高效获取时尚潮流商品的新途径

API接口概述 搜款网商品列表API接口允许开发者根据设定的条件&#xff08;如分类、价格区间、关键词等&#xff09;查询搜款网上的商品信息&#xff0c;并返回符合条件的商品列表。通过调用该接口&#xff0c;您可以轻松获取到搜款网上最新、最热的时尚商品数据&#xff0c;为…

批量视频剪辑新选择:一键式按照指定秒数分割视频并轻松提取视频中的音频,让视频处理更高效!

是否经常为大量的视频剪辑工作感到头疼&#xff1f;还在一个个手动分割、提取音频吗&#xff1f;现在&#xff0c;我们为你带来了一款全新的视频批量剪辑神器&#xff0c;让你轻松应对各种视频处理需求&#xff01; 首先&#xff0c;进入媒体梦工厂的主页面&#xff0c;并在板…

TFT显示屏偶发无法点亮

一. 问题描述 最近接到一起客诉&#xff1a;设备偶发显示屏不亮。复现现象时&#xff0c;发现有如下规律&#xff1a; 上电后&#xff0c;如果显示屏正常启动&#xff0c;则在使用过程中会一直正常。反之&#xff0c;如果显示屏一上电就无法显示&#xff0c;则一直黑屏。 是…

安卓硬件访问服务

安卓硬件访问服务 硬件访问服务通过硬件抽象层模块来为应用程序提供硬件读写操作。 由于硬件抽象层模块是使用C语言开发的&#xff0c; 而应用程序框架层中的硬件访问服务是使用Java语言开发的&#xff0c; 因此&#xff0c; 硬件访问服务必须通过Java本地接口&#xff08;Jav…

vector的使用

1.构造函数 void test_vector1() {vector<int> v; //无参的构造函数vector<int> v2(10, 0);//n个value构造&#xff0c;初始化为10个0vector<int> v3(v2.begin(), v2.end());//迭代器区间初始化,可以用其他容器的区间初始化vector<int> v4(v3); //拷贝…

Java项目:基于SSM框架实现的学院党员管理系统高校党员管理系统(ssm+B/S架构+源码+数据库+毕业论文+开题)

一、项目简介 本项目是一套基于SSM框架实现的学院党员管理系统 包含&#xff1a;项目源码、数据库脚本等&#xff0c;该项目附带全部源码可作为毕设使用。 项目都经过严格调试&#xff0c;eclipse或者idea 确保可以运行&#xff01; 该系统功能完善、界面美观、操作简单、功能齐…

ConstraintLayout 特殊用法详解

1.使用百分比设置间距 app:layout_constraintHorizontal_bias"0.4" 水平偏移&#xff08;0-1&#xff09; app:layout_constraintVertical_bias"0.4" 垂直偏移 &#xff08;0-1&#xff09; <?xml version"1.0" encoding"u…

第18章 基于经验的测试技术

一、错误猜想法 &#xff08;一&#xff09;概念 错误推算法基于测试人员对以往测试项目中一些经验测试程序中的错误测试程序时&#xff0c;人们可根据经验或直觉推测程序中可能存在的各种错误&#xff0c;然后有针对性地编写检查这些错误的测试用例的方法 &#xff08;二&a…

使用MATLAB/Simulink点亮STM32开发板LED灯

使用MATLAB/Simulink点亮STM32开发板LED灯-笔记 一、STM32CubeMX新建工程二、Simulink 新建工程三、MDK导入生成的代码 一、STM32CubeMX新建工程 1. 打开 STM32CubeMX 软件&#xff0c;点击“新建工程”&#xff0c;选择中对应的型号 2. RCC 设置&#xff0c;选择 HSE(外部高…

LeetCode 69—— x 的平方根

阅读目录 1. 题目2. 解题思路一3. 代码实现一4. 解题思路二5. 代码实现二 1. 题目 2. 解题思路一 二分查找法&#xff0c;对于整数 i ∈ [ 0 , x ] i \in [0,x] i∈[0,x]&#xff0c;我们判断 i 2 i^2 i2 和 x x x 的关系&#xff0c;然后找到最后一个平方小于等于 x x x …

【 书生·浦语大模型实战营】作业(六):Lagent AgentLego 智能体应用搭建

【 书生浦语大模型实战营】作业&#xff08;六&#xff09;&#xff1a;Lagent & AgentLego 智能体应用搭建 &#x1f389;AI学习星球推荐&#xff1a; GoAI的学习社区 知识星球是一个致力于提供《机器学习 | 深度学习 | CV | NLP | 大模型 | 多模态 | AIGC 》各个最新AI方…