【Pytorch学习笔记】模型模块06——hook函数

hook函数

什么是hook函数

hook函数相当于插件,可以实现一些额外的功能,而又不改变主体代码。就像是把额外的功能挂在主体代码上,所有叫hook(钩子)。下面介绍Pytorch中的几种主要hook函数。

torch.Tensor.register_hook

torch.Tensor.register_hook()是一个用于注册梯度钩子函数的方法。它主要用于获取和修改张量在反向传播过程中的梯度。

语法格式:

hook = tensor.register_hook(hook_fn)
# hook_fn的格式为:
def hook_fn(grad):# 处理梯度return new_grad  # 可选

主要特点:

  • hook函数在反向传播计算梯度时被调用
  • hook函数接收梯度作为输入参数
  • 可以返回修改后的梯度,或者不返回(此时使用原始梯度)
  • 可以注册多个hook函数,按照注册顺序依次调用

使用示例:

import torch# 创建需要跟踪梯度的张量
x = torch.tensor([1., 2., 3.], requires_grad=True)# 定义hook函数
def hook_fn(grad):print('梯度值:', grad)return grad * 2  # 将梯度翻倍# 注册hook函数
hook = x.register_hook(hook_fn)# 进行一些运算
y = x.pow(2).sum()
y.backward()# 移除hook函数(可选)
hook.remove()

注意事项:

  • 只能在requires_grad=True的张量上注册hook函数
  • hook函数在不需要时应该及时移除,以免影响后续计算
  • 不建议在hook函数中修改梯度的形状,可能导致错误
  • 主要用于调试、可视化和梯度修改等场景

torch.nn.Module.register_forward_hook

torch.nn.Module.register_forward_hook()是一个用于注册前向传播钩子函数的方法。它允许我们在模型的前向传播过程中获取和处理中间层的输出

语法格式:

hook = module.register_forward_hook(hook_fn)
# hook_fn的格式为:
def hook_fn(module, input, output):# 处理输入和输出return modified_output  # 可选

主要特点:

  • hook函数在前向传播过程中被调用
  • 可以访问模块的输入和输出数据
  • 可以用于监控和修改中间层的特征
  • 不影响反向传播过程

使用示例:

import torch
import torch.nn as nn# 创建一个简单的神经网络
class Net(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)def forward(self, x):x = self.conv1(x)x = self.conv2(x)return x# 创建模型实例
model = Net()# 定义hook函数
def hook_fn(module, input, output):print('模块:', module)print('输入形状:', input[0].shape)print('输出形状:', output.shape)# 注册hook函数
hook = model.conv1.register_forward_hook(hook_fn)# 前向传播
x = torch.randn(1, 1, 32, 32)
output = model(x)# 移除hook函数
hook.remove()

注意事项:

  • hook函数在每次前向传播时都会被调用
  • 可以同时注册多个hook函数,按注册顺序调用
  • 适用于特征可视化、调试网络结构等场景
  • 建议在不需要时移除hook函数,以提高性能

torch.nn,Module.register_forward_pre_hook

torch.nn.Module.register_forward_pre_hook()是一个用于注册前向传播预处理钩子函数的方法。它允许我们在模型的前向传播开始之前对输入数据进行处理或修改。

语法格式:

hook = module.register_forward_pre_hook(hook_fn)
# hook_fn的格式为:
def hook_fn(module, input):# 处理输入return modified_input  # 可选

主要特点:

  • hook函数在前向传播开始前被调用
  • 可以访问和修改输入数据
  • 常用于输入预处理和数据转换
  • 在实际计算前执行,可以改变输入特征

使用示例:

import torch
import torch.nn as nn# 创建一个简单的神经网络
class Net(nn.Module):def __init__(self):super().__init__()self.linear = nn.Linear(10, 5)def forward(self, x):return self.linear(x)# 创建模型实例
model = Net()# 定义pre-hook函数
def pre_hook_fn(module, input_data):print('模块:', module)print('原始输入形状:', input_data[0].shape)# 对输入数据进行处理,例如标准化modified_input = input_data[0] * 2.0return modified_input# 注册pre-hook函数
hook = model.linear.register_forward_pre_hook(pre_hook_fn)# 前向传播
x = torch.randn(32, 10)  # 批次大小为32,特征维度为10
output = model(x)# 移除hook函数
hook.remove()

注意事项:

  • pre-hook函数在每次前向传播前都会被调用
  • 可以用于数据预处理、特征转换等操作
  • 返回值会替换原始输入,影响后续计算
  • 建议在不需要时及时移除,以免影响模型性能

