Keras深度学习框架第二十九讲:在自定义训练循环中应用KerasTuner超参数优化

1、简介

在KerasTuner中,HyperModel类提供了一种方便的方式来在可重用对象中定义搜索空间。你可以通过重写HyperModel.build()方法来定义和进行模型的超参数调优。为了对训练过程进行超参数调优(例如,通过选择适当的批处理大小、训练轮数或数据增强设置),程序员可以重写HyperModel.fit()方法,在该方法中你可以访问:

  • hp对象,它是keras_tuner.HyperParameters的一个实例
  • 由HyperModel.build()构建的模型

在“开始使用KerasTuner”一文的“调整模型训练”部分中给出了一个基本示例。

2、自定义训练循环的超参数调优

本文将通过重写HyperModel.fit()方法来子类化HyperModel类,并编写一个自定义训练循环。如果你想了解如何使用Keras编写一个自定义训练循环,可以参考指南《从零开始编写训练循环》。

首先,我们导入所需的库,并为训练和验证创建数据集。在这里,我们仅使用随机数据作为演示目的。

import keras_tuner
import tensorflow as tf
import keras
import numpy as npx_train = np.random.rand(1000, 28, 28, 1)
y_train = np.random.randint(0, 10, (1000, 1))
x_val = np.random.rand(1000, 28, 28, 1)
y_val = np.random.randint(0, 10, (1000, 1))

接着,我们将HyperModel类子类化为MyHyperModel。在MyHyperModel.build()中,我们构建一个简单的Keras模型来进行10个不同类别的图像分类。MyHyperModel.fit()接受几个参数,其签名如下所示:

def fit(self, hp, model, x, y, validation_data, callbacks=None, **kwargs):

hp 参数用于定义超参数。
model 参数是由 MyHyperModel.build() 返回的模型。
x, y, 和 validation_data 都是自定义参数。稍后我们将通过调用 tuner.search(x=x, y=y, validation_data=(x_val, y_val)) 来传递我们的数据给它们。你可以定义任意数量的这些参数并给它们自定义的名称。
callbacks 参数原本是为了与 model.fit() 一起使用的。KerasTuner 在其中放置了一些有用的 Keras 回调,例如,在模型最佳轮次时保存模型的回调。

在自定义训练循环中,我们将手动调用这些回调。但在调用它们之前,我们需要使用以下代码将我们的模型分配给它们,以便它们可以访问模型以进行保存。

for callback in callbacks:callback.model = model

在这个例子中,我们只调用了回调的 on_epoch_end() 方法来帮助我们保存模型的最佳状态。如果需要,你也可以调用其他回调方法。如果你不需要保存模型,那么你就不需要使用回调。

在自定义训练循环中,我们将通过将NumPy数据包装成tf.data.Dataset来调优数据集的批处理大小。请注意,你也可以在这里调优任何预处理步骤。此外,我们还调优了优化器的学习率。

我们将使用验证损失作为模型的评估指标。为了计算平均验证损失,我们将使用keras.metrics.Mean(),它在批次之间平均验证损失。我们需要返回验证损失,以便Tuner可以记录它。

class MyHyperModel(keras_tuner.HyperModel):def build(self, hp):"""Builds a convolutional model."""inputs = keras.Input(shape=(28, 28, 1))x = keras.layers.Flatten()(inputs)x = keras.layers.Dense(units=hp.Choice("units", [32, 64, 128]), activation="relu")(x)outputs = keras.layers.Dense(10)(x)return keras.Model(inputs=inputs, outputs=outputs)def fit(self, hp, model, x, y, validation_data, callbacks=None, **kwargs):# Convert the datasets to tf.data.Dataset.batch_size = hp.Int("batch_size", 32, 128, step=32, default=64)train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(batch_size)validation_data = tf.data.Dataset.from_tensor_slices(validation_data).batch(batch_size)# Define the optimizer.optimizer = keras.optimizers.Adam(hp.Float("learning_rate", 1e-4, 1e-2, sampling="log", default=1e-3))loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)# The metric to track validation loss.epoch_loss_metric = keras.metrics.Mean()# Function to run the train step.@tf.functiondef run_train_step(images, labels):with tf.GradientTape() as tape:logits = model(images)loss = loss_fn(labels, logits)# Add any regularization losses.if model.losses:loss += tf.math.add_n(model.losses)gradients = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(gradients, model.trainable_variables))# Function to run the validation step.@tf.functiondef run_val_step(images, labels):logits = model(images)loss = loss_fn(labels, logits)# Update the metric.epoch_loss_metric.update_state(loss)# Assign the model to the callbacks.for callback in callbacks:callback.set_model(model)# Record the best validation loss valuebest_epoch_loss = float("inf")# The custom training loop.for epoch in range(2):print(f"Epoch: {epoch}")# Iterate the training data to run the training step.for images, labels in train_ds:run_train_step(images, labels)# Iterate the validation data to run the validation step.for images, labels in validation_data:run_val_step(images, labels)# Calling the callbacks after epoch.epoch_loss = float(epoch_loss_metric.result().numpy())for callback in callbacks:# The "my_metric" is the objective passed to the tuner.callback.on_epoch_end(epoch, logs={"my_metric": epoch_loss})epoch_loss_metric.reset_state()print(f"Epoch loss: {epoch_loss}")best_epoch_loss = min(best_epoch_loss, epoch_loss)# Return the evaluation metric value.return best_epoch_loss

