[预备知识] 5. 优化理论(一)

优化理论

梯度下降(Gradient Descent)

数学原理与可视化

梯度下降是优化领域的基石算法,其核心思想是沿负梯度方向迭代更新参数。数学表达式为:
θ t + 1 = θ t − α ∇ θ J ( θ t ) \theta_{t+1} = \theta_t - \alpha \nabla_\theta J(\theta_t) θt+1=θtαθJ(θt)
其中:

  • α \alpha α:学习率,控制步长
  • ∇ θ J \nabla_\theta J θJ:损失函数关于参数的梯度

几何解释:在三维空间中,梯度下降如同沿着最陡下降方向下山。二维可视化展示参数更新路径:

import matplotlib.pyplot as plt
import numpy as np# 定义二次函数及其梯度
def f(x): return x**2
def grad(x): return 2*x# 梯度下降轨迹可视化
x_path = []
x = 2.0
lr = 0.1
for _ in range(20):x_path.append(x)x -= lr * grad(x)# 绘制函数曲线和更新路径
xs = np.linspace(-2, 2, 100)
plt.figure(figsize=(10,6))
plt.plot(xs, f(xs), label="f(x) = x²")
plt.scatter(x_path, [f(x) for x in x_path], c='red', s=50, zorder=3)
plt.plot(x_path, [f(x) for x in x_path], 'r--', label="gradient descent path")
plt.title("梯度下降在二次函数上的优化轨迹", fontsize=14)
plt.xlabel("x", fontsize=12)
plt.ylabel("f(x)", fontsize=12)
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

在这里插入图片描述

学习率对比实验

lrs = [0.01, 0.1, 0.5]  # 不同学习率plt.figure(figsize=(12,6))
for lr in lrs:x = 2.0path = []for _ in range(20):path.append(x)x -= lr * grad(x)plt.plot(path, label=f"lr={lr}")plt.title("不同学习率对收敛速度的影响", fontsize=14)
plt.xlabel("Number of iterations", fontsize=12)
plt.ylabel("Parameter value", fontsize=12)
plt.axhline(0, color='black', linestyle='--')
plt.legend()
plt.grid(True, alpha=0.3)

在这里插入图片描述


随机梯度下降(Stochastic Gradient Descent, SGD)

算法原理

与传统梯度下降的对比:

方法梯度计算内存需求收敛性适用场景
批量梯度下降全数据集稳定小数据集
SGD单样本震荡在线学习
小批量SGD批量样本平衡最常见

数学表达式:
θ t + 1 = θ t − α ∇ θ J ( θ t ; x ( i ) , y ( i ) ) \theta_{t+1} = \theta_t - \alpha \nabla_\theta J(\theta_t; x^{(i)}, y^{(i)}) θt+1=θtαθJ(θt;x(i),y(i))

实际应用示例(MNIST分类)

import torchvision
from torch.utils.data import DataLoader# 数据准备
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
train_set = torchvision.datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)# 模型定义
model = torch.nn.Sequential(torch.nn.Flatten(),torch.nn.Linear(784, 10)
)# 优化器配置
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)# 训练循环
losses = []
for epoch in range(5):for batch_idx, (data, target) in enumerate(train_loader):optimizer.zero_grad()output = model(data)loss = torch.nn.functional.cross_entropy(output, target)loss.backward()optimizer.step()# 记录损失losses.append(loss.item())# 绘制损失曲线
plt.figure(figsize=(12,6))
plt.plot(losses, alpha=0.6)
plt.title("SGD在MNIST分类任务中的损失曲线", fontsize=14)
plt.xlabel("Number of iterations", fontsize=12)
plt.ylabel("Cross-entropy loss", fontsize=12)
plt.grid(True, alpha=0.3)

在这里插入图片描述


动量法(Momentum)

物理类比与数学表达

动量法引入速度变量 v v v,模拟物体运动惯性:

更新规则:
v t + 1 = β v t − α ∇ θ J ( θ t ) θ t + 1 = θ t + v t + 1 \begin{aligned} v_{t+1} &= \beta v_t - \alpha \nabla_\theta J(\theta_t) \\ \theta_{t+1} &= \theta_t + v_{t+1} \end{aligned} vt+1θt+1=βvtαθJ(θt)=θt+vt+1

其中 β ∈ [ 0 , 1 ) \beta \in [0,1) β[0,1)为动量系数,典型值为0.9

对比实验

