【PL 基础】如何启用早停机制
- 摘要
- 1. on_train_batch_start()
- 2. EarlyStopping Callback
摘要
本文介绍了两种在 PyTorch Lightning 中实现早停机制的方法。第一种是通过重写on_train_batch_start()
方法手动控制训练流程;第二种是使用内置的EarlyStopping
回调,可以监控验证指标并在指标停止改善时自动停止训练。文章详细说明了EarlyStopping
的参数设置,包括监控指标、模式选择、耐心值等核心参数,以及停止阈值、发散阈值等进阶参数。同时介绍了如何通过子类化修改早停触发时机,并提醒注意验证频率与耐心值的配合使用。文末提供了完整的代码示例,展示了如何在实际训练中配置和使用早停机制。
1. on_train_batch_start()
通过重写 on_train_batch_start()
方法,在满足特定条件时提前返回,从而停止并跳过当前epoch的剩余训练批次。
如果对于最初要求的每个epoch重复这样做,将停止整个训练。
2. EarlyStopping Callback
EarlyStopping
回调可用于监控指标,并在没有观察到改善时停止训练。
要启用此功能,请执行以下操作:
-
导入
EarlyStopping
回调模块; -
使用
log()
方法记录需要监控的指标; -
初始化回调并设置要监控的指标名称(
monitor
参数); -
根据指标特性设置监控模式(
mode
参数); -
将
EarlyStopping
回调传递给Trainer
的callbacks
参数。
from lightning.pytorch.callbacks.early_stopping import EarlyStoppingclass LitModel(LightningModule):def validation_step(self, batch, batch_idx):loss = ...self.log("val_loss", loss)model = LitModel()
trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
trainer.fit(model)
可以通过更改其参数来自定义回调行为。
early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, patience=3, verbose=False, mode="max")
trainer = Trainer(callbacks=[early_stop_callback])
用于在极值点停止训练的附加参数:
-
stopping_threshold
(停止阈值):当监控指标达到该阈值时立即终止训练。适用于已知超过特定最优值后模型不再提升的场景。 -
divergence_threshold
(发散阈值):当监控指标劣化至该阈值时即刻停止训练。当指标恶化至此程度时,我们认为模型已无法恢复,此时应提前终止并尝试不同初始条件。 -
check_finite
(有限值检测):启用后,若监控指标出现NaN(非数值)或无穷大时终止训练。 -
check_on_train_epoch_end
(训练周期结束检测):启用后,在训练周期结束时检查指标。仅当监控指标通过周期级训练钩子记录时才需启用此功能。
若需在训练过程的其他阶段启用早停机制,请通过创建子类继承 EarlyStopping
类并修改其调用位置:
class MyEarlyStopping(EarlyStopping):def on_validation_end(self, trainer, pl_module):# override this to disable early stopping at the end of val looppassdef on_train_end(self, trainer, pl_module):# instead, do it at the end of training loopself._run_early_stopping_check(trainer)
默认情况下,EarlyStopping
回调会在每个验证周期结束时触发。但验证频率可通过 Trainer
中的参数调节,例如通过设置 check_val_every_n_epoch
(每N个训练周期验证一次)和 val_check_interval
(验证间隔)。需特别注意:patience
(耐心值)统计的是验证结果未提升的次数,而非训练周期数。因此当设置 check_val_every_n_epoch=10
且 patience=3
时,训练器需经历至少 40个训练周期才会停止。