pytorch基本运算-梯度运算:requires_grad_(True)和backward()

引言

前序学习进程中,已经对pytorch基本运算中的求导进行了基础讨论,相关文章链接为:

导数运算pytorch基本运算-导数和f-string-CSDN博客

实际上,求导是微分的进一步计算,要想求导的前一步其实是计算微分:

导数表达式:
f ′ ( x ) 或 d y d x f^{'}(x) 或 \frac{dy}{dx} f(x)dxdy
​微分表达式:
f ′ ( x ) d x 或 d y = f ′ ( x ) d x f^{'}(x) dx 或 {dy=f^{'}(x) dx} f(x)dxdy=f(x)dx
导数是某一点处的变化率,微分是某一点附近的变化量。
如果一个函数在多个点进行导数求解,或者说子安多维度上进行导数计算,实际上就是在求梯度。

pytorch自动微分获取梯度

为完整展示pytorch的梯度计算功能,将测试分为以下部分。

初始定义

首先是引入模块,完成变量定义:

# 导入模块
import torch
# 定义变量
x=torch.arange(3.0)
print('x=',x)

这里的输出结果是:

x= tensor([0., 1., 2.])

需要说明的是,因为pytorch默认对浮点数进行求导,所以定义变量的时候,pyorch.arange()使用了3.0而不是整数3。
紧接着,需要对变量执行梯度运算。

梯度运算标定

梯度运算标定的目的是,声明要对x进行梯度运算。任何没有经过提前标定的量,都不能正常执行梯度运算。

# 标记需要对x进行梯度计算
z=x.requires_grad_(True)
print('z=',z)

梯度标定使用requires_grad_(True),就像对话一样,需要求梯度_(需要)。
代码运行的效果为:

z= tensor([0., 1., 2.], requires_grad=True)

下一步是定义一个函数。

函数定义

这里定义一个简单函数:
f ( x ) = 2 x 2 f(x)=2x^{2} f(x)=2x2
具体定义代码为:

# 点乘定义
m=2*torch.dot(x,x)
print('m=',m)

计算微分对函数开展才有意义,所以必须定义函数,这里只是一个示例,也可以是其他函数。
torch.dot()函数的计算规则为:对位相乘然后求和。
代码运行效果为:

m= tensor(10., grad_fn=)

这里输出了两个部分:
第一部分是10,就是元素对位相乘后求和的效果(2X0X0+2X1X1+2X2X2=10)。
第二部分是grad_fn=,grad_fn的意思是grad_function,就是求导函数的意思,后面的MulBackward0是对求导函数的具体定义。
MulBackward0 表示这是一个乘法操作的梯度函数,具体拆开来:multiplication-backward,字面意思解释:乘法-反向传播。
这就是pytorch自动微分的核心机制:它可以自动测算求导函数的类型,比如这是一个自变量相乘的函数,并且指出要用哪种方法,比如这里要用反向传播法。
到这一步还无法计算微分,只是通过输出效果知道用反向传播方法计算微分,然后就是正式使用反向传播方法计算微分。

梯度计算

微分计算使用的代码为:

# 执行梯度运算
n=m.backward()
k=x.grad
print('n=',n)
print('k=',k)

这里用了两步,第一步是定义对函数m调用backward方法求倒数,然后具体是对x求导数,所获得计算结果为:

n= None
k= tensor([0., 4., 8.])

n对应的其实是方法定义,k才是具体的对x的求导效果。

实际上到这一步,如何用pytorch直接计算导数已经非常清晰:先要标定梯度计算的变量,然后要对函数声明梯度计算的方法,最后直接计算梯度。完整代码为:

# 导入模块
import torch
# 定义变量
x=torch.arange(3.0)
print('x=',x)
# 标记需要对x进行梯度计算
z=x.requires_grad_(True)
print('z=',z)
# 点乘定义
m=2*torch.dot(x,x)
print('m=',m)
# 执行梯度运算
n=m.backward()
k=x.grad
print('n=',n)
print('k=',k)

新的函数

