【机器学习】反向传播如何求梯度(公式推导)

写在前面

前期学习深度学习的时候,很多概念都是一笔带过,只是觉得它在一定程度上解释得通就行,但是在强化学习的过程中突然意识到,反向传播求梯度其实并不是一件简单的事情,这篇博客的目的就是要讲清楚反向传播是如何对特定的某一层求梯度,进而更新其参数的

为什么反向传播不容易实现

首先,众所周知,深度网络就是一个广义的多元函数,
但在通常情况下,想要求一个函数的梯度,就必须知道这个函数的具体表达式
但是问题就在于,深度网络的“传递函数”并不容易获得,或者说并不容易显式地获得
进而导致反向传播的过程难以进行

为什么反向传播可以实现

损失函数是关于参数的函数

如果要将一个函数F对一个变量x求偏导,那偏导存在的前提条件就是F是关于x的函数,否则求导结果就是0

  • 符号定义(后续公式均据此展开)
    • x=[x1,x2,…xn]x=[x_1,x_2,\dots x_n]x=[x1,x2,xn]
    • ypred=[y1,y2,…ym]y_{pred}=[y_1,y_2, \dots y_m]ypred=[y1,y2,ym]
    • θ=[w1,b1,w2,b2,…wn,bn]\theta=[w_1,b_1,w_2,b_2,\dots w_n,b_n]θ=[w1,b1,w2,b2,wn,bn]
    • ai=第i层网络激活函数的输出,最后一层的输出就是ypreda^{i}=第 i层网络激活函数的输出,最后一层的输出就是y_{pred}ai=i层网络激活函数的输出,最后一层的输出就是ypred
    • zi=第i层网络隐藏层的输出z^{i}=第i层网络隐藏层的输出zi=i层网络隐藏层的输出
    • gi ′(zi)第i层激活函数的导数,在输入=zi处的值g^i\ '(z^i)第i层激活函数的导数,在输入=z^i处的值gi (zi)i层激活函数的导数,在输入=zi处的值
  • 关系式
    • 网络的抽象函数式 ypred=F(x;θ)y_{pred}=F(x;\theta)ypred=F(x;θ)
      即网络就是一个巨大的多元函数,接受两个向量(模型输入和参数)作为输入,经过内部正向传播后输出一个向量
    • 损失函数Loss的抽象函数式 Loss=L(ytrue,ypred)=L(ytrue,F(x;θ))Loss=L(y_{true},y_{pred})=L(y_{true},F(x;\theta))Loss=L(ytrue,ypred)=L(ytrue,F(x;θ))
      其中 ytruey_{true}ytruexxx 属于参变量,它虽然会变,但是和模型本身没什么关系,唯一属于模型自己的变量就是 θ\thetaθ,所以不难看出,损失函数L是关于模型参数 θ\thetaθ 的函数,损失值Loss完全由模型参数 θ\thetaθ 决定

