深度学习篇---MNIST:手写数字数据集

下面我将详细介绍使用 PyTorch 处理 MNIST 手写数字数据集的完整流程,包括数据加载、模型定义、训练和评估,并解释每一行代码的含义和注意事项。

整个流程可以分为五个主要步骤:准备工作、数据加载与预处理、模型定义、模型训练和模型评估

# MNIST手写数字数据集完整处理流程
# 包含数据加载、模型定义、训练和评估的全步骤# 1. 导入必要的库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt# 2. 设置超参数
batch_size = 64       # 每次训练的样本数量
learning_rate = 0.001 # 学习率
num_epochs = 5        # 训练轮数
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 注意:如果有GPU,会使用cuda加速训练,否则使用CPU# 3. 数据预处理与加载
# 定义数据变换:将图像转为Tensor并标准化
transform = transforms.Compose([transforms.ToTensor(),  # 转换为Tensor格式,像素值从0-255归一化到0-1# 标准化处理:均值为0.1307,标准差为0.3081(MNIST数据集的统计特性)transforms.Normalize((0.1307,), (0.3081,))
])# 加载训练集
train_dataset = datasets.MNIST(root='./data',        # 数据保存路径train=True,           # True表示加载训练集download=True,        # 如果数据不存在则自动下载transform=transform   # 应用上面定义的数据变换
)# 加载测试集
test_dataset = datasets.MNIST(root='./data',train=False,          # False表示加载测试集download=True,transform=transform
)# 创建数据加载器,用于批量加载数据
train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True          # 训练时打乱数据顺序
)test_loader = DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=False         # 测试时不需要打乱顺序
)# 4. 可视化样本数据(可选,用于理解数据)
def show_samples():# 获取一些随机的训练样本dataiter = iter(train_loader)images, labels = next(dataiter)# 显示6个样本plt.figure(figsize=(10, 4))for i in range(6):plt.subplot(1, 6, i+1)plt.imshow(images[i].numpy().squeeze(), cmap='gray')plt.title(f'Label: {labels[i].item()}')plt.axis('off')plt.show()# 调用函数显示样本
show_samples()# 5. 定义神经网络模型
class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()# 第一个卷积块:卷积层 + 激活函数 + 池化层self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)# 第二个卷积块self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)# 全连接层self.fc1 = nn.Linear(7 * 7 * 64, 128)  # 经过两次池化后,28x28变为7x7self.relu3 = nn.ReLU()self.fc2 = nn.Linear(128, 10)  # 10个输出,对应0-9十个数字def forward(self, x):# 前向传播过程x = self.pool1(self.relu1(self.conv1(x)))x = self.pool2(self.relu2(self.conv2(x)))x = x.view(-1, 7 * 7 * 64)  # 展平操作x = self.relu3(self.fc1(x))x = self.fc2(x)return x# 初始化模型并移动到设备上
model = SimpleCNN().to(device)# 6. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()  # 交叉熵损失,适合分类问题
optimizer = optim.Adam(model.parameters(), lr=learning_rate)  # Adam优化器# 7. 训练模型
def train_model():# 记录训练过程中的损失和准确率train_losses = []train_accuracies = []# 开始训练model.train()  # 设置为训练模式for epoch in range(num_epochs):running_loss = 0.0correct = 0total = 0# 遍历训练数据for i, (images, labels) in enumerate(train_loader):# 将数据移动到设备上images = images.to(device)labels = labels.to(device)# 清零梯度optimizer.zero_grad()# 前向传播outputs = model(images)loss = criterion(outputs, labels)# 反向传播和优化loss.backward()  # 计算梯度optimizer.step()  # 更新参数# 统计损失和准确率running_loss += loss.item()# 计算预测结果_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()# 每100个批次打印一次信息if (i + 1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], 'f'Loss: {running_loss/100:.4f}, Accuracy: {100*correct/total:.2f}%')running_loss = 0.0# 记录每个epoch的平均损失和准确率epoch_loss = running_loss / len(train_loader)epoch_acc = 100 * correct / totaltrain_losses.append(epoch_loss)train_accuracies.append(epoch_acc)print(f'Epoch [{epoch+1}/{num_epochs}] completed. Training Accuracy: {epoch_acc:.2f}%')print('训练完成!')return train_losses, train_accuracies# 调用训练函数
train_losses, train_accuracies = train_model()# 8. 绘制训练曲线
def plot_training_curves(losses, accuracies):plt.figure(figsize=(12, 5))# 损失曲线plt.subplot(1, 2, 1)plt.plot(range(1, num_epochs+1), losses)plt.title('Training Loss')plt.xlabel('Epoch')plt.ylabel('Loss')# 准确率曲线plt.subplot(1, 2, 2)plt.plot(range(1, num_epochs+1), accuracies)plt.title('Training Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy (%)')plt.tight_layout()plt.show()# 绘制训练曲线
plot_training_curves(train_losses, train_accuracies)# 9. 评估模型
def evaluate_model():model.eval()  # 设置为评估模式correct = 0total = 0# 不需要计算梯度with torch.no_grad():for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'测试集准确率: {100 * correct / total:.2f}%')return 100 * correct / total# 评估模型在测试集上的表现
test_accuracy = evaluate_model()# 10. 可视化预测结果(可选)
def show_predictions():model.eval()dataiter = iter(test_loader)images, labels = next(dataiter)# 选择前6个样本进行预测images = images[:6].to(device)labels = labels[:6]with torch.no_grad():outputs = model(images)_, predicted = torch.max(outputs.data, 1)# 显示结果plt.figure(figsize=(10, 4))for i in range(6):plt.subplot(1, 6, i+1)plt.imshow(images[i].cpu().numpy().squeeze(), cmap='gray')plt.title(f'真实: {labels[i].item()}\n预测: {predicted[i].item()}')plt.axis('off')plt.tight_layout()plt.show()# 显示预测结果
show_predictions()# 11. 保存模型(可选)
torch.save(model.state_dict(), 'mnist_cnn_model.pth')
print("模型已保存为 'mnist_cnn_model.pth'")

