【Python时序预测系列】建立CNN-LSTM-Transformer融合模型实现多变量时序预测(案例+源码)

这是我的第449篇原创文章。

一、引言

CNN(卷积)擅长抓“局部模式”,LSTM(长短时记忆网络)擅长记住“时间上的因果和长期依赖”,Transformer(自注意力)擅长把序列里任意两个时刻相互比较、找全局相关性,而且能并行处理。

融合方式:串联
CNN → LSTM → Transformer。先提取局部特征,再用 LSTM 建长期状态,最后用 Transformer 做全局交互。

下面通过一个具体的案例,融合CNN + LSTM + Transformer进行多变量输入单变量输出单步时间序列预测,包括模型构建、训练、预测等等。

二、实现过程

2.1 数据加载

核心代码:

df = pd.read_csv('data.csv', parse_dates=["Date"], index_col=[0]) df = pd.DataFrame(df)

结果:原始数据集总数5203

2.2 数据划分

核心代码:

test_split=round(len(df)*0.20) df_for_training=df[:-test_split] df_for_testing=df[-test_split:]

训练集:4162,测试集:1041

2.3 数据归一化

核心代码:

scaler = MinMaxScaler(feature_range=(0,1)) df_for_training_scaled = scaler.fit_transform(df_for_training) df_for_testing_scaled=scaler.transform(df_for_testing)

2.4 构造时序数据集

核心代码:

train_dataset = TimeSeriesDataset(df_for_training_scaled, seq_len=30, pred_len=1) test_dataset = TimeSeriesDataset(df_for_testing_scaled, seq_len=30, pred_len=1) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

时序训练集和测试集数组形状:

2.5 CNN_LSTM_Transformer模型

核心代码:

class CNN_LSTM_Transformer(nn.Module): def __init__(self, input_dim=5, cnn_channels=16, lstm_hidden=32, transformer_dim=32, transformer_heads=4, transformer_layers=1, pred_len=1): super().__init__() # CNN self.cnn = nn.Conv1d(in_channels=input_dim, out_channels=cnn_channels, kernel_size=3, padding=1) self.cnn_relu = nn.ReLU() # LSTM self.lstm = nn.LSTM(input_size=cnn_channels, hidden_size=lstm_hidden, batch_first=True) # Transformer Encoder encoder_layer = nn.TransformerEncoderLayer(d_model=transformer_dim, nhead=transformer_heads, batch_first=True) self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=transformer_layers) # Projection layers self.proj_lstm = nn.Linear(lstm_hidden, transformer_dim) self.pred_len = pred_len self.fc_out = nn.Linear(transformer_dim, pred_len) def forward(self, x): # x: [batch, seq_len, 1] batch_size, seq_len, _ = x.shape # CNN expects [batch, channels, seq_len] cnn_out = self.cnn_relu(self.cnn(x.transpose(1,2))) # [B, C, T] cnn_out = cnn_out.transpose(1,2) # [B, T, C] # LSTM lstm_out, _ = self.lstm(cnn_out) # [B, T, hidden] lstm_proj = self.proj_lstm(lstm_out) # [B, T, transformer_dim] # Transformer trans_out = self.transformer(lstm_proj) # [B, T, transformer_dim] # 取最后时间步输出预测 out = self.fc_out(trans_out[:, -1, :]) # [B, pred_len] return out.unsqueeze(-1) # [B, pred_len, 1]

2.6 训练模型

核心代码:

def train_model(model, dataloader, num_epochs=50, learning_rate=1e-3, device='cpu'): optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) criterion = nn.MSELoss() model.train() loss_history = [] for epoch in range(num_epochs): epoch_losses = [] for batch_data, batch_targets in dataloader: batch_data = batch_data.to(device) batch_targets = batch_targets.to(device) optimizer.zero_grad() outputs = model(batch_data) loss = criterion(outputs, batch_targets) loss.backward() optimizer.step() epoch_losses.append(loss.item()) avg_loss = np.mean(epoch_losses) loss_history.append(avg_loss) if (epoch + 1) % 10 == 0: print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}") return loss_history

结果:

2.7 模型测试集评估

核心代码:

def evaluate_model(model, dataloader, device='cpu'): model.eval() preds = [] trues = [] with torch.no_grad(): for batch_data, batch_targets in dataloader: batch_data = batch_data.to(device) outputs = model(batch_data) preds.append(outputs.cpu().numpy()) trues.append(batch_targets.cpu().numpy()) preds = np.concatenate(preds, axis=0).squeeze() trues = np.concatenate(trues, axis=0).squeeze() return preds, trues

