【机器学习笔记 Ⅱ】7 多类分类

1. 多类分类(Multi-class Classification)

定义

多类分类是指目标变量(标签)有超过两个类别的分类任务。例如:

  • 手写数字识别:10个类别(0~9)。
  • 图像分类:区分猫、狗、鸟等。
  • 新闻主题分类:政治、经济、体育等。

特点

  • 互斥性:每个样本仅属于一个类别(区别于多标签分类)。
  • 输出要求:模型需输出每个类别的概率分布,且概率之和为1。

实现方式

  • One-vs-Rest (OvR):训练K个二分类器(K为类别数),每个分类器判断“是否属于该类”。
  • Softmax回归:直接输出多类概率分布(更高效,主流方法)。

2. Softmax函数

定义

Softmax将神经网络的原始输出(logits)转换为概率分布,公式为:
[
\sigma(\mathbf{z})_i = \frac{e{z_i}}{\sum_{j=1}K e^{z_j}}}, \quad i = 1, \dots, K
]

  • ( \mathbf{z} ):神经网络的原始输出向量(logits)。
  • ( K ):类别总数。
  • 输出:每个类别的概率 ( \sigma(\mathbf{z})i \in (0,1) ),且 ( \sum{i=1}^K \sigma(\mathbf{z})_i = 1 )。

示例

假设神经网络对3个类别的原始输出为 ( \mathbf{z} = [2.0, 1.0, 0.1] ):
[
\begin{aligned}
\sigma(\mathbf{z})_1 &= \frac{e{2.0}}{e{2.0} + e^{1.0} + e^{0.1}}} \approx 0.659 \
\sigma(\mathbf{z})_2 &= \frac{e{1.0}}{e{2.0} + e^{1.0} + e^{0.1}}} \approx 0.242 \
\sigma(\mathbf{z})_3 &= \frac{e{0.1}}{e{2.0} + e^{1.0} + e^{0.1}}} \approx 0.099 \
\end{aligned}
]
最终概率分布:[0.659, 0.242, 0.099] → 预测为第1类。

特性

  1. 放大差异:较大的 ( z_i ) 会获得显著更高的概率。
  2. 数值稳定:实际计算时,通常减去最大值(( z_i - \max(\mathbf{z}) ))避免数值溢出。
    exp_z = np.exp(z - np.max(z, axis=1, keepdims=True))
    softmax = exp_z / np.sum(exp_z, axis=1, keepdims=True)
    

3. Softmax与交叉熵损失

组合使用

  • Softmax:将logits转为概率。
  • 交叉熵损失(Cross-Entropy Loss):衡量预测概率与真实标签的差异。

反向传播

Softmax与交叉熵的梯度计算被合并优化,梯度形式简洁:
[
\frac{\partial J}{\partial z_i} = \hat{y}_i - y_i
]

  • 梯度直接反映预测与真实的差异。

4. 代码实现

(1) Python(NumPy)

import numpy as npdef softmax(z):exp_z = np.exp(z - np.max(z, axis=1, keepdims=True))  # 防溢出return exp_z / np.sum(exp_z, axis=1, keepdims=True)# 示例:3个样本,4个类别
logits = np.array([[1.0, 2.0, 3.0, 4.0],[0.5, 1.0, 2.0, 3.0],[-1.0, 0.0, 1.0, 2.0]])
probabilities = softmax(logits)
print("概率分布:\n", probabilities)

(2) PyTorch

import torch
import torch.nn as nnlogits = torch.tensor([[1.0, 2.0, 3.0], [0.1, 0.2, 0.3]])
softmax = nn.Softmax(dim=1)
probabilities = softmax(logits)
print(probabilities)

(3) TensorFlow/Keras

from tensorflow.keras.layers import Softmaxlogits = tf.constant([[1.0, 2.0, 3.0], [0.1, 0.2, 0.3]])
softmax = Softmax(axis=-1)
probabilities = softmax(logits)
print(probabilities.numpy())

5. 多类分类 vs 二分类

任务类型输出层激活函数损失函数标签格式
二分类Sigmoid二元交叉熵0或1
多类分类Softmax分类交叉熵One-hot编码

6. 常见问题

Q1:为什么Softmax输出概率和为1?

  • 设计目的即为生成概率分布,便于直观解释和优化。