链式法则

  • 这个法则是深层网络得以实现梯度计算的关键
    核心公式如下:
    ∂L∂θi=∂L∂zi⋅∂zi∂θi \frac{\partial L}{\partial \theta^i}=\frac{\partial L}{\partial z^i}·\frac{\partial z^i}{\partial \theta^i} θiL=ziLθizi
    其中,∂L∂zi\frac{\partial L}{\partial z^i}ziL是损失L对第i层加权输入ziz^izi的梯度,∂zi∂θi\frac{\partial z^i}{\partial \theta^i}θizi是第i层加权输入ziz^izi对本层参数θi\theta^iθi的梯度

  • 进一步深究可以发现∂zi∂θi\frac{\partial z^i}{\partial \theta^i}θizi相对容易求,因为它只涉及到当前层的当前神经元的求解,在面向对象语言中,很容易为每个属于同一个类的实例增加一个方法,比如像这里的输入对参数求导,举例来说;
    if θi=Wi and Zi=Wi∗ai−1+bi,then ∂zi∂θi=(ai−1)T if\ \theta^i=W^i\ and\ Z^i=W^i*a^{i-1}+b^i,\\ then\ \frac{\partial z^i}{\partial \theta^i}=(a^{i-1})^T if θi=Wi and Zi=Wiai1+bi,then θizi=(ai1)T
    其中,
    (说实话,我非常想把隐藏层称为“传递函数”,控制和机器学习实际上有非常多可以相互借鉴的地方,而且在事实上,二者也确实是不可分割的关系)

  • 然后我们要来处理相对麻烦的 ∂L∂zi\frac{\partial L}{\partial z^i}ziL

    • 多层感知机为例,共k层,已知网络输出,求网络第i层的梯度
    • 用数学归纳法在这种递归系统中比较合适
      • 归纳奠基
        L=L(ytrue,ypred)=L(ytrue,F(x;θ))L=L(y_{true},y_{pred})=L(y_{true},F(x;\theta))L=L(ytrue,ypred)=L(ytrue,F(x;θ))

        ∂L∂zk=∂L∂ak⋅∂ak∂zk=∂L∂ak⊗(gk)′(zk)\frac{\partial L}{\partial z^k}=\frac{\partial L}{\partial a^k}·\frac{\partial a^k}{\partial z^k}=\frac{\partial L}{\partial a^k}\otimes (g^k)'(z^k)zkL=akLzkak=akL(gk)(zk)

        上面的公式说明:损失对隐藏层输出的偏导,等价于损失函数对最终输出的偏导,再逐元素乘上最后层激活函数 在隐藏层输出处 的导数

        其中,激活函数在创建网络时就明确已知,因此求导取值并没有难度
        由于Loss=L(ytrue,ypred)Loss=L(y_{true},y_{pred})Loss=L(ytrue,ypred)直接与网络最终输出ypredy_{pred}ypred相关,因此损失对最终输出的偏导并不难求;
        比如将损失函数定义为均方差MSE:(其他网络基本同理)
        L=12∑j=1m(yj−ajk)2∂L∂ak=−(yj−ajk)L=\frac{1}{2}\sum^m_{j=1}(y_j-a_j^k)^2\\\frac{\partial L}{\partial a^k}=-(y_j-a_j^k)L=21j=1m(yjajk)2akL=(yjajk)

      • 归纳递推(从第 i 层到第 i-1 层)

        假设已知 ∂L∂zi\frac{\partial L}{\partial z^i}ziL(反向传播,因此我们假设的是后一层已知)

        由链式法则可得:
        ∂L∂zi−1=(∂L∂zi)⋅(∂zi∂ai−1)⋅(∂ai−1∂zi−1)\frac{\partial L}{\partial z^{i-1}}=(\frac{\partial L}{\partial z^i})·(\frac{\partial z^i}{\partial a^{i-1}})·(\frac{\partial a^{i-1}}{\partial z^{i-1}})zi1L=(ziL)(ai1zi)(zi1ai1)
        其中,第一个因子已知

        第二个因子∂zi∂ai−1\frac{\partial z^i}{\partial a^{i-1}}ai1zi,分子为第 i 层隐藏层的输出,分母为第 i 层隐藏层的输入(即第 i-1 层激活层的输出),因此其值就是第 i 层隐藏层的权重矩阵WiW^iWi本身

        第三个因子∂ai−1∂zi−1\frac{\partial a^{i-1}}{\partial z^{i-1}}zi1ai1,分子为第 i-1 层激活层的输出,分母为第 i-1 层激活层的输入,因此其值就是 第 i-1 层激活函数 在隐藏层输出处 的导数

        综上:在已知第 i 层损失对输出的梯度的情况下,可以推出第 i-1 层损失对输出的梯度,递推成立

      • 归纳总结
        综上所述,反向传播求梯度完全可行,按照上面的过程撰写程序,就可以很方便地反向逐层 根据损失梯度 更新参数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/914300.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/914300.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ALB、NLB、CLB 负载均衡深度剖析

ALB、NLB、CLB 负载均衡深度剖析 前言 笔者在上周的实际工作中遇到了一个典型的负载均衡选择问题:在使用代理调用相关模型时,最初配置 Nginx 的代理地址为 ALB 的 7 层虚拟 IP(VIP),但由于集团网络默认的超时时间为 3 …

历史数据分析——云南白药

医药板块走势分析: 从月线级别来看 2008年11月到2021年2月,月线上走出了两个震荡中枢的月线级别2085-20349的上涨段; 2021年2月到2024年9月,月线上走出了20349-6702的下跌段; 目前月线级别放巨量,总体还在震荡区间内,后续还有震荡和上涨的概率。 从周线级别来看 从…

【读书笔记】《Effective Modern C++》第3章 Moving to Modern C++

《Effective Modern C》第3章 Moving to Modern C 一、区分圆括号 () 与大括号 {} (Item 7) C11 引入统一初始化(brace‑initialization),即使用 {} 来初始化对象,与传统的 () 存在细微差别:避…

Rust基础-part1

Rust基础[part1]—安装和编译 安装 ➜ rust curl --proto https --tlsv1.2 https://sh.rustup.rs -sSf | sh安装成功 [外链图片转存中…(img-ClSHJ4Op-1752058241580)] 验证 ➜ rust rustc --version zsh: command not found: rustc因为我是用的是zsh,所以zsh配置…

PyQt5布局管理(QGridLayout(网格布局))

