【深度学习】实验四 卷积神经网络CNN

 实验四  卷积神经网络CNN

一、实验学时: 2学时

二、实验目的

  1. 掌握卷积神经网络CNN的基本结构;
  2. 掌握数据预处理、模型构建、训练与调参;
  3. 探索CNN在MNIST数据集中的性能表现;

三、实验内容

实现深度神经网络CNN。

四、主要实验步骤及结果

1.搭建一个CNN网络,使用MNIST手写数字数据集进行训练与测试,并体现模型最终结果,CNN网络的具体框架可参考下图,也可自己设计:

图4-1 CNN架构图

(1)该图表示输入层为28*28*1的尺寸,符合MNIST数据集的标准尺寸。

(2)第一个卷积层,使用5*5卷积核,32个滤波器,填充(Padding)为2。输出尺寸为28*28*32。

(3)第一个池化层,使用2*2池化窗口,步长(stride)为2。输出尺寸为14*14*32。

(4)第二个卷积层,使用5*5卷积核,64个滤波器,填充(Padding)为2。输出尺寸为14*14*64。

(5)第二个池化层,使用2*2池化窗口,步长(stride)为2。输出尺寸为7*7*64。

(6)全连接层包含1024个神经元,输出尺寸为1*1*1024。

(7)Dropout层用于防止过拟合。

(8)输出层包含10个神经元,对应手写数字的0-9。输出尺寸为1*1*10。

模型实现:

以该架构图搭建CNN网络,使用MNIST手写数字数据集进行训练与测试,训练和测试结果如图4-2所示:

图4-2 CNN测试结果

2.尝试使用不同的数据增强方法、优化器、损失函数、学习率、batch size和迭代次数来进行训练,记录训练过程,评估模型性能,保存最佳模型。

编号

batch size

训练轮次

学习率

数据增强方法

优化器

实验结果

1

32

2

1e-4

Adam

98.62%

2

64

2

1e-4

Adam

98.56%

3

64

4

1e-4

Adam

99.08%

4

64

4

3e-4

Adam

99.08%

5

64

4

3e-4

旋转+平移

Adam

98.90%

5

64

4

3e-4

Adam(L2正则化)

99.23%

6

64

4

1e-4

SGD+momentum

97.30%

其中数据增强方法采用随机旋转和平移吗,原始代码中包含ToTensor()和Normalize(),给原始代码添加随机旋转10度和随机平移10%,代码如下:

# 数据加载(归一化)
transform = torchvision.transforms.Compose([torchvision.transforms.RandomRotation(10),  # 随机旋转10度torchvision.transforms.RandomAffine(0, translate=(0.1, 0.1)),  # 随机平移10%torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))
])

优化器选择方面使用SGD+momentum(0.9)替代原Adam优化器,

# 使用SGD+momentum
optimizer = torch.optim.SGD(model.parameters(), lr=LEARN_RATE, momentum=0.9)

根据训练过程记录的数据,最佳模型尊却绿为99.23%,最佳模型代码如下:

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoaderBATCH_SIZE = 64
EPOCHS = 4
LEARN_RATE = 3e-4
DROPOUT_RATE = 0.5device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')# 数据加载(归一化)
transform = torchvision.transforms.Compose([# torchvision.transforms.RandomRotation(10),  # 随机旋转10度# torchvision.transforms.RandomAffine(0, translate=(0.1, 0.1)),  # 随机平移10%torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))
])train_data = torchvision.datasets.MNIST(root='./mnist',train=True,download=True,transform=transform
)
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)test_data = torchvision.datasets.MNIST(root='./mnist',train=False,transform=transform
)
test_loader = DataLoader(test_data, batch_size=1000, shuffle=False)class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_layers = nn.Sequential(# 第一层卷积:5x5 卷积核,32 个过滤器,padding=2nn.Conv2d(1, 32, kernel_size=5, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),  # 池化后 14x14x32# 第二层卷积:5x5 卷积核,64 个过滤器,padding=2nn.Conv2d(32, 64, kernel_size=5, padding=2),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2)  # 池化后 7x7x64)self.fc_layers = nn.Sequential(nn.Linear(64 * 7 * 7, 1024),  # 全连接层:7x7x64 → 1024nn.ReLU(),nn.Dropout(DROPOUT_RATE),  # Dropout层nn.Linear(1024, 10)  # 输出层:1024 → 10)self._initialize_weights()def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')nn.init.constant_(m.bias, 0)def forward(self, x):x = self.conv_layers(x)x = x.view(x.size(0), -1)  # 展平操作x = self.fc_layers(x)return xmodel = CNN().to(device)
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_RATE, weight_decay=1e-5)
# optimizer = torch.optim.SGD(model.parameters(), lr=LEARN_RATE, momentum=0.9)  # 使用SGD+momentum
# 训练循环
for epoch in range(EPOCHS):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = loss_fn(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Epoch {epoch + 1} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.4f}')# 测试
model.eval()
correct = 0
with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)pred = output.argmax(dim=1)correct += pred.eq(target).sum().item()print(f'Test Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.2f}%)')