def optimize_with_momentum(lr=0.01, beta=0.9):x = torch.tensor([2.0], requires_grad=True)velocity = 0path = []for _ in range(20):path.append(x.item())loss = x**2loss.backward()with torch.no_grad():velocity = beta * velocity - lr * x.gradx += velocityx.grad.zero_()return path# 运行对比实验
paths = {'普通SGD': optimize_with_momentum(beta=0),'动量法(beta=0.9)': optimize_with_momentum()
}# 可视化对比
plt.figure(figsize=(12,6))
for label, path in paths.items():plt.plot(path, marker='o', linestyle='--', label=label)plt.title("动量法与普通SGD收敛对比", fontsize=14)
plt.xlabel("Number of iterations", fontsize=12)
plt.ylabel("Parameter value", fontsize=12)
plt.axhline(0, color='black', linestyle='--')
plt.legend()
plt.grid(True, alpha=0.3)

在这里插入图片描述


算法选择指南

算法优点缺点适用场景
梯度下降稳定收敛计算成本高小规模数据集
SGD内存需求低收敛路径震荡在线学习、大规模数据
动量法加速收敛、抑制震荡需调参动量系数高维非凸优化

实践建议

  1. 学习率设置:从3e-4开始尝试,按数量级调整
  2. 批量大小:通常选择2的幂次(32, 64, 128)
  3. 动量系数:默认0.9,对RNN可尝试0.99
  4. 学习率衰减:配合StepLR或CosineAnnealing使用效果更佳
