python打卡day42

Grad-CAM与Hook函数

知识点回顾

  1. 回调函数
  2. lambda函数
  3. hook函数的模块钩子和张量钩子
  4. Grad-CAM的示例

在深度学习中,我们经常需要查看或修改模型中间层的输出或梯度,但标准的前向传播和反向传播过程通常是一个黑盒,很难直接访问中间层的信息。PyTorch 提供了一种强大的工具——hook 函数,它允许我们在不修改模型结构的情况下,获取或修改中间层的信息。常用场景如下:

  1. 调试与可视化中间层输出
  2. 特征提取:如在图像分类模型中提取高层语义特征用于下游任务
  3. 梯度分析与修改: 在训练过程中,对某些层进行梯度裁剪或缩放,以改变模型训练的动态
  4. 模型压缩:在推理阶段对特定层的输出应用掩码(如剪枝后的模型权重掩码),实现轻量化推理

1、回调函数

Hook本质是回调函数,所以我们先介绍一下回调函数。回调函数是作为参数传递给其他函数的函数,其目的是在某个特定事件发生时被调用执行。这种机制允许代码在运行时动态指定需要执行的逻辑,其中回调函数作为参数传入,所以在定义的时候一般用callback来命名

在 PyTorch 的 Hook API 中,回调参数通常命名为 hook,PyTorch 的 Hook 机制基于其动态计算图系统:

  1. 当你注册一个 Hook 时,PyTorch 会在计算图的特定节点(如模块或张量)上添加一个回调函数
  2. 当计算图执行到该节点时(前向或反向传播),自动触发对应的 Hook 函数
  3. Hook 函数可以访问或修改流经该节点的数据(如输入、输出或梯度)

2、lambda函数

在hook中常常用到lambda函数,它是一种匿名函数(没有正式名称的函数),最大特点是用完即弃,无需提前命名和定义。它的语法形式非常简约,仅需一行即可完成定义,格式:lambda 参数列表: 表达式

  • 参数列表:可以是单个参数、多个参数或无参数
  • 表达式:函数的返回值(无需 return 语句,表达式结果直接返回)

举个例子

# 定义匿名函数:计算平方
square = lambda x: x ** 2# 调用
print(square(5))  # 输出: 25

3、hook函数

PyTorch 提供了两种主要的 hook:

  • Module Hooks(模块钩子):用于监听整个模块的输入和输出
  • Tensor Hooks:用于监听张量的梯度

(1)模块钩子

允许我们在模块的输入或输出经过时进行监听。PyTorch 提供了两种模块钩子:

  1. register_forward_hook:在模块的前向传播完成后立即被调用,这个函数可以访问模块的输入和输出,但不能修改
  2. register_backward_hook:在反向传播过程中被调用的,可以用来获取或修改梯度信息

前向钩子举个例子

# 创建模型实例
model = SimpleModel()# 创建一个列表用于存储中间层的输出
conv_outputs = []# 定义前向钩子函数 - 用于在模型前向传播过程中获取中间层信息
def forward_hook(module, input, output):print(f"钩子被调用!模块类型: {type(module)}")print(f"输入形状: {input[0].shape}") #  input是一个元组,对应 (image, label)print(f"输出形状: {output.shape}")# 保存卷积层的输出用于后续分析# 使用detach()避免追踪梯度,防止内存泄漏conv_outputs.append(output.detach())# 在卷积层注册前向钩子
# register_forward_hook返回一个句柄,用于后续移除钩子
hook_handle = model.conv.register_forward_hook(forward_hook)# 创建一个随机输入张量 (批次大小=1, 通道=1, 高度=4, 宽度=4)
x = torch.randn(1, 1, 4, 4)# 执行前向传播 - 此时会自动触发钩子函数
output = model(x)# 释放钩子 - 重要!防止在后续模型使用中持续调用钩子造成意外行为或内存泄漏
hook_handle.remove()

反向钩子

