模型调试实用技巧 (Pytorch Lightning)

【PL 基础】模型调试实用技巧

  • 摘要
  • 1. 设置断点
  • 2. 快速运行所有模型代码一次
  • 3. 缩短 epoch 长度
  • 4. 运行健全性检查
  • 5. 打印 LightningModule 权重摘要
  • 6. 打印输入输出层尺寸

摘要

  本文总结了6种实用的模型调试技巧:1)通过设置断点逐行检查代码;2)使用fast_dev_run参数快速验证全流程;3)限制批次量缩短训练周期;4)利用num_sanity_val_steps进行预验证;5)通过ModelSummary打印模型权重结构;6)设置example_input_array显示各层输入输出尺寸。这些方法可显著提升调试效率,特别适用于大规模深度学习模型的开发验证环节,帮助开发者快速定位问题并优化模型结构。

1. 设置断点

  断点会停止代码执行,以便您可以检查变量等。并允许您的代码一次执行一行。

def function_to_debug():x = 2# set breakpointbreakpoint()y = x**2

在此示例中,代码将在执行该行 y = x**2 之前停止。

2. 快速运行所有模型代码一次

  如果你曾经历过模型训练数日后却在验证或测试阶段崩溃的痛苦,那么这个训练器参数将成为你的救星。

fast_dev_run(快速开发运行模式)参数会让训练器仅执行:
5个批次的训练 → 验证 → 测试 → 预测全流程
快速检测代码是否存在错误:

trainer = Trainer(fast_dev_run=True)

要更改要使用的批次数,请将参数更改为整数。在这里,我们运行每个批次的 7 个批次:

trainer = Trainer(fast_dev_run=7)

启用fast_dev_run参数时,将自动禁用以下功能组件:

  • 超参优化器(tuner)

  • 模型检查点回调(checkpoint callbacks)

  • 早停回调(early stopping callbacks)

  • 所有日志记录器(loggers)

  • 日志类回调(如学习率监控器 LearningRateMonitor / 设备状态监控器 DeviceStatsMonitor)

3. 缩短 epoch 长度

  在某些场景下,仅使用训练集/验证集/测试集/预测数据的子集(或限定批次量)能显著提升效率。例如:

✅ 仅抽取20%训练集

✅ 仅使用1%验证集

  在处理ImageNet等大型数据集时,此方法可帮助您:

✅ 快速完成调试或验证

✅ 避免等待完整周期结束

✅ 大幅缩短反馈周期

# use only 10% of training data and 1% of val data
trainer = Trainer(limit_train_batches=0.1, limit_val_batches=0.01)# use 10 batches of train and 5 batches of val
trainer = Trainer(limit_train_batches=10, limit_val_batches=5)

4. 运行健全性检查

  Lightning框架在训练初始阶段会预先执行2步验证,该设计能有效避免:当训练进入耗时漫长的深水区后,才在验证环节意外崩溃的风险。

trainer = Trainer(num_sanity_val_steps=2)

5. 打印 LightningModule 权重摘要

  1. 每当调用该函数.fit() 时,Trainer 都会打印 LightningModule 的权重摘要。

trainer.fit(...)

这会生成一个表,如下所示:

  | Name  | Type        | Params | Mode
-------------------------------------------
0 | net   | Sequential  | 132 K  | train
1 | net.0 | Linear      | 131 K  | train
2 | net.1 | BatchNorm1d | 1.0 K  | train
  1. 如需在模型摘要中显示子模块,需添加 ModelSummary 回调:
from lightning.pytorch.callbacks import ModelSummary  # 导入模型摘要组件trainer = Trainer(callbacks=[ModelSummary(max_depth=-1)])  # 创建训练器时配置回调

参数解释

ModelSummary(max_depth=-1,  # 深度控制:-1=无限递归,0=仅顶层,1=展开一级子模块max_recursion=10  # 可选:防止无限递归的保险机制(默认10层)
)

典型输出示例

| Name        | Type          | Params | In dim       | Out dim      |
|-------------|---------------|--------|--------------|--------------|
| net         | Sequential    | 1.5 M  | [32, 3, 224] | [32, 1000]   |
|  ├─conv1    | Conv2d        | 9.4 K  | [32, 3, 224] | [32, 64,112] |
|  ├─bn1      | BatchNorm2d   | 128    | [32,64,112]  | [32,64,112]  |
|  └─...      | ...           | ...    | ...          | ...          |
  1. 若需在不调用 .fit() 的情况下打印模型摘要,请使用以下方案:
from lightning.pytorch.utilities.model_summary import ModelSummary  # 从工具库导入摘要类model = LitModel()  # 实例化自定义模型
summary = ModelSummary(model, max_depth=-1)  # 生成深度摘要对象
print(summary)  # 打印结构化模型报告

参数解释

ModelSummary(model,        # 必需:继承LightningModule的自定义模型max_depth=-1, # 层级深度:-1=无限递归(显示所有子模块)max_recursion=10  # 递归安全限制(防循环引用崩溃)
)

典型输出示例

╒═════════════╤══════════════╤═════════╤══════════╤═══════════╕
│ Layer       │ Type         │ Params  │ In dim   │ Out dim   │
╞═════════════╪══════════════╪═════════╪══════════╪═══════════╡
│ encoder     │ Sequential   │ 4.7M    │ [32,256][32,512]  │
│ ├─lstm1     │ LSTM         │ 3.2M    │ [32,256][32,128]  │
│ ├─dropout   │ Dropout      │ 0[32,128][32,128]  │
│ └─...............       │
╘═════════════╧══════════════╧═════════╧══════════╧═══════════╛
Trainable params: 4.7M
Non-trainable params: 0
  1. 要关闭自动汇总,请使用:
trainer = Trainer(enable_model_summary=False)

6. 打印输入输出层尺寸

另一个调试工具是通过在 LightningModule 中设置属性来显示所有层的中间输入和输出大小。example_input_array

class LitModel(LightningModule):def __init__(self, *args, **kwargs):self.example_input_array = torch.Tensor(32, 1, 28, 28)

对于输入数组,摘要表将包括输入和输出层维度:

  | Name  | Type        | Params | Mode  | In sizes  | Out sizes
----------------------------------------------------------------------
0 | net   | Sequential  | 132 K  | train | [10, 256] | [10, 512]
1 | net.0 | Linear      | 131 K  | train | [10, 256] | [10, 512]
2 | net.1 | BatchNorm1d | 1.0 K  | train | [10, 512] | [10, 512]

调用 Trainer.fit() 方法时,该机制可帮助您检测网络层组合中的潜在错误。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/87304.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/87304.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机网络(四)网际层IP

目录 一、概念 ​编辑 二、网际层和数据链路层的关系​ 三、IP地址的基础认识 四、IP地址的分类 五、无分类地址CIDR 六、子网掩码 七、为什么要分离网络号和主机号 八、公有IP和私有IP ​编辑 九、IP地址与路由控制 十、IP分片和重组 十一、IPv6 十二、IP协议…

Java--多态--向上转型--动态绑定机制--断点调试--向下转型

目录 1. 向上转型 2. 向下转型 3. java的动态绑定机制: 4. Object类讲解 5. 断点调试 1. 向上转型 提前:俩个对象(类)存在继承关系 本质:父类的引用指向了子类的对象 语法:父类 类型 引用名 new…

Python爬虫实战:研究urllib 库相关技术

1. 引言 1.1 研究背景与意义 互联网每天产生海量数据,如何高效获取和利用这些数据成为重要研究方向。网页爬虫作为自动获取网络信息的核心技术,在市场调研、舆情分析、学术研究等领域具有广泛应用。Python 凭借其简洁语法和丰富库支持,成为爬虫开发的首选语言。 1.2 相关…

【机器学习赋能的智能光子学器件系统研究与应用】

目前在Nature和Science杂志上发表的机器学习与光子学结合的研究主要集中在以下几个方面: 1.光子器件的逆向设计:通过机器学习,特别是深度学习,可以高效地进行光子器件的逆向设计,这在传统的多参数优化问题中尤为重要。…

Codeforces Round 1034 (Div. 3)

比赛链接如下:https://codeforces.com/contest/2123 A. Blackboard Game Initially, the integers from 00 to n−1 are written on a blackboard. In one round, Alice chooses an integer a on the blackboard and erases it;then Bob chooses an integer b on …

微电网系列之微电网的孤岛运行

个人主页:云纳星辰怀自在 座右铭:“所谓坚持,就是觉得还有希望!” 微电网的孤岛运行 微电网具有并网和孤岛两种运行模式,由于孤岛运行模式下,分布式电源为微电网内部负荷提供频率和电压支撑,由…

JsonCpp的核心类及核心函数使用汇总

文章目录 JsonCpp的核心类及核心函数使用汇总一、前言二、JsonCpp 核心类介绍三、Value 类函数解析1. 值获取函数(asxxx 系列 )2. 值类型判断函数(isxxx 系列 )3. 数组操作函数4. 对象操作函数5. 运算符重载6. 迭代器7. JSON 转化…

Qt写入excel

