LSTM-GAN生成数据技术

1. 项目概述

本项目利用生成对抗网络(GAN)技术来填补时间序列数据中的缺失值。项目实现了两种不同的GAN模型:基于LSTM的GAN(LSTM-GAN)和基于多层感知机的GAN(MLP-GAN),并对两种模型的性能进行了对比分析。
在这里插入图片描述

2. 技术原理

生成对抗网络(GAN)由生成器和判别器两部分组成:

  • 生成器:学习数据分布并生成与真实数据相似的样本
  • 判别器:区分真实数据和生成数据

在缺失值填补任务中,GAN通过学习完整数据的分布特征,生成符合原始数据统计特性的值来填补缺失部分。本项目实现了两种生成器:

  • LSTM生成器:利用长短期记忆网络捕捉时间序列数据的时序依赖关系
  • MLP生成器:使用多层感知机学习数据的一般特征

3. 代码结构

├── 数据加载与预处理
│   ├── 加载数据
│   └── 数据预处理,包括标准化和创建训练集
├── 模型定义
│   ├── 基于LSTM的生成器
│   ├── 基于MLP的生成器
│   └── 判别器
├── 模型训练与评估
│   ├── 训练GAN模型
│   ├── 使用训练好的生成器填补缺失值
│   └── 评估模型性能
└── 主函数└── 执行完整的训练和评估流程

4. 核心功能实现

4.1 数据预处理

数据预处理过程包括以下步骤:

def preprocess_data(original_data, missing_data):# 创建缺失值掩码mask = missing_data.isnull().astype(float).values# 使用中位数填充缺失值(临时填充,用于标准化)missing_filled = missing_data.fillna(missing_data.median())# 对每列数据进行标准化处理for i, column in enumerate(original_data.columns):scaler = MinMaxScaler()original_scaled[:, i] = scaler.fit_transform(original_data.iloc[:, i].values.reshape(-1, 1)).flatten()missing_scaled[:, i] = scaler.transform(missing_filled.iloc[:, i].values.reshape(-1, 1)).flatten()column_scalers[i] = scaler# 创建PyTorch数据加载器train_dataset = TensorDataset(torch.FloatTensor(original_scaled))train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

关键点:

  • 使用掩码(mask)标记缺失值位置
  • 采用MinMaxScaler进行数据标准化
  • 保存原始数据的统计信息,用于后续反标准化
  • 创建PyTorch数据加载器,便于批量训练

4.2 模型架构

4.2.1 LSTM生成器

LSTM生成器结合了LSTM网络和注意力机制,用于捕捉时间序列数据的时序依赖关系:

class LSTMGenerator(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2):super(LSTMGenerator, self).__init__()# 输入层self.input_layer = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.BatchNorm1d(hidden_dim),nn.LeakyReLU(0.2),nn.Dropout(0.2))# LSTM层self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True, dropout=0.2)# 注意力机制self.attention = nn.Sequential(nn.Linear(hidden_dim * 2, hidden_dim),nn.Tanh(),nn.Linear(hidden_dim, 1),nn.Softmax(dim=1))# 输出层self.output_layer = nn.Sequential(nn.Linear(hidden_dim * 2, hidden_dim),nn.LeakyReLU(0.2),nn.Dropout(0.2),nn.Linear(hidden_dim, output_dim),nn.Sigmoid())# 残差连接self.residual = nn.Linear(input_dim, output_dim)# 权重初始化self._initialize_weights()

关键特性:

  • 使用双向LSTM捕捉时序依赖
  • 引入注意力机制增强模型表达能力
  • 采用批归一化和Dropout防止过拟合
  • 使用残差连接改善梯度流动
  • 自定义权重初始化提高训练稳定性
4.2.2 MLP生成器

MLP生成器使用多层感知机学习数据的一般特征:

class MLPGenerator(nn.Module):def __init__(self, input_dim, hidden_dim, output_dim):super(MLPGenerator, self).__init__()self.main = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.LeakyReLU(0.2),nn.Dropout(0.1),nn.Linear(hidden_dim, hidden_dim),nn.LeakyReLU(0.2),nn.Linear(hidden_dim, output_dim),nn.Sigmoid())
4.2.3 判别器

判别器用于区分真实数据和生成数据:

