深度学习----卷积神经网络实现数字识别

一、准备工作

导入库,导入数据集,划分训练批次数量,规定训练硬件(这部分

import torch
from torch import nn  # 导入神经网络模块
from torch.utils.data import DataLoader  # 数据包管理工具,打包数据
from torchvision import datasets  # 封装了很多与图像相关的模型,和数据集
from torchvision.transforms import ToTensor  # 将其他数据类型转化为张量train_data = datasets.MNIST(root='data',train=True,  # 是否读取下载后数据中的训练集download=True,  # 如果之前下载过则不用下载transform=ToTensor()
)
test_data = datasets.MNIST(root='data',train=False,download=True,transform=ToTensor()
)train_dataloader = DataLoader(train_data,batch_size=256)#是一个类,现在初始化了,但没开始打包,训练开始才打包
test_dataloader = DataLoader(test_data,batch_size=256)device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using {device} device')

二、定义神经网络(重点)

这部分比较重要,我分开讲

1、类定义与继承

class CNN(nn.Module):

这里定义了一个名为CNN的类,它继承自 PyTorch 的nn.Module类。nn.Module是 PyTorch 中所有神经网络模块的基类,通过继承它,我们可以利用 PyTorch 提供的各种功能,如参数管理、设备迁移等。

2、初始化方法

def __init__(self):super().__init__()

这是类的构造函数,super().__init__()调用了父类nn.Module的构造函数,确保父类得到正确初始化。

3、网络层定义

  • 第一个卷积块(conv1)
self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,      # 输入通道数,1表示灰度图像out_channels=16,    # 输出通道数/卷积核数量kernel_size=5,      # 卷积核大小5×5stride=1,           # 步长为1padding=2,          # 填充为2,保持特征图大小不变),nn.ReLU(),               # ReLU激活函数nn.MaxPool2d(kernel_size=2),  # 2×2最大池化
)

这个卷积块接收 1 通道的输入,通过 16 个 5×5 的卷积核进行卷积操作,然后经过 ReLU 激活和 2×2 的最大池化。

  • 第二个卷积块(conv2)
self.conv2 = nn.Sequential(nn.Conv2d(16,32,5,1,2),  # 16→32通道,5×5卷积核nn.ReLU(),nn.Conv2d(32,32,5,1,2),  # 32→32通道,5×5卷积核nn.ReLU(),nn.Conv2d(32,32,5,1,2),  # 32→32通道,5×5卷积核nn.ReLU(),nn.Conv2d(32,64,5,1,2),  # 32→64通道,5×5卷积核nn.ReLU(),nn.Conv2d(64,64,5,1,2),  # 64→64通道,5×5卷积核nn.ReLU(),nn.MaxPool2d(kernel_size=2),  # 2×2最大池化
)

这个卷积块包含多个卷积层,逐步增加通道数,并在最后进行一次最大池化。

  • 第三个卷积块(conv3)
self.conv3 = nn.Sequential(nn.Conv2d(64,64,5,1,2),  # 64→64通道,5×5卷积核nn.ReLU(),
)

这是一个简单的卷积块,保持通道数不变。

  • 全连接层(out)
self.out = nn.Linear(64*7*7,10)

这是网络的输出层,将卷积得到的特征图展平后映射到 10 个输出(可能对应 10 类分类问题)。

4、前向传播方法

def forward(self, x):x = self.conv1(x)    # 通过第一个卷积块x = self.conv2(x)    # 通过第二个卷积块x = self.conv3(x)    # 通过第三个卷积块x = x.view(x.size(0),-1)  # 展平特征图,保留批次维度output = self.out(x)  # 通过全连接层得到输出return output

forward方法定义了数据在网络中的流动路径,即前向传播过程。x.view(x.size(0),-1)将卷积操作得到的多维特征图展平成一维向量,以便输入到全连接层。

5、模型实例化

model = CNN().to(device)

创建 CNN 类的实例,并将模型迁移到指定的设备(CPU 或 GPU)上。

完整代码:

 定义神经网络,通过类的继承
