神经网络模型搭建及手写数字识别案例

代码实现:

import torch
print(torch.__version__)
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
training_data = datasets.MNIST(root='data',train=True,download=True,transform=ToTensor())
test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor())
from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):img,labels = training_data[i + 59000]figure.add_subplot(3,3,i + 1)plt.title(labels)plt.axis('off')plt.imshow(img.squeeze(),cmap="gray")
plt.show()
train_dataloader = DataLoader(training_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)for X,y in test_dataloader:print(f"Shape of X[N,C,H,W]:{X.shape}")print(f"Shape of y: {y.shape} {y.dtype}")breakdevice = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Using {device} device")class NeuralNetwork(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.hidden1 = nn.Linear(28*28,128)self.hidden2 = nn.Linear(128,256)self.out = nn.Linear(256,10)def forward(self,x):x = self.flatten(x)x = self.hidden1(x)x = torch.sigmoid(x)x = self.hidden2(x)x = torch.sigmoid(x)x = self.out(x)return xmodel = NeuralNetwork().to(device)
print(model)def train(dataloader,model,loss_fn,optimizer):model.train()batch_size_num = 1for X ,y in dataloader:X,y = X.to(device),y.to(device)pred = model.forward(X)loss = loss_fn(pred,y)optimizer.zero_grad()loss.backward()optimizer.step()loss_value = loss.item()if batch_size_num % 100 ==0:print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num +=1
def test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches= len(dataloader)model.eval()test_loss = 0correct = 0with torch.no_grad():for X ,y in dataloader:X,y = X.to(device),y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_pj_loss = test_loss / num_batchestest_acy = correct / size * 100print(f"Avg loss: {test_pj_loss:>7f} \n Accuray: {test_acy:>5.2f}%")
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.01)
# train(train_dataloader,model,loss_fn,optimizer)
# test(test_dataloader,model,loss_fn)
i=10
for j in range(i):print(f"Epoch {j+1}\n----------")train(train_dataloader, model,loss_fn,optimizer)
print("Done!")
test(test_dataloader,model,loss_fn)

这段代码是通过PyTorch 实现的 MNIST 手写数字识别神经网络,包含数据加载、可视化、模型构建、训练和评估的完整流程。下面分步骤解析:

1. 库导入

import torch  # 导入PyTorch核心库
print(torch.__version__)  # 打印PyTorch版本
from torch import nn  # 导入神经网络模块
from torch.utils.data import DataLoader  # 导入数据加载工具
from torchvision import datasets  # 导入计算机视觉数据集
from torchvision.transforms import ToTensor  # 导入图像转张量的工具
from matplotlib import pyplot as plt  # 导入可视化库

核心库:torch是 PyTorch 的主库,nn用于构建神经网络,DataLoader用于批量加载数据。

数据集:torchvision.datasets提供了 MNIST 等经典数据集,ToTensor将图像(PIL 格式)转换为 PyTorch 张量(便于计算)。

可视化:matplotlib用于展示数据样本。

2. 数据加载

# 加载MNIST训练集(60000张图片),若本地没有则自动下载,并用ToTensor转换
training_data = datasets.MNIST(root='data',  # 数据存储路径train=True,   # 标记为训练集download=True,  # 自动下载transform=ToTensor()  # 转换为张量
)# 加载MNIST测试集(10000张图片),参数含义同上
test_data = datasets.MNIST(root='data',train=False,  # 标记为测试集download=True,transform=ToTensor()
)

MNIST 是手写数字数据集(0-9),每张图片为 28×28 像素的灰度图,常用于图像识别入门。

train=True对应训练集(用于模型学习),train=False对应测试集(用于评估模型性能)。

3. 数据可视化

figure = plt.figure()  # 创建画布
for i in range(9):  # 展示9张图片img, labels = training_data[i + 59000]  # 取训练集中第59000+1到59000+9张图片figure.add_subplot(3, 3, i + 1)  # 3行3列布局plt.title(labels)  # 显示标签(真实数字)plt.axis('off')  # 关闭坐标轴plt.imshow(img.squeeze(), cmap="gray")  # 显示灰度图(squeeze移除多余维度)
plt.show()  # 展示图片

作用:直观查看数据样本,确认数据加载正确(图片与标签是否匹配)。

img.squeeze():原始图像张量形状为(1,28,28)(1 个通道,28×28 像素),squeeze()移除通道维度,变为(28,28)便于显示。

4. 数据批量处理

