用逻辑回归(Logistic Regression)处理鸢尾花(iris)数据集

# 导入必要的库
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (accuracy_score, confusion_matrix,classification_report, ConfusionMatrixDisplay)
from sklearn.preprocessing import StandardScaler# 1. 加载鸢尾花数据集
iris = load_iris()
# 转换为DataFrame方便查看(特征+标签)
iris_df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
iris_df['species'] = [iris.target_names[i] for i in iris.target]  # 添加花名标签# 2. 数据基本信息查看
print("数据集形状:", iris.data.shape)  # 150个样本,4个特征
print("\n特征名称:", iris.feature_names)  # 花萼长度、宽度,花瓣长度、宽度
print("\n类别名称:", iris.target_names)  # 山鸢尾、变色鸢尾、维吉尼亚鸢尾# 3. 数据划分(特征X和标签y)
X = iris.data  # 特征:4个植物学测量值
y = iris.target  # 标签:0,1,2分别对应三种鸢尾花# 划分训练集(80%)和测试集(20%),随机种子确保结果可复现
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y  # stratify=y保持类别比例
)# 4. 特征标准化(逻辑回归对特征尺度敏感,标准化可提升性能)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)  # 训练集拟合并标准化
X_test_scaled = scaler.transform(X_test)  # 测试集使用相同的标准化参数# 5. 训练逻辑回归模型(多分类任务)
model = LogisticRegression(max_iter=200, random_state=42)  # 增加迭代次数确保收敛
model.fit(X_train_scaled, y_train)# 6. 模型预测
y_pred = model.predict(X_test_scaled)  # 测试集预测标签
y_pred_proba = model.predict_proba(X_test_scaled)  # 预测每个类别的概率# 7. 模型评估
print("\n===== 模型评估结果 =====")
print(f"训练集准确率:{model.score(X_train_scaled, y_train):.4f}")
print(f"测试集准确率:{accuracy_score(y_test, y_pred):.4f}")print("\n混淆矩阵:")
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=iris.target_names)
disp.plot(cmap=plt.cm.Blues)
plt.title("混淆矩阵(测试集)")
plt.show()print("\n分类报告:")
print(classification_report(y_test, y_pred, target_names=iris.target_names))# 8. 特征重要性分析(逻辑回归系数)
feature_importance = pd.DataFrame({'特征': iris.feature_names,'系数绝对值': np.abs(model.coef_).mean(axis=0)  # 多分类取各系数的绝对值均值
}).sort_values(by='系数绝对值', ascending=False)print("\n特征重要性(系数绝对值):")
print(feature_importance)# 可视化特征重要性
plt.figure(figsize=(8, 4))
sns.barplot(x='系数绝对值', y='特征', data=feature_importance, palette='coolwarm')
plt.title("特征对分类的重要性")
plt.show()# 9. 新样本预测示例
# 假设一个新的鸢尾花测量数据(花萼长、花萼宽、花瓣长、花瓣宽)
new_sample = np.array([[5.8, 3.0, 4.9, 1.6]])  # 接近变色鸢尾的特征
new_sample_scaled = scaler.transform(new_sample)  # 标准化# 预测结果
predicted_class = model.predict(new_sample_scaled)
predicted_prob = model.predict_proba(new_sample_scaled)print("\n===== 新样本预测 =====")
print(f"预测类别:{iris.target_names[predicted_class[0]]}")
print("各类别概率:")
for i, prob in enumerate(predicted_prob[0]):print(f"{iris.target_names[i]}: {prob:.4f}")

这段代码使用逻辑回归算法对经典的鸢尾花数据集进行分类,是一个完整的机器学习项目流程。

1. 导入必要的库

import numpy as np

import pandas as pd

import matplotlib.pyplot as plt

import seaborn as sns

from sklearn.datasets import load_iris

from sklearn.model_selection import train_test_split

from sklearn.linear_model import LogisticRegression

from sklearn.metrics import (accuracy_score, confusion_matrix,

                             classification_report, ConfusionMatrixDisplay)

from sklearn.preprocessing import StandardScaler

  1. numpy/pandas:用于数据处理(如矩阵运算、表格操作)。
  2. matplotlib/seaborn:用于绘制图表(如混淆矩阵、特征重要性)。
  3. sklearn:机器学习库,提供数据集、模型、评估工具。

2. 加载和查看数据

