用 PyTorch 实现全连接网络识别 MNIST 手写数字

目录

一、什么是全连接网络

二、代码实现步骤

1. 导入必要的库

2. 数据准备

3. 定义网络结构

4. 模型训练

5. 模型保存和加载

6. 预测单张图片

7. 主函数

三、运行结果说明

四、小结


一、什么是全连接网络

全连接神经网络(Fully Connected Neural Network)是一种最基础的神经网络结构,其特点是每一层的每个神经元都与上一层的所有神经元相连。

打个比方,就像公司里的部门架构:输入层是基层员工,隐藏层是中层管理,输出层是高层决策。基层的每个人都要向所有中层汇报,中层再向所有高层汇报,这样信息就能经过多层处理后得到最终结果。

但全连接网络处理图像时有个缺点:它会把图像的二维像素矩阵转换成一维向量,这就像把一张完整的图片撕成一条线,会丢失图像的空间特征。

二、代码实现步骤

1. 导入必要的库

import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from PIL import Image

这些库就像我们的工具包:

  • torch 是 PyTorch 的核心库
  • nn 模块包含神经网络相关的工具
  • optim 提供优化器
  • torchvision 有现成的数据集和图像处理工具
  • DataLoader 帮助我们批量加载数据
  • PIL 用于处理图像

2. 数据准备

def build_data():transform = transforms.Compose([transforms.ToTensor(),])train_set = datasets.MNIST(root = '../dataset',train = True,download = True,transform = transform)test_set = datasets.MNIST(root = '../dataset',train = False,download = True,transform = transform)train_loader = DataLoader(dataset = train_set,batch_size = 128,shuffle = True)test_loader = DataLoader(dataset = test_set,batch_size = 64,shuffle = True)return train_loader, test_loader

这段代码做了三件事:

  • 定义了数据转换方式,ToTensor()会把图像转换成张量并归一化
  • 加载 MNIST 数据集(手写数字数据集,包含 0-9 共 10 类数字)
  • DataLoader把数据分成批次,方便训练时批量处理

batch_size表示每次处理多少张图片,shuffle=True表示打乱数据顺序,让模型学习更全面。

3. 定义网络结构

class MNISTNet(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(28 * 28, 256)self.relu1 = nn.ReLU()self.fc2 = nn.Linear(256, 128)self.relu2 = nn.ReLU()self.fc3 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28 * 28)  # 把28x28的图像展平成784维向量x = self.relu1(self.fc1(x))x = self.relu2(self.fc2(x))x = self.fc3(x)return x

我们定义了一个 3 层的全连接网络:

  • 输入层:MNIST 图像是 28x28 的,展平后是 784 个像素点
  • 第一个隐藏层:256 个神经元,使用 ReLU 激活函数
  • 第二个隐藏层:128 个神经元,同样使用 ReLU 激活函数
  • 输出层:10 个神经元(对应 0-9 十个数字)

激活函数 ReLU 的作用是引入非线性,让网络能够学习复杂的模式,就像给计算器增加了更多运算功能。

4. 模型训练

def train(model, train_loader, epochs):criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数,适合分类问题opt = optim.SGD(model.parameters(), lr=0.01)  # 随机梯度下降优化器for epoch in range(epochs):loss_sum = 0count = 0for x, y in train_loader:y_pred = model(x)  # 前向传播,得到预测结果loss = criterion(y_pred, y)  # 计算损失# 反向传播更新参数opt.zero_grad()  # 清空梯度loss.backward()  # 计算梯度opt.step()  # 更新参数loss_sum += loss.item()_, pred = torch.max(y_pred, dim=1)  # 找到概率最大的类别count += (pred == y).sum().item()  # 统计正确的数量acc = count / len(train_loader.dataset)  # 计算准确率print(f'epoch: {epoch+1}, Loss: {loss_sum:.4f}, Acc: {acc:.4f}')

训练过程就像学生做习题:

  1. 先用当前模型做预测(前向传播)
  2. 计算预测结果和正确答案的差距(损失函数)
  3. 分析哪里错了,怎么改进(反向传播计算梯度)
  4. 调整模型参数(优化器更新参数)

我们用交叉熵损失函数来衡量预测错误的程度,用随机梯度下降(SGD)来优化模型参数,学习率lr=0.01控制每次调整的幅度。

5. 模型保存和加载

def save_model(model, model_path):torch.save(model.state_dict(), model_path)  # 保存模型参数def load_model(model_path):model = MNISTNet()model.load_state_dict(torch.load(model_path))  # 加载模型参数return model

训练好的模型可以保存下来,下次用的时候直接加载,不用重新训练,就像保存游戏进度一样。

