【DL学习笔记】计算图与自动求导

计算图

  • 计算图(Computation Graph)是一种用于描述计算过程的图形化表示方法。

  • 在深度学习中,计算图通常用于描述 网络结构、运算过程 和数据流向

  • 计算图是一种有向无环图,用图形方式来表示算子与变量之间的关系,直观高效。

  • 它由节点(Node)和边(Edge)组成,如下图Netron库可视化的例子,其中节点表示操作或函数,边表示数据流向

在这里插入图片描述

前向传播 与 反向传播

  • 在Pytorch中,计算图的构建是通过神经网络的 前向传播 (forward) 过程完成的。

  • 反向传播 根据计算图来计算梯度,从而进行参数更新。它为自动微分(automatic differentiation)提供了基础,使得深度学习框架能够自动计算梯度并进行反向传播。

静态计算图、动态计算图

计算图可以分为两种类型:静态计算图 和 动态计算图

  • 静态计算图: 在静态计算图中,计算图在模型定义阶段就被固定下来,不会发生变化。典型的例子是 TensorFlow 1.x 中的计算图。在这种情况下,首先定义计算图,然后运行会话(session)来执行图中的操作。

  • 动态计算图: 在动态计算图中,计算图在运行时根据输入数据的形状和大小动态构建。PyTorch 和 TensorFlow 2.x 采用了动态计算图的方式。在这种情况下,每次前向传播都会重新构建计算图,使得模型更加灵活。

在整个前向计算过程中,PyTorch采用 动态计算图 的形式进行组织,且在每次 前向传播时重新构建。
其他深度学习架构,如TensorFlow、Keras 一般为静态图。

叶子节点、非叶子节点、根节点

在这里插入图片描述

  • 上面的计算图中,圆形表示变量矩形表示算子,这些变量和算子构成了一个完整的前向传播过程
  • 叶子节点 : x、w、bx、w、bxwb 为叶子节点,它们是用户创建的变量,不依赖于其他变量
  • 非叶子节点 : y、zy、zyz为非叶子节点,它们是通过计算得到的变量
  • 根节点 : zzz 为根节点,它之后不会再有后续的运算,我们一般让根节点来执行 反向传播方法 z.backward()

torch.tensor()的requires_grad参数

  • 对于叶子节点(Leaf Node)的 张量Tensor,需要用 requires_grad 指明是否记录对其的操作运算,以便之后通过 反向传播求梯度。

  • 一般仅对 叶子节点 设置 requires_grad, 这些叶子节点,一般就是网络中层的参数,他们一般都是 torch.nn.Parameter 对象,requires_grad 属性 默认为 True

  • 叶子结点如果需要求导,requires_grad 需设置为 True,那么由这些叶子节点计算得出的非叶子节点,requires_grad 会自动置为True

import torchx = torch.tensor([2.0], requires_grad=True)   # 叶子节点
w = torch.tensor([3.0], requires_grad=True)   # 叶子节点
b = torch.tensor([1.0], requires_grad=True)   # 叶子节点y = w * x  # 非叶子节点
z = y + b  # 非叶子节点# 查看叶子节点和非叶子节点的 requires_grad 属性
print('x 的 requires_grad 属性:', x.requires_grad)
print('w 的 requires_grad 属性:', w.requires_grad)
print('b 的 requires_grad 属性:', b.requires_grad)
print('y 的 requires_grad 属性:', y.requires_grad)
print('z 的 requires_grad 属性:', z.requires_grad) 

grad_fn属性

  • 通过运算创建的非叶子节点 tensor,会自动被赋予 grad_fn 属性,用于表明生成这个张量的操作
  • 叶子节点的 grad_fn 为 None,因为它们不是通过其他操作计算得来的,而是网络的参数或输入数据

一些常见的 grad_fn 类型包括:

  • <CatBackward>:表示这个张量是通过 torch.cat 操作得到的。
  • <MatMulBackward>:表示这个张量是通过矩阵乘法操作得到的。
  • <AddBackward>:表示这个张量是通过加法操作得到的。
  • <AddmmBackward>:表示一个张量是通过 torch.addmm 操作得到的。
  • <DivBackward>:表示这个张量是通过除法操作得到的。
  • <ReLUBackward>:表示这个张量是通过 ReLU 激活函数得到的。
