自己编写一个神经网络模型识别数字验证码(卷积神经网络的 Hello world)

开篇之前说明一下:本文纯粹是技术交流和探讨,所用数据为非公开数据集,仅限于学习,不可用以商业和其他用途。

 一、项目目标

  • 通过构建一个简单的 CNN 神经网络,实现对 数字验证码(如 “7384”) 的识别。
  • 熟悉图像分类流程:数据生成、图像预处理、网络设计、训练评估与预测。
  • 掌握卷积神经网络在图像识别中的基本应用方法。

二、整体思路

实现数字验证码识别目标的整体思路是:(1)验证码图像数据准备-->(2)数据集制作(为了简单,将每张验证码图像中的4个数字切割开来,形成每张图只有1个数字的训练集,并且文件名的第一个字符为训练图的数字,即标签。)-->(3)编写数据处理方法、模型、训练方法、测试方法-->(4)训练和保存模型->(5)测试

三、实现步骤

步骤一:验证码图像数据准备

本例准备了1143张验证码图像,切割成4572张单个数字的训练图像。

步骤二:数据集制作

如上文所属,本例的训练集为4572张单个数字的训练图像,并且图像文件名的第一个字符为标签,如下图所示:

步骤三:编写实现代码

        1.编写数据加载模块

        构建一个用于数字验证码识别任务的数据加载模块,并通过 PyTorch 提供的数据接口 DatasetDataLoader 对图像数据进行封装和加载。

        具体过程如下:


        (1)导入必要的库:

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

        这些库的功能:

        torch.utils.data.Dataset: PyTorch 自定义数据集基类

        DataLoader: 用于批量加载数据、打乱顺序

        transforms: 图像预处理(如灰度、归一化)

        PIL.Image: 用于打开图像文件

        os: 文件路径操作


        (2)自定义数据集类 DigitDataset

class DigitDataset(Dataset):

        这个类继承了 PyTorch 的 Dataset,是用户自定义的数据集类。

        初始化方法:

def __init__(self, data_dir):

        data_dir 是图像数据所在的目录(如 "data")。

        图像预处理流程:

self.transform = transforms.Compose([transforms.Grayscale(),         # 转换为灰度图(1通道)transforms.ToTensor(),          # 转换为张量并归一化为 [0, 1]transforms.Normalize(0.5, 0.5)  # 再归一化为 [-1, 1]
])

        获取所有图像文件名:

self.file_list = os.listdir(data_dir)

        将文件夹中所有图片文件名收集到列表中。


        (3)数据集长度:

def __len__(self):return len(self.file_list)

        告诉 PyTorch 这个数据集总共有多少样本。


        (4) 获取某个样本:

def __getitem__(self, idx):

        PyTorch 在训练时会调用这个方法来获取第 idx 个样本。

filename = self.file_list[idx]
img_path = os.path.join(self.data_dir, filename)
image = Image.open(img_path)

         获取图像路径并打开图像。

label = int(filename[0])  # 文件名第一个字符是标签

         将文件名的**第一个字符(数字)**作为标签,例如:"7abc.png" 的标签是 7

return self.transform(image), label

         返回的是图像张量和对应的标签整数。


          (5) 创建数据加载器:

dataset = DigitDataset("data")
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

        加载器的作用是:使用 "data" 文件夹下的数据创建一个数据集。用 DataLoader 创建批次大小为 32 的训练集,每个 epoch 自动打乱顺序。


        总结一下:数据加载模块实现了一个用于数字图像分类任务(如验证码识别)的数据集加载器,图像标签从文件名中自动提取,图像被统一预处理为灰度、张量、归一化格式,方便输入到神经网络中进行训练。DigitDataset的完整代码如下:

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import osclass DigitDataset(Dataset):def __init__(self, data_dir):self.data_dir = data_dirself.transform = transforms.Compose([transforms.Grayscale(),         # 转换为灰度图transforms.ToTensor(),          # 转为张量 [0,1]transforms.Normalize(0.5, 0.5)  # 归一化到 [-1,1]])self.file_list = os.listdir(data_dir)def __len__(self):return len(self.file_list)def __getitem__(self, idx):filename = self.file_list[idx]img_path = os.path.join(self.data_dir, filename)image = Image.open(img_path)label = int(filename[0])  # 文件名第一个字符为标签return self.transform(image), label# 创建数据集和数据加载器
dataset = DigitDataset("data")
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

        2.编写模型

        这里定义一个用于数字识别(0~9)分类的卷积神经网络模型 DigitNet,使用的是 PyTorch 框架中的 nn.Module

        下面我们逐层解读它的结构和作用:


        📦 模型整体结构概览

DigitNet(└── Conv2d(1, 2, 3, padding=1)  -> ReLU -> MaxPool2d(2)└── Conv2d(2, 4, 3, padding=1)  -> ReLU└── Flatten -> Linear(168, 128) -> ReLU└── Linear(128, 10)
)

        🔍 详细分层解读

        1️⃣ 输入层

nn.Conv2d(1, 2, kernel_size=3, padding=1)

                 输入通道数 1:灰度图像

                输出通道数 2:提取两个特征图

                卷积核大小 3x3

                Padding=1 保证输入输出尺寸不变
        👉 输出尺寸仍为 13x15(本例中,训练输入图像大小为 13x15)

 

nn.ReLU()

         激活函数,增加非线性能力

 

nn.MaxPool2d(2)

         池化操作:尺寸缩小为 1/2

        输入 13x15 → 输出约为 6x7(向下取整)


        2️⃣ 第二个卷积层

nn.Conv2d(2, 4, kernel_size=3, padding=1)

         输入通道数 2,输出通道数 4

        卷积后尺寸仍为 6x7

        特征图变为 4 个

nn.ReLU()

        3️⃣ 展平层(Flatten)

nn.Flatten()

         将多维特征图转换为一维向量

        展平后的尺寸为:

6(height) * 7(width) * 4(channels) = 168

        4️⃣ 全连接层

nn.Linear(6*7*4, 128)

         输入维度:168

        输出维度:128,特征压缩

 

nn.ReLU()

         激活函数

 

nn.Linear(128, 10)

         输出层:10个神经元,对应 数字类别 0~9


        🧠 forward 方法

def forward(self, x):return self.model(x)

         将输入图像 x 传入 self.model 中按顺序执行模型结构,输出是10维的向量,表示每个数字的分类概率(一般用于交叉熵分类损失)。


        📈 模型输出解释

输出为:

tensor([[0.1, 0.05, ..., 0.15]])  # 共10个数

         每个值表示图像属于某个数字(0~9)的预测得分(在训练时会搭配 CrossEntropyLoss,不需要手动加 softmax)。


        ✅ 总结:该模型是一个轻量级的 CNN 分类器,用于识别尺寸约为 13×15 的灰度图像中的单个数字,最终输出一个 10 类的分类结果。模型 DigitNet的完整代码如下:

import torch.nn as nnclass DigitNet(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Conv2d(1, 2, kernel_size=3, padding=1), # 13x15 -> 13x15nn.ReLU(),nn.MaxPool2d(2),                            # 6x7nn.Conv2d(2, 4, kernel_size=3, padding=1), # 6x7 -> 6x7nn.ReLU(),nn.Flatten(),nn.Linear(6*7*4, 128),nn.ReLU(),nn.Linear(128, 10))def forward(self, x):return self.model(x)

        3.编写训练函数

        定义一个用于训练 PyTorch 神经网络模型的函数 train_model


        ✅ 一、总体功能说明

def train_model(model, train_loader, epochs=10):

        该函数的作用是:使用指定的数据加载器和模型,进行多轮(epoch)训练,并返回训练完成后的模型。

输入参数:

  • model: 待训练的神经网络模型(如你之前定义的 DigitNet

  • train_loader: 数据加载器(封装了训练数据)

  • epochs: 训练轮数,默认 10

输出结果:

  • 返回训练完成后的 model


        🔍 二、代码编写逻辑

        1️⃣ 确定训练设备(CPU 或 GPU)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
  • 自动检测是否有可用的 GPU,如果有则用 cuda 加速,否则退回到 cpu

  • 将模型移到设备上(很关键!)


        2️⃣ 定义损失函数和优化器

criterion = nn.CrossEntropyLoss()
  • 使用交叉熵损失函数,适用于多分类任务(如 0~9 数字识别)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
  • 使用 Adam 优化器(自动调整学习率),学习率设为 0.001


        3️⃣ 模型训练模式

model.train()
  • 设置模型为“训练模式”,激活 Dropout 和 BatchNorm(如果有)


        4️⃣ 外层循环:遍历每个 epoch

for epoch in range(epochs):running_loss = 0.0
  • running_loss:用于记录每个 epoch 的累计损失


        5️⃣ 内层循环:遍历每个 batch

for images, labels in train_loader:images, labels = images.to(device), labels.to(device)
  • 加载一个 batch 的图像和标签,并移动到指定设备(GPU/CPU)

optimizer.zero_grad()        # 梯度清零
outputs = model(images)      # 正向传播
loss = criterion(outputs, labels)  # 计算损失
loss.backward()              # 反向传播,计算梯度
optimizer.step()             # 更新权重

这一套就是典型的 PyTorch 训练步骤:
清零 → 前向传播 → 计算损失 → 反向传播 → 更新参数

running_loss += loss.item()
  • .item() 获取损失的浮点数,累加总损失


        6️⃣ 每个 epoch 打印平均损失

print(f"Epoch {epoch + 1}, Loss: {running_loss / len(train_loader):.4f}")
  • 将这个 epoch 的平均损失打印出来,方便观察训练情况。


        7️⃣ 返回训练好的模型

return model

📈 总结:

        代码实现了一个完整的 PyTorch 训练流程,是一个标准、干净、可复用的模型训练函数模板。适合用于任何图像分类模型(只要输出是 logits 且标签是整数类别 ID)。用于训练神经网络模型的函数 train_model的完整代码如下:

import torch
import torch.nn as nn
def train_model(model, train_loader, epochs=10):device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = model.to(device)criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)model.train()for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:images, labels = images.to(device), labels.to(device)optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch + 1}, Loss: {running_loss / len(train_loader):.4f}")return model

        4.编写预测函数

        实现了一个 数字验证码图像的预测函数 predict,其核心目的是:

  加载训练好的模型参数,读取指定图像,并输出预测的数字类别(0~9 之间的某个整数)。


        ✅ 总体功能:

def predict(model_path, image_path) -> int

        输入:

  • model_path: 训练好的模型参数文件(如 .pt.pth

  • image_path: 一张数字图像文件路径(如 "7abc.png"

        输出:

  • 返回图像中预测的数字(0~9 的某一个 int


        🔍 编写逻辑

        1️⃣ 模型初始化与加载

model = DigitNet()
model.load_state_dict(torch.load(model_path))
  • 初始化模型结构(必须和训练时一致)

  • 加载模型参数(state_dict 是权重的字典)


        2️⃣ 设备选择

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
  • 自动判断是否使用 GPU

  • 模型放到 GPU 或 CPU 上


        3️⃣ 图像预处理

transform = transforms.Compose([transforms.Resize((15, 13)),       # 调整图像大小为 (高15, 宽13)transforms.Grayscale(),            # 转为灰度图(1 通道)transforms.ToTensor(),             # 转为 Tensor 格式并归一化到 [0,1]transforms.Normalize(0.5, 0.5)     # 再归一化到 [-1,1],与训练一致!
])
  • 这一部分非常重要,必须和训练时的图像处理完全一致,否则模型可能无法正确识别。


        4️⃣ 模型评估模式

model.eval()
  • 设置为推理模式:关闭 Dropout、BatchNorm 的训练行为。


        5️⃣ 不计算梯度(节省资源)

with torch.no_grad():
  • 在推理阶段关闭反向传播计算,提升效率、节省内存。


        6️⃣ 图像读取与转换

image = Image.open(image_path)
tensor = transform(image).unsqueeze(0).to(device)
  • 使用 PIL 打开图像

  • 用 transform 进行预处理

  • unsqueeze(0):在第 0 维添加 batch 维度(变为 [1, 1, 15, 13]


        7️⃣ 模型推理与预测

output = model(tensor)                # 模型输出 logits (形如 [1,10])
predicted = torch.argmax(output, dim=1)  # 取最大值所在的类别索引
  • torch.argmax(..., dim=1) 表示在分类维度上找到概率最大的类别(0~9)


        8️⃣ 返回预测结果

return predicted.item()
  • .item() 将张量转换为 Python 的 int 类型


        📌 示例使用:

digit = predict("model.pth", "data/7abc.png")
print("预测结果为:", digit)

        ✅ 总结:

模块作用
DigitNet()构造模型结构
load_state_dict()加载训练权重
transform()图像预处理
model.eval()设置为评估模式
torch.no_grad()关闭梯度计算
argmax(output)选出最可能的数字标签

        预测函数 predict的完整代码如下:

import torch
#from mymodel import DigitNet
from model import DigitNet
from torchvision import transforms  # 正确导入图像变换模块
from PIL import Image               # 标准PIL导入方式#切割好的单数字的验证码图像预测
def predict(model_path, image_path):# 初始化模型结构model = DigitNet()# 加载参数model.load_state_dict(torch.load(model_path))# 设备选择device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(device)model = model.to(device)# 图像预处理(需与训练时一致)transform = transforms.Compose([transforms.Resize((15, 13)),transforms.Grayscale(),transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])# 执行预测model.eval()with torch.no_grad():image = Image.open(image_path)tensor = transform(image).unsqueeze(0).to(device)output = model(tensor)predicted = torch.argmax(output, dim=1)return predicted.item()

        步骤四:训练、模型保存和预测

        直接上代码:

from model import DigitNet
from train import train_model
from dataset import  train_loader
import torch
from predict import predictif __name__ == "__main__":# 初始化模型model = DigitNet()# 训练模型trained_model = train_model(model, train_loader, epochs=20)# 保存模型torch.save(trained_model.state_dict(), "digit_model1.pth")# 预测#print(predict('digit_model1.pth', "test/3_39031.jpeg"))

        运行以上代码,即进行模型训练,训练完成后保存模型 digit_model1.pth 。预测的时候则注释掉模型训练和保存模型的代码,取消后面预测代码的注释。

        声明

        本例附上完整的项目代码和数据,仅供学习和交流,不可作为其它用途。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/news/908947.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

常用ADB命令

ADB:Android Debug Bridge,Android 调试桥。 是一个命令行工具,主要用于在开发过程中实现计算机与Android设备之间的通信。 ADB工具允许开发者执行一系列调试操作,如安装应用、管理应用的生命周期、读取日志数据、执行shell命令等…

JavaScript BOM 详细介绍

JavaScript BOM (Browser Object Model) 详细介绍 BOM (Browser Object Model) 是浏览器对象模型,它提供了与浏览器窗口交互的对象和方法,允许 JavaScript 与浏览器"对话"。 1. BOM 概述 BOM 的核心是 window 对象,它代表浏览器…

DeepSeek生成流程图

通过DeepSeek生成代码 请用 Mermaid 语法生成一个电商订单处理流程的流程图,流程包括用户下单、订单审核、库存检查、生成发货单、发货以及各个环节可能出现的分支情况,如订单审核不通过返回修改,库存不足通知用户等 打开在线绘图 Flowchart…

WebGL与Three.js:从基础到应用的关系与原理解析

WebGL 和 Three.js 是现代网页中实现 3D 图形和动画的两大关键技术。尽管它们有着紧密的关系,但它们在功能和使用场景上有所不同。简单来说,WebGL 是一个底层图形库,提供了对计算机 GPU 的直接访问,而 Three.js 则是建立在 WebGL …

Spring Boot消息系统开发指南

消息系统基础概念 消息系统作为分布式架构的核心组件,实现了不同系统模块间的高效通信机制。其应用场景从即时通讯软件延伸至企业级应用集成,形成了现代软件架构中不可或缺的基础设施。 通信模式本质特征 同步通信要求收发双方必须同时在线交互&#…

JavaWeb笔记

六、MVC模式 ✅ Model(模型) 职责:处理数据和业务逻辑。 负责数据的存储、读取和操作。 包含业务规则和逻辑。 ✅ View(视图) 职责:展示界面和接收用户输入。 把数据以可视化的形式呈现给用户。 不处…

解决启动SpringBoot是报错Command line is too long的问题

文章目录 错误全称原因解决方法(一图到底) 错误全称 在启动springBoot项目时,会报错: Error running Application. Command line is too long. Shorten the command line via JAR manifest 原因 命令行太长的原因导致SpringBoot和…

DAY47打卡

DAY 47 注意力热图可视化 昨天代码中注意力热图的部分顺移至今天 知识点回顾:热力图(代码学习在day46天) 作业:对比不同卷积层热图可视化的结果 通道注意力热图的代码整体结构与核心功能 数据处理:对 CIFAR-10 数据集进…

Java在word中指定位置插入图片。

Java使用(Poi-tl) 在word(docx)中指定位置插入图片 Poi-tl 简介Maven 依赖配置Poi-tl 实现原理与步骤1. 模板标签规范2.完整实现代码3.效果展示 Poi-tl 简介 Poi-tl 是基于 Apache POI 的 Java 开源文档处理库,专注于…

迁移科技:破解纸箱拆垛场景的自动化升级密码

一、当传统拆垛遇上智能视觉:一场效率革命的必然选择 在汽车制造基地的物流中转区,每天有超过2万件零部件纸箱需要完成拆垛分拣。传统人工拆垛面临三大挑战: 效率瓶颈:熟练工人每小时处理量不超过200箱安全隐患:重型…

redis和redission的区别

Redis 和 Redisson 是两个密切相关但又本质不同的技术,它们扮演着完全不同的角色: Redis: 内存数据库/数据结构存储 本质: 它是一个开源的、高性能的、基于内存的 键值存储数据库。它也可以将数据持久化到磁盘。 核心功能: 提供丰…

AIStarter 4.0 苹果版体验评测|轻松部署 ComfyUI 与 DeepSeek 的 AI 工具箱

最近在测试一款名为 AIStarter 4.0 的 AI 工具管理平台,主要用于在 Mac 系统上快速部署各类开源 AI 项目,如 ComfyUI 和 DeepSeek ,非常适合开发者、设计师及 AI 入门者使用。 通过简单的拖拽操作即可完成安装,支持普通下载与网盘…

ArcGIS Pro 3.4 二次开发 - 图形图层

环境:ArcGIS Pro SDK 3.4 + .NET 8 文章目录 图形图层1.1 创建图形图层1.2 访问GraphicsLayer1.3 复制图形元素1.4 移除图形元素2 创建图形元素2.1 使用CIMGraphic创建点图形元素2.2 使用CIMGraphic创建线图元素2.3 使用 CIMGraphic 的多边形图形元素2.4 使用CIMGraphic创建多…

《广度优先搜索》题集

1、模板题集 聚合一块 2、课内题集 寻找图中是否存在路径 钥匙和房间 受限条件下可到达节点的数目 3、课后题集 最少操作数 社交网络新来的朋友 Ignatius and the Princess I Collect More Jewels Gap Nightmare Remainder Ferry Loading III 连连看 诡异的楼梯 Open the …

界面组件DevExpress WPF中文教程:Grid - 如何获取行句柄?

DevExpress WPF拥有120个控件和库,将帮助您交付满足甚至超出企业需求的高性能业务应用程序。通过DevExpress WPF能创建有着强大互动功能的XAML基础应用程序,这些应用程序专注于当代客户的需求和构建未来新一代支持触摸的解决方案。 无论是Office办公软件…

零跑汽车5月交付45067台车,同比增长超148%

零跑汽车在5月交付新车45,067辆,同比大增148%,连续5个月实现单月交付量增长,稳居新势力交付量第一位置。今年1-5月,零跑累计交付新车达173,658辆,展现出强劲的市场竞争力和产品实力。 根据Q1财报,零跑不仅营…

netty中的粘包问题详解

一起来学netty 一、粘包问题的本质二、粘包问题的成因三、Netty中的解决方案1. 固定长度解码器(FixedLengthFrameDecoder)2. 行分隔符解码器(LineBasedFrameDecoder)3. 分隔符解码器(DelimiterBasedFrameDecoder)4. 长度字段解码器(LengthFieldBasedFrameDecoder)四、解…

【基础算法】枚举(普通枚举、二进制枚举)

文章目录 一、普通枚举1. 铺地毯(1) 解题思路(2) 代码实现 2. 回文日期(1) 解题思路思路一:暴力枚举思路二:枚举年份思路三:枚举月日 (2) 代码实现 3. 扫雷(2) 解题思路(2) 代码实现 二、二进制枚举1. 子集(1) 解题思路(2) 代码实现 2. 费解的…

利用ngx_stream_return_module构建简易 TCP/UDP 响应网关

一、模块概述 ngx_stream_return_module 提供了一个极简的指令&#xff1a; return <value>;在收到客户端连接后&#xff0c;立即将 <value> 写回并关闭连接。<value> 支持内嵌文本和内置变量&#xff08;如 $time_iso8601、$remote_addr 等&#xff09;&a…

Java如何权衡是使用无序的数组还是有序的数组

在 Java 中,选择有序数组还是无序数组取决于具体场景的性能需求与操作特点。以下是关键权衡因素及决策指南: ⚖️ 核心权衡维度 维度有序数组无序数组查询性能二分查找 O(log n) ✅线性扫描 O(n) ❌插入/删除需移位维护顺序 O(n) ❌直接操作尾部 O(1) ✅内存开销与无序数组相…