30天打牢数模基础-卷积神经网络讲解

案例代码实现

一、代码说明

本案例使用PyTorch实现一个改进版LeNet-5模型,用于CIFAR-10数据集的图像分类任务。代码包含以下核心步骤:

数据加载与预处理(含数据增强,划分训练/验证/测试集);

定义CNN网络结构(LeNet-5改进版,适配3通道输入);

模型训练(用验证集评估泛化能力);

模型测试与结果可视化(用独立测试集最终评估)。

适合人群:数模小白(无需深度学习基础,代码注释详细,逻辑清晰)。运行环境:Python3.8+、PyTorch1.10+、torchvision0.11+、matplotlib3.5+。

二、完整代码实现

# 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np# ------------------------------
# 1. 配置全局参数(数模小白可调整这里)
# ------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 优先用GPU
BATCH_SIZE = 64  # 每批数据量(越大训练越快,但占内存越多)
EPOCHS = 10  # 训练轮数(越大模型越准,但训练时间越长)
LEARNING_RATE = 0.001  # 学习率(越小收敛越稳,但训练越慢)
VAL_SPLIT = 0.2  # 验证集占训练集的比例(20%)# ------------------------------
# 2. 数据加载与预处理(含数据增强,划分训练/验证/测试集)
# ------------------------------
def load_data():"""加载CIFAR-10数据集,返回训练/验证/测试DataLoader"""# 训练集数据增强(防止过拟合):随机裁剪、水平翻转、归一化train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4),  # 随机裁剪32x32,边缘补4像素transforms.RandomHorizontalFlip(),     # 随机水平翻转(50%概率)transforms.ToTensor(),                 # 转为Tensor(0-1)transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 归一化到[-1,1]])# 验证集/测试集预处理(不增强,保持真实分布)val_test_transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 下载/加载数据集(第一次运行会下载,约170MB)full_train_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)val_dataset = datasets.CIFAR10(root="./data", train=True, download=True, transform=val_test_transform)test_dataset = datasets.CIFAR10(root="./data", train=False, download=True, transform=val_test_transform)# 划分训练集和验证集(8:2)train_size = int((1 - VAL_SPLIT) * len(full_train_dataset))val_size = len(full_train_dataset) - train_sizetrain_dataset, _ = random_split(full_train_dataset, [train_size, val_size])_, val_dataset = random_split(val_dataset, [train_size, val_size])  # 保持验证集transform正确# 生成DataLoader(批量加载数据)train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)return train_loader, val_loader, test_loader# ------------------------------
# 3. 定义CNN网络结构(改进版LeNet-5)
# ------------------------------
class LeNet5(nn.Module):"""改进版LeNet-5,适配CIFAR-10的3通道输入(3x32x32)"""def __init__(self):super(LeNet5, self).__init__()# 卷积层1:提取边缘特征(3通道→6通道,5x5 kernel)self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)# 最大池化层1:简化特征(2x2窗口,步长2)self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)# 卷积层2:提取纹理/形状特征(6通道→16通道,5x5 kernel)self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)# 最大池化层2:进一步简化特征(2x2窗口,步长2)self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)# 全连接层1:整合高级特征(16*5*5→120)self.fc1 = nn.Linear(16 * 5 * 5, 120)# 全连接层2:进一步整合特征(120→84)self.fc2 = nn.Linear(120, 84)# 输出层:分类决策(84→10类,对应CIFAR-10标签)self.fc3 = nn.Linear(84, 10)# 激活函数(ReLU,引入非线性,解决线性模型表达能力不足问题)self.relu = nn.ReLU()def forward(self, x):"""前向传播:定义数据在网络中的流动路径"""# 卷积层1 → ReLU → 池化层1:3x32x32 → 6x28x28 → 6x14x14x = self.pool1(self.relu(self.conv1(x)))# 卷积层2 → ReLU → 池化层2:6x14x14 → 16x10x10 → 16x5x5x = self.pool2(self.relu(self.conv2(x)))# 展平:将二维特征图转为一维向量(16x5x5 → 400),适配全连接层x = x.view(-1, 16 * 5 * 5)# 全连接层1 → ReLU:400 → 120x = self.relu(self.fc1(x))# 全连接层2 → ReLU:120 → 84x = self.relu(self.fc2(x))# 输出层:84 → 10(不使用Softmax,因为CrossEntropyLoss会自动处理)x = self.fc3(x)return x# ------------------------------
# 4. 模型训练与验证函数(用验证集评估泛化能力)
# ------------------------------
def train_model(model, train_loader, val_loader, optimizer, criterion):"""训练模型,每轮输出训练/验证损失与准确率"""best_val_acc = 0.0  # 记录最佳验证准确率(用于保存最优模型)for epoch in range(EPOCHS):# ------------------------------# 训练阶段(更新模型参数)# ------------------------------model.train()  # 切换到训练模式(启用BatchNorm/ Dropout等训练专用层)train_loss = 0.0train_correct = 0for inputs, labels in train_loader:inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)  # 数据移至GPU/CPUoptimizer.zero_grad()  # 清空梯度(避免梯度累积)outputs = model(inputs)  # 前向传播:输入→模型→输出(预测值)loss = criterion(outputs, labels)  # 计算损失(预测值与真实值的差距)loss.backward()  # 反向传播:计算梯度(从损失到各层参数)optimizer.step()  # 更新参数(用梯度调整参数,最小化损失)# 统计训练损失与准确率train_loss += loss.item() * inputs.size(0)  # 累计损失(乘以批量大小,避免批量大小影响)_, preds = torch.max(outputs, 1)  # 取预测概率最大的类别(0-9)train_correct += (preds == labels).sum().item()  # 统计正确预测的样本数# 计算训练集平均损失与准确率train_loss = train_loss / len(train_loader.dataset)train_acc = train_correct / len(train_loader.dataset)# ------------------------------# 验证阶段(评估泛化能力,不更新参数)# ------------------------------model.eval()  # 切换到验证模式(关闭BatchNorm/ Dropout等)val_loss = 0.0val_correct = 0with torch.no_grad():  # 关闭梯度计算(节省内存,加速验证)for inputs, labels in val_loader:inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)outputs = model(inputs)loss = criterion(outputs, labels)# 统计验证损失与准确率val_loss += loss.item() * inputs.size(0)_, preds = torch.max(outputs, 1)val_correct += (preds == labels).sum().item()# 计算验证集平均损失与准确率val_loss = val_loss / len(val_loader.dataset)val_acc = val_correct / len(val_loader.dataset)# 打印本轮训练/验证结果print(f"Epoch {epoch+1}/{EPOCHS}")print(f"训练集:损失={train_loss:.4f},准确率={train_acc:.4f}")print(f"验证集:损失={val_loss:.4f},准确率={val_acc:.4f}")print("-" * 50)# 保存最佳模型(验证准确率最高的模型,避免过拟合)if val_acc > best_val_acc:best_val_acc = val_acctorch.save(model.state_dict(), "best_model.pth")print(f"训练结束,最佳验证准确率={best_val_acc:.4f}(模型已保存至best_model.pth)")# ------------------------------
# 5. 模型测试与结果可视化(用独立测试集最终评估)
# ------------------------------
def test_model(model, test_loader):"""用独立测试集评估模型性能,输出准确率并可视化预测结果"""model.eval()  # 切换到验证模式test_correct = 0with torch.no_grad():  # 关闭梯度计算for inputs, labels in test_loader:inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)outputs = model(inputs)_, preds = torch.max(outputs, 1)test_correct += (preds == labels).sum().item()# 计算测试集准确率test_acc = test_correct / len(test_loader.dataset)print(f"\n测试集最终准确率={test_acc:.4f}")# 可视化10张测试图像的预测结果(直观展示模型效果)class_names = ["飞机", "汽车", "鸟", "猫", "鹿", "狗", "青蛙", "马", "船", "卡车"]inputs, labels = next(iter(test_loader))  # 取一批测试数据(BATCH_SIZE=64)inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)outputs = model(inputs)_, preds = torch.max(outputs, 1)# 绘制图像(2行5列,显示10张)plt.figure(figsize=(12, 6))for i in range(10):plt.subplot(2, 5, i+1)# 反归一化:将[-1,1]转回[0,1](方便显示图像)img = inputs[i].cpu().numpy().transpose((1, 2, 0))  # 转为HWC格式(高度×宽度×通道)img = img * 0.5 + 0.5  # 反归一化(原归一化公式:img = (img - mean) / std → 反推:img = img * std + mean)plt.imshow(img)# 设置标题:真实标签 vs 预测标签plt.title(f"真实:{class_names[labels[i]]}\n预测:{class_names[preds[i]]}", fontsize=10)plt.axis("off")  # 隐藏坐标轴plt.tight_layout()  # 调整子图间距plt.show()# ------------------------------
# 6. 主程序(整合所有步骤,执行训练与测试)
# ------------------------------
if __name__ == "__main__":# 1. 加载数据(划分训练/验证/测试集)print("正在加载数据...")train_loader, val_loader, test_loader = load_data()print(f"数据加载完成:\n- 训练集大小:{len(train_loader.dataset)} \n- 验证集大小:{len(val_loader.dataset)} \n- 测试集大小:{len(test_loader.dataset)}")# 2. 初始化模型、损失函数、优化器print("\n正在初始化模型...")model = LeNet5().to(DEVICE)  # 将模型移至GPU/CPUcriterion = nn.CrossEntropyLoss()  # 交叉熵损失(适用于多分类任务)optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)  # Adam优化器(自适应学习率,收敛更稳定)# 3. 训练模型(用验证集评估)print("\n正在训练模型...")train_model(model, train_loader, val_loader, optimizer, criterion)# 4. 加载最佳模型并测试(用独立测试集)print("\n正在测试最佳模型...")model.load_state_dict(torch.load("best_model.pth"))  # 加载训练过程中保存的最佳模型test_model(model, test_loader)

