机器学习 入门——决策树分类

决策树是一种直观且强大的机器学习算法,适用于分类和回归任务。本文将全面介绍决策树分类的原理、实现、调优和实际应用。

一、什么是决策树分类

1.概念

决策树分类是一种树形结构的分类模型,它通过递归地将数据集分割成更小的子集来构建决策规则。就像我们日常生活中做决策一样(例如:如果天气晴朗,就去公园;否则在家看电影),决策树通过一系列的判断条件来对数据进行分类。下图为一个决策树

2.构建过程

  1. ​特征选择​​:

    • 使用指标(如信息增益、增益率或基尼指数)选择最佳分裂特征。

    • ​信息增益​​(ID3算法):选择使信息熵下降最多的特征。

    • ​增益率​​(C4.5算法):解决信息增益对多值特征的偏好问题。

    • ​基尼指数​​(CART算法):衡量数据不纯度,值越小纯度越高。

  2. ​节点分裂​​:

    • 根据特征的阈值(连续值)或类别(离散值)将数据划分为子集。

    • 递归处理子集,直到满足停止条件。

  3. ​剪枝策略​​(防止过拟合):

    • ​预剪枝​​:在分裂前评估,若增益不足则停止分裂。

    • ​后剪枝​​:先构建完整树,再自底向上剪去不重要的分支。

二、决策树的分类标准

1、信息增益(Information Gain)

1. 核心概念​

​(1)熵(Entropy)​

  • ​定义​​:衡量数据集的不确定性(混乱程度)。熵越大,数据越无序。

    • 公式:

      
      
      • S:当前数据集;

      • pi:类别 i在数据集中的比例;

      • c:类别总数。

  • ​例子​​:

    • 若数据集全是同一类别(如全为“是”),熵为0(完全确定)。

    • 若类别均匀分布(如“是”“否”各占50%),熵为1(最大不确定性)。

天气

温度

湿度

风力

是否出去玩

多云

正常

上述列表中类别分布为3个“是”,2个“否”。

所以信息熵为

​(2)信息增益(Information Gain)​

  • ​定义​​:划分前后熵的减少量,反映属性对分类的贡献。

    • 公式:

      
      
      • A:候选属性;

      • Values(A):属性 A的所有可能取值;

      • Sv​:属性 A取值为 v的子集。

  • ​目标​​:选择使 IG(S,A)最大的属性 A。

对于上述数据,可以计算出每个属性的信息增益以天气为例

  • 取值​​:晴(2条)、多云(1条)、雨(2条)。

  • 子集熵计算:

    • 晴:2条全为“否” → Entropy(S晴​)=0。

    • 多云:1条全为“是” → Entropy(S多云​)=0。

    • 雨:2条全为“是” → Entropy(S雨​)=0。

  • 信息增益:

    
    (同理可计算其他属性的信息增益,选择最大的作为划分节点。)

2. 构建决策树

通过计算信息增益我们就可​以构建决策树。​

根节点划分(天气)​

  • ​天气 = 晴​​:2条数据,全为“否” → ​​叶节点(否)​​。

  • ​天气 = 多云​​:1条数据,全为“是” → ​​叶节点(是)​​。

  • ​天气 = 雨​​:2条数据,全为“是” → ​​叶节点(是)​​。

此时决策树已完全分类,无需进一步划分(所有子集纯度100%)。

但若假设“雨”的子集不纯(例如有“否”),则需继续划分其他属性。

 2、信息增益比

信息增益比是​​对信息增益的改进​​,用于解决信息增益对多值属性的偏好问题。

通过引入属性的​​固有值(Intrinsic Value)​​,惩罚取值较多的属性,从而平衡划分标准。

公式​​:

其中:

  • IntrinsicValue(A)=

  • Values(A):属性 A的所有取值,Sv​是取值为 v的子集。

  • InformationGain(A)为a属性的信息增益

 计算步骤(示例)​

沿用之前的天气数据集,但假设“湿度”有更多取值以演示效果:

天气

湿度(新)

是否出去玩

80%

85%

多云

90%

75%

60%

​Step 1: 计算信息增益(IG)​

  • 对属性“湿度”(连续属性需离散化,假设分为高/正常):

    • 高(80%, 85%, 90%, 75%):3否,1是 → 熵 ≈ 0.811

    • 正常(60%):1是 → 熵 = 0

    • IG(湿度)=0.971−(54​×0.811+51​×0)≈0.322

