小白的进阶之路系列之七----人工智能从初步到精通pytorch自动微分优化以及载入和保存模型

本文将介绍Pytorch的以下内容

自动微分函数

优化

模型保存和载入

好了,我们首先介绍一下关于微分的内容。

在训练神经网络时,最常用的算法是反向传播算法。在该算法中,根据损失函数相对于给定参数的梯度来调整参数(模型权重)。

为了计算这些梯度,PyTorch有一个内置的微分引擎,名为torch.autograd。它支持任何计算图的梯度自动计算。

考虑最简单的单层神经网络,输入x,参数w和b,以及一些损失函数。它可以在PyTorch中以以下方式定义:

import torchx = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

张量、函数与计算图

这段代码定义了以下计算图:

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

在这个网络中,w和b是我们需要优化的参数。因此,我们需要能够计算损失函数相对于这些变量的梯度。为了做到这一点,我们设置了这些张量的requires_grad属性。

我们应用于张量来构造计算图的函数实际上是函数类的对象。该对象知道如何在正向方向上计算函数,以及如何在反向传播步骤中计算其导数。对反向传播函数的引用存储在张量的grad_fn属性中。您可以在文档中找到Function的更多信息。

print(f"Gradient function for z = {z.grad_fn}")
print(f"Gradient function for loss = {loss.grad_fn}")

输出为:

Gradient function for z = <AddBackward0 object at 0x0000022EDB445C30>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x0000022EDB445D20>

计算梯度

为了优化神经网络中参数的权重,我们需要计算损失函数对参数的导数,即我们需要∂loss/∂w和∂loss/∂B。为了计算这些导数,我们调用loss.backward(),然后从w.g grad和b.g grad中检索值:

loss.backward()
print(w.grad)
print(b.grad)

输出为:

tensor([[0.0549, 0.1796, 0.0399],[0.0549, 0.1796, 0.0399],[0.0549, 0.1796, 0.0399],[0.0549, 0.1796, 0.0399],[0.0549, 0.1796, 0.0399]])
tensor([0.0549, 0.1796, 0.0399])

禁用梯度跟踪

默认情况下,所有requires_grad=True的张量都在跟踪它们的计算历史并支持梯度计算。然而,在某些情况下,我们不需要这样做,例如,当我们训练了模型,只想将其应用于一些输入数据时,即我们只想通过网络进行前向计算。我们可以通过使用torch.no_grad()块包围我们的计算代码来停止跟踪计算:

z = torch.matmul(x, w)+b
print(z.requires_grad)with torch.no_grad():z = torch.matmul(x, w)+b
print(z.requires_grad)

输出为:

True
False

实现相同结果的另一种方法是在张量上使用detach()方法:

z = torch.matmul(x, w)+b
z_det = z.detach()
print(z_det.requires_grad)

输出为:

False

你可能想要禁用渐变跟踪的原因如下:

  • 将神经网络中的一些参数标记为冻结参数。

  • 当你只做正向传递时,为了加快计算速度,因为在不跟踪梯度的张量上的计算会更有效率。

更多关于计算图的知识

从概念上讲,autograd在由Function对象组成的有向无环图(DAG)中保存数据(张量)和所有执行的操作(以及产生的新张量)的记录。在DAG中,叶是输入张量,根是输出张量。通过从根到叶的跟踪图,您可以使用链式法则自动计算梯度。

在向前传递中,autograd同时做两件事:

  • 运行请求的操作来计算结果张量

  • 在DAG中维持操作的梯度函数。

当在DAG根上调用.backward()时,向后传递开始。autograd:

  • 计算每个。grad_fn的梯度,

  • 在各自张量的.grad属性中累积它们

  • 利用链式法则,一直传播到叶张量。

[!TIP]

PyTorch中的dag是动态的,需要注意的重要一点是图形是从头开始重新创建的;在每次.backward()调用之后,autograd开始填充一个新图。这正是允许您在模型中使用控制流语句的原因;如果需要,您可以在每次迭代中更改形状、大小和操作

张量梯度和雅可比积

在很多情况下,我们有一个标量损失函数,我们需要计算关于一些参数的梯度。然而,在某些情况下,输出函数是一个任意张量。在这种情况下,PyTorch允许你计算所谓的雅可比积,而不是实际的梯度。

inp = torch.eye(4, 5, requires_grad=True)
out = (inp+1).pow(2).t()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"First call\n{inp.grad}")
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")
inp.grad.zero_()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")

