政安晨:【Keras机器学习实践要点】(七)—— 使用TensorFlow自定义fit()

政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras实战演绎机器学习

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

在TensorFlow中,fit()是一个非常强大和常用的训练函数,它可以批次地训练模型并监测其性能。虽然fit()提供了很多有用的默认行为,但有时您可能想自定义fit()中发生的操作。

一种自定义fit()的方法是使用回调函数。回调函数是在训练过程中的特定时间点被调用的函数,您可以编写自己的回调函数来执行特定的操作。

TensorFlow提供了许多内置的回调函数,如EarlyStoppingCallback、ModelCheckpointCallback等,您也可以编写自己的自定义回调函数。通过将回调函数传递给fit()函数的callbacks参数,您可以自定义在每个训练批次或训练周期结束时发生的操作。

另一种自定义fit()的方法是编写自定义的训练循环。

默认情况下,fit()函数使用单个步骤来执行训练循环,但您可以重写这个步骤来实现自己的训练逻辑。您可以使用TensorFlow的GradientTape来手动计算梯度,并使用优化器来更新模型的权重。通过这种方式,您可以完全控制训练过程中的每个步骤。

最后,您还可以自定义fit()的行为通过设置fit()函数的其他参数。

例如,您可以通过设置batch_size参数来定义每个训练批次的大小,或者通过设置epochs参数来指定训练周期的数量。除了这些参数,您还可以设置其他参数,如学习率、损失函数等,以进一步自定义fit()的行为。

总而言之,TensorFlow提供了多种方法来自定义fit()函数中发生的操作。无论您是通过回调函数、自定义训练循环还是设置fit()的参数,您都可以根据自己的需求来定制训练过程。

这使得TensorFlow成为一个非常灵活和强大的深度学习框架。

今天我们讲的就是keras api 在这里的应用方法


前言

当你进行监督学习时,你可以使用fit(),一切都运行得很顺利。

当你需要控制每一个细节时,你可以完全从头开始编写自己的训练循环。

但是如果你需要一个自定义的训练算法,但仍然希望从fit()的便利功能中受益,比如回调函数,内置的分发支持或步骤融合,该怎么办呢?

Keras的一个核心原则是逐步揭示复杂性。您应该总是能够逐渐进入更低级的工作流程。如果高级功能与您的使用情况不完全匹配,您不应该感到掉入悬崖。您应该能够在保留相应数量高级便利的同时获得对细节的更多控制。

当你需要自定义fit()函数的行为时,你应该重写Model类的训练步骤函数。这是fit()函数在处理每个数据批次时调用的函数。然后,你仍然可以像往常一样调用fit()函数 - 它将会运行你自己的学习算法。

请注意这种模式并不妨碍您使用函数式API构建模型。无论您是构建顺序模型、函数式API模型还是子类化模型,都可以使用这种模式。

现在让我们看看这一切都是怎么工作的吧。

导入

import os# This guide can only be run with the TF backend.
os.environ["KERAS_BACKEND"] = "tensorflow"import tensorflow as tf
import keras
from keras import layers
import numpy as np

来一个简单例子

我们创建一个新的类,继承自keras.Model。

我们只需重写train_step(self, data)方法。

我们返回一个将指标名称(包括损失)映射到它们当前值的字典。

输入参数data是传递给fit作为训练数据的内容:

如果你通过调用fit(x, y, ...)传递NumPy数组,则data将是元组(x, y)

如果你通过调用fit(dataset, ...)传递tf.data.Dataset,则data将是每个批次由数据集产生的内容。

在train_step()方法的主体中,我们实现了一个常规的训练更新,与你已经熟悉的类似。重要的是,我们通过self.compute_loss()计算损失,它包装了传递给compile()的损失函数。

类似地,我们对self.metrics中的指标调用metric.update_state(y, y_pred)来更新在compile()中传递的指标的状态,并在最后从self.metrics查询结果以获取它们的当前值。

