使用纯NumPy实现回归任务:深入理解机器学习本质

在深度学习框架普及的今天,回归基础用NumPy从头实现机器学习模型具有特殊意义。本文将完整演示如何用纯NumPy实现二次函数回归任务,揭示机器学习底层原理。整个过程不使用任何深度学习框架,每一行代码都透明可见。


1. 环境配置与数据生成

import numpy as np
from matplotlib import pyplot as plt 设置随机种子保证可复现性 
np.random.seed(100)  生成训练数据:100个点在[-1,1]区间均匀分布 
x = np.linspace(-1, 1, 100).reshape(100, 1)基于y=3x²+2生成目标值,并添加高斯噪声
y = 3 * np.power(x, 2) + 2 + 0.2 * np.random.rand(x.size).reshape(100, 1)

**数据可视化结果: **

散点图展示了添加噪声后的数据分布,我们的目标是找到最佳拟合曲线y=wx2+by=wx^2+by=wx2+b


2. 模型初始化与核心参数

随机初始化待学习参数
w = np.random.rand(1, 1)  # 权重参数 (理论值应接近3)
b = np.random.rand(1, 1)  # 偏置项 (理论值应接近2)lr = 0.001  # 学习率 (梯度下降步长)
epochs = 800  # 训练轮数

初始参数可视化:

print(f"初始参数: w={w[0][0]:.4f}, b={b[0][0]:.4f}")
典型输出: w=0.7123, b=0.1582 (每次运行结果不同)

3. 训练过程与数学原理

3.1 前向传播计算预测值

y_pred = np.power(x, 2) * w + b

3.2 损失函数定义

采用均方误差(MSE)的变体:

loss = 0.5 * (y_pred - y)  2 
total_loss = loss.sum()  # 所有样本损失之和 

3.3 梯度计算解析

关键数学推导(链式法则):

权重w的梯度: ∂Loss/∂w = Σ(y_pred - y)*x² 
grad_w = np.sum((y_pred - y) * np.power(x, 2))偏置b的梯度: ∂Loss/∂b = Σ(y_pred - y)
grad_b = np.sum((y_pred - y))

3.4 参数更新(梯度下降)

w -= lr * grad_w  # w = w - η·(∂Loss/∂w)
b -= lr * grad_b  # b = b - η·(∂Loss/∂b)

4. 完整训练代码

for epoch in range(epochs):# 前向传播y_pred = np.power(x, 2) * w + b # 损失计算 loss = 0.5 * (y_pred - y)  2total_loss = loss.sum()# 梯度计算grad_w = np.sum((y_pred - y) * np.power(x, 2))grad_b = np.sum((y_pred - y))# 参数更新w -= lr * grad_w b -= lr * grad_b# 每100轮打印训练进展 if epoch % 100 == 0:print(f"Epoch {epoch}: w={w[0][0]:.4f}, b={b[0][0]:.4f}, Loss={total_loss:.4f}")

训练过程输出:

Epoch 0: w=0.9461, b=0.3827, Loss=160.9256 
Epoch 100: w=2.1433, b=1.8047, Loss=1.8925 
Epoch 200: w=2.6555, b=2.0404, Loss=0.4583
Epoch 300: w=2.8543, b=2.1023, Loss=0.2985 
...
Epoch 700: w=2.9887, b=2.0161, Loss=0.2502 

5. 训练结果可视化

生成预测曲线 
x_test = np.linspace(-1, 1, 30).reshape(30, 1)
y_test = np.power(x_test, 2) * w + b 绘制结果对比图 
plt.figure(figsize=(10, 6))
plt.scatter(x, y, color='blue', alpha=0.5, label='真实数据')
plt.plot(x_test, y_test, 'r-', linewidth=3, label='模型预测')
plt.plot(x_test, 3*x_test2+2, 'g--', label='理论曲线')
plt.xlim(-1, 1)
plt.ylim(2, 6)
plt.legend()
plt.title('NumPy实现回归结果')
plt.show()输出最终参数
print(f"训练结果: w={w[0][0]:.4f} (接近理论值3), b={b[0][0]:.4f} (接近理论值2)")

可视化结果:

红色实线为模型预测曲线,绿色虚线为理论曲线y=3x2+2y=3x^2+2y=3x2+2,蓝色点为带噪声的训练数据


6. 关键技术解析

1. 梯度下降的本质

通过参数空间中的"下坡运动"寻找最优解,学习率控制步长大小:

  • 学习率过大 → 震荡发散
  • 学习率过小 → 收敛缓慢
  • 本例0.001是多次试验后的平衡值

2. 手动求导的意义

# 关键导数计算
grad_w = np.sum((y_pred - y) * np.power(x, 2))