QGridLayout(网格布局) QGridLayout(网格布局)是将窗口分隔成行和列的网格来进行排列。通常可以使用函数addWidget()将被管理的控件(Widget)添加到窗口中,或者使用addLayout() 函数将布局(Layou…

Java设计模式之行为型模式(责任链模式)介绍与说明

一、核心概念与定义 责任链模式是一种行为型设计模式,其核心思想是将请求沿着处理对象链传递,直到某个对象能够处理该请求为止。通过这种方式,解耦了请求的发送者与接收者,使多个对象有机会处理同一请求。 关键特点: 动…

SQL server之版本的初认知

SQL server之版本的初认知 为什么要编写此篇文档呢,主要是因为在最近测试OGG实时同步SQL server数据库表数据的时候,经过多次测试,发现在安装了一套SQL server2017初始版本,未安装任何补丁的时候,在添加TRANDATA的时候…

【前端】jQuery动态加载CSS方法总结

在jQuery 中动态加载 CSS 文件有多种方法&#xff0c;以下是几种常用实现方式&#xff1a; 方法 1&#xff1a;创建 <link> 标签&#xff08;推荐&#xff09; // 动态加载外部 CSS 文件 function loadCSS(url) {$(<link>, {rel: stylesheet,type: text/css,href:…

Python爬虫实战:研究xlwings库相关技术

1. 引言 在金融科技快速发展的背景下,数据驱动决策已成为投资领域的核心竞争力。金融市场数据具有海量、多源、实时性强等特点,传统人工收集与分析方式难以满足高效决策需求。Python 凭借其丰富的开源库生态,成为金融数据分析的首选语言。结合 Requests、BeautifulSoup 等爬…

Linux 内核日志中常见错误

目录 **1. `Oops`****含义****典型日志****可能原因****处理建议****2. `panic`****含义****典型日志****可能原因****处理建议****3. `BUG`****含义****典型日志****可能原因****处理建议****4. `kernel NULL pointer`****含义****典型日志****可能原因****处理建议****5. `WA…

Linux驱动开发2:字符设备驱动

Linux驱动开发2&#xff1a;字符设备驱动 字符设备驱动开发流程 字符设备是 Linux 驱动中最基本的一类设备驱动&#xff0c;字符设备就是一个一个字节&#xff0c;按照字节流进行读写操作的设备&#xff0c;读写数据是分先后顺序的。比如最常见的点灯、按键、 IIC、 SPI&#x…

RuoYi-Cloud 验证码处理流程

以该处理流程去拓展其他功能模块处理流程&#xff0c;进而熟悉项目开发代码一、思路JavaWeb流程主干线&#xff1a;发起请求、处理请求、响应请求二、登录页面在登录页面按键F12打开开发者工具&#xff0c;点击network&#xff0c;刷新页面&#xff0c;点击code&#xff0c;查看…

云计算三大服务模式深度解析:IaaS、PaaS、SaaS

架构本质&#xff1a;云计算服务模式定义了资源抽象层级和责任分担边界&#xff0c;形成从基础设施到应用的全栈服务金字塔。三种模式共同构成云计算的服务交付模型核心框架。一、服务模式全景图 #mermaid-svg-f0Klw2fbuhBQqJTh {font-family:"trebuchet ms",verdana…

【sql学习之拉链表】

1.拉链表理解 记录历史。记录一个事物从开始&#xff0c;一直到当前状态的所有变化的信息。字段说明&#xff1a; start_dt&#xff1a;该条记录的生命周期开始时间 end_dt&#xff1a;该条记录的生命周期结束时间 end_dt’9999/12/31’表示该条记录目前处于有效状态 如果查询当…

STM32中实现shell控制台(shell窗口输入实现)

文章目录 一、总体结构二、串口接收机制三、命令输入与处理逻辑四、命令编辑与显示五、历史命令管理六、命令执行七、初始化与使用八、小结在嵌入式系统开发中,使用串口Shell控制台是一种非常常见且高效的调试方式。本文将基于STM32平台,分析一个简洁但功能完整的Shell控制台…

区分三种IO模型和select/poll/epoll

部分内容来源&#xff1a;JavaGuide select/poll/epoll 和 三种IO模型之间的关系是什么&#xff1f;区分普通IO和IO多路复用普通IO&#xff0c;即一个线程对应一个连接&#xff0c;因为每个线程只处理一个客户端 socket&#xff0c;目标明确&#xff1a;线程中直接操作该 socke…

Actor-Critic重要性采样原理

目录 AC的数据低效性&#xff1a; 根本原因&#xff1a;策略更新导致数据失效 应用场景&#xff1a; 1. 离策略值函数估计 2. 离策略策略优化 3. 经验回放&#xff08;Experience Replay&#xff09; 4. 策略梯度方法 具体场景分析 场景1&#xff1a;连续策略更新 场…

【赠书福利,回馈公号读者】《智慧城市与智能网联汽车,融合创新发展之路》

「5G行业应用」公号作家团队推出《智慧城市与智能网联汽车&#xff0c;融合创新发展之路》。本书由机械工业出版社出版&#xff0c;探讨如何通过车城融合和创新应用&#xff0c;促进汽车产业转型升级与生态集群发展&#xff0c;提升智慧城市精准治理与出行服务效能。&#xff0…

5G NR PDCCH之处理流程

本节主要介绍PDCCH处理流程概述。PDCCH&#xff08;Physical Downlink Control Channel&#xff0c;物理下行控制信道&#xff09;主要用于传输DCI&#xff08;Downlink Control Information&#xff0c;下行控制信息&#xff09;&#xff0c;用于通知UE资源分配&#xff0c;调…

力扣网编程135题:分发糖果(贪心算法)

一. 简介本文记录力扣网上涉及数组方面的编程题&#xff1a;分发糖果。这里使用贪心算法的思路来解决&#xff08;求局部最优&#xff0c;最终求全局最优解&#xff09;&#xff1a;每个孩子只需要考虑与相邻孩子的相对关系。二. 力扣网编程135题&#xff1a;分发糖果&#xff…