使用PyTorch构建卷积神经网络(CNN)实现CIFAR-10图像分类

在计算机视觉领域,卷积神经网络(CNN)已经成为处理图像识别任务的事实标准。从人脸识别到医学影像分析,CNN展现出了惊人的能力。本文将详细介绍如何使用PyTorch框架构建一个CNN模型,并在经典的CIFAR-10数据集上进行图像分类任务。

CIFAR-10数据集包含10个类别的60000张32x32彩色图像,每个类别有6000张图像,其中50000张用于训练,10000张用于测试。这个数据集虽然图像尺寸较小,但包含了足够的复杂性,是学习计算机视觉和深度学习的理想起点。

一、卷积神经网络基础

1.1 卷积层

卷积层是CNN的核心组件,它通过卷积核(滤波器)在输入图像上滑动,计算局部区域的点积。PyTorch中的nn.Conv2d实现了这一功能:

self.conv1 = nn.Conv2d(3, 32, 3, padding=1)

这行代码创建了一个卷积层,参数含义如下:

  • 输入通道数:3(对应RGB三通道)

  • 输出通道数:32(即使用32个不同的滤波器)

  • 卷积核大小:3×3

  • padding=1保持空间维度不变

卷积层能够自动学习从简单边缘到复杂模式的各种特征,这种层次化的特征学习是CNN强大性能的关键。

1.2 池化层

池化层(通常是最大池化)用于降低特征图的空间维度:

self.pool = nn.MaxPool2d(2, 2)

最大池化取2×2窗口中的最大值,步长为2,这会使特征图尺寸减半。池化的作用包括:

  1. 减少计算量和参数数量

  2. 增强特征的位置不变性

  3. 防止过拟合

1.3 全连接层

在多个卷积和池化层之后,我们使用全连接层进行分类:

self.fc1 = nn.Linear(128 * 4 * 4, 512)
self.fc2 = nn.Linear(512, 10)

第一个全连接层将展平的特征向量(128×4×4)映射到512维空间,第二个则输出10维向量对应10个类别。

二、数据准备与预处理

2.1 数据加载

PyTorch的torchvision.datasets模块提供了便捷的CIFAR-10加载方式:

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)

2.2 数据预处理

良好的数据预处理对模型性能至关重要:

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

这里进行了两个关键操作:

  1. ToTensor():将PIL图像转换为PyTorch张量,并自动将像素值从[0,255]缩放到[0,1]

  2. Normalize:用均值0.5和标准差0.5对每个通道进行标准化

2.3 数据批量加载

使用DataLoader实现高效的批量数据加载:

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,shuffle=True, num_workers=2)

参数说明:

  • batch_size=64:每次迭代处理64张图像

  • shuffle=True:每个epoch打乱数据顺序

  • num_workers=2:使用2个子进程加载数据

三、模型构建

3.1 网络架构设计

我们构建的CNN包含四个卷积层和两个全连接层:

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.conv3 = nn.Conv2d(64, 128, 3, padding=1)self.conv4 = nn.Conv2d(128, 128, 3, padding=1)self.fc1 = nn.Linear(128 * 4 * 4, 512)self.fc2 = nn.Linear(512, 10)self.dropout = nn.Dropout(0.5)

3.2 前向传播

定义数据在网络中的流动路径:

def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = self.pool(F.relu(self.conv3(x)))x = F.relu(self.conv4(x))x = x.view(-1, 128 * 4 * 4)x = self.dropout(x)x = F.relu(self.fc1(x))x = self.dropout(x)x = self.fc2(x)return x

关键点:

  1. 每个卷积层后接ReLU激活函数引入非线性

  2. 使用view将三维特征图展平为一维向量

  3. Dropout层以0.5的概率随机失活神经元,防止过拟合

四、模型训练

4.1 训练设置

model = CNN()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

我们使用:

  • 交叉熵损失函数:适合多分类问题

  • Adam优化器:自适应学习率,通常比SGD表现更好

  • GPU加速(如果可用)

4.2 训练循环

for epoch in range(num_epochs):running_loss = 0.0correct = 0total = 0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()

每个epoch中:

  1. 从DataLoader获取一个batch的数据

  2. 清零梯度(防止梯度累积)

  3. 前向传播计算输出和损失

  4. 反向传播计算梯度

  5. 优化器更新权重

  6. 统计损失和准确率

4.3 训练可视化

绘制训练过程中的损失和准确率曲线:

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Training Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()

五、模型评估

5.1 测试集评估

correct = 0
total = 0
with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy on test images: {100 * correct / total:.2f}%')

关键点:

  1. with torch.no_grad():禁用梯度计算,节省内存和计算资源

  2. 计算模型在未见过的测试集上的准确率

