python学习打卡day43

DAY 43 复习日

作业:
kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化

@浙大疏锦行

数据集使用猫狗数据集,训练集中包含猫图像4000张、狗图像4005张。测试集包含猫图像1012张,狗图像1013张。以下是数据集的下载地址。

猫和狗 --- Cat and Dog

1.数据集加载与数据预处理

我这里对数据集文件路径做了改变

C:\Users\vijay\Desktop\1\

├── train\

│      ├── cats\ 

│      └── dogs\

└── test\

        ├── cats\ 

        └── dags\ 

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F# 设置随机种子确保结果可复现
torch.manual_seed(42)
np.random.seed(42)# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题# 检查GPU是否可用
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")# 1. 数据预处理
# 训练集:使用多种数据增强方法提高模型泛化能力
train_transform = transforms.Compose([# 新增:调整图像大小为统一尺寸transforms.Resize((32, 32)),  # 确保所有图像都是32x32像素transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.RandomRotation(15),transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])# 测试集:仅进行必要的标准化,保持数据原始特性
test_transform = transforms.Compose([# 新增:调整图像大小为统一尺寸transforms.Resize((32, 32)),  # 确保所有图像都是32x32像素transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])# 定义数据集根目录
root = r'C:\Users\vijay\Desktop\1'train_dataset = datasets.ImageFolder(root=root + '/train',  # 指向 train 子文件夹transform=train_transform
)
test_dataset = datasets.ImageFolder(root=root + '/test',  # 指向 test 子文件夹transform=test_transform
)# 打印类别信息,确认数据加载正确
print(f"训练集类别: {train_dataset.classes}")
print(f"测试集类别: {test_dataset.classes}")# 3. 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

2.模型训练与评估 

# 定义一个简单的CNN模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()# 第一个卷积层,输入通道为3(彩色图像),输出通道为32,卷积核大小为3x3,填充为1以保持图像尺寸不变self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)# 第二个卷积层,输入通道为32,输出通道为64,卷积核大小为3x3,填充为1self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)# 第三个卷积层,输入通道为64,输出通道为128,卷积核大小为3x3,填充为1self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)# 最大池化层,池化核大小为2x2,步长为2,用于下采样,减少数据量并提取主要特征self.pool = nn.MaxPool2d(2, 2)# 第一个全连接层,输入特征数为128 * 4 * 4(经过前面卷积和池化后的特征维度),输出为512self.fc1 = nn.Linear(128 * 4 * 4, 512)# 第二个全连接层,输入为512,输出为2(对应猫和非猫两个类别)self.fc2 = nn.Linear(512, 2)def forward(self, x):# 第一个卷积层后接ReLU激活函数和最大池化操作,经过池化后图像尺寸变为原来的一半,这里输出尺寸变为16x16x = self.pool(F.relu(self.conv1(x)))# 第二个卷积层后接ReLU激活函数和最大池化操作,输出尺寸变为8x8x = self.pool(F.relu(self.conv2(x)))# 第三个卷积层后接ReLU激活函数和最大池化操作,输出尺寸变为4x4x = self.pool(F.relu(self.conv3(x)))# 将特征图展平为一维向量,以便输入到全连接层x = x.view(-1, 128 * 4 * 4)# 第一个全连接层后接ReLU激活函数x = F.relu(self.fc1(x))# 第二个全连接层输出分类结果x = self.fc2(x)return x# 初始化模型
model = SimpleCNN()
print("模型已创建")# 如果有GPU则使用GPU,将模型转移到对应的设备上
model = model.to(device)# 训练模型
def train_model(model, train_loader, test_loader, epochs=10):# 定义损失函数为交叉熵损失,用于分类任务criterion = nn.CrossEntropyLoss()# 定义优化器为Adam,用于更新模型参数,学习率设置为0.001optimizer = torch.optim.Adam(model.parameters(), lr=0.001)for epoch in range(epochs):# 训练阶段model.train()running_loss = 0.0correct = 0total = 0for i, data in enumerate(train_loader, 0):# 从数据加载器中获取图像和标签inputs, labels = data# 将图像和标签转移到对应的设备(GPU或CPU)上inputs, labels = inputs.to(device), labels.to(device)# 清空梯度,避免梯度累加optimizer.zero_grad()# 模型前向传播得到输出outputs = model(inputs)# 计算损失loss = criterion(outputs, labels)# 反向传播计算梯度loss.backward()# 更新模型参数optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()if i % 100 == 99:# 每100个批次打印一次平均损失和准确率print(f'[{epoch + 1}, {i + 1}] 损失: {running_loss / 100:.3f} | 准确率: {100.*correct/total:.2f}%')running_loss = 0.0# 测试阶段model.eval()test_loss = 0correct = 0total = 0with torch.no_grad():for data in test_loader:images, labels = dataimages, labels = images.to(device), labels.to(device)outputs = model(images)test_loss += criterion(outputs, labels).item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels).sum().item()print(f'测试集 [{epoch + 1}] 损失: {test_loss/len(test_loader):.3f} | 准确率: {100.*correct/total:.2f}%')print("训练完成")return model# 训练模型
try:# 尝试加载预训练模型(如果存在)model.load_state_dict(torch.load('cat_classifier.pth'))print("已加载预训练模型")
except:print("无法加载预训练模型,训练新模型")model = train_model(model, train_loader, test_loader, epochs=10)# 保存训练后的模型参数torch.save(model.state_dict(), 'cat_classifier.pth')# 设置模型为评估模式
model.eval()

3. Grad-CAM实现

# Grad-CAM实现
class GradCAM:def __init__(self, model, target_layer):self.model = modelself.target_layer = target_layerself.gradients = Noneself.activations = None# 注册钩子,用于获取目标层的前向传播输出和反向传播梯度self.register_hooks()def register_hooks(self):# 前向钩子函数,在目标层前向传播后被调用,保存目标层的输出(激活值)def forward_hook(module, input, output):self.activations = output.detach()# 反向钩子函数,在目标层反向传播后被调用,保存目标层的梯度def backward_hook(module, grad_input, grad_output):self.gradients = grad_output[0].detach()# 在目标层注册前向钩子和反向钩子self.target_layer.register_forward_hook(forward_hook)self.target_layer.register_backward_hook(backward_hook)def generate_cam(self, input_image, target_class=None):# 前向传播,得到模型输出model_output = self.model(input_image)if target_class is None:# 如果未指定目标类别,则取模型预测概率最大的类别作为目标类别target_class = torch.argmax(model_output, dim=1).item()# 清除模型梯度,避免之前的梯度影响self.model.zero_grad()# 反向传播,构造one-hot向量,使得目标类别对应的梯度为1,其余为0,然后进行反向传播计算梯度one_hot = torch.zeros_like(model_output)one_hot[0, target_class] = 1model_output.backward(gradient=one_hot)# 获取之前保存的目标层的梯度和激活值gradients = self.gradientsactivations = self.activations# 对梯度进行全局平均池化,得到每个通道的权重,用于衡量每个通道的重要性weights = torch.mean(gradients, dim=(2, 3), keepdim=True)# 加权激活映射,将权重与激活值相乘并求和,得到类激活映射的初步结果cam = torch.sum(weights * activations, dim=1, keepdim=True)# ReLU激活,只保留对目标类别有正贡献的区域,去除负贡献的影响cam = F.relu(cam)# 调整大小并归一化,将类激活映射调整为与输入图像相同的尺寸(32x32),并归一化到[0, 1]范围cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)cam = cam - cam.min()cam = cam / cam.max() if cam.max() > 0 else camreturn cam.cpu().squeeze().numpy(), target_class# 可视化Grad-CAM结果的函数
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
# 设置中文字体支持
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题
# 选择一个随机图像
# idx = np.random.randint(len(test_dataset))
idx = 102  # 选择测试集中的第101张图片 (索引从0开始)
image, label = test_dataset[idx]
print(f"选择的图像类别: {test_dataset.classes[label]}")# 转换图像以便可视化
def tensor_to_np(tensor):img = tensor.cpu().numpy().transpose(1, 2, 0)mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])img = std * img + meanimg = np.clip(img, 0, 1)return img# 添加批次维度并移动到设备
input_tensor = image.unsqueeze(0).to(device)# 初始化Grad-CAM(选择最后一个卷积层)
grad_cam = GradCAM(model, model.conv3)# 生成热力图
heatmap, pred_class = grad_cam.generate_cam(input_tensor)# 可视化
plt.figure(figsize=(12, 4))# 原始图像
plt.subplot(1, 3, 1)
plt.imshow(tensor_to_np(image))
plt.title(f"原始图像: {test_dataset.classes[label]}")
plt.axis('off')# 热力图
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM热力图: {test_dataset.classes[pred_class]}")
plt.axis('off')# 叠加的图像
plt.subplot(1, 3, 3)
img = tensor_to_np(image)
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("叠加热力图")
plt.axis('off')plt.tight_layout()
plt.savefig('grad_cam_result.png')
plt.show()print("Grad-CAM可视化完成。已保存为grad_cam_result.png")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/news/908058.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大数据与数据分析【数据分析全栈攻略:爬虫+处理+可视化+报告】

- 第 100 篇 - Date: 2025 - 05 - 25 Author: 郑龙浩/仟墨 大数据与数据分析 文章目录 大数据与数据分析一 大数据是什么?1 定义2 大数据的来源3 大数据4个方面的典型特征(4V)4 大数据的应用领域5 数据分析工具6 数据是五种生产要素之一 二 …

uniapp 开发企业微信小程序,如何区别生产环境和测试环境?来处理不同的服务请求

在 uniapp 开发企业微信小程序时,区分生产环境和测试环境是常见需求。以下是几种可靠的方法,帮助你根据环境处理不同的服务请求: 一、通过条件编译区分(推荐) 使用 uniapp 的 条件编译 语法,在代码中标记…

青少年编程与数学 02-020 C#程序设计基础 15课题、异常处理

青少年编程与数学 02-020 C#程序设计基础 15课题、异常处理 一、异常1. 异常的分类2. 异常的作用小结 二、异常处理1. 异常处理的定义2. 异常处理的主要组成部分3. 异常处理的作用小结 三、C#异常处理1. 异常的基本概念2. 异常处理的关键字3. 异常处理的流程4. 自定义异常5. 异…

云原生时代 Kafka 深度实践:05性能调优与场景实战

5.1 性能调优全攻略 Producer调优 批量发送与延迟发送 通过调整batch.size和linger.ms参数提升吞吐量: props.put(ProducerConfig.BATCH_SIZE_CONFIG, 16384); // 默认16KB props.put(ProducerConfig.LINGER_MS_CONFIG, 10); // 等待10ms以积累更多消息ba…

在 Dify 项目中的 Celery:异步任务的实现与集成

Celery 是一个强大而灵活的分布式任务队列系统,旨在帮助应用程序在后台异步运行耗时的任务,提高系统的响应速度和性能。在 Dify 项目中,Celery 被广泛用于处理异步任务和定时任务,并与其他工具(如 Sentry、OpenTelemet…

Pytorch Geometric官方例程pytorch_geometric/examples/link_pred.py环境安装教程及图数据集制作

最近需要训练图卷积神经网络(Graph Convolution Neural Network, GCNN),在配置GCNN环境上总结了一些经验。 我觉得对于初学者而言,图神经网络的训练会有2个难点: ①环境配置 ②数据集制作 一、环境配置 我最初光想…

2025年微信小程序开发:AR/VR与电商的最新案例

引言 微信小程序自2017年推出以来,已成为中国移动互联网生态的核心组成部分。根据最新数据,截至2025年,微信小程序的日活跃用户超过4.5亿,总数超过430万,覆盖电商、社交、线下服务等多个领域(WeChat Mini …

互联网向左,区块链向右

2008年,中本聪首次提出了比特币的设想,这打开了去中心化的大门。 比特币白皮书清晰的描述了去中心化支付的解决方案,并分别从以下几个方面阐述了他的理念: 一、由转账双方点对点的通讯,而不通过中心化的第三方&#xf…

PV操作的C++代码示例讲解

文章目录 一、PV操作基本概念(一)信号量(二)P操作(三)V操作 二、PV操作的意义三、C中实现PV操作的方法(一)使用信号量实现PV操作代码解释: (二)使…

《对象创建的秘密:Java 内存布局、逃逸分析与 TLAB 优化详解》

大家好呀!今天我们来聊聊Java世界里那些"看不见摸不着"但又超级重要的东西——对象在内存里是怎么"住"的,以及JVM这个"超级管家"是怎么帮我们优化管理的。放心,我会用最接地气的方式讲解,保证连小学…

简单实现Ajax基础应用

Ajax不是一种技术,而是一个编程概念。HTML 和 CSS 可以组合使用来标记和设置信息样式。JavaScript 可以修改网页以动态显示,并允许用户与新信息进行交互。内置的 XMLHttpRequest 对象用于在网页上执行 Ajax,允许网站将内容加载到屏幕上而无需…

详解开漏输出和推挽输出

开漏输出和推挽输出 以上是 GPIO 配置为输出时的内部示意图,我们要关注的其实就是这两个 MOS 管的开关状态,可以组合出四种状态: 两个 MOS 管都关闭时,输出处于一个浮空状态,此时他对其他点的电阻是无穷大的&#xff…

Matlab实现LSTM-SVM回归预测,作者:机器学习之心

Matlab实现LSTM-SVM回归预测,作者:机器学习之心 目录 Matlab实现LSTM-SVM回归预测,作者:机器学习之心效果一览基本介绍程序设计参考资料 效果一览 基本介绍 代码主要功能 该代码实现了一个LSTM-SVM回归预测模型,核心流…

Leetcode - 周赛 452

目录 一,3566. 等积子集的划分方案二,3567. 子矩阵的最小绝对差三,3568. 清理教室的最少移动四,3569. 分割数组后不同质数的最大数目 一,3566. 等积子集的划分方案 题目列表 本题有两种做法,dfs 选或不选…

【FAQ】HarmonyOS SDK 闭源开放能力 —Account Kit(5)

1.问题描述: 集成华为一键登录的LoginWithHuaweiIDButton, 但是Button默认名字叫 “华为账号一键登录”,太长无法显示,能否简写成“一键登录”与其他端一致? 解决方案: 问题分两个场景: 一、…

Asp.Net Core SignalR的分布式部署

文章目录 前言一、核心二、解决方案架构三、实现方案1.使用 Azure SignalR Service2.Redis Backplane(Redis 背板方案)3.负载均衡配置粘性会话要求无粘性会话方案(仅WebSockets)完整部署示例(Redis Docker)性能优化技…

L2-054 三点共线 - java

L2-054 三点共线 语言时间限制内存限制代码长度限制栈限制Java (javac)2600 ms512 MB16KB8192 KBPython (python3)2000 ms256 MB16KB8192 KB其他编译器2000 ms64 MB16KB8192 KB 题目描述: 给定平面上 n n n 个点的坐标 ( x _ i , y _ i ) ( i 1 , ⋯ , n ) (x\_i…

【 java 基础知识 第一篇 】

目录 1.概念 1.1.java的特定有哪些? 1.2.java有哪些优势哪些劣势? 1.3.java为什么可以跨平台? 1.4JVM,JDK,JRE它们有什么区别? 1.5.编译型语言与解释型语言的区别? 2.数据类型 2.1.long与int类型可以互转吗&…

高效背诵英语四级范文

以下是结合认知科学和实战验证的 ​​高效背诵英语作文五步法​​,助你在30分钟内牢固记忆一篇作文,特别适配考前冲刺场景: 📝 ​​一、解构作文(5分钟)​​ ​​拆解逻辑框架​​ 用荧光笔标出&#xff…

RHEL7安装教程

RHEL7安装教程 下载RHEL7镜像 通过网盘分享的文件:RHEL 7.zip 链接: https://pan.baidu.com/s/1ExLhdJigj-tcrHJxIca5XA?pwdjrrj 提取码: jrrj --来自百度网盘超级会员v6的分享安装 1.打开VMware,新建虚拟机,选择自定义然后下一步 2.点击…