TensorFlow2实战-系列教程6:迁移学习实战

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

1、迁移学习

  • 用已经训练好模型的权重参数当做自己任务的模型权重初始化
  • 一般全连接层需要自己训练,可以选择是否训练已经训练好的特征提取层

一般情况下根据自己的任务,选择对那些网络进行微调和重新训练:
如果预训练模型的任务和自己任务非常接近,那可能只需要把最后的全连接层重新训练即可
如果自己任务的数据量比较小,那么应该选择重新训练少数层
如果自己任务的数据量比较大,可以适当多选择几层进行训练

2、猫狗识别

import os
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers
from tensorflow.keras import Model
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')

前面的内容和TensorFlow2实战-系列教程3:猫狗识别1完全一样

3、加载预训练模型

from tf.keras.applications.resnet import ResNet50
from tensorflow.keras.applications.resnet import ResNet101
from tensorflow.keras.applications.inception_v3 import InceptionV3

从keras中导入预训练模型,在TensorFlow的keras模块,有很多可以直接导入的预训练权重。

pre_trained_model = ResNet101(input_shape = (75, 75, 3),  include_top = False, weights = 'imagenet')
  • 加载导入的模型
  • input_shape 为输入大小
  • include_top为False就是表示不要最后的全连接层
  • 这段代码执行后,会自动进行下载

downloading data from
https://storage.googleapis.com/tensorflow/kerasapplications/resnet/resnet101_weights_tf_dim_ordering_tf_kernels_notop.h5
171446536/171446536 [==============================] - 15s 0us/step

for layer in pre_trained_model.layers:layer.trainable = False

选择要进行重新训练的层

4、callback模块

在 TensorFlow 中,回调(Callbacks)是一个强大的工具,用于在训练的不同阶段(例如在每个时代的开始和结束、在每个批次的处理前后)自定义和控制模型的行为,相当于一个监视器:

4.1 callback示例

callbacks = [
# 如果连续两个epoch还没降低就停止:tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
# 可以动态改变学习率:tf.keras.callbacks.LearningRateScheduler
# 保存模型:tf.keras.callbacks.ModelCheckpoint
# 自定义方法:tf.keras.callbacks.Callback
]

上面是一个模板,继续我们的猫狗识别的迁移学习项目:

4.2 定义callback

class myCallback(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs={}):if(logs.get('acc')>0.95):print("\nReached 95% accuracy so cancelling training!")self.model.stop_training = True
  1. 定义一个类,继承Callback
  2. 定义一个函数,传入epoch值和日志
  3. 从当前epoch的日志中取出准确率,如果准确率大于95%
  4. 打印信息
  5. 停止训练
from tensorflow.keras.optimizers import Adam
x = layers.Flatten()(pre_trained_model.output)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)                  
x = layers.Dense(1, activation='sigmoid')(x)           
model = Model(pre_trained_model.input, x) 
model.compile(optimizer = Adam(lr=0.001), loss = 'binary_crossentropy', metrics = ['acc'])
  1. 导入优化器
  2. 将预训练模型的输出展平为一维
  3. 定义一个1024的全连接层
  4. 在这层加入dropout
  5. 输出全连接层
  6. 构建模型
  7. 指定优化器、损失函数、验证方法等配置训练器

5、模型训练

定义需要重新训练的层

train_datagen = ImageDataGenerator(rescale = 1./255.,rotation_range = 40,width_shift_range = 0.2,height_shift_range = 0.2,shear_range = 0.2,zoom_range = 0.2,horizontal_flip = True)test_datagen = ImageDataGenerator( rescale = 1.0/255. )train_generator = train_datagen.flow_from_directory(train_dir,batch_size = 20,class_mode = 'binary', target_size = (75, 75))     validation_generator =  test_datagen.flow_from_directory( validation_dir,batch_size  = 20,class_mode  = 'binary', target_size = (75, 75))

前面的内容和TensorFlow2实战-系列教程3:猫狗识别1一样,制作数据

callbacks = myCallback()
history = model.fit_generator(train_generator,validation_data = validation_generator,steps_per_epoch = 100,epochs = 100,validation_steps = 50,verbose = 2,callbacks=[callbacks])

