用 PyTorch 实现一个简单的神经网络:从数据到预测

PyTorch 是目前最流行的深度学习框架之一,以其灵活性和易用性受到开发者的喜爱。本文将带你从零开始,用 PyTorch 实现一个简单的神经网络,用于解决经典的 MNIST 手写数字分类问题。我们将涵盖数据准备、模型构建、训练和预测的完整流程,并提供可运行的代码示例。

1. 环境准备

首先,确保你已安装 PyTorch 和相关依赖。本例使用 Python 3.8+ 和 PyTorch。你可以通过以下命令安装:

pip install torch torchvision

我们将使用 MNIST 数据集,它包含 28x28 像素的手写数字图像(0-9),目标是训练一个神经网络来识别这些数字。

2. 数据准备

MNIST 数据集可以通过 PyTorch 的 torchvision 模块直接加载。我们需要将数据加载为张量,并进行归一化处理以加速训练。

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms# 定义数据预处理:将图像转换为张量并归一化
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))  # MNIST 的均值和标准差
])# 加载 MNIST 数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

代码说明

  • transforms.ToTensor() 将图像转换为 PyTorch 张量,并将像素值从 [0, 255] 缩放到 [0, 1]。

  • transforms.Normalize 标准化数据,加速梯度下降收敛。

  • DataLoader 用于批量加载数据,batch_size=64 表示每次处理 64 张图像。

3. 构建神经网络

我们将定义一个简单的全连接神经网络,包含两个隐藏层,适合处理 MNIST 的分类任务。

import torch.nn as nnclass SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.flatten = nn.Flatten()  # 将 28x28 图像展平为 784 维向量self.fc1 = nn.Linear(28 * 28, 128)  # 第一个全连接层self.relu = nn.ReLU()  # 激活函数self.fc2 = nn.Linear(128, 64)  # 第二个全连接层self.fc3 = nn.Linear(64, 10)   # 输出层,10 个类别(0-9)def forward(self, x):x = self.flatten(x)x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.fc3(x)return x# 实例化模型
model = SimpleNN()

代码说明

  • nn.Module 是 PyTorch 模型的基类,自定义模型需要继承它。

  • forward 方法定义了前向传播的计算流程。

  • 网络结构:输入层 (784) → 隐藏层1 (128) → ReLU → 隐藏层2 (64) → ReLU → 输出层 (10)。

4. 定义损失函数和优化器

我们使用交叉熵损失(适合分类任务)和 Adam 优化器来训练模型。

import torch.optim as optim# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

代码说明

  • nn.CrossEntropyLoss 结合了 softmax 和负对数似然损失,适合多分类任务。

  • Adam 优化器以 0.001 的学习率更新模型参数。

5. 训练模型

接下来,我们训练模型 5 个 epoch,观察损失变化。

def train(model, train_loader, criterion, optimizer, epochs=5):model.train()  # 切换到训练模式for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:optimizer.zero_grad()  # 清零梯度outputs = model(images)  # 前向传播loss = criterion(outputs, labels)  # 计算损失loss.backward()  # 反向传播optimizer.step()  # 更新参数running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")# 开始训练
train(model, train_loader, criterion, optimizer)

代码说明

  • model.train() 启用训练模式(影响 dropout 和 batch norm 等层)。

  • 每次迭代清零梯度、计算损失、反向传播并更新参数。

  • 每轮 epoch 打印平均损失。

6. 测试模型

训练完成后,我们在测试集上评估模型的准确率。

def test(model, test_loader, criterion):model.  # 切换到评估模式correct = 0total = 0test_loss = 0.0with torch.no_grad():  # 禁用梯度计算for images, labels in test_loader:outputs = model(images)loss = criterion(outputs, labels)test_loss += loss.item()_, predicted = torch.max(outputs.data, 1)  # 获取预测类别total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint(f"Test Loss: {test_loss/len(test_loader):.4f}, Accuracy: {accuracy:.2f}%")# 测试模型
test(model, test_loader, criterion)

代码说明

  • model. 切换到评估模式,禁用 dropout 等。

  • 使用 torch.no_grad() 减少内存消耗。

  • 计算测试集的损失和准确率。

7. 进行预测

最后,我们用训练好的模型对单张图像进行预测。

import matplotlib.pyplot as plt# 获取一张测试图像
dataiter = iter(test_loader)
images, labels = next(dataiter)
image, label = images[0], labels[0]# 预测
model.
with torch.no_grad():output = model(image.unsqueeze(0))  # 增加 batch 维度_, predicted = torch.max(output, 1)# 显示图像和预测结果
plt.imshow(image.squeeze(), cmap='gray')
plt.title(f"Predicted: {predicted.item()}, Actual: {label.item()}")
plt.savefig('prediction.png')  # 保存图像

代码说明

  • 从测试集取一张图像,调用模型进行预测。

  • 使用 Matplotlib 显示图像及其预测结果,保存为 PNG 文件。

