《sklearn机器学习》——交叉验证迭代器

sklearn 交叉验证迭代器

scikit-learn (sklearn) 中,交叉验证迭代器(Cross-Validation Iterators)是一组用于生成训练集和验证集索引的工具。它们是 model_selection 模块的核心组件,决定了数据如何被分割,从而支持模型评估、超参数调优等任务。

这些迭代器实现了不同的数据划分策略,以适应各种数据类型和问题场景。下面详细介绍 sklearn 中主要的交叉验证迭代器。


一、核心概念

所有交叉验证迭代器都遵循相同的接口:

  • 输入:数据集大小 n_samples
  • 输出:一个生成器(generator),每次迭代返回一对 (train_indices, test_indices) 的 NumPy 数组。
  • 用途:可用于 cross_val_score, GridSearchCV 等函数的 cv 参数。

二、主要交叉验证迭代器

1. KFold - 标准 K 折交叉验证

用途:最基础的 K 折 CV,适用于类别均衡的分类或回归问题。

工作方式

  • 将数据集划分为 k 个大小基本相等的折(folds)。
  • 每次使用其中 1 折作为验证集,其余 k-1 折作为训练集。
  • 重复 k 次,确保每折都恰好被用作一次验证集。

参数

  • n_splits:折数,默认为 5。
  • shuffle:是否在划分前打乱数据顺序。建议设为 True,除非数据有时间顺序。
  • random_state:随机种子,确保结果可复现。

代码示例

from sklearn.model_selection import KFold
import numpy as npX = np.array([[1], [2], [3], [4], [5]])
y = np.array([1, 2, 3, 4, 5])kf = KFold(n_splits=3, shuffle=True, random_state=42)
for train_index, test_index in kf.split(X):print("TRAIN:", train_index, "TEST:", test_index)

2. StratifiedKFold - 分层 K 折交叉验证

用途:分类任务的首选,尤其当类别分布不均衡时。

工作方式

与 KFold 类似,但确保每一折中各类别的比例与原始数据集大致相同。
避免某些折中某个类别样本过少或缺失,导致评估偏差。
为什么重要?

例如:一个二分类数据集中正类占 10%。使用普通 KFold 可能在某折中正类样本极少,导致模型无法学习或评估失真。
StratifiedKFold 保证每折中正类比例都接近 10%。
代码示例:

python
深色版本
from sklearn.model_selection import StratifiedKFoldy = np.array([0, 0, 0, 1, 1])  # 不均衡数据skf = StratifiedKFold(n_splits=2, shuffle=True, random_state=42)
for train_index, test_index in skf.split(X, y):print("TRAIN:", train_index, "TEST:", test_index)print("Y_TRAIN:", y[train_index], "Y_TEST:", y[test_index])

3. LeaveOneOut (LOO) - 留一法交叉验证

用途:样本量非常小(如 < 100)时使用。

工作方式

每次留出一个样本作为验证集,其余所有样本作为训练集。
重复 n_samples 次。
优缺点:

✅ 几乎无偏估计(训练集最大)。
❌ 计算成本极高(训练 n 次),且方差可能很大(单个样本影响大)。
代码示例:

python
深色版本
from sklearn.model_selection import LeaveOneOutloo = LeaveOneOut()
for train_index, test_index in loo.split(X):print("TRAIN:", train_index, "TEST:", test_index)

4. LeavePOut - 留 P 法交叉验证

用途:比 LOO 更一般化,但计算更昂贵。

工作方式

每次留出 p 个样本作为验证集,其余所有样本作为训练集。
所有可能的 p 个样本组合都会被尝试,因此总次数为 C(n, p)。
p=1 时退化为 LOO。
注意:当 n 或 p 稍大时,组合数爆炸,极少在实际中使用。

5. ShuffleSplit - 随机划分分割

用途:灵活的随机抽样 CV,适合大数据集或需要控制训练/验证比例时。

工作方式

不强制使用所有样本。
每次迭代从数据中随机抽取指定比例作为训练集,其余作为验证集(可重叠)。
可指定迭代次数 n_splits。
参数:

n_splits:迭代次数。
train_size, test_size:训练/验证集比例。
优点:

可独立控制训练集大小。
适用于大数据,无需完整 K 折。
代码示例:

python
深色版本
from sklearn.model_selection import ShuffleSplitss = ShuffleSplit(n_splits=3, test_size=0.25, random_state=0)
for train_index, test_index in ss.split(X):print("TRAIN:", train_index, "TEST:", test_index)

6. StratifiedShuffleSplit - 分层随机划分

用途:ShuffleSplit 的分层版本,用于类别不均衡的分类任务。

工作方式

在每次随机划分时,保持训练集和验证集中各类别的比例一致。
适用场景:

大数据集上的分层 CV。
需要固定验证集大小且保持类别平衡。

7. GroupKFold - 组 K 折交叉验证

用途:当数据中存在组结构(如:同一用户多次记录、同一病人多个样本),需确保同一组的数据不同时出现在训练和验证集中,防止数据泄露。

工作方式

根据 groups 数组划分,确保一个组的所有样本要么全在训练集,要么全在验证集。
参数:

groups:长度为 n_samples 的数组,表示每个样本所属的组。
代码示例:

python
深色版本
from sklearn.model_selection import GroupKFoldX = [0.1, 0.2, 2.2, 2.4, 2.3, 4.5, 5.7, 5.8]
y = [1, 1, 0, 0, 0, 1, 1, 1]
groups = [1, 1, 2, 2, 2, 3, 3, 3]  # 3 个组gkf = GroupKFold(n_splits=3)
for train_index, test_index in gkf.split(X, y, groups):print("TRAIN:", train_index, "TEST:", test_index)print("GROUPS:", groups[test_index])

8. TimeSeriesSplit - 时间序列交叉验证

用途:处理时间序列数据,确保不使用未来数据预测过去。

工作方式

按时间顺序划分。
每次迭代,训练集是过去的数据,验证集是接下来的一段数据。
训练集逐渐增长(“前滚”交叉验证)。
关键特性:

不打乱数据。
验证集始终在训练集之后。
代码示例:

python
深色版本
from sklearn.model_selection import TimeSeriesSplittscv = TimeSeriesSplit(n_splits=3)
for train_index, test_index in tscv.split(X):print("TRAIN:", train_index, "TEST:", test_index)

输出:

深色版本
TRAIN: [0 1 2] TEST: [3]
TRAIN: [0 1 2 3] TEST: [4]

三、如何选择合适的 CV 迭代器?

场景 推荐迭代器
一般分类(类别均衡) KFold
分类(类别不均衡) ✅ StratifiedKFold
回归任务 KFold 或 ShuffleSplit
小样本数据 LeaveOneOut(谨慎使用)
大数据,灵活划分 ShuffleSplit, StratifiedShuffleSplit
数据有组结构(避免泄露) GroupKFold, LeaveOneGroupOut
时间序列数据 ✅ TimeSeriesSplit
需要分层 + 随机划分 StratifiedShuffleSplit

四、使用建议

默认选择:
分类:StratifiedKFold
回归:KFold
设置 shuffle=True:除非数据有序(如时间序列),否则建议打乱。
固定 random_state:确保实验可复现。
避免数据泄露:在使用 CV 时,任何数据预处理(如标准化、填充)都应在 CV 循环内部进行(使用 Pipeline)。

python
深色版本
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScalerpipe = Pipeline([('scaler', StandardScaler()),('clf', SVC())
])

在 cross_val_score 中使用 pipe,确保 scaler 只在训练集上拟合

cross_val_score(pipe, X, y, cv=5)

总结

sklearn 的交叉验证迭代器提供了丰富且灵活的工具,能够适应从标准分类到时间序列、组数据等各种复杂场景。选择合适的迭代器是获得可靠、无偏模型评估的关键第一步。务必根据数据的结构和任务类型,选择最匹配的 CV 策略。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/95263.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/95263.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Trae+Chrome MCP Server 让AI接管你的浏览器

一、核心优势1、无缝集成现有浏览器环境直接复用用户已打开的 Chrome 浏览器&#xff0c;保留所有登录状态、书签、扩展及历史记录&#xff0c;无需重新登录或配置环境。对比传统工具&#xff08;如 Playwright&#xff09;需独立启动浏览器进程且无法保留用户环境&#xff0c;…

