pytorch强化学习(1)——DQNSARSA

实验环境

python=3.10
torch=2.1.1
gym=0.26.2
gym[classic_control]
matplotlib=3.8.0
numpy=1.26.2

DQN代码

首先是module.py代码,在这里定义了网络模型和DQN模型

import torch
import torch.nn as nn
import numpy as npclass Net(nn.Module):# 构造只有一个隐含层的网络def __init__(self, n_states, n_hidden, n_actions):super(Net, self).__init__()# [b,n_states]-->[b,n_hidden]self.network = nn.Sequential(torch.nn.Linear(n_states, n_hidden),torch.nn.ReLU(),torch.nn.Linear(n_hidden, n_actions))# 前传def forward(self, x):  # [b,n_states]return self.network(x)class DQN:def __init__(self, n_states, n_hidden, n_actions, lr, gamma, epsilon):# 属性分配self.n_states = n_states  # 状态的特征数self.n_hidden = n_hidden  # 隐含层个数self.n_actions = n_actions  # 动作数self.lr = lr  # 训练时的学习率self.gamma = gamma  # 折扣因子,对下一状态的回报的缩放self.epsilon = epsilon  # 贪婪策略,有1-epsilon的概率探索# 计数器,记录迭代次数self.count = 0# 实例化训练网络self.q_net = Net(self.n_states, self.n_hidden, self.n_actions)# 优化器,更新训练网络的参数self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)self.criterion = torch.nn.MSELoss()  # 损失函数def choose_action(self, gym_state):state = torch.Tensor(gym_state)if np.random.random() < self.epsilon:action_values = self.q_net(state)  # q_net(state)采取动作后的预测action = action_values.argmax().item()else:# 随机选择一个动作action = np.random.randint(self.n_actions)return actiondef update(self, gym_state, action, reward, next_gym_state, done):state, next_state = torch.tensor(gym_state), torch.tensor(next_gym_state)q_value = self.q_net(state)[action]# 前千万不能缺少done,如果下一步游戏结束的花,那下一步的q值应该为0q_target = reward + self.gamma * self.q_net(next_state).max() * (1 - float(done))self.optimizer.zero_grad()dqn_loss = self.criterion(q_value, q_target)dqn_loss.backward()self.optimizer.step()

然后是train.py代码,在这里调用DQN模型和gym环境,来进行训练:

import gym
import torch
from module import DQN
import matplotlib.pyplot as pltlr = 1e-3  # 学习率
gamma = 0.95  # 折扣因子
epsilon = 0.8  # 贪心系数
n_hidden = 200  # 隐含层神经元个数env = gym.make("CartPole-v1")
n_states = env.observation_space.shape[0]  # 4
n_actions = env.action_space.n  # 2 动作的个数dqn = DQN(n_states, n_hidden, n_actions, lr, gamma, epsilon)if __name__ == '__main__':reward_list = []for i in range(500):state = env.reset()[0]  # len=4total_reward = 0done = Falsewhile True:# 获取当前状态下需要采取的动作action = dqn.choose_action(state)# 更新环境next_state, reward, done, _, _ = env.step(action)dqn.update(state, action, reward, next_state, done)state = next_statetotal_reward += rewardif done:breakprint("第%d回合,total_reward=%f" % (i, total_reward))reward_list.append(total_reward)# 绘图episodes_list = list(range(len(reward_list)))plt.plot(episodes_list, reward_list)plt.xlabel('Episodes')plt.ylabel('Returns')plt.title('DQN Returns')plt.show()

SARSA代码