​Step 2: 计算固有值(IV)​

  • 湿度取值分布:高(4条)、正常(1条)。

    
    

​Step 3: 计算增益比​

GainRatio(湿度)=0.7220.322​≈0.446

3、GINI系数

基尼系数是决策树(如CART算法)中用于衡量数据​​不纯度​​的指标,表示从数据集中随机抽取两个样本,其类别标签不一致的概率。

GINI系数

  • ​公式​​:

    • S:当前数据集;

    • pi​:类别 i在数据集中的比例;

    • c:类别总数。

  • ​范围​​:0(完全纯净)到 0.5(均匀分布的两类)

特征A条件下的加权基尼系数​

​公式​​:


  • Values(A):特征A的所有可能取值(如“天气”取值为晴、雨、多云)。

  • Sv​:特征A取值为 v的子数据集。

  • S:特征A的数据集数

  • Gini(Sv​):子集 Sv​的基尼系数。​

假设银行根据以下特征决定是否批准贷款申请:

年龄

收入

学历

是否有房产

是否批准贷款

青年

高中

青年

高中

青年

本科

中年

本科

中年

硕士

老年

硕士

​目标​​:预测“是否批准贷款”

​特征​​:年龄、收入、学历、是否有房产

​Step 1: 计算初始基尼系数​

  • 类别分布:3“否”,3“是”。

  • 初始基尼系数:

    Gini(S)=1−((63​)2+(63​)2)=1−(0.25+0.25)=0.5

Step 2: 计算各特征的基尼增益​

​(1)特征:年龄​

  • ​取值​​:青年(3条)、中年(2条)、老年(1条)。

  • ​子集基尼系数​​:

    • 青年:3条(全“否”)→ Gini=1−(1^2+0^2)=0

    • 中年:2条(1“否”,1“是”)→ Gini=1−(0.5^2+0.5^2)=0.5

    • 老年:1条(全“是”)→ Gini=0

  • ​加权基尼系数​​:

    Ginisplit​(年龄)=1/2​×0+62​×0.5+1/6​×0≈0.167

(2)特征:是否有房产​

  • ​取值​​:有(3条)、无(3条)。

  • ​子集基尼系数​​:

    • 有:3条(1“否”,2“是”)→ Gini=1−(3/1​)^2−(3/2​)^2≈0.444

    • 无:3条(2“否”,1“是”)→ Gini≈0.444

  • ​加权基尼系数​​:

    Ginisplit​(房产)=1/2​×0.444+1/2​×0.444≈0.444

4、决策树剪枝(Pruning)

1. 剪枝的目的​

决策树容易过拟合(Overfitting),当你的数据量过大时,会导致树深度过大或节点过多。剪枝通过​​移除部分分支或子树​​,简化模型结构,提升泛化能力(指模型在​​未见过的数据​​上表现良好的能力,即从训练数据中学到的规律能否推广到新样本。)。

  • ​核心目标​​:在训练集准确性和测试集泛化性之间取得平衡。

​2. 剪枝方法分类​

​(1)预剪枝(Pre-Pruning)​

在决策树构建过程中提前停止生长,通过设定阈值限制树的复杂度。

预剪枝就像​​给树苗修剪枝叶​​,在决策树生长过程中提前阻止不必要的分支。通过设定规则限制树的复杂度,防止它长得"太茂盛"(过拟合)。

预剪枝的常见方法​

​(1) 限制树的高度(max_depth)​
  • ​作用​​:控制树的最大层数,避免决策规则过于复杂。

  • ​例子​​:贷款审批时,最多只问3个问题(如年龄→收入→房产),再多就拒绝(防止过度追问隐私)。

​(2) 设置节点最小样本数(min_samples_split)​
  • ​作用​​:节点至少需要多少样本才允许继续分裂。

  • ​例子​​:医生诊断时,至少要有10个相似病例才新增检查项,否则按经验开药。

​(3) 信息增益阈值(min_impurity_decrease)​
  • ​作用​​:只有分裂能显著提升分类效果时才允许分裂。

  • ​例子​​:挑西瓜时,如果"听声音"和"看颜色"判断效果差不多,就只用其中一个特征。

