强化学习与神经网络结合(以 DQN 展开)

目录

基于 PyTorch 实现简单 DQN

double DQN

dueling DQN

Noisy DQN:通过噪声层实现探索,替代 ε- 贪心策略

Rainbow_DQN如何计算连续型的Actions


强化学习中,智能体(Agent)通过与环境交互学习最优策略。当状态空间或动作空间庞大时,传统表格法(如 Q-Learning)难以存储所有状态 - 动作值(Q 值)。此时引入神经网络,用其函数拟合能力近似 Q 函数,即深度 Q 网络(DQN)。

图片所示的运流程

1.环境交互:智能体根据当前状态St,选择动作A执行,环境反馈奖励R,转移到新状态St+1

Q 值计算:将St+1输入 Q 网络,计算所有动作的 Q 值,取最大值

目标值构建:更新目标为

为折扣因子,平衡短期与长期奖励)

损失计算:当前 Q 值Q(St,a)作为预测值,目标值作为标签,通过均方误差(MSE计算损失

网络更新:利用梯度下降优化 Q 网络,减少损失

基于 PyTorch 实现简单 DQN

import torch  
import torch.nn as nn  
import torch.optim as optim  
import gym  # 定义Q网络  
class QNetwork(nn.Module):  def __init__(self, state_dim, action_dim):  super(QNetwork, self).__init__()  
        self.fc = nn.Sequential(  
            nn.Linear(state_dim, 64),  
            nn.ReLU(),  
            nn.Linear(64, action_dim)  )  def forward(self, x):  return self.fc(x)  # 初始化环境与网络  
env = gym.make('CartPole-v1')  
state_dim = env.observation_space.shape[0]  
action_dim = env.action_space.n  
q_net = QNetwork(state_dim, action_dim)  
optimizer = optim.Adam(q_net.parameters(), lr=0.001)  
criterion = nn.MSELoss()  # 训练循环  
for episode in range(100):  
    state = env.reset()  
    state = torch.FloatTensor(state)  
    total_reward = 0  
    done = False  while not done:  # 选择动作(简化示例,未包含探索策略)  with torch.no_grad():  
            q_values = q_net(state.unsqueeze(0))  
        action = torch.argmax(q_values).item()  # 执行动作,获取反馈  
        next_state, reward, done, _ = env.step(action)  
        next_state = torch.FloatTensor(next_state)  
        reward = torch.tensor(reward, dtype=torch.float32)  # 计算目标值与当前Q值  with torch.no_grad():  
            next_q = q_net(next_state.unsqueeze(0)).max(1)[0]  
        target = reward + 0.99 * next_q  
        current_q = q_net(state.unsqueeze(0))[0][action]  # 计算损失并更新网络  
        loss = criterion(current_q, target)  
        optimizer.zero_grad()  
        loss.backward()  
        optimizer.step()          state = next_state  
        total_reward += reward  print(f"Episode {episode}, Total Reward: {total_reward}")  env.close()  

agent与env的交互

判断done:

done 状态是由环境(env)在交互过程中返回的终止信号,用于标识当前回合(episode)是否结束

done 的来源:

在 observation_, reward, done = env.step(action) 这一步中,环境会根据智能体执行的动作(action),反馈当前状态的后续信息:

observation_:执行动作后转移到的新状态。

reward:执行动作获得的即时奖励。

done:一个布尔值(True/False),由环境规则定义。当环境认为当前回合结束时(如达到任务目标、触发终止条件等),done 会被置为 True

double DQN

Double DQN(Double Deep Q-Network)是对传统DQN的改进算法,旨在解决Q值估计过估计(Overestimation)的问题

传统DQN的局限

传统DQN通过 同一个网络 同时完成两个任务:

1. 选择动作:根据当前状态选择Q值最大的动作(贪心策略)。

2. 评估价值:计算目标Q值以更新网络参数。

这会导致**过估计**问题:网络在计算目标值时,倾向于选择自身高估的动作,从而引入偏差,导致训练不稳定甚至发散。

Double DQN的核心改进

Double DQN通过 分离动作选择和动作评估 来解决过估计问题,具体方法如下:

1. 引入两个网络:

- 在线网络(Online Network):负责选择动作(ε-贪心策略)

- 目标网络(Target Network):负责计算目标Q值,其参数定期从在线网络复制

2. 解耦动作选择与评估:

- 用在线网络选择动作a_t = argmax Q(s_t, a; θ)

- 用目标网络评估该动作的Q值 Q(s_{t+1}, a_t; θ^-)

算法流程

1. 初始化:

- 在线网络和目标网络结构相同,但参数独立

- 目标网络参数初始化为在线网络的副本θ^- = θ

2. 经验回放池:存储训练样本s_t, a_t, r_t, s_{t+1}

3. 训练循环:

- 步骤1:在线网络根据当前状态选择动作

- 步骤2:执行动作,获取奖励和下一状态

- 步骤3:将样本存入经验回放池

- 步骤4:从池中随机采样一批数据

- 步骤5:计算目标Q值:

target_Q = r + γ * Q_target(s_{t+1}, argmax Q_online(s_{t+1}, a; θ); θ^-)

- 步骤6:用在线网络计算当前Q值,并通过MSE损失更新参数

- 步骤7:定期更新目标网络参数(如每C步复制θ到θ^-)

关键技术细节

1. 目标网络更新策略:

- 目标网络参数并非实时更新,而是每隔C步从在线网络复制一次,避免梯度震荡

2. 经验回放(Experience Replay):

- 打破数据相关性,提高样本利用率

3. ε-贪心策略:

- 平衡探索与利用,确保充分探索环境

伪代码

初始化在线网络 Q(θ) 和目标网络 Q(θ^-)
θ^- = θ  # 初始同步
经验回放池 D 初始化for episode in episodes:
    初始化状态 swhile s 非终止状态:
        根据 ε-贪心策略选择动作 a(由 Q(θ) 决定)
        执行动作 a,得到奖励 r 和下一状态 s'
(s, a, r, s') 存入 D
        从 D 中随机采样 mini-batchfor 每个样本 (s_i, a_i, r_i, s'_i):
            a_next = argmax Q(s'_i, a; θ)  # 在线网络选动作
            Q_target = r_i + γ * Q_target(s'_i, a_next; θ^-)  # 目标网络评估
            loss = MSE(Q(s_i, a_i; θ), Q_target)
            反向传播更新 θ
        每隔 C 步,θ^- = θ  # 同步目标网络
        s = s'

优势与效果

1. 减少过估计:通过分离动作选择和评估,显著降低Q值偏差

2. 训练更稳定:目标网络参数定期更新,避免梯度震荡

3. 性能提升:在Atari游戏等任务中,Double DQN比传统DQN表现更优,如《Pong》《Breakout》等

扩展与变种

1. Dueling DQN:将Q值分解为状态价值和动作优势,进一步提升性能

2. Rainbow DQN:融合Double DQN、Dueling DQN、Prioritized Replay等技术

3. Noisy DQN:通过噪声层实现探索,替代ε-贪心策略

Double DQN通过解耦动作选择与评估,有效解决了传统DQN的过估计问题,成为深度强化学习的经典算法之一。其核心思想是利用两个网络的分工协作,平衡探索与利用,提升训练稳定性和样本效率

dueling DQN

Dueling DQN(对决网络 DQN)是对传统 DQN 的改进,核心在于将 Q 值拆解为状态价值和动作优势,让网络更高效地学习

核心思想:拆分 Q 值的意义

传统 DQN 直接学习 Q (s, a),即 “状态 - 动作” 的价值。而 Dueling DQN 将 Q 值拆分为两部分:

状态价值 V (s):衡量当前状态本身的好坏,不依赖具体动作。例如 “站在十字路口” 这个状态的基础价值。

动作优势 Adv (s, a):衡量某个动作相对于平均动作的优势。例如 “在十字路口,右转比左转 / 直行更好” 的优势。

公式:Q(s, a) = V(s) + Adv(s, a)。

通过这种拆分,网络能更清晰地学习 “状态本身的价值” 和 “动作的相对优势”,尤其在状态价值明显、动作选择影响较小时,学习效率更高。

网络结构:共享特征 + 双分支

共享特征提取层:

前几层网络(如卷积层、全连接层)用于提取状态特征,类似传统 DQN 的特征处理。双分支结构:V (s) 分支:输入共享层特征,输出标量(状态价值)。

Adv (s, a) 分支:输入共享层特征,输出向量(每个动作的优势)。

约束条件:为避免 V 和 Adv 的表示冗余,通常对 Adv 添加约束,例如让同一状态下所有动作的 Adv 均值为 0(∑Adv(s, a)/num_actions = 0)。这样 V (s) 实际代表该状态下所有动作 Q 值的平均值。

共享特征提取层:特征的 “通用加工厂”

作用:无论是计算状态价值 V(s) 还是动作优势Adv(s,a),都需要先从原始状态(如游戏画面、机器人传感器数据)中提取有意义的特征。

类比理解:类似炒菜前的备菜环节。不管最后是要炒青菜还是炒肉,都需要先洗菜、切菜(提取特征)。共享特征层就是 “洗菜切菜” 的通用流程,用卷积层、全连接层等处理原始状态,得到后续分支可用的特征。

双分支结构:分工明确的 “价值分析员”

V(s) 分支

目标:评估“状态本身的价值”,不考虑具体动作。例如,在“游戏角色站在能量道具旁边”的状态,V(s)衡量这个状态潜在的整体收益(无论捡道具还是不捡,先评估状态基础价值)。

输出:一个标量(单一数值),代表状态 s 的价值。

Adv(s,a) 分支:

目标:评估每个动作的“相对优势”。例如,在上述状态下,“捡道具”这个动作比“不捡道具”好多少,优势通过Adv(s,a) 体现。

输出:向量,每个元素对应一个动作的优势值。比如游戏有 4 个动作,就输出 4 个数值,分别表示每个动作的优势。

更新机制:为何能 “一次性调整所有 Q 值”?

普通 DQN 的更新局限:

传统 DQN 更新时,只针对选中的动作调整 Q 值。例如在状态 s 下选动作 a,仅更新 Q (s, a),其他动作的 Q 值不受影响。Dueling DQN 的更新优势:由于Adv(s, a)的和为 0 的约束,网络更倾向于先调整V(s)(状态价值)。而V(s)是该状态下所有动作 Q 值的平均值,调整V(s)相当于对该状态下所有动作的 Q 值进行了一次全局更新。

举例:若状态 s 的 V (s) 从 10 提升到 12,且 Adv (s, a) 不变,那么所有动作的 Q 值都会增加 2。这种 “批量更新” 让网络学习更高效,尤其在状态价值主导决策时,能更快收敛。

Noisy DQN:通过噪声层实现探索,替代 ε- 贪心策略

传统探索策略的局限

ε- 贪心的问题:通过固定概率(如 ε=0.1)随机选动作实现探索,但 ε 需手动调整,且无法根据训练阶段动态适应。例如,前期需要强探索,后期应聚焦利用,ε- 贪心难以灵活平衡

Noisy DQN 的核心原理:参数空间噪声

Noisy DQN 将探索直接融入网络结构,通过在网络参数中添加噪声,让智能体在决策时自然产生探索行为:

噪声添加方式:在网络的全连接层参数(权重W和偏置b)中加入噪声。例如,对某一层的权重W,实际使用W + epsilon_W是噪声,偏置b同理使用b + epsilon_b。

探索与利用的平衡:训练阶段:噪声激活,智能体因参数扰动尝试不同动作,实现探索。例如,原本选 Q 值最高的动作,加噪声后可能因参数变化选择其他动作。

推理阶段:去掉噪声(或噪声趋近于 0),直接基于确定的网络参数选择最优动作,专注利用。

实现细节:噪声层设计

噪声类型:常用高斯噪声,如权重噪声,偏置噪声,通过超参数sigma控制噪声强度。

Frostbite 噪声层:一种典型实现,对权重和偏置的噪声初始化做特殊设计。例如,权重噪声epsilon_W的每个元素初始化为(x是高斯分布样本),确保噪声有合理的方差尺度,避免探索过度或不足。

噪声的动态调整:训练中可逐渐降低噪声强度(如退火策略),让智能体从探索为主过渡到利用为主,无需像 ε- 贪心一样手动设定探索概率。

Rainbow_DQN如何计算连续型的Actions

Rainbow DQN 最初设计用于离散动作空间(如 Atari 游戏中的有限操作),若要处理连续动作空间,需对其进行改造

离散化处理:将连续动作转为离散

区间划分对连续动作的每个维度(如机器人关节角度、车辆速度)划分离散区间。例如,动作是二维连续空间 (a_1, a_2),可将 a_1分为 N_1 个区间,a_2分为 N_2个区间,形成N_1 * N_2个离散动作组合。

直接套用 Rainbow DQN离散化后,动作空间变为有限集合,直接使用 Rainbow DQN 的网络结构(如融合 Double DQN、Dueling DQN 等模块)计算每个离散动作的 Q 值,选择 Q 值最高的动作执行。局限性:离散化粒度影响性能,粒度过粗丢失细节,过细则计算量剧增。

结合连续动作策略网络:改造输出层

若需保持动作连续性,可改造 Rainbow DQN 的网络结构,引入连续动作生成机制:

策略网络输出修改网络末端,输出连续动作的参数。例如:均值 - 方差输出:网络输出动作的均值 mu和方差sigma^2,基于高斯分布 采样连续动作。

直接回归:通过全连接层直接回归连续动作值(如 DDPG 的思路),但需结合评论家网络(Critic)评估动作价值,与 Rainbow DQN 的 Q 学习框架融合。

价值函数计算保留 Rainbow DQN 的多组件改进(如优先经验回放、双网络结构),但将 Q 值计算适配连续动作。例如,使用积分或采样近似连续动作空间的 Q 值:

其中是连续动作的策略分布。

典型实践框架

实际应用中,常将 Rainbow DQN 的优化组件(如多步引导、噪声探索)与连续动作算法结合,形成新框架:

网络结构:特征提取层:与 Rainbow DQN 一致,提取状态特征。

策略分支:输出连续动作参数(如均值、方差)。

价值分支:计算状态 - 动作对的 Q 值,融合 Dueling DQN 等思想。

训练流程:

结合策略梯度(PG)或深度确定性策略梯度(DDPG)的训练方式,利用 Rainbow DQN 的经验回放、多步更新等技术优化训练稳定性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/75630.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

“11.9元“引发的系统雪崩:Spring Boot中BigDecimal反序列化异常全链路狙击战 ✨

💥 "11.9元"引发的系统雪崩:Spring Boot中BigDecimal反序列化异常全链路狙击战 🎯 🔍 用 Mermaid原生防御体系图 #mermaid-svg-XZtcYBnmHrF9bFjc {font-family:"trebuchet ms",verdana,arial,sans-serif;fon…

Cortex-M7进入异常中断分析

使用cmbacktrace库,其支持M3,4,7。 1、串口输出异常信息 #define cmb_println(...) Debug_Printf(__VA_ARGS__)//cmb_println处理可变参数和格式化字符串 int Debug_Printf(const char *fmt, ...) {char buffer[DEBUG_TxBUFLEN];INT16U n;va_list args;va_star…

如何管理间接需求?团队实践分享

管理间接需求的核心方法包括明确需求识别流程、建立规范的需求管理体系、实施有效的需求沟通机制。 其中,明确需求识别流程最为关键。企业在实际业务中,往往会遇到大量的间接需求,如非直接生产性的采购需求、服务类需求等。这些需求往往隐蔽性…

与Aspose.pdf类似的jar库分享

如果你在寻找类似于 Aspose.PDF 的 JAR 库,这些库通常用于处理 PDF 文档的创建、编辑、转换、合并等功能。以下是一些类似的 Java 库,它们提供 PDF 处理的功能,其中一些是收费的,但也有开源选项: 1. iText (iText PDF…

2-2 MATLAB鮣鱼优化算法ROA优化CNN超参数回归预测

本博客来源于CSDN机器鱼,未同意任何人转载。 更多内容,欢迎点击本专栏目录,查看更多内容。 目录 0.引言 1.ROA优化CNN 2.主程序调用 3.结语 0.引言 在博客【ROA优化LSTM超参数回归】中,我们采用ROA对LSTM的学习率、迭代次数…

企业入驻成都国际数字影像产业园,可享150多项专业服务

企业入驻成都国际数字影像产业园,可享150多项专业服务 全方位赋能,助力影像企业腾飞 入驻成都国际数字影像产业园,企业将获得一个涵盖超过150项专业服务的全周期、一站式支持体系,旨在精准解决企业发展各阶段的核心需求&#xf…

线路板元器件介绍及选型指南:提高电路设计效率

电路板(PCB)是现代电子设备的核心,其上安装了各类电子元器件,这些元器件通过PCB的导电线路彼此连接,实现信号传输与功能执行。 元器件的选择与安装直接决定了电子产品的性能与稳定性。本文将为大家详细介绍电路板上的…

探究 Arm Compiler for Embedded 6 的 Clang 版本

原创标题:Arm Compiler for Embedded 6 的 Clang 版本 原创作者:庄晓立(LIIGO) 原创日期:20250218(首发日期20250326) 原创连接:https://blog.csdn.net/liigo/article/details/14653…

RedHat7.6_x86_x64服务器(最小化安装)搭建使用记录(二)

PostgreSQL数据库部署管理 1.rpm方式安装 挂载系统安装镜像: [rootlocalhost ~]# mount /dev/cdrom /mnt 进入安装包路径: [rootlocalhost ~]# cd /mnt/Packages 依次安装如下程序包: [rootlocalhost Packages]# rpm -ihv postgresql-libs-9…

浏览器存储 IndexedDB

IndexedDB 1. 什么是 IndexedDB? IndexedDB 是一种 基于浏览器的 NoSQL 数据库,用于存储大量的结构化数据,包括文件和二进制数据。它比 localStorage 和 sessionStorage 更强大,支持索引查询、事务等特性。 IndexedDB 主要特点…

panda3d 渲染

目录 安装 设置渲染宽高: 渲染3d 安装 pip install Panda3D 设置渲染宽高: import panda3d.core as pdmargin 100 screen Tk().winfo_screenwidth() - margin, Tk().winfo_screenheight() - margin width, height (screen[0], int(screen[0] / 1…

Node.js 包管理工具 - NPM 与 PNPM 清理缓存

NPM 清理缓存 1、基本介绍 npm 缓存是 npm 用来存储已下载包的地方,以加快后续安装速度 但是,有时缓存可能会损坏或占用过多磁盘空间,这时可以清理 npm 缓存 2、清理操作 执行如下指令,清理 npm 缓存 npm cache clean --for…

STM32F103_LL库+寄存器学习笔记05 - GPIO输入模式,捕获上升沿进入中断回调

导言 GPIO设置输入模式后,一般会用轮询的方式去查看GPIO的电平状态。比如,最常用的案例是用于检测按钮的当前状态(是按下还是没按下)。中断的使用一般用于计算脉冲的频率与计算脉冲的数量。 项目地址:https://github.…

【C++进阶二】string的模拟实现

【C进阶二】string的模拟实现 1.构造函数和C_strC_str: 2.operator[]3.拷贝构造3.1浅拷贝3.2深拷贝 4.赋值5.迭代器6.比较ascll码值的大小7.reverse扩容8.push_back尾插和append尾插9.10.insert10.1在pos位置前插入字符ch10.2在pos位置前插入字符串str 11.resize12.erase12.1从…

wokwi arduino mega 2560 - 点亮LED案例

截图: 点亮LED案例仿真截图 代码: unsigned long t[20]; // 定义一个数组t,用于存储20个LED的上次状态切换时间(单位:毫秒)void setup() {pinMode(13, OUTPUT); // 将引脚13设置为输出模式(此…

vue3项目使用 python +flask 打包成桌面应用

server.py import os import sys from flask import Flask, send_from_directory# 获取静态文件路径 if getattr(sys, "frozen", False):# 如果是打包后的可执行文件base_dir sys._MEIPASS else:# 如果是开发环境base_dir os.path.dirname(os.path.abspath(__file…

后端学习day1-Spring(八股)--还剩9个没看

一、Spring 1.请你说说Spring的核心是什么 参考答案 Spring框架包含众多模块,如Core、Testing、Data Access、Web Servlet等,其中Core是整个Spring框架的核心模块。Core模块提供了IoC容器、AOP功能、数据绑定、类型转换等一系列的基础功能,…

LeetCode 第34、35题

LeetCode 第34题:在排序数组中查找元素的第一个和最后一个位置 题目描述 给你一个按照非递减顺序排列的整数数组nums,和一个目标值target。请你找出给定目标值在数组中的开始位置和结束位置。如果数组中不存在目标值target,返回[-1,1]。你必须…

告别分库分表,时序数据库 TDengine 解锁燃气监控新可能

达成效果: 从 MySQL 迁移至 TDengine 后,设备数据自动分片,运维更简单。 列式存储可减少 50% 的存储占用,单服务器即可支撑全量业务。 毫秒级漏气报警响应时间控制在 500ms 以内,提升应急管理效率。 新架构支持未来…

第十四届蓝桥杯真题

一.LED 先配置LED的八个引脚为GPIO_OutPut,锁存器PD2也是,然后都设置为起始高电平,生成代码时还要去解决引脚冲突问题 二.按键 按键配置,由原理图按键所对引脚要GPIO_Input 生成代码,在文件夹中添加code文件夹,code中添加fun.c、fun.h、headfile.h文件,去资源包中把lc…