Q2:Softmax和Sigmoid的区别?

  • Sigmoid:用于二分类,独立计算每个类别的概率(可非互斥)。
  • Softmax:用于多分类,强制所有类别概率和为1(互斥)。

Q3:如何处理类别不平衡?

  • 加权交叉熵:为少数类赋予更高权重。
    model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy'],class_weight={0: 1, 1: 2, 2: 1})  # 类别1的权重加倍
    

7. 总结

  • 多类分类:目标变量有多个互斥类别,需输出概率分布。

  • Softmax:将logits转换为归一化概率,与交叉熵损失配合使用。

  • 关键公式

  • 实践建议

    • 输出层用Softmax,隐藏层用ReLU。
    • 标签需为one-hot编码(或稀疏类别标签)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/87943.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/87943.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2025年深圳杉川机器人性格测评和Verify测评SHL题库高分攻略

1、杉川机器人包含性格测评和Verify测评,预计用时60min,请确保作答时周围环境无干扰、网络畅通;2、请使用电脑完成作答,建议使用以下浏览器登录:IE9.0及以上版本,火狐,谷歌;3、杉川机…

【flutter 在最新版本中已经弃用了传统的 apply from 方式引入 Gradle 插件】

报错 Flutter assets will be downloaded from https://storage.flutter-io.cn. Make sure you trust this source! Launching lib\main.dart on 2112123AC in debug mode... Running Gradle task assembleDebug...FAILURE: Build failed with an exception.* Where: Script D…

Web后端实战:(部门管理)

1.准备工作 1.1开发规范 1.1.1前后端分离开发 我们目前基本都是采用的前后台分离开发方式,如下图所示: 那么基于前后台分离开发的模式下,我们后台开发者开发一个功能的具体流程如何呢?如下图所示: 需求分析&…

字节寻址(Byte Addressing) 与 Verilog中的寄存器索引

字节寻址(Byte Addressing) 与 Verilog中的寄存器索引 之间的关系。 您的疑问非常正确,直接看 3h1 很容易让人以为地址就是 0x01。 但答案是:是的,3h1 在这里对应的字节地址(Byte Address)确实是…

Ubuntu远程桌面

方法1: 检查并使用已安装的VNC或远程桌面组件 请在终端中执行以下命令检查系统中已安装的相关组件: bash# 检查系统中已安装的VNC和远程桌面相关软件 dpkg -l | grep -E "vnc|vino|remote|rdp"# 检查常见远程桌面服务 which vino-server tigervncserver x11vnc xr…

WEB攻防-文件包含LFIRFI伪协议编码算法无文件利用黑白盒

知识点: 1、文件包含-原理&分类&危害-LFI&RFI 2、文件包含-利用-黑白盒&无文件&伪协议 一、演示案例-文件包含-原理&分类&利用 1、原理 程序开发人员通常会把可重复使用的函数写到单个文件中,在使用某些函数时&#xff…

LabVIEW的GPIB仪器校准

基于LabVIEW开发平台与 GPIB 总线技术,采用是德科技、泰克等硬件设备,构建示波器与频谱分析仪自动校准系统。通过图形化编程实现校准流程自动化,涵盖设备连接、参数配置、数据采集、误差分析及报告生成,显著提升校准效率与精度&am…

Zotero中进行文献翻译【Windows11】

zotero官网:https://www.zotero.org/ 1 在Zotero软件中安装插件 进入Zotero百科全书,依次点击:插件→翻译插件→插件介绍→Zotero 中文社区插件商店 进去后搜索pdf2zh,然后下载后放入空白文件夹zotero-pdf2zh 打开Zotero软件后…

用U盘启动制作centos系统最常见报错,系统卡住无法继续问题(手把手)

一、按照操作系统centos7报错如下(U盘) 按照系统报错如下: ERROR: could not insert ‘floppy’; ERROR: could not insert ‘edd’ : No这种报错很常见,基本上就是u盘启动路径找不到导致,遇到次数比较多所以也比较好解…

C#中的BindingList有什么作用?

