YOLOv8 剪枝模型加载踩坑记:解决 YAML 覆盖剪枝结构的问题

1. 问题背景

模型剪枝是实现模型轻量化、加速推理的关键步骤。然而,在 Ultralytics YOLOv8 的生态中,在成功剪枝后,进行微调(Fine-tuning)时会遇到一个令人困惑的现象:明明加载的是剪枝后的模型(例如 20M 参数),但训练启动时打印的日志却显示为标准版模型的参数(例如 25M)。并且经过验证,微调后的模型参数就是标准的yolo模型。

加载代码如下:

    model = YOLO("pruned.pt")     # load a pretrained model (recommended for training)model.train(data=name_yaml, device=0, imgsz=640, epochs=50, batch=32, workers=16, name=path_fineturn)  # train the model

原因是Ultralytics 的 Trainer 仍会先依据 原始 YAML 构建标准结构(约 25M 参数)。随后仅将 .pt 文件中的权重加载到这张标准结构中。


2. 代码触发点与根本原因

问题的根源在于 Ultralytics 的 Trainer 在初始化模型时(get_model 方法)的执行顺序。

ultralytics/engine/model.py中的Model类的train()方法中,原始代码如下:

self.trainer.get_model 方法的执行流程如下:

  • 优先使用 cfg 参数构建模型:该参数接收 cfg=self.model.yaml。由于 pruned.pt 在保存时不会自动更新其内部的 YAML 配置( model = YOLO("pruned.pt")会构造出一个实例,里面的self.model有很多属性,其中self.model.model是模型网络,这是真正的、由网络层构成的可执行实体。我们的剪枝操作直接修改了这个对象,比如减少了某些卷积层的通道数,从而改变了它的实际结构self.model.yaml是配置文件,剪枝时只修改了self.model.model,没有更新原始的self.model.yaml),所以这里的 self.model.yaml 仍然是标准版 YOLOv8m 的网络结构

  • 创建标准结构并打印摘要get_model 会立即执行 model = DetectionModel(cfg) 通过self.model.yaml来构建一个完整的未剪枝模型(25.8M)。随后调用 model.info() 方法,这就是日志中显示"标准版"摘要的原因。完成标准结构创建后,get_model 才会处理 weights 参数,将 pruned.pt 中的权重加载到刚创建的标准结构中。PyTorch 的 load_state_dict 会按照名称和形状匹配的原则加载对应层的权重,跳过不匹配的层,此时模型仍保持标准骨架结构。


3. 改进写法(实际切换到剪枝后结构)

为了解决这个问题,我们必须在 Trainer 开始训练前,确保其内部持有的模型对象是我们剪枝后的那一个。

将代码调整为:

        if not args.get("resume"):  # manually set model only if not resumingself.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)# ★ 关键修正:用我们剪枝后的模型对象,替换掉 Trainer 内部刚刚由 YAML 创建的模型self.trainer.model.model = self.model.modelprint("\n--- Verifying model after swapping in Trainer ---")# 打印替换后的模型参数量params_after_swap = sum(p.numel() for p in self.trainer.model.model.parameters()) / 1e6print(f"Parameters inside trainer: {params_after_swap:.2f}M\n")  # 应显示约 20.8Mself.model = self.trainer.modelif SETTINGS["hub"] is True and not self.session:

  • 依然允许 get_model 按部就班地完成它的初始化流程(包括打印那条“误导性”的日志)。

  • 但在这之后,立即通过 self.trainer.model.model = self.model.model 这行代码,强行将 Trainer 内部的 nn.Module 对象替换为我们真正的、剪枝后的模型 (self.model.model)

  • 启动阶段的日志已打印过标准版结构,因此显示上仍是标准参数量,但通过打印替换后的模型对象的参数量可以看到已经替换为剪枝后的模型

深度解析:为什么是替换 .model.model 而不是 .model
  1. yolo.model 对象 (DetectionModelBaseModel 的实例)
    它是一个“功能完备的检测器”,不仅包含了网络结构,还封装了与之相关的元数据和方法(如 .train(), .info(), .yaml 等)。把它理解为一个高级接口

  2. yolo.model.model 对象 (纯 nn.Module 实例)
    这才是我们通常意义上所说的PyTorch 模型网络。它是一个纯粹的 torch.nn.Module 子类,由各种网络层搭建而成。我们的剪枝操作,直接修改的就是这个对象。

为什么不写成 self.trainer.model = self.model

  • 源(Source)self.model.model 是我们从加载的 pruned.pt 中取出的、那个已经被剪枝过的纯粹网络结构

  • 目标(Destination)self.trainer.model.modelTrainer 内部那个标准结构的纯粹网络

self.trainer.model 是一个高级的 BaseModel 对象,Trainer 在初始化时已经对其进行了一些配置(如设备分配等)。如果我们用self.trainer.model = self.model整个地替换掉它,可能会破坏这些已经完成的设置,存在潜在风险。只替换最底层的 nn.Module,既能保证网络结构正确,又不会干扰 Trainer 的其他工作流程。
注意替换模型必须在self.trainer.model构建好之后,如果直接使用self.trainer.model.model = self.model.model会显示self.trainer.model是个str,还不是对象。

