矿物分类系统开发笔记(二):模型训练[删除空缺行]

目录

一、阶段衔接与开发目标

二、数据准备

三、模型选择与训练

1. 逻辑回归(LR)

2. 随机森林(RF)

3. 高斯朴素贝叶斯(GNB)

4. 支持向量机(SVM)

5. AdaBoost

6. XGBoost

四、模型评估与结果分析

评估指标

评估结果

结果分析

五、开发总结

六、后续计划


一、阶段衔接与开发目标

在《矿物分类系统开发笔记(一)》中,我们完成了矿物数据集的收集、清洗与预处理工作,重点对数据中的空缺值进行了分析,并采用 “删除空缺行” 的方式生成了可供模型训练的标准化数据集。本阶段作为开发流程的延续,主要基于预处理后的数据完成以下目标:

  • 选取 6 种经典机器学习算法进行矿物分类模型训练
  • 通过网格搜索优化模型参数,提升分类性能
  • 构建统一的评估体系,对比各模型在测试集上的表现
  • 记录并分析实验结果,为后续系统选型提供依据

二、数据准备

数据来源:使用预处理阶段生成的训练集(训练数据集 [删除空缺行].xlsx)和测试集(测试数据集 [删除空缺行].xlsx)

数据划分:

  • 特征集(X):所有样本的属性数据(除最后一列标签外的所有列)
  • 标签集(y):
    • 训练集标签:包含 0、1、3 三类(训练集中标签为 2 的样本均存在数据空缺,已在预处理阶段随空缺行一同删除)
    • 测试集标签:包含 0、1、2、3 四类(保留了数据完整的标签 2 样本,用于验证模型对未见过类别的泛化能力)

特殊处理:
针对 XGBoost 模型特性,构建标签映射关系:{0:0, 1:1, 3:2},将原始标签转换为连续整数编码;预测后通过反向映射{0:0, 1:1, 2:3}还原原始标签,对测试集特有的标签 2 单独处理(预测结果中若出现未映射编码则判定为 2)

import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from xgboost import XGBClassifier
from sklearn.model_selection import GridSearchCV
from sklearn import metrics
import json# 数据读取
train_data = pd.read_excel('..//temp_data//训练数据集[删除空缺行].xlsx')
test_data = pd.read_excel('..//temp_data//测试数据集[删除空缺行].xlsx')# 特征与标签分割
train_X = train_data.iloc[:, :-1]
train_y = train_data.iloc[:, -1]  # 训练标签:0、1、3
test_X = test_data.iloc[:, :-1]
test_y = test_data.iloc[:, -1]    # 测试标签:0、1、2、3# XGBoost标签映射处理
label_mapping = {0: 0, 1: 1, 3: 2}
reverse_mapping = {v: k for k, v in label_mapping.items()}
train_y_xgb = train_y.map(label_mapping)  # 转换为连续编码
test_y_xgb = test_y.map(label_mapping)# 结果存储容器
result_data = {}

三、模型选择与训练

选取 6 种经典分类算法进行对比实验,均采用网格搜索(GridSearchCV)进行参数优化,5 折交叉验证确定最佳参数:

1. 逻辑回归(LR)

核心参数:C=0.001, max_iter=100, multi_class='ovr', penalty='l1', solver='liblinear'
特点:采用 L1 正则化(Lasso),适合高维数据特征选择,使用 ovr 策略处理多分类

# 网格搜索优化(实际运行时启用)
# logreg = LogisticRegression()
# param_grid = [
#     {'penalty': ['l1'], 'solver': ['liblinear'], 'C': [0.001, 0.01, 0.1], 'multi_class': ['ovr']},
#     {'penalty': ['l2'], 'solver': ['lbfgs'], 'C': [0.001, 0.01, 0.1], 'multi_class': ['multinomial']}
# ]
# grid_search = GridSearchCV(logreg, param_grid, cv=5)
# grid_search.fit(train_X, train_y)
# print("LR最佳参数:", grid_search.best_params_)# 最佳模型训练
LR_result = {}
lr = LogisticRegression(C=0.001, max_iter=100, multi_class='ovr', penalty='l1', solver='liblinear')
lr.fit(train_X, train_y)# 评估
train_pred = lr.predict(train_X)
test_pred = lr.predict(test_X)
print("LR训练集评估:\n", metrics.classification_report(train_y, train_pred))
print("LR测试集评估:\n", metrics.classification_report(test_y, test_pred))# 结果提取
report = metrics.classification_report(test_y, test_pred, digits=6).split()
LR_result['recall_0'] = float(report[6])
LR_result['recall_1'] = float(report[11])
LR_result['recall_2'] = float(report[16])
LR_result['recall_3'] = float(report[21])
LR_result['acc'] = float(report[25])
result_data['LR'] = LR_result

