深度学习之模型压缩三驾马车:基于ResNet18的模型剪枝实战(2)

前言

《深度学习之模型压缩三驾马车:基于ResNet18的模型剪枝实战(1)》里面我只是提到了对conv1层进行剪枝,只是为了验证这个剪枝的整个过程,但是后面也有提到:仅裁剪 conv1层的影响极大,原因如下:

  • 底层特征的重要性 : conv1输出的是最基础的图像特征,所有后续层的特征均基于此生成。裁剪 conv1 会直接限制后续所有层的特征表达能力。
  • 结构连锁反应 : conv1的输出通道减少会触发 bn1layer1.0.conv1downsample 等多个模块的调整,任何一个模块的调整失误(如通道数不匹配、参数初始化不当)都会导致整体性能下降。
    虽然,在例子中,我们只是简单的进行了验证,发现效果也不是很差,但是如果具体到自己的数据,或者更加复杂的特征或者模型,可能就会影响到了整体的性能,因此,我们在原有的基础上做了如下的改动:
  1. 剪枝目标层调整 :将 conv1 改为 layer2.0.conv1 ,减少对底层特征的破坏。
  2. 通道评估优化 :通过前向传播收集激活值,优先剪枝激活值低的通道,更符合实际特征贡献。
  3. 微调策略改进 :动态解冻剪枝层及关联的BN、downsample层,学习率降低(0.0001),微调轮次增加(10轮),确保参数充分适应。

这些修改可显著提升剪枝后模型的稳定性和准确率。建议运行时观察微调阶段的Loss是否持续下降,若下降缓慢可进一步降低学习率(如0.00001)。
所有代码都在这:https://gitee.com/NOON47/model_prune

详细改动

  1. 剪枝目标层调整 :将 conv1 改为 layer2.0.conv1 ,减少对底层特征的破坏。
    layer_to_prune = 'layer2.0.conv1'  # 显式定义要剪枝的层名pruned_model = prune_conv_layer(model, layer_to_prune, amount=0.2)
  1. 通道评估优化 :通过前向传播收集激活值,优先剪枝激活值低的通道,更符合实际特征贡献。
    model.eval()with torch.no_grad():test_input = torch.randn(128, 3, 32, 32).to(device)  # 模拟 CIFAR10 输入features = []def hook_fn(module, input, output):features.append(output)handle = layer.register_forward_hook(hook_fn)model(test_input)handle.remove()activation = features[0]  # shape: [128, out_channels, H, W]channel_importance = activation.mean(dim=(0, 2, 3))  # 按通道求平均激活值num_channels = weight.shape[0]num_prune = int(num_channels * amount)_, indices = torch.topk(channel_importance, k=num_prune, largest=False)mask = torch.ones(num_channels, dtype=torch.bool)mask[indices] = False  # 生成剪枝掩码
  1. 微调策略改进 :动态解冻剪枝层及关联的BN、downsample层,学习率降低(0.0001),微调轮次增加(10轮),确保参数充分适应。
    print("开始微调剪枝后的模型")# 新增:根据剪枝层动态解冻相关层(假设剪枝层为layer2.0.conv1)pruned_layer_prefix = layer_to_prune.rpartition('.')[0]  # 例如 'layer2.0'for name, param in pruned_model.named_parameters():if (pruned_layer_prefix in name) or ('fc' in name) or ('bn' in name):  # 解冻剪枝层、BN层和fc层param.requires_grad = Trueelse:param.requires_grad = Falseoptimizer = optim.Adam(filter(lambda p: p.requires_grad, pruned_model.parameters()), lr=0.0001)  # 微调学习率降低pruned_model = train_model(pruned_model, train_loader, criterion, optimizer, device, epochs=10)  # 增加微调轮次

完整的裁剪函数:

def prune_conv_layer(model, layer_name, amount=0.2):device = next(model.parameters()).devicelayer = dict(model.named_modules())[layer_name]weight = layer.weight.data# 基于激活值的通道重要性评估model.eval()with torch.no_grad():test_input = torch.randn(128, 3, 32, 32).to(device)  # 模拟 CIFAR10 输入features = []def hook_fn(module, input, output):features.append(output)handle = layer.register_forward_hook(hook_fn)model(test_input)handle.remove()activation = features[0]  # shape: [128, out_channels, H, W]channel_importance = activation.mean(dim=(0, 2, 3))  # 按通道求平均激活值num_channels = weight.shape[0]num_prune = int(num_channels * amount)_, indices = torch.topk(channel_importance, k=num_prune, largest=False)mask = torch.ones(num_channels, dtype=torch.bool)mask[indices] = False  # 生成剪枝掩码# 创建并替换新卷积层new_conv = nn.Conv2d(in_channels=layer.in_channels,out_channels=num_channels - num_prune,kernel_size=layer.kernel_size,stride=layer.stride,padding=layer.padding,bias=layer.bias is not None).to(device)new_conv.weight.data = layer.weight.data[mask]  # 应用掩码剪枝权重if layer.bias is not None:new_conv.bias.data = layer.bias.data[mask]# 替换原始卷积层parent_name, sep, name = layer_name.rpartition('.')parent = model.get_submodule(parent_name)setattr(parent, name, new_conv)# 仅处理首层 conv1 的特殊逻辑if layer_name == 'conv1':# 更新首层 BN 层(bn1)bn1 = model.bn1new_bn1 = nn.BatchNorm2d(new_conv.out_channels).to(device)with torch.no_grad():new_bn1.weight.data = bn1.weight.data[mask].clone()new_bn1.bias.data = bn1.bias.data[mask].clone()new_bn1.running_mean.data = bn1.running_mean.data[mask].clone()new_bn1.running_var.data = bn1.running_var.data[mask].clone()model.bn1 = new_bn1# 处理 layer1.0 的 downsample 层(若不存在则创建)block = model.layer1[0]if not hasattr(block, 'downsample') or block.downsample is None:# 创建 1x1 卷积 + BN 用于通道匹配downsample_conv = nn.Conv2d(in_channels=new_conv.out_channels,out_channels=block.conv2.out_channels,  # 与主路径输出通道一致(ResNet18 为 64)kernel_size=1,stride=1,bias=False).to(device)# 初始化权重(使用原卷积层的统计量)with torch.no_grad():downsample_conv.weight.data = layer.weight.data.mean(dim=(2,3), keepdim=True)  # 原卷积核均值初始化downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels).to(device)with torch.no_grad():downsample_bn.weight.data.fill_(1.0)downsample_bn.bias.data.zero_()downsample_bn.running_mean.data.zero_()downsample_bn.running_var.data.fill_(1.0)block.downsample = nn.Sequential(downsample_conv, downsample_bn)print("✅ 为 layer1.0 添加新的 downsample 层")else:# 调整已有 downsample 层的输入通道downsample_conv = block.downsample[0]downsample_conv.in_channels = new_conv.out_channelsdownsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone()).to(device)# 更新对应的 BN 层downsample_bn = block.downsample[1]new_downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels).to(device)with torch.no_grad():new_downsample_bn.weight.data = downsample_bn.weight.data.clone()new_downsample_bn.bias.data = downsample_bn.bias.data.clone()new_downsample_bn.running_mean.data = downsample_bn.running_mean.data.clone()new_downsample_bn.running_var.data = downsample_bn.running_var.data.clone()block.downsample[1] = new_downsample_bn# 同步 layer1.0.conv1 的输入通道target_conv = model.layer1[0].conv1if target_conv.in_channels != new_conv.out_channels:print(f"同步 layer1.0.conv1 输入通道: {target_conv.in_channels}{new_conv.out_channels}")target_conv.in_channels = new_conv.out_channelstarget_conv.weight = nn.Parameter(target_conv.weight.data[:, mask, :, :].clone()).to(device)else:# 中间层剪枝逻辑(如 layer2.0.conv1)block_prefix = layer_name.rsplit('.', 1)[0]  # 提取 block 前缀(如 'layer2.0')block = model.get_submodule(block_prefix)     # 获取对应的 block(如 layer2.0)# 更新当前 block 内的 BN 层(conv1 对应 bn1,conv2 对应 bn2)target_bn_name = f"{block_prefix}.bn1" if 'conv1' in layer_name else f"{block_prefix}.bn2"try:target_bn = model.get_submodule(target_bn_name)new_bn = nn.BatchNorm2d(new_conv.out_channels).to(device)with torch.no_grad():new_bn.weight.data = target_bn.weight.data[mask].clone()new_bn.bias.data = target_bn.bias.data[mask].clone()new_bn.running_mean.data = target_bn.running_mean.data[mask].clone()new_bn.running_var.data = target_bn.running_var.data[mask].clone()setattr(block, target_bn_name.split('.')[-1], new_bn)  # 替换原 BN 层print(f"✅ 更新剪枝层 {layer_name} 对应的 BN 层 {target_bn_name}")except AttributeError:print(f"⚠️ 未找到剪枝层 {layer_name} 对应的 BN 层,跳过 BN 更新")# 新增:同步后续卷积层的输入通道(如 conv1 后调整 conv2)if 'conv1' in layer_name:next_conv = block.conv2if next_conv.in_channels != new_conv.out_channels:print(f"同步 {block_prefix}.conv2 输入通道: {next_conv.in_channels}{new_conv.out_channels}")next_conv.in_channels = new_conv.out_channelsnext_conv.weight = nn.Parameter(next_conv.weight.data[:, mask, :, :].clone()).to(device)  # 按剪枝掩码筛选输入通道权重# 可选:如果存在 downsample 层,调整其输入通道(根据实际需求启用)# if hasattr(block, 'downsample') and block.downsample is not None:#     downsample_conv = block.downsample[0]#     downsample_conv.in_channels = new_conv.out_channels#     downsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone()).to(device)#     print(f"✅ 调整剪枝层 {layer_name} 关联的 downsample 层输入通道")# 验证前向传播with torch.no_grad():test_input = torch.randn(1, 3, 32, 32).to(device)try:model(test_input)print("✅ 前向传播验证通过")except Exception as e:print(f"❌ 验证失败: {str(e)}")raisereturn model

