pytorch之诗词生成5--train

先上代码:


import tensorflow as tf
from dataset import PoetryDataGenerator, poetry, tokenizer
from model import model
import settings
import utilsclass Evaluate(tf.keras.callbacks.Callback):"""在每个epoch训练完成后,保留最优权重,并随机生成settings.SHOW_NUM首古诗展示"""def __init__(self):super().__init__()# 给loss赋一个较大的初始值self.lowest = 1e10def on_epoch_end(self, epoch, logs=None):# 在每个epoch训练完成后调用# 如果当前loss更低,就保存当前模型参数if logs['loss'] <= self.lowest:self.lowest = logs['loss']model.save(settings.BEST_MODEL_PATH)# 随机生成几首古体诗测试,查看训练效果print(tokenizer.id_to_token((3)))print(tokenizer.id_to_token(2))print(tokenizer.id_to_token(1))print(tokenizer.id_to_token(0))for i in range(settings.SHOW_NUM):print(utils.generate_random_poetry(tokenizer, model))# 创建数据集
data_generator = PoetryDataGenerator(poetry, random=True)
# 开始训练
model.fit_generator(data_generator.for_fit(), steps_per_epoch=data_generator.steps, epochs=settings.TRAIN_EPOCHS,callbacks=[Evaluate()])

接下来我们开始分析代码:

class Evaluate(tf.keras.callbacks.Callback):

定义了一个回调类Evaluate,它是tf.keras.callbacks.Callback类的子类,回调函数是在训练的不同阶段调用的函数,用于执行额外的操作或监控模型的性能。

def __init__(self):super().__init__()# 给loss赋一个较大的初始值self.lowest = 1e10

这是Evaluate类的构造函数__init__(self),在这个构造函数中,有以下操作:

  • super().__init__():调用父类tf.keras.callbacks.Callback的构造函数,确保父类的初始化操作得到执行。
  • self.lwest=1e10:将lowest属性初始化为一个较大的值1e10。这个属性用于跟踪最低的损失值。通常将其初始化为一个较大的值,确保在训练过程的初始阶段,任何较小的损失值都可以成为新的最新值。
def on_epoch_end(self, epoch, logs=None):# 在每个epoch训练完成后调用# 如果当前loss更低,就保存当前模型参数if logs['loss'] <= self.lowest:self.lowest = logs['loss']model.save(settings.BEST_MODEL_PATH)# 随机生成几首古体诗测试,查看训练效果print(tokenizer.id_to_token((3)))print(tokenizer.id_to_token(2))print(tokenizer.id_to_token(1))print(tokenizer.id_to_token(0))for i in range(settings.SHOW_NUM):print(utils.generate_random_poetry(tokenizer, model))

在TensorFlow的回调函数中,logs是一个字典,其中包含了训练过程中的各种指标和损失值。它提供了一些有关模型的信息,可以用于监控和记录训练的进程。我们初始化我们的logs为空值,也就是没有记录任何信息。

logs['loss']表示访问logs字典的loss键。(由于函数是在每个epoch训练完成之后使用,训练之后logs就保存了模型的信息)。同时,如果损失值低于我们的预设值(第一轮),就将最低损失值进行更新。

然后使用模型的save方法,将模型的各种参数都保存到我们给定的路径中去。

然后我们就输出SHOW_NUM首我们生成的古诗。

# 创建数据集
data_generator = PoetryDataGenerator(poetry, random=True)
# 开始训练
model.fit_generator(data_generator.for_fit(), steps_per_epoch=data_generator.steps, epochs=settings.TRAIN_EPOCHS,callbacks=[Evaluate()])

这段代码调用PoetryDataGenerator类,我们将poetry传入模型,并进行随机打乱。

开始训练,使用fit_generator方法来训练模型。是模型对象的方法,用于使用生成器进行模型训练。它适用于数据较大无法一次加载到内存的情况,可以按照批次从生成器中获取数据进行训练。

steps_pre_epoch表示每个时钟周期加载多少个批次的数据进行训练。

