Pytorch为什么 nn.CrossEntropyLoss = LogSoftmax + nn.NLLLoss?

为什么 nn.CrossEntropyLoss = LogSoftmax + nn.NLLLoss

在使用 PyTorch 时,我们经常听说 nn.CrossEntropyLossLogSoftmaxnn.NLLLoss 的组合。这句话听起来简单,但背后到底是怎么回事?为什么这两个分开的功能加起来就等于一个完整的交叉熵损失?今天我们就从数学公式到代码实现,彻底搞清楚它们的联系。

1. 先认识三个主角

要理解这个等式,先得知道每个部分的定义和作用:

  • nn.CrossEntropyLoss:交叉熵损失,直接接受未归一化的 logits,计算模型预测与真实标签的差距,适用于多分类任务。
  • LogSoftmax:将 logits 转为对数概率(log probabilities),输出范围是负值。
  • nn.NLLLoss:负对数似然损失,接受对数概率,计算正确类别的负对数值。

表面上看,nn.CrossEntropyLoss 是一个独立的损失函数,而 LogSoftmaxnn.NLLLoss 是两步操作。为什么说它们本质上是一回事呢?答案藏在数学公式和计算逻辑里。

2. 数学上的拆解

让我们从交叉熵的定义开始,逐步推导。

(1) 交叉熵的数学形式

交叉熵(Cross-Entropy)衡量两个概率分布的差异。在多分类任务中:

  • ( p p p ):真实分布,通常是 one-hot 编码(比如 [0, 1, 0] 表示第 1 类)。
  • ( q q q ):预测分布,是模型输出的概率(比如 [0.2, 0.5, 0.3])。

交叉熵公式为:

H ( p , q ) = − ∑ c = 1 C p c log ⁡ ( q c ) H(p, q) = -\sum_{c=1}^{C} p_c \log(q_c) H(p,q)=c=1Cpclog(qc)

对于 one-hot 编码,( p c p_c pc ) 在正确类别上为 1,其他为 0,所以简化为:

H ( p , q ) = − log ⁡ ( q correct ) H(p, q) = -\log(q_{\text{correct}}) H(p,q)=log(qcorrect)

其中 ( q correct q_{\text{correct}} qcorrect ) 是正确类别对应的预测概率。对 ( N N N ) 个样本取平均,损失为:

Loss = − 1 N ∑ i = 1 N log ⁡ ( q i , y i ) \text{Loss} = -\frac{1}{N} \sum_{i=1}^{N} \log(q_{i, y_i}) Loss=N1i=1Nlog(qi,yi)

这正是交叉熵损失的核心。

(2) 从 logits 到概率

神经网络输出的是原始分数(logits),比如 ( z = [ z 1 , z 2 , z 3 ] z = [z_1, z_2, z_3] z=[z1,z2,z3] )。要得到概率 ( q q q ),需要用 Softmax:

q j = e z j ∑ k = 1 C e z k q_j = \frac{e^{z_j}}{\sum_{k=1}^{C} e^{z_k}} qj=k=1Cezkezj

交叉熵损失变成:

Loss = − 1 N ∑ i = 1 N log ⁡ ( e z i , y i ∑ k = 1 C e z i , k ) \text{Loss} = -\frac{1}{N} \sum_{i=1}^{N} \log\left(\frac{e^{z_{i, y_i}}}{\sum_{k=1}^{C} e^{z_{i,k}}}\right) Loss=N1i=1Nlog(k=1Cezi,kezi,yi)

这就是 nn.CrossEntropyLoss 的数学形式。

(3) 分解为两步

现在我们把这个公式拆开:

  • 第一步:LogSoftmax
    计算对数概率:
    log ⁡ ( q j ) = log ⁡ ( e z j ∑ k = 1 C e z k ) = z j − log ⁡ ( ∑ k = 1 C e z k ) \log(q_j) = \log\left(\frac{e^{z_j}}{\sum_{k=1}^{C} e^{z_k}}\right) = z_j - \log\left(\sum_{k=1}^{C} e^{z_k}\right) log(qj)=log(k=1Cezkezj)=zjlog(k=1Cezk)
    这正是 LogSoftmax 的定义。它把 logits ( z z z ) 转为对数概率 ( log ⁡ ( q ) \log(q) log(q) )。

  • 第二步:NLLLoss
    有了对数概率 ( log ⁡ ( q ) \log(q) log(q) ),取出正确类别的值,取负号并平均:
    NLL = − 1 N ∑ i = 1 N log ⁡ ( q i , y i ) \text{NLL} = -\frac{1}{N} \sum_{i=1}^{N} \log(q_{i, y_i}) NLL=N1i=1Nlog(qi,yi)
    这就是 nn.NLLLoss 的公式。