6. 预测单张图片

def predict(model, filePath):img = Image.open(filePath)# 图像预处理:调整大小、转成张量、归一化transform = transforms.Compose([transforms.Resize((28, 28)),transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])t_img = transform(img)with torch.no_grad():  # 预测时不需要计算梯度y_pred = model(t_img)_, pred = torch.max(y_pred, dim=1)print(f'预测结果: {pred.item()}')

预测时需要对输入图片做和训练数据相同的预处理,with torch.no_grad()可以加快计算速度,因为预测时不需要更新参数。

7. 主函数

if __name__ == '__main__':train_loader, test_loader = build_data()model = MNISTNet()# 训练模型train(model, train_loader, epochs=10)# 保存模型save_model(model, './mnist.pt')# 加载模型并预测model_pred = load_model('./mnist.pt')predict(model_pred, './img/3.png')

三、运行结果说明

训练过程中,我们会看到损失(Loss)逐渐减小,准确率(Acc)逐渐提高,这说明模型在不断进步。

对于 MNIST 这种简单数据集,用这个全连接网络通常能达到 97% 以上的准确率。如果想进一步提高性能,可以考虑使用卷积神经网络(CNN),它能更好地保留图像的空间特征。

四、小结

本文用 PyTorch 实现了一个全连接神经网络来识别 MNIST 手写数字,主要步骤包括:

  1. 准备数据:加载并预处理 MNIST 数据集
  2. 定义网络:设计 3 层全连接网络
  3. 训练模型:使用交叉熵损失和 SGD 优化器
  4. 保存和加载模型:方便复用
  5. 单张图片预测:实际应用模型

全连接网络虽然简单,但它是理解更复杂神经网络的基础。通过这个例子,我们可以了解神经网络的基本工作原理和 PyTorch 的使用方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/90328.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/90328.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vscode怎么安装MINGW

下载: 第一步选择MINGW官网:MinGW-w64 - for 32 and 64 bit Windows - SourceForge.net 点击Files 点击Toolchains targetting Win64 点击第一个 Personal Builds 点击mingw-builds 选择8.1.0 点击第二个 threads-posix 点击第二个seh 最后左键点击下…

CSS图片分层设置

在CSS中实现图片分层效果,主要通过定位属性和层叠上下文控制。以下是核心实现方法和示例: 一、核心实现原理定位方式 使用 position: relative/absolute/fixed 使图片脱离文档流 .layer {position: absolute; /* 关键属性 */top: 0;left: 0; }层叠控制 通…

GEMINUS 和 Move to Understand a 3D Scene

论文链接:https://arxiv.org/abs/2507.14456 代码链接:https://github.com/newbrains1/GEMINUS 端到端自动驾驶的挑战 端到端自动驾驶是一种“一站式”方法:模型直接从传感器输入(如摄像头图像)生成驾驶轨迹或控制信号…

算法与数据结构:线性表

