DAY 43 复习日

作业:

kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化

划分数据集

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os
from sklearn.model_selection import train_test_split
from shutil import copyfile
import cv2
from torch.nn import functional as F# 数据集划分
data_root = "flowers"  # 数据集根目录
classes = ["daisy", "tulip", "rose", "sunflower", "dandelion"]for folder in ["train", "val", "test"]:os.makedirs(os.path.join(data_root, folder), exist_ok=True)for cls in classes:cls_path = os.path.join(data_root, cls)imgs = [f for f in os.listdir(cls_path) if f.lower().endswith((".jpg", ".jpeg", ".png"))]# 划分数据集(测试集20%,验证集20% of 剩余数据,训练集60%)train_val, test = train_test_split(imgs, test_size=0.2, random_state=42)train, val = train_test_split(train_val, test_size=0.25, random_state=42)# 复制到train/val/test下的类别子文件夹(关键修正!)for split, imgs_list in zip(["train", "val", "test"], [train, val, test]):split_class_path = os.path.join(data_root, split, cls)# 创建子文件夹:train/chamomile/os.makedirs(split_class_path, exist_ok=True)for img in imgs_list:copyfile(os.path.join(cls_path, img), os.path.join(split_class_path, img))

数据预处理

  # 数据预处理(新增旋转增强)# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"使用设备: {device}")# 训练集数据增强(彩色图像通用处理)
train_transform = transforms.Compose([transforms.Resize((224, 224)),         # 调整尺寸为224x224(匹配CNN输入)transforms.RandomCrop(224, padding=4), # 随机裁剪并填充,增加数据多样性transforms.RandomHorizontalFlip(),     # 水平翻转(概率0.5)transforms.RandomRotation(15),     # 新增旋转transforms.ColorJitter(brightness=0.2, contrast=0.2),  # 颜色抖动transforms.ToTensor(),                 # 转换为张量transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) # ImageNet标准归一化
])# 测试集仅归一化,不增强
test_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
]))

加载数据集

 # 数据加载器(保持不变)data_root = "flowers"  # 数据集根目录,需包含5个子类别文件夹train_dataset = datasets.ImageFolder(os.path.join(data_root, "train"), transform=train_transform)val_dataset = datasets.ImageFolder(os.path.join(data_root, "val"), transform=test_transform)test_dataset = datasets.ImageFolder(os.path.join(data_root, "test"), transform=test_transform)# 创建数据加载器
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)# 获取类别名称(自动从文件夹名获取)
class_names = train_dataset.classesprint(f"检测到的类别: {class_names}")  # 确保输出5个类别名称

定义模型

  # 模型定义(新增第4卷积块)
class FlowerCNN(nn.Module):def __init__(self, num_classes=5):super().__init__()# 卷积块1self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.bn1 = nn.BatchNorm2d(32)self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(2, 2)  # 224→112# 卷积块2self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.bn2 = nn.BatchNorm2d(64)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(2, 2)  # 112→56# 卷积块3self.conv3 = nn.Conv2d(64, 128, 3, padding=1)self.bn3 = nn.BatchNorm2d(128)self.relu3 = nn.ReLU()self.pool3 = nn.MaxPool2d(2, 2)  # 56→28# 卷积块4self.conv4 = nn.Conv2d(128, 256, 3, padding=1)  # 新增卷积块self.bn4 = nn.BatchNorm2d(256)self.relu4 = nn.ReLU()self.pool4 = nn.MaxPool2d(2, 2)  # 28→14# 全连接层self.fc1 = nn.Linear(256 * 14 * 14, 512)   # 计算方式:224->112->56->28->14(四次池化后尺寸)self.dropout = nn.Dropout(0.5)self.fc2 = nn.Linear(512, num_classes)   # 输出5个类别def forward(self, x):x = self.pool1(self.relu1(self.bn1(self.conv1(x))))x = self.pool2(self.relu2(self.bn2(self.conv2(x))))x = self.pool3(self.relu3(self.bn3(self.conv3(x))))x = self.pool4(self.relu4(self.bn4(self.conv4(x))))  # 新增池化x = x.view(x.size(0), -1)      # 展平特征图x = self.dropout(self.relu1(self.fc1(x)))x = self.fc2(x)return x# 初始化模型并移至设备# 训练配置(增加轮数,使用StepLR)
model = FlowerCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

