【大模型LLM】大模型训练加速 - 梯度累积(Gradient Accumulation)原理详解

在这里插入图片描述

梯度累积(Gradient Accumulation)原理详解

梯度累积是一种在深度学习训练中常用的技术,特别适用于显存有限但希望使用较大批量大小(batch size)的情况。通过梯度累积,可以在不增加单个批次大小的情况下模拟较大的批量大小,从而提高模型的稳定性和收敛速度。

基本概念

在标准的随机梯度下降(SGD)及其变体(如Adam、RMSprop等)中,每次更新模型参数时都需要计算整个批次数据的损失函数梯度,并立即用这个梯度来更新模型参数。然而,在处理大规模数据集或使用非常大的模型时,单个批次的数据量可能会超出GPU显存的容量。此时,梯度累积技术就可以发挥作用。

工作原理

梯度累积的核心思想是:将多个小批次(mini-batch)的梯度累加起来,然后一次性执行一次参数更新。具体步骤如下:

  1. 初始化梯度累积器:在每个训练步骤开始时,初始化一个梯度累积器(通常为零)。
  2. 前向传播与梯度计算
    • 对于每一个小批次 i(从 1 到 k),执行前向传播计算损失。
    • 执行反向传播计算该小批次的梯度。
  3. 累积梯度:将当前小批次的梯度累加到梯度累积器中。
  4. 参数更新:当累积了 k 个小批次的梯度后,使用累积的梯度来更新模型参数,并重置梯度累积器。
详细步骤

假设我们希望使用的批量大小是 N,但由于显存限制只能使用较小的批量大小 n(其中 N = k * n),那么我们可以进行 k 次前向和后向传播,每次都计算一个小批次的梯度并将其累加,直到累积了 k 个小批次的梯度之后,再进行一次参数更新。

示例代码

以下是一个简单的PyTorch示例,展示了如何实现梯度累积:

import torch
import torch.nn as nn
import torch.optim as optim# 假设有一个简单的模型
model = nn.Linear(10, 2)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)# 设置梯度累积步数
accumulation_steps = 4
optimizer.zero_grad()  # 清空梯度for i, (inputs, labels) in enumerate(data_loader):outputs = model(inputs)loss = criterion(outputs, labels)# 将损失除以累积步数,使得总的损失不变loss = loss / accumulation_steps# 反向传播计算梯度loss.backward()if (i + 1) % accumulation_steps == 0:# 累积足够步数后,执行优化步骤optimizer.step()optimizer.zero_grad()  # 清空梯度
关键点解释
  1. 损失缩放:由于我们将一个大批次分成多个小批次,并且每次只计算一个小批次的损失,因此需要将每个小批次的损失除以累积步数 accumulation_steps,以确保总的损失值保持不变。

  2. 梯度累积:每次反向传播后,梯度会被累加而不是立即用于更新参数。只有当累积了足够的步数后,才会使用累积的梯度进行一次参数更新。

  3. 参数更新:在累积了足够的梯度后,调用 optimizer.step() 来更新模型参数,并清空梯度累积器(即调用 optimizer.zero_grad())。

优点
  • 突破显存限制:通过使用较小的批量大小,可以有效地减少每一步所需的显存量,从而允许在有限的硬件资源上训练更大的模型或使用更大的批量大小。
  • 模拟大批次训练效果:梯度累积实际上模拟了使用较大批量大小的效果,有助于提高模型训练的稳定性和收敛速度。
  • 灵活性:可以根据实际硬件条件灵活调整累积步数,适应不同的训练需求。
注意事项
  • 学习率调整:由于梯度累积实际上是将多个小批次的梯度累加起来进行一次更新,因此需要相应地调整学习率。例如,如果原始设置的学习率为 lr,并且使用了 k 步梯度累积,则新的有效学习率应为 lr * k
  • 随机性影响:梯度累积可能会引入一定的随机性,因为不同小批次之间的顺序可能会影响最终的梯度累积结果。不过,在实践中这种影响通常是可以接受的。
总结

梯度累积是一种非常实用的技术,特别是在显存受限但希望利用更大批量大小的情况下。它不仅帮助克服了硬件限制,还能够保持甚至提升模型训练的质量。通过合理配置梯度累积步数和学习率,可以显著改善训练效率和效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/93889.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/93889.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据分享】各省文旅融合耦合协调度及原始数据(2012-2022)

数据介绍引言 文旅融合是推动区域经济高质量发展、促进共同富裕的重要路径。党的二十大报告明确提出“推进文化和旅游深度融合发展”的战略目标,文旅产业通过资源整合与业态创新,可显著缩小城乡、区域差距,提升物质与精神双重福祉&#xff08…

Linux编程: 10、线程池与初识网络编程

今天我计划通过一个小型项目,系统讲解线程池与网络编程的核心原理及实践。项目将围绕 “利用线程池实现高并发网络通信” 这一核心需求展开,具体设计如下: 为保证线程安全,线程池采用单例模式设计,确保全局唯一实例避…

藏云阁 Logo 库(开源项目SVG/PNG高清Logo)

在日常技术方案设计、架构图绘制或PPT制作中,常常会遇到一些问题,比如: 找不到统一风格的开源项目组件图标,PPT中的logo五花八门下载的图标分辨率不足,放大后模糊失真不同来源的图标颜色风格冲突,破坏整体…

从0开始学习R语言--Day64--决策树回归

对于没有特征或者说需要寻找另类关系的数据集,我们通常会用聚合或KNN近邻的方法来分类,但这样的分类或许在结果上是好的,但是解释性并不好,有时候我们甚至能看到好的结果反直觉;而决策树回归做出的结果,由于…

