Qwen2.5-VL 损失函数

Qwen2.5-VL 损失函数

flyfish

文章名称链接
深入理解交叉熵损失 CrossEntropyLoss - 概率基础链接
深入理解交叉熵损失 CrossEntropyLoss - 对数链接
深入理解交叉熵损失 CrossEntropyLoss - 概率分布链接
深入理解交叉熵损失 CrossEntropyLoss - 信息论(交叉熵)链接
深入理解交叉熵损失 CrossEntropyLoss - 损失函数链接
深入理解交叉熵损失 CrossEntropyLoss - one-hot 编码链接
深入理解交叉熵损失 CrossEntropyLoss - Softmax链接
深入理解交叉熵损失 CrossEntropyLoss - 归一化链接
深入理解交叉熵损失 CrossEntropyLoss - nn.LogSoftmax链接
深入理解交叉熵损失 CrossEntropyLoss - 似然链接
深入理解交叉熵损失 CrossEntropyLoss - 乘积符号在似然函数中的应用链接
深入理解交叉熵损失 CrossEntropyLoss - 最大似然估计链接
深入理解交叉熵损失 CrossEntropyLoss - nn.NLLLoss(Negative Log-Likelihood Loss)链接
深入理解交叉熵损失 CrossEntropyLoss - CrossEntropyLoss链接

qwen2_5_vl/modular_qwen2_5_vl.py
qwen2_5_vl/modeling_qwen2_5_vl.py
文件的forward方法中使用了CrossEntropyLoss

loss = None
if labels is not None:# Upcast to float if we need to compute the loss to avoid potential precision issueslogits = logits.float()# Shift so that tokens < n predict nshift_logits = logits[..., :-1, :].contiguous()shift_labels = labels[..., 1:].contiguous()# Flatten the tokensloss_fct = CrossEntropyLoss()shift_logits = shift_logits.view(-1, self.config.vocab_size)shift_labels = shift_labels.view(-1)# Enable model parallelismshift_labels = shift_labels.to(shift_logits.device)loss = loss_fct(shift_logits, shift_labels)if not return_dict:output = (logits,) + outputs[1:]return (loss,) + output if loss is not None else output

loss_fct(shift_logits, shift_labels)

1. shift_logits

  • 类型torch.Tensor
  • 含义:它代表模型的原始输出得分,也就是未经Softmax函数处理过的对数概率。在多分类任务里,每个样本的 shift_logits 是一个长度为词汇表大小(self.config.vocab_size)的向量,此向量的每个元素对应着模型预测该样本属于各个类别的得分。
  • 形状:在代码里,shift_logits 经过 view(-1, self.config.vocab_size) 操作后,其形状为 (N, C),其中 N 是所有样本的总数(即批次大小与序列长度的乘积),C 是类别数量(也就是词汇表的大小)。

2. shift_labels

  • 类型torch.LongTensor
  • 含义:它代表样本的真实标签。每个元素是一个整数,这个整数对应着样本的真实类别索引,其取值范围是 [0, C - 1],这里的 C 是类别数量(即词汇表的大小)。
  • 形状:在代码中,shift_labels 经过 view(-1) 操作后,其形状为 (N,),这里的 N 是所有样本的总数(即批次大小与序列长度的乘积),要和 shift_logits 的第一个维度保持一致。

示例代码

使用 CrossEntropyLoss 计算

import torch
from torch.nn import CrossEntropyLoss# 假设词汇表大小为10
vocab_size = 10
# 假设总样本数为5
N = 5# 生成随机的logits和labels
shift_logits = torch.randn(N, vocab_size)
shift_labels = torch.randint(0, vocab_size, (N,))# 创建CrossEntropyLoss实例
loss_fct = CrossEntropyLoss()# 计算损失
loss = loss_fct(shift_logits, shift_labels)
print(f"Loss: {loss.item()}")

shift_logits 是模型的原始输出得分,shift_labels 是样本的真实标签,loss_fct(shift_logits, shift_labels) 会计算出这些样本的交叉熵损失。

代码中使用了CrossEntropyLoss类来计算损失,在 PyTorch 中,CrossEntropyLoss结合了LogSoftmax和NLLLoss(负对数似然损失)。LogSoftmax操作会对模型的输出logits应用 Softmax 函数,将其转换为概率分布,然后再取对数;NLLLoss则基于这个对数概率分布和真实标签计算损失。因此,CrossEntropyLoss本质上实现的是 Softmax 交叉熵。

