机器学习回顾——决策树详解

决策树基础概念与应用详解

1. 决策树基础概念

1.1 什么是决策树

决策树是一种树形结构的预测模型,其核心思想是通过一系列规则对数据进行递归划分。它模拟人类决策过程,广泛应用于分类和回归任务。具体结构包括:

  • 内部节点:表示对某个特征的条件判断,例如"年龄>30岁?"或"收入<5万?"
  • 分支:代表判断结果的可能取值,如"是/否"或离散特征的各个类别
  • 叶节点:包含最终的预测结果。在分类任务中可能输出"批准贷款"或"拒绝贷款";在回归任务中可能输出具体数值如"房价=45.6万"

1.2 决策树的主要组成部分

  1. 根节点:位于树的最顶端,包含完整的训练数据集。例如在客户信用评估中,根节点可能包含所有申请人的特征数据
  2. 决策节点:进行条件判断的内部节点,通常会选择最具区分度的特征进行判断。如选择"信用评分"而非"性别"作为关键判断标准
  3. 叶节点:终止节点,存储最终决策结果。在医疗诊断中,可能输出"良性"或"恶性"的诊断结论
  4. 分支:连接节点的路径,表示决策条件的具体取值。例如"体温>38.5℃"的分支可能导向"疑似感染"的子节点

1.3 决策树的工作流程

决策树的构建遵循以下详细步骤:

  1. 特征选择:在当前节点计算所有特征的分裂质量(使用信息增益、基尼指数等指标),选择最优特征
  2. 数据分割:根据选定特征的取值将数据集划分为若干子集。例如将"收入"特征按阈值50k划分为高/低收入两组
  3. 递归构建:对每个子节点重复步骤1-2,直到满足以下任一停止条件:
    • 节点样本数小于预设阈值(如min_samples_leaf=5)
    • 所有样本属于同一类别(纯度达到100%)
    • 树达到最大深度限制(如max_depth=10)
    • 继续分裂不能显著改善模型性能
  4. 剪枝处理:为防止过拟合,可能进行预剪枝或后剪枝操作

2. 决策树构建算法

2.1 ID3算法

核心思想:通过最大化信息增益来选择特征分裂点,倾向于选择能够最有效降低不确定性的特征。

信息熵计算: H(S) = -∑ p_i log₂ p_i 其中p_i是第i类样本在集合S中的比例。例如对于一个二分类问题(正例60%,负例40%): H(S) = -0.6log₂0.6 - 0.4log₂0.4 ≈ 0.971

信息增益计算: Gain(S, A) = H(S) - ∑ (|S_v|/|S|) * H(S_v) 其中S_v是特征A取值为v的子集。例如,对于包含100个样本的节点,按特征A分为两个子集(60个和40个),分别计算其熵值后加权平均。

实际应用中的局限性

  1. 偏向于选择取值较多的特征(如"用户ID"这种唯一标识符)
  2. 无法直接处理连续型特征,需要预先离散化
  3. 对缺失值敏感,缺乏有效的处理机制
  4. 没有剪枝步骤,容易生成过深的树导致过拟合

2.2 C4.5算法

核心改进

  1. 信息增益率:解决ID3对多值特征的偏好问题 GainRatio(S, A) = Gain(S, A) / SplitInfo(S, A) 其中SplitInfo(S, A) = -∑ (|S_v|/|S|) * log₂(|S_v|/|S|) 这相当于对信息增益进行标准化处理

  2. 连续特征处理:采用二分法自动离散化连续特征

    • 对特征值排序后,取相邻值的中点作为候选分割点
    • 选择信息增益率最大的分割点
  3. 缺失值处理

    • 在计算信息增益时,仅使用特征A不缺失的样本
    • 预测时,如果遇到缺失值,可以按照分支样本比例分配
  4. 剪枝策略

    • 采用悲观剪枝(PEP)方法
    • 基于统计显著性检验决定是否剪枝

2.3 CART算法