指定训练参数、数据、加入callback模块到模型中,执行训练,verbose = 2表示每次epoch记录一次日志

打印结果:

Epoch 99/100 100/100 - 76s - loss: 0.6138 - acc: 0.6655 - val_loss: 0.6570 - val_acc: 0.6900
Epoch 100/100 100/100 - 76s - loss: 0.5993 - acc: 0.6735 - val_loss: 0.7176 - val_acc: 0.6910

6、预测效果展示

import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(acc))plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()plt.figure()plt.plot(epochs, loss, 'b', label='Training Loss')
plt.plot(epochs, val_loss, 'r', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()

展示
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/653327.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【机器学习】工程实践问题概述

机器学习实际应用时的工程问题与面临的挑战 一、实现细节问题 1.1 训练样本 训练样本与标注对各类机器学习算法和模型的精度影响 训练样本的选择对各类机器学习算法和模型的影响 训练样本的优化 如何进行数据增强? 如何进行数据清洗? 样本的标注对各类机…

C语言每日一练之31

31 第三十一练 使用递归的方式求解第n位的斐波那契数列 要求: 1、输入整数n 2、输出第n位的斐波那契数列的值 第三十练答案 #include<stdio.h> int recur(int n) {if (n == 1)return 1;

数据结构(二)------单链表

制作不易&#xff0c;三连支持一下呗&#xff01;&#xff01;&#xff01; 文章目录 前言一.什么是链表二.链表的分类三.单链表的实现总结 前言 上一节&#xff0c;我们介绍了顺序表的实现与一些经典算法。 但是顺序表这个数据结构依然有不少缺陷&#xff1a; 1.顺序表指定…

导航页配置服务Dashy本地部署并实现公网远程访问

文章目录 简介1. 安装Dashy2. 安装cpolar3.配置公网访问地址4. 固定域名访问 简介 Dashy 是一个开源的自托管的导航页配置服务&#xff0c;具有易于使用的可视化编辑器、状态检查、小工具和主题等功能。你可以将自己常用的一些网站聚合起来放在一起&#xff0c;形成自己的导航…

基于springboot宠物领养系统

摘要 随着社会的不断发展和人们生活水平的提高&#xff0c;宠物在家庭中的地位逐渐上升&#xff0c;宠物领养成为一种流行的社会现象。为了更好地管理和促进宠物领养的过程&#xff0c;本文基于Spring Boot框架设计和实现了一套宠物领养系统。该系统以用户友好的界面为特点&…

时序分析中的去趋势化方法

时序分析中的去趋势化方法 时序分析是研究随时间变化的数据模式的一门学科。在时序数据中&#xff0c;趋势是一种随着时间推移而呈现的长期变化趋势&#xff0c;去趋势化是为了消除或减弱这种趋势&#xff0c;使数据更具平稳性。本文将简单介绍时序分析中常用的去趋势化方法&a…

跟着cherno手搓游戏引擎【13】着色器(shader)

创建着色器类&#xff1a; shader.h:初始化、绑定和解绑方法&#xff1a; #pragma once #include <string> namespace YOTO {class Shader {public:Shader(const std::string& vertexSrc, const std::string& fragmentSrc);~Shader();void Bind()const;void Un…

怎样自行搭建幻兽帕鲁游戏联机服务器?

幻兽帕鲁是一款深受玩家喜爱的多人在线游戏&#xff0c;为了获取更好的游戏体验&#xff0c;许多玩家希望能够自行搭建幻兽帕鲁游戏联机服务器&#xff0c;本文将指导大家如何自行搭建幻兽帕鲁游戏联机服务器。 自行搭建幻兽帕鲁游戏联机服务器&#xff0c;阿里云是一个不错的选…

了解云原生

一.什么是云原生 云原生是一种构建和运行程序的方法。云原生(Cloud Native&#xff09;是一个组合词&#xff0c;Cloud Native。Cloud表示应用程序位于云中&#xff0c;而不是传统的数据中心;Native表示应用程序从设计之初即考虑到云的环境。 二.云原生四要素 1.微服务 和微服…

结构体的增删查改

结构体&#xff0c;是为了解决生活中的一些不方便利用c语言自带数据类型来表示的问题。例如表示一个学生&#xff0c;那么学生这个个体假如用c语言自带数据类型怎么表示呢。可以使用名字&#xff0c;也就是字符数组&#xff1b;也可以使用学号&#xff0c;也就是int类型。但是这…

【解决方法】git pull报错ssh: connect to host github.com port 22: Connection timed out

问题 git pull ssh: connect to host github.com port 22: Connection timed out fatal: Could not read from remote repository.解决方法 在C:\Users\username.ssh文件夹下新建config文件&#xff0c;填入以下文本&#xff08;如有则直接在文件最后一行新增&#xff09;&am…

iOS 面试 Swift基础题

一、Swift 存储属性和计算属性比较&#xff1a; 存储型属性:用于存储一个常量或者变量 计算型属性: 计算性属性不直接存储值,而是用 get / set 来取值 和 赋值,可以操作其他属性的变化. 计算属性可以用于类、结构体和枚举&#xff0c;存储属性只能用于类和结构体。存储属性可…

【AIGC】Diffusers:加载管道、模型和调度程序

前言 拥有一种使用扩散系统进行推理的简单方法对于&#x1f9e8;扩散器至关重要。扩散系统通常由多个组件组成&#xff0c;例如参数化模型、分词器和调度器&#xff0c;它们以复杂的方式进行交互。这就是为什么我们设计了 DiffusionPipeline&#xff0c;将整个扩散系统的复杂性…

检测头篇 | 原创自研 | YOLOv8 更换 SEResNeXtBottleneck 头 | 附详细结构图

左图:ResNet 的一个模块。右图:复杂度大致相同的 ResNeXt 模块,基数(cardinality)为32。图中的一层表示为(输入通道数,滤波器大小,输出通道数)。 1. 思路 ResNeXt是微软研究院在2017年发表的成果。它的设计灵感来自于经典的ResNet模型,但ResNeXt有个特别之处:它采用…

HiveSQL题——用户连续登陆

目录 一、连续登陆 1.1 连续登陆3天以上的用户 0 问题描述 1 数据准备 2 数据分析 3 小结 1.2 每个用户历史至今连续登录的最大天数 0 问题描述 1 数据准备 2 数据分析 3 小结 1.3 每个用户连续登录的最大天数(间断也算) 0 问题描述 1 数据准备 2 数据分析 3 小…

qt信号与槽机制及使用demo

要在 Qt 中将 rclcomm 类与 MainWindow 连接&#xff0c;并使用 rcl->pose_uids 中的项更新 comboBox_model&#xff0c;您可以按照以下步骤操作&#xff1a; 信号与槽机制&#xff1a;Qt 使用信号和槽机制来处理事件和对象间通信。您可以在 rclcomm 类中定义一个信号&#…

MySQL-窗口函数 简单易懂

窗口函数 考查知识点&#xff1a; • 如何用窗口函数解决排名问题、Top N问题、前百分之N问题、累计问题、每组内比较问题、连续问题。 什么是窗口函数 窗口函数也叫作OLAP&#xff08;Online Analytical Processing&#xff0c;联机分析处理&#xff09;函数&#xff0c;可…

Python入门知识点分享——(十七)正则表达式和re模块

不好意思鸽了这么久&#xff0c;这几天备赛美赛没有太多时间写博客。好了闲话少叙&#xff0c;这次为大家带来的是正则表达式的相关介绍。正则表达式又叫做规则表达式,英文全称Regular Expression。是一种对字符串操作的逻辑公式&#xff0c;就是用事先定义好的一些特定字符、及…

RK3568平台 of 操作函数获取属性

一.of 操作函数获取属性 of_find_property 函数&#xff0c;用于在设备树中查找节点 下具有指定名称的属性。 struct property *of_find_property(const struct device_node *np, const char *name, int*lenp)np: 要查找的节点。 name: 要查找的属性的属性名。 lenp: 一个指…

Android 基础技术——列表卡顿问题如何分析解决

笔者希望做一个系列&#xff0c;整理 Android 基础技术&#xff0c;本章是关于列表卡顿问题如何分析解决 onBindViewHolder 优化 是否有耗时操作、重复创建对象、设置监听器、findViewByID、局部的动画对象等操作 是否存在内存泄漏 发生内存泄露&#xff0c;会导致一些不再使用…