多分类场景的体现:shift_logits被调整为形状(-1, self.config.vocab_size),其中self.config.vocab_size表示词汇表的大小,这意味着模型的输出是一个多分类的概率分布,每个类别对应词汇表中的一个词。

简单的说

1. 基础概念:熵、交叉熵与KL散度

1.1 信息熵(Entropy)

信息熵是随机变量的不确定性度量,对于离散概率分布 P ( x ) P(x) P(x),其公式为 H ( P ) = − ∑ i P ( x i ) log ⁡ P ( x i ) H(P) = -\sum_{i} P(x_i) \log P(x_i) H(P)=iP(xi)logP(xi)。直观来看,熵越高意味着分布越“均匀”,不确定性越大——例如掷骰子结果的熵高于抛硬币,因为骰子的可能结果更多且分布更均匀,其不确定性更强。

1.2 交叉熵(Cross-Entropy)

交叉熵用于衡量用分布 Q Q Q 表示分布 P P P 的困难程度,公式为 H ( P , Q ) = − ∑ i P ( x i ) log ⁡ Q ( x i ) H(P, Q) = -\sum_{i} P(x_i) \log Q(x_i) H(P,Q)=iP(xi)logQ(xi)。交叉熵越小,说明 Q Q Q P P P 的差异越小,即 Q Q Q 越接近真实分布 P P P,因此常被用作衡量两个分布相似性的指标。

1.3 KL散度(Kullback-Leibler Divergence)

KL散度是对两个概率分布差异的量化度量,公式为 D K L ( P ∣ ∣ Q ) = ∑ i P ( x i ) log ⁡ P ( x i ) Q ( x i ) D_{KL}(P||Q) = \sum_{i} P(x_i) \log \frac{P(x_i)}{Q(x_i)} DKL(P∣∣Q)=iP(xi)logQ(xi)P(xi),其与交叉熵、信息熵的关系为 H ( P , Q ) = H ( P ) + D K L ( P ∣ ∣ Q ) H(P, Q) = H(P) + D_{KL}(P||Q) H(P,Q)=H(P)+DKL(P∣∣Q),即交叉熵等于信息熵与KL散度之和。当真实分布 P P P 固定时,最小化交叉熵等价于最小化KL散度,目标是让预测分布 Q Q Q 尽可能接近 P P P

2. 交叉熵损失在分类任务中的应用

2.1 二分类问题

在二分类场景中,模型需要预测样本属于0或1类的概率,通常通过sigmoid函数将输出映射到[0, 1]区间,得到概率 y ^ \hat{y} y^。二元交叉熵(BCE)损失函数为 L = − 1 N ∑ i = 1 N [ y i log ⁡ ( y ^ i ) + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ] L = -\frac{1}{N} \sum_{i=1}^{N} \left[ y_i \log(\hat{y}_i) + (1-y_i) \log(1-\hat{y}_i) \right] L=N1i=1N[yilog(y^i)+(1yi)log(1y^i)]:当真实标签 y i = 1 y_i=1 yi=1 时,损失为 − log ⁡ ( y ^ i ) -\log(\hat{y}_i) log(y^i),鼓励 y ^ i \hat{y}_i y^i 接近1;当 y i = 0 y_i=0 yi=0 时,损失为 − log ⁡ ( 1 − y ^ i ) -\log(1-\hat{y}_i) log(1y^i),鼓励 y ^ i \hat{y}_i y^i 接近0。

2.2 多分类问题

多分类任务中,样本需被预测为 C C C 个类别之一,模型通过softmax函数将输出转换为概率分布 y ^ 1 , … , y ^ C \hat{y}_1, \dots, \hat{y}_C y^1,,y^C(满足 ∑ i = 1 C y ^ i = 1 \sum_{i=1}^{C} \hat{y}_i = 1 i=1Cy^i=1),损失函数为多类交叉熵 L = − 1 N ∑ i = 1 N ∑ c = 1 C y i , c log ⁡ ( y ^ i , c ) L = -\frac{1}{N} \sum_{i=1}^{N} \sum_{c=1}^{C} y_{i,c} \log(\hat{y}_{i,c}) L=N1i=1Nc=1Cyi,clog(y^i,c),其中 y i , c y_{i,c} yi,c 是one-hot标签(样本 i i i 属于类别 c c c 时为1,否则为0),该公式本质是对每个样本的真实类别对应的预测概率取负对数并求平均。

3. 交叉熵损失与最大似然估计

3.1 从最大似然到交叉熵