1.tableView导出到excel 点击导出函数按钮、发送sendMessage信号(信号名称,对象,数据) void HydroelectricPowerPluginImpl::exportTableViewSelectedRows(QTableView* tableView, QWidget* parent) {if (!tableView || !tableVie…

OSCP - Proving Grounds - DC - 1

主要知识点 drupal 7 RCEfind SUID提权 具体步骤 nmap起手,80端口比较有意思,安装了 Drupal 7 Starting Nmap 7.94SVN ( https://nmap.org ) at 2024-12-17 14:23 UTC Nmap scan report for 192.168.57.193 Host is up (0.00087s latency). Not shown: 65531 cl…

仿小红书交流社区(微服务架构)

文章目录 framework - 平台基础设施starter - jacksoncommonexceptionresponseutil starter - content 全局上下文distributed - id - generate - 分布式 IdSnowflake - 基于雪花算法生成 IdSegment - 基于分段式生成 Id OSS - 对象存储KV - 短文本存储笔记评论 user - 用户服务…

大模型开源技术解析 4.5 的系列开源技术解析:从模型矩阵到产业赋能的全栈突破

提示:本篇文章 1300 字,阅读时间:5分钟。 前言 6 月 30 日,百度正式开源文心大模型 4.5 系列,这一动作不仅兑现了 2 月发布会上的技术承诺,更以 10 款全维度模型矩阵刷新了国内开源模型的技术边界。从学术…

[6-02-01].第05节:配置文件 - YAML配置文件语法

SpringBoot学习大纲 一、YAML语法 1.1.概述: 1.YAML是一种数据序列化格式;2.它是以数据为中心3.容易阅读,容易与脚本语言交互,如下图所示: 1.2.基本语法 1.key: value:kv之间有空格2.使用缩进表示层级关系3.缩进时…

FPGA学习

一、module : 定义: 是构建数字系统的基本单元,用于封装电路的结构和行为。它可以表示从简单的逻辑门到复杂的处理器等任何硬件组件。 1. module 的基本定义 module 模块名 (端口列表);// 端口声明input [位宽] 输入端口1;output [位宽] 输出端口1;ino…

26-计组-存储器与Cache机制

一、存储器与局部性原理 1. 局部性原理 基础概念: 时间局部性:一个存储单元被访问后,短时间内可能再次被访问(例如循环变量)。空间局部性:一个存储单元被访问后,其附近单元可能在短时间内被访…

I/O 线程 7.3

前言 以下: 概述 1.基础 2.代码演示 3.练习 4.分析题 1.基础 一、线程基础概念 并发执行原理 通过时间片轮转实现多任务"并行"效果 实际为CPU快速切换执行不同线程 线程 vs 进程 线程共享进程地址空间,切换开销更小 进程拥有独立资源&am…

MySQL JSON数据类型完全指南:从版本演进到企业实践的深度对话

📊 MySQL JSON数据类型完全指南:从版本演进到企业实践的深度对话 在当今数据驱动的时代,MySQL作为最受欢迎的关系型数据库之一,不断演进以满足现代应用的需求。JSON数据类型的引入,让MySQL在保持关系型数据库优势的同时…

BI × 餐饮行业 | 以数据应用重塑全链路业务增长路径

在竞争激烈的餐饮行业中,数据已成为企业保持竞争力的关键资产。通过深入分析顾客数据,餐饮企业能够洞察消费者的需求和偏好,从而提供更加精准和个性化的服务。此外,利用数据优化业务管理,降低成本,并提高运…

【学习线路】机器学习线路概述与内容关键点说明

文章目录 零、机器学习的企业价值一、基础概念1. 机器学习定义2. 学习类型3. 学习范式 二、核心算法与技术1. 监督学习2. 无监督学习3. 模型评估与优化 三、深度学习与神经网络1. 神经网络基础2. 深度学习框架3. 应用场景 四、工具与实践1. 数据处理2. 模型部署3. 机器学习的生…

Linux 命令:cp

Linux cp 命令详细教程 cp 是 Linux 系统中最常用的命令之一,用于复制文件或目录。它可以将源文件/目录复制到指定的目标位置,支持批量复制、强制覆盖、保留文件属性等功能。下面详细介绍其用法。资料已经分类整理好:https://pan.quark.cn/s…

java分页插件| MyBatis-Plus分页 vs PageHelper分页:全面对比与最佳实践

MyBatis-Plus分页 vs PageHelper分页:全面对比与最佳实践 一、分页技术概述 在Java持久层框架中,分页是高频使用的功能。主流方案有: MyBatis-Plus分页:MyBatis增强工具的内置分页方案PageHelper分页:独立的MyBatis…