callbacks=[Evaluate()]表示在训练模型过程中使用Evalute()函数,并将其作为一个回调函数传递给callbacks函数。回调函数是在训练的过程中特定时间被调用的函数,用于执行一些额外的操作,回调函数是在每个训练周期结束后被调用,在每轮训练的on_epoch_end事件中,回调函数会被触发并执行相应的操作。这意味着在每个训练周期结束时,回调函数会被调用用以执行自定义的评估操作。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/746442.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

将PostgreSQL插件移植到openGauss指导

1 概述 PostgreSQL 社区提供了丰富的插件&#xff0c;但由于 openGauss 和 PostgreSQL 存在一定的差异&#xff0c;如线程/进程模型、系统表和视图等&#xff0c;无法直接为 openGauss 所用&#xff0c;不可避免的需要在插件上做整改。 本文档主要对 Postgresql 插件移植到 o…

面试官:说说C++的引用和指针有什么区别

C中的引用和指针虽然都是用于间接访问和操作对象的工具&#xff0c;但它们之间存在几个重要的区别&#xff1a; 本质和存在性&#xff1a; 指针是一个变量&#xff0c;它存储了另一个变量的地址。指针有自己的内存地址&#xff0c;并且可以改变其指向的内容。 引用是一个别名&a…

springboot271制造装备物联及生产管理ERP系统

制造装备物联及生产管理ERP系统设计与实现 摘 要 传统办法管理信息首先需要花费的时间比较多&#xff0c;其次数据出错率比较高&#xff0c;而且对错误的数据进行更改也比较困难&#xff0c;最后&#xff0c;检索数据费事费力。因此&#xff0c;在计算机上安装制造装备物联及…

3月14日,每日信息差

&#x1f396; 素材来源官方媒体/网络新闻 &#x1f384; 5.5G通信网络在海南投入商用&#xff0c;较5G提升10倍 &#x1f30d; 国务院批复同意&#xff0c;珠海港口岸将整合并扩大开放 &#x1f30b; 同有科技&#xff1a;正在研究新型磁电存储技术 &#x1f381; 美国折扣零售…

考研模拟面试-答案【攻略】

考研模拟面试-答案【攻略】 前言版权推荐考研模拟面试-答案前面的问题通用问题专业题数据结构计算机网络操作系统数据库网络安全 手写题数据结构操作系统计算机网络 代码题基础代码题其他代码题 后面的问题补充题目 基础代码题答案链栈循环队列1循环队列2哈希表 最后 前言 202…

Oracle基础-分组查询 备份

一、概述 数据分组的目的是用来汇总数据或为整个分组显示单行的汇总信息&#xff0c;通常在查询结果集中使用GROUP BY 子句对记录进行分组。在SELECT 语句中&#xff0c;GROUP BY 子句位于FROM 子句之后&#xff0c;语法格式&#xff1a; SELECT columns_list FROM table_nam…

【测试知识】业务面试问答突击版1

高内聚低耦合 高内聚指的是将相关的功能或数据组织在一起&#xff0c;使得模块内部的各个元素紧密地联系在一起&#xff0c;完成特定的任务。 低耦合指的是模块之间的依赖关系尽可能地降低&#xff0c;模块之间的接口简单清晰&#xff0c;减少模块之间的相互影响。 文章目录 整…

【数据结构】二叉搜索树底层刨析

