CNN手写数字识别/全套源码+注释可直接运行

数据集选择:

MNIST数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST)。训练集(training set)由来自250个不同人手写的数字构成,其中50%是高中学生,50%来自人口普查局(the Census Bureau)的工作人员。测试集(test set)也是同样比例的手写数字数据,但保证了测试集和训练集的作者集不相交。MNIST数据集一共有7万张图片,其中6万张是训练集,1万张是测试集。每张图片是28 × 28 28\times 2828×28的0 − 9 0-90−9的手写数字图片组成。每个图片是黑底白字的形式,黑底用0表示,白字用0-1之间的浮点数表示,越接近1,颜色越白。

图片的标签以一维数组的one-hot编码形式给出:

[ 0 , 0 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 ]
每个元素表示图片对应的数字出现的概率,显然,该向量标签表示的是数字5。

MNIST数据集下载地址是http://yann.lecun.com/exdb/mnist/,它包含了4个部分:

训练数据集:train-images-idx3-ubyte.gz (9.45 MB,包含60,000个样本)。
训练数据集标签:train-labels-idx1-ubyte.gz(28.2 KB,包含60,000个标签)。
测试数据集:t10k-images-idx3-ubyte.gz(1.57 MB ,包含10,000个样本)。
测试数据集标签:t10k-labels-idx1-ubyte.gz(4.43 KB,包含10,000个样本的标签)。

废话不多说,首先看成果:
在这里插入图片描述
本项目请按照以下架构搭建:
在这里插入图片描述
下面是各个文件的python代码:
cnn_model.py:

模型架构和训练

# 导入必要的库
import torch
import torch.nn.functional as f  # 包含常用激活函数和操作
import torch.optim as optim  # 优化算法模块
from DataSet.mnist_set import mnist_set  # 自定义MNIST数据集加载器# 定义神经网络模型
class Net(torch.nn.Module):def __init__(self):super(Net, self).__init__()# 定义第一个卷积层:输入通道1(灰度图),输出通道10,卷积核5x5self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)# 定义第二个卷积层:输入通道10,输出通道20,卷积核5x5self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)# 定义最大池化层,窗口大小2x2,用于下采样self.pooling = torch.nn.MaxPool2d(2)# 全连接层:输入320个特征(由图像尺寸计算得到),输出10个类别(MNIST有0-9十个数字)self.fc = torch.nn.Linear(320, 10)def forward(self, x):batch_size = x.size(0)  # 获取当前批次大小# 第一层卷积 -> 池化 -> ReLU激活x = f.relu(self.pooling(self.conv1(x)))# 第二层卷积 -> 池化 -> ReLU激活x = f.relu(self.pooling(self.conv2(x)))# 将四维张量展平为二维:[batch_size, channels*width*height]x = x.view(batch_size, -1)  # -1表示自动计算该维度大小# 通过全连接层得到最终输出(未使用softmax,因为CrossEntropyLoss会自动处理)x = self.fc(x)return x# 创建模型实例
model = Net()
# 检测GPU可用性,并设置设备(GPU优先)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("是否使用GPU", torch.cuda.is_available())
model.to(device)  # 将模型转移到选定的设备(GPU/CPU)# 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()  # 交叉熵损失(适用于分类任务)
optimizer = optim.SGD(  # 随机梯度下降优化器model.parameters(),  # 需要优化的模型参数lr=0.01,  # 学习率momentum=0.5  # 动量参数,加速收敛
)def train(epoch, train_loader):""" 模型训练函数:param epoch: 当前训练轮次:param train_loader: 训练数据加载器"""running_loss = 0.0  # 累计损失值# 遍历训练数据(enumerate自动生成批次索引)for batch_idx, data in enumerate(train_loader, 0):inputs, target = data  # 解包数据(输入图像,目标标签)inputs, target = inputs.to(device), target.to(device)  # 数据转移至设备optimizer.zero_grad()  # 清空梯度(防止梯度累积)outputs = model(inputs)  # 前向传播loss = criterion(outputs, target)  # 计算损失loss.backward()  # 反向传播计算梯度optimizer.step()  # 更新模型参数running_loss += loss.item()  # 累加损失值# 每300个batch打印一次训练状态(300是任意选择的打印频率)if batch_idx % 300 == 299:print('[%d, %.5d] loss: %.3f' %(epoch + 1, batch_idx + 1, running_loss / 2000))running_loss = 0.0  # 重置累计损失def test(test_loader):""" 模型测试函数:param test_loader: 测试数据加载器"""correct = 0  # 正确预测数total = 0  # 总样本数with torch.no_grad():  # 禁用梯度计算(节省内存,加速计算)for data in test_loader:inputs, target = datainputs, target = inputs.to(device), target.to(device)outputs = model(inputs)# 获取预测结果(返回最大值和对应索引,这里取索引即类别)_, predicted = torch.max(outputs.data, dim=1)total += target.size(0)  # 累加批次样本总数correct += (predicted == target).sum().item()  # 统计正确预测数# 打印测试准确率(正确数/总数)print('Accuracy on test set: %d %% [%d/%d]' %(100 * correct / total, correct, total))if __name__ == '__main__':# 加载数据集train_loader, test_loader = mnist_set()# 训练循环for epoch in range(10):train(epoch, train_loader)test(test_loader)  # 每个epoch后测试# 训练完成后保存模型参数model_path = 'mnist_model.pth'torch.save(model.state_dict(), model_path)print(f'\n模型参数已保存至:{model_path}')# 初始化新模型实例loaded_model = Net().to(device)# 加载保存的权重loaded_model.load_state_dict(torch.load(model_path))print('\n模型参数加载验证完成')