训练模型

 def train_model(epochs=30):best_val_acc = 0.0train_loss, val_loss, train_acc, val_acc = [], [], [], []for epoch in range(epochs):model.train()running_loss, correct, total = 0.0, 0, 0for data, target in train_loader:data, target = data.to(device), target.to(device)optimizer.zero_grad()outputs = model(data)loss = criterion(outputs, target)loss.backward()optimizer.step()running_loss += loss.item()_, pred = torch.max(outputs, 1)correct += (pred == target).sum().item()total += target.size(0)# 计算 epoch 指标epoch_train_loss = running_loss / len(train_loader)epoch_train_acc = 100 * correct / total# 验证集评估model.eval()val_running_loss, val_correct, val_total = 0.0, 0, 0with torch.no_grad():for data, target in val_loader:data, target = data.to(device), target.to(device)outputs = model(data)val_running_loss += criterion(outputs, target).item()_, pred = torch.max(outputs, 1)val_correct += (pred == target).sum().item()val_total += target.size(0)epoch_val_loss = val_running_loss / len(val_loader)epoch_val_acc = 100 * val_correct / val_totalscheduler.step()# 记录历史数据train_loss.append(epoch_train_loss)val_loss.append(epoch_val_loss)train_acc.append(epoch_train_acc)val_acc.append(epoch_val_acc)print(f"Epoch {epoch+1}/{epochs} | 训练损失: {epoch_train_loss:.4f} 验证准确率: {epoch_val_acc:.2f}%")# 保存最佳模型if epoch_val_acc > best_val_acc:torch.save(model.state_dict(), "best_model.pth")best_val_acc = epoch_val_acc# 绘制曲线plt.figure(figsize=(12, 4))# 损失曲线plt.subplot(1, 2, 1); plt.plot(train_loss, label='训练损失'); plt.plot(val_loss, label='验证损失'); plt.legend()# 准确率曲线plt.subplot(1, 2, 2); plt.plot(train_acc, label='训练准确率'); plt.plot(val_acc, label='验证准确率'); plt.legend()plt.show()return best_val_acc# 训练与可视化(保持不变)
print("开始训练...")
train_model(epochs=30)
print("训练完成,开始可视化...")
开始训练...
Epoch 1/30 | 训练损失: 5.8699 验证准确率: 47.05%
Epoch 2/30 | 训练损失: 1.3307 验证准确率: 53.76%
Epoch 3/30 | 训练损失: 1.3045 验证准确率: 52.95%
Epoch 4/30 | 训练损失: 1.2460 验证准确率: 55.38%
Epoch 5/30 | 训练损失: 1.2342 验证准确率: 49.48%
Epoch 6/30 | 训练损失: 1.2442 验证准确率: 54.10%
Epoch 7/30 | 训练损失: 1.2309 验证准确率: 50.75%
Epoch 8/30 | 训练损失: 1.2172 验证准确率: 56.65%
Epoch 9/30 | 训练损失: 1.2025 验证准确率: 56.53%
Epoch 10/30 | 训练损失: 1.1733 验证准确率: 56.53%
Epoch 11/30 | 训练损失: 1.1167 验证准确率: 61.04%
Epoch 12/30 | 训练损失: 1.0763 验证准确率: 64.28%
Epoch 13/30 | 训练损失: 1.0564 验证准确率: 63.12%
Epoch 14/30 | 训练损失: 1.0469 验证准确率: 62.31%
Epoch 15/30 | 训练损失: 1.0295 验证准确率: 65.09%
Epoch 16/30 | 训练损失: 1.0365 验证准确率: 65.78%
Epoch 17/30 | 训练损失: 1.0091 验证准确率: 66.71%
Epoch 18/30 | 训练损失: 1.0152 验证准确率: 65.32%
Epoch 19/30 | 训练损失: 0.9794 验证准确率: 65.43%
Epoch 20/30 | 训练损失: 0.9875 验证准确率: 68.90%
Epoch 21/30 | 训练损失: 0.9496 验证准确率: 69.94%
Epoch 22/30 | 训练损失: 0.9608 验证准确率: 69.71%
Epoch 23/30 | 训练损失: 0.9342 验证准确率: 69.71%
Epoch 24/30 | 训练损失: 0.9586 验证准确率: 69.25%
Epoch 25/30 | 训练损失: 0.9554 验证准确率: 69.60%
Epoch 26/30 | 训练损失: 0.9463 验证准确率: 69.83%
Epoch 27/30 | 训练损失: 0.9373 验证准确率: 69.94%
Epoch 28/30 | 训练损失: 0.9282 验证准确率: 69.48%
Epoch 29/30 | 训练损失: 0.9130 验证准确率: 69.36%
Epoch 30/30 | 训练损失: 0.9585 验证准确率: 69.94%

