卷积神经网络实战:MNIST手写数字识别

夜渐深,我还在😘

老地方

睡觉了🙌

文章目录

  • 📚 卷积神经网络实战:MNIST手写数字识别
    • 🧠 4.1 预备知识
      • ⚙️ 4.1.1 `torch.nn.Conv2d()` 三维卷积操作
      • 📏 4.1.2 `nn.MaxPool2d()` 池化层的作用
    • 📥 4.2 数据输入与处理
      • 🗃️ MNIST数据集加载
      • 🔍 数据格式验证
    • 🚀 4.3 卷积模型构建与训练
      • 🧩 4.3.1 网络架构设计
      • ⚡ 4.3.2 GPU加速与模型初始化
      • 📉 4.3.3 训练与评估函数
      • 🔁 4.3.4 模型训练循环
    • 🧪 4.4 函数式API
      • 🔌 4.4.1导入函数式模块
      • ⚡ 4.4.2激活函数应用
      • 🧮 4.4.3池化操作实现

📚 卷积神经网络实战:MNIST手写数字识别

🧠 4.1 预备知识

⚙️ 4.1.1 torch.nn.Conv2d() 三维卷积操作

torch.nn.Conv2d()是PyTorch中实现三维卷积的核心方法,其关键参数包括:

  • in_channels:输入通道数(彩色图为3,灰度图为1)
  • out_channels:输出通道数(卷积核数量)
  • kernel_size:卷积核尺寸(如3×3)
  • stride:步长(默认为1)
  • padding:填充(默认为0)
import torch
from torch import nn# 创建随机输入数据 (batch_size=20, 通道=3, 高=256, 宽=356)
input = torch.randn(20, 3, 256, 256) # 定义卷积层:输入通道3→输出通道16,3×3卷积核,步长1,填充1
conv_layer = nn.Conv2d(3, 16, (3, 3), stride=1, padding=1)# 执行卷积操作
output = conv_layer(input)
output.shape  # torch.Size([20, 16, 256, 256])

💡 输出解析:经过卷积后,特征图尺寸保持256×256不变(因padding=1),通道数从3增加到16

📏 4.1.2 nn.MaxPool2d() 池化层的作用

池化层的重要性

  1. 🎯 增大感受野:小卷积核视野有限,池化间接扩大覆盖区域
  2. 🛡️ 降低过拟合:减少参数量,增强模型泛化能力
  3. 加速计算:缩减特征图尺寸,减少后续计算量

核心参数kernel_size(池化窗口尺寸)

# 创建随机图像批次 (64张256×256的RGB图像)
img_batch = torch.randn(64, 3, 256, 256)# 2×2最大池化操作
pool_out = torch.max_pool2d(img_batch, kernel_size=(2, 2))
pool_out.shape  # torch.Size([64, 3, 128, 128])

💡 输出解析:池化后图像尺寸减半(256→128),通道数不变,实现特征降维


📥 4.2 数据输入与处理

🗃️ MNIST数据集加载

使用PyTorch内置工具加载手写数字数据集:

import torchvision
from torchvision.transforms import ToTensor# 下载并加载训练集/测试集
train_ds = torchvision.datasets.MNIST("data/", train=True, transform=ToTensor(), download=True
)
test_ds = torchvision.datasets.MNIST("data/", train=False, transform=ToTensor(), download=True
)# 创建数据加载器 (batch_size=64)
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=64)

🔍 数据格式验证

imgs, labels = next(iter(train_dl))
print(imgs.shape, labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])

数据格式:符合卷积网络输入要求(batch_size, 通道, 高, 宽)


🚀 4.3 卷积模型构建与训练

🧩 4.3.1 网络架构设计

LeNet风格CNN模型

class Model(nn.Module):def __init__(self):super(Model, self).__init__()# 卷积层1:1→6通道,5×5卷积核self.conv1 = nn.Conv2d(1, 6, 5)  # 卷积层2:6→16通道,5×5卷积核self.conv2 = nn.Conv2d(6, 16, 5)  # 全连接层1:256→256节点self.linear1 = nn.Linear(16*4*4, 256)  # 输出层:256→10节点 (10个数字类别)self.linear2 = nn.Linear(256, 10)  def forward(self, x):# 卷积→ReLU→池化 (28×28 → 12×12)x = torch.max_pool2d(torch.relu(self.conv1(x)), (2, 2))  # 卷积→ReLU→池化 (12×12 → 4×4)x = torch.max_pool2d(torch.relu(self.conv2(x)), (2, 2))  # 展平特征图x = x.view(-1, 16*4*4)  # 全连接层→ReLUx = torch.relu(self.linear1(x))  # 输出层return self.linear2(x)  

⚡ 4.3.2 GPU加速与模型初始化

