ray.rllib 入门实践-5: 训练算法

        前面的博客介绍了ray.rllib中算法的配置和构建,也包含了算法训练的代码。 但是rllib中实现算法训练的方式不止一种,本博客对此进行介绍。很多教程使用 PPOTrainer 进行训练,但是 PPOTrainer 在最近的 ray 版本中已经取消了。

方式1: algo.train()

        rllib 中的 Algorithm 类自带了.train() 函数,实现算法训练,前面几个博客教程均是采用的这种方式。这里仅再提供一下示例, 不再赘述:

import os 
from ray.rllib.algorithms.ppo import PPO,PPOConfig
from ray.tune.logger import pretty_print## 配置算法
storage_path = "F:/codes/RLlib_study/ray_results/build_method_3"
config = PPOConfig()
config = config.rollouts(num_rollout_workers=2)
config = config.resources(num_gpus=0,num_cpus_per_worker=1,num_gpus_per_worker=0)
config = config.environment(env="CartPole-v1",env_config={})
config.output = storage_path  ## 设置过程文件的存储路径## 构建算法
algo = config.build()## 训练
for i in range(3):result = algo.train() print(f"episode_{i}")

方式2:tune.Tuner()

        以上方式只能固定训练超参数,不能对训练超参数寻优。ray中还有一个模块 tune, 专门用于算法训练过程中超参数调参。

        在使用tune.Tuner()执行rllib算法训练时, 可以默认为tune背后自动执行了以下操作:

algo = PPOConfig().build() ## 构建算法

result = algo.train()  ## 算法训练

print(pretty_print(result))  ## 每完成一次algo.train, 打印一次阶段性训练结果

algo.save_checkpoint()  ## 保存训练模型

并且遍历了多个超参数组合,多次进行训练。直到达到停止训练的条件(自己配置)。 

        基于tune的rllib训练示例如下(代码篇幅比较大是因为添加的功能模块和注释比较多,后面的介绍主要以 方式一为主,所以对于这种方式,这里介绍的多一些):