假设训练数据 ( x i , y i ) (x_i, y_i) (xi,yi) 独立同分布,模型预测为条件概率 P ( y ∣ x ; θ ) P(y|x; \theta) P(yx;θ),对数似然函数为 log ⁡ L ( θ ) = ∑ i = 1 N log ⁡ P ( y i ∣ x i ; θ ) \log \mathcal{L}(\theta) = \sum_{i=1}^{N} \log P(y_i|x_i; \theta) logL(θ)=i=1NlogP(yixi;θ),最小化负对数似然(NLL)即 L = − 1 N ∑ i = 1 N log ⁡ P ( y i ∣ x i ; θ ) L = -\frac{1}{N} \sum_{i=1}^{N} \log P(y_i|x_i; \theta) L=N1i=1NlogP(yixi;θ) 等价于最大化似然。对于分类问题,若 P ( y i ∣ x i ; θ ) P(y_i|x_i; \theta) P(yixi;θ) 是softmax分布,则NLL损失恰好对应交叉熵损失。

3.2 为什么用交叉熵而非MSE?

从梯度特性看,交叉熵在预测错误时梯度较大,能加速收敛;而MSE在预测值远离真实值时梯度较小,可能导致训练缓慢。从概率解释角度,交叉熵直接优化似然函数,与概率模型的训练目标一致,而MSE更适用于回归任务,不直接关联概率分布的拟合。

4. 交叉熵损失在语言模型中的应用

4.1 自回归语言模型

自回归语言模型的任务是根据前文 x 1 , … , x t − 1 x_1, \dots, x_{t-1} x1,,xt1 预测下一个token x t x_t xt 的概率分布,损失函数为对序列每个位置 t t t 计算交叉熵 L = − 1 T ∑ t = 1 T log ⁡ P ( x t ∣ x 1 , … , x t − 1 ) L = -\frac{1}{T} \sum_{t=1}^{T} \log P(x_t | x_1, \dots, x_{t-1}) L=T1t=1TlogP(xtx1,,xt1)。以PyTorch为例,假设模型输出logits为 [batch_size, seq_len, vocab_size],真实标签为 [batch_size, seq_len],可通过 nn.CrossEntropyLoss() 计算损失:loss = loss_fct(logits.view(-1, vocab_size), labels.view(-1)),其中函数会自动对logits应用softmax并计算交叉熵。

4.2 与分类任务的联系

语言模型中每个token位置的预测可视为独立的分类问题,词汇表大小即类别数,模型需在每个位置预测当前token属于词汇表中某个词的概率。因此,语言模型的训练本质是对序列中每个位置的“分类器”进行联合优化,与多分类任务的核心逻辑一致,只是序列场景下需要考虑上下文依赖关系。

交叉熵损失和交叉熵区别

1. 区别

交叉熵(信息论概念)是衡量两个概率分布差异的指标,数学定义为 H ( P , Q ) = − ∑ i P ( x i ) log ⁡ Q ( x i ) H(P, Q) = -\sum_{i} P(x_i) \log Q(x_i) H(P,Q)=iP(xi)logQ(xi),其中 P P P是真实分布, Q Q Q是预测分布,核心作用是衡量用分布 Q Q Q表示分布 P P P的“不匹配程度”,不匹配程度越低,交叉熵越小。而交叉熵损失(机器学习损失函数)是交叉熵在模型训练中的具体应用,本质上是对样本的交叉熵求平均,用于量化模型预测与真实标签的差距,作为优化目标,例如在分类任务中,其表达式为 L = − 1 N ∑ i = 1 N ∑ c = 1 C y i , c log ⁡ y ^ i , c L = -\frac{1}{N} \sum_{i=1}^{N} \sum_{c=1}^{C} y_{i,c} \log \hat{y}_{i,c} L=N1i=1Nc=1Cyi,clogy^i,c,其中 N N N是样本数, C C C是类别数, y i , c y_{i,c} yi,c是样本 i i i的one-hot标签, y ^ i , c \hat{y}_{i,c} y^i,c是模型预测的概率。

2. 联系:交叉熵损失是交叉熵的“工程化应用”

