深度学习与图像处理 | 基于PaddlePaddle的梯度下降算法实现(线性回归投资预测)

 演示基于PaddlePaddle自动求导技术实现梯度下降,简化求解过程。

01、梯度下降法

梯度下降法是机器学习领域非常重要和具有代表性的算法,它通过迭代计算来逐步寻找目标函数极小值。既然是一种迭代计算方法,那么最重要的就是往哪个方向迭代,梯度下降法选择从目标函数的梯度切入。首先需要明确一个数学概念,即函数的梯度方向是函数值变化最快的方向。梯度下降法就是基于此来进行迭代。

图2.28对应一个双自变量函数

图片

。想要求得该函数极小值,只需要随机选择一个初始点,然后计算当前点对应的梯度,按照梯度反方向下降一定高度,然后重新计算当前位置对应的梯度,继续按照梯度反方向下降。按照上述方式迭代,最终就可以用最快的速度到达极小值附近。

■  梯度下降法示意图

 

如果函数很复杂并有多个极小值点,那么选择不同的初始值,按照梯度下降算法的计算方式很有可能会到达不同的极小值点,并且耗时也不一样。因此,在工程实现上选择一个好的初始值是非常重要的。

对于前面的直线拟合任务来说,其目标函数就是L,模型参数就是a和b。按照梯度下降算法的原理,对应实现步骤如下:

(1)初始化模型参数a和b;

(2)输入每个样本x,根据公式y=ax+b计算每个样本数据的预测输出值

图片

(3)计算所有样本的预测值

图片

和真值y之间的平方差L;

(4)计算当前L对模型参数a和b的梯度值,即

图片

按照下式更新参数a和b:

图片

其中t表示当前迭代的轮次,

图片

是一个提前设置好的参数,这个参数的作用是代表每一步迭代下降的跨度,专业术语也叫学习率;

(6)重复步骤(2)~(5),直至迭代次数超过某个预设值。

注意到,上述算法第(5)步中,需要计算目标函数L对a和b的偏导。尽管对于这个直线拟合任务来说其偏导求取非常简单,但是依然需要手工进行求导。在2.3.3节中,介绍过可以通过PaddlePaddle来自动计算梯度,因此,可以使用PaddlePaddle来更便捷的实现这个梯度下降算法。

完整代码如下(machine_learning/auto_diff.py):

import matplotlib.pyplot as plt
import numpy as np
import paddle
# 输入数据
x_train = np.array(    [3.3, 4.4, 5.5, 6.7, 6.9, 4.2, 9.8, 6.2, 7.6, 2.2, 7, 10.8, 5.3, 8, 3.1],    dtype=np.float32,)
y_train = np.array(    [17, 28, 21, 32, 17, 16, 34, 26, 25, 12, 28, 35, 17, 29, 13], dtype=np.float32)
# numpy转tensor
x_train = paddle.to_tensor(x_train)
y_train = paddle.to_tensor(y_train)
# 随机初始化模型参数
a = np.random.randn(1)
a = paddle.to_tensor(a, dtype="float32", stop_gradient=False)
b = np.random.randn(1)
b = paddle.to_tensor(b, dtype="float32", stop_gradient=False)
# 循环迭代
for t in range(10):# 计算平方差损失   y_ = a * x_train + b loss = paddle.sum((y_ - y_train) ** 2)# 自动计算梯度loss.backward()# 更新参数(梯度下降),学习率默认使用1e-3a = a.detach() - 1e-3 * float(a.grad)b = b.detach() - 1e-3 * float(b.grad)a.stop_gradient = Falseb.stop_gradient = False# 输出当前轮的目标函数值Lprint("epoch: {}, loss: {}".format(t, (float(loss))))
# 训练结束,终止a和b的梯度计算
a.stop_gradient = True
b.stop_gradient = True
# 可视化输出
x_pred = paddle.arange(0, 15)
y_pred = a * x_pred + b
plt.plot(x_train.numpy(), y_train.numpy(), "go", label="Original Data")
plt.plot(x_pred.numpy(), y_pred.numpy(), "r-", label="Fitted Line")plt.xlabel("investment")
plt.ylabel("income")plt.legend()plt.savefig("result.png")
# 预测第16年的收益值
x = 12.5
y = a * x + b
print(y.numpy())

 上述代码对每轮迭代的目标函数进行了输出,同时预测了第16年的收益值,结果如下:

