基于“动手学强化学习”的知识点(二):第 15 章 模仿学习(gym版本 >= 0.26)

第 15 章 模仿学习(gym版本 >= 0.26)

  • 摘要

摘要

本系列知识点讲解基于动手学强化学习中的内容进行详细的疑难点分析!具体内容请阅读动手学强化学习!


对应动手学强化学习——模仿学习


# -*- coding: utf-8 -*-import gym
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import rl_utilsclass PolicyNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(PolicyNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc1(x))return F.softmax(self.fc2(x), dim=1)class ValueNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim):super(ValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x):x = F.relu(self.fc1(x))return self.fc2(x)class PPO:''' PPO算法,采用截断方式 '''def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,lmbda, epochs, eps, gamma, device):self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.critic = ValueNet(state_dim, hidden_dim).to(device)self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)self.gamma = gammaself.lmbda = lmbdaself.epochs = epochs  # 一条序列的数据用于训练轮数self.eps = eps  # PPO中截断范围的参数self.device = devicedef take_action(self, state):if isinstance(state, tuple):state = state[0]state = torch.tensor([state], dtype=torch.float).to(self.device)probs = self.actor(state)'''根据概率分布创建一个离散分类分布对象,用于采样离散动作。离散的概率模型。'''action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def update(self, transition_dict):    processed_state = []for s in transition_dict['states']:if isinstance(s, tuple):# 如果元素是元组,则取元组的第一个元素processed_state.append(s[0])else:processed_state.append(s)# states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)states = torch.tensor(processed_state, dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)'''计算 TD 目标(即回归目标):td_target=r+γ×V(s′)×(1−done)'''td_target = rewards + self.gamma * self.critic(next_states) * (1 -dones)'''计算 TD 残差(或优势估计的基础):当前状态的 TD 目标减去当前 critic 估计的状态价值。'''td_delta = td_target - self.critic(states)'''调用辅助函数(在 rl_utils 模块中定义)计算优势函数,通常使用广义优势估计(GAE)。'''advantage = rl_utils.compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)'''先将状态输入 actor 网络得到动作概率分布(例如 shape 为 (batch_size, action_dim))。使用 .gather(1, actions) 选出每个样本所执行动作对应的概率(注意 actions 的形状必须匹配)。取对数得到旧的对数概率,再 detach() 阻断梯度流,保存旧策略下的概率值。'''old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()for _ in range(self.epochs):'''在当前策略下重新计算所有样本的对数概率,与旧对数概率进行比较。'''log_probs = torch.log(self.actor(states).gather(1, actions))'''计算概率比率,即新旧策略的概率之比,用于 PPO 的 clip 损失计算。'''ratio = torch.exp(log_probs - old_log_probs)'''计算无截断的策略目标,乘上优势值。'''surr1 = ratio * advantage'''对 ratio 进行截断,确保其在 [1−ϵ,1+ϵ] 范围内(例如 [0.8, 1.2]),然后乘以优势。'''surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage  # 截断'''PPO 算法的目标是最大化最小值,因此这里取两者中的较小值再取负号作为损失。对整个 batch 求均值。'''actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO损失函数'''计算 critic 的均方误差(MSE)损失:当前 critic 估计与 TD 目标之间的误差,对整个 batch 取平均。'''critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))self.actor_optimizer.zero_grad()self.critic_optimizer.zero_grad()actor_loss.backward()critic_loss.backward()self.actor_optimizer.step()self.critic_optimizer.step()actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 250
hidden_dim = 128
gamma = 0.98
lmbda = 0.95
epochs = 10
eps = 0.2
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")env_name = 'CartPole-v0'
env = gym.make(env_name)
if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
ppo_agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device)return_list = rl_utils.train_on_policy_agent(env, ppo_agent, num_episodes)def sample_expert_data(n_episode):states = []actions = []for episode in range(n_episode):state = env.reset()done = Falsewhile not done:action = ppo_agent.take_action(state)states.append(state)actions.append(action)result = env.step(action)if len(result) == 5:next_state, reward, done, truncated, info = resultdone = done or truncated  # 可合并 terminated 和 truncated 标志else:next_state, reward, done, info = result# next_state, reward, done, _ = env.step(action)state = next_stateprocessed_states = []for s in states:if isinstance(s, tuple):# 如果元素是元组,则取元组的第一个元素processed_states.append(s[0])else:processed_states.append(s)return np.array(processed_states), np.array(actions)if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
random.seed(0)
n_episode = 1
expert_s, expert_a = sample_expert_data(n_episode)n_samples = 30  # 采样30个数据
random_index = random.sample(range(expert_s.shape[0]), n_samples)
expert_s = expert_s[random_index]
expert_a = expert_a[random_index]class BehaviorClone:def __init__(self, state_dim, hidden_dim, action_dim, lr):self.policy = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)def learn(self, states, actions):"""解释:定义一个学习函数,接收一批专家数据中的状态和动作,用于更新策略网络。"""states = torch.tensor(states, dtype=torch.float).to(device)actions = torch.tensor(actions).view(-1, 1).to(device)'''- 将 states 输入 policy 网络,得到每个状态下所有动作的概率分布,假设输出形状为 (batch_size, action_dim);- 使用 .gather(1, actions.long()) 从概率分布中取出对应专家动作的概率(注意动作需要转换为长整型索引);- 对这些概率取对数,得到对数概率(log likelihood)。'''log_probs = torch.log(self.policy(states).gather(1, actions.long()))# log_probs = torch.log(self.policy(states).gather(1, actions))'''计算行为克隆的损失,即负对数似然损失。对所有样本的负对数概率取均值。'''bc_loss = torch.mean(-log_probs)  # 最大似然估计self.optimizer.zero_grad()bc_loss.backward()self.optimizer.step()def take_action(self, state):if isinstance(state, tuple):state = state[0]state = torch.tensor([state], dtype=torch.float).to(device)probs = self.policy(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def test_agent(agent, env, n_episode):return_list = []for episode in range(n_episode):episode_return = 0state = env.reset()done = Falsewhile not done:action = agent.take_action(state)result = env.step(action)if len(result) == 5:next_state, reward, done, truncated, info = resultdone = done or truncated  # 可合并 terminated 和 truncated 标志else:next_state, reward, done, info = result# next_state, reward, done, _ = env.step(action)state = next_stateepisode_return += rewardreturn_list.append(episode_return)return np.mean(return_list)if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
np.random.seed(0)lr = 1e-3
bc_agent = BehaviorClone(state_dim, hidden_dim, action_dim, lr)
n_iterations = 1000
batch_size = 64
test_returns = []with tqdm(total=n_iterations, desc="进度条") as pbar:for i in range(n_iterations):sample_indices = np.random.randint(low=0, high=expert_s.shape[0], size=batch_size)bc_agent.learn(expert_s[sample_indices], expert_a[sample_indices])current_return = test_agent(bc_agent, env, 5)test_returns.append(current_return)if (i + 1) % 10 == 0:pbar.set_postfix({'return': '%.3f' % np.mean(test_returns[-10:])})pbar.update(1)iteration_list = list(range(len(test_returns)))
plt.plot(iteration_list, test_returns)
plt.xlabel('Iterations')
plt.ylabel('Returns')
plt.title('BC on {}'.format(env_name))
plt.show()class Discriminator(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(Discriminator, self).__init__()self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x, a):cat = torch.cat([x, a], dim=1)x = F.relu(self.fc1(cat))return torch.sigmoid(self.fc2(x))class GAIL:def __init__(self, agent, state_dim, action_dim, hidden_dim, lr_d):print(state_dim, action_dim, hidden_dim)self.discriminator = Discriminator(state_dim, hidden_dim, action_dim).to(device)self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d)self.agent = agentdef learn(self, expert_s, expert_a, agent_s, agent_a, next_s, dones):expert_states = torch.tensor(expert_s, dtype=torch.float).to(device)expert_actions = torch.tensor(expert_a).to(device)processed_state = []for s in agent_s:if isinstance(s, tuple):# 如果元素是元组,则取元组的第一个元素processed_state.append(s[0])else:processed_state.append(s)agent_states = torch.tensor(processed_state, dtype=torch.float).to(device)agent_actions = torch.tensor(agent_a).to(device)'''作用:将专家动作转换为 one-hot 编码形式,转换为浮点数。'''expert_actions = F.one_hot(expert_actions.long(), num_classes=2).float()agent_actions = F.one_hot(agent_actions.long(), num_classes=2).float()expert_prob = self.discriminator(expert_states, expert_actions)agent_prob = self.discriminator(agent_states, agent_actions)'''作用:计算二元交叉熵损失(BCE):- 对 agent 数据,目标标签设为 1(即希望判别器认为 agent 数据为“真”),损失为 BCE(agent_prob, 1);- 对专家数据,目标标签设为 0(希望判别器认为专家数据为“假”),损失为 BCE(expert_prob, 0)。- 然后将两部分损失相加。'''discriminator_loss = nn.BCELoss()(agent_prob, torch.ones_like(agent_prob)) + nn.BCELoss()(expert_prob, torch.zeros_like(expert_prob))self.discriminator_optimizer.zero_grad()discriminator_loss.backward()self.discriminator_optimizer.step()'''作用:利用判别器对 agent 数据输出计算奖励:- 计算 –log(agent_prob) 作为奖励信号(当 agent_prob 较小时,奖励较高,鼓励 agent 模仿专家);- detach() 阻断梯度,转移到 CPU 并转换为 numpy 数组,方便后续传入 agent.update。'''rewards = -torch.log(agent_prob).detach().cpu().numpy()transition_dict = {'states': agent_s,'actions': agent_a,'rewards': rewards,'next_states': next_s,'dones': dones}self.agent.update(transition_dict)if not hasattr(env, 'seed'):def seed_fn(self, seed=None):env.reset(seed=seed)return [seed]env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
lr_d = 1e-3
agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device)
gail = GAIL(agent, state_dim, action_dim, hidden_dim, lr_d)
n_episode = 500
return_list = []with tqdm(total=n_episode, desc="进度条") as pbar:for i in range(n_episode):episode_return = 0state = env.reset()done = Falsestate_list = []action_list = []next_state_list = []done_list = []while not done:action = agent.take_action(state)result = env.step(action)if len(result) == 5:next_state, reward, done, truncated, info = resultdone = done or truncated  # 可合并 terminated 和 truncated 标志else:next_state, reward, done, info = result# next_state, reward, done, _ = env.step(action)state_list.append(state)action_list.append(action)next_state_list.append(next_state)done_list.append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)gail.learn(expert_s, expert_a, state_list, action_list, next_state_list, done_list)if (i + 1) % 10 == 0:pbar.set_postfix({'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)    iteration_list = list(range(len(return_list)))
plt.plot(iteration_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('GAIL on {}'.format(env_name))
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/73682.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JAVA面试_进阶部分_Java JVM:垃圾回收(GC 在什么时候,对什么东西,做了什么事情)

在什么时候: 首先需要知道,GC又分为minor GC 和 Full GC(major GC)。Java堆内存分为新生代和老年代,新生代 中又分为1个eden区和两个Survior区域。 一般情况下,新创建的对象都会被分配到eden区&#xff…

2024年消费者权益数据分析

📅 2024年315消费者权益数据分析 数据见:https://mp.weixin.qq.com/s/eV5GoionxhGpw7PunhOVnQ 一、引言 在数字化时代,消费者维权数据对于市场监管、商家诚信和行业发展具有重要价值。本文基于 2024年315平台线上投诉数据,采用数…

设计模式Python版 访问者模式

文章目录 前言一、访问者模式二、访问者模式示例 前言 GOF设计模式分三大类: 创建型模式:关注对象的创建过程,包括单例模式、简单工厂模式、工厂方法模式、抽象工厂模式、原型模式和建造者模式。结构型模式:关注类和对象之间的组…

安全无事故连续天数计算,python 时间工具的高效利用

安全天数计算,数据系统时间直取,安全标准高效便捷好用。 笔记模板由python脚本于2025-03-17 23:50:52创建,本篇笔记适合对python时间工具有研究欲的coder翻阅。 【学习的细节是欢悦的历程】 博客的核心价值:在于输出思考与经验&am…

大型语言模型(LLM)部署中的内存消耗计算

在部署大型语言模型(LLM)时,显存(VRAM)的合理规划是决定模型能否高效运行的核心问题。本文将通过详细的公式推导和示例计算,系统解析模型权重、键值缓存(KV Cache)、激活内存及额外开…

Mysql表的查询

一:创建一个新的数据库(companydb),并查看数据库。 二:使用该数据库,并创建表worker。 mysql> use companydb;mysql> CREATE TABLE worker(-> 部门号 INT(11) NOT NULL,-> 职工号 INT(11) NOT NULL,-> 工作时间 D…

ASP.NET Webform和ASP.NET MVC 后台开发 大概80%常用技术

本文涉及ASP.NET Webform和ASP.NET MVC 后台开发大概80%技术 2019年以前对标 深圳22K左右 广州18K左右 武汉16K左右 那么有人问了2019年以后的呢? 答:吉祥三宝。。。 So 想继续看下文的 得有自己的独立判断能力。 C#.NET高级笔试题 架构 优化 性能提…

首页性能优化

首页性能提升是前端优化中的核心任务之一,因为首页是用户访问的第一入口,其加载速度和交互体验直接影响用户的留存率和转化率。 1. 性能瓶颈分析 在优化之前,首先需要通过工具分析首页的性能瓶颈。常用的工具包括: Chrome DevTo…

一周学会Flask3 Python Web开发-SQLAlchemy删除数据操作-班级模块

锋哥原创的Flask3 Python Web开发 Flask3视频教程&#xff1a; 2025版 Flask3 Python web开发 视频教程(无废话版) 玩命更新中~_哔哩哔哩_bilibili 首页list.html里加上删除链接&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta c…

改变一生的思维模型【12】笛卡尔思维模型

目录 基本结构 警惕认知暗礁 案例分析应用 一、怀疑阶段:破除惯性认知 二、解析阶段:拆解问题为最小单元 三、整合阶段:重构逻辑链条 四、检验阶段:多维验证解决方案 总结与启示 笛卡尔说,唯独自己的思考是可以相信的。 世界上所有的事情,都是值得被怀疑的,但是…

需求文档(PRD,Product Requirement Document)的基本要求和案例参考:功能清单、流程图、原型图、逻辑能力和表达能力

文章目录 引言I 需求文档的基本要求结构清晰内容完整语言准确图文结合版本管理II 需求文档案例参考案例1:电商平台“商品中心”功能需求(简化版)案例2:教育类APP“记忆宝盒”非功能需求**案例3:软件项目的功能需求模板3.1 功能需求III 需求文档撰写技巧1. **从核心逻辑出发…

五大方向全面对比 IoTDB 与 OpenTSDB

对比系列第三弹&#xff0c;详解 IoTDB VS OpenTSDB&#xff01; 之前&#xff0c;我们已经深入探讨了时序数据库 Apache IoTDB 与 InfluxDB、Apache HBase 在架构设计、性能和功能方面等多个维度的区别。还没看过的小伙伴可以点击阅读&#xff1a; Apache IoTDB vs InfluxDB 开…

Electron使用WebAssembly实现CRC-16 MAXIM校验

Electron使用WebAssembly实现CRC-16 MAXIM校验 将C/C语言代码&#xff0c;经由WebAssembly编译为库函数&#xff0c;可以在JS语言环境进行调用。这里介绍在Electron工具环境使用WebAssembly调用CRC-16 MAXIM格式校验的方式。 CRC-16 MAXIM校验函数WebAssembly源文件 C语言实…

vue3vue-elementPlus-admin框架中form组件的upload写法

dialog中write组件代码 let ImageList reactive<UploadFile[]>([])const formSchema reactive<FormSchema[]>([{field: ImageFiles,label: 现场图片,component: Upload,colProps: { span: 24 },componentProps: {limit: 5,action: PATH_URL /upload,headers: {…

Linux mount和SSD分区

为什么要用 mount&#xff1f; Linux 的文件系统结构是单一的树状层次 所有文件、目录和设备都从根目录 / 开始延伸。 外部的存储设备&#xff08;如硬盘、U盘、网络存储&#xff09;或虚拟文件系统&#xff08;如 /proc、/sys&#xff09;必须通过挂载点“嫁接”到这棵树上&a…

【Function】Azure Function通过托管身份或访问令牌连接Azure SQL数据库

【Function】Azure Function通过托管身份或访问令牌连接Azure SQL数据库 推荐超级课程: 本地离线DeepSeek AI方案部署实战教程【完全版】Docker快速入门到精通Kubernetes入门到大师通关课AWS云服务快速入门实战目录 【Function】Azure Function通过托管身份或访问令牌连接Azu…

举例说明 牛顿法 Hessian 矩阵

矩阵求逆的方法及示例 目录 矩阵求逆的方法及示例1. 伴随矩阵法2. 初等行变换法矩阵逆的实际意义1. 求解线性方程组2. 线性变换的逆操作3. 数据分析和机器学习4. 优化问题牛顿法原理解释举例说明 牛顿法 Hessian 矩阵1. 伴随矩阵法 原理:对于一个 n n n 阶方阵 A A

安科瑞分布式光伏监测系统:推动绿色能源高效发展

安科瑞顾强 为应对传统能源污染与资源短缺&#xff0c;分布式光伏发电成为关键解决方案。安科瑞Acrel-1000DP分布式光伏监控系统结合光功率预测技术&#xff0c;有效提升发电稳定性&#xff0c;助力上海汽车变速器有限公司8.3MW屋顶光伏项目实现清洁能源高效利用。 项目亮点 …

从零开始使用 **Taki + Node.js** 实现动态网页转静态网站的完整代码方案

以下是从零开始使用 Taki Node.js 实现动态网页转静态网站的完整代码方案&#xff0c;包含预渲染、自动化构建、静态托管及优化功能&#xff1a; 一、环境准备 1. 初始化项目 mkdir static-site && cd static-site npm init -y2. 安装依赖 npm install taki expre…

商业智能BI分析中,汽车4S销售行业的返厂频次有什么分析价值?

买过车的朋友会发现&#xff0c;同一款车不管在哪个4S店去买&#xff0c;基本上价格都相差不大。即使有些差别&#xff0c;也是带着附加条件的&#xff0c;比如要做些加装需要额外再付一下费用。为什么汽车4S销售行业需要商业智能BI&#xff1f;就是因为在汽车4S销售行业&#…