iris = load_iris()  # 加载内置鸢尾花数据集

iris_df = pd.DataFrame(iris.data, columns=iris.feature_names)

iris_df['species'] = [iris.target_names[i] for i in iris.target]

print("数据集形状:", iris.data.shape)  # (150, 4) → 150个样本,4个特征

print("特征名称:", iris.feature_names)  # 花瓣/花萼的长度、宽度

print("类别名称:", iris.target_names)  # ['setosa' 'versicolor' 'virginica']

  1. 鸢尾花数据集:包含 150 朵花的数据,分为 3 个品种(每个品种 50 朵)。
  2. 4 个特征:花瓣长度、花瓣宽度、花萼长度、花萼宽度(都是厘米)。
  3. 目标:根据这 4 个特征预测花的品种。

3. 数据划分(训练集和测试集)

X = iris.data  # 特征(花瓣/花萼的测量值)

y = iris.target  # 标签(0/1/2对应3个品种)

X_train, X_test, y_train, y_test = train_test_split(

    X, y, test_size=0.2, random_state=42, stratify=y

)

  1. train_test_split:将数据分为 80% 训练集和 20% 测试集。
    1. stratify=y:确保训练集和测试集中 3 个品种的比例相同(避免数据偏斜)。
    2. random_state=42:固定随机种子,确保结果可复现(每次运行划分结果相同)。

4. 特征标准化

scaler = StandardScaler()

X_train_scaled = scaler.fit_transform(X_train)  # 训练集标准化

X_test_scaled = scaler.transform(X_test)  # 测试集用相同参数标准化

  1. 为什么标准化?:逻辑回归对特征尺度敏感(例如,如果某个特征的数值范围很大,会影响模型收敛)。
  2. StandardScaler:将特征转换为均值为 0、标准差为 1 的标准正态分布。
    1. fit_transform:计算训练集的均值 / 标准差,并应用转换。
    2. transform:用训练集的统计参数(均值 / 标准差)转换测试集(不能重新计算)。

5. 训练逻辑回归模型

model = LogisticRegression(max_iter=200, random_state=42)

model.fit(X_train_scaled, y_train)

  1. LogisticRegression:逻辑回归是分类算法(尽管名字带 “回归”)。
    1. max_iter=200:增加最大迭代次数,确保模型收敛(默认 100 可能不够)。
  2. fit:用训练数据学习模型参数(找到最佳分类边界)。

6. 模型预测

y_pred = model.predict(X_test_scaled)  # 预测类别(0/1/2

y_pred_proba = model.predict_proba(X_test_scaled)  # 预测每个类别的概率

  1. predict:直接输出预测的类别(例如 1 代表 versicolor)。
  2. predict_proba:输出样本属于每个类别的概率(例如 [0.01, 0.95, 0.04] 表示 95% 概率是第二类)。

7. 模型评估

print(f"训练集准确率:{model.score(X_train_scaled, y_train):.4f}")

print(f"测试集准确率:{accuracy_score(y_test, y_pred):.4f}")

  1. 准确率(Accuracy:预测正确的样本比例。
    1. 训练集准确率:约 0.99(模型对训练数据的拟合程度)。
    2. 测试集准确率:约 0.97(模型对新数据的泛化能力)。
混淆矩阵(Confusion Matrix)

cm = confusion_matrix(y_test, y_pred)

disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=iris.target_names)

disp.plot()

  1. 混淆矩阵:可视化分类结果,对角线表示预测正确的样本数。
    1. 例如:预测 setosa(0)的样本全部分类正确;有 1 个 versicolor(1)被误分类为 virginica(2)。
分类报告(Classification Report)