输出为:

First call
tensor([[4., 2., 2., 2., 2.],[2., 4., 2., 2., 2.],[2., 2., 4., 2., 2.],[2., 2., 2., 4., 2.]])Second call
tensor([[8., 4., 4., 4., 4.],[4., 8., 4., 4., 4.],[4., 4., 8., 4., 4.],[4., 4., 4., 8., 4.]])Call after zeroing gradients
tensor([[4., 2., 2., 2., 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/web/81732.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【图像处理基石】立体匹配的经典算法有哪些?

1. 立体匹配的经典算法有哪些&#xff1f; 立体匹配是计算机视觉中从双目图像中获取深度信息的关键技术&#xff0c;其经典算法按技术路线可分为以下几类&#xff0c;每类包含若干代表性方法&#xff1a; 1.1 基于区域的匹配算法&#xff08;Local Methods&#xff09; 通过…

《Map 到底适合用哪个?HashMap、TreeMap、LinkedHashMap 对比实战》

大家好呀&#xff01;今天我们来聊聊Java中超级重要的Map集合家族 &#x1f3a2;。Map就像是一个神奇的魔法口袋&#xff0c;可以帮我们把东西&#xff08;值&#xff09;和标签&#xff08;键&#xff09;一一对应存放起来。不管你是Java新手还是老司机&#xff0c;掌握Map都是…

TencentOSTiny

开放原子开源基金会 腾讯物联网终端操作系统 _物联网操作系统_物联网OS_TencentOS tiny-腾讯云 GitHub - OpenAtomFoundation/TobudOS: 开放原子开源基金会孵化的物联网操作系统&#xff0c;捐赠前为腾讯物联网终端操作系统TencentOS Tiny 项目简介 TencentOS Tiny 是腾讯…

使用 Selenium 进行自动化测试:入门指南

在现代软件开发中&#xff0c;自动化测试已经成为不可或缺的一部分。它不仅提高了测试效率&#xff0c;还减少了人为错误的可能性。Selenium 是一个强大的开源工具&#xff0c;广泛用于 Web 应用程序的自动化测试。本文将详细介绍如何使用 Selenium 进行自动化测试&#xff0c;…

C54-动态开辟内存空间

1.malloc 原型&#xff1a;void* malloc(size_t size);&#xff08;位于 <stdlib.h> 头文件中&#xff09; 作用&#xff1a;分配一块连续的、未初始化的内存块&#xff0c;大小为 size 字节。 返回值&#xff1a; 成功&#xff1a;返回指向分配内存首地址的 void* 指针…

ELK服务搭建-0-1搭建记录

ELK搭建 需要准备一台linux服务器&#xff08;最好是CentOS7&#xff09;,内存至少4G以上&#xff08;三个组件都比较占用内存&#xff09; 演示基于ElasticSearch采用的是8.5.0版本 1、 Docker安装Elasticsearch 创建一个网络 因为我们还需要部署kibana容器、logstash容器&am…

调参指南:如何有效优化模型训练效果

🚀 调参指南:如何有效优化模型训练效果(深度学习实战) 模型跑通不难,调得好才是本事。本篇文章将系统讲解如何在训练过程中有效调参,从学习率到网络结构,从损失函数到正则化,让你的模型效果“飞升”。 🧠 一、为什么需要调参? 初学者常常以为模型训练完就“任务完…

laya3的2d相机与2d区域

2d相机和2d区域都继承自Sprite。 2d相机必须作为2d区域的子节点&#xff0c;且2d相机必须勾选isMain才能正常使用。 2d区域下如果没有主相机&#xff0c;则他和Sprite无异&#xff0c;他的主要操作皆是针对主相机。 2d相机可以调整自己的移动范围&#xff0c;是否紧密跟随&a…

【保姆级教程】Windows部署LibreTV+cpolar实现远程影音库访问全步骤

文章目录 前言1.关于LibreTV2.docker部署LibreTV3.简单使用LibreTV4.安装cpolar内网穿透5.配置ward公网地址6.配置固定公网地址总结 前言 当周末的闲暇时光来临时&#xff0c;您是否也习惯性地瘫倒在沙发上&#xff0c;渴望通过影视作品缓解一周的疲惫&#xff1f;然而在准备点…

Windows安装Docker部署dify,接入阿里云api-key进行rag测试

一、安装docker 1.1 傻瓜式安装docker Get Docker | Docker Docs Docker原理&#xff08;图解秒懂史上最全&#xff09;-CSDN博客 官网选择好windows的安装包下载&#xff0c;傻瓜式安装。如果出现下面的报错&#xff0c;说明主机没有安装WSL 1.2 解决办法 安装 WSL | Mic…

Cursor 与DeepSeek的完美契合

这两天在看清华大学最近出的一个关于deepseek入门的官方视频中&#xff0c;看了几个deepseek的应用场景还是能够感觉到它的强大之处的&#xff0c;例如根据需求生成各种markdown格式的代码&#xff0c;再结合市面上已有的一些应用平台生成非常好看的流程图&#xff0c;PPT,报表…

【深度学习】13. 图神经网络GCN,Spatial Approach, Spectral Approach

图神经网络 图结构 vs 网格结构 传统的深度学习&#xff08;如 CNN 和 RNN&#xff09;在处理网格结构数据&#xff08;如图像、语音、文本&#xff09;时表现良好&#xff0c;因为这些数据具有固定的空间结构。然而&#xff0c;真实世界中的很多数据并不遵循网格结构&#x…

[Python] 避免 PyPDF2 写入 PDF 出现黑框问题:基于语言自动匹配系统字体的解决方案

在使用 Python 操作 PDF 文件时,尤其是在处理中文、日语等非拉丁字符语言时,常常会遇到一个令人头疼的问题——文字变成“黑框”或“方块”,这通常是由于缺少合适的字体支持所致。本文将介绍一种自动选择系统字体的方式,结合 PyPDF2 模块解决此类问题。 一、问题背景:黑框…

Java求职面试:从核心技术到AI与大数据的全面考核

Java求职面试&#xff1a;从核心技术到AI与大数据的全面考核 第一轮&#xff1a;基础框架与核心技术 面试官&#xff1a;谢飞机&#xff0c;咱们先从简单的开始。请你说说Spring Boot的启动过程。 谢飞机&#xff1a;嗯&#xff0c;Spring Boot启动的时候会自动扫描组件&…

Espresso 是什么

Espresso 是 Android 开发者的首选 UI 测试工具&#xff0c;是 Google 官方推出的 Android 应用 UI 测试框架&#xff0c;专为 白盒测试 设计&#xff0c;强调 速度快、API 简洁&#xff0c;适合开发者在编写代码时同步进行自动化测试。它是 Android Jetpack 测试工具的一部分&…

Axios 如何通过配置实现通过接口请求下载文件

前言 今天&#xff0c;我写了 《Nodejs 实现 Mysql 数据库的全量备份的代码演示》 和 《NodeJS 基于 Koa, 开发一个读取文件&#xff0c;并返回给客户端文件下载》 两篇文章。在这两篇文章中&#xff0c;我实现了数据库的备份&#xff0c;和提供数据库下载等接口。 但是&…

IDEA项目推送到远程仓库

打开IDEA——>VCS——>Creat Git 选择项目 push提交到本地 创建远程仓库 复制地址 定义远程仓库 推送 推送成功

Prompt工程:解锁大语言模型的终极密钥

Prompt工程&#xff1a;解锁大语言模型的终极密钥 一、引言&#xff1a;Prompt的战略价值重构 在人工智能技术加速渗透的2025年&#xff0c;Prompt&#xff08;提示词&#xff09;作为连接人类意图与大语言模型&#xff08;LLM&#xff09;的核心接口&#xff0c;其战略地位已…

架构意识与性能智慧的双重修炼

架构意识与性能智慧的双重修炼 ——现代软件架构师的核心能力建设指南 作者:蓝葛亮 🎯引言 在当今快速发展的技术环境中,软件架构师面临着前所未有的挑战。随着业务复杂度的不断增长和用户对性能要求的日益严苛,如何在架构设计中平衡功能实现与性能优化,已成为每个技术…

Flutter下的一点实践

目录 1、背景2、refena创世纪代码3、localsend里refena的刷新3.1 初始状态3.2 发起设备扫描流程3.3 扫描过程3.3 刷新界面 4.localsend的设备扫描流程4.1 UDP广播设备注册流程4.2 TCP/HTTP设备注册流程4.3 localsend的服务器初始化工作4.4总结 1、背景 在很久以前&#xff0c;…