首先是module.py代码,在这里定义了网络模型和SARSA模型。
SARSA和DQN基本相同,只有在更新Q网络的时候略有不同,已在代码相应位置做出注释。

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as Fclass Net(nn.Module):# 构造只有一个隐含层的网络def __init__(self, n_states, n_hidden, n_actions):super(Net, self).__init__()# [b,n_states]-->[b,n_hidden]self.network = nn.Sequential(torch.nn.Linear(n_states, n_hidden),torch.nn.ReLU(),torch.nn.Linear(n_hidden, n_actions))# 前传def forward(self, x):  # [b,n_states]return self.network(x)class SARSA:def __init__(self, n_states, n_hidden, n_actions, lr, gamma, epsilon):# 属性分配self.n_states = n_states  # 状态的特征数self.n_hidden = n_hidden  # 隐含层个数self.n_actions = n_actions  # 动作数self.lr = lr  # 训练时的学习率self.gamma = gamma  # 折扣因子,对下一状态的回报的缩放self.epsilon = epsilon  # 贪婪策略,有1-epsilon的概率探索# 计数器,记录迭代次数self.count = 0# 实例化训练网络self.q_net = Net(self.n_states, self.n_hidden, self.n_actions)# 优化器,更新训练网络的参数self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)self.criterion = torch.nn.MSELoss()  # 损失函数def choose_action(self, gym_state):state = torch.Tensor(gym_state)# 基于贪婪系数,有一定概率采取随机策略if np.random.random() < self.epsilon:action_values = self.q_net(state)  # q_net(state)是在当前状态采取各个动作后的预测action = action_values.argmax().item()else:# 随机选择一个动作action = np.random.randint(self.n_actions)return actiondef update(self, gym_state, action, reward, next_gym_state, done):state, next_state = torch.tensor(gym_state), torch.tensor(next_gym_state)q_value = self.q_net(state)[action]'''sarsa在更新网络时选择的是q_net(next_state)[next_action] 这是sarsa算法和dqn的唯一不同dqn是选择max(q_net(next))'''next_action = self.choose_action(next_state)# 千万不能缺少done,如果下一步游戏结束的话,那下一步的q值应该为0,而不是q网络输出的值q_target = reward + self.gamma * self.q_net(next_state)[next_action] * (1 - float(done))self.optimizer.zero_grad()dqn_loss = self.criterion(q_value, q_target)dqn_loss.backward()self.optimizer.step()

SARSA也有tarin.py文件,功能和上面DQN的一样,内容也几乎完全一样,只是把DQN的名字改成SARSA而已,所以在这里不再赘述。

运行结果

DQN的运行结果如下:
在这里插入图片描述

SARSA运行结果如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/227963.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt容器QToolBox工具箱

# QToolBox QToolBox是Qt框架中的一个窗口容器类,常用的几个函数有: ​setCurrentIndex(int index):设置当前显示的页面索引。可以通过调用该函数,将指定索引的页面设置为当前显示的页面。 addItem(QWidget * widget, const QString & text):向QToolBox中添加一个页面…

Flink系列之:分组聚合

Flink系列之&#xff1a;分组聚合 一、DISTINCT 聚合二、GROUPING SETS三、ROLLUP四、CUBE五、HAVING 适用于流、批 像大多数数据系统一样&#xff0c;Apache Flink支持聚合函数&#xff1b;包括内置的和用户定义的。用户自定义函数在使用前必须在目录中注册。 聚合函数把多行…

flutter学习-day10-布局类组件

&#x1f4da; 目录 介绍布局原理和约束盒模型布局 约束容器ConstrainedBox非约束容器UnconstrainedBox 线性布局 行row列column 弹性布局流式布局 WrapFlow 层叠布局对齐和相对定位布局构建回调 LayoutBuilder布局过程中AfterLayout布局完成后执行 本文学习和引用自《Flutte…

LLM大语言模型(二):Streamlit 无需前端经验也能画web页面

目录 问题 Streamlit是什么&#xff1f; 怎样用Streamlit画一个LLM的web页面呢&#xff1f; 文本输出 页面布局 滑动条 按钮 对话框 输入框 总结 问题 假如你是一位后端开发&#xff0c;没有任何的web开发经验&#xff0c;那如何去实现一个LLM的对话交互页面呢&…

Python MySQL数据库连接与基本使用

一、应用场景 python项目连接MySQL数据库时&#xff0c;需要第三方库的支持。这篇文章使用的是PyMySQL库&#xff0c;适用于python3.x。 二、安装 pip install PyMySQL三、使用方法 导入模块 import pymysql连接数据库 db pymysql.connect(hostlocalhost,usercode_space…

Spring MVC开发流程

1.Spring MVC环境基本配置 Maven工程依赖spring-webmvc <dependency><groupId>org.springframework</groupId><artifactId>spring-webmvc</artifactId><version>5.1.9.RELEASE</version> </dependency>web.xml配置Dispatche…

NSSCTF第16页(2)

