2025-05-28 Python深度学习8——优化器

文章目录

  • 1 工作原理
  • 2 常见优化器
    • 2.1 SGD
    • 2.2 Adam
  • 3 优化器参数
  • 4 学习率
  • 5 使用最佳实践

本文环境:

  • Pycharm 2025.1
  • Python 3.12.9
  • Pytorch 2.6.0+cu124

​ 优化器 (Optimizer) 是深度学习中的核心组件,负责根据损失函数的梯度来更新模型的参数,使模型能够逐步逼近最优解。在 PyTorch 中,优化器通过torch.optim模块提供。

​ Pytorch 链接:https://docs.pytorch.org/docs/stable/optim.html。

1 工作原理

​ 优化器的工作流程如下:

  1. 计算损失函数的梯度 (通过backward()方法)。
  2. 根据梯度更新模型参数 (通过step()方法)。
  3. 清除之前的梯度 (通过zero_grad()方法)。
result_loss.backward()  # 计算梯度
optim.step()           # 更新参数
optim.zero_grad()      # 清除梯度

2 常见优化器

​ PyTorch 提供多种优化器,以 SGD 和 Adam 为例。

2.1 SGD

​ 基础优化器,可以添加动量 (momentum) 来加速收敛。

image-20250528111156055
参数类型默认值作用使用建议
paramsiterable-待优化参数必须传入model.parameters()或参数组字典,支持分层配置
lrfloat1e-3学习率控制参数更新步长,SGD常用0.01-0.1,Adam常用0.001
momentumfloat0动量因子加速梯度下降(Adam内置动量,无需单独设置)
dampeningfloat0动量阻尼抑制动量震荡(仅当momentum>0时生效)
weight_decayfloat0L2正则化防止过拟合,AdamW建议0.01-0.1
nesterovboolFalseNesterov动量改进版动量法(需momentum>0)
maximizeboolFalse最大化目标默认最小化损失,True时改为最大化
foreachboolNone向量化实现CUDA下默认开启,内存不足时禁用
differentiableboolFalse可微优化允许优化器步骤参与自动微分(影响性能)
fusedboolNone融合内核CUDA加速,支持float16/32/64/bfloat16

2.2 Adam

image-20250528111450759
  • 特点:自适应矩估计,结合了动量法和 RMSProp 的优点。
  • 优点:通常收敛速度快,对学习率不太敏感。
参数名称类型默认值作用使用建议
paramsiterable-需要优化的参数(如model.parameters()必须传入,支持参数分组配置
lrfloat1e-3学习率(控制参数更新步长)推荐0.001起调,CV任务可尝试0.0001-0.01
betas(float, float)(0.9, 0.999)梯度一阶矩(β₁)和二阶矩(β₂)的衰减系数保持默认,除非有特殊需求
epsfloat1e-8分母稳定项(防止除以零)混合精度训练时可增大至1e-6
weight_decayfloat0L2正则化系数推荐0.01-0.1(使用AdamW时更有效)
decoupled_weight_decayboolFalse启用AdamW模式(解耦权重衰减)需要权重衰减时建议设为True
amsgradboolFalse使用AMSGrad变体(解决收敛问题)训练不稳定时可尝试启用
foreachboolNone使用向量化实现加速(内存消耗更大)CUDA环境下默认开启,内存不足时禁用
maximizeboolFalse最大化目标函数(默认最小化)特殊需求场景使用
capturableboolFalse支持CUDA图捕获仅在图捕获场景启用
differentiableboolFalse允许通过优化器步骤进行自动微分高阶优化需求启用(性能下降)
fusedboolNone使用融合内核实现(需CUDA)支持float16/32/64时启用可加速

3 优化器参数

​ 所有优化器都接收两个主要参数:

  1. params:要优化的参数,通常是model.parameters()
  2. lr:学习率(learning rate),控制参数更新的步长。

​ 其他常见参数:

  • weight_decay:L2 正则化系数,防止过拟合。
  • momentum:动量因子,加速 SGD 在相关方向的收敛。
  • betas(Adam 专用):用于计算梯度及其平方的移动平均的系数。

4 学习率

​ 学习率是优化器中最重要的超参数之一。

  • 太大:可能导致震荡或发散。
  • 太小:收敛速度慢。
  • 常见策略:
    • 固定学习率 (如代码中的 0.01)。
    • 学习率调度器 (Learning Rate Scheduler) 动态调整。

5 使用最佳实践

  1. 梯度清零:每次迭代前调用optimizer.zero_grad(),避免梯度累积。
  2. 参数更新顺序:先backward()step()
  3. 学习率选择:可以从默认值开始 (如 Adam 的 0.001),然后根据效果调整。
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoaderdataset = torchvision.datasets.CIFAR10(root='./dataset',  # 保存路径train=False,  # 是否为训练集transform=torchvision.transforms.ToTensor(),  # 转换为张量download=True  # 是否下载
)dataloader = DataLoader(dataset, batch_size=1)class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):return self.model(x)loss = nn.CrossEntropyLoss()
model = MyModel()
torch.optim.Adam(model.parameters(), lr=0.01)
optim = torch.optim.SGD(model.parameters(), lr=0.01)for epoch in range(20):running_loss = 0.0# 遍历dataloader中的数据for data in dataloader:# 获取数据和标签imgs, targets = data# 使用模型对数据进行预测output = model(imgs)# 计算预测结果和真实标签之间的损失result_loss = loss(output, targets)# 将梯度置零optim.zero_grad()# 反向传播计算梯度result_loss.backward()# 更新模型参数optim.step()running_loss += result_lossprint(f'第 {epoch + 1} 轮的损失为 {running_loss}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/diannao/84961.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Web攻防-SQL注入增删改查HTTP头UAXFFRefererCookie无回显报错

