使用LSTM(长短期记忆网络)模型预测股票价格的实例分析

一:LSTM与RNN的区别

LSTM(Long Short-Term Memory)是一种特殊的循环神经网络(RNN)架构。LSTM是为了解决传统RNN在处理长序列数据时遇到的梯度消失或梯度爆炸问题而设计的。

在传统的RNN中,信息通过隐藏状态在时间步之间传递,但由于权重的重复应用,随着时间的推移,梯度可能会迅速减小或增大,导致网络难以学习长期依赖关系。LSTM通过引入了一种称为“”(gates)的机制来解决这个问题,这些门可以控制信息的流动,从而允许网络在长序列中有效地保持和传递信息。

LSTM的四个主要组成部分是:

1: 细胞状态(Cell State):一个流动的载体,它携带有关观察到的输入序列的信息。细胞状态可以跨越时间步传递信息。

2: 遗忘门(Forget Gate):决定哪些信息应该从细胞状态中丢弃。遗忘门会读取当前的输入和上一时间步的隐藏状态,并输出一个0到1之间的数值,表示保留信息的程度。

3: 输入门(Input Gate):决定哪些新信息将被存储到细胞状态中。输入门由两部分组成:一个sigmoid层决定哪些值将要更新,和一个tanh层创建一个新的候选值向量,它们将会被加入到状态中。

4: 输出门(Output Gate):决定下一个隐藏状态的值。它读取当前的细胞状态和输入,并通过一个sigmoid层和一个tanh层来计算输出值。

LSTM的这些门通过使用sigmoid激活函数来决定信息的保留或丢弃,而tanh激活函数则用来创建新的候选值或输出值。

由于其设计上的优势,LSTM能够捕捉长期依赖关系,因此在处理复杂序列数据时非常有效。

二:使用LSTM预测股票价格

一个典型的LSTM实例可以是股票价格预测。在这个例子中,我们可以使用LSTM模型来学习股票价格的时间序列数据,并尝试预测未来的价格走势。

为了实现这个实例,我们需要完成以下几个步骤:

  1. 数据收集:获取股票价格的历史数据。
  2. 数据预处理
    • 数据清洗:去除异常值。
    • 数据归一化:使用MinMaxScaler将数据缩放到0到1之间。
  3. 构建LSTM模型
    • 设计网络结构:确定LSTM层的数量和每层的神经元数量。
    • 添加全连接层:用于输出预测结果。
    • 编译模型:选择优化器和损失函数。
  4. 训练模型:使用历史数据训练模型。
  5. 预测和评估:使用测试数据评估模型的性能。

接下来将演示一个使用Keras库中的LSTM(长短期记忆网络)模型进行股票价格预测的简单示例。

导入必要的库

import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from keras.models import Sequential
from keras.layers import LSTM, Dense
plt.rcParams['font.sans-serif'] = ['SimHei'] 
plt.rcParams['axes.unicode_minus'] = False 
  • numpy:用于数值计算。
  • matplotlib.pyplot:用于绘制图表。
  • MinMaxScaler:来自sklearn.preprocessing,用于将数据缩放到指定的范围(这里是0到1)。
  • Sequential:来自keras.models,用于创建神经网络模型。
  • LSTMDense:来自keras.layers,分别是LSTM层和全连接层。
  • plt.rcParams:设置matplotlib绘图参数,确保中文字体可以正确显示,并处理坐标轴的负号。

生成假设的股票价格数据集

prices = np.random.rand(100, 1).cumsum()
  • 使用numpy生成一个100行1列的随机数组,并将其累加,模拟股票价格走势。

数据预处理

prices_reshaped = prices.reshape(-1, 1)
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_prices = scaler.fit_transform(prices_reshaped)
  • 将一维的prices数组转换为二维,,因为MinMaxScaler需要二维输入。例如,如果 prices 是一个包含100个元素的一维数组,那么 prices_reshaped 将会是一个形状为 (100, 1) 的二维数组。
  • 创建一个MinMaxScaler对象,并将其用于缩放数据到0和1之间。

创建数据集

X, Y = [], []
for i in range(60, len(scaled_prices)):X.append(scaled_prices[i-60:i, 0])Y.append(scaled_prices[i, 0])
X, Y = np.array(X), np.array(Y)
  • 对于数据集中的每个点,使用过去60个时间点的数据作为输入X,并使用第61个时间点的数据作为输出Y
  • 遍历归一化后的股票价格数据:

    • for i in range(60, len(scaled_prices))::这个循环从索引60开始,直到scaled_prices数组的末尾。索引60意味着每个样本包含60个时间步长的数据。
  • 构建输入数据X:

    • X.append(scaled_prices[i-60:i, 0]):对于每个索引i,从scaled_prices中取出从i-60i-1的60个数据点,这些数据点将作为模型的输入。这里[:, 0]确保只选择一列数据,因为scaled_prices是一个二维数组。
  • 构建输出数据Y:

    • Y.append(scaled_prices[i, 0]):对于每个索引i,从scaled_prices中取出索引为i的数据点,这个数据点将作为模型的输出,即第61个时间步长的股票价格。

