DAY43打卡

@浙大疏锦行

kaggle找到一个图像数据cnn网络进行训练并且grad-cam可视化

进阶并拆分成多个文件

fruit_cnn_project/
├─ data/                # 存放数据集(需手动创建,后续放入图片)
│  ├─ train/            # 训练集图像
│  └─ val/              # 验证集图像
├─ models/              # 模型定义
│  └─ cnn_model.py      # CNN网络结构
├─ utils/               # 工具函数
│  ├─ dataset_utils.py  # 数据加载与预处理
│  ├─ grad_cam.py       # Grad-CAM可视化
│  └─ train_utils.py    # 训练与评估
├─ main.py              # 主程序
└─ requirements.txt     # 依赖列表(可选)
# 第一部分:导入库
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline# 第二部分:数据加载与预处理
def load_data():data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])train_dataset = datasets.ImageFolder(root='data/train', transform=data_transform)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)test_dataset = datasets.ImageFolder(root='data/test', transform=data_transform)test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)return train_loader, test_loader# 第三部分:模型定义
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)self.fc1 = nn.Linear(32 * 56 * 56, 128)self.fc2 = nn.Linear(128, 2)def forward(self, x):x = self.pool(self.relu(self.conv1(x)))x = self.pool(self.relu(self.conv2(x)))x = x.view(-1, 32 * 56 * 56)x = self.relu(self.fc1(x))x = self.fc2(x)return x# 第四部分:模型训练
train_loader, _ = load_data()
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 10for epoch in range(num_epochs):running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch {epoch + 1}, Loss: {running_loss / len(train_loader)}')torch.save(model.state_dict(), 'trained_model.pth')# 第五部分:模型测试
_, test_loader = load_data()
model = SimpleCNN()
model.load_state_dict(torch.load('trained_model.pth'))
model.eval()
correct = 0
total = 0with torch.no_grad():for data in test_loader:images, labels = dataoutputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the network on the test images: {100 * correct / total}%')# 第六部分:Grad-CAM可视化(修复版)
def get_activation():activation = {}def hook(model, input, output):activation['target_layer'] = output.detach()return hook, activationdef grad_cam(model, image, target_class_index):hook, activation = get_activation()target_layer = model.conv2target_layer.register_forward_hook(hook)model.eval()image = image.unsqueeze(0)image.requires_grad_(True)output = model(image)one_hot = torch.zeros(1, output.size()[-1]).to(image.device)one_hot[0][target_class_index] = 1output.backward(gradient=one_hot, retain_graph=True)gradients = image.grad[0].cpu().numpy()# 从activation字典中获取激活图activation_map = activation['target_layer'].cpu().numpy()[0]weights = np.mean(gradients, axis=(1, 2))cam = np.zeros(activation_map.shape[1:], dtype=np.float32)for i, w in enumerate(weights):cam += w * activation_map[i]cam = np.maximum(cam, 0)cam = F.interpolate(torch.from_numpy(cam).unsqueeze(0).unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False)[0][0].numpy()cam = (cam - cam.min()) / (cam.max() - cam.min())return cam# 可视化前几张测试图片
dataiter = iter(test_loader)
images, labels = dataiter.next()for i in range(5):  # 可视化前5张图片image = images[i]label = labels[i].item()cam = grad_cam(model, image, label)plt.figure(figsize=(10, 5))plt.subplot(1, 2, 1)plt.imshow(image.permute(1, 2, 0).numpy())plt.title(f'Original Image (Class: {label})')plt.axis('off')plt.subplot(1, 2, 2)plt.imshow(image.permute(1, 2, 0).numpy())plt.imshow(cam, cmap='jet', alpha=0.5)plt.title('Grad-CAM Visualization')plt.axis('off')plt.tight_layout()plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/pingmian/83327.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[蓝桥杯C++ 2024 国 B ] 立定跳远(二分)

题目描述 在运动会上,小明从数轴的原点开始向正方向立定跳远。项目设置了 n n n 个检查点 a 1 , a 2 , ⋯ , a n a_1, a_2, \cdots , a_n a1​,a2​,⋯,an​ 且 a i ≥ a i − 1 > 0 a_i \ge a_{i−1} > 0 ai​≥ai−1​>0。小明必须先后跳跃到每个检查…

LINUX530 rsync定时同步 环境配置

rsync定时代码同步 环境配置 关闭防火墙 selinux systemctl stop firewalld systemctl disable firewalld setenforce 0 vim /etc/selinux/config SELINUXdisable设置主机名 hostnamectl set-hostname code hostnamectl set-hostname backup设置静态地址 cd /etc/sysconfi…

鸿蒙OSUniApp结合机器学习打造智能图像分类应用:HarmonyOS实践指南#三方框架 #Uniapp

UniApp结合机器学习打造智能图像分类应用:HarmonyOS实践指南 引言 在移动应用开发领域,图像分类是一个既经典又充满挑战的任务。随着机器学习技术的发展,我们现在可以在移动端实现高效的图像分类功能。本文将详细介绍如何使用UniApp结合Ten…

【Redis】大key问题详解

目录 1、什么是大key2、大key的危害【1】阻塞风险【2】网络阻塞【3】内存不均【4】持久化问题 3、如何发现大key【1】使用内置命令【2】使用memory命令(Redis 4.0)【3】使用scan命令【4】监控工具 4、解决方案【1】拆分大key【2】使用合适的数据结构【3】…

redis核心知识点

Redis是一种基于内存的数据库,对数据的读写操作都是在内存中完成,因此读写速度非常快,常用于缓存,消息队列、分布式锁等场景。 Redis 提供了多种数据类型来支持不同的业务场景,比如 String(字符串)、Hash(哈希)、 Lis…