class CustomModel(keras.Model):def train_step(self, data):# Unpack the data. Its structure depends on your model and# on what you pass to `fit()`.x, y = datawith tf.GradientTape() as tape:y_pred = self(x, training=True)  # Forward pass# Compute the loss value# (the loss function is configured in `compile()`)loss = self.compute_loss(y=y, y_pred=y_pred)# Compute gradientstrainable_vars = self.trainable_variablesgradients = tape.gradient(loss, trainable_vars)# Update weightsself.optimizer.apply(gradients, trainable_vars)# Update metrics (includes the metric that tracks the loss)for metric in self.metrics:if metric.name == "loss":metric.update_state(loss)else:metric.update_state(y, y_pred)# Return a dict mapping metric names to current valuereturn {m.name: m.result() for m in self.metrics}

现在让咱们试试:

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)

下沉到更低的级别

当然,你可以在compile()中跳过传递损失函数,而是在train_step中手动完成所有操作。同样,在指标方面也是如此。

这是一个低级示例,只使用compile()配置优化器:

我们首先创建度量实例来跟踪我们的损失和MAE得分(在__init__()中)。

我们实现了一个自定义的train_step()函数,该函数通过调用update_state()来更新这些度量的状态,然后通过调用result()来查询它们的当前平均值,以便在进度条中显示并传递给任何回调函数。

注意,我们需要在每个epoch之间调用reset_states()来重置我们的度量!

否则,调用result()将返回训练开始以来的平均值,而我们通常使用每个epoch的平均值。

幸运的是,框架可以为我们做到这一点:只需将要重置的任何度量列在模型的metrics属性中。

模型将在每次fit()的epoch开始或调用evaluate()的开始时调用此处列出的任何对象的reset_states()函数。

class CustomModel(keras.Model):def __init__(self, *args, **kwargs):super().__init__(*args, **kwargs)self.loss_tracker = keras.metrics.Mean(name="loss")self.mae_metric = keras.metrics.MeanAbsoluteError(name="mae")self.loss_fn = keras.losses.MeanSquaredError()def train_step(self, data):x, y = datawith tf.GradientTape() as tape:y_pred = self(x, training=True)  # Forward pass# Compute our own lossloss = self.loss_fn(y, y_pred)# Compute gradientstrainable_vars = self.trainable_variablesgradients = tape.gradient(loss, trainable_vars)# Update weightsself.optimizer.apply(gradients, trainable_vars)# Compute our own metricsself.loss_tracker.update_state(loss)self.mae_metric.update_state(y, y_pred)return {"loss": self.loss_tracker.result(),"mae": self.mae_metric.result(),}@propertydef metrics(self):# We list our `Metric` objects here so that `reset_states()` can be# called automatically at the start of each epoch# or at the start of `evaluate()`.return [self.loss_tracker, self.mae_metric]# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)# We don't pass a loss or metrics here.
model.compile(optimizer="adam")# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)

支持样本权重和类别权重

您可能已经注意到,我们的第一个基本示例没有提及样本加权。

如果您想支持fit()函数的sample_weight和class_weight参数,只需按照以下步骤操作:

从data参数中解包sample_weight 将其传递给compute_loss和update_state函数(当然,如果您不依赖compile()函数来计算损失和指标,您也可以手动应用它) 就是这样。

class CustomModel(keras.Model):def train_step(self, data):# Unpack the data. Its structure depends on your model and# on what you pass to `fit()`.if len(data) == 3:x, y, sample_weight = dataelse:sample_weight = Nonex, y = datawith tf.GradientTape() as tape:y_pred = self(x, training=True)  # Forward pass# Compute the loss value.# The loss function is configured in `compile()`.loss = self.compute_loss(y=y,y_pred=y_pred,sample_weight=sample_weight,)# Compute gradientstrainable_vars = self.trainable_variablesgradients = tape.gradient(loss, trainable_vars)# Update weightsself.optimizer.apply(gradients, trainable_vars)# Update the metrics.# Metrics are configured in `compile()`.for metric in self.metrics:if metric.name == "loss":metric.update_state(loss)else:metric.update_state(y, y_pred, sample_weight=sample_weight)# Return a dict mapping metric names to current value.# Note that it will include the loss (tracked in self.metrics).return {m.name: m.result() for m in self.metrics}# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])# You can now use sample_weight argument
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)