B+树高效实现与优化技巧

B树的定义 一颗M阶B树T,满足以下条件 每个结点至多拥有M课子树 根结点至少拥有两颗子树 除了根结点以外,其余每个分支结点至少拥有M/2课子树 所有的叶结点都在同一层上 有k棵子树的分支结点则存在k-1个关键字,关键字按照递增顺序进行排序 关键字数量满足 ceil( M/2 ) - 1 &…

Android 基础入门学习目录(持续更新)

四大组件 Activity: Service: BroadcastReceiver: ContentProvider: UI 与交互开发 常见的UI布局和UI控件 样式与主题 Fragment Intent 数据存储 自定义View和自定义Group 自定义View 自定义ViewGroup 事件分发 Key…

Linux移动大量文件命令

背景 使用 mv 命令报“/bin/mv: 参数列表过长”,也是第一遇到,查了一下,最后用rsync命令解决了。还好每台服务器,都必装rsync了,记录如下。 命令 nohup rsync -av --remove-source-files --progress /public/tmp/video…

SQL中的HAVING用法

HAVING 是 SQL 中专门对 “分组之后的聚合结果” 再做筛选的子句。 它一般跟在 GROUP BY 后面,不能单独使用,作用类似于分组版的 WHERE。✅ 1. 语法位置 SELECT 列1, 聚合函数(列2) AS 别名 FROM 表 GROUP BY 列1 HAVING 聚合条件; -- 这里写对聚合…

【Halcon 】Halcon 实战:如何为 XLD 模板添加极性信息以提升匹配精度?

Halcon 实战:如何为 XLD 模板添加极性信息以提升匹配精度? 在使用 Halcon 进行模板匹配时,我们通常有两种方式创建模板: 基于图像灰度(CreateScaledShapeModel)基于轮廓 XLD(CreateScaledShapeM…

grafana/lock-stack 日志 Pipeline 配置

前言 本文使用的是 grafana/loki-stack chart 抓取的 k8s 日志。其他 chart 配置都差不多。 日志问题 docker 容器运行时 pod 内原始日志 [cpu-4] Hello, 第 9788 次报时,时间:2025-08-01T06:35:420000 {"HOSTNAME":"cpu-4",&qu…

appium2.0+之PointerActions详解

以下内容在 夜神模拟器 上进行。 一、应用场景 一些针对手势的操作,比如滑动、长按、拖动等。可以将这些基本手势组合成一个相对复杂的手势。 二、使用步骤创建触摸输入设备(模拟手指操作) touch_input PointerInput(interaction.POINTER_TO…

Java HTTPS 请求失败排查与证书导入全过程

文章目录Java HTTPS 请求失败排查与证书导入全过程问题背景问题初步分析排查过程查看目标地址证书导入证书验证证书是否导入成功重启应用进一步验证:是否真的是证书问题?1. 浏览器访问2. 抓包工具验证(如 Charles、Wireshark)补充…

android APT技术

1,背景 对于注解的使用,想必大家都不陌生,它出现在我们的源码中,以及大部分框架中,比如ButterKnife、Arouter、Retrofit,但它们是有区别的,其中前2个是编译时注解,最后一个是运行时注…

MySQL 和 PostgreSQL综合比对分析汇总

面对大数据项目或其它类型项目中,面对关系型数据库选择一直是很总要的一点,本文针对MySQL 和 PostgreSQL进行综合比对分析汇总,内容仅供参考。MySQL 和 PostgreSQL 是两款主流的开源关系型数据库(RDBMS),但…

Linux---make和makefile

一、基本概念1.是什么make是一条命令,makefile是一个文件2.对应在vs中按一下f5就能运行代码,在Linux中make就相当于f5,使用makefile来封装从而实现我, 想要的功能3.使用①创建makefile文件②编辑makefile解释:test.exe…

【DAB收音机】DAB收音机协议及其他资料汇总

目录[ETSI DAB标准协议文档](https://www.etsi.org/standards)Other DAB资料DAB收音机相关的专利DAB收音机相关的期刊及学位论文DAB开源项目代码仓库qt-dab工具welle.io工具dablin工具【eti广播工具】⚙️ 项目对比与选型建议Other 收音机资料Other资料ETSI DAB标准协议文档 官…

RabbitMQ的特点和消息可靠性保障

掌握RabbitMQ的核心知识,需从其特点和消息可靠性保障(尤其是消息丢失解决方案)两方面入手,以下是详细说明: 一、RabbitMQ的核心特点 RabbitMQ是基于AMQP(Advanced Message Queuing Protocol)协议…

项目升级啦

公司要新做一个医疗行业的业务,经过业务端和产品端的评估该业务与公司已有的产品线关联不大,用户后续也不想在老系统那台老爷车上继续使用,话说老系统到现在差不多10年了,中间经历过的前后端开发者形形色色,维护者换了…

Android中页面生命周期变化

一、Activity切换的生命周期变化(A启动B)1. 标准流程(B完全覆盖A)完整生命周期路径:Activity A:onPause():失去焦点,仍部分可见onStop():完全不可见(当B完全覆…

自动驾驶控制算法——PID算法

自动驾驶控制算法——PID算法 文章目录自动驾驶控制算法——PID算法一、PID 是什么?二、PID 原理2.1 **比例环节(P)**2.2 **积分环节(I)**2.3 **微分环节(D)**2.4 特点总结2.5 案例分析 —— 小…