怎么用pytorch训练一个模型,并跑起来

MNIST 手写数字识别

任务描述

MNIST 手写数字识别是机器学习和计算机视觉领域的经典任务,其本质是解决 “从手写数字图像中自动识别出对应的数字(0-9)” 的问题,属于单标签图像分类任务(每张图像仅对应一个类别,即 0-9 中的一个数字)。

任务的核心定义:输入与输出

MNIST 任务的本质是建立 “手写数字图像” 到 “数字类别” 的映射关系,具体如下:
维度
| 具体 | 内容 |
|输入|28×28 像素的灰度图像(像素值范围 0-255,0 代表黑色背景,255 代表白色前景),图像内容是人类手写的 0-9 中的某一个数字。
例如:一张 28×28 的图像,像素分布呈现 “3” 的形状,就是模型的输入。|
|输出 |一个 “类别标签”,即从 10 个可能的类别(0、1、2、…、9)中选择一个,作为输入图像对应的数字。
例如:输入 “3” 的图像,模型输出 “类别 3”,即完成一次正确识别。 |
|目标|让模型在 “未见的手写数字图像” 上,尽可能准确地输出正确类别(通常用 “准确率” 衡量,即正确识别的图像数 / 总图像数)|

任务的核心挑战

为什么需要 “机器学习模型”?如果只是简单的 “看图像认数字”,人类可以轻松完成,但让计算机自动识别,需要解决多个关键挑战 —— 这些挑战也是 MNIST 成为经典任务的原因(它浓缩了计算机视觉的核心难题):
不同人书写习惯差异极大:有人写的 “4” 带弯钩,有人写的 “7” 带横线,有人字体粗大,有人字体纤细;甚至同一个人不同时间写的同一数字,笔画粗细、倾斜角度也会不同。
例如:同样是 “5”,可能是 “直笔 5”“圆笔 5”,也可能是倾斜 10° 或 20° 的 “5”—— 模型需要忽略这些 “风格差异”,抓住 “数字的本质特征”(如 “5 有一个上半圆 + 一个竖线”)。
图像噪声与干扰
手写数字图像可能存在噪声:比如纸张上的污渍、书写时的断笔、扫描时的光线不均,这些都会影响像素分布。
例如:一张 “0” 的图像,边缘有一小块污渍,模型需要判断 “这是噪声” 而不是 “0 的一部分”,避免误判为 “6” 或 “8”。

特征的自动提取

人类认数字时,会自动关注 “关键特征”(如 “0 是圆形、1 是竖线、8 是两个圆形叠加”),但计算机只能处理像素矩阵 —— 模型需要从 28×28=784 个像素值中,自动学习到这些抽象的 “数字特征”,而不是依赖人工定义(这也是深度学习优于传统方法的核心)。

MNIST 数据集的背景

MNIST(Modified National Institute of Standards and Technology database)是由美国国家标准与技术研究院(NIST)整理的手写数字数据集,后经修改(调整图像大小、居中对齐)成为机器学习领域的 “基准数据集”,其规模和特点非常适合入门:
数据量适中:包含 70000 张图像,其中 60000 张用于训练(让模型学习特征),10000 张用于测试(验证模型泛化能力);
图像规格统一:所有图像都是 28×28 灰度图,无需复杂的预处理(如尺寸缩放、颜色通道处理),降低入门门槛;
标注准确:每张图像都有明确的 “正确数字标签”(人工标注),无需额外标注成本。

任务的实际价值:解决这个问题有什么用?

MNIST 看似简单,但它是很多实际场景的 “简化版任务”,其解决思路可以迁移到更复杂的场景:
光学字符识别(OCR)的基础
例如:银行支票上的手写数字识别(识别金额)、快递单上的手写邮编识别、试卷批改中的选择题填涂识别 —— 这些场景本质都是 “手写字符分类”,MNIST 的技术思路(如卷积神经网络、全连接网络)可以直接复用或扩展。
机器学习模型的 “基准测试”
新提出的模型(如早期的 LeNet-5、后来的 ResNet 轻量版)会先在 MNIST 上测试性能:如果在简单的 MNIST 上都表现差,说明模型设计有问题;如果在 MNIST 上表现好,再迁移到更复杂的任务(如人脸识别、医学图像分类),这能大幅降低研发成本。
入门教学的 “绝佳案例”
MNIST 任务足够简单(数据量小、目标明确),但又能覆盖机器学习的完整流程(数据准备、模型定义、训练、评估、泛化性验证),因此成为初学者理解 “如何用代码实现一个完整机器学习任务” 的最佳载体(就像编程入门的 “Hello World”)。
总结
MNIST 手写数字识别的核心是 “让计算机从标准化的手写数字灰度图中,自动识别出对应的 0-9 数字”,它看似基础,却浓缩了图像分类的核心挑战(风格多样性、噪声鲁棒性、特征自动提取),同时是实际 OCR 场景的技术基础和机器学习入门的经典案例。