4. 显示不一致的原因

  • Summary 打印时机get_model 在构建标准结构后立即输出层数与参数量。

  • 结构替换发生在 summary 之后:没有重新打印,因此日志没有更新为剪枝后的参数量

  • 保存阶段:调用 model.save()torch.save({'model': ...}) 时,写入的是替换后的剪枝模型对象,所以最终 .pt 文件尺寸/参数量正确

5. 验证流程建议

为了确保操作是正确的,最好进行验证。

步骤 1:验证初始剪枝模型
在开始微调训练前,先确认 pruned.pt 是真的被剪枝了。

from ultralytics import YOLO
initial_model = YOLO("pruned.pt")
print("--- Verifying initial pruned model ---")
initial_model.model.info(verbose=False)  # 应显示约 20.8M 参数

步骤 2:在替换后立即验证
在修正代码的核心行之后,立刻加入打印验证,就是之前的代码。

# ...
self.trainer.model.model = self.model.model
print("\n--- Verifying model after swapping in Trainer ---")
# 打印替换后的模型参数量
params_after_swap = sum(p.numel() for p in self.trainer.model.model.parameters()) / 1e6
print(f"Parameters inside trainer: {params_after_swap:.2f}M\n") # 应显示约 20.8M

步骤 3:验证最终保存的模型
训练结束后,加载最终生成的权重文件,再次确认。

final_model = YOLO("runs/train/exp/weights/last.pt")
print("--- Verifying final saved model ---")
final_model.model.info() # 应显示约 20.8M 参数

结果如图:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/90937.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/90937.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

js的学习1

1.数组 数组方法 push()数组尾部添加unshift()数组头部添加pop()数组尾部删除shift()数组头部删除splice(起始位置,删除几个元素,要替换的元素)删除指定的元素,改变了原数组,返回值是被删除的元素indexOf()第一次查到的索引&#…

LeetCode 2563.统计公平数对的数目

给你一个下标从 0 开始、长度为 n 的整数数组 nums &#xff0c;和两个整数 lower 和 upper &#xff0c;返回 公平数对的数目 。 如果 (i, j) 数对满足以下情况&#xff0c;则认为它是一个 公平数对 &#xff1a; 0 < i < j < n&#xff0c;且 lower < nums[i] n…

ZABBIX配置自动发现与自动注册,网易邮箱告警和钉钉告警

一、自动发现zabbix server 主动的去发现所有的客户端&#xff0c;然后将客户端的信息登记在服务端上。缺点是如果定义的网段中的主机数量多&#xff0c;zabbix server 登记耗时较久&#xff0c;且压力会较大。1、部署准备准备三台虚拟机192.168.80.151&#xff1b;192.168.80.…

QT(五)常用类

1. QString字符串类(掌握) QString是Qt的字符串类&#xff0c;与C的string相比&#xff0c;不再使用ASCII编码&#xff0c;QString使用的是Unicode编码。 QString中每个字符都是一个16位的QChar&#xff0c;而不是8位的char。 QString完全支持中文&#xff0c;但是由于不同的技…

EXCEL怎么提取表名

错误的方法&#xff1a;使用以下方法提取表名的时候&#xff0c;会存在1个问题&#xff0c;公式只在当前工作表生效&#xff0c;换工作表会出现表名覆盖的情况。RIGHT(CELL("filename"),LEN(CELL("filename"))-FIND("]",CELL("filename&quo…

springboot校园外卖配送系统

目 录 第一章 绪 论 1.1背景及意义 1.2国内外研究概况 1.3 研究的内容 第二章 关键技术的研究 2.1开发技术 2.2 Springboot框架介绍 2.3 Vue.js 主要功能 2.4 MVVM模式介绍 2.4 B/S体系工作原理 2.5 MySQL数据库 第三章 系统分析 3.1 系统设计目标 3.2 系统可行性…

【智慧物联网平台】安装部署教程——仙盟创梦IDE

一、部署前准备1. 环境要求基础环境&#xff1a;JDK 1.8、MySQL 5.7/8.0、Maven 3.6、Redis&#xff08;用于缓存&#xff09;、Node.js&#xff08;用于前端构建&#xff0c;可选&#xff09;。依赖服务&#xff1a;若需对接门禁、道闸等硬件设备&#xff0c;需确保设备网络可…

【安全漏洞】防范未然:如何有效关闭不必要的HTTP请求方法,保护你的Web应用

在构建和维护Web应用的过程中&#xff0c;安全问题总是我们最关心的话题之一。今天&#xff0c;我们要探讨的是一个经常被忽视的Web漏洞——未关闭或限制不必要的HTTP请求方法。 虽然我们在日常开发中主要使用 GET 和 POST 这两种请求方法&#xff0c;但像 PUT、DELETE、HEAD、…

嵌入式Linux裸机开发笔记8(IMX6ULL)主频和时钟配置实验(1)