2. 随机森林(RF)

核心参数:bootstrap=True, criterion='gini', max_depth=None, min_samples_leaf=1, min_samples_split=2, n_estimators=200
特点:集成多棵决策树降低过拟合风险,Gini 系数作为不纯度度量,保留完整决策树深度

# 网格搜索优化(实际运行时启用)
# rf = RandomForestClassifier(random_state=42)
# param_grid = {
#     'n_estimators': [100, 200],
#     'max_depth': [None, 20],
#     'min_samples_split': [2, 5],
#     'bootstrap': [True]
# }
# grid_search = GridSearchCV(rf, param_grid, cv=5, n_jobs=-1)
# grid_search.fit(train_X, train_y)
# print("RF最佳参数:", grid_search.best_params_)# 最佳模型训练
RF_result = {}
rf = RandomForestClassifier(bootstrap=True, criterion='gini', max_depth=None,min_samples_leaf=1, min_samples_split=2, n_estimators=200,random_state=42
)
rf.fit(train_X, train_y)# 评估
train_pred = rf.predict(train_X)
test_pred = rf.predict(test_X)
print("RF训练集评估:\n", metrics.classification_report(train_y, train_pred))
print("RF测试集评估:\n", metrics.classification_report(test_y, test_pred))# 结果提取
report = metrics.classification_report(test_y, test_pred, digits=6).split()
RF_result['recall_0'] = float(report[6])
RF_result['recall_1'] = float(report[11])
RF_result['recall_2'] = float(report[16])
RF_result['recall_3'] = float(report[21])
RF_result['acc'] = float(report[25])
result_data['RF'] = RF_result

3. 高斯朴素贝叶斯(GNB)

核心参数:var_smoothing=1e-06
特点:基于贝叶斯定理的概率模型,通过 var_smoothing 参数提高数值稳定性

# 网格搜索优化(实际运行时启用)
# gnb = GaussianNB()
# param_grid = {'var_smoothing': [1e-9, 1e-6, 1e-3]}
# grid_search = GridSearchCV(gnb, param_grid, cv=5)
# grid_search.fit(train_X, train_y)
# print("GNB最佳参数:", grid_search.best_params_)# 最佳模型训练
GNB_result = {}
gnb = GaussianNB(var_smoothing=1e-06)
gnb.fit(train_X, train_y)# 评估
train_pred = gnb.predict(train_X)
test_pred = gnb.predict(test_X)
print("GNB训练集评估:\n", metrics.classification_report(train_y, train_pred))
print("GNB测试集评估:\n", metrics.classification_report(test_y, test_pred))# 结果提取
report = metrics.classification_report(test_y, test_pred, digits=6).split()
GNB_result['recall_0'] = float(report[6])
GNB_result['recall_1'] = float(report[11])
GNB_result['recall_2'] = float(report[16])
GNB_result['recall_3'] = float(report[21])
GNB_result['acc'] = float(report[25])
result_data['GNB'] = GNB_result

4. 支持向量机(SVM)

核心参数:C=10, gamma=1, kernel='rbf', max_iter=1000
特点:采用 RBF 核函数处理非线性关系,较大的 C 值表示对误分类惩罚更严格

# 网格搜索优化(实际运行时启用)
# svm = SVC(random_state=42)
# param_grid = {
#     'kernel': ['rbf'],
#     'C': [1, 10],
#     'gamma': [0.1, 1],
#     'max_iter': [1000]
# }
# grid_search = GridSearchCV(svm, param_grid, cv=5, n_jobs=-1)
# grid_search.fit(train_X, train_y)
# print("SVM最佳参数:", grid_search.best_params_)# 最佳模型训练
SVM_result = {}
svm = SVC(C=10, gamma=1, kernel='rbf', max_iter=1000, random_state=42)
svm.fit(train_X, train_y)# 评估
train_pred = svm.predict(train_X)
test_pred = svm.predict(test_X)
print("SVM训练集评估:\n", metrics.classification_report(train_y, train_pred))
print("SVM测试集评估:\n", metrics.classification_report(test_y, test_pred))# 结果提取
report = metrics.classification_report(test_y, test_pred, digits=6).split()
SVM_result['recall_0'] = float(report[6])
SVM_result['recall_1'] = float(report[11])
SVM_result['recall_2'] = float(report[16])
SVM_result['recall_3'] = float(report[21])
SVM_result['acc'] = float(report[25])
result_data['SVM'] = SVM_result