vscode不满足先决条件问题的解决——vscode的老版本安装与禁止更新(附安装包)

目录 起因 vscode更新设置的关闭 安装包 结语 起因 由于主包用的系统是centos的,且版本有点老了,再加上vscode现在不支持老版本的,这对主包来说更是雪上加霜啊 但是主包看了网上很多教程,眼花缭乱,好多配置要改&…

如何成为一名优秀的产品经理(自动驾驶)

一、 夯实核心基础 深入理解智能驾驶技术栈: 感知: 摄像头、雷达(毫米波、激光雷达)、超声波传感器的工作原理、优缺点、融合策略。了解目标检测、跟踪、SLAM等基础算法概念。 定位: GNSS、IMU、高精地图、轮速计等定…

【ISAQB大纲解读】信息隐藏指的是什么

在软件架构中,信息隐藏(Information Hiding) 是核心设计原则之一,由 David Parnas 在 1972 年提出。它强调通过限制对模块内部实现细节的访问,来降低系统复杂度、提高可维护性和可扩展性。在 ISAQB 的学习目标&#xf…

网页前端开发(基础进阶2--JS)

前面学习了html与css,接下来学习JS(JavaScript与Java无关)。 web标准(网页标准)分为3个部分: 1.html主要负责网页的结构(页面的元素和内容) 2.css主要负责网页的表现(…

完全移除内联脚本

说明 日期&#xff1a;2025年5月9日。 内联脚本给跨站脚本攻击&#xff08;XSS&#xff09;留了条路。 示例 日期&#xff1a;2025年5月9日。 如下网页文件a.html&#xff1a; <!-- 内联脚本块 --> <script> function handleClick{ alert("Hello")…

[蓝桥杯]约瑟夫环

约瑟夫环 题目描述 nn 个人的编号是 1 ~ nn&#xff0c;如果他们依编号按顺时针排成一个圆圈&#xff0c;从编号是 1 的人开始顺时针报数。 &#xff08;报数是从 1 报起&#xff09;当报到 kk 的时候&#xff0c;这个人就退出游戏圈。下一个人重新从 1 开始报数。 求最后剩…

电子电气架构 --- 如何应对未来区域式电子电气(E/E)架构的挑战?

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 做到欲望极简,了解自己的真实欲望,不受外在潮流的影响,不盲从,不跟风。把自己的精力全部用在自己。一是去掉多余,凡事找规律,基础是诚信;二是…

isp中的 ISO代表什么意思

isp中的 ISO代表什么意思 在摄影和图像信号处理&#xff08;ISP&#xff0c;Image Signal Processor&#xff09;领域&#xff0c;ISO是一个用于衡量相机图像传感器对光线敏感度的标准参数。它最初源于胶片摄影时代的 “国际标准化组织&#xff08;International Organization …

第十二节:第五部分:集合框架:Set集合的特点、底层原理、哈希表、去重复原理

Set系列集合特点 哈希值 HashSet集合的底层原理 HashSet集合去重复 代码 代码一&#xff1a;整体了解一下Set系列集合的特点 package com.itheima.day20_Collection_set;import java.util.HashSet; import java.util.LinkedHashSet; import java.util.Set; import java.util.…

迈向分布式智能:解析MCP到A2A的通信范式迁移

智能体与外部世界的桥梁之言&#xff1a; 在深入探讨智能体之间的协作机制之前&#xff0c;我们有必要先厘清一个更基础的问题&#xff1a;**单个智能体如何与外部世界建立连接&#xff1f;** 这就引出了我们此前介绍过的 **MCP&#xff08;Model Context Protocol&…

Android Studio 配置之gitignore

1.创建或编辑.gitignore文件 在项目根目录下检查是否已有.gitignore文件。如果没有&#xff0c;创建一个新文件&#xff0c;命名为.gitignore&#xff08;注意文件名前有个点&#xff09;。 添加忽略规则&#xff1a;在.gitignore中添加以下内容&#xff1a; 忽略整个 .idea …

算法:二分查找

1.二分查找 704. 二分查找 - 力扣&#xff08;LeetCode&#xff09; 二分查找算法要确定“二段性”&#xff0c;时间复杂度为O(lonN)。为了防止数据溢出&#xff0c;所以求mid时要用防溢出的方式。 class Solution { public:int search(vector<int>& nums, int tar…

day62—DFS—太平洋大西洋水流问题(LeetCode-417)

题目描述 有一个 m n 的矩形岛屿&#xff0c;与 太平洋 和 大西洋 相邻。 “太平洋” 处于大陆的左边界和上边界&#xff0c;而 “大西洋” 处于大陆的右边界和下边界。 这个岛被分割成一个由若干方形单元格组成的网格。给定一个 m x n 的整数矩阵 heights &#xff0c; hei…

Langchaine4j 流式输出 (6)

Langchaine4j 流式输出 大模型的流式输出是指大模型在生成文本或其他类型的数据时&#xff0c;不是等到整个生成过程完成后再一次性 返回所有内容&#xff0c;而是生成一部分就立即发送一部分给用户或下游系统&#xff0c;以逐步、逐块的方式返回结果。 这样&#xff0c;用户…

自动驾驶与智能交通:构建未来出行的智能引擎

随着人工智能、物联网、5G和大数据等前沿技术的发展&#xff0c;自动驾驶汽车和智能交通系统正以前所未有的速度改变人类的出行方式。这一变革不仅是技术的融合创新&#xff0c;更是推动城市可持续发展的关键支撑。 一、自动驾驶与智能交通的定义 1. 自动驾驶&#xff08;Auto…