《零基础入门AI:线性回归进阶(梯度下降算法详解)》

在上一篇博客中,我们学习了线性回归的基本概念、损失函数(如MSE)以及最小二乘法。最小二乘法通过求解解析解(直接计算出最优参数)的方式得到线性回归模型,但它有一个明显的局限:当特征数量很多时,计算过程会非常复杂(涉及矩阵求逆等操作)。今天我们来学习另一种更通用、更适合大规模数据的参数优化方法——梯度下降

一、什么是梯度下降?

梯度下降(Gradient Descent)是一种迭代优化算法,核心思想是:通过不断地沿着损失函数"下降最快"的方向调整参数,最终找到损失函数的最小值(或近似最小值)。

我们可以用一个生活中的例子理解:假设你站在一座山上,周围被大雾笼罩,你看不见山脚在哪里,但你想以最快的速度走到山脚下。此时,你能做的最合理的选择就是:先感受一下脚下的地面哪个方向坡度最陡且向下,然后沿着那个方向走一步;走到新的位置后,再重复这个过程——感受坡度最陡的向下方向,再走一步;直到你感觉自己已经走到了最低点(脚下各个方向都不再向下倾斜)。

这个过程就是梯度下降的直观体现:

  • 这座山就是我们的损失函数
  • 你的位置代表当前参数值
  • 你感受到的"坡度最陡的向下方向"就是负梯度方向
  • 你每次走的"一步"的长度就是学习率
  • 最终到达的"山脚下"就是损失函数的最小值点

在线性回归中,我们的目标是找到最优参数(如权重w),使得损失函数L(w)达到最小值。梯度下降的作用就是帮助我们一步步调整这些参数,最终找到让损失函数最小的参数值。

二、梯度下降的基本步骤

梯度下降的过程可以总结为4个核心步骤,我们以单特征且不含偏置项的线性回归模型y = wx为例(即b=0,损失函数使用MSE),逐步说明:

步骤1:初始化参数

首先需要给参数w设定初始值。初始值可以是任意的(比如随机值、0或1),因为梯度下降会通过迭代不断优化它。

为什么初始值可以任意选择?因为梯度下降是一个迭代优化的过程,无论从哪个点开始,只要迭代次数足够多且学习率合适,最终都会收敛到损失函数的最小值附近。

例如:我们可以简单地将初始值设为w = 0,然后开始优化过程。

步骤2:计算损失函数的梯度

“梯度"在单参数情况下就是损失函数对该参数的导数,它表示损失函数在当前参数位置的"变化率"和"变化方向”。

对单特征且b=0的模型y = wx,我们只需要计算一个导数:

  • 损失函数L对w的导数:∂L/∂w(表示当w变化时,损失函数L的变化率)

这个导数就是"梯度",它指向损失函数增长最快的方向。这很重要:梯度指向的是损失函数值上升最快的方向,所以要让损失函数减小,我们需要向相反的方向移动。

步骤3:更新参数

为了让损失函数减小,我们需要沿着梯度的反方向(即负梯度方向)调整参数。更新公式为:

w = w - α · (∂L/∂w)

其中α是"学习率"(后面会详细解释),它控制参数更新的"步长"。

为什么是减去梯度而不是加上?因为梯度指向损失函数增大的方向,所以减去梯度就意味着向损失函数减小的方向移动,这正是我们想要的。

步骤4:重复迭代,直到收敛

重复步骤2和步骤3:每次计算当前参数的梯度,然后沿负梯度方向更新参数。当满足以下条件之一时,停止迭代(即"收敛"):

  • 梯度的绝对值接近0(此时损失函数变化很小,接近最小值);
  • 损失函数L(w)的变化量小于某个阈值(比如连续两次迭代的损失差小于10⁻⁶);
  • 达到预设的最大迭代次数(防止无限循环)。

"收敛"这个词可以理解为:参数值已经稳定下来,继续迭代也不会有明显变化,此时我们可以认为找到了最优参数

三、梯度下降的公式推导(单特征且b=0)

要实现梯度下降,核心是求出损失函数对参数w的导数。我们以MSE损失函数为例(且b=0),详细推导∂L/∂w的计算过程,每一步都会给出详细说明。