三、代码使用说明

1.环境安装

打开命令行,运行以下命令安装依赖库(建议使用虚拟环境):

pip install torch torchvision matplotlib numpy

2.运行代码

将代码保存为cnn_cifar10.py,在命令行中运行:

python cnn_cifar10.py

3.结果解释

训练过程:每轮(Epoch)输出训练集(更新参数)和验证集(评估泛化能力)的损失(Loss,越小说明预测越准)和准确率(Accuracy,越大说明模型越准)。

最佳模型:训练结束后,保存验证准确率最高的模型到best_model.pth(避免过拟合)。

测试结果:加载最佳模型后,用独立测试集评估,输出测试集准确率(一般在70%-85%之间,增加EPOCHS可提高),并显示10张测试图像的真实标签预测标签(直观看到模型效果)。

四、数模小白调整建议

提高准确率:若训练集准确率低(<80%),可增加EPOCHS(如改为20),让模型多学习几轮;或增大LEARNING_RATE(如改为0.002),加快收敛速度。

缓解过拟合:若验证集准确率远低于训练集(如差 10% 以上),可添加更多数据增强(如transforms.RandomRotation(10)随机旋转 10 度、transforms.ColorJitter(brightness=0.2)调整亮度),或减小模型复杂度(如将conv1的out_channels=6改为3)。

