PyTorch自动微分:从基础到实战

目录

1. 自动微分是什么?

1.1 计算图

1.2 requires_grad 属性

2. 标量和向量的梯度计算

2.1 标量梯度

2.2 向量梯度

3. 梯度上下文控制

3.1 禁用梯度计算

3.2 累计梯度

4. 梯度下降实战

4.1 求函数最小值

4.2 线性回归参数求解

5. 总结


在深度学习中,自动微分是神经网络训练的核心机制之一。PyTorch通过torch.autograd模块提供了强大的自动微分功能,能够自动计算张量操作的梯度。今天,我们就来深入探讨PyTorch的自动微分机制,并通过一些实战案例来理解它的原理和应用。

1. 自动微分是什么?

在神经网络训练中,我们通常需要计算损失函数对模型参数的梯度,以便通过梯度下降法更新参数,从而最小化损失函数。手动计算梯度是非常繁琐且容易出错的,尤其是当网络结构复杂时。自动微分通过自动构建计算图并计算梯度,极大地简化了这一过程。

1.1 计算图

计算图是自动微分的核心概念。它是一个有向图,节点表示张量(Tensor),边表示张量之间的操作。当我们对张量进行操作时,PyTorch会自动构建一个动态计算图,并在反向传播时沿着这个图计算梯度。

例如:
在上述代码中,x y 是输入张量,即叶子节点,z 是中间结果,loss 是最终输出。每一步操作都
会记录依赖关系:
z = x * yz 依赖于 x y
loss = z.sum()loss 依赖于 z
这些依赖关系形成了一个动态计算图,如下所示:

1.2 requires_grad 属性

在PyTorch中,每个张量都有一个requires_grad属性,用于指定是否需要计算梯度。如果requires_grad=True,则该张量的所有操作都会被记录在计算图中;如果requires_grad=False,则不会记录操作,也不会计算梯度。

x = torch.tensor(1.0, requires_grad=True)
y = x ** 2
y.backward()
print(x.grad)  # 输出梯度

2. 标量和向量的梯度计算

2.1 标量梯度

当我们对一个标量进行操作时,可以直接调用backward()方法来计算梯度。

x = torch.tensor(1.0, requires_grad=True)
y = x ** 2
y.backward()
print(x.grad)  # 输出:tensor(2.)

2.2 向量梯度

对于向量,我们需要提供一个与输出形状相同的梯度张量作为backward()的参数。

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2
y.backward(torch.tensor([1.0, 1.0, 1.0]))
print(x.grad)  # 输出:tensor([2., 4., 6.])

如果我们将输出转换为标量(例如通过求和),则可以直接调用backward()

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x ** 2
loss = y.sum()
loss.backward()
print(x.grad)  # 输出:tensor([2., 4., 6.])

3. 梯度上下文控制

在某些情况下,我们可能不需要计算梯度,或者希望控制梯度的计算过程。PyTorch提供了几种方式来控制梯度计算的上下文。

3.1 禁用梯度计算

使用torch.no_grad()上下文管理器可以临时禁用梯度计算。

x = torch.tensor(1.0, requires_grad=True)
with torch.no_grad():y = x ** 2
print(y.requires_grad)  # 输出:False

3.2 累计梯度

默认情况下,多次调用backward()会累计梯度。如果需要清零梯度,可以使用x.grad.zero_()

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
for i in range(3):y = x ** 2loss = y.sum()if x.grad is not None:x.grad.zero_()loss.backward()print(x.grad)
# 输出:tensor([2., 4., 6.])tensor([2., 4., 6.])tensor([2., 4., 6.])

4. 梯度下降实战

4.1 求函数最小值

通过梯度下降法,我们可以找到函数的最小值。以下是一个简单的例子,通过梯度下降法找到函数y = x^2的最小值。

x = torch.tensor([3.0], requires_grad=True)
lr = 0.1
epochs = 50
for epoch in range(epochs):y = x ** 2if x.grad is not None:x.grad.zero_()y.backward()with torch.no_grad():x -= lr * x.gradprint(f'Epoch {epoch}, x: {x.item()}, y: {y.item()}')

4.2 线性回归参数求解

我们还可以通过梯度下降法求解线性回归模型的参数。以下是一个简单的线性回归模型,通过梯度下降法求解参数ab