组合起来

  • LogSoftmax 把 logits 转为 ( log ⁡ ( q ) \log(q) log(q) )。
  • nn.NLLLoss 对 ( log ⁡ ( q ) \log(q) log(q) ) 取负号,计算损失。
  • 两步合起来正好是 ( − log ⁡ ( q correct ) -\log(q_{\text{correct}}) log(qcorrect) ),与交叉熵一致。
3. PyTorch 中的实现验证

从数学上看,nn.CrossEntropyLoss 的确可以分解为 LogSoftmaxnn.NLLLoss。我们用代码验证一下:

import torch
import torch.nn as nn# 输入数据
logits = torch.tensor([[1.0, 2.0, 0.5], [0.1, 0.5, 2.0]])  # [batch_size, num_classes]
target = torch.tensor([1, 2])  # 真实类别索引# 方法 1:直接用 nn.CrossEntropyLoss
ce_loss_fn = nn.CrossEntropyLoss()
ce_loss = ce_loss_fn(logits, target)
print("CrossEntropyLoss:", ce_loss.item())# 方法 2:LogSoftmax + nn.NLLLoss
log_softmax = nn.LogSoftmax(dim=1)
nll_loss_fn = nn.NLLLoss()
log_probs = log_softmax(logits)  # 计算对数概率
nll_loss = nll_loss_fn(log_probs, target)
print("LogSoftmax + NLLLoss:", nll_loss.item())

运行结果:两个输出的值完全相同(比如 0.75)。这证明 nn.CrossEntropyLoss 在内部就是先做 LogSoftmax,再做 nn.NLLLoss

4. 为什么 PyTorch 这么设计?

既然 nn.CrossEntropyLoss 等价于 LogSoftmax + nn.NLLLoss,为什么 PyTorch 提供了两种方式?

  • 便利性
    nn.CrossEntropyLoss 是一个“一体式”工具,直接输入 logits 就能用,适合大多数场景,省去手动搭配的麻烦。

  • 模块化
    LogSoftmaxnn.NLLLoss 分开设计,给开发者更多灵活性:

    • 你可以在模型里加 LogSoftmax,只用 nn.NLLLoss 计算损失。
    • 可以单独调试对数概率(比如打印 log_probs)。
    • 在某些自定义损失中,可能需要用到独立的 LogSoftmax
  • 数值稳定性
    nn.CrossEntropyLoss 内部优化了计算,避免了分开操作时可能出现的溢出问题(比如 logits 很大时,Softmax 的分母溢出)。

5. 为什么不直接用 Softmax?

你可能好奇:为什么不用 Softmax + 对数 + 取负,而是用 LogSoftmax
答案是数值稳定性:

  • 单独计算 Softmax(指数运算)可能导致溢出(比如 ( e 1000 e^{1000} e1000 ))。
  • LogSoftmax 把指数和对数合并为 ( z j − log ⁡ ( ∑ e z k ) z_j - \log(\sum e^{z_k}) zjlog(ezk) ),计算更稳定。
6. 使用场景对比
  • nn.CrossEntropyLoss

    • 输入:logits。
    • 场景:标准多分类任务(图像分类、文本分类)。
    • 优点:简单直接。
  • LogSoftmax + nn.NLLLoss

    • 输入:logits 需手动转为对数概率。
    • 场景:需要显式控制 Softmax,或者模型已输出对数概率。
    • 优点:灵活性高。
7. 小结:为什么等价?
  • 数学上:交叉熵 ( − log ⁡ ( q correct ) -\log(q_{\text{correct}}) log(qcorrect) ) 可以拆成两步:
    1. LogSoftmax:从 logits 到 ( log ⁡ ( q ) \log(q) log(q) )。
    2. nn.NLLLoss:从 ( log ⁡ ( q ) \log(q) log(q) ) 到 ( − log ⁡ ( q correct ) -\log(q_{\text{correct}}) log(qcorrect) )。
  • 实现上nn.CrossEntropyLoss 把这两步封装成一个函数,结果一致。
  • 设计上:PyTorch 提供两种方式,满足不同需求。

所以,nn.CrossEntropyLoss = LogSoftmax + nn.NLLLoss 不是巧合,而是交叉熵计算的自然分解。理解这一点,能帮助你更灵活地使用 PyTorch 的损失函数。