交叉熵损失的设计直接基于交叉熵的数学定义,在机器学习中,真实标签(如one-hot向量)可视为真实分布 P P P,模型预测的概率分布是 Q Q Q,此时交叉熵即对应单个样本的损失,而交叉熵损失是所有样本交叉熵的平均。例如在二分类问题中,真实标签 y y y(0或1)对应伯努利分布 P P P,模型预测概率 y ^ \hat{y} y^对应分布 Q Q Q,单个样本的交叉熵为 L = − [ y log ⁡ y ^ + ( 1 − y ) log ⁡ ( 1 − y ^ ) ] L = -[y \log \hat{y} + (1-y) \log (1-\hat{y})] L=[ylogy^+(1y)log(1y^)],这正是二元交叉熵损失(BCE Loss)的表达式。
交叉熵是通用指标,不依赖于机器学习任务,只要存在两个概率分布,就可用交叉熵衡量它们的差异(如在信息压缩、分布匹配中);而交叉熵损失是任务特定的优化目标,仅在模型训练时使用,目的是通过最小化损失让预测分布 Q Q Q逼近真实分布 P P P,例如在语言模型中,交叉熵损失对应“预测下一个token的概率与真实token的匹配程度”,而交叉熵本身可用来评估语言模型的困惑度(Perplexity)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/pingmian/83431.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode 1524. 和为奇数的子数组数目

好的&#xff01;让我们详细解释 LeetCode 1524. 和为奇数的子数组数目 这道题的思路和解法。 题目&#xff1a; https://leetcode.cn/problems/number-of-sub-arrays-with-odd-sum/description/ 题目分析 问题描述&#xff1a; 给定一个整数数组 arr&#xff0c;返回其中和…

kafka学习笔记(三、消费者Consumer使用教程——消费性能多线程提升思考)

1.简介 KafkaConsumer是非线程安全的&#xff0c;它定义了一个acquire()方法来检测当前是否只有一个线程在操作&#xff0c;如不是则会抛出ConcurrentModifcationException异常。 acquire()可以看做是一个轻量级锁&#xff0c;它仅通过线程操作计数标记的方式来检测线程是否发…

百度蜘蛛池的作用是什么?技术@baidutopseo

百度蜘蛛池的作用 百度蜘蛛池是一种用于集中管理和调度百度搜索引擎蜘蛛&#xff08;爬虫&#xff09;的技术工具。其主要目的是优化网站内容的收录效率&#xff0c;帮助搜索引擎更高效地发现和抓取网页内容。技术baidutopseo 提升收录效率 通过集中大量百度蜘蛛的访问请求&…

黑马Java面试笔记之 消息中间件篇(Kafka)

一. Kafka保证消息不丢失 Kafka如何保证消息不丢失 使用Kafka在消息的收发过程中都会出现消息丢失&#xff0c;Kafka分别给出了解决方案 生产者发送消息到Brocker丢失消息在Brocker中存储丢失消费者从Brocker接收消息丢失 1.1 生产者发送消息到Brocker丢失 设置异步发送 消息…

dis css port brief 命令详细解释

华为交换机命令 display css port brief 详细解释 display css port brief 是华为交换机中用于 快速查看堆叠&#xff08;CSS&#xff0c;Cluster Switch System&#xff09;端口状态及关键参数 的命令&#xff0c;适用于日常运维、堆叠链路健康检查及故障定位。以下是该命令的…

Redis 缓存问题及其解决方案

1. 缓存雪崩 概念&#xff1a;缓存雪崩是指在缓存层出现大范围缓存失效或缓存服务器宕机的情况下&#xff0c;大量请求直接打到数据库&#xff0c;导致数据库压力骤增&#xff0c;甚至可能引发数据库宕机。 影响&#xff1a;缓存雪崩会导致系统性能急剧下降&#xff0c;甚至导…

使用Python进行函数作画

前言 因为之前通过deepseek绘制一下卡通的人物根本就不像&#xff0c;又想起来之前又大佬通过函数绘制了一些图像&#xff0c;想着能不能用Python来实现&#xff0c;结果发现可以&#xff0c;不过一些细节还是需要自己调整&#xff0c;deepseek整体的框架是没有问题&#xff0…

关于list集合排序的常见方法

目录 1、list.sort() 2、Collections.sort() 3、Stream.sorted() 4、进阶排序技巧 4.1 空值安全处理 4.2 多字段组合排序 4.3. 逆序 5、性能优化建议 5.1 并行流加速 5.2 原地排序 6、最佳实践 7、注意事项 前言 Java中对于集合的排序操作&#xff0c;分别为list.s…

Java高级 | (二十二)Java常用类库

参考&#xff1a;Java 常用类库 | 菜鸟教程 一、核心Java类库 二、常用第三方库 以下是 Java 生态系统中广泛使用的第三方库&#xff1a; 类别库名称主要功能官方网站JSON 处理JacksonJSON 序列化/反序列化https://github.com/FasterXML/jacksonGsonGoogle 的 JSON 库https:…