5. AdaBoost

核心参数:algorithm='SAMME', learning_rate=0.5, n_estimators=50
特点:通过 SAMME 算法集成弱分类器,学习率 0.5 控制迭代步长,50 个基分类器

# 网格搜索优化(实际运行时启用)
# ada = AdaBoostClassifier(random_state=42)
# param_grid = {
#     'n_estimators': [50, 100],
#     'learning_rate': [0.5, 1.0],
#     'algorithm': ['SAMME']
# }
# grid_search = GridSearchCV(ada, param_grid, cv=5, n_jobs=-1)
# grid_search.fit(train_X, train_y)
# print("AdaBoost最佳参数:", grid_search.best_params_)# 最佳模型训练
Ada_result = {}
ada = AdaBoostClassifier(algorithm='SAMME', learning_rate=0.5, n_estimators=50, random_state=42
)
ada.fit(train_X, train_y)# 评估
train_pred = ada.predict(train_X)
test_pred = ada.predict(test_X)
print("AdaBoost训练集评估:\n", metrics.classification_report(train_y, train_pred))
print("AdaBoost测试集评估:\n", metrics.classification_report(test_y, test_pred))# 结果提取
report = metrics.classification_report(test_y, test_pred, digits=6).split()
Ada_result['recall_0'] = float(report[6])
Ada_result['recall_1'] = float(report[11])
Ada_result['recall_2'] = float(report[16])
Ada_result['recall_3'] = float(report[21])
Ada_result['acc'] = float(report[25])
result_data['AdaBoost'] = Ada_result

6. XGBoost

核心参数:colsample_bytree=0.8, gamma=0, learning_rate=0.1, max_depth=3, n_estimators=200
特点:基于树的集成模型,通过列采样(80%)防止过拟合,深度 3 的树结构控制复杂度

# 网格搜索优化(实际运行时启用)
# xgb = XGBClassifier(random_state=42, use_label_encoder=False, eval_metric='mlogloss', num_class=3)
# param_grid = {
#     'n_estimators': [100, 200],
#     'max_depth': [3, 5],
#     'learning_rate': [0.1],
#     'colsample_bytree': [0.8]
# }
# grid_search = GridSearchCV(xgb, param_grid, cv=5, n_jobs=-1)
# grid_search.fit(train_X, train_y_xgb)
# print("XGBoost最佳参数:", grid_search.best_params_)# 最佳模型训练
XGB_result = {}
xgb_best = XGBClassifier(colsample_bytree=0.8, gamma=0, learning_rate=0.1, max_depth=3,n_estimators=200, reg_alpha=0, reg_lambda=0, subsample=0.8,random_state=42, use_label_encoder=False, eval_metric='mlogloss', num_class=3
)
xgb_best.fit(train_X, train_y_xgb)# 评估(含标签映射还原)
train_pred_encoded = xgb_best.predict(train_X)
train_pred = [reverse_mapping[code] for code in train_pred_encoded]
test_pred_encoded = xgb_best.predict(test_X)
test_pred = [reverse_mapping[code] if code in reverse_mapping else 2 for code in test_pred_encoded]print("XGBoost训练集评估:\n", metrics.classification_report(train_y, train_pred))
print("XGBoost测试集评估:\n", metrics.classification_report(test_y, test_pred))# 结果提取
report = metrics.classification_report(test_y, test_pred, digits=6).split()
XGB_result['recall_0'] = float(report[6])
XGB_result['recall_1'] = float(report[11])
XGB_result['recall_2'] = float(report[16])
XGB_result['recall_3'] = float(report[21])
XGB_result['acc'] = float(report[25])
result_data['XGBoost'] = XGB_result# 保存所有结果
with open('..//temp_data//结果数据[删除空缺行].json', 'w', encoding='utf-8') as f:json.dump(result_data, f, ensure_ascii=False, indent=4)

四、模型评估与结果分析

评估指标

  • 各类别召回率(recall):分别记录 0、1、2、3 四类的召回率
  • 整体准确率(acc):模型整体分类正确率

评估结果

