用 PyTorch 搭建 CNN 实现 MNIST 手写数字识别

在图像识别领域,卷积神经网络(CNN) 凭借其对空间特征的高效提取能力,成为手写数字识别、人脸识别等任务的首选模型。而 MNIST(手写数字数据集)作为入门级数据集,几乎是每个深度学习学习者的 “第一个项目”。

本文将带大家从零开始,用 PyTorch 搭建一个 CNN 模型完成 MNIST 手写数字识别任务,不仅会贴出完整代码,还会逐行解析核心逻辑,帮你搞懂 “每个参数为什么这么设”“每一层的作用是什么”,即使是刚接触 PyTorch 的新手也能轻松跟上。

一、前置知识与环境准备

在开始前,我们需要先明确两个核心背景,以及搭建好运行环境:

1. 核心背景速览

  • MNIST 数据集:包含 70000 张 28×28 像素的灰度手写数字图片(0-9),其中 60000 张为训练集,10000 张为测试集,每张图片对应一个 “数字类别” 标签(0-9)。
  • CNN 为什么适合?:相比全连接神经网络,CNN 通过 “卷积层提取局部特征(边缘、纹理)+ 池化层下采样”,能大幅减少参数数量、避免过拟合,同时更好地保留图像的空间结构信息。

2. 环境准备

需要安装 PyTorch 和 TorchVision(PyTorch 官方的计算机视觉库,内置 MNIST 数据集):

# pip安装命令(根据系统自动匹配版本,若需指定CUDA版本可参考PyTorch官网)
pip install torch torchvision

验证环境是否安装成功:

import torch
print(torch.__version__)  # 输出PyTorch版本,如2.0.1
print(torch.cuda.is_available())  # 输出True表示支持GPU加速(需NVIDIA显卡)

二、完整代码先行(可直接运行)

先贴出完整可运行的代码,后面会逐段拆解解析:

注意:nn.Sequential()是将网络层组合在一起,内部不能写函数

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor# 1. 加载MNIST数据集
train_data = datasets.MNIST(root='data',          # 数据保存路径train=True,           # 加载训练集download=True,        # 若路径下无数据则自动下载transform=ToTensor()  # 将图像转为Tensor(0-1归一化+维度调整:(H,W,C)→(C,H,W))
)
test_data = datasets.MNIST(root='data',train=False,          # 加载测试集download=True,transform=ToTensor()
)# 2. 数据加载器(分批处理数据)
train_loader = DataLoader(train_data, batch_size=64)  # 每批64个样本
test_loader = DataLoader(test_data, batch_size=64)# 3. 设备配置(优先GPU,其次CPU)
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')  # 打印当前使用的设备# 4. 定义CNN模型
class CNN(nn.Module):def __init__(self):super().__init__()# 卷积块1:输入(1,28,28) → 输出(8,14,14)self.conv1 = nn.Sequential(# 卷积层:1个输入通道→8个输出通道,卷积核5×5,步长1,填充2nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5, stride=1, padding=2),nn.ReLU(),  # 激活函数(引入非线性)nn.MaxPool2d(kernel_size=2)  # 池化层:2×2下采样,尺寸减半)# 卷积块2:输入(8,14,14) → 输出(32,7,7)self.conv2 = nn.Sequential(nn.Conv2d(8, 16, 5, 1, 2),  # 8→16通道,其他参数同上nn.ReLU(),nn.Conv2d(16, 32, 5, 1, 2),  # 16→32通道nn.ReLU(),nn.MaxPool2d(2)  # 下采样后尺寸14→7)# 卷积块3:输入(32,7,7) → 输出(64,7,7)(无池化,保留尺寸)self.conv3 = nn.Sequential(nn.Conv2d(32, 64, 5, 1, 2),  # 32→64通道nn.ReLU(),nn.Conv2d(64, 64, 5, 1, 2),  # 64→64通道(加深特征提取)nn.ReLU())# 全连接层:输入(64×7×7) → 输出10(对应10个数字类别)self.out = nn.Linear(64 * 7 * 7, 10)# 前向传播(定义数据在模型中的流动路径)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)  # 展平:(batch_size, 64,7,7) → (batch_size, 64×7×7)output = self.out(x)return output# 5. 初始化模型并移至指定设备
model = CNN().to(device)
print(model)  # 打印模型结构,验证是否正确# 6. 定义训练函数
def train(dataloader, model, loss_fn, optimizer):model.train()  # 启用训练模式(如BatchNorm、Dropout会生效)batch_count = 1  # 计数批次,用于打印日志for X, y in dataloader:# 将数据移至指定设备(GPU/CPU)X, y = X.to(device), y.to(device)# 前向传播:计算模型预测值pred = model(X)# 计算损失(多分类任务用CrossEntropyLoss)loss = loss_fn(pred, y)# 反向传播:更新模型参数optimizer.zero_grad()  # 清空上一轮梯度(避免累积)loss.backward()        # 计算梯度(反向传播)optimizer.step()       # 根据梯度更新参数(优化器执行)# 每100个批次打印一次损失(监控训练进度)if batch_count % 100 == 0:loss_value = loss.item()  # 取出损失值(脱离计算图)print(f'Batch: {batch_count:>4} | Loss: {loss_value:>6.4f}')batch_count += 1# 7. 定义测试函数
def test(dataloader, model, loss_fn):model.eval()  # 启用评估模式(关闭BatchNorm、Dropout)total_samples = len(dataloader.dataset)  # 测试集总样本数correct = 0  # 正确预测的样本数total_loss = 0  # 总损失# 禁用梯度计算(测试阶段无需更新参数,节省内存)with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)# 累积损失和正确数total_loss += loss_fn(pred, y).item()# pred.argmax(1):取每行最大概率的索引(即预测类别),与y比较correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率avg_loss = total_loss / len(dataloader)  # len(dataloader) = 总批次accuracy = (correct / total_samples) * 100  # 准确率(百分比)print(f'\nTest Result | Accuracy: {accuracy:>5.2f}% | Avg Loss: {avg_loss:>6.4f}\n')# 8. 配置训练参数并执行
loss_fn = nn.CrossEntropyLoss()  # 多分类交叉熵损失(内置Softmax)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam优化器,学习率0.001
epochs = 10  # 训练轮次(整个训练集遍历10次)# 循环训练+测试
for epoch in range(epochs):print(f'=================== Epoch {epoch + 1}/{epochs} ===================')train(train_loader, model, loss_fn, optimizer)  # 训练一轮test(test_loader, model, loss_fn)  # 测试一轮print("Training Finished!")