提供您自己的评估步骤

如果你想对model.evaluate()的调用做同样的操作,那么你可以通过同样的方式覆盖test_step。下面是实现的样例:

class CustomModel(keras.Model):def test_step(self, data):# Unpack the datax, y = data# Compute predictionsy_pred = self(x, training=False)# Updates the metrics tracking the lossloss = self.compute_loss(y=y, y_pred=y_pred)# Update the metrics.for metric in self.metrics:if metric.name == "loss":metric.update_state(loss)else:metric.update_state(y, y_pred)# Return a dict mapping metric names to current value.# Note that it will include the loss (tracked in self.metrics).return {m.name: m.result() for m in self.metrics}# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(loss="mse", metrics=["mae"])# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)

总结:一个端到端的GAN示例

让我们通过一个端到端的示例来演示你刚刚学到的所有内容。

让我们考虑以下内容:

一个生成器网络,用于生成28x28x1的图像。

一个判别器网络,用于将28x28x1的图像分为两类("假"和"真")。

每个网络都有一个优化器。 一个损失函数,用于训练判别器。

# Create the discriminator
discriminator = keras.Sequential([keras.Input(shape=(28, 28, 1)),layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),layers.LeakyReLU(negative_slope=0.2),layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),layers.LeakyReLU(negative_slope=0.2),layers.GlobalMaxPooling2D(),layers.Dense(1),],name="discriminator",
)# Create the generator
latent_dim = 128
generator = keras.Sequential([keras.Input(shape=(latent_dim,)),# We want to generate 128 coefficients to reshape into a 7x7x128 maplayers.Dense(7 * 7 * 128),layers.LeakyReLU(negative_slope=0.2),layers.Reshape((7, 7, 128)),layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),layers.LeakyReLU(negative_slope=0.2),layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),layers.LeakyReLU(negative_slope=0.2),layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),],name="generator",
)

这是一个完整的GAN类,它重写了compile()方法以使用自己的签名,并在train_step中使用17行代码实现了整个GAN算法。

class GAN(keras.Model):def __init__(self, discriminator, generator, latent_dim):super().__init__()self.discriminator = discriminatorself.generator = generatorself.latent_dim = latent_dimself.d_loss_tracker = keras.metrics.Mean(name="d_loss")self.g_loss_tracker = keras.metrics.Mean(name="g_loss")self.seed_generator = keras.random.SeedGenerator(1337)@propertydef metrics(self):return [self.d_loss_tracker, self.g_loss_tracker]def compile(self, d_optimizer, g_optimizer, loss_fn):super().compile()self.d_optimizer = d_optimizerself.g_optimizer = g_optimizerself.loss_fn = loss_fndef train_step(self, real_images):if isinstance(real_images, tuple):real_images = real_images[0]# Sample random points in the latent spacebatch_size = tf.shape(real_images)[0]random_latent_vectors = keras.random.normal(shape=(batch_size, self.latent_dim), seed=self.seed_generator)# Decode them to fake imagesgenerated_images = self.generator(random_latent_vectors)# Combine them with real imagescombined_images = tf.concat([generated_images, real_images], axis=0)# Assemble labels discriminating real from fake imageslabels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)# Add random noise to the labels - important trick!labels += 0.05 * keras.random.uniform(tf.shape(labels), seed=self.seed_generator)# Train the discriminatorwith tf.GradientTape() as tape:predictions = self.discriminator(combined_images)d_loss = self.loss_fn(labels, predictions)grads = tape.gradient(d_loss, self.discriminator.trainable_weights)self.d_optimizer.apply(grads, self.discriminator.trainable_weights)# Sample random points in the latent spacerandom_latent_vectors = keras.random.normal(shape=(batch_size, self.latent_dim), seed=self.seed_generator)# Assemble labels that say "all real images"misleading_labels = tf.zeros((batch_size, 1))# Train the generator (note that we should *not* update the weights# of the discriminator)!with tf.GradientTape() as tape:predictions = self.discriminator(self.generator(random_latent_vectors))g_loss = self.loss_fn(misleading_labels, predictions)grads = tape.gradient(g_loss, self.generator.trainable_weights)self.g_optimizer.apply(grads, self.generator.trainable_weights)# Update metrics and return their value.self.d_loss_tracker.update_state(d_loss)self.g_loss_tracker.update_state(g_loss)return {"d_loss": self.d_loss_tracker.result(),"g_loss": self.g_loss_tracker.result(),}

让我们试试吧:

# Prepare the dataset. We use both the training & test MNIST digits.
batch_size = 64
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(batch_size)gan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
)# To limit the execution time, we only train on 100 batches. You can train on
# the entire dataset. You will need about 20 epochs to get nice results.
gan.fit(dataset.take(100), epochs=1)