经过这个循环,X将包含40个的60个时间步长的数据,而Y将包含对应时间步长之后的股票价格。这样的数据结构非常适合用于训练时间序列预测模型LSTM,其中模型需要学习如何根据过去60个时间步长的数据来预测下一个时间步长的价格。

重构输入数据

X = np.reshape(X, (X.shape[0], X.shape[1], 1))
  • X: 这是一个NumPy数组,包含了模型的输入数据。

  • np.reshape(): NumPy中的函数,用于在不改变数据内容的情况下改变数组的形状。

  • (X.shape[0], X.shape[1], 1): 这是重塑操作的目标形状。

    • X.shape[0]: 表示X数组的第一个维度40,即样本的数量
    • X.shape[1]: 表示X数组的第二个维度60,即每个样本的特征数量
    • 1: 表示为每个样本增加一个维度,使其成为三维数组。

在LSTM网络中,期望的输入数据格式通常是三维的,其形状为 [样本数量, 时间步长, 特征数量]在这个例子中,每个样本是一个时间序列,包含了过去60个时间点的数据,而每个时间点只有一个特征(股票价格)。通过这行代码,X数组被重塑为以下形状:

  • [样本数量(40), 时间步长(60), 特征数量(1)]

这种形状是LSTM层能够正确处理的数据格式。

构建LSTM模型

model = Sequential()
model.add(LSTM(units=50, return_sequences=True, input_shape=(X.shape[1], 1)))
model.add(LSTM(units=50))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mean_squared_error')
  • 创建一个序贯模型。
  • 添加两个LSTM层,第一个LSTM层返回序列,第二个不返回。
  • 添加一个全连接层,输出一个值。
  • 编译模型,使用Adam优化器和均方误差损失函数。

训练模型

model.fit(X, Y, epochs=1, batch_size=1, verbose=2)
  • 使用数据XY训练模型,设置一个周期,批量大小为1。
  • verbose=2:输出每个epoch的进度以及每个epoch结束时的一些统计信息(如损失值)。

预测

predicted_prices = model.predict(X)
predicted_prices = scaler.inverse_transform(predicted_prices)
  • 使用模型进行预测。
  • 将预测结果从缩放后的数据转换回原始数据范围。

可视化结果

plt.figure(figsize=(10, 6))
plt.plot(prices, color='blue', label='实际价格')
plt.plot(np.arange(60, 100), predicted_prices, color='red', label='预测价格')
plt.title('股票价格预测')
plt.xlabel('时间')
plt.ylabel('价格')
plt.legend()
plt.show()
  • 绘制实际价格和预测价格的图表,蓝色表示实际价格,红色表示预测价格。可视化图表如下:

可以看出建立的LSTM模型的预测效果较好。

三:每日股票行情数据

想要探索更多元化的数据分析视角,可以关注之前发布的相关内容。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/53690.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux-用户与权限管理-组管理

在 Linux 系统中,用户、组与权限管理是保障系统安全的重要机制。用户和组的管理不仅涉及对系统资源的访问控制,还用于权限的分配和共享。组管理在 Linux 中尤其重要,它能够帮助管理员组织用户并为不同的组分配特定权限,从而控制用…

使用虚拟信用卡WildCard轻松订阅POE:全面解析平台功能与订阅方式

POE(Platform of Engagement)是一个由Quora推出的人工智能聊天平台,汇集了多个强大的AI聊天机器人,如GPT-4、Claude、Sage等。POE提供了一个简洁、统一的界面,让用户能够便捷地与不同的AI聊天模型进行互动。这种平台的…

基于SpringBoot的社团管理系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 基于JavaSpringBootVueMySQL的社团管理系统【附源码文档】、…

HashTable哈希表

概念 散列表(Hash Table),又称哈希表。是一种数据结构,特点是:数据元素的关键字与其存储地址直接相关 在顺序结构以及树型结构中,数据元素的关键字与其存储位置没有对应的关系,因此在查找一个元素时,必须要经过关键码…

Windows 平台安装 Nacos 2.x

环境准备 64 位操作系统,Windows 10 / Linux Centos 7JDK 1.8 安装包下载 安装包下载官方地址:https://github.com/alibaba/nacos/releases 启动 将安装包解压到安装的目录下,改名为 nacos-2.0.4。然后进行到 bin 目录下,打开…

软件测试服务公司出具第三方软件测试报告流程和周期简析

随着信息技术的飞速发展,软件已成为各行各业不可或缺的重要工具。然而,软件的质量直接影响到企业的效率和用户体验,因此,软件测试服务的重要性日益凸显。软件测试服务公司,顾名思义,就是专门提供专业的软件…

