基于LSTM-AutoEncoder的心电信号时间序列数据异常检测(PyTorch版)

时序异常检验
心电信号(ECG)的异常检测对心血管疾病早期预警至关重要,但传统方法面临时序依赖建模不足与噪声敏感等问题。本文使用一种基于LSTM-AutoEncoder的深度时序异常检测框架,通过编码器-解码器结构捕捉心电信号的长期时空依赖特征,并结合动态阈值自适应识别异常片段。模型在编码阶段利用LSTM层提取时序上下文信息,解码阶段重构正常ECG波形,以重构误差为异常评分依据。在MIT-BIH心律失常数据库上的实验表明,该方法在AUC-ROC(0.932)和F1-Score(0.876)上显著优于孤立森林、CNN-AE等基线模型,误报率降低23.6%。该技术可应用于可穿戴设备的实时心电监护,为临床提供高鲁棒性的自动化异常检测方案。

系列专栏:【深度学习:算法项目实战】✨︎
涉及医疗健康、财经金融、商业零售、食品饮料、运动健身、交通运输、环境科学、能源电力以及自然语言处理等诸多领域,探讨如何使用各种复杂的深度神经网络思想,如卷积神经网络、循环神经网络、生成对抗网络、门控循环单元、长短期记忆、注意力机制等实现时序预测、分类、异常检验以及概率预测。

1. 数据集介绍

本文使用ECG5000心电图时间序列数据集

import pandas as pd
from scipy.io.arff import loadarff
import matplotlib.pyplot as pltimport torch
import torch.nn as nn
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset
from torchinfo import summary
from torchmetrics.functional.classification import precision, recall, f1_score, auroc
from torchmetrics.functional.classification import binary_confusion_matrix
# Download the dataset
traindata, trainmeta = loadarff('../ECG5000/ECG5000_TRAIN.arff')
testdata, testmeta = loadarff('../ECG5000/ECG5000_TEST.arff')
train = pd.DataFrame(traindata, columns=trainmeta.names())
test = pd.DataFrame(testdata, columns=testmeta.names())
df = pd.concat([train, test])
print(train.shape, test.shape, df.shape)
(500, 141) (4500, 141) (5000, 141)

2. 数据可视化

将数据划分为正常心电信号数据normal和异常心电信号数据abnormal

normal = df[df.iloc[:, -1] == b'1']
abnormal = df[df.iloc[:, -1] != b'1']
# 设置全局字体样式
plt.style.use('ggplot')
plt.rcParams['font.family'] = 'serif'
fig, axes = plt.subplots(2, 1, figsize=(9, 12))# 绘制正常数据
axes[0].plot(normal.values.T)
axes[0].set_title('Normal Electrocardiogram (ECG)', fontsize=20, pad=10)# 绘制异常数据
axes[1].plot(abnormal.values.T)
axes[1].set_title('Abnormal Electrocardiogram (ECG)',fontsize=20,pad=10)# 调整子图间距
plt.tight_layout()
plt.show()

心电图

3. 数据预处理

# 2. 数据预处理
# 只使用正常样本训练自编码器
X_normal = normal.iloc[:, :-1].values
X_abnormal = abnormal.iloc[:, :-1].values

3.1 转换数据类型

# 转换为PyTorch张量 (添加通道维度)
normal_tensor = torch.tensor(data=X_normal, dtype=torch.float).unsqueeze(-1)
abnormal_tensor = torch.tensor(data=X_abnormal, dtype=torch.float).unsqueeze(-1)
print(normal_tensor.shape, abnormal_tensor.shape)
torch.Size([2919, 140, 1]) torch.Size([2081, 140, 1])

3.2 数据集划分(Subset)