现在,我们可以初始化Tuner了。在这里,我们使用Objective("my_metric", "min")作为需要最小化的指标。目标名称应该与你在传递给回调的on_epoch_end()方法的日志中使用的键一致。回调需要使用日志中的这个值来找到最佳的epoch以保存模型的检查点。

换句话说,当你自定义训练循环并决定在每个epoch结束时记录一些指标时,你需要确保你传递给on_epoch_end()方法的日志中包含一个键(例如"my_metric"),该键与你在Tuner中定义的Objective的名称相匹配。这样,Tuner就可以使用这个指标来跟踪模型性能的变化,并决定何时保存最佳的模型检查点。

在上面的例子中,如果我们在每个epoch结束时计算了验证损失,并将其作为"val_loss"键传递给on_epoch_end()方法,那么我们需要在初始化Tuner时使用Objective("val_loss", "min"),因为我们的目标是找到具有最小验证损失的epoch。

tuner = keras_tuner.RandomSearch(objective=keras_tuner.Objective("my_metric", "min"),max_trials=2,hypermodel=MyHyperModel(),directory="results",project_name="custom_training",overwrite=True,
)

我们通过将我们在MyHyperModel.fit()方法的签名中定义的参数传递给tuner.search()来开始搜索。

tuner.search(x=x_train, y=y_train, validation_data=(x_val, y_val))

最后,我们可以检索结果。

在Keras Tuner中,一旦tuner.search()方法执行完毕,你就可以从Tuner对象中检索最佳模型、最佳超参数配置以及搜索结果的历史记录。这些结果可以帮助你理解模型性能如何随着超参数的变化而变化,并为你提供最佳的模型配置以进行进一步的应用或部署。

通常,你可以使用tuner.get_best_models()来获取一个或多个最佳模型,使用tuner.get_best_hyperparameters()来获取最佳超参数配置,以及使用tuner.results_summary()来查看搜索结果的摘要。

best_hps = tuner.get_best_hyperparameters()[0]
print(best_hps.values)best_model = tuner.get_best_models()[0]
best_model.summary()

3、总结

使用Keras Tuner进行自定义训练循环超参数调优的过程可以大致分为以下几个步骤:

3.1. 安装Keras Tuner

首先,确保你已经安装了Keras Tuner库。可以使用pip进行安装:

pip install keras-tuner

3.2. 定义继承自keras_tuner.HyperModel的类

你需要定义一个继承自keras_tuner.HyperModel的类,并在其中定义buildfit方法。

  • build方法:用于定义模型的架构,并使用hp参数设置超参数的搜索空间。
  • fit方法:用于模型的训练过程,它接受hp参数以及训练数据和其他必要的参数。
import tensorflow as tf
from tensorflow.keras import layers
from keras_tuner import HyperModelclass MyHyperModel(HyperModel):def build(self, hp):model = tf.keras.Sequential()# 示例:定义含可调参数的全连接层hp_units = hp.Int('units', min_value=32, max_value=512, step=32)model.add(layers.Dense(units=hp_units, activation='relu'))# ... 其他层 ...model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])return modeldef fit(self, hp, x_train, y_train, **kwargs):model = self.build(hp)model.fit(x_train, y_train, epochs=10, **kwargs)# 假设这里只进行一轮训练作为示例,实际中可能需要多轮return {'loss': model.evaluate(x_train, y_train)[0], 'accuracy': model.evaluate(x_train, y_train)[1]}

3.3. 准备数据和回调

准备好你的训练数据和验证数据,以及可能需要的回调函数(如模型保存、早停等)。

3.4. 使用Tuner进行搜索

