PyTorch 深度学习实战(24):分层强化学习(HRL)

一、分层强化学习原理

1. 分层学习核心思想

分层强化学习(Hierarchical Reinforcement Learning, HRL)通过时间抽象任务分解解决复杂长程任务。核心思想是:

对比维度传统强化学习分层强化学习
策略结构单一策略直接输出动作高层策略选择选项(Option)
时间尺度单一步长决策高层策略决策跨度长,底层策略执行
适用场景简单短程任务复杂长程任务(如迷宫导航、机器人操控)
2. Option-Critic 算法框架

Option-Critic 是 HRL 的代表性算法,其核心组件包括:


二、Option-Critic 实现步骤(基于 Gymnasium)

我们将以 Meta-World 机械臂多阶段任务 为例,实现 Option-Critic 算法:

  1. 定义选项集合:包含 reach(接近目标)、grasp(抓取)、move(移动) 三个选项

  2. 构建策略网络:高层策略 + 选项内部策略 + 终止条件网络

  3. 分层交互训练:高层选择选项,底层执行多步动作

  4. 联合梯度更新:优化高层和底层策略


三、代码实现

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical, Normal
import gymnasium as gym
from metaworld.envs import ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE
import time
​
# ================== 配置参数优化 ==================
class OptionCriticConfig:num_options = 3                  # 选项数量(reach, grasp, move)option_length = 20               # 选项最大执行步长hidden_dim = 128                 # 网络隐藏层维度lr_high = 1e-4                   # 高层策略学习率lr_option = 3e-4                 # 选项策略学习率gamma = 0.99                     # 折扣因子entropy_weight = 0.01            # 熵正则化权重max_episodes = 5000              # 最大训练回合数device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
​
# ================== 高层策略网络 ==================
class HighLevelPolicy(nn.Module):def __init__(self, state_dim, num_options):super().__init__()self.net = nn.Sequential(nn.Linear(state_dim, OptionCriticConfig.hidden_dim),nn.ReLU(),nn.Linear(OptionCriticConfig.hidden_dim, num_options))def forward(self, state):return self.net(state)
​
# ================== 选项内部策略网络 ==================
class OptionPolicy(nn.Module):def __init__(self, state_dim, action_dim):super().__init__()self.net = nn.Sequential(nn.Linear(state_dim, OptionCriticConfig.hidden_dim),nn.ReLU(),nn.Linear(OptionCriticConfig.hidden_dim, action_dim))def forward(self, state):return self.net(state)
​
# ================== 终止条件网络 ==================
class TerminationNetwork(nn.Module):def __init__(self, state_dim):super().__init__()self.net = nn.Sequential(nn.Linear(state_dim, OptionCriticConfig.hidden_dim),nn.ReLU(),nn.Linear(OptionCriticConfig.hidden_dim, 1),nn.Sigmoid()  # 输出终止概率)def forward(self, state):return self.net(state)
​
# ================== 训练系统 ==================
class OptionCriticTrainer:def __init__(self):# 初始化环境self.env = ALL_V2_ENVIRONMENTS_GOAL_OBSERVABLE['pick-place-v2-goal-observable']()# 处理观测空间if isinstance(self.env.observation_space, gym.spaces.Dict):self.state_dim = sum([self.env.observation_space.spaces[key].shape[0] for key in ['observation', 'desired_goal']])self.process_state = self._process_dict_stateelse:self.state_dim = self.env.observation_space.shape[0]self.process_state = lambda x: xself.action_dim = self.env.action_space.shape[0]# 初始化网络self.high_policy = HighLevelPolicy(self.state_dim, OptionCriticConfig.num_options).to(OptionCriticConfig.device)self.option_policies = nn.ModuleList([OptionPolicy(self.state_dim, self.action_dim).to(OptionCriticConfig.device)for _ in range(OptionCriticConfig.num_options)])self.termination_networks = nn.ModuleList([TerminationNetwork(self.state_dim).to(OptionCriticConfig.device)for _ in range(OptionCriticConfig.num_options)])# 优化器self.optimizer_high = optim.Adam(self.high_policy.parameters(), lr=OptionCriticConfig.lr_high)self.optimizer_option = optim.Adam(list(self.option_policies.parameters()) + list(self.termination_networks.parameters()),lr=OptionCriticConfig.lr_option)def _process_dict_state(self, state_dict):return np.concatenate([state_dict['observation'], state_dict['desired_goal']])def select_option(self, state):state = torch.FloatTensor(state).to(OptionCriticConfig.device)logits = self.high_policy(state)dist = Categorical(logits=logits)option = dist.sample()return option.item(), dist.log_prob(option)def select_action(self, state, option):state = torch.FloatTensor(state).to(OptionCriticConfig.device)action_mean = self.option_policies[option](state)dist = Normal(action_mean, torch.ones_like(action_mean))  # 假设动作空间连续action = dist.sample()log_prob = dist.log_prob(action).sum(dim=-1)  # 沿最后一个维度求和得到标量return action.cpu().numpy(), log_prob  # 返回标量log概率def should_terminate(self, state, current_option):state = torch.FloatTensor(state).to(OptionCriticConfig.device)terminate_prob = self.termination_networks[current_option](state)return torch.bernoulli(terminate_prob).item() == 1def train(self):for episode in range(OptionCriticConfig.max_episodes):state_dict, _ = self.env.reset()state = self.process_state(state_dict)episode_reward = 0current_option, log_prob_high = self.select_option(state)option_step = 0while True:# 执行选项内部策略action, log_prob_option = self.select_action(state, current_option)next_state_dict, reward, terminated, truncated, _ = self.env.step(action)done = terminated or truncatednext_state = self.process_state(next_state_dict)episode_reward += reward# 判断是否终止选项terminate = self.should_terminate(next_state, current_option) or (option_step >= OptionCriticConfig.option_length)# 计算梯度if terminate or done:# 计算选项价值(添加detach防止梯度传递)with torch.no_grad():next_value = self.high_policy(torch.FloatTensor(next_state).to(OptionCriticConfig.device)).max().item()termination_output = self.termination_networks[current_option](torch.FloatTensor(state).to(OptionCriticConfig.device))# 计算delta时分离终止网络的梯度delta = reward + OptionCriticConfig.gamma * next_value - termination_output.detach()
​# 高层策略梯度计算loss_high = -log_prob_high * deltaself.optimizer_high.zero_grad()loss_high.backward(retain_graph=True)  # 保留计算图self.optimizer_high.step()
​# 选项策略梯度计算loss_option = -log_prob_option * deltaentropy = -log_prob_option * torch.exp(log_prob_option.detach())loss_option_total = loss_option + OptionCriticConfig.entropy_weight * entropyself.optimizer_option.zero_grad()loss_option_total.backward()  # 此时仍可访问保留的计算图self.optimizer_option.step()# 重置选项if not done:current_option, log_prob_high = self.select_option(next_state)option_step = 0else:breakelse:option_step += 1state = next_stateif (episode + 1) % 100 == 0:print(f"Episode {episode+1} | Reward: {episode_reward:.1f}")
​
if __name__ == "__main__":start = time.time()start_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(start))print(f"开始时间: {start_str}")print("初始化环境...")trainer = OptionCriticTrainer()trainer.train()end = time.time()end_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(end))print(f"训练完成时间: {end_str}")print(f"训练完成,耗时: {end - start:.2f}秒")