已知条件

  • 模型:y_pred = wx(预测值,因b=0,无偏置项)

  • 真实值:y

  • 损失函数(MSE):

    L(w) = (1/2n)Σ(yᵢ - y_pred,ᵢ)² = (1/2n)Σ(yᵢ - wxᵢ)²
    

(注:公式中加入1/2是为了后续求导时抵消平方项的系数2,使计算更简洁,不影响最终结果)

推导∂L/∂w(损失函数对w的导数)

  1. 先对单个样本的损失求导:
    单个样本的损失为lᵢ = (1/2)(yᵢ - wxᵢ)²,对w求导:

    ∂lᵢ/∂w = 2 · (1/2)(yᵢ - wxᵢ) · (-xᵢ) = -(yᵢ - wxᵢ)xᵢ
    

    这里用到了复合函数求导法则(链式法则):首先对平方项求导得到2·(1/2)(…),然后对括号内的内容求导,由于我们是对w求导,所以(wxᵢ)对w的导数是xᵢ,前面有个负号,所以整体是-(yᵢ - wxᵢ)xᵢ。

  2. 对所有样本的损失求和后求导:
    总损失L是所有单个样本损失的平均值:L = (1/n)Σlᵢ,因此:

    ∂L/∂w = (1/n)Σ(∂lᵢ/∂w) = (1/n)Σ[-(yᵢ - wxᵢ)xᵢ] = -(1/n)Σ(yᵢ - y_pred,ᵢ)xᵢ
    

    这一步的含义是:总损失对w的导数等于所有单个样本损失对w的导数的平均值。

最终更新公式

将上面得到的导数代入参数更新公式(参数 = 参数 - 学习率 × 导数),得到:

w = w + α · (1/n)Σ(yᵢ - y_pred,ᵢ)xᵢ

(注:负负得正,公式中的减号变为加号)

这个公式的含义是:

  • 如果预测值y_pred,ᵢ小于真实值yᵢ(即yᵢ - y_pred,ᵢ为正),则w会增大;反之则减小。
  • 增大或减小的幅度取决于三个因素:误差大小(yᵢ - y_pred,ᵢ)、特征值xᵢ的大小和学习率α。
  • 特征值xᵢ越大,相同误差下w的更新幅度也越大,这体现了特征对参数调整的影响。

四、学习率(α)的作用

学习率(Learning Rate)是梯度下降中最重要的超参数(需要人工设定的参数),它控制参数更新的"步长"。我们继续用"下山"的例子来理解:

  • 如果学习率α太小:就像每次只迈一小步下山,虽然安全,但需要走很多步才能到达山脚(迭代次数多,效率低)。
  • 如果学习率α太大:就像每次迈一大步下山,可能会直接跨过山脚,甚至走到对面的山坡上(跳过最小值,甚至导致损失函数越来越大,无法收敛)。
  • 合适的学习率:步长适中,能快速逼近最小值,既不会太慢也不会跳过。

实际应用中,学习率通常需要通过尝试确定,常见的初始值有0.1、0.01、0.001等。一种常用的策略是"学习率衰减":随着迭代次数增加,逐渐减小学习率,这样在开始时可以快速接近最小值,后期可以精细调整。

举个形象的例子:假设你在下山,开始时你离山脚很远,可以大踏步前进(较大的学习率);当快到山脚时,你会放慢脚步,小步移动(较小的学习率),以免走过头。

完整示例(手动实现梯度下降,单特征,b=0)

import numpy as np
import matplotlib.pyplot as plt  # 可视化# 创建数据  植物的温度、和生长高度 [[20,10],[22,10],[27,12],[25,16]]
data =np.array([[20,10],[22,10],[27,12],[25,16]])
# 划分
x=data[:,0]
y=data[:,1]
print(x)
print(y)# 创建一个模型
def model(x,w):return x*w# 定义损失函数
# def loss(y_pred,y):
#      return np.sum((y_pred-y)**2)/len(y)# 手动将损失函数展开  便于下面写梯度函数
def loss(w):return 2238*(w**2) - 1144*w + 600# 梯度函数  即,将损失函数求导
def gradient(w):return 2*2238*w - 1144# 梯度下降  给定初始系数w  迭代100次 优化w
w=0
learning_rate = 1e-5  # 降低学习率避免溢出
for i in range(100):w=w-learning_rate*gradient(w)print('e:',loss(w),'w:',w)# 绘制损失函数
plt.plot(np.linspace(0,1,100),loss(np.linspace(0,1,100)))# 绘制模型
def draw_line(w):point_x = np.linspace(0, 30, 100)point_y = model(point_x, w)plt.plot(point_x, point_y, label=f'Fitted line (w={w:.4f})')plt.scatter(x, y, color='red', label='Data points')plt.legend()plt.xlabel("Temperature")plt.ylabel("Height")plt.title("Linear Regression via Gradient Descent")plt.grid(True)plt.show()# draw_line(w)