**import torch# 创建叶子节点张量
x = torch.tensor([2.0], requires_grad=True)
y = torch.tensor([3.0], requires_grad=True)# 创建非叶子节点张量,通过运算生成
z = x * y# 查看叶子节点和非叶子节点的 grad_fn 属性
print('x 的 grad_fn:', x.grad_fn) 
print('y 的 grad_fn:', y.grad_fn)
print('z 的 grad_fn:', z.grad_fn) **

反向传播

在反向传播中,以 loss(根节点 tensor )为核心,步骤为:

  1. optimizer.zero_grad() 清空叶子节点梯度,避免多次 optimizer.step() 时梯度累加。
  2. 调用 loss.backward() 反向传播,计算叶子节点梯度并存入 .grad 属性。
  3. 执行 optimizer.step() ,依优化器算法和学习率,用 .grad 梯度更新叶子节点(即模型参数 ) 。
for epoch in range(epochs):model.train()for imgs, labels in train_loader        :# trainoptimizer.zero_grad()loss.backward()optimizer.step()

完整举例:

import torch# 输入张量 x, require_grad 属性默认为 False
x = torch.Tensor([2])# 初始化 权重参数w, 偏移量b,并设置 require_grad 属性为 True, 为自动求导
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)# 实现向前传播
y = torch.mul(w, x)
z = torch.add(y, b)# 分别查看叶子节点 x, w, b 和 非叶子节点 y、z 的require_grad属性
print(x.requires_grad, w.requires_grad, b.requires_grad)  # False True True
print(y.requires_grad, z.requires_grad )  # True True# 查看各节点是否为叶子节点
print(x.is_leaf, w.is_leaf, b.is_leaf, y.is_leaf, z.is_leaf)  # True True True False False# 分别查看 叶子节点 和 非叶子节点 的 grad_fn 属性
print(x.grad_fn, w.grad_fn, b.grad_fn)   # None None None
print(y.grad_fn, z.grad_fn)   # <MulBackward0 object at 0x7f8ac1303910> <AddBackward0 object at 0x7f8ac1303070># 反向传播计算梯度
z.backward()  # 查看叶子节点的梯度,x是叶子节点但它无须求导,故其梯度为None 
print(w.grad,b.grad,x.grad)  # tensor([2.]) tensor([1.]) None# 非叶子节点的梯度,执行backward之后,会自动清空 
print(y.grad,z.grad)  # None None

在这里插入图片描述

自动求导 Autograd

  • 在神经网络中,一个重要内容就是进行参数学习,而参数学习的反向传播离不开求导。
  • 现在大部分深度学习架构都有自动求导的功能,torch.autograd包 就是用来自动求导的。
  • torch.autograd 包为张量上所有的操作提供了自动求导功能

实验:backward()反向传播自动求导

以下代码实现 : 机器学习 回归问题举例,使用 backward() 反向传播自动求导,并手动更新参数

  1. 先来造一批数据,作为样本数据 x 和 标签值y
import torch
import matplotlib.pyplot as plttorch.manual_seed(100)# 生成 x坐标数据,形状为 100 x 1
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)# 生成 y坐标数据,,形状为 100 x 1,加上一些噪声
y = 3 * x.pow(2) + 2 + 0.2 * torch.rand(x.size())# 把tensor数据转换为numpy数据,并可视化
plt.scatter(x.numpy(), y.numpy())
plt.show()

在这里插入图片描述

  1. 定义一个模型 y = wx +b, 我们要学习出 w 和 b 的值,用来拟合 x 和 y
# 初始化权重参数,参数 w、b 为需要学习的,故需要设置参数 requires_grad=True
w = torch.randn(1, 1, dtype=torch.float, requires_grad=True)
b = torch.zeros(1, 1, dtype=torch.float, requires_grad=True)
print(w)  # tensor([[1.1046]], requires_grad=True)
print(b)  # tensor([[0.]], requires_grad=True)lr = 0.001 # 学习率for i in range(800):# 向前传播,得到预测的y值,记为 y_predy_pred = w * x.pow(2) + b# 定义损失函数loss = (y - y_pred) ** 2loss = loss.sum()# 反向传播,自动计算梯度,存放在 grad 属性中loss.backward()# 手动更新参数,需要用torch.no_grad(), 使上下文环境中切断自动求导的计算with torch.no_grad():# 更新参数w -= lr * w.gradb -= lr * b.grad# 梯度清零w.grad.zero_()b.grad.zero_()print(w)  # tensor([[2.9668]], requires_grad=True)
print(b)  # tensor([[2.1138]], requires_grad=True)