四、关键代码解析

  1. 高层策略选择选项

    select_option:基于当前状态选择选项,返回选项 ID 和选择概率的对数值。
  2. 选项内部策略执行

    select_action:根据当前选项生成动作,支持连续动作空间(使用高斯分布)。
  3. 终止条件判断

    should_terminate:根据终止网络输出概率判断是否终止当前选项。
  4. 梯度更新逻辑

    高层策略:基于选项的价值差(TD Error)更新。
    选项策略:结合 TD Error 和熵正则化更新。

五、训练输出示例

开始时间: 2025-03-24 08:29:46
初始化环境...
Episode 100 | Reward: 2.7
Episode 200 | Reward: 4.9
Episode 300 | Reward: 2.2
Episode 400 | Reward: 2.8
Episode 500 | Reward: 3.0
Episode 600 | Reward: 3.3
Episode 700 | Reward: 3.2
Episode 800 | Reward: 4.7
Episode 900 | Reward: 5.3
Episode 1000 | Reward: 7.5
Episode 1100 | Reward: 6.3
Episode 1200 | Reward: 3.7
Episode 1300 | Reward: 7.8
Episode 1400 | Reward: 3.8
Episode 1500 | Reward: 2.4
Episode 1600 | Reward: 2.3
Episode 1700 | Reward: 2.5
Episode 1800 | Reward: 2.7
Episode 1900 | Reward: 2.7
Episode 2000 | Reward: 3.9
Episode 2100 | Reward: 4.5
Episode 2200 | Reward: 4.1
Episode 2300 | Reward: 4.7
Episode 2400 | Reward: 4.0
Episode 2500 | Reward: 4.3
Episode 2600 | Reward: 3.8
Episode 2700 | Reward: 3.3
Episode 2800 | Reward: 4.6
Episode 2900 | Reward: 5.2
Episode 3000 | Reward: 7.7
Episode 3100 | Reward: 7.8
Episode 3200 | Reward: 3.3
Episode 3300 | Reward: 5.3
Episode 3400 | Reward: 4.5
Episode 3500 | Reward: 3.9
Episode 3600 | Reward: 4.1
Episode 3700 | Reward: 4.0
Episode 3800 | Reward: 5.2
Episode 3900 | Reward: 8.2
Episode 4000 | Reward: 2.2
Episode 4100 | Reward: 2.2
Episode 4200 | Reward: 2.2
Episode 4300 | Reward: 2.2
Episode 4400 | Reward: 6.9
Episode 4500 | Reward: 5.6
Episode 4600 | Reward: 2.0
Episode 4700 | Reward: 1.6
Episode 4800 | Reward: 1.7
Episode 4900 | Reward: 1.9
Episode 5000 | Reward: 3.1
训练完成时间: 2025-03-24 12:41:48
训练完成,耗时: 15122.31秒