知识点: 1、Web攻防-SQL注入-操作方法&增删改查 2、Web攻防-SQL注入-HTTP头&UA&Cookie 3、Web攻防-SQL注入-HTTP头&XFF&Referer 案例说明: 在应用中,存在增删改查数据的操作,其中SQL语句结构不一导致注入语句…

Windows MongoDB C++驱动安装

MongoDB驱动下载 MongoDB 官网MongoDB C驱动程序入门MongoDB C驱动程序入门 安装环境 安装CMAKE安装Visual Studio 编译MongoDB C驱动 C驱动依赖C驱动,需要先编译C驱动 下载MongoDB C驱动源码 打开CMAKE(cmake-gui) 选择源码及输出路径,然后点击configure …

使用 C/C++ 和 OpenCV 调用摄像头

使用 C/C 和 OpenCV 调用摄像头 📸 OpenCV 是一个强大的计算机视觉库,它使得从摄像头捕获和处理视频流变得非常简单。本文将指导你如何使用 C/C 和 OpenCV 来调用摄像头、读取视频帧并进行显示。 准备工作 在开始之前,请确保你已经正确安装…

使用微软最近开源的WSL在Windows上优雅的运行Linux

install wsl https://github.com/microsoft/WSL/releases/download/2.4.13/wsl.2.4.13.0.x64.msi install any distribution from microsoft store, such as kali-linux from Kali office website list of distribution PS C:\Users\50240> wsl -l -o 以下是可安装的有…

Win11安装Dify

1、打开Virtual Machine Platform功能 电脑系统为:Windows 11 家庭中文版24H2版本。 打开控制面板,点击“程序”,点击“启用或关闭Windows功能”。 下图标记的“Virtual Machine Platform”、“适用于 Linux 的 Windows 子系统”、“Windows…

C++模板类深度解析与气象领域应用指南

支持开源,为了更好的后来者 ————————————————————————————————————————————————————By 我说的 C模板类深度解析与气象领域应用指南 一、模板类核心概念 1.1 模板类定义 模板类是C泛型编程的核心机制&#x…

MongoDB(七) - MongoDB副本集安装与配置

文章目录 前言一、下载MongoDB1. 下载MongoDB2. 上传安装包3. 创建相关目录 二、安装配置MongoDB1. 解压MongoDB安装包2. 重命名MongoDB文件夹名称3. 修改配置文件4. 分发MongoDB文件夹5. 配置环境变量6. 启动副本集7. 进入MongoDB客户端8. 初始化副本集8.1 初始化副本集8.2 添…

mac笔记本如何快捷键截图后自动复制到粘贴板

前提:之前只会进行部分区域截图操作(commandshift4)操作,截图后发现未自动保存在剪贴板,还要进行一步手动复制到剪贴板的操作。 mac笔记本如何快捷键截图后自动复制到粘贴板 截取 Mac 屏幕的一部分并将其自动复制到剪…

WPF 按钮点击音效实现

WPF 按钮点击音效实现 下面我将为您提供一个完整的 WPF 按钮点击音效实现方案&#xff0c;包含多种实现方式和高级功能&#xff1a; 完整实现方案 MainWindow.xaml <Window x:Class"ButtonClickSound.MainWindow"xmlns"http://schemas.microsoft.com/win…

