深度学习中的迁移学习:预训练模型微调与实践

深度学习中的迁移学习:预训练模型微调与实践

目录

  1. 💡 迁移学习的核心概念
  2. 🧠 预训练模型的使用:ResNet与VGG的微调
  3. 🏥 迁移学习在医学图像分析中的应用
  4. 🔄 实践中的迁移学习微调过程

1. 💡 迁移学习的核心概念

迁移学习(Transfer Learning)在深度学习领域中发挥着至关重要的作用。其核心思想是:在大型数据集上训练好的模型可以被“迁移”到新的任务中,从而避免从零开始训练。深度神经网络的训练通常需要大量的数据和时间,通过利用已经在大规模数据集(如ImageNet)上训练过的模型,迁移学习能够极大地缩短训练时间,并显著提高性能。

迁移学习的关键点:

  • 预训练模型:通过在通用数据集上训练模型(如ResNet、VGG等),这些模型学到了基础的特征表示,如边缘、形状和纹理。迁移学习的核心在于将这些基础特征应用到新的领域任务中。
  • 微调(Fine-tuning):通过对预训练模型进行部分或全部参数的微调,模型可以适应新任务中的特定数据。微调的程度取决于新任务的相似性和目标。
  • 冻结与解冻层:迁移学习过程中,通常会冻结模型的部分层,以保留通用的特征提取能力,针对新任务只对高层进行微调。

通过迁移学习,即使在拥有较少数据的情况下,也能获得优异的模型性能。接下来的部分将详细介绍如何使用经典的预训练模型,如ResNet和VGG,进行微调和迁移学习的实现。


2. 🧠 预训练模型的使用:ResNet与VGG的微调

深度学习中的经典模型如ResNetVGG,常被用作迁移学习的预训练模型。它们在ImageNet等大规模数据集上预训练,并能够捕获图像中的通用特征。

ResNet与VGG的区别:

  • ResNet(Residual Networks):ResNet通过引入残差块,解决了深度神经网络中的梯度消失问题。这使得ResNet可以训练非常深的网络(如ResNet50、ResNet101),同时保持较高的性能。
  • VGG:VGG网络的特点在于其非常规则的卷积层堆叠结构,尽管深度较浅,但它能通过更宽的卷积核捕捉丰富的图像特征。

示例代码:微调ResNet进行图像分类

以下代码展示了如何使用预训练的ResNet模型并进行迁移学习,以适应新的图像分类任务。

# 引入必要的库
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder# 加载预训练的ResNet模型
resnet = models.resnet50(pretrained=True)# 冻结所有层的参数,以便只微调最后的全连接层
for param in resnet.parameters():param.requires_grad = False# 修改ResNet的最后一层,以适应新任务的分类数目
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 2)  # 假设目标任务是二分类# 定义数据增强和预处理
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载数据集
train_dataset = ImageFolder(root='path_to_train_data', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet.fc.parameters(), lr=0.001)# 模型训练
for epoch in range(10):  # 假设训练10个周期resnet.train()running_loss = 0.0for inputs, labels in train_loader:optimizer.zero_grad()outputs = resnet(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

代码解析:

  • 模型加载与微调:代码中使用了torchvision.models中的resnet50预训练模型,并冻结了所有卷积层,只微调最后的全连接层以适应新任务(如二分类)。
  • 数据增强与预处理:通过transforms.Compose进行图像预处理,包括图像的缩放、裁剪和归一化。
  • 训练过程:通过微调最后的全连接层,模型能够快速适应新任务。

微调深度学习模型的关键在于,冻结模型的大部分层次,并根据任务的需求重新训练部分层。通过这种方式,可以在有限数据的情况下,获得良好的性能表现。


3. 🏥 迁移学习在医学图像分析中的应用

迁移学习在医学图像分析等领域中的应用尤为广泛,特别是在这种特定领域中,通常面临数据稀缺的问题。由于医学图像数据的获取和标注成本高昂,直接从头训练深度学习模型往往不可行。因此,利用预训练模型进行迁移学习成为一种行之有效的解决方案。

医学图像分析中的挑战:

  1. 数据稀缺:标注的医学图像数据通常较少,这使得从零开始训练模型变得困难。
  2. 高精度要求:医学图像分析任务通常需要非常高的精度,因为其结果会直接影响临床诊断。
  3. 特征差异:尽管预训练模型在自然图像上表现优异,但医学图像的特征通常与自然图像有显著区别,因此需要对模型进行专门的微调。

通过迁移学习,医学图像分析可以借助在ImageNet等大数据集上预训练的模型提取基础特征,然后通过微调,模型可以有效学习到医学图像中特定的病变或异常区域。

示例代码:应用ResNet进行医学图像分析