# 将数据集转换为可迭代的批量数据加载器
train_dataloader = DataLoader(training_data, batch_size=64)  # 训练集,每批64个样本
test_dataloader = DataLoader(test_data, batch_size=64)  # 测试集,每批64个样本# 打印一个测试批次的数据形状
for X, y in test_dataloader:print(f"Shape of X[N,C,H,W]: {X.shape}")  # 输出:[64,1,28,28]print(f"Shape of y: {y.shape} {y.dtype}")  # 输出:[64] torch.int64break

DataLoader的作用:将数据集拆分为多个批次(batch_size=64),支持并行加载,提高训练效率。

数据形状说明:

X(图像):[N, C, H, W],其中N=64(批次大小)、C=1(灰度图通道数)、H=28W=28(图像尺寸)。

y(标签):[64],每个元素是 0-9 的整数(表示图像对应的数字)。

5. 设备选择

# 优先使用GPU(cuda),其次苹果芯片GPU(mps),最后CPU
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f"Using {device} device")

作用:选择计算设备,GPU(cuda/mps)的并行计算能力可大幅加速神经网络训练,CPU 速度较慢。

6. 神经网络模型定义

class NeuralNetwork(nn.Module):  # 继承nn.Module(PyTorch神经网络基类)def __init__(self):super().__init__()  # 初始化父类self.flatten = nn.Flatten()  # 展平层:将28×28图像转为1维向量self.hidden1 = nn.Linear(28*28, 128)  # 全连接层1:784→128(输入784=28×28)self.hidden2 = nn.Linear(128, 256)  # 全连接层2:128→256self.out = nn.Linear(256, 10)  # 输出层:256→10(10个类别,对应0-9)def forward(self, x):  # 定义前向传播(必须实现)x = self.flatten(x)  # 展平:(64,1,28,28)→(64,784)x = self.hidden1(x)  # 第一层计算:(64,784)→(64,128)x = torch.sigmoid(x)  # 激活函数:引入非线性x = self.hidden2(x)  # 第二层计算:(64,128)→(64,256)x = torch.sigmoid(x)  # 激活函数x = self.out(x)  # 输出层:(64,256)→(64,10)return x# 实例化模型并移动到指定设备
model = NeuralNetwork().to(device)
print(model)  # 打印模型结构

模型结构:3 层全连接网络(2 个隐藏层 + 1 个输出层),通过nn.Linear定义线性变换,sigmoid激活函数引入非线性(使模型能拟合复杂关系)。

forward方法:定义数据在网络中的流动过程,是模型计算的核心。

7. 训练函数

def train(dataloader, model, loss_fn, optimizer):model.train()  # 设为训练模式(启用 dropout/batchnorm等训练特有的层)batch_size_num = 1  # 批次计数器for X, y in dataloader:  # 遍历每个批次X, y = X.to(device), y.to(device)  # 数据移到设备(GPU/CPU)# 前向传播:计算预测值pred = model.forward(X)# 计算损失(预测值与真实标签的差距)loss = loss_fn(pred, y)# 反向传播+参数更新optimizer.zero_grad()  # 清空上一轮梯度loss.backward()  # 反向传播计算梯度optimizer.step()  # 优化器更新模型参数# 每100个批次打印一次损失loss_value = loss.item()  # 取出损失值(转为Python数值)if batch_size_num % 100 == 0:print(f"loss: {loss_value:>7f} [number: {batch_size_num}]")batch_size_num += 1

核心逻辑:通过 “前向传播计算损失→反向传播求梯度→优化器更新参数” 的循环,让模型逐步学习数据规律。

model.train():启用训练模式(部分层如 Dropout 在训练和测试时行为不同)。

梯度清空:optimizer.zero_grad()避免梯度累积,保证每轮梯度计算独立。

8. 测试(评估)函数

def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches= len(dataloader)model.eval()  # 设为评估模式(关闭 dropout/batchnorm等训练特有的层)test_loss = 0  # 总损失correct = 0  # 正确预测数with torch.no_grad():  # 关闭梯度计算(节省内存,加速评估)for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)  # 前向传播(不计算梯度)test_loss += loss_fn(pred, y).item()  # 累加损失# 计算正确预测数:pred.argmax(1)取预测概率最大的类别,与y比较correct += (pred.argmax(1) == y).type(torch.float).sum().item()# 计算平均损失和准确率test_pj_loss = test_loss /  num_batchestest_acy = correct / size  * 100print(f"Avg loss: {test_pj_loss:>7f} \n Accuracy: {test_acy:>5.2f}%")