与register_forward_hook的区别:

  • pre-hook在模块计算之前执行,forward_hook在计算之后执行
  • pre-hook只能访问输入数据,forward_hook可以同时访问输入和输出
  • pre-hook更适合做输入预处理,forward_hook更适合做特征分析

torch.nn.Module.register_full_backward_hook

torch.nn.Module.register_full_backward_hook()是一个用于注册完整反向传播钩子函数的方法。它允许我们在模型的反向传播过程中访问和修改梯度信息

语法格式:

hook = module.register_full_backward_hook(hook_fn)
# hook_fn的格式为:
def hook_fn(module, grad_input, grad_output):# 处理梯度return modified_grad_input  # 可选

主要特点:

  • hook函数在反向传播过程中被调用
  • 可以同时访问输入梯度和输出梯度
  • 可以修改反向传播的梯度流
  • 比register_backward_hook更强大,提供更完整的梯度信息

使用示例:

import torch
import torch.nn as nn# 创建一个简单的神经网络
class Net(nn.Module):def __init__(self):super().__init__()self.linear = nn.Linear(5, 3)def forward(self, x):return self.linear(x)# 创建模型实例
model = Net()# 定义backward hook函数
def backward_hook_fn(module, grad_input, grad_output):print('模块:', module)print('输入梯度形状:', [g.shape if g is not None else None for g in grad_input])print('输出梯度形状:', [g.shape if g is not None else None for g in grad_output])# 可以返回修改后的输入梯度return grad_input# 注册backward hook函数
hook = model.linear.register_full_backward_hook(backward_hook_fn)# 前向和反向传播
x = torch.randn(2, 5, requires_grad=True)
output = model(x)
loss = output.sum()
loss.backward()# 移除hook函数
hook.remove()

注意事项:

  • hook函数可能会影响模型的训练过程,使用时需要谨慎
  • 建议仅在调试和分析梯度流时使用
  • 返回值会替换原始输入梯度,可能影响模型收敛
  • 在不需要时应及时移除hook函数

与register_backward_hook的区别:

  • register_full_backward_hook提供更完整的梯度信息
  • 更适合处理复杂的梯度修改场景
  • 建议使用register_full_backward_hook替代已废弃的register_backward_hook

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/pingmian/83304.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SQL进阶之旅 Day 11:复杂JOIN查询优化

【SQL进阶之旅 Day 11】复杂JOIN查询优化 在数据处理日益复杂的今天,JOIN操作作为SQL中最强大的功能之一,常常成为系统性能瓶颈。今天我们进入"SQL进阶之旅"系列的第11天,将深入探讨复杂JOIN查询的优化策略。通过本文学习&#xf…

Spring AI 之检索增强生成(Retrieval Augmented Generation)

检索增强生成(RAG)是一种技术,有助于克服大型语言模型在处理长篇内容、事实准确性和上下文感知方面的局限性。 Spring AI 通过提供模块化架构来支持 RAG,该架构允许自行构建自定义的 RAG 流程,或者使用 Advisor API 提…

前端开源JavaScrip库

以下内容仍在持续完善中,如有遗漏或需要补充之处,欢迎在评论区指出。感谢支持,如果觉得有帮助,欢迎点赞鼓励。感谢支持 JavaScript 框架Vue.jsVue.js - 渐进式 JavaScript 框架 | Vue.jsReactReactAngularHome • AngularjQueryj…

什么是 CPU 缓存模型?

导语: CPU 缓存模型是后端性能调优、并发编程乃至分布式系统设计中一个绕不开的核心概念。它不仅关系到指令执行效率,还影响锁机制、内存可见性等多个面试高频点。本文将以资深面试官视角,详解缓存模型的原理、常见面试题及实战落地&#xff…

海外tk抓包简单暴力方式