未计算对新函数进行求导运算,需要提前将梯度清零,避免梯度计算效果彼此叠加,出现预料之外的效果。

梯度清零

代码为:

# 梯度清零
kk=x.grad.zero_()
print('kk=',kk)

代码运行效果为:

kk= tensor([0., 0., 0.])

定义新函数

代码为:

# 定义新函数
hh=x.sum()
print('hh=',hh)

这里使用了求和函数sim(),代码运行效果为:

hh= tensor(3., grad_fn=)

这里也输出了两个部分:
第一部分是3,就是元素求和的效果(0+1+2=3)。
第二部分是grad_fn=,grad_fn的意思是grad_function,就是求导函数的意思,后面的SumBackward0是对求导函数的具体定义。
SumBackward0 表示这是一个加法操作的梯度函数,具体拆开来:Sum-backward,字面意思解释:加法-反向传播。

导数计算

此时可以直接计算导数,代码为:

# 定义用backward方法计算导数
nn=hh.backward()
print('nn=',nn)
# 导数计算
tt=x.grad
print('tt=',tt)

代码运行效果为:

nn= None
tt= tensor([1., 1., 1.])

因为是各个变量直接叠加,所以每个变量前的系数都是1,所以导数运算的结果是[1.0,1.0,1.0].
此时的完整代码为:

# 导入模块
import torch
# 定义变量
x=torch.arange(3.0)
print('x=',x)
# 标记需要对x进行梯度计算
z=x.requires_grad_(True)
print('z=',z)
# 点乘定义
m=2*torch.dot(x,x)
print('m=',m)
# 执行梯度运算
n=m.backward()
k=x.grad
print('n=',n)
print('k=',k)
# 梯度清零
kk=x.grad.zero_()
print('kk=',kk)
# 定义新函数
hh=x.sum()
print('hh=',hh)
# 定义用backward方法计算导数
nn=hh.backward()
print('nn=',nn)
# 导数计算
tt=x.grad
print('tt=',tt)

完整的输出效果为:

x= tensor([0., 1., 2.])
z= tensor([0., 1., 2.], requires_grad=True)
m= tensor(10., grad_fn=)
n= None
k= tensor([0., 4., 8.])
kk= tensor([0., 0., 0.])
hh= tensor(3., grad_fn=)
nn= None
tt= tensor([1., 1., 1.])

梯度清零操作的讨论

前述有一个梯队清零的操作,如果没有这步操作,输出效果会如何变化,这里直接给出完整代码来测试。给出完整代码为:

# 导入模块
import torch
# 定义变量
x=torch.arange(3.0)
print('x=',x)
# 标记需要对x进行梯度计算
z=x.requires_grad_(True)
print('z=',z)
# 点乘定义
m=2*torch.dot(x,x)
print('m=',m)
# 执行梯度运算
n=m.backward()
k=x.grad
print('n=',n)
print('k=',k)
# 梯度清零
#kk=x.grad.zero_()
#print('kk=',kk)
# 定义新函数
hh=x.sum()
print('hh=',hh)
# 定义用backward方法计算导数
nn=hh.backward()
print('nn=',nn)
# 导数计算
tt=x.grad
print('tt=',tt)

此时的输出效果为:

x= tensor([0., 1., 2.])
z= tensor([0., 1., 2.], requires_grad=True)
m= tensor(10., grad_fn=)
n= None
k= tensor([0., 4., 8.])
hh= tensor(3., grad_fn=)
nn= None
tt= tensor([1., 5., 9.])

这里可以看到sum()函数的梯度输出为:[1.,5.,9.],这个结果的来源其实是:[0., 4., 8.]+[1., 1., 1.]=[1., 5., 9.]。
此处可见,及时将梯度清零很有必要。

总结

掌握了通过python+pytorch执行梯度运算的基本技巧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/web/83747.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

idea64.exe.vmoptions配置

这个idea64.exe.vmoptions文件是用于配置 IntelliJ IDEA(64位版本)运行时的 Java 虚拟机(JVM)参数。这些参数直接影响到 IDEA 的性能、内存使用、调试能力和行为。 下面是对文件中每一行配置的详细解读: -Xms2048m 作…