epoch: 0, loss: 9141.90625
epoch: 1, loss: 1103.665283203125
epoch: 2, loss: 397.5347900390625
epoch: 3, loss: 335.0281982421875
epoch: 4, loss: 329.02337646484375
epoch: 5, loss: 327.982421875
epoch: 6, loss: 327.381103515625
epoch: 7, loss: 326.8223571777344
epoch: 8, loss: 326.2711486816406
epoch: 9, loss: 325.72442626953125
[44.96492]

可以看到,随着迭代的不断进行,目标函数逐渐减少,说明模型的预测输出越来越接近真值。最终训练好的模型所预测的第16年的收益值与上一节使用导数法求解的标准解非常接近,验证了梯度下降算法的有效性。

梯度下降法拟合结果如图2.29所示。

■  梯度下降法拟合结果

 

从拟合结果看到,利用PaddlePaddle自动帮助求导,通过梯度下降迭代更新模型参数,最后得到了令人满意的结果,拟合出来的直线基本吻合数据的分布。整个过程不需要手工计算梯度,实现非常简单。

注意,上述代码使用了随机值来初始化模型参数,因此每次运算的结果可能略有不同。另外,使用了固定的学习率1e-3,并得到了一个比较好的训练结果。如果训练过程中目标函数没有逐步下降,那么就需要适当调整学习率重新训练。

本节案例是一个非常简单的使用PaddlePaddle进行机器学习的示例,旨在帮助读者熟悉和巩固PaddlePaddle的基本使用方法。虽然任务简单,但是该示例“五脏俱全”,整个建模学习过程分为4个部分,如图2.30所示。

■  基于PaddlePaddle的梯度下降法步骤

 对于后面的深度学习任务,也会按照上述方式进行模型训练。下面正式开始介绍如何基于PaddlePaddle实现更复杂的深度学习图像应用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/917060.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/917060.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

负载均衡集群HAproxy