理解此式需掌握:

  • 链式法则:∂Loss/∂w = (∂Loss/∂y_pred)·(∂y_pred/∂w)
  • 损失函数导数:∂Loss/∂y_pred = (y_pred - y)
  • 模型输出导数:∂y_pred/∂w = x²

3. 批量梯度下降特点

  • 每次迭代使用全部样本(不同于随机梯度下降)
  • 计算稳定但内存消耗大
  • 适合中小规模数据集

7. 拓展思考

1. 学习率动态调整

# 添加学习率衰减 
if epoch % 200 == 0:lr *= 0.8  # 每200轮衰减20%

2. 添加正则化项(L2正则化)

# 修改损失函数
lambda_reg = 0.01  # 正则化系数 
loss = 0.5*(y_pred-y)2 + 0.5*lambda_reg*(w2)

3. 动量优化(Momentum)

# 添加动量项
beta = 0.9  # 动量系数 
v_w = beta*v_w + (1-beta)*grad_w 
w -= lr * v_w

8. 总结与启示

NumPy实现的价值

  • 透明机制:每个运算步骤完全可见
  • ⚙️ 数学本质:揭示梯度下降和反向传播核心原理
  • 🔍 调试优势:便于定位问题和理解优化过程

局限性:

  • 📈 仅适合简单模型
  • ⏱️ 复杂网络需大量重复代码
  • 缺乏自动微分等高级功能

通过这个基础实现,我们能更深刻地理解PyTorch/TensorFlow等框架封装的高级功能背后的数学原理,为后续学习打下坚实基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/93072.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/93072.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java理解

springboot 打包 mvn install:install-file -Dfile=<path-to-jar> -DgroupId=<group-id> -DartifactId=<artifact-id> -Dversion=<version> -Dpackaging=jar <path-to-jar> 是你的 JAR 文件的路径。 <group-id> 是你的项目的组 ID。 <…

图论核心算法详解:从存储结构到最短路径(附C++实现)

目录 一、图的基础概念与术语 二、图的存储结构 1. 邻接矩阵 实现思路&#xff1a; 2. 邻接表 实现思路&#xff1a; 应用场景&#xff1a; 时间复杂度分析&#xff1a; 三、图的遍历算法 1. 广度优先搜索&#xff08;BFS&#xff09; 核心思想&#xff1a; 应用场…

力扣top100(day03-02)--图论

本文为力扣TOP100刷题笔记 笔者根据数据结构理论加上最近刷题整理了一套 数据结构理论加常用方法以下为该文章&#xff1a; 力扣外传之数据结构&#xff08;一篇文章搞定数据结构&#xff09; 200. 岛屿数量 class Solution {// DFS辅助方法&#xff0c;用于标记和"淹没&q…

建造者模式:从“参数地狱”到优雅构建

深夜&#xff0c;一条紧急告警刺穿寂静&#xff1a;核心报表服务因NullPointerException全线崩溃。排查根源&#xff0c;罪魁祸首竟是一个拥有10多个参数的“上帝构造函数”。本文将从这个灾难现场出发&#xff0c;引入“链式建造者模式”进行重构&#xff0c;并深入Spring AI、…

jenkins在windows配置sshpass

我的服务器里jenkins是通过docker安装的&#xff0c;jenkins与项目都部署在同一台服务器上还好&#xff0c;但是当需要通过jenkins构建&#xff0c;再通过scp远程推送到别的服务器上&#xff0c;就出问题了&#xff0c;毕竟不是手动执行scp命令&#xff0c;可以手动输入密码&am…

Linux操作系统从入门到实战(十八)在Linux里面怎么查看进程

Linux操作系统从入门到实战&#xff08;十八&#xff09;在Linux里面怎么查看进程前言一、如何识别一个进程&#xff1f;—— PID二、怎么查看进程的信息&#xff1f;方式1&#xff1a;通过/proc目录方式2&#xff1a;用ps命令三、父进程是什么&#xff1f;—— PPID四、bash是…

[TryHackMe](知识学习)---基于堆栈得到缓冲区溢出

1.了解缓冲区溢出WINDOWS程序动态调试工具immunity debuggerhttps://www.immunityinc.com/products/debugger/2.Mona脚本#!/usr/bin/env python3import socket, time, sysip "10.201.99.37"port 1337 timeout 5 prefix "OVERFLOW1 "string prefix &q…

LRU算法与LFU算法

知识点&#xff1a; LRU是Least Recently Used的缩写&#xff0c;意思是最近最少使用&#xff0c;它是一种Cache替换算法 Cache的容量有限&#xff0c;因此当Cache的容量用完后&#xff0c;而又有新的内容需要添加进来时&#xff0c; 就需要挑选 并舍弃原有的部分内容&#xf…

目标检测公开数据集全解析:从经典到前沿