代码

下面我将为你提供一个使用 PyTorch 训练模型的完整示例,包括数据准备、模型定义、训练和测试的全过程。我们将使用一个简单的神经网络来解决 MNIST 手写数字识别问题。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 设置随机种子,确保结果可复现
torch.manual_seed(42)# 1. 数据准备
# 定义数据变换
transform = transforms.Compose([transforms.ToTensor(),  # 转换为Tensortransforms.Normalize((0.1307,), (0.3081,))  # 标准化,MNIST数据集的均值和标准差
])# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data',  # 数据保存路径train=True,     # 训练集download=True,  # 如果数据不存在则下载transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,    # 测试集download=True,transform=transform
)# 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 2. 定义模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()# 输入层到隐藏层self.fc1 = nn.Linear(28*28, 128)  # MNIST图像大小为28x28# 隐藏层到输出层self.fc2 = nn.Linear(128, 10)     # 10个类别(0-9)def forward(self, x):# 将图像展平为一维向量x = x.view(-1, 28*28)# 隐藏层,使用ReLU激活函数x = torch.relu(self.fc1(x))# 输出层,不使用激活函数(因为后面会用CrossEntropyLoss)x = self.fc2(x)return x# 3. 初始化模型、损失函数和优化器
model = SimpleNN()
criterion = nn.CrossEntropyLoss()  # 交叉熵损失,适用于分类问题
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam优化器# 4. 训练模型
def train(model, train_loader, criterion, optimizer, epochs=5):model.train()  # 设置为训练模式train_losses = []for epoch in range(epochs):running_loss = 0.0for batch_idx, (data, target) in enumerate(train_loader):# 清零梯度optimizer.zero_grad()# 前向传播outputs = model(data)loss = criterion(outputs, target)# 反向传播和优化loss.backward()optimizer.step()running_loss += loss.item()# 每100个批次打印一次信息if batch_idx % 100 == 99:print(f'Epoch [{epoch+1}/{epochs}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {running_loss/100:.4f}')running_loss = 0.0train_losses.append(running_loss / len(train_loader))return train_losses# 5. 测试模型
def test(model, test_loader):model.eval()  # 设置为评估模式correct = 0total = 0# 不计算梯度,节省内存和计算时间with torch.no_grad():for data, target in test_loader:outputs = model(data)_, predicted = torch.max(outputs.data, 1)total += target.size(0)correct += (predicted == target).sum().item()accuracy = 100 * correct / totalprint(f'Test Accuracy: {accuracy:.2f}%')return accuracy# 6. 运行训练和测试
if __name__ == '__main__':# 训练模型print("开始训练模型...")train_losses = train(model, train_loader, criterion, optimizer, epochs=5)# 测试模型print("开始测试模型...")test_accuracy = test(model, test_loader)# 保存模型torch.save(model.state_dict(), 'mnist_model.pth')print("模型已保存为 mnist_model.pth")# 绘制训练损失曲线plt.plot(train_losses)plt.title('Training Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.show()# 可视化一些预测结果model.eval()with torch.no_grad():# 获取一些测试数据data, target = next(iter(test_loader))outputs = model(data)_, predicted = torch.max(outputs, 1)# 显示前5个样本fig, axes = plt.subplots(1, 5, figsize=(15, 3))for i in range(5):axes[i].imshow(data[i].numpy().squeeze(), cmap='gray')axes[i].set_title(f'预测: {predicted[i]}, 实际: {target[i]}')axes[i].axis('off')plt.show()

代码解释