加速训练:若训练太慢,可增大BATCH_SIZE(如改为128,需确保GPU内存足够),或使用更高效的优化器(如optim.AdamW,带权重衰减的Adam)。

五、常见问题解答

Q:为什么要划分验证集?A:验证集用于在训练过程中评估模型的泛化能力,避免模型“记住”训练集细节(过拟合)。测试集是最终评估模型性能的“考题”,不能在训练过程中使用。

Q:数据增强为什么有效?A:数据增强(如随机裁剪、翻转)通过生成“虚拟”训练数据,扩大了训练集的多样性,让模型学习到更通用的特征,从而提高泛化能力。

Q:为什么用Adam优化器而不是SGD?A:Adam优化器会为每个参数自适应调整学习率,比传统SGD(随机梯度下降)收敛更快、更稳定,适合新手使用。

通过运行这份代码,你可以完整体验CNN从数据预处理到模型部署的全流程,理解“卷积层提取特征、池化层简化特征、全连接层做决策”的核心逻辑,为后续更复杂的深度学习模型(如ResNet、YOLO)打下基础!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/92250.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/92250.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Dev-C++——winAPI贪吃蛇小游戏

&#x1f680;欢迎互三&#x1f449;&#xff1a;雾狩 &#x1f48e;&#x1f48e; &#x1f680;关注博主&#xff0c;后期持续更新系列文章 &#x1f680;如果有错误感谢请大家批评指出&#xff0c;及时修改 &#x1f680;感谢大家点赞&#x1f44d;收藏⭐评论✍ 今天水一篇吧…