三、核心代码逐段解析

上面的代码看似长,但逻辑很清晰,我们按 “数据→模型→训练→测试” 的流程拆解核心部分。

1. 数据加载与预处理

MNIST 数据集的加载全靠torchvision.datasets.MNIST,无需手动下载和解析,非常方便。关键参数解析:

  • root='data':数据会保存在当前目录的data文件夹下(自动创建);
  • train=True/FalseTrue加载 6 万张训练集,False加载 1 万张测试集;
  • transform=ToTensor():这是核心预处理步骤,作用有两个:
    1. 将图像从 “PIL 格式(0-255 像素值)” 转为 “Tensor 格式(0-1 归一化值)”,避免大数值导致梯度爆炸;
    2. 调整维度:从图像默认的(高度H, 宽度W, 通道C)转为 PyTorch 要求的(通道C, 高度H, 宽度W)(MNIST 是灰度图,C=1)。

然后用DataLoader将数据集分批:

  • batch_size=64:每次训练取 64 个样本计算梯度(batch_size 越大,训练越稳定,但内存占用越高);
  • DataLoader会自动打乱训练集(默认shuffle=True),避免模型学习到 “样本顺序” 的无关特征。

2. 设备配置:GPU 加速有多重要?

代码中这行是 “硬件适配” 的关键:

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
  • cuda:NVIDIA 显卡的 GPU 加速(训练 10 轮可能只需 1-2 分钟);
  • mps:苹果芯片(M1/M2)的 GPU 加速;
  • cpu:默认选项(训练 10 轮可能需要 10-20 分钟,速度慢很多)。

后续通过model.to(device)X.to(device),将模型和数据都移到指定设备上,确保计算在同一设备进行(否则会报错)。

3. CNN 模型搭建(核心中的核心)

我们定义的CNN类继承自nn.Module(PyTorch 所有模型的基类),核心是__init__(定义层)和forward(定义数据流动)。

先看模型结构总览

