strict=False 但还是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur

strict=False 但还是size mismatch for []: copying a param with shape [] from checkpoint,the shape in cur

问题

我们知道通过

model.load_state_dict(state_dict, strict=False)

可以暂且忽略掉模型和参数文件中不匹配的参数,先将正常匹配的参数从文件中载入模型。

笔者在使用时遇到了这样一个报错:

RuntimeError: Error(s) in loading state_dict for ViT_Aes:size mismatch for mlp_head.1.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([10, 768]).size mismatch for mlp_head.1.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([10]).

一开始笔者很奇怪,我已经写明strict=False了,不匹配参数的不管就是了,为什么还要给我报错。

原因及解决方案

经过笔者仔细打印模型的键和文件中的键进行比对,发现是这样的:strict=False可以保证模型中的键与文件中的键不匹配时暂且跳过不管,但是一旦模型中的键和文件中的键匹配上了,PyTorch就会尝试帮我们加载参数,就必须要求参数的尺寸相同,所以会有上述报错。

比如在我们需要将某个预训练的模型的最后的全连接层的输出的类别数替换为我们自己的数据集的类别数,再进行微调,有时会遇到上述情况。这时,我们知道全连接层的参数形状会是不匹配,比如我们加载 ImageNet 1K 1000分类的预训练模型,它的最后一层全连接的输出维度是1000,但如果我们自己的数据集是10分类,我们需要将最后一层全链接的输出维度改为10。但是由于键名相同,所以PyTorch还是尝试给我们加载,这时1000和10维度不匹配,就会导致报错。

解决方案就是我们将 .pth 模型文件读入后,将其中我们不需要的层(通常是最后的全连接层)的参数pop掉即可。

以 ViT 为例子,假设我们有一个 ViT 模型,并有一个参数文件 vit-in1k.pth,它里面存储着 ViT 模型在 ImageNet-1K 1000分类数据集上训练的参数,而我们要在自己的10分类数据集上微调这个模型。

model = ViT(num_classes=10)
ckpt = torch.load('vit-in1k.pth', map_location='cpu')
msg = model.load_state_dict(ckpt, strict=False)
print(msg)

直接这样加载会出错,就是上面的错误:

	size mismatch for head.weight: copying a param with shape torch.Size([1000, 768]) from checkpoint, the shape in current model is torch.Size([10, 768]).size mismatch for head.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([10]).

我们将最后 pth 文件加载进来之后(即 ckpt) 中全连接层的参数直接pop掉,至于需要pop掉哪些键名,就是上面报错信息中提到了的,在这里就是 head.weighthead.bias

ckpt.pop('head.weight')
ckpt.pop('head.bias')

之后在运行,会发现我们打印的 msg 显示:

_IncompatibleKeys(missing_keys=['head.weight', 'head.bias'], unexpected_keys=[])

即缺失了head.weighthead.bias 这两个参数,这是正常的,因为在自己的数据集上微调时,我们本就不需要这两个参数,并且已经将它们从模型文件字典 ckpt 中pop掉了。现在,模型全连接之前的层(通常即所谓的特征提取层)的参数已经正常加载了,接下来可以在自己的数据集上进行微调。

因为反正我们也不用这些参数,就直接把这个键值对从字典中pop掉,以免 PyTorch 在帮我们加载时试图加载这些维度不匹配,我们也不需要的参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/532802.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux中权限765啥意思,Linux中的文件权限

Linux系统中的每一个文件都与多种权限类型相关联。在这些权限中,我们主要和三类权限打交道:用户(user)、用户组(group)和其他用户(others)。用户是文件的所有者;用户组是指和文件所有者在同一组的其他多个用户的集合;其他用户是除…

CV中的色彩空间大全

HSI、HSV、RGB、CMY、CMYK、HSL、HSB、Ycc、XYZ、Lab、YUV颜色模型 HSV颜色空间 HSV(hue,saturation,value)颜色空间的模型对应于圆柱坐标系中的一个圆锥形子集,圆锥的顶面对应于V1. 它包含RGB模型中的R1,G1,B1 三个面,所代表的…

linux 系统调用时怎么知道当前上下文属于那个进程,linux – 编写系统调用来计算进程的上下文切换...

如果您的系统调用只应报告统计信息,则可以使用内核中已有的上下文切换计数代码.struct rusage {...long ru_nvcsw; /* voluntary context switches */long ru_nivcsw; /* involuntary context switches */};您可以通过运行来尝试:$/usr/bin/time -v /bin/ls -R....V…

linux串口缓冲区的大小,linux-----------串口设置缓冲器的大小

转自:http://stackoverflow.com/questions/10815811/linux-serial-port-reading-can-i-change-size-of-input-bufferYou want to use the serial IOCTL TIOCSSERIAL which allows changing both receive buffer depth and send buffer depth (among other things). The maximum…

FLOPs、FLOPS、Params的含义及PyTorch中的计算方法

FLOPs、FLOPS、Params的含义及PyTorch中的计算方法 含义解释 FLOPS:注意全大写,是floating point operations per second的缩写(这里的大S表示second秒),表示每秒浮点运算次数,理解为计算速度。是一个衡量…

设置中文linux输入ubuntu,Linux_ubuntu怎么设置成中文?ubuntu中文设置图文方法,  很多朋友安装ubuntu后,发 - phpStudy...