五、多特征的梯度下降(以2个特征为例)

现实中,我们遇到的问题往往有多个特征(比如用"面积"和"房间数"预测房价)。下面我们推导2个特征的线性回归模型的梯度下降公式,方法与单特征类似,但需要考虑更多参数。

模型与损失函数

  • 2个特征的模型:y_pred = w₁x₁ + w₂x₂(x₁、x₂是两个特征,w₁、w₂是对应的权重,因b=0,无偏置项)

  • 损失函数(MSE):

    L(w₁,w₂) = (1/2n)Σ(yᵢ - (w₁x₁,ᵢ + w₂x₂,ᵢ))²
    

推导各参数的偏导数

与单特征思路一致,我们分别对w₁、w₂求偏导:

  1. 对w₁的偏导:

    ∂L/∂w₁ = -(1/n)Σ(yᵢ - y_pred,ᵢ)x₁,ᵢ
    

    推导过程与单特征中w的导数完全相同,只是这里特征是x₁,所以最后乘以x₁,ᵢ。

  2. 对w₂的偏导:

    ∂L/∂w₂ = -(1/n)Σ(yᵢ - y_pred,ᵢ)x₂,ᵢ
    

    同理,这里特征是x₂,所以最后乘以x₂,ᵢ。

参数更新公式

将上述偏导数代入更新公式,得到:

w₁ = w₁ + α · (1/n)Σ(yᵢ - y_pred,ᵢ)x₁,ᵢ
w₂ = w₂ + α · (1/n)Σ(yᵢ - y_pred,ᵢ)x₂,ᵢ

多特征的扩展规律

从2个特征的推导可以看出,梯度下降的公式可以很容易扩展到k个特征的情况:

  • 模型:y_pred = w₁x₁ + w₂x₂ + … + wₖxₖ

  • 对第j个权重wⱼ的更新公式:

    wⱼ = wⱼ + α · (1/n)Σ(yᵢ - y_pred,ᵢ)xⱼ,ᵢ
    

(xⱼ,ᵢ表示第i个样本的第j个特征值)

这个规律非常重要,它告诉我们:无论有多少个特征,梯度下降的更新规则都是相似的——每个权重wⱼ的更新量都与对应特征xⱼ和误差(yᵢ - y_pred,ᵢ)的乘积有关。

完整示例(手动实现梯度下降,两个特征,b=0)

import numpy as np
import matplotlib.pyplot as plt
# 如果使用中文显示,建议添加以下配置
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False    # 用来正常显示负号# 创建数据  [[1,1,3],[2,1,4],[1,2,5],[2,2,6]]
data = np.array([[1,1,3],[2,1,4],[1,2,5],[2,2,6]])
# 划分
x=data[:,:-1]
y=data[:,-1]
print(x)
print(y)# 创建模型
def model(x,w):return np.sum(x*w)# 创建损失函数
# def loss(w,x):# return np.sum((np.sum(x*w,axis=1)-y)**2)# return np.sum((model(x,w)-y)**2)
def loss(w1,w2):return 5*w1**2 + 5*w2**2 +9*w1*w2 -28*w1-29*w2 +43# 创建梯度函数
def gradient_w1(w1,w2):return 10*w1+9*w2-28def gradient_w2(w1,w2):return 9*w1+10*w2-29# 初始化w1,w2
w1=0
w2=0# 迭代100次 优化w1,w2
for i in range(100):w1,w2=w1-0.01*gradient_w1(w1,w2),w2-0.01*gradient_w2(w1,w2)print('e:',loss(w1,w2),'w1:',w1,'w2:',w2)# # 绘制模型  没写出来(所以注释了)
# def draw_line(w1,w2):
#     point_x=np.linspace(0,5,100)
#     point_y=model(point_x,w1,w2)
#     plt.plot(point_x,point_y)# draw_line(w1,w2)

