RNN实现阿尔茨海默症的诊断识别

本文为为🔗365天深度学习训练营内部文章

原作者:K同学啊

 一 导入数据

import torch.nn as nn
import torch.nn.functional as F
import torchvision,torch
from sklearn.preprocessing import StandardScaler
from torch.utils.data import TensorDataset,DataLoader
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei']
import warnings
warnings.filterwarnings('ignore')# 设置硬件设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")df = pd.read_excel('dia.xls')
df

二 数据处理分析

# 删除第一列和最后一列
df = df.iloc[:,1:-1]
print(df)
 Age  Gender  Ethnicity  EducationLevel        BMI  Smoking   
0      73       0          0               2  22.927749        0  \
1      89       0          0               0  26.827681        0   
2      73       0          3               1  17.795882        0   
3      74       1          0               1  33.800817        1   
4      89       0          0               0  20.716974        0   
...   ...     ...        ...             ...        ...      ...   
2144   61       0          0               1  39.121757        0   
2145   75       0          0               2  17.857903        0   
2146   77       0          0               1  15.476479        0   
2147   78       1          3               1  15.299911        0   
2148   72       0          0               2  33.289738        0   AlcoholConsumption  PhysicalActivity  DietQuality  SleepQuality  ...   
0              13.297218          6.327112     1.347214      9.025679  ...  \
1               4.542524          7.619885     0.518767      7.151293  ...   
2              19.555085          7.844988     1.826335      9.673574  ...   
3              12.209266          8.428001     7.435604      8.392554  ...   
4              18.454356          6.310461     0.795498      5.597238  ...   
...                  ...               ...          ...           ...  ...   
2144            1.561126          4.049964     6.555306      7.535540  ...   
2145           18.767261          1.360667     2.904662      8.555256  ...   
2146            4.594670          9.886002     8.120025      5.769464  ...   
2147            8.674505          6.354282     1.263427      8.322874  ...   
2148            7.890703          6.570993     7.941404      9.878711  ...   FunctionalAssessment  MemoryComplaints  BehavioralProblems       ADL   
0                 6.518877                 0                   0  1.725883  \
1                 7.118696                 0                   0  2.592424   
2                 5.895077                 0                   0  7.119548   
3                 8.965106                 0                   1  6.481226   
4                 6.045039                 0                   0  0.014691   
...                    ...               ...                 ...       ...   
2144              0.238667                 0                   0  4.492838   
2145              8.687480                 0                   1  9.204952   
2146              1.972137                 0                   0  5.036334   
2147              5.173891                 0                   0  3.785399   
2148              6.307543                 0                   1  8.327563   Confusion  Disorientation  PersonalityChanges   
0             0               0                   0  \
1             0               0                   0   
2             0               1                   0   
3             0               0                   0   
4             0               0                   1   
...         ...             ...                 ...   
2144          1               0                   0   
2145          0               0                   0   
2146          0               0                   0   
2147          0               0                   0   
2148          0               1                   0   DifficultyCompletingTasks  Forgetfulness  Diagnosis  
0                             1              0          0  
1                             0              1          0  
2                             1              0          0  
3                             0              0          0  
4                             1              0          0  
...                         ...            ...        ...  
2144                          0              0          1  
2145                          0              0          1  
2146                          0              0          1  
2147                          0              1          1  
2148                          0              1          0  [2149 rows x 33 columns]

三 探索性数据分析 

1.得病分布

res = df.groupby('Diabetes')['Age'].count()
print(res)plt.figure(figsize=(8, 6))
plt.pie(res.values, labels=res.index, autopct='%1.1f%%', startangle=90,colors=['#ff9999','#66b3ff','#99ff99'], explode=(0.1,  0),wedgeprops={'edgecolor': 'black', 'linewidth': 1, 'linestyle': 'solid'})
plt.title('是否得阿尔茨海默症', fontsize=16, fontweight='bold')
plt.show()

 2.BMI分布直方图

# BMI分布直方图
sns.displot(df['BMI'], kde=True, color='skyblue', bins=30, height=6, aspect=1.2)
plt.title('BMI Distribution', fontsize=18, fontweight='bold', color='darkblue')
plt.xlabel('BMI', fontsize=14, color='darkgreen')
plt.ylabel('Frequency', fontsize=14, color='darkgreen')
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()

 

3.年龄分布直方图 