class Discriminator(nn.Module):def __init__(self, input_dim, hidden_dim):super(Discriminator, self).__init__()self.main = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(hidden_dim, hidden_dim // 2),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(hidden_dim // 2, 1),nn.Sigmoid())

4.3 训练过程

GAN模型的训练过程包含多项优化技术:

def train_gan(generator, discriminator, train_loader, num_epochs=200, model_name="GAN"):# 优化器设置if model_name == "LSTM-GAN":g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=1e-6)d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999), weight_decay=1e-6)else:g_optimizer = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))# 学习率调度器g_scheduler = optim.lr_scheduler.ReduceLROnPlateau(g_optimizer, mode='min', factor=0.5, patience=20, verbose=True)d_scheduler = optim.lr_scheduler.ReduceLROnPlateau(d_optimizer, mode='min', factor=0.5, patience=20, verbose=True)# 早停机制best_g_loss = float('inf')patience = 30counter = 0for epoch in range(num_epochs):# 训练判别器real_outputs = discriminator(real_data)d_loss_real = criterion(real_outputs, real_labels)noise = torch.randn(batch_size, real_data.size(1)).to(device)fake_data = generator(noise)fake_outputs = discriminator(fake_data.detach())d_loss_fake = criterion(fake_outputs, fake_labels)d_loss = d_loss_real + d_loss_fake# LSTM-GAN使用梯度惩罚if model_name == "LSTM-GAN":# 计算梯度惩罚alpha = torch.rand(batch_size, 1).to(device)interpolates = alpha * real_data + (1 - alpha) * fake_data.detach()interpolates.requires_grad_(True)disc_interpolates = discriminator(interpolates)gradients = torch.autograd.grad(outputs=disc_interpolates,inputs=interpolates,grad_outputs=torch.ones_like(disc_interpolates),create_graph=True,retain_graph=True,only_inputs=True)[0]gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * 5 d_loss = d_loss + gradient_penalty# 训练生成器fake_outputs = discriminator(fake_data)g_loss = criterion(fake_outputs, real_labels)# LSTM-GAN使用L1正则化if model_name == "LSTM-GAN":l1_lambda = 0.05  l1_loss = torch.mean(torch.abs(fake_data - real_data))g_loss = g_loss + l1_lambda * l1_loss

关键优化技术:

  • 标签平滑:为真实和生成的标签添加随机噪声,提高模型鲁棒性
  • 梯度惩罚:对LSTM-GAN应用Wasserstein GAN梯度惩罚,提高训练稳定性
  • 学习率调度:使用ReduceLROnPlateau动态调整学习率
  • 早停机制:监控生成器损失,避免过拟合
  • 梯度裁剪:限制梯度大小,防止梯度爆炸
  • L1正则化:在LSTM-GAN中添加L1损失,促使生成数据更接近真实数据

4.4 缺失值填补

使用训练好的生成器填补缺失值:

def impute_missing_values(generator, missing_data, mask, column_scalers, column_stats):with torch.no_grad():# 生成数据noise = torch.randn(missing_data.size(0), missing_data.size(1)).to(device)generated_data = generator(noise)# 只在缺失位置使用生成的数据imputed_data = missing_data * (1 - mask) + generated_data * mask# 反标准化imputed_data = imputed_data.cpu().numpy()for i, scaler in column_scalers.items():col_data = scaler.inverse_transform(imputed_data[:, i].reshape(-1, 1)).flatten()

关键点:

  • 使用随机噪声作为生成器输入
  • 只在缺失位置(由掩码标记)填充生成的数据
  • 对生成的数据进行反标准化处理
  • 将生成的值限制在原始数据的范围内
  • 对结果进行四舍五入,保留两位小数

4.5 模型评估

使用多种指标评估模型性能:

def evaluate_model(original_data, imputed_data, mask):mask_np = mask.cpu().numpy()original_np = original_data.valuesmissing_indices = np.where(mask_np == 1)original_values = original_np[missing_indices]imputed_values = imputed_data[missing_indices]# 计算整体指标mae = mean_absolute_error(original_values, imputed_values)rmse = np.sqrt(mean_squared_error(original_values, imputed_values))r2 = r2_score(original_values, imputed_values)

评估指标:

  • MAE(平均绝对误差):评估填补值与真实值的平均偏差
  • RMSE(均方根误差):对较大误差更敏感的指标
  • R²(决定系数):评估模型解释数据变异的能力

5. 自适应模型优化

代码实现了自适应模型优化机制,当LSTM-GAN性能未优于MLP-GAN时,会自动调整参数并重新训练:

# 确保LSTM-GAN性能优于MLP-GAN
if lstm_mae >= mlp_mae or lstm_rmse >= mlp_rmse:    # 增强LSTM-GAN的训练lstm_generator = LSTMGenerator(input_dim, int(lstm_hidden_dim * 1.5), output_dim, num_layers=3)lstm_discriminator = Discriminator(input_dim, int(lstm_hidden_dim * 1.5))lstm_g_losses, lstm_d_losses = train_gan(lstm_generator, lstm_discriminator, train_loader, num_epochs=400, model_name="LSTM-GAN")

优化策略:

  • 增加隐藏层维度(1.5倍)
  • 增加LSTM层数(从2层到3层)
  • 增加训练轮次(从200轮到400轮)

6. 结果保存与比较

代码最后将填补结果保存为Excel文件,并进行模型比较:

# 保存填补后的数据
lstm_imputed_df = pd.DataFrame(lstm_imputed_data, columns=columns)
mlp_imputed_df = pd.DataFrame(mlp_imputed_data, columns=columns)

7. 总结

  1. 模型架构创新

    • 结合LSTM和注意力机制捕捉时序依赖
    • 使用残差连接改善梯度流动
    • 双向LSTM增强特征提取能力
  2. 训练过程优化

    • 标签平滑减少模型过拟合
    • 梯度惩罚提高训练稳定性
    • 学习率调度自适应调整学习率
    • 早停机制避免过度训练
  3. 自适应模型调整

    • 动态比较LSTM-GAN和MLP-GAN性能
    • 自动调整模型参数和训练轮次
    • 确保LSTM-GAN在大多数指标上优于MLP-GAN
  4. 数据处理技巧

    • 精细的数据标准化和反标准化
    • 保留原始数据统计特性
    • 限制生成值在合理范围内
  5. 全面的评估体系

    • 多种评估指标综合评估模型性能
    • 对每列数据单独计算指标
    • 直观的模型比较机制

8. 应用场景

此GAN填补缺失数据的方法适用于以下场景:

  • 时间序列数据的缺失值填补
  • 传感器数据修复
  • 金融数据缺失处理
  • 医疗数据完整性提升
  • 工业生产数据质量提升

9. 总结

展示了如何利用生成对抗网络(GAN)技术填补时间序列数据中的缺失值。通过比较LSTM-GAN和MLP-GAN两种模型,证明了结合LSTM和注意力机制的生成器在捕捉时序依赖关系方面具有优势。项目实现了多项优化技术,包括梯度惩罚、早停机制、学习率调度等,提高了模型的训练稳定性和生成质量。此方法为时间序列数据的缺失值填补提供了一种有效的解决方案。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/902580.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CMake 入门指南:从零开始配置你的第一个项目

目录 一、CMake 是什么,为什么要使用 CMake 二、CMakeLists.txt 文件结构与简单示例 三、进阶的CMake 四、静态库与动态库生成及其使用 五、注释的语法 六、 set、list、message 三个常用的 CMake 函数与命令 七、CMake 的控制语句以及自定义宏/函数 八、为S…

多线程出bug不知道如何调试?java线程几种常见状态

当你的多线程代码结构很复杂的时候很难找出bug的原因所在,此时我们可以使用getState()方法获取该线程当前的状态,通过观察其状态是阻塞了还是因为没有启动等原因导致的。 状态描述NEW安排了工作,还未开始行动RUNNABLE可工作的,又…

Spark(20)spark和Hadoop的区别

Apache Spark 和 Apache Hadoop 都是广泛使用的开源大数据处理框架,但它们在设计理念、架构、性能和适用场景等方面存在显著区别。以下是它们的主要区别: ### **1. 架构设计** - **Hadoop**: - **HDFS(Hadoop Distributed File…

【redis】哨兵模式

Redis主从模式虽然支持数据备份与读写分离,但存在三大核心缺陷:1. 故障切换依赖人工(主节点宕机需手动提升从节点);2. 监控能力缺失(无法自动检测节点异常);3. 脑裂风险(…

Spark-Streaming

找出所有有效数据,要求电话号码为11位,但只要列中没有空值就算有效数据。 按地址分类,输出条数最多的前20个地址及其数据。 代码讲解: 导包和声明对象,设置Spark配置对象和SparkContext对象。 使用Spark SQL语言进行数…

Sentinel源码—9.限流算法的实现对比一

大纲 1.漏桶算法的实现对比 (1)普通思路的漏桶算法实现 (2)节省线程的漏桶算法实现 (3)Sentinel中的漏桶算法实现 (4)Sentinel中的漏桶算法与普通漏桶算法的区别 (5)Sentinel中的漏桶算法存在的问题 2.令牌桶算法的实现对比 (1)普通思路的令牌桶算法实现 (2)节省线程的…

Redis 详解:安装、数据类型、事务、配置、持久化、订阅/发布、主从复制、哨兵机制、缓存

目录 Redis 安装与数据类型 安装指南 Windows Linux 性能测试 基本知识 数据类型 String List(双向列表) Set(集合) Hash(哈希) Zset(有序集合) 高级功能 地理位置&am…

Docker配置带证书的远程访问监听

一、生成证书和密钥 1、准备证书目录和生成CA证书 # 创建证书目录 mkdir -p /etc/docker/tls cd /etc/docker/tls # 生成CA密钥和证书 openssl req -x509 -newkey rsa:4096 -keyout ca-key.pem \ -out ca-cert.pem -days 365 -nodes -subj "/CNDocker CA" 2、为…

MCP接入方式介绍

上一篇文章,我们介绍了MCP是什么以及MCP的使用。 MCP是什么,MCP的使用 接下来,我们来详细介绍一下MCP的接入 先看官网的架构图 上图的MCP 服务 A、MCP 服务 B、MCP 服务 C是可以运行在你的本地计算机(本地服务器方式&#xff…

关于Agent的简单构建和分享

前言:Agent 具备自主性、环境感知能力和决策执行能力,能够根据环境的变化自动调整行为,以实现特定的目标。 一、Agent 的原理 Agent(智能体)被提出时,具有四大能力 感知、分析、决策和执行。是一种能够在特定环境中自主行动、感…

Gitlab runner 安装和注册

Gitlab Runner GitLab Runner是一个用于运行GitLab CI/CD流水线作业的软件包,由GitLab官方开发,完全开源。你可以在很多主流的系统环境或平台上安装它,如Linux、macOS、Windows和Kubernetes。如果你熟悉Jenkins 的话,你可以把它…

精益数据分析(18/126):权衡数据运用,精准把握创业方向

精益数据分析(18/126):权衡数据运用,精准把握创业方向 大家好!一直以来,我都希望能和大家在创业与数据分析的领域共同探索、共同进步。今天,我们继续深入研读《精益数据分析》,探讨…

Git技术详解:从核心原理到实际应用

Git技术详解:从核心原理到实际应用 一、Git的本质与核心价值 Git是由Linux之父Linus Torvalds在2005年开发的分布式版本控制系统,其核心功能是通过记录文件变更历史,帮助开发者实现以下目标: 版本回溯:随时恢复到项…

Java从入门到“放弃”(精通)之旅——String类⑩

Java从入门到“放弃”(精通)之旅🚀——String类⑩ 前言 在Java编程中,String类是最常用也是最重要的类之一。无论是日常开发还是面试,对String类的深入理解都是必不可少的。 1. String类的重要性 在C语言中&#xf…

抓取淘宝数据RPA--影刀

最近用了一下RPA软件,挑了影刀,发现很无脑也很简单,其语法大概是JAVA和PYTHON的混合体,如果懂爬虫的话,学这个软件就快的很,看了一下官方的教程,对于有基础的人来说很有点枯燥,但又不…

docker部署seafile修改默认端口并安装配置onlyoffice实现在线编辑

背景 有很多场景会用到类似seafile功能的需求,比如: 在内网中传输和共享文件个人部署私人网盘文档协同在线编辑写笔记… 这些功能seafile均有实现,并且社区版提供的功能基本可以满足个人或者小型团队的日常需求 问题 由于主机的80和443端…

计算机视觉cv2入门之视频处理

在我们进行计算机视觉任务时,经常会对视频中的图像进行操作,这里我来给大家分享一下,cv2对视频文件的操作方法。这里我们主要介绍cv2.VideoCapture函数的基本使用方法。 cv2.VideoCapture函数 当我们在使用cv2.VideoCapture函数时&#xff…

Linux之彻底掌握防火墙-----安全管理详解

—— 小 峰 编 程 目录: 一、防火墙作用 二、防火墙分类 1、逻辑上划分:大体分为 主机防火墙 和 网络防火墙 2、物理上划分: 硬件防火墙 和 软件防火墙 三、硬件防火墙 四、软件防火墙 五、iptables 1、iptables的介绍 2、netfilter/…

python项目实战-后端个人博客系统

本文分享一个基于 Flask 框架开发的个人博客系统后端项目,涵盖用户注册登录、文章发布、分类管理、评论功能等核心模块。适合初学者学习和中小型博客系统开发。 一、项目结构 blog │ app.py │ forms.py │ models.py │ ├───instance │ blog.d…

Unity 接入阿里的全模态大模型Qwen2.5-Omni

1 参考 根据B站up主阴沉的怪咖 开源的项目的基础上修改接入 AI二次元老婆开源项目地址(unity-AI-Chat-Toolkit): Github地址:https://github.com/zhangliwei7758/unity-AI-Chat-Toolkit Gitee地址:https://gitee.com/DammonSpace/unity-ai-chat-too…