在下一篇文章中,我们将探索 逆向强化学习(Inverse RL),并实现 GAIL 算法!


注意事项

  1. 安装依赖:

    pip install metaworld gymnasium torch
  2. Meta-World 需要 MuJoCo 许可证:

    export MUJOCO_PY_MUJOCO_PATH=/path/to/mujoco
  3. 训练时间较长(推荐 GPU 加速):

    CUDA_VISIBLE_DEVICES=0 python option_critic.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/898983.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

车载网络测试实操源码_使用CAPL脚本进行UDS刷写及其自动化测试

系列文章目录 使用CAPL脚本解析hex、S19、vbf文件 使用CAPL脚本对CAN报文的Counter、CRC、周期、错误帧进行实时监控 使用CAPL脚本模拟发送符合协议要求(Counter和CRC)的CAN报文 使用CAPL脚本控制继电器实现CAN线、电源线的通断 使用CAPL脚本实现安全访问解锁 使用CAPL脚本实现…

Spring Boot整合Spring Data JPA

Spring Data作为Spring全家桶中重要的一员,在Spring项目全球使用市场份额排名中多次居前位,而在Spring Data子项目的使用份额排名中,Spring Data JPA也一直名列前茅。Spring Boot为Spring Data JPA提供了启动器,使Spring Data JPA…

JS 应用WebPack 打包器第三方库 JQuery安装使用安全检测

# 打包器 -WebPack- 使用 & 安全 参考: https://mp.weixin.qq.com/s/J3bpy-SsCnQ1lBov1L98WA Webpack 是一个模块打包器。在 Webpack 中会将前端的所有资源文件都作为模块处理。 它将根据模块的依赖关系进行分析,生成对应的资源。 五个核心概…

Oracle归档配置及检查

配置归档位置到 USE_DB_RECOVERY_FILE_DEST,并设置存储大小 startup mount; !mkdir /db/archivelog ALTER SYSTEM SET db_recovery_file_dest_size100G SCOPEBOTH; ALTER SYSTEM SET db_recovery_file_dest/db/archivelog SCOPEBOTH; ALTER SYSTEM SET log_archive…

Four.meme是什么,一篇文章读懂

一、什么是Four.meme? Four.meme 是一个运行在 BNB 链的去中心化平台旨在为 meme 代币供公平启动服务。它允许用户以极低的成本创建和推出 meme 代币,无需预售或团队分配,它消除了传统的预售、种子轮和团队分配,确保所有参与者有…

Simula语言的正则表达式

Simula语言中的正则表达式 引言 Simula是一种开创性的编程语言,最初在1960年代由Ole-Johan Dahl和Kristen Nygaard在挪威的计算机中心开发。它不仅是面向对象编程的先驱,还在模拟和各种计算领域有显著的应用。然而,Simula语言本身并不直接支…

Java 集合 List、Set、Map 区别与应用

一、核心特性对比 二、底层实现与典型差异 ‌List‌ ‌ArrayList‌:动态数组结构,随机访问快(O(1)),中间插入/删除效率低(O(n))‌‌LinkedList‌:双向链表结构,头尾操作…

