使用pytorch创建/训练/推理OCR模型

一、任务描述

        从手写数字图像中自动识别出对应的数字(0-9)” 的问题,属于单标签图像分类任务(每张图像仅对应一个类别,即 0-9 中的一个数字)

        1、任务的核心定义:输入与输出

  • 输入:28×28 像素的灰度图像(像素值范围 0-255,0 代表黑色背景,255 代表白色前景),图像内容是人类手写的 0-9 中的某一个数字,例如:一张 28×28 的图像,像素分布呈现 “3” 的形状,就是模型的输入。
  • 输出:一个 “类别标签”,即从 10 个可能的类别(0、1、2、…、9)中选择一个,作为输入图像对应的数字,例如:输入 “3” 的图像,模型输出 “类别 3”,即完成一次正确识别。
  • 目标:让模型在 “未见的手写数字图像” 上,尽可能准确地输出正确类别(通常用 “准确率” 衡量,即正确识别的图像数 / 总图像数)

        2、任务的核心挑战

  • 不同人书写习惯差异极大:有人写的 “4” 带弯钩,有人写的 “7” 带横线,有人字体粗大,有人字体纤细;甚至同一个人不同时间写的同一数字,笔画粗细、倾斜角度也会不同。例如:同样是 “5”,可能是 “直笔 5”“圆笔 5”,也可能是倾斜 10° 或 20° 的 “5”—— 模型需要忽略这些 “风格差异”,抓住 “数字的本质特征”(如 “5 有一个上半圆 + 一个竖线”)。
  • 图像噪声与干扰:手写数字图像可能存在噪声,比如纸张上的污渍、书写时的断笔、扫描时的光线不均,这些都会影响像素分布。例如:一张 “0” 的图像,边缘有一小块污渍,模型需要判断 “这是噪声” 而不是 “0 的一部分”,避免误判为 “6” 或 “8”。

二、模型训练

       1、MNIST数据集

        MNIST(Modified National Institute of Standards and Technology database)是由美国国家标准与技术研究院(NIST)整理的手写数字数据集,后经修改(调整图像大小、居中对齐)成为机器学习领域的 “基准数据集”,MNIST手写数字识别的核心是 “让计算机从标准化的手写数字灰度图中,自动识别出对应的 0-9 数字”,它看似基础,却浓缩了图像分类的核心挑战(风格多样性、噪声鲁棒性、特征自动提取),同时是实际 OCR 场景的技术基础和机器学习入门的经典案例。

  • 数据量适中:包含 70000 张图像,其中 60000 张用于训练(让模型学习特征),10000 张用于测试(验证模型泛化能力);
  • 图像规格统一:所有图像都是 28×28 灰度图,无需复杂的预处理(如尺寸缩放、颜色通道处理),降低入门门槛;
  • 标注准确:每张图像都有明确的 “正确数字标签”(人工标注),无需额外标注成本。

        2、代码

  • 数据准备:使用torchvision.datasets加载 MNIST 数据集,对数据进行转换(转为 Tensor 并标准化),使用DataLoader创建可迭代的数据加载器;
  • 模型定义:定义了一个简单的两层神经网络SimpleNN,第一层将 28x28 的图像展平后映射到 128 维,第二层将 128 维特征映射到 10 个类别(对应数字 0-9);
  • 训练设置:使用交叉熵损失函数(CrossEntropyLoss),使用 Adam 优化器,设置批量大小为64,训练轮次为5;
  • 训练过程:循环多个训练轮次(epoch),每个轮次中迭代所有批次数据,执行前向传播、计算损失、反向传播和参数更新。
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 设置随机种子,确保结果可复现
torch.manual_seed(42)# 1. 数据准备
# 定义数据变换
transform = transforms.Compose([transforms.ToTensor(),  # 转换为Tensortransforms.Normalize((0.1307,), (0.3081,))  # 标准化,MNIST数据集的均值和标准差
])# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data',  # 数据保存路径train=True,  # 训练集download=True,  # 如果数据不存在则下载transform=transform
)test_dataset = datasets.MNIST(root='./data',train=False,  # 测试集download=True,transform=transform
)# 创建数据加载器
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)# 2. 定义模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()# 输入层到隐藏层self.fc1 = nn.Linear(28 * 28, 128)  # MNIST图像大小为28x28# 隐藏层到输出层self.fc2 = nn.Linear(128, 10)  # 10个类别(0-9)def forward(self, x):# 将图像展平为一维向量x = x.view(-1, 28 * 28)# 隐藏层,使用ReLU激活函数x = torch.relu(self.fc1(x))# 输出层,不使用激活函数(因为后面会用CrossEntropyLoss)x = self.fc2(x)return x# 3. 初始化模型、损失函数和优化器
model = SimpleNN()
criterion = nn.CrossEntropyLoss()  # 交叉熵损失,适用于分类问题
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam优化器# 4. 训练模型
def train(model, train_loader, criterion, optimizer, epochs=5):model.train()  # 设置为训练模式train_losses = []for epoch in range(epochs):running_loss = 0.0for batch_idx, (data, target) in enumerate(train_loader):# 清零梯度optimizer.zero_grad()# 前向传播outputs = model(data)loss = criterion(outputs, target)# 反向传播和优化loss.backward()optimizer.step()running_loss += loss.item()# 每100个批次打印一次信息if batch_idx % 100 == 99:print(f'Epoch [{epoch + 1}/{epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {running_loss / 100:.4f}')running_loss = 0.0train_losses.append(running_loss / len(train_loader))return train_losses# 6. 运行训练和测试
if __name__ == '__main__':# 训练模型print("开始训练模型...")train_losses = train(model, train_loader, criterion, optimizer, epochs=5)print("模型训练完成...")# 保存模型torch.save(model.state_dict(), 'mnist_model.pth')print("模型已保存为 mnist_model.pth")