ubuntu怎么设置成中文?ubuntu中文设置图文方法很多朋友安装ubuntu后,发现都是英文,看不懂要怎么办?其实ubuntu是可以设置成中文的,下文小编就为大家带来ubuntu中文的设置方法,一起去看下设置方法吧。ubuntu中文设置方…

科普 | 单精度、双精度、多精度和混合精度计算的区别是什么?

科普 | 单精度、双精度、多精度和混合精度计算的区别是什么? 转自:https://zhuanlan.zhihu.com/p/93812784 我们提到圆周率 π 的时候,它有很多种表达方式,既可以用数学常数3.14159表示,也可以用一长串1和0的二进制长串表示。 …

linux 磁盘分配 简书,linux 磁盘分区

1物理磁盘的构成: 盘面:由一圈一圈的磁道组成机械手臂:读取数据主轴马达:帮助机械手臂转动2 扇区:磁盘上存取数据的最小单位512字节按照扇区分配大小,如果数据只有一字节也会占用512字节簇:用若…

条件控制与条件传送详解

条件控制与条件传送详解 提要 CSAPP3e中文译本 3.6.5 用条件控制来实现条件分支 3.6.6 用条件传送来实现条件分支 CSAPP3e第三章前面主要是介绍了机器级代码的二进制形式和汇编形式、反汇编、x86汇编的基础指令、条件码及其访问方式等。 在介绍到汇编语言的条件分支时分了两…

联合体(union)的使用方法及其本质

联合体(union)的使用方法及其本质 转自:https://blog.csdn.net/huqinwei987/article/details/23597091 有些基础知识快淡忘了,所以有必要复习一遍,在不借助课本死知识的前提下做些推理判断,温故知新。 1…

linux设备驱动之串口移植,Linux设备驱动之UART驱动结构

一、对于串口驱动Linux系统中UART驱动属于终端设备驱动,应该说是实现串口驱动和终端驱动来实现串口终端设备的驱动。要了解串口终端的驱动在Linux系统的结构就先要了解终端设备驱动在Linux系统中的结构体系,一方面自己了解的不够,另一发面关于…

linux python复制安装,复制一个Python全部环境到另一个环境,python另一个,导出此环境下安装的包...

复制一个Python全部环境到另一个环境,python另一个,导出此环境下安装的包导出此环境下安装的包的版本信息清单pipfreeze>requirements.txt联网,下载清单中的包到all-packet文件夹[[email protected] ~]# pip download -d ./all-packet -r requirement…

NVIDIA英伟达的Multi-GPU多卡通信框架NCCL

NVIDIA英伟达的Multi-GPU多卡通信框架NCCL 笔者注:NCCL 开源项目地址:https://github.com/NVIDIA/nccl 转自:https://www.zhihu.com/question/63219175/answer/206697974 NCCL是Nvidia Collective multi-GPU Communication Library的简称&…

C语言n个坐标点间的最大距离,c语言已知两点坐标,求另一点到穿过这两点的直线最短距离。...

c语言已知两点坐标,求另一点到穿过这两点的直线最短距离。以下文字资料是由(历史新知网www.lishixinzhi.com)小编为大家搜集整理后发布的内容,让我们赶快一起来看一下吧!c语言已知两点坐标,求另一点到穿过这两点的直线最短距离。#…

[分布式训练] 单机多卡的正确打开方式:理论基础

[分布式训练] 单机多卡的正确打开方式:理论基础 转自:https://fyubang.com/2019/07/08/distributed-training/ 瓦砾由于最近bert-large用的比较多,踩了很多分布式训练的坑,加上在TensorFlow和PyTorch之间更换,算是熟…

s3c2416开发板 linux,S3C2416移植内核Linux3.1的wm9713声卡过程

移植内核的声卡驱动。原因没有声卡驱动,WM9713声卡驱动移植(原来的内核有UDA1341声卡驱动,我们再次基础上直接修改)1、直接复制内核得到三个文件:s3c2416_wm9713.c , wm9713.c , s3c2416_ac97.c.linux-3.1\sound\soc\codecs\Wm9713.c---->wm9713.c;li…

Linux查看文件内容命令:cat, tail, head, more, less

Linux查看文件内容命令:cat, tail, head, more, less cat 直接显示整个文件。 cat直接显示全部文件内容,没有换页等交互。 cat filenamemore more命令,功能类似 cat ,cat命令是整个文件的内容从上到下显示在屏幕上。 more会…

linux查看队列 msg,linux第10天 msg消息队列

cat /proc/sys/kernel/msgmax最大消息长度限制cat /proc/sys/kernel/msgmnb消息队列总的字节数cat /proc/sys/kernel/msgmni消息条目数消息队列综合案例//server#include #include #include #include #include #include #include #include #define ERR_EXIT(m)do{perror(m);}wh…

Linux中 C++ main函数参数argc和argv含义及用法

Linux中 C main函数参数argc和argv含义及用法 简介 argc 是 argument count的缩写,表示传入main函数的参数个数; argv 是 argument vector的缩写,表示传入main函数的参数序列或指针,并且第一个参数argv[0]一定是程序的名称&…

c语言六位抢答器课程设计,51单片机八路抢答器课程设计

;说明:本人的这个设计改进后解决了前一个版本中1号抢答优先的问题,并增加了锦囊的设置,当参赛选手在回答问题时要求使用锦囊,则主持人按下抢答开始键,计时重新开始。;八路抢答器电路请看下图是用ps仿真的,已…