深度学习入门day4--手写数字识别初探

鱼书提供的代码可以在github找到。源码地址

环境配置部分可以看前面几篇博客,还是用Anaconda,运行下面代码,可以看哪个库缺失。

import importlib
import numpy as np
deps = {"torch": "torch","torchvision": "torchvision","timm": "timm","scipy": "scipy","matplotlib": "matplotlib","tensorboard": "tensorboard","open_clip_torch": "open_clip_torch","sklearn": "sklearn","wandb": "wandb","tqdm": "tqdm","fairscale": "fairscale","sentencepiece": "sentencepiece","PIL": "PIL","cv2": "opencv-python","gradio": "gradio","autofaiss": "autofaiss","diffusers": "diffusers","plotly": "plotly","easydict": "easydict","huggingface_hub": "huggingface_hub","transformers": "transformers","open3d": "open3d","openai": "openai","knn_cuda": "knn_cuda",
}for k, v in deps.items():try:importlib.import_module(k)print(f"[✓] {v}")except ImportError:print(f"[✗] {v}")

一、MNIST数据集

由0到9的数字图像构成,训练图像约6万张,测试图片约1万张。图片格式28×28的灰度图像(单通道)。各个像素取值在0~255之间

  • mnist_show.py
# coding: utf-8
import sys, os
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
import numpy as np
from dataset.mnist import load_mnist
from PIL import Image
def img_show(img):pil_img = Image.fromarray(np.uint8(img))pil_img.show()#训练图像,训练标签,测试图像,测试标签
(x_train, t_train), (x_test, t_test) = load_mnist(flatten=True, normalize=False)
print(x_train.shape)
print(t_train.shape)
print(x_test.shape)
print(t_test.shape)
img = x_train[0]
label = t_train[0]
print(label)  # 5print(img.shape)  # (784,)
img = img.reshape(28, 28)  # 把图像的形状变为原来的尺寸
print(img.shape)  # (28, 28)
img_show(img)

其中,load_mnist这个函数在datatest目录下的mnist.py中定义,负责下载数据集,并将其转化为NumPy数字。函数全称是load_mnist(normalize=True, flatten=True, one_hot_label=False)

  • 第一个参数normalize表示是否将输入图像正规化为0.0~1.0的值。否则保持原来的0~255。
  • 第二个参数flatten表示是否展开输入图像(变为一维数组)。若为True,输入图像会保存为由784个元素构成的一维数组。若为False,则输入图像为1×28×28的三维数组1.
  • 第三个参数one_hot_label表示是否把标签保存为one-hot表示,one-hot是仅正确解释标签为1,其余皆为0的数组,类似[0,0,0,1,0,0]。若为False,仅仅表示为2,3,4这样的简单正确解标签,为True时,就是one-hot。

由于flatten=True时读入的图像是一列Numpy数组。比如这样

因此为了显示成功,使用img.reshape(28, 28) , 把图像的形状变为原来的尺寸

二、神经网络推理处理

其中,输入层有28×28=784个神经单元,输出层10个神经元(0-9这10个类别)。然后还包括2个隐藏层,第一个隐藏层有50个神经单元,第2个隐藏层100个神经单元。(神经元数目可以自己设置)

#头文件解释
import sys, os
#sys:提供与 Python 解释器相关的功能(如模块搜索路径 sys.path)。
#os:提供与操作系统交互的功能(如文件路径操作 os.pardir)。
sys.path.append(os.pardir)  # 为了导入父目录的文件而进行的设定
#将父目录(os.pardir,通常是 "..")添加到 Python 的模块搜索路径 sys.path 中。
#目的:为了让 Python 能够找到父目录中的模块(例如后续要导入的 dataset.mnist 和 common.functions)
import numpy as np
import pickle#内置模块,用于加载或保存模型/数据
from dataset.mnist import load_mnist#导入自定义模块中的函数,(/datatest/mnist.py)
  • 函数部分。
    get_data()负责加载 MNIST 数据集,并返回测试集的输入数据 x_test和标签 t_test。init_network()负责初始化网络读入保存在pickle中的sample_weight.pkl中学习到的权重参数。
    predict(network,x)负责前向传播,并返回预测概率。
