pytorch实现循环神经网络

 人工智能例子汇总:AI常见的算法和例子-CSDN博客 

PyTorch 提供三种主要的 RNN 变体:

  • nn.RNN:最基本的循环神经网络,适用于短时依赖任务。
  • nn.LSTM:长短时记忆网络,适用于长序列数据,能有效解决梯度消失问题。
  • nn.GRU:门控循环单元,比 LSTM 计算更高效,适用于大部分任务。
网络类型优势适用场景
RNN计算简单,适用于短时序列语音、文本处理(短序列)
LSTM适用于长序列,能记忆长期信息机器翻译、语音识别、股票预测
GRU比 LSTM 计算更高效,效果相似语音处理、文本生成

例子:

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 1. 生成正弦波数据(仅使用 PyTorch)
def generate_sine_wave(seq_length=10, num_samples=1000):x = torch.linspace(0, 100, num_samples)  # 生成 1000 个等间距数据点y = torch.sin(x)  # 计算正弦值X_data, Y_data = [], []for i in range(len(y) - seq_length):X_data.append(y[i:i + seq_length].unsqueeze(-1))  # 过去 seq_length 作为输入Y_data.append(y[i + seq_length])  # 预测下一个点return torch.stack(X_data), torch.tensor(Y_data).unsqueeze(-1)# 生成数据
seq_length = 10  # 序列长度
X, Y = generate_sine_wave(seq_length)# 划分训练集和测试集
train_size = int(0.8 * len(X))
X_train, X_test = X[:train_size], X[train_size:]
Y_train, Y_test = Y[:train_size], Y[train_size:]# 2. 定义 RNN 模型
class SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(SimpleRNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)  # 初始化隐藏状态out, _ = self.rnn(x, h0)out = self.fc(out[:, -1, :])  # 取最后一个时间步的输出return out# 3. 训练模型
# 超参数
input_size = 1
hidden_size = 32
output_size = 1
num_layers = 1
num_epochs = 100
learning_rate = 0.001# 初始化模型
model = SimpleRNN(input_size, hidden_size, output_size, num_layers)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 训练
for epoch in range(num_epochs):model.train()optimizer.zero_grad()outputs = model(X_train)loss = criterion(outputs, Y_train)loss.backward()optimizer.step()if (epoch + 1) % 10 == 0:print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')# 4. 评估与绘图
model.eval()
with torch.no_grad():predictions = model(X_test)# 画图
plt.figure(figsize=(10, 5))
plt.plot(Y_test.numpy(), label="Real Data")
plt.plot(predictions.numpy(), label="Predicted Data")
plt.legend()
plt.title("RNN Sine Wave Prediction")
plt.show()

代码解析

数据生成

  • torch.linspace(0, 100, num_samples) 生成 1000 个均匀分布的数据点。
  • torch.sin(x) 计算正弦值,形成时间序列数据。
  • X过去 10 个时间步的数据,Y下一个时间步的预测目标

构建 RNN

  • nn.RNN(input_size, hidden_size, num_layers, batch_first=True) 定义循环神经网络
    • input_size=1:每个时间步只有一个输入值(正弦波)。
    • hidden_size=32:隐藏层神经元数目。
    • num_layers=1:单层 RNN。
  • self.fc = nn.Linear(hidden_size, output_size) 负责最终输出。

训练

  • 使用 MSELoss(均方误差损失) 计算预测值与真实值的误差。
  • 使用 Adam 优化器 更新模型参数。
  • 每 10 个 epoch 输出一次损失 loss

测试 & 绘图

  • 关闭梯度计算 (torch.no_grad()),执行前向传播预测测试数据。
  • Matplotlib 绘制预测曲线与真实曲线。

运行效果

如果训练成功,预测曲线(橙色)应该与真实曲线(蓝色)非常接近

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/70130.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt u盘自动升级软件

Qt u盘自动升级软件 Chapter1 Qt u盘自动升级软件u盘自动升级软件思路:step1. 获取U盘 判断U盘名字是否正确, 升级文件是否存在。step2. 升级step3. 升级界面 Chapter2 Qt 嵌入式设备应用程序,通过U盘升级的一种思路Chapter3 在开发板上运行的…

4种架构的定义和关联

文章目录 **1. 各架构的定义****业务架构(Business Architecture)****应用架构(Application Architecture)****数据架构(Data Architecture)****技术架构(Technology Architecture)*…

FinRobot:一个使用大型语言模型的金融应用开源AI代理平台

“FinRobot: An Open-Source AI Agent Platform for Financial Applications using Large Language Models” 论文地址:https://arxiv.org/pdf/2405.14767 Github地址:https://github.com/AI4Finance-Foundation/FinRobot 摘要 在金融领域与AI社区间&a…

DDD - 微服务架构模型_领域驱动设计(DDD)分层架构 vs 整洁架构(洋葱架构) vs 六边形架构(端口-适配器架构)

文章目录 引言1. 概述2. 领域驱动设计(DDD)分层架构模型2.1 DDD的核心概念2.2 DDD架构分层解析 3. 整洁架构:洋葱架构与依赖倒置3.1 整洁架构的核心思想3.2 整洁架构的层次结构 4. 六边形架构:解耦核心业务与外部系统4.1 六边形架…

【大模型LLM面试合集】大语言模型架构_llama系列模型

llama系列模型 1.LLama 1.1 简介 Open and Efficient Foundation Language Models (Open但没完全Open的LLaMA) 2023年2月,Meta(原Facebook)推出了LLaMA大模型,使用了1.4T token进行训练,虽然最大模型只有65B&…