代码详细解释与注意事项

1. 准备工作

  • 我们导入了 PyTorch 的核心库以及处理图像所需的工具
  • device设置会自动检测是否有可用的 GPU,如果有则使用 GPU 加速训练,否则使用 CPU

2. 数据加载与预处理

  • 数据变换 (transforms)

    • ToTensor()将图像从 PIL 格式转换为 PyTorch 的 Tensor 格式,并将像素值从 0-255 归一化到 0-1 范围
    • Normalize()进行标准化,使用的均值和标准差是 MNIST 数据集的统计特性,这有助于模型更快收敛
  • 数据集加载

    • datasets.MNIST会自动下载数据(如果本地没有)并加载
    • train=True加载训练集(60,000 张图片),train=False加载测试集(10,000 张图片)
  • DataLoader

    • 用于批量加载数据,支持自动打乱数据顺序
    • batch_size=64表示每次处理 64 张图片
    • 训练时shuffle=True打乱数据顺序,测试时shuffle=False保持顺序

3. 模型定义

  • 我们定义了一个简单的卷积神经网络 (SimpleCNN),包含:

    • 两个卷积块:每个卷积块由卷积层、ReLU 激活函数和池化层组成
    • 两个全连接层:最后一层输出 10 个值,对应 0-9 十个数字的预测概率
  • 卷积操作的作用:

    • 提取图像的局部特征,如边缘、纹理等
    • 池化层用于降低特征图尺寸,减少计算量

4. 模型训练

  • 损失函数:使用CrossEntropyLoss,适合多分类问题

  • 优化器:使用Adam优化器,比传统的 SGD 收敛更快

  • 训练过程中的关键步骤:

    1. 清零梯度:optimizer.zero_grad()
    2. 前向传播:计算模型输出和损失
    3. 反向传播:loss.backward()计算梯度
    4. 更新参数:optimizer.step()应用梯度更新
  • 注意事项:

    • 训练前调用model.train()设置为训练模式
    • 定期打印损失和准确率,监控训练进度
    • 将数据和模型移动到相同的设备上(CPU 或 GPU)

5. 模型评估

  • 评估时调用model.eval()设置为评估模式,这会关闭 dropout 等训练特有的操作
  • 使用torch.no_grad()关闭梯度计算,节省内存并加速计算
  • 计算测试集上的准确率,评估模型的泛化能力

6. 常见问题与解决方法

  1. 训练速度慢

    • 检查是否使用了 GPU(代码会自动检测,但需要正确安装 PyTorch GPU 版本)
    • 尝试调大batch_size(受限于 GPU 内存)
  2. 过拟合

    • 增加训练轮数
    • 添加正则化(如 Dropout)
    • 增加数据增强
  3. 准确率低

    • 检查模型结构是否合理
    • 尝试调整学习率
    • 增加训练轮数

