在多分类任务中,模型输出一个概率分布,常用的损失函数是 Categorical Cross Entropy(多类交叉熵)。本文将带你理解其数学本质、应用场景、数值稳定性及完整 Python 实现。
📘 一、什么是 Categorical Cross Entropy?
多类交叉熵损失函数 衡量的是预测的概率分布 与真实类别分布
之间的距离。通常用于 Softmax 输出层 + 多分类问题。
🧮 二、数学公式
设:
:真实标签(独热编码 One-Hot)
:模型预测概率(Softmax 输出)
则多类交叉熵定义为:
含义:
若
且其他为 0(独热编码),则只考虑正确类别对应的概率;
预测越接近真实标签,对应损失越小。
🧑💻 三、Python 实现(含数值稳定)
函数实现如下:
import mathdef categorical_cross_entropy(y_true, y_pred):"""计算多类交叉熵损失(适用于独热编码标签)参数:y_true (List[float]):真实标签(One-Hot)y_pred (List[float]):预测概率(Softmax 输出)返回:float:交叉熵损失值"""epsilon = 1e-15 # 防止 log(0)y_pred = [min(max(p, epsilon), 1 - epsilon) for p in y_pred]return -sum(y * math.log(p) for y, p in zip(y_true, y_pred))# 示例:3类分类问题
y_true = [0, 1, 0]
y_pred = [0.2, 0.7, 0.1]loss = categorical_cross_entropy(y_true, y_pred)
print("Categorical Cross Entropy:", loss)
✅ 输出示例:
Categorical Cross Entropy: 0.35667494393873245
⚠️ 四、为什么需要 Epsilon 防止 log(0)?
在预测中,某些类概率可能非常接近 0(例如 1e-20),直接对其取对数会:
产生
math domain error
;导致梯度爆炸或模型不稳定。
因此我们设置:
epsilon = 1e-15
y_pred = max(min(p, 1 - epsilon), epsilon)
🔄 五、与 Binary Cross Entropy 的区别
项目 | Binary Cross Entropy | Categorical Cross Entropy |
---|---|---|
应用场景 | 二分类或多标签 | 多分类(单标签) |
标签格式 | 0 或 1 | 独热编码 |
输出层 | Sigmoid | Softmax |
🧠 六、实际应用场景
图像分类(如 CIFAR-10、ImageNet)
文本分类(如新闻分类、情感分析)
多类别实体识别(NER)
📌 七、总结
Categorical Cross Entropy 是多分类任务的首选损失函数;
与 Softmax 输出层配合使用;
一定要做 数值稳定性处理(加 epsilon);
真实标签应为 One-Hot 向量;
预测越准,损失越小。