基于深度学习的胸部 X 光图像肺炎分类系统(三)

目录

二分类胸片判断:

1. 数据加载时指定了两类标签

2. 损失函数用了二分类专用的

3. 输出层只有 1 个神经元,用了sigmoid激活函数

4. 预测时用 0.5 作为分类阈值


二分类胸片判断:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, roc_auc_score, roc_curve, confusion_matrix, classification_report
from imblearn.over_sampling import RandomOverSampler
import tensorflow as tf
from keras import layers
from keras import models
# 或者更常用的是直接导入Sequential类
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator
import os
import zipfile
import requests
from tensorflow.python.keras.callbacks import EarlyStopping
#  这个代码执行 请切换环境到tf_env
plt.rcParams['font.sans-serif'] = ['SimHei']  # 使用 SimHei 字体
plt.rcParams['axes.unicode_minus'] = False    # 解决负号显示问题
plt.rcParams['font.size'] = 10  # 设置全局字体大小# 数据加载和预处理
def load_data(train_dir, test_dir, val_dir, img_size=(150, 150), batch_size=32):# 数据增强器 - 仅用于训练集train_datagen = ImageDataGenerator(rescale=1. / 255,rotation_range=10,width_shift_range=0.1,height_shift_range=0.1,shear_range=0.1,zoom_range=0.1,horizontal_flip=True)# 验证集和测试集只需要重新缩放val_test_datagen = ImageDataGenerator(rescale=1. / 255)# 加载训练数据train_generator = train_datagen.flow_from_directory(train_dir,target_size=img_size,batch_size=batch_size,class_mode='binary',classes=['NORMAL', 'PNEUMONIA'],shuffle=True)# 加载验证数据val_generator = val_test_datagen.flow_from_directory(val_dir,target_size=img_size,batch_size=batch_size,class_mode='binary',classes=['NORMAL', 'PNEUMONIA'],shuffle=False)# 加载测试数据test_generator = val_test_datagen.flow_from_directory(test_dir,target_size=img_size,batch_size=batch_size,class_mode='binary',classes=['NORMAL', 'PNEUMONIA'],shuffle=False)return train_generator, val_generator, test_generator# 处理样本不均衡(过采样)
def handle_imbalance(generator):# 提取特征和标签X, y = [], []num_batches = len(generator)# 重置生成器以确保从开始获取数据generator.reset()for i in range(num_batches):batch_x, batch_y = generator.next()X.append(batch_x)y.append(batch_y)X = np.concatenate(X)y = np.concatenate(y)# 打印原始分布print(f"原始样本分布: 正常={np.sum(y == 0)}, 肺炎={np.sum(y == 1)}")# 展平特征用于过采样X_flat = X.reshape(X.shape[0], -1)# 过采样少数类ros = RandomOverSampler(random_state=42)X_resampled, y_resampled = ros.fit_resample(X_flat, y)# 恢复图像形状X_resampled = X_resampled.reshape(-1, *X.shape[1:])print(f"过采样后分布: 正常={np.sum(y_resampled == 0)}, 肺炎={np.sum(y_resampled == 1)}")return X_resampled, y_resampled, y# 构建改进的CNN模型
def build_model(input_shape):model = models.Sequential([# 第一个卷积块layers.Conv2D(32, (3, 3), activation='relu', input_shape=input_shape),layers.BatchNormalization(),layers.MaxPooling2D((2, 2)),layers.Dropout(0.2),# 第二个卷积块layers.Conv2D(64, (3, 3), activation='relu'),layers.BatchNormalization(),layers.MaxPooling2D((2, 2)),layers.Dropout(0.3),# 第三个卷积块layers.Conv2D(128, (3, 3), activation='relu'),layers.BatchNormalization(),layers.MaxPooling2D((2, 2)),layers.Dropout(0.4),# 第四个卷积块layers.Conv2D(256, (3, 3), activation='relu'),layers.BatchNormalization(),layers.MaxPooling2D((2, 2)),layers.Dropout(0.5),# 分类器layers.Flatten(),layers.Dense(512, activation='relu'),layers.BatchNormalization(),layers.Dropout(0.5),layers.Dense(1, activation='sigmoid')])# 使用更稳定的优化器optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)model.compile(optimizer=optimizer,loss='binary_crossentropy',metrics=['accuracy',tf.keras.metrics.Precision(name='precision'),tf.keras.metrics.Recall(name='recall'),tf.keras.metrics.AUC(name='auc')])return model# 主函数
def main():# 假设数据集已经手动下载并解压train_dir = "chest_xray/train"test_dir = "chest_xray/test"val_dir = "chest_xray/val"# 加载数据img_size = (150, 150)batch_size = 32train_generator, val_generator, test_generator = load_data(train_dir, test_dir, val_dir, img_size, batch_size)# 处理样本不均衡X_train, y_train_resampled, y_train_original = handle_imbalance(train_generator)# 计算类别权重(基于原始分布)n_normal = np.sum(y_train_original == 0)n_pneumonia = np.sum(y_train_original == 1)total = n_normal + n_pneumoniaweight_for_normal = (1 / n_normal) * (total / 2.0)weight_for_pneumonia = (1 / n_pneumonia) * (total / 2.0)class_weights = {0: weight_for_normal, 1: weight_for_pneumonia}print(f"类别权重: 正常={weight_for_normal:.2f}, 肺炎={weight_for_pneumonia:.2f}")# 构建模型model = build_model((*img_size, 3))model.summary()# 提前停止回调early_stopping = EarlyStopping(monitor='val_loss',patience=5,restore_best_weights=True,verbose=1)# 训练模型history = model.fit(X_train, y_train_resampled,epochs=30,batch_size=32,validation_data=val_generator,class_weight=class_weights,callbacks=[early_stopping],verbose=1)# 评估模型 - 使用完整测试集test_generator.reset()test_steps = len(test_generator)test_results = model.evaluate(test_generator, steps=test_steps, verbose=1)print("\n测试集评估结果:")print(f"准确率: {test_results[1]:.4f}")print(f"精确率: {test_results[2]:.4f}")print(f"召回率: {test_results[3]:.4f}")print(f"AUC: {test_results[4]:.4f}")# 获取测试集所有预测结果test_generator.reset()y_true = []y_pred_prob = []for i in range(test_steps):batch_x, batch_y = test_generator.next()y_true.extend(batch_y)batch_pred = model.predict(batch_x, verbose=0).ravel()y_pred_prob.extend(batch_pred)y_true = np.array(y_true)y_pred_prob = np.array(y_pred_prob)y_pred = (y_pred_prob > 0.5).astype(int)# 计算额外指标f1 = f1_score(y_true, y_pred)auc = roc_auc_score(y_true, y_pred_prob)print(f"\nF1-score: {f1:.4f}")print(f"AUC-ROC: {auc:.4f}")# 分类报告print("\n分类报告:")print(classification_report(y_true, y_pred, target_names=['NORMAL', 'PNEUMONIA']))# 混淆矩阵cm = confusion_matrix(y_true, y_pred)print("混淆矩阵:")print(cm)# 绘制ROC曲线fpr, tpr, _ = roc_curve(y_true, y_pred_prob)plt.figure(figsize=(10, 6))plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC曲线 (AUC = {auc:.4f})')plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')plt.xlim([0.0, 1.0])plt.ylim([0.0, 1.05])plt.xlabel('False Positive Rate')plt.ylabel('True Positive Rate')plt.title('接收者操作特征曲线(ROC)')plt.legend(loc="lower right")plt.savefig('roc_curve.png', dpi=300)plt.show()# 绘制训练历史plt.figure(figsize=(12, 8))plt.subplot(2, 2, 1)plt.plot(history.history['accuracy'], label='训练准确率')plt.plot(history.history['val_accuracy'], label='验证准确率')plt.title('准确率')plt.legend()plt.subplot(2, 2, 2)plt.plot(history.history['loss'], label='训练损失')plt.plot(history.history['val_loss'], label='验证损失')plt.title('损失')plt.legend()plt.subplot(2, 2, 3)plt.plot(history.history['precision'], label='训练精确率')plt.plot(history.history['val_precision'], label='验证精确率')plt.title('精确率')plt.legend()plt.subplot(2, 2, 4)plt.plot(history.history['recall'], label='训练召回率')plt.plot(history.history['val_recall'], label='验证召回率')plt.title('召回率')plt.legend()plt.tight_layout()plt.savefig('training_history.png', dpi=300)plt.show()if __name__ == "__main__":main()

这段代码里有很多地方明确体现了这是一个二分类任务(判断 “正常胸片” 和 “肺炎胸片” 两类),最关键的有这几个地方:

1. 数据加载时指定了两类标签

load_data 函数中,加载数据时明确指定了类别为两类:

train_generator = train_datagen.flow_from_directory(

    train_dir,

    ...

    class_mode='binary',  # 这里指定是二分类模式

    classes=['NORMAL', 'PNEUMONIA'],  # 明确两类:正常(NORMAL)和肺炎(PNEUMONIA

    ...

)

  1. class_mode='binary':直接告诉程序 “这是二分类任务”,标签会被处理成 0 和 1(0 代表正常,1 代表肺炎)。
  2. classes=['NORMAL', 'PNEUMONIA']:手动指定只有这两个类别,没有第三种情况。

2. 损失函数用了二分类专用的

在模型编译时,损失函数用的是 binary_crossentropy(二分类交叉熵):

model.compile(

    ...

    loss='binary_crossentropy',  # 专门用于二分类的损失函数

    ...

)

这个损失函数的作用是:计算 “模型判断为 0 或 1 的概率” 与 “实际标签(0 或 1)” 之间的差距,指导模型优化。如果是多分类任务,会用其他损失函数(比如 categorical_crossentropy)。

3. 输出层只有 1 个神经元,用了sigmoid激活函数

模型的最后一层是:

layers.Dense(1, activation='sigmoid')  # 输出层

  1. Dense(1):只输出 1 个数值,这个数值经过 sigmoid 激活后,会被压缩到 0~1 之间。
  2. 实际含义:
    1. 数值越接近 0 → 模型认为 “更可能是正常胸片(0 类)”;
    2. 数值越接近 1 → 模型认为 “更可能是肺炎胸片(1 类)”。

这是二分类任务的典型输出方式(多分类会有多个神经元,对应多个类别)。

4. 预测时用 0.5 作为分类阈值

在生成最终判断结果时:

y_pred = (y_pred_prob > 0.5).astype(int)  # 大于0.51类(肺炎),否则算0类(正常)

直接用 0.5 作为 “两类的分界线”,把输出概率分成 “0” 和 “1” 两类,进一步说明这是二分类。

从 “数据标签定义”“损失函数选择”“输出层设计” 到 “最终预测规则”,全流程都围绕 “只能分成两类” 展开,没有任何支持多类别的设计。所以这段代码是典型的二分类任务,目标就是区分 “正常胸片” 和 “肺炎胸片”。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/92799.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/92799.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入理解 BIO、NIO、AIO

目录 一、同步与非同步 二、阻塞与非阻塞 三、BIO(Blocking I/O,阻塞I/O) 四、NIO(Non-blocking I/O,非阻塞I/O) 五、AIO(Asynchronous I/O,异步I/O) 同步阻塞&…

电脑无法识别固态硬盘怎么办?

随着固态硬盘(SSD)越来越普及,不少用户在给电脑更换、加装SSD时会遇到一个让人头大的问题——电脑识别不了固态硬盘。可能是开不了机,或者在“此电脑”中找不到硬盘,甚至连系统安装界面都提示“找不到驱动器”。这时候…

Kingbasepostgis 安装实践

文章目录前言一、安装准备1.1 部署方案规划1.2 SELINUX、防火墙状态检查1.3 操作系统时间检查1.4 创建用户及密码1.5 目录创建1.6 操作系统参数配置1.6.1 配置limits.conf文件二、安装2.1 上传安装包以及license授权文件2.2 拷贝安装文件2.3 命令行方式安装2.3.1简介2.3.2 许可…

移动端设备能部署的llm

mlc-llm 内置RedPajama hf示例模型 TheBloke/Mistral-7B-Instruct-v0.2-GGUF https://github.com/mlc-ai/mlc-llm/tree/main llama.cpp https://github.com/ggml-org/llama.cpp reference --- MLC-LLM:大模型如何部署到浏览器 / 手机?完整流程复现…

Ubuntu硬盘挂载

一、在 Ubuntu 中,你可以用以下命令快速查看 所有已连接但尚未挂载的硬盘和分区:lsblk -o NAME,SIZE,FSTYPE,MOUNTPOINT,UUID输出中 MOUNTPOINT 为空的行,就是 未挂载的分区。sda ├─sda1 500M ext4 /boot ├─sda2 1.8T ntfs └─sda3 …

JavaScript -Socket5代理使用

axios 安装两个包 socks-proxy-agent,axios const { SocksProxyAgent } require(socks-proxy-agent); const axios require(axios);const socks5Axios axios.create();const socks5 () > {const socks5Agent new SocksProxyAgent("socks5://112.194.8…

[特殊字符] 从数据库无法访问到成功修复崩溃表:一次 MySQL 故障排查实录

一次典型的 MySQL 故障排查与修复全过程,涵盖登录失败、表崩溃、innodb_force_recovery 救援、坏表剔除与数据恢复等关键操作。一、问题背景某业务系统运行多年,数据库使用的是 MySQL 8.0.18,近期在一次服务器重启后,发现无法正常…

【Agent】API Reference Manual(API 参考手册)

https://github.com/Intelligent-Internet/CommonGround/blob/main/docs/framework/03-api-reference.md 以下是这份 API Reference Manual(API 参考手册) 的完整中文翻译: API 参考手册 版本:0.1 目录 概览 1.1 API 目的 1.2 通信协议与核心概念 HTTP API 2.1 POST /se…

LeetCode Hot 100 全排列

给定一个不含重复数字的数组 nums ,返回其 所有可能的全排列 。你可以 按任意顺序 返回答案。示例 1:输入:nums [1,2,3] 输出:[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]示例 2:输入:nums [0,1]…

AI大模型如何有效识别和纠正数据中的偏见?

当下,人工智能大模型已成为推动各行业发展的关键力量,广泛应用于自然语言处理、图像识别、医疗诊断、金融风控等领域,为人们的生活和工作带来了诸多便利。然而,随着其应用的不断深入,数据偏见问题逐渐浮出水面&#xf…

如何通过内网穿透,访问公司内部服务器?

“凌晨2点,销售总监王姐在机场候机时突然接到客户电话——对方要求立即查看产品库存数据。她慌忙翻出笔记本电脑,却发现公司内网数据库没有公网IP,VPN连接又卡在验证环节……这样的场景,是否让你想起某个手忙脚乱的时刻&#xff1…

12. isaacsim4.2教程-ROS 导航

1. Teleport 示例 ROS 服务的作用: 提供了一种同步、请求-响应的通信方式,用于执行那些需要即时获取结果或状态反馈的一次性操作或查询。 Teleport 服务在 ROS 仿真(尤其是 Gazebo)和某些简单机器人控制中扮演着瞬移机器人或对象…

DeepSpeed-FastGen:通过 MII 和 DeepSpeed-Inference 实现大语言模型的高吞吐文本生成

温馨提示: 本篇文章已同步至"AI专题精讲" DeepSpeed-FastGen:通过 MII 和 DeepSpeed-Inference 实现大语言模型的高吞吐文本生成 摘要 随着大语言模型(LLM)被广泛应用,其部署与扩展变得至关重要&#xff0…

操作系统:操作系统的结构(Structures of Operating System)

目录 简单结构(Simple Structure) 整体式结构(Monolithic Structure) 什么是 Kernel(内核)? 层次结构(Layered Structure) 微内核结构(Microkernel&#x…

Python柱状图

1.各国GDP柱状图2.各国GDP时间线柱状图

FastGPT:企业级智能问答系统,让知识库触手可及

在信息爆炸的时代,企业如何高效管理和利用海量知识?传统搜索和文档库已难以满足需求。FastGPT正成为企业构建智能知识核心的首选。一、FastGPT:不止于问答的智能知识引擎FastGPT 颠覆了传统知识库的局限,其核心优势在于&#xff1…

探索 MyBatis-Plus

引言在当今的 Java 开发领域,数据库操作是一个至关重要的环节。MyBatis 作为一款优秀的持久层框架,已经被广泛应用。而 MyBatis-Plus 则是在 MyBatis 基础上进行增强的工具,它简化了开发流程,提高了开发效率。本文将详细介绍 MyBa…

Hive【安装 01】hive-3.1.2版本安装配置(含 mysql-connector-java-5.1.47.jar 网盘资源)

我使用的安装文件是 apache-hive-3.1.2-bin.tar.gz ,以下内容均以此版本进行说明。 以下环境测试安装成功: openEuler 22.03 (LTS-SP1)系统 MySQL-8.0.40 1.前置条件 MySQL数据库 我安装的是 mysql-5.7.28 版本的,安装方法可参考《Linux环境…

璞致 PZSDR-P101:ZYNQ7100+AD9361 架构软件无线电平台,重塑宽频信号处理范式

璞致电子 PZSDR-P101 软件无线电平台以 "异构计算 宽频射频 工业级可靠性" 为核心设计理念,基于 Xilinx ZYNQ7100 处理器与 ADI AD9361 射频芯片构建,为工程师提供从 70MHz 到 6GHz 的全频段信号处理解决方案。无论是频谱监测、无线通信原型…