Grad-CAM可视化

class GradCAM:def __init__(self, model, target_layer_name="conv3"):self.model = model.eval()                       # 设置模型为评估模式self.target_layer_name = target_layer_name      # 目标卷积层名称(需与模型定义一致)self.gradients, self.activations = None, None   # 存储梯度,激活值# 注册前向和反向钩子函数for name, module in model.named_modules():if name == target_layer_name:module.register_forward_hook(self.forward_hook)module.register_backward_hook(self.backward_hook)breakdef forward_hook(self, module, input, output):"""前向传播时保存激活值"""self.activations = output.detach()  # 不记录梯度的激活值def backward_hook(self, module, grad_input, grad_output):"""反向传播时保存梯度"""self.gradients = grad_output[0].detach()    # 提取梯度(去除批量维度)def generate(self, input_image, target_class=None):"""生成Grad-CAM热力图"""# 前向传播获取模型输出outputs = self.model(input_image)   # 输出形状: [batch_size, num_classes]target_class = torch.argmax(outputs, dim=1).item() if target_class is None else target_class# 反向传播计算梯度self.model.zero_grad()one_hot = torch.zeros_like(outputs); one_hot[0, target_class] = 1outputs.backward(gradient=one_hot)# 计算通道权重(全局平均池化)weights = torch.mean(self.gradients, dim=(2, 3))# 生成类激活映射(CAM)cam = torch.sum(self.activations[0] * weights[0][:, None, None], dim=0)cam = F.relu(cam); cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8) cam = F.interpolate(cam.unsqueeze(0).unsqueeze(0), size=(224, 224), mode='bilinear').squeeze()return cam.cpu().numpy(), target_class# 可视化函数(关键修改:增加图像尺寸统一和颜色通道转换)
def visualize_gradcam(img_path, model, class_names, alpha=0.6):"""可视化Grad-CAM结果:param img_path: 测试图像路径:param model: 训练好的模型:param class_names: 类别名称列表:param alpha: 热力图透明度(0-1)"""# 加载图像并统一尺寸为224x224(解决尺寸不匹配问题)img = Image.open(img_path).convert("RGB").resize((224, 224))img_np = np.array(img) / 255.0# 预处理图像(与模型输入一致)transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])input_tensor = transform(img).unsqueeze(0).to(device)# 生成Grad-CAM热力图grad_cam = GradCAM(model, target_layer_name="conv3")heatmap, pred_class = grad_cam.generate(input_tensor)# 热力图后处理(解决颜色通道问题)heatmap = np.uint8(255 * heatmap); heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET) / 255.0; heatmap_rgb = heatmap[:, :, ::-1]# 叠加原始图像和热力图(尺寸和通道完全匹配)superimposed = cv2.addWeighted(img_np, 1 - alpha, heatmap, alpha, 0)# 绘制结果plt.figure(figsize=(12, 4))plt.subplot(1, 3, 1); plt.imshow(img_np); plt.title(f"原始图像\n真实类别: {img_path.split('/')[-2]}"); plt.axis('off')plt.subplot(1, 3, 2); plt.imshow(heatmap_rgb); plt.title(f"Grad-CAM热力图\n预测类别: {class_names[pred_class]}"); plt.axis('off')plt.subplot(1, 3, 3); plt.imshow(superimposed); plt.title("叠加热力图"); plt.axis('off')plt.tight_layout(); plt.show()# 选择测试图像(需存在且路径正确)
test_image_path = "flowers/tulip/100930342_92e8746431_n.jpg"  # 执行可视化
visualize_gradcam(test_image_path, model, class_names)

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/918567.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/918567.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Flink运行时的实现细节

