python打卡day47

特征图与注意力热图
知识点回顾:

  1. 不同CNN层的特征图:不同通道的特征图
  2. 通道注意力后的特征图和热力图

特征图本质就是不同的卷积核的输出,浅层指的是离输入图近的卷积层,浅层卷积层的特征图通常较大,而深层特征图会经过多次下采样,尺寸显著缩小,尺寸差异过大时,小尺寸特征图在视觉上会显得模糊或丢失细节。步骤逻辑如下:

1. 初始化设置:

  • 将模型设为评估模式,准备类别名称列表

2.数据加载与处理:

  • 从测试数据加载器中获取图像和标签
  • 限制处理量以避免冗余,仅处理前 `num_images` 张图像(如2张)

3. 注册钩子捕获特征图:

  • 为指定层(如 `conv1`, `conv2`, `conv3`)注册前向钩子
  • 钩子函数将这些层的输出(特征图)保存到字典中(以层名为键,特征图(Tensor)为值)

4. 前向传播与特征提取:

  • 模型处理图像,触发钩子函数,获取并保存特征图
  • 移除钩子,避免后续干扰

5. 可视化特征图:对每张图像

  • 恢复原始像素值并显示(反标准化)
  • 为每个目标层创建子图,展示前 `num_channels` 个通道的特征图(每个通道的特征图是单通道灰度图,需独立显示)
  • 每个通道的特征图以网格形式排列,显示通道编号
def visualize_feature_maps(model, test_loader, device, layer_names, num_images=3, num_channels=9):"""可视化指定层的特征图(修复循环冗余问题)参数:model: 模型test_loader: 测试数据加载器layer_names: 要可视化的层名称(如['conv1', 'conv2', 'conv3'])num_images: 可视化的图像总数num_channels: 每个图像显示的通道数(取前num_channels个通道)"""model.eval()  # 设置为评估模式class_names = ['飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车']# 从测试集加载器中提取指定数量的图像(避免嵌套循环)images_list, labels_list = [], []for images, labels in test_loader:images_list.append(images)labels_list.append(labels)if len(images_list) * test_loader.batch_size >= num_images:break# 拼接并截取到目标数量images = torch.cat(images_list, dim=0)[:num_images].to(device) # torch.cat拼接多个batch的数据labels = torch.cat(labels_list, dim=0)[:num_images].to(device)with torch.no_grad():# 存储各层特征图feature_maps = {}# 保存钩子句柄hooks = []# 定义钩子函数,捕获指定层的输出def hook(module, input, output, name):feature_maps[name] = output.cpu()  # 保存特征图到字典,保存到CPU,避免GPU内存占用# 为每个目标层注册钩子,并保存钩子句柄for name in layer_names:module = getattr(model, name) # 根据层名获取模块hook_handle = module.register_forward_hook(lambda m, i, o, n=name: hook(m, i, o, n)) # 绑定层名到钩子hooks.append(hook_handle)# 前向传播触发钩子_ = model(images)# 正确移除钩子for hook_handle in hooks:hook_handle.remove()# 可视化每个图像的各层特征图(仅一层循环)for img_idx in range(num_images):img = images[img_idx].cpu().permute(1, 2, 0).numpy() # [C,H,W]→[H,W,C]# 反标准化处理(恢复原始像素值)img = img * np.array([0.2023, 0.1994, 0.2010]).reshape(1, 1, 3) + np.array([0.4914, 0.4822, 0.4465]).reshape(1, 1, 3)img = np.clip(img, 0, 1)  # 确保像素值在[0,1]范围内# 创建子图num_layers = len(layer_names)fig, axes = plt.subplots(1, num_layers + 1, figsize=(4 * (num_layers + 1), 4))# 显示原始图像axes[0].imshow(img)axes[0].set_title(f'原始图像\n类别: {class_names[labels[img_idx]]}')axes[0].axis('off')# 显示各层特征图for layer_idx, layer_name in enumerate(layer_names):fm = feature_maps[layer_name][img_idx]  # 取第img_idx张图像的特征图fm = fm[:num_channels]  # 仅取前num_channels个通道num_rows = int(np.sqrt(num_channels)) # 计算网格行数num_cols = num_channels // num_rows if num_rows != 0 else 1 # 计算网格列数# 创建子图网格layer_ax = axes[layer_idx + 1]layer_ax.set_title(f'{layer_name}特征图 \n')# 加个换行让文字分离上去layer_ax.axis('off')  # 关闭大子图的坐标轴# 在大子图内创建小网格for ch_idx, channel in enumerate(fm):ax = layer_ax.inset_axes([ch_idx % num_cols / num_cols, (num_rows - 1 - ch_idx // num_cols) / num_rows, 1/num_cols, 1/num_rows])ax.imshow(channel.numpy(), cmap='viridis')ax.set_title(f'通道 {ch_idx + 1}')ax.axis('off')plt.tight_layout()plt.show()# 调用示例(按需修改参数)
layer_names = ['conv1', 'conv2', 'conv3']
visualize_feature_maps(model=model,test_loader=test_loader,device=device,layer_names=layer_names,num_images=5,  # 可视化5张测试图像 → 输出5张大图num_channels=9   # 每张图像显示前9个通道的特征图
)