def sigmoid(x):return 1 / (1 + np.exp(-x))
def softmax(x):return np.exp(x) / np.sum(np.exp(x), axis=0)def get_data():(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, flatten=True, one_hot_label=False)return x_test, t_testdef init_network():with open("sample_weight.pkl", 'rb') as f:network = pickle.load(f)return networkdef predict(network, x):W1, W2, W3 = network['W1'], network['W2'], network['W3']b1, b2, b3 = network['b1'], network['b2'], network['b3']a1 = np.dot(x, W1) + b1z1 = sigmoid(a1)a2 = np.dot(z1, W2) + b2z2 = sigmoid(a2)a3 = np.dot(z2, W3) + b3y = softmax(a3)return y
  • 主程序整体流程。首先获取数据集,x获取图像数据,t获取标签数据(Numpy数组形式)。然后使用for循环,逐一取出保存在x中的图像数据,使用predict()进行预测,会返回一个含有各个标签对应概率的Numpy数组,[0,1,0.2,0.03...],表示识别为“0”,“1”,“2”的概率。然后我们用np,argmax(x)取出数组中最大概率值对应的索引。然后我们比较神经网络所预测的标签和正确答案标签,把正确的概率作为识别精度。
x, t = get_data()          # 加载测试数据
network = init_network()   # 加载预训练的网络权重
accuracy_cnt = 0           # 初始化正确预测的计数器for i in range(len(x)):y = predict(network, x[i])  # 对第 i 个样本进行预测p = np.argmax(y)            # 取概率最高的类别作为预测结果if p == t[i]:               # 如果预测正确accuracy_cnt += 1       # 计数器 +1print("Accuracy:" + str(float(accuracy_cnt) / len(x)))  # 计算准确率

项目结构是这样的,因为权重是加载好的,运行一下neuralnet_mnist.py这个文件。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/912768.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/912768.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STM32中定时器配置,HAL_Delay的原理,滴答定时器,微秒延时实现,PWM,呼吸灯

目录 定时器基本定时功能实现 CubeMX设置 手动书写代码部分 定时器启动 实现溢出回调函数 HAL_Delay介绍 HAL_Delay实现原理 HAL_Delay的优点 HAL_Delay的缺点 利用滴答定时器(SysTick)实现微秒级延时 PWM PWM介绍 通用定时器中的重要寄存器 PWM中的捕获比较通道 …

飞牛NAS(fnOS)详细安装教程

以下是飞牛NAS(fnOS)的详细安装教程,结合官方指南和社区实践整理而成: 一、准备工作 硬件需求 8GB或更大容量的U盘(用于制作启动盘)待安装设备(支持x86架构的物理机或迷你主机,如天钡…

springboot 显示打印加载bean耗时工具类

一 spring的原生接口说明 1.1 接口说明 Aware是Spring框架提供的一组特殊接口,可以让Bean从Spring容器中拿到一些资源信息。 BeanFactoryAware:实现该接口,可以访问BeanFactory对象,从而获取Bean在容器中的相关信息。 Environm…

OpenGL空间站场景实现方案

OpenGL空间站场景实现方案 需求分析 根据任务要求,我需要完成一个基于Nehe OpenGL的空间站场景,实现以下功能: 完整的空间站场景建模(包含多个模型和纹理贴图)Phong光照模型实现(包含多种光源和材质效果)摄像机键盘控制交互功能解决方案设计 技术栈 C++编程语言OpenG…

基于昇腾310B4的YOLOv8目标检测推理

YOLOv8目标检测 om 模型推理 本篇博客将手把手教你如何将 YOLOv8 目标检测模型部署到华为昇腾 310B4 开发板上进行高效推理(其他昇腾开发版也可参考此流程)。 整个流程包括: 模型格式转换(ONNX → OM)昇腾推理环境配…

前端跨域问题解决Access to XMLHttpRequest at xxx from has been blocked by CORS policy

在前端开发中,跨域资源共享(CORS)是一个常见的问题。它涉及到浏览器安全机制,防止网页从一个域获取资源时被另一个域阻止。错误信息如“Access to XMLHttpRequest at xxx from origin has been blocked by CORS policy”是典型的跨…

[ linux-系统 ] 软硬链接与动静态库

软硬链接 介绍 软链接 通过下图可以看出软链接和原始文件是两个独立的文件,因为软链接有着自己的inode编号: 具有独立的 inode ,也有独立的数据块,它的数据块里面保存的是指向的文件的路径,公用 inode 硬链接 通过…

3D 商品展示与 AR 试戴能为珠宝行业带来一些便利?

对于珠宝行业而言,长久以来,如何让消费者在做出购买决策之前,便能真切且直观地领略到珠宝独一无二的魅力,始终是横亘在行业发展道路上的一道棘手难题。而 3D 互动营销的横空出世,恰似一道曙光,完美且精准地…

电子电气架构 --- SOVD功能简单介绍

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 简单,单纯,喜欢独处,独来独往,不易合同频过着接地气的生活,除了生存温饱问题之外,没有什么过多的欲望,表面看起来很高冷,内心热情,如果你身…

【Java编程动手学】 Java中的运算符全解析

文章目录 一、引言二、算术运算符1、基本概念2、具体运算符及示例 三、关系运算符1、基本概念2、具体运算符及示例 四、自增减运算符1、基本概念2、具体运算符及示例 五、逻辑运算符1、基本概念2、具体运算符及示例 六、位运算符1、基本概念2、具体运算符及示例 七、移位运算符…

【前端】1 小时实现 React 简历项目

近期更新完毕。仅包括核心代码 目录结构 yarn.lock保证开发者每次能下载到同版本依赖,一般不需要特别留意 package.json 是 Node.js 项目、前端项目、npm/yarn的配置文件。 Dockerfile 是用来 定义 Docker 镜像构建过程的文本文件。它是一份脚本,告诉 …

python中的pydantic是什么?

Pydantic 是 Python 中一个用于数据验证和设置管理的库,主要通过 Python 类型注解(Type Hints)来定义数据结构,并自动验证输入数据的合法性。它广泛应用于 API 开发(如 FastAPI)、配置管理、数据序列化等场…

腾讯云市场目前饱和度

首先我需要理解市场饱和度的概念。市场饱和度通常指一个产品或服务在潜在市场中的渗透程度,高饱和度意味着市场增长空间有限,低饱和度则表明还有较大发展潜力。 从搜索结果看,腾讯云目前在中国云服务市场排名第三,市场份额约为15%…

EDR、NDR、XDR工作原理和架构及区别

大家读完觉得有帮助记得关注和点赞!!! EDR、NDR、XDR是网络安全中关键的检测与响应技术,它们在覆盖范围、数据源和响应机制上有显著差异。以下是它们的工作原理和架构详解: --- ### 🔍 一、EDR&#xff0…

vue3 + luckysheet 实现在线编辑Excel

效果图奉上: 引入的依赖: "dependencies": {"types/jquery": "^3.5.32","types/xlsx": "^0.0.36","jquery": "^3.7.1","xlsx": "^0.18.5",}在index.html中…

Linux下MinIO分布式安装部署

文章目录 一、MinIO简单说明二、MinIO分布式安装部署1、关闭SELINUX2、开启防火墙2.1、关闭firewall:2.2、安装iptables防火墙 3、安装MinIO4、添加MinIO集群控制脚本4.1添加启动脚本4.2添加关闭脚本 5、MinIO控制台使用 一、MinIO简单说明 1、MinIO是一个轻量的对…

Codeforces Round 980 (Div. 2)

ABC 略 D 这个过程一定是由1向后跳的过程中穿插有几次向前一步一步走。直到跳到一个位置后再把前面所有没有走过的位置倒序走一遍。总分就等于最大位置的前缀和-前面所有起跳位置和。前缀和固定我们只需要求到每个位置的最小起跳和即可。对于这个向后跳和向前走的过程我们可以…

Langchain实现rag功能

RAG(检索增强生成)的核心是通过外部知识库增强大模型回答的准确性和针对性,其工作流程与优化策略如下: 一、RAG 核心流程 ‌知识库构建‌ ‌文档加载与分割‌:将非结构化文档(PDF、Markdown等)…

算法笔记上机训练实战指南刷题

算法笔记上机训练实战指南刷题记录 文章目录 算法笔记上机训练实战指南刷题记录模拟B1001 害死人不偿命的(3n1)猜想B1011 AB 和 CB1016 部分ABB1026 程序运行时间B1046划拳B1008数组元素循环右移问题B1012 数字分类B1018 锤子剪刀布A1042 Shuffling Machine 每天两题&#xff0…

MYSQL基础内容

一、介绍 1.不用数据库:使用IO流对数据进行管理 2.使用数据库:使用SQL语句对开发的数据进行管理,能储存上亿条数据 3.MYSQL: 是流行的关系型数据库管理系统之一,将数据保存在不同的数据表中,通过表与表之…