一、Flink集群中各角色运行架构先说Flink集群中的角色吧,有三个分别是客户端(Client)、JobManager、TaskManager。客户端负责接收作业任务并进行解析,将解析后的二进制数据发送给JobManager;JobManager是作业调度中心,负责对所有作…

思科、华为、华三如何切换三层端口?

三层交换机融合了二层交换技术与三层转发技术,具备强大的网络功能。主流厂商(思科、H3C、华为)的三层交换机均支持二层端口与三层端口的相互切换,但具体命令存在差异。本文将详细介绍三大厂商设备的端口切换方法及相关知识。一、各…

springboot的基础要点

Spring Boot 的核心设计理念是 ​​"约定优于配置"​​(Convention Over Configuration),旨在简化 Spring 应用的初始搭建和开发过程。以下是需要掌握的核心基础要点:​一、核心机制​​自动配置 (Auto-Configuration)​…

lesson36:MySQL从入门到精通:全面掌握数据库操作与核心原理

目录 一、引言:为什么选择MySQL? 二、MySQL安装与登录配置 2.1 环境准备 2.2 登录指令详解 三、数据库核心操作 3.1 数据库生命周期管理 3.2 数据库存储引擎选择 四、数据表设计与操作 4.1 表结构创建(含数据类型详解) …

Spring源码解析 - SpringApplication run流程-prepareContext源码分析