# Age分布直方图
sns.displot(df['Age'], kde=True, color='skyblue', bins=30, height=6, aspect=1.2)
plt.title('Age Distribution', fontsize=18, fontweight='bold', color='darkblue')
plt.xlabel('Age', fontsize=14, color='darkgreen')
plt.ylabel('Frequency', fontsize=14, color='darkgreen')
plt.grid(True, linestyle='--', alpha=0.7)
plt.show()

 

 

四 构建划分数据集 

X = df.iloc[:,:-1]
y = df.iloc[:,-1]sc = StandardScaler()
X = sc.fit_transform(X)# 划分数据集
X = torch.tensor(np.array(X),dtype=torch.float32)
y = torch.tensor(np.array(y),dtype=torch.int64)X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.1,random_state=1)# 构建数据加载器
train_dl = DataLoader(TensorDataset(X_train,y_train),batch_size=64,shuffle=False)
test_dl = DataLoader(TensorDataset(X_test,y_test),batch_size=64,shuffle=False)

 五 训练模型

1.构建模型

# 构建模型
class model_rnn(nn.Module):def __init__(self):super(model_rnn, self).__init__()self.rnn0 = nn.RNN(input_size=32,hidden_size=200,num_layers=1,batch_first=True)self.fc0 = nn.Linear(200,50)self.fc1 = nn.Linear(50,2)def forward(self,x):out , hidden1 = self.rnn0(x)out = self.fc0(out)out = self.fc1(out)return outmodel = model_rnn().to(device)
print(model)
model_rnn((rnn0): RNN(32, 200, batch_first=True)(fc0): Linear(in_features=200, out_features=50, bias=True)(fc1): Linear(in_features=50, out_features=2, bias=True)
)

2.训练函数 

'''
训练模型
'''
# 训练循环
def train(dataloader,model,loss_fn,optimizer):size = len(dataloader.dataset)   # 训练集的大小num_batches = len(dataloader)      # 批次数目,(size/batchsize,向上取整)train_acc,train_loss = 0,0  # 初始化训练损失和正确率for x,y in dataloader:    # 获取数据X,y = x.to(device),y.to(device)# 计算预测误差pred = model(X)   # 网络输出loss = loss_fn(pred,y)   # 计算误差# 反向传播optimizer.zero_grad()    # grad属性归零loss.backward()   # 反向传播optimizer.step()   # 每一步自动更新# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc,train_loss

3.测试函数 

# 测试循环
def valid(dataloader,model,loss_fn):size = len(dataloader.dataset)  # 训练集的大小num_batches = len(dataloader)  # 批次数目,(size/batchsize,向上取整)test_loss, test_acc = 0, 0  # 初始化训练损失和正确率# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs,target in dataloader:imgs,target = imgs.to(device),target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred,target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc,test_loss

4.正式训练 

