机器学习进阶,梯度提升机(GBM)与XGBoost

梯度提升机(Gradient Boosting Machine, GBM),特别是其现代高效实现——XGBoost。这是继随机森林后自然进阶的方向,也是当前结构化数据竞赛和工业界应用中最强大、最受欢迎的算法之一。

为什么推荐XGBoost?

  1. 与随机森林互补:同属集成学习,但Random Forest是Bagging思想,而XGBoost是Boosting思想。学习它可以帮助你全面理解集成学习的两种主流范式。
  2. State-of-the-Art性能:在表格型数据上,XGBoost通常比随机森林表现更好,是Kaggle等数据科学竞赛中的"大杀器"。
  3. 高效且可扩展:专为速度和性能设计,支持并行处理,能处理大规模数据。
  4. 内置正则化:相比传统GBM,XGBoost自带正则化项,更不容易过拟合。

核心概念:Boosting vs Bagging

● Bagging(随机森林):并行构建多个独立的弱模型,然后通过投票/平均得到最终结果。
● Boosting(XGBoost):串行构建多个相关的弱模型,每个新模型都专注于纠正前一个模型的错误。

完整代码示例

下面我们使用XGBoost来解决同样的鸢尾花分类问题,并与随机森林进行对比。

# xgboost_module.py
# -*- coding: utf-8 -*-"""
XGBoost分类器示例 - 鸢尾花数据集
模块化实现,包含数据加载、模型训练、评估、可视化和高级功能
"""# 1. 导入必要的库
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import accuracy_score, classification_report
from sklearn.ensemble import RandomForestClassifier
import xgboost as xgb
import warnings
warnings.filterwarnings('ignore')# 设置全局样式
plt.style.use('seaborn-v0_8')
np.random.seed(42)  # 设置随机种子以确保结果可重现# 2. 数据加载模块
def load_data():"""加载鸢尾花数据集"""iris = load_iris()X = iris.datay = iris.targetfeature_names = iris.feature_namestarget_names = iris.target_namesreturn X, y, feature_names, target_names# 3. 数据预处理模块
def prepare_data(X, y, test_size=0.2, random_state=42):"""准备训练和测试数据集"""X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state, stratify=y)print(f"训练集大小: {X_train.shape[0]}")print(f"测试集大小: {X_test.shape[0]}")return X_train, X_test, y_train, y_test# 4. 随机森林基准模型模块
def train_random_forest(X_train, y_train, **params):"""训练随机森林模型作为基准"""# 设置默认参数default_params = {'n_estimators': 100,'max_depth': 3,'random_state': 42}# 更新默认参数default_params.update(params)# 初始化并训练模型model = RandomForestClassifier(**default_params)model.fit(X_train, y_train)print("\n=== 随机森林模型训练完成 ===")print(f"使用参数: {default_params}")return model# 5. XGBoost模型训练模块
def train_xgboost(X_train, y_train, **params):"""训练XGBoost模型"""# 设置默认参数default_params = {'n_estimators': 100,'max_depth': 3,'learning_rate': 0.1,'random_state': 42,'use_label_encoder': False,'eval_metric': 'logloss'}# 更新默认参数default_params.update(params)# 初始化并训练模型model = xgb.XGBClassifier(**default_params)model.fit(X_train, y_train)print("\n=== XGBoost模型训练完成 ===")print(f"使用参数: {default_params}")return model# 6. 模型评估模块
def evaluate_model(model, X_test, y_test, model_name="模型"):"""评估模型性能"""# 预测y_pred = model.predict(X_test)# 计算准确率accuracy = accuracy_score(y_test, y_pred)print(f"\n=== {model_name}性能 ===")print(f"测试集准确率: {accuracy:.4f}")return accuracy, y_pred# 7. 交叉验证比较模块
def compare_cv_models(models, X, y, cv=5):"""使用交叉验证比较多个模型"""print("\n=== 交叉验证比较 ===")results = {}for name, model in models.items():scores = cross_val_score(model, X, y, cv=cv, scoring='accuracy')results[name] = scoresprint(f"{name} 交叉验证平均分: {scores.mean():.4f}{scores.std():.4f})")return results# 8. 特征重要性可视化模块
def plot_feature_importance(models, feature_names):"""可视化多个模型的特征重要性"""n_models = len(models)plt.figure(figsize=(5 * n_models, 5))for i, (name, model) in enumerate(models.items(), 1):plt.subplot(1, n_models, i)# 获取特征重要性if hasattr(model, 'feature_importances_'):importances = model.feature_importances_else:# 对于XGBoost模型importances = model.get_booster().get_score(importance_type='weight')# 转换为数组格式importances_array = np.zeros(len(feature_names))for j, feat in enumerate(feature_names):importances_array[j] = importances.get(f"f{j}", 0)importances = importances_array# 排序并绘制indices = np.argsort(importances)[::-1]plt.bar(range(len(feature_names)), importances[indices])plt.xticks(range(len(feature_names)), [feature_names[i] for i in indices], rotation=45)plt.title(f'{name} - Feature Importance')plt.tight_layout()plt.show()# 9. 高级功能:早停法训练模块
def train_xgboost_early_stopping(X_train, y_train, X_test, y_test, **params):"""使用早停法训练XGBoost模型"""# 设置默认参数default_params = {'max_depth': 3,'learning_rate': 0.1,'objective': 'multi:softmax','num_class': 3,'eval_metric': 'mlogloss'}# 更新默认参数default_params.update(params)# 转换为XGBoost的DMatrix格式dtrain = xgb.DMatrix(X_train, label=y_train)dtest = xgb.DMatrix(X_test, label=y_test)# 训练并使用早停法evals = [(dtrain, 'train'), (dtest, 'test')]model = xgb.train(default_params, dtrain, num_boost_round=1000,evals=evals,early_stopping_rounds=10,verbose_eval=False)print("\n=== 早停法训练完成 ===")print(f"在 {model.best_iteration} 轮停止")print(f"最佳验证分数: {model.best_score:.4f}")return model# 10. 预测模块
def make_predictions(model, new_samples, target_names, model_type='sklearn'):"""使用模型进行新样本预测"""if model_type == 'xgboost_early_stop':# 对于早停法训练的XGBoost模型dnew = xgb.DMatrix(new_samples)predictions = model.predict(dnew)# 早停法训练的模型不直接提供概率,需要额外处理print("注意: 早停法训练的XGBoost模型不直接提供概率输出")predictions_proba = Noneelse:# 对于标准sklearn接口的模型predictions = model.predict(new_samples)predictions_proba = model.predict_proba(new_samples)print("\n=== 新样本预测 ===")for i, sample in enumerate(new_samples):predicted_class = target_names[int(predictions[i])]print(f"样本 {i+1} {sample}:")print(f"  预测类别: {predicted_class}")if predictions_proba is not None:print(f"  类别概率: {dict(zip(target_names, predictions_proba[i].round(4)))}")return predictions, predictions_proba# 11. 主函数 - 整合所有模块
def main():"""主函数,整合所有模块"""# 加载数据X, y, feature_names, target_names = load_data()print("=== 鸢尾花数据集 ===")print(f"数据集形状: {X.shape}")print(f"特征名称: {feature_names}")print(f"类别名称: {target_names}")# 准备数据X_train, X_test, y_train, y_test = prepare_data(X, y)# 训练随机森林模型rf_model = train_random_forest(X_train, y_train)rf_accuracy, rf_pred = evaluate_model(rf_model, X_test, y_test, "随机森林")# 训练XGBoost模型xgb_model = train_xgboost(X_train, y_train)xgb_accuracy, xgb_pred = evaluate_model(xgb_model, X_test, y_test, "XGBoost")# 交叉验证比较models = {'随机森林': rf_model,'XGBoost': xgb_model}cv_results = compare_cv_models(models, X, y)# 特征重要性可视化plot_feature_importance(models, feature_names)# 详细分类报告print("\n=== XGBoost详细分类报告 ===")print(classification_report(y_test, xgb_pred, target_names=target_names))# 高级功能:早停法训练xgb_early_model = train_xgboost_early_stopping(X_train, y_train, X_test, y_test)# 进行预测new_samples = [[5.1, 3.5, 1.4, 0.2],  # 很可能为setosa[6.7, 3.0, 5.2, 2.3]   # 很可能为virginica]predictions, predictions_proba = make_predictions(xgb_model, new_samples, target_names)return {'rf_model': rf_model,'xgb_model': xgb_model,'xgb_early_model': xgb_early_model,'rf_accuracy': rf_accuracy,'xgb_accuracy': xgb_accuracy,'cv_results': cv_results,'predictions': predictions}# 12. 执行主程序
if __name__ == "__main__":results = main()

代码解析与学习要点

  1. 参数对比:
    ○ XGBoost有与随机森林相似的参数(n_estimators, max_depth)
    ○ 但也有特有参数如learning_rate(学习率),控制每棵树的贡献程度
  2. 性能比较:
    ○ 代码中比较了两种算法的准确率和交叉验证结果
    ○ 通常情况下,XGBoost会略优于随机森林
  3. 特征重要性:
    ○ 可视化对比两种算法计算的特征重要性
    ○ 注意:两种算法计算重要性的方法不同,结果可能有差异
  4. 高级功能:
    ○ 演示了早停法(Early Stopping),这是防止过拟合的重要技术
    ○ 展示了DMatrix数据格式,这是XGBoost的高效数据容器
  5. 预测概率:
    ○ XGBoost可以提供每个类别的预测概率,这对于不确定性分析很有用

代码运行结果

=== 鸢尾花数据集 ===
数据集形状: (150, 4)
特征名称: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
类别名称: ['setosa' 'versicolor' 'virginica']
训练集大小: 120
测试集大小: 30=== 随机森林模型训练完成 ===
使用参数: {'n_estimators': 100, 'max_depth': 3, 'random_state': 42}=== 随机森林性能 ===
测试集准确率: 0.9667=== XGBoost模型训练完成 ===
使用参数: {'n_estimators': 100, 'max_depth': 3, 'learning_rate': 0.1, 'random_state': 42, 'use_label_encoder': False, 'eval_metric': 'logloss'}=== XGBoost性能 ===
测试集准确率: 0.9333=== 交叉验证比较 ===
随机森林 交叉验证平均分: 0.9667 (±0.0211)
XGBoost 交叉验证平均分: 0.9467 (±0.0267)=== XGBoost详细分类报告 ===precision    recall  f1-score   supportsetosa       1.00      1.00      1.00        10versicolor       0.90      0.90      0.90        10virginica       0.90      0.90      0.90        10accuracy                           0.93        30macro avg       0.93      0.93      0.93        30
weighted avg       0.93      0.93      0.93        30=== 早停法训练完成 ===
在 33 轮停止
最佳验证分数: 0.1948=== 新样本预测 ===
样本 1 [5.1, 3.5, 1.4, 0.2]:预测类别: setosa类别概率: {'setosa': 0.9911, 'versicolor': 0.0067, 'virginica': 0.0023}
样本 2 [6.7, 3.0, 5.2, 2.3]:预测类别: virginica类别概率: {'setosa': 0.0019, 'versicolor': 0.0025, 'virginica': 0.9956}

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/95639.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/95639.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【ARMv7】开篇:掌握ARMv7架构Soc开发技能

本专栏,开始与大家共同总结使用ARMv7系列CPU的Soc开发技能。大概汇总了一下,后面再逐步完善下面的思维导图。简单说说:与通用的ARMv7-A/R相比,以STM32F为代表的ARMv7-M架构有以下关键区别和重点:无MMU,有MP…

【学术会议论文投稿】JavaScript在数据可视化领域的探索与实践

【ACM出版 | EI快检索 | 高录用】2024年智能医疗与可穿戴智能设备国际学术会议(SHWID 2024)_艾思科蓝_学术一站式服务平台 更多学术会议请看 学术会议-学术交流征稿-学术会议在线-艾思科蓝 目录 引言 JavaScript可视化库概览 D3.js基础入门 1. 引入…

CSS基础学习步骤

好的,这是一份为零基础初学者量身定制的 **CSS 学习基础详细步骤**。我们将从最根本的概念开始,通过一步一步的实践,带你稳稳地入门。 第一步:建立核心认知 - CSS 是做什么的? 1. 理解角色: HTML&…

MTK Linux DRM分析(三十七)- MTK phy-mtk-hdmi.c 和 phy-mtk-hdmi-mt8173.c

一、简介 HDMI PHY驱动 HDMI 的物理层接口主要就是 HDMI Type-A 接口(19 pin),除此之外还有 Type-B、Type-C(Mini HDMI)、Type-D(Micro HDMI)、Type-E(车载专用)。 1. HDMI Type-A(常见 19-pin 标准接口) HDMI Type-A Connector Pinout ========================…

【人工智能学习之MMdeploy部署踩坑总结】

【人工智能学习之MMdeploy部署踩坑总结】报错1:TRTNet: device must be a GPU!报错2:Failed to create Net backend: tensorrt报错3:Failed to load library libonnxruntime_providers_shared.so1. 确认库文件是否存在2. 重新安装 ONNX Runti…

力扣516 代码随想录Day16 第一题

找二叉树左下角的值class Solution { public:int maxd0;int result;void traversal(TreeNode* root,int depth){if(root->leftNULL&&root->rightNULL){if(depth>maxd){maxddepth;resultroot->val;}}if(root->left){depth;traversal(root->left,depth…

网格图--Day07--网格图DFS--LCP 63. 弹珠游戏,305. 岛屿数量 II,2061. 扫地机器人清扫过的空间个数,489. 扫地机器人,2852. 所有单元格的远离程度之和

网格图–Day07–网格图DFS–LCP 63. 弹珠游戏,305. 岛屿数量 II,2061. 扫地机器人清扫过的空间个数,489. 扫地机器人,2852. 所有单元格的远离程度之和 今天要训练的题目类型是:【网格图DFS】,题单来自灵茶山…

多功能修改电脑机器码序列号工具 绿色版

多功能修改电脑机器码序列号工具 绿色版电脑机器码序列号修改软件是一款非常使用的数据化虚拟修改工具。机器码修改软件可以虚拟的定制您电脑上的硬件信息,软件不会对您的电脑造成伤害。软件不需要您有专业的知识,就可以模拟一份硬件信息。机器码修改软…

React Hooks深度解析:useState、useEffect及自定义Hook最佳实践

React Hooks自16.8版本引入以来,彻底改变了我们编写React组件的方式。它们让函数组件拥有了状态管理和生命周期方法的能力,使代码更加简洁、可复用且易于测试。本文将深入探讨三个最重要的Hooks:useState、useEffect,以及如何创建…

期权平仓后权利金去哪了?

本文主要介绍期权平仓后权利金去哪了?期权平仓后权利金的去向需结合交易角色(买方/卖方)、平仓方式及市场价格变动综合分析,具体可拆解为以下逻辑链条。期权平仓后权利金去哪了?1. 买方平仓:权利金的“差价…

2025国赛C题题目及最新思路公布!

C 题 NIPT 的时点选择与胎儿的异常判 问题 1 试分析胎儿 Y 染色体浓度与孕妇的孕周数和 BMI 等指标的相关特性,给出相应的关系模 型,并检验其显著性。 思路1:针对附件中孕妇的 NIPT 数据,首先对数据进行预处理,并对多…

NLP技术爬取

“NLP技术爬取”这个词组并不指代一种单独的爬虫技术,而是指将自然语言处理(NLP)技术应用于网络爬虫的各个环节,以解决传统爬虫难以处理的问题,并从中挖掘出更深层次的价值。简单来说,它不是指“用NLP去爬”…

让录音变得清晰的软件:语音降噪AI模型与工具推荐

在数字内容创作日益普及的今天,无论是播客、线上课程、视频口播,还是远程会议,清晰的录音质量都是提升内容专业度和观众体验的关键因素之一。然而,由于环境噪音、设备限制等因素,录音中常常夹杂各种干扰声音。本文将介…

大话 IOT 技术(1) -- 架构篇

文章目录前言抛出问题现有条件初步设想HTTP 与 MQTT中间的服务端完整的链路测试的虚拟设备实现后话当你迷茫的时候,请点击 物联网目录大纲 快速查看前面的技术文章,相信你总能找到前行的方向 前言 Internet of Things (IoT) 就是物联网,万物…

【wpf】WPF 自定义控件绑定数据对象的最佳实践

WPF 自定义控件绑定数据对象的最佳实践:以 ImageView 为例 在 WPF 中开发自定义控件时,如何优雅地绑定数据对象,是一个经常遇到的问题。最近在实现一个自定义的 ImageView 控件时,我遇到了一个典型场景: 控件内部需要使…

[Dify 专栏] 如何通过 Prompt 在 Dify 中模拟 Persona:即便没有专属配置,也能让 AI 扮演角色

在 AI 应用开发中,“Persona(角色扮演)”常被视为塑造 AI 个性与专业边界的重要手段。然而,许多开发者在使用 Dify 时会疑惑:为什么我在 Chat 应用 / Agent 应用 / Workflow 里都找不到所谓的 Persona 配置项? 答案是:Dify 平台目前并没有内建的 Persona 配置入口。角色…

解决双向循环链表中对存储数据进行奇偶重排输出问题

1. 概念 对链表而言,双向均可遍历是最方便的,另外首尾相连循环遍历也可大大增加链表操作的便捷性。因此,双向循环链表,是在实际运用中是最常见的链表形态。 2. 基本操作 与普通的链表完全一致,双向循环链表虽然指针较多,但逻辑是完全一样。基本的操作包括: 节点设计 初…

Kubernetes集群升级与etcd备份恢复指南

目录 Kubernetes etcd备份恢复 集群管理命令 环境变量 查看etcd版本 查看etcd集群节点信息 查看集群健康状态 查看告警事件 添加成员(单节点部署的etcd无法直接扩容)(不用做) 更新成员 删除成员 数据库操作命令 增加(put) 查询(get) 删除(…

【LeetCode热题100道笔记】旋转图像

题目描述 给定一个 n n 的二维矩阵 matrix 表示一个图像。请你将图像顺时针旋转 90 度。 你必须在 原地 旋转图像,这意味着你需要直接修改输入的二维矩阵。请不要 使用另一个矩阵来旋转图像。 示例 1:输入:matrix [[1,2,3],[4,5,6],[7,8,9]…

SpringBoot【集成p6spy】使用p6spy-spring-boot-starter集成p6spy监控数据库(配置方法举例)

使用p6spy-spring-boot-starter集成p6spy监控数据库1.简单说明2.核心依赖3.主要配置4.简单测试5.其他配置1.简单说明 p6spy 类似于 druid 可以拦截 SQL 可以用于项目调试,直接引入 p6spy 的博文已经很多了,这里主要是介绍一下 springboot 使用 p6spy-sp…