基于Pytorch的LSTM网络全流程实验(自带数据集,可直接运行出结果,替换自己的数据集即可使用)

文章目录

    • LSTM代码
    • 双向LSTM,需要修改哪几个参数?

LSTM代码

import numpy as np
import matplotlib.pyplot as pltimport torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDatasetfrom sklearn.model_selection import train_test_split# 生成数据集
def generate_data(num_samples, seq_length):# 生成正弦波形数据(类别0)half_num_samples=num_samples//2 # 整除x_sin = np.array([np.sin(0.06 * np.arange(seq_length) + np.random.rand()) for _ in range(half_num_samples)])y_sin = np.zeros(half_num_samples, dtype=np.int64)# 生成余弦波形数据(类别1),频率略有不同x_cos = np.array([np.cos(0.05 * np.arange(seq_length) + np.random.rand()) for _ in range(half_num_samples)])y_cos = np.ones(half_num_samples, dtype=np.int64)# 合并数据x = np.concatenate((x_sin, x_cos), axis=0)y = np.concatenate((y_sin, y_cos), axis=0)# 打乱数据indices = np.arange(num_samples)np.random.shuffle(indices)x = x[indices]y = y[indices]# 转换为pytorch张量,LSTM需要3D tensor [batch, seq_len, features],# 所以用unsqueeze(2)在第二个维度上增加一个维度x_tensor = torch.tensor(x, dtype=torch.float32).unsqueeze(2)  print('x_tensor.shape:',x_tensor.shape) # x_tensor.shape: torch.Size([1000, 100, 1])y_tensor = torch.tensor(y, dtype=torch.int64) # y_tensor.shape: torch.Size([1000])print('y_tensor.shape:',y_tensor.shape)return x_tensor, y_tensor# LSTM分类模型
class LSTMClassifier(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim, n_layers):super(LSTMClassifier, self).__init__()self.hidden_dim = hidden_dimself.n_layers = n_layers# LSTM Layerself.lstm = nn.LSTM(input_dim, hidden_dim, n_layers, batch_first=True)# 全连接层(Fully connected layer)self.fc = nn.Linear(hidden_dim, output_dim)# forward方法在模型训练时会自动调用def forward(self, x):# 用零初始化隐藏层的状态h0 = torch.zeros(self.n_layers, x.size(0), self.hidden_dim).requires_grad_()# 用零初始化细胞状态c0 = torch.zeros(self.n_layers, x.size(0), self.hidden_dim).requires_grad_()out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))out = self.fc(out[:, -1, :])return out# 训练模型
def train_model(model, train_loader, criterion, optimizer, num_epochs):for epoch in range(num_epochs):for i, (sequences, labels) in enumerate(train_loader):# Forward passoutputs = model(sequences)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}], Loss: {loss.item():.4f}')# 评估模型
def evaluate_model(model, test_loader):model.eval()  # Set model to evaluation modewith torch.no_grad():correct = 0total = 0for sequences, labels in test_loader:outputs = model(sequences)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the model on the test sequences: {100 * correct / total} %')if __name__=='__main__':# ----------------- 生成样本数据 ----------------- num_samples = 1000  # 训练总样本数seq_length = 100    # 每个样本的序列长度(可以看作是特征的长度)x_data,y_data = generate_data(num_samples, seq_length) # 产生总的样本x_train, x_test, y_train, y_test = train_test_split(x_data, y_data, test_size=0.2, random_state=2) # ----------------- 数据加载器 ----------------- batch_size=64train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)test_loader = DataLoader(TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)# ----------------- 可视化数据 ----------------- plt.figure(figsize=(12, 6))for i in range(6):plt.subplot(2, 3, i+1)plt.plot(x_train[i].numpy().flatten(), label=f"Class {y_train[i].item()}")plt.legend()plt.tight_layout()plt.show() # 不想看数据,可以注释掉这行# ----------------- 超参设定 ----------------- input_dim = 10    # 输入特征的维数hidden_dim = 50   # LSTM 隐藏层的维度output_dim = 2    # 输出的维度(分类的类别数)n_layers = 1      # 堆叠的 LSTM 层的数量(默认为1层)# ----------------- 创建模型 ----------------- model = LSTMClassifier(input_dim=1, hidden_dim=50, output_dim=2, n_layers=1)criterion = nn.CrossEntropyLoss() # 损失函数optimizer = optim.Adam(model.parameters(), lr=0.01) # 优化器# ----------------- 训练模型 ----------------- train_model(model, train_loader, criterion, optimizer, num_epochs=10)# ----------------- 评估模型 ----------------- evaluate_model(model, test_loader)

双向LSTM,需要修改哪几个参数?

