第R8周:RNN实现阿尔兹海默病诊断

数据集包含2149名患者的广泛健康信息,每名缓则的ID范围从4751到6900不等,该数据集包含人口统计详细信息,生活方式因素、病史、临床测量、认知和功能评估、症状以及阿尔兹海默症的诊断。

一、准备工作

1、硬件准备

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F# 设置GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

在这里插入图片描述

2、导入数据

df = pd.read_csv('./alzheimers_disease_data.csv')
# 删除最后一列和第一列
df = df.iloc[:, 1:-1]
df.head()

在这里插入图片描述

二、构建数据集

1、标准化

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScalerX = df.iloc[:, :-1]
y = df.iloc[:, -1]# 将每一列特征标准化为标准正态分布,注意,标准化是针对每一列而言的
scaler = StandardScaler()
X = scaler.fit_transform(X)

2、划分数据集

X = torch.tensor(np.array(X), dtype=torch.float32)
y = torch.tensor(np.array(y), dtype=torch.int64)X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=1)X_train.shape, y_train.shape

3、构建数据加载器

from torch.utils.data import TensorDataset, DataLoadertrain_dl = DataLoader(TensorDataset(X_train, y_train), batch_size=32, shuffle=False)
test_dl = DataLoader(TensorDataset(X_test, y_test), batch_size=32, shuffle=False)

三、模型训练

在这里插入图片描述

1、构建模型

class model_rnn(nn.Module):def __init__(self):super(model_rnn, self).__init__()self.rnn0 = nn.RNN(input_size=32, hidden_size=200, num_layers=1, batch_first=True)self.fc0 = nn.Linear(200, 50)self.fc1 = nn.Linear(50, 2)def forward(self, x):out, hidden1 = self.rnn0(x)out          = self.fc0(out)out            = self.fc1(out)return outmodel = model_rnn().to(device)
model

在这里插入图片描述

2、定义训练函数

# 训练循环
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 训练集的大小num_batches = len(dataloader)   # 批次数目, (size/batch_size,向上取整)train_loss, train_acc = 0, 0  # 初始化训练损失和正确率for X, y in dataloader:  # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X)          # 网络输出loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad()  # grad属性归零loss.backward()        # 反向传播optimizer.step()       # 每一步自动更新# 记录acc与losstrain_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc  /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

3、测试函数

def test (dataloader, model, loss_fn):size        = len(dataloader.dataset)  # 测试集的大小num_batches = len(dataloader)          # 批次数目, (size/batch_size,向上取整)test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss        = loss_fn(target_pred, target)test_loss += loss.item()test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc  /= sizetest_loss /= num_batchesreturn test_acc, test_loss

4、正式训练

loss_fn    = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 5e-5
opt = torch.optim.Adam(model.parameters(), lr= learn_rate)epochs     = 50train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = opt.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss,epoch_test_acc*100, epoch_test_loss, lr))print('Done')

在这里插入图片描述

四、模型评估

1.Loss与Accuracy图

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率from datetime import datetime
current_time = datetime.now()epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time)plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

2、混沌矩阵

print("====================输入数据Shape为====================")
print("X_test.shape: ",X_test.shape)
print("y_test.shape: ",y_test.shape)pred = model(X_test.to(device)).argmax(1).cpu().numpy()
print("====================输出数据Shape为====================")
print("pred.shape: ",pred.shape)