model_use.py:

从数据集挑十张图片进行预测,使用保存的模型

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as f  # 包含常用激活函数和操作
import randomdef visualize_predictions(model, dataset, num_images=10):"""可视化模型预测结果Args:model: 加载好的模型dataset: 数据集对象(测试集)num_images: 需要可视化的图片数量"""# 设置为评估模式(影响Dropout和BatchNorm等层的计算)model.eval()device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 随机选择图片索引indices = random.sample(range(len(dataset)), num_images)# 创建画布fig, axes = plt.subplots(2, 5, figsize=(30, 12))plt.subplots_adjust(hspace=3, wspace=2)  # 调整子图间距for idx, ax in enumerate(axes.flat):# 获取数据image, true_label = dataset[idx]original_image = image.numpy().squeeze()  # 转换为numpy数组并去除通道维度# 预处理:添加批次维度并转移到设备image = image.unsqueeze(0).to(device)  # 形状从 [1,28,28] -> [1,1,28,28]# 预测with torch.no_grad():output = model(image)probabilities = f.softmax(output, dim=1)predicted_prob, predicted_label = torch.max(probabilities, 1)# 可视化设置ax.imshow(original_image, cmap='gray')ax.set_xticks([])ax.set_yticks([])# 标题显示预测结果(红色表示错误预测,蓝色表示正确)color = 'blue' if predicted_label == true_label else 'red'ax.set_title(f'Pred: {predicted_label.item()}' +f'True: {true_label}' +f'Prob: {predicted_prob.item():.1%}',color=color)plt.show()

mnist_set.py:

提供数据集的下载和导入,没有自动下载哦

from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader# 准备数据集
def mnist_set():""":param::return: train_loader, test_loader"""batch_size = 64transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])train_dataset = datasets.MNIST(root='../dataset/mnist/',train=True,download=True,transform=transform)train_loader = DataLoader(train_dataset,shuffle=True,batch_size=batch_size)test_dataset = datasets.MNIST(root='../dataset/mnist',train=False,download=True,transform=transform)test_loader = DataLoader(test_dataset,shuffle=False,batch_size=batch_size)return train_loader, test_loader

main.py:
在训练并保存模型后,调用保存的模型来进行手写预测

注意cv2库实际上要下载opencv-python

Python安装cv2(OpenCV)的终极指南:告别pip install cv2的坑!-CSDN博客

