深度学习之模型压缩三驾马车:基于ResNet18的模型剪枝实战(1)

一、背景:为什么需要模型剪枝?

随着深度学习的发展,模型参数量和计算量呈指数级增长。以ResNet18为例,其在ImageNet上的参数量约为1100万,虽然在服务器端运行流畅,但在移动端或嵌入式设备上部署时,内存和计算资源的限制使得直接使用大模型变得困难。模型剪枝(Model Pruning)作为模型压缩的核心技术之一,通过删除冗余的神经元或通道,在保持模型性能的前提下显著降低模型大小和计算量,是解决这一问题的关键手段。
在前面一篇文章我们也提到了模型压缩的一些基本定义和核心原理:《深度学习之模型压缩三驾马车:模型剪枝、模型量化、知识蒸馏》。

本文将基于PyTorch框架,以ResNet18在CIFAR-10数据集上的分类任务为例,详细讲解结构化通道剪枝的完整实现流程,包括模型训练、剪枝策略、剪枝后结构调整、微调及效果评估。

二、整体流程概览

本文代码的核心流程可总结为以下6步:

  1. 环境初始化与数据集加载
  2. 原始模型训练与评估
  3. 卷积层结构化剪枝(以conv1层为例)
  4. 剪枝后模型结构调整(BN层、残差下采样层等)
  5. 剪枝模型微调
  6. 剪枝前后模型效果对比
    特地说明:在这里选择conv1层作为例子,不是因为选择这个就会效果更好。

三、关键步骤代码解析

3.1 环境初始化与数据集准备

首先需要配置计算设备(GPU/CPU),并加载CIFAR-10数据集。CIFAR-10包含10类32x32的彩色图像,训练集5万张,测试集1万张。

def setup_device():return torch.device("cuda" if torch.cuda.is_available() else "cpu")def load_dataset():transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # 归一化到[-1,1]])train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)return train_dataset, test_dataset

3.2 原始模型训练

使用预训练的ResNet18模型,修改全连接层输出为10类(匹配CIFAR-10的类别数),并进行5轮训练:

def create_model(device):model = models.resnet18(pretrained=True)  # 加载ImageNet预训练权重model.fc = nn.Linear(512, 10)  # 修改输出层为10类return model.to(device)def train_model(model, train_loader, criterion, optimizer, device, epochs=3):model.train()for epoch in range(epochs):running_loss = 0.0for images, labels in tqdm(train_loader):images, labels = images.to(device), labels.to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")return model

3.3 结构化通道剪枝核心实现

本文重点是对卷积层进行结构化剪枝(按通道剪枝),具体步骤如下:

3.3.1 计算通道重要性

通过计算卷积核的L2范数评估通道重要性。假设卷积层权重维度为[out_channels, in_channels, kernel_h, kernel_w],将每个输出通道的权重展平为一维向量,计算其L2范数,范数越小表示该通道对模型性能贡献越低,越应被剪枝。

layer = dict(model.named_modules())[layer_name]  # 获取目标卷积层
weight = layer.weight.data
channel_norm = torch.norm(weight.view(weight.shape[0], -1), p=2, dim=1)  # 计算每个输出通道的L2范数
3.3.2 生成剪枝掩码

根据剪枝比例(如20%),选择范数最小的通道生成掩码:

num_channels = weight.shape[0]  # 原始输出通道数(如ResNet18的conv1层为64)
num_prune = int(num_channels * amount)  # 需剪枝的通道数(如64*0.2=12)
_, indices = torch.topk(channel_norm, k=num_prune, largest=False)  # 找到最不重要的12个通道mask = torch.ones(num_channels, dtype=torch.bool)
mask[indices] = False  # 掩码:保留的通道标记为True(52个),剪枝的标记为False(12个)
3.3.3 替换卷积层

创建新的卷积层,仅保留掩码为True的通道:

new_conv = nn.Conv2d(in_channels=layer.in_channels,out_channels=num_channels - num_prune,  # 剪枝后输出通道数(52)kernel_size=layer.kernel_size,stride=layer.stride,padding=layer.padding,bias=layer.bias is not None
).to(device)  # 移动到模型所在设备new_conv.weight.data = layer.weight.data[mask]  # 保留掩码为True的通道权重
if layer.bias is not None:new_conv.bias.data = layer.bias.data[mask]  # 偏置同理
3.3.4 关键:剪枝后结构调整

直接剪枝会导致后续层(如BN层、残差连接中的下采样层)的输入/输出通道不匹配,必须同步调整:

(1) 调整BN层
卷积层后通常接BN层,BN的num_features需与卷积输出通道数一致:

if 'conv1' in layer_name:bn1 = model.bn1new_bn1 = nn.BatchNorm2d(new_conv.out_channels).to(device)  # 新BN层通道数52with torch.no_grad():# 同步原始BN层的参数(仅保留未被剪枝的通道)new_bn1.weight.data = bn1.weight.data[mask].clone()new_bn1.bias.data = bn1.bias.data[mask].clone()new_bn1.running_mean.data = bn1.running_mean.data[mask].clone()new_bn1.running_var.data = bn1.running_var.data[mask].clone()model.bn1 = new_bn1

(2) 调整残差下采样层
ResNet的残差块(如layer1.0)中,若主路径的通道数被剪枝,需要通过1x1卷积的下采样层(downsample)匹配 shortcut 的通道数:

block = model.layer1[0]
if not hasattr(block, 'downsample') or block.downsample is None:# 原始无downsample,创建新的1x1卷积+BNdownsample_conv = nn.Conv2d(in_channels=new_conv.out_channels,  # 52(剪枝后的conv1输出)out_channels=block.conv2.out_channels,  # 64(主路径conv2的输出)kernel_size=1,stride=1,bias=False).to(device)torch.nn.init.kaiming_normal_(downsample_conv.weight, mode='fan_out', nonlinearity='relu')  # 初始化权重downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels).to(device)block.downsample = nn.Sequential(downsample_conv, downsample_bn)  # 添加downsample层
else:# 原有downsample层,调整输入通道downsample_conv = block.downsample[0]downsample_conv.in_channels = new_conv.out_channels  # 输入通道改为52downsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone())  # 输入通道用掩码筛选

(3) 前向传播验证
调整后需验证模型能否正常前向传播,避免通道不匹配导致的错误:

with torch.no_grad():test_input = torch.randn(1, 3, 32, 32).to(device)  # 测试输入(B, C, H, W)try:model(test_input)print("✅ 前向传播验证通过")except Exception as e:print(f"❌ 验证失败: {str(e)}")raise

3.3的总结,直接上代码