【第二月_day7】Pandas 简介与数据结构_Pandas_ day1

以下是专为小白设计的 Pandas 简介与数据结构 学习内容,用最通俗的语言和案例讲解核心概念: 一、安装 Pandas 1. 安装方法 打开电脑的命令提示符(Windows)或终端(Mac/Linux)输入以下命令并回车&#xff1…

欢迎来到未来:探索 Dify 开源大语言模型应用开发平台

欢迎来到未来:探索 Dify 开源大语言模型应用开发平台 如果你对 AI 世界有所耳闻,那么你一定听说过大语言模型(LLM)。这些智能巨兽能够生成文本、回答问题、甚至编写代码!但是,如何将它们变成真正的实用工具…

python多线程和多进程的区别有哪些

python多线程和多进程的区别有七种: 1、多线程可以共享全局变量,多进程不能。 2、多线程中,所有子线程的进程号相同;多进程中,不同的子进程进程号不同。 3、线程共享内存空间;进程的内存是独立的。 4、同一…

【MySQL报错】:Column count doesn’t match value count at row 1

MySQL报错:Column count doesn’t match value count at row 1 意思是存储的数据与数据库表的字段类型定义不相匹配. 由于类似 insert 语句中,前后列数不等造成的 主要有3个易错点: 要传入表中的字段数和values后面的值的个数不相等。 由于类…

TCP/IP 协议栈深度解析

1. 分层结构设计 TCP/IP协议栈采用四层模型,其分层结构与协议实现细节如下: 1.1 网络层(Network Layer) 核心功能:提供端到端的数据包路由与寻址 核心协议: IP协议(IPv4/IPv6) I…

Apache Tomcat CVE-2025-24813 安全漏洞

Apache Tomcat CVE-2025-24813被广泛利用,但是他必须要满足两个点: 1.被广泛的使用,并且部署在服务器中。 2.漏洞必须依赖在服务器中的配置。 并且漏洞补丁已经发布。 漏洞攻击方式: CVE-2025-24813 是 Apache Tomcat 部分 PUT…

怎么查看linux是Ubuntu还是centos

要确定你的Linux系统是基于Ubuntu还是CentOS,可以通过几种不同的方法来进行判断。下面是一些常用的方法: 要快速判断 Linux 系统是 Ubuntu 还是 CentOS,可通过以下方法综合验证: 一、查看系统信息文件 1. /etc/os-release 文件…

PostgreSQL 连接数超限问题

目录标题 **PostgreSQL 连接数超限问题解决方案****一、错误原因分析****二、查看连接数与配置****三、排查连接泄漏(应用侧问题)****四、服务侧配置调整****1. 调整最大连接数****2. 释放无效连接(谨慎操作)****3. 使用连接池工具…

数据结构模拟-用栈实现队列

用栈实现队列的基本操作,包括pop(), push(), empty(), peek(). 可以用两个栈来实现,一个栈保存入队的一端,也就是队尾,一个栈保存出队的一端,也就是队首。当遇到出队pop()时,如果stack out不为空&#xff…

2025最新-智慧小区物业管理系统

目录 1. 项目概述 2. 技术栈 3. 功能模块 3.1 管理员端 3.1.1 核心业务处理模块 3.1.2 基础信息模块 3.1.3 数据统计分析模块 3.2 业主端 5. 系统架构 5.1 前端架构 5.2 后端架构 5.3 数据交互流程 6. 部署说明 6.1 环境要求 6.2 部署步骤 7. 使用说明 7.1 管…

智能汽车图像及视频处理方案,支持视频智能包装能力

美摄科技的智能汽车图像及视频处理方案,通过深度学习算法与先进的色彩管理技术,能够自动调整图像中的亮度、对比度、饱和度等关键参数,确保在各种光线条件下,图像都能呈现出最接近人眼的自然色彩与细节层次。这不仅提升了驾驶者的…

跨层封装简单介绍

跨层封装 跨四层封装 数据封装时不经过第四层(传输层)。应用层封装后直接来到网络层。一般出现在直连路由设备之间。代表协议: OSPF协议、ICMP协议。 既然不经过四层封装,那四层相应的功能由谁来实现?答案是由三层&a…

SSE进阶详解

嗯,用户的问题涉及到SSE在处理富媒体文件、早期聊天应用选择SSE的原因,以及如何控制流式渲染频率。我需要根据提供的搜索结果来解答这些问题。 首先,关于SSE传输富媒体文件的问题。根据搜索结果,SSE是基于文本的,比如…