import cv2
import numpy as np
import torch
import timefrom CnnModel import Netclass DrawingApp:def __init__(self):# 窗口参数self.win_name = "MNIST Drawing Pad"self.win_size = (800, 400)self.pad_pos = (50, 80)  # 书写区域位置self.pad_size = (280, 280)self.preview_pos = (400, 80)# 初始化黑底画布(0=黑色,255=白色)self.img = np.zeros((self.pad_size[1], self.pad_size[0]), np.uint8)self.processed_img = np.zeros((28, 28), np.uint8)# 创建窗口cv2.namedWindow(self.win_name, cv2.WINDOW_NORMAL)cv2.resizeWindow(self.win_name, *self.win_size)cv2.setMouseCallback(self.win_name, self.mouse_handler)# 加载模型self.model = Net()self.model.load_state_dict(torch.load('CnnModel/mnist_model.pth'))self.model.eval()# 预测参数self.last_predict = {"pred": -1, "conf": 0.0}self.last_predict_time = 0def mouse_handler(self, event, x, y, flags, param):pad_x = x - self.pad_pos[0]pad_y = y - self.pad_pos[1]if (0 <= pad_x < self.pad_size[0]) and (0 <= pad_y < self.pad_size[1]):if event == cv2.EVENT_LBUTTONDOWN:self.drawing = Trueself.last_point = (pad_x, pad_y)elif event == cv2.EVENT_MOUSEMOVE and self.drawing:# 用白色(255)绘制线条cv2.line(self.img, self.last_point, (pad_x, pad_y), 255, 15)self.last_point = (pad_x, pad_y)elif event == cv2.EVENT_LBUTTONUP:self.drawing = Falsecv2.line(self.img, self.last_point, (pad_x, pad_y), 255, 15)else:self.drawing = Falsedef preprocess(self):"""预处理(保持黑底白字)"""resized = cv2.resize(self.img, (28, 28))# 直接归一化,保持黑底白字tensor_img = torch.from_numpy(resized).float() / 255.0# 存储处理后的图像用于显示self.processed_img = resizedreturn tensor_img.unsqueeze(0).unsqueeze(0)def update_ui(self):# 创建黑底背景canvas = np.zeros((self.win_size[1], self.win_size[0], 3), dtype=np.uint8)# 绘制书写区域边框(浅灰色)cv2.rectangle(canvas, self.pad_pos,(self.pad_pos[0] + self.pad_size[0], self.pad_pos[1] + self.pad_size[1]),(200, 200, 200), 2)# 嵌入书写内容(直接显示白字)canvas[self.pad_pos[1]:self.pad_pos[1] + self.pad_size[1],self.pad_pos[0]:self.pad_pos[0] + self.pad_size[0]] = cv2.cvtColor(self.img, cv2.COLOR_GRAY2BGR)# 显示预处理画面preview_size = 140preview_img = cv2.resize(self.processed_img, (preview_size, preview_size),interpolation=cv2.INTER_NEAREST)# 转换为彩色显示preview_display = cv2.cvtColor(preview_img, cv2.COLOR_GRAY2BGR)# 绘制预处理框(浅灰色)cv2.rectangle(canvas, self.preview_pos,(self.preview_pos[0] + preview_size, self.preview_pos[1] + preview_size),(200, 200, 200), 2)canvas[self.preview_pos[1]:self.preview_pos[1] + preview_size,self.preview_pos[0]:self.preview_pos[0] + preview_size] = preview_display# 添加文字说明(白色)cv2.putText(canvas, "Model Input (28x28)",(self.preview_pos[0], self.preview_pos[1] - 10),cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)# 显示预测结果(绿色文字)result_text = f"Prediction: {self.last_predict['pred']}" if self.last_predict['pred'] != -1 else "Draw a digit"conf_text = f"Confidence: {self.last_predict['conf']:.1f}%"cv2.putText(canvas, result_text, (20, 40),cv2.FONT_HERSHEY_DUPLEX, 0.9, (0, 255, 0), 2)cv2.putText(canvas, conf_text, (20, 80),cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 200, 200), 2)# 帮助文字(白色)help_text = "[SPACE] Clear  [ESC] Exit"cv2.putText(canvas, help_text, (self.win_size[0] - 250, self.win_size[1] - 20),cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)return canvasdef run(self):while True:# 自动预测if time.time() - self.last_predict_time > 0.5 and np.any(self.img != 0):processed_tensor = self.preprocess()with torch.no_grad():output = self.model(processed_tensor)prob, pred = torch.max(torch.nn.functional.softmax(output, 1), 1)self.last_predict = {"pred": pred.item(), "conf": prob.item() * 100}self.last_predict_time = time.time()# 更新界面display = self.update_ui()cv2.imshow(self.win_name, display)# 处理窗口变化new_size = cv2.getWindowImageRect(self.win_name)[2:]if new_size != self.win_size:self.win_size = new_sizeself.pad_size = (min(300, self.win_size[0] // 2 - 100), min(300, self.win_size[1] - 100))self.img = cv2.resize(self.img, self.pad_size)# 按键处理key = cv2.waitKey(1)if key == 27:breakelif key == 32:  # 空格键清除self.img = np.zeros((self.pad_size[1], self.pad_size[0]), np.uint8)self.processed_img = np.zeros((28, 28), np.uint8)self.last_predict = {"pred": -1, "conf": 0.0}if __name__ == "__main__":app = DrawingApp()app.run()cv2.destroyAllWindows()