上面的代码实现了一个完整的 PyTorch 模型训练流程,主要包含以下几个部分:

  1. 数据准备:
    ◦ 使用torchvision.datasets加载 MNIST 数据集
    ◦ 对数据进行转换(转为 Tensor 并标准化)
    ◦ 使用DataLoader创建可迭代的数据加载器
  2. 模型定义:
    ◦ 定义了一个简单的两层神经网络SimpleNN
    ◦ 第一层将 28x28 的图像展平后映射到 128 维
    ◦ 第二层将 128 维特征映射到 10 个类别(对应数字 0-9)
  3. 训练设置:
    ◦ 使用交叉熵损失函数(CrossEntropyLoss)
    ◦ 使用 Adam 优化器
    ◦ 设置批量大小为 64,训练轮次为 5
  4. 训练过程:
    ◦ 循环多个训练轮次(epoch)
    ◦ 每个轮次中迭代所有批次数据
    ◦ 执行前向传播、计算损失、反向传播和参数更新
  5. 测试评估:
    ◦ 在测试集上评估模型性能
    ◦ 计算并打印准确率
  6. 结果可视化:
    ◦ 绘制训练损失曲线
    ◦ 展示部分测试样本的预测结果
    运行后,程序会自动下载 MNIST 数据集(首次运行),然后开始训练模型。训练完成后,会打印测试准确率,保存模型,并显示损失曲线和部分预测结果。
    这个示例比较基础,你可以根据需要调整模型结构、超参数(如学习率、批量大小、训练轮次等)来获得更好的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/96494.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/96494.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt应用程序发布方式

解决的问题:在自己电脑上用QT Creator编译的exe文件放到其他电脑上不能正常打开的问题。1、拷贝已经编译好的exe应用程序到桌面文件夹。桌面新建文件夹WindowsTest,并且将编译好的软件WindowTest.exe放入此文件夹中。2、在此文件夹空白处按住Shift再点击…

Linux 软件编程(九)网络编程:IP、端口与 UDP 套接字

1. 学习目的实现 不同主机之间的进程间通信。在 Linux 下,进程间通信(IPC)不仅可以发生在同一台主机上,也可以通过网络实现不同主机之间的通信。要做到这一点,必须同时满足以下两个条件:物理层面&#xff1…

5.Kotlin作用于函数let、run、with、apply、also

选择建议 需要返回值:使用 let、run 或 with配置对象:使用 apply附加操作:使用 also非空检查:使用 let链式调用:使用 let 或 run Kotlin作用域函数详解 概述 Kotlin提供了5个作用域函数:let、run、with、ap…

嵌入式学习日记(32)Linux下的网络编程

1. 目的不同主机,进程间通信。2. 解决的问题1). 主机与主机之间物理层面必须互联互通。2.) 进程与进程在软件层面必须互联互通。IP地址:计算机的软件地址,用来标识计算机设备 MAC地址:计算机的硬件地址&…

C#_接口设计:角色与契约的分离

2.3 接口设计:角色与契约的分离 在软件架构中,接口(Interface)远不止是一种语言结构。它是一份契约(Contract),明确规定了实现者必须提供的能力,以及使用者可以依赖的服务。优秀的接…

vsCode或Cursor 使用remote-ssh插件链接远程终端

一、Remote-SSH介绍Remote-SSH 是 VS Code 官方提供的一个扩展插件,允许开发者通过 SSH 协议连接到远程服务器,并在本地编辑器中直接操作远程文件,实现远程开发。它将本地编辑器的功能(如语法高亮、智能提示、调试等)与…

C语言实战:从零开始编写一个通用配置文件解析器

资料合集下载链接: ​https://pan.quark.cn/s/472bbdfcd014​ 在软件开发中,我们经常需要将一些可变的参数(如数据库地址、端口号、游戏角色属性等)与代码本身分离,方便日后修改而无需重新编译整个程序。这种存储配置信息的文件,我们称之为配置文件。 一、 什么是配置…

车机两分屏运行Unity制作的效果

目录 效果概述 实现原理 完整实现代码 实际车机集成注意事项 1. 显示系统集成 多屏显示API调用 代码示例(AAOS副驾屏显示) 2. 性能优化 GPU Instancing 其他优化技术 3. 输入处理 触控处理 物理按键处理 4. 安全规范 驾驶员侧限制 乘客侧…

vivo“空间计算-机器人”生态落下关键一子

出品 | 何玺排版 | 叶媛不出所料,vivo Vision热度很高。从21号下午发布到今天(22号),大众围绕vivo Vision探索版展开了多方面的讨论,十分热烈。从讨论来看,大家现在的共识是,MR行业目前还处于起…