作用:评估模型在测试集上的性能(泛化能力),不更新参数。

model.eval():关闭训练特有的层(如 Dropout),确保评估稳定。

torch.no_grad():关闭梯度计算,减少内存占用,加速评估。

准确率计算:pred.argmax(1)获取每个样本预测的最大概率类别,与真实标签y比较,统计正确比例。

9. 训练与评估执行

loss_fn = nn.CrossEntropyLoss()  # 损失函数:适合多分类问题(内置Softmax)
optimizer = torch.optim.SGD(model.parameters(), lr=1)  # 优化器:随机梯度下降,学习率1epochs = 10  # 训练轮数(完整遍历训练集10次)
for j in range(epochs):print(f"Epoch {j+1}\n----------")train(train_dataloader, model, loss_fn, optimizer)  # 训练一轮
print("Done!")
test(test_dataloader, model, loss_fn)  # 训练完成后在测试集评估

损失函数:CrossEntropyLoss是多分类任务的常用损失,结合了nn.LogSoftmaxnn.NLLLoss,直接接收原始输出(无需手动加 Softmax)。

优化器:SGD(随机梯度下降)用于更新模型参数,lr=1是学习率(控制参数更新幅度)。

训练轮数(epochs=10):模型会完整遍历训练集 10 次,逐步降低损失、提高准确率。

最终评估:训练完成后,在测试集上输出平均损失和准确率(通常能达到 90% 以上)。

总结

这段代码完整实现了一个基于全连接网络的 MNIST 手写数字识别流程,涵盖数据加载、可视化、模型构建、训练和评估。核心逻辑是通过反向传播算法,让模型从数据中学习 “图像像素→数字类别” 的映射关系,最终实现对未知手写数字的识别。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/94603.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/94603.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CRMEB标准版PC扫码登录配置教程(PHP版)

需要在开放平台创建网站应用 微信开放平台地址:https://open.weixin.qq.com/ 1、注册网站应用 2、填写信息,网站地址填写前台访问的域名就行 3、复制开放平台AppId和开放平台AppSecret 4、粘贴到后台应用配置的PC站点配置里

AmazeVault 核心功能分析,认证、安全和关键的功能