# 划分训练集(正常样本)和验证集索引
dataset = TensorDataset(normal_tensor, normal_tensor)
train_idx = list(range(len(dataset)*4//5)) # 划分训练集索引
val_idx = list(range(len(dataset)*4//5, len(dataset))) # 划分验证集索引
print(len(train_idx), len(val_idx))
2335 584

划分测试集,包含异常数据,用于模型的最终测试。

# 划分测试集(正常+异常)
x_val_tensor = normal_tensor[val_idx]
x_test_tensor = torch.cat((x_val_tensor, abnormal_tensor), dim=0)
y_test_tensor = torch.cat((torch.zeros(len(x_val_tensor),dtype=torch.long),torch.ones(len(abnormal_tensor),dtype=torch.long)),dim=0
)
print(x_test_tensor.shape, y_test_tensor.shape)
torch.Size([2665, 140, 1]) torch.Size([2665])

3.3 数据加载器

通过 SubsetRandomSampler 从完整数据集 dataset 中按索引划分训练集和验证集,并生成批量数据迭代器‌。SubsetRandomSampler 会在每次迭代时随机打乱索引顺序,避免训练数据顺序固定导致的模型过拟合‌。

train_sampler = SubsetRandomSampler(indices=train_idx)
val_sampler = SubsetRandomSampler(indices=val_idx)
train_loader = DataLoader(dataset, batch_size=128, sampler=train_sampler)
val_loader = DataLoader(dataset, batch_size=128, sampler=val_sampler)

DataLoadersampler 参数优先级高于 shuffle,因此无需设置 shuffle=True‌

4. 构建时序异常检测模型

4.1 构建LSTM编码器

class Encoder(nn.Module):def __init__(self, context_len, n_variables, embedding_dim=64):super(Encoder, self).__init__()self.context_len, self.n_variables = context_len, n_variables  # 时间步、输入特征self.embedding_dim, self.hidden_dim = embedding_dim, 2 * embedding_dimself.lstm1 = nn.LSTM(input_size=self.n_variables,hidden_size=self.hidden_dim,num_layers=1,batch_first=True,)self.lstm2 = nn.LSTM(input_size=self.hidden_dim,hidden_size=embedding_dim,num_layers=1,batch_first=True,)def forward(self, x):batch_size = x.shape[0]x, (_, _) = self.lstm1(x)x, (hidden_n, _) = self.lstm2(x)return hidden_n.reshape((batch_size, self.embedding_dim))

4.2 构建LSTM解码器

class Decoder(nn.Module):def __init__(self, context_len, n_variables=1, input_dim=64):super(Decoder, self).__init__()self.context_len, self.input_dim = context_len, input_dimself.hidden_dim, self.n_variables = 2 * input_dim, n_variablesself.lstm1 = nn.LSTM(input_size=input_dim, hidden_size=input_dim, num_layers=1, batch_first=True)self.lstm2 = nn.LSTM(input_size=input_dim,hidden_size=self.hidden_dim,num_layers=1,batch_first=True,)self.output_layer = nn.Linear(self.hidden_dim, self.n_variables)def forward(self, x):batch_size = x.shape[0]x = x.repeat(self.context_len, self.n_variables)x = x.reshape((batch_size, self.context_len, self.input_dim))x, (hidden_n, cell_n) = self.lstm1(x)x, (hidden_n, cell_n) = self.lstm2(x)x = x.reshape((batch_size, self.context_len, self.hidden_dim))return self.output_layer(x)

4.3 构建LSTM AE

class LSTMAutoencoder(nn.Module):def __init__(self, context_len, n_variables, embedding_dim):super().__init__()self.encoder = Encoder(context_len, n_variables, embedding_dim)self.decoder = Decoder(context_len, n_variables, embedding_dim)def forward(self, x):x = self.encoder(x)x = self.decoder(x)return x

4.4 实例化模型、定义损失函数与优化器

automodel = LSTMAutoencoder(context_len=140, n_variables=1, embedding_dim=64)
optimizer = torch.optim.Adam(params=automodel.parameters(), lr=1e-4)
criterion = nn.MSELoss()

4.5 模型概要

summary(model=automodel, input_size=(128, 140, 1))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
LSTMAutoencoder                          [128, 140, 1]             --
├─Encoder: 1-1                           [128, 64]                 --
│    └─LSTM: 2-1                         [128, 140, 128]           67,072
│    └─LSTM: 2-2                         [128, 140, 64]            49,664
├─Decoder: 1-2                           [128, 140, 1]             --
│    └─LSTM: 2-3                         [128, 140, 64]            33,280
│    └─LSTM: 2-4                         [128, 140, 128]           99,328
│    └─Linear: 2-5                       [128, 140, 1]             129
==========================================================================================
Total params: 249,473
Trainable params: 249,473
Non-trainable params: 0
Total mult-adds (Units.GIGABYTES): 4.47
==========================================================================================
Input size (MB): 0.07
Forward/backward pass size (MB): 55.19
Params size (MB): 1.00
Estimated Total Size (MB): 56.26
==========================================================================================

5. 模型训练

5.1 定义训练函数

在模型训练之前,我们需先定义 train 函数来执行模型训练过程

def train(model, iterator):model.train()epoch_loss = 0for batch_idx, (data, target) in enumerate(iterable=iterator):optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()epoch_loss += loss.item()avg_loss = epoch_loss / len(iterator)return avg_loss

上述代码定义了一个名为 train 的函数,用于训练给定的模型。它接收模型、数据迭代器作为参数,并返回训练过程中的平均损失。

5.2 定义评估函数

def evaluate(model, iterator): # Being used to validate and testmodel.eval()epoch_loss = 0with torch.no_grad():for batch_idx, (data, target) in enumerate(iterable=iterator):output = model(data)loss = criterion(output, target)epoch_loss += loss.item()avg_loss = epoch_loss / len(iterator)return avg_loss

上述代码定义了一个名为 evaluate 的函数,用于评估给定模型在给定数据迭代器上的性能。它接收模型、数据迭代器作为参数,并返回评估过程中的平均损失。这个函数通常在模型训练的过程中定期被调用,以监控模型在验证集或测试集上的性能。通过评估模型的性能,可以了解模型的泛化能力和训练的进展情况。

5.3 定义早停法并保存模型

定义早停法以便在模型训练过程中调用

class EarlyStopping:def __init__(self, patience=5, delta=0.0):self.patience = patience  # 允许的连续未改进次数self.delta = delta        # 损失波动容忍阈值self.counter = 0          # 未改进计数器self.best_loss = float('inf')  # 最佳验证损失值self.early_stop = False   # 终止训练标志def __call__(self, val_loss, model):if val_loss < (self.best_loss - self.delta):self.best_loss = val_lossself.counter = 0# 保存最佳模型参数‌:ml-citation{ref="1,5" data="citationList"}torch.save(model.state_dict(), 'best_model.pth')else:self.counter +=1if self.counter >= self.patience:self.early_stop = True
EarlyStopper = EarlyStopping(patience=10, delta=0.00001)  # 设置参数

若不想使用早停法EarlyStopper,参数patience设置一个超大的值,delta设置为0,即可。

5.4 定义模型训练主程序

通过定义模型训练主程序来执行模型训练

def main():train_losses = []val_losses = []for epoch in range(300):train_loss = train(model=automodel, iterator=train_loader)val_loss = evaluate(model=automodel, iterator=val_loader)train_losses.append(train_loss)val_losses.append(val_loss)print(f'Epoch: {epoch + 1:02}, Train MSELoss: {train_loss:.5f}, Val. MSELoss: {val_loss:.5f}')# 触发早停判断EarlyStopper(val_loss, model=automodel)if EarlyStopper.early_stop:print(f"Early stopping at epoch {epoch}")breakplt.figure(figsize=(10, 5))plt.plot(train_losses, label='Training Loss')plt.plot(val_losses, label='Validation Loss')plt.xlabel('Epoch')plt.ylabel('MSELoss')plt.title('Training and Validation Loss over Epochs')plt.legend()plt.grid(True)plt.show()

5.5 执行模型训练过程

main()
Epoch: 69, Train MSELoss: 0.21886, Val. MSELoss: 0.21556
Epoch: 70, Train MSELoss: 0.22166, Val. MSELoss: 0.21716
Epoch: 71, Train MSELoss: 0.22082, Val. MSELoss: 0.20737
Epoch: 72, Train MSELoss: 0.21676, Val. MSELoss: 0.20873
Epoch: 73, Train MSELoss: 0.22007, Val. MSELoss: 0.21766
Epoch: 74, Train MSELoss: 0.22644, Val. MSELoss: 0.21219
Epoch: 75, Train MSELoss: 0.22045, Val. MSELoss: 0.20890
Epoch: 76, Train MSELoss: 0.22027, Val. MSELoss: 0.21222
Epoch: 77, Train MSELoss: 0.21933, Val. MSELoss: 0.20765
Epoch: 78, Train MSELoss: 0.22219, Val. MSELoss: 0.20903
Epoch: 79, Train MSELoss: 0.22051, Val. MSELoss: 0.20856
Epoch: 80, Train MSELoss: 0.22001, Val. MSELoss: 0.21346
Epoch: 81, Train MSELoss: 0.21968, Val. MSELoss: 0.21276
Early stopping at epoch 80

损失

6. 异常检测

6.1 异常检测

接下来,我们通过构建 detect_anomalies 函数来对模型中的数据进行检测。

# 5. 异常检测
def detect_anomalies(model, x):model.eval()with torch.no_grad():reconstructions = model(x)mse = torch.mean((x - reconstructions)**2, dim=(1,2))return mse

6.2 设置阈值

# 在测试集上计算重建误差
test_mse = detect_anomalies(automodel, x_test_tensor)# 设置阈值 (使用验证集正常样本的95%分位数)
val_mse = detect_anomalies(automodel, x_val_tensor)
threshold = torch.quantile(val_mse, 0.95)# 预测结果
y_pred = (test_mse > threshold).long()
print(f'Threshold: {threshold:.4f}')
print(y_pred.dtype)
print(y_pred.shape)
Threshold: 0.5402
torch.int64
torch.Size([2665])

7. 模型评估

7.1 评估函数

torchmetrics库提供了各种评估函数,例如:精确率Precision、召回率Recall、F1分数F1-Score Area Under ROC Curve \text{Area Under ROC Curve} Area Under ROC Curve,我们可以直接用来评估模型性能

pre = precision(preds=y_pred, target=y_test_tensor, task="binary")
print(f"Precision: {pre:.5f}")rec = recall(preds=y_pred, target=y_test_tensor, task="binary")
print(f"Recall: {rec:.5f}")f1 = f1_score(preds=y_pred, target=y_test_tensor, task="binary")
print(f"F1 Score: {f1:.5f}")auc = auroc(preds=test_mse, target=y_test_tensor, task="binary")
print(f"AUC: {auc:.5f}")
Precision: 0.98586
Recall: 0.97165
F1 Score: 0.97870
AUC: 0.98020

7.2 混淆矩阵

cm = binary_confusion_matrix(preds=y_pred, target=y_test_tensor)
print(cm)
tensor([[ 555,   29],[  59, 2022]])

预测可视化

# 7. 可视化部分结果
plt.figure(figsize=(12, 6))
plt.plot(test_mse, label='Reconstruction Error')
plt.axhline(threshold, color='r', linestyle='--', label='Threshold')
plt.title('Anomaly Detection Results')
plt.xlabel('Sample Index')
plt.ylabel('MSE')
plt.legend()
plt.show()

Anomaly Detection Results

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/78010.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Docker 部署 PostgreSQL 数据库

Docker 部署 PostgreSQL 数据库 基于 Docker 部署 PostgreSQL 数据库一、拉取 PostgreSQL 镜像二、运行 PostgreSQL 容器三、运行命令参数详解四、查看容器运行状态 基于 Docker 部署 PostgreSQL 数据库 一、拉取 PostgreSQL 镜像 首先&#xff0c;确保你的 Docker 环境已正确…

MySQL性能调优(四):MySQL的执行原理(MYSQL的查询成本)

文章目录 MySQL性能调优数据库设计优化查询优化配置参数调整硬件优化 1.MySQL的执行原理-21.1.MySQL的查询成本1.1.1.什么是成本1.1.2.单表查询的成本1.1.2.1.基于成本的优化步骤实战1. 根据搜索条件&#xff0c;找出所有可能使用的索引2. 计算全表扫描的代价3. 计算使用不同索…

用 Go 优雅地清理 HTML 并抵御 XSS——Bluemonday

1、背景与动机 只要你的服务接收并回显用户生成内容&#xff08;UGC&#xff09;——论坛帖子、评论、富文本邮件正文、Markdown 等——就必须考虑 XSS&#xff08;Cross‑Site Scripting&#xff09;攻击风险。浏览器在解析 HTML 时会执行脚本&#xff1b;如果不做清理&#…

Redis SCAN 命令的详细介绍

Redis SCAN 命令的详细介绍 以下是 Redis SCAN​ 命令的详细介绍&#xff0c;结合其核心特性、使用场景及底层原理进行综合说明&#xff1a; 工作原理图 &#xff1a; ​ 一、核心特性 非阻塞式迭代 通过游标&#xff08;Cursor&#xff09; 分批次遍历键&#xff0c;避免一次…

SpringBoot3集成MyBatis-Plus(解决Boot2升级Boot3)

总结&#xff1a;目前升级仅发现依赖有变更&#xff0c;其他目前未发现&#xff0c;如有发现&#xff0c;后续会继续更新 由于项目架构提升&#xff0c;以前开发的很多公共的组件&#xff0c;以及配置都需要升级&#xff0c;因此记录需要更改的配置&#xff08;记录时间&#…

基于mybatis与PageHelper插件实现条件分页查询(3.19)

实现商品分页例子 需要先引入mybatis与pagehelper插件&#xff0c;在pom.xml里 <!-- Mybatis --> <dependency><groupId>org.mybatis.spring.boot</groupId><artifactId>mybatis-spring-boot-starter</artifactId><version>3.0.3&l…

Spring Bean 全方位指南:从作用域、生命周期到自动配置详解

目录 1. Bean 的作用域 1.1 singleton 1.2 prototype 1.3 request 1.4 session 1.5 application 1.5.1 servletContext 和 applicationContext 区别 2. Bean 的生命周期 2.1 详解初始化 2.1.1 Aware 接口回调 2.1.2 执行初始化方法 2.2 代码示例 2.3 源码 [面试题…

C++ (非类型参数)

模板除了定义类型参数之外&#xff0c;也可以在模板内定义非类型参数 非类型参数不是类型&#xff0c;而是值&#xff0c;比如&#xff1a;指针&#xff0c;整数&#xff0c;引用 非类型参数的用法&#xff1a; 1.整数常量&#xff1a;非类型参数最常见的形式是整数常量&…

短视频+直播商城系统源码全解析:音视频流、商品组件逻辑剖析

时下&#xff0c;无论是依托私域流量运营的品牌方&#xff0c;还是追求用户粘性与转化率的内容创作者&#xff0c;搭建一套完整的短视频直播商城系统源码&#xff0c;已成为提升用户体验、增加商业变现能力的关键。本文将围绕三大核心模块——音视频流技术架构、商品组件设计、…

5.QT-常用控件-QWidget|enabled|geometry|window frame(C++)

控件概述 实现图形化界面的程序. Qt中已经给我们提供了很多的“控件" 就需要学习和了解这些控件&#xff0c;学会如何使用这些控件 编程讲究的是“站在巨人的肩膀上”&#xff0c;而不是“从头发明轮子" 一个图形化界面上的内容&#xff0c;不需要咱们全都从零去实…

2025-04-22| Docker: --privileged参数详解

在 Docker 中&#xff0c;--privileged 是一个运行容器时的标志&#xff0c;它赋予容器特权模式&#xff0c;大幅提升容器对宿主机资源的访问权限。以下是 --privileged 的作用和相关细节&#xff1a; 作用 完全访问宿主机的设备&#xff1a; 容器可以访问宿主机的所有设备&am…

高性能服务器配置经验指南1——刚配置好服务器应该做哪些事

文章目录 安装ubuntu安装必要软件设置用户远程连接安全问题ClamAV安装教程步骤 1&#xff1a;更新系统软件源步骤 2&#xff1a;升级系统&#xff08;可选但推荐&#xff09;步骤 3&#xff1a;安装 ClamAV步骤 4&#xff1a;更新病毒库步骤 5&#xff1a;验证安装ClamAV 常用命…

直流绝缘监测解决方案:保障工业与新能源系统的安全运行

一、引言 随着工业自动化和新能源技术的快速发展&#xff0c;直流供电系统在光伏发电、储能电站、电动汽车充电桩等领域的应用日益广泛。然而&#xff0c;直流系统的正负极不接地&#xff08;IT系统&#xff09;特性&#xff0c;使得绝缘故障可能导致漏电、短路甚至设备损毁等…

VSCode 用于JAVA开发的环境配置,JDK为1.8版本时的配置

插件安装 JAVA开发在VSCode中&#xff0c;需要安装JAVA的必要开发。当前安装只需要安装 “Language Support for Java(TM) by Red Hat”插件即可 安装此插件后&#xff0c;会自动安装包含如下插件&#xff0c;不再需要单独安装 Project Manager for Java Test Runner for J…

C++入门语法

C入门 首先第一点&#xff0c;C中可以混用C语言中的语法。但是C语言是不兼容C的。C主要是为了改进C语言而创建的一门语言&#xff0c;就是有人用C语言用不爽了&#xff0c;改出来个C。 命名空间 c语言中会有如下这样的问题&#xff1a; 那么C为了解决这个问题就整出了一个命名…

输入框仅支持英文、特殊符号、全角自动转半角 vue3

需求&#xff1a;封装一个输入框组件 1.只能输入英文。 2.输入的小写英文自动转大写。 3.输入的全角特殊符号自动转半角特殊字符 效果图 代码 <script setup> import { defineEmits, defineModel, defineProps } from "vue"; import { debounce } from "…

Uniapp:创建项目

目录 一、前提准备二、创建项目三、项目结构四、运行测试 一、前提准备 首先要创建uniapp项目&#xff0c;需要先下载HBuilderX&#xff0c;HBuilderX是一款开箱即用的工具&#xff0c;下载完毕之后&#xff0c;解压到指定的目录即可使用&#xff0c;需要注意的是最好路径里面…

ESM 内功心法:化解 require 中的夺命一击!

前言 传闻在JavaScript与TypeScript武林中,曾有两大绝世心法:CommonJS与ESM。两派高手比肩而立,各自称霸一方,江湖一度风平浪静。 岂料,时局突变。ESM逐步修成阳春白雪之姿,登堂入室,成为主流正统。CommonJS则渐入下风,功力不济,逐渐退出主舞台。 话说某日,一位前…

【STL】unordered_set

在 C C C 11 11 11 中&#xff0c; S T L STL STL 标准库引入了一个新的标准关联式容器&#xff1a; u n o r d e r e d _ s e t unordered\_set unordered_set&#xff08;无序集合&#xff09;。功能和 s e t set set 类似&#xff0c;都用于存储唯一元素。但是其底层数据结…

go语言八股文

1.go语言的接口是怎么实现 接口&#xff08;interface&#xff09;是一种类型&#xff0c;它定义了一组方法的集合。任何类型只要实现了接口中定义的所有方法&#xff0c;就被认为实现了该接口。 代码的实现 package mainimport "fmt"// 定义接口 type Shape inte…