【漫话机器学习系列】277.梯度裁剪(Gradient Clipping)

【深度学习】什么是梯度裁剪(Gradient Clipping)?一张图彻底搞懂!

在训练深度神经网络,尤其是 RNN、LSTM、Transformer 这类深层结构时,你是否遇到过以下情况:

  • 模型 loss 突然变成 NaN;

  • 梯度爆炸导致训练中断;

  • 训练刚开始几步模型就“失控”了。

这些问题,很多时候都是因为——梯度过大(梯度爆炸)。而应对这个问题的常见方案之一,就是本文要讲的主角:梯度裁剪(Gradient Clipping)


一、梯度裁剪是什么?

我们先看一张图,一图胜千言:

图中文字解读如下:

  • 标题:梯度裁剪(Gradient Clipping)

  • 说明文字

    损失函数中的梯度悬崖会导致模型在学习过程中超出期望最小值。发生这种情况,是因为梯度陡峭。解决方法:阻止梯度选择极端值。

  • 图示公式

    if ‖g‖ > v:g ← (g / ‖g‖) * v
    

    意思是:

    • 如果梯度的范数(即长度)大于某个阈值 v,就将梯度缩放为长度为 v 的向量。

    • 这样可以防止某些参数更新过大。


二、为什么需要梯度裁剪?

1. 梯度爆炸的根源

在反向传播中,每一层的梯度是前面所有梯度的乘积。在深层网络中,如果这些乘积的值都 > 1,最终梯度将呈指数级增长,导致所谓的梯度爆炸(Gradient Explosion)

表现形式

  • loss 一直上升,甚至变成 NaN

  • 参数更新过大,模型发散

  • 模型无法收敛

2. 梯度裁剪的作用

梯度裁剪并不会改变梯度的方向,它只是在梯度的模(大小)超过某个阈值时,进行缩放。这就像是给模型装了一个“刹车”系统,一旦速度过快就减速。


三、梯度裁剪的数学原理

设:

  • 当前梯度为 g

  • 范数为 ∥g∥

  • 阈值为 v

裁剪操作如下:

\text{if } \|g\| > v, \quad g \leftarrow \frac{g}{\|g\|} \cdot v

也就是说:将梯度的模限制在最大值 vv 内,方向保持不变。


四、实战中如何实现梯度裁剪?

在 PyTorch 中非常简单:

import torch# 假设已经定义 optimizer 和 model
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

在 TensorFlow(Keras)中也可以:

optimizer = tf.keras.optimizers.Adam(clipnorm=1.0)

五、梯度裁剪 vs 梯度正则化

名称作用是否改变方向
梯度裁剪控制梯度最大值,避免爆炸
L2 正则化(权重衰减)防止模型过拟合,限制权重大小

注意:梯度裁剪是为了“救训练”,不是为了“提高精度”!


六、何时需要使用梯度裁剪?

  • 训练深度模型如 RNN、LSTM、Transformer

  • loss 出现爆炸性增长,模型训练不稳定;

  • 使用高学习率训练时容易出问题;

  • 模型结构复杂,层数深,非线性强。


七、调参建议

参数建议取值说明
clip norm0.1 ~ 5通常从 1.0 开始尝试,逐步调整
适用优化器Adam、SGD梯度裁剪不依赖特定优化器
使用频率每次 step 前每次梯度更新前裁剪

八、总结

梯度裁剪是深度学习中极其实用的一种 训练稳定性保障机制。它的作用不是提升模型能力,而是防止模型“发疯”。在某些模型结构中(如 LSTM、GAN),它几乎是标配操作。

一句话总结:梯度裁剪不是为了让模型跑得快,而是为了别让它翻车。


推荐阅读

  • 《Deep Learning》by Ian Goodfellow(第 6 章)

  • PyTorch 官方文档:clip_grad_norm_


如果你觉得本文对你有帮助,欢迎点赞、收藏、评论~
也欢迎你分享你在训练中使用梯度裁剪的经验或踩过的坑!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/bicheng/82402.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

零基础弄懂 ngx_http_slice_module分片缓存加速

