Pytorch 加载部分预训练模型并冻结某些层

目录

1  pytorch的版本:

2  数据下载地址:

3  原始版本代码下载:

4  直接上代码:


 

1  pytorch的版本:

2  数据下载地址:

<https://download.pytorch.org/tutorial/hymenoptera_data.zip>

3  原始版本代码下载:

https://pytorch.org/tutorials/_downloads/transfer_learning_tutorial.py

 

4  直接上代码:

# -*- coding: utf-8 -*-
# @File    : test4.py
# @Blog    : https://blog.csdn.net/caomin1haofrom __future__ import print_function, divisionimport torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copydevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")plt.ion()   # interactive mode######################################################################
# 1.定义模型,  2.加载部分预训练数据,  3.冻结部分层
######################################
#1.定义模型
model_conv = models.resnet18()
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)'''
#打印模型的结构
print('###打印模型model_conv的结构####')
print(model_conv)
print('\n')print('###打印模型model_conv加载参数前的初始值####')
print(list(model_conv.parameters()))
print('\n')
'''#############################################
#2.加载部分预训练数据
pretrained_dict = torch.load('./08 transfer_learning/resnet18-5c106cde.pth')
'''
for k,v in pretrained_dict.items():print(k)
'''
#删除预训练模型跟当前模型层名称相同,层结构却不同的元素;这里有两个'fc.weight'、'fc.bias'
pretrained_dict.pop('fc.weight')
pretrained_dict.pop('fc.bias')#自己的模型参数变量
model_dict = model_conv.state_dict()
#去除一些不需要的参数
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#参数更新
model_dict.update(pretrained_dict)# 加载我们真正需要的state_dict
model_conv.load_state_dict(model_dict)'''
print('###打印模型model_conv加载参数后的参数值####')
print(list(model_conv.parameters()))
print('\n')
'''
#############################################
#3.冻结部分层
#将满足条件的参数的 requires_grad 属性设置为False
for name, value in model_conv.named_parameters():if (name != 'fc.weight') and (name != 'fc.bias'):value.requires_grad = False
'''
#打印各层的requires_grad属性
print('###打印模型model_conv参数的requires_grad属性####')
for name, param in model_conv.named_parameters():print(name,param.requires_grad)
'''# filter 函数将模型中属性 requires_grad = True 的参数选出来
params_conv = filter(lambda p: p.requires_grad, model_conv.parameters())
model_conv = model_conv.to(device)criterion = nn.CrossEntropyLoss()# Observe that only parameters of final layer are being optimized as
# opoosed to before.
optimizer_conv = optim.SGD(params_conv, lr=0.001, momentum=0.9)# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)######################################################################
# Training the model
#编写一个通用函数来训练模型。
# 下面将说明: * 调整学习速率 * 保存最好的模型
#下面的参数scheduler是一个来自 torch.optim.lr_scheduler 的学习速率调整类的对象(LR scheduler object)。def train_model(model, criterion, optimizer, scheduler, num_epochs=25):since = time.time()best_model_wts = copy.deepcopy(model.state_dict())best_acc = 0.0for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 每个epoch都有一个训练和验证阶段for phase in ['train', 'val']:if phase == 'train':scheduler.step()model.train()  # Set model to training modeelse:model.eval()   # Set model to evaluate moderunning_loss = 0.0running_corrects = 0#  迭代数据.for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# zero the parameter gradientsoptimizer.zero_grad()# forward# track history if only in trainwith torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 后向+仅在训练阶段进行优化if phase == 'train':loss.backward()optimizer.step()# statisticsrunning_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# 深度复制moif phase == 'val' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())print()time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))# 加载最佳模型权重model.load_state_dict(best_model_wts)return model######################################################################
# 可视化部分训练图像,以便了解数据扩充。def imshow(inp, title=None):"""Imshow for Tensor."""inp = inp.numpy().transpose((1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])inp = std * inp + meaninp = np.clip(inp, 0, 1)plt.imshow(inp)if title is not None:plt.title(title)plt.pause(0.001)  # pause a bit so that plots are updated######################################################################
# Visualizing the model predictions
# 一个通用的展示少量预测图片的函数def visualize_model(model, num_images=6):was_training = model.trainingmodel.eval()images_so_far = 0fig = plt.figure()with torch.no_grad():for i, (inputs, labels) in enumerate(dataloaders['val']):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)for j in range(inputs.size()[0]):images_so_far += 1ax = plt.subplot(num_images//2, 2, images_so_far)ax.axis('off')ax.set_title('predicted: {}'.format(class_names[preds[j]]))imshow(inputs.cpu().data[j])if images_so_far == num_images:model.train(mode=was_training)returnmodel.train(mode=was_training)######################################################################
#训练集数据扩充和归一化
#在验证集上仅需要归一化
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}data_dir = './08 transfer_learning/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")if __name__ == '__main__':# Train and evaluate 2# 训练模型 在CPU上,与前一个场景相比,这将花费大约一半的时间,因为不需要为大多数网络计算梯度。但需要计算转发。model_conv = train_model(model_conv, criterion, optimizer_conv,exp_lr_scheduler, num_epochs=11)visualize_model(model_conv)plt.ioff()plt.show()

 

