模型剪枝----ResNet18剪枝实战

剪枝

模型剪枝(Model Pruning) 是一种 模型压缩(Model Compression) 技术,主要思想是:
深度神经网络里有很多 冗余参数(对预测结果贡献很小)。
通过去掉这些冗余连接/通道/卷积核,能让模型更小、更快,同时尽量保持精度。

非结构化剪枝(Unstructured Pruning)

对单个权重参数设置阈值,小于阈值的直接置零。
优点:保留了原始网络结构,容易实现。
缺点:稀疏矩阵计算对普通硬件加速有限(需要专门稀疏库)。

#将所有的卷积层通道减掉30%
for module in pruned_model.modules():if isinstance(module,nn.Conv2d):#这行代码的作用是对指定模块按照L2范数的标准,沿着输出通道维度剪去30%的不重要通道,prune.ln_structured(module,name = "weight",amount = 0.3,n=2,dim = 0)

对ResNet18减和不减的效果差不多,一个是精度,另一个是一轮推理的时间
在这里插入图片描述
分析原因 确实把 30% 卷积核置零,但是模块结构没变:Conv2d 还是原来那么大,只是部分权重被置零, PyTorch 的默认实现不会自动跳过这些“无效通道”, 所以 FLOPs 还是一样,ptflops 统计出来的数字没减少, GPU 上仍然执行全量卷积,推理时间几乎不会变化

结构化剪枝(Structured Pruning)

删除整个卷积核、通道、层。
优点:能直接减少计算量和推理时间。
缺点:剪掉的多了容易掉精度。

完整代码

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils.prune as prune
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import time
from tqdm import tqdm
from ptflops import get_model_complexity_info
import torch_pruning as tp# ======================
# 1. 数据准备
# ======================
transform_train = transforms.Compose([transforms.RandomCrop(32, padding=4),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)),
])transform_test = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465),(0.2023, 0.1994, 0.2010)),
])trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100,shuffle=False, num_workers=2)device = torch.device( "cuda" if torch.cuda.is_available() else "cpu" )
# ======================
# 2. 定义训练和测试函数
# ======================
def train(model,optimizer,criterion,epoch):model.train()for inx,(inputs,targets) in enumerate(trainloader):inputs,targets = inputs.to(device),targets.to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs,targets)loss.backward()optimizer.step()def test(model,criterion,epoch,tag = ""):model.eval()start = time.time()correct,total,loss_sum = 0,0,0.0with torch.no_grad():for inputs, targets in testloader:inputs,targets = inputs.to(device), targets.to(device)outputs = model(inputs)loss_sum = criterion(outputs,targets).item()_, predicted = outputs.max(1)total += targets.size(0)correct += predicted.eq(targets).sum().item()acc = 100. * correct / totalend = time.time()time_cost = end - startprint(f"{tag} Epoch {epoch}: Loss={loss_sum:.4f}, Acc={acc:.2f}%, Time={time_cost:.2f}s")return acc,time_costdef print_model_stats(model,tag = ""):#统计模型参数和flopsmac, params = get_model_complexity_info(model,(3,32,32),as_strings = True,print_per_layer_stat = False,verbose = False)print(f"{tag} Params:{params},FLOPs:{mac}")# ======================
# 3. 训练基线模型
# ======================
print("===============BaseLine ResNet18")
baseline_model = models.resnet18(pretrained = True)
baseline_model.fc = nn.Linear(baseline_model.fc.in_features,10)
baseline_model = baseline_model.to(device)
print_model_stats(baseline_model,"Baseline")criterion = nn.CrossEntropyLoss()
optimer = optim.SGD(baseline_model.parameters(),lr = 0.01,momentum = 0.9,weight_decay = 5e-4)
baseline_acc = []
baseline_time = []
for epoch in tqdm(range(10)):train(baseline_model,optimer,criterion,epoch)acc,time_cost = test(baseline_model,criterion,epoch,"Baseline")baseline_acc.append(acc)baseline_time.append(time_cost)# ======================
# 4. 剪枝 + 微调
# ======================
pruned_model = models.resnet18(pretrained = True)
pruned_model.fc = nn.Linear(pruned_model.fc.in_features,10)
pruned_model = pruned_model.to(device)#===============非结构化剪枝=====================
# #将所有的卷积层通道减掉30%
# for module in pruned_model.modules():
#     if isinstance(module,nn.Conv2d):
#         #这行代码的作用是对指定模块按照L2范数的标准,沿着输出通道维度剪去30%的不重要通道,
#         prune.ln_structured(module,name = "weight",amount = 0.3,n=2,dim = 0)#==========================结构化剪枝=====================
# 创建依赖图对象,用于处理剪枝时各层之间的依赖关系
DG = tp.DependencyGraph()
# 构建模型的依赖关系图,需要提供示例输入来追踪计算图
# example_inputs用于追踪模型的前向传播路径,确定各层之间的依赖关系
DG.build_dependency(pruned_model,example_inputs = torch.randn(1,3,32,32).to(device))def prune_conv_by_ratio(conv, ratio=0.3):# 计算每个输出通道的L1范数(绝对值求和),用于评估通道的重要性# conv.weight.data.abs().sum((1, 2, 3)) 对卷积核的后三维(H, W, C_in)求和,得到每个输出通道的L1范数weight = conv.weight.data.abs().sum((1, 2, 3))  # 根据指定的剪枝比例计算需要移除的通道数量num_remove = int(weight.numel() * ratio)# 找到L1范数最小的num_remove个通道的索引# torch.topk返回最大的k个元素,设置largest=False后返回最小的k个元素_, idxs = torch.topk(weight, k=num_remove, largest=False)# 获取剪枝组,指定要剪枝的层、剪枝方式和剪枝索引# tp.prune_conv_out_channels表示沿输出通道维度进行剪枝group = DG.get_pruning_group(conv, tp.prune_conv_out_channels, idxs=idxs.tolist())# 执行剪枝操作,物理移除指定的通道group.prune()# 遍历剪枝模型的所有模块
for m in pruned_model.modules():# 检查模块是否为卷积层if isinstance(m, nn.Conv2d):# 对该卷积层执行剪枝操作,移除30%的输出通道prune_conv_by_ratio(m, ratio=0.3)#=======================================================print_model_stats(pruned_model,"Pruned")
criterion1 = nn.CrossEntropyLoss()
optimer1 = optim.SGD(pruned_model.parameters(),lr = 0.01,momentum = 0.9,weight_decay = 5e-4)
pruned_acc = []
pruned_time = []for epoch in tqdm(range(10)):train(pruned_model,optimer1,criterion1,epoch)acc,time_cost = test(pruned_model,criterion1,epoch,"Pruned")pruned_acc.append(acc)pruned_time.append(time_cost)# ======================
# 5. 对比结果
# ======================
print("\n==== Final Accuracy Comparison ====")print(f" Baseline={max(baseline_acc):.2f}% time={sum(baseline_time)/len(baseline_time):.2f}, Pruned={max(pruned_acc):.2f}% time={sum(pruned_time)/len(pruned_time):.2f}")