通过这个完整流程,你可以加载 MNIST 数据集,训练一个卷积神经网络对手写数字进行分类,并评估模型性能。对于初学者来说,这个例子涵盖了深度学习的基本流程和关键概念,是一个很好的入门练习。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/97947.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/97947.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

k8s集群搭建(二)-------- 集群搭建

安装 containerd 需要在集群内的每个节点上都安装容器运行时&#xff08;containerd runtime&#xff09;&#xff0c;这个软件是负责运行容器的软件。 1. 启动 ipv4 数据包转发 # 设置所需的 sysctl 参数&#xff0c;参数在重新启动后保持不变 cat <<EOF | sudo tee …

【Docker】P1 前言:容器化技术发展之路

目录容器发展之路物理服务器时代&#xff1a;一机一应用的局限虚拟化时代&#xff1a;突破与局限并存容器化时代&#xff1a;轻量级的革新技术演进的价值体现各位&#xff0c;欢迎来到容器化时代。 容器发展之路 现代业务的核心是应用程序&#xff08;Application&#xff09;…

WPF依赖属性和依赖属性的包装器:

依赖属性是WPF&#xff08;Windows Presentation Foundation&#xff09;中的一种特殊类型的属性&#xff0c;特别适用于内存使用优化和属性值继承。依赖属性的定义包括以下几个步骤&#xff1a; 使用 DependencyProperty.Register 方法注册依赖属性。 该方法需要四个参数&…

图生图算法

图生图算法研究细分&#xff1a;技术演进、应用与争议 1. 基于GAN的传统图生图方法 定义&#xff1a;利用生成对抗网络&#xff08;GAN&#xff09;将输入图像转换为目标域图像&#xff08;如语义图→照片、草图→彩图&#xff09;。关键发展与趋势&#xff1a; Pix2Pix&#…

Go 自建库的使用教程与测试

附加一个Go库的实现&#xff0c;相较于Python&#xff0c;Go的实现更较为日常&#xff0c;不需要额外增加setup.py类的文件去额外定义,计算和并发的性能更加。 1. 创建 Go 模块项目结构 首先创建完整的项目结构&#xff1a; gomathlib/ ├── go.mod ├── go.sum ├── cor…

What is a prototype network in few-shot learning?

A prototype network is a method used in few-shot learning to classify new data points when only a small number of labeled examples (the “shots”) are available per class. It works by creating a representative “prototype” for each class, which is typical…

Linux中用于线程/进程同步的核心函数——`sem_wait`函数

<摘要> sem_wait 是 POSIX 信号量操作函数&#xff0c;用于对信号量执行 P 操作&#xff08;等待、获取&#xff09;。它的核心功能是原子地将信号量的值减 1。如果信号量的值大于 0&#xff0c;则减 1 并立即返回&#xff1b;如果信号量的值为 0&#xff0c;则调用线程&…

25高教社杯数模国赛【B题超高质量思路+问题分析】

注&#xff1a;本内容由”数模加油站“ 原创出品&#xff0c;虽无偿分享&#xff0c;但创作不易。 欢迎参考teach&#xff0c;但请勿抄袭、盗卖或商用。 B 题 碳化硅外延层厚度的确定碳化硅作为一种新兴的第三代半导体材料&#xff0c;以其优越的综合性能表现正在受到越来越多…

【Linux篇章】再续传输层协议UDP :从低可靠到极速传输的协议重生之路,揭秘无连接通信的二次进化密码!

&#x1f4cc;本篇摘要&#xff1a; 本篇将承接上次的UDP系列网络编程&#xff0c;来深入认识下UDP协议的结构&#xff0c;特性&#xff0c;底层原理&#xff0c;注意事项及应用场景&#xff01; &#x1f3e0;欢迎拜访&#x1f3e0;&#xff1a;点击进入博主主页 &#x1f4c…

《A Study of Probabilistic Password Models》(IEEE SP 2014)——论文阅读

提出更高效的密码评估工具&#xff0c;将统计语言建模技术引入密码建模&#xff0c;系统评估各类概率密码模型性能&#xff0c;打破PCFGw的 “最优模型” 认知。一、研究背景当前研究存在两大关键问题&#xff1a;一是主流的 “猜测数图” 计算成本极高&#xff0c;且难以覆盖强…