输入(1,28,28) → 卷积块1 → 输出(8,14,14) → 卷积块2 → 输出(32,7,7) → 卷积块3 → 输出(64,7,7) → 展平 → 全连接层 → 输出(10)
(1)卷积层参数解析

conv1的第一个卷积层为例:

nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5, stride=1, padding=2)
  • in_channels=1:输入通道数(MNIST 是灰度图,所以 1);
  • out_channels=8:输出通道数 = 卷积核数量(8 个卷积核,提取 8 种不同特征);
  • kernel_size=5:卷积核大小(5×5 的窗口,比 3×3 能提取更复杂的局部特征);
  • stride=1:卷积核每次滑动 1 个像素(步长越小,特征保留越完整);
  • padding=2:填充(在图像边缘补 2 个像素),目的是让卷积后图像尺寸不变:
    👉 尺寸计算公式:输出尺寸 = (输入尺寸 - 卷积核尺寸 + 2×padding) / stride + 1
    👉 代入:(28 - 5 + 2×2)/1 + 1 = 28,所以卷积后还是 28×28。
(2)激活函数 ReLU

每个卷积层后都加nn.ReLU(),作用是引入非线性

  • 没有激活函数的话,多个卷积层叠加还是线性变换,无法拟合复杂数据;
  • ReLU 的公式:ReLU(x) = max(0, x),计算简单、梯度不易消失,是目前最常用的激活函数。
(3)池化层 MaxPool2d

nn.MaxPool2d(kernel_size=2)是 2×2 最大池化,作用是下采样

  • 尺寸减半:28×28→14×14,14×14→7×7,大幅减少后续计算量;
  • 保留关键特征:取 2×2 窗口的最大值,相当于 “强化局部最显著的特征”,提高模型鲁棒性。
(4)展平与全连接层

卷积块 3 输出的是(batch_size, 64, 7, 7)的张量(batch_size 是每批样本数),需要用x.view(x.size(0), -1)展平为(batch_size, 64×7×7)的一维向量,才能输入全连接层:

  • x.size(0):获取 batch_size(确保展平后每一行对应一个样本);
  • -1:让 PyTorch 自动计算剩余维度(64×7×7=3136);
  • 全连接层nn.Linear(3136, 10):将 3136 维特征映射到 10 维(对应 0-9 的 10 个类别)。

4. 训练函数:模型如何 “学习”?

训练的核心是 “前向传播算损失→反向传播求梯度→优化器更新参数” 的循环:

  1. model.train():启用训练模式(比如如果模型有 BatchNorm,会计算当前批次的均值和方差);
  2. 前向传播:pred = model(X),用当前模型参数计算预测值;
  3. 计算损失:loss = loss_fn(pred, y),用CrossEntropyLoss(多分类任务专用,内置了 Softmax,无需手动在模型输出加 Softmax);
  4. 反向传播:
    • optimizer.zero_grad():清空上一轮的梯度(如果不清空,梯度会累积,导致参数更新错误);
    • loss.backward():自动计算所有可训练参数的梯度(PyTorch 的自动微分机制);
    • optimizer.step():用计算出的梯度更新参数(Adam 优化器会自适应调整学习率,比 SGD 收敛更快)。

5. 测试函数:模型学得怎么样?

测试阶段不需要更新参数,核心是计算 “准确率” 和 “平均损失”:

  1. model.eval():启用评估模式(关闭 BatchNorm 的批次统计更新、关闭 Dropout);
  2. with torch.no_grad():禁用梯度计算(节省内存,加速测试);
  3. 准确率计算:pred.argmax(1) == y,比较预测类别和真实类别,求和后除以总样本数。

四、预期结果与优化方向

1. 预期训练结果

在 GPU 上训练 10 轮后,通常能达到:

  • 测试准确率:98.5% 以上(甚至 99%);
  • 测试平均损失:0.04 以下。

训练过程中,损失会逐渐下降,准确率会逐渐上升(如果出现损失不下降或准确率波动,可能是学习率太大或 batch_size 太小)。

2. 模型优化方向