在C#中&#xff0c;BindingList<T> 是一个非常重要的集合类&#xff0c;位于 System.ComponentModel 命名空间&#xff0c;主要用于实现数据绑定&#xff08;Data Binding&#xff09;场景。1. 核心作用BindingList<T> 是 List<T> 的增强版&#xff0c;主要提…

Python爬取知乎评论:多线程与异步爬虫的性能优化

1. 知乎评论爬取的技术挑战 知乎的评论数据通常采用动态加载&#xff08;Ajax&#xff09;&#xff0c;这意味着直接使用**<font style"color:rgb(64, 64, 64);background-color:rgb(236, 236, 236);">requests</font>****<font style"color:rg…

软件系统测试的基本流程

软件系统测试流程是确保软件质量的规范化过程&#xff0c;涵盖从测试准备到最终上线评估的全周期&#xff0c;通常分为以下6个核心阶段&#xff0c;各阶段紧密衔接、形成闭环&#xff1a; 一、测试启动与规划阶段 核心目标&#xff1a;明确“测什么、谁来测、怎么测”&#xff…

使用Linux操作MySQL数据库分批导出数据为.SQL文件

当数据库某张数据量非常大的表进行其他操作&#xff0c;需要先进行导出时&#xff1b; 先用linux进入操作环境&#xff0c; 1.添加一个export_mysql_batches.sh脚本文件&#xff0c; #!/bin/bash# 数据库连接配置 DB_HOST"36.33.0.138:3306" DB_USER"devuser&qu…

LeetCode 算法题解:链表与二叉树相关问题 打打卡

LeetCode 算法题解&#xff1a;链表与二叉树相关问题 在算法学习和实践中&#xff0c;LeetCode 是一个非常好的平台&#xff0c;它包含了各种各样的算法题目&#xff0c;有助于我们提升编程能力和解决问题的能力。本文将详细讲解在 leetcoding.cpp 文件中实现的一些链表和二叉树…

故宫票价监控接口分析(一)

故宫票价监控接口分析(一) 对爬虫、逆向感兴趣的同学可以查看文章,一对一小班教学(系统理论和实战教程)、提供接单兼职渠道:https://blog.csdn.net/weixin_35770067/article/details/142514698 本文内容仅供学习和参考之用,不得用于商业目的。作者对文中内容的准确性、完整…

AWS OpenSearch Dev Tools使用

# 创建通用索引模版 PUT _template/aws-waf_logs_template {"index_patterns": ["aws-waf-logs-*"],"mappings": {"properties": {"timestamp": {"type": "date"}}} }# 设置单个索引格式 PUT /aws-waf-…

git-安装 Gerrit Hook 自动生成changeid

要在 Git 中安装 Gerrit Hook 以自动生成 Change-ID&#xff0c;可以按照以下步骤操作&#xff1a; 全局钩子配置&#xff08;推荐&#xff09; 创建全局钩子目录并下载 Gerrit 提供的 commit-msg 钩子脚本&#xff0c;确保所有仓库共享该配置&#xff1a; mkdir -p ~/.githook…

Excel 的多线程特性

Excel 本身并不是完全多线程的应用程序&#xff0c;但它在某些操作和功能上支持多线程处理。以下是对 Excel 是否多线程的详细解答&#xff0c;结合你之前提到的 VBA/COM 自动化代码和受保护视图问题&#xff0c;提供清晰且准确的分析。 Excel 的多线程特性计算引擎的多线程支持…

【嵌入式ARM汇编】-操作系统基础(一)

操作系统基础(一) 文章目录 操作系统基础(一)1、操作系统架构概述2、用户模式与内核模式3、进程4、系统调用5、对象和句柄我们想要逆向的程序几乎从来不会在真空中执行。相反,程序通常在正在运行的操作系统(例如 Linux、Windows 或 macOS)的上下文中运行。因此,了解这些…

[创业之路-474]:企业经营层 - 小米与华为多维对比分析(2025年视角),以后不要把这两家公司放在同一个维度上 进行比较了

一、行业定位与市场角色不同华为&#xff1a;用技术手段解决行业的难题&#xff0c;顺便赚钱技术驱动型硬科技企业&#xff1a;以通信设备起家&#xff0c;延伸至智能手机、芯片、操作系统&#xff08;鸿蒙&#xff09;、云计算、智能汽车等领域&#xff0c;构建“云-管-端”全…