def prune_conv_layer(model, layer_name, amount=0.2):# 获取模型当前所在设备device = next(model.parameters()).device  # 新增:获取设备layer = dict(model.named_modules())[layer_name]weight = layer.weight.datachannel_norm = torch.norm(weight.view(weight.shape[0], -1), p=2, dim=1)num_channels = weight.shape[0]  # 原始通道数(如 64)num_prune = int(num_channels * amount)_, indices = torch.topk(channel_norm, k=num_prune, largest=False)mask = torch.ones(num_channels, dtype=torch.bool)mask[indices] = False  # 生成剪枝掩码(长度 64,52 个 True)new_conv = nn.Conv2d(in_channels=layer.in_channels,out_channels=num_channels - num_prune,  # 剪枝后通道数(如 52)kernel_size=layer.kernel_size,stride=layer.stride,padding=layer.padding,bias=layer.bias is not None)new_conv = new_conv.to(device)  # 新增:移动到模型所在设备new_conv.weight.data = layer.weight.data[mask]  # 保留 mask 为 True 的通道if layer.bias is not None:new_conv.bias.data = layer.bias.data[mask]# 替换原始卷积层parent_name, sep, name = layer_name.rpartition('.')parent = model.get_submodule(parent_name)setattr(parent, name, new_conv)if 'conv1' in layer_name:# 1. 更新与 conv1 直接关联的 BN1 层bn1 = model.bn1new_bn1 = nn.BatchNorm2d(new_conv.out_channels)  # 新 BN 层通道数 52new_bn1 = new_bn1.to(device)  # 新增:移动到模型所在设备with torch.no_grad():new_bn1.weight.data = bn1.weight.data[mask].clone()new_bn1.bias.data = bn1.bias.data[mask].clone()new_bn1.running_mean.data = bn1.running_mean.data[mask].clone()new_bn1.running_var.data = bn1.running_var.data[mask].clone()model.bn1 = new_bn1# 2. 处理残差连接中的 downsample(关键修正:添加缺失的 downsample)block = model.layer1[0]if not hasattr(block, 'downsample') or block.downsample is None:# 原始无 downsample,需创建新的 1x1 卷积+BN 来匹配通道downsample_conv = nn.Conv2d(in_channels=new_conv.out_channels,  # 52out_channels=block.conv2.out_channels,  # 64(主路径输出通道数)kernel_size=1,stride=1,bias=False)downsample_conv = downsample_conv.to(device)  # 新增:移动到模型所在设备# 初始化 1x1 卷积权重(这里简单复制原模型可能的统计量,实际可根据需求调整)torch.nn.init.kaiming_normal_(downsample_conv.weight, mode='fan_out', nonlinearity='relu')downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels)downsample_bn = downsample_bn.to(device)  # 新增:移动到模型所在设备with torch.no_grad():# 初始化 BN 参数(可保持默认,或根据原模型统计量调整)downsample_bn.weight.fill_(1.0)downsample_bn.bias.zero_()downsample_bn.running_mean.zero_()downsample_bn.running_var.fill_(1.0)block.downsample = nn.Sequential(downsample_conv, downsample_bn)print("✅ 为 layer1.0 添加新的 downsample 层")else:# 原有 downsample 层,调整输入通道downsample_conv = block.downsample[0]downsample_conv.in_channels = new_conv.out_channels  # 输入通道调整为 52downsample_conv.weight = nn.Parameter(downsample_conv.weight.data[:, mask, :, :].clone())  # 输入通道用 mask 筛选downsample_conv = downsample_conv.to(device)  # 新增:移动到模型所在设备downsample_bn = block.downsample[1]new_downsample_bn = nn.BatchNorm2d(downsample_conv.out_channels)new_downsample_bn = new_downsample_bn.to(device)  # 新增:移动到模型所在设备with torch.no_grad():new_downsample_bn.weight.data = downsample_bn.weight.data.clone()new_downsample_bn.bias.data = downsample_bn.bias.data.clone()new_downsample_bn.running_mean.data = downsample_bn.running_mean.data.clone()new_downsample_bn.running_var.data = downsample_bn.running_var.data.clone()block.downsample[1] = new_downsample_bn# 3. 同步 layer1.0.conv1 的输入通道(保持原有逻辑)next_convs = ['layer1.0.conv1']for conv_path in next_convs:try:conv = model.get_submodule(conv_path)if conv.in_channels != new_conv.out_channels:print(f"同步输入通道: {conv.in_channels}{new_conv.out_channels}")conv.in_channels = new_conv.out_channelsconv.weight = nn.Parameter(conv.weight.data[:, mask, :, :].clone())conv = conv.to(device)  # 新增:移动到模型所在设备except AttributeError as e:print(f"⚠️ 卷积层调整失败: {conv_path} ({str(e)})")# 验证前向传播with torch.no_grad():test_input = torch.randn(1, 3, 32, 32).to(device)  # 确保测试输入也在相同设备try:model(test_input)print("✅ 前向传播验证通过")except Exception as e:print(f"❌ 验证失败: {str(e)}")raisereturn model

3.4 剪枝模型微调

剪枝后模型的部分参数被删除,需要通过微调恢复性能。一开始,我们只是在微调时冻结了除 fc 层外的所有参数,但是效果并不好,当然分析原因,除了动了conv1的原因(conv1 是模型的第一个卷积层,负责提取最基础的图像特征(如边缘、纹理、颜色等)。这些底层特征对后续所有层的特征提取至关重要。),最重要的是裁剪后,需要对裁剪的层进行微调,确保参数适应新的特征维度。
微调时冻结了除 fc 层外的所有参数的代码和结果:

for name, param in pruned_model.named_parameters():if 'fc' not in name:param.requires_grad = Falseoptimizer = optim.Adam(pruned_model.fc.parameters(), lr=0.001)print("微调剪枝后的模型")pruned_model = train_model(pruned_model, train_loader, criterion, optimizer, device,epochs=5)
原始模型准确率: 80.07%
剪枝后模型准确率: 37.80%

可以看到这个相差很大
本文选择解冻被剪枝的层(如conv1bn1)及相关层(如layer1.0.conv1downsample)进行参数更新:

print("开始微调剪枝后的模型")
for name, param in pruned_model.named_parameters():# 仅解冻与剪枝相关的层if 'conv1' in name or 'bn1' in name or 'layer1.0.conv1' in name or 'layer1.0.downsample' in name or 'fc' in name:param.requires_grad = Trueelse:param.requires_grad = False
optimizer = optim.Adam(filter(lambda p: p.requires_grad, pruned_model.parameters()), lr=0.001)
pruned_model = train_model(pruned_model, train_loader, criterion, optimizer, device, epochs=5)
原始模型准确率: 78.94%
剪枝后模型准确率:  81.30%

重新微调了裁剪后的层后,结果有了很大改变。

四、实验结果与分析

通过代码中的evaluate_model函数评估剪枝前后的模型准确率:

def evaluate_model(model, device, test_loader):model.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()acc = 100 * correct / totalreturn acc