C语言数据结构基础:线性表详解线性表是数据结构中最基础、最常用的形式,就像一列整齐排队的游客:每个元素有固定位置(前驱和后继),长度可动态变化。在C语言中,它主要通过顺序表(数组…

制作mac 系统U盘

使用 installinstallmacos.py(更兼容) 苹果官方不提供所有历史版本的安装器,但可以通过一个开源脚本下载(Apple 提供的企业支持工具): git clone https://github.com/munki/macadmin-scripts.git cd macadm…

渗透部分总结

docker环境搭建以及dns等原理讲解Docker搭建:Linux 系统上安装 Docker 引擎并启动服务:# 安装Docker引擎 curl -fsSL https://get.docker.com | sh 通过 curl 下载并执行 Docker 官方的安装脚本,这会自动配置 Docker 仓库并安装最新版本的 Do…

k8s pvc是否可绑定在多个pod上

1.pvc是否可绑定在多个podPVC 是否能被多个 Pod 使用,取决于它的 accessModes。PVC 的 accessModes是否支持多个 Pod 同时使用说明ReadWriteOnce (RWO)❌ 若多个Pod,需在相同节点上(仅允许被单个节点上的Pod挂载)常用于本地磁盘、…

如何加固Endpoint Central服务器的安全?(下)

Endpoint Central 作为企业终端管理的 “中枢系统”,掌控着全网终端的补丁推送、软件部署、配置管理、远程控制等关键权限,存储着大量终端资产信息、用户数据及企业策略配置。一旦服务器被攻破,攻击者可能篡改管理指令(如推送恶意…

信息整合注意力IIA,通过双方向注意力机制重构空间位置信息,动态增强目标关键特征并抑制噪声

在遥感图像语义分割等视觉任务中,编码器 - 解码器结构通过跳跃连接融合多尺度特征时,常面临两大挑战:一是编码器的局部细节特征与解码器的全局语义特征融合时,空间位置信息易丢失,导致目标定位不准;二是复杂…

如何迁移jenkins至另一台服务器

前言公司旧的服务器快到期了,需要将部署在其上的jenkins整体迁移到另一台服务器,两台都是aws ec2服务器。文章主要提供给大家一种迁移思路,并不一定是最优解,仅供参考,大家根据实际情况自行选用和修改,举一…

在vue中遇到Uncaught TypeError: Assignment to constant variable(常亮无法修改)

1.问题如下:2.出现这个问题的原因----在设计变量的时候采用了const来进行修饰,在修改的时候直接对其进行修改3.利用响应式变量的特点,修改为下面这样就可以正常了

RCE随笔-奇技淫巧(2)

Linux命令长度限制在7个字符的情况下&#xff0c;如何拿到shell <?php $param $_REQUEST[param]; If ( strlen($param) < 8 ) { echo shell_exec($param); }分析代码&#xff1a;这段代码传入参数param然后进入if语句判断是否小于8个字符&#xff0c;然后如果小于就会进…

设计模式九:构建器模式 (Builder Pattern)

动机(Motivation)1、在软件系统中&#xff0c;有时候面临着“一个复杂对象”的创建工作&#xff0c;其通常由各个部分的子对象用一定的算法构成&#xff1b;由于需求的变化&#xff0c;这个复杂对象的各个部分经常面临着剧烈的变化&#xff0c;但是将它们组合在一起的算法却相对…

如何高效合并音视频文件

在自我学习或者进行视频剪辑的时候&#xff0c;经常从资源网址下载音视频分离的文件&#xff0c;例如audio_file1.m4a和video_1.mp4&#xff0c;之后需要把这两个文件合并在一起。于是条件反射得想要利用剪映等第三方工具&#xff0c;进行音视频的封装。可惜不幸的是&#xff0…

虚幻 5 与 3D 软件的协作:实时渲染,所见所得

《曼达洛人》的星际飞船在片场实时掠过虚拟荒漠&#xff0c;游戏开发者拖动滑块就能即时看到角色皮肤的通透变化&#xff0c;实时渲染技术正以 “所见即所得” 的核心优势&#xff0c;重塑着 3D 创作的整个逻辑。虚幻引擎 5&#xff08;UE5&#xff09;凭借 Lumen 全局光照和 N…

​Eyeriss 架构中的访存行为解析(腾讯元宝)

​Eyeriss 架构中的访存行为解析​Eyeriss 是 MIT 提出的面向卷积神经网络&#xff08;CNN&#xff09;的能效型 NPU&#xff08;神经网络处理器&#xff09;架构&#xff0c;其核心创新在于通过硬件结构优化访存行为&#xff0c;以解决传统 GPU 在处理 CNN 时因数据搬运导致的…

数字图像处理(三:图像如果当作矩阵,那加减乘除处理了矩阵,那图像咋变):从LED冬奥会、奥运会及春晚等等大屏,到手机小屏,快来挖一挖里面都有什么

数字图像处理&#xff08;三&#xff09;一、&#xff08;准备工作&#xff1a;咋玩&#xff0c;用什么玩具&#xff09;图像以矩阵形式存储&#xff0c;那矩阵一变、图像立刻跟着变&#xff1f;1. Python Jupyter Notebook/Lab 库 (NumPy, OpenCV, Matplotlib, scikit-image…

docker-desktop启动失败

报错提示deploying WSL2 distributions ensuring main distro is deployed: checking if main distro is up to date: checking main distro bootstrap version: getting main distro bootstrap version: open \\wsl$\docker-desktop\etc\wsl_bootstrap_version: The network n…

基于FastMCP创建MCP服务器的小白级教程

以下是基于windows 11操作系统环境的开发步骤。 1、python环境搭建 访问官网&#xff1a;https://www.python.org/。下载相应的版本&#xff08;如&#xff1a;3.13.5&#xff09;&#xff0c;然后安装。 安装完成之后&#xff0c;使用命令行工具输入python&#xff0c;显示…

网络协议与层次对应表

网络协议与层次对应表&#xff08;OSI & TCP/IP模型&#xff09;OSI七层模型TCP/IP四层模型协议/技术核心功能与应用​应用层​应用层HTTP/HTTPS网页传输协议&#xff08;HTTP&#xff09;及其加密版&#xff08;HTTPS&#xff09;FTP文件上传/下载协议SMTP/POP3/IMAPSMTP发…