使用预训练模型构建自己的深度学习模型(迁移学习)

在深度学习的实际应用中,很少会去从头训练一个网络,尤其是当没有大量数据的时候。即便拥有大量数据,从头训练一个网络也很耗时,因为在大数据集上所构建的网络通常模型参数量很大,训练成本大。所以在构建深度学习应用时,通常会使用预训练模型

需求

训练一个模型来分类蚂蚁ants和蜜蜂bees

步骤

  1. 加载数据集
  2. 编写函数(训练并寻找最优模型)
  3. 编写函数(查看模型效果)
  4. 使用torchvision微调模型
  5. 使用tensorboard可视化训练情况

Pytorch保存和加载模型的两种方式

1.完整保存模型、加载模型

torch.save(net, 'mnist.pth')
net = torch.load('mnist.pth', 'map_location="cpu"')
# # Model class must be defined somewhere

load() 默认会将该张量加载到保存时所在的设备上,map_location 可以强制加载到指定的设备上

2. 保存、加载模型的状态字典(模型中的参数)

torch.save(model.state_dict(), PATH)
model = TheModelClass(*args, **kwargs)model.load_state_dict(torch.load(PATH))

迁移学习全过程

1.加载数据集

ants和bees各有约120张训练图片。

每个类有75张验证图片,从零开始在 如此小的数据集上进行训练通常是很难泛化的。

由于我们使用迁移学习,模型的泛化能力会相当好。

from __future__ import print_function, divisionimport torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
cudnn.benchmark = True  
plt.ion()

lr_scheduler:学习率调度器,用于在训练过程中动态调整学习率

torch.backends.cudnn.benchmark = True:大部分情况下,设置这个 flag 可以让内置的 cuDNN 的 auto-tuner 自动寻找最适合当前配置的高效算法,从而加速运算

plt.ion():这允许你在一个交互式环境中运行matplotlib

data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}

transforms.RandomResizedCrop(224):随机大小裁剪和缩放图像,使得裁剪后的图像尺寸为224x224像素。这个操作有助于模型学习不同尺度和长宽比的图像特征,从而提高模型的泛化能力。

transforms.RandomHorizontalFlip():以一定的概率(默认为0.5)对图像进行水平翻转。这也是一种数据增强技术,有助于模型学习对称性

transforms.Resize(256):将图像缩放到256x256像素

transforms.CenterCrop(224):从图像的中心裁剪出224x224像素的区域。这种裁剪方式确保每次裁剪的都是图像的中心部分,有助于在验证或测试时获得更一致的结果

data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

可视化数据

def imshow(inp, title=None): # 可视化一组 Tensor 的图片inp = inp.numpy().transpose((1, 2, 0)) mean = np.array([0.485, 0.456, 0.406]) std = np.array([0.229, 0.224, 0.225]) inp = std * inp + mean inp = np.clip(inp, 0, 1) plt.imshow(inp) if title is not None: plt.title(title) plt.pause(0.001) # 暂停一会儿,为了将图片显示出来
# 获取一批训练数据
inputs, classes = next(iter(dataloaders['train'])) 
# 批量制作网格
out = torchvision.utils.make_grid(inputs) 
imshow(out, title=[class_names[x] for x in classes]) 

2.编写函数(训练并寻找最优模型)

def train_model(model, criterion, optimizer, scheduler, num_epochs=25): """ 训练模型,并返回在验证集上的最佳模型和准确率 - criterion: 损失函数 Return: - model(nn.Module): 最佳模型 - best_acc(float): 最佳准确率 """since = time.time()best_model_wts = copy.deepcopy(model.state_dict())best_acc = 0.0for epoch in range(num_epochs):print(f'Epoch {epoch}/{num_epochs - 1}')print('-' * 10)# 训练集和验证集交替进行前向传播for phase in ['train', 'val']:if phase == 'train':model.train()  else:model.eval()   running_loss = 0.0running_corrects = 0# 遍历数据集for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 清空梯度,避免累加了上一次的梯度optimizer.zero_grad()with torch.set_grad_enabled(phase == 'train'):# 正向传播outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 反向传播且仅在训练阶段进行优化if phase == 'train':loss.backward() optimizer.step()# 统计loss、准确率running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)if phase == 'train':scheduler.step()epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')# 发现了更优的模型,记录起来if phase == 'val' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())print()time_elapsed = time.time() - sinceprint(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')print(f'Best val Acc: {best_acc:4f}')# 加载训练的最好的模型model.load_state_dict(best_model_wts)return model

3.编写函数(查看模型效果)