引言在前几章实验中我们都没有涉及到 I.MX6U 的时钟和主频配置操作&#xff0c;全部使用的默认配置&#xff0c; 默认配置下 I.MX6U 工作频率为 396MHz。但是 I.MX6U 系列标准的工作频率为 528MHz&#xff0c;有些 型号甚至可以工作到 696MHz。本章学习 I.MX6U 的时钟系统&…

设计模式(四)创建型:生成器模式详解

设计模式&#xff08;四&#xff09;创建型&#xff1a;生成器模式详解生成器模式&#xff08;Builder Pattern&#xff09;是 GoF 23 种设计模式中的核心创建型模式之一&#xff0c;其核心价值在于将一个复杂对象的构建过程与其表示分离&#xff0c;使得同样的构建过程可以创建…

《Angular+Spring Boot:ERP前端采购销售库存协同架构解析》

基于Angular与Spring Boot构建的全栈ERP前端&#xff0c;绝非技术的简单叠加&#xff0c;而是通过深度融合两者特性&#xff0c;打造出兼具稳定性与灵活性的业务载体。Angular的组件化架构将复杂界面拆解为可复用的独立单元&#xff0c;依赖注入机制则让服务调用与数据流转条理…

Java 排序

文章目录排序插入排序分析希尔排序分析选择排序分析堆排序分析冒泡排序分析快速排序霍尔法分析挖坑法找基准前后指针法题目快排的优化三数取中法非递归实现快排归并排序分析非递归实现归并排序海量数据的排序非比较的排序计数排序分析基数排序桶排序排序 稳定的排序&#xff1…

日本IT就职面试|仪容礼仪篇分享建议

日系企業で好印象を与える「身だしなみ」と「面接マナー」ガイドこんにちは。 日系企業への就職・転職活動をされている方にとって、「第一印象」は合否を左右する大切なポイントですよね。実は、面接の評価は入室の瞬間から始まっていると言っても過言ではありません。 今回は…

英语听力口语词汇-8.美食类

1.crispy,crisp adj.酥脆的&#xff0c;易碎的 2.sweet adj.甜的 比如说chocolate is so sweet and delicious 3.chewy adj.难嚼的&#xff0c;难咽的 4.oatmeal n.燕麦粉 5.pickle n.泡菜 7.stir-fry v.炒菜 8.bacon n.咸肉&#xff0c;熏肉 9.yummy adj.美味可口的 1…

力扣7:整数反转

力扣7:整数反转题目思路代码题目 给你一个 32 位的有符号整数 x &#xff0c;返回将 x 中的数字部分反转后的结果。 如果反转后整数超过 32 位的有符号整数的范围 [−2^31, 2^31 − 1] &#xff0c;就返回 0。 思路 这道题我们可以分成两部分来做&#xff0c;一是完成反转二…

PWM信号控制电机

1&#xff1a;环境 STM32F103C8T6 KEIL5.38 2个电机 2个轮子 1个L298N STLINKV2 CH340 1个4位独立按键 杜邦线若干 2&#xff1a;代码 key.h #ifndef __KEY_H #define __KEY_H#include "stm32f10x.h"extern volatile uint8_t key_t ; extern volatile uint8_t …

开源赋能产业,生态共筑未来 | 开源科学计算与系统建模(openSCS)分论坛圆满举行

2025开放原子开源生态大会于7月23日-24日在北京国家会议中心召开。本届大会以“开源赋能产业&#xff0c;生态共筑未来”为主题&#xff0c;汇聚政、产、学、研、用、金、创、投等各领域开源力量&#xff0c;聚焦开源政策导向、生态发展趋势、开源产业实践&#xff0c;共探中国…

Android广播机制体系初识

Android广播机制体系大白话把Android的广播机制想象成小区里的“大喇叭”谁在喊话&#xff1f;任何App或系统都能当“大喇叭”&#xff0c;比如喊一嗓子“电量不足啦&#xff01;”&#xff08;这就是发送广播&#xff09;谁在听&#xff1f;其他App只要“竖起耳朵”&#xff0…

微信小程序点击输入框时,顶部导航栏被遮挡问题如何解决?

前言 不知道大家开发微信小程序的时候有没有遇到这么一个问题&#xff0c;就是在表单页面中&#xff0c;点击输入框后&#xff0c;输入框顶起会把顶部栏给遮挡住&#xff0c;如下图所示&#xff1a;遇到这种情况有没有解决的办法呢&#xff1f;能不能既将页面顶起&#xff0c;同…

通过具有一致性嵌入的大语言模型(LMMs)实现端到端乳腺癌放射治疗计划制定|文献速递-医学影像算法文献分享

Title题目End-to-end breast cancer radiotherapy planning via LMMs with consistencyembedding通过具有一致性嵌入的大语言模型&#xff08;LMMs&#xff09;实现端到端乳腺癌放射治疗计划制定01文献速递介绍近年来&#xff0c;受大型语言模型&#xff08;LLM&#xff09;启发…