如果想进一步提升性能,可以尝试这些改进:

  1. 增加 Dropout 层:在卷积层或全连接层后加nn.Dropout(0.2),随机 “关闭” 20% 的神经元,防止过拟合;
  2. 使用学习率调度:比如torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5),每 5 轮将学习率减半,后期精细调整;
  3. 加深网络:增加卷积块数量(比如再加一个 conv4),或增加每个卷积层的输出通道数;
  4. 数据增强:用torchvision.transforms添加旋转、平移、缩放等操作,比如:
    transform = transforms.Compose([transforms.RandomRotation(5),  # 随机旋转±5度transforms.ToTensor()
    ])

        数据增强能让模型看到更多 “变种” 样本,提升泛化能力。

五、总结

本文用 PyTorch 实现了一个基础的 CNN 模型,完成了 MNIST 手写数字识别任务,核心收获包括:

  1. 掌握了 PyTorch 加载数据集、搭建 CNN 模型的基本流程;
  2. 理解了卷积层、池化层、激活函数的作用和参数意义;
  3. 熟悉了 “训练 - 测试” 的循环逻辑,以及 GPU 加速的配置方法。

MNIST 是入门任务,但 CNN 的核心思想(特征提取 + 下采样)可以迁移到更复杂的图像任务(如 CIFAR-10、ImageNet)。建议大家动手修改代码,比如调整卷积核大小、学习率、网络层数,观察结果变化,这样才能真正理解每个参数的影响~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/95105.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/95105.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CTFshow系列——命令执行web61-68

本篇文章介绍了不同了方法进行题目的解析以及原因讲解。 文章目录Web61尝试了一下,被过滤的payload如下:所以,根据上述思路,这里尝试过的payload为:Web62(同Web61)Web63(同Web62&…

.Net程序员就业现状以及学习路线图(二)

一、.NET程序员就业现状分析 1. 市场需求与岗位分布 2025年.NET开发岗位全国招聘职位约1676个,占全国技术岗位的0.009%,主要集中在一线城市如深圳、上海等地。就业单位类型分布为:软件公司占43.3%,研发机构占33.1%,物联…

MTK Linux DRM分析(二十二)- MTK mtk_drm_crtc.c(Part1)

一、代码分析 mtk_drm_crtc.c以mtk_crtc_comp_is_busy函数为界限进行拆分分析 static const struct drm_crtc_funcs mtk_crtc_funcs = {.set_config = drm_atomic_helper_set_config,.page_flip = drm_atomic_helper_page_flip,.destroy = mtk_drm_crtc_destroy,.reset = mtk…

stm32f103c8t6 led闪灯实验

目录 闪灯原理 2种接线方式控制闪灯 使用推挽接法 使用开漏接法 看原理图 写代码 闪灯原理 LED灯有个2-10mA的电流就可以点亮 3.3/5100.006A6mA 2种接线方式控制闪灯 使用推挽接法 当设置推挽模式时,CPU控制寄存器写0,IO引脚输出低电压&#xff0…

“我同意”按钮别乱点——你的“职业EULA”漏洞扫描报告

尊敬的审核: 本人文章《“我同意”按钮别乱点——你的“职业EULA”漏洞扫描报告》 1. 纯属技术交流,无任何违法内容 2. 所有法律引用均来自公开条文 3. 请依据《网络安全法》第12条“不得无故删除合法内容”处理 附:本文结构已通过区块链存证…

Product Hunt 每日热榜 | 2025-09-01

1. A01 标语:你个人的新闻助手 介绍:A01 是你的新闻助手,可以帮你关注你关心的任何话题。只需告诉它你想了解什么,它就能为你带来最新的文章。 产品网站: 立即访问 Product Hunt: View on Product Hunt…

【OpenFeign】基础使用

【OpenFeign】基础使用1. Feign介绍1.1 使用示例1.2 Feign与RPC对比1.3 SpringCloud Alibaba快速整合OpenFeign1.3.1 详细代码1. Feign介绍 1.什么是 Feign Feign 是 Netflix 开发的一个 声明式的 HTTP 客户端,在 Spring Cloud 中被广泛使用。它的目标是&#xff…

访问相同的url,相同入参的请求,Apifox/Postman可以正常响应结果,而本地调用不行(或结果不同)

文章目录问题概述Apifox查看实际请求总结问题概述 开发中有一个需求需要去别的系统中拿数据,配置好相关参数后发起请求时发现响应结果和在Apifox上不同,Apifox上正常显示数据,而本地调用后返回数据不存在。 这就很奇怪了,想了很多…

数据结构(C语言篇):(七)双向链表

目录 前言 一、概念与结构 二、双向链表的实现 2.1 头文件的准备 2.2 函数的实现 2.2.1 LTPushBack( )函数(尾插) (1)LTBuyNode( ) (2)LTInit( ) (3)LTPrint( ) &#x…

从拿起简历(resume)重新找工作开始聊起

经济萧条或经济衰退在经济相关学术上似乎有着严格的定义,我不知道我们的经济是否已经走向了衰退或者萧条,但有一点那是肯定的,那就现在我们的经济肯定是不景气的。经济不景气会怎么样?是的,会有很多人失业,…

OS+MySQL+(其他)八股小记

鲁迅先生曾经说过,每天进步一点点,妈妈夸我小天才。 依旧今日八股,这是我在多个文档整合一起的,可能格式有些问题,请谅解。 操作系统 1.进程和线程的区别? 进程是代码在数据集合的一次执行活动,…

Transformer的并行计算与长序列处理瓶颈总结

🌟 第0层:极简版(30秒理解)一句话核心:Transformer像圆桌会议——所有人都能同时交流(并行优势),但人越多会议越混乱(长序列瓶颈)。核心问题 并行优势&#x…

Vue 3 useId 完全指南:生成唯一标识符的最佳实践

📖 概述 useId() 是 Vue 3 中的一个组合式 API 函数,用于生成唯一的标识符。它确保在服务端渲染(SSR)和客户端渲染之间生成一致的 ID,避免水合不匹配的问题。 🎯 基本概念 什么是 useId? useId…

CGroup 资源控制组 + Docker 网络模式

1 CGroup 资源控制组1.1 为什么需要 CGroup - 容器本质 宿主机上一组进程 - 若无资源边界,一个暴走容器即可拖垮整机 - CGroup 提供**内核级硬限制**,比 ulimit、nice 更可靠1.2 核心概念 3 件套 | 概念 | 一句话解释 | 查看方式 | | Hierarchy | 树…

【ArcGIS微课1000例】0150:如何根据地名获取经纬度坐标

本文介绍了三种获取地理坐标的方法:1)在ArcGIS Pro中通过搜索功能定位目标点(如月牙泉)并查看其WGS84坐标;2)使用ArcGIS内置工具获取坐标;3)推荐三个在线工具(maplocation、地球在线、yanue)支持批量查询和多地图源坐标转换。强调了使用WGS84坐标系以减少误差,并展示…