校园外卖点餐系统(代码+数据库+LW)

摘要 随着校园生活节奏的加快&#xff0c;学生对外卖的需求日益增长。然而&#xff0c;传统的外卖服务存在诸多不便&#xff0c;如配送时间长、菜品选择有限、信息更新不及时等。为解决这些问题&#xff0c;本研究开发了一款校园外卖点餐系统&#xff0c;采用前端 Vue、后端 S…

友思特案例 | 食品行业视觉检测案例集锦(三)

食品制造质量检测对保障消费者安全和产品质量稳定至关重要&#xff0c;覆盖原材料至成品全阶段&#xff0c;含过程中检测与成品包装检测。近年人工智能深度学习及自动化系统正日益融入食品生产。本篇文章将介绍案例三&#xff1a;友思特Neuro-T深度学习平台进行面饼质量检测。在…

SQLynx 3.7 发布:数据库管理工具的性能与交互双重进化

目录 &#x1f511; 核心功能更新 1. 单页百万级数据展示 2. 更安全的数据更新与删除机制 3. 更智能的 SQL 代码提示 4. 新增物化视图与外表支持 5. 数据库搜索与过滤功能重构 ⚡ 总结与思考 在大数据与云原生应用快速发展的今天&#xff0c;数据库管理工具不仅要“能用…

10G网速不是梦!5G-A如何“榨干”毫米波,跑出比5G快10倍的速度?

5G-A&#xff08;5G-Advanced&#xff09;网络技术已经在中国福建省厦门市软件园成功实现万兆&#xff08;10Gbps&#xff09;速率验证&#xff0c;标志着我国正式进入5G增强版商用阶段。这一突破性成果不仅验证了5G-A技术的可行性&#xff0c;也为6G网络的发展奠定了坚实基础。…

Linux笔记---UDP套接字实战:简易聊天室

1. 项目需求分析 我们要设计的是一个简单的匿名聊天室&#xff0c;用户的客户端要求用户输入自己的昵称之后即可在一个公共的群聊当中聊天。 为了简单起见&#xff0c;我们设计用户在终端当中与客户端交互&#xff0c;而在一个文件当中显式群聊信息&#xff1a; 当用户输入的…

RTP打包与解包全解析:从RFC规范到跨平台轻量级RTSP服务和低延迟RTSP播放器实现

引言 在实时音视频系统中&#xff0c;RTSP&#xff08;Real-Time Streaming Protocol&#xff09;负责会话与控制&#xff0c;而 RTP&#xff08;Real-time Transport Protocol&#xff09;负责媒体数据承载。开发者在实现跨平台、低延迟的 RTSP 播放器或轻量级 RTSP 服务时&a…

Ubuntu 用户和用户组

一、 Linux 用户linux 是一个多用户操作系统&#xff0c;不同的用户拥有不同的权限&#xff0c;可以查看和操作不同的文件。 Ubuntu 有三种用户1、初次创建的用户2、root 用户---上帝3、普通用户初次创建的用户权限比普通用户要多&#xff0c;但是没有 root 用户多。Linux 用户…

FastGPT社区版大语言模型知识库、Agent开源项目推荐

​ FastGPT 项目说明 项目概述 FastGPT 是一个基于大语言模型&#xff08;LLM&#xff09;的知识库问答系统&#xff0c;提供开箱即用的数据处理和模型调用能力&#xff0c;支持通过可视化工作流编排实现复杂问答场景。 技术架构 前端: Next.js TypeScript Chakra UI 后…

jsencrypt公钥分段加密,支持后端解密

前端使用jsencryp实现分段加密。 解决长文本RSA加密报错问题。 支持文本包含中文。 支持后端解密。前端加密代码&#xff1a; // import { JSEncrypt } from jsencrypt const JSEncrypt require(jsencrypt) /*** 使用 JSEncrypt 实现分段 RSA 加密&#xff08;正确处理中文字符…

生成一份关于电脑电池使用情况、健康状况和寿命估算的详细 HTML 报告

核心作用 powercfg /batteryreport 是一个在 Windows 命令提示符或 PowerShell 中运行的命令。它的核心作用是&#xff1a;生成一份关于电脑电池使用情况、健康状况和寿命估算的详细 HTML 报告。 这份报告非常有用&#xff0c;特别是对于笔记本电脑用户&#xff0c;它可以帮你&…