loss_fn = nn.CrossEntropyLoss()   # 创建损失函数
learn_rate = 1e-4   # 学习率
opt = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs = 30train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):model.train()epoch_train_acc,epoch_train_loss = train(train_dl,model,loss_fn,opt)model.eval()epoch_test_acc,epoch_test_loss = valid(test_dl,model,loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = opt.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f},lr:{:.2E}')print(template.format(epoch+1,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss,lr))print("="*20,'Done',"="*20)
Epoch: 1,Train_acc:52.9%,Train_loss:0.688,Test_acc:67.0%,Test_loss:0.658,lr:1.00E-04
Epoch: 2,Train_acc:68.7%,Train_loss:0.612,Test_acc:67.4%,Test_loss:0.600,lr:1.00E-04
Epoch: 3,Train_acc:68.7%,Train_loss:0.566,Test_acc:70.7%,Test_loss:0.567,lr:1.00E-04
Epoch: 4,Train_acc:74.4%,Train_loss:0.526,Test_acc:72.6%,Test_loss:0.533,lr:1.00E-04
Epoch: 5,Train_acc:77.9%,Train_loss:0.487,Test_acc:78.1%,Test_loss:0.501,lr:1.00E-04
Epoch: 6,Train_acc:81.1%,Train_loss:0.451,Test_acc:79.5%,Test_loss:0.473,lr:1.00E-04
Epoch: 7,Train_acc:82.3%,Train_loss:0.421,Test_acc:80.0%,Test_loss:0.451,lr:1.00E-04
Epoch: 8,Train_acc:83.4%,Train_loss:0.397,Test_acc:78.6%,Test_loss:0.434,lr:1.00E-04
Epoch: 9,Train_acc:84.7%,Train_loss:0.378,Test_acc:80.0%,Test_loss:0.422,lr:1.00E-04
Epoch:10,Train_acc:85.2%,Train_loss:0.365,Test_acc:80.0%,Test_loss:0.414,lr:1.00E-04
Epoch:11,Train_acc:85.6%,Train_loss:0.354,Test_acc:80.0%,Test_loss:0.408,lr:1.00E-04
Epoch:12,Train_acc:85.9%,Train_loss:0.347,Test_acc:80.0%,Test_loss:0.405,lr:1.00E-04
Epoch:13,Train_acc:86.3%,Train_loss:0.341,Test_acc:78.6%,Test_loss:0.403,lr:1.00E-04
Epoch:14,Train_acc:87.0%,Train_loss:0.335,Test_acc:78.1%,Test_loss:0.403,lr:1.00E-04
Epoch:15,Train_acc:87.1%,Train_loss:0.331,Test_acc:78.6%,Test_loss:0.404,lr:1.00E-04
Epoch:16,Train_acc:87.1%,Train_loss:0.327,Test_acc:78.1%,Test_loss:0.405,lr:1.00E-04
Epoch:17,Train_acc:87.1%,Train_loss:0.324,Test_acc:78.6%,Test_loss:0.407,lr:1.00E-04
Epoch:18,Train_acc:87.3%,Train_loss:0.321,Test_acc:78.6%,Test_loss:0.409,lr:1.00E-04
Epoch:19,Train_acc:87.4%,Train_loss:0.318,Test_acc:77.7%,Test_loss:0.412,lr:1.00E-04
Epoch:20,Train_acc:87.7%,Train_loss:0.315,Test_acc:78.1%,Test_loss:0.415,lr:1.00E-04
Epoch:21,Train_acc:87.8%,Train_loss:0.312,Test_acc:77.7%,Test_loss:0.418,lr:1.00E-04
Epoch:22,Train_acc:88.1%,Train_loss:0.309,Test_acc:78.1%,Test_loss:0.422,lr:1.00E-04
Epoch:23,Train_acc:88.6%,Train_loss:0.306,Test_acc:78.1%,Test_loss:0.425,lr:1.00E-04
Epoch:24,Train_acc:88.6%,Train_loss:0.303,Test_acc:79.1%,Test_loss:0.429,lr:1.00E-04
Epoch:25,Train_acc:88.6%,Train_loss:0.301,Test_acc:79.5%,Test_loss:0.433,lr:1.00E-04
Epoch:26,Train_acc:88.6%,Train_loss:0.298,Test_acc:79.5%,Test_loss:0.437,lr:1.00E-04
Epoch:27,Train_acc:88.8%,Train_loss:0.295,Test_acc:80.0%,Test_loss:0.440,lr:1.00E-04
Epoch:28,Train_acc:89.1%,Train_loss:0.292,Test_acc:79.5%,Test_loss:0.444,lr:1.00E-04
Epoch:29,Train_acc:89.1%,Train_loss:0.290,Test_acc:79.1%,Test_loss:0.449,lr:1.00E-04
Epoch:30,Train_acc:89.2%,Train_loss:0.287,Test_acc:79.1%,Test_loss:0.453,lr:1.00E-04
==================== Done ====================

六 结果可视化 

1.Loss和Acc图

epochs_range = range(30)
plt.figure(figsize=(14,4))
plt.subplot(1,2,1)
plt.plot(epochs_range,train_acc,label='training accuracy')
plt.plot(epochs_range,test_acc,label='validation accuracy')
plt.legend(loc='lower right')
plt.title('training and validation accuracy')plt.subplot(1,2,2)
plt.plot(epochs_range,train_loss,label='training loss')
plt.plot(epochs_range,test_loss,label='validation loss')
plt.legend(loc='upper right')
plt.title('training and validation loss')
plt.show()

 2.调用模型预测

test_X = X_test[0].reshape(1,-1)
pred = model(test_X.to(device)).argmax(1).item()
print('模型预测结果:',pred)
print('=='*20)
print('0:未患病')
print('1:已患病')
模型预测结果: 0
========================================
0:未患病
1:已患病

3.绘制混淆矩阵 