文章目录 1. 二叉搜索树的实现2. 二叉搜索树的应用3. 改造二叉搜索树为 KV 结构4. 二叉搜索树的性能分析 1. 二叉搜索树的实现 namespace key {template<class K>struct BSTreeNode{typedef BSTreeNode<K> Node;Node* _left;Node* _right;K _key;BSTreeNode(const…

工作中用到的 —— 工作总结提炼出来的股文

这里是目录 ---------------- VUE相关 -----------------1 - Vue3 是怎么得更快的&#xff1f;1-1 Fragment [frɡˈment]1-2 Suspense [səˈspens]1-3 Teleport [ˈtelipɔːt]1-4 v-memo 2- 说一下 Composition API3- 说一下 setup4- watch 和 watchEffect 的区别5- Vue3 响…

Sublime查看ANSI编码文档乱码问题

原因为没有安装对应的解码插件。 选择安装插件包 选择插件包&#xff1a;ConvertToUTF8或者GBK&#xff0c;我试了第一个插件包不行&#xff0c;安装GBK插件包后OK。

Git如何清除账户凭证

场景&#xff1a;一般发生在Git用户变更的情况 1.git base 操作 Git会使用凭证助手 credential.helper来储存账户凭证&#xff0c;通过以下命令移除&#xff1a; git config --system --unset credential.helper 除了system系统级外&#xff0c;还有 global、local范围。 查…

20万英文单词同义词宝典ACCESS\EXCEL数据库

英语同义词反义词的数据之前搞到过《近万英语单词同义词典ACCESS数据库》、《上百万英语同义反义词词典ACCESS数据库》&#xff0c;今天又搞到一份几十万行数据的&#xff0c;发上来看看有没有适合朋友们的需求。 今天这个数据提供了非常全的词汇单词以及词汇对应的含义以及近…

将Java项目Jar包制作成Docker镜像

文章目录 前言一、准备事项二、使用步骤1.Dockerfile脚本2.制作镜像推送Harbor仓库前言 以前单体项目通常采用传统部署方式将项目打成Jar包再进行部署。如果我们项目是微服务则需要进行Docker容器部署。本文将介绍如何在本地将Jar包制作成Docker镜像并推送到Harbor仓库 一、准…

Spring揭秘:ClassPathScanningProvider接口应用场景及实现原理!

技术应用场景 ClassPathScanningCandidateComponentProvider是Spring框架中一个非常核心的类&#xff0c;它主要用于在类路径下扫描并发现带有特定注解的组件&#xff0c;支持诸如ComponentScan、Component、Service、Repository和Controller等注解的自动扫描和注册。 ClassP…

Mysql 无法启动,mysql-bin.日志丢失删除处理

在linux操作系统中&#xff0c;当mysql无法启动时候&#xff0c;先看日志 2024-03-15T05:20:16.352075Z 0 [Warning] [MY-000081] [Server] option max_allowed_packet: unsigned value 107374182400 adjusted to 1073741824. 2024-03-15T05:20:16.352156Z 0 [Warning] [MY-010…

Marshmallow,一个有点甜的Python库

前言 在许多场景中&#xff0c;我们常常需要执行Python对象的序列化、反序列化操作。例如&#xff0c;在开发REST API时&#xff0c;或者在进行一些面向对象化的数据加载和保存时&#xff0c;这一功能经常派上用场。 经常cv Python代码的臭宝&#xff0c;接触最多的应该是通过…

验证与分享执行计划突变引发的问题

作者简介 张瑞远&#xff0c;曾经从事银行、证券数仓设计、开发、优化类工作&#xff0c;现主要从事电信级IT系统及数据库的规划设计、架构设计、运维实施、运维服务、故障处理、性能优化等工作。 持有Orale OCM,MySQL OCP及国产代表数据库认证。 获得的专业技能与认证包括 Oce…

被军训到的两天

1.gradle7.6.1 1.安装gradle7.6.1,一定要注意的是&#xff0c;使用的JDK是否能用&#xff0c;比如gradle7.6.1用的是JDK11。 2. F:/sofer....是Gradle自己的仓库地址&#xff0c;注意不能和maven使用一样的仓库。 使用specified location,可以避免下本项目的gradle版本&…

如何更改SonarQube的JDK版本

如何更改SonarQube的JDK版本 当需要升级或更换SonarQube所使用的JDK版本时&#xff0c;可以按照以下步骤进行操作&#xff1a; 第一步&#xff1a;确定新JDK的安装路径 首先&#xff0c;您需要找到您打算使用的JDK的安装路径。这通常是一个包含JDK各种工具和库的文件夹。请确…

ego - 人工智能原生 3D 模拟引擎——基于AI的3D引擎,可以做游戏、空间计算、元宇宙等项目

1. 产品概述:Ego是一款AI本地化的3D模拟引擎,旨在让非技术创作者通过自然语言生成逼真的角色、3D世界和交互式脚本。该平台提供了创建和分享游戏、虚拟世界和交互体验的功能。 2. 定位:Ego定位于解决开放世界游戏和模拟的三大难题:难以编写游戏脚本、非玩家角色无法展现人…