目标检测公开数据集全解析&#xff1a;从经典到前沿 一、引言 目标检测&#xff08;Object Detection&#xff09;是计算机视觉领域的核心任务之一&#xff0c;旨在在图像或视频中识别并定位感兴趣的物体。与图像分类不同&#xff0c;目标检测不仅需要判断物体的类别&#xf…

数据备份与进程管理

一、数据备份1.Linux服务器中需要备份的数据&#xff08;1&#xff09;Linux系统重要数据&#xff1a;/root/目录&#xff0c;/home/目录&#xff0c;/etc/目录&#xff08;2&#xff09;安装服务的数据&#xff1a;Apache&#xff08;配置文件&#xff0c;网页主目录&#xff…

docker volume卷入门教程

1. 基础概念 Docker卷是专门用于持久化容器数据的存储方案&#xff0c;独立于容器生命周期。其核心优势包括&#xff1a; 数据持久化&#xff1a;容器删除后数据仍保留跨容器共享&#xff1a;多个容器可访问同一卷备份与迁移&#xff1a;支持直接复制卷数据驱动支持&#xff1a…

计算机网络——协议

1. 计算机网络分层1.1 OSI 7层模型应用层表示层会话层传输层网络层数据链路层物理层1.2 TCP/IP 4 层模型应用层运输层网际层网络接口层1.3 5层体系机构应用层传输层网络层数据链路层物理层2. 应用层协议2.1 HTTP协议2.1.1 基本介绍HTTP&#xff08;HyperText Transfer Protocol…

【React】hooks 中的闭包陷阱

在 React Hooks 中的 闭包陷阱&#xff08;Closure Trap&#xff09;在 useEffect、事件回调、定时器等场景里很常见。1. 闭包陷阱是什么 当你在函数组件里定义一个回调&#xff08;比如事件处理函数&#xff09;&#xff0c;这个回调会捕获当时渲染时的变量值。如果后面状态更…

校园快递小程序(腾讯地图API、二维码识别、Echarts图形化分析)

&#x1f388;系统亮点&#xff1a;腾讯地图API、二维码识别、Echarts图形化分析&#xff1b;一.系统开发工具与环境搭建1.系统设计开发工具后端使用Java编程语言的Spring boot框架 项目架构&#xff1a;B/S架构 运行环境&#xff1a;win10/win11、jdk17小程序&#xff1a; 技术…

Python网络爬虫(二) - 解析静态网页

文章目录一、网页解析技术介绍二、Beautiful Soup库1. Beautiful Soup库介绍2. Beautiful Soup库几种解析器比较3. 安装Beautiful Soup库3.1 安装 Beautiful Soup 43.2 安装解析器4. Beautiful Soup使用步骤4.1 创建Beautiful Soup对象4.2 获取标签4.2.1 通过标签名获取4.2.2 通…

【Linux基础知识系列】第九十四篇 - 如何使用traceroute命令追踪路由

在网络环境中&#xff0c;了解数据包从源主机到目标主机的路径是非常重要的。这不仅可以帮助我们分析网络连接问题&#xff0c;还可以用于诊断网络延迟、丢包等问题。traceroute命令是一个强大的工具&#xff0c;它能够追踪数据包在网络中的路径&#xff0c;显示每一跳的延迟和…

达梦数据闪回查询-快速恢复表

Time:2025/08/12Author:skatexg一、环境说明DM数据库&#xff1a;DM8.0及以上版本二、适用场景研发在误操作或变更数据后&#xff0c;想马上恢复表到某个时间点&#xff0c;可以通过闪回查询功能快速实现&#xff08;通过全量备份恢复时间长&#xff0c;成本高&#xff09;三、…

力扣(LeetCode) ——225 用队列实现栈(C语言)

题目&#xff1a;用队列实现栈示例1&#xff1a; 输入&#xff1a; [“MyStack”, “push”, “push”, “top”, “pop”, “empty”] [[], [1], [2], [], [], []] 输出&#xff1a; [null, null, null, 2, 2, false] 解释&#xff1a; MyStack myStack new MyStack(); mySta…

微软推出AI恶意软件检测智能体 Project Ire

开篇 在8月5号&#xff0c;微软研究院发布了一篇博客文章&#xff0c;在该篇博客中推出了一款名为Project Ire的AI Agent。该Agent可以在无需人类协助的情况下&#xff0c;自主分析和分类二进制文件。它可以在无需了解二进制文件来源或用途的情况下&#xff0c;对文件进行完全的…

哪些对会交由SpringBoot容器管理?

在 Spring Boot 中,交由容器管理的对象通常称为“Spring Bean”,这些对象的创建、依赖注入、生命周期等由 Spring 容器统一管控。以下是常见的会被 Spring Boot 容器管理的对象类型及识别方式: 一、通过注解声明的组件(最常见) Spring Boot 通过类级别的注解自动扫描并注…