模型召回率_0召回率_1召回率_2召回率_3准确率(acc)
LR1.00.00.00.00.6
RF0.9333330.3333330.01.00.68
GNB0.7333330.3333330.01.00.56
SVM0.8666670.00.01.00.56
AdaBoost0.80.6666670.01.00.68
XGBoost0.9333330.1666670.01.00.64

结果分析

  • 整体表现:随机森林(RF)和 AdaBoost 模型表现最优,准确率均达到 0.68;逻辑回归(LR)和支持向量机(SVM)对标签 1 的识别能力较弱(召回率为 0)
  • 类别特异性:
    • 标签 0:逻辑回归(LR)识别效果最佳(召回率 1.0),XGBoost 和 RF 次之(0.933333)
    • 标签 1:AdaBoost 表现最优(召回率 0.666667),显著高于其他模型
    • 标签 2:所有模型召回率均为 0,主要原因是训练集中无该类别样本(因数据空缺已删除),模型无法学习该类特征
    • 标签 3:RF、GNB、SVM、AdaBoost、XGBoost 均能 100% 识别,说明该类别特征与其他类别区分度较高
  • 泛化能力:由于训练集缺失标签 2 样本,所有模型对该类别均无识别能力,反映了训练数据完整性对模型泛化能力的关键影响

五、开发总结

  • 完成了 6 种分类模型的训练与优化,验证了不同算法在矿物分类任务上的适用性
  • 通过标签映射机制解决了 XGBoost 对非连续标签的处理问题,保证了模型间评估标准的一致性
  • 明确了训练数据空缺对模型性能的影响:标签 2 因训练样本缺失导致所有模型识别失败
  • 确定了 RF 和 AdaBoost 为当前阶段表现最优的模型,为后续系统开发提供了选型依据

后续将进行其他5种预处理生成的数据集(平均值填充、中位数填充、众数填充、线性回归填充、随机森林填充)进行模型训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/93853.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/93853.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

通信方式:命名管道

一、命名管道 1. 命名管道的原理 有了匿名管道,理解命名管道就非常简单了。 对于普通文件而言,两个进程打开同一个文件,OS是不会将文件加载两次的,这两个进程都会指向同一个文件,那么,也就享有同一份 in…

如何将数据库快速接入大模型实现智能问数,实现chatbi、dataagent,只需短短几步,不需要配置工作流!

智能问数系统初始化操作流程 一、系统初始化与管理员账号创建登录与初始化提示:首次访问系统登录页,若系统未初始化,会弹出 “系统未完成初始化,请初始化管理员账号” 提示,点击【去创建】。填写管理员信息&#xff1a…

告别手写文档!Spring Boot API 文档终极解决方案:SpringDoc OpenAPI

在前后端分离和微服务盛行的今天,API 文档是团队协作的“通用语言”。一份清晰、准确、实时同步的文档,能极大提升开发和联调效率。然而,手动编写和维护 API 文档(如 Word、Markdown 或 Postman)是一场永无止境的噩梦—…

N4200EX是一款全智能超声波检测仪产品简析

N4200EX是一款全智能超声波检测仪,适用于石油、石化、天然气、气体生产等行业的压力管路、阀门、设备的各种防爆场合气体泄漏、真空泄漏、阀门内漏检测。●本安防爆设计,防爆、防尘、防水、抗摔。●适应恶劣环境,可在-25℃超低温环境检测&…

NestJS @Inject 装饰器入门教程

一、核心概念解析 1.1 依赖注入(DI)的本质 依赖注入是一种设计模式,通过 IoC(控制反转)容器管理对象生命周期。在 NestJS 中,Injectable() 标记的类会被容器管理,而 Inject() 用于显式指定依赖项…

网络地址详解

子网划分详解:从 IP 地址结构到实际应用 在计算机网络中,子网划分是一项关键的技术,它能帮助我们更高效地管理 IP 地址资源,优化网络性能。要深入理解子网划分,首先需要从 IP 地址的基本结构说起。 一、IPv4 地址的基…

吾日三省吾身 | 周反思 8.19

上周一览总体来说,上个周是一个被项目驱使而险些丧失自主思考能力的危险阶段。相比任何有机械化工作经验的读者都有类似的体验,在手上打螺丝的无尽循环中,自己的脑子就会逐渐丧失对自身的感知以及自主思考的能力。而这个负循环一旦开始&#…

08.19总结

连通性 在无向图中,若任意两点间均存在路径相连,则该图称为连通图。 若删除图中任意一个顶点后,剩余图仍保持连通性,则该图为点双连通图。 若删除图中任意一条边后,图仍保持连通性,则该图为边双连通图。 在…