HAProxy 简介HAProxy 是一款高性能的负载均衡器和代理服务器,支持 TCP 和 HTTP 应用。广泛用于高可用性集群,能够有效分发流量到多个后端服务器,确保服务的稳定性和可扩展性。HAProxy 核心功能负载均衡:支持轮询(round…

重生之我在10天内卷赢C++ - DAY 1

坐稳了,我们的C重生之旅现在正式发车!请系好安全带,前方高能,但绝对有趣!🚀 重生之我在10天内卷赢C - DAY 1导师寄语:嘿,未来的编程大神!欢迎来到C的世界。我知道&#x…

[mind-elixir]Mind-Elixir 的交互增强:单击、双击与鼠标 Hover 功能实现

[mind-elixir]Mind-Elixir 的交互增强:单击、双击与鼠标 Hover 功能实现 功能简述 通过防抖,实现单击双击区分通过mousemove事件,实现hover效果 实现思路 (一)单击与双击事件 功能描述 单击节点时,可以触发…

c++-迭代器类别仿函数常用算法函数

C常用算法函数 1. 前置知识 1.1 迭代器的类别 C中,迭代器是 STL 容器库的核心组件之一,具有举足轻重的作用,它提供了一种 统一的方式来访问和遍历容器,而无需关心底层数据结构的具体实现。迭代器类似指针,但比指针更通…

Python深度学习框架TensorFlow与Keras的实践探索

基础概念与安装配置 TensorFlow核心架构解析 TensorFlow是由Google Brain团队开发的开源深度学习框架,其核心架构包含数据流图(Data Flow Graph)和张量计算系统。数据流图通过节点表示运算操作(如卷积、激活函数)&…

c# net6.0+ 安装中文智能提示

https://github.com/stratosblue/IntelliSenseLocalizer 1、安装tool dotnet tool install -g islocalizer 2、 安装IntelliSense 文件,安装其他net版本修改下版本号 安装中文net6.0采集包 islocalizer install auto -m net6.0 -l zh-cn 安装中英文双语net6.0采集包…

【建模与仿真】二阶邻居节点信息驱动的节点重要性排序算法

导读: 在复杂网络中,挖掘重要节点对精准推荐、交通管控、谣言控制和疾病遏制等应用至关重要。为此,本文提出一种局部信息驱动的节点重要性排序算法Leaky Noisy Integrate-and-Fire (LNIF)。该算法通过获取节点的二阶邻居信息计算节点重要性&…

指令微调Qwen3实现文本分类任务

参考文档: SwanLab入门深度学习:Qwen3大模型指令微调 - 肖祥 - 博客园 vLLM:让大语言模型推理更高效的新一代引擎 —— 原理详解一_vllm 原理-CSDN博客 概述 为了实现对100个标签的多标签文本分类任务,前期调用gpt-4o进行prom…

【机器学习-3】 | 决策树与鸢尾花分类实践篇

0 序言 本文将深入探讨决策树算法,先回顾下前边的知识,从其基本概念、构建过程讲起,带你理解信息熵、信息增益等核心要点。 接着在引入新知识点,介绍Scikit - learn 库中决策树的实现与应用,再通过一个具体项目的方式来…

【数字投影】折幕影院都是沉浸式吗?

折幕影院作为一种现代化的展示形式,其核心特点在于通过多块屏幕拼接和投影融合技术,打造更具包围感的视觉体验。折幕影院设计通常采用多折幕结构,如三折幕、五折幕等,利用多台投影机的协同工作,呈现无缝衔接的超大画面…

数据结构——图(三、图的 广度/深度 优先搜索)

一、广度优先搜索(BFS)①找到与一个顶点相邻的所有顶点 ②标记哪些顶点被访问过 ③需要一个辅助队列#define MaxVertexNum 100 bool visited[MaxVertexNum]; //访问标记数组 void BFSTraverse(Graph G){ //对图进行广度优先遍历,处理非连通图的函数 for(int i0;i…

直击WAIC | 百度袁佛玉:加速具身智能技术及产品研发,助力场景应用多样化落地

7月26日,2025世界人工智能大会暨人工智能全球治理高级别会议(WAIC)在上海开幕。同期,由国家地方共建人形机器人创新中心(以下简称“国地中心”)与中国电子学会联合承办,百度智能云、中国联通上海…

2025年人形机器人动捕技术研讨会将在本周四召开

2025年7月31日爱迪斯通所主办的【2025人形机器动作捕捉技术研讨会】是携手北京天树探界公司线下活动结合线上直播的形式,会议将聚焦在“动作捕捉软硬件协同,加速人形机器人训练”,将深度讲解多项核心技术,包含全球知名的惯性动捕大…

Apple基础(Xcode①-项目结构解析)

要运行设备之前先选择好设备Product---->Destination---->选择设备首次运行手机提示如出现 “未受信任的企业级开发者” → 手机打开 设置 ▸ 通用 ▸ VPN与设备管理 → 信任你的 Apple ID 即可ContentView 是 SwiftUI 项目里 最顶层、最主界面 的那个“页面”&#xff0…

微服务 02

一、网关路由网关就是网络的关口。数据在网络间传输,从一个网络传输到另一网络时就需要经过网关来做数据的路由和转发以及数据安全的校验。路由是网关的核心功能之一,决定如何将客户端请求映射到后端服务。1、快速入门创建新模块,引入网关依赖…

04动手学深度学习笔记(上)

04数据操作 import torch(1)张量表示一个数据组成的数组,这个数组可能有多个维度。 xtorch.arange(12) xtensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])(2)通过shape来访问张量的形状和张量中元素的总数 x.shapetorch.Size([12])(3)number of elements表…

MCU中的RTC(Real-Time Clock,实时时钟)是什么?

MCU中的RTC(Real-Time Clock,实时时钟)是什么? 在MCU(微控制器单元)中,RTC(Real-Time Clock,实时时钟) 是一个独立计时模块,用于在系统断电或低功耗状态下持续记录时间和日期。以下是关于RTC的详细说明: 1. RTC的核心功能 精准计时:提供年、月、日、时、分、秒、…

Linux 进程调度管理

进程调度器可粗略分为两类:实时调度器(kernel),系统中重要的进程由实时调度器调度,获得CPU能力强。非实时调度器(user),系统中大部分进程由非实时调度器调度,获得CPU能力弱。实时调度器实时调度器支持的调度策略&#…

基于 C 语言视角:流程图中分支与循环结构的深度解析

前言(约 1500 字)在 C 语言程序设计中,控制结构是构建逻辑的核心骨架,而流程图作为可视化工具,是将抽象代码逻辑转化为直观图形的桥梁。对于入门 C 语言的工程师而言,掌握流程图与分支、循环结构的对应关系…

threejs创建自定义多段柱

最近在研究自定义建模,有一个多断柱模型比较有意思,分享下,就是利用几组点串,比如上中下,然后每组点又不一样多,点续还不一样,(比如第一个环的第一个点在左边,第二个环在右边)&#…