Shell 编程 —— 正则表达式与文本处理器

目录 一. 正则表达式 1.1 定义 1.2 用途 1.3 Linux 正则表达式分类 1.4 正则表达式组成 &#xff08;1&#xff09;普通字符 &#xff08;2&#xff09;元字符&#xff1a;规则的核心载体 &#xff08;3&#xff09; 重复次数 &#xff08;4&#xff09;两类正则的核心…

Springboot 监控篇

在 Spring Boot 中实现 JVM 在线监控&#xff08;包括线程曲线、内存使用、GC 情况等&#xff09;&#xff0c;最常用的方案是结合 Spring Boot Actuator Micrometer 监控可视化工具&#xff08;如 Grafana、Prometheus&#xff09;。以下是完整实现方案&#xff1a; 一、核…

Java 大视界 --Java 大数据在智能教育学习资源整合与知识图谱构建中的深度应用(406)

Java 大视界 --Java 大数据在智能教育学习资源整合与知识图谱构建中的深度应用&#xff08;406&#xff09;引言&#xff1a;正文&#xff1a;一、智能教育的两大核心痛点与 Java 大数据的适配性1.1 资源整合&#xff1a;42% 重复率背后的 “三大堵点”1.2 知识图谱&#xff1a…

2025年新版C语言 模电数电及51单片机Proteus嵌入式开发入门实战系统学习,一整套全齐了再也不用东拼西凑

最近有同学说想系统学习嵌入式&#xff0c;问我有没有系统学习的路线推荐。刚入门的同学可能不知道如何下手&#xff0c;这里一站式安排上。先说下学习的顺序&#xff0c;先学习C语言&#xff0c;接着学习模电数电&#xff08;即模拟电路和数字电路&#xff09;最后学习51单片机…

Android的USB通信 (AOA Android开放配件协议)

USB 主机和配件概览Android 通过 USB 配件和 USB 主机两种模式支持各种 USB 外围设备和 Android USB 配件&#xff08;实现 Android 配件协议的硬件&#xff09;。在 USB 配件模式下&#xff0c;外部 USB 硬件充当 USB 主机。配件示例可能包括机器人控制器、扩展坞、诊断和音乐…

人工智能视频画质增强和修复软件Topaz Video AI v7.1.1最新汉化,自带星光模型

软件介绍 这是一款专业的视频修复工具-topaz video ai&#xff0c;该版本是解压即可使用&#xff0c;自带汉化&#xff0c;免登陆无输出水印。 软件特点 不登录不注册解压即可使用无水印输出视频画质提升 软件使用 选择我们需要提升画质的视频即可 软件下载 夸克 其他网盘…

LeetCode 777.在LR字符串中交换相邻字符

在一个由 ‘L’ , ‘R’ 和 ‘X’ 三个字符组成的字符串&#xff08;例如"RXXLRXRXL"&#xff09;中进行移动操作。一次移动操作指用一个 “LX” 替换一个 “XL”&#xff0c;或者用一个 “XR” 替换一个 “RX”。现给定起始字符串 start 和结束字符串 result&#x…

RK-Android15-WIFI白名单功能实现

实现WIFI白名单功能 。 三个模式: 1、默认模式:允许搜索所有的WIFI显示、搜索出来 ; 2、禁用模式:允许所有WIFI显示,能够搜索出来 ;3、白名单模式:允许指定WIFI名单显示,被搜索出来 文章目录 前言-需求 一、参考资料 二、核心修改文件和实现方式 1、修改文件 疑问思考 …

Maven + JUnit:Java单元测试的坚实组合

Maven JUnit&#xff1a;Java单元测试的坚实组合Maven JUnit&#xff1a;Java单元测试的坚实组合一、什么是软件测试&#xff1f;二、测试的维度&#xff1a;阶段与方法&#xff08;一&#xff09;测试的四大阶段&#xff08;二&#xff09;测试的三大方法三、main方法测试与…

FFMPEG 10BIT下 Intel b570 qsv 硬解AV1,H265视频编码测试