算法特点

  1. 二叉树结构:每个节点只产生两个分支,简化决策过程

    • 对于离散特征:生成"是否属于某类别"的判断
    • 对于连续特征:生成"是否≤阈值"的判断
  2. 基尼指数(用于分类): Gini(S) = 1 - ∑ p_i² 基尼指数表示从数据集中随机抽取两个样本,其类别不一致的概率。值越小表示纯度越高。

  3. 回归树实现

    • 分裂标准:最小化平方误差 MSE = 1/n ∑ (y_i - ŷ)²
    • 叶节点输出:该节点所有样本的目标变量均值
    • 特征选择:选择使MSE降低最多的特征和分割点

3. 决策树的关键技术

3.1 特征选择标准

分类任务

  1. 信息增益(ID3):

    • 优点:理论基础强,符合信息论原理
    • 缺点:对多值特征有偏好
  2. 信息增益率(C4.5):

    • 优点:解决了多值特征偏好问题
    • 缺点:可能过度补偿,倾向于选择分裂信息小的特征
  3. 基尼指数(CART):

    • 优点:计算简单,不涉及对数运算
    • 缺点:没有信息增益的理论基础

回归任务

  1. 方差减少:选择使子节点目标变量方差和最小的分裂
  2. 最小二乘偏差:直接优化预测误差的平方和

3.2 剪枝技术

预剪枝(提前停止树的生长):

  1. 最大深度限制(max_depth):控制树的层数
  2. 最小样本分裂数(min_samples_split):节点样本数少于该值则不再分裂
  3. 最小叶节点样本数(min_samples_leaf):确保叶节点有足够样本支撑
  4. 最大特征数(max_features):限制每次分裂考虑的特征数量

后剪枝(先构建完整树再修剪):

  1. 代价复杂度剪枝(CCP):

    • 计算各节点的α值:α = (R(t)-R(T_t))/(|T_t|-1)
    • 剪去使整体代价函数Cα(T)=R(T)+α|T|最小的子树
    • 通过交叉验证选择最优α
  2. 悲观错误剪枝(PEP):

    • 基于统计检验,认为训练误差是乐观估计
    • 使用二项分布计算误差上限,决定是否剪枝

3.3 连续值和缺失值处理

连续值处理流程

  1. 对特征值排序(如年龄:22,25,28,30,36,...)
  2. 取相邻值中点作为候选分割点(如(22+25)/2=23.5)
  3. 计算每个候选点的分裂质量指标
  4. 选择最佳分割点构建决策节点

缺失值处理策略

  1. 替代法

    • 分类:用众数填充
    • 回归:用均值填充
    • 优点:实现简单
    • 缺点:可能引入偏差
  2. 概率分配

    • 根据特征值分布概率将样本分配到各分支
    • 保持样本权重不变
    • 更合理但实现复杂
  3. 特殊分支

    • 为缺失值创建专用分支路径
    • 需要足够多的缺失样本支持

4. 决策树的优缺点

4.1 优势分析

  1. 模型可解释性

    • 决策路径可以直观展示,适合需要解释预测结果的场景
    • 例如在信贷审批中,可以明确告知客户"因收入不足被拒"
  2. 数据处理优势

    • 不需要特征缩放(如标准化)
    • 能同时处理数值型(年龄、收入)和类别型(性别、职业)特征
    • 自动进行特征选择(忽略无关特征)
  3. 计算效率

    • 预测时间复杂度为O(树深度),通常非常高效
    • 适合实时预测场景,如欺诈检测
  4. 多功能性

    • 可处理多输出问题(同时预测多个目标变量)
    • 适用于分类和回归任务
  5. 可视化能力

    • 可通过图形展示决策流程
    • 便于向非技术人员解释模型逻辑

4.2 局限性

  1. 过拟合风险

    • 可能生成过于复杂的树,捕捉数据中的噪声
    • 需要通过剪枝、设置最小叶节点样本数等约束控制
  2. 模型不稳定性

    • 训练数据的微小变化可能导致完全不同的树结构
    • 可通过集成方法(如随机森林)缓解
  3. 局部最优问题

    • 贪心算法无法保证全局最优
    • 可能错过更好的特征组合分裂方式
  4. 连续变量处理

    • 需要将连续特征离散化
    • 可能丢失连续特征的精细信息
  5. 忽略特征相关性

    • 独立考虑每个特征的分裂效果
    • 无法捕捉特征间的交互作用