# 最佳实践示例:带学习率衰减的动量SGD
optimizer = torch.optim.SGD(model.parameters(),lr=0.1,momentum=0.9,weight_decay=1e-4  # L2正则化
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/904285.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/904285.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大模型微调Fine-tuning:从概念到实践的全面解析

目录 引言 一、什么是大模型微调? 1.1 预训练与微调的区别 1.2 微调的技术演进 二、为什么需要微调? 2.1 解决大模型的固有局限 2.2 微调的优势 三、主流微调方法 3.1 全参数微调 3.2 参数高效微调(PEFT) 四、微调实践指…

Docker 使用下 (二)

Docker 使用下 (二) 文章目录 Docker 使用下 (二)前言一、初识Docker1.1 、Docker概述1.2 、Docker的历史1.3 、Docker解决了什么问题1.4 、Docker 的优点1.5 、Docker的架构图 二、镜像三、容器四、数据卷4.1、数据卷的概念4.2 、…

洛谷P12238 [蓝桥杯 2023 国 Java A] 单词分类

[Problem Discription] \color{blue}{\texttt{[Problem Discription]}} [Problem Discription] Copy from luogu. [Analysis] \color{blue}{\texttt{[Analysis]}} [Analysis] 既然都是字符串前缀的问题了,那当然首先就应该想到 Trie \text{Trie} Trie 树。 我们可…

pta作业中有启发性的程序题

1 【知识点】&#xff1a;多态 函数接口定义&#xff1a; 以Student为基类&#xff0c;构建GroupA, GroupB和GroupC三个类 裁判测试程序样例&#xff1a; #include<iostream> #include <string> using namespace std;/* 请在这里填写答案 */int main() {const …

Scrapy框架之CrawlSpider爬虫 实战 详解

CrawlSpider 是 Scrapy 框架中一个非常实用的爬虫基类&#xff0c;它继承自 Spider 类&#xff0c;主要用于实现基于规则的网页爬取。相较于普通的 Spider 类&#xff0c;CrawlSpider 可以根据预定义的规则自动跟进页面中的链接&#xff0c;从而实现更高效、更灵活的爬取。 Scr…

Glide 如何加载远程 Base64 图片

最近有个需求&#xff0c;后端给出的图片地址并不是正常的 URL&#xff0c;而且需要一个接口去请求&#xff0c;但是返回的是 base64 数据流。这里不关心为啥要这么多&#xff0c;原因有很多&#xff0c;可能是系统的问题&#xff0c;也可能是能力问题。当然作为我们 Android 程…

004-nlohmann/json 快速认识-C++开源库108杰

了解 nlohmann/json 的特点&#xff1b;理解编程中 “数据战场”划分的概念&#xff1b;迅速上手多种方式构建一个JSON对象&#xff1b; 1 特点与安装 nlohmann/json 是一个在 github 长期霸占 “JSON” 热搜版第1的CJSON处理库。它的最大优点是与 C 标准库的容器数据&#xf…

#基础Machine Learning 算法(上)

机器学习算法的分类 机器学习算法大致可以分为三类&#xff1a; 监督学习算法 (Supervised Algorithms&#xff09;:在监督学习训练过程中&#xff0c;可以由训练数据集学到或建立一个模式&#xff08;函数 / learning model&#xff09;&#xff0c;并依此模式推测新的实例。…

正弦波、方波、三角波和锯齿波信号发生器——Multisim电路仿真

目录 Multisim使用教程说明链接 一、正弦波信号发生电路 1.1正弦波发生电路 电路组成 工作原理 振荡频率 1.2 正弦波发生电路仿真分析 工程文件链接 二、方波信号发生电路 2.1 方波发生电路可调频率 工作原理 详细过程 2.2 方波发生电路可调频率/可调占空比 调节占空比 方波产生…

【AND-OR-~OR锁存器设计】2022-8-31

缘由锁存器11111111111-硬件开发-CSDN问答 重置1&#xff0c;不论输入什么&#xff0c;输出都为0&#xff1b; 重置0&#xff0c;输入1就锁住1 此时输入再次变为0&#xff0c;输出不变&#xff0c;为锁住。

力扣-字符串-468 检查ip

思路 考察字符串的使用&#xff0c;还有对所有边界条件的检查 spilt&#xff08;“\.”&#xff09;&#xff0c;toCharArray&#xff0c;Integer.parseInt() 代码 class Solution {boolean checkIpv4Segment(String str){if(str.length() 0 || str.length() > 4) retur…

BC8 十六进制转十进制

题目&#xff1a;BC8 十六进制转十进制 描述 BoBo写了一个十六进制整数ABCDEF&#xff0c;他问KiKi对应的十进制整数是多少。 输入描述&#xff1a; 无 输出描述&#xff1a; 十六进制整数ABCDEF对应的十进制整数&#xff0c;所占域宽为15。 备注&#xff1a; printf可以使用…

ARM子程序和栈

微处理器中的栈由栈指针指向存储器中的栈顶来实现&#xff0c;当数据项入栈时&#xff0c;栈 指针向上移动&#xff0c;当数据项出栈时&#xff0c;栈指针向下移动。 实现栈时需要做出两个决定&#xff1a;一是当数据项进栈时是向低位地址方向向上生 长&#xff08;图a和图b&a…

jwt身份验证和基本的利用方式

前言 &#xff1a; 什么是jwt&#xff08;json web token&#xff09;&#xff1f; 看看英文单词的意思就是 json形式的token 他的基本的特征 &#xff1a; 类似于这样的 他有2个点 分割 解码的时候会有三个部分 头部 payload 对称密钥 这个就是对称加密 头部&am…

n8n工作流自动化平台的实操:利用本地嵌入模型,完成文件内容的向量化及入库

1.成果展示 1.1n8n的工作流 牵涉节点&#xff1a;FTP、Code、Milvus Vector Store、Embeddings OpenAI、Default Data Loader、Recursive Character Text Splitter 12.向量库的结果 2.实操过程 2.1发布本地嵌入模型服务 将bge-m3嵌入模型&#xff0c;发布成满足open api接口…

MATLAB人工大猩猩部队GTO优化CNN-LSTM多变量时间序列预测

本博客来源于CSDN机器鱼&#xff0c;未同意任何人转载。 更多内容&#xff0c;欢迎点击本专栏目录&#xff0c;查看更多内容。 目录 0 引言 1 数据准备 2 CNN-LSTM模型搭建 3 GTO超参数优化 3.1 GTO函数极值寻优 3.2 GTO优化CNN-LSTM超参数 3.3 主程序 4 结语 0 引言…

git项目迁移,包括所有的提交记录和分支 gitlab迁移到gitblit

之前git都是全新项目上传&#xff0c;没有迁移过&#xff0c;因为迁移的话要考虑已有项目上的分支都要迁移过去&#xff0c;提交记录能迁移就好&#xff1b;分支如果按照全新项目上传的方式需要新git手动创建好老git已有分支&#xff0c;在手动一个一个克隆老项目分支代码依次提…

Photo-SLAM论文理解、环境搭建、代码理解与实测效果

前言&#xff1a;第一个解耦式Photo-SLAM&#xff0c;亮点和效果。 参考&#xff1a;https://zhuanlan.zhihu.com/p/715311759 全网最细PhotoSLAM的conda环境配置教程&#xff0c;拒绝环境污染&#xff01;&#xff01;-CSDN博客 1. 环境搭建 硬件&#xff1a;RTX 4090D wi…

如何使用VSCode编写C、C++和Python程序

一、首先准备好前期工作。如下载安装Python、VSCode、一些插件等。写代码之前需要先创建文件夹和文件。 二、将不同语言写的代码放在不同的文件夹中&#xff0c;注意命名时不要使用中文。 三、打开VSCode&#xff0c;点击“文件”->“打开文件夹”->“daimalainxi”->…

基于不确定性感知学习的单图像自监督3D人体网格重建 (论文笔记与思考)

文章目录 论文解决的问题提出的算法以及启发点 论文解决的问题 首先这是 Self-Supervised 3D Human mesh recovery from a single image with uncertainty-aware learning &#xff08;AAAI 2024&#xff09;的论文笔记。该文中主要提出了一个自监督的framework用于人体的姿态…