5.2 示例预测

可视化一些测试图像及其预测结果:

dataiter = iter(testloader)
images, labels = next(dataiter)imshow(torchvision.utils.make_grid(images[:4]))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))outputs = model(images.to(device))
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}' for j in range(4)))

六、性能优化建议

虽然我们的基础模型已经能达到75-80%的准确率,但还可以通过以下方法进一步提升:

  1. 网络架构改进

    • 添加批量归一化层(nn.BatchNorm2d)加速训练并提高性能

    • 使用更深的网络结构(如ResNet残差连接)

  2. 数据增强

    transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomCrop(32, padding=4),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
  3. 训练技巧

    • 使用学习率调度器(如lr_scheduler.StepLR

    • 早停法防止过拟合

    • 尝试不同的优化器(如AdamW)

  4. 正则化

    • 增加Dropout比例

    • 在优化器中添加权重衰减(L2正则化)

七、总结

本文详细介绍了使用PyTorch实现CNN进行CIFAR-10图像分类的完整流程。我们从CNN的基础组件开始,逐步构建了一个包含卷积层、池化层和全连接层的网络模型。通过合理的数据预处理、模型训练和评估,我们实现了一个具有不错分类性能的图像识别系统。

CNN之所以在图像任务中表现优异,关键在于它的两个特性:

  1. 局部连接:卷积核只关注局部区域,大大减少了参数量

  2. 参数共享:同一卷积核在整个图像上滑动使用,提高了效率

通过本实践,读者不仅能够理解CNN的工作原理,还能掌握PyTorch实现深度学习模型的标准流程。这为进一步探索更复杂的计算机视觉任务(如目标检测、图像分割等)奠定了坚实基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/95887.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/95887.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Springboot3+SpringSecurity6Oauth2+vue3前后端分离认证授权-客户端

客户端服务整体流程前端技术栈项目结构代码后端技术栈项目结构代码整体流程 #mermaid-svg-G5v6x8CSkEcvglmt {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-G5v6x8CSkEcvglmt .error-icon{fill:#552222;}#mermaid-s…

DEEP THINK WITH CONFIDENCE-Meta-基于置信度的深度思考

原文地址 摘要 大型语言模型(LLM)通过自我一致性和多数投票等测试时间缩放方法,在推理任务中显示出巨大的潜力。然而,这种方法经常导致精度回报递减和高计算开销。为了应对这些挑战,我们引入了深度自信思考(DeepConf),这是一种简…

零基础学习数据采集与监视控制系统SCADA

新晋码农一枚,小编定期整理一些写的比较好的代码,作为自己的学习笔记,会试着做一下批注和补充,转载或者参考他人文献会标明出处,非商用,如有侵权会删改!欢迎大家斧正和讨论! 目录 一…

docker run 命令,不接it选项,run一个centos没有显示在运行,而run一个nginx却可以呢?

docker run 命令,不接it选项,run一个centos没有显示在运行,而run一个nginx却可以呢? ChatGPT said: 你问到的这个现象,其实就是 镜像默认启动命令 (ENTRYPOINT / CMD) 的差异导致的。🔍 情况分析 1. docker…

【完整源码+数据集+部署教程】水培植物病害检测系统源码和数据集:改进yolo11-AKConv

背景意义 研究背景与意义 随着全球人口的不断增长,农业生产面临着前所未有的挑战,尤其是在资源有限的环境中,如何提高作物的产量和质量成为了亟待解决的问题。水培技术作为一种新兴的农业生产方式,因其高效的水资源利用和较少的土…

第2课:环境搭建:基于DeepSeek API的开发环境配置

概述 在开始大模型RAG实战之旅前,一个正确且高效的开发环境是成功的基石。本文将手把手指导您完成从零开始的环境配置过程,涵盖Python环境设置、关键库安装、DeepSeek API配置以及开发工具优化。通过详细的步骤说明、常见问题解答和最佳实践分享&#x…

Boost电路:稳态和小信号分析

稳态分析 参考张卫平的《开关变换器的建模与控制》的1.3章节内容;伏秒平衡:在稳态下,一个开关周期内电感电流的增量是0,即 dIL(t)dt0\frac{dI_{L}(t)}{dt} 0dtdIL​(t)​0。电荷平衡:在稳态下,一个开关周期…

Vue-25-利用Vue3大模型对话框设计之前端和后端的基础实现

文章目录 1 设计思路 1.1 核心布局与组件 1.2 交互设计(Interaction Design) 1.3 视觉与用户体验 1.4 高级功能与创新设计 2 vue3前端设计 2.1 项目启动 2.1.1 创建和启动项目(vite+vue) 2.1.2 清理不需要的代码 2.1.3 下载必备的依赖(element-plus) 2.1.4 完整引入并注册(main…

Elasticsearch面试精讲 Day 7:全文搜索与相关性评分

【Elasticsearch面试精讲 Day 7】全文搜索与相关性评分 文章标签:Elasticsearch, 全文搜索, 相关性评分, TF-IDF, BM25, 面试, 搜索引擎, 后端开发, 大数据 文章简述: 本文是“Elasticsearch面试精讲”系列的第7天,聚焦于全文搜索与相关性评…

Vllm-0.10.1:vllm bench serve参数说明

一、KVM 虚拟机环境 GPU:4张英伟达A6000(48G) 内存:128G 海光Cpu:128核 大模型:DeepSeek-R1-Distill-Qwen-32B 推理框架Vllm:0.10.1 二、测试命令(random ) vllm bench serve \ --backend vllm \ --base-url http://127.0.…

B.50.10.11-Spring框架核心与电商应用

Spring框架核心原理与电商应用实战 核心理念: 本文是Spring框架深度指南。我们将从Spring的两大基石——IoC和AOP的底层原理出发,详细拆解一个Bean从定义到销毁的完整生命周期,并深入探讨Spring事务管理的实现机制。随后,我们将聚焦于Spring …

雅菲奥朗SRE知识墙分享(六):『混沌工程的定义与实践』

混沌工程不再追求“永不宕机”的童话,而是主动在系统中注入可控的“混乱”,通过实验验证系统在真实故障场景下的弹性与自我修复能力。混沌工程不是简单的“搞破坏”,也不是运维团队的专属游戏。它是一种以实验为导向、以度量为核心、以文化为…

从0死磕全栈第五天:React 使用zustand实现To-Do List项目

代码世界是现实的镜像,状态管理教会我们:真正的控制不在于凝固不变,而在于优雅地引导变化。 这是「从0死磕全栈」系列的第5篇文章,前面我们已经完成了环境搭建、路由配置和基础功能开发。今天,我们将引入一个轻量级但强大的状态管理工具 —— Zustand,来实现一个完整的 T…

力扣29. 两数相除题解

原题链接29. 两数相除 - 力扣(LeetCode) 主要不能用乘除取余,于是用位运算代替: Java题解 class Solution {public int divide(int dividend, int divisor) {//全都转为负数计算, 避免溢出, flag记录结果的符号int flag 1;if(…

【工具类】Nuclei YAML POC 编写以及批量检测

Nuclei YAML POC 编写以及批量检测法律与道德使用声明前言Nuclei 下载地址下载对应版本的文件关于检查cpu架构关于hkws的未授权访问参考资料关于 Neclei Yaml 脚本编写BP Nuclei Template 插件下载并安装利用插件编写 POC YAML 文件1、找到有漏洞的页面抓包发送给插件2、同时将…

自动化运维之ansible

一、认识自动化运维假如管理很多台服务器,主要关注以下几个方面“1.管理机与被管理机的连接(管理机如何将管理指令发送给被管理机)2.服务器信息收集(如果被管理的服务器有centos7.5外还有其它linux发行版,如suse,ubunt…

【温室气体数据集】亚洲地区长期空气污染物和温室气体排放数据 REAS

目录 REAS 数据集概述 REAS 数据版本及特点 数据内容(以 REASv3.2.1 为例) 数据形式 数据下载 参考 REAS 数据集(Regional Emission inventory in ASia,亚洲区域排放清单)是由日本国立环境研究所(NIES)及相关研究人员开发的一个覆盖亚洲地区长期空气污染物和温室气体排放…

中州养老项目:利用Redis解决权限接口响应慢的问题

目录 在Java中使用Redis缓存 项目中集成SpringCache 在Java中使用Redis缓存 Redis作为缓存,想要在Java中操作Redis,需要 Java中的客户端操纵Redis就像JDBC操作数据库一样,实际底层封装了对Redis的基础操作 如何在Java中使用Redis呢?先导入Redis的依赖,这个依赖导入后相当于把…

MathJax - LaTeX:WordPress 公式精准呈现方案

写在前面:本博客仅作记录学习之用,部分图片来自网络,如需引用请注明出处,同时如有侵犯您的权益,请联系删除! 文章目录前言安装 MathJax-LaTeX 插件修改插件文件效果总结互动致谢参考前言 在当今知识传播与…

详细解读Docker

1.概述Docker是一种优秀的开源的容器化平台。用于部署、运行应用程序,它通过将应用及其依赖打包成轻量级、可移植的容器,实现高效一致的运行效果,简单来说,Docker就是一种轻量级的虚拟技术。2.核心概念2.1.容器(Contai…