最后如果报错的话注意路径即可,正常情况是能直接运行的,因为使用的相对路径
end

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/pingmian/82428.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

探秘谷歌Gemini:开启人工智能新纪元

一、引言 在人工智能的浩瀚星空中&#xff0c;每一次重大模型的发布都宛如一颗璀璨新星闪耀登场&#xff0c;而谷歌 Gemini 的亮相&#xff0c;无疑是其中最为耀眼的时刻之一。它的出现&#xff0c;犹如在 AI 领域投下了一颗重磅炸弹&#xff0c;引发了全球范围内的广泛关注与热…

小白场成长之路-计算机网络(三)

文章目录 一、网络参数配置1.图形化配置2.命令行配置2.1、ifconfig命令2.2ifup和ifdown子接口配置 2.3 多ip地址配置2.4子接口配置 总结 一、网络参数配置 1.图形化配置 NetworkManager&#xff0c;Linux7系统中&#xff0c;一般建议停止该管理方式&#xff1b;Linux8以上操作…

WireShark网络抓包—详细教程

本文仅用于技术研究&#xff0c;禁止用于非法用途。 Wireshark入门指南&#xff1a;从零开始掌握网络抓包分析 一、Wireshark是什么&#xff1f; Wireshark 是全球最受欢迎的开源网络协议分析工具&#xff0c;被广泛应用于网络故障排查、协议学习、网络安全分析等领域。它支…

区块链DApp的开发技术方案

区块链DApp开发技术方案&#xff1a;架构设计与实践指南 引言&#xff1a;DApp的技术革新与生态价值 区块链技术的去中心化特性与智能合约的自动化执行能力&#xff0c;推动DApp&#xff08;去中心化应用&#xff09;成为Web3.0的核心载体。截至2025年&#xff0c;全球DApp用…

Linux(3)——基础开发工具

目录 一、软件包管理器——yum 1.Linux下安装程序的方式 2.什么是yum 3.查找软件包 4.安装软件 5.本地与服务器端进行文件互传 6.卸载软件 二、Linux的编辑器——vim 1.基本概念 2.vim下各个模式之间的切换 3.vim在命令行模式下的命令汇总 4.vim在底行模式下的命令…

大数据学习(121)-sql重点问题

&#x1f34b;&#x1f34b;大数据学习&#x1f34b;&#x1f34b; &#x1f525;系列专栏&#xff1a; &#x1f451;哲学语录: 用力所能及&#xff0c;改变世界。 &#x1f496;如果觉得博主的文章还不错的话&#xff0c;请点赞&#x1f44d;收藏⭐️留言&#x1f4dd;支持一…

【QT】QString和QStringList去掉空格的方法总结

目录 一、QString去掉空格 1. 移除字符串首尾的空格&#xff08;trimmed&#xff09; 2. 移除字符串中的所有空格&#xff08;remove&#xff09; 3. 仅移除左侧&#xff08;开头&#xff09;或右侧&#xff08;结尾&#xff09;空格 4. 替换多个连续空格为单个空格 5. 移…

电脑 IP 地址修改工具,轻松实现异地登陆

在互联网时代&#xff0c;异地登陆需求日益频繁 —— 访问区域限制内容、跨区协作、优化游戏体验等场景&#xff0c;都需要通过修改 IP 地址实现。 一、IP 地址基础认知 IP 地址是设备的网络身份标识&#xff0c;不同地区分配不同 IP 段。通过修改 IP&#xff0c;可模拟目标地…

[BUG]Debian/Linux操作系统中 安装 curl等软件显示无候选安装(E: 软件包 curl 没有可安装候选)

本文内容组织形式 问题描述失效原因解决方案首先修改源列表为国内确认当前系统的版本Debian 11 (Bullseye)Debian 12 (Bookworm) 执行系统升级更新系统重新安装curl 结语 问题描述 日期&#xff1a;20250526 操作系统&#xff1a; debian darkchunkdebian:/home$ sudo apt i…

leetcode hot100刷题日记——12.反转链表

解答&#xff1a; /*** Definition for singly-linked list.* struct ListNode {* int val;* ListNode *next;* ListNode() : val(0), next(nullptr) {}* ListNode(int x) : val(x), next(nullptr) {}* ListNode(int x, ListNode *next) : val(x), next(n…