假设原始模型准确率为88.5%,剪枝20%通道后(模型大小降低约20%),通过微调可恢复至87.2%,验证了剪枝策略的有效性。

五、总结与改进方向

本文实现了基于通道L2范数的结构化剪枝,重点解决了剪枝后模型结构不一致的问题(如BN层、残差下采样层的调整),并通过微调恢复了模型性能。
在这个例子中,仅裁剪 conv1 层的影响
仅裁剪 conv1 层对模型的影响极大,原因如下:

  • 底层特征的重要性 : conv1 输出的是最基础的图像特征,所有后续层的特征均基于此生成。裁剪 conv1 会直接限制后续所有层的特征表达能力。
  • 结构连锁反应 : conv1 的输出通道减少会触发 bn1 、 layer1.0.conv1 、 downsample 等多个模块的调整,任何一个模块的调整失误(如通道数不匹配、参数初始化不当)都会导致整体性能下降。
    实际应用中可从以下方向改进:

模型裁剪通常优先选择 中间层(如ResNet的 layer2 、 layer3 ) ,而非底层或顶层,原因如下:

  • 底层(如 conv1 ) :负责基础特征提取,裁剪后特征损失大,对性能影响显著。
  • 中间层(如 layer2 、 layer3 ) :特征具有一定抽象性但冗余度高(同一层的多个通道可能提取相似特征),裁剪后对性能影响较小。
  • 顶层(如 fc 层) :负责分类决策,参数密度高但冗余度低,裁剪易导致分类能力下降。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/news/908873.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uni-app学习笔记二十四--showLoading和showModal的用法

showLoading(OBJECT) 显示 loading 提示框, 需主动调用 uni.hideLoading 才能关闭提示框。 OBJECT参数说明 参数类型必填说明平台差异说明titleString是提示的文字内容,显示在loading的下方maskBoolean否是否显示透明蒙层,防止触摸穿透,默…

【大模型RAG】六大 LangChain 支持向量库详细对比

摘要 向量数据库已经成为检索增强生成(RAG)、推荐系统和多模态检索的核心基础设施。本文从 Chroma、Elasticsearch、Milvus、Redis、FAISS、Pinecone 六款 LangChain 官方支持的 VectorStore 出发,梳理它们的特性、典型应用场景与性能边界&a…

【MySQL】数据库三大范式

目录 一. 什么是范式 二. 第一范式 三. 第二范式 不满足第二范式时可能出现的问题 四. 第三范式 一. 什么是范式 在数据库中范式其实就是一组规则,在我们设计数据库的时候,需要遵守不同的规则要求,设计出合理的关系型数据库,…

Coze工作流-语音故事创作-文本转语音的应用

教程简介 本教程将带着大家去了解怎么样把文本转换成语音,例如说我们要做一些有声故事,我们可能会用上一些语音的技术,来把你创作的故事朗读出来 首先我们创建一个工作流 对各个模块进行编辑,如果觉得系统提示词写的不好&#xf…

5.子网划分及分片相关计算

某公司网络使用 IP 地址空间 192.168.2.0/24,现需将其均分给 市场部 和 研发部 两个子网。已知: 🏢 市场部子网 🖥️ 已分配 IP 地址范围:192.168.2.1 ~ 192.168.2.30🌐 路由器接口 IP:192.16…

三体问题详解

从物理学角度,三体问题之所以不稳定,是因为三个天体在万有引力作用下相互作用,形成一个非线性耦合系统。我们可以从牛顿经典力学出发,列出具体的运动方程,并说明为何这个系统本质上是混沌的,无法得到一般解…

机器学习算法时间复杂度解析:为什么它如此重要?

时间复杂度的重要性 虽然scikit-learn等库让机器学习算法的实现变得异常简单(通常只需2-3行代码),但这种便利性往往导致使用者忽视两个关键方面: 算法核心原理的理解缺失 忽视算法的数据适用条件 典型算法的时间复杂度陷阱 SV…

uniapp 对接腾讯云IM群组成员管理(增删改查)

UniApp 实战:腾讯云IM群组成员管理(增删改查) 一、前言 在社交类App开发中,群组成员管理是核心功能之一。本文将基于UniApp框架,结合腾讯云IM SDK,详细讲解如何实现群组成员的增删改查全流程。 权限校验…

OPENCV图形计算面积、弧长API讲解(1)

一.OPENCV图形面积、弧长计算的API介绍 之前我们已经把图形轮廓的检测、画框等功能讲解了一遍。那今天我们主要结合轮廓检测的API去计算图形的面积,这些面积可以是矩形、圆形等等。图形面积计算和弧长计算常用于车辆识别、桥梁识别等重要功能,常用的API…