fpga系列 HDL:全连接层的浮点数加法器FADD实现

32 位 float 型的二进制存储 在fpga系列 HDL:全连接层的浮点数乘法器FM实现中已经提到过32 位 float 型的二进制存储结构。 32 位 f l o a t 型数 V ( − 1 ) S ∗ M ∗ 2 E 32 位 float 型数V(-1)^S*M*2^E 32位float型数V(−1)S∗M∗2E Layer 1 22 0 1 0 1 0 0 0 0 0 0 0 0…

工业机器人9公里远距离图传模块,无人机低延迟高清视界,跨过距离限制

在科技日新月异的今天,无线通信技术正以未有的速度发展,其中,图传模块作为连接现实与数字世界的桥梁,正逐步展现出其巨大的潜力和应用价值。今天,我们将聚焦一款引人注目的产品——飞睿智能9公里远距离图传模块&#x…

LiveKit的agent介绍

概念 LiveKit核心概念: Room(房间)Participant(参会人)Track(信息流追踪) Agent 架构图 ​ 订阅信息流 ​ agent交互流程 客户端操作 加入房间 房间创建方式 手动 赋予用户创建房间的…

JavaScript - Api学习 Day03 (日期对象、节点操作、两种定时器、本地存储)

文章目录 一、日期对象1.1 实例化1.2 日期对象方法 二、节点操作2.1 父子兄弟节点1. 父节点查找2. 子节点查找3. 兄弟关系查找 2.2 增删节点1. 创建节点 - createElement2. 添加节点2.1 appendChild() 方法2.2 insertBefore() 方法2.3. 克隆节点 - cloneNode 3. 删除节点3.1 re…

[数据集][目标检测]抽烟检测数据集VOC+YOLO格式22559张2类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):22559 标注数量(xml文件个数):22559 标注数量(txt文件个数):22559 标…

408算法题leetcode--第一天

参考 参考题单 1523. 在区间范围内统计奇数数目 1523. 在区间范围内统计奇数数目思路:数据量有 1 0 9 10^9 109,所以遍历求解会超时;而(low, high)区间中的奇数 (0, high) - (0, low - 1)的奇数时间和空间复杂度:O(1) class …

flink中slotSharingGroup() 的详解

在 Apache Flink 中,slotSharingGroup() 是一个用于控制算子(operator)之间资源共享的机制。它允许多个算子共享相同的 slot(即资源容器)。Slot 是 Flink 中的资源单位,slot 共享可以提高资源利用率&#x…

VisualStudio环境搭建C++

Visual Studio环境搭建 说明 C程序编写中,经常需要链接头文件(.h/.hpp)和源文件(.c/.cpp)。这样的好处是:控制主文件的篇幅,让代码架构更加清晰。一般来说头文件里放的是类的申明,函数的申明,全局变量的定义等等。源…

echarts图 图例跑上面去了 不在右边

legend 中缺少 "right": "1%", "top": "center",

PyTorch----模型运维与实战

一、PyTorch是什么 PyTorch 由Facebook开源的神经网络框架,专门针对 GPU 加速的深度神经网络(DNN)编程。 二、PyTorch安装 首先确保你已经安装了GPU环境,即Anaconda、CUDA和CUDNN 随后进入Pytorch官网​​​​​​PyTorch 官…

【机器学习】高斯网络的基本概念和应用领域以及在python中的实例

引言 高斯网络(Gaussian Network)通常指的是一个概率图模型,其中所有的随机变量(或节点)都遵循高斯分布 文章目录 引言一、高斯网络(Gaussian Network)1.1 高斯过程(Gaussian Proces…

细说STM32F407通用定时器的基础知识

目录 一、通用定时器功能概述 二、细说2通道定时器的功能 1.时钟信号和触发控制器 2.时基单元工作原理 (1)计数寄存器(CNT) (2)预分频寄存器(PSC) (3)自动重载寄存器(ARR) (4)时基单元的控制位 3.捕获/比较通道 三、生成PWM波 1.生成PWM波的原理 2.与生成PWM波相关的HA…

2023下半年软考网络规划

【考情分析】2023下半年软考网络规划设计师机考考情分析-真题解析公开课视频!_哔哩哔哩_bilibili2023年11月软考网络规划设计师案例分析解析与考后复盘_哔哩哔哩_bilibili全网首发!2023年下半年软考【高级】网规真题试卷--案例分析(部分回忆版…

表观遗传系列1:DNA 甲基化以及组蛋白修饰

1. 表观遗传 表观遗传信息很多为化学修饰,包括 DNA 甲基化以及组蛋白修饰,即DNA或蛋白可以通过化学修饰添加附加信息。 DNA位于染色质(可视为微环境)中,并不是裸露的,因此DNA分子研究需要跟所处环境结合起…