Using Torch-Pruning, a Model Pruning Tool for Deep Learning

      Torch-Pruning (TP) is a structural pruning framework. Source code: https://github.com/VainF/Torch-Pruning. The latest release at the time of writing is v1.6.0, and the license is MIT.

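      TP supports structural pruning for a wide range of deep neural networks. Unlike torch.nn.utils.prune, which sets parameters to zero through a mask, TP uses an algorithm called DepGraph to group and physically remove coupled parameters.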
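      To illustrate the mask-based behavior that TP is contrasted with, here is a minimal sketch using torch.nn.utils.prune on a single layer. The layer size and the 50% amount are arbitrary choices for illustration, not values from this article.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(8, 4)  # weight has 4 * 8 = 32 entries, bias has 4

# Mask-based (unstructured) pruning: zero the 50% of weights with the smallest |w|.
prune.l1_unstructured(layer, name="weight", amount=0.5)

weight_shape = tuple(layer.weight.shape)
zero_count = int((layer.weight == 0).sum())
param_count = sum(p.numel() for p in layer.parameters())

print(weight_shape)  # the tensor keeps its original shape
print(zero_count)    # number of masked (zeroed) entries
print(param_count)   # stored parameters are not reduced
```

      The weights are zeroed but the tensor shape and the stored parameter count stay the same, which is why mask-based pruning alone does not shrink the model.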

      TP depends only on PyTorch and NumPy and is compatible with PyTorch 1.x and 2.x. To install v1.6.0 with pip in an Anaconda virtual environment, run:

pip install torch-pruning==1.6.0

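      In structural pruning, removing a single parameter can affect several layers. For example, pruning the output dimension of a linear layer requires removing the corresponding input dimension of the next linear layer. These inter-layer dependencies make it very hard to prune complex networks by hand. TP addresses this with DepGraph, a graph-based algorithm that automatically identifies the dependencies and collects the groups of parameters that must be pruned together.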
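      To make this dependency concrete, here is a minimal sketch in plain PyTorch that prunes two coupled linear layers by hand. It does not use TP's API; the layer sizes and the kept indices are hypothetical. This is the bookkeeping that DepGraph automates for a whole network.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc1 = nn.Linear(6, 4)  # produces 4 features
fc2 = nn.Linear(4, 2)  # consumes those 4 features

keep = [0, 2]  # hypothetical choice: keep output channels 0 and 2 of fc1

new_fc1 = nn.Linear(6, len(keep))
new_fc2 = nn.Linear(len(keep), 2)
with torch.no_grad():
    # Pruning fc1's output dimension keeps only the selected weight rows and bias entries.
    new_fc1.weight.copy_(fc1.weight[keep])
    new_fc1.bias.copy_(fc1.bias[keep])
    # The coupled change: fc2 must keep only the matching input columns.
    new_fc2.weight.copy_(fc2.weight[:, keep])
    new_fc2.bias.copy_(fc2.bias)

x = torch.randn(3, 6)
out = new_fc2(torch.relu(new_fc1(x)))
print(out.shape)

# Without the coupled change, the old fc2 no longer fits the pruned fc1.
try:
    fc2(new_fc1(x))
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)
```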

      Using the dataset from https://blog.csdn.net/fengbingchun/article/details/149307432 as an example, with DenseNet as the classifier, the test code is as follows:

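      1. Prune the previously trained classification model. Note: save the pruned model with torch.save(model, name), not torch.save(model.state_dict(), name). Structural pruning changes the layer shapes, so a state_dict alone no longer matches the architecture that models.densenet121 rebuilds.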

import torch
import torch.nn as nn
import torch_pruning as tp
from torchvision import models

def model_pruning(model_name, classes_number, prune_amount):
    # https://github.com/VainF/Torch-Pruning/blob/master/examples/torchvision_models/torchvision_global_pruning.py
    model = models.densenet121(weights=None)
    model.classifier = nn.Linear(model.classifier.in_features, classes_number)
    # print("before pruning, model:", model)
    model.load_state_dict(torch.load(model_name, weights_only=False, map_location="cpu"))

    original_size = tp.utils.count_params(model)
    model.cpu().eval()
    for p in model.parameters():
        p.requires_grad_(True)

    # keep the classifier untouched so the number of output classes does not change
    ignored_layers = []
    for m in model.modules():
        if isinstance(m, nn.Linear):
            ignored_layers.append(m)
    print(f"ignored_layers: {ignored_layers}")

    example_inputs = torch.randn(1, 3, 224, 224)

    # build network pruners
    importance = tp.importance.MagnitudeImportance(p=1)
    pruner = tp.pruner.MagnitudePruner(
        model,
        example_inputs=example_inputs,
        importance=importance,
        iterative_steps=1,
        pruning_ratio=prune_amount,
        global_pruning=True,
        round_to=None,
        unwrapped_parameters=None,
        ignored_layers=ignored_layers,
        channel_groups={}
    )

    # pruning: record the channel counts of the prunable layers before pruning
    layer_channel_cfg = {}
    for module in model.modules():
        if module not in pruner.ignored_layers:
            if isinstance(module, nn.Conv2d):
                layer_channel_cfg[module] = module.out_channels
            elif isinstance(module, nn.Linear):
                layer_channel_cfg[module] = module.out_features

    pruner.step()
    # print("after pruning, model", model)
    result_size = tp.utils.count_params(model)
    print(f"model: original size: {original_size}; result_size: {result_size}")

    # testing
    with torch.no_grad():
        out = model(example_inputs)
        print("test out:", out)

    torch.save(model, "new_structured_prune_melon_classify.pt") # can't be used: torch.save(model.state_dict(), "")

      The changes to the model before and after pruning are shown in the figure below:

      The model file is about 27.1MB before pruning and about 14.0MB after pruning.
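      The note in step 1 about saving the whole model can be reproduced on a toy network. The sketch below is an illustration under stated assumptions: build_model and its hidden sizes are hypothetical, and the smaller hidden size only stands in for the effect of structural pruning (no actual pruning is run).

```python
import io
import torch
import torch.nn as nn

def build_model(hidden=8):
    # hypothetical toy architecture; hidden=8 plays the role of the unpruned model
    return nn.Sequential(nn.Linear(4, hidden), nn.ReLU(), nn.Linear(hidden, 2))

pruned = build_model(hidden=5)  # stand-in for a structurally pruned model

# Saving only the state_dict: loading it into the original architecture fails.
state_buffer = io.BytesIO()
torch.save(pruned.state_dict(), state_buffer)
state_buffer.seek(0)
original_arch = build_model(hidden=8)
try:
    original_arch.load_state_dict(torch.load(state_buffer))
    load_failed = False
except RuntimeError as err:
    load_failed = True
    print("load_state_dict failed:", str(err).splitlines()[0])

# Saving the whole model keeps the pruned structure together with the weights.
model_buffer = io.BytesIO()
torch.save(pruned, model_buffer)
model_buffer.seek(0)
restored = torch.load(model_buffer, weights_only=False)
restored_hidden = restored[0].out_features
print(restored_hidden)
```

      Loading a fully pickled model with weights_only=False executes arbitrary pickled code, so it should only be done with files from a trusted source.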

      2. After pruning, the model needs to be fine-tuned, that is, retrained:

import time

import colorama
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

def _load_dataset(dataset_path, mean, std, batch_size):
    mean = _str2tuple(mean)
    std = _str2tuple(std)

    train_transform = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std), # RGB
    ])

    train_dataset = ImageFolder(root=dataset_path+"/train", transform=train_transform)
    print(f"train dataset length: {len(train_dataset)}; classes: {train_dataset.class_to_idx}; number of categories: {len(train_dataset.class_to_idx)}")

    train_loader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=0)

    val_transform = transforms.Compose([
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std), # RGB
    ])

    val_dataset = ImageFolder(root=dataset_path+"/val", transform=val_transform)
    print(f"val dataset length: {len(val_dataset)}; classes: {val_dataset.class_to_idx}")
    assert len(train_dataset.class_to_idx) == len(val_dataset.class_to_idx), f"the number of categories in the train set must be equal to the number of categories in the validation set: {len(train_dataset.class_to_idx)} : {len(val_dataset.class_to_idx)}"

    val_loader = DataLoader(val_dataset, batch_size, shuffle=True, num_workers=0)

    return len(train_dataset), len(val_dataset), train_loader, val_loader

def fine_tuning(dataset_path, epochs, mean, std, model_name):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load the whole pruned model (structure + weights), not just a state_dict
    model = torch.load(model_name, weights_only=False)
    model.to(device)

    train_dataset_num, val_dataset_num, train_loader, val_loader = _load_dataset(dataset_path, mean, std, 4)

    optimizer = optim.Adam(model.parameters(), lr=0.00001) # set the optimizer
    criterion = nn.CrossEntropyLoss() # set the loss

    highest_accuracy = 0.
    minimum_loss = 100.
    new_model_name = "fine_tuning_melon_classify.pt"

    for epoch in range(epochs):
        epoch_start = time.time()

        train_loss = 0.0
        train_acc = 0.0
        val_loss = 0.0
        val_acc = 0.0

        model.train() # set to training mode
        for _, (inputs, labels) in enumerate(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad() # clean existing gradients
            outputs = model(inputs) # forward pass
            loss = criterion(outputs, labels) # compute loss
            loss.backward() # backpropagate the gradients
            optimizer.step() # update the parameters

            train_loss += loss.item() * inputs.size(0) # compute the total loss
            _, predictions = torch.max(outputs.data, 1) # compute the accuracy
            correct_counts = predictions.eq(labels.data.view_as(predictions))
            acc = torch.mean(correct_counts.type(torch.FloatTensor)) # convert correct_counts to float
            train_acc += acc.item() * inputs.size(0) # compute the total accuracy
            # print(f"train batch number: {i}; train loss: {loss.item():.4f}; accuracy: {acc.item():.4f}")

        model.eval() # set to evaluation mode
        with torch.no_grad():
            for _, (inputs, labels) in enumerate(val_loader):
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs) # forward pass
                loss = criterion(outputs, labels) # compute loss
                val_loss += loss.item() * inputs.size(0) # compute the total loss
                _, predictions = torch.max(outputs.data, 1) # compute validation accuracy
                correct_counts = predictions.eq(labels.data.view_as(predictions))
                acc = torch.mean(correct_counts.type(torch.FloatTensor)) # convert correct_counts to float
                val_acc += acc.item() * inputs.size(0) # compute the total accuracy

        avg_train_loss = train_loss / train_dataset_num # average training loss
        avg_train_acc = train_acc / train_dataset_num # average training accuracy
        avg_val_loss = val_loss / val_dataset_num # average validation loss
        avg_val_acc = val_acc / val_dataset_num # average validation accuracy
        epoch_end = time.time()
        print(f"epoch:{epoch+1}/{epochs}; train loss:{avg_train_loss:.6f}, accuracy:{avg_train_acc:.6f}; validation loss:{avg_val_loss:.6f}, accuracy:{avg_val_acc:.6f}; time:{epoch_end-epoch_start:.2f}s")

        # keep the checkpoint that improves both validation accuracy and validation loss
        if highest_accuracy < avg_val_acc and minimum_loss > avg_val_loss:
            torch.save(model, new_model_name)
            highest_accuracy = avg_val_acc
            minimum_loss = avg_val_loss

        if avg_val_loss < 0.0001 or avg_val_acc > 0.9999:
            print(colorama.Fore.YELLOW + "stop training early")
            torch.save(model, new_model_name)
            break

      A few epochs of fine-tuning are enough to meet the requirement. The results are shown in the figure below:

      3. Run prediction with both the pruned model and the fine-tuned model. Note: load the model with torch.load(model_name, weights_only=False), not model.load_state_dict(torch.load(model_name, weights_only=False, map_location="cpu")).

from pathlib import Path

import torch
from PIL import Image
from torchvision import transforms

def _parse_labels_file(labels_file):
    # each line of the labels file has the form "<index> <class name>"
    classes = {}

    with open(labels_file, "r") as file:
        for line in file:
            idx_value = []
            for v in line.split(" "):
                idx_value.append(v.replace("\n", "")) # remove line breaks(\n) at the end of the line
            assert len(idx_value) == 2, f"the length must be 2: {len(idx_value)}"
            classes[int(idx_value[0])] = idx_value[1]

    return classes

def _get_images_list(images_path):
    # collect every file under images_path, recursively
    image_names = []

    p = Path(images_path)
    for subpath in p.rglob("*"):
        if subpath.is_file():
            image_names.append(subpath)

    return image_names

def predict(model_name, labels_file, images_path, mean, std):
    classes = _parse_labels_file(labels_file)
    assert len(classes) != 0, "the number of categories can't be 0"

    image_names = _get_images_list(images_path)
    assert len(image_names) != 0, "no images found"

    mean = _str2tuple(mean)
    std = _str2tuple(std)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = torch.load(model_name, weights_only=False)
    model.to(device)

    model.eval()
    with torch.no_grad():
        for image_name in image_names:
            input_image = Image.open(image_name)
            preprocess = transforms.Compose([
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std) # RGB
            ])

            input_tensor = preprocess(input_image) # (c,h,w)
            input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model, (1,c,h,w)
            input_batch = input_batch.to(device)

            output = model(input_batch)
            probabilities = torch.nn.functional.softmax(output[0], dim=0) # the output has unnormalized scores, to get probabilities, you can run a softmax on it
            max_value, max_index = torch.max(probabilities, dim=0)
            print(f"{image_name.name}\t{classes[max_index.item()]}\t{max_value.item():.4f}")

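      The results are shown in the figure below: the accuracy of the pruned model before fine-tuning is very low, while the accuracy after fine-tuning is very high.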

      GitHub: https://github.com/fengbingchun/NN_Test