深入探索Vue 3组合式API

深入探索Vue 3组合式API 深入探索Vue 3组合式API一、组合式API诞生背景1.1 Options API的局限性1.2 设计目标二、核心概念解析2.1 setup() 函数:组合式API的基石2.2 响应式系统:重新定义数据驱动2.3 生命周期:全新的接入方式2.4 响应式原理探…

微调llama3问题解决-RuntimeError: CUDA unknown error - this may be due to an incorrectly set up environment

问题说明之一 具体问题如下: RuntimeError: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero.我使用的这套是根据…

【JavaScript】《JavaScript高级程序设计 (第4版) 》笔记-Chapter1-什么是 JavaScript

一、什么是 JavaScript 虽然 JavaScript 和 ECMAScript(发音为“ek-ma-script”) 基本上是同义词,但 JavaScript 远远不限于 ECMA-262 所定义的那样。没错,完整的 JavaScript 实现包含以下几个部分。 核心(ECMAScript&…

2. 【.NET Aspire 从入门到实战】--理论入门与环境搭建--.NET Aspire 概览

在当今快速发展的软件开发领域,构建高效、可靠且易于维护的云原生应用程序已成为开发者和企业的核心需求。.NET Aspire 作为一款专为云原生应用设计的开发框架,旨在简化分布式系统的构建和管理,提供了一整套工具、模板和集成包,帮…

49【服务器介绍】

服务器和你的电脑可以说是一模一样的,只不过用途不一样,叫法就不一样了 物理服务器和云服务器的区别 整台设备眼睛能够看得到的,我们一般称之为物理服务器。所以物理服务器是比较贵的,不是每一个开发者都能够消费得起的。 …

Redis代金卷(优惠卷)秒杀案例-单应用版

优惠卷表:优惠卷基本信息,优惠金额,使用规则 包含普通优惠卷和特价优惠卷(秒杀卷) 优惠卷的库存表:优惠卷的库存,开始抢购时间,结束抢购时间.只有特价优惠卷(秒杀卷)才需要填写这些信息 优惠卷订单表 卷的表里已经有一条普通优惠卷记录 下面首先新增一条秒杀优惠卷记录 { &quo…

Notepad++消除生成bak文件

设置(T) ⇒ 首选项... ⇒ 备份 ⇒ 勾选 "禁用" 勾选禁用 就不会再生成bak文件了 notepad怎么修改字符集编码格式为gbk 如图所示

DeepSeek蒸馏模型:轻量化AI的演进与突破

目录 引言 一、知识蒸馏的技术逻辑与DeepSeek的实践 1.1 知识蒸馏的核心思想 1.2 DeepSeek的蒸馏架构设计 二、DeepSeek蒸馏模型的性能优势 2.1 效率与成本的革命性提升 2.2 性能保留的突破 2.3 场景适应性的扩展 三、应用场景与落地实践 3.1 智能客服系统的升级 3.2…

物联网领域的MQTT协议,优势和应用场景

MQTT(Message Queuing Telemetry Transport)作为轻量级发布/订阅协议,凭借其低带宽消耗、低功耗与高扩展性,已成为物联网通信的事实标准。其核心优势包括:基于TCP/IP的异步通信机制、支持QoS(服务质量&…

基于“蘑菇书”的强化学习知识点(五):条件期望

条件期望 摘要一、条件期望的定义二、条件期望的关键性质三、条件期望的直观理解四、条件期望的应用场景五、简单例子离散情况连续情况 摘要 本系列知识点讲解基于蘑菇书EasyRL中的内容进行详细的疑难点分析!具体内容请阅读蘑菇书EasyRL! 对应蘑菇书Eas…

Node.js与嵌入式开发:打破界限的创新结合

文章目录 一、Node.js的本质与核心优势1.1 什么是Node.js?1.2 嵌入式开发的范式转变 二、Node.js与嵌入式结合的四大技术路径2.1 硬件交互层2.2 物联网协议栈2.3 边缘计算架构2.4 轻量化运行时方案 三、实战案例:智能农业监测系统3.1 硬件配置3.2 软件架…

Shell 中的 Globbing:原理、使用方法与实现解析(中英双语)

Shell 中的 Globbing:原理、使用方法与实现解析 在 Unix Shell(如 Bash、Zsh)中,globbing 是指 文件名模式匹配(filename pattern matching),它允许用户使用特殊的通配符(wildcards…

7 与mint库对象互转宏(macros.rs)

macros.rs代码定义了一个Rust宏mint_vec,它用于在启用mint特性时,为特定的向量类型实现与mint库中对应类型的相互转换。mint库是一个提供基本数学类型(如点、向量、矩阵等)的Rust库,旨在与多个图形和数学库兼容。这个宏…

P3078[USACO13MAR] Poker Hands S

P3078[USACO13MAR] Poker Hands S https://www.luogu.com.cn/problem/P3078 前言 学习差分后写的第一道题,直接给我干懵逼,题解都看不懂……吃了个晚饭后开窍写出来了,遂成此篇。 题目 翻译版本 Bessie 和她的朋友们正在玩一种独特的扑克游…

【物联网】ARM核常用指令(详解):数据传送、计算、位运算、比较、跳转、内存访问、CPSR/SPSR

文章目录 指令格式(重点)1. 立即数2. 寄存器位移 一、数据传送指令1. MOV指令2. MVN指令3. LDR指令 二、数据计算指令1. ADD指令1. SUB指令1. MUL指令 三、位运算指令1. AND指令2. ORR指令3. EOR指令4. BIC指令 四、比较指令五、跳转指令1. B/BL指令2. l…