齐次变换矩阵相乘的复合变换:左乘与右乘的深度解析

在三维几何变换中,齐次变换矩阵相乘是实现复杂变换的核心方法。本文将通过一个包含四个变换步骤的完整示例,深入探讨齐次变换矩阵左乘和右乘的区别,并结合 Python sympy 库的代码实现,详细阐述变换过程和结果差异。 二维齐次坐标的旋转变换 在二维齐次坐标系中,一个点可以…

5g LDPC编译码-LDPC编码

目录 1、LDPC编码基础知识 2、5g的LDPC编码 2.1 LDPC分块: 2.2 LDCP编码 2.3 校验位的产生 1、LDPC编码基础知识 LDPC属于线性分组码,线性分组码的基本知识如下: 编码后的码字是由初始二进制序列与生成矩阵在二进制域相乘后得到,生成矩阵与校验矩阵,校验矩阵与编码后…

OpenVINO使用教程--resnet分类模型部署

OpenVINO使用教程--resnet分类模型部署 本节内容模型准备推理测试分析&总结本节内容 OpenVINO 根据AI技术类型将部署任务分成传统模型模型部署和生成式AI模型部署,传统模型指的是各种CNN小模型,这部分部署只需要OpenVINO包,具体安装教程可以参考之前的章节:OpenVINO环境…

无字母数字webshell的命令执行

在Web安全领域,WebShell是一种常见的攻击手段,通过它攻击者可以远程执行服务器上的命令,获取敏感信息或控制系统。而无字母数字WebShell则是其中一种特殊形式,通过避免使用字母和数字字符,来绕过某些安全机制的检测。 …

C++斯特林数在C++中的数学理论与计算实现1

一、 斯特林数概述 1.1 组合数学中的核心地位 斯特林数(Stirling Numbers)是组合数学中连接排列、组合与分划问题的核心工具,分为两类: 第一类斯特林数(Stirling Numbers of the First Kind)&#xff1a…

[C++] STL大家族之<map>(字典)容器(附洛谷)

map-目录 使用方法头文件与声明定义基本操作 使用方法 头文件与声明定义 头文件是: #include <map>我们这样声明一个字典: map</*key_type*/, /*value_type*/> /*map_name*/; // 例子: map<int, char> mp;这里稍作解释: key_type是你每个键值对中的键的…

使用 Flutter 在 Windows 平台开发 Android 应用

以下是完整的开发流程&#xff0c;包括环境搭建、代码实现和应用发布&#xff0c;帮助你开发一个具有地图显示、TCP 通信功能的 Android 应用。 一、环境搭建 1. 安装 Flutter SDK 从 Flutter 官网 下载最新稳定版 SDK解压到本地目录&#xff08;如 D:\flutter&#xff09;添…

【模板】埃拉托色尼筛法(埃氏筛)

一、算法简介 在数论与编程竞赛中&#xff0c;求解 [ 1 , n ] [1,n] [1,n] 范围内的所有质数是常见的基础问题。埃拉托色尼筛法&#xff08;Sieve of Eratosthenes&#xff09; 是一种古老而高效的算法&#xff0c;可以在 O ( n log ⁡ log ⁡ n ) O(n \log \log n) O(nlogl…

AI Agent实战 - LangChain+Playwright构建火车票查询Agent

本篇文章将带你一步步构建一个智能火车票查询 Agent&#xff1a;你只需要输入自然语言指令&#xff0c;例如&#xff1a; “帮我查一下6月15号从上海到南京的火车票” Agent就能自动理解你的需求并使用 Playwright 打开 12306 官网查询前 10 条车次信息&#xff0c;然后汇总结果…

RabbitMQ的交换机和队列概念

&#x1f3ea; 场景&#xff1a;一个外卖平台的后台系统 假设你开了一家在线外卖平台&#xff1a; 饭店是消息的生产者&#xff08;Producer&#xff09;顾客是消息的消费者&#xff08;Consumer&#xff09;你开的外卖平台就是RabbitMQ消息系统 &#x1f501; 第一部分&…

