Pytorch入门实战 P06-调用vgg16模型,进行人脸预测

目录

1、本文内容:

1、内容:

2、简单介绍下VGG16:

3、相关其他模型也可以调用:

2、代码展示:

3、训练结果:

1、不同优化器:

①【使用SGD优化器】

②【使用Adam优化器】

③Adam + 动态学习率ExponentialLR

④Adam + 动态学习率ExponentialLR+ 降低初始学习率(lr=0.001)

⑤Adam+动态学习率LinearLR

⑥Adam+动态学习率LinearLR+ 降低c初始学习率(lr=0.001)

4、总结


  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

1、本文内容:

1、内容:

这篇文章,主要是通过调用现有VGG16的模型,来完成人脸的预测。

这篇文章的亮点主要是提高测试集的精确度

2、简单介绍下VGG16:

VGG-16的主要特点:
        1、深度:VGG-16 = 16个卷积层+3个全连接层组成 ,因此具有相对较深的网络结构。这种深度有助于网络学习到更加抽象和复杂的特征。
        2、卷积层的设计:VGG-16的卷积层全部采用3x3的卷积核和步长为1的卷积操作,同时在卷积层之后都有接ReLU激活函数。这种设计的好处在于,通过堆叠多个较小的卷积核,可以提高网络的非线性建模能力,同时减少了参数数量, 从而降低了过拟合的风险。
        3、池化层:在卷积层之后,VGG-16使用最大池化层来减少特征图的空间尺寸,帮助提取更加显著的特征并减少计算量。
        4、全连接层:VGG-16在卷积层之后接有3个全连接层,最后一个全连接层输出与类别数相对应的向量,用于进行分类。

VGG-16结构说明:
        13个卷积层,分别用blockX-convX表示
        3个全连接层,用classifier表示
        5个池化层。

3、相关其他模型也可以调用:

 Pytorch官网链接地址

2、代码展示:

import copy
import pathlib
import warningsimport matplotlib.pyplot as plt
import torch
from PIL import Image
from torch import nn
from torchvision import datasets
from torchvision.models import vgg16, VGG16_Weights
from torchvision.transforms import transforms
import matplotlib as mplmpl.use('Agg')  # 在服务器上运行的时候,打开注释# 检查设备
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)# 1、导入数据
data_dir = './data'
data_dir = pathlib.Path(data_dir)
print(data_dir)data_paths = list(data_dir.glob('*'))
classNames = [str(path).split('/')[1] for path in data_paths]# 图像预处理
train_transforms = transforms.Compose([transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])total_data = datasets.ImageFolder('./data', transform=train_transforms)# 划分数据集
train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
print(train_size, test_size)  # 1440  360# 数据加载
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)for X, y in test_dl:print('Shape of X [N,C,H,W]:', X.shape)  # [32, 3, 224, 224]print('Shape of y:', y.shape, y.dtype)  # torch.Size([32]) torch.int64break# 调用官方的VGG-16模型
"""VGG-16的主要特点:1、深度:VGG-16 = 16个卷积层+3个全连接层组成 ,因此具有相对较深的网络结构。这种深度有助于网络学习到更加抽象和复杂的特征。2、卷积层的设计:VGG-16的卷积层全部采用3x3的卷积核和步长为1的卷积操作,同时在卷积层之后都有接ReLU激活函数。这种设计的好处在于,通过堆叠多个较小的卷积核,可以提高网络的非线性建模能力,同时减少了参数数量,从而降低了过拟合的风险。3、池化层:在卷积层之后,VGG-16使用最大池化层来减少特征图的空间尺寸,帮助提取更加显著的特征并减少计算量。4、全连接层:VGG-16在卷积层之后接有3个全连接层,最后一个全连接层输出与类别数相对应的向量,用于进行分类。VGG-16结构说明:13个卷积层,分别用blockX-convX表示3个全连接层,用classifier表示5个池化层。
"""
# 加载预训练模型,并且对模型进行微调。
model = vgg16(weights=VGG16_Weights.DEFAULT).to(device)for param in model.parameters():param.requires_grad = False  # 冻结模型参数,这样在训练的时候只训练最后一层的参数。# print("原始模型:",model)
# 修改classifier模块的第6层。即:(6): Linear(in_features=4096, out_features=1000, bias=True)
model.classifier[6] = nn.Linear(4096, len(classNames))  # 修改vgg16 模型中最后一层全连接层,输出目标类别个数
model.to(device)# print("改完后的模型:",model)# 编写训练函数
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)num_batches = len(dataloader)train_loss, train_acc = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X)loss = loss_fn(pred, y)# 反向传播optimizer.zero_grad()  # grad梯度归零loss.backward()  # 反向传播optimizer.step()  # 每一步自动更新# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss# 编写测试函数
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)test_acc, test_loss = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss# 设置动态学习率
learn_rate = 1e-4  # 初始学习率# 调用官方动态学习率接口时使用:
lambda1 = lambda epoch: 0.92 ** (epoch // 4)
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)  # 选定调整学习率的方法# 正式训练
loss_fn = nn.CrossEntropyLoss()  # 创建损失函数
epochs = 40
train_loss = []
train_acc = []
test_loss = []
test_acc = []best_acc = 0  # 设置一个最佳准确率,作为最佳模型的判别指标。
for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)scheduler.step()  # 用于更新学习率(调用官网动态学习率接口的时候,在这里使用)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)# 保存最佳模型到best_modelif epoch_test_acc > best_acc:best_acc = epoch_test_accbest_model = copy.deepcopy(model)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f},|||| Test_acc:{:.1f}%,Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss, lr))# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)
print('Done')# 结果可视化
warnings.filterwarnings('ignore')
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100  # 分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label="Training Accuracy")
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title("Training and Validation Accuracy")plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validataion Loss')
plt.savefig("/data/jupyter/deepinglearning_train_folder/p06_vgg16/resultImg.jpg")  # 保存图片在服务器的位置
plt.show()# 指定图片进行预测
classes = list(total_data.class_to_idx)def predict_one_image(image_path, model, transform, classes):test_img = Image.open(image_path).convert('RGB')plt.imshow(test_img)  # 展示预测图片test_img = transform(test_img)img = test_img.to(device).unsqueeze(0)model.eval()output = model(img)_, pred = torch.max(output, 1)print(_, pred)pred_class = classes[pred]print(f'预测结果:{pred_class}')# 预测训练集中的某张照片
predict_one_image(image_path='./data/Angelina Jolie/001_fe3347c0.jpg', model=model, transform=train_transforms,classes=classes)# 评估模型
best_model.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)
print(f'模型评估:测试acc:{epoch_test_acc}-----测试Loss:{epoch_test_loss}')