可以看到特征逐层抽象,从“看得见的细节”(conv1)→ “局部结构”(conv2)→ “类别相关的抽象模式”(conv3),模型通过这种方式实现从“看图像”到“理解语义”的跨越;并且每层各个通道也聚焦了不同特征(如通道1检测边缘轮廓,通道2检测纹理细节...),共同协作完成分类任务

 分别针对不同层次的模型理解需求的三种主流的神经网络可视化方法,就差注意力热图没学了,一般用于检查注意力模块是否聚焦与任务相关的关键通道或者空间,具体来说:

  1. 通道注意力热图:展示不同通道的权重分布,例如,在图像分类任务中,模型可能会对包含目标物体的通道赋予更高的注意力权重

  2. 空间注意力热图:展示输入图像中不同空间位置的权重分布,例如,在目标检测或图像分类任务中,模型可能会对目标物体所在的区域分配更高的注意力权重

注意力热图的步骤和特征图差不多,就是注意一下钩子注册位置仅在最后一个卷积层,就用昨天加入了通道注意力模块的cnn模型训练的结果来可视化

# 可视化空间注意力热力图(显示模型关注的图像区域)
def visualize_attention_map(model, test_loader, device, class_names, num_samples=3):"""可视化模型的注意力热力图,展示模型关注的图像区域"""model.eval()  # 设置为评估模式with torch.no_grad():for i, (images, labels) in enumerate(test_loader):if i >= num_samples:  # 只可视化前几个样本breakimages, labels = images.to(device), labels.to(device)# 创建一个钩子,捕获中间特征图activation_maps = []def hook(module, input, output):activation_maps.append(output.cpu())# 为最后一个卷积层注册钩子(获取特征图)hook_handle = model.conv3.register_forward_hook(hook)# 前向传播,触发钩子outputs = model(images)# 移除钩子hook_handle.remove()# 获取预测结果_, predicted = torch.max(outputs, 1)# 获取原始图像img = images[0].cpu().permute(1, 2, 0).numpy()# 反标准化处理img = img * np.array([0.2023, 0.1994, 0.2010]).reshape(1, 1, 3) + np.array([0.4914, 0.4822, 0.4465]).reshape(1, 1, 3)img = np.clip(img, 0, 1)# 获取激活图(最后一个卷积层的输出)feature_map = activation_maps[0][0].cpu()  # 取第一个样本# 计算通道注意力权重(使用SE模块的全局平均池化)channel_weights = torch.mean(feature_map, dim=(1, 2))  # [C]# 按权重对通道排序sorted_indices = torch.argsort(channel_weights, descending=True)# 创建子图fig, axes = plt.subplots(1, 4, figsize=(16, 4))# 显示原始图像axes[0].imshow(img)axes[0].set_title(f'原始图像\n真实: {class_names[labels[0]]}\n预测: {class_names[predicted[0]]}')axes[0].axis('off')# 显示前3个最活跃通道的热力图for j in range(3):channel_idx = sorted_indices[j]# 获取对应通道的特征图channel_map = feature_map[channel_idx].numpy()# 归一化到[0,1]channel_map = (channel_map - channel_map.min()) / (channel_map.max() - channel_map.min() + 1e-8)# 调整热力图大小以匹配原始图像from scipy.ndimage import zoomheatmap = zoom(channel_map, (32/feature_map.shape[1], 32/feature_map.shape[2]))# 显示热力图axes[j+1].imshow(img)axes[j+1].imshow(heatmap, alpha=0.5, cmap='jet')axes[j+1].set_title(f'注意力热力图 - 通道 {channel_idx}')axes[j+1].axis('off')plt.tight_layout()plt.show()# 调用可视化函数
visualize_attention_map(model, test_loader, device, class_names, num_samples=3)