# 定义一个存储梯度的列表
conv_gradients = []# 定义反向钩子函数
def backward_hook(module, grad_input, grad_output):print(f"反向钩子被调用!模块类型: {type(module)}")print(f"输入梯度数量: {len(grad_input)}")print(f"输出梯度数量: {len(grad_output)}")# 保存梯度供后续分析conv_gradients.append((grad_input, grad_output))# 在卷积层注册反向钩子
hook_handle = model.conv.register_backward_hook(backward_hook)# 创建一个随机输入并进行前向传播
x = torch.randn(1, 1, 4, 4, requires_grad=True)
output = model(x)# 定义一个简单的损失函数并进行反向传播
loss = output.sum()
loss.backward()# 释放钩子
hook_handle.remove()

(2)张量钩子

PyTorch 还提供了张量钩子,允许我们直接监听和修改张量的梯度。张量钩子有两种:

  1. register_hook:用于监听张量的梯度
  2. register_full_backward_hook:用于在完整的反向传播过程中监听张量的梯度
# 创建一个需要计算梯度的张量
x = torch.tensor([2.0], requires_grad=True)
y = x ** 2
z = y ** 3# 定义一个钩子函数,用于修改梯度
def tensor_hook(grad):print(f"原始梯度: {grad}")# 修改梯度,例如将梯度减半return grad / 2# 在y上注册钩子
hook_handle = y.register_hook(tensor_hook)# 计算梯度,梯度会从z反向传播经过y到x,此时调用钩子函数
z.backward()print(f"x的梯度: {x.grad}")# 释放钩子
hook_handle.remove()

4、Grad-CAM

一个可视化算法,通过梯度信息用热力图显示图片中哪些区域让CNN做出了某个分类决定(比如为什么认为这是“猫”),原理:

  • 梯度计算:看最后几层特征图的梯度,哪个特征图对预测“猫”的贡献大
  • 加权融合:把重要的特征图合并成一张热力图(重要区域更亮)
  • 叠加显示:把热力图盖在原图上,一眼看出猫的脸/耳朵等关键部位被高亮了
# Grad-CAM实现
class GradCAM:def __init__(self, model, target_layer):self.model = modelself.target_layer = target_layerself.gradients = Noneself.activations = None# 注册钩子,用于获取目标层的前向传播输出和反向传播梯度self.register_hooks()def register_hooks(self):# 前向钩子函数,在目标层前向传播后被调用,保存目标层的输出(激活值)def forward_hook(module, input, output):self.activations = output.detach()# 反向钩子函数,在目标层反向传播后被调用,保存目标层的梯度def backward_hook(module, grad_input, grad_output):self.gradients = grad_output[0].detach()# 在目标层注册前向钩子和反向钩子self.target_layer.register_forward_hook(forward_hook)self.target_layer.register_backward_hook(backward_hook)def generate_cam(self, input_image, target_class=None):# 前向传播,得到模型输出model_output = self.model(input_image)if target_class is None:# 如果未指定目标类别,则取模型预测概率最大的类别作为目标类别target_class = torch.argmax(model_output, dim=1).item()# 清除模型梯度,避免之前的梯度影响self.model.zero_grad()# 反向传播,构造one-hot向量,使得目标类别对应的梯度为1,其余为0,然后进行反向传播计算梯度one_hot = torch.zeros_like(model_output)one_hot[0, target_class] = 1model_output.backward(gradient=one_hot)# 获取之前保存的目标层的梯度和激活值gradients = self.gradientsactivations = self.activations# 对梯度进行全局平均池化,得到每个通道的权重,用于衡量每个通道的重要性weights = torch.mean(gradients, dim=(2, 3), keepdim=True)# 加权激活映射,将权重与激活值相乘并求和,得到类激活映射的初步结果cam = torch.sum(weights * activations, dim=1, keepdim=True)# ReLU激活,只保留对目标类别有正贡献的区域,去除负贡献的影响cam = F.relu(cam)# 调整大小并归一化,将类激活映射调整为与输入图像相同的尺寸(32x32),并归一化到[0, 1]范围cam = F.interpolate(cam, size=(32, 32), mode='bilinear', align_corners=False)cam = cam - cam.min()cam = cam / cam.max() if cam.max() > 0 else camreturn cam.cpu().squeeze().numpy(), target_classidx = 102  # 选择测试集中的第101张图片 (索引从0开始)
image, label = testset[idx]
print(f"选择的图像类别: {classes[label]}")# 转换图像以便可视化
def tensor_to_np(tensor):img = tensor.cpu().numpy().transpose(1, 2, 0)mean = np.array([0.5, 0.5, 0.5])std = np.array([0.5, 0.5, 0.5])img = std * img + meanimg = np.clip(img, 0, 1)return img# 添加批次维度并移动到设备
input_tensor = image.unsqueeze(0).to(device)# 初始化Grad-CAM(选择最后一个卷积层)
grad_cam = GradCAM(model, model.conv3)# 生成热力图
heatmap, pred_class = grad_cam.generate_cam(input_tensor)# 可视化
plt.figure(figsize=(12, 4))# 原始图像
plt.subplot(1, 3, 1)
plt.imshow(tensor_to_np(image))
plt.title(f"原始图像: {classes[label]}")
plt.axis('off')# 热力图
plt.subplot(1, 3, 2)
plt.imshow(heatmap, cmap='jet')
plt.title(f"Grad-CAM热力图: {classes[pred_class]}")
plt.axis('off')# 叠加的图像
plt.subplot(1, 3, 3)
img = tensor_to_np(image)
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6
plt.imshow(superimposed_img)
plt.title("叠加热力图")
plt.axis('off')plt.tight_layout()
plt.savefig('grad_cam_result.png')
plt.show()