C++ list基础概念、list初始化、list赋值操作、list大小操作、list数据插入

list基础概念&#xff1a;list中的每一部分是一个Node&#xff0c;由三部分组成&#xff1a;val、next、prev&#xff08;指向上一个节点的指针&#xff09; list初始化的代码&#xff0c;见下 #include<iostream> #include<list>using namespace std;void printL…

【Pandas】pandas DataFrame equals

Pandas2.2 DataFrame Reindexing selection label manipulation 方法描述DataFrame.add_prefix(prefix[, axis])用于在 DataFrame 的行标签或列标签前添加指定前缀的方法DataFrame.add_suffix(suffix[, axis])用于在 DataFrame 的行标签或列标签后添加指定后缀的方法DataFram…

【ROS2】创建单独的launch包

【ROS】郭老二博文之:ROS目录 1、简述 项目中,可以创建单独的launch包来管理所有的节点启动 2、示例 1)创建launch包(python) ros2 pkg create --build-type ament_python laoer_launch --license Apache-2.02)创建启动文件 先创建目录:launch 在目录中创建文件:r…

GitHub 趋势日报 (2025年05月23日)

本日报由 TrendForge 系统生成 https://trendforge.devlive.org/ &#x1f310; 本日报中的项目描述已自动翻译为中文 &#x1f4c8; 今日整体趋势 Top 10 排名项目名称项目描述今日获星总星数语言1All-Hands-AI/OpenHands&#x1f64c;开放式&#xff1a;少代码&#xff0c;做…

鸿蒙OSUniApp 实现的数据可视化图表组件#三方框架 #Uniapp

UniApp 实现的数据可视化图表组件 前言 在移动互联网时代&#xff0c;数据可视化已成为产品展示和决策分析的重要手段。无论是运营后台、健康监测、还是电商分析&#xff0c;图表组件都能让数据一目了然。UniApp 作为一款优秀的跨平台开发框架&#xff0c;支持在鸿蒙&#xf…

[ctfshow web入门] web124

信息收集 error_reporting(0); //听说你很喜欢数学&#xff0c;不知道你是否爱它胜过爱flag if(!isset($_GET[c])){show_source(__FILE__); }else{//例子 c20-1$content $_GET[c];// 长度不允许超过80个字符if (strlen($content) > 80) {die("太长了不会算");}/…

Vue 技术文档

一、引言 Vue 是一款用于构建用户界面的渐进式 JavaScript 框架&#xff0c;具有易上手、高性能、灵活等特点&#xff0c;能够帮助开发者快速开发出响应式的单页面应用。本技术文档旨在全面介绍 Vue 的相关技术知识&#xff0c;为开发人员提供参考和指导。 二、环境搭建 2.1…

Nodejs+http-server 使用 http-server 快速搭建本地图片访问服务

在开发过程中&#xff0c;我们经常需要临时查看或分享本地的图片资源&#xff0c;比如设计稿、截图、素材等。虽然可以通过压缩发送&#xff0c;但效率不高。本文将教你使用 Node.js 的一个轻量级工具 —— http-server&#xff0c;快速搭建一个本地 HTTP 图片预览服务&#xf…

通义智文开源QwenLong-L1: 迈向长上下文大推理模型的强化学习

&#x1f389; 动态 2025年5月26日: &#x1f525; 我们正式发布&#x1f917;QwenLong-L1-32B——首个采用强化学习训练、专攻长文本推理的LRM模型。在七项长文本文档问答基准测试中&#xff0c;QwenLong-L1-32B性能超越OpenAI-o3-mini和Qwen3-235B-A22B等旗舰LRM&#xff0c…

学习如何设计大规模系统,为系统设计面试做准备!

前言 在当今快速发展的技术时代&#xff0c;系统设计能力已成为衡量一名软件工程师专业素养的重要标尺。随着云计算、大数据、人工智能等领域的兴起&#xff0c;构建高性能、可扩展且稳定的系统已成为企业成功的关键。然而&#xff0c;对于许多工程师而言&#xff0c;如何有效…

Python生成ppt(python-pptx)N问N答(如何绘制一个没有背景的矩形框;如何绘制一个没有背景的矩形框)

文章目录 [toc]1. **如何安装python-pptx库&#xff1f;**2. **如何创建一个空白PPT文件&#xff1f;**3. **如何添加幻灯片并设置布局&#xff1f;**4. **如何添加文本内容&#xff1f;**5. **如何插入图片&#xff1f;**6. **如何设置动画和转场效果&#xff1f;**9. **如何绘…