部分运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/350250.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

INT类型知多少

前言&#xff1a; 整型是MySQL中最常用的字段类型之一&#xff0c;通常用于存储整数&#xff0c;其中int是整型中最常用的&#xff0c;对于int类型你是否真正了解呢&#xff1f;本文会带你熟悉int类型相关知识&#xff0c;也会介绍其他整型字段的使用。 1.整型分类及存储范围 整…

altera fpga 型号说明_A/X家FPGA架构及资源评估

欢迎FPGA工程师加入官方微信技术群点击蓝字关注我们FPGA之家-中国最好最大的FPGA纯工程师社群评估对比xilinx以及altera两家FPGA芯片逻辑资源。首先要说明&#xff0c;现今FPGA除了常规逻辑资源&#xff0c;还具有很多其他片内资源比如块RAM、DSP单元、高速串行收发器、PLL、AD…

guava api_使用Google Guava的订购API

guava api我们在Google的Guava库中玩的更多&#xff0c;这真是一个了不起的库&#xff01; 我们用于它的最新内容是为我们的域对象整理比较器。 这是如何做。 使用Apache Isis的JDO Objectstore &#xff0c;使您的类实现java.lang.Comparable &#xff0c;并对集合使用SortedS…

Pytorch 加载和保存模型

目录 保存和加载模型 1. 什么是状态字典&#xff1a;state_dict? 2.保存和加载推理模型 2.1 保存/加载 state_dict &#xff08;推荐使用&#xff09; 2.2 保存/加载完整模型 3. 保存和加载 Checkpoint 用于推理/继续训练 4. 在一个文件中保存多个模型 5. 使用在不同…

02-CSS基础与进阶-day9_2018-09-12-20-29-40

定位 静态定位 position: static 相对定位 position: relative 绝对定位 position: absolute 脱标 参考点 子绝父相 让绝对定位的盒子水平居中和垂直居中 固定定位 position: fixed 参考点 浏览器左上角 固定定位的元素脱标不占有位置 兼容性 ie6低版本不支持固定定位 02绝对…

activity直接销毁_Android -- Activity的销毁和重建

两种销毁第一种是正常的销毁&#xff0c;比如用户按下Back按钮或者是activity自己调用了finish()方法&#xff1b;另一种是由于activity处于stopped状态&#xff0c;并且它长期未被使用&#xff0c;或者前台的activity需要更多的资源&#xff0c;这些情况下系统就会关闭后台的进…

Storm和Kafka集成的重要生产错误和修复

我将在此处描述Storm和Kafka集成模块的一些细节&#xff0c;一些您应该意识到的重要错误以及如何克服其中的一些错误&#xff08;尤其是对于生产安装&#xff09;。 我在生产安装中大量使用Apache Storm&#xff0c;并将Kafka作为主要输入源&#xff08;Spout&#xff09;。 …

博客园背景设置CSS代码

/配色参考->>->>>//https://zh.spycolor.com/color-index,a*/ #home { margin: 0 auto; width: 90%;/原始65/ min-width: 980px;/页面顶部的宽度/ background-color:rgba(233,214,107,0.3);/博客主页主体框的颜色/ padding: 30px; margin-top: 25px; margin-bot…