8. 彩蛋:手动推导

想自己验证?试试手动计算:

  • logits [1.0, 2.0, 0.5],目标是 1。
  • Softmax:[0.23, 0.63, 0.14]
  • LogSoftmax:[-1.47, -0.47, -1.97]
  • NLL:-(-0.47) = 0.47
  • 直接用 nn.CrossEntropyLoss,结果一样!

希望这篇博客解开了你的疑惑!

后记

2025年2月28日18点51分于上海,在grok3 大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/70968.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/70968.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

rabbitmq 延时队列

要使用 RabbitMQ Delayed Message Plugin 实现延时队列,首先需要确保插件已安装并启用。以下是实现延时队列的步骤和代码示例。 1. 安装 RabbitMQ Delayed Message Plugin 首先,确保你的 RabbitMQ 安装了 rabbitmq-delayed-message-exchange 插件。你可…

在 Vue 单文件组件(SFC)中,标签的显式关闭与隐式关闭有着重要的区别

一、显式关闭标签 1、定义&#xff1a; 所有的 HTML 标签都必须有一个对应的结束标签。 自闭合标签也必须使用 / 来关闭。 <template> <div> <p>这是一个段落</p> <img src"image.png"…

第四届大数据、区块链与经济管理国际学术会议

重要信息 官网&#xff1a;www.icbbem.com 时间&#xff1a;2025年3月14-16日 地点&#xff1a;中国-武汉 &#xff08;线上召开&#xff09; 简介 第四届大数据、区块链与经济管理国际学术会议(ICBBEM 2025)&#xff0c;将于2025年3月14-16日在中国湖北省武汉市召开。…

每日十个计算机专有名词 (7)

Metasploit 词源&#xff1a;Meta&#xff08;超越&#xff0c;超出&#xff09; exploit&#xff08;漏洞利用&#xff09; Metasploit 是一个安全测试框架&#xff0c;用来帮助安全专家&#xff08;也叫渗透测试人员&#xff09;发现和利用计算机系统中的漏洞。你可以把它想…

使用Docker Compose部署 MySQL8

MySQL 8 是一个功能强大的关系型数据库管理系统,而 Docker 则是一个流行的容器化平台。结合使用它们可以极大地简化 MySQL 8 的部署过程,并且确保开发环境和生产环境的一致性。 安装 Docker 和 Docker Compose 首先,确保你的机器上已经安装了 Docker 和 Docker Compose。 …

mamba_ssm和causal-conv1d详细安装教程

1.前言 Mamba是近年来在深度学习领域出现的一种新型结构&#xff0c;特别是在处理长序列数据方面表现优异。在本文中&#xff0c;我将介绍如何在 Linux 系统上安装并配置 mamba_ssm 虚拟环境。由于官方指定mamba_ssm适用于 PyTorch 版本高于 1.12 且 CUDA 版本大于 11.6 的环境…

c++中初始化列表的使用

在 C 中&#xff0c;初始化列表是在构造函数的定义中&#xff0c;用于对类的成员变量进行初始化的一种方式。它紧跟在构造函数的参数列表之后&#xff0c;使用冒号 : 分隔&#xff0c;各成员变量的初始化用逗号 , 分隔。下面详细介绍初始化列表及其参数的含义。 基本语法 clas…

《Linux系统编程篇》System V信号量实现生产者与消费者问题(Linux 进程间通信(IPC))——基础篇(拓展思维)

文章目录 &#x1f4da; **生产者-消费者问题**&#x1f511; **问题分析**&#x1f6e0;️ **详细实现&#xff1a;生产者-消费者****步骤 1&#xff1a;定义信号量和缓冲区****步骤 2&#xff1a;创建信号量****步骤 3&#xff1a;生产者进程****步骤 4&#xff1a;消费者进程…

利用 Python 爬虫进行跨境电商数据采集

1 引言2 代理IP的优势3 获取代理IP账号4 爬取实战案例---&#xff08;某电商网站爬取&#xff09;4.1 网站分析4.2 编写代码4.3 优化代码 5 总结 1 引言 在数字化时代&#xff0c;数据作为核心资源蕴含重要价值&#xff0c;网络爬虫成为企业洞察市场趋势、学术研究探索未知领域…

HONOR荣耀MagicBook 15 2021款 独显(BOD-WXX9,BDR-WFH9HN)原厂Win10系统