上10bitffmpeg 8.0 b570最新驱动 &#xff0c;CPU 12100F 显卡 Intel b570 ffmpeg -hwaccel_output_format qsv -i "XXX.mkv" -vf "formatp010le" -c:v hevc_qsv -global_quality 19 -quality best -rc_mode ICQ -preset veryslow -g 120 -refs 5 -b…

SQL分类详解:掌握DQL、DML、DDL等数据库语言类型

如果你是一名数据库运维工程师&#xff0c;或者正在学习数据库技术&#xff0c;那么理解SQL的不同类型是非常重要的。让我们一起看看SQL到底有哪些种类&#xff0c;以及它们各自的作用。 1. 什么是SQL&#xff1f; SQL&#xff08;Structured Query Language&#xff09;是一种…

[特殊字符] 预告!我正在开发一款让自动化操作变得「像呼吸一样自然」的AI神器

各位技术爱好者和创作者朋友们&#xff0c;我要解决一个行业痛点&#xff01;在上一个项目中&#xff08;&#x1f525; 重磅预告&#xff01;我要用AI开发一个自媒体神器&#xff0c;彻底解决创作者的7大痛点&#xff01;&#xff09;&#xff0c;我本来雄心勃勃地打算直接用R…

加密软件哪个好用?加密软件-为数据共享提供安全保障

企业与合作伙伴协作时需共享大量数据&#xff0c;若缺乏保护&#xff0c;数据可能被非法获取&#xff0c;影响合作信任&#xff0c;甚至引发商业纠纷。加密软件可确保共享数据仅授权方可见&#xff0c;为数据共享提供安全保障&#xff0c;推动合作顺利开展。​1.固信软件固信加…

FPGA复位

1:能不复位尽量不要复位&#xff0c;减少逻辑扇出数&#xff1a;比如打拍信号。2:xilinx的FPGA推荐高复位&#xff0c;ATERAL的FPGA推荐低复位。3:尽量使用异步复位&#xff1a;大多数厂商目标库内的触发器都只有异步复位端口&#xff0c;采用同步复位需消耗较多逻辑资源。一&a…

Cursor 教我学 Python

文章目录1. 写在最前面2. Python 语法2.1 yield2.1.1 yield 和 return 的区别2.1.2 golang 中实现 yield 语法3. aiohttp 库3.1 原始写法3.2 修改写法3.2 耗时对比分析4. 碎碎念5. 参考资料1. 写在最前面 最近加了很多 Python Coding 的任务&#xff0c;虽然在 AI 加持下能够顺…

Ollama:本地大语言模型部署和使用详解

1.什么是Ollama&#xff1f; Ollama是一个开源的大语言模型管理工具&#xff0c;具有以下特点&#xff1a; 简单易用&#xff1a;提供简单的命令行接口本地部署&#xff1a;模型运行在本地&#xff0c;保护数据隐私跨平台支持&#xff1a;支持Windows、macOS、Linux丰富的模型…

云计算学习100天-第41天 -普罗米修斯2

目录 五、添加被监控端 1、在web1[192.168.88.100]上部署node exporter 2、在Prometheus服务器上添加监控节点 3、浏览器查看添加结果 六、Grafana的部署 概述 部署步骤 七、监控MySQL数据库 1、配置MySQL 2、配置mysql exporter 3、配置prometheus监控mysql 五、添…

集成电路学习:什么是SVM支持向量机

SVM:支持向量机 SVM,即支持向量机(Support Vector Machine),是一种常用的机器学习算法,特别适用于分类和回归问题。以下是对SVM的详细解析: 一、SVM的基本原理 SVM的基本思想是在特征空间中寻找一个最优的超平面,使得不同类别的样本能够被最大化地分开。这个最优…

盲盒抽谷机小程序开发:如何用3D技术重构沉浸式体验?

在盲盒经济中&#xff0c;“沉浸感”是提升用户停留时长与转化率的核心武器。某品牌通过3D扭蛋机旋转、卡牌翻转特效&#xff0c;使用户停留时长从15秒延长至45秒&#xff0c;转化率提升25%&#xff1b;另一品牌上线AR试戴功能后&#xff0c;单次抽谷时长延长至2分钟&#xff0…