需要在 nn.LSTM 的构造函数中设置 bidirectional=True。此外,由于双向 LSTM 在每个时间步将会有两个隐藏状态(正向和反向),因此全连接层的输入特征数需要调整为 2 * hidden_size

下面是对您的代码的修改部分,以及需要注意的几个点:

  1. nn.LSTM 中设置 bidirectional=True来启用双向功能。
  2. h0c0 的尺寸都乘以了 2,因为对于每一层 LSTM,我们现在有两个隐藏层状态(一个用于前向传播,一个用于后向传播)。
  3. 调整全连接层的输入特征数,由 hidden_size 改为 2 * hidden_size,以适应双向输出。

修改后的代码如下:

import torch
import torch.nn as nn# 定义LSTM网络
class LSTM_Model(nn.Module):"""input_size:输入特征的维数hidden_size:LSTM 隐藏层的维度num_layers:堆叠的 LSTM 层的数量class_num: 分类的类别数batch_first: 输入和输出的维度顺序是否为 (batch, seq, feature)"""def __init__(self, input_size, hidden_size, num_layers, class_num):super(LSTM_Model, self).__init__()self.hidden_size = hidden_size self.num_layers = num_layers# 修改为双向LSTMself.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)# 修改全连接层输入特征数为 2 * hidden_sizeself.fc = nn.Linear(in_features=2 * hidden_size, out_features=class_num)def forward(self, x):DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 初始化隐藏层状态全为0h0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).requires_grad_().to(DEVICE)  # 注意乘以2,因为是双向c0 = torch.zeros(self.num_layers * 2, x.size(0), self.hidden_size).requires_grad_().to(DEVICE)  # 注意乘以2,因为是双向x = x.view(x.size(0), 1, -1)out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))# 只需要最后一层隐层的状态,考虑双向,所以取最后一步的输出out = self.fc(out[:, -1, :])  # 这里不用改变return out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/4529.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

沉浸式翻译 chrome 插件 Immersive Translate - Translate Website PDF

免费翻译网站,翻译PDF和Epub电子书,双语翻译视频字幕 📣 网络上口碑爆炸的网站翻译扩展工具【沉浸式翻译】⭐⭐⭐⭐⭐ 💻 功能特点如下: 📰 网站翻译 🚀 提供双语网站翻译,智能识…

代谢组数据分析五:溯源分析