六、梯度下降与最小二乘法的对比

特点梯度下降最小二乘法
本质迭代优化(数值解)直接求解方程(解析解)
计算复杂度低(适合大规模数据/多特征)高(涉及矩阵求逆)
适用性几乎所有损失函数仅适用于凸函数且有解析解
超参数依赖需要调整学习率等无需超参数
内存需求低(可分批处理数据)高(需要一次性加载所有数据)

简单来说,当特征数量较少时,最小二乘法可能更简单直接;但当特征数量很多(比如超过1000个)时,梯度下降通常是更好的选择。

总结

梯度下降是机器学习中最基础也最常用的优化算法,它通过"沿损失函数负梯度方向迭代更新参数"的方式,找到使损失最小的参数值。与最小二乘法相比,梯度下降更适合处理大规模数据和复杂模型。

本文我们从概念、步骤、公式推导(单特征且b=0和双特征)、学习率作用等方面详细讲解了梯度下降,希望能帮助你理解其核心逻辑。掌握梯度下降不仅对理解线性回归至关重要,也是学习更复杂机器学习算法(如神经网络)的基础。

下一篇博客中,我们将通过实际案例演示如何用梯度下降实现线性回归,进一步加深理解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/917176.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/917176.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于C语言实现的KV存储引擎(一)

基于C语言实现的KV存储引擎项目简介整体架构网络模块的实现recatorproactorNtyco项目简介 本文主要是基于 C 语言来实现一个简单的 KV 存储架构,目的就是将网络模块跟实际开发结合起来。 首先我们知道对于数据的存储可以分为两种方式,一种是在内存中进…

c++和python联合编程示例

安装 C与 Python 绑定工具 pip install pybind11这其实相当于使用 python 安装了一个 c的库 pybind11,这个库只由头文件构成, 支持基础数据类型传递以及 python 的 numpy 和 c的 eigen 库之间的自动转换。 编写 CMakeList.txt cmake_minimum_required(VERSION 3.14)…

【OD机试题解法笔记】贪心歌手

题目描述 一个歌手准备从A城去B城参加演出。 按照合同,他必须在 T 天内赶到歌手途经 N 座城市歌手不能往回走每两座城市之间需要的天数都可以提前获知。歌手在每座城市都可以在路边卖唱赚钱。 经过调研,歌手提前获知了每座城市卖唱的收入预期&#xff1a…

AI: 告别过时信息, 用RAG和一份PDF 为LLM打造一个随需更新的“外脑”

嘿,各位技术同学!今天,我们来聊一个大家在使用大语言模型(LLM)时都会遇到的痛点:知识过时。 无论是像我一样,用 Gemini Pro 学习日新月异的以太坊,还是希望它能精确掌握某个特定工具…

深度学习(鱼书)day08--误差反向传播(后三节)