JavaSE核心知识点04工具04-01(JDK21)

&#x1f91f;致敬读者 &#x1f7e9;感谢阅读&#x1f7e6;笑口常开&#x1f7ea;生日快乐⬛早点睡觉 &#x1f4d8;博主相关 &#x1f7e7;博主信息&#x1f7e8;博客首页&#x1f7eb;专栏推荐&#x1f7e5;活动信息 文章目录 JavaSE核心知识点04工具04-01&#xff08;JD…

数据库入门:以商品订单系统为例

数据库入门&#xff1a;以商品订单系统为例 一、前言 数据库是现代软件开发中不可或缺的基础&#xff0c;掌握数据库的基本概念和操作&#xff0c;是每个开发者的必经之路。本文将以“商品-品牌-客户-订单-订单项”为例&#xff0c;带你快速入门数据库的核心知识和基本操作。…

UE失落方舟特效学习 笔记01

通过法线扭曲贴图 Begin Object Class/Script/UnrealEd.MaterialGraphNode Name"MaterialGraphNode_0" ExportPath"/Script/UnrealEd.MaterialGraphNode/Engine/Transient.M_RadialUV_01:MaterialGraph_0.MaterialGraphNode_0"Begin Object Class/Script/E…

跨境支付风控失效?用代理 IP 构建「地域 - 设备 - 行为」三维防护网

针对跨境支付风控失效问题&#xff0c;结合代理IP技术构建「地域-设备-行为」三维防护网是当前最有效的解决方案。以下是基于最新实践的技术路径与策略指南&#xff1a; 一、地域维度&#xff1a;IP地理特征精准匹配 IP属地真实性验证 优先选择住宅代理IP&#xff08;Residenti…

AI的“软肋”:架构设计与业务分析的壁垒

尽管人工智能&#xff08;AI&#xff09;在代码生成、数据分析等方面取得了显著进展&#xff0c;但在架构设计和业务分析的核心领域&#xff0c;人类的智慧和经验仍然是不可替代的。这些领域往往涉及高度的抽象思维、战略远见、对复杂商业逻辑的深刻理解以及在模糊不清的环境中…

【Redis实战篇】基于Redis的功能实现附近商铺查询(Geo),用户签到与统计(Bitmap),网站UV统计(HyperLogLog)

文章目录 附近商铺GEOSEARCH 实现语法参数解释 GEORADIUS 实现基本语法参数详解必选参数可选参数参数详解必选参数 代码实现 用户签到BitmapRedis 中 Bitmap 基本操作1. 设置位值2. 获取位值3. 统计位值为 1 的数量4. 位图运算 Spring Data Redis 中操作 Bitmap1. 操作示例(1) …

【C++高阶一】二叉搜索树

【C高阶一】二叉搜索树剖析 1.什么是二叉搜索树2.二叉搜索树非递归实现2.1插入2.2删除2.2.1删除分析一2.2.2删除分析二 2.3查找 3.二叉搜索树递归实现3.1插入3.2删除3.3查找 4.完整代码 1.什么是二叉搜索树 任何一个节点&#xff0c;他的左子树的所有节点都比他小&#xff0c;右…

前端面试热门知识点总结

URL从输入到页面展示的过程 版本1 1.用户在浏览器的地址栏输入访问的URL地址。浏览器会先根据这个URL查看浏览器缓存-系统缓存-路由器缓存&#xff0c;若缓存中有&#xff0c;直接跳到第6步操作&#xff0c;若没有&#xff0c;则按照下面的步骤进行操作。 2.浏览器根据输入的UR…

Swagger | 解决Springboot2.x/3.x不兼容和依赖报错等问题

目录 不兼容报错提醒 1. 修改Spring Boot版本 2. 修改application.yml配置文件 3. 使用其他替代方案 依赖兼容 配置 Yaml 文件 依赖报错提醒 解决方法 1. 选择一个库 2. 移除springfox依赖 3. 添加springdoc依赖 4. 配置springdoc 5. 清理项目 6. 启动项目 示例代…

C++默认构造函数、普通构造函数、拷贝构造、移动构造、委托构造及析构函数深度解析

目录 一、默认构造函数&#xff08;Default Constructor&#xff09;二、普通构造函数&#xff08;General Constructor&#xff09;三、拷贝构造函数&#xff08;Copy Constructor&#xff09;四、移动构造函数&#xff08;Move Constructor&#xff0c;C11&#xff09;五、委…