一.设计模式的基本概念

一.核心概念 对软件设计中重复出现问题的成熟解决方案,提供代码可重用性、可维护性和扩展性保障。核心原则包括: 1.1. 单一职责原则‌ ‌定义‌:一个类只承担一个职责,避免因职责过多导致的代码耦合。 1.2. 开闭原则‌ ‌定义‌&#xf…

React第五十七节 Router中RouterProvider使用详解及注意事项

前言 在 React Router v6.4 中&#xff0c;RouterProvider 是一个核心组件&#xff0c;用于提供基于数据路由&#xff08;data routers&#xff09;的新型路由方案。 它替代了传统的 <BrowserRouter>&#xff0c;支持更强大的数据加载和操作功能&#xff08;如 loader 和…

Opencv中的addweighted函数

一.addweighted函数作用 addweighted&#xff08;&#xff09;是OpenCV库中用于图像处理的函数&#xff0c;主要功能是将两个输入图像&#xff08;尺寸和类型相同&#xff09;按照指定的权重进行加权叠加&#xff08;图像融合&#xff09;&#xff0c;并添加一个标量值&#x…

C++ 基础特性深度解析

目录 引言 一、命名空间&#xff08;namespace&#xff09; C 中的命名空间​ 与 C 语言的对比​ 二、缺省参数​ C 中的缺省参数​ 与 C 语言的对比​ 三、引用&#xff08;reference&#xff09;​ C 中的引用​ 与 C 语言的对比​ 四、inline&#xff08;内联函数…

关于面试找工作的总结(四)

不同情况下收到offer后的处理方法 1.不会去的,只是面试练手2.还有疑问,考虑中3.offer/职位不满足期望的4.已确认,但又收到更好的5.还想挽回之前的offer6.确认,准备入职7.还想拖一下的1.不会去的,只是面试练手 HR您好,非常荣幸收到贵司的offer,非常感谢一直以来您的帮助,…

什么是高考?高考的意义是啥?

能见到这个文章的群体&#xff0c;应该都经历过高考&#xff0c;突然想起“什么是高考&#xff1f;意义何在&#xff1f;” 一、高考的定义与核心功能 **高考&#xff08;普通高等学校招生全国统一考试&#xff09;**是中国教育体系的核心选拔性考试&#xff0c;旨在为高校选拔…

L1和L2核心区别 !!--part 2

哈喽&#xff0c;我是 我不是小upper~ 昨天&#xff0c;咱们分享了关于 L1 正则化和 L2 正则化核心区别的精彩内容。今天我来进一步补充和拓展。 首先&#xff0c;咱们先来聊聊 L1 和 L2 正则化&#xff0c;方便刚接触的同学理解。 L1 正则化&#xff08;Lasso&#xff09;&…

字节推出统一多模态模型 BAGEL,GPT-4o 级的图像生成能力直接开源了!

字节推出的 BAGEL 是一个开源的统一多模态模型&#xff0c;他们直接开源了GPT-4o级别的图像生成能力。&#xff08;轻松拿捏“万物皆可吉卜力”玩法~&#xff09;。可以在任何地方对其进行微调、提炼和部署&#xff0c;它以开放的形式提供与 GPT-4o 和 Gemini 2.0 等专有系统相…

互联网大厂Java面试:从Spring Cloud到Kafka的技术考察

场景&#xff1a;互联网大厂Java求职者面试 面试官与谢飞机的对话 面试官&#xff1a;我们先从基础开始&#xff0c;谢飞机&#xff0c;你能简单介绍一下Java SE和Java EE的区别吗&#xff1f; 谢飞机&#xff1a;哦&#xff0c;这个简单。Java SE是标准版&#xff0c;适合桌…

18-Oracle 23ai JSON二元性颠覆传统

在当今百花齐放的多模型数据库时代&#xff0c;开发人员常在关系型与文档型数据库间艰难取舍。Oracle Database 23ai推出的JSON关系二元性&#xff08;JSON Relational Duality&#xff09;​​ 和二元性视图&#xff08;Duality Views&#xff09;​​ 创新性地统一了两者优势…

蓝桥杯 冶炼金属

原题目链接 &#x1f527; 冶炼金属转换率推测题解 &#x1f4dc; 原题描述 小蓝有一个神奇的炉子用于将普通金属 O O O 冶炼成为一种特殊金属 X X X。这个炉子有一个属性叫转换率 V V V&#xff0c;是一个正整数&#xff0c;表示每 V V V 个普通金属 O O O 可以冶炼出 …