深度学习(鱼书)day08–误差反向传播(后三节)一、激活函数层的实现 这里,我们把构成神经网络的层实现为一个类。先来实现激活函数的ReLU层和Sigmoid层。ReLU层 激活函数ReLU(Rectified Linear Unit&#xff…

C# 中生成随机数的常用方法

1. 使用 Random 类(简单场景) 2. 使用 RandomNumberGenerator 类(安全场景) 3. 生成指定精度的随机小数 C# 中生成随机数的常用方法: 随机数类型实现方式示例代码特点与适用场景随机整数(无范围&#xf…

Flink 算子链设计和源代码实现

1、JobGraph (JobManager) JobGraph 生成时,通过 ChainingStrategy 连接算子,最终在 Task 中生成 ChainedDriver 链表。StreamingJobGraphGeneratorcreateJobGraph() 构建jobGrapch 包含 JobVertex setChaining() 构建算子链isCha…

对接八大应用渠道

背景最近公司想把游戏包上到各个渠道上,因此需要对接各种渠道,渠道如下,oppo、vivo、华为、小米、应用宝、taptap、荣耀、三星等应用渠道 主要就是对接登录、支付接口(后续不知道会不会有其他的)&#x…

学习:入门uniapp Vue3组合式API版本(17)

42.打包发行微信小程序的上线全流程 域名 配置 发行 绑定手机号 上传 提交后等待,上传 43.打包H5并发布上线到unicloud的前端页面托管 完善配置 unicloud 手机号实名信息不一致:请确保手机号的实名信息与开发者姓名、身份证号一致,请前往开…

SOLIDWORKS材料明细表设置,属于自己的BOM表模板

上一期我们了解了如何在SOLIDWORKS工程图中添加材料明细表?接下来,我们将进行对SOLIDWORKS材料明细表的设置、查看缩略图、模板保存的深度讲解。01 材料明细表设置菜单栏生成表格后左侧菜单栏会显示关于材料明细表的相关设置信息。我们先了解一下菜单栏设置详情&am…

全栈:Maven的作用是什么?本地仓库,私服还有中央仓库的区别?Maven和pom.xml配置文件的关系是什么?

Maven和pom.xml配置文件的关系是什么: Maven是一个构建工具和依赖管理工具,而pom.xml(Project Object Model)是Maven的核心配置文件。 SSM 框架的项目不一定是 Maven 项目,但推荐使用 Maven进行管理。 SSM 框架的项目可…

超越 ChatGPT:智能体崛起,开启全自主 AI 时代

引言 短短三年,生成式 AI 已从对话助手跨越到能自主规划并完成任务的“智能体(Agentic AI)”时代。这场演进不仅体现在模型规模的提升,更在于系统架构、交互范式与安全治理的全面革新。本文按时间线梳理关键阶段与核心技术,为您呈现 AI 智能体革命的脉络与未来趋势。 1. …

一杯就够:让大脑瞬间在线、让肌肉满电的 “Kick-out Drink” 全解析

一杯就够:让大脑瞬间在线、让肌肉满电的 “Kick-out Drink” 全解析“每天清晨,当闹钟还在哀嚎,你举杯一饮,睡意像被扔出擂台——这,就是 Kick-out Drink 的全部浪漫。”清晨 30 分钟后,250 mL 常温水里溶解…

系统开机时自动执行指令

使用 systemd 创建一个服务单元可以让系统开机时自动执行指令,假设需要执行的指令如下,运行可执行文件(/home/demo/可执行文件),并输入参数(–input/home/config/demo.yaml): /home/…

Docker 初学者需要了解的几个知识点 (七):php.ini

这段配置是 php.ini 文件中针对 PHP 扩展和 Xdebug 调试工具的设置,主要用于让 PHP 支持数据库连接和代码调试(尤其在 Docker 环境中),具体解释如下:[PHP] extensionpdo_mysql extensionmysqli xdebug.modedebug xdebu…

【高阶版】R语言空间分析、模拟预测与可视化高级应用

随着地理信息系统(GIS)和大尺度研究的发展,空间数据的管理、统计与制图变得越来越重要。R语言在数据分析、挖掘和可视化中发挥着重要的作用,其中在空间分析方面扮演着重要角色,与空间相关的包的数量也达到130多个。在本…

dolphinscheduler中一个脚本用于从列定义中提取列名列表

dolphinscheduler中,我们从一个mysql表导出数据,上传到hdfs, 再创建一个临时表,所以需要用到列名定义和列名列表。 原来定义两个变量,不仅繁锁,还容易出现差错,比如两者列序不对。 所以考虑只定义列定义变量…

JavaWeb(苍穹外卖)--学习笔记16(定时任务工具Spring Task,Cron表达式)

前言 本篇文章是学习B站黑马程序员苍穹外卖的学习笔记📑。我的学习路线是Java基础语法-JavaWeb-做项目,管理端的功能学习完之后,就进入到了用户端微信小程序的开发,用户端开发的流程大致为用户登录—商品浏览(其中涉及…

灵敏度,精度,精确度,精密度,精准度,准确度,分辨率,分辨力——概念

文章目录前提总结前提 我最近在整理一份数据指标要求的时候,总是混淆这几个概念:灵敏度,精度,精确度,精密度,精准度,准确度,分辨率,分辨力,搜了一些文章&…

python-异常(笔记)

#后续代码可以正常运行 try:f open("xxx.txt","r",encodingutf-8)except:print("except error")#捕获指定异常,其他异常报错程序中止,管不到 try:print(name) except NameError as you_call:print("name error"…