将地址替换下面代码就可以 function hook_dlopen(module_name, fun) {var android_dlopen_ext Module.findExportByName(null, "android_dlopen_ext");if (android_dlopen_ext) {Interceptor.attach(android_dlopen_ext, {onEnter: function (args) {var pathptr …

多模态大语言模型arxiv论文略读(103)

Are Bigger Encoders Always Better in Vision Large Models? ➡️ 论文标题:Are Bigger Encoders Always Better in Vision Large Models? ➡️ 论文作者:Bozhou Li, Hao Liang, Zimo Meng, Wentao Zhang ➡️ 研究机构: 北京大学 ➡️ 问题背景&…

代码随想录算法训练营 Day61 图论ⅩⅠ Floyd A※ 最短路径算法

图论 题目 97. 小明逛公园 本题是经典的多源最短路问题。 在这之前我们讲解过,dijkstra朴素版、dijkstra堆优化、Bellman算法、Bellman队列优化(SPFA) 都是单源最短路,即只能有一个起点。 而本题是多源最短路,即求多…

【机器学习】集成学习与梯度提升决策树

目录 一、引言 二、自举聚合与随机森林 三、集成学习器 四、提升算法 五、Python代码实现集成学习与梯度提升决策树的实验 六、总结 一、引言 在机器学习的广阔领域中,集成学习(Ensemble Learning)犹如一座闪耀的明星,它通过组合多个基本学习器的力量,创造出…

yarn、pnpm、npm

非常好,这样从“问题驱动 → 工具诞生 → 优化演进”的角度来讲,更清晰易懂。下面我按时间线和动机,把 npm → yarn → pnpm 的演变脉络讲清楚。 🧩 一、npm 为什么一开始不够好? 早期(npm v4 及之前&…

如何用AI写作?

过去半年,我如何用AI高效写作,节省数倍时间 过去六个月,我几乎所有文章都用AI辅助完成。我的朋友——大多是文字工作者,对语言极为敏感——都说看不出我的文章是AI写的还是亲手创作的。 我的AI写作灵感部分来自丘吉尔。这位英国…

什么是trace,分布式链路追踪(Distributed Tracing)

在你提到的 “个人免费版” 套餐中,“Trace 上报量:5 万条 / 月,存储 3 天” 里的 Trace 仍然是指 分布式链路追踪记录,但需要结合具体产品的场景来理解其含义和限制。以下是更贴近个人用户使用场景的解释: 一、这里的…

[免费]微信小程序网上花店系统(SpringBoot后端+Vue管理端)【论文+源码+SQL脚本】

大家好,我是java1234_小锋老师,看到一个不错的微信小程序网上花店系统(SpringBoot后端Vue管理端)【论文源码SQL脚本】,分享下哈。 项目视频演示 【免费】微信小程序网上花店系统(SpringBoot后端Vue管理端) Java毕业设计_哔哩哔哩_bilibili 项…

PyTorch——DataLoader的使用

batch_size, drop_last 的用法 shuffle shuffleTrue 各批次训练的图像不一样 shuffleFalse 在第156step顺序一致

【Linux】基础文件IO

🌟🌟作者主页:ephemerals__ 🌟🌟所属专栏:Linux 前言 无论是日常使用还是系统管理,文件是Linux系统中最核心的概念之一。对于初学者来说,理解文件是如何被创建、读取、写入以及存储…

【JAVA后端入门基础001】Tomcat 是什么?通俗易懂讲清楚!

📚博客主页:代码探秘者 ✨专栏:《JavaSe》 其他更新ing… ❤️感谢大家点赞👍🏻收藏⭐评论✍🏻,您的三连就是我持续更新的动力❤️ 🙏作者水平有限,欢迎各位大佬指点&…

TDengine 的 AI 应用实战——电力需求预测

作者: derekchen Demo数据集准备 我们使用公开的UTSD数据集里面的电力需求数据,作为预测算法的数据来源,基于历史数据预测未来若干小时的电力需求。数据集的采集频次为30分钟,单位与时间戳未提供。为了方便演示,按…

D2000平台上Centos使用mmap函数遇到的陷阱

----------原创不易,欢迎点赞收藏。广交嵌入式开发的朋友,讨论技术和产品------------- 在飞腾D2000平台上,安装了麒麟linux系统,我写了个GPIO点灯的程序,在应用层利用mmap函数将内核空间映射到用户态,然后…

深入了解linux系统—— 进程间通信之管道

前言 本篇博客所涉及到的代码一同步到本人gitee:testfifo 迟来的grown/linux - 码云 - 开源中国 一、进程间通信 什么是进程间通信 在之前的学习中,我们了解到了进程具有独立性,就算是父子进程,在修改数据时也会进行写时拷贝&…

设计模式——模版方法设计模式(行为型)

摘要 模版方法设计模式是一种行为型设计模式,定义了算法的步骤顺序和整体结构,将某些步骤的具体实现延迟到子类中。它通过抽象类定义模板方法,子类实现抽象步骤,实现代码复用和算法流程控制。该模式适用于有固定流程但部分步骤可…

Python使用

Python学习,从安装,到简单应用 前言 Python作为胶水语言在web开发,数据分析,网络爬虫等方向有着广泛的应用 一、Python入门 相关基础语法直接使用相关测试代码 Python编译器版本使用3以后,安装参考其他教程&#xf…