# 自动检测GPU加速
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Model().to(device)
model
Model((conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))(linear1): Linear(in_features=256, out_features=256, bias=True)(linear2): Linear(in_features=256, out_features=10, bias=True)
)

📉 4.3.3 训练与评估函数

# 训练函数
def train(dataloader, model, loss_fn, optimizer):model.train()total_samples = len(dataloader.dataset)total_batches = len(dataloader)train_loss, correct = 0, 0for X, y in dataloader:X, y = X.to(device), y.to(device)# 前向传播pred = model(X)loss = loss_fn(pred, y)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 统计指标with torch.no_grad():correct += (pred.argmax(1) == y).sum().item()train_loss += loss.item()return train_loss/total_batches, correct/total_samples# 测试函数
def test(dataloader, model):model.eval()total_samples = len(dataloader.dataset)total_batches = len(dataloader)test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).sum().item()return test_loss/total_batches, correct/total_samples

🔁 4.3.4 模型训练循环

# 超参数设置
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
loss_fn = nn.CrossEntropyLoss()
epochs = 20# 训练日志
for epoch in range(epochs):train_loss, train_acc = train(train_dl, model, loss_fn, optimizer)test_loss, test_acc = test(test_dl, model)# 打印训练进度print(f"epoch:{epoch:2d}, train_loss:{train_loss:.5f}, "f"train_acc:{train_acc*100:.1f}%, test_loss:{test_loss:.5f}, "f"test_acc:{test_acc*100:.1f}%")

训练输出

epoch: 0, train_loss:0.24543, train_acc:92.8%, test_loss:0.07341, test_acc:97.7%
epoch: 1, train_loss:0.06720, train_acc:97.9%, test_loss:0.04788, test_acc:98.4%
...
epoch:19, train_loss:0.00509, train_acc:99.8%, test_loss:0.04585, test_acc:99.2%
Done

🎯 性能总结:模型在20个epoch内达到**99.2%**的测试准确率,显著优于全连接网络


🧪 4.4 函数式API

🔌 4.4.1导入函数式模块

import torch.nn.functional as F # 行业标准导入方式

⚡ 4.4.2激活函数应用

# 传统方式
output = torch.relu(input)# 函数式API方式
output = F.relu(input)

🧮 4.4.3池化操作实现

