【信号处理】基于DGGAN的单通道脑电信号增强和情绪检测(tensorflow)

关于

情绪检测,是脑科学研究中的一个常见和热门的方向。在进行情绪检测的分类中,真实数据不足,经常导致情绪检测模型的性能不佳。因此,对数据进行增强,成为了一个提升下游任务的重要的手段。本项目通过DCGAN模型实现脑电信号的扩充。

 图片来源:https://www.medicalnewstoday.com/articles/seizure-eeg

工具

数据

方法实现

DCGAN速递:https://arxiv.org/abs/1511.06434

数据加载和预处理
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from sklearn.preprocessing import StandardScaler
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.layers import Embedding
from tensorflow.keras.layers import LSTM
from tensorflow.keras.optimizers import SGD
from sklearn.metrics import accuracy_score
from model_DCGAN import DCGAN
from tensorflow.keras.optimizers import Adam, SGD, RMSprop
from sklearn.utils import shuffle
from sklearn.ensemble import GradientBoostingClassifieruse_feature_reduction = Truetf.keras.backend.clear_session()df=pd.read_csv('dataset/emotions.csv')encode = ({'NEUTRAL': 0, 'POSITIVE': 1, 'NEGATIVE': 2} )
#new dataset with replaced values
df_encoded = df.replace(encode)print(df_encoded.head())
print(df_encoded['label'].value_counts()),x=df_encoded.drop(["label"]  ,axis=1)
y = df_encoded.loc[:,'label'].valuesscaler = StandardScaler()
scaler.fit(x)
x = scaler.transform(x)
y = to_categorical(y)x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2, random_state = 4)if use_feature_reduction:# Feature reduction partest = GradientBoostingClassifier(n_estimators=10, learning_rate=0.1, random_state=0).fit(x_train,y_train.argmax(-1))# Obtain feature importance results from Gradient Boosting Regressorfeature_importance = est.feature_importances_epsilon_feature = 1e-2x_train = x_train[:, feature_importance > epsilon_feature]x_test = x_test[:, feature_importance > epsilon_feature]
设置DCGAN优化器

# setup optimzers
gen_optim = Adam(1e-4, beta_1=0.5)
disc_optim = RMSprop(5e-4)
 训练GAN生成类别0脑电数据
# generate samples for class 0
generator_class = 0
dcgan = DCGAN(gen_optim, disc_optim, noise_dim=100, dropout=0.3, input_dim=x_train.shape[2])
x_train_class_0 = x_train[y_train[:,generator_class]==1,:]
loss_history_class_0, acc_history_class_0, grads_history_class_0 = dcgan.train(x_train_class_0, epochs=100)
print("Class 0 fake samples are generating")
generator_class_0 = dcgan.generator
generated_samples_class_0, _ = dcgan.generate_fake_data(N=len(x_train_class_0))
  训练GAN生成类别1脑电数据
# generate samples for class 1
generator_class = 1
dcgan = DCGAN(gen_optim, disc_optim, noise_dim=100, dropout=0.3, input_dim=x_train.shape[2])
x_train_class_1 = x_train[y_train[:,generator_class]==1,:]
loss_history_class_1, acc_history_class_1, grads_history_class_1 = dcgan.train(x_train_class_1, epochs=100)
print("Class 1 fake samples are generating")
generator_class_1 = dcgan.generator
generated_samples_class_1, _ = dcgan.generate_fake_data(N=len(x_train_class_1))
 训练GAN生成类别2脑电数据
# generate samples for class 2
generator_class = 2
dcgan = DCGAN(gen_optim, disc_optim, noise_dim=100, dropout=0.3, input_dim=x_train.shape[2])
x_train_class_2 = x_train[y_train[:,generator_class]==1,:]
loss_history_class_2, acc_history_class_2, grads_history_class_2 = dcgan.train(x_train_class_2,epochs=100)
print("Class 2 fake samples are generating")
generator_class_2 = dcgan.generator
generated_samples_class_2, _ = dcgan.generate_fake_data(N=len(x_train_class_2))
合成数据融入真实训练数据集
generated_samples = np.concatenate((generated_samples_class_0,generated_samples_class_1,generated_samples_class_2),axis=0)
generated_y =np.concatenate((np.zeros((len(x_train_class_0),),dtype=np.int32),np.ones((len(x_train_class_1),),dtype=np.int32),2 * np.ones((len(x_train_class_2),),dtype=np.int32)),axis=0)generated_y = to_categorical(generated_y)x_train_all = np.concatenate((x_train,generated_samples),axis=0)
y_train_all = np.concatenate((y_train,generated_y), axis=0)#shuffle training data
x_train_all, y_train_all = shuffle(x_train_all,y_train_all)
 基于数据增强的LSTM模型情绪检测