实例化你的Tuner类(如RandomSearchHyperband等),并传入你的HyperModel、数据以及搜索的目标(如最小化验证损失)。

from keras_tuner import RandomSearchtuner = RandomSearch(MyHyperModel(),objective='val_loss',max_trials=10,  # 搜索的最大试验次数executions_per_trial=3,  # 每个试验的重复次数directory='my_dir',  # 结果保存目录project_name='my_project'
)tuner.search(x_train, y_train,validation_data=(x_val, y_val),epochs=10,  # 注意这里的epochs仅用于fit方法中的一轮训练callbacks=[...])  # 可能的回调,如ModelCheckpoint

3.5. 检索结果

搜索完成后,你可以从Tuner对象中检索最佳模型、最佳超参数配置以及搜索结果的历史记录。

best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
print(f'Best hyperparameters: {best_hps.values}')best_model = tuner.get_best_models(num_models=1)[0]
# 使用best_model进行预测或进一步评估

3.6. 可视化结果

Keras Tuner提供了丰富的可视化支持,你可以使用TensorBoard等工具来查看搜索过程的详细结果。

使用Keras Tuner进行自定义训练循环的超参数调优涉及安装Keras Tuner库,定义继承自HyperModel的类并实现其build和fit方法,准备训练数据和验证数据以及可能的回调。随后,实例化Tuner类并传入定义的HyperModel和数据,开始搜索最佳超参数组合。搜索完成后,可以通过Tuner的接口检索到最佳的超参数和模型。整个调优过程中需要注意设置合理的搜索空间、试验次数,并使用独立的验证集来评估模型性能,最后可以利用可视化工具查看调优结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/16463.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

mysql中连接的原理

大家好。我们在日常开发中经常会遇到多表联查的场景。今天我来为大家讲一下我们在进行多表联查时,表与表之间连接的原理。 为了方便讲解,我们先创建两个表,并填充一些数据。 如图所示,我创建了t1、t2两张表,每张表中…

四大运营商大流量卡测评,手机卡,物联网卡,纯流量卡

买大流量卡,看4个方面 优惠时间。有的只是12个月,24个月有优惠【可以先用一年,然后注销】通用流量。而不是定向流量全国通话分钟数。而不是亲情通话分钟数销户方式。是否支持随时销户,异地销户,线上销户,额…

火箭升空AR虚拟三维仿真演示满足客户的多样化场景需求

在航空工业的协同研发领域,航空AR工业装配系统公司凭借前沿的AR增强现实技术,正引领一场革新。通过将虚拟信息无缝融入实际环境中,我们为工程师、设计师和技术专家提供了前所未有的共享和审查三维模型的能力,极大地提升了研发效率…

stream-基本流

定义 一般流中:都是以Object对象存储的,基本流中是将数据作为基本类型存储的,空间占用率更低,效率更高基本流只有三种:int、long、double基本流也有一些特有的方法 // 基本流 有三种 IntStream LongStream DoubleStrea…

使用Prometheus组件node_exporter采集linux系统的指标数据(包括cpu/内存/磁盘/网络)

一、背景 Linux系统的基本指标包括cpu、内存、磁盘、网络等,其中网络可以细分为带宽进出口流量、连接数和tcp监控等。 本文使用Prometheus组件node_exporter采集,存储在promethues,展示在grafana面板。 二、安装node_exporter 1、下载至本…

【数学建模】碎纸片的拼接复原

2013高教社杯全国大学生数学建模竞赛B题 问题一模型一模型二条件设立思路 问题求解 问题一 已知 d i d_i di​为第 i i i张图片图片的像素矩阵 已知 d i d_i di​都是 n ∗ m n*m n∗m二维矩阵 假设有 N N N张图片 模型一 我们认为对应位置像素匹配为 d i [ j ] [ 1 ] d k…

C++:单例模型、强制类型转换

目录 特殊类的设计不能被拷贝的类实现一个类,只能在堆上实例化的对象实现一个类,只能在栈上实例化的对象不能被继承的类 单例模式饿汉模式懒汉模式饿汉模式与懒汉模式的对比饿汉优缺点懒汉优缺点懒汉模式简化版本(C11) 单例释放问…

索引失效的场景有哪些?

一.概念 索引失效是指在查询时,数据库引擎无法使用索引来加速查询,从而导致查询性能下降。常见的索引失效原因有以下几种: 索引列没有被包含在查询条件中。如果查询条件中没有包含索引列,那么数据库引擎无法使用索引来加速查询。…

机器人--路径--bezier

教学文章 链接1 链接2--逼近拟合 路径 路径由控制点定义,这些控制点将路径描述为一系列链接的线段。 路径控制点 将路径控制点连接起来,就是路径。 Bezier 曲线的初衷就是用尽可能少的数据表示出复杂的图形。 皮埃尔贝塞尔的想法是,在设…

域提权漏洞系列分析-Zerologon漏洞分析

2020年08⽉11⽇,Windows官⽅发布了 NetLogon 特权提升漏洞的⻛险通告,该漏洞编号为CVE-2020-1472,漏 洞等级:严重,漏洞评分:10分,该漏洞也称为“Zerologon”,2020年9⽉11⽇&#xff…

WinRAR技巧:如何让多个文件压缩到更小!?

但我们要压缩多个文件的时候,可能会出现压缩后的体积仍然过大,或者需要将文件再压缩到更小,这种情况下,小编之前建议过大家将文件压缩成7z格式就会更加压缩体积。今天分享另一个技巧,帮助我们将多个文件压缩到更小。 …

Istio ICA考试之路---4-3

Istio ICA考试之路---4-3 1. 题目2. 解题2.1 获取模板2.2 修改yaml 1. 题目 Using Kubernetes context cluster-2 Create an authorization policy named "allow-get" in the namespace policy-3, allowing all GET requests from workloads in the default names…

【网络安全】勒索软件ShrinkLocker使用 windows系统安全工具BitLocker实施攻击

文章目录 威胁无不不在BitLocker 概述如何利用BitLocker进行攻击如何降低影响Win11 24H2 装机默认开启 BitLocker推荐阅读 威胁无不不在 网络攻击的形式不断发展,即便是合法的 Windows 安全功能也会成为黑客的攻击工具。 卡巴斯基实验室专家 发现 使用BitLocker的…

以不变应万变:在复杂世界中保持初心,坚持原则

在这个日新月异、瞬息万变的世界里,人情世故也显得尤为复杂。我们常常会因为忙碌的生活、工作压力以及人际关系的纠葛而感到迷茫和疲惫。在面对这些复杂局面的同时,如何保持内心的平静,坚持自己的原则,并在变幻莫测的环境中持续成…

ClickHouse架构概览 —— Clickhouse 架构篇(一)

文章目录 前言Clickhouse 架构简介Clickhouse 的核心抽象列和字段数据类型块表 Clickhouse 的运作过程数据插入过程数据查询过程数据更新和删除过程 前言 本文介绍了ClickHouse的整体架构,并对ClickHouse中的一些重要的抽象对象进行了分析。然后此基础上&#xff0…

乘风破浪,创维汽车旗舰店落户安徽

2024年5月19日,创维汽车宣城家奇体验中心盛大开业。宣城市委办公室副主任师典雅、市投资促进局副局长金崇学、经开区管委会副主任汤晓峰、宣城市通信局局长梁登峰、创维汽车战区总经理刘俊、创维汽车大区总监王大明等人出席此次开业盛典,共同见证了创维汽…

内网穿透实现公网访问自己搭建的Ollma架构的AI服务器

内网穿透实现公网访问自己搭建的Ollma架构的AI服务器 [2024年5月9号我发布一个博文关于搭建本地AI服务器的博文][https://blog.csdn.net/weixin_41905135/article/details/138588043?spm1001.2014.3001.5501],今天我们内网穿透实现从公网访问我的本地AI服务器&…

Julia Ide Neptune

一 pkg> add Neptune julia> using Neptune julia> Neptune.run() 二 pkg> add Pluto julia> import Pluto julia> Pluto.run() 主要是装IJulia总报错,索性找了如上替代品。

全国多地入夏!对抗“高温高湿”约克VRF中央空调有妙招

随着气温飙升,北京、上海、广州、南京、天津、江苏、新疆、内蒙古部分地区等多地进入夏季状态,华北、黄淮等不少地方最高气温都超过了30℃,大街上人们短袖、短裤纷纷上阵,一派夏日炎炎的景象。 炎热夏季不仅高温频频来袭,往往还伴随着降雨带来的潮湿,天气湿热交织容易让人们身…

C++第二十弹---深入理解STL中vector的使用

✨个人主页: 熬夜学编程的小林 💗系列专栏: 【C语言详解】 【数据结构详解】【C详解】 目录 1、vector的介绍 2、vector的使用 2.1、构造函数和赋值重载 2.1.1、构造函数的介绍 2.1.2、代码演示 2.2、容量操作 2.3、遍历 2.4、增删…