知识点回顾
1.回调函数
2.lambda函数
3.hook函数的模块钩子和张量钩子
4.Grad-CAM的示例
一。回调函数示例
Hook本质是回调函数,所以我们先介绍一下回调函数。回调函数是作为参数传递给其他函数的函数,其目的是在某个特定事件发生时被调用执行。这种机制允许代码在运行时动态指定需要执行的逻辑,实现了代码的灵活性和可扩展性。
回调函数的核心价值在于:解耦逻辑:将通用逻辑与特定处理逻辑分离,使代码更模块化。
事件驱动编程:在异步操作、事件监听(如点击按钮、网络请求完成)等场景中广泛应用。
延迟执行:允许在未来某个时间点执行特定代码,而不必立即执行。
其中回调函数作为参数传入,所以在定义的时候一般用callback来命名,在 PyTorch 的 Hook API 中,回调参数通常命名为 hook
# 训练过程中的回调函数
class Callback:def on_train_begin(self):print("训练开始")def on_epoch_end(self, epoch, logs=None):print(f"Epoch {epoch} 完成")# 使用示例
callback = Callback()
callback.on_train_begin()
for epoch in range(10):# ...训练代码...callback.on_epoch_end(epoch)
二、lambda函数示例
在hook中常常用到lambda函数,它是一种匿名函数(没有正式名称的函数),最大特点是用完即弃,无需提前命名和定义。它的语法形式非常简约,仅需一行即可完成定义,格式如下:
lambda 参数列表: 表达式
参数列表:可以是单个参数、多个参数或无参数。
表达式:函数的返回值(无需 return 语句,表达式结果直接返回)
# 简单lambda
add = lambda x, y: x + y# 在PyTorch中的使用
data = torch.randn(10)
processed = list(map(lambda x: x*2, data)) # 每个元素乘以2
三、hook函数示例
# 模块钩子
model = nn.Sequential(nn.Linear(10,5), nn.ReLU())
def module_hook(module, input, output):print(f"{module.__class__.__name__} 输出形状: {output.shape}")
model[0].register_forward_hook(module_hook)# 张量钩子
x = torch.randn(3, requires_grad=True)
x.register_hook(lambda grad: grad * 0.5) # 梯度修改
四、Grad-CAM示例
class GradCAM:def __init__(self, model, target_layer):self.model = modelself.gradients = Noneself.activations = Nonetarget_layer.register_forward_hook(self.save_activations)target_layer.register_backward_hook(self.save_gradients)def save_activations(self, module, input, output):self.activations = output.detach()def save_gradients(self, module, grad_input, grad_output):self.gradients = grad_output[0].detach()def __call__(self, x, class_idx=None):# ...前向/反向传播逻辑...cam = torch.relu(torch.sum(self.activations * weights, dim=1))return cam