改动后结果

经过改动后, 增加微调轮次,得到的结果如下:

剪枝前模型大小信息:
==========================================================================================
Total params: 11,181,642
Trainable params: 11,181,642
Non-trainable params: 0
Total mult-adds (M): 37.03
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.81
Params size (MB): 44.73
Estimated Total Size (MB): 45.55
==========================================================================================
原始模型准确率: 81.42%剪枝后模型大小信息:
==========================================================================================
Total params: 11,138,392
Trainable params: 11,138,392
Non-trainable params: 0
Total mult-adds (M): 36.33
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 0.80
Params size (MB): 44.55
Estimated Total Size (MB): 45.37
==========================================================================================
剪枝后模型准确率: 83.28%

个人认为,这个才是比较符合实际应用的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/diannao/86533.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

传输层协议:UDP

目录 1、概念 2、报文结构 3、核心特性 3.1 无连接 3.2 不可靠交付 3.3 面向数据报 3.4 轻量级&高效 3.5 支持广播和组播 4、典型应用场景 5、优缺点分析 6、与TCP的区别 1、概念 UDP(User Datagram Protocol,用户数据报协议&#xff09…

JVM虚拟机:内存结构、垃圾回收、性能优化

1、JVM虚拟机的简介 Java 虚拟机(Java Virtual Machine 简称:JVM)是运行所有 Java 程序的抽象计算机,是 Java 语言的运行环境,实现了 Java 程序的跨平台特性。JVM 屏蔽了与具体操作系统平台相关的信息,使得 Java 程序只需生成在 JVM 上运行的目标代码(字节码),就可以…

c++ 面试题(1)-----深度优先搜索(DFS)实现

操作系统:ubuntu22.04 IDE:Visual Studio Code 编程语言:C11 题目描述 地上有一个 m 行 n 列的方格,从坐标 [0,0] 起始。一个机器人可以从某一格移动到上下左右四个格子,但不能进入行坐标和列坐标的数位之和大于 k 的格子。 例…

【汇编逆向系列】七、函数调用包含多个参数之浮点型- XMM0-3寄存器

目录 1. 汇编代码 1.1 debug编译 1.2 release编译 2. 汇编分析 2.1 浮点参数传递规则 2.2 栈帧rsp的变化时序 2.3 参数的访问逻辑 2.4 返回值XMM0寄存器 3. 汇编转化 3.1 Debug编译 3.2 Release 编译 3.3 C语言转化 1. 汇编代码 上一节介绍了整型的函数传参&#x…

华为云Flexus+DeepSeek征文 | 从零到一:用Flexus云服务打造低延迟联网搜索Agent

作者简介 我是摘星,一名专注于云计算和AI技术的开发者。本次通过华为云MaaS平台体验DeepSeek系列模型,将实际使用经验分享给大家,希望能帮助开发者快速掌握华为云AI服务的核心能力。 目录 作者简介 前言 1. 项目背景与技术选型 1.1 项目…

【多智能体】受木偶戏启发实现多智能体协作编排

😊你好,我是小航,一个正在变秃、变强的文艺倾年。 🔔本专栏《人工智能》旨在记录最新的科研前沿,包括大模型、具身智能、智能体等相关领域,期待与你一同探索、学习、进步,一起卷起来叭&#xff…

Java八股文——Spring篇

文章目录 Java八股文专栏其它文章Java八股文——Spring篇SpringSpring的IoC和AOPSpring IoC实现机制Spring AOP实现机制 动态代理JDK ProxyCGLIBByteBuddy Spring框架中的单例Bean是线程安全的吗?什么是AOP,你们项目中有没有使用到AOPSpring中的事务是如…

NineData数据库DevOps功能全面支持百度智能云向量数据库 VectorDB,助力企业 AI 应用高效落地

NineData 的数据库 DevOps 解决方案已完成对百度智能云向量数据库 VectorDB 的全链路适配,成为国内首批提供 VectorDB 原生操作能力的服务商。此次合作聚焦 AI 开发核心场景,通过标准化 SQL 工作台与细粒度权限管控两大能力,助力企业安全高效…

开源技术驱动下的上市公司财务主数据管理实践

开源技术驱动下的上市公司财务主数据管理实践 —— 以人造板制造业为例 引言:财务主数据的战略价值与行业挑战 在资本市场监管日益严格与企业数字化转型的双重驱动下,财务主数据已成为上市公司财务治理的核心基础设施。对于人造板制造业而言&#xff0…

借助它,普转也能获得空转信息?

在生命科学研究领域,转录组技术是探索基因表达奥秘的有力工具,在疾病机制探索、生物发育进程解析等诸多方面取得了显著进展。然而,随着研究的深入,研究人员发现普通转录组只能提供整体样本中的基因表达水平信息,却无法…

synchronized 学习

学习源: https://www.bilibili.com/video/BV1aJ411V763?spm_id_from333.788.videopod.episodes&vd_source32e1c41a9370911ab06d12fbc36c4ebc 1.应用场景 不超卖,也要考虑性能问题(场景) 2.常见面试问题: sync出…

Java事务回滚详解

一、什么是事务回滚? 事务回滚指的是:当执行过程中发生异常时,之前对数据库所做的更改全部撤销,数据库状态恢复到事务开始前的状态。这是数据库“原子性”原则的体现。 二、Spring 中的 Transactional 默认行为 在 Spring 中&am…

云灾备数据复制技术研究

云灾备数据复制技术:数字时代的“安全气囊” 在当今信息化时代,数据就像城市的“生命线”,一旦中断,后果不堪设想。想象一下,如果政务系统突然崩溃,成千上万的市民服务将陷入瘫痪。这就是云灾备技术的重要…

如何处理Shopify主题的显示问题:实用排查与修复指南

在Shopify店铺运营过程中,主题显示问题是影响用户体验与品牌形象的常见痛点。可能是字体错位、图片无法加载、移动端显示混乱、功能失效等,这些都可能造成客户流失和转化下降。 本文将从问题识别、原因分析、修复方法到开发者建议全方位解读如何高效解决…

前端监控方案详解

一、前端监控方案是什么? 前端监控方案是一套系统化的工具和流程,用于收集、分析和报告网站或Web应用在前端运行时的各种性能指标、错误日志、用户行为等数据。它通常包括以下几个核心模块: 性能监控:页面加载时间、资源加载时间…

Camera相机人脸识别系列专题分析之十二:人脸特征检测FFD算法之libvega_face.so数据结构详解

【关注我,后续持续新增专题博文,谢谢!!!】 上一篇我们讲了: Camera相机人脸识别系列专题分析之十一:人脸特征检测FFD算法之低功耗libvega_face.so人脸属性(年龄,性别,肤…

如何配置HarmonyOS 5与React Native的开发环境?

配置 HarmonyOS 5 与 React Native 的开发环境需遵循以下步骤 一、基础工具安装 ‌DevEco Studio 5.0‌ 从 HarmonyOS 开发者官网 下载安装勾选组件: HarmonyOS SDK (API 12)ArkTS 编译器JS/ArkTS 调试工具HarmonyOS 本地模拟器 ‌Node.js 18.17 # 安装后验证版…

kotlin kmp 副作用函数 effect

在 Kotlin Multiplatform (KMP) Compose 中,“effect functions”(或“effect handlers”)是专门的可组合函数,用于在 UI 中管理副作用。 在 Compose 中,可组合函数应该是“纯”的和声明式的。这意味着它们应该理想地…

3.3.1_1 检错编码(奇偶校验码)

从这节课开始,我们会探讨数据链路层的差错控制功能,差错控制功能的主要目标是要发现并且解决一个帧内部的位错误,我们需要使用特殊的编码技术去发现帧内部的位错误,当我们发现位错误之后,通常来说有两种解决方案。第一…

【Pandas】pandas DataFrame isna

Pandas2.2 DataFrame Missing data handling 方法描述DataFrame.fillna([value, method, axis, …])用于填充 DataFrame 中的缺失值(NaN)DataFrame.backfill(*[, axis, inplace, …])用于**使用后向填充(即“下一个有效观测值”&#xff09…