系列文章目录 Amazevault 是一款专注于本地安全的桌面密码管理器 AmazeVault 核心功能分析,认证、安全和关键的功能 AmazeVault 快速开始,打造个人专属桌面密码管理器 文章目录系列文章目录前言一、认证系统核心组件图形解锁实现图形锁控件 (PatternLoc…

Coze用户账号设置修改用户昵称-后端源码

前言 本文将深入分析Coze Studio项目的用户昵称修改功能后端实现,通过源码解读来理解整个昵称更新流程的架构设计和技术实现。用户昵称修改作为用户个人信息管理系统的重要组成部分,主要负责处理用户显示名称的更新和管理。 昵称修改功能相对简单但不失重…

基于WebTransport(底层QUIC)实现视频传输(HTML+JavaScript)

工作目录和基本操作见博客《基于HTTP3的WebTransport实践》,在此仅展示服务器端和客户端代码。 服务器端 import { readFile } from "node:fs/promises"; import { createServer } from "node:https"; import {Server} from "socket.io&qu…

Git 怎么仓库迁移?error: remote origin already exists.怎么解决

1. 确认本地已经有完整的 旧 仓库你本地应该有旧的项目,并且能看到 .git 文件夹。如果没有,可以先从旧仓库克隆下来:git clone 旧仓库git地址 cd 旧代码目录比如:git clone https://gitee.com/star-information/esflow-pro-api.gi…

【语法】【C+V】本身常用图表类型用法快查【CSDN不支持,VSCODE可用】

文章目录Mermaid 简介Mermaid 由三部分组成Mermaid 的使用方法复杂图表结构的技巧饼图简介饼图语法饼图示例雷达图简介雷达图语法雷达图语法细节标题轴曲线选项雷达图示例时间线简介时间线语法时间线示例树形图简介树形图语法树形图示例数据包图简介数据包图语法1:数…

Anaconda下载安装教程(附Anaconda安装包)

Anaconda下载安装教程(附Anaconda安装包)获取:python开发工具包丨夸克网盘-资源免费下载 有位小粉丝刚刚入门python,问我怎么安装Anaconda,于是我帮他找了一个详细的Anaconda安装教程和安装包。 Anaconda其实是一个p…

那我现在有3个输入 9层神经元 每层神经元数为 3 9 3 5 6 2 3 9 8 请给出我所有的权重矩阵

问题1 那我现在有3个输入 9层神经元 每层神经元数为 3 9 3 5 6 2 3 9 8 请给出我所有的权重矩阵 明白,我们来系统地列出你的 9层全连接网络 的权重矩阵。1️⃣ 网络结构 输入层:3 个神经元隐藏层/输出层:9 层神经元数分别为 [3,9,3,5,6,2,3,9…

广东省省考备考(第八十七天8.26)——判断推理(听课后强化训练)

判断推理:定义判断 错题解析 第一步:找出定义关键词。 “农村中各项经济活动及由此产生的经济关系”、“同农业有直接或间接的关系”。 第二步:逐一分析选项。 A项:该项指出具体的夏粮产量和增量,其中生产粮食属于种植…

读取 STM32H5 Data Flash 触发 NMI 的问题解析 LAT1544

关键字:STM32H5, data flash, high-cycle data, NMI问题描述客户反馈,使用 STM32H563 的 data flash(high-cycle data flash),在还没有写入任何数据之前去读取 data flash, 会触发 hardfault 异常。1. 问题分析我们尝试在 NUCLEO-…

学云计算还是网络,选哪个好?

云计算工程师和网络工程师,都是IT界香饽饽,但方向差很大!选错路后悔3年!今天极限二选一,帮你彻底搞懂工作职责 网络工程师:网络世界的交警工程师!主要管物理网络和逻辑连接。负责设计、搭建、维…

Matlab使用——开发上位机APP,通过串口显示来自单片机的电压电流曲线,实现光伏I-V特性监测的设计

预览此处的测试数据的采集频率和曲线变化是通过更换电阻来测试的,所以电压电流曲线显示并不是很平滑,图中可以看到每一个采集点的数值。这个设计是福州大学第三十期SRTP的一个校级的项目,打算通过分布式的在线扫描电路低成本的单片机&#xf…

云原生 JVM 必杀技:3 招让容器性能飞跃 90%

最近佳作推荐: Java 大厂面试题 – JVM 与分布式系统的深度融合:实现技术突破(34)(New) Java 大厂面试题 – JVM 新特性深度解读:紧跟技术前沿(33)(New&#…

你真的了解操作系统吗?

文章目录操作系统是什么?操作系统核心功能为什么需要操作系统(目的)?操作系统的下层是什么?上层又是什么?如何理解“管理”?——“先描述,再组织”操作系统是什么? 任何…

从0到1详解requests接口自动化测试

前言 接口测试是测试系统组件间接口的一种测试。接口测试主要用于检测外部系统与系统之间以及内部各个子系统之间的交互点。测试的重点是要检查数据的交换,传递和控制管理过程,以及系统间的相互逻辑依赖关系等。 1、理解什么是接口 接口一般来说有两种…

Linux系统操作编程——http

万维网www万维网是一个大规模的、联机式的信息储藏所 ,实现从一个站点链接到另一个站点万维网服务器后台标记万维网数据方式:url:统一资源定位符万维网客户端与万维网服务器的通信方式:HTTP:超文本传输协议万维网客户端…

Langchian-chatchat私有化部署和踩坑问题以及解决方案[v0.3.1]

文章目录一 langchain-chatchat项目二 本地私有部署2.1 源码下载2.2 创建虚拟环境2.3 安装Poetry2.4 安装项目依赖2.5 初始化项目2.6 修改配置信息2.7 初始化知识库2.8 启动服务三 问题和解决方法3.1 poetry和packaging版本兼容性3.2 Langchain-chatchatPDF加载错误分析[win平台…

Day3--HOT100--42. 接雨水,3. 无重复字符的最长子串,438. 找到字符串中所有字母异位词

Day3–HOT100–42. 接雨水,3. 无重复字符的最长子串,438. 找到字符串中所有字母异位词 每日刷题系列。今天的题目是力扣HOT100题单。 双指针和滑动窗口题目。其中438题踩了坑,很值得看一下。 42. 接雨水 思路: 每个位置i&#x…

Kafka Broker 核心原理全解析:存储、高可用与数据同步

Kafka Broker 核心原理全解析:存储、高可用与数据同步 思维导图正文:Kafka Broker 核心原理深度剖析 Kafka 作为高性能的分布式消息队列,其 Broker 节点的设计是支撑高吞吐、高可用的核心。本文将从存储结构、消息清理、高可用选举、数据同步…

RTTR反射机制示例

1. Person类型头文件 #ifndef PERSON_H …