def visualize_model(model, num_images=6):was_training = model.trainingmodel.eval()images_so_far = 0fig = plt.figure()with torch.no_grad():for i, (inputs, labels) in enumerate(dataloaders['val']):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)for j in range(inputs.size()[0]):images_so_far += 1ax = plt.subplot(num_images//2, 2, images_so_far)ax.axis('off')ax.set_title(f'predicted: {class_names[preds[j]]}')imshow(inputs.cpu().data[j])if images_so_far == num_images:model.train(mode=was_training)returnmodel.train(mode=was_training)

4.使用torchvision微调模型

微调步骤图

加载预训练模型,并将最后一个全连接层重置

model = models.resnet18(pretrained=True) # 加载预训练模型
num_ftrs = model.fc.in_features # 获取低级特征维度 
model.fc = nn.Linear(num_ftrs, 2) # 替换新的输出层 
model = model.to(device) 
criterion = nn.CrossEntropyLoss() 
# 所有参数都参加训练 
optimizer_ft = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) 
# 每过 7 个 epoch 将学习率变为原来的 0.1 
scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

训练

model_ft = train_model(model, criterion, optimizer_ft, scheduler, num_epochs=3)

预估

visualize_model(model_ft)

 

5.使用tensorboard可视化训练情况

将以下代码块合理的加入到训练模块里

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()
ep_losses, ep_acces = [], []ep_losses.append(epoch_loss)
ep_acces.append(epoch_acc.item())writer.add_scalars('loss', {'train': ep_losses[-2], 'val': ep_losses[-1]}, global_step=epoch)
writer.add_scalars('acc', {'train': ep_acces[-2], 'val': ep_acces[-1]}, global_step=epoch)
writer.close()

运行命令启动tensorboard

tensorboard --logdir runs

可以通过执行命令的终端看到tensorboard可视化页面地址,且该命令会在执行该命令的目录下生成一个runs文件夹(日志文件的保存位置)

以下是我训练了25个epochs的acc和loss的动态图,

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/829435.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

OS对软件的管理,进程,PCB、子进程

进程 可执行程序加载到内存中,操作系统为内个程序都形成一个PCB对象(结构体对象),PCB里存放着这个程序的所有的属性。进程可执行程序PCB ,CPU执行程序也是先通过该程序的PCB找到相应的程序代码,然后一条一…

ThinkPHP5 SQL注入漏洞敏感信息泄露漏洞

1 漏洞介绍 ThinkPHP是在中国使用极为广泛的PHP开发框架。在其版本5.0&#xff08;<5.1.23&#xff09;中,开启debug模式&#xff0c;传入的某参数在绑定编译指令的时候又没有安全处理&#xff0c;预编译的时候导致SQL异常报错。然而thinkphp5默认开启debug模式&#xff0c…

分享一些实用的工具

1、amCharts5&#xff1a;模拟航线飞行/业务分布图/k线/数据分析/地图等 网址&#xff1a; JavaScript mapping library: amCharts 5https://www.amcharts.com/javascript-maps/ Demo地址&#xff1a;Chart Demos - amChartshttps://www.amcharts.com/demos/#maps 他分为amC…

小龙虾优化算法(Crayfish Optimization Algorithm,COA)

小龙虾优化算法&#xff08;Crayfish Optimization Algorithm&#xff0c;COA&#xff09; 前言一、小龙虾优化算法的实现1.初始化阶段2.定义温度和小龙虾的觅食量3.避暑阶段&#xff08;探索阶段&#xff09;4.竞争阶段&#xff08;开发阶段&#xff09;5.觅食阶段&#xff08…

【誉天战报】3月HCIE战报火热来袭!新增45位同学通过认证!

2024年3月&#xff0c;誉天教育共有45名学员顺利通过了HCIE认证&#xff0c;其中&#xff1a;云计算20人、数通18人、存储5人、云服务2人。让我们一起祝贺他们吧~ 誉天教育是华为优选级授权培训合作伙伴&#xff0c;专业从事华为授权认证课程实战技能培训。连续13年荣获“华为优…

和林曦老师一起读书吧 | 愿我们:只生欢喜不生愁

今天&#xff0c;想和你一起来读书&#xff0c;林曦老师的《只生欢喜不生愁》。    这本书的名字很有意味&#xff0c;它来自于清代《养真集》中的一句话&#xff1a;自古神仙无别法&#xff0c;只生欢喜不生愁。      我们会羡慕这样的状态&#xff1a;只生欢喜不生愁…

018基于SSM的音乐系统网站

018基于SSM的音乐系统/网站 开发环境&#xff1a; Jdk7(8)Tomcat7(8)MysqlIntelliJ IDEA(Eclipse)Maven 数据库&#xff1a; MySQL 技术&#xff1a; SpringSpring mvcMybatisJqueryVideo jsJSPJSTLEasyUI 适用于&#xff1a; 课程设计&#xff0c;毕业设计&#xff0c;学习…

“你需要TrustedInstaller提供的权限才能对此文件进行更改” 解决方案