【openbmc6】entity-manager

文章目录 2.1 事件监听:dbus在linux上使用的底层通信方式多半是unix domain socket ,事件的到来可被抽象为:socket上有数据,可读 2.2 事件处理:由于主线程肯定有逻辑得跑,因此新开一个线程甚至多个线程专门用来监听和处理事件,但存在多线程就意味着可能存在竞争,存在竞…

Java 实现 UDP 多发多收通信

在网络通信领域&#xff0c;UDP&#xff08;用户数据报协议&#xff09;以其无连接、高效率的特点&#xff0c;在实时通信场景中占据重要地位。本文将结合一段实现 UDP 多发多收的 Java 代码&#xff0c;详细解析其实现逻辑&#xff0c;帮助开发者深入理解 UDP 通信的底层逻辑与…

Java学习第六十二部分——Git

目录 一、关键概述 二、核心概念 三、常用命令 四、优势因素 五、应用方案 六、使用建议 一、关键概述 提问&#xff1a;Git 是什么&#xff1f; 回答&#xff1a;一句话&#xff0c;分布式版本控制系统&#xff08;DVCS&#xff09;&#xff0c;用来跟踪文件&#…

CDN和DNS 在分布式系统中的作用

一、DNS&#xff1a;域名系统&#xff08;Domain Name System&#xff09; 1. 核心功能 DNS是互联网的“地址簿”&#xff0c;负责将人类易记的域名&#xff08;如www.baidu.com&#xff09;解析为计算机可识别的IP地址&#xff08;如180.101.50.242&#xff09;。没有DNS&…

uniapp用webview导入本地网页,ios端打开页面空白问题

目前还没解决&#xff0c;DCloud官方也说不行 IOS下webview加载本地网页时&#xff0c;无法加载资源 - DCloud问答

软考 系统架构设计师系列知识点之面向服务架构设计理论与实践(8)

接前一篇文章:软考 系统架构设计师系列知识点之面向服务架构设计理论与实践(7) 所属章节: 第15章. 面向服务架构设计理论与实践 第3节 SOA的参考架构 15.3 SOA的参考架构 IBM的Websphere业务集成参考架构(如图15-2所示,以下简称参考架构)是典型的以服务为中心的企业集…

基于 Docker 及 Kubernetes 部署 vLLM:开启机器学习模型服务的新篇章

在当今数字化浪潮中&#xff0c;机器学习模型的高效部署与管理成为众多开发者和企业关注的焦点。vLLM 作为一款性能卓越的大型语言模型推理引擎&#xff0c;其在 Docker 及 Kubernetes 上的部署方式如何呢&#xff1f;本文将深入探讨如何在 Docker 及 Kubernetes 集群中部署 vL…

工业互联网六大安全挑战的密码“解法”

目录 工业互联网密码技术应用Q&A Q1&#xff1a;设备身份认证与接入控制 Q2&#xff1a;通信数据加密与完整性保护 Q3&#xff1a;远程安全访问 Q4&#xff1a;平台与数据安全 Q5&#xff1a;软件与固件安全 Q6&#xff1a;日志审计与抗抵赖 首传信安-解决方案 总…

基于springboot的在线问卷调查系统的设计与实现(源码+论文)

一、开发环境 1 Java语言 Java语言是当今为止依然在编程语言行业具有生命力的常青树之一。Java语言最原始的诞生&#xff0c;不仅仅是创造者感觉C语言在编程上面很麻烦&#xff0c;如果只是专注于业务逻辑的处理&#xff0c;会导致忽略了各种指针以及垃圾回收这些操作&#x…

