前文中,只是给了基础模型:
PyTorch 实现 CIFAR-10 图像分类:从数据预处理到模型训练与评估-CSDN博客
今天我们增加交叉验证和超参数调优,
先看运行结果:
===== 在测试集上评估最终模型 =====
最终模型在测试集上的准确率:60.14%
最优模型已保存为 'cifar10_best_model.pth'(超参数:{'batch_size': 32, 'epochs': 5, 'lr': 0.01, 'momentum': 0.85})
Process finished with exit code 0
比基础模型准确率高了一点,
完整代码如下:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from sklearn.model_selection import KFold, ParameterGrid # 用于交叉验证和超参数网格搜索# --------------------------
# 1. 数据准备(与原代码一致,但后续会在训练集内部做交叉验证)
# --------------------------
# 数据预处理:标准化(与原代码相同)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 数据集路径(请替换为你的实际路径)
data_path = r'D:\workspace_py\deeplean\data'# 加载完整训练集和测试集(测试集始终不变,用于最终评估)
full_trainset = datasets.CIFAR10(root=data_path, train=True, download=False, transform=transform)
testset = datasets.CIFAR10(root=data_path, train=False, download=False, transform=transform)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# --------------------------
# 2. 定义CNN模型(与原代码一致)
# --------------------------
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return x# --------------------------
# 3. 交叉验证函数(核心新增)
# --------------------------
def cross_validate(model, train_dataset, k_folds=5, epochs=5, lr=0.001, batch_size=32, momentum=0.9):"""5折交叉验证:将训练集分成5份,每次用4份训练,1份验证,返回平均准确率"""kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42) # 固定随机种子,结果可复现fold_results = [] # 存储每折的验证准确率for fold, (train_ids, val_ids) in enumerate(kfold.split(train_dataset)):print(f'\n===== 第 {fold + 1}/{k_folds} 折交叉验证 =====')# 1. 划分当前折的训练集和验证集train_subset = Subset(train_dataset, train_ids) # 本次训练用的数据val_subset = Subset(train_dataset, val_ids) # 本次验证用的数据# 2. 创建数据加载器train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)# 3. 初始化模型和优化器(每折都重新训练新模型,避免干扰)model_instance = Net() # 重新实例化模型criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(model_instance.parameters(), lr=lr, momentum=momentum)# 4. 训练当前折的模型for epoch in range(epochs):model_instance.train() # 训练模式running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = model_instance(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()# 每200步打印一次损失(简化输出)if i % 200 == 199:print(f'折 {fold + 1},轮次 {epoch + 1},第 {i + 1} 步:平均损失 {running_loss / 200:.3f}')running_loss = 0.0# 5. 在验证集上评估当前折的模型model_instance.eval() # 验证模式correct = 0total = 0with torch.no_grad():for data in val_loader:images, labels = dataoutputs = model_instance(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()val_acc = 100 * correct / totalprint(f'第 {fold + 1} 折验证准确率:{val_acc:.2f}%')fold_results.append(val_acc)# 计算所有折的平均准确率(该超参数组合的最终得分)avg_acc = sum(fold_results) / len(fold_results)print(f'\n===== 该超参数组合的平均验证准确率:{avg_acc:.2f}% =====')return avg_acc# --------------------------
# 4. 超参数调优(核心新增)
# --------------------------
def hyperparameter_tuning(train_dataset):"""超参数网格搜索:尝试不同的超参数组合,用交叉验证选最优"""# 定义要测试的超参数组合(可根据需要增减)param_grid = {'lr': [0.001, 0.01], # 学习率:尝试两个值'batch_size': [32, 64], # 批大小:尝试两个值'momentum': [0.9, 0.85], # 动量:尝试两个值'epochs': [5] # 训练轮次(固定为5,减少计算量)}best_acc = 0.0best_params = None # 存储最优超参数# 遍历所有超参数组合(共 2×2×2=8 种组合)for params in ParameterGrid(param_grid):print(f'\n---------- 测试超参数组合:{params} ----------')# 用交叉验证评估当前组合的性能current_acc = cross_validate(model=Net(),train_dataset=train_dataset,k_folds=5,epochs=params['epochs'],lr=params['lr'],batch_size=params['batch_size'],momentum=params['momentum'])# 记录最优组合if current_acc > best_acc:best_acc = current_accbest_params = paramsprint(f'★ 发现更优组合!当前最优准确率:{best_acc:.2f}%')print(f'\n===== 超参数调优完成 =====')print(f'最优超参数:{best_params}')print(f'最优平均验证准确率:{best_acc:.2f}%')return best_params# --------------------------
# 5. 主函数:执行超参数调优 + 最终训练 + 测试集评估
# --------------------------
if __name__ == '__main__':# 步骤1:超参数调优(用交叉验证选最优参数)print('===== 开始超参数调优(这一步比较慢,需要耐心等待)=====')best_params = hyperparameter_tuning(full_trainset)# 步骤2:用最优超参数在完整训练集上训练最终模型print('\n===== 用最优超参数训练最终模型 =====')final_model = Net()criterion = nn.CrossEntropyLoss()optimizer = optim.SGD(final_model.parameters(),lr=best_params['lr'],momentum=best_params['momentum'])train_loader = DataLoader(full_trainset,batch_size=best_params['batch_size'],shuffle=True)# 训练最终模型(轮次与调优时一致)for epoch in range(best_params['epochs']):final_model.train()running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = dataoptimizer.zero_grad()outputs = final_model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 200 == 199:print(f'最终模型训练 - 轮次 {epoch + 1},第 {i + 1} 步:平均损失 {running_loss / 200:.3f}')running_loss = 0.0# 步骤3:在测试集上评估最终模型(用从未见过的测试数据)print('\n===== 在测试集上评估最终模型 =====')final_model.eval()test_loader = DataLoader(testset, batch_size=32, shuffle=False)correct = 0total = 0with torch.no_grad():for data in test_loader:images, labels = dataoutputs = final_model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()test_acc = 100 * correct / totalprint(f'最终模型在测试集上的准确率:{test_acc:.2f}%')# 步骤4:保存最优模型torch.save(final_model.state_dict(), 'cifar10_best_model.pth')print(f"最优模型已保存为 'cifar10_best_model.pth'(超参数:{best_params})")
新增加的功能 :
(1)5 折交叉验证(cross_validate
函数)
- 作用:把训练集分成 5 份,每次用 4 份训练、1 份验证,重复 5 次,取平均准确率作为 “该参数组合的得分”。
- 白话举例:相当于学生做 5 套模拟题,每次用 4 套复习、1 套测试,最后算平均分,比只做 1 套题更能反映真实水平。
- 关键细节:每折都重新训练新模型,避免前一折的 “记忆” 影响结果。
(2)超参数调优(hyperparameter_tuning
函数)
- 作用:尝试不同的超参数组合(如学习率 0.001 vs 0.01,批大小 32 vs 64),用交叉验证选平均分最高的组合。
- 白话举例:相当于学生尝试不同的复习方法(每天学 1 小时 vs 2 小时,刷题 vs 看笔记),通过模拟题平均分找到最适合自己的方法。
- 参数网格:代码中测试了 8 种组合(2 学习率 ×2 批大小 ×2 动量),可根据需要增减(组合越多,计算时间越长)。
(3)最终模型训练
- 用调优得到的 “最优超参数” 在完整训练集上重新训练模型(之前交叉验证只用了部分数据)。
- 最后在独立的测试集上评估(测试集从未参与训练和调优,相当于 “高考”)。
3. 运行说明
- 计算时间:超参数调优 + 交叉验证会比原代码慢很多(8 种组合 ×5 折 ×5 轮训练),建议在有 GPU 的环境运行。
- 结果解读:最终会输出 “最优超参数” 和 “测试集准确率”,这个准确率比原代码更可信(排除了偶然因素)。
- 可调整项:
param_grid
中的参数可以修改(如增加学习率选项[0.0001, 0.001, 0.01]
),但组合数会增加,计算时间变长。
通过这两个步骤,模型的性能和可靠性会显著提升,尤其适合数据量不大的场景(如医学影像、小数据集)。
交叉验证
一、什么是交叉验证?为什么需要它?
1. 核心问题:如何判断模型好坏?
假设你用一份训练集训练模型,然后用同一批数据测试,准确率 90%—— 这能说明模型好吗?不能!因为模型可能 “死记硬背” 了训练数据(过拟合),换一批新数据就不行了。
所以需要用 “没见过的数据” 来验证模型 —— 但我们只有一份训练集,怎么办?
2. 交叉验证的解决思路
交叉验证(以代码中的5 折交叉验证为例)就像 “多次模拟考试”:
- 把训练集分成 5 等份(比如 5 个小数据集 A、B、C、D、E)。
- 第一次:用 A、B、C、D 训练模型,用 E 验证(看模型在 E 上的准确率)。
- 第二次:用 A、B、C、E 训练,用 D 验证。
- 重复 5 次(每次换一份做验证集),最后取 5 次验证准确率的平均值。
这样做的好处:
- 避免 “一次验证” 的偶然性(比如刚好抽到简单的验证集)。
- 更全面地评估模型在不同数据分布上的表现,结果更可靠。
3. 代码中的交叉验证实现(cross_validate 函数)
代码里的cross_validate函数就是干这个的:
- 用KFold(n_splits=5)把训练集分成 5 份。
- 循环 5 次(每折):
- 每次从 5 份中选 4 份做 “临时训练集”,1 份做 “临时验证集”。
- 用临时训练集训练模型,用临时验证集算准确率。
- 最后返回 5 次准确率的平均值,作为这个模型 / 超参数组合的 “评分”。
二、什么是超参数调优?为什么需要它?
1. 超参数是什么?
超参数是训练前手动设定的参数,不是模型自己学出来的。比如代码中的:
- lr(学习率):模型更新参数的 “步长”,太大可能跑过头,太小可能学太慢。
- batch_size(批大小):每次训练用多少数据,影响训练速度和稳定性。
- momentum(动量):优化器的参数,帮助模型更快收敛。
这些参数直接影响模型的训练效果,但没有 “标准答案”,需要试出来。
2. 超参数调优的目的
找到一组最好的超参数组合,让模型的性能(比如准确率)达到最高。
比如:学习率 0.01 + 批大小 32 + 动量 0.9 可能比 学习率 0.001 + 批大小 64 + 动量 0.85 效果更好,我们需要找到这个 “更好” 的组合。
3. 代码中的超参数调优实现(网格搜索)
代码用了 “网格搜索” 的方法,原理很简单:
- 列清单:先定义每个超参数的可能取值(比如lr选 [0.001, 0.01],batch_size选 [32, 64])。
- 组合所有可能:把这些取值的所有搭配列出来(比如 2×2×2=8 种组合)。
- 逐个测试:对每种组合,用交叉验证算它的 “评分”(平均验证准确率)。
- 选最优:最后挑出评分最高的组合,作为 “最佳超参数”。
对应代码中的hyperparameter_tuning函数:
- param_grid定义了要测试的超参数和可能值。
- ParameterGrid自动生成所有组合。
- 循环每个组合,用cross_validate算分,保存最高分的组合。
三、交叉验证和超参数调优的关系
简单说:超参数调优是 “找最好的配方”,交叉验证是 “判断配方好不好的工具”。
- 没有交叉验证,直接用一组数据测试超参数,可能因为 “运气好” 选错(比如刚好验证集简单)。
- 用交叉验证评估每个超参数组合,结果更可靠,能真正找到 “稳定好” 的组合。
总结
- 交叉验证:通过多次 “训练 - 验证” 划分,更可靠地评估模型性能,避免偶然性。
- 超参数调优:通过尝试不同的超参数组合(网格搜索),结合交叉验证的评分,找到让模型表现最好的 “参数配方”。
代码中,先通过超参数调优找到最好的参数,再用这些参数训练最终模型,最后在测试集上验证 —— 这样得到的模型更可能在新数据上表现良好。