8. 完整代码

以下是完整的可运行代码,整合了上述所有步骤:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt# 数据准备
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)# 定义模型
class SimpleNN(nn.Module):def __init__(self):super(SimpleNN, self).__init__()self.flatten = nn.Flatten()self.fc1 = nn.Linear(28 * 28, 128)self.relu = nn.ReLU()self.fc2 = nn.Linear(128, 64)self.fc3 = nn.Linear(64, 10)def forward(self, x):x = self.flatten(x)x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.fc3(x)return x# 实例化模型、损失函数和优化器
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练函数
def train(model, train_loader, criterion, optimizer, epochs=5):model.train()for epoch in range(epochs):running_loss = 0.0for images, labels in train_loader:optimizer.zero_grad()outputs = model(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")# 测试函数
def test(model, test_loader, criterion):model.correct = 0total = 0test_loss = 0.0with torch.no_grad():for images, labels in test_loader:outputs = model(images)loss = criterion(outputs, labels)test_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = 100 * correct / totalprint(f"Test Loss: {test_loss/len(test_loader):.4f}, Accuracy: {accuracy:.2f}%")# 训练和测试
train(model, train_loader, criterion, optimizer)
test(model, test_loader, criterion)# 预测单张图像
dataiter = iter(test_loader)
images, labels = next(dataiter)
image, label = images[0], labels[0]
model.
with torch.no_grad():output = model(image.unsqueeze(0))_, predicted = torch.max(output, 1)
plt.imshow(image.squeeze(), cmap='gray')
plt.title(f"Predicted: {predicted.item()}, Actual: {label.item()}")
plt.savefig('prediction.png')

9. 总结

通过本文,可以了解到如何用 PyTorch 实现一个简单的神经网络,包括:

  • 加载和预处理 MNIST 数据集。

  • 构建一个全连接神经网络。

  • 使用交叉熵损失和 Adam 优化器进行训练。

  • 在测试集上评估模型性能。

  • 对单张图像进行预测并可视化结果。

这个模型虽然简单,但在 MNIST 数据集上通常能达到 95% 以上的准确率。可以进一步尝试调整网络结构(如增加层数)、优化超参数(如学习率)或使用卷积神经网络(CNN)来提升性能。希望这篇文章对你理解 PyTorch 和深度学习有所帮助!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/91967.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/91967.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

四级页表通俗讲解与实践(以 64 位 ARM Cortex-A 为例)

📖 🎥 B 站博文精讲视频:点击链接,配合视频深度学习 四级页表通俗讲解与实践(以 64 位 ARM Cortex-A 为例) 本文面向希望彻底理解现代 64 位架构下四级页表的开发者,结合 ARM Cortex-A 系列处理…

AI模型整合包上线!一键部署ComfyUI,2.19TB模型全解析

最近体验了AIStarter平台上线的AI模型整合包,包含2.19TB ComfyUI大模型,整合市面主流模型,一键部署ComfyUI,省去重复下载烦恼!以下是使用心得和部署步骤,适合AI开发者参考。工具亮点这款AI模型整合包由熊哥…

灰色优选模型及算法MATLAB代码

电子装备试验方案优选是一个典型的多属性决策问题,通常涉及指标复杂、信息不完整、数据量少且存在不确定性的特点。灰色系统理论(Grey System Theory)特别擅长处理“小样本、贫信息”的不确定性问题,因此非常适合用于此类方案的优…

AI框架工具FastRTC快速上手6——视频流案例之物体检测(下)

一 前言 上一篇,我们实现了用YOLO对图片上的物体进行检测,并在图片上框出具体的对象并打出标签。但只是应用在单张图片,且还没用上FastRTC。 本篇,我们希望结合FastRTC的能力,实现基于YOLO的实时视频流的物体检测。 本篇文字将不会太多。学习完本篇,对比前面的文章,你…

PHP常见中高面试题汇总

一、 PHP部分 1、PHP如何实现静态化 PHP的静态化分为:纯静态和伪静态。其中纯静态又分为:局部纯静态和全部纯静态。 PHP伪静态:利用Apache mod_rewrite实现URL重写的方法; PHP纯静态,就是生成HTML文件的方式&#xff0…

基于Java AI(人工智能)生成末日题材的实践

Java AI 生成《全球末日》文章的实例 使用Java结合AI技术生成《全球末日》题材的文章可以通过多种方式实现,包括调用预训练模型、使用自然语言处理库或结合生成式AI框架。以下是30个实例的生成方法和示例代码片段。 调用预训练模型(如GPT-3或GPT-4) 使用OpenAI API生成末日…

针对软件定义车载网络的动态服务导向机制

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 做到欲望极简,了解自己的真实欲望,不受外在潮流的影响,不盲从,不跟风。把自己的精力全部用在自己。一是去掉多余,凡事找规律,基础是诚信;二是…

Pytorch实现婴儿哭声检测和识别

Pytorch实现婴儿哭声检测和识别 目录 Pytorch实现婴儿哭声检测识别 1. 项目说明 2. 数据说明 (1)婴儿哭声语音数据集 (2)自定义数据集 3. 模型训练 (1)项目安装 (2)准备Tra…

海信IP810N/海信IP811N_海思MV320-安卓9.0主板-TTL烧录包-可救砖

海信IP810N/海信IP811N_海思MV320处理器-安卓9主板-TTL烧录包-可救砖准备工作:TTL线自备跑码工具【putty跑码中文版】路径:【工具大全】-【putty跑码中文版】测试跑码以后将跑码窗口关闭;然后到下方下载烧录工具并大致看下教程烧录…

Go 中的 interface{} 与 Java 中的 Object:相似之处与本质差异

在软件系统开发中,“通用类型”的处理是各语言设计中不可忽视的一部分。Java 使用 Object,Go 使用 interface{},它们都可以容纳任意类型的值,是实现动态行为或通用容器的基础类型。然而,虽然两者在使用层面看似相似&am…

Docker-07.Docker基础-数据卷挂载

一.案例首先我们通过一则案例来引出问题。我们要修改nginx容器内的html目录下的index.html文件,并且要将静态资源部署到nginx的html目录,就要首先知道该html目录的所在位置。我们首先查看nginx镜像的帮助文档,这里就是将有关静态资源目录的&a…

数据结构(三)双向链表

一、什么是 make 工具?make 是一个自动化构建工具,主要用于管理 C/C 项目的编译和链接过程。它通过读取 Makefile 文件中定义的规则,自动判断哪些文件被修改,并仅重新编译这些部分,从而大幅提高构建效率。二、什么是 M…

如何在没有iCloud的情况下将联系人转移到新iPhone?

升级到新 iPhone 后,设置已完成,想在不使用 iCloud 的情况下将联系人从 iPhone 转移到 iPhone 吗?别担心。还有其他 5 种方法可以帮助您轻松地将联系人转移到新 iPhone。这样,您就无需再次重置新设备了。第 1 部分:如何…

SpringBoot3.x入门到精通系列:4.2 整合 Kafka 详解

SpringBoot 3.x 整合 Kafka 详解 🎯 Kafka简介 Apache Kafka是一个分布式流处理平台,主要用于构建实时数据管道和流应用程序。它具有高吞吐量、低延迟、可扩展性和容错性等特点。 核心概念 Producer: 生产者,发送消息到Kafka集群Consumer: 消…

Android audio之 AudioDeviceInventory

1. 类介绍 AudioDeviceInventory 是 Android 音频系统中的一个核心类,位于 frameworks/base/services/core/java/com/android/server/audio/ 路径下。它负责 管理所有音频设备的连接状态,包括设备的添加、移除、状态更新以及策略应用。 设备连接状态管理:记录所有已连接的音…

系统设计入门:成为更优秀的工程师

系统设计入门指南 动机 现在你可以学习如何设计大规模系统,为系统设计面试做准备。本指南包含的是一个有组织的资源集合,旨在帮助你了解如何构建可扩展的系统。 学习设计大规模系统 学习如何设计可扩展系统将帮助你成为更优秀的工程师。系统设计是一个…

Pandas数据分析工具基础

文章目录 0. 学习目标 1. Pandas的数据结构分析 1.1 Series - 序列 1.1.1 Series概念 1.1.2 Series类的构造方法 1.1.3 创建Series对象 1.1.3.1 基于列表创建Series对象 1.1.3.2 基于字典创建Series对象 1.1.4 获取Series对象的数据 1.1.5 Series对象的运算 1.1.6 增删Series对…

大模型——Qwen开源会写中文的生图模型Qwen-Image

Qwen开源会写中文的生图模型Qwen-Image 会写中文,这基本上是开源图片生成模型的独一份了。 这次开源的Qwen-Image 的最大卖点是“像素级文字生成”。它能直接在像素空间内完成排版:从小字注脚到整版海报均可清晰呈现,且同时支持英文字母与汉字。 以下图片均来自官网的生成…

大模型知识库(1)京东云 JoyAgent介绍

一、核心定位​ JoyAgent 是京东云推出的 ​首个 100% 开源的企业级多智能体平台,定位为“可插拔的智能发动机”,旨在通过开箱即用的产品级能力,降低企业部署智能体的门槛。其特点包括: ​完整开源​:前端&#xff0…

PowerShell 入门2: 使用帮助系统

PowerShell 入门 2:使用帮助系统 🎯 一、认识 PowerShell 帮助系统 1. 使用 Get-Help 查看命令说明 Get-Help Get-Service或使用别名: gsv2. 更新帮助系统 Update-Help3. 搜索包含关键词的命令(模糊搜索) Help *log*&a…