'''
绘制混淆矩阵
'''
print('=============输入数据shape为==============')
print('X_test.shape:',X_test.shape)
print('y_test.shape:',y_test.shape)pred = model(X_test.to(device)).argmax(1).cpu().numpy()print('\n==========输出数据shape为==============')
print('pred.shape:',pred.shape)from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay# 计算混淆矩阵
cm = confusion_matrix(y_test,pred)plt.figure(figsize=(6,5))
plt.suptitle('')
sns.heatmap(cm,annot=True,fmt='d',cmap='Blues')
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title('Confusion Matrix',fontsize=12)
plt.xlabel('Pred Label',fontsize=10)
plt.ylabel('True Label',fontsize=10)
plt.tight_layout()
plt.show()
=============输入数据shape为==============
X_test.shape: torch.Size([215, 32])
y_test.shape: torch.Size([215])==========输出数据shape为==============
pred.shape: (215,)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/68574.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【新春特辑】2025年春节技术展望:蛇年里的科技创新与趋势预测

🔥【新春特辑】2025年春节技术展望:蛇年里的科技创新与趋势预测 📅 发布日期:2025年01月29日(大年初一) 在这个辞旧迎新的美好时刻,我们迎来了充满希望的2025年,也是十二生肖中的蛇…

使用 Docker + Nginx + Certbot 实现自动化管理 SSL 证书

使用 Docker Nginx Certbot 实现自动化管理 SSL 证书 在互联网安全环境日益重要的今天,为站点或应用部署 HTTPS 已经成为一种常态。然而,手动申请并续期证书既繁琐又容易出错。本文将以 Nginx Certbot 为示例,基于 Docker 容器来搭建一个…

C++11新特性之使用using(代替typedef)定义别名

1.介绍 传统的C使用typedef重定义一个类型存在一些限制&#xff0c;例如无法直接重定义一个模版。如下所示。 template <typename Val> struct str_map {typedef std::map<std::string, Val> type; };str_map<int>::type map1; 需要添加额外的类来实现&…

编程题-最长的回文子串(中等)

题目&#xff1a; 给你一个字符串 s&#xff0c;找到 s 中最长的回文子串。 示例 1&#xff1a; 输入&#xff1a;s "babad" 输出&#xff1a;"bab" 解释&#xff1a;"aba" 同样是符合题意的答案。示例 2&#xff1a; 输入&#xff1a;s &…

maven、npm、pip、yum官方镜像修改文档

文章目录 Maven阿里云网易华为腾讯云 Npm淘宝腾讯云 pip清华源阿里中科大华科 Yum 由于各博客繁杂&#xff0c;本文旨在记录各常见镜像官网&#xff0c;及其配置文档。常用镜像及配置可评论后加入 Maven 阿里云 官方文档 setting.xml <mirror><id>aliyunmaven&l…

CNN-GRU卷积门控循环单元时间序列预测(Matlab完整源码和数据)

CNN-GRU卷积门控循环单元时间序列预测&#xff08;Matlab完整源码和数据&#xff09; 目录 CNN-GRU卷积门控循环单元时间序列预测&#xff08;Matlab完整源码和数据&#xff09;预测效果基本介绍CNN-GRU卷积门控循环单元时间序列预测一、引言1.1、研究背景与意义1.2、研究现状1…

HTML-新浪新闻-实现标题-样式1

用css进行样式控制 css引入方式&#xff1a; --行内样式&#xff1a;写在标签的style属性中&#xff08;不推荐&#xff09; --内嵌样式&#xff1a;写在style标签中&#xff08;可以写在页面任何位置&#xff0c;但通常约定写在head标签中&#xff09; --外联样式&#xf…

搜索引擎友好:设计快速收录的网站架构

本文来自&#xff1a;百万收录网 原文链接&#xff1a;https://www.baiwanshoulu.com/14.html 为了设计一个搜索引擎友好的网站架构&#xff0c;以实现快速收录&#xff0c;可以从以下几个方面入手&#xff1a; 一、清晰的目录结构与层级 合理划分内容&#xff1a;目录结构应…

CF1098F Ж-function

【题意】 给你一个字符串 s s s&#xff0c;每次询问给你 l , r l, r l,r&#xff0c;让你输出 s s s l , r sss_{l,r} sssl,r​中 ∑ i 1 r − l 1 L C P ( s s i , s s 1 ) \sum_{i1}^{r-l1}LCP(ss_i,ss_1) ∑i1r−l1​LCP(ssi​,ss1​)。 【思路】 和前一道题一样&#…

C++ 拷贝构造