MetOrigin Analysis {#MetOriginAnalysis} 微生物群及其代谢产物与人类健康和疾病密切相关。然而,理解微生物组和代谢物之间复杂的相互作用是具有挑战性的。 在研究肠道代谢物时,代谢物的来源是一个无法避免的问题即代谢物到底是来自肠道微生物的代谢还是宿主本身代谢产生的…

web自动化系列-selenium的基本方法介绍

web自动化 ,一个老生常谈的话题 ,很多人的自动化之路就是从它开始 。它学起来简单 ,但做起来又比较难以驾驭 ;它的执行效率慢 、但又是最接近于用户的操作场景 ; 1.web自动化中的三大亮点技术 我们先聊聊 &#xff0…

登录rabbitMQ管理界面时浏览器显示要求进行身份验证,与此站点连接不安全解决办法

问题描述 最近在黑马学习rabbitMQ的过程中,在使用docker部署好rabbitMQ后,使用账号为:itcast,密码为:123321 登录的时候浏览器显示了这个问题,如图所示: 当时以为自己需要输入自己的浏览…

Spring Web MVC入门(3)——响应

目录 一、返回静态页面 RestController 和 Controller之间的关联和区别 二、返回数据ResponseBody ResponseBody作用在类和方法的情况 三、返回HTML代码片段 响应中的Content-Type常见的取值: 四、返回JSON 五、设置状态码 六、设置Header 1、设置Content…

【C++】---STL容器适配器之底层deque浅析

【C】---STL容器适配器之底层deque浅析 一、deque的使用二、deque的原理1、deque的结构2、deque的底层结构(1)deque的底层空间(2)deque如何支持随机访问、deque迭代器 3、deque的优缺点(1)deque的优势&…

java基础之java容器-Collection,Map

java容器 java容器分类一. Collection1. List①. ArrayList② . LinkedList③ . Vector 2. Queue队列①. LinkedList②. PriorityQueue 3. Set集合①. HashSet②. TreeSet 二. Map1. HashMap2.TreeMap3. Hashtable java容器分类 java容器分为两大类,分别是Collecti…

代码随想录算法训练营第五十三天|1143.最长公共子序列 、 1035.不相交的线、 53. 最大子序和

1143 题目: 给定两个字符串 text1 和 text2,返回这两个字符串的最长 公共子序列 的长度。如果不存在 公共子序列 ,返回 0 。一个字符串的 子序列 是指这样一个新的字符串:它是由原字符串在不改变字符的相对顺序的情况下删除某些…

探索区块链世界:赋能创新,揭示区块链媒体发稿的影响力-世媒讯

区块链,这个由“区块”和“链”组成的概念,可能在您眼中充满神秘和复杂,但其实甚至无所不在,它正静悄悄地改变着我们日常生活的方方面面,从金融到媒体,从医疗到教育。 我们来揭开区块链的神秘面纱。区块链…

VRRP基础

1.基本概念 VRRP(Virtual Router Redundancy protocol,虚拟路由冗余协议) VRRP能够在不改变组网的情况下,将多台路由器虚拟成一个虚拟路由器,通过配置虚拟路由器的IP地址为默认网关,实现网关的备份。 VRRP协议版本为VRRPv2&…

SQLServer条件查询,排序

一.常用的运算符 &#xff1a;相等 !&#xff1a;不等 >&#xff1a;大于 <&#xff1a;小于 >&#xff1a;大于等于 <&#xff1a;小于等于 IS NULL&#xff1a;为空 IS NOT NULL&#xff1a;不为空 in&#xff1a;在其中 like&#xff1a;模糊查询 BE…

Java多线程基础

Java多线程 文章目录 Java多线程一、线程介绍及相关概念二、创建和启动线程2.1 Thread类的常用结构2.2 创建线程法1&#xff1a;继承Thread类&#xff08;分配线程对象&#xff09;2.3 创建线程法2&#xff1a;实现Runnable接口&#xff08;创建线程的目标对象&#xff09;2.4 …

揭示C++设计模式中的实现结构及应用——行为型设计模式

简介 行为型模式&#xff08;Behavioral Pattern&#xff09;是对在不同的对象之间划分责任和算法的抽象化。 行为型模式不仅仅关注类和对象的结构&#xff0c;而且重点关注它们之间的相互作用。 通过行为型模式&#xff0c;可以更加清晰地划分类与对象的职责&#xff0c;并…

易错知识点(学习过程中不断记录)

快捷键专区&#xff1a; 注释&#xff1a;ctrl/ ctrlshift/ 保存&#xff1a;ctrls 调试&#xff1a; 知识点专区&#xff1a; 1基本数据类型 基本数据类型有四类&#xff1a;整型、浮点型、字符型、布尔型&#xff08;Boolean&#xff09;&#xff0c; 分为八种&#xff…

JS判断元素是否在数组中

在JavaScript中&#xff0c;有多种方法可以用来判断一个元素是否存在于数组中。以下是其中的一些方法&#xff1a; 1. 使用 Array.prototype.includes() 方法 includes() 方法用于判断一个数组是否包含一个指定的值&#xff0c;根据情况&#xff0c;如果需要区分大小写&#…

AI图书推荐:《企业AI转型:如何在企业中部署ChatGPT?》

Jay R. Enterprise AI in the Cloud. A Practical Guide...ChatGPT Solutions &#xff08;《企业AI转型&#xff1a;如何在企业中部署ChatGPT&#xff1f;》&#xff09;是一本由Rabi Jay撰写、于2024年由John Wiley & Sons出版的书籍&#xff0c;主要为企业提供实施AI转型…

半导体厂FDC系统 的trace data知识

01、什么是FDC系统 在半导体行业中,FDC系统通常指的是"Failure Data Collection"(故障数据收集)系统。FDC系统的作用是收集、存储和分析在半导体制造过程中检测到的故障或不良品数据。以下是FDC系统的一些关键作用: 1. **故障检测**:FDC系统可以实时监测生产线…

python facebook business SDK campaign 广告复制方法

facebook广告复制调试了一天&#xff0c;特此记录&#xff0c;广告复制分为两个步骤&#xff1a; 第一步&#xff1a;使用campaign.create_copy()复制广告系列。 第二步&#xff1a;复制源广告广告集&#xff08;ad_set&#xff09;如果广告集需要修改&#xff0c;使用api_upd…

(六)Servlet教程——JSP与Servlet的关系

JSP与Servlet的关系 JSP&#xff08;Java Server Pages&#xff09;是继Servlet后Sun公司推出的新技术。JSP技术在传统的HTML文件中插入Java程序段和JSP标记&#xff0c;从而形成JSP文件&#xff08;*.jsp&#xff09;。用JSP开发的Web应用是跨平台的&#xff0c;既能在Window…

数据结构与算法目录

1、基本认识 算法的复杂度和稳定性 链接&#xff1a;算法的复杂度和稳定性_o(1) < o(logn) < o(n) < o(nlogn) < o(n^2) < o(n^3)-CSDN博客Java的链表的创建、插入、修改、删除、查询等 链接&#xff1a;Java的链表的创建、插入、修改、删除、查询等_java链表查…