车e估牵头正式启动乘用车金融价值评估师编制

8月13日,汽车金融行业职业能力评价规范编制启动工作会议在广州圆满落幕。本次会议由中国机械工业联合会机械工业人才评价中心主办,广州穗圣信息科技有限公司(车e估)承办。会议汇聚了众多行业精英,包括中国机械工业联合…

清空 github 仓库的历史提交记录(创建新分支)

想在 现有仓库中创建一个新分支 master,删除原来的 main,然后把 master 重命名为 main,并且清空历史。可以用下面一条完整的命令序列操作: # 1. 创建一个没有历史的新分支 master git checkout --orphan master# 2. 添加当前所有文…

使用B210在Linux下实时处理ETC专用短程通信数据(2)-CPU单核高速数据处理

在上一篇文章中,使用Octave初步验证了ETC车联数据的格式。然而,Octave无法实时处理20M的采样带宽。我们本节通过C语言,重写 Octave程序,实现实时处理,涉及下面三个关键特点。 文章目录1. 全静态内存2. 使用环状缓存3 无…

Spark 运行流程核心组件(二)任务调度

1、调度策略参数默认值说明spark.scheduler.modeFIFO调度策略(FIFO/FAIR)spark.locality.wait3s本地性降级等待时间spark.locality.wait.processspark.locality.waitPROCESS_LOCAL 等待时间spark.locality.wait.nodespark.locality.waitNODE_LOCAL 等待时…

Orbbec---setBoolProperty 快捷配置设备行为

在奥比中光(Orbbec)SDK(通常称为ob库)中,setBoolProperty函数是用于设置设备或传感器的布尔类型属性的核心接口。它主要用于开启/关闭设备的某些功能或模式,是配置设备行为的重要方法。 函数原型与参数解析…

[OWASP]智能体应用安全保障指南

1.关键组件定义 KC1 生成式语言模型(Generative Language Models) KC1.1 大语言模型(LLMs):作为代理的“大脑”,基于预训练基础模型(如 GPT 系列、Claude、Llama、Gemini)&#xff…

【Vivado TCL 教程】从零开始掌握 Xilinx Vivado TCL 脚本编程(三)

【Vivado TCL 教程】从零开始掌握 Xilinx Vivado TCL 脚本编程(三) 系列文章目录 1、VMware Workstation Pro安装指南:详细步骤与配置选项说明 2、VMware 下 Ubuntu 操作系统下载与安装指南 3、基于 Ubuntu 的 Linux 系统中 Vivado 2020.1 下…

AI与大数据驱动下的食堂采购系统源码:供应链管理平台的未来发展

在数字化浪潮不断加速的今天,很多企业和机构都在追求一个目标:如何把“效率”与“成本”做到最佳平衡。对于学校、企事业单位的食堂来说,采购环节就是重中之重。往小了说,它关系到食堂员工的工作体验;往大了说&#xf…

HarmonyOS 实战:学会在鸿蒙中使用第三方 JavaScript 库(附完整 Demo)

摘要 在鸿蒙(HarmonyOS NEXT / ArkTS)开发中,我们大部分业务逻辑和 UI 都是用 ArkTS 写的。不过在做一些数据处理、网络请求、工具函数或者复杂算法时,完全没必要“重复造轮子”。这时候就可以直接引入 JavaScript 的第三方库。鸿…

C++实现教务管理系统,文件操作账户密码登录(附源码)

教务管理系统项目介绍 项目概述 这是一个基于C开发的教务管理系统,提供了学生、教师和系统管理员三种角色的功能模块,实现了教务信息的录入、查询、修改和删除等基本操作。系统采用文件存储方式保存数据,具有简单易用、功能完备的特点。 项…

《C++进阶之STL》【二叉搜索树】

【二叉搜索树】目录前言:------------概念介绍------------1. 什么是二叉搜索树?2. 二叉搜索树的性能怎么样?------------基本操作------------一、查找操作思想步骤简述二、插入操作目标步骤简述三、删除操作目标步骤简述------------代码实现--------…

Orange的运维学习日记--47.Ansible进阶之异步处理

Orange的运维学习日记–47.Ansible进阶之异步处理 文章目录Orange的运维学习日记--47.Ansible进阶之异步处理Playbook 执行顺序原理可选执行策略调整并发连接数:forks 参数查看与修改 forks性能调优建议分批执行全局任务:serial 关键字serial 用法示例应…