三、python实战

1.sklearn.tree.DecisionTreeClassifier()参数

sklearn.tree.DecisionTreeClassifier(criterion="gini",    splitter="best",    max_depth=None,   min_samples_split=2,  min_samples_leaf=1,   min_weight_fraction_leaf=0.0,  max_features=None,  random_state=None,   max_leaf_nodes=None, min_impurity_decrease=0.0,   class_weight=None, ccp_alpha=0.0,)

DecisionTreeClassifier是 scikit-learn 提供的分类决策树模型,适用于离散类别预测(如垃圾邮件分类、疾病诊断)。

1.核心分裂参数​

参数

说明

推荐值

引用

criterion

分裂质量评估标准:
'gini'(默认):基尼系数,计算更快
'entropy':信息增益,对不纯度更敏感但可能过拟合

高维数据用 'gini',低维清晰数据两者差异小

splitter

分裂策略:
'best'(默认):全局最优分裂
'random':局部随机分裂,适合大数据加速训练

小数据用 'best',大数据用 'random'


​2. 剪枝与复杂度控制​

参数

说明

推荐值

引用

max_depth

树的最大深度。None表示不限,但易过拟合

通常设为 3-10,通过交叉验证选择

min_samples_split

节点继续分裂的最小样本数:
- 整数:绝对数量
- 浮点数:占总样本比例

样本量大时建议 ≥10 或 0.01-0.1

min_samples_leaf

叶节点最小样本数,防止噪声干扰

分类任务建议 ≥5

max_leaf_nodes

限制叶节点总数,优先于 max_depth

特征多时设为 10-100

min_impurity_decrease

分裂需达到的最小不纯度减少量

0-0.1,值越大树越简单


​3. 特征与随机性控制​

参数

说明

推荐值

引用

max_features

分裂时考虑的最大特征数:
'sqrt'(默认):特征数的平方根
'log2':log2(特征数)
- 整数/浮点数:固定数量/比例

高维数据用 'sqrt'或 'log2'

random_state

随机种子,保证结果可复现

固定值如 42


​4. 类别不平衡处理​

参数

说明

推荐值

引用

class_weight

类别权重:
None(默认):等权重
'balanced':自动按类别频率反比加权

类别不平衡时用 'balanced'


​5. 其他实用参数​

参数

说明

推荐值

引用

ccp_alpha

代价复杂度剪枝参数,后剪枝强度

通过交叉验证选择

presort

预排序数据以加速训练(已弃用)

不推荐使用

 2.实例——客户历史流失数据

对于600多条数据,一共20个属性,两种状态,为二分类问题

import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split as trtesp
#sklearn自带的拆分数据集的函数train_test_split
#导入需要用到的库data = pd.read_excel('电信客户流失数据.xlsx')x = data.iloc[:, :-1]
y = data.iloc[:, -1]x_train,x_test,y_train,y_test = \trtesp(x, y, test_size=0.2, random_state=100)#拆分数据集lr = DecisionTreeClassifier()
lr.fit(x_train, y_train)from sklearn.model_selection import cross_val_score#导入打印召回率的函数
#定义循环列表定义最大深度,分裂的最小样本,叶节点最小样本数,最大叶子节点总数范围
max_depth = [i for i in range(4,13)]
min_samples_split = [i for i in range(2,10)]
min_samples_leaf = [i for i in range(12,19)]
max_leaf_nodes = [i for i in range(9,16)]scores = 0
best = []#定义最佳参数存放列表#利用循环嵌套寻找最佳参数
for depth in max_depth:for min_samples in min_samples_split:for min_leaf in min_samples_leaf:for max_leaf in max_leaf_nodes:lr = DecisionTreeClassifier(criterion='gini',max_depth=depth,min_samples_split=min_samples,min_samples_leaf=min_leaf,max_leaf_nodes=max_leaf,random_state=42)#进行交叉验证,评估模型的泛化能力score = /cross_val_score(lr, x_train, y_train, cv=8,scoring='recall')scores_m = sum(score)/len(score)#计算分数if scores_m > scores:#统计最大分数scores= scores_mbest = [depth,min_samples,min_leaf,max_leaf]print('最佳惩罚因子为:',best[:])
#训练最佳参数模型
lr = DecisionTreeClassifier(criterion='gini',max_depth=best[0],min_samples_split=best[1],min_samples_leaf=best[2],max_leaf_nodes=best[3],random_state=42)
lr.fit(x_train, y_train)
y_spred = lr.predict(x_train)
y_pred = lr.predict(x_test)
from sklearn import metrics
print('自测:',metrics.classification_report(y_train, y_spred))#获取自测报告
print('测试:',metrics.classification_report(y_test, y_pred))#获得测试集测试报告#输出决策树图像
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
fig,ax = plt.subplots(figsize=(10,10))
plot_tree(lr,filled=True,ax=ax)
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/91839.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/91839.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