5. 决策树的实际应用

5.1 分类任务实现(乳腺癌诊断)

from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier, export_graphviz
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt# 加载威斯康星乳腺癌数据集
data = load_breast_cancer()
X = data.data  # 包含30个特征(半径、纹理等)
y = data.target  # 0=恶性,1=良性
feature_names = data.feature_names# 划分训练测试集(70%训练,30%测试)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y)# 设置参数网格进行调优
param_grid = {'criterion': ['gini', 'entropy'],'max_depth': [3, 5, 7, None],'min_samples_split': [2, 5, 10],'min_samples_leaf': [1, 2, 5]
}# 创建决策树分类器
clf = DecisionTreeClassifier(random_state=42)# 使用网格搜索寻找最优参数
grid_search = GridSearchCV(clf, param_grid, cv=5, scoring='accuracy')
grid_search.fit(X_train, y_train)# 最佳参数模型
best_clf = grid_search.best_estimator_
print(f"最佳参数:{grid_search.best_params_}")# 在测试集上评估
y_pred = best_clf.predict(X_test)
print(f"测试集准确率:{accuracy_score(y_test, y_pred):.4f}")
print(classification_report(y_test, y_pred))# 可视化混淆矩阵
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(6,6))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.colorbar()
plt.xticks([0,1], ["Malignant", "Benign"])
plt.yticks([0,1], ["Malignant", "Benign"])
plt.xlabel("Predicted")
plt.ylabel("Actual")
for i in range(2):for j in range(2):plt.text(j, i, str(cm[i,j]), ha="center", va="center", color="white" if cm[i,j] > cm.max()/2 else "black")
plt.show()# 导出决策树图形(需要graphviz)
export_graphviz(best_clf, out_file="breast_cancer_tree.dot",feature_names=feature_names,class_names=data.target_names,filled=True, rounded=True)

5.2 回归任务实现(房价预测)

from sklearn.datasets import fetch_california_housing
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import numpy as np
import pandas as pd# 加载加州房价数据集
housing = fetch_california_housing()
X = housing.data  # 8个特征(经度、纬度、房龄等)
y = housing.target  # 房价中位数(单位:10万美元)
feature_names = housing.feature_names# 转换为DataFrame便于分析
df = pd.DataFrame(X, columns=feature_names)
df['MedHouseVal'] = y# 划分训练测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 设置参数分布进行随机搜索
param_dist = {'criterion': ['mse', 'friedman_mse', 'mae'],'max_depth': np.arange(3, 15),'min_samples_split': np.arange(2, 20),'min_samples_leaf': np.arange(1, 15),'max_features': ['auto', 'sqrt', 'log2', None]
}# 创建决策树回归器
reg = DecisionTreeRegressor(random_state=42)# 随机参数搜索
random_search = RandomizedSearchCV(reg, param_dist, n_iter=100, cv=5,scoring='neg_mean_squared_error',random_state=42)
random_search.fit(X_train, y_train)# 最佳参数模型
best_reg = random_search.best_estimator_
print(f"最佳参数:{random_search.best_params_}")# 在测试集上评估
y_pred = best_reg.predict(X_test)
print(f"测试集MSE:{mean_squared_error(y_test, y_pred):.4f}")
print(f"测试集MAE:{mean_absolute_error(y_test, y_pred):.4f}")
print(f"测试集R²:{r2_score(y_test, y_pred):.4f}")# 特征重要性分析
importance = pd.DataFrame({'feature': feature_names,'importance': best_reg.feature_importances_
}).sort_values('importance', ascending=False)# 绘制特征重要性
plt.figure(figsize=(10,6))
plt.barh(importance['feature'], importance['importance'])
plt.xlabel("Feature Importance")
plt.title("Decision Tree Feature Importance")
plt.gca().invert_yaxis()
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/95180.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/95180.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux开发必备:yum/vim/gcc/make全攻略

