PyTorch Lightning in Depth: Mastering PyTorch Lightning Through One Worked Example for More Efficient Training
PyTorch Lightning is organized around a few major components, which this article walks through:
- The model: the core of PyTorch Lightning is a subclass of `pl.LightningModule`
- The data: the data module subclasses `pl.LightningDataModule`
- Callbacks: how to build, use, and customize them
- Hooks: how they are used in the model class, the data module, and callbacks, and the order in which they are called
- Logging: usually with PyTorch Lightning's built-in TensorBoard logger
1. Defining the model class
1.1 The basic structure of the model
```python
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self, imgsize=28, hidden=128, num_class=10):
        super().__init__()
        self.flatten = nn.Flatten()
        self.net = nn.Sequential(
            nn.Linear(imgsize * imgsize, hidden),
            nn.ReLU(),
            nn.Linear(hidden, num_class)
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        x = self.flatten(x)
        return self.net(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        self.log("train_loss", loss)  # log the training loss
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        self.log("val_loss", loss)  # log the validation loss
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        self.log("test_loss", loss)  # log the test loss
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```
1.2 The step methods: `training_step`, `validation_step`, `test_step`

These methods take two required arguments, `batch` and `batch_idx`, whose names must not be changed. (There are a few other optional arguments, but they are rarely needed.)
- `batch` is whatever the dataloader returns.
- `batch_idx` is the index of the current batch within the epoch. For example, to log only every 100 steps you can guard with `if batch_idx % 100 == 0:` (see the sketch after this list).
- `training_step` must return the loss to backpropagate; the other two do not have to return anything. With multiple optimizers, backpropagation has to be implemented manually (in recent versions).
- `validation_step` and `training_step` must be present during training; `test_step` is only called when testing the model and can be omitted during training.
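For instance, a minimal sketch of logging only every 100 steps inside `training_step` (the interval and metric name are just illustrative):

```python
def training_step(self, batch, batch_idx):
    x, y = batch
    pred = self(x)
    loss = self.loss_fn(pred, y)
    # Only write the metric every 100 batches
    if batch_idx % 100 == 0:
        self.log("train_loss", loss)
    return loss
```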
1.3 Defining optimizers and schedulers: `configure_optimizers`

This method takes no extra arguments. With multiple optimizers, return them as lists: `return [optimizer1, optimizer2, ...], [scheduler1, scheduler2, ...]`.
It can also return dictionaries; the examples here use the list form, and a sketch of the dictionary form is shown below.
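For reference, a minimal sketch of the dictionary-style return, assuming a single Adam optimizer with a StepLR scheduler (the values are illustrative):

```python
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    # Dictionary form: the optimizer plus an "lr_scheduler" config dict
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "epoch",  # step the scheduler once per epoch
        },
    }
```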
1.4 Manual backpropagation with multiple optimizers

Let's modify the model from 1.1 to see how to handle multiple optimizers. Here we assume two optimizers are used.
```python
class LitModel(pl.LightningModule):
    def __init__(self, imgsize=28, hidden=128, num_class=10):
        super().__init__()
        self.flatten = nn.Flatten()
        ####################
        # 1. Split the model into two parts so different optimizers can be applied
        self.feature_extractor = nn.Sequential(
            nn.Linear(imgsize * imgsize, hidden),
            nn.ReLU()
        )
        self.classifier = nn.Sequential(
            nn.Linear(hidden, num_class)
        )
        # Turn off automatic optimization
        self.automatic_optimization = False
        ####################
        self.loss_fn = nn.CrossEntropyLoss()
        # Save the constructor arguments as hyperparameters
        self.save_hyperparameters()

    def forward(self, x):
        x = self.flatten(x)
        features = self.feature_extractor(x)
        return self.classifier(features)

    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        opt1, opt2 = self.optimizers()
        # 3. Manual optimization
        if batch_idx % 2 != 0:
            opt1.zero_grad()
            self.manual_backward(loss)
            opt1.step()
        # Update the classifier every 2 steps
        if batch_idx % 2 == 0:
            opt2.zero_grad()
            self.manual_backward(loss)
            opt2.step()
        self.log("train_loss", loss)  # log the training loss
        # No loss is returned: backpropagation is already done manually
        return None

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        self.log("val_loss", loss)  # log the validation loss
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        self.log("test_loss", loss)  # log the test loss
        return loss

    def configure_optimizers(self):
        # 2. Apply a different optimizer (and scheduler) to each part of the model
        optimizer1 = torch.optim.Adam(
            self.feature_extractor.parameters(),
            lr=1e-3,
            weight_decay=1e-4
        )
        optimizer2 = torch.optim.SGD(
            self.classifier.parameters(),
            lr=1e-2,
            momentum=0.9
        )
        # Suppose only the second optimizer needs a learning-rate schedule
        scheduler2 = torch.optim.lr_scheduler.MultiStepLR(
            optimizer2, milestones=[1, 3, 5], gamma=0.5, last_epoch=-1)
        return [optimizer1, optimizer2], [scheduler2]

    # 4. Hook that manually steps the scheduler after every batch
    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Get the scheduler
        scheduler2 = self.lr_schedulers()
        # Step the scheduler (per batch)
        scheduler2.step()
        # Log the learning rate
        lr = scheduler2.get_last_lr()[0]
        self.log("lr", lr, prog_bar=True)

    # 5. Hook that stores the hyperparameters in the checkpoint
    def on_save_checkpoint(self, checkpoint: dict) -> None:
        checkpoint["save_data"] = self.hparams
```
2 Data
```python
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    # Data preparation, runs only once
    def prepare_data(self):
        # Download the dataset
        MNIST(root="data", train=True, download=True)
        MNIST(root="data", train=False, download=True)

    # In distributed training this runs on every process
    def setup(self, stage=None):
        # Preprocess and split the data
        transform = ToTensor()
        mnist_full = MNIST(root="data", train=True, transform=transform)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        self.mnist_test = MNIST(root="data", train=False, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        # If there is no test set, just return None
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
```
The above is the basic structure of a data module; a minimal usage sketch follows.
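As a quick orientation (not part of the original example), the data module is simply passed to the Trainer alongside the model:

```python
# Minimal wiring sketch: Lightning calls prepare_data/setup and the dataloaders automatically
model = LitModel()
dm = MNISTDataModule(batch_size=64)

trainer = pl.Trainer(max_epochs=3)
trainer.fit(model, datamodule=dm)
trainer.test(model, datamodule=dm)
```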
3 Defining callbacks

Below is an example that defines a callback whose job is to save the config.
```python
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl
from typing import Dict, Any
import os
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from omegaconf import OmegaConf


class SetupCallback(Callback):
    def __init__(self, now, cfgdir, config):
        super().__init__()
        self.now = now
        self.cfgdir = cfgdir
        self.config = config

    def on_train_start(self, trainer, pl_module):
        # Save only on the main process
        if trainer.global_rank == 0:
            print("Project config")
            OmegaConf.save(self.config,
                           os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

    # Also store the config in the checkpoint
    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        checkpoint["cfg_project"] = self.config
```
Here `on_train_start` and `on_save_checkpoint` are hook functions defined inside the callback. A short sketch of registering this callback with a Trainer follows, and the next section compares how these hooks are used.
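A minimal registration sketch (the directory and config values are placeholders, not from the original example):

```python
import datetime

now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
cfg = OmegaConf.create({"note": "placeholder config"})  # placeholder config
os.makedirs("logs", exist_ok=True)

trainer = pl.Trainer(
    max_epochs=3,
    callbacks=[SetupCallback(now=now, cfgdir="logs", config=cfg)],
)
```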
4 Using a few common hooks

The table below compares how three hooks are used in the model class versus in a callback, along with their call order:
| Hook | Location | Method signature | Typical use | Call order | Frequency |
|---|---|---|---|---|---|
| on_fit_start | Model class | `def on_fit_start(self)` | Global initialization, distributed setup | Called first | Once per fit |
| | Callback | `def on_fit_start(self, trainer, pl_module)` | Prepare logging, set global state | Called second | Once per fit |
| on_train_start | Model class | `def on_train_start(self)` | Initialize training metrics, timers | Called first | Once per training loop |
| | Callback | `def on_train_start(self, trainer, pl_module)` | Reset callback state, pre-training preparation | Called second | Once per training loop |
| on_save_checkpoint | Model class | `def on_save_checkpoint(self, checkpoint)` | Save extra model state | Called second | Every time a checkpoint is saved |
| | Callback | `def on_save_checkpoint(self, trainer, pl_module, checkpoint)` | Save callback state, add metadata | Called first | Every time a checkpoint is saved |
Call-order diagram

```
Training starts
├── on_fit_start
│   ├── ModelClass.on_fit_start        (1)
│   └── Callback.on_fit_start          (2)
│
├── on_train_start
│   ├── ModelClass.on_train_start      (3)
│   └── Callback.on_train_start        (4)
│
├── Training epochs ...
│
└── Checkpoint is saved
    ├── Callback.on_save_checkpoint    (5)
    └── ModelClass.on_save_checkpoint  (6)
```
5 Full code
simple.yaml
```yaml
trainer:
  accelerator: gpu
  devices: [1]
  max_epochs: 100

call_back:
  modelckpt:
    filename: "{epoch:03}-{train_loss:.2f}-{val_loss:.2f}"
    save_top_k: -1
    save_step: true
    every_n_epochs: 10
    verbose: true
    save_last: true

model:
  target: main_pl.LitModel
  params:
    imgsize: 28
    hidden: 128
    num_class: 10

data:
  target: main_pl.MNISTDataModule
  params:
    batch_size: 32
```
main_pl.py
```python
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl
from typing import Dict, Any
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from omegaconf import OmegaConf
import argparse, os, datetime, importlib, glob
from pytorch_lightning import Trainer


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


# 1. Define the LightningModule subclass
class LitModel(pl.LightningModule):
    def __init__(self, imgsize=28, hidden=128, num_class=10):
        super().__init__()
        self.flatten = nn.Flatten()
        ####################
        # 1. Split the model into two parts so different optimizers can be applied
        self.feature_extractor = nn.Sequential(
            nn.Linear(imgsize * imgsize, hidden),
            nn.ReLU()
        )
        self.classifier = nn.Sequential(
            nn.Linear(hidden, num_class)
        )
        # Turn off automatic optimization
        self.automatic_optimization = False
        ####################
        self.loss_fn = nn.CrossEntropyLoss()
        # Save the constructor arguments as hyperparameters
        self.save_hyperparameters()

    def forward(self, x):
        x = self.flatten(x)
        features = self.feature_extractor(x)
        return self.classifier(features)

    def training_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        opt1, opt2 = self.optimizers()
        # 3. Manual optimization
        if batch_idx % 2 != 0:
            opt1.zero_grad()
            self.manual_backward(loss)
            opt1.step()
        # Update the classifier every 2 steps
        if batch_idx % 2 == 0:
            opt2.zero_grad()
            self.manual_backward(loss)
            opt2.step()
        self.log("train_loss", loss)  # log the training loss
        # No loss is returned: backpropagation is already done manually
        return None

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        self.log("val_loss", loss)  # log the validation loss
        return loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = self.loss_fn(pred, y)
        self.log("test_loss", loss)  # log the test loss
        return loss

    def configure_optimizers(self):
        # 2. Apply a different optimizer (and scheduler) to each part of the model
        optimizer1 = torch.optim.Adam(
            self.feature_extractor.parameters(),
            lr=1e-3,
            weight_decay=1e-4
        )
        optimizer2 = torch.optim.SGD(
            self.classifier.parameters(),
            lr=1e-2,
            momentum=0.9
        )
        # Suppose only the second optimizer needs a learning-rate schedule
        scheduler2 = torch.optim.lr_scheduler.MultiStepLR(
            optimizer2, milestones=[1, 3, 5], gamma=0.5, last_epoch=-1)
        return [optimizer1, optimizer2], [scheduler2]

    # 4. Hook that manually steps the scheduler after every batch
    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Get the scheduler
        scheduler2 = self.lr_schedulers()
        # Step the scheduler (per batch)
        scheduler2.step()
        # Log the learning rate
        lr = scheduler2.get_last_lr()[0]
        self.log("lr", lr, prog_bar=True)

    # 5. Hook that stores the hyperparameters in the checkpoint
    def on_save_checkpoint(self, checkpoint: dict) -> None:
        checkpoint["save_data"] = self.hparams


# 2. Prepare the data module
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    # Data preparation, runs only once
    def prepare_data(self):
        # Download the dataset
        MNIST(root="data", train=True, download=True)
        MNIST(root="data", train=False, download=True)

    # In distributed training this runs on every process
    def setup(self, stage=None):
        # Preprocess and split the data
        transform = ToTensor()
        mnist_full = MNIST(root="data", train=True, transform=transform)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
        self.mnist_test = MNIST(root="data", train=False, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        # If there is no test set, just return None
        return DataLoader(self.mnist_test, batch_size=self.batch_size)


class SetupCallback(Callback):
    def __init__(self, now, cfgdir, config):
        super().__init__()
        self.now = now
        self.cfgdir = cfgdir
        self.config = config

    def on_train_start(self, trainer, pl_module):
        # Save only on the main process
        if trainer.global_rank == 0:
            print("Project config")
            OmegaConf.save(self.config,
                           os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

    # Also store the config in the checkpoint
    def on_save_checkpoint(
        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
    ) -> None:
        checkpoint["cfg_project"] = self.config


if __name__ == '__main__':
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    configs_path = 'simple.yaml'
    logdir = os.path.join("logs", now)
    config = OmegaConf.load(configs_path)

    # 0. Read the basic Trainer arguments (accelerator, devices, max_epochs) from the config
    trainer_kwargs = dict(config.trainer)

    # 1. Define the logger and add it to trainer_kwargs
    logger_cfg = {
        "target": "pytorch_lightning.loggers.TensorBoardLogger",
        "params": {
            "name": "tensorboard",
            "save_dir": logdir,
        }
    }
    logger_cfg = OmegaConf.create(logger_cfg)
    trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

    # 2. Callbacks
    modelckpt_params = config.call_back.modelckpt
    callbacks_cfg = {
        "setup_callback": {
            "target": "main_pl.SetupCallback",
            "params": {
                "now": now,
                "cfgdir": logdir,
                "config": config,
            }
        },
        "learning_rate_logger": {
            "target": "main_pl.LearningRateMonitor",
            "params": {
                "logging_interval": "step",
            }
        },
        "checkpoint_callback": {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": logdir,
                "filename": modelckpt_params.filename,
                "save_top_k": modelckpt_params.save_top_k,
                "verbose": modelckpt_params.verbose,
                "save_last": modelckpt_params.save_last,
                "every_n_epochs": modelckpt_params.every_n_epochs,
            }
        }
    }
    callbacks = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
    trainer_kwargs["callbacks"] = callbacks
    print(trainer_kwargs)

    # 3. Build the trainer
    trainer = Trainer(**trainer_kwargs)

    # 4. Build the data and the model
    data = instantiate_from_config(config.data)
    model = instantiate_from_config(config.model)

    # 5. Train
    try:
        trainer.fit(model, data)
    except Exception:
        raise
```
Running it, the logs directory will contain our config, the TensorBoard logs, and the model checkpoints.
6 Summary

PyTorch Lightning is a lightweight wrapper around PyTorch. By structuring code and automating engineering details, it significantly improves the efficiency of deep-learning research and development. Its main advantages are summarized below:
1. Code structure and readability
- Focus on research, not engineering: model definition, training logic, and engineering code are decoupled
- Standardized interface: you implement the `LightningModule` methods (`training_step`, `configure_optimizers`, etc.)
- Less boilerplate: training-loop code is reduced by 80%+
```python
# Traditional PyTorch vs Lightning
# ---------------------------
# PyTorch: write the training loop by hand
for epoch in epochs:
    for batch in data:
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()

# Lightning: only define the logic
class LitModel(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss
```
2. Automated engineering details
| Feature | How to enable | Benefit |
|---|---|---|
| Distributed training | `Trainer(accelerator="gpu", devices=4)` | Multi-GPU/TPU with a single line |
| Mixed-precision training | `Trainer(precision="16-mixed")` | Less memory, more speed |
| Gradient accumulation | `Trainer(accumulate_grad_batches=4)` | Simulates a larger batch size |
| Early stopping | `callbacks=[EarlyStopping(...)]` | Automatically guards against overfitting |
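A minimal sketch combining these options in one Trainer (the monitored metric matches the earlier `val_loss` example; the exact values are illustrative):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

trainer = Trainer(
    accelerator="gpu",
    devices=4,                  # distributed training over 4 GPUs
    precision="16-mixed",       # mixed-precision training
    accumulate_grad_batches=4,  # gradient accumulation
    callbacks=[EarlyStopping(monitor="val_loss", patience=3)],  # early stopping
)
```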
3. Reproducibility and experiment management
- Version control: hyperparameters are saved automatically (`self.save_hyperparameters()`)
- Experiment tracking: built-in support for TensorBoard/W&B/MLflow
- Complete checkpoints: model + optimizer + hyperparameter state saved in one step
```python
# Automatically record every experiment
from pytorch_lightning.loggers import TensorBoardLogger

trainer = Trainer(logger=TensorBoardLogger("logs/"))
```
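Because the checkpoint also stores the hyperparameters, a model can be restored in one call (a sketch; the checkpoint path below is hypothetical):

```python
# Restore model weights and hyperparameters from a saved checkpoint
model = LitModel.load_from_checkpoint("logs/last.ckpt")
```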
4. Hardware independence

Switch the hardware environment with a single line:

```python
# CPU → GPU → TPU → multi-node distributed
trainer = Trainer(
    accelerator="auto",  # detect the hardware automatically
    devices="auto",      # use all available devices
    strategy="ddp_find_unused_parameters_true"  # distributed strategy
)
```
5. Debugging and development efficiency

```python
# Quickly sanity-check the code
trainer = Trainer(fast_dev_run=True)    # run only 1 batch

# Profiling
trainer = Trainer(profiler="advanced")  # find bottlenecks
```