PyTorch 实现 MNIST 手写数字识别

PyTorch 实现 MNIST 手写数字识别

MNIST 是一个经典的手写数字数据集,包含 60000 张训练图像和 10000 张测试图像。使用 PyTorch 实现 MNIST 分类通常包括数据加载、模型构建、训练和评估几个部分。

数据加载与预处理

使用 torchvision 加载 MNIST 数据集,并进行归一化和数据增强(可选)。以下是数据加载的示例代码:

import torch
from torchvision import datasets, transformstransform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data', train=False, transform=transform)train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False)

构建模型

定义一个简单的卷积神经网络(CNN)模型:

import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.fc1 = nn.Linear(1024, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = F.relu(self.conv1(x))x = F.max_pool2d(x, 2)x = F.relu(self.conv2(x))x = F.max_pool2d(x, 2)x = torch.flatten(x, 1)x = F.relu(self.fc1(x))x = self.fc2(x)return F.log_softmax(x, dim=1)

训练模型

定义优化器和损失函数,并进行训练:

model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()def train(model, device, train_loader, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = criterion(output, target)loss.backward()optimizer.step()if batch_idx % 100 == 0:print(f'Epoch: {epoch}, Loss: {loss.item():.4f}')

测试模型

在测试集上评估模型性能:

def test(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += criterion(output, target).item()pred = output.argmax(dim=1, keepdim=True)correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print(f'Test Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.2f}%)')

完整训练循环

将训练和测试整合到一个完整的循环中:进行10次训练

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)for epoch in range(1, 10):train(model, device, train_loader, optimizer, epoch)test(model, device, test_loader)

模型保存与加载

训练完成后,可以保存模型:

torch.save(model.state_dict(), 'mnist_cnn.pt')

加载模型:

加载模块

model = Net()
model.load_state_dict(torch.load('mnist_cnn.pt'))
model.eval()

或者是Mnist其他源代码

Mnist main.py

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLRclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 32, 3, 1)self.conv2 = nn.Conv2d(32, 64, 3, 1)self.dropout1 = nn.Dropout(0.25)self.dropout2 = nn.Dropout(0.5)self.fc1 = nn.Linear(9216, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = F.relu(x)x = self.conv2(x)x = F.relu(x)x = F.max_pool2d(x, 2)x = self.dropout1(x)x = torch.flatten(x, 1)x = self.fc1(x)x = F.relu(x)x = self.dropout2(x)x = self.fc2(x)output = F.log_softmax(x, dim=1)return output#训练模型
def train(args, model, device, train_loader, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()if batch_idx % args.log_interval == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))if args.dry_run:break#测试模型
def test(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch losspred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probabilitycorrect += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))#主函数 MNIST Example的主函数
def main():# Training settingsparser = argparse.ArgumentParser(description='PyTorch MNIST Example')parser.add_argument('--batch-size', type=int, default=64, metavar='N',help='input batch size for training (default: 64)')parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',help='input batch size for testing (default: 1000)')parser.add_argument('--epochs', type=int, default=14, metavar='N',help='number of epochs to train (default: 14)')parser.add_argument('--lr', type=float, default=1.0, metavar='LR',help='learning rate (default: 1.0)')parser.add_argument('--gamma', type=float, default=0.7, metavar='M',help='Learning rate step gamma (default: 0.7)')parser.add_argument('--no-accel', action='store_true',help='disables accelerator')parser.add_argument('--dry-run', action='store_true',help='quickly check a single pass')parser.add_argument('--seed', type=int, default=1, metavar='S',help='random seed (default: 1)')parser.add_argument('--log-interval', type=int, default=10, metavar='N',help='how many batches to wait before logging training status')parser.add_argument('--save-model', action='store_true', help='For Saving the current Model')args = parser.parse_args()use_accel = not args.no_accel and torch.accelerator.is_available()torch.manual_seed(args.seed)if use_accel:device = torch.accelerator.current_accelerator()else:device = torch.device("cpu")train_kwargs = {'batch_size': args.batch_size}test_kwargs = {'batch_size': args.test_batch_size}if use_accel:accel_kwargs = {'num_workers': 1,'pin_memory': True,'shuffle': True}train_kwargs.update(accel_kwargs)test_kwargs.update(accel_kwargs)transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])dataset1 = datasets.MNIST('../data', train=True, download=True,transform=transform)dataset2 = datasets.MNIST('../data', train=False,transform=transform)train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)model = Net().to(device)optimizer = optim.Adadelta(model.parameters(), lr=args.lr)scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)for epoch in range(1, args.epochs + 1):train(args, model, device, train_loader, optimizer, epoch)test(model, device, test_loader)scheduler.step()if args.save_model:torch.save(model.state_dict(), "mnist_cnn.pt")if __name__ == '__main__':main()