虚拟机中查看和修改文件权限

在虚拟机中管理文件权限是系统管理的重要部分,无论是在Linux还是Windows虚拟机中。下面我将详细介绍两种主要系统的权限管理方法。Linux虚拟机中的文件权限管理查看文件权限使用ls命令:ls -l 文件名输出示例:-rwxr-xr-- 1 user group 1024 Ju…

图像处理拉普拉斯算子

AI对话记录,还没有来得及仔细验证和推导,目前只是记录 当然可以!我们来一步步推导拉普拉斯算子在旋转变换下保持不变的数学过程。这里以二维情况为例,最直观也最常见。🧮 拉普拉斯算子旋转不变性的推导(二维…

React ahooks——副作用类hooks之useThrottleEffect

useThrottleEffect 是 ahooks 提供的节流版 useEffect,它在依赖项变化时执行副作用函数,但会限制执行频率。一、基本语法useThrottleEffect(effect: React.EffectCallback,deps?: React.DependencyList,options?: Options )二、参数详解2.1. effect (必…

【建模与仿真】融合画像约束和潜在特征的深度推荐算法

导读: 基于深度学习的推荐算法已成为推荐系统领域的研究趋势。然而,大多数现有工作仅考虑单一的用户与物品交互数据,限制了算法的预测性能。本文提出一种画像约束的编码方式,并融合隐因子模型中的潜在特征,丰富了推荐…

华为网路设备学习-26(BGP协议 二)路径属性

一、属性分类二、属性含义①公认必遵:所有BGP对等体 必须识别 且 在Update报文中携带1.Origin2.AS-Path3.Next hop②公认自决:所有BGP对等体 必须识别但可以不在Update报文中携带 1.Local-Preference2.ATOMIC_Aggregate③可选传递:所有BGP对…

从0搭建YOLO目标检测系统:实战项目+完整流程+界面开发(附源码)

文章目录一、前言二、专栏介绍三、已有系统介绍3.0 基于yolo通用目标检测系统(手把手教你修改成为自己的检测系统)3.1 基于yolov8柑橘检测系统3.2 基于yolov8舰船检测系统3.3 基于yolo11人脸检测系统3.4 基于yolov8无人机影像光伏板缺陷检测系统一、前言…

【测试】自动化测试工具基础知识及基本应用

下面详细介绍一些常用的自动化测试工具及其基本概念,并提供具体的示例代码,帮助你更好地理解和应用这些工具。1. 自动化测试的基本概念自动化测试是通过软件程序自动执行测试用例的过程。与手动测试相比,自动化测试能够提高测试效率、减少人为…

ArcGIS的字段计算器生成随机数

在ArcGIS的字段计算器中使用Python脚本生成0-100的随机数,可以按照以下步骤操作: 打开属性表,选择要计算的字段打开字段计算器选择"Python"解析器勾选"显示代码块"在"预逻辑脚本代码"中输入以下代码在下方表达…

【前端:Html】--1.1.基础语法

目录 1.HTML--简介 2.HTML--编译器 步骤一:启动记事本 步骤二:用记事本来编辑 HTML 步骤三:保存 HTML 步骤四:在浏览器中运行 HTML 3.HTML--基础 3.1.HTML声明--!DOCTYPE 3.2.HTML 标题--h1 3.3.HTML 段落--p 3.3.1. 水平线--hr 3.3.2.换行符--br 3.3.3.固定格式…

FreeSWITCH 简单图形化界面46 - 收集打包的一些ASR服务

FreeSWITCH 简单图形化界面46 - 收集打包的一些ASR服务 0、一个fs的web配置界面预览1、docker地址2、使用2.1 下载2.2 运行 3、例子3.1 下载3.2 启动3.3 编译mod_audio_fork或者mod_audio_stream模块使用3.4 编写呼叫路由和呼叫脚本呼叫路由呼叫脚本 3.5 esl捕获识别结果3.6 其…

20250805问答课题-实现TextRank + 问题分类

textRank的工具包实现其他可能的实现方法,对比结果查找分类的相关算法 目录 1. 关键词提取TF-IDF TextRank 1.1. TF-IDF算法 1.2. TextRank算法 1.3. 双算法提取关键词 2. 问题分类 2.1. 预处理 2.2. 获取BERT向量 2.3. 一级标签预测 2.4. 二级标签预测 3…

Memcached缓存与Redis缓存的区别、优缺点和适用场景

一、核心差异概述特性MemcachedRedis​数据结构​简单键值存储丰富数据结构(String/Hash/List/Set等)​持久化​不支持支持RDB和AOF两种方式​线程模型​多线程单线程(6.0支持多线程I/O)​内存管理​Slab分配LRU淘汰多种淘汰策略&…

Git简易教程

Git教程 VCS Version Control System版本控制系统 配置用户名邮箱 配置用户名和邮箱 git config --global user.name mihu git config --global user.email aaabbb.com初始化仓库 从项目仓库拉 git clone [项目地址]新建文件夹之后 git init提交操作 提交到仓库 git add . #把…

关于Web前端安全之XSS攻击防御增强方法

仅依赖前端验证是无法完全防止 XSS的,还需要增强后端验证,使用DOMPurify净化 HTML 时,还需要平衡安全性与业务需求。一、仅依赖前端验证无法完全防止 XSS 的原因及后端验证的重要性1. 前端验证的局限性前端验证(如 JavaScript 输入…

消息系统技术文档

消息系统技术文档 概述 本文档详细说明了如何在现有的LHD通信系统中添加自己的消息类型,包括消息的发送、接收、解析和处理的完整流程。 系统架构 消息流程架构图 #mermaid-svg-My7ThVxSl6aftvWK {font-family:"trebuchet ms",verdana,arial,sans-serif;f…

【NLP舆情分析】基于python微博舆情分析可视化系统(flask+pandas+echarts) 视频教程 - 微博舆情数据可视化分析-热词情感趋势树形图

大家好,我是java1234_小锋老师,最近写了一套【NLP舆情分析】基于python微博舆情分析可视化系统(flaskpandasecharts)视频教程,持续更新中,计划月底更新完,感谢支持。今天讲解微博舆情数据可视化分析-热词情感趋势树形图…

8月4日 强对流天气蓝色预警持续:多地需警惕雷暴大风与短时强降水

中央气象台8月4日10时继续发布强对流天气蓝色预警,提醒广大民众注意防范即将到来的恶劣天气。 预警详情: 时间范围: 8月4日14时至5日14时 影响区域: 雷暴大风或冰雹: 西北地区中东部、华北中北部、华南南部等地,风力可达8级以上。 短时强降水: 西北地区中东部、华北、…

C语言数据结构(4)单链表专题2.单链表的应用

1. 链表经典算法——OJ题目 1.1 单链表相关经典算法OJ题1:移除链表元素 1.2 单链表相关经典算法OJ题2:反转链表 1.3 单链表相关经典算法OJ题3:合并两个有序链表 1.4 单链表相关经典算法OJ题4:链表的中间结点 1.5 循环链表…

Shell 脚本发送信号给 C 应用程序,让 C 应用程序回收线程资源后自行退出。

下面分别给出一个 Shell 脚本和 C 程序的例子,实现通过 Shell 脚本发送信号给 C 应用程序,让 C 应用程序回收线程资源后自行退出。原理在 Linux 系统中,我们可以使用信号机制来实现进程间的通信。Shell 脚本可以使用 kill 命令向指定的进程发…

C++入门自学Day6-- STL简介(初识)

往期内容回顾 C模版 C/C内存管理(初识) C/C内存管理(续) STL简介: STL 是 C 标准库的重要组成部分,是一个通用程序设计的模板库,用于数据结构和算法的复用。它极大地提升了代码效率、可靠性…