民法学学习笔记(个人向) Part.1

民法学学习笔记(个人向) Part.1有关民法条文背后的事理、人心、经济社会基础&#xff1b;民法的结构民法学习的特色就是先学最难的民法总论&#xff0c;再学较难的物权法、合同法等&#xff0c;最后再学习最简单的婚姻、继承、侵权部分。这是一个由难到易的过程&#xff0c;尤为…

ElasticSearch Doc Values和Fielddata详解

一、Doc Values介绍倒排索引在搜索包含指定 term 的文档时效率极高&#xff0c;但在执行相反操作&#xff0c;比如查询一个文档中包含哪些 term&#xff0c;以及进行排序、聚合等与指定字段相关的操作时&#xff0c;表现就很差了&#xff0c;这时候就需要用到 Doc Values。倒排…

【C语言】解决VScode中文乱码问题

文章目录【C语言】解决VScode中文乱码问题弹出无法写入用户设置的处理方法弹出无法在只读编辑器编辑的问题处理方法【C语言】解决VScode中文乱码问题 &#x1f4ac;欢迎交流&#xff1a;在学习过程中如果你有任何疑问或想法&#xff0c;欢迎在评论区留言&#xff0c;我们可以共…

MySQL笔记4

一、范式1.概念与意义范式&#xff08;Normal Form&#xff09;是数据库设计需遵循的规范&#xff0c;解决“设计随意导致后期重构困难”问题。主流有 三大范式&#xff08;1NF、2NF、3NF&#xff09;&#xff0c;还有进阶的 BCNF、4NF、5NF 等&#xff0c;范式间是递进依赖&am…

切比雪夫不等式的理解以及推导【超详细笔记】

文章目录参考教程一、意义1. 正态分布的 3σ 法则2. 不等式的含义3. 不等式的意义二、不等式的证明1. 马尔科夫不等式马尔可夫不等式证明(YYY 为非负随机变量 &#xff09;2. 切比雪夫不等式推导参考教程 一个视频&#xff0c;彻底理解切比雪夫不等式 一、意义 1. 正态分布的…

Spring Boot Jackson 序列化常用配置详解

一、引言在当今的 Web 开发领域&#xff0c;JSON&#xff08;JavaScript Object Notation&#xff09;已然成为数据交换的中流砥柱。无论是前后端分离架构下前后端之间的数据交互&#xff0c;还是微服务架构里各个微服务之间的通信&#xff0c;JSON 都承担着至关重要的角色 。它…

Jetpack ViewModel LiveData:现代Android架构组件的核心力量

引言在Android应用开发中&#xff0c;数据管理和界面更新一直是开发者面临的重大挑战。传统的开发方式常常导致Activity和Fragment变得臃肿&#xff0c;难以维护&#xff0c;且无法优雅地处理配置变更&#xff08;如屏幕旋转&#xff09;。Jetpack中的ViewModel和LiveData组件正…

Python数据分析案例79——基于征信数据开发信贷风控模型

背景 虽然模型基本都是表格数据那一套了&#xff0c;算法都没什么新鲜点&#xff0c;但是本次数据还是很值得写个案例的&#xff0c;有征信数据&#xff0c;各种&#xff0c;个人&#xff0c;机构&#xff0c;逾期汇总..... 这么多特征来做机器学习模型应该还不错。本次带来&…

板凳-------Mysql cookbook学习 (十二--------3_2)

3.3链接表 结构 P79页 用一个类图来表示EmployeeNode类的结构&#xff0c;展示其属性和关系&#xff1a; plaintext ----------------------------------------- | EmployeeNode | ----------------------------------------- | - emp_no: int …

深度学习图像预处理:统一输入图像尺寸方案

在实际训练中&#xff0c;最常见也最简单的做法&#xff0c;就是在送入网络前把所有图片「变形」到同一个分辨率&#xff08;比如 256256 或 224224&#xff09;&#xff0c;或者先裁剪&#xff0f;填充成同样大小。具体而言&#xff0c;可以分成以下几类方案&#xff1a;一、图…