在这里插入图片描述

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay# 计算混淆矩阵
cm = confusion_matrix(y_test, pred)plt.figure(figsize=(6,5))
# plt.suptitle('Confusion Matrix')
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')# 修改字体大小
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.title('Confusion Matrix', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Labels', fontsize=10)# 调整布局防止重叠
plt.tight_layout()# 显示图形
plt.show()

在这里插入图片描述

五、预测

在这里插入图片描述

六、总结

当然,在学习完RNN及其演进模型(如LSTM、GRU)后,对“如何处理序列数据”进行总结是非常有价值的。这能帮你建立起一个清晰的知识框架。

以下是一个系统性的总结,涵盖了从核心思想、关键挑战到解决方案和现代最佳实践。

处理序列数据的核心思想与总结

  1. 核心目标:处理带有“顺序依赖”的数据
    序列数据的根本特征是​​数据点之间的顺序关系蕴含重要信息​​。例如,一句话中单词的顺序、一段音乐中音符的先后、股票价格随时间的变化等。模型的目标是学习这种顺序依赖关系,并做出预测、分类或生成。
  2. 基础架构:循环神经网络 (RNN)
    RNN提供了处理序列数据的基本范式:
    ​​核心机制​​: ​​循环连接​​。网络为每个时间步的输入进行处理,并将一个“隐藏状态(Hidden State)”传递给下一个时间步。这个隐藏状态作为“记忆”,承载了之前所有时间步的摘要信息。

h_t = f(W * h_{t-1} + U * x_t + b)

​​优势​​: 参数共享(所有时间步共用同一组参数),理论上可以处理任意长度的序列。
​​典型结构​​:
​​一对一​​: 单个输入 -> 单个输出(例如,图像分类)
​​一对多​​: 单个输入 -> 序列输出(例如,图像字幕生成)
​​多对一​​: 序列输入 -> 单个输出(例如,情感分析)
​​多对多​​: 序列输入 -> 序列输出(例如,机器翻译、股票预测)
3. 核心挑战与致命缺陷:梯度消失/爆炸
​​问题​​: 当序列很长时,RNN在反向传播(BPTT)过程中,梯度需要连续乘以相同的权重矩阵,导致梯度呈指数级缩小(消失)或增大(爆炸)。
​​后果​​: ​​模型无法学习长期依赖关系​​。它变得“健忘”,只能记住近期信息,而难以利用序列早期的重要信息。这严重限制了基础RNN在长序列任务上的应用。
4. 解决方案:门控机制 (Gating Mechanism)
为了解决长期依赖问题,引入了更为强大的循环单元,其核心思想是使用“门”来精确控制信息的流动。

​​LSTM (长短期记忆网络)​​:
​​核心​​: 引入了​​细胞状态(Cell State)​​ 作为“信息高速公路”和三个门(​​输入门、遗忘门、输出门​​)。
​​工作方式​​: 门(Sigmoid函数)决定让多少信息通过(0~1)。遗忘门决定从细胞状态中丢弃什么信息;输入门决定添加什么新信息。这使得LSTM可以长期保存和传递关键信息。
​​GRU (门控循环单元)​​:
​​核心​​: LSTM的简化版,将LSTM的三个门合并为两个:​​更新门​​和​​重置门​​。
​​特点​​: 参数更少,训练速度更快,但在大多数任务上的效果与LSTM相当。它成为了一个非常流行且高效的默认选择。
​​小结:RNN -> LSTM/GRU 的演进,是为了解决基础RNN的“短期记忆”问题,其核心技术创新是“门控”。​​
5. 更深与更广:架构的扩展
​​深度RNN​​: 将多个RNN层堆叠起来,底层处理低级特征(如字符、音素),高层处理高级特征(如语义、意图),以增强模型的表达能力。
​​双向RNN (Bi-RNN/Bi-LSTM/Bi-GRU)​​: 同时运行两个独立的RNN,一个从序列开头到结尾(正向),一个从结尾到开头(反向),然后将它们的输出合并。
​​优势​​: 对于任何一个时间点,模型都拥有​​完整的上下文信息​​(过去和未来)。这在自然语言处理(如阅读理解、命名实体识别)中极其重要。
现代范式:注意力机制与Transformer
尽管LSTM/GRU解决了长期依赖,但仍存在序列计算无法并行、信息压缩丢失等问题。这催生了更革命的架构:
​​注意力机制 (Attention Mechanism)​​:
​​核心思想​​: 允许模型在生成输出时,直接“关注”并加权输入序列中的任何部分,而不是强制将所有信息压缩到最后一個隐藏状态。
​​优势​​: 极大地改善了长序列性能,提供了更好的可解释性(可以看到模型在关注哪里)。
​​Transformer​​:
​​核心​​: ​​完全基于自注意力机制(Self-Attention)​​,彻底抛弃了循环结构。
​​优势​​:
1.​​极高的并行化能力​​: 训练速度远超RNN/LSTM。
2.​​全局建模能力​​: 一步计算即可捕捉序列中任意两个元素之间的关系,无论距离多远。
​​影响​​: Transformer及其衍生模型(如BERT, GPT)已成为当前NLP乃至跨模态领域(视觉、音频)的绝对主流架构。

在学完循环神经网络RNN后,同时学习完门控机制后,GRU的优势(相对于RNN)​​解决了核心缺陷​​,极大缓解了梯度消失问题,具有强大的长序列建模能力。收敛更快​​: 训练过程更稳定,收敛速度通常更快。
​​性能卓越​​: 在绝大多数任务上的性能远超传统RNN。

​​GRU的劣势(相对于RNN):​​
结构更复杂​​,参数更多,计算量稍大。

​​过拟合风险​​稍高。那么,GRU和LSTM又如何选择呢?(你可能会问的下一个问题)​​

GRU通常被认为是LSTM的一个更轻量、更快的替代品。它们的性能在大多数任务上​​非常接近​​,没有绝对的赢家。
​​优先选择GRU​​:当计算资源受限、训练速度是关键因素,或者数据集较小时,GRU是一个很好的选择,因为它用更少的参数达到了与LSTM相似的性能。

​​优先选择LSTM​​:在一些非常长和复杂的序列任务上(如语音识别、音乐生成),LSTM凭借其更精细的门控控制(三个独立的门),可能拥有微弱的优势,但这并非绝对。

​​最佳实践是:​​ 在你的特定数据集上同时试验GRU和LSTM,选择表现更好的那个。
​​GRU是对传统RNN的一次重大升级​​。它通过巧妙的门控设计,以可接受的计算成本为代价,成功解决了RNN的核心痛点,使其成为处理序列数据的强大而高效的模型。在学习上,从RNN到GRU/LSTM的演进,是理解如何通过设计更复杂的细胞结构来优化梯度流和信息保存的关键一步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/95801.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/95801.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL复制技术的发展历程

在互联网应用不断发展的二十多年里,MySQL 一直是最广泛使用的开源关系型数据库之一。它凭借开源、轻量、灵活的优势,支撑了无数网站、移动应用和企业系统。支撑 MySQL 长期发展的关键之一,就是 复制(Replication)技术。…

C++从字符串中移除前导零

该程序用于去除字符串开头的零字符。当输入"0000123456"时,程序会输出"123456"。核心函数removeZero()通过while循环找到第一个非零字符的位置,然后使用erase()方法删除前面的所有零。主函数读取输入字符串并调用该函数处理。程序简…

【面试题】C++系列(一)

本专栏文章持续更新,新增内容使用蓝色表示。C面向对象的三大特性:封装,继承,多态(1)封装是将数据和函数组合到一个类里。主要目的是隐藏内部的实现细节,仅暴露必要的接口给外部。通过封装&#…

当没办法实现从win复制东西到Linux虚拟机时的解决办法

① 先确认是否已安装bash复制sudo apt list --installed | grep open-vm-tools如果 没有任何回显 → 没装,跳到 ③如果看到 open-vm-tools 已安装 → 继续 ②② 启动正确的服务(单词别打错)bash复制systemctl status vmtoolsd # 查看…

用Markdown写自动化用例:Gauge实战全攻略!

你作为一名自动化测试工程师,正在为一个复杂的Web应用编写测试脚本:传统工具要求写大量代码,维护起来像解谜游戏,团队非技术成员完全插不上手。这时,Gauge这个“自动化神器”如魔法般出现——它允许用Markdown写可读的…

Unity开发保姆级教程:C#脚本+物理系统+UI交互,3大模块带你通关游戏开发

文章目录基础概念Unity开发环境搭建版本选择:为什么2021 LTS是最佳起点?三步安装:从下载到项目创建界面认知:5分钟掌握核心操作区配置优化:让开发更顺畅验证环境:创建你的第一个CubeC#基础语法与Unity脚本结…

Depth Anything V2论文速读

这篇论文主要讲了两方面1.为了解决模型在正常标注的现实图像上训练的缺陷问题、提出了新的模型训练数据和训练方法真实标记图像存在缺点:标签噪声(深度传感器可能存在空洞、玻璃等物体反射导致精度不准确)、标签细节粗糙(深度图边…

数据库原理及应用_数据库管理和保护_第5章数据库的安全性_理论部分

前言 "<数据库原理及应用>(MySQL版)".以下称为"本书"中第5章前6节内容 引入 数据库的安全性是非常重要的,表现在两个方面:一数据的访问权限,二数据的物理安全.本书在这一章前6节基本上都是理论性的内容,选择其中重要部分进行解读. 5.1数据库安全性…

QT6 配置 Copilot插件

下载项目&#xff1a;解压 GitHub - github/copilot.vim: Neovim plugin for GitHub Copilot Node.js必须安装 Node.js — Download Node.js 例如先安装一个qt6 ,qt Cteatror选择新版本的 设置 效果&#xff0c;注释里面写要求&#xff0c;tab同意 #include "mainwindow…

ArcGIS学习-15 实战-建设用地适宜性评价

选定参评因子 高程坡度河流道路土地利用 确定因子分析标准 以下仅参数仅做展示&#xff0c;并非合理的数值 高程 0-100m&#xff1a;100 分&#xff0c;此高程范围通常地势较为平坦&#xff0c;建设成本相对较低&#xff0c;适宜建设。100-200m&#xff1a;70 分&#xff…

[C/C++学习] 7.“旋转蛇“视觉图形生成

参考文献: 童晶. C和C游戏趣味编程[M].人民邮电出版社.2021. 一.弧度制和角度制的转换 弧度制数值和角度对应表: (PI为圆周率&#xff0c;值为3.1415926)弧度制角度制00PI/630PI/360PI/2902*PI/3120PI1802*PI360二.扇形的绘制 easyx的solidpie( )函数用于在一个矩形区域内绘制…

自然语言处理之PyTorch实现词袋CBOW模型

在自然语言处理&#xff08;NLP&#xff09;领域&#xff0c;词向量&#xff08;Word Embedding&#xff09;是将文本转换为数值向量的核心技术。它能让计算机“理解”词语的语义关联&#xff0c;例如“国王”和“女王”的向量差可能与“男人”和“女人”的向量差相似。而Word2…

TCP, 三次握手, 四次挥手, 滑动窗口, 快速重传, 拥塞控制, 半连接队列, RST, SYN, ACK

目录 TCP 是什么&#xff1a;面向连接 可靠 字节流三次握手&#xff1a;为什么不是两次四次挥手与 TIME_WAIT&#xff1a;谁等谁序列号/确认号与去重、排序、确认重传机制&#xff1a;超时重传与快速重传滑动窗口与流量控制拥塞控制&#xff1a;慢启动/拥塞避免/快重传/快恢…

CentOS 7.2 虚机 ssh 登录报错在重启后无法进入系统

文章目录前言1. 故障描述2. 故障诊断3. 故障原因4. 解决方案总结前言 上周帮用户处理了一个 linux 虚拟机在重启后无法正常进入操作系统的故障&#xff0c;觉得比较有意思&#xff0c;在这里分享给大家。 1. 故障描述 事情的起因是一台系统版本为 CentOS 7.2 的 VMware 虚拟机…

《从使用到源码:OkHttp3责任链模式剖析》

一 从使用开始0.依赖引入implementation ("com.squareup.okhttp3:okhttp:3.14.7")1.创建OkHttpClient实例方式一&#xff1a;直接使用默认配置的Builder//从源码可以看出&#xff0c;当我们直接new创建OkHttpClient实例时&#xff0c;会默认给我们配置好一个Builder …

安装3DS MAX 2026后,无法运行,提示缺少.net core的解决方案

今天安装了3DS MAX 2026&#xff08;俗称3DMAX&#xff09;&#xff0c;安装完毕后死活运行不了。提示如下&#xff1a; 大意是找不到所需的.NET Core 8库文件。后来搜索了下&#xff0c;各种文章说.NET CORE和.NET FRAMEWORK不是一个东西。需要单独下载安装。然后根据提示&…

FastAPI + LangChain 和 Spring AI + LangChain4j

FastAPI+LangChain和Spring AI+LangChain4j这两个技术组合进行详细对比。 核心区别: 特性维度 FastAPI + LangChain (Python栈) Spring AI + LangChain4j (Java栈) 技术栈 Python生态 (FastAPI, LangChain) Java生态 (Spring Boot, Spring AI, LangChain4j) 核心设计哲学 灵活…

Apache 2.0 开源协议详解:自由、责任与商业化的完美平衡-优雅草卓伊凡

Apache 2.0 开源协议详解&#xff1a;自由、责任与商业化的完美平衡-优雅草卓伊凡引言由于我们优雅草要推出收银系统&#xff0c;因此要采用开源代码&#xff0c;卓伊凡目前看好了一个产品是apache 2.0协议&#xff0c;因此我们有必要深刻理解apache 2.0协议避免触犯版权问题。…

自学嵌入式第37天:MQTT协议

一、MQTT&#xff08;消息队列遥测传输协议Message Queuing Telemetry Transport&#xff09;1.MQTT是应用层的协议&#xff0c;是一种基于发布/订阅模式的“轻量级”通讯协议&#xff0c;建构于TCP/IP协议上&#xff0c;可以以极少的代码和有限的带宽为连接远程设备提供实时可…

RabbitMQ--延时队列总结

一、延迟队列概念 延迟队列&#xff08;Delay Queue&#xff09;是一种特殊类型的队列&#xff0c;队列中的元素需要在指定的时间点被取出和处理。简单来说&#xff0c;延时队列就是存放需要在某个特定时间被处理的消息。它的核心特性在于“延迟”——消息在队列中停留一段时间…