[NSSRound#4 SWPU]1zweb(revenge) 查看index.php <?php class LoveNss{public $ljt;public $dky;public $cmd;public function __construct(){$this->ljt"ljt";$this->dky"dky";phpinfo();}public function __destruct(){if($this->ljt"…

【力扣100】73.矩阵置零

添加链接描述 class Solution:def setZeroes(self, matrix: List[List[int]]) -> None:"""Do not return anything, modify matrix in-place instead."""# 思路是1.记录每一个0元素的行和列下标 2.遍历全数组row_index[]column_index[]mlen(…

day01unittest复习,断言

1.unittest 方法执行前 # def setUp(self) -> None: # print(方法执行前执行) # # def tearDown(self) -> None: # print(方法执行后执行一次) 2.unittest 类方法执行前后执行一次 classmethod def setUpClass(cls) -> None:print(类执行前执行一次)classm…

41、BatchNorm - 什么是批归一化

在 CNN 网络中有一个很重要的技术,叫作批归一化(bn, BatchNorm )。 归一化层一般位于卷积的后面,学术或者工程上,一般习惯将卷积+批归一化+激活统一成一个小的网络结构,比如口语化上称为conv+bn+relu。 这是因为基本上卷积后面肯定会有批归一化,而后面肯定会接激活函数…

微分和导数(一)

1.微分&#xff1a; 假设我们有⼀个函数f : R → R&#xff0c;其输⼊和输出都是标量。如果f的导数存在&#xff0c;这个极限被定义为 如果f′(a)存在&#xff0c;则称f在a处是可微的。如果f在⼀个区间内的每个数上都是可微的&#xff0c;则此函数在此区间中是可微的。导数f′…

网络协议 - UDP 协议详解

网络协议 - UDP 协议详解 UDP概述UDP特点UDP的首部格式UDP校验 參考文章 基于TCP和UDP的协议非常广泛&#xff0c;所以也有必要对UDP协议进行详解。 UDP概述 UDP(User Datagram Protocol)即用户数据报协议&#xff0c;在网络中它与TCP协议一样用于处理数据包&#xff0c;是一种…

必要时进行保护性拷贝

保护性拷贝&#xff08;Defensive Copy&#xff09;是一种常见的编程实践&#xff0c;用于在传递参数或返回值时&#xff0c;创建副本以防止原始对象被意外修改。以下是一个例子&#xff0c;展示了何时进行保护性拷贝&#xff1a; mport java.util.ArrayList; import java.uti…

成功解决 Plugin ‘org.springframework.boot:spring-boot-maven-plugin:‘ not found

Plugin ‘org.springframework.boot:spring-boot-maven-plugin:‘ not found的解决方案&#xff0c;亲测可用&#xff01; 方法一&#xff1a;清理IDEA的缓存 File -> Invalidate Caches 方法二&#xff1a;添加版本号 先看自己当前的版本号 首先打开pom.xml文件进行查看C…

数据手册Datasheet解读-肖特基二极管笔记

数据手册Datasheet解读笔记1-肖特基二极管 数据手册大体结构共包含10个部分肖特基二极管-SS14第一重点关注点&#xff1a;极限值第二重点关注点&#xff1a;电气特性 数据手册大体结构共包含10个部分 1.Features一特性 2.Application一应用 3.Description一说明4.Pin Configur…

关于在Java中打印“数字”三角形图形的汇总

之前写过一篇利用*打印三角形汇总&#xff0c;网友需要查看可以去本专栏查找之前的文章&#xff0c;这里利用二维数组嵌套循环打印“数字”三角形&#xff0c;汇总如下&#xff0c;话不多说&#xff0c;直接上代码&#xff1a; /*** 打印如下数字三角形图形*/ public class Wo…

逻辑分析仪_使用手册

LA1010 1> 能干啥&#xff1f;2> 硬件连接3> 软件安装4> 参数设置4.1> 采样深度和采样率4.2> 添加协议解析器4.3> 毛刺过滤设置 1> 能干啥&#xff1f; 测量通信波形&#xff0c;并自动解析&#xff1b; 比如测量&#xff0c;UART&#xff0c;SPI&…

K8S学习指南(22)-k8s核心对象Endpoint

文章目录 前言什么是Kubernetes Endpoint&#xff1f;Endpoint的结构Endpoint与Service的关系Endpoint的使用动态管理Endpoint总结 前言 在Kubernetes&#xff08;K8s&#xff09;中&#xff0c;Endpoint是一个关键的核心对象&#xff0c;它承担着连接Service和后端Pod的重要角…

【DataSophon】大数据管理平台DataSophon-1.2.1安装部署详细流程

&#x1f984; 个人主页——&#x1f390;开着拖拉机回家_Linux,大数据运维-CSDN博客 &#x1f390;✨&#x1f341; &#x1fa81;&#x1f341;&#x1fa81;&#x1f341;&#x1fa81;&#x1f341;&#x1fa81;&#x1f341; &#x1fa81;&#x1f341;&#x1fa81;&am…

java_web_电商项目

java_web_电商项目 1.登录界面2.注册界面3. 主界面4.分页界面5.商品详情界面6. 购物车界面7.确认订单界面8.个人中心界面9.收货地址界面10.用户信息界面11.用户余额充值界面12.后台首页13.后台商品增加14.后台用户增加15.用户管理16.源码分享1.登录页面的源码2.我们的主界面 1.…