转载地址 【“你需要TrustedInstaller提供的权限才能对此文件进行更改” 解决方案-CSDN博客】

37.WEB渗透测试-信息收集-企业信息收集(4)

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 内容参考于&#xff1a; 易锦网校会员专享课 上一个内容&#xff1a;36.WEB渗透测试-信息收集-企业信息收集&#xff08;3&#xff09;-CSDN博客 关于主域名收…

Git -- 运用总结

文章目录 1. Git2. 基础/查阅2.1 基础/查阅 - git2.2 仓库 - remote2.3 清理 - rm/clean2.4 版本回退 - reset 3. 分支3.1 分支基础 - branch3.2 分支暂存更改 - stash3.3 分支切换 - checkout 4. 代码提交/拉取4.1 代码提交 - push4.2 代码拉取 - pull 1. Git 2. 基础/查阅 2…

(学习日记)2024.05.07:UCOSIII第六十一节:User文件夹函数概览(uCOS-III->Source文件夹)第七部分

写在前面&#xff1a; 由于时间的不足与学习的碎片化&#xff0c;写博客变得有些奢侈。 但是对于记录学习&#xff08;忘了以后能快速复习&#xff09;的渴望一天天变得强烈。 既然如此 不如以天为单位&#xff0c;以时间为顺序&#xff0c;仅仅将博客当做一个知识学习的目录&a…

RISC-V CVA6 在 Linux 下相关环境下载与安装

RISC-V CVA6 在 Linux 下相关环境下载与安装 所需环境与源码下载 CVA6 源码下载 首先&#xff0c;我们可以直接从 GitHub 一次性拉取所有源码&#xff1a; git clone --recursive https://github.com/openhwgroup/cva6.git如果这里遇到网络问题&#xff0c;拉取失败&#x…

Fluent.Ribbon创建Office的RibbonWindow菜单

链接&#xff1a; Fluent.Ribbon文档 优势&#xff1a; 1. 可以创建类似Office办公软件的复杂窗口&#xff1b; 2. 可以应用自定义主题风格界面

实现 <el-cascader> 组件的回显功能

vue A页面&#xff0c;用户填写了el-cascader多层级数据&#xff0c;层级list数据从接口获取&#xff1b; vue B页面&#xff0c;多层级数据要进行回显&#xff0c;接口给到的数据是value值&#xff1b; 直接看demo <template><div><el-cascaderv-model"…

android studio 编译一直显示Download maven-metadata.xml

今天打开之前的项目的时候遇到这个问题:android studio 编译一直显示Download maven-metadata.xml, AI 查询 报错问题&#xff1a;"android studio 编译一直显示Download maven-metadata.xml" 解释&#xff1a; 这个错误通常表示Android Studio在尝试从Maven仓库…

ReactJS中使用TypeScript

TypeScript TypeScript 实际上就是具有强类型的 JavaScript&#xff0c;可以对类型进行强校验&#xff0c;好处是代码阅读起来比较清晰&#xff0c;代码类型出现问题时&#xff0c;在编译时就可以发现&#xff0c;而不会在运行时由于类型的错误而导致报错。但是&#xff0c;从…

中国各银行流动性比例数据集(2000-2022年)

01、数据简介 银行流动性比例是指银行的流动性资产期末余额与流动性负债期末余额之比&#xff0c;用于衡量银行流动性的总体水平。这个比例越高&#xff0c;表明银行偿还短期债务的能力越强&#xff0c;流动性风险越小。 本数据覆盖到城市商业银行、城镇银行、大型商业银行、…

张小泉签约实在智能,用实在Agent打造自动化高

在不少老杭州人的童年记忆里&#xff0c;妈妈裁剪衣服、料理食材、修剪各种物品&#xff0c;用的都是张小泉刀剪。 近日&#xff0c;实在智能与“刀剪第一股”张小泉&#xff08;股票代码&#xff1a;301055.SZ&#xff09;正式达成合作&#xff0c;实在Agent数字员工助力张小…

【工具】-根源上解决VScode打印输出乱码的问题

目录 1 第一步&#xff1a; 改编译命令&#xff0c;保持一致2 第二步&#xff1a; 更改VScode的编码格式-保持一致 1 第一步&#xff1a; 改编译命令&#xff0c;保持一致 看一下你的控制台的编译的命名后缀&#xff0c;有两个关键的参数&#xff0c;如下图&#xff1a; “-f…

不同路径 1 2

class Solution {public int uniquePaths(int m, int n) {int[][] dpnew int[m][n];//记录到每个格子有多少种路径for(int i0;i<m;i) dp[i][0]1;for(int j0;j<n;j) dp[0][j]1;//初始化for(int i1;i<m;i){for(int j1;j<n;j){dp[i][j]dp[i-1][j]dp[i][j-1];}}return …