深度学习背后的思想很简单吧。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/778241.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python+Django+Yolov5路面墙体桥梁裂缝特征检测识别html网页前后端

程序示例精选 PythonDjangoYolov5路面墙体桥梁裂缝特征检测识别html网页前后端 如需安装运行环境或远程调试,见文章底部个人QQ名片,由专业技术人员远程协助! 前言 这篇博客针对《PythonDjangoYolov5路面墙体桥梁裂缝特征检测识别html网页前…

Parade Series - SVG Resource

iconfont https://www.iconfont.cn/?spma313x.search_index.i3.2.74e53a819tkkcG音符 <div class"form-group"><a href"Javascript:reload();" class"btn btn-icon btn-outline-light btn-block" style";"><svg t&q…

打造快乐成长的乐园:探索少儿教育项目的魅力

在当今社会&#xff0c;家长们越来越重视孩子的全面发展和个性培养&#xff0c;少儿教育项目因其独特的魅力吸引着越来越多的关注。本文将探讨少儿教育项目的特点、重要性&#xff0c;以及如何打造一个快乐成长的教育乐园。 ### 少儿教育项目的价值 少儿教育项目不仅仅是传授…

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之九 简单闪烁效果

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之九 简单闪烁效果 目录 Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单实战案例 之九 简单闪烁效果 一、简单介绍 二、简单闪烁效果实现原理 三、简单闪烁效果案例实现简单步骤 四、注意事项 一、简单…

【开发篇】十二、GCeasy报告分析

文章目录 1、图一&#xff1a;正常情况2、图二&#xff1a;缓存对象过多3、图三&#xff1a;内存泄漏4、图四&#xff1a;频繁持续Full GC5、图五&#xff1a;元空间不足导致的Full GC 1、图一&#xff1a;正常情况 正常的堆内存如图&#xff1a; 锯齿状对象创建后内存占用上…

基础算法-去重字符串,辗转相除法,非递归前序遍历二叉树题型分析

目录 不同子串 辗转相除法-求最大公约数 二叉树非递归前序遍历 不同子串 从a开始&#xff0c;截取 a aa aaa aaab 从第二个下标开始a aa aab 从第三个 a ab 从第四个 b 使用set的唯一性&#xff0c;然后暴力遍历来去去重&#xff0c;从第一个下标开始截取aaab a aa aaa aaab…

ES学习日记(三)-------第三方插件选择

前言 在学习和使用Elasticsearch的过程中&#xff0c;必不可少需要通过一些工具查看es的运行状态以及数据。如果都是通过rest请求&#xff0c;未免太过麻烦&#xff0c;而且也不够人性化。 目前我了解的比较主流的插件就三个,head,cerebor和elasticHD 1.head 老牌插件,功能…

原生js实现循环滚动效果

原生js实现如下图循环滚动效果 核心代码 <div class"scroll"><div class"blist" id"scrollContainer"><div class"bitem"></div>......<div class"bitem"></div></div> </di…

Long long类型比较大小

long 与 Long long类型和Long类型是不一样&#xff0c;long类型属于基本的数据类型&#xff0c;而Long是long类型的包装类。 结论 long是基本数据类型&#xff0c;判断是否相等时使用 &#xff0c;即可判断值是否相等。&#xff08;基本数据类型没有equals()方法&#xff0…

局域网找不到共享电脑怎么办?