3、训练结果:

1、不同优化器:

对比下,目前主流的两个优化器:SGD和Adam优化器。

①【使用SGD优化器】

测试精确度达到18%。

②【使用Adam优化器】

测试精确达到39%。

③Adam + 动态学习率ExponentialLR

测试精确度达到43%。

④Adam + 动态学习率ExponentialLR+ 降低初始学习率(lr=0.001)

测试精确度达到48%

⑤Adam+动态学习率LinearLR

测试精确度达到43%。

⑥Adam+动态学习率LinearLR+ 降低c初始学习率(lr=0.001)

测试精确度达到48%,最高可达到51%。

4、总结

①将SGD优化器换成Adam优化器,精确度会提升1倍。

②使用Adam+动态学习率(即:③、⑤,精度会再次提升。)

③使用③里的的配置,仅改变学习率(lr=1e-4→lr=1e-3),测试精度会再次提升,见④。

④使用⑤里的的配置,仅改变学习率(lr=1e-4→lr=1e-3),测试精度会再次提升,见⑥。

总结上述,测试精确度的提升,最大是优化器的改变、动态学习率、初始学习率的降低。

这些都会影响到模型的精确度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/821973.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(BERT蒸馏)TinyBERT: Distilling BERT for Natural Language Understanding

文章链接:https://arxiv.org/abs/1909.10351 背景 在自然语言处理(NLP)领域,预训练语言模型(如BERT)通过大规模的数据训练,已在多种NLP任务中取得了卓越的性能。尽管BERT模型在语言理解和生成…

深度学习 Lecture 7 迁移学习、精确率、召回率和F1评分

一、迁移学习(Transfer learning) 用来自不同任务的数据来帮助我解决当前任务。 场景:比如现在我想要识别从0到9度手写数字,但是我没有那么多手写数字的带标签数据。我可以找到一个很大的数据集,比如有一百万张图片的猫、狗、汽…

论文笔记:(INTHE)WILDCHAT:570K CHATGPT INTERACTION LOGS IN THE WILD

iclr 2024 spotlight reviewer 评分 5668 1 intro 由大型语言模型驱动的对话代理(ChatGPT,Claude 2,Bard,Bing Chat) 他们的开发流程通常包括三个主要阶段 预训练语言模型在被称为“指令调优”数据集上进行微调&…

JDK5.0新特性

目录 1、JDK5特性 1.1、静态导入 1.2 增强for循环 1.3 可变参数 1.4 自动装箱/拆箱 1.4.1 基本数据类型包装类 1.5 枚举类 1.6 泛型 1.6.1 泛型方法 1.6.2 泛型类 1.6.3 泛型接口 1.6.4 泛型通配符 1、JDK5特性 JDK5中新增了很多新的java特性,利用这些新…

v-for中涉及的key

一、为什么要用key? key可以标识列表中每个元素的唯一性,方便Vue高效地更新虚拟DOM;key主要用于dom diff算法,diff算法是同级比较,比较当前标签上的key和标签名,如果都一样,就只移动元素&#…

【刷题笔记】第七天