2.8 结果可视化

核心代码:

def visualize_results(loss_history, preds, trues): sns.set(font_scale=1.2) plt.rc('font', family=['Times New Roman', 'Simsun'], size=12) # 图 1:训练损失曲线 # 模型在训练过程中损失的下降情况,说明模型不断优化拟合数据。 plt.plot(loss_history, marker='o', color='dodgerblue', linestyle='-', linewidth=2) plt.title("Training Loss Curve") plt.xlabel("Epoch") plt.ylabel("MSE Loss") plt.tight_layout() plt.savefig('output_image1.png', dpi=300, format='png') plt.show() # 图 2:真实值与预测值对比曲线 # 对比曲线直观展示模型预测趋势与真实数据的匹配情况,越接近表示模型效果越好。 plt.plot(trues, label="True Values", color='limegreen') plt.plot(preds, label="Predicted Values", color='crimson') plt.title("True vs. Predicted Values") plt.xlabel("Sample Index") plt.ylabel("Trend Value") plt.legend() plt.tight_layout() plt.savefig('output_image2.png', dpi=300, format='png') plt.show()

图 1:训练损失曲线

图 2:真实值与预测值对比曲线

2.9 计算误差

核心代码:

testScore1 = math.sqrt(mean_squared_error(preds_test, trues_test)) print('Test Score: %.2f RMSE' % (testScore1)) testScore2 = mean_absolute_error(preds_test, trues_test) print('Test Score: %.2f MAE' % (testScore2)) testScore3 = r2_score(preds_test, trues_test) print('Test Score: %.2f R2' % (testScore3)) testScore4 = mean_absolute_percentage_error(preds_test, trues_test) print('Test Score: %.2f MAPE' % (testScore4))

结果:

作者简介:

读研期间发表6篇SCI数据挖掘相关论文,现在某研究院从事数据算法相关科研工作,结合自身科研实践经历不定期分享关于Python、机器学习、深度学习、人工智能系列基础知识与应用案例。致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。需要数据集和源码的小伙伴可以关注底部公众号添加作者微信。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/1219872.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机Java毕设实战-基于springboot的面向企业用户的复合型活动基地活动中心线上管理系统会议室预订系统【完整源码+LW+部署说明+演示视频,全bao一条龙等】

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

计算机Java毕设实战-基于协同过滤推荐算法的在线教育平台基于springboot+协同过滤课程推荐的线上安全教育平台【完整源码+LW+部署说明+演示视频,全bao一条龙等】

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

计算机Java毕设实战-基于SpringBoot+vue的本地生活攻略与美食发现平台基于web的美食探店平台【完整源码+LW+部署说明+演示视频,全bao一条龙等】

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

DBeaver连接sql server数据库时,提示驱动版本不合适

当你使用DBeaver,连接不上sql server数据库时,可以查看本篇文章,也许可以帮助你解决问题 目录 一.先看报错信息 二.尝试解决方案:降低驱动版本 1.打开“编辑连接”页面 2.点击“编辑驱动设置” —>"库"—>当前…

企业网站制作公司对比:2026年十大靠谱网站建设公司盘点

在数字经济深度渗透的2026年,企业官网已从单纯的“线上名片”进化为承载品牌价值传递、用户精准转化与全链路数据沉淀的核心数字化阵地。面对AI智能交互、3D沉浸式体验、移动端优先等技术趋势,选择适配的网站建设服务商成为企业数字化转型的关键决策。本…

什么是网关支付?网上付款的 “安全中转站”

一、 网关支付的定义网关支付(又称网络支付)是一种在线支付方式,指消费者通过商家平台,借助支付网关直接完成向商家付款的交易模式。二、 核心角色与作用支付网关是核心枢纽,连接消费者、商家与银行 / 清算机构&#x…

[特殊字符]天津别墅装修|选正规材料商,筑牢家的质感基底

🏡天津别墅装修|选正规材料商,筑牢家的质感基底天津作为北方高端住宅聚集地,别墅装修的品质需求持续升级,但某行业报告显示,近60%的别墅业主曾因材料商不正规遭遇环保超标、质感降级、售后推诿等问题。对于…

产品种类繁多,工艺路线录入太繁琐,用好APS排产的工艺路线批量导入,生产效率飙升