关于全局平均池化 torch.mean vs nn.AdaptiveAvgPool2d 两种方法其实功能上完全等价,只不过前者是直接操作张量的函数,更适合后处理;后者适合神经网络。再明确一点,假设输入特征图[C, H, W]:

  • torch.mean(dim=(1, 2)) → 输出 [C](每个通道的标量均值)
  • nn.AdaptiveAvgPool2d(1) → 输出 [C, 1, 1](可通过 squeeze() 转为 [C],所以这里直接注册钩子捕获这一层输出得到通道权重也行

当前代码的热力图是单通道特征图的归一化激活强度,再叠加到原图上,并不是严格意义上的注意力热图,严谨一点的话应该用sigmoid后的通道权重来画直方图

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/news/908559.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

缓存一致性 与 执行流

上接多执行流系统中的可见性 在缓存一致性协议描述中,使用“处理器”或“CPU核心”比“执行流”更精确吗? 核心结论:在缓存一致性协议描述中,使用“处理器”或“CPU核心”比“执行流”更精确! 你的直觉是正确的。 原因分析&am…

机器学习:load_predict_project

本文目录: 一、project目录二、utils里的两个工具包(一)common.py(二)log.py 三、src文件夹代码(一)模型训练(train.py)(二)模型预测(…

Qt Test功能及架构

Qt Test 是 Qt 框架中的单元测试模块,在 Qt 6.0 中提供了全面的测试功能。 一、主要功能 核心功能 1. 单元测试框架 提供完整的单元测试基础设施 支持测试用例、测试套件的组织和执行 包含断言宏和测试结果收集 2. 测试类型支持 单元测试:对单个函…

零基础在实践中学习网络安全-皮卡丘靶场(第十一期-目录遍历模块)

经过前面几期的内容我们学习了很多网络安全的知识,而这期内容就涉及到了前面的第六期-RCE模块,第七期-File inclusion模块,第八期-Unsafe Filedownload模块。 什么是"遍历"呢:对学过一些开发语言的朋友来说应该知道&…

LLM 笔记:Speculative Decoding 投机采样

1 基本介绍 投机采样(Speculative Sampling)是一种并行预测多个可能输出,然后快速验证并采纳正确部分的加速策略 在不牺牲输出质量的前提下,减少语言模型生成 token 所需的时间 传统的语言模型生成是 串行 的 必须生成一个&…

Mysql批处理写入数据库

在学习mybatisPlus时,看到一个原本没用过的参数: rewriteBatchedStatementstrue 将上述代码装入jdbc的url中即可使数据库启用批处理写入。 需要注意的是,这个参数仅适用于MySQL JDBC 驱动的私有扩展参数。 作用原理是: 原本的…

数据类型--实型

C中的实型(也称为浮点型,Floating Point Type)用于表示带有小数部分的数值。 常见的实型有 float、double 和 long double,它们在精度和存储空间上有所不同。 1. 常见实型及其特性 类型字节数(通常)精度&…

引领AI安全新时代 Accelerate 2025北亚巡展·北京站成功举办

6月5日,网络安全行业年度盛会——"Accelerate 2025北亚巡展北京站"圆满落幕!来自智库、产业界、Fortinet管理层及技术团队的权威专家,与来自各行业的企业客户代表齐聚一堂,围绕"AI智御全球引领安全新时代"主题…

coze平台创建智能体,关于智能体后端接入的问题

一、智能体的插件在coze平台能正常调用,在Apifox中测试,它却直接回复直接回复“人设”或“知识库”,你的提问等内容: 为什么会这样?: Coze官方的插件(工具调用)机制是“分步交互式”…

Shell编程核心符号与格式化操作详解

Shell编程作为Linux系统管理和自动化运维的核心技能,掌握其常用符号和格式化操作是提升脚本开发效率的关键。本文将深入解析Shell中重定向、管道符、EOF、输入输出格式化等核心概念,并通过丰富的实践案例帮助读者掌握这些重要技能。 一、信息传递与重定…

C++课设:简易科学计算器(支持+-*/、sin、cos、tan、log等科学函数)

名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 专栏介绍:《编程项目实战》 目录 一、项目概览与设计理念1. 功能特色2. 技…

WPF八大法则:告别模态窗口卡顿

⚙️ 核心问题:阻塞式模态窗口的缺陷 原始代码中ShowDialog()会阻塞UI线程,导致后续逻辑无法执行: var result modalWindow.ShowDialog(); // 线程阻塞 ProcessResult(result); // 必须等待窗口关闭根本问题&#xff1a…

UOS无法安装deb软件包

UOS无法安装deb软件包 问题描述解决办法: 关闭安全中心的应用隔离结果验证 问题描述 UOS安装Linux微信的deb包时,无法正常安装 解决办法: 关闭安全中心的应用隔离 要关闭-安全中心的应用隔离后才可以正常软件和运行。 应用安全----》 允许任意应用。 结果验证 # …

鸿蒙jsonToArkTS_工具exe版本来了

前言导读 相信大家在学习鸿蒙开发过程中最痛苦的就是编写model 类 特别是那种复杂的json的时候对不对, 这时候有一个自动化的工具给你生成model是不是很开心。我们今天要分享的就是这个工具 JsonToArkTs 的用法 工具地址 https://gitee.com/qiuyu123/jsontomodel…

【Java算法】八大排序

八大排序算法 目录 注意:以下排序均属于内部排序 (1)插入排序 直接插入排序 改进版本 折半插入排序 希尔排序 (2)交换排序 冒泡排序 快速排序 (3)选择排序 简单选择排序 堆排序&…

玩转Docker | 使用Docker部署Qwerty Learner英语单词学习网站

玩转Docker | 使用Docker部署Qwerty Learner英语单词学习网站 前言一、Qwerty Learner简介Qwerty Learner 简介主要特点二、系统要求环境要求环境检查Docker版本检查检查操作系统版本三、部署Qwerty Learner服务下载Qwerty Learner镜像编辑部署文件创建容器检查容器状态检查服务…

Vue3中computed和watch的区别

文章目录 前言🔍 一、computed vs watch✅ 示例对比1. computed 示例(适合模板绑定、衍生数据)2. watch 示例(副作用,如调用接口) 🧠 二、源码实现原理(简化理解)1. comp…

C++修炼:C++11(二)

Hello大家好&#xff01;很高兴我们又见面啦&#xff01;给生活添点passion&#xff0c;开始今天的编程之路&#xff01; 我的博客&#xff1a;<但凡. 我的专栏&#xff1a;《编程之路》、《数据结构与算法之美》、《题海拾贝》、《C修炼之路》 欢迎点赞&#xff0c;关注&am…

单元测试与QTestLib框架使用

一.单元测试的意义 在软件开发中&#xff0c;单元测试是指对软件中最小可测试单元&#xff08;通常是函数、类的方法&#xff09;进行隔离的、可重复的验证。进行单元测试具有以下重要意义&#xff1a; 1.提升代码质量与可靠性&#xff1a; 早期错误检测&#xff1a; 在开发…

(附实现代码)Step-Back 回答回退策略扩大检索范围

1. LangChain 少量示例提示模板 在与 LLM 的对话中&#xff0c;提供少量的示例被称为 少量示例&#xff0c;这是一种简单但强大的指导生成的方式&#xff0c;在某些情况下可以显著提高模型性能&#xff08;与之对应的是零样本&#xff09;&#xff0c;少量示例可以降低 Prompt…