在这里插入图片描述

  1. 可视化一下结果,红色曲线是预测结果 ,蓝色点是真实标签值
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/94003.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/94003.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大型地面光伏电站开发建设流程

​地面电站特特点&#xff1a;规模大&#xff0c;通常占用土地、水面等&#xff0c;地面式选址选项多&#xff0c;且不断拓展出新的用地模式&#xff0c;地面式选址集中在山体、滩涂、沼泽、戈壁、沙漠、受污染土地等闲置或废弃土地上。

除数博弈(动态规划)

爱丽丝和鲍勃一起玩游戏&#xff0c;他们轮流行动。爱丽丝先手开局。最初&#xff0c;黑板上有一个数字 n 。在每个玩家的回合&#xff0c;玩家需要执行以下操作&#xff1a;选出任一 x&#xff0c;满足 0 < x < n 且 n % x 0 。用 n - x 替换黑板上的数字 n 。如果玩家…

一起学springAI系列一:初体验

Spring AI是干嘛的官网最权威&#xff0c;直接粘贴&#xff1a;“Spring AI”项目旨在简化那些包含人工智能功能的应用程序的开发过程&#xff0c;同时避免不必要的复杂性。AI相关领域的功能对python的支持是最好的&#xff0c;相关供应商在出了啥功能的时候&#xff0c;都会优…

Ext JS极速项目之 Coworkee

ExtJS Coworkee 是什么? Ext JS 的 Coworkee 是一个由 Sencha 官方提供的完整员工管理应用示例,旨在展示 Ext JS 框架在企业级应用开发中的能力。 在线试用的地址是: https://examples.sencha.com/coworkee/#home 页面效果与布局 登录页面: 主页效果 左右分区结构:左…

飞算科技:原创技术重塑 Java 开发,引领行业数智化新浪潮

在科技革新的浪潮中&#xff0c;飞算科技作为一家坚持自主创新的数字科技企业&#xff0c;同时也是国家级高新技术企业&#xff0c;正深耕互联网科技、大数据、人工智能等前沿领域&#xff0c;为众多企业的数字化与智能化转型提供强劲动力。​飞算科技的成长轨迹&#xff0c;是…

cesium FBO(一)渲染到纹理(RTT)

一听到三维的RTT&#xff08;Render To Texture&#xff09;&#xff0c;似乎很神秘&#xff0c;但从底层实现一看&#xff0c;其实也就那样&#xff0c;设计API的哪些顶级家伙已经帮你安排的明明白白了&#xff0c;咱们只需要学会怎么用就可以了。我认为得从WebGL入手&#xf…

PNP机器人机器人学术年会展示灵巧手动作捕捉方案。

2025年8月1-3日&#xff0c;第六届中国机器人学术年会&#xff08;CCRS2025&#xff09;在长沙国际会议中心举行&#xff0c;主题“人机共融&#xff0c;智向未来”。PNP机器人与灵巧智能联合展出最新灵巧手模仿学习方案&#xff1a;基于少量示教数据即可快速复现复杂抓取动作&…

【45】C#入门到精通——C#调用C/C++生成动态库.dll及C++ 生成动态库.dll ,DllImport()方式导入 C++动态库.dll方法总结

文章目录1 C 生成动态库.dll2 C#调用C/C生成动态库.dll2.1 [DllImport()] 方式导入 C动态库.dll2.2 调用测试3 C/C 生成通用dll,改进3.1改进后.h3.2 .cpp3.2 C# 调用4 [DllImport()] 方式导入C生成的 .dll 总结4.1 指定路径导入4.2 .dll放在 执行目录下&#xff08;一定要放对&…

从协议栈到ath12k_mac_op_tx的完整调用路径

