1. pytorch手写数字预测

1. pytorch手写数字预测

  • 1.背景
  • 2.准备数据集
  • 2.定义模型
  • 3.dataloader和训练
  • 4.训练模型
  • 5.测试模型
  • 6.保存模型

1.背景

因为自身的研究方向是多模态目标跟踪,突然对其他的视觉方向产生了兴趣,所以心血来潮的回到最经典的视觉任务手写数字预测上来,所以这份教程并不是一份非常详尽的教程,是在一部分pytorch,深度学习基础上的教程,如果需要的是非常保姆级的教程建议看别的文章

2.准备数据集

这里我才用了直接导torchvision中的dataset包来下载Mnist数据集,也算是一个非常经典的数据集了

# 导入数据集
from torchvision.datasets import MNIST
import torch# 设置随机种子
torch.manual_seed(3306)# 数据预处理
from torchvision import transforms
# 定义数据转换
transform = transforms.Compose([transforms.ToTensor(),  # 转换为 Tensortransforms.Normalize((0.1307,), (0.3081,))  # 标准化
])# 下载 MNIST 数据集
mnist_train = MNIST(root='./dataset_file/mnist_raw', train=True, download=True,transform=transform)
mnist_test = MNIST(root='./dataset_file/mnist_raw', train=False, download=True,transform=transform)
# 查看数据集大小
print(f"MNIST train dataset size: {len(mnist_train)}")
print(f"MNIST test dataset size: {len(mnist_test)}")

其中,MNIST()中的root代表的是数据集存放的位置,download代表是如果当前位置没有数据集是否需要下载。
transformer则是对数据的处理方式,我这里采用了简单地转成tensor和简单地标准化。

不过这样子下载下来的数据集是二进制格式的,无法直接查看图片,当然,如果你需要查看图片,也有办法。

# 查看图片
import matplotlib.pyplot as pltdef show_image(id):img, label = mnist_train[id]img = img.squeeze().numpy()  # 去掉通道维度print(img.shape)# print(img)plt.imshow(img, cmap='gray')plt.title(f"Label: {label}")plt.axis('off')plt.show()show_image(1)

效果
在这里插入图片描述

又或者你想要下载的数据集是图片格式,我这里也准备了代码

代码是在别人的基础上改的,其中数据集存放路径是dataset_dir,如果需要修改自行打印然后修改位置就好了。

#!/usr/bin/env python3
# -*- encoding utf-8 -*-'''
@File: save_mnist_to_jpg.py
@Date: 2024-08-23
@Author: KRISNAT
@Version: 0.0.0
@Email: ****
@Copyright: (C)Copyright 2024, KRISNAT
@Desc:1. 通过 torchvision.datasets.MNIST 下载、解压和读取 MNIST 数据集;2. 使用 PIL.Image.save 将 MNIST 数据集中的灰度图片以 JPEG 格式保存。
'''import sys, os
sys.path.insert(0, os.getcwd())from torchvision.datasets import MNIST
import PIL
from tqdm import tqdmif __name__ == "__main__":home_dir = os.path.abspath('.')root = os.path.abspath(os.path.join(home_dir, '../dataset_file'))print(root)# exit(0)# 图片保存路径dataset_dir = os.path.join(root, 'mnist_jpg')if not os.path.exists(dataset_dir):os.makedirs(dataset_dir)# 从网络上下载或从本地加载MNIST数据集# 训练集60K、测试集10K# torchvision.datasets.MNIST接口下载的数据一组元组# 每个元组的结构是: (PIL.Image.Image image model=L size=28x28, 标签数字 int)training_dataset = MNIST(root='mnist',train=True,download=True,)test_dataset = MNIST(root='mnist',train=False,download=True,)# 保存训练集图片with tqdm(total=len(training_dataset), ncols=150) as pro_bar:for idx, (X, y) in enumerate(training_dataset):f = dataset_dir + "/" + "training_" + str(idx) + \"_" + str(training_dataset[idx][1] ) + ".jpg"  # 文件路径training_dataset[idx][0].save(f)pro_bar.update(n=1)# 保存测试集图片with tqdm(total=len(test_dataset), ncols=150) as pro_bar:for idx, (X, y) in enumerate(test_dataset):f = dataset_dir + "/" + "test_" + str(idx) + \"_" + str(test_dataset[idx][1] ) + ".jpg"  # 文件路径test_dataset[idx][0].save(f)pro_bar.update(n=1)

2.定义模型

这里我准备了两个模型,一个MLP模型和一个简单地CNN模型,其中MLP模型参数量1M,CNN模型参数量大概8M,当然这俩模型也没有很仔细的规划

import torch
import torch.nn as nnclass DigitLinear(nn.Module):def __init__(self):super(DigitLinear, self).__init__()self.fc1 = nn.Linear(28 * 28, 1000)self.fc2 = nn.Linear(1000, 500)self.dropout = nn.Dropout(0.3)self.fc3 = nn.Linear(500, 10)def forward(self, x):x = x.view(-1, 28 * 28)x = self.fc1(x)x = torch.relu(x)x = self.dropout(x)x = self.fc2(x)x = torch.relu(x)x = self.fc3(x)return xclass DigitCNN(nn.Module):def __init__(self):super(DigitCNN,self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)self.fc1 = nn.Linear(64*28*28, 128)self.dropout = nn.Dropout(0.1)self.fc2 = nn.Linear(128, 10)def forward(self, x):# print("x.shape:", x.shape)B,N,H,W = x.shapex = self.conv1(x)x = torch.relu(x)x = self.conv2(x)x = torch.relu(x)x = x.view(B, -1)  # 展平x = self.fc1(x)x = torch.relu(x)x = self.dropout(x)x = self.fc2(x)return x

3.dataloader和训练

这里的代码就很简单了,就是一些参数的选择,例如epoch,batchsize。其中的训练函数我写的买有很全面,只是勉强满足了训练功能,还有好多可以优化的点,比如打印fps,断点续训练啥的,不过这个任务提不起劲去干这事,大家可以自行优化。

# 数据加载器
from torch.utils.data import DataLoader
from lib.model.DigitModel import DigitLinear,DigitCNN
# 定义数据加载器
batch_size = 256
train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)epoch = 50
# 训练模型
net = DigitLinear() # 参数量1M 97.50%
# net = DigitCNN() # 参数量8M 98.81%
net.cuda()# 定义损失函数和优化器
import torch.optim as optim
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 训练函数def train_model(model, train_loader, criterion, optimizer, num_epochs=10):model.train()  # 设置模型为训练模式for epoch in range(num_epochs):running_loss = 0.0correct = 0total = 0for i, (inputs, labels) in enumerate(train_loader):inputs= inputs.cuda()y = torch.tensor(torch.zeros((inputs.shape[0],10), dtype=torch.float)).cuda()y[torch.arange(inputs.shape[0]), labels] = 1optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, y)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels.cuda()).sum().item()epoch_loss = running_loss / len(train_loader)epoch_acc = 100. * correct / totalprint(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')# 训练模型
train_model(net, train_loader, criterion, optimizer, num_epochs=epoch)

4.训练模型

有了上面的代码就可以开始训练了,我这里训练的截图是我的MLP模型,效果不是很好,CNN的效果稍微好一点,比MLP高1%,但是图忘记截了。反正够用了,因为本身MNIST的数据就不是很完美,有很多类似于噪声的数据例如:
在这里插入图片描述
这些数字我人眼都分不出是什么玩意。

训练效果如下
在这里插入图片描述

5.测试模型

训练完当然是测试了
最后我的MLP模型跑了97.50%的准确率

代码如下

# 测试模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.eval()
correct = 0
total = 0
with torch.no_grad():for inputs, labels in test_loader:inputs, labels = inputs.to(device).float(), labels.to(device).float()outputs = net(inputs)_, predicted = outputs.max(1)total += labels.size(0)correct += predicted.eq(labels.cuda()).sum().item()# print(f"Predicted: {predicted}, Ground Truth: {targets}")print(f"Accuracy: {correct / total * 100:.4f} %")

在这里插入图片描述

6.保存模型

保存模型代码就更简单了

# 保存模型
torch.save(net.state_dict(), './digit_model.pth')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/bicheng/82992.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

AWS WebRTC:获取ICE服务地址(part 2): ICE Agent的作用

上一篇,已经获取到了ICE服务地址,从返回结果中看,是两组TURN服务地址。 拿到这些地址有什么用呢?接下来就要说到WebRTC中ICE Agent的作用了,返回的服务地址会传给WebRTC最终给到ICE Agent。 ICE Agent的作用&#xf…

大数据时代的利剑:Bright Data网页抓取与自动化工具共建高效数据采集新生态

目录 一、为何要选用Bright Data网页自动化抓取——帮助我们高效高质解决以下问题! 二、Bright Data网页抓取工具 - 网页爬虫工具实测 2.1 首先注册用户 2.2 首先点击 Proxies & Scraping ,再点击浏览器API的开始使用 2.3 填写通道名称&#xff…

指纹识别+精准化POC攻击

开发目的 解决漏洞扫描器的痛点 第一就是扫描量太大,对一个站点扫描了大量的无用 POC,浪费时间 指纹识别后还需要根据对应的指纹去进行 payload 扫描,非常的麻烦 开发思路 我们的思路分为大体分为指纹POC扫描 所以思路大概从这几个方面…

【Day40】

DAY 40 训练和测试的规范写法 知识点回顾: 彩色和灰度图片测试和训练的规范写法:封装在函数中展平操作:除第一个维度batchsize外全部展平dropout操作:训练阶段随机丢弃神经元,测试阶段eval模式关闭dropout 作业&#x…

【HTML-13】HTML表格合并技术详解:打造专业数据展示

表格是HTML中展示结构化数据的重要元素,而表格合并则是提升表格表现力的关键技术。本文将全面介绍HTML中的表格合并方法,帮助您创建更专业、更灵活的数据展示界面。 1. 表格合并基础概念 在HTML中,表格合并主要通过两个属性实现&#xff1a…

<uniapp><threejs>在uniapp中,怎么使用threejs来显示3D图形?

前言 本专栏是基于uniapp实现手机端各种小功能的程序,并且基于各种通讯协议如http、websocekt等,实现手机端作为客户端(或者是手持机、PDA等),与服务端进行数据通讯的实例开发。 发文平台 CSDN 环境配置 系统:windows 平台:visual studio code、HBuilderX(uniapp开…

如何制作全景VR图?

全景VR图,特别是720度全景VR,为观众提供一种沉浸式体验。 全景VR图能够捕捉场景的全貌,还能将多个角度的图片或视频无缝拼接成一个完整的全景视角,让观众在虚拟环境中自由探索。随着虚拟现实(VR)技术的飞速…

前端使用qrcode来生成二维码的时候中间添加logo图标

这个开源仓库可以让你在前端页面中生成二维码图片,并且支持调整前景色和背景色,但是有个问题,就是不能添加logo图片。issue: GitHub Where software is built 但是已经有解决方案了: add a logo picture Issue #21…

【C语言】函数指针及其应用

目录 1.1 函数指针的概念和应用 1.2 赋值与内存模型 1.3 调用方式与注意事项 二、函数指针的使用 2.1 函数指针的定义和访问 2.2 动态调度:用户输入驱动函数执行 2.3 函数指针数组进阶应用 2.4 函数作为参数的高阶抽象 三、回调函数 3.1 指针函数…

安装flash-attention失败的终极解决方案(WINDOWS环境)

想要看linux版本下安装问题的请走这里:安装flash-attention失败的终极解决方案(LINUX环境) 其实,现在的flash-attention不像 v2.3.2之前的版本,基本上不兼容WINDOWS环境。但是在WINDOWS环境安装总还是有那么一点不顺畅…

[C]基础16.数据在内存中的存储

博客主页:向不悔本篇专栏:[C]您的支持,是我的创作动力。 文章目录 0、总结1、整数在内存中的存储1.1 整数的二进制表示方法1.2 不同整数的表示方法1.3 内存中存储的是补码 2、大小端字节序和字节序判断2.1 什么是大小端2.2 为什么有大小端2.3…

Python 基于卷积神经网络手写数字识别

Ubuntu系统:22.04 python版本:3.9 安装依赖库: pip install tensorflow2.13 matplotlib numpy -i https://mirrors.aliyun.com/pypi/simple 代码实现: import tensorflow as tf from tensorflow.keras.models import Sequent…

ElectronBot复刻-电路测试篇

typec-16p 接口部分 USB1(Type - C 接口):这是通用的 USB Type - C 接口,具备供电和数据传输功能。 GND 引脚(如 A1、A12、B1、B12 等):接地引脚,用于提供电路的参考电位&#xff0…

ESP8266+STM32 AT驱动程序,心知天气API 记录时间: 2025年5月26日13:24:11

接线为 串口2 接入ESP8266 esp8266.c #include "stm32f10x.h"//8266预处理文件 #include "esp8266.h"//硬件驱动 #include "delay.h" #include "usart.h"//用得到的库 #include <string.h> #include <stdio.h> #include …

CDN安全加速:HTTPS加密最佳配置方案

CDN安全加速的HTTPS加密最佳配置方案需从证书管理、协议优化、安全策略到性能调优进行全链路设计&#xff0c;以下是核心实施步骤与注意事项&#xff1a; ​​一、证书配置与管理​​ ​​证书选择与格式​​ ​​证书类型​​&#xff1a;优先使用受信任CA机构颁发的DV/OV/EV证…

【前端】Twemoji(Twitter Emoji)

目录 注意使用Vue / React 项目 验证 Twemoji 的作用&#xff1a; Twemoji 会把你网页/应用中的 Emoji 字符&#xff08;如 &#x1f604;&#xff09;自动替换为 Twitter 风格的图片&#xff08;SVG/PNG&#xff09;&#xff1b; 它不依赖系统字体&#xff0c;因此在 Android、…

GCN图神经网络的光伏功率预测

一、GCN图神经网络的核心优势 图结构建模能力 GCN通过邻接矩阵&#xff08;表示节点间关系&#xff09;和节点特征矩阵&#xff08;如气象数据、历史功率&#xff09;进行特征传播&#xff0c;能够有效捕捉光伏电站间的空间相关性。其核心公式为&#xff1a; H ( l 1 ) σ (…

按照状态实现自定义排序的方法

方法一&#xff1a;使用 MyBatis-Plus 的 QueryWrapper 自定义排序 在查询时动态构建排序规则&#xff0c;通过 CASE WHEN 语句实现优先级排序&#xff1a; import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; import org.springframework.stereotype.Ser…

【计算机网络】IPv6和NAT网络地址转换

IPv6 IPv6协议使用由单/双冒号分隔一组数字和字母&#xff0c;例如2001:0db8:85a3:0000:0000:8a2e:0370:7334&#xff0c;分成8段。IPv6 使用 128 位互联网地址&#xff0c;有 2 128 2^{128} 2128个IP地址无状态地址自动配置&#xff0c;主机可以通过接口标识和网络前缀生成全…

【Redis】string

String 字符串 字符串类型是 Redis 最基础的数据类型&#xff0c;关于字符串需要特别注意&#xff1a; 首先 Redis 中所有的键的类型都是字符串类型&#xff0c;而且其他几种数据结构也都是在字符串的基础上构建的。字符串类型的值实际可以是字符串&#xff0c;包含一般格式的…