@浙大疏锦行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/bicheng/83263.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

中国风展示工作总结商务通用PPT模版

中国风展示工作总结商务通用PPT模版:中国风商务通用PPT 模版https://pan.quark.cn/s/42ad18c010d4

TeleAI发布TeleChat2.5及T1正式版,双双开源上线魔乐社区!

5月12日,中国电信开源TeleChat系列四个模型,涵盖复杂推理和通用问答的多个尺寸模型,包括TeleChat-T1-35B、TeleChat-T1-115B、TeleChat2.5-35B和TeleChat2.5-115B,实测模型性能均有显著的性能效果。TeleChat系列模型基于昇思MindS…

机器视觉2D定位引导一般步骤

机器视觉的2D定位引导是工业自动化中的核心应用,主要用于精确确定目标物体的位置(X, Y坐标)和角度(旋转角度θ),并引导机器人或运动机构进行抓取、装配、对位、检测等操作。其一般步骤可概括如下: 一、系统规划与硬件选型 明确需求: 定位精度要求(多少毫米/像素,多少…

儿童节快乐,聊聊数字的规律和同余原理

某年的6月1日是星期日。那么,同一年的6月30日是星期几? 星期是7天一个循环。所以说,这一天是星期几,7天之后同样也是星期几。而6月30日是在6月1日的29天之后:29 7 4 ... 1用29除以7,可以得出余数为1。而…

最佳实践|互联网行业软件供应链安全建设的SCA纵深实践方案

在数字化转型的浪潮中,开源组件已成为企业构建云服务与应用的基石,但其引入的安全风险也日益凸显。某互联网大厂的核心安全研究团队,通过深度应用软件成分分析(SCA)技术,构建了一套覆盖开源组件全生命周期管…

Docker Compose(容器编排)

目录 什么是 Docker Compose Docker Compose 的功能 Docker Compose 使用场景 Docker Compose 文件(docker-compose.yml) Docker Compose 命令清单 常见命令说明 操作案例 总结 什么是 Docker Compose docker-compose 是 Docker 官方的开源项…

【网络安全】轻量敏感路径扫描工具

订阅专栏,获取文末项目源码。 文章目录 工具简介工具特点项目结构使用方法1.环境准备2.配置目标URL3.运行扫描4.结果查看5.自定义扩展项目源码工具简介 该工具是一款基于Python的异步敏感路径扫描工具,用于检测目标网站是否存在敏感文件或路径泄露(如配置文件、密钥、版本控…

SpringAI+DeepSeek大模型应用开发实战

内容来自黑马程序员 这里写目录标题 认识AI和大模型大模型应用开发模型部署方案对比模型部署-云服务模型部署-本地部署调用大模型什么是大模型应用传统应用和大模型应用大模型应用 大模型应用开发技术架构 SpringAI对话机器人快速入门会话日志会话记忆 认识AI和大模型 AI的发…

高温炉制造企业Odoo ERP实施规划与深度分析报告

摘要 本报告旨在为高温炉生产企业提供一个基于Odoo 18平台的企业资源规划(ERP)系统实施的全面分析与规划。报告首先系统梳理了高温炉制造业独特的业务流程特点,随后详细映射了Odoo 18各核心模块功能与这些业务需求的匹配程度。重点分析了生产…

简述什么是全局锁?它的应用场景有哪些?

全局锁是数据库管理系统中的一种特殊锁机制,用于对整个数据库实例进行加锁,使数据库处于只读状态,阻止所有数据更新(DML)、数据定义(DDL)及更新类事务提交等操作。 其核心应用场景包括&#xf…

window 显示驱动开发-呈现开销改进(二)

对共享表面的纹理格式支持 驱动程序应支持共享资源和可共享的后台缓冲区,以使用 DXGI_FORMAT 枚举中的这些附加纹理格式: DXGI_FORMAT_A8_UNORMDXGI_FORMAT_R8_UNORMDXGI_FORMAT_R8G8_UNORMDXGI_FORMAT_BC1_TYPELESS\*DXGI_FORMAT_BC1_UNORMDXGI_FORMAT…

jenkins集成gitlab实现自动构建

jenkins集成gitlab实现自动构建 前面我们已经部署了Jenkins和gitlab,本文介绍将二者结合使用 项目源码上传至gitee提供公网访问:https://gitee.com/ye-xiao-tian/my-webapp 1、创建一个群组和项目 2、添加ssh密钥 #生成密钥 [rootgitlab ~]# ssh-keyge…

barker-OFDM模糊函数原理及仿真

文章目录 前言一、巴克码序列二、barker-OFDM 信号1、OFDM 信号表达式2、模糊函数表达式 三、MATLAB 仿真1、MATLAB 核心源码2、仿真结果①、barker-OFDM 模糊函数②、barker-OFDM 距离分辨率③、barker-OFDM 速度分辨率④、barker-OFDM 等高线图 四、资源自取 前言 本文进行 …

深入解析 Redis Cluster 架构与实现(一)

#作者:stackofumbrella 文章目录 Redis Cluster特点Redis Cluster与其它集群模式的区别集群目标性能hash tagsMutli-key操作Cluster Bus安全写入(write safety)集群节点的属性集群拓扑节点间handshake重定向与reshardingMOVED重定向ASK重定向…

linux centos 服务器性能排查 vmstat、top等常用指令

背景:项目上经常出现系统运行缓慢,由于数据库服务器是linux服务器,记录下linux服务器性能排查常用指令 vmstat vmstat介绍 vmstat 命令报告关于内核线程、虚拟内存、磁盘、陷阱和 CPU 活动的统计信息。由 vmstat 命令生成的报告可以用于平衡系统负载活动。系统范围内的这…

在IIS上无法使用PUT等请求

错误来源: chat:1 Access to XMLHttpRequest at http://101.126.139.3:11000/api/receiver/message from origin http://101.126.139.3 has been blocked by CORS policy: No Access-Control-Allow-Origin header is present on the requested resource. 其实我的后…

Python训练第四十一天

DAY 41 简单CNN 知识回顾 数据增强卷积神经网络定义的写法batch归一化:调整一个批次的分布,常用与图像数据特征图:只有卷积操作输出的才叫特征图调度器:直接修改基础学习率 卷积操作常见流程如下: 1. 输入 → 卷积层 →…

Linux线程同步实战:多线程程序的同步与调度

个人主页:chian-ocean 文章专栏-Linux Linux线程同步实战:多线程程序的同步与调度 个人主页:chian-ocean文章专栏-Linux 前言:为什么要实现线程同步线程饥饿(Thread Starvation)示例:抢票问题 …

5.2 初识Spark Streaming

在本节实战中,我们初步探索了Spark Streaming,它是Spark的流式数据处理子框架,具备高吞吐量、可伸缩性和强容错能力。我们了解了Spark Streaming的基本概念和运行原理,并通过两个案例演示了如何利用Spark Streaming实现词频统计。…

Go 即时通讯系统:日志模块重构,并从main函数开始

重构logger 上次写的logger.go过于繁琐,有很多没用到的功能;重构后只提供了简洁的日志接口,支持日志轮转、多级别日志记录等功能,并采用单例模式确保全局只有一个日志实例 全局变量 var (once sync.Once // 用于实现…