适用型号&#xff1a;【BOD-WXX9】 MagicBook 15 2021款 i7 独显 MX450 16GB512GB (BDR-WFE9HN) MagicBook 15 2021款 i5 独显 MX450 16GB512GB (BDR-WFH9HN) MagicBook 15 2021款 i5 集显 16GB512GB (BDR-WFH9HN) 链接&#xff1a;https://pan.baidu.com/s/1S6L57ADS18fnJZ1…

c语言实现三子棋小游戏(涉及二维数组、函数、循环、常量、动态取地址等知识点)

使用C语言实现一个三子棋小游戏 涉及知识点&#xff1a;二维数组、自定义函数、自带函数库、循环、常量、动态取地址等等 一些细节点&#xff1a; 1、引入自定义头文件&#xff0c;需要用""双引号包裹文件名&#xff0c;目的是为了和官方头文件的<>区分开。…

C语言数据类型及其使用 (带示例)

目录 1. 基本数据类型 整型 浮点型 字符型 2. 构造数据类型 数组 结构体 联合体&#xff08;共用体&#xff09; 枚举类型 3. 指针类型 4. 空类型 在 C 语言中&#xff0c;数据类型是非常重要的概念&#xff0c;它决定了数据在内存中的存储方式、占用空间大小以及可…

Web自动化之Selenium添加网站Cookies实现免登录

在使用Selenium进行Web自动化时&#xff0c;添加网站Cookies是实现免登录的一种高效方法。通过模拟浏览器行为&#xff0c;我们可以将已登录状态的Cookies存储起来&#xff0c;并在下次自动化测试或爬虫任务中直接加载这些Cookies&#xff0c;从而跳过登录步骤。 Cookies简介 …

NAT 技术:网络中的 “地址魔术师”

目录 一、性能瓶颈&#xff1a;NAT 的 “阿喀琉斯之踵” &#xff08;一&#xff09;数据包处理延迟 &#xff08;二&#xff09;高并发下的性能损耗 二、应用兼容性&#xff1a;NAT 带来的 “适配难题” &#xff08;一&#xff09;端到端通信的困境 &#xff08;二&…

php序列化与反序列化

文章目录 基础知识魔术方法&#xff1a;在序列化和反序列化过程中自动调用的方法什么是 __destruct() 方法&#xff1f;何时触发 __destruct() 方法&#xff1f;用途&#xff1a;语法示例&#xff1a; 反序列化漏洞利用前提条件一些绕过策略绕过__wakeup函数绕过正则匹配绕过相…

docker 占用系统空间太大了,整体迁移到挂载的其他磁盘|【当前普通用户使用docker时,无法指定镜像、容器安装位置【无法指定】】

文章目录 前言【核心步骤皆为 大模型生成的方案】总结步骤应该是&#xff1a;详细步骤如下1. **停止 Docker 服务**2. **备份原数据&#xff08;防止迁移失败&#xff09;**3. **迁移数据到新磁盘**4. **修改 Docker 配置文件**5. **重启 Docker 服务**6. **验证容器和镜像**7.…

设计后端返回给前端的返回体

目录 1、为什么要设计返回体&#xff1f; 2、返回体包含哪些内容&#xff08;如何设计&#xff09;&#xff1f; 举例 3、总结 1、为什么要设计返回体&#xff1f; 在设计后端返回给前端的返回体时&#xff0c;通常需要遵循一定的规范&#xff0c;以确保前后端交互的清晰性…

Springboot 自动化装配的原理

Springboot 自动化装配的原理 SpringBoot 主要作用为&#xff1a;起步依赖、自动装配。而为了实现这种功能&#xff0c;SpringBoot 底层主要使用了 SpringBootApplication 注解。 首先&#xff0c;SpringBootApplication 是一个复合注解&#xff0c;它结合了 Configuration、…

基于vue框架的游戏博客网站设计iw282(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。

系统程序文件列表 项目功能&#xff1a;用户,博客信息,资源共享,游戏视频,游戏照片 开题报告内容 基于FlaskVue框架的游戏博客网站设计开题报告 一、项目背景与意义 随着互联网技术的飞速发展和游戏产业的不断壮大&#xff0c;游戏玩家对游戏资讯、攻略、评测等内容的需求日…

算法-二叉树篇13-路径总和

路径总和 力扣题目链接 题目描述 给你二叉树的根节点 root 和一个表示目标和的整数 targetSum 。判断该树中是否存在 根节点到叶子节点 的路径&#xff0c;这条路径上所有节点值相加等于目标和 targetSum 。如果存在&#xff0c;返回 true &#xff1b;否则&#xff0c;返回…