一、为什么需要 Slice? 在 NGINX 反向代理或 CDN 场景中,大文件(视频、软件包、镜像等)常因单体体积过大而令缓存命中率低、回源代价高。 ngx_http_slice_module 通过把一次完整响应拆分成 固定大小的字节块(Slice&am…

机器人强化学习入门学习笔记(三)

强化学习(Reinforcement Learning, RL)与监督学习不同——你不需要预先准备训练数据集,而是要设计环境、奖励函数,让智能体通过交互不断探索和学习。 🎯 一、强化学习和训练数据的关系 强化学习不依赖固定的数据集。它…

【python实战】二手房房价数据分析与预测

个人主页:大数据蟒行探索者 目录 一、数据分析目标与任务 1.1背景介绍 1.2课程设计目标与任务 1.3研究方法与技术路线 二、数据预处理 2.1数据说明 2.2数据清洗 2.3数据处理 三、数据探索分析 四、数据分析模型 五、方案评估 摘要:随着社会经…

Kotlin IR编译器插件开发指南

在 Kotlin 中开发基于 IR(Intermediate Representation)的编译器插件,可以深度定制语言功能或实现高级代码转换。以下是分步骤指南: 一、IR 编译器插件基础 IR 是什么? Kotlin 编译器将源码转换为 IR 中间表示&#xf…

如何用 python 代码复现 MATLAB simulink 的 PID

MATLAB在 Simulink 里做以下设置MATLAB 脚本调用示例 python 实现离散 PID 实现(并行形式) Simulink 中两种 PID 结构(并联形式, I-形式)下连续/离散时域里积分增益 I 的表示并联(Parallel) vs 理想&#x…

黑马点评--基于Redis实现共享session登录

集群的session共享问题分析 session共享问题:多台Tomcat无法共享session存储空间,当请求切换到不同Tomcat服务时,原来存储在一台Tomcat服务中的数据,在其他Tomcat中是看不到的,这就导致了导致数据丢失的问题。 虽然系…

SkyWalking启动失败:OpenSearch分片数量达到上限的完美解决方案

🚨 问题现象 SkyWalking OAP服务启动时报错: org.apache.skywalking.oap.server.library.module.ModuleStartException: java.lang.RuntimeException: {"error":{"root_cause":[{"type":"validation_exception", "reason&q…

向量数据库选型实战指南:Milvus架构深度解析与技术对比

导读:随着大语言模型和AI应用的快速普及,传统数据库在处理高维向量数据时面临的性能瓶颈日益凸显。当文档经过嵌入模型处理生成768到1536维的向量后,传统B-Tree索引的检索效率会出现显著下降,而现代应用对毫秒级响应的严苛要求使得…

MySQL#秘籍#一条SQL语句执行时间以及资源分析

背景 一条 SQL 语句的执行完,每个模块耗时,不同资源(CPU/IO/IPC/SWAP)消耗情况我该如何知道呢?别慌俺有 - MySQL profiling 1. SQL语句执行前 - 开启profiling -- profiling (0-关闭 1-开启) -- 或者:show variables like prof…

【数据结构】实现方式、应用场景与优缺点的系统总结

以下是编程中常见的数据结构及其实现方式、应用场景与优缺点的系统总结: 一、线性数据结构 1. 数组 (Array) 定义:连续内存空间存储相同类型元素。实现方式:int[] arr new int[10]; // Javaarr [0] * 10 # Python操作: 访问&…

PyTorch中cdist和sum函数使用示例详解

以下是PyTorch中cdist与sum函数的联合使用详解: 1. cdist函数解析 功能:计算两个张量间的成对距离矩阵 输入格式: X1:形状为(B, P, M)的张量X2:形状为(B, R, M)的张量p:距离类型(默认2表示欧式距离)输出:形状为(B, P, R)的距离矩阵,其中元素 d i j d_{ij} dij​表示…

Ansible配置文件常用选项详解

Ansible 的配置文件采用 INI 格式,分为多个模块,每个模块包含特定功能的配置参数。 以下是ansible.cfg配置文件中对各部分的详细解析: [defaults](全局默认配置) inventory 指定主机清单文件路径,默认值为 …

了解FTP搜索引擎

根据资料, FTP搜索引擎是专门搜集匿名FTP服务器提供的目录列表,并向用户提供文件信息的网站; FTP搜索引擎专门针对FTP服务器上的文件进行搜索; 就是它的搜索结果是一些FTP资源; 知名的FTP搜索引擎如下, …

【大模型面试每日一题】Day 28:AdamW 相比 Adam 的核心改进是什么?

【大模型面试每日一题】Day 28:AdamW 相比 Adam 的核心改进是什么? 📌 题目重现 🌟🌟 面试官:AdamW 相比 Adam 的核心改进是什么? #mermaid-svg-BJoVHwvOm7TY1VkZ {font-family:"trebuch…

C++系统IO

C系统IO 头文件的使用 1.使用系统IO必须包含相应的头文件,通常使用#include预处理指令。 2.头文件中包含了若干变量的声明,用于实现系统IO。 3.头文件的引用方式有双引号和尖括号两种,区别在于查找路径的不同。 4.C标准库提供的头文件通常没…

多模态理解大模型高性能优化丨前沿多模态模型开发与应用实战第七期

一、引言 在前序课程中,我们系统剖析了多模态理解大模型(Qwen2.5-VL、DeepSeek-VL2)的架构设计。鉴于此类模型训练需消耗千卡级算力与TB级数据,实际应用中绝大多数的用户场景均围绕推理部署展开,模型推理的效率影响着…

各个网络协议的依赖关系

网络协议的依赖关系 学习网络协议之间的依赖关系具有多方面重要作用,具体如下: 帮助理解网络工作原理 - 整体流程明晰:网络协议分层且相互依赖,如TCP/IP协议族,应用层协议依赖传输层的TCP或UDP协议来传输数据&#…

11.8 LangGraph生产级AI Agent开发:从节点定义到高并发架构的终极指南

使用 LangGraph 构建生产级 AI Agent:LangGraph 节点与边的实现 关键词:LangGraph 节点定义, 条件边实现, 状态管理, 多会话控制, 生产级 Agent 架构 1. LangGraph 核心设计解析 LangGraph 通过图结构抽象复杂 AI 工作流,其核心要素构成如下表所示: 组件作用描述代码对应…

相机--基础

在机器人开发领域,相机种类很多,作为一个机器人领域的开发人员,我们需要清楚几个问题: 1,相机的种类有哪些? 2,各种相机的功能,使用场景? 3,需要使用的相机…

【备忘】 windows 11安装 AdGuardHome,实现开机自启,使用 DoH

windows 11安装 AdGuardHome,实现开机自启,使用 DoH 下载 AdGuardHome解压 AdGuardHome启动 AdGuard Home设置 AdGuardHome设置开机自启安装 NSSM设置开机自启重启电脑后我们可以访问 **http://127.0.0.1/** 设置使用 AdGuardHome DNS 效果图 下载 AdGua…