文章目录 [924. 尽量减少恶意软件的传播](https://leetcode.cn/problems/minimize-malware-spread/)方法一,并查集方法二,dfs [GCD and LCM ](https://vjudge.net.cn/problem/HDU-4497#authorKING_LRL) 924. 尽量减少恶意软件的传播 如果移除一个感染节…

上海计算机学会 2023年10月月赛 乙组T4 树的覆盖(树、最小点覆盖、树形dp)

第四题:T4树的覆盖 标签:树、最小点覆盖、树形 d p dp dp题意:求树的最小点覆盖集的大小和对应的数量,数量对 1 , 000 , 000 , 007 1,000,000,007 1,000,000,007取余数。所谓覆盖集,是该树的点构成的集合,…

docker 环境变量设置实现方式

1、前言 docker在当前运用的越来广泛,很多应用或者很多中间软件都有很多docker镜像资源,运行docker run 启动镜像资源即可应用。但是很多应用或者中间件有很多配置参数。这些参数在运用过程怎么设置给docker 容器呢?下面介绍几种方式 2 、do…

无线网络安全之WiFi Pineapple初探

背景 WiFi Pineapple(大菠萝)是由国外无线安全审计公司Hak5开发并售卖的一款无线安全测试神器。集合了一些功能强大的模块,基本可以还原钓鱼攻击的全过程。在学习无线安全时也是一个不错的工具,本文主要讲WiFi Pineapple基础配置…

和鲸科技将参与第五届空间数据智能学术会议并于应急减灾与可持续发展专题论坛做报告分享

ACM SIGSPATIAL中国分会致力于推动空间数据的研究范式及空间智能理论与技术在时空大数据、智慧城市、交通科学、社会治理等领域的创新与应用。ACM SIGSPATIAL中国分会创办了空间数据智能学术会议(SpatialDI),分会将于2024年4月25日-27日在南京…

javaWeb项目-快捷酒店管理系统功能介绍

项目关键技术 开发工具:IDEA 、Eclipse 编程语言: Java 数据库: MySQL5.7 框架:ssm、Springboot 前端:Vue、ElementUI 关键技术:springboot、SSM、vue、MYSQL、MAVEN 数据库工具:Navicat、SQLyog 1、Spring Boot框架 …

PSCAD|应用于输电线路故障测距的行波波速仿真分析

1 主要内容 该程序参考文献《应用于输电线路故障测距的行波波速仿真分析》,利用线路内部故障产生的初始行波浪涌达线路两端测量点的绝对时间之差值计算故障点到两端测量点之间的距离,并利用小波变换得到初始行波波头准确到达时刻,从而精准定…

富文本在线编辑器 - tinymce

tinymce 项目是一个比较好的富文本编辑器. 这里有个小demo, 下载下来尝试一下, 需要配置个本地服务器才能够访问, 我这里使用的nginx, 下面是我的整个操作过程: git clone gitgitee.com:chick1993/layui-tinymce.git cd layui-tinymcewget http://nginx.org/download/nginx-1.…

JavaEE:JVM

基本介绍 JVM:Java虚拟机,用于解释执行Java字节码 jdk:Java开发工具包 jre:Java运行时环境 C语言将写入的程序直接编译成二进制的机器语言,而java不想重新编译,希望能直接执行。Java先通过javac把.java…

RK3568 学习笔记 : 更改 u-boot spl 中的 emmc 的启动次序

环境 开发板: 【正点原子】 的 RK3568 开发板 ATK-DLRK3568 u-boot 版本:来自 【正点原子】 的 RK3568 开发板 Linux SDK,单独复制出来一份,手动编译 编译环境:VMware 虚拟机 ubuntu 20.04 问题描述 RK3568 默认 …

浅谈线程的生命周期

Java线程的生命周期是一个从创建到终止的过程,经历了多种状态的转变。在Java中,线程的生命周期可以划分为以下几个主要状态: 新建(New): 当使用 new Thread() 创建一个新的线程对象但尚未调用 start() 方法…

CSS基础之伪元素选择器(如果想知道CSS的伪元素选择器知识点,那么只看这一篇就足够了!)

前言:我们已经知道了在CSS中,选择器有基本选择器、复合选择器、伪类选择器、那么选择器学习完了吗?显然是没有的,这篇文章讲解最后一种选择器——伪元素选择器。 ✨✨✨这里是秋刀鱼不做梦的BLOG ✨✨✨想要了解更多内容可以访问我…

【linux】mobaterm如何kill pycharm进程

【linux】mobaterm如何kill pycharm进程 【先赞后看养成习惯】求点赞关注收藏😀 使用云服务器时,pycharm在打开状态下,不小心关了mobaxterm,然后再输入pycharm.sh就会打不开pycharm,显示已经打开报错:Com…

软考131-上午题-【软件工程】-软件可靠性、可用性、可维护性

可靠性、可用性和可维护性是软件的质量属性,软件工程中,用 0-1 之间的数来度量。 0.66 66% 1、 可靠性 可靠性是指一个系统对于给定的时间间隔内、在给定条件下无失效运作的概率。 可以用 MTTF/ (1MTTF) 来度量,其中 MTTF 为平均无故障时间…

PHP一句话木马

一句话木马 PHP 的一句话木马是一种用于 Web 应用程序漏洞利用的代码片段。它通常是一小段 PHP 代码,能够在目标服务器上执行任意命令。一句话木马的工作原理是利用 Web 应用程序中的安全漏洞,将恶意代码注入到服务器端的 PHP 脚本中。一旦执行&#xf…