Azure TTS Importer:一键导入,将微软TTS语音接入你的阅读软件!

Azure TTS Importer:一键导入,将微软TTS语音接入你的阅读软件! 文章来源:Poixe AI 厌倦了机械、生硬的文本朗读?想让你的阅读软件拥有自然流畅的AI语音?今天,我们将为您介绍一款强大且安全的开…

用过redis哪些数据类型?Redis String 类型的底层实现是什么?

Redis 数据类型有哪些? 详细可以查看:数据类型及其应用场景 基本数据类型: String:最常用的一种数据类型,String类型的值可以是字符串、数字或者二进制,但值最大不能超过512MB。一般用于 缓存和计数器 Ha…

大视协作码垛机:颠覆传统制造,开启智能工厂新纪元

在东三省某食品厂的深夜生产线上,码垛作业正有序进行,却不见人影——这不是魔法,而是大视协作码垛机器人带来的现实变革。在工业4.0浪潮席卷全球的今天,智能制造已成为企业生存与发展的必由之路。智能码垛环节作为产线的关键步骤&…

c# 保姆级分析继承详见问题 父类有一个列表对象,子类继承这个列表对象并对其进行修改后,将子类对象赋值给父类对象,父类对象是否能包含子类新增的内容?

文章目录 深入解析:父类与子类列表继承关系的终极指南 一、问题背景:从实际开发困惑说起 二、基础知识回顾:必备概念理解 2.1 继承的本质 2.2 引用类型 vs 值类型 2.3 多态的实现方式 三、核心问题分析:列表继承场景 3.1 基础代码示例 3.2 关键问题分解 3.3 结论验证 四、深…

tensorflow-gpu 2.7下的tensorboard与profiler插件版本问题

可行版本: python3.9.23cuda12.0tensorflow-gpu2.7.0tensorboard2.20.0 tensorboard-plugin-profile 2.4.0 问题描述: 1. 安装tensorboard后运行tensorboard --logdirlogs在网页中打开,发现profile模块无法显示,报错如下&#x…

数据结构青铜到王者第一话---数据结构基本常识(1)

目录 一、集合框架 1、什么是集合框架 2、集合框架的重要性 2.1开发中的使用 2.2笔试及面试题 3、背后涉及的数据结构以及算法 3.1什么是数据结构 3.2容器背后对应的数据结构 3.3相关java知识 3.4什么是算法 3.5如何学好数据结构以及算法 二、时间和空间复杂度 1、…

【Verilog】延时和时序检查

Verilog中延时和时序检查1. 延时模型1.1 分布延迟1.2 集总延迟1.3 路径延迟2. specify 语法2.1 指定路径延时基本路径延时边沿敏感路径延时状态依赖路径延时2.2 时序检查$setup, $hold, $setuphold$recovery, $removal, $recrem$width, $periodnotifier1. 延时模型 真实的逻辑元…

DigitalOcean Gradient AI平台现已支持OpenAI gpt-oss

OpenAI 的首批开源 GPT 模型(200 亿和 1200 亿参数)现已登陆 Gradient AI 平台。此次发布让开发者在构建 AI 应用时拥有更高的灵活度和更多选择,无论是快速原型还是大规模生产级智能体,都能轻松上手。新特性开源 GPT 模型&#xf…

藏在 K8s 幕后的记忆中枢(etcd)

目录1)etcd 基本架构2)etcd 的读写流程总览a)一个读流程b)一个写流程3)k8s存储数据过程源码解读4)watch 机制Informer 机制etcd watch机制etcd的watchableStore源码解读5) k8s大规模集群时会存在…

腾讯云EdgeOne安全防护:快速上手,全面抵御Web攻击

为什么需要专业的安全防护? 在当今数字化时代,网站面临的安全威胁日益增多。据统计,2023年全球Web应用程序攻击超7千亿次,持续快速增长。 其中最常见的包括: DDoS攻击:通过海量请求使服务器瘫痪Web应用攻…

SpringBoot中的条件注解

文章目录前言什么是条件注解核心原理常用条件注解详解1. ConditionalOnClass和ConditionalOnMissingClass2. ConditionalOnBean和ConditionalOnMissingBean3. ConditionalOnProperty应用场景:多数据源配置在SpringBoot自动配置中的核心作用自动配置的工作原理经典自…