import ray 
from ray.rllib.algorithms.ppo import PPO,PPOConfig 
from ray import train, tune
import torch 
import os 
import shutil 
from ray.tune.logger import pretty_print 
import gymnasium as gym ray.init()####  配置算法  ####
config = PPOConfig()
config = config.training(lr=tune.grid_search([0.01, 0.001]))
config = config.environment(env="CartPole-v1")####  配置 tune  ###### 准备 tune 的 stop_condition,多个条件之间是”或“的关系。有一个满足即停止训练。 
##   'episode_reward_mean'关键字将在 ”ray-2.40”版中中被抛弃,
##   届时需要用 'env_runners/episode_return_mean' 替代 'episode_reward_mean'
stop_condition = {'episode_reward_mean':10,    ## 这里设置的结束条件很宽松,所以能够快速结束训练。"training_iteration":3}## 准备 tune 的过程文件存储路径
storage_path = "F:/codes/RLlib_study/ray_results"
os.makedirs(storage_path, exist_ok=True)## 准备 tune 的 checkpoint_config
##    tune 默认保存每个 algo.train() 训练得到的 checkpoint.
##    通过以下配置,可以对此进行自定义修改
checkpoint_config =  train.CheckpointConfig(num_to_keep=None, ## 保存几个checkpoint, None 表示保存所有checkpointcheckpoint_at_end=True) ## 是否在训练结束后保存 checkpoint. ## 配置 tuner
tuner = tune.Tuner(PPO,                                       ## 需要是一个 rllib 的 Algorithm 类, 从 ray.rllib.algorithms 导入, 也可以是自定义的,后面介绍                   run_config = train.RunConfig(stop = stop_condition, checkpoint_config = checkpoint_config,  ## 用于设置保存哪个checkpoint. storage_path = storage_path,            ## 如果不设置, 默认存储路径是 “~/ray-results” 或 “C:/用户/xxx/ray_results”),param_space=config,  ## 这里定义了参数搜索调优空间,是一个 PPOConfig 对象
)## 执行训练                         
results = tuner.fit() ## tuner 返回一个Result表格对象,该对象允许进一步分析训练结果并检索经过训练的智能体的checkpoint。
print("====训练结束====")## 获取最佳训练结果
best_result = results.get_best_result(metric="episode_reward_mean", mode="max")
## 以 "episode_reward_mean" 为选择指标, 从results里面选择checkpoint, 选择模式是“max”,
## 'episode_reward_mean'关键字将在ray-2.40版中中被抛弃,届时需要用 'env_runners/episode_return_mean' 替代 'episode_reward_mean'## 从最佳训练结果中提取对应的 checkpoint , 并保存
checkpoint_save_dir = "F:/codes/RLlib_study/ray_results/best_checkpoints"
os.makedirs(checkpoint_save_dir, exist_ok=True)best_checkpoint = best_result.checkpoint 
if best_checkpoint:with best_checkpoint.as_directory() as checkpoint_dir:print(f"====最佳模型路径位于:{checkpoint_dir}====")## 把最佳模型转存到指定位置。 shutil.rmtree(checkpoint_save_dir)shutil.copytree(checkpoint_dir,checkpoint_save_dir)print(f"====保存最佳模型到:{checkpoint_save_dir}====")## 加载保存的最佳模型
checkpoint_dir = "F:/codes/RLlib_study/ray_results/best_checkpoints"
algo = PPO.from_checkpoint(checkpoint_dir)
print(f"==== 加载最佳模型: {checkpoint_dir}")## evaluate 模型
env_name = "CartPole-v1"
env = gym.make(env_name)## 模型推断: method-1
step = 0
episode_reward = 0
terminated = truncated = False obs,info = env.reset()
while not terminated and not truncated:action = algo.compute_single_action(obs)obs, reward, terminated, truncated, info = env.step(action)episode_reward += rewardstep += 1print(f"step = {step}, reward = {reward}, action = {action}, obs = {obs}, episode_reward = {episode_reward}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/69111.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uart、iic、spi通信总线

一、uart uart一种异步串行通信协议,用于在两个设备之间传输数据。它将数据按位发送,不需要时钟信号进行同步。在uart通信中,数据通过两根线路传输:发送线(TX)和接收线(RX)。它主要用…

LMI Gocator GO_SDK VS2019引用配置

LMI SDK在VS2019中的引用是真的坑爹,总结一下经验,希望后来的人能少走弯路.大致内容如下: (1) 环境变量 (2)C/C 附加包含目录 E:\GWQ\Gocator\GO_SDK\Gocator\GoSdk E:\GWQ\Gocator\GO_SDK\Platform\kApi (3&#…

QT QTableWidget控件 全面详解

本系列文章全面的介绍了QT中的57种控件的使用方法以及示例,包括 Button(PushButton、toolButton、radioButton、checkBox、commandLinkButton、buttonBox)、Layouts(verticalLayout、horizontalLayout、gridLayout、formLayout)、Spacers(verticalSpacer、horizontalSpacer)、…

C# OpenCV机器视觉:红外体温检测

在一个骄阳似火的夏日,全球却被一场突如其来的疫情阴霾笼罩。阿强所在的小镇,平日里熙熙攘攘的街道变得冷冷清清,人们戴着口罩,行色匆匆,眼神中满是对病毒的恐惧。阿强作为镇上小有名气的科技达人,看着这一…

12、MySQL锁相关知识

目录 1、全局锁和表锁使用场景 2、行锁的意义 3、为什么说间隙锁解决了快照的幻读? 4、RR隔离级别产生幻读的场景 5、详解元数据锁(MDL)作用以及如何减少元数据锁 6、出现死锁场景 7、查看MySQL锁情况 8、自增锁 1、全局锁和表锁使用场景 全局锁 备份数据库:当需要…

立创开发板入门ESP32C3第八课 修改AI大模型接口为deepseek3接口

#原代码用的AI模型是minimax的API接口,现在试着改成最热门的deepseek3接口。# 首先按理解所得,在main文件夹下,有minimax.c和minimax.h, 它们是这个API接口的头文件和实现文件,然后在main.c中被调用。所以我们一步步更改。 申请…

2025.1.21——六、BUU XSS COURSE 1 XSS漏洞|XSS平台搭建

题目来源:buuctf BUU XSS COURSE 1 目录 一、打开靶机,整理信息 二、解题思路 step 1:输入框尝试一下 step 2:开始xss注入 step 3:搭建平台 step 4:利用管理员cookie访问地址 三、小结 二编&#…

第20篇:Python 开发进阶:使用Django进行Web开发详解

第20篇:使用Django进行Web开发 内容简介 在上一篇文章中,我们深入探讨了Flask框架的高级功能,并通过构建一个博客系统展示了其实际应用。本篇文章将转向Django,另一个功能强大且广泛使用的Python Web框架。我们将介绍Django的核…

操作无法完成,因为文件已经在Electronic Team Virtual Serial Port Driver Service中打开

报错 操作无法完成,因为文件已经在Electronic Team Virtual Serial Port Driver Service中打开 现象 这个exe文件无法删除 解决办法 按下WinR, 找到Electronic Team Virtual Serial Port Driver Service,右击停止. 再次尝试删除,发现这个exe文件成功删除!

单值二叉树(C语言详解版)

一、摘要 今天要讲的是leetcode单值二叉树,这里用到的C语言,主要提供的是思路,大家看了我的思路之后可以点击链接自己试一下。 二、题目简介 如果二叉树每个节点都具有相同的值,那么该二叉树就是单值二叉树。 只有给定的树是单…

【多表查询】

目录 一. 一对多二. 一对一 and 多对多三. 多表设计案例四. 多表查询4.1 概述4.2 内连接与外连接4.3 子查询4.4 案例 \quad 一. 一对多 \quad 删除外键 \quad 二. 一对一 and 多对多 \quad \quad 三. 多表设计案例 \quad 一个员工对应多个工作经历 \quad 四. 多表查询 \quad \q…

CentOS 7 搭建lsyncd实现文件实时同步 —— 筑梦之路

在 CentOS 7 上搭建 lsyncd(Live Syncing Daemon)以实现文件的实时同步,可以按照以下步骤进行操作。lsyncd 是一个基于 inotify 的轻量级实时同步工具,支持本地和远程同步。以下是详细的安装和配置步骤: 1. 系统准备 …

[Dialog屏幕开发] Table Control 列数据操作

阅读该篇文章之前,可先阅读下述资料 [Dialog屏幕开发] 屏幕绘制(Table Control控件)https://blog.csdn.net/Hudas/article/details/145314623?spm1001.2014.3001.5501https://blog.csdn.net/Hudas/article/details/145314623?spm1001.2014.3001.5501上篇文章我们…

Arduino大师练成手册 -- 读取DHT11

要在 Arduino 上控制 DHT11 温湿度传感器,你可以按照以下步骤进行: 硬件连接: 将 DHT11 的 VCC 引脚连接到 Arduino 的 5V 引脚。 将 DHT11 的 GND 引脚连接到 Arduino 的 GND 引脚。 将 DHT11 的 DATA 引脚连接到 Arduino 的数字引脚&am…

leetcode刷题记录(八十九)——35. 搜索插入位置

(一)问题描述 35. 搜索插入位置 - 力扣(LeetCode)35. 搜索插入位置 - 给定一个排序数组和一个目标值,在数组中找到目标值,并返回其索引。如果目标值不存在于数组中,返回它将会被按顺序插入的位…

-bash: ./uninstall.command: /bin/sh^M: 坏的解释器: 没有那个文件或目录

终端报错: -bash: ./uninstall.command: /bin/sh^M: 坏的解释器: 没有那个文件或目录原因:由于文件行尾符不匹配导致的。当脚本文件在Windows环境中创建或编辑后,行尾符为CRLF(即回车和换行,\r\n)&#xf…

渐变颜色怎么调?

渐变颜色的调整是设计中非常重要的一部分,尤其是在创建具有视觉吸引力和深度感的设计作品时。以下是一些在不同设计软件中调整渐变颜色的详细步骤和技巧: 一、Adobe Photoshop 1. 创建渐变 打开渐变工具: 选择工具栏中的“渐变工具”&#x…

安装wxFormBuilder

1. 网址:GitHub - wxFormBuilder/wxFormBuilder: A wxWidgets GUI Builder 2. 安装MSYS2 MSYS2可以在GitHub的内容中找到,这个版本是32位64位的 3. 在程序中打开MINGW64 shell 4. 在MSYS2 MINGW64 shell中输入 pacman -Syu pacman -S ${MINGW_PACKAGE…

在 Windows 系统上,将 Ubuntu 从 C 盘 迁移到 D 盘

在 Windows 系统上,如果你使用的是 WSL(Windows Subsystem for Linux)并安装了 Ubuntu,你可以将 Ubuntu 从 C 盘 迁移到 D 盘。迁移过程涉及导出当前的 Ubuntu 发行版,然后将其导入到 D 盘的目标目录。以下是详细的步骤…

【知识】可视化理解git中的cherry-pick、merge、rebase

转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn] 如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~ 这三个确实非常像,以至于对于初学者来说比较难理解。 总结对比 先给出对比: 特性git mergegit rebasegit cherry-pick功能合并…