3.使用画图工具将自己的学号逐个写出,使用保存的最佳模型对每个数字进行推理,比较模型对每个数字的准确率预测,也可以尝试实现一个实时识别手写数字的demo。
(1)使用画图工具将自己的学号逐个写出,进行反色处理,并将图片命名为“x_001.png”格式。

图4-3手写数字

(2)在训练代码(CNN.py)中添加模型保存代码。

torch.save(model.state_dict(), 'mnist_cnn.pth')

(3)编写推理代码读取img文件夹中的手写图片并预测,预测代码如下所示:

import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import numpy as np
import os# 定义模型结构(需与训练代码一致)
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv_layers = nn.Sequential(nn.Conv2d(1, 32, kernel_size=5, padding=2),nn.ReLU(),nn.MaxPool2d(2, 2),nn.Conv2d(32, 64, kernel_size=5, padding=2),nn.ReLU(),nn.MaxPool2d(2, 2))self.fc_layers = nn.Sequential(nn.Linear(64 * 7 * 7, 1024),nn.ReLU(),nn.Dropout(0.5),nn.Linear(1024, 10))def forward(self, x):x = self.conv_layers(x)x = x.view(x.size(0), -1)x = self.fc_layers(x)return x# 加载模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CNN().to(device)
model.load_state_dict(torch.load('mnist_cnn.pth', map_location=device))
model.eval()# 定义预处理(与训练一致)
transform = transforms.Compose([transforms.Resize((28, 28)),  # 确保输入为28x28transforms.Grayscale(num_output_channels=1),  # 转换为单通道transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])# 遍历img文件夹中的图片并推理
img_dir = 'img'
digit_stats = {str(i): {'correct': 0, 'total': 0} for i in range(10)}for filename in os.listdir(img_dir):if filename.lower().endswith(('.png', '.jpg', '.jpeg')):# 从文件名中提取真实标签(假设文件名为 "label_xxx.png")try:true_label = filename.split('_')[0]  # 例如文件名 "3_001.png" → 标签为3true_label = int(true_label)if true_label < 0 or true_label > 9:continueexcept:print(f"跳过文件 {filename}(文件名格式错误)")continue# 加载并预处理图像img_path = os.path.join(img_dir, filename)image = Image.open(img_path)image = transform(image).unsqueeze(0).to(device)  # 添加batch维度# 推理with torch.no_grad():output = model(image)pred = output.argmax(dim=1).item()# 统计结果digit_stats[str(true_label)]['total'] += 1if pred == true_label:digit_stats[str(true_label)]['correct'] += 1print(f"图片 {filename} 真实标签: {true_label}, 预测: {pred} → {'正确' if pred == true_label else '错误'}")# 计算每个数字的准确率
accuracies = {}
for digit in digit_stats:if digit_stats[digit]['total'] > 0:acc = digit_stats[digit]['correct'] / digit_stats[digit]['total']accuracies[digit] = accprint(f"数字 {digit} 的准确率: {acc:.2%}")

预测结果如图4-4所示:

图4-4预测结果

预测结果显示“1”和“4”预测结果错误,其他均正确。

五、实验小结(包括问题和解决办法、心得体会、意见与建议等)

1.问题和解决办法:

问题1:RuntimeError: Dataset not found. You can use download=True to download it。

解决方法:添加下载训练集的参数download=True。

问题2:使用SGD+momentum优化器后,准确率反而下降了。

解决方法:因为SGD对学习率比较敏感,学习率没有适配,使用StepLR梯度衰减,另外也可以增加训练轮次。

问题3:预测结果全部错误。

解决方法:图片要像素28*28,且黑色背景,白色笔迹,对Windows画图的图片反色处理即可。

2.心得体会:通过本次CNN手写数字识别实验的完整实践,我深刻体会到深度学习模型性能的提升是一个系统工程,需要从数据、模型、训练策略到结果分析的全流程精细化把控,尝试使用不同的数据增强方法、优化器、损失函数、学习率、batch size和迭代次数来进行训练,迭代出最佳模型,再手写数字进行测试。通过以上的学习和实践,我对神经网络的原理和应用有了更深入的理解。神经网络的发展给人工智能带来了巨大的影响,它在图像识别、自然语言处理等领域发挥着重要的作用。我相信,随着技术的进步,神经网络将会有更广泛的应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/web/82311.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot高校宿舍信息管理系统小程序

概述 基于SpringBoot的高校宿舍信息管理系统小程序项目&#xff0c;这是一款非常适合高校使用的信息化管理工具。该系统包含了完整的宿舍管理功能模块&#xff0c;采用主流技术栈开发&#xff0c;代码结构清晰&#xff0c;非常适合学习和二次开发。 主要内容 这个宿舍管理系…

Redis 难懂命令-- ZINTERSTORE

**背景&#xff1a;**学习的过程中 常用的redis命令都能快速通过官方文档理解 但是还是有一些比较难懂的命令 **目的&#xff1a;**写博客记录一下&#xff08;当然也可以使用AI搜索&#xff09; 在Redis中&#xff0c;ZINTERSTORE 是一个用于计算多个有序集合&#xff08;So…

React 路由管理与动态路由配置实战

React 路由管理与动态路由配置实战 前言 在现代单页应用(SPA)开发中&#xff0c;路由管理已经成为前端架构的核心部分。随着React应用规模的扩大&#xff0c;静态路由配置往往难以满足复杂业务场景的需求&#xff0c;尤其是当应用需要处理权限控制、动态菜单和按需加载等高级…

【学习笔记】深度学习-梯度概念

一、定义 梯度向量不仅表示函数变化的速度&#xff0c;还表示函数增长最快的方向 二、【问】为什么说它表示方向&#xff1f; 三、【问】那在深度学习梯度下降的时候&#xff0c;还要判断梯度是正是负来更新参数吗&#xff1f; 假设某个参数是 w&#xff0c;损失函数对它的…

题海拾贝:P8598 [蓝桥杯 2013 省 AB] 错误票据

Hello大家好&#xff01;很高兴我们又见面啦&#xff01;给生活添点passion&#xff0c;开始今天的编程之路&#xff01; 我的博客&#xff1a;<但凡. 我的专栏&#xff1a;《编程之路》、《数据结构与算法之美》、《题海拾贝》 欢迎点赞&#xff0c;关注&#xff01; 1、题…

webpack的安装及其后序部分

npm install原理 这个其实就是npm从registry下载项目到本地&#xff0c;没有什么好说的 值得一提的是npm的缓存机制&#xff0c;如果多个项目都需要同一个版本的axios&#xff0c;每一次重新从registry中拉取的成本过大&#xff0c;所以会有缓存&#xff0c;如果缓存里有这个…

百度golang研发一面面经

输入一个网址&#xff0c;到显示界面&#xff0c;中间的过程是怎样的 IP 报文段的结构是什么 Innodb 的底层结构 知道几种设计模式 工厂模式 简单工厂模式&#xff1a;根据传入类型参数判断创建哪种类型对象工厂方法模式&#xff1a;由子类决定实例化哪个类抽象工厂模式&#…

使用 HTML + JavaScript 实现图片裁剪上传功能

本文将详细介绍一个基于 HTML 和 JavaScript 实现的图片裁剪上传功能。该功能支持文件选择、拖放上传、图片预览、区域选择、裁剪操作以及图片下载等功能&#xff0c;适用于需要进行图片处理的 Web 应用场景。 效果演示 项目概述 本项目主要包含以下核心功能&#xff1a; 文…

GO+RabbitMQ+Gin+Gorm+docker 部署 demo

更多个人笔记见&#xff1a; &#xff08;注意点击“继续”&#xff0c;而不是“发现新项目”&#xff09; github个人笔记仓库 https://github.com/ZHLOVEYY/IT_note gitee 个人笔记仓库 https://gitee.com/harryhack/it_note 个人学习&#xff0c;学习过程中还会不断补充&…

【安全】VulnHub靶场 - W1R3S

【安全】VulnHub靶场 - W1R3S 备注一、故事背景二、Web渗透1.主机发现端口扫描2.ftp服务3.web服务 三、权限提升 备注 2025/05/22 星期四 简单的打靶记录 一、故事背景 您受雇对 W1R3S.inc 个人服务器进行渗透测试并报告所有发现。 他们要求您获得 root 访问权限并找到flag&…

WEB安全--SQL注入--MSSQL注入

一、SQLsever知识点了解 1.1、系统变量 版本号&#xff1a;version 用户名&#xff1a;USER、SYSTEM_USER 库名&#xff1a;DB_NAME() SELECT name FROM master..sysdatabases 表名&#xff1a;SELECT name FROM sysobjects WHERE xtypeU 字段名&#xff1a;SELECT name …

工作流引擎-18-开源审批流项目之 plumdo-work 工作流,表单,报表结合的多模块系统

工作流引擎系列 工作流引擎-00-流程引擎概览 工作流引擎-01-Activiti 是领先的轻量级、以 Java 为中心的开源 BPMN 引擎&#xff0c;支持现实世界的流程自动化需求 工作流引擎-02-BPM OA ERP 区别和联系 工作流引擎-03-聊一聊流程引擎 工作流引擎-04-流程引擎 activiti 优…

Docker 笔记 -- 借助AI工具强势辅助

常用命令 镜像管理命令&#xff1a; docker images&#xff08;列出镜像&#xff09; docker pull&#xff08;拉取镜像&#xff09; docker build&#xff08;构建镜像&#xff09; docker save/load&#xff08;保存/加载镜像&#xff09; 容器操作命令 docker run&#…

5G-A时代与p2p

5G-A时代正在走来&#xff0c;那么对P2P的影响有多大。 5G-A作为5G向6G过渡的关键技术&#xff0c;将数据下载速率从千兆提升至万兆&#xff0c;上行速率从百兆提升至千兆&#xff0c;时延降至毫秒级。这种网络性能的跨越式提升&#xff0c;为P2P提供了更强大的底层支撑&#x…

Redis-6.2.9 主从复制配置和详解

1 主从架构图 192.168.254.120 u24-redis-120 #主库 192.168.254.121 u24-redis-121 #从库 2 redis软件版本 rootu24-redis-121:~# redis-server --version Redis server v6.2.9 sha00000000:0 malloclibc bits64 build56edd385f7ce4c9b 3 主库redis配置文件(192.168.254.1…

004 flutter基础 初始文件讲解(3)

之前&#xff0c;我们正向的学习了一些flutter的基础&#xff0c;如MaterialApp&#xff0c;Scaffold之类的东西&#xff0c;那么接下来&#xff0c;我们将正式接触原代码&#xff1a; import package:flutter/material.dart;void main() {runApp(const MyApp()); }class MyAp…

Linux 系统 Docker Compose 安装

个人博客地址&#xff1a;Linux 系统 Docker Compose 安装 | 一张假钞的真实世界 本文方法是直接下载 GitHub 项目的 release 版本。项目地址&#xff1a;GitHub - docker/compose: Define and run multi-container applications with Docker。 执行以下命令将发布程序加载至…

Tree 树形组件封装

整体思路 数据结构设计 使用递归的数据结构&#xff08;TreeNode&#xff09;表示树形数据每个节点包含id、name、可选的children数组和selected状态 状态管理 使用useState在组件内部维护树状态的副本通过deepCopyTreeData函数进行深拷贝&#xff0c;避免直接修改原始数据 核…

tortoisegit 使用rebase修改历史提交

在 TortoiseGit 中使用 rebase 修改历史提交&#xff08;如修改提交信息、合并提交或删除提交&#xff09;的步骤如下&#xff1a; --- ### **一、修改最近一次提交** 1. **操作**&#xff1a; - 右键项目 → **TortoiseGit** → **提交(C)** - 勾选 **"Amend…

中科院报道铁电液晶:从实验室突破到多场景应用展望

2020年的时候&#xff0c;相信很多关注科技前沿的朋友都注意到&#xff0c;中国科学院一篇报道聚焦一项有望改写显示产业格局的新技术 —— 铁电液晶&#xff08;FeLC&#xff09;。这项被业内称为 "下一代显示核心材料" 的研究&#xff0c;究竟取得了哪些实质性进展…