Transformer 与 LSTM 在时序回归中的实践与优化


🧠 深度学习混合模型:Transformer 与 LSTM 在时序回归中的实践与优化

在处理多特征输入、多目标输出的时序回归任务时,结合 Transformer 和 LSTM 的混合模型已成为一种有效的解决方案。Transformer 擅长捕捉长距离依赖关系,而 LSTM 在处理序列数据时表现出色。通过将两者结合,可以充分发挥各自的优势,提高模型的预测性能。


📊 数据生成与预处理

首先,我们生成一个包含多个特征的时序数据集,并进行必要的预处理。

import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split# 设置随机种子以确保结果可复现
np.random.seed(42)# 生成时间序列数据
n_samples = 1000
time_steps = 10
n_features = 5
X = np.random.rand(n_samples, time_steps, n_features)
y = np.random.rand(n_samples, 1)  # 假设我们有一个目标变量# 数据归一化
scaler_X = MinMaxScaler()
scaler_y = MinMaxScaler()X_scaled = X.reshape(-1, n_features)
X_scaled = scaler_X.fit_transform(X_scaled)
X_scaled = X_scaled.reshape(n_samples, time_steps, n_features)y_scaled = scaler_y.fit_transform(y)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y_scaled, test_size=0.2, random_state=42)

🧩 模型架构设计

我们设计一个结合 Transformer 和 LSTM 的混合模型架构。

import tensorflow as tf
from tensorflow.keras import layers, modelsdef build_transformer_lstm_model(input_shape, lstm_units=64, transformer_units=64, num_heads=4, num_layers=2, dropout_rate=0.1):inputs = layers.Input(shape=input_shape)# LSTM 层x = layers.LSTM(lstm_units, return_sequences=True)(inputs)x = layers.Dropout(dropout_rate)(x)# Transformer 层for _ in range(num_layers):attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=transformer_units)(x, x)x = layers.Add()([x, attention])x = layers.LayerNormalization()(x)x = layers.Dropout(dropout_rate)(x)# 输出层x = layers.GlobalAveragePooling1D()(x)x = layers.Dense(64, activation='relu')(x)x = layers.Dropout(dropout_rate)(x)outputs = layers.Dense(1)(x)model = models.Model(inputs, outputs)return model# 构建模型
input_shape = (X_train.shape[1], X_train.shape[2])
model = build_transformer_lstm_model(input_shape)
model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mae'])

🏋️‍♂️ 模型训练与评估

from tensorflow.keras.callbacks import EarlyStopping# 定义早停机制
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)# 训练模型
history = model.fit(X_train, y_train, epochs=50, batch_size=32, validation_data=(X_test, y_test), callbacks=[early_stopping])# 评估模型
loss, mae = model.evaluate(X_test, y_test)
print(f"Test Loss: {loss}, Test MAE: {mae}")

🔧 超参数调优

我们使用 Keras Tuner 进行超参数调优。

import keras_tuner as ktdef model_builder(hp):model = build_transformer_lstm_model(input_shape)model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=hp.Float('learning_rate', min_value=1e-5, max_value=1e-2, sampling='log')),loss='mean_squared_error',metrics=['mae'])return model# 定义调优器
tuner = kt.Hyperband(model_builder,objective='val_loss',max_epochs=10,factor=3,directory='hyperband',project_name='transformer_lstm'
)# 执行超参数调优
tuner.search(X_train, y_train, epochs=50, validation_data=(X_test, y_test), callbacks=[early_stopping])# 获取最佳超参数
best_hps = tuner.get_best_hyperparameters()[0]
print(f"Best learning rate: {best_hps.get('learning_rate')}")

📈 结果可视化

import matplotlib.pyplot as plt# 绘制训练过程中的损失和 MAE
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Val Loss')
plt.title('Loss Over Epochs')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(history.history['mae'], label='Train MAE')
plt.plot(history.history['val_mae'], label='Val MAE')
plt.title('MAE Over Epochs')
plt.legend()plt.tight_layout()
plt.show()

📝 总结