带命令的Mnist函数

python main.py --help
usage: main.py [-h] [--batch-size N] [--test-batch-size N] [--epochs N] [--lr LR] [--gamma M] [--no-accel][--dry-run] [--seed S] [--log-interval N] [--save-model]PyTorch MNIST Exampleoptional arguments:-h, --help           show this help message and exit--batch-size N       input batch size for training (default: 64)--test-batch-size N  input batch size for testing (default: 1000)--epochs N           number of epochs to train (default: 14)--lr LR              learning rate (default: 1.0)--gamma M            Learning rate step gamma (default: 0.7)--no-accel           disables accelerator--dry-run            quickly check a single pass--seed S             random seed (default: 1)--log-interval N     how many batches to wait before logging training status--save-model         For Saving the current Model

数据目录

注意事项

  • 确保安装了 PyTorch 和 torchvision 库。
  • 可以根据硬件条件调整 batch_size
  • 模型结构和超参数(如学习率)可以根据需求调整。
  • 使用 GPU 可以显著加速训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/web/83857.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python内存互斥与共享深度探索:从GIL到分布式内存的实战之旅

引言:并发编程的内存困局 在开发高性能Python应用时,我遭遇了这样的困境:多进程间需要共享百万级数据,而多线程间又需保证数据一致性。传统解决方案要么性能低下,要么引发竞态条件。本文将深入探讨Python内存互斥与共…

【Unity】使用 C# SerialPort 进行串口通信

索引 一、SerialPort串口通信二、使用SerialPort1.创建SerialPort对象,进行基本配置2.写入串口数据①.写入串口数据的方法②.封装数据 3.读取串口数据①.读取串口数据的方法②.解析数据 4.读取串口数据的时机①.DataReceived事件②.多线程接收数据 5.粘包问题处理 一…

如何写好单元测试:Mock 脱离数据库,告别 @SpringBootTest 的重型启动

如何写好单元测试:Mock 脱离数据库,告别 SpringBootTest 的重型启动 作者:Killian(重庆) — 欢迎各位架构猎头、技术布道者联系我,项目实战丰富,代码稳健,Mock测试爱好者。 技术栈&a…

【DNS】在 Windows 下修改 `hosts` 文件

在 Windows 下修改 hosts 文件,一般用于本地 DNS 覆盖。操作步骤如下(以 Windows 10/11 为例): 1. 以管理员权限打开记事本 点击 开始 → 输入 “记事本”在“记事本”图标上右键 → 选择 以管理员身份运行 如果提示“是否允许此…

共享内存实现进程通信

目录 system V共享内存 共享内存示意图 共享内存函数 shmget函数 shmat函数 shmdt函数 shmctl函数 代码示例 shm头文件 构造函数 获取key值 创建者的构造方式 GetShmHelper 函数 GetShmUseCreate 函数 使用者的构造方式 GetShmForUse 函数 分离附加操作 DetachShm 函数 AttachS…

6月15日星期日早报简报微语报早读

6月15日星期日,农历五月二十,早报#微语早读。 1、证监会拟修订期货公司分类评价:明确扣分标准,优化加分标准; 2、国家考古遗址公园再添10家,全国已评定65家; 3、北京多所高校禁用罗马仕充电宝…

破解关键领域软件测试“三重难题”:安全、复杂性、保密性

在国家关键领域,软件系统正成为核心战斗力的一部分。相比通用软件,关键领域软件在 安全性、复杂性、实时性、保密性 等方面要求极高。如何保障安全合规前提下提升测试效率,确保系统稳定,已成为软件质量保障的核心挑战。 关键领域…

记录一次 Oracle DG 异常停库问题解决过程

记录一次 Oracle DG 异常停库问题解决过程 某医院有以下架构的双节点 Oracle 集群: 节点1:172.16.20.2 节点2:172.16.20.3 SCAN IP:172.16.20.1 DG:172.16.20.1206月12日,医院信息科用户反映无法连接 DG 服务器。 登录 DG 服务…

MySQL使用EXPLAIN命令查看SQL的执行计划

