pytorch实现长短期记忆网络 (LSTM)

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

LSTM 通过 记忆单元(cell)三个门控机制(遗忘门、输入门、输出门)来控制信息流:

 记忆单元(Cell State)

  • 负责存储长期信息,并通过门控机制决定保留或丢弃信息。

 遗忘门(Forget Gate, ftf_tft​)

 输入门(Input Gate, iti_tit​)

 输出门(Output Gate, oto_tot​)

特性

传统 RNNLSTM
记忆能力短期记忆长短期记忆
计算复杂度
解决梯度消失
适用场景短序列数据长序列数据

LSTM 应用场景

  • 自然语言处理(NLP):文本生成、情感分析、机器翻译
  • 时间序列预测:股票预测、天气预报、传感器数据分析
  • 语音识别:自动字幕生成、语音转文字(ASR)
  • 机器人与控制系统:智能体决策、自动驾驶

例子:

下面例子实现了一个 基于 LSTM 的强化学习智能体,在 1D 网格环境 里移动,并找到最优路径。
最终,我们 绘制 5 条测试路径,并高亮显示最佳路径(红色)

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# ========== 1. 定义 LSTM 策略网络 ==========
class LSTMPolicy(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(LSTMPolicy, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)self.softmax = nn.Softmax(dim=-1)def forward(self, x, hidden_state):batch_size = x.size(0)# 确保 hidden_state 维度正确if hidden_state[0].dim() == 2:hidden_state = (hidden_state[0].unsqueeze(1).repeat(1, batch_size, 1),hidden_state[1].unsqueeze(1).repeat(1, batch_size, 1))out, hidden_state = self.lstm(x, hidden_state)out = self.fc(out[:, -1, :])  # 取最后时间步的输出action_prob = self.softmax(out)  # 归一化输出,作为策略return action_prob, hidden_statedef init_hidden(self, batch_size=1):return (torch.zeros(self.num_layers, batch_size, self.hidden_size),torch.zeros(self.num_layers, batch_size, self.hidden_size))# ========== 2. 创建网格环境 ==========
class GridWorld:def __init__(self, grid_size=10, goal_position=9):self.grid_size = grid_sizeself.goal_position = goal_positionself.reset()def reset(self):self.position = 0return self.positiondef step(self, action):if action == 0:self.position = max(0, self.position - 1)elif action == 1:self.position = min(self.grid_size - 1, self.position + 1)reward = 1 if self.position == self.goal_position else -0.1done = self.position == self.goal_positionreturn self.position, reward, done# ========== 3. 训练智能体 ==========
def train(num_episodes=500, max_steps=50):env = GridWorld()input_size = 1hidden_size = 64output_size = 2num_layers = 1policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)optimizer = optim.Adam(policy.parameters(), lr=0.01)gamma = 0.99for episode in range(num_episodes):state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0)  # (1, 1, input_size)hidden_state = policy.init_hidden(batch_size=1)log_probs = []rewards = []for step in range(max_steps):action_probs, hidden_state = policy(state, hidden_state)action = torch.multinomial(action_probs, 1).item()log_prob = torch.log(action_probs.squeeze(0)[action])log_probs.append(log_prob)next_state, reward, done = env.step(action)rewards.append(reward)if done:breakstate = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)# 计算回报并更新策略returns = []R = 0for r in reversed(rewards):R = r + gamma * Rreturns.insert(0, R)returns = torch.tensor(returns, dtype=torch.float32)returns = (returns - returns.mean()) / (returns.std() + 1e-9)loss = sum([-log_prob * R for log_prob, R in zip(log_probs, returns)])optimizer.zero_grad()loss.backward()optimizer.step()if (episode + 1) % 50 == 0:print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {sum(rewards)}")torch.save(policy.state_dict(), "policy.pth")# 训练智能体
train(500)# ========== 4. 测试智能体并绘制最佳路径 ==========
def test(num_episodes=5):env = GridWorld()input_size = 1hidden_size = 64output_size = 2num_layers = 1policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)policy.load_state_dict(torch.load("policy.pth"))plt.figure(figsize=(10, 5))best_path = Nonebest_steps = float('inf')for episode in range(num_episodes):state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0)  # (1, 1, input_size)hidden_state = policy.init_hidden(batch_size=1)positions = [env.position]  # 记录位置变化while True:action_probs, hidden_state = policy(state, hidden_state)action = torch.argmax(action_probs, dim=-1).item()next_state, reward, done = env.step(action)positions.append(next_state)if done:breakstate = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)# 记录最佳路径(最短步数)if len(positions) < best_steps:best_steps = len(positions)best_path = positions# 绘制普通路径(蓝色)plt.plot(range(len(positions)), positions, marker='o', linestyle='-', color='blue', alpha=0.6,label=f'Episode {episode + 1}' if episode == 0 else "")# 绘制最佳路径(红色)if best_path:plt.plot(range(len(best_path)), best_path, marker='o', linestyle='-', color='red', linewidth=2,label="Best Path")# 打印最佳路径print(f"Best Path (steps={best_steps}): {best_path}")plt.xlabel("Time Steps")plt.ylabel("Agent Position")plt.title("Agent's Movement Path (Best Path in Red)")plt.legend()plt.grid(True)plt.show()# 测试并绘制智能体移动路径
test(5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/70070.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

后盾人JS--继承

继承是原型的继承 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title> </hea…

自定义数据集 使用scikit-learn中SVM的包实现SVM分类

生成自定义数据集 生成一个简单的二维数据集&#xff0c;包含两类数据点&#xff0c;分别用不同的标签表示。 import numpy as np import matplotlib.pyplot as plt# 生成数据 np.random.seed(42) X np.r_[np.random.randn(100, 2) - [2, 2], np.random.randn(100, 2) [2, …

实际操作 检测缺陷刀片

号he 找到目标图像的缺陷位置&#xff0c;首先思路为对图像进行预处理&#xff0c;灰度-二值化-针对图像进行轮廓分析 //定义结构元素 Mat se getStructuringElement(MORPH_RECT, Size(3, 3), Point(-1, -1)); morphologyEx(thre, tc, MORPH_OPEN, se, Point(-1, -1), 1); …

从实数与复数在交流电路正弦量表示中的对比分析

引言 在交流电路领域&#xff0c;深入理解电压和电流等正弦量的表示方式对电路分析至关重要。其中&#xff0c;只用实数表示正弦量存在诸多局限性&#xff0c;而复数的引入则为正弦量的描述与分析带来了极大的便利。下面将从瞬时值角度&#xff0c;详细剖析只用实数的局限性&a…

Python3 OS模块中的文件/目录方法说明十四

一. 简介 前面文章简单学习了 Python3 中 OS模块中的文件/目录的部分函数。 本文继续来学习 OS 模块中文件、目录的操作方法&#xff1a;os.statvfs() 方法&#xff0c;os.symlink() 方法。 二. Python3 OS模块中的文件/目录方法 1. os.statvfs() 方法 os.statvfs() 方法用…

知识蒸馏教程 Knowledge Distillation Tutorial

来自于&#xff1a;Knowledge Distillation Tutorial 将大模型蒸馏为小模型&#xff0c;可以节省计算资源&#xff0c;加快推理过程&#xff0c;更高效的运行。 使用CIFAR-10数据集 import torch import torch.nn as nn import torch.optim as optim import torchvision.tran…

day38|leetcode 322零钱兑换,279.完全平方数,139.单词拆分

322. 零钱兑换 给你一个整数数组 coins &#xff0c;表示不同面额的硬币&#xff1b;以及一个整数 amount &#xff0c;表示总金额。 计算并返回可以凑成总金额所需的 最少的硬币个数 。如果没有任何一种硬币组合能组成总金额&#xff0c;返回 -1 。 你可以认为每种硬币的数量是…

Turing Complete-1位开关

要求如下&#xff1a; 我的思考&#xff1a; 把输入1当作控制信号&#xff0c;把输入2当作输出信号。 通过非门和开关使输入2形成双通道输出&#xff0c; 通道一为输出输入2取反。 通道二为输出输入2本身。 通过输入1来控制两个通道的开闭。

从Transformer到世界模型:AGI核心架构演进

文章目录 引言:架构革命推动AGI进化一、Transformer:重新定义序列建模1.1 注意力机制的革命性突破1.2 从NLP到跨模态演进1.3 规模扩展的黄金定律二、通向世界模型的关键跃迁2.1 从语言模型到认知架构2.2 世界模型的核心特征2.3 混合架构的突破三、构建世界模型的技术路径3.1 …

深度求索DeepSeek横空出世

真正的强者从来不是无所不能&#xff0c;而是尽我所能。多少有关输赢胜负的缠斗&#xff0c;都是直面本心的搏击。所有令人骄傲振奋的突破和成就&#xff0c;看似云淡风轻寥寥数语&#xff0c;背后都是数不尽的焚膏继晷、汗流浃背。每一次何去何从的困惑&#xff0c;都可能通向…

性能优化中的数据过滤优化

目录 以下是一些关于数据过滤优化的策略和方法 索引使用 避免全表扫描 使用分区 数据预处理 合理设计查询 利用缓存机制 数据库层面优化 系统中通常会有一些统计和分析的功能&#xff0c;以前我们主要针对结构化数据&#xff08;关系型数据库存储&#xff09;进行分析&a…

与本地Deepseek R1:14b的第一次交流

本地部署DS的方法&#xff0c;见&#xff1a;本地快速部署DeepSeek-R1模型——2025新年贺岁-CSDN博客 只有16GB内存且没有强大GPU的个人电脑&#xff0c;部署和运行14b参数的DS大模型已是天花板了。 运行模型 ollama run deepseek-r1:14b C:\Users\Administrator>ollama r…

Python 梯度下降法(六):Nadam Optimize

文章目录 Python 梯度下降法&#xff08;六&#xff09;&#xff1a;Nadam Optimize一、数学原理1.1 介绍1.2 符号定义1.3 实现流程 二、代码实现2.1 函数代码2.2 总代码 三、优缺点3.1 优点3.2 缺点 四、相关链接 Python 梯度下降法&#xff08;六&#xff09;&#xff1a;Nad…

【狂热算法篇】探秘图论之Dijkstra 算法:穿越图的迷宫的最短路径力量(通俗易懂版)

羑悻的小杀马特.-CSDN博客羑悻的小杀马特.擅长C/C题海汇总,AI学习,c的不归之路,等方面的知识,羑悻的小杀马特.关注算法,c,c语言,青少年编程领域.https://blog.csdn.net/2401_82648291?typebbshttps://blog.csdn.net/2401_82648291?typebbshttps://blog.csdn.net/2401_8264829…

Git 的起源与发展

序章&#xff1a;版本控制的前世今生 在软件开发的漫长旅程中&#xff0c;版本控制犹如一位忠诚的伙伴&#xff0c;始终陪伴着开发者们。它的存在&#xff0c;解决了软件开发过程中代码管理的诸多难题&#xff0c;让团队协作更加高效&#xff0c;代码的演进更加有序。 简单来…

MySQL(Undo日志)

后面也会持续更新&#xff0c;学到新东西会在其中补充。 建议按顺序食用&#xff0c;欢迎批评或者交流&#xff01; 缺什么东西欢迎评论&#xff01;我都会及时修改的&#xff01; 大部分截图和文章采用该书&#xff0c;谢谢这位大佬的文章&#xff0c;在这里真的很感谢让迷茫的…

全面剖析 XXE 漏洞:从原理到修复

目录 前言 XXE 漏洞概念 漏洞原理 XML 介绍 XML 结构语言以及语法 XML 结构 XML 语法规则 XML 实体引用 漏洞存在原因 产生条件 经典案例介绍分析 XXE 漏洞修复方案 结语 前言 网络安全领域暗藏危机&#xff0c;各类漏洞威胁着系统与数据安全。XXE 漏洞虽不常见&a…

初级数据结构:栈和队列

目录 一、栈 (一)、栈的定义 (二)、栈的功能 (三)、栈的实现 1.栈的初始化 2.动态扩容 3.压栈操作 4.出栈操作 5.获取栈顶元素 6.获取栈顶元素的有效个数 7.检查栈是否为空 8.栈的销毁 9.完整代码 二、队列 (一)、队列的定义 (二)、队列的功能 (三&#xff09…

C++STL(一)——string类

目录 一、string的定义方式二、 string类对象的容量操作三、string类对象的访问及遍历操作四、string类对象的修改操作五、string类非成员函数 一、string的定义方式 string是个管理字符数组的类&#xff0c;其实就是字符数组的顺序表。 它的接口也是非常多的。本章介绍一些常…

与,|与||的区别

按位运算符 | 和 & 功能与运算规则 |&#xff08;按位或运算符&#xff09;&#xff1a;对两个操作数的对应二进制位进行逻辑或运算。只要对应的两个二进制位中有一个为 1&#xff0c;则该位的结果为 1&#xff1b;只有当两个二进制位都为 0 时&#xff0c;结果才为 0。&…