德国马克斯·普朗克数学研究所:几何朗兰兹猜想

2025年科学突破奖 4月5日在美国洛杉矶揭晓&#xff1a;数学突破奖&#xff1a;德国马克斯普朗克数学研究所&#xff1a;几何朗兰兹猜想 德国马克斯普朗克数学研究所&#xff08;Max Planck Institute for Mathematics, MPIM&#xff09;在几何朗兰兹猜想的研究中扮演了核心角色…

TerraFE 脚手架开发实战系列(一):项目架构设计与技术选型

TerraFE 脚手架开发实战系列&#xff08;一&#xff09;&#xff1a;项目架构设计与技术选型 前言 在前端开发中&#xff0c;项目初始化往往是一个重复且繁琐的过程。每次新建项目都需要配置 webpack、安装依赖、设置目录结构等&#xff0c;这些重复性工作不仅浪费时间&#…

准确--CentOS 7.9在线安装docker

一、安装Docker前的准备工作 操作系统版本为CentOS 7.9&#xff0c;内核版本需要在3.10以上。确保能够连通互联网&#xff0c;为避免网络异常&#xff0c;建议关闭Linux的防火墙&#xff08;生产环境下请根据实际情况设置防火墙出入站规则&#xff09;。 # 查看内核版本 sudo…

中兴B860AV1.1强力降级固件包

中兴B860AV1.1强力降级固件包 关于中兴b860av1.1顽固盒子降级教程终极版 将附件解压好以后&#xff0c;准备一个8G以下的U盘重新格式化为FAT32格式后&#xff0c;并插入电脑 将以下文件及文件夹一同复制到优盘主目录下&#xff08;见下图&#xff09; 全选并复制到U盘主目录下&…

nacos-作为注册中心与springcloud整合(三)

前一篇文章nacos-简介和初体验&#xff08;一&#xff09;我们已经在服务器部署了nacos应用了。 在另外一篇文章中nacos-作为配置中心与springcloud整合&#xff08;二&#xff09;已经作为配置中心整合到springcloud 接下来让我们尝试把nacos作为注册中心和springcloud中整合&…

Seata的TC(事务协调器)高可用如何实现?

Seata的TC&#xff08;事务协调器&#xff09;确实运行在Seata服务进程中&#xff0c;其高可用实现和宕机恢复主要通过以下机制实现&#xff1a; 一、高可用架构 集群部署 多TC节点组成集群&#xff0c;通过注册中心&#xff08;如Nacos&#xff09;实现服务发现采用Raft协议实…

Mac安装docker desktop

一、背景 最近在学习Spring AI&#xff0c;于是在GitHub上找了个开源项目&#xff0c;个人觉得还是比较适合有Java基础和AI基础的同学学习的。GitHub地址如下&#xff1a; https://github.com/qifan777/dive-into-spring-ai 但是看了下运行环境需要 MySQL 8 Redis-Stack n…

【算法深练】二分答案:从「猜答案」到「精准求解」的解题思路

目录 前言 二分求最小值 1283. 使结果不超过阈值的最小除数 2187. 完成旅途的最少时间 1011. 在 D 天内送达包裹的能力 875. 爱吃香蕉的珂珂 3296. 移山所需的最少秒数 475. 供暖器 2594. 修车的最少时间 1482. 制作 m 束花所需的最少天数 3048. 标记所有下标的最早秒…

基于RK3588,飞凌教育品牌推出嵌入式人工智能实验箱EDU-AIoT ELF 2

在AIoT技术驱动产业变革的浪潮中&#xff0c;嵌入式人工智能已成为工业物联网、智慧交通、智慧医疗等领域创新突破的关键引擎。飞凌嵌入式教育品牌ElfBoard立足产业前沿&#xff0c;重磅推出嵌入式人工智能实验箱EDU-AIoT ELF 2&#xff0c;以“软硬协同、产教融合”为设计理念…