# 传统方式
pooled = torch.max_pool2d(input, kernel_size=2)# 函数式API方式
pooled = F.max_pool2d(input, kernel_size=2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/88862.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/88862.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HarmonyOS应用无响应(AppFreeze)深度解析:从检测原理到问题定位

HarmonyOS应用无响应(AppFreeze)深度解析:从检测原理到问题定位 在日常应用使用中,我们常会遇到点击无反应、界面卡顿甚至完全卡死的情况——这些都可能是应用无响应(AppFreeze) 导致的。对于开发者而言&am…

湖北设立100亿元人形机器人产业投资母基金

湖北设立100亿元人形机器人产业投资母基金 湖北工信 2025年07月08日 12:03 湖北 ,时长01:20 近日,湖北设立100亿元人形机器人产业投资母基金,重点支持人形机器人和人工智能相关产业发展。 人形机器人产业投资母基金由湖北省财政厅依托省政府…

时序预测 | Pytorch实现CNN-LSTM-KAN电力负荷时间序列预测模型

预测效果 代码主要功能 该代码实现了一个结合CNN(卷积神经网络)、LSTM(长短期记忆网络)和KAN(Kolmogorov-Arnold Network)的混合模型,用于时间序列预测任务。主要流程包括: 数据加…

OCR 识别:车牌识别相机的 “火眼金睛”

车牌识别相机在交通管理、停车场收费等场景中,需快速准确识别车牌信息。但实际环境中,车牌可能存在污渍、磨损、光照不均等情况,传统识别方式易出现误读、漏读。OCR 技术让车牌识别相机如虎添翼。它能精准提取车牌上的字符,不管是…

Java面试基础:面向对象(2)

1. 接口里可以定义哪些方法抽象方法:抽象方法是接口的核心部分,所有实现接口的类都必须实现这些方法。抽象方法默认是 public 和 abstract 修饰,这些修饰符可以省略。public interface Animal {void Sound(); }默认方法:默认方法是…

有哪些更加简洁的for循环?循环语句?

目录 简洁的for循环 循环过程修改循环变量 循环语句 不同编程语言支持的循环语句 foreach 无限循环 for循环历史 break和continue 循环判断结束值 循环标签 循环语句优化 循环表达式返回值 简洁的for循环 如果需要快速枚举一个集合的元素,尽管C语言可以…

RK3568/3588 Android 12 源码默认使用蓝牙mic录音

遇到客户一个需求,如果连接了带mic的蓝牙耳机,默认所有的录音要走蓝牙mic通道。这个功能搞了好久,终于搞定了。1. 向RK寻求帮助,先打通 bt sco能力。此时,还无法默认就切换到蓝牙 mic通道,接下来我们需求默…

解锁HTTP:从理论到实战的奇妙之旅

目录一、HTTP 协议基础入门1.1 HTTP 协议是什么1.2 HTTP 协议的特点1.3 HTTP 请求与响应的结构二、HTTP 应用场景大揭秘2.1 网页浏览2.2 API 调用2.3 文件传输2.4 内容分发网络(CDN)2.5 流媒体服务三、HTTP 应用实例深度剖析3.1 使用 JavaScript 的 fetc…

uvm_config_db examples

通过uvm_config_db类访问的UVM配置数据库,是在多个测试平台组件之间传递不同对象的绝佳方式。 methods 有两个主要函数用于从数据库中放入和检索项目,分别是 set() 和 get()。 static function void set ( uvm_component cntxt,string inst_name,string …

(C++)任务管理系统(文件存储)(正式版)(迭代器)(list列表基础教程)(STL基础知识)

目录 前言: 源代码: 代码解析: 一.头文件和命名空间 1. #include - 输入输出功能2. #include - 链表容器3. #include - 字符串处理4. using namespace std; - 命名空间 可视化比喻:建造房子 🏠 二.menu()函数 …

Java 中的异步编程详解

前言 在现代软件开发中,异步编程(Asynchronous Programming) 已经成为构建高性能、高并发应用程序的关键技术之一。Java 作为一门广泛应用于后端服务开发的语言,在其发展过程中不断引入和优化异步编程的支持。从最初的 Thread 和…

MySQL逻辑删除与唯一索引冲突解决

问题背景 在MySQL数据库设计中,逻辑删除(软删除)是一种常见的实践,它通过设置标志位(如is_delete)来标记记录被"删除",而不是实际删除数据。然而,当表中存在唯一约束时&am…

php命名空间用正斜杠还是反斜杠?

在PHP中,命名空间使用反斜杠(\)作为分隔符,这是PHP语言规范明确规定的。反斜杠在命名空间中扮演路径分隔的角色,用于区分不同层级的命名空间。 具体说明:语法规则 PHP命名空间使用反斜杠(\&…

《从依赖纠缠到接口协作:ASP.NET Core注入式开发指南》

在C#的ASP.NET Core开发中,依赖注入绝非简单的技术技巧,而是重构代码关系的底层逻辑。它像一套隐形的神经网络,让程序模块摆脱硬编码的束缚,在运行时实现动态连接,从而为系统注入可测试、可进化的核心生命力。理解其深…

星云ERP本地环境搭建笔记

看到星云ERP两个比较实用的功能,编号规则和打印模板,如下图所示,于是本地跑起来学习学习。开发环境必备:1. JDK 1.82. MySQL 5.73. Redis 44. RabbitMQ 3.12.45. nodejs 206. pnpm 9.7.1 (npm install -g pnpm9.7.1)其他开发工具&…

RedisJSON 的 `JSON.ARRAPPEND`一行命令让数组动态生长

1 、 为什么选择 JSON.ARRAPPEND 在传统的键值模型里,若要往数组尾部追加元素,通常需要 取→改→写 三步: GET 整个 JSON;在应用层把元素 push 进数组;SET 回 Redis。 一条 JSON.ARRAPPEND 则可一次完成,具…

14:00开始面试,14:08就出来了,问的问题有点变态。。。

从小厂出来,没想到在另一家公司又寄了。 到这家公司开始上班,加班是每天必不可少的,看在钱给的比较多的份上,就不太计较了。没想到4月一纸通知,所有人不准加班,加班费不仅没有了,薪资还要降40%…

Unity物理系统由浅入深第四节:物理约束求解与稳定性

Unity物理系统由浅入深第一节:Unity 物理系统基础与应用 Unity物理系统由浅入深第二节:物理系统高级特性与优化 Unity物理系统由浅入深第三节:物理引擎底层原理剖析 Unity物理系统由浅入深第四节:物理约束求解与稳定性 物理引擎的…

深入浅出Kafka Consumer源码解析:设计哲学与实现艺术

一、Kafka Consumer全景架构 1.1 核心组件交互图 #mermaid-svg-JDEEOd2M5PzLkYa6 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-JDEEOd2M5PzLkYa6 .error-icon{fill:#552222;}#mermaid-svg-JDEEOd2M5PzLkYa6 .erro…

Matplotlib(一)- 数据可视化与Matplotlib

文章目录一、数据可视化1. 数据可视化的概念2. 数据可视化流程3. 数据可视化目的4. 常见的可视化图表4.1 折线图4.2 柱形图4.3 条形图4.4 堆积图4.4.1 堆积面积图4.4.2 堆积柱形图和堆积条形图4.5 直方图4.6 箱形图4.7 饼图4.8 散点图4.9 气泡图4.10 误差棒图4.11 雷达图二、Py…