目录 1.学习yum、apt⼯具&#xff0c;进⾏软件安装 1-1 什么是软件包 1-2 yum/apt具体操作 2. 编辑器Vim 2-1 Linux编辑器-vim的引入 2-2 vim的基本概念 2-3 vim的基本操作 2-4 vim正常模式命令集 2-5 vim末⾏模式命令集 3. 编译器gcc/g 3-1 背景知识 3-2 gcc编译选…

【Linux系统】万字解析,进程间的信号

前言&#xff1a; 上文我们讲到了&#xff0c;进程间通信的命名管道与共享内存&#xff1a;【Linux系统】命名管道与共享内存-CSDN博客​​​​​​ 本文我们来讲一讲&#xff0c;进程的信号问题 点个关注&#xff01; 信号概念 信号是OS发送给进程的异步机制&#xff01;所谓异…

AI时代SEO关键词实战解析

内容概要 随着人工智能技术深度融入搜索引擎的运行机制&#xff0c;传统的SEO关键词研究方法正经历着根本性的变革。本文聚焦于AI时代背景下&#xff0c;如何利用智能化的策略精准定位目标用户&#xff0c;实现搜索可见度的实质性跃升。我们将深入探讨AI技术如何革新关键词研究…

Spring Boot + Spring MVC 项目结构

下面一个既能返回 JSP 页面&#xff0c;又能提供 JSON API 的 Spring Boot Spring MVC 项目结构&#xff0c;这样你就能同时用到 Controller 和 RestController 的优势。 &#x1f3d7; 项目结构 springboot-mvc-mixed/ ├── src/main/java/com/example/demo/ │ ├── …

通俗易懂的讲解下Ceph的存储原理

Ceph存储原理解析 要理解 Ceph 的存储原理&#xff0c;我们可以用一个 “分布式仓库” 的比喻来拆解 —— 把 Ceph 想象成一个由多个 “仓库管理员”&#xff08;硬件节点&#xff09;共同打理的大型仓库&#xff0c;能高效存储、管理海量货物&#xff08;数据&#xff09;&…

软件测试小结(1)

一、什么是测试&#xff1f;1.1 生活中常见的测试例如去商场买衣服&#xff1a;①、选择一件符合审美的衣服 -> 外观测试&#xff1b;②、穿上身上试试是否合身 -> 试穿测试&#xff1b;③、 看看衣服的材料是否纯棉 -> 材料测试&#xff1b;④、 询问衣服的价格 ->…

Python未来3-5年技术发展趋势分析:从AI到Web的全方位演进

Python作为全球最流行的编程语言之一&#xff0c;在开发者社区中占据核心地位。其简洁语法、丰富库生态和跨领域适用性&#xff0c;使其在AI、Web开发、数据科学等领域持续领先。本文基于当前技术演进趋势&#xff08;如2023-2024年的开源项目、社区讨论和行业报告&#xff09;…

【ComfyUI】SDXL Turbo一步完成高速高效的图像生成

今天演示的案例是一个基于 ComfyUI 与 Stable Diffusion XL Turbo 的图生图工作流。整体流程通过加载轻量化的 Turbo 版本模型&#xff0c;在文本编码与调度器的配合下&#xff0c;以极快的推理速度完成从提示词到高质量图像的生成。 配合演示图可以直观感受到&#xff0c;简洁…

基于 GPT-OSS 的在线编程课 AI 助教追问式对话 API 开发全记录

本文记录了如何在 3 天内使用 GPT-OSS 开源权重搭建一个 在线编程课 AI 助教追问式对话 API&#xff0c;从需求分析、数据准备到微调与部署全流程实战。 1️⃣ 需求与指标 回答准确率 ≥ 95%响应延迟 < 1 秒支持多学生并发提问 2️⃣ 数据准备 收集课程问答对清理无效数据…

YOLO v11 目标检测+关键点检测 实战记录

流水账记录一下yolo目标检测 1.搭建pytorch 不做解释 看以往博客或网上搜都行 2.下载yolo源码 &#xff1a; https://github.com/ultralytics/ultralytics 3.样本标注工具&#xff1a;labelme 自己下载 4.准备数据集 4.1 新建一个放置数据集的路径4.2 构建训练集和测试集 运行以…

uniApp 混合开发全指南:原生与跨端的协同方案