通过结合 Transformer 和 LSTM 的混合模型,可以实现更好地捕捉时序数据中的长期依赖关系和复杂模式。本章所讲述流程展示了从数据生成、模型设计到训练和评估的完整过程,并引入了早停机制和超参数调优,以提高模型的性能和稳定性。


 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/78760.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QT —— 信号和槽(带参数的信号和槽函数)

QT —— 信号和槽(带参数的信号和槽函数) 带参的信号和槽函数信号参数个数和槽函数参数个数1. 参数匹配规则2. 实际代码示例✅ 合法连接(槽参数 ≤ 信号参数)❌ 非法连接(槽参数 > 信号参数) 3. 特殊处理…

设计模式简述(十七)备忘录模式

备忘录模式 描述组件使用 描述 备忘录模式用于将对象的状态进行保存为备忘录,以便在需要时可以从备忘录会对象状态;其核心点在于备忘录对象及其管理者是独立于原有对象之外的。 常用于需要回退、撤销功能的场景。 组件 原有对象(包含自身…

标签语句分析

return userList.stream().filter(user -> {String tagsStr user.getTags(); 使用 Stream API 来过滤 userList 中的用户 解析 tagsStr 并根据标签进行过滤 假设 tagsStr 是一个 JSON 格式的字符串,存储了一个标签集合。你希望过滤出包含所有指定标签的用户。…

【应用密码学】实验四 公钥密码1——数学基础

一、实验要求与目的 学习快速模幂运算、扩展欧几里得、中国剩余定理的算法思想以及代码实现。 二、实验内容与步骤记录(只记录关键步骤与结果,可截图,但注意排版与图片大小) 1.快速模幂运算的设计思路 快速模幂运算的核心思想…

WebSocket与Socket、TCP、HTTP的关系及区别

1.什么是WebSocket及原理 WebSocket是HTML5中新协议、新API。 WebSocket从满足基于Web的日益增长的实时通信需求应运而生,解决了客户端发起多个Http请求到服务器资源浏览器必须要在经过长时间的轮询问题,实现里多路复用,是全双工、双向、单套…

基于C++的IOT网关和平台4:github项目ctGateway交互协议

初级代码游戏的专栏介绍与文章目录-CSDN博客 我的github:codetoys,所有代码都将会位于ctfc库中。已经放入库中我会指出在库中的位置。 这些代码大部分以Linux为目标但部分代码是纯C++的,可以在任何平台上使用。 源码指引:github源码指引_初级代码游戏的博客-CSDN博客 系…

【PPT制作利器】DeepSeek + Kimi生成一个初始的PPT文件

如何基于DeepSeek Kimi进行PPT制作 步骤: Step1:基于DeepSeek生成文本,提问 Step2基于生成的文本,用Kimi中PPT助手一键生成PPT 进行PPT渲染-自动渲染 可选择更改模版 生成PPT在桌面 介绍的比较详细,就是这个PPT模版…

拷贝多个Excel单元格区域为图片并粘贴到Word

Excel工作表Sheet1中有两个报表,相应单元格区域分别定义名称为Report1和Report2,如下图所示。 现在需要将图片拷贝图片粘贴到新建的Word文档中。 示例代码如下。 Sub Demo()Dim oWordApp As ObjectDim ws As Worksheet: Set ws ThisWorkbook.Sheets(&…

Spring是如何传播事务的?什么是事务传播行为

Spring是如何传播事务的? Spring框架通过声明式事务管理来传播事务,主要依赖于AOP(面向切面编程)和事务拦截器来实现。Spring的事务传播机制是基于Java Transaction API (JTA) 或者本地资源管理器(如Hibernate、JDBC等…

Python-pandas-操作Excel文件(读取数据/写入数据)及Excel表格列名操作详细分享

Python-pandas-操作Excel文件(读取数据/写入数据) 提示:帮帮志会陆续更新非常多的IT技术知识,希望分享的内容对您有用。本章分享的是pandas的使用语法。前后每一小节的内容是存在的有:学习and理解的关联性。【帮帮志系列文章】:每…

PHP分页显示数据,在phpMyadmin中添加数据

<?php $conmysqli_connect(localhost,root,,stu); mysqli_query($con,"set names utf8"); //设置字符集为utf8 $sql"select * from teacher"; $resultmysqli_query($con,$sql); $countmysqli_num_rows($result); //记录总条数$count。 $pagesize10;//每…

智能参谋部系统架构和业务场景功能实现

将以一个基于微服务和云原生理念、深度集成人工智能组件、强调实时性与韧性的系统架构为基础,详细阐述如何落地“智能参谋部”的各项能力。这不是一个简单的软件堆叠,而是一个有机整合了数据、知识、模型、流程与人员的复杂体系。 系统愿景:“智能参谋部”——基于AI赋能的…

企业级RAG架构设计:从FAISS索引到HyDE优化的全链路拆解,金融/医疗领域RAG落地案例与避坑指南(附架构图)

本文较长&#xff0c;纯干货&#xff0c;建议点赞收藏&#xff0c;以免遗失。更多AI大模型应用开发学习内容&#xff0c;尽在聚客AI学院。 一. RAG技术概述 1.1 什么是RAG&#xff1f; RAG&#xff08;Retrieval-Augmented Generation&#xff0c;检索增强生成&#xff09; 是…

Spring Boot Validation实战详解:从入门到自定义规则

目录 一、Spring Boot Validation简介 1.1 什么是spring-boot-starter-validation&#xff1f; 1.2 核心优势 二、快速集成与配置 2.1 添加依赖 2.2 基础配置 三、核心注解详解 3.1 常用校验注解 3.2 嵌套对象校验 四、实战开发步骤 4.1 DTO类定义校验规则 4.2 Cont…

理清缓存穿透、缓存击穿、缓存雪崩、缓存不一致的本质与解决方案

在构建高性能系统中&#xff0c;缓存&#xff08;如Redis&#xff09; 是不可或缺的关键组件&#xff0c;它大幅减轻了数据库压力、加快了响应速度。然而&#xff0c;在高并发环境下&#xff0c;缓存也可能带来一系列棘手的问题&#xff0c;如&#xff1a;缓存穿透、缓存击穿、…

PyTorch_构建线性回归

使用 PyTorch 的 API 来手动构建一个线性回归的假设函数&#xff0c;数据加载器&#xff0c;损失函数&#xff0c;优化方法&#xff0c;绘制训练过程中的损失变化。 数据构建 import torch from sklearn.datasets import make_regression import matplotlib.pyplot as plt i…

005-nlohmann/json 基础方法-C++开源库108杰

《二、基础方法》&#xff1a;节点访问、值获取、显式 vs 隐式、异常处理、迭代器、类型检测、异常处理……一节课搞定C处理JSON数据85%的需求…… JSON 字段的简单类型包括&#xff1a;number、boolean、string 和 null&#xff08;即空值&#xff09;&#xff1b;复杂类型则有…

HarmonyOS 5.0 分布式数据协同与跨设备同步​​

大家好&#xff0c;我是 V 哥。 使用 Mate 70有一段时间了&#xff0c;系统的丝滑使用起来那是爽得不要不要的&#xff0c;随着越来越多的应用适配&#xff0c;目前使用起来已经和4.3的兼容版本功能差异无碍了&#xff0c;还有些纯血鸿蒙独特的能力很是好用&#xff0c;比如&am…

Linux云计算训练营笔记day02(Linux、计算机网络、进制)

Linux 是一个操作系统 Linux版本 RedHat Rocky Linux CentOS7 Linux Ubuntu Linux Debian Linux Deepin Linux 登录用户 管理员 root a 普通用户 nsd a 打开终端 放大: ctrl shift 缩小: ctrl - 命令行提示符 [rootlocalhost ~]# ~ 家目录 /root 当前登录的用户…

macOS 安装了Docker Desktop版终端docker 命令没办法使用

macOS 安装了Docker Desktop版终端docker 命令没办法使用 1、检查Docker Desktop能否正常运行。 确保Docker Desktop能正常运行。 2、检查环境变量是否添加 1、添加环境变量 如果环境变量中没有包含Docker的路径&#xff0c;你可以手动添加。首先&#xff0c;找到Docker的…