拷贝构造函数会在以下几种场景中被调用: 1. 用一个对象显式初始化另一个对象。 2. 对象按值传递给函数。 3. 函数按值返回对象。 4. 将对象插入到容器中。 5. 明确调用拷贝构造函数。 1. 当用一个对象显式初始化另一个对象时 MyClass obj1("Hello"); MyClass obj2…

2024年终总结

回顾 今年过年没回老家&#xff0c;趁着有时间&#xff0c;总结一下24年吧。 我把23年看做是打基础的一年&#xff0c;而24年主要是忙于项目的一年&#xff0c;基本上大部分时间都是忙着交付软件&#xff0c;写的一些文章也大部分都是项目中遇到的问题和解决方案&#xff0c;虽…

《哈佛家训》

《哈佛家训》是一本以教育为主题的书籍&#xff0c;旨在通过一系列富有哲理的故事和案例&#xff0c;传递积极的人生观、价值观和教育理念。虽然它并非直接由哈佛大学官方出版&#xff0c;但其内容深受读者喜爱&#xff0c;尤其是在家庭教育和个人成长领域。 以下是《哈佛家训…

[c语言日寄]越界访问:意外的死循环

【作者主页】siy2333 【专栏介绍】⌈c语言日寄⌋&#xff1a;这是一个专注于C语言刷题的专栏&#xff0c;精选题目&#xff0c;搭配详细题解、拓展算法。从基础语法到复杂算法&#xff0c;题目涉及的知识点全面覆盖&#xff0c;助力你系统提升。无论你是初学者&#xff0c;还是…

使用 KNN 搜索和 CLIP 嵌入构建多模态图像检索系统

作者&#xff1a;来自 Elastic James Gallagher 了解如何使用 Roboflow Inference 和 Elasticsearch 构建强大的语义图像搜索引擎。 在本指南中&#xff0c;我们将介绍如何使用 Elasticsearch 中的 KNN 聚类和使用计算机视觉推理服务器 Roboflow Inference 计算的 CLIP 嵌入构建…

深入理解三高架构:高可用性、高性能、高扩展性的最佳实践

引言 在现代互联网环境下&#xff0c;随着用户规模和业务需求的快速增长&#xff0c;系统架构的设计变得尤为重要。为了确保系统能够在高负载和复杂场景下稳定运行&#xff0c;"三高架构"&#xff08;高可用性、高性能、高扩展性&#xff09;成为技术架构设计中的核…

Nginx 开发总结

文章目录 1. Nginx 基础概念1-1、什么是 Nginx1-2、Nginx 的工作原理1-3、Nginx 的核心特点1-4、Nginx 的常见应用场景1-5、Nginx 与 Apache 的区别1-6、 Nginx 配置的基本结构1-7、Nginx 常见指令 2. Nginx 配置基础2-1、Nginx 配置文件结构2-2、全局配置 (Global Block)2-3、…

Git 出现 Please use your personal access token instead of the password 解决方法

目录 前言1. 问题所示2. 原理分析3. 解决方法前言 1. 问题所示 执行Git提交代码的时候,出现如下所示: lixiaosong@IT07 MINGW64 /f/java_project/JavaDemo (master) $ git push -u origin --all libpng warning: iCCP: known incorrect sRGB profile libpng warning

maven的打包插件如何使用

默认的情况下&#xff0c;当直接执行maven项目的编译命令时&#xff0c;对于结果来说是不打第三方包的&#xff0c;只有一个单独的代码jar&#xff0c;想要打一个包含其他资源的完整包就需要用到maven编译插件&#xff0c;使用时分以下几种情况 第一种&#xff1a;当只是想单纯…

Golang Gin系列-7:认证和授权

在本章中&#xff0c;我们将探讨Gin框架中身份验证和授权的基本方面。这包括实现基本的和基于令牌的身份验证&#xff0c;使用基于角色的访问控制&#xff0c;应用中间件进行授权&#xff0c;以及使用HTTPS和漏洞防护保护应用程序。 实现身份认证 Basic 认证 Basic 认证是内置…

CTF-web: phar反序列化+数据库伪造 [DASCTF2024最后一战 strange_php]

step 1 如何触发反序列化? 漏洞入口在 welcome.php case delete: // 获取删除留言的路径&#xff0c;优先使用 POST 请求中的路径&#xff0c;否则使用会话中的路径 $message $_POST[message_path] ? $_POST[message_path] : $_SESSION[message_path]; $msg $userMes…