深度学习——模型训练

以Pytorch自带的手写数据集为例。我们已经构建了一个输入层(28*28),两个隐藏层(128和256),一个输出层(10)的人工神经网络。并且结合非线性激活函数sigmoid定义前向传播的方向。


class NeuralNet(nn.Module):def __init__(self):super().__init__()self.flatten = nn.Flatten()self.hidden1 = nn.Linear(28*28,128)#第一层(28*28为输入的神经原数,128为输出的神经元数)self.hidden2 = nn.Linear(128,256)self.output = nn.Linear(256,10)def forward(self, x):#前向传播,表明数据流向,不能改变函数名,在父类中拥有同名函数,#必须在子类中覆盖该函数,不然会调用父类中的空函数。x = self.flatten(x)x = self.hidden1(x)x = torch.sigmoid(x)#非线性激活函数x = self.hidden2(x)x = torch.sigmoid(x)x = self.output(x)#x = torch.sigmoid(x)#对输出进行非线性激活return x

现在我们需要对模型进行训练

1.准备

创建数据加载器DataLoader加载数据

DataLoader是用来批量加载数据的工具,可以高效地迭代数据集并支持多进程加速。

training_dataloader = DataLoader(dataset=training_data, batch_size=64)

  • dataset=training_data:指定要加载的数据集(通常是 torch.utils.data.Dataset的子类实例)。
  • batch_size=64:每个批次加载 ​​64个样本​​。

以下是 DataLoader的常用参数(你可以在需要时补充):

参数

作用

shuffle

是否打乱数据顺序(训练集通常设为 True,测试集设为 False)。

num_workers

使用多少子进程加载数据(建议设为CPU核心数,如 4)。

drop_last

是否丢弃最后一个不完整的批次(当样本数不能被 batch_size整除时)。

pin_memory

是否将数据锁页(True可加速GPU传输,需配合CUDA使用)。

​数据格式​​:

  • 确保 training_data返回的数据是张量(或可转换为张量)。
  • 如果使用自定义数据集,需实现 __getitem__和 __len__方法。

现在我们加载训练集和测试集。

training_dataloader = DataLoader(dataset=training_data,batch_size=64)
test_dataloader = DataLoader(dataset=test_data,batch_size=64)

如果你拥有gpu可以通过以下代码对使用gpu

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
# print(f'Using {device} device')model = NeuralNet().to(device)
print(model)

model.to(device)将模型的所有参数和缓冲区移动到指定设备(GPU/CPU):

model = NeuralNet().to(device)

​作用​​:

  • 若 device="cuda",模型会在NVIDIA GPU上运行(需安装CUDA版PyTorch)。
  • 若 device="mps",模型会使用Apple Silicon的GPU加速(需macOS 12.3+和M1/M2芯片)。
  • 若 device="cpu",模型在CPU上运行(兼容所有环境但速度较慢)。

分别导入交叉熵损失函数和(随机梯度下降)SGD优化器

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

1. 交叉熵损失函数(CrossEntropyLoss)​

​用途​
  • 适用于多分类任务(如MNIST手写数字识别、CIFAR-10图像分类)。
  • 输入应为​​未归一化的类别分数(logits)​​,无需手动添加Softmax层。
​数学形式​
  • yc​:真实标签的one-hot编码(实际由PyTorch自动处理)。
  • pc​:预测类别的概率(通过Softmax隐式计算)。
​关键注意事项​
  1. ​输入形状​​:
    • 预测值(logits):[batch_size, num_classes]
    • 真实标签:[batch_size](值为类别索引,如0到9)。

 随机梯度下降优化器(SGD)​

​参数解析​

