强化学习 动作空间(离散/连续)

news/2025/10/11 13:49:28/文章来源:https://www.cnblogs.com/pass-ion/p/19133476

1. 离散动作空间的策略网络

在离散空间中,动作是可数的,例如:{左, 右, 上, 下} 或 {加速, 刹车}。

网络架构与处理方式

  1. 输出层:Softmax

    • 策略网络的最后一层是一个 Softmax 层。

    • 假设有 N 个可选动作,网络会输出一个长度为 N 的向量

    • Softmax 函数确保这个向量的所有元素都在 (0, 1) 之间,且和为 1。这样,每个元素就代表了选择对应动作的概率。

  2. 策略表示

    • 策略 π(a|s) 直接由网络输出给出:
      π(a=i|s) = Softmax(Logits(s))[i]

  3. 动作采样

    • 根据网络输出的概率分布,进行分类采样来选择动作。

    • 在 Python 中,可以使用 np.random.choice 或 torch.distributions.Categorical

import torch
import torch.nn as nn
import torch.nn.functional as Fclass DiscretePolicyNetwork(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(DiscretePolicyNetwork, self).__init__()self.fc1 = nn.Linear(input_dim, hidden_dim)self.fc2 = nn.Linear(hidden_dim, output_dim) # output_dim = 动作数量def forward(self, state):x = F.relu(self.fc1(state))logits = self.fc2(x) # 输出 logits,未归一化的概率return logitsdef act(self, state):logits = self.forward(state)# 创建分类分布action_probs = F.softmax(logits, dim=-1)dist = torch.distributions.Categorical(action_probs)# 采样动作action = dist.sample()# 计算对数概率,用于策略梯度更新log_prob = dist.log_prob(action)return action.detach().item(), log_prob# 假设有4个动作
policy_net = DiscretePolicyNetwork(input_dim=8, hidden_dim=128, output_dim=4)
state = torch.tensor([0.1, 0.5, -0.2, ...]) # 状态向量
action, log_prob = policy_net.act(state)
print(f"Sampled action: {action}")

 

2. 连续动作空间的策略网络

在连续空间中,动作是实数向量,例如:方向盘转角 [-1, 1],机器人关节扭矩 [τ₁, τ₂, ...]

这里有两种主要设计思路:

A. 随机策略 - 输出分布参数

这是最常用的方法,策略网络输出一个概率分布的参数,动作从这个分布中采样。

    1. 输出层:分布参数

      • 最常用的是高斯分布。网络为每个动作维度输出两个值:

        • 均值:通常使用 tanh 作为激活函数,将均值限制在 [-1, 1] 范围内,或者不适用激活函数。

        • 标准差:通常使用 softplus 等函数确保其为正数。也可以是一个与状态无关的可学习参数。

    2. 策略表示

      • 策略 π(a|s) 是一个概率密度函数。例如,对于高斯分布:
        a ~ N(μ(s), σ(s)²)

    3. 动作采样

      • 使用网络输出的均值和标准差构建一个高斯分布,然后从这个分布中采样。

      • 由于采样操作不可导,在训练时需要使用重参数化技巧。

class ContinuousPolicyNetwork(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(ContinuousPolicyNetwork, self).__init__()self.output_dim = output_dim # 动作空间的维度self.fc1 = nn.Linear(input_dim, hidden_dim)# 输出均值self.mean_head = nn.Linear(hidden_dim, output_dim)# 输出对数标准差(更稳定),通常作为一个独立的层self.log_std_head = nn.Linear(hidden_dim, output_dim)# 或者:self.log_std = nn.Parameter(torch.zeros(1, output_dim))def forward(self, state):x = F.relu(self.fc1(state))mean = torch.tanh(self.mean_head(x)) # 将均值限制在[-1,1]log_std = self.log_std_head(x)# 使用 clamp 将标准差限制在一个合理范围内log_std = torch.clamp(log_std, min=-20, max=2)std = torch.exp(log_std)return mean, stddef act(self, state):mean, std = self.forward(state)# 创建多元高斯分布(假设各维度独立)dist = torch.distributions.Normal(mean, std)# 重参数化技巧采样action = dist.rsample()# 计算对数概率(对于多维动作,需要对数概率的和)log_prob = dist.log_prob(action).sum(dim=-1)# 如果需要将动作限制在[-1,1],可以使用tanh,但需要修正对数概率# action = torch.tanh(raw_action)# 更复杂的实现会处理tanh变换后的概率计算return action.detach().numpy(), log_prob# 假设动作是2维的(如:速度,方向)
policy_net = ContinuousPolicyNetwork(input_dim=8, hidden_dim=128, output_dim=2)
state = torch.tensor([0.1, 0.5, -0.2, ...])
action, log_prob = policy_net.act(state)
print(f"Sampled continuous action: {action}")

 

torch.clamp 将输入张量中的所有元素限制在一个指定的区间 [min, max] 内。具体来说:

  • 如果元素小于 min,则将其设置为 min

  • 如果元素大于 max,则将其设置为 max

  • 如果元素在 [min, max] 范围内,则保持不变

 

tanh函数:

image

 

torch.distributions.Normal 表示一个一元高斯分布,由两个参数定义:

  • loc: 分布的均值

  • scale: 分布的标准差

# 创建分布
mean = torch.tensor([0.0, 1.0])
std = torch.tensor([1.0, 0.5])
normal = dist.Normal(mean, std)# 1. sample() - 普通采样
samples = normal.sample()
print("Sample:", samples)
# 输出: tensor([-0.1234, 1.2345])# 2. rsample() - 重参数化采样(可微分)
reparam_samples = normal.rsample()
print("Reparameterized sample:", reparam_samples)
# 输出: tensor([0.5678, 0.8765])# 3. sample() 批量采样
batch_samples = normal.sample((3,))  # 采样3次
print("Batch samples shape:", batch_samples.shape)
# 输出: torch.Size([3, 2])

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/934753.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QuickLook软件!一款鼠标单击PDF即能显示内容的软件!

软件介绍 大家都知道,苹果电脑有个非常实用的功能,那就是只要单击文件,然后按空格就可以预览文件里的内容,但是Windows没有这功能。今天介绍的这款叫QuickLook,它能在Windows的环境下实现快速预览文件的功能。软件…

Http Security Headers

HTTP 安全相关的响应头(Security Headers)是 Web 应用安全防护的核心手段,通过浏览器与服务器的协作,防御跨站脚本(XSS)、点击劫持、中间人攻击、信息泄露等常见风险。以下是最常用的安全头及其作用机制、使用方…

参照Yalla、Hawa等主流APP核心功能,开发一款受欢迎的海外语聊需要从哪些方面入手

近期,从海外客户的主要咨询需求来看,主要是围绕在借鉴主流APP,在此基础上需要开发属于他们Agency、Coinseller、CP、PK等特色功能。每个客户的需求都有差异,建议您从自己的运营方向出发,来开发符合自己需求的海外…

本土化DevOps的突围之路:Gitee如何重塑企业研发效能

本土化DevOps的突围之路:Gitee如何重塑企业研发效能 在数字经济加速发展的今天,DevOps已从技术概念升级为企业数字化转型的核心引擎。国际权威调研机构Gartner预测,到2025年全球DevOps市场规模将突破300亿美元,其中…

【STM32计划开源】基于STM32的智能点滴输液系统

【STM32计划开源】基于STM32的智能点滴输液系统pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", &…

溶气气浮/浅层气浮/国内知名气浮机靠谱厂家品牌推荐

溶气气浮/浅层气浮/国内知名气浮机靠谱厂家品牌推荐 无锡工源环境科技股份有限公司是一家在环保水处理设备领域,特别是气浮设备研发与制造方面,具有深厚技术积累和市场声誉的高新技术企业。公司始终专注于水处理技术…

iOS 26 崩溃日志深度指南,如何收集、符号化、定位与监控 - 实践

iOS 26 崩溃日志深度指南,如何收集、符号化、定位与监控 - 实践2025-10-11 13:35 tlnshuju 阅读(0) 评论(0) 收藏 举报pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !imp…

鸿蒙Next密码自动填充服务:安全与便捷的完美融合 - 实践

鸿蒙Next密码自动填充服务:安全与便捷的完美融合 - 实践pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas…

覆盖动画 / 工业 / 科研!Rhino 7:专业 3D 建模的全能解决方案,新手也能上手

在 3D 建模领域,一款功能强大、兼容广泛且高效稳定的工具,往往能成为设计师、工程师突破创作瓶颈的关键。由美国 Robert McNeel 公司打造的Rhinoceros(简称 Rhino) 系列软件,凭借其 “集百家之长” 的设计理念,早…

2020CSP-J2比赛记录题解

题目请看洛谷备注:这次比赛我是没打的T1 先把数转成二进制,逐位计算,并判断是否可完整正确拆分贴一下代码 #include <bits/stdc++.h> using namespace std; #define fre(c) freopen(c".in","r…

Binder.getCallingPid()和Binder.getCallingUid()漏洞分析

最近在学习安卓漏挖,在分析ghera数据集时发现一个很有意思的binder特性,但还没搞懂底层原理,先挖个坑 漏洞分析EnforceCallingOrSelfPermission-PrivilegeEscalation-Lean以下代码使用Binder.getCallingPid()和Bind…

详细介绍:golang基础语法(五)切片

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

让博客园设置支持PlantUml画图

1. 引入 2. 博客园不支持plantuml渲染 3. 编写js脚本支持plantuml 4. 缺点‍ 1. 引入众所周知,我们在写博客的时候,常使用PlantUML 和 Mermaid绘制图表、流程图、架构图。这是因为用代码去画图,不怎么需要手动控制格…

jj

jjimport numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from sklearn.preprocessing import StandardScaler from sklearn.decomposition import PCA from sklearn.cluste…

光谱相机的未来趋势 - 详解

光谱相机的未来趋势 - 详解pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", …

Hall定理学习笔记

内容 设二分图左部点点数为 \(x\),右部点点数为 \(y\),且满足 \(x<y\)。定义一张二分图的完备匹配为:对于任意一个左部点都有与之匹配的右部点。 \(\text{Hall}\) 定理的内容是:一张二分图有完备匹配,等价于对…

面向对象抽象,接口多态综合-动物模拟系统

1、抽象一个动物类,会说话和走路。 public abstract class Animal() { public abstract void Speak(); public abstract void Walk(); } 2、抽象出能力,有的动物会飞,有的动物能用四条腿走路 interface IFly { void…

实用指南:APache shiro-550 CVE-2016-4437复现

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

Spark - deprecated registerTempTable() function

Spark - deprecated registerTempTable() functionIn Apache Spark, the function registerTempTable() was an old API (deprecated since Spark 2.0 and removed in Spark 3.0) that allowed you to register a Data…

MinGW-即时入门-全-

MinGW 即时入门(全)原文:zh.annas-archive.org/md5/a899d9a6a04025b2abd50163c83cff2a 译者:飞龙 协议:CC BY-NC-SA 4.0第一章. 立即开始使用 MinGW 欢迎使用 立即开始使用 MinGW。 本书特别创建,旨在为您提供所…