# 引入必要的库
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision import transforms, datasets# 加载预训练的ResNet50模型
resnet = models.resnet50(pretrained=True)# 冻结所有层的参数
for param in resnet.parameters():param.requires_grad = False# 修改最后一层以适应医学图像分析的分类
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 5)  # 假设任务为五分类# 定义数据增强和预处理
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])# 加载医学图像数据集
train_dataset = datasets.ImageFolder(root='path_to_medical_data', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet.fc.parameters(), lr=0.0001)# 模型训练过程
for epoch in range(20):  # 假设训练20个周期resnet.train()running_loss = 0.0for inputs, labels in train_loader:optimizer.zero_grad()outputs = resnet(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader)}")

代码解析:

  • 医学数据微调:利用预训练的ResNet模型,只微调最后的分类层,使其能够适应五分类任务,适用于医学图像分析中的不同疾病分类任务。
  • 医学图像预处理:通过数据增强,如缩放、裁剪等操作,增强模型的泛化能力。

迁移学习在医学图像分析中的应用能够有效降低数据需求,同时提高模型的准确性和可靠性。


4. 🔄 实践中的迁移学习微调过程

在实际操作中,迁移学习的微调过程需要根据任务的复杂度和数据集的大小进行调整。具体微调的策略包括:

  1. 冻结大部分层:对于简单任务,只

需微调网络的高层特征表示层,而保留低层特征不变。
2. 解冻更多层:对于复杂任务,可能需要解冻更多层次,以学习更多领域特定的特征。
3. 调整学习率:微调时,通常使用较小的学习率,以避免破坏预训练模型中学到的有用特征。

以下是微调不同层的实践过程:

# 解冻部分层,允许更多层进行训练
for name, param in resnet.named_parameters():if "layer4" in name:  # 假设只解冻ResNet的最后一层param.requires_grad = Trueelse:param.requires_grad = False# 调整学习率以适应微调
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, resnet.parameters()), lr=0.00001)# 继续进行模型的训练与微调

拓展部分:使用迁移学习进行图像分割任务

迁移学习不仅可以应用于分类任务,还可以应用于图像分割等更复杂的任务。通过调整预训练模型的结构,可以实现图像中的目标检测或分割。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/55635.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Streamlit:用Python快速构建交互式Web应用

在传统的Web开发中,开发者常常需要编写大量的前端和后端代码,才能实现一个简单的交互式Web应用。Streamlit 通过简化这一过程,使得你只需要用Python编写代码,就能快速创建具有丰富交互功能的Web应用。本文将介绍如何使用Streamlit…

Pikachu-SSRF(curl / file_get_content)

SSRF SSRF是Server-side Request Forge的缩写,中文翻译为服务端请求伪造。产生的原因是由于服务端提供了从其他服务器应用获取数据的功能且没有对地址和协议等做过滤和限制。常见的一个场景就是,通过用户输入的URL来获取图片。这个功能如果被恶意使用&am…

Ascend C 自定义算子开发:高效的算子实现

Ascend C 自定义算子开发:高效的算子实现 在 Ascend C 平台上,开发自定义算子能够充分发挥硬件的性能优势,帮助开发者针对不同的应用场景进行优化。本文将以 AddCustom 算子为例,介绍 Ascend C 中自定义算子的开发流程及关键技术…

如何制作一个企业网站,建设网站的基本步骤有哪些?

企业网站是企业的门面和名片,决定网民对企业的第一印象,因此,现在很多公司想做一个属于自己网站,但是不知道怎么做,更不知道从何做起,更别说做成了。为了能够让大家清楚如何做一个企业网站,现在…

Mysql数据库原理--查询收尾+索引+事务

文章目录 1.查询收尾1.1自查询1.2合并查询 2.索引事务2.1约束自动生成索引2.2create手动添加索引2.3.删除手动创建的索引2.4索引背后的数据结构2.5B树的结构特点和优点--经典面试题 3.事务--经典面试题3.1基本理解3.2事务的特性3.3隔离级别 1.查询收尾 1.1自查询 子查询就是套…

《CUDA编程》6.CUDA的内存组织