print(classification_report(y_test, y_pred, target_names=iris.target_names))

  1. 精确率(Precision:预测为某类的样本中,实际属于该类的比例。
  2. 召回率(Recall:实际属于某类的样本中,被正确预测的比例。
  3. F1 分数(F1-score:精确率和召回率的调和平均。

8. 特征重要性分析

feature_importance = pd.DataFrame({

    '特征': iris.feature_names,

    '系数绝对值': np.abs(model.coef_).mean(axis=0)

}).sort_values('系数绝对值', ascending=False)

  1. 逻辑回归系数:系数绝对值越大,说明该特征对分类的影响越大。
    1. 通常petal width(花瓣宽度)和petal length(花瓣长度)对分类最重要。

9. 新样本预测示例

new_sample = np.array([[5.8, 3.0, 4.9, 1.6]])  # 手动构造一个样本

new_sample_scaled = scaler.transform(new_sample)  # 标准化

predicted_class = model.predict(new_sample_scaled)  # 预测类别

predicted_prob = model.predict_proba(new_sample_scaled)  # 预测概率

  1. 预测结果:输出新样本的预测类别和概率(例如 95% 概率是 versicolor)。

总结

这个代码展示了一个完整的机器学习流程:

  1. 数据准备:加载数据、划分训练集 / 测试集。
  2. 特征工程:标准化特征,避免量纲影响。
  3. 模型训练:用逻辑回归学习分类规则。
  4. 模型评估:用准确率、混淆矩阵等指标衡量性能。
  5. 预测应用:对新样本进行分类。

鸢尾花数据集是机器学习的 “Hello World”,适合入门。逻辑回归是简单但强大的分类算法,尤其适合特征与类别之间存在线性关系的场景。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/91839.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/91839.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

华大北斗TAU1201-1216A00高精度双频GNSS定位模块 自动驾驶专用

在万物互联的时代,您还在为定位不准、信号丢失而烦恼吗?TAU1201-1216A00华大北斗高精度定位模块TAU1201是一款高性能的双频GNSS定位模块,搭载了华大北斗的CYNOSURE III GNSS SoC 芯片,该模块支持新一代北斗三号信号体制&#xff0…

坚持继续布局32位MCU,进一步完善产品阵容,96Mhz主频CW32L012新品发布!

在全球MCU市场竞争加剧、国产替代加速的背景下,嵌入式设备对核心控制芯片的性能、功耗、可靠性及性价比提出了前所未有的严苛需求。为适应市场竞争,2025年7月16日,武汉芯源半导体正式推出基于CW32L01x系列低功耗微控制器家族的全新成员&#…

用线性代数推导码分多址(CDMA)

什么是码分多址 码分多址:CDMA允许多个用户同时、在同一频率上传输数据。它通过给每个用户分配唯一的、相互正交的二进制序列来实现区分。用户的数据比特被这个码片序列扩展成一个高速率的信号,然后在接收端通过相同的码片序列进行相关运算来回复原数据 …

mac 配置svn

1.查看brew的版本:brew install subversion2.安装brew命令:bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"3.把路径添加到path环境变量:echo export PATH"/opt/homebrew/b…

使用 .NET Core 的原始 WebSocket

在 Web 开发中,后端存在一些值得注意的通信协议,用于将更改通知给已连接的客户端。所有这些协议都用于处理同一件事。但鲜为人知的协议很少,鲜为人知的协议也很少。今天,将讨论 WebSocket,它在开发中使用最少&#xff…

编程实现Word自动排版:从理论到实践的全面指南

在现代办公环境中,文档排版是一项常见但耗时的工作。特别是对于需要处理大量文档的专业人士来说,手动排版不仅费时费力,还容易出现不一致的问题。本文将深入探讨如何通过编程方式实现Word文档的自动排版,从理论基础到实际应用&…

力扣经典算法篇-25-删除链表的倒数第 N 个结点(计算链表的长度,利用栈先进后出特性,双指针法)

1、题干 给你一个链表,删除链表的倒数第 n 个结点,并且返回链表的头结点。 示例 1:输入:head [1,2,3,4,5], n 2 输出:[1,2,3,5] 示例 2: 输入:head [1], n 1 输出:[] 示例 3&…

VIT速览

当我们取到一张图片,我们会把它划分为一个个patch,如上图把一张图片划分为了9个patch,然后通过一个embedding把他们转换成一个个token,每个patch对应一个token,然后在输入到transformer encoder之前还要经过一个class …

【服务器与部署 14】消息队列部署:RabbitMQ、Kafka生产环境搭建指南

【服务器与部署 14】消息队列部署:RabbitMQ、Kafka生产环境搭建指南 关键词:消息队列、RabbitMQ集群、Kafka集群、消息中间件、异步通信、微服务架构、高可用部署、消息持久化、生产环境配置、分布式系统 摘要:本文从实际业务场景出发&#x…

LeetCode中等题--167.两数之和II-输入有序数组

1. 题目 给你一个下标从 1 开始的整数数组 numbers &#xff0c;该数组已按 非递减顺序排列 &#xff0c;请你从数组中找出满足相加之和等于目标数 target 的两个数。如果设这两个数分别是 numbers[index1] 和 numbers[index2] &#xff0c;则 1 < index1 < index2 <…

【C# in .NET】19. 探秘抽象类:具体实现与抽象契约的桥梁

探秘抽象类:具体实现与抽象契约的桥梁 在.NET类型系统中,抽象类是连接具体实现与抽象契约的关键桥梁,它既具备普通类的状态承载能力,又拥有类似接口的行为约束特性。本文将从 IL 代码结构、CLR 类型加载机制、方法调度逻辑三个维度,全面揭示抽象类的底层工作原理,通过与…

Apache RocketMQ + “太乙” = 开源贡献新体验

Apache RocketMQ 是 Apache 基金会托管的顶级项目&#xff0c;自 2012 年诞生于阿里巴巴&#xff0c;服务于淘宝等核心交易系统&#xff0c;历经多次双十一万亿级数据洪峰稳定性验证&#xff0c;至今已有十余年发展历程。RocketMQ 致力于构建低延迟、高并发、高可用、高可靠的分…

永磁同步电机控制算法--弱磁控制(变交轴CCR-VQV)

一、原理介绍CCR-FQV弱磁控制不能较好的利用逆变器的直流侧电压&#xff0c;造成电机的调速范围窄、效率低和带载能力差。为了解决CCR-FQV弱磁控制存在的缺陷&#xff0c;可以在电机运行过程中根据工况的不同实时的改变交轴电压给定uq的值&#xff0c;实施 CCR-VQV弱磁控制。…

达梦数据守护集群搭建(1主1实时备库1同步备库1异步备库)

目录 1 环境信息 1.1 目录信息 1.2 其他环境信息 2 环境准备 2.1 新建dmdba用户 2.2 关闭防火墙 2.3 关闭Selinux 2.4 关闭numa和透明大页 2.5 修改文件打开最大数 2.6 修改磁盘调度 2.7 修改cpufreq模式 2.8 信号量修改 2.9 修改sysctl.conf 2.10 修改 /etc/sy…

电感与电容充、放电极性判断和电感选型

目录 一、电感 二、电容 三、电感选型 一、电感 充电&#xff1a;左右-为例 放电&#xff1a;极性相反&#xff0c;左-右 二、电容 充电&#xff1a;左右-为例 放电&#xff1a;左右-&#xff08;与充电极性一致&#xff09; 三、电感选型 主要考虑额定电流和饱和电流。…

新建模范式Mamba——“Selectivity is All You Need?”

目录 一、快速走进和理解Mamba建模架构 &#xff08;一&#xff09;从Transformer的统治地位谈起 &#xff08;二&#xff09;另一条道路&#xff1a;结构化状态空间模型&#xff08;SSM&#xff09; &#xff08;三&#xff09;Mamba 的核心创新&#xff1a;Selective SSM…

Python实现Word文档中图片的自动提取与加载:从理论到实践

在现代办公和文档处理中&#xff0c;Word文档已经成为最常用的文件格式之一。这些文档不仅包含文本内容&#xff0c;还经常嵌入各种图片、图表和其他媒体元素。在许多场景下&#xff0c;我们需要从Word文档中提取这些图片&#xff0c;例如进行内容分析、创建图像数据库、或者在…

Kafka、RabbitMQ 与 RocketMQ 高可靠消息保障方案对比分析

Kafka、RabbitMQ 与 RocketMQ 高可靠消息保障方案对比分析 在分布式系统中&#xff0c;消息队列承担着异步解耦、流量削峰、削峰填谷等重要职责。为了保证应用的数据一致性和业务可靠性&#xff0c;各大消息中间件都提供了多种高可靠消息保障机制。本文以Kafka、RabbitMQ和Rock…

四足机器人远程视频与互动控制的全链路方案

随着机器人行业的快速发展&#xff0c;特别是四足仿生机器人在巡检、探测、安防、救援等复杂环境中的广泛部署&#xff0c;如何实现高质量、低延迟的远程视频监控与人机互动控制&#xff0c;已经成为制约其应用落地与规模化推广的关键技术难题。 四足机器人常常面临以下挑战&a…

把leetcode官方题解自己简单解释一下

自用自用&#xff01;&#xff01;&#xff01;leetcode hot 100