最终训练10轮的情况下精度下降7%,模型参数量减少4倍,感觉能够接受
Params:11.18 M – > 2.7M
FLOPs:37.25 MMac --> 9.48 MMac
acc : 82.86% —> 75.77%
time : 1.20 ----> 1.12
在这里插入图片描述

基于正则化/稀疏约束

在训练时加上稀疏正则项,让网络自动学习出“重要性低”的权重趋近于零,再做剪枝。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/98441.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/98441.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

K8S-Pod(上)

Pod概念 Pod 是可以在 Kubernetes 中创建和管理的、最小的可部署的计算单元。 Pod是一组(一个或多个)容器;这些容器共享存储、网络、以及怎样运行这些容器的规约。Pod 中的内容总是并置(colocated)的并且一同调度&am…

Flink TaskManager日志时间与实际时间有偏差

Flink 启动一个任务后,发现TaskManager上日志时间与实际时间相差约 15 小时。 核心原因可能是: 1、 服务器(或容器)的系统时间配置错误2、 Flink 日志组件(如 Logback/Log4j)的时间配置未使用系统默认时区…

Webug3.0通关笔记18 中级进阶第06关 实战练习:DisCuz论坛SQL注入漏洞

目录 一、环境搭建 1、服务启动 2、源码解压 3、构造访问靶场URL 4、靶场安装 5、访问论坛首页 二、代码分析 1、源码分析 2、SQL注入分析 三、渗透实战 (1)判断是否有SQL注入风险 (2)查询账号密码 Discuz! 作为国内知…

SWEET:大语言模型的选择性水印