prepareContext源码分析 private void prepareContext(DefaultBootstrapContext bootstrapContext, ConfigurableApplicationContext context,ConfigurableEnvironment environment, SpringApplicationRunListeners listeners,ApplicationArguments applicationArguments, Bann…

HIS系统:医院信息化建设的核心,采用Angular+Java技术栈,集成MySQL、Redis等技术,实现医院全业务流程管理。

HIS系统在医院信息化建设中扮演着核心的角色。它是一个综合性的信息系统,旨在管理和运营医院的各种业务,包括门诊、住院、财务、物资、科研等。技术细节:前端:AngularNginx后台:JavaSpring,SpringBoot&…

深度学习-卷积神经网络-LeNet

卷积神经网络是一种专门用于处理具有网格结构数据(如图像、音频等)的深度学习模型。它通过卷积层自动提取数据中的特征,利用局部连接和参数共享的特性减少了模型的参数数量,降低了过拟合的风险,同时能够有效地捕捉数据…

【Java项目与数据库、Maven的关系详解】

Java项目与数据库、Maven的关系详解 一、Java项目是否都需要连接本地数据库? 不一定,这取决于项目类型和需求: 1. 需要数据库的项目类型项目类型数据库作用典型场景Web应用存储用户数据/业务数据电商系统、CMS服务端程序持久化数据金融交易系…

两个Maven工程,使用idea开发,工程A中依赖了工程B,改了工程B,工程A如何获取最新代码

两个Maven工程,使用idea开发,工程A中依赖了工程B,改了工程B,工程A如何获取最新代码 如果工程B的版本是快照,那么如下。 步骤一 工程B 执行 clean package install deploy 步骤二 工程A 刷新Maven

奥比中光与地平线、地瓜机器人达成战略合作,携手推动机器人智能化

摘要:机器人“慧眼”与“智脑”强强联合!8月11日,奥比中光与地平线及其控股子公司地瓜机器人在北京签订合作协议,双方将在机器人智能化领域展开深度合作,充分发挥各自的技术与产品优势,携手推动机器人产业的…

【Linux】Tomcat

Tomcat简介Tomcat 服务器是一个免费的开放源代码的Web 应用服务器,属于轻量级应用服务器,在中小型系统和 并发访问用户不是很多的场合下被普遍使用,Tomcat 具有处理HTML页面的功能,它还是一个Servlet和 JSP容器Tomcat的使用安装ja…

Putting it all together 将所有内容整合在一起

官方链接 https://www.youtube.com/watch?vAa_FAA3v22g&t1s Task1 Putting It All Together 将所有内容整合在一起 图片版 文字版 Putting It All Together 将所有内容整合在一起 From the previous modules, youll have learned that quite a lot of things go on b…

Python 闭包详解:从变量作用域到实战案例

一、变量作用域基础在 Python 中,变量根据作用范围可分为三类:全局变量:定义在函数外部的变量,作用范围是整个程序。如果在函数内部需要修改全局变量,必须使用global关键字声明。局部变量:定义在函数内部的…

Docker 跨主机容器之间的通信macvlan

默认一个物理网卡,只有一个物理mac地址,虚拟多个mac地址 缺点:每次需要手动配置ip地址,容易ip地址冲突。类似于保存到execl表格里面。 两台物理机: docker-01和docker-02 创建macvlan网络 [rootdocker-01 ~]# docker n…

android 换肤框架详解1-换肤逻辑基本

android 换肤框架详解1-换肤逻辑基本-CSDN博客 android 换肤框架详解2-LayoutInflater源码解析-CSDN博客 android 换肤框架详解3-自动换肤原理梳理-CSDN博客 换肤框架流程 1,通过AssetManager获取换肤的资源文件 2,通过原文件中的resId获取到res名称…

NEON性能优化总结

转自 NEON优化:性能优化经验总结-CSDN博客 NEON优化:性能优化经验总结 1. 什么是 NEON Arm Adv SIMD 历史 2. 寄存器 3. NEON 命名方式 4. 优化技巧 5. 优化 NEON 代码(Armv7-A内容,但区别不大) 5.1 优化 NEON 汇编代码 …

计算机网络摘星题库800题笔记 第2章 物理层

第2章 物理层2.1 物理层概述题组闯关1.采用以下哪种设备,可以使数字信号传输得更远 ( )。 A. 放大器 B. 中继器 C. 网桥 D. 路由器1.【参考答案】B 【解析】选项 A 放大器只是单纯地放大信号、抑制噪音和干扰。选项 B 中继器是把一根线缆中的电或者光信号传递给另一…

导入文件到iPhone实现

我们有时候开发需要加载一些自己的文件&#xff0c;这个时候就需要导入文件到iPhone等设备。在info里面open as source code&#xff0c;加入如下配置&#xff1a;<!-- 开启 iTunes / Finder 文件共享 --><key>UIFileSharingEnabled</key><true/>或者o…

Ubuntu Server系统安装磁盘分区方案

最近打算把家里的旧电脑利用起来&#xff0c;装上Ubuntu Server 24.04.3 LTS作为一个家用NAS服务器&#xff0c;但是给旧电脑安装系统时遇到了一些问题&#xff0c;遂记录下来 GPT分区与MBR分区 GPT 指的是 GUID Partition Table&#xff08;全局唯一标识分区表&#xff09;&am…

1小时 MySQL 数据库基础速通

目录 一、MySQL安装配置 1、下载mysql 2、下载mysql-shell 二、MySQL基本概念 1. 数据库&#xff08;Database&#xff09; 2. 表&#xff08;Table&#xff09; 3. 数据类型&#xff08;Data Type&#xff09; 4. 主键&#xff08;Primary Key&#xff09; 5. 索引&am…