文章目录 从协议栈到ath12k_mac_op_tx的完整调用路径 1. 整体架构概览 2. 详细调用路径分析 2.1 应用层到Socket层 2.2 协议层处理 2.3 网络设备层到mac80211 2.4 mac80211发送入口 2.5 mac80211核心发送处理 2.6 mac80211发送核心处理 2.7 mac80211发送调度 2.8 最终驱动调用 …

WPFC#超市管理系统(4)入库管理

入库管理7. 商品入库管理7.2 入库实现显示名称、图片、单位7.3 界面设计7.3 功能实现7. 商品入库管理 数据库中StockRecord表需要增加商品出入库Type类型为nvarchar(50)。C#中的数据库重新同步StockRecord表在Entity→Model中新建枚举类型StockType namespace 超市管理系统.E…

CSS 打字特效

效果图.wxml <view class"tips"><text>{{ tipsText }}</text><text class"tips-line">|</text> </view>.wxss .tips{padding: 50rpx 100rpx;font-size: 28rpx; } .tips-line{color: #ccc;animation: tips-line .5s al…

直播小程序 app 系统架构分析

一、引言 直播行业近年来发展迅猛&#xff0c;直播小程序和 APP 成为众多用户获取直播内容以及主播进行内容输出的重要平台。一个完善且高效的系统架构是支撑直播业务稳定运行、提供优质用户体验的关键。本文将详细剖析直播小程序 / APP 的系统架构&#xff0c;包括整体架构设计…

Vue常见题目

1. 什么是 Vue.js&#xff1f;它的核心特点是什么&#xff1f; Vue.js 是一个渐进式 JavaScript 框架&#xff0c;用于构建用户界面。它的核心特点包括&#xff1a; - 响应式数据绑定 - 组件化开发 - 虚拟 DOM - 指令系统 - 轻量级且易于集成 - 丰富的生态系统&#xff08;Vue…

ipynb文件直接发布csdn

第一步&#xff0c;下载markdown文件 file --> save and export notebook as --> markdown第二步&#xff0c;导入markdown文件 进入csdn发布文章界面&#xff0c;点击导入&#xff0c;选择第一步下载的markdown文件即可

广东省省考备考(第六十四天8.2)——判断推理(重点回顾)

判断推理&#xff1a;数量规律 错题解析解析解析解析解析解析解析标记题解析解析解析解析解析解析解析今日题目正确率&#xff1a;53% 判断推理&#xff1a;属性规律 错题解析解析解析解析解析解析标记题解析解析今日题目正确率&#xff1a;60%

【C++/STL】vector的OJ,深度剖析和模拟实现

vector在OJ中的使用 1.只出现一次的数字 class Solution { public:int singleNumber(vector<int>& nums) {int value 0;for(auto e : v) {value ^ e; }return value;} };2.杨辉三角 class Solution { public:vector<vector<int>> generate(int numRow…

衡石湖仓一体架构深度解构:统一元数据层如何破除数据孤岛?

一、数据融合的世纪难题典型困境二、衡石统一元数据层设计架构核心关键技术实现智能元数据发现自动构建跨源血缘关系动态查询重写 将标准SQL翻译为最优执行计划text Original: SELECT SUM(sales) FROM virtual_view Rewritten: [S3] SELECT SUM(amount) FROM crm_sales [My…

Windows 下 fping 指令使用指南

fping 作为一款强大的网络工具&#xff0c;能够同时向多个主机发送 ICMP 回声请求&#xff0c;相较于传统的 ping 命令&#xff0c;在处理大量主机时具有显著优势。 一、fping 简介​ fping 是 “fast pinger” 的缩写&#xff0c;它可以向一系列 IP 地址发送 ICMP 回声请求。…

代码随想录day52图论3

文章目录101. 孤岛的总面积102. 沉没孤岛103. 水流问题104.建造最大岛屿101. 孤岛的总面积 题目链接 文章讲解 #include<bits/stdc.h> using namespace std;int ans 0; // 记录不与边界相连的孤岛数量 int sum 0; // 当前孤岛的面积 bool flag false; /…

linux pip/conda 修改默认cache位置

1 pip pip cache默认在/home/{username}目录下&#xff0c;容易导致系统盘写满报错。查看pip cache位置pip cache dir假设移动pip cache目录到 /data/.cache/pip/cache&#xff0c;命令如下pip config set global.cache-dir /data/.cache/pip/cache2 conda 查看conda缓存位置c…