摘要背景与问题大语言模型出色的生成能力引发了伦理与法律层面的担忧,于是通过嵌入水印来检测机器生成文本的方法逐渐发展起来。但现有工作在代码生成任务中无法良好发挥作用,原因在于代码生成任务本身的特性(代码有其特定的语法、逻辑结构&a…

FastDFS V6双IP特性及配置

FastDFS V6.0开始支持双IP,tracker server和storage server均支持双IP。V6.0新增特性说明如下:支持双IP,一个内网IP,一个外网IP,可以支持NAT方式的内网和外网两个IP,解决跨机房或混合云部署问题。FastDFS双…

笔记本、平板如何成为电脑拓展屏?向日葵16成为副屏功能一键实现

向日葵16重磅上线,本次更新新增了诸多实用功能,提升远控效率,实现应用融合突破设备边界,同时全面提升远控性能,操作更顺滑、画质更清晰!无论远程办公、设计、IT运维、开发还是游戏娱乐,向日葵16…

基于Spring Boot + MyBatis的用户管理系统配置

我来为您详细分析这两个配置文件的功能和含义。 一、文件整体概述 这是一个基于Spring Boot MyBatis的用户管理系统配置: UserMapper.xml:MyBatis的SQL映射文件,定义了用户表的增删改查操作application.yml:Spring Boot的核心配置…

80(HTTP默认端口)和8080端口(备用HTTP端口)区别

文章目录**1. 用途**- **80端口**- **8080端口****2. 默认配置**- **80端口**- **8080端口****3. 联系**- **逻辑端口**:两者都是TCP/IP协议中的逻辑端口,用于标识不同的网络服务。- **可配置性**:端口号可以根据需要修改(例如将T…

【开题答辩全过程】以 汽车知名品牌信息管理系统为例,包含答辩的问题和答案

个人简介一名14年经验的资深毕设内行人,语言擅长Java、php、微信小程序、Python、Golang、安卓Android等开发项目包括大数据、深度学习、网站、小程序、安卓、算法。平常会做一些项目定制化开发、代码讲解、答辩教学、文档编写、也懂一些降重方面的技巧。感谢大家的…

从全栈工程师视角解析Java与前端技术在电商场景中的应用

从全栈工程师视角解析Java与前端技术在电商场景中的应用 面试背景介绍 面试官:你好,很高兴见到你。我叫李明,是这家电商平台的资深架构师。今天我们会聊聊你的技术能力和项目经验。你可以先简单介绍一下自己吗? 应聘者&#xff1a…

【python】python进阶——多线程

引言在现代软件开发中,程序的执行效率至关重要。无论是处理大量数据、响应用户交互,还是与外部系统通信,常常需要让程序同时执行多个任务。Python作为一门功能强大且易于学习的编程语言,提供了多种并发编程方式,其中多…

【JavaEE】(23) 综合练习--博客系统

一、功能描述 用户登录后,可查看所有人的博客。点击 “查看全文” 可查看该博客完整内容。如果该博客作者是登录用户,可以编辑或删除博客。发表博客的页面同编辑页面。 本练习的博客网站,并没有添加注册功能,以及上传作者头像功能…

MySQL全库检索关键词 - idea 工具 Full-Text Search分享

我们经常要在库中查找一个数据,又不知道在哪个表、哪个字段;或者想找到哪里有在用这个数据。我们可以用:idea 的 Database工具 - Full-Text Search打开idea,在工具栏找到 Database 然后新建自己的连接,然后右键&#x…

银行卡号识别案例

代码实现:import cv2 import numpy as np import argparse import myutils-i moban.png -t card1.pngap argparse.ArgumentParser() ap.add_argument("-i","--image", requiredTrue,help"path to input image") ap.add_argument(&quo…

云管平台上线只是开始:从“建好”到“用好”的运营、推广与深化指南

项目上线的喜悦转瞬即逝,随之而来的是一个更为现实和复杂的阶段:运营。云管平台(CMP)的成功,不再仅仅取决于其技术架构的先进性,更在于它能否融入组织的肌理,为不同角色持续创造价值。本文将从管理者、平台团队、开发者、运维和财务五个核心角色的视角,深入探讨平台上线…

distributed.client.Client 用户可调用函数分析

distributed.client.Client 用户可调用函数分析 1. 核心计算函数 任务提交和执行submit(func, *args, keyNone, workersNone, resourcesNone, retriesNone, priority0, fifo_timeout60s, allow_other_workersFalse, actorFalse, actorsFalse, pureNone, **kwargs) 提交单个函数…

数字图像处理——信用卡识别

在数字支付时代,信用卡处理自动化技术日益重要。本文介绍如何利用Python和OpenCV实现信用卡数字的自动识别,结合图像处理与模式识别技术,具有显著实用价值。系统概述与工作原理信用卡数字识别系统包含两大核心模块:模板数字预处理…

嵌入式ARM64 基于RK3588原生SDK添加用户配置选项./build lunch debian

1 背景 在我们正常拿到SDK后会有一些配置选项,在使用./build.sh lunch之后会输出一些defautconfig让我们选择,瑞芯微的原厂sdk会提供一些主板的配置选项,但是我们的如果是一块新的主板就需要添加自己的配置选项,本文就讨论如何来添…

专为石油和天然气检测而开发的基于无人机的OGI相机

专为石油和天然气检测而开发的基于无人机的OGI相机基于无人机的 OGI 相机:(Optical Gas Imaging,光学气体成像)其实是近几年油气、电力、化工等行业里非常热门的应用方向。什么是 OGI 相机OGI(Optical Gas Imaging)&am…

iPhone17全系优缺点分析,加持远程控制让你的手机更好用!

知名数码厂商苹果,不久前已官宣将于北京时间9月10日凌晨1点开启发布会,主打对于iPhone 17系列产品介绍,并且和以往不同的是,今年会在购物平台上开启线上直播,还是很有新意的。9.13全平台渠道将开启预售模式&#xff0c…