1‌、EXPLAIN 的语法 MySQL 中的 EXPLAIN 命令是用于分析 SQL 查询执行计划的关键工具,它能帮助开发者理解查询的执行方式并找出性能瓶颈‌‌。 语法格式: EXPLAIN <sql语句> 【示例】查询学生表关联班级表的执行计划。 (1)创建班级信息表和学生信息表,并创建索…

Go语言2个协程交替打印

WaitGroup 无缓冲channel waitgroup 用来控制2个协程 Add() 、Done()、Wait() channel用来实现信号的传递和信号的打印 ch1: 用来记录打印的信号 ch2:用来实现信号的传递&#xff0c;实现2个协程的顺序打印 package mainimport ("fmt""sync" )func ma…

微信小程序 路由跳转

路由方式 官方参考文档 wx.switchTab 实现底部导航栏 1.配置信息 app.json"tabBar": {"custom": true,"list": [{"pagePath": "pages/home/index","text": "首页"},{"pagePath": "p…

[Java 基础]正则表达式

正则表达式是一种强大的文本模式匹配工具&#xff0c;它使用一种特殊的语法来描述要搜索或操作的字符串模式。在 Java 中&#xff0c;我们可以使用 java.util.regex包提供的类来处理正则表达式。 :::color3 正则表达式不止 Java 语言提供了相应的功能&#xff0c;很多其他语言…

ArcGIS安装出现1606错误解决办法

问题背景&#xff1a; 由于最近Arcgis10.2打是有些功能不正常退出&#xff0c;比如arctoolbox中的&#xff0c;table to excel 功能&#xff0c;只要一点击&#xff0c;arcgis就报错退出&#xff0c;平常在使用过程中&#xff0c;也经常出现一些莫名其妙的崩溃现象&#xff0c…

wpf 解决DataGridTemplateColumn中width绑定失效问题

感谢酪酪烤奶 提供的Solution 文章目录 感谢酪酪烤奶 提供的Solution使用示例示例代码分析各类交互流程 WPF DataGrid 列宽绑定机制分析整体架构数据流分析1. ViewModel到Slider的绑定2. ViewModel到DataGrid列的绑定a. 绑定代理(BindingProxy)b. 列宽绑定c. 数据流 关键机制详…

语音转文本ASR、文本转语音TTS

ASR Automatic Speech Recognition&#xff0c;语音转文本。 技术难点&#xff1a; 声学多样性 口音、方言、语速、背景噪声会影响识别准确性&#xff1b;多人对话场景&#xff08;如会议录音&#xff09;需要区分说话人并分离语音。 语言模型适配 专业术语或网络新词需要动…

通用embedding模型和通用reranker模型,观测调研

调研Qwen3-Embedding和Qwen3-Reranker 现在有一个的问答库&#xff0c;包括150个QA-pair&#xff0c;用10个query去同时检索问答库的300个questionanswer Embedding模型&#xff0c;query-question的匹配分数 普遍高于 query-answer的匹配分数。比如对于10个query&#xff0c…

基于YOLOv8+Deepface的人脸检测与识别系统

摘要 人脸检测与识别系统是一个集成了先进计算机视觉技术的应用&#xff0c;通过深度学习模型实现人脸检测、识别和管理功能。系统采用双模式架构&#xff1a; ​​注册模式​​&#xff1a;检测新人脸并添加到数据库​​删除模式​​&#xff1a;识别数据库中的人脸并移除匹…

Grdle版本与Android Gradle Plugin版本, Android Studio对应关系

Grdle版本与Android Gradle Plugin版本&#xff0c; Android Studio对应关系 各个 Android Gradle 插件版本所需的 Gradle 版本&#xff1a; https://developer.android.com/build/releases/gradle-plugin?hlzh-cn Maven上发布的Android Gradle Plugin&#xff08;AGP&#x…

用c语言实现简易c语言扫雷游戏

void test(void) {int input 0;do{menu();printf("请选择&#xff1a; >");scanf("%d", &input);switch (input){menu();case 1:printf("扫雷\n");game();break;case 2:printf("退出游戏\n");break;default:printf("输入…

系统辨识的研究生水平读书报告期末作业参考

这是一份关于系统辨识的研究生水平读书报告&#xff0c;内容系统完整、逻辑性强&#xff0c;并深入探讨了理论、方法与实际应用。报告字数超过6000字 从理论到实践&#xff1a;系统辨识的核心思想、方法论与前沿挑战 摘要 系统辨识作为连接理论模型与客观世界的桥梁&#xff…