三、模型使用测试

import torch
import torch.nn as nn
from PIL import Image
import numpy as np
from torchvision import transforms  # 修正transforms的导入方式# 定义与训练时相同的模型结构
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.fc1 = nn.Linear(28*28, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = x.view(-1, 28*28)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 加载模型
def load_model(model_path='mnist_model.pth'):model = SimpleNN()# 加载模型时添加参数以避免潜在的Python 3兼容性问题model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))model.eval()  # 设置为评估模式return model# 图像预处理(与训练时保持一致)
def preprocess_image(image_path):# 打开图像并转换为灰度图img = Image.open(image_path).convert('L')  # 'L'表示灰度模式# 调整大小为28x28img = img.resize((28, 28))# 转换为numpy数组并归一化img_array = np.array(img) / 255.0# 定义图像转换(使用torchvision的transforms)transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])# 注意:这里需要先将numpy数组转换为PIL图像再应用transformimg_pil = Image.fromarray((img_array * 255).astype(np.uint8))img_tensor = transform(img_pil).unsqueeze(0)  # 增加批次维度return img_tensor# 预测函数
def predict_digit(model, image_path):# 预处理图像img_tensor = preprocess_image(image_path)# 预测with torch.no_grad():  # 不计算梯度outputs = model(img_tensor)_, predicted = torch.max(outputs.data, 1)return predicted.item()  # 返回预测的数字# 示例使用
if __name__ == '__main__':# 加载模型model = load_model('mnist_model.pth')# 预测示例图像test_image_path = 'test_digit.png'  # 用户需要提供的测试图像路径try:predicted_digit = predict_digit(model, test_image_path)print(f"预测的数字是: {predicted_digit}")except Exception as e:print(f"预测出错: {str(e)}")

使用gpu0(第一块gpu)进行训练/推理:
        torch.cuda.set_device(0)    
        model = model.cuda(0)
使用cpu记性训练/推理:
        model = model.cpu()


怎么用pytorch训练一个模型-手写数字识别
手把手教你如何跑通一个手写中文汉字识别模型-OCR识别【pytorch】
手把手教你用PyTorch从零训练自己的大模型(非常详细)零基础入门到精通,收藏这一篇就够了
揭秘大模型的训练方法:使用PyTorch进行超大规模深度学习模型训练
全套解决方案:基于pytorch、transformers的中文NLP训练框架,支持大模型训练和文本生成,快速上手,海量训练数据!
用 pytorch 从零开始创建大语言模型(三):编码注意力机制

YOLOv5源码逐行超详细注释与解读(1)——项目目录结构解析

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/98023.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/98023.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

新启航开启深孔测量新纪元:激光频率梳技术攻克光学遮挡,达 130mm 深度 2μm 精度

摘要:本文聚焦于深孔测量领域,介绍了一种创新的激光频率梳技术。该技术成功攻克传统测量中的光学遮挡难题,在深孔测量深度达 130mm 时,可实现 2μm 的高精度测量,为深孔测量开启了新的发展篇章。关键词:激光…

GEO优化推荐:AI搜索新纪元下的品牌内容权威构建

引言:AI搜索引擎崛起与GEO策略的战略重心转移2025年,以ChatGPT、百度文心一言、DeepSeek为代表的AI搜索引擎已深入成为公众信息获取的核心渠道。这标志着品牌营销策略的重心,正从传统的搜索引擎优化(SEO)加速向生成式引…

uniapp的上拉加载H5和小程序

小程序配置{"path": "list/course-list","style": {"navigationBarTitleText": "课程列表","enablePullDownRefresh": true,"onReachBottomDistance": 150}}上拉拉触底钩子onReachBottom() {var that …

【和春笋一起学C++】(四十)抽象数据类型

抽象数据类型(abstract data type, ADT)以通用的方式描述数据类型。C中类的概念非常适合于ADT方法。例如,C程序通过堆栈来管理自动变量,堆栈可由对它执行的操作来描述。可创建空堆栈;可将数据项添加到堆顶(…

大文件断点续传解决方案:基于Vue 2与Spring Boot的完整实现

大文件断点续传解决方案:基于Vue 2与Spring Boot的完整实现 在现代Web应用中,大文件上传是一个常见但具有挑战性的需求。传统的文件上传方式在面对网络不稳定、大文件传输时往往表现不佳。本文将详细介绍如何实现一个支持断点续传的大文件上传功能,结合Vue 2前端和Spring Bo…

LeNet-5:手写数字识别经典CNN

配套讲解视频,点击下方名片获取20 世纪 90 年代,计算机已经能识别文本,但图片识别很困难。比如银行支票的手写数字识别,传统方法需要人工设计规则,费时费力且精度不高。 于是,Yann LeCun 及其团队提出了 Le…

如何在 C# 中将文本转换为 Word 以及将 Word 转换为文本

在现代软件开发中,处理文档内容是一个非常常见的需求。无论是生成报告、存储日志,还是处理用户输入,开发者都可能需要在纯文本与 Word 文档之间进行转换。有时需要将文本转换为 Word,以便生成结构化的 .docx 文件,使内…

Open SWE:重构代码协作的智能范式——从规划到PR的全流程自动化革命

在软件开发的演进史上,工具链的每一次革新都深刻重塑着开发者的工作方式。LangChain AI推出的Open SWE,作为首个开源的异步编程代理,正在重新定义代码协作的边界——它不再仅仅是代码生成工具,而是构建了从代码库分析、方案规划、代码实现到拉取请求创建的端到端自动化工作…

【ARDUINO】通过ESP8266控制电机【待测试】

需求 通过Wi-Fi控制Arduino驱动的3V直流电机。这个方案使用外部6V或9V电源,ESP8266作为Wi-Fi模块,Arduino作为主控制器,L298N作为电机驱动器。 手机/电脑 (Wi-Fi客户端) | | (Wi-Fi) | ESP8266 (Wi-Fi模块, AT指令模式) | | (串口通信) | A…

cuda编程笔记(18)-- 使用im2col + GEMM 实现卷积

我们之前介绍了cudnn调用api直接实现卷积,本文我们探究手动实现。对于直接使用for循环在cpu上的实现方法,就不过多介绍,只要了解卷积的原理,就很容易实现。im2col 的核心思想im2col image to column把输入 feature map 的每个卷积…

Loopback for Mac:一键打造虚拟音频矩阵,实现跨应用音频自由流转

虚拟音频设备创建 模拟物理设备:Loopback允许用户在Mac上创建虚拟音频设备,这些设备可被系统及其他应用程序识别为真实硬件,实现音频的虚拟化传输。多源聚合:支持将麦克风、应用程序(如Skype、Zoom、GarageBand、Logic…

深入解析Django重定向机制

概述 核心是一个基类 HttpResponseRedirectBase,以及两个具体的子类 HttpResponseRedirect(302 临时重定向)和 HttpResponsePermanentRedirect(301 永久重定向)。它们都是 HttpResponse 的子类,专门用于告诉…

【Java实战⑳】从IO到NIO:Java高并发编程的飞跃

目录一、NIO 与 IO 的深度剖析1.1 IO 的局限性1.2 NIO 核心特性1.3 NIO 核心组件1.4 NIO 适用场景二、NIO 核心组件实战2.1 Buffer 缓冲区2.2 Channel 通道2.3 Selector 选择器2.4 NIO 文件操作案例三、NIO2.0 实战3.1 Path 类3.2 Files 类3.3 Files 类高级操作3.4 NIO2.0 实战…

OpenCV 实战:图像模板匹配与旋转处理实现教程

目录 一、功能概述:代码能做什么? 二、环境准备:先搭好运行基础 1. 安装 Python 2. 安装 OpenCV 库 3. 准备图像文件 三、代码逐段解析:从基础到核心 1. 导入 OpenCV 库 2. 读取图像文件 3. 模板图像旋转:处理…

一、cadence的安装及入门教学(反相器的设计与仿真)

一、Cadence的安装 1、安装VMware虚拟机 2、安装带有cadence软件的Linux系统 注:网盘链接 分享链接:https://disk.ningsuan.com.cn/#s/8XaVdtRQ 访问密码:11111 所有文件压缩包及文档密码: Cadence_ic 3、安装tsmc18工艺库…

用ai写了个UE5插件

文章目录实际需求1.头文件2.源文件3.用法小结实际需求 这个需求来源于之前的一个项目,当时用了一个第三方插件,里边有一些绘制线段的代码,c层用的是drawdebugline,当时看底层,觉得应该没问题,不应该在rele…

机器学习从入门到精通 - 强化学习初探:Q-Learning到Deep Q-Network实战

机器学习从入门到精通 - 强化学习初探:从 Q-Learning 到 Deep Q-Network 实战 一、开场白:推开强化学习这扇门 不知道你有没有过这种感觉 —— 盯着一个复杂的系统,既想让它达到某个目标,又苦于无法用传统规则去精确描述每一步该怎…

【OpenHarmony文件管理子系统】文件访问接口解析

OpenHarmony文件访问接口(filemanagement_file_api) 概述 OpenHarmony文件访问接口(filemanagement_file_api)是开源鸿蒙操作系统中的核心文件系统接口,为应用程序提供了完整的文件IO操作能力。该项目基于Node-API&…

云手机运行是否消耗自身流量?

云手机运行是否消耗自身流量,取决于具体的使用场景和设置:若用户在连接云手机时,使用的是家中Wi-Fi、办公室局域网等非移动数据网络,那么在云手机运行过程中,基本不会消耗用户自身的移动数据流量,在家中连接…

JavaSe之多线程

一、多线程基本了解 1、多线程基本知识 1.进程:进入到内存中执行的应用程序 2.线程:内存和CPU之间开通的通道->进程中的一个执行单元 3.线程作用:负责当前进程中程序的运行.一个进程中至少有一个线程,一个进程还可以有多个线程,这样的应用程序就称之为多线程程序 4.简单理解…