几种常用的Agent的Prompt格式

一、基础框架范式&#xff08;Google推荐标准&#xff09; 1. 角色与职能定义 <Role_Definition> 你是“项目专家”&#xff08;Project Pro&#xff09;&#xff0c;作为家居园艺零售商的首席AI助手&#xff0c;专注于家装改造领域。你的核心使命&#xff1a; 1. 协助…

蛋白质结构预测软件openfold介绍

openfold 是一个用 Python 和 PyTorch 实现的 AlphaFold2 的开源复现版&#xff0c;旨在提升蛋白质结构预测的可复现性、可扩展性以及研究友好性。它允许研究者在不开源 DeepMind 原始代码的情况下&#xff0c;自由地进行蛋白结构预测的训练和推理&#xff0c;并支持自定义模型…

AD转嘉立创EDA

可以通过嘉立创文件迁移助手进行格式的转换 按照它的提示我们整理好文件 导出后是这样的&#xff0c;第一个文件夹中有原理图和PCB&#xff0c;可以把它们压缩成一个压缩包 这个时候我们打开立创EDA&#xff0c;选择导入AD 这样就完成了

MySQL(50)如何使用UNSIGNED属性?

在 MySQL 中&#xff0c;UNSIGNED 属性用于数值数据类型&#xff08;如 TINYINT、SMALLINT、MEDIUMINT、INT 和 BIGINT&#xff09;&#xff0c;表示该列只能存储非负整数。使用 UNSIGNED 属性可以有效地扩展列的正整数范围&#xff0c;因为它不需要为负数保留空间。 1. 定义与…

什么是链游,链游系统开发价格以及方案

2025 Web3钱包开发指南&#xff1a;从多版本源码到安全架构实战 在数字资产爆发式增长的今天&#xff0c;Web3钱包已成为用户进入链上世界的核心入口。作为开发者&#xff0c;如何高效构建安全、跨链、可扩展的钱包系统&#xff1f;本文结合前沿技术方案与开源实践&#xff0c…

文件IO流

IO使用函数 标准IO文件IO(低级IO)打开fopen, freopen, fdopenopen关闭fcloseclose读getc, fgetc, getchar, fgets, gets, fread printf fprintfread写putc, fputc, putchar, fputs, puts, fwrite scanf fscanfwrite操作文件指针fseeklseek其它fflush rewind ftell 文件描述符 …

云原生DMZ架构实战:基于AWS CloudFormation的安全隔离区设计

在云时代,传统的DMZ(隔离区)概念已经演变为更加灵活和动态的架构。本文通过解析一个实际的AWS CloudFormation模板,展示如何在云原生环境中构建现代化的DMZ安全架构。 1. 云原生DMZ的核心理念 传统DMZ是网络中的"缓冲区",位于企业内网和外部网络之间。而在云环境…

一、虚拟货币概述

1. 定义 - 虚拟货币是一种基于网络技术、加密技术和共识机制的数字货币&#xff0c;它不依赖传统金融机构发行&#xff0c;而是通过计算机算法生成&#xff0c;例如比特币、以太坊等。 2. 特点 - 去中心化&#xff1a;没有一个单一的机构或个人控制整个虚拟货币系统&#xff0c…

Make All Equal

给定一个循环数组 a1,a2,…,ana1​,a2​,…,an​。 你可以对 aa 至多执行 n−1n−1 次以下操作&#xff1a; 设 mm 为 aa 的当前大小&#xff0c;你可以选择任何两个相邻的元素&#xff0c;其中前一个不大于后一个&#xff08;特别地&#xff0c;amam​ 和 a1a1​ 是相邻的&a…

任务中心示例及浏览器强制高效下载实践

1. 效果展示 这里的进度展示&#xff0c;可以通过我们之前讲到的Vue3实现类ChatGPT聊天式流式输出(vue-sse实现) SSE技术实现&#xff0c;比如用户点击全量下载时&#xff0c;后台需要将PDF文件打包为ZIP文件&#xff0c;由于量较大&#xff0c;需要展示进度&#xff0c;用户点…

SpringBoot整合Flowable【08】- 前后端如何交互

引子 在第02篇中&#xff0c;我通过 Flowable-UI 绘制了一个简单的绩效流程&#xff0c;并在后续章节中基于这个流程演示了 Flowable 的各种API调用。然而&#xff0c;在实际业务场景中&#xff0c;如果要求前端将用户绘制的流程文件发送给后端再进行解析处理&#xff0c;这种…