x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float)
y = torch.tensor([3, 5, 7, 9, 11], dtype=torch.float)
a = torch.tensor([1.0], requires_grad=True)
b = torch.tensor([1.0], requires_grad=True)
lr = 0.01
epochs = 1000
for epoch in range(epochs):y_pred = a * x + bloss = ((y_pred - y) ** 2).mean()if a.grad is not None and b.grad is not None:a.grad.zero_()b.grad.zero_()loss.backward()with torch.no_grad():a -= lr * a.gradb -= lr * b.gradif (epoch + 1) % 100 == 0:print(f'Epoch {epoch + 1}, Loss: {loss.item()}')
print(f'a: {a.item()}, b: {b.item()}')

5. 总结

通过这篇文章,我们学习了PyTorch的自动微分机制,包括:

  • 如何构建计算图。

  • 如何计算标量和向量的梯度。

  • 如何控制梯度计算的上下文。

  • 如何通过梯度下降法求解函数最小值和线性回归模型的参数。

自动微分是深度学习的核心技术之一,希望这篇文章能帮助你更好地理解和使用PyTorch的自动微分功能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/914340.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/914340.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Spring AI 项目实战(十六):Spring Boot + AI + 通义万相图像生成工具全栈项目实战(附完整源码)

系列文章 序号文章名称1Spring AI 项目实战(一):Spring AI 核心模块入门2Spring AI 项目实战(二):Spring Boot + AI + DeepSeek 深度实战(附完整源码)3Spring AI 项目实战(三):Spring Boot + AI + DeepSeek 打造智能客服系统(附完整源码)4

从零到一:企业如何组建安全团队

在这个"黑客满天飞,漏洞遍地跑"的时代,没有安全团队的企业就像裸奔的勇士——虽然很有勇气,但结局往往很悲惨。 📋 目录 为什么要组建安全团队安全团队的核心职能团队架构设计人员配置策略技术体系建设制度流程建立实施…

业务访问控制-ACL与包过滤