uniApp 作为跨端框架&#xff0c;虽能覆盖多数场景&#xff0c;但在需要调用原生能力&#xff08;如蓝牙、传感器&#xff09;、集成第三方原生 SDK&#xff08;如支付、地图&#xff09; 或在现有原生 App 中嵌入 uniApp 页面时&#xff0c;需采用「混合开发」模式。本文将系统…

【大模型】使用MLC-LLM转换和部署Qwen2.5 0.5B模型

目录 ■准备工作 下载模型 安装依赖 安装基础依赖 安装mlc-llm ■权重转换 ■生成配置文件 ■模型编译 GPU版本编译 CPU版本编译 ■启动服务 启动GPU服务 启动CPU服务 ■服务测试 ■扩展 优化量化版本(可选,节省内存) INT4量化版本 调整窗口大小以节省内存…

云计算学习100天-第43天-cobbler

目录 Cobbler 基本概念 命令 搭建cobbler 网络架构 Cobbler 基本概念 Cobbler是一款快速的网络系统部署工具&#xff0c;比PXE配置简单 集中管理所需服务&#xff08;DHCP、DNS、TFTP、WEB&#xff09; 内部集成了一个镜像版本仓库 内部集成了一个ks应答文件仓库 提供…

接口测试:如何定位BUG的产生原因

1小时postman接口测试从入门到精通教程我们从在日常功能测试过程中对UI的每一次操作说白了就是对一个或者多个接口的一次调用&#xff0c;接口的返回的内容(移动端一般为json)经过前端代码的处理最终展示在页面上。http接口是离我们最近的一层接口&#xff0c;web端和移动端所展…

GPIO的8种工作方式

GPIO的8种工作方式&#xff1a;一、4 种输入模式1.1 Floating Input 浮空输入1.2 Pull-up Input 上拉输入1.3 Pull-down Input 下拉输入1.4 Analog Input 模拟输入二、4种输出模式2.1 General Push-Pull Output 推挽输出2.2 General Open-Drain Output 开漏输出2.3…

LeetCode算法日记 - Day 29: 重排链表、合并 K 个升序链表

目录 1. 重排链表 1.1 题目解析 1.2 解法 1.3 代码实现 2. 合并 K 个升序链表 2.1 题目解析 2.2 解法 2.3 代码实现 1. 重排链表 143. 重排链表 - 力扣&#xff08;LeetCode&#xff09; 给定一个单链表 L 的头节点 head &#xff0c;单链表 L 表示为&#xff1a; L…

算法模板(Java版)_前缀和与差分

ZZHow(ZZHow1024) &#x1f4a1; 差分是前缀和的逆运算。 前缀和 &#x1f4a1; 前缀和作用&#xff1a;快速求出 [l, r] 区间的和。 一维前缀和 例题&#xff1a;AcWing 795. 前缀和 import java.util.Scanner;public class Main {public static void main(String[] args)…

openssl使用SM2进行数据加密和数据解密

一、准备工作 1. 安装依赖 sudo apt-get update sudo apt-get install libssl-dev2. 确认 OpenSSL 版本 openssl version如果是 1.1.1 或 3.0&#xff0c;就支持 SM2/SM3/SM4。二、C 语言示例代码 这个程序会&#xff1a; 生成 SM2 密钥对使用公钥加密一段明文使用私钥解密恢复…

用滑动窗口与线性回归将音频信号转换为“Token”序列:一种简单的音频特征编码方法

在深度学习和语音处理领域&#xff0c;如何将原始音频信号有效地表示为离散的“Token”序列&#xff0c;是语音识别、音频生成等任务中的关键问题。常见的方法如Mel频谱图向量量化&#xff08;VQ&#xff09;、wav2vec等已经非常成熟&#xff0c;但这些模型通常依赖复杂的神经网…

Vue开发准备

vs code VSCode的下载地址https://code.visualstudio.com/Download Node.js node.js的下载地址 https://nodejs.org/zh-cn/download 注意&#xff1a;nodejs安装路径不要和vscode安装到同一个文件夹&#xff0c;两个应用分别装到两个不同的文件夹 npm config set cache &q…