局域网找不到共享电脑是一种常见的问题&#xff0c;给我们的共享与合作带来一定的困扰。天联组网技术可以解决这个问题。本文将介绍天联组网的原理和优势&#xff0c;并探讨其在解决局域网找不到共享电脑问题中的应用。 天联组网的原理和优势 天联组网是一种基于加速服务器的远…

基于Pytorch的验证码识别模型应用

前言 在做OCR文字识别的时候&#xff0c;或多或少会接触一些验证码图片&#xff0c;这里收集了一些验证码图片&#xff0c;可以对验证码进行识别&#xff0c;可以识别4到6位&#xff0c;纯数字型、数字字母型和纯字母型的一些验证码&#xff0c;准确率还是相当高&#xff0c;需…

STM32 PWM通过RC低通滤波转双极性SPWM测试

STM32 PWM通过RC低通滤波转双极性SPWM测试 &#x1f4cd;参考内容《利用是stm32cubemx实现双极性spwm调制 基于stm32f407vet6》&#x1f4fa;相关视频链接&#xff1a;https://www.bilibili.com/video/BV16S4y147hB/?spm_id_from333.788 双极性SPWM调制讲解以及基于stm32的代码…

基于 RisingWave 和 ScyllaDB 构建事件驱动应用

概览 在构建事件驱动应用时&#xff0c;人们面临着两大挑战&#xff1a;1&#xff09;低延迟处理大量数据&#xff1b;2&#xff09;实现流数据的实时摄取和转换。 结合 RisingWave 的流处理功能和 ScyllaDB 的高性能 NoSQL 数据库&#xff0c;可为构建事件驱动应用和数据管道…

使用html实现图片相册展示设计

<!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>图片&#xff08;相册&#xff09;展示设计</title><link rel"stylesheet" href"./style.css"> </head> <b…

YOLOv8改进 | 检测头篇 | 2024最新HyCTAS模型提出SAttention(自研轻量化检测头 -> 适用分割、Pose、目标检测)

一、本文介绍 本文给大家带来的改进机制是由全新SOTA分割模型(Real-Time Image Segmentation via Hybrid Convolutional-TransformerArchitecture Search)HyCTAS提出的一种SelfAttention注意力机制,论文中叫该机制应用于检测头当中(论文中的分割效果展现目前是最好的)。我…

【Ubuntu】Ubuntu LTS 稳定版更新策略

1、确保下载环境 sudo apt update && sudo apt upgrade -y sudo apt autoremove 2、安装更新管理器 sudo apt install update-manager-core -y 3、设置只更新稳定版 sudo vim /etc/update-manager/release-upgrades 4、开始更新&#xff0c;耐心等待 sudo do-re…

深入浅出的揭秘游标尺模式与迭代器模式的神秘面纱 ✨

​&#x1f308; 个人主页&#xff1a;danci_ &#x1f525; 系列专栏&#xff1a;《设计模式》 &#x1f4aa;&#x1f3fb; 制定明确可量化的目标&#xff0c;坚持默默的做事。 &#x1f680; 转载自&#xff1a;设计模式深度解析&#xff1a;深入浅出的揭秘游标尺模式与迭代…

【测试开发学习历程】Python数据类型:字符串-str(上)

目录 1 Python中的引号 2 字符串的声明 3 字符串的切片 4 字符串的常用函数 4.1 len()函数 4.2 ord()函数 4.3 chr()函数 5 字符串的常用方法&#xff08;内置方法/内建方法&#xff09; 5.1 find()方法 5.2 index()方法 5.3 rfind()方法 5.4 rindex()方法 1 Python…

SAP-CO主数据之统计指标创建-<KK01>

公告&#xff1a;周一至周五每日一更&#xff0c;周六日存稿&#xff0c;请您点“关注”和“在看”&#xff0c;后续推送的时候不至于看不到每日更新内容&#xff0c;感谢。 目录 一、背景&#xff1a; 成本中心主数据创建&#xff1a;传送门 成本要素主数据创建&#xff1…

Linux内核之最核心数据结构之二:struct inode(三十一)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a;多媒…