model = Sequential()
model.add(LSTM(64, input_shape=(1,x_train_all.shape[2]),activation="relu",return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(32,activation="sigmoid"))
model.add(Dropout(0.2))
model.add(Dense(3, activation='sigmoid'))
model.compile(loss = 'categorical_crossentropy', optimizer = "adam", metrics = ['accuracy'])
model.summary()history = model.fit(x_train_all, y_train_all, epochs = 250, validation_data= (x_test, y_test))
score, acc = model.evaluate(x_test, y_test)pred = model.predict(x_test)
predict_classes = np.argmax(pred,axis=1)
expected_classes = np.argmax(y_test,axis=1)
print(expected_classes.shape)
print(predict_classes.shape)
correct = accuracy_score(expected_classes,predict_classes)
print(f"Test Accuracy: {correct}")

已附DCGAN模型

相关项目和代码问题,欢迎沟通交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/775980.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于STC12C5A60S2系列1T 8051单片机的按键单击长按实现互不干扰增加减少数值应用

基于STC12C5A60S2系列1T 8051单片机的按键单击长按实现互不干扰增加减少数值应用 STC12C5A60S2系列1T 8051单片机管脚图STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式及配置STC12C5A60S2系列1T 8051单片机I/O口各种不同工作模式介绍基于STC12C5A60S2系列1T 8051单片机的…

iscsi网络协议(连接硬件设备)

iscsi概念 iscsi是一种互联网协议,用于将存储设备(如硬盘驱动器或磁带驱动器)通过网络连接到计算机。它是一种存储区域网络(SAN)技术,允许服务器通过网络连接到存储设备,就像它们是本地设备一样…

区块链技术与大数据结合的商业模式探索

hello宝子们...我们是艾斯视觉擅长ui设计和前端开发10年经验!希望我的分享能帮助到您!如需帮助可以评论关注私信我们一起探讨!致敬感谢感恩! 随着区块链技术和大数据技术的不断发展,两者的结合为企业带来了新的商业模式…

科东软件联手英特尔,用工业AI智能机器人赋能工业升级

AI浪潮已经冲击到各行各业中,它能够帮助人们提高思考和生产效率。在创作中,AI能够帮助人们释放创意,那在工业中,AI能够为产业带来什么呢? 科东软件是国内专注于操作系统开发的企业。当前,科东开发的Intewe…

机器学习——贝叶斯分类器(基础理论+编程)

目录 一、理论 1、初步引入 2、做简化 3、拉普拉斯修正 二、实战 1、计算P(c) 2、计算P(x|c) 3、实战结果 1、数据集展示 2、相关信息打印 一、理论 1、初步引入 在所有相关概率都已知的理想情形下,贝叶斯决策论考虑如何基于这些概率和误判损失来选择最…

Jenkins升级中的小问题

文章目录 使用固定版本安装根据jenkins页面下载war包升级jenkins重启jenkins报错问题解决 K8s部署过程中的一些小问题 ##### Jenkins版本小插曲 ​ 在Jenkins环境进行插件安装时全部清一色飘红,发现是因为Jenkins版本过低导致,报错的位置可以找到更新je…

巨控GRM560工业物联网的升级后的功能

巨控GRM560:工业自动化领域的革命者 标签:#工业自动化 #PLC #远程控制 #OPCUA #MQTT 随着工业4.0时代的到来,智能制造已经成为了发展的大势所趋。在这样的背景下,自动化控制系统的核心——可编程逻辑控制器(PLC)的作用…

shell脚本发布docker-nginx vue2 项目示例

docker、git、node.js安装略过。 使git pull或者git push不需要输入密码操作方法 nginx安装在docker容器里面,参见:https://blog.csdn.net/HSJ0170/article/details/128631155 姊妹篇(宿主机nginx,非docker-nginx)&am…

基于java+SpringBoot+Vue的数码论坛系统设计与实现

基于javaSpringBootVue的数码论坛系统设计与实现 开发语言: Java 数据库: MySQL技术: SpringBoot MyBatis工具: IDEA/Eclipse、Navicat、Maven 系统展示 前台展示 后台展示 系统简介 整体功能包含: 数码论坛系统是一个基于互联网的数码产品讨论和信息分享平台…

深度学习语义分割篇——DeepLabV2原理详解篇

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题 🍊专栏推荐:深度学习网络原理与实战 🍊近期目标:写好专栏的每一篇文章 🍊支持小苏:点赞👍🏼、…

小狐狸JSON-RPC:wallet_watchAsset(向钱包中新增资产代币)

wallet_watchAsset 请求用户在 MetaMask 中添加新的资产。返回一个布尔值,是否已成功添加。 var res await window.ethereum.request({ "method": "wallet_watchAsset","params": {"type": "ERC20","opti…

盘点库存怎么做账

库存的盘点是企业中非常重要的一步,也是仓管经常要做的工作,盘点通俗点说就是点一下实物与账面上的数据是否一至,来判断我们平时的货物管理是否与账面上的业务往来符合,盘点库存怎么做账? 按目前的情况来看&#xff0c…

【数据结构】Java中Map和Set详解(含二叉搜索树和哈希表)

目录 Map和Set详解 1.二叉搜索树 2.Map常见方法 3.Set常见方法 4.哈希表 Map和Set详解 Map:一种键值对结构,hashMap中键和值均可以为空,hashTable中则不可以存放null值 Set:一种集合,不能存放重复元素&#xff0c…

SpringBoot使用Jedis步骤

基础连接方式 引入依赖 <!-- Jedis --><dependency><groupId>redis.clients</groupId><artifactId>jedis</artifactId></dependency>创建Jedis对象&#xff0c;建立连接 操作字符串 方法名与Rdeis命令一致 操作Hash类型 释放资源 测…

【小米SU7实测发布】Python与人工智能的结合

小米在2023年底正式发布小米SU7,成为继华为之后第二个推出成品的的科技企业。不过此时小米需要做的不仅是打造一款产品力够高的车型,更是要以后发者的身份更快速地追上头部智驾车企。从昨天的发布会中可以发现,小米SU7采用双Orin-X芯片以及27个感知硬件组合,这套硬件组合在…

FFmpeg拉取RTSP流并定时生成10秒短视频

生成效果: 视频时长为10秒 生成格式为FLV 输出日志: 完整实现代码如下: 需要在Mac和终端先安装FFmpeg brew install ffmpeg CMake文件配置: cmake_minimum_required(VERSION 3.27) project(ffmpeg_open_stream) set(CMAKE_CXX_STANDARD 17)#头文件包目录 include_director…

ETL工具-nifi干货系列 第五讲 处理器GenerateFlowFile

1、今天我们一起来学习处理器GenerateFlowFile。这个处理器创建带有随机数据或自定义内容的 FlowFiles。GenerateFlowFile 对于负载测试、配置和模拟非常有用。从工具栏拖动处理器到画布&#xff0c;然后选择GenerateFlowFile即可。 2、点击add按钮或者双击 GenerateFlowFile可…

【蓝桥杯省赛真题34】python积木搭建 中小学青少年组蓝桥杯比赛 算法思维python编程省赛真题解析

python积木搭建 第十三届蓝桥杯青少年组python比赛省赛真题 一、题目要求 &#xff08;注&#xff1a;input&#xff08;&#xff09;输入函数的括号中不允许添加任何信息&#xff09; 1、编程实现 小蓝和小青在玩积木搭建游戏&#xff0c;具体玩法如下: 小蓝报一个数字N&…

vue多语言包i18n

1.安装 如果是vue2直接安装8.2.1版本&#xff0c;否则会出现版本不匹配的错误 npm install vue-i18n8.2.1 --save2.文件编辑 在src目录下创建文件 en.js export const h {system: "Background management system",loginOut:"LoginOut",LayoutSet:Layout …

用搜索引擎收集信息-常用方式

1&#xff0c;site csdn.net &#xff08;下图表示只在csdn网站里搜索java&#xff09; 2&#xff0c;filetype:pdf &#xff08;表示只检索某pdf文件类型&#xff09; 表示在浏览器里面查找有关java的pdf文件 3&#xff0c;intitle:花花 &#xff08;表示搜索网页标题里面有花…