optimizer = torch.optim.SGD( model.parameters(), # 待优化的模型参数

lr=0.01, # 学习率(关键超参数)

momentum=0.9, # 动量(可选,加速收敛)

weight_decay=0.001 # L2正则化(可选,防止过拟合) )

二、训练模型

在进行一系列处理后我们就可以训练模型了

def train(train_dataloader, model, loss_fn, optimizer):#train_dataloader为要训练的数据#model为训练的模型#loss_fn损失函数#optimizer优化器model.train()  # 设置模型为训练模式(启用Dropout/BatchNorm等)batch_size_num = 1  # 初始化批次计数器for X, y in train_dataloader:X, y = X.to(device), y.to(device)  # 数据移动到设备(GPU/CPU)pred = model(X)  # 前向传播(等价于model.forward(X))loss = loss_fn(pred, y)  # 计算损失optimizer.zero_grad()  # 清零梯度(防止累积)loss.backward()        # 反向传播(计算梯度)optimizer.step()       # 更新参数loss_val = loss.item()  # 获取标量损失值if batch_size_num % 100 == 0:print(f'Train loss: {loss_val:>7f}[number: {batch_size_num}]')batch_size_num += 1     # 更新批次计数

导入相应参数开始训练

train(training_dataloader,model,loss_fn,optimizer)

下列为训练结果

有训练好的模型后需要对其进行测试

def test(test_dataloader, model, loss_fn):size = len(test_dataloader.dataset)  # 测试集总样本数num_batches = len(test_dataloader)   # 测试集批次数量model.eval()  # 设置模型为评估模式(关闭Dropout/BatchNorm的随机性)test_loss, correct = 0, 0  # 初始化累计损失和正确预测数with torch.no_grad():  # 禁用梯度计算(节省内存和计算资源)for X, y in test_dataloader:X, y = X.to(device), y.to(device)  # 数据移动到设备pred = model(X)  # 前向传播test_loss += loss_fn(pred, y).item()  # 累加批次损失correct += (pred.argmax(1) == y).type(torch.float).sum().item()  # 累加正确预测数# 计算平均损失和准确率test_loss /= num_batchescorrect /= sizeprint(f'Test result:\n Accuracy: {(100*correct)}%, Avg loss: {test_loss}')

填入参数进行测试

test(test_dataloader,model,loss_fn)

结果如下

通过测试发现我们的结果准确率仅有23.73%,结果并不理想。在这我们可以通过多轮训练来优化模型。

epochs = 50
for epoch in range(epochs):print(f'Epoch {epoch+1}')train(training_dataloader,model,loss_fn,optimizer)
print("Finished Training")
test(test_dataloader,model,loss_fn)

结果如下

上图为50轮和100轮的结果。

通过调整优化器的学习率使lr=1可以将准确率进一步提高

但是上述方法的训练轮次太多,太过消耗时间。我们可以通过改变激活函数或者优化器优化训练模型

由于仅构建了两层隐含层,使用relu激活函数效果不如sigmoid激活函数。

这里我们修改优化器为Adam优化器

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

Adam优化器​​(Adaptive Moment Estimation)是深度学习中广泛使用的自适应学习率优化算法,结合了​​动量(Momentum)​​和​​RMSProp​​的优点,能够自动调整每个参数的学习率。以下是关于Adam优化器的详细解析及在PyTorch中的实践指南:

​1. Adam的核心思想​

  • ​自适应学习率​​:为每个参数维护独立的学习率,根据梯度的一阶矩(均值)和二阶矩(方差)动态调整。
  • ​动量机制​​:保留梯度的指数移动平均值(类似Momentum),加速收敛。
  • ​偏差校正​​:对初始时刻的矩估计进行校正,避免冷启动偏差。

​2. Adam的数学形式​

对于参数 θ和梯度 gt​:

  1. ​计算梯度的一阶矩(均值)和二阶矩(方差)​​:
    • mt​:梯度均值(动量)。
    • vt​:梯度方差(自适应学习率)。
    • β1​,β2​:衰减率(默认0.9和0.999)。
  2. ​偏差校正​​:
  3. 参数更新​​:
    • η:初始学习率。
    • ϵ:极小值(如1e-8)防止除零。

​3. PyTorch中的Adam优化器​

​基本用法​

import torch.optim as optim optimizer = optim.Adam( model.parameters(), # 待优化的模型参数 lr=0.001, # 初始学习率(默认0.001) betas=(0.9, 0.999), # 动量衰减系数(β₁, β₂) eps=1e-08, # 数值稳定项(默认1e-8) weight_decay=0.0 # L2正则化(默认0) )

​关键参数说明​

参数

作用

lr

学习率(通常设为0.001,需根据任务调整)。

betas

一阶矩和二阶矩的衰减率(默认(0.9, 0.999))。

eps

数值稳定项,防止分母为零(通常无需修改)。

weight_decay

L2正则化系数(如0.01),防止过拟合。

amsgrad

是否使用AMSGrad变体(默认False,解决Adam可能不收敛的问题)。


​4. Adam的优缺点​

​优点​
  • 自适应学习率​​:无需手动调整学习率,适合大多数任务。
  • 高效收敛​​:结合动量和自适应学习率,在稀疏梯度场景下表现优异。
  • ​超参数鲁棒性​​:默认参数(如lr=0.001)通常表现良好。
​缺点​

最终结果

  • ​内存占用较高​​:需保存每个参数的 mt​和 vt​。
  • 可能不收敛​​:在某些非凸问题上(如GAN训练),AMSGrad变体可能更稳定。

我们发现仅用10轮就达到了97.06%的准确率

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/94363.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/94363.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用Kiro智能开发PYTHON应用程序

文章目录使用Kiro智能开发PYTHON应用程序1. 什么是KIRO?2. 获取KIRO3. 安装KIRO4. 用KIRO开发智能应用程序6. 推荐阅读使用Kiro智能开发PYTHON应用程序 By JacksonML KIRO是AWS亚马逊云科技旗下的独立AI产品,是用来开发生产级应用程序的AI IDE。 本文简…

UNIX网络编程笔记:高级套接字编程12-19

IPv4与IPv6互操作性:技术解析与实践指南 在网络协议演进进程中,IPv4向IPv6的过渡是绕不开的关键阶段。尽管IPv6凭借海量地址、更优扩展性成为发展方向,但IPv4设备与网络的广泛存在,使得二者的互操作性成为保障网络平滑演进、业务持…

同类软件对比(一):Visual Studio(IDE) VS Visual Studio Code

文章目录前言一、Visual Studio(IDE)是什么?二、Visual Studio Code 是什么?三、两者的相同点四、两者的不同点五、实战选择建议总结前言 Visual Studio 和 Visual Studio Code,它们一个是微软旗下的老牌霸主&#xf…

数据结构初阶:详解单链表(一)

🔥个人主页:胡萝卜3.0 🎬作者简介:C研发方向学习者 📖个人专栏: 《C语言》《数据结构》 《C干货分享》 ⭐️人生格言:不试试怎么知道自己行不行 目录 顺序表问题与思考 正文 一、单链表 1.…

塞尔达传说 旷野之息 PC/手机双端(The Legend of Zelda: Breath of the Wild)免安装中文版

网盘链接: 塞尔达传说 旷野之息 免安装中文版 名称:塞尔达传说 旷野之息 PC/手机双端 免安装中文版 描述:忘记你所知道的关于塞尔达传说游戏的一切。在《塞尔达传说:旷野之息》中步入一个充满发现、探索和冒险的世界&#xff0…

【分享开题答辩过程】一辆摩托车带来的通关副本攻略----《摩托车网上销售系统》开题答辩!!

一、开题陈述 各位评委老师好,我是A同学。 本次我设计与实现的是基于ASP.NET的摩托车网上销售系统,该系统以 MySQL 为后台数据库,主要解决当前社会背景下用户线下看车购车困难的问题,同时顺应摩托车网络营销的发展趋势&#xff…

python + unicorn + xgboost + pytorch 搭建机器学习训练平台遇到的问题

1.背景前段时间,使用 python unicorn xgboost pytorch 写了一个机器学习训练平台的后端服务,根据公司开发需要,需具备两种需求:1. 可以本地加载使用;2.支持web服务,2. 使用本地加载使用2.1 问题针对第一…

Odoo 非标项目型生产行业解决方案:专业、完整、开源

概述您眼前的这张应用蓝图,是由 Odoo 官方金牌服务商——开源智造 (OSCG) 凭借多年在非标项目型制造领域的深厚积累,精心设计的 Odoo 解决方案核心流程图。它不仅体现了我们对行业复杂业务场景的深刻理解,更彰显了我们将先进的管理理念与强大…

OpenAI 开源模型 gpt-oss 是在合成数据上训练的吗?一些合理推测

编者按: OpenAI 首次发布的开源大模型 gpt-oss 系列为何在基准测试中表现亮眼,却在实际应用后发现不如预期? 我们今天为大家带来的这篇文章,作者推测 OpenAI 的新开源模型本质上就是微软 Phi 模型的翻版,采用了相同的合…

Linux / 宝塔面板下 PHP OPcache 完整实践指南

Linux / 宝塔面板下 PHP OPcache 完整实践指南 OPcache 是 PHP 官方提供的字节码缓存扩展,通过缓存 PHP 脚本的编译结果,提高 PHP 执行效率。本文讲解从 检测 → 开启 → 使用 → 清理 → 排查问题 的全流程,同时针对宝塔面板界面不实用或无法…

Linux(从入门到精通)

Linux概述 Linux内核最初只是由芬兰人林纳斯托瓦兹1991年在赫尔辛基大学上学时出于个人爱好而编写的。 Linux特点 首先Linux作为自由软件有两个特点:一是它免费提供源代码,二是爱好者可以根据自己的需要自由修改、复制和发布源码 Linux的各个发行版本 Linux 的发行版说简单…

链表相关题目---19、删除链表的倒数第N个节点

题目链接:删除链表的倒数第N个节点 这道题 很常规的思路就是 先拷贝两次头结点 然后一个先走N步 然后同时开始走,直到先走N步的节点为空后,就停止,此时另一个没提前走的节点的下一个就是要删除的节点。不过需要注意的是&#xff0…

Vue工具类使用指南:实用函数与全局组件安装

概述在Vue项目开发中,我们经常需要一些通用的工具函数来处理路径转换、链接判断、数据格式化等任务。本文将介绍一个实用的Vue工具类,包含多种常用功能,并演示如何在项目中使用它们。工具函数详解1. 路径转驼峰命名import { pathToCamel } fr…

​Visual Studio + UE5 进行游戏开发的常见故障问题解决

从零开始,学习 虚幻引擎5(UE5),开始游戏开发之旅! 本文章仅提供学习,切勿将其用于不法手段! 有些项目在 Visual Studio 的 Unreal Engine 集成配置界面中,涉及 ​Unreal Engine 与 V…

MiniCPM-V4.0开源并上线魔乐社区,多模态能力进化,手机可用,还有最全CookBook!

今天,面壁小钢炮新一代多模态模型 MiniCPM-V 4.0 正式开源。依靠 4B 参数,在 OpenCompass、OCRBench、MathVista 等多个榜单上取得了同级 SOTA 成绩,且 实现了在手机上稳定、丝滑运行。此外,面壁团队也正式开源了 推理部署工具 Mi…

FCT/ATE/ICT通用测试上位机软件

在当今智能制造与电子产品快速迭代的背景下,功能测试(FCT)已成为确保产品质量的关键环节。然而,传统的测试上位机往往存在扩展困难、功能固化、二次开发成本高等问题。为此,我们提出一款模块化、可扩展、可脚本化的 FC…

IndexTTS介绍与部署(B站开源的工业级语音合成模型)

语音合成效果非常好,可作为自己日常文本转语音使用工具! 软件介绍 IndexTTS 是由哔哩哔哩(B 站)开源的工业级可控高效零样本文本转语音(TTS)系统,基于 XTTS 和 Tortoise 构建,采用 …

uniApp对接实人认证

前端代码部分<template><view class"wrap"><view class"box"><view class"item flex-row align-items-center space-between"><view class"name"><text style"color:#FF4D4D">*</te…

pytest 并发执行用例(基于受限的测试资源)

概要 本文主要介绍了如何在测试资源&#xff08;被测对象&#xff09;受限的情况下&#xff0c;使用 pytest 进行并发测试以减少总体测试时间的方法和过程。 背景 在软件开发过程中&#xff0c;我们通常使用测试用例来持续保证软件的质量&#xff08;例如&#xff0c;确保关…

结构化智能编程:用树形向量存储重构AI代码理解范式

结构化智能编程:用树形向量存储重构AI代码理解范式 告别暴力embedding,通过分层存储策略让AI精准理解百万行代码库 在AI编程助手日益普及的今天,开发者面临一个新的困境:当项目规模达到数万甚至数百万行代码时,传统的暴力向量化方法不仅效率低下,而且往往导致AI理解偏差。…