前面几章讲了一些编写高性能CUDA程序的要点,但还有很多其他需要注意的,其中最重要的就是合理的使用设备内存 1 CUDA的内存组织简介 现代计算机中的内存存在一种组织结构(hierachy),即不同类型的内存具有不同的容量和访问延迟(可以…

力扣203.移除链表元素

题目链接:203. 移除链表元素 - 力扣(LeetCode) 给你一个链表的头节点 head 和一个整数 val ,请你删除链表中所有满足 Node.val val 的节点,并返回 新的头节点 。 示例 1: 输入:head [1,2,6…

PDF怎么转换成TXT文本?这4个方法简单还免费,pdf转txt就靠它!

PDF怎么转换成TXT文本?PDF文件虽然广泛支持,但在某些设备或软件上可能无法完全正确显示,尤其是当文件包含特殊字体或复杂布局时。此外,PDF文件的阅读体验也可能受到格式干扰,如复杂的背景颜色或字体样式。将PDF转换为T…

python 实现最小路径和算法

最小路径和算法介绍 最小路径和问题通常指的是在一个网格(如二维数组)中,找到从起点(如左上角)到终点(如右下角)的一条路径,使得路径上经过的元素值之和最小。这类问题可以通过多种…

IDEA几大常用AI插件

文章目录 前言列表GPT中文版TalkXBito AIIDEA自带的AI 前言 最近AI、GPT特别火,IDEA里面又有一堆插件支持GPT,所以做个专题比较一下各个GPT插件 列表 先看idea的plugins里支持哪些,搜索“GPT”之后得到的,我用下来感觉第一第二和…

阿里云云虚拟主机SSL证书安装指南

在安装SSL证书的过程中,您需要确保已经正确获取了SSL证书文件,并且能够访问阿里云云虚拟主机的管理页面。以下是详细的步骤说明: 第一步:准备SSL证书 申请SSL证书:访问华测ctimall网站(https://www.ctimal…

Unite Barcelona主题演讲回顾:深入了解 Unity 6

本周,来自世界各地的 Unity 开发者齐聚西班牙巴塞罗那,参加 Unite 2024。本次大会的主题演讲持续了一个多小时,涵盖新功能的介绍、开发者成功案例的分享,以及在编辑器中进行的技术演示,重点展示了 Unity 6 在实际项目中…

Java | Leetcode Java题解之第457题环形数组是否存在循环

题目&#xff1a; 题解&#xff1a; class Solution {public boolean circularArrayLoop(int[] nums) {int n nums.length;for (int i 0; i < n; i) {if (nums[i] 0) {continue;}int slow i, fast next(nums, i);// 判断非零且方向相同while (nums[slow] * nums[fast]…

游戏开发指南:使用 UOS C# 云函数快速构建与部署服务端逻辑实战教学

零基础的服务端小白&#xff0c;现在也可以使用 Unity 结合 C# 来轻松搞定游戏服务端啦&#xff01; 在本篇文章中&#xff0c;我们将以游戏中的“抽卡”功能为例&#xff0c;展示如何使用 Unity Online Services&#xff08;UOS&#xff09;提供的强大 C# 云函数服务&#xf…

如何革新源代码保密?七大方法教你应对!

在数字化时代&#xff0c;源代码的安全保密对于企业而言至关重要&#xff0c;它不仅关系到企业的核心竞争力&#xff0c;还涉及到知识产权的保护。源代码一旦泄露&#xff0c;可能会给企业带来无法估量的损失。因此&#xff0c;采取有效的源代码保密措施&#xff0c;是每个企业…

焊接缺陷分割系统源码&数据集分享

焊接缺陷分割系统源码&#xff06;数据集分享 [yolov8-seg-C2f-DiverseBranchBlock&#xff06;yolov8-seg-C2f-DCNV3等50全套改进创新点发刊_一键训练教程_Web前端展示] 1.研究背景与意义 项目参考ILSVRC ImageNet Large Scale Visual Recognition Challenge 项目来源AAAI…

Django一分钟:在Django中怎么存储树形结构的数据,DRF校验递归嵌套模型的替代方案

引言 在开发过程中我们可能需要这样的树形结构: [{"data": {"name": "牛奶"},"children": [{"data": {"name": "蒙牛"}, },{"data": {"name": "伊利"}, }]},{"da…

如何使类目树与闭包表相结合

类目树与闭包表结合的教程 类目树和闭包表是非常常见的组合,特别是在处理带有层次关系的分类数据时,这种组合可以让查询和维护更高效。接下来,我们将详细讲解如何将类目树和闭包表结合起来,以实现企业级项目中的分类管理。 什么是类目树? 类目树是一种数据结构,它表示…

减少重复的请求之promise缓存池(构造器版) —— 缓存promise,多次promise等待并返回第一个promise的结果

减少重复的请求之promise缓存池 —— 缓存promise&#xff0c;多次promise等待并返回第一个promise的结果 背景简介 当一个业务组件初始化调用了接口&#xff0c;统一个页面多吃使用同一个组件&#xff0c;将会请求大量重复的接口 如果将promise当作一个普通的对象&#xff0…

LeetCode39:组合总和

题目&#xff1a; 给你一个 无重复元素 的整数数组 candidates 和一个目标整数 target &#xff0c;找出 candidates 中可以使数字和为目标数 target 的 所有 不同组合 &#xff0c;并以列表形式返回。你可以按 任意顺序 返回这些组合。 candidates 中的 同一个 数字可以 无限…