现代制造业中,工艺路线 定义了产品从原材料到成品的完整加工路径,当产品种类繁多时,逐个手动录入工艺路线效率就显得低下,并且容易出错。在APS排产系统里,工艺路线模块为产品生产的每个步骤流程搭建起了清晰明确的路径…

重构 CPython 的遐想:三个改变 Python 未来的关键设计

重构 CPython 的遐想:三个改变 Python 未来的关键设计 引言:站在巨人肩膀上的思考 作为一名与 Python 相伴十余年的开发者,我见证了它从小众脚本语言成长为全球最受欢迎的编程语言之一。从 Web 后端到数据科学,从自动化运维到人工智能,Python 的身影无处不在。根据 TIOB…

无人机滑模控制模块详解

1.什么是滑模控制?滑模控制是一种变结构控制,属于非线性鲁棒控制方法。其核心思想是:1.设计一个滑模面(SlidingSurface):一个在状态空间中的超平面或流形。当系统的状态轨迹被约束在这个面上时,…

GIL 的囚笼与自由:Python 多线程性能之谜完全解析

GIL 的囚笼与自由:Python 多线程性能之谜完全解析 引言:一个让人困惑的实验 2019年,我在优化一个数据处理系统时遇到了职业生涯中最反直觉的现象。我将单线程改为4线程处理,期待性能提升4倍,结果却发现多线程版本比单线程慢了20%。 这不是我的代码问题,而是触碰到了 P…

17.设置笔记本电脑不休眠

有时候,需要让笔记本电脑一直开着,不休眠。 记录下设置方法: 快捷键win i键。 搜索框:输入电源

国产知识协作平台如何重塑企业数字化转型路径

国产知识协作平台如何重塑企业数字化转型路径 在数字化转型的浪潮中,知识协作平台已从简单的文档存储工具进化为支撑企业核心业务的关键基础设施。随着国产化替代进程加速,Gitee Wiki、CODING Wiki等本土产品正凭借其独特的优势,在研发效能提…

救命神器10个AI论文网站,助你搞定研究生毕业论文!

救命神器10个AI论文网站,助你搞定研究生毕业论文! AI 工具如何成为论文写作的得力助手 在研究生阶段,撰写毕业论文是一项既重要又充满挑战的任务。面对繁重的研究内容和严格的格式要求,许多同学都感到无从下手。而随着 AI 技术的…

收藏!RAG技术全面解析:从基础到智能化的演进之路

本文系统梳理了检索增强生成(RAG)架构的演进历程,从Naive、Advanced、Modular到Agentic四代架构的发展。文章详细分析了各代架构的核心特点与技术突破,揭示了模块化设计、智能体协同等创新如何解决知识更新、语义对齐和复杂任务处理等关键问题&#xff0…

鼎捷ERP和MES系统集成方案详解,如何实现现有软件无缝对接?

某汽车零部件制造商通过上述方法将ERP与MES系统对接后,生产数据流转效率提升40%;某零售电商平台整合订单与仓储系统,使订单处理时长缩短至500毫秒内,错误率下降至0.01%。为实现新系统与既有架构的平滑集成,应基于业务流…

高标准康复理疗实训室,夯实职业技能基础

一、 建设标准是康复理疗实训室的核心前提建设标准决定了康复理疗实训室的起点与效能。高标准康复理疗实训室在空间规划上,严格区分评估、训练及治疗区域。高标准康复理疗实训室在设备选型上,对接行业主流技术与临床规范。这种对硬件与环境的精细要求&am…

国内iPaaS平台前十排行榜,鼎捷ERP和MES系统集成深度测评

企业数字化转型规模化落地阶段,iPaaS(集成平台即服务)已从可选工具升级为数字化基座,是打通数据孤岛、实现业务自动化的核心。ERP(企业资源计划系统)统筹全流程资源,MES(制造执行系统…

ABC422F题解

先考虑第一个限制,设每一行的 \(k\) 的集合为 \(\{k_1,k_2,\dots,k_n\}\),那么对于第 \((i,k_i)\) 的方格,如果它是白色,那根据第二个限制,\((i-1,k_{i-1})\) 的方格也一定是白色的(原因显然),于是我们成功消除…

精油品牌方必看:2026年值得关注的水性ODM厂商,机场香氛/固体香氛/天然植物精油香氛/除味香薰,精油公司推荐

在香氛行业快速发展的背景下,水性香薰精油因其环保性、安全性及适配性成为公共与私人空间环境优化的核心工具。作为连接品牌方与终端市场的关键环节,ODM厂商的技术实力、产能稳定性及定制化能力直接影响产品竞争力。…