HTML应用指南:利用GET请求获取MSN财经股价数据并可视化

随着数字化金融服务的不断深化,及时、准确的财经信息已成为投资者决策与市场分析的重要支撑。MSN财经股价数据服务作为广受信赖的金融信息平台,依托微软强大的技术架构与数据整合能力,持续为全球用户提供全面、可靠的证券市场数据。平台不仅提…

雅思听力第四课:配对题核心技巧与词汇深化

现在,请拿出剑桥真题,开始你的刻意练习! 内容大纲 课程核心目标旧题回顾与基础巩固配对题/匹配题核心解题策略考点总结与精听训练表 一、课程核心目标 掌握第二部分配对题的解题策略攻克第三部分匹配题的改写难点系统整理高频场景词汇与特…

SQL Server从入门到项目实践(超值版)读书笔记 25

第12章 存储过程的应用 🎉学习指引 存储过程(Stored Procedure)是在大型数据库系统中,一组为了完成特定功能的SQL语句集,存储过程时数据库中的一个重要对象,它代替了传统的逐条执行SQL语句的方式。本章就来…

20.29 QLoRA适配器实战:24GB显卡轻松微调650亿参数大模型

QLoRA适配器实战:24GB显卡轻松微调650亿参数大模型 QLoRA 适配器配置深度解析 一、QLoRA 适配器核心原理 QLoRA 作为当前大模型微调领域的前沿技术,通过量化与低秩适配的协同设计,在保证模型效果的前提下实现了显存占用的革命性降低。其核心由三大技术支柱构成: 4位量化…

QMainWindow使用QTabWidget添加多个QWidget

QTabWidget添加其它Wdiget的2个函数如下&#xff1a; QTabWidget的介绍可参考官网QTabWidget Class | Qt Widgets | Qt 6.9.1 直接上代码&#xff0c;代码如下&#xff1a; #include <QMainWindow>#include <QApplication> #include <QVBoxLayout> #includ…