matplotlib 画多条折线图且x轴下标非数值

直接上python代码&#xff1a; # -*- coding: utf-8 -*- import matplotlib.pyplot as plt names [GFK, SA, DA-NBNN, DLID, DaNN, Ours] x range(len(names))y_1 [0.464, 0.45, 0.528, 0.519, 0.536, 0.841] y_2 [0.613, 0.648, 0.766, 0.782, 0.712, 0.954] y_3 [0.663…

julia常用矩阵函数_Julia系列教程3 数学运算 矩阵运算

数学运算https://www.zhihu.com/video/1113554595376295936数学运算比Matlab更直观的数学表达方式x 102x>>20但这就导致了可能会出现语法的冲突十六进制整数文本表达式 0xff 可以被解析为数值文本 0 乘以变量 xff浮点数文本表达式 1e10 可以被解析为数值文本 1 乘以变量…

Mysql 模糊查询 转义字符

MySQL的转义字符“\”\0 一个ASCII 0 (NUL)字符。 \n 一个新行符。 \t 一个定位符。 \r 一个回车符。 \b 一个退格符。 \ 一个单引号(“”)符。 \ " 一个双引号(“ "”)符。 \\ 一个反斜线(“\”)符。 \% 一个…

Pytorch LSTM初识(详解LSTM+torch.nn.LSTM()实现)1

pytorch LSTM1初识 目录 pytorch LSTM1初识 ​​​​​​​​​​​​​​​​​​​​​ 一、LSTM简介1

Java命令行界面(第8部分):Argparse4j

Argparse4j是“ Java命令行参数解析器库”&#xff0c;其主页描述为“基于Python的argparse模块的Java命令行参数解析器库”。 在本文中&#xff0c;我将简要介绍如何使用Argparse4j 0.7.0处理命令行参数&#xff0c;该参数与本系列中的前七篇有关Java命令行处理的文章中所解析…

MvvmLight框架使用入门(三)

MvvmLight框架使用入门&#xff08;三&#xff09; 本篇是MvvmLight框架使用入门的第三篇。从本篇开始&#xff0c;所有代码将通过Windows 10的Universal App来演示。我们将创建一个Universal App并应用MvvmLight框架。 首先通过VS2015创建一个名为UniversalApp的空工程&#x…

Pytorch LSTM实例2

对Pytorch中LSTM实例稍作修改,这是一个词性标注的实例 #导入相应的包 import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optimtorch.manual_seed(1)#准备数据的阶段 def prepare_sequence(seq, to_ix):idxs = [to_ix[w] for w in …

java更好的语言_Java,如果这是一个更好的世界

java更好的语言只是梦想着有一个更好的世界&#xff0c;在该世界中&#xff0c;Java平台中的一些旧错误已得到纠正&#xff0c;而某些令人敬畏的缺失功能也已实现。 不要误会我的意思。 我认为Java很棒。 但是它仍然存在一些问题&#xff0c;就像其他平台一样。 我没有任何特定…

anaconda安装成功测试_学习笔记120—Win10 成功安装Anaconda 【亲测有效,需注意几点!!!】...

Win10 下安装 Anaconda一、下载安装 Anaconda(勾选 PATH)&#xff1a;Anaconda 是专注于数据分析的 Python 发行版本&#xff0c;包含了 conda、Python 等 190 多个科学包及其依赖项。使用 Anaconda 的好处在于可以省去很多配置环境的步 骤&#xff0c;省时省心又便于分析。下载…

Pytorch 词嵌入word_embedding1初识

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)参数所表示的含义: num_embeddings (int) :嵌入字典的大小 embedding_dim (int) :每个嵌入向量的大小 padding_idx (int, optio…

Python语言 目录

待续.... 转载于:https://www.cnblogs.com/jiangchunsheng/p/11077884.html

JDK 9清单:Project Jigsaw,sun.misc.Unsafe,G1,REPL等

Java 9距离&#xff08;希望&#xff09;数月了&#xff0c;现在该讨论一下即将发生的变化以及您应该采取的措施 Java 9即将来临&#xff08;我们正在计算到达的日子 &#xff09;&#xff0c;其中包含一系列新功能和改进功能。 这就是为什么我们决定创建一份清单来准备自己的…