业务访问控制-ACL与包过滤 ACL的定义及应用场景ACL(Access Control List,访问控制列表)是用来实现数据包识别功能的;ACL可以应用于诸多场景: 包过滤功能:对数据包进行放通或过滤操作。NAT(Netwo…

穿梭时空的智慧向导:Deepoc具身智能如何赋予导览机器人“人情味”

穿梭时空的智慧向导:Deepoc具身智能如何赋予导览机器人“人情味”清晨,当第一缕阳光透过高大的彩绘玻璃窗,洒在博物馆光洁的地板上,一位特别的“馆员”已悄然“苏醒”。它没有制服,却有着清晰的指引;它无需…

PostgreSQL 查询库中所有表占用磁盘大小、表大小

SELECTn.nspname AS schema_name,c.relname AS table_name,-- 1️⃣ 总大小(表 toast 索引)pg_size_pretty(pg_total_relation_size(c.oid)) AS total_size,-- 2️⃣ 表不包含索引(含 TOAST)pg_size_pretty(pg_total_relation_s…

日记-生活随想

最近鼠鼠也是来到上海打拼(实习)了,那么秉持着来都来了的原则,鼠鼠也是去bw逛了逛,虽说没票只能在外场看看😭。可惜几乎没有多少我非常喜欢的ip,不由感慨现在的二次元圈已经变样了。虽说我知道内…

串口A和S的含义以及RT的含义

A async 异步S sync 同步RT 收发U A RT 异步U SA RT 同步/异步

spring cloud负载均衡分析之FeignBlockingLoadBalancerClient、BlockingLoadBalancerClient

本文主要分析被 FeignClient 注解的接口类请求过程中负载均衡逻辑&#xff0c;流程分析使用的依赖版本信息如下&#xff1a;<spring-boot.version>3.2.1</spring-boot.version><spring-cloud.version>2023.0.0</spring-cloud.version><com.alibaba.…

ref 和 reactive

文章目录ref 和 reactive一、差异二、能否替代的场景分析&#xff08;1&#xff09;基本类型数据&#xff08;2&#xff09;对象类型数据&#xff08;3&#xff09;数组类型数据&#xff08;4&#xff09; 需要整体替换的场景三、替代方案与兼容写法1. 用 reactive 模拟 ref2. …

BatchNorm 与 LayerNorm:原理、实现与应用对比

BatchNorm 与 LayerNorm&#xff1a;原理、实现与应用对比 Batch Normalization (批归一化) 和 Layer Normalization (层归一化) 是深度学习中两种核心的归一化技术&#xff0c;它们解决了神经网络训练中的内部协变量偏移问题&#xff0c;大幅提升了模型训练的稳定性和收敛速度…

OcsNG基于debian一键部署脚本

&#x1f914; 为什么有了GLPI还要部署OCS-NG&#xff1f; 核心问题&#xff1a;数据收集的风险 GLPI直接收集的问题&#xff1a; Agent直接向GLPI报告数据时&#xff0c;任何收集异常都会直接影响资产数据库网络问题、Agent故障可能导致重复资产、错误数据、资产丢失无法对收集…

001_Claude开发者指南介绍

Claude开发者指南介绍 目录 Claude简介Claude 4 模型开始使用核心功能支持资源 Claude简介 Claude 是由 Anthropic 构建的高性能、可信赖和智能的 AI 平台。Claude 具备出色的语言、推理、分析和编程能力&#xff0c;可以帮助您解决各种复杂任务。 想要与 Claude 聊天吗&a…

004_Claude功能特性与API使用

Claude功能特性与API使用 目录 API 基础使用核心功能特性高级功能开发工具平台支持 API 基础使用 快速开始 通过 Anthropic Console 获取 API 访问权限&#xff1a; 在 console.anthropic.com/account/keys 生成 API 密钥使用 Workbench 在浏览器中测试 API 认证方式 H…

ReAct论文解读(1)—什么是ReAct?

什么是ReAct&#xff1f; 在大语言模型&#xff08;LLM&#xff09;领域中&#xff0c;ReAct 指的是一种结合了推理&#xff08;Reasoning&#xff09; 和行动&#xff08;Acting&#xff09; 的提示方法&#xff0c;全称是 “ReAct: Synergizing Reasoning and Acting in Lan…

【云服务器安全相关】服务器防火墙常见系统日志信息说明

目录✅ 一、防火墙日志是做什么的&#xff1f;&#x1f6e0;️ 二、常见防火墙日志信息及说明&#x1f9ea; 三、典型日志示例解析1. 被阻断的访问&#xff08;DROP&#xff09;2. 被允许的访问&#xff08;ACCEPT&#xff09;3. 被拒绝的端口访问4. 可疑端口扫描行为&#x1f…

011_视觉能力与图像处理

视觉能力与图像处理 目录 视觉能力概述支持的图像格式图像上传方式使用限制最佳实践应用场景API使用示例视觉能力概述 多模态交互 Claude 3 系列模型具备强大的视觉理解能力,可以分析和理解图像内容,实现真正的多模态AI交互。这种能力使Claude能够: 图像内容分析:理解图…

ansible自动化部署考试系统前后端分离项目

1. ✅ansible编写剧本步骤1️⃣创建roles目录结构2️⃣在group_vars/all/main.yml中定义变量列表3️⃣在tasks目录下编写tasks任务4️⃣在files目录下准备部署文件5️⃣在templates目录下创建j2模板文件6️⃣在handlers目录下编写handlers7️⃣在roles目录下编写主playbook8️⃣…

【AI论文】GLM-4.1V-Thinking:迈向具备可扩展强化学习的通用多模态推理

摘要&#xff1a;我们推出GLM-4.1V-Thinking&#xff0c;这是一款旨在推动通用多模态推理发展的视觉语言模型&#xff08;VLM&#xff09;。在本报告中&#xff0c;我们分享了在以推理为核心的训练框架开发过程中的关键发现。我们首先通过大规模预训练开发了一个具备显著潜力的…

Linux进程通信——匿名管道

目录 1、进程间通信基础概念 2、管道的工作原理 2.1 什么是管道文件 3、匿名管道的创建与使用 3.1、pipe 系统调用 3.2 父进程调用 fork() 创建子进程 3.3. 父子进程的文件描述符共享 3.4. 关闭不必要的文件描述符 3.5 父子进程通过管道进行通信 父子进程通信的具体例…

sql:sql在office中的应用有哪些?

在Office软件套件中&#xff0c;主要是Access和Excel会用到SQL&#xff08;结构化查询语言&#xff09;&#xff0c;以下是它们在这两款软件中的具体应用&#xff1a; 在Access中的应用 创建和管理数据库对象&#xff1a; 创建表&#xff1a;使用CREATE TABLE语句可以创建新的数…