class CNN(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Sequential(#容器,添加网络层nn.Conv2d(in_channels=1,out_channels = 16,kernel_size = 5,stride = 1,padding = 2,),nn.ReLU(),nn.MaxPool2d(kernel_size = 2),)self.conv2 = nn.Sequential(  # 容器,添加网络层nn.Conv2d(16,32,5,1,2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.Conv2d(32,64,5,1,2),nn.ReLU(),nn.Conv2d(64, 64, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(kernel_size=2),)self.conv3 = nn.Sequential(nn.Conv2d(64,64,5,1,2),nn.ReLU(),)self.out = nn.Linear(64*7*7,10)def forward(self, x):  # 前向传播x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0),-1)output = self.out(x)return output
model = CNN().to(device)

三、模型的训练

这一段和上一个博客的步骤一样这里就不多做讲解了

不了解的直接看

深度学习----由手写数字识别案例来认识PyTorch框架-CSDN博客


def train(dataloader, model, loss_fn, optimizer):model.train()  # 开启模型训练模式,像 Dropout、BatchNorm 层会在训练/测试时表现不同,需此设置batch_size_num = 1  # 用于计数当前处理到第几个 batchfor X, y in dataloader:  # 从数据加载器中逐个取出 batch 的数据(特征 X、标签 y )X, y = X.to(device), y.to(device)  # 把数据和标签放到指定计算设备(CPU/GPU)pred = model(X)  # 将数据输入模型,得到预测结果(模型自动做前向传播计算 )loss = loss_fn(pred, y)  # 用损失函数计算预测结果和真实标签的损失# 以下是反向传播更新参数的标准流程optimizer.zero_grad()  # 清空优化器里参数的梯度,避免梯度累加影响计算loss.backward()  # 反向传播,计算参数的梯度optimizer.step()  # 根据梯度,更新模型参数loss = loss.item()  # 取出损失张量的数值(脱离计算图 )# 打印当前 batch 的损失和 batch 编号if batch_size_num % 100 == 0:print(f"loss:{loss:>7f}  [number:{batch_size_num}")batch_size_num += 1  # batch 计数加一def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for X , y in dataloader:X ,y = X.to(device),y.to(device)pred = model.forward(X)#.forward可以被省略test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= sizeprint(f"test result: \n Accuracy: {(100*correct)}%,Avg loss:{test_loss}")# print(list(model.parameters()))
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)#换优化器可以提高准确率,Adam,SGD等# train(train_dataloader,model,loss_fn,optimizer)
# test(test_dataloader,model,loss_fn)
#epochs = 10
for t in range(epochs):print(f"轮次:{t+1}\n----------------------------")train(train_dataloader,model,loss_fn,optimizer)
print("Done")
test(test_dataloader,model,loss_fn)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/94702.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/94702.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙Harmony-从零开始构建类似于安卓GreenDao的ORM数据库(四)

目录 一,查询表的所有数据 二,根据条件查询数据 三,数据库升级 前面章节已经讲解了数据库的创建,表的创建,已经增删改等操作。下面我们来讲解一下数据库的查询以及升级操作。 一,查询表的所有数据 先来看看官方文档: query(predicates: RdbPredicates, callback: Asy…

20250829_编写10.1.11.213MySQL8.0异地备份传输脚本+在服务器上创建cron任务+测试成功

0.已知前提条件: 10.1.11.213 堡垒机访问 mysql 8.0 版本 密码在/root/.my.cnf 备份脚本:/data/backup_mysql/mysql_backup.sh alarm_system:动环数据库 exit_and_entry:出入境数据库 logs:备份日志 project_cg_view_prod:采购跟踪系统 all :数据库整体备份 imip_ecb…

PostgreSQL 流复制与逻辑复制性能优化与故障切换实战经验分享

PostgreSQL 流复制与逻辑复制性能优化与故障切换实战经验分享 在高可用和数据安全愈发受到重视的生产环境中,PostgreSQL 复制技术是保障业务连续性的重要手段。本文结合真实生产场景,分享流复制(Physical Replication)与逻辑复制&…

Django开发规范:构建可维护的AWS资源管理应用

引言 在现代Web开发中,遵循一致的开发规范对于项目的可维护性和团队协作至关重要。本文基于实际的AWS资源管理项目,分享一套经过实践检验的Django开发规范,涵盖模型设计、Admin配置、管理命令和工具类开发等方面。 模型开发规范 数据模型设计原则 良好的数据模型设计是应…

机器学习可解释库Shapash的快速使用教程(五)

文章目录1 快速使用1.1 安装1.2 三个简单步骤快速入门1.2.1 步骤 1:准备模型和数据1.2.2 步骤 2:声明并编译 SmartExplainer1.2.3 步骤 3:可视化和探索1.2.4 启动 Web 应用1.2.5 将解释结果导出为数据2 Shapash的后端集成2.1 方法一&#xff…

如何在emacs中添加imenu插件

在配置文件中添加: ;; 删除现有的包管理器配置(如果有),然后添加以下:;; 初始化包管理器 (require package);; 清除现有的仓库列表 (setq package-archives nil);; 添加正确的仓库(注意:使用 H…

Linux下的网络编程SQLITE3详解

常用数据库关系型数据库将复杂的数据结构简化为二维表格形式大型:Oracle、DB2中型:MySql、SQLServer小型:Sqlite非关系型数据库以键值对存储,且结构不固定JSONRedisMongoDBsqlite数据库特点开源免费,C语言开发代码量少…

适配openai

openai 脚本 stream脚本import os from openai import OpenAIclient OpenAI(base_url"http://127.0.0.1:9117/api/v1",api_keyos.environ["ACCESS_TOKEN"], )stream client.chat.completions.create(model "Qwen/Qwen2-7B-Instruct",messages…

一天认识一个神经网络之--CNN卷积神经网络

CNN 是一种非常强大的深度学习模型,尤其擅长处理像图片这样的网格结构数据。你可以把它想象成一个系统,它能像我们的大脑一样,自动从图片中学习并识别出各种特征,比如边缘、角落、纹理,甚至是更复杂的物体部分&#xf…

13 SQL进阶-InnoDB引擎(8.23)

一、逻辑存储结构(1)表空间(ibd文件):一个mysql实例可以对应多个表空间,用于存储记录、索引等数据。cd /var/lib/mysql(2)段,分为数据段(leaf node segment&a…

MTK Linux DRM分析(二十四)- MTK mtk_drm_plane.c

一、代码分析 mtk_drm_plane.h 和 mtk_drm_plane.c 两个文件,并生成基于文本的函数调用图,我将首先解析文件中的主要函数及其功能,然后根据代码中的调用关系整理出调用图。由于文件内容较长,我会专注于关键函数及其相互调用关系,并以清晰的文本形式呈现。 文件分析 1. …

滚珠导轨如何赋能精密制造?

在智能制造发展的趋势下,新兴行业对高精度、高稳定性的运动控制需求激增。作为直线传动领域的“精密纽带”,滚珠导轨凭借低摩擦、长寿命、高刚性优势,广泛应用于精密传动领域,成为产业升级的关键。新能源汽车制造领域:…

医疗 AI 的 “破圈” 时刻:辅助诊断、药物研发、慢病管理,哪些场景已落地见效?

一、引言在科技迅猛发展的当下,医疗领域正经历着深刻变革,人工智能(AI)技术宛如一颗璀璨新星,强势 “破圈” 闯入,为医疗行业带来了前所未有的机遇与活力。从辅助医生精准诊断病情,到助力药企高…

【项目思维】编程思维学习路线(推荐)

本篇博客是一份系统性、分阶段的 编程思维学习路线图推荐,从零基础小白到系统架构级别,帮助你全面建立和提升编程思维能力。 🚦 阶段 0:思维准备(理解编程是什么) 🎯 学习目标: 理…

vue3+antd实现华为云OBS文件拖拽上传详解

1、文件上传核心流程 选择文件​​:用户通过拖拽或点击选择文件手动触发上传​​:点击"确定"按钮后开始上传(阻止自动上传)​​获取上传凭证​​:从后端获取华为云OBS的上传配置构建表单数据​​&#xff1…

Mac 开发环境与配置操作速查表

Mac 开发环境与配置操作速查表 安装和配置 nvm / Node 安装 Homebrew Homebrew 安装参考文章 如果没有VPN,不要使用此命令安装! /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)" brew --v…

【论文简读】MuGS

今天读一篇ICCV 2025的文章,关注的是Generalizable Gaussian Splatting,作者来自华中科技大学。 文章链接:arxiv 代码仓库:https://github.com/EuclidLou/MuGS(摘要中的链接,但暂时404) 文章目…

基于SpringBoot和百度人脸识别API开发的保安门禁系统

角色: 管理员、保安 技术: Spring Boot, MyBatis, MySQL, PageHelper, Bootstrap, jQuery, JavaScript, CSS3, HTML5, JSP, 百度人脸识别API 核心功能: 小区保安门禁系统是一个基于Spring Boot技术栈开发的综合性平台,旨在实现小区…

抖音电商首创最严珠宝玉石质检体系,推动行业规范与消费扩容

8月27日,“抖音电商开放日质检专场”活动在广州华林国际举行。活动上,抖音电商首次对外介绍了质检仓配一体化中心(QIC)的运作流程,并发布了服务升级计划。这一行业首创的“先鉴定后发货”模式,被认为推动了…

SpringBoot整合Spring WebFlux弃用自带的logback,使用log4j2,并启动异步日志处理

第一步&#xff1a;修改pom文件<!-- Spring Boot Starter WebFlux (排除默认日志) --><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-webflux</artifactId><version>${spring-boot.vers…