【Next Token Prediction】VLM模型训练中数据集标签预处理详解

源代码来自:https://github.com/huggingface/nanoVLM/blob/main/data/collators.py

详解如下所示:

import torch#-------------------------------#
# 主要是在数据加载器的构建中被使用
#-------------------------------#class BaseCollator(object):def __init__(self, tokenizer):self.tokenizer               = tokenizerrandom_string_5_letters      = "xzyvd" # 作为“锚点”,查找它在模板化后的完整文本中的位置# 将输入消息转换成Chat模板格式的字符串 例如 "<|start|>assistant\nxzyvd<|end|>" 此为纯文本而不是被编码后得到的token idsrandom_string_chat_templated = self.tokenizer.apply_chat_template([{"role": "assistant", "content": random_string_5_letters}], tokenize=False, add_special_tokens=False)random_string_location       = random_string_chat_templated.find(random_string_5_letters) # 查找我们之前插入的“随机标记”出现的位置# 例如回复为<|start|>assistant\nxzyvd<|end|># 获取到nxzyvd开始后的位置, 然后从而获取到前缀的长度# 目的是在后续设置loss_mask时能够精准跳过模板前缀,只对assistant回复的实际内容进行监督self.prefix_len              = len(self.tokenizer.encode(random_string_chat_templated[:random_string_location])) # 找到前缀模板结束的位置#----------------------------------------------------------## 用于处理批量对话消息# 随后返回模型需要的token ids、attention mask以及loss mask# 1.将消息转换为模型所需的 token 格式# 2.根据消息中的role(例如 assistant)标记哪些token需要计算损失(loss_mask),即只对assistant的具体输出进行损失计算,而不对user的内容进行计算# 3.将所有输入统一padding到最大长度max_len,确保批次的输入大小一致#----------------------------------------------------------#def prepare_inputs_and_loss_mask(self, batched_messages, max_length=None):batch_token_ids: list[list[int]]  = [] # 保存每个批次消息的token idsbatch_masks:     list[list[int]]  = [] # 保存每个批次消息的loss_mask,即哪些token需要计算损失batch_attentions: list[list[int]] = [] # 保存每个批次消息的attention mask,模型用来指示哪些部分是有效输入,哪些是 paddingfor messages in batched_messages: # 每一条消息中都包含若干user和assistant的内容#---------------------------------------------------------------------------------------## 对于此处生成的attention mask# tokenizer会自动将padding部分的attention mask设为0,其余为1# 其作用为告诉模型哪些token是“真正需要注意的内容”,哪些只是为了凑长度而padding的垃圾位# 它是Transformer中注意力机制不可或缺的一部分,尤其在处理变长输入(如自然语言对话)时非常关键# NOTE:此处,tokenizer没有做统一长度 padding,而是保留了变长的attention_mask#---------------------------------------------------------------------------------------#conv_ids = self.tokenizer.apply_chat_template(messages,tokenize=True, # 控制attention mask相关内容add_special_tokens=False,return_dict=True,) # conv_ids是面向整个对话的一个字典,包含了对应的 input_ids(token ids)和 attention_maskmask   = [0] * len(conv_ids["input_ids"]) # 为每个对话消息初始化一个全零的 mask 列表# Locate each assistant turn and flip its mask to 1cursor = 0 # 用来记录当前已经处理过的token数量for msg in messages: # 对user与assistant的内容均进行处理segment_ids = self.tokenizer.apply_chat_template([msg], tokenize=True, add_special_tokens=False) # 将每条消息msg转换为token ids # 只包含这一条消息的内容seg_len = len(segment_ids) # 获取消息的长度, 即为每条消息的实际token数目#---------------------------------------## 当处理角色为assistant的时候展开下述操作# 只对其具体回复的内容进行操作#---------------------------------------#if msg["role"] == "assistant":start = cursor + self.prefix_len # 确定消息的起点end   = cursor + seg_len         # 根据消息的长度去确定终点mask[start:end] = [1] * (end - start)  # attend to these tokens # 将assistant的回复部分的mask设置为1cursor += seg_len # 因为一组对话中assistant回复的内容可能有多处, 因此需要进行累积batch_token_ids.append(conv_ids["input_ids"]) # token idsbatch_masks.append(mask) # 哪些token需要去计算batch_attentions.append(conv_ids["attention_mask"]) # 哪些部分是有效输入# NOTE:主要针对assistant回复过长的情况进行处理if max_length is not None:  # We need to keep the tokens to allow for the img embed replacing logic to work. Otherwise, we would need to track which images correspond to long samples.batch_token_ids  = [ids[:max_length] for ids in batch_token_ids] # 对超过max length的样本进行裁剪, 使其长度满足要求# 如果长度超过 max_length,则将其截断为全零的 mask(表示忽略该样本)batch_masks      = [m if len(m) <= max_length else [0]*max_length for m in batch_masks] # Ignore samples that are longer than max_lengthbatch_attentions = [a[:max_length] for a in batch_attentions] # 同样进行截取# Pad samples to max lengthif max_length is not None:max_len = max_lengthelse:max_len = max(map(len, batch_token_ids))# 对每个样本均展开padding操作batch_token_ids  = [[self.tokenizer.pad_token_id]*(max_len-len(ids)) + ids for ids in batch_token_ids] # 使用pad_token_id将长度填充到max lengthbatch_masks      = [[0]*(max_len-len(m)) + m         for m   in batch_masks]                           # 填充至最大长度max_len,使用0填充batch_attentions = [[0]*(max_len-len(a)) + a         for a   in batch_attentions]                      # 填充至最大长度max_len,使用0填充 # NOTE: 相当于是在tokenzier的基础上 根据max length去展开补充性paddingreturn torch.tensor(batch_token_ids), torch.tensor(batch_attentions), torch.tensor(batch_masks).to(torch.bool)#-------------------------------------#
# Visual Question Answering Collator
# 训练与验证数据集
#-------------------------------------#
class VQACollator(BaseCollator):def __init__(self, tokenizer, max_length):self.max_length  = max_lengthsuper().__init__(tokenizer)def __call__(self, batch):images           = [item["images"] for item in batch]messages_batched = [item["text_data"] for item in batch]# Stack imagesimgs   = [img for sublist in images for img in sublist]images = torch.stack(imgs)# Create inputs by concatenating special image tokens, question, and answerbatch_input_ids, batch_attention_mask, loss_masks = self.prepare_inputs_and_loss_mask(messages_batched, max_length=self.max_length)#--------------------------------------------------------------------------------------------------------------------------------------------------------------------------## Create labels where only answer tokens are predicted# 1. 首先将模型回复的内容全部复制一份出来, 然后将为mask为0的区域全部填充为-100, 表明直接忽视不参与计算# 2. 为适应因果语言建模, 展开标签平移操作, 作用为确保模型在展开语言生成任务时, 能够预测当前时间步的下一个token# 具体而言, labels[:, :-1]为选择每个样本的所有token中除去最后一个token的部分, labels[:, 1:]为获取每个样本中从第二个token到最后一个token的所有内容# 这样就可以将每个样本的所有token都可以向左移动一位, 从而将每个位置对应的token都用它的下一个token去进行预测。这样每个token的标签都变成了它的下一个token, 即为next token prediction# 3. 这样最后一个token由于没有标签目标, 直接设置为-100即可, 表明到了结尾# 例子:# batch_input_ids为[[101, 2001, 2023, 2045, 102]], 其中2001处的loss mask为0, 那么labels即为[[101, 2023, 2045, 102]]# 然后第一个样本的0 1 2 3四个位置上对应的label即变为[2023, 2045, 102, -100]# 这样就形成了真值标签[[2023, 2045, 102, -100]]#--------------------------------------------------------------------------------------------------------------------------------------------------------------------------#labels         = batch_input_ids.clone().masked_fill(~loss_masks, -100) # 将~loss_masks为1的地方填充为-100 NOTE:此处相当于就是无效的地方labels[:, :-1] = labels[:, 1:] # Shift labels for causal LMlabels[:, -1]  = -100 # Last token has no targetreturn {"image": images, # 图像"input_ids": batch_input_ids, # 输入内容"attention_mask": batch_attention_mask, # 告诉模型在等长序列中, 哪些是需要关注的实际token, 哪些是padding token"labels": labels, #标签}#--------------------------------------------------------#
# 测试数据集
# https://huggingface.co/datasets/Lin-Chen/MMStar
#--------------------------------------------------------#
class MMStarCollator(BaseCollator): def __init__(self, tokenizer):super().__init__(tokenizer)def __call__(self, batch):images           = [item["image"] for item in batch]messages_batched = [item["text_data"] for item in batch]# Stack imagesimages = torch.stack(images)# Create inputs by concatenating special image tokens, question, and answerbatch_input_ids, batch_attention_mask, loss_masks = self.prepare_inputs_and_loss_mask(messages_batched)#---------------------------------------------------------------------------------------------------------------------------------------------## 1. 把需要预测的位置(即 loss_masks=1)设成pad token, 这意味着这些位置不会被送去模型作为“输入”,因为它们是模型需要生成的内容# 2. 把要预测的部分在attention mask里屏蔽掉, 导致模型不会“看到”这些 token,符合推理阶段的auto-regressive decoding 逻辑# 3. 只保留需要预测的token作为标签,其余地方用pad填充#---------------------------------------------------------------------------------------------------------------------------------------------#"""example:query: "User: What color is the sky?\nAssistant: The sky is"prediction: "blue."那么 loss_mask 会标记 "blue." 这一段, collator就会:把 input_ids 中 "blue." 变成pad(输入时忽略)把 attention_mask 中对应位置设为0(不关注)把 labels 中 "blue." 保留, 其余是pad(只评估蓝天这个词)"""input_ids      = batch_input_ids.masked_fill(loss_masks, self.tokenizer.pad_token_id)attention_mask = batch_attention_mask.masked_fill(loss_masks, 0)labels         = batch_input_ids.clone().masked_fill(~loss_masks, self.tokenizer.pad_token_id)return {"images": images,"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels,}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/88907.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/88907.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Istio 简介

Istio 简介 什么是 Istio Istio 是一个开源的 服务网格&#xff08;Service Mesh&#xff09; 框架&#xff0c;由 Google、IBM 和 Lyft 联合开发&#xff0c;目前属于 CNCF&#xff08;云原生计算基金会&#xff09;项目。它主要用于管理和连接微服务架构中的服务&#xff0…

融云在华为开发者大会分享智能办公平台的鸿蒙化探索实践

6 月 20 日-22 日&#xff0c;“华为开发者大会&#xff08;HDC 2025&#xff09;”在东莞隆重召开&#xff0c;融云受邀出席并在“政企内部应用论坛”发表主旨演讲。 鸿蒙为千行百业的生态伙伴创新带来了独特的历史机遇&#xff0c;其蓬勃发展也为我国数字经济高质量发展提供…

滚珠导轨如何助力自动化生产实现高质量输出?

在自动化生产线的蓬勃发展中&#xff0c;高效、精准与稳定是核心追求。滚珠导轨作为关键的传动部件&#xff0c;以其独特的优势&#xff0c;在众多自动化生产场景里大放异彩&#xff0c;为生产流程的优化和产品质量的提升显著提高设备系统的稳定性和可靠性。 汽车自动化装配线 …

消息队列的推拉模式详解:实现原理与代码实战

消息队列是现代分布式系统中不可或缺的中间件&#xff0c;它通过"生产者-消费者"模式实现了系统间的解耦和异步通信。本文将深入探讨消息队列中的两种核心消息传递模式&#xff1a;推送(Push)和拉取(Pull)&#xff0c;并通过代码示例展示它们的实现方式。 目录 消息…

OpenCV图像噪点消除五大滤波方法

在数字图像处理中&#xff0c;噪点消除是提高图像质量的关键步骤。本文将基于OpenCV库&#xff0c;详细讲解五种经典的图像去噪滤波方法&#xff1a;均值滤波、方框滤波、高斯滤波、中值滤波和双边滤波&#xff0c;并通过丰富的代码示例展示它们的实际应用效果。 一、图像噪点…

Rust宏和普通函数的区别

Rust 中的宏&#xff08;macro&#xff09;和普通函数有以下核心区别&#xff0c;分别从用途、扩展方式、性能影响和语法特征等多个方面来解释&#xff1a; &#x1f4cc; 1. 定义方式 项目宏函数定义方式macro_rules! 或 macro&#xff08;新版&#xff09;fn 关键字调用方式…

基于Qt C++的影像重采样批处理工具设计与实现

摘要 本文介绍了一种基于Qt C++框架开发的高效影像重采样批处理工具。该工具支持按分辨率(DPI) 和按缩放倍率两种重采样模式,提供多种插值算法选择,具备强大的批量处理能力和直观的用户界面。工具实现了影像处理的自动化流程,显著提高了图像处理效率,特别适用于遥感影像处…

TypeScript 中的 WebSocket 入门

如何开始使用 Typescript 和 React 中的 WebSockets 创建一个简单的聊天应用程序 示例源码&#xff1a;ws 下一篇&#xff1a;https://blog.csdn.net/hefeng_aspnet/article/details/148898147 介绍 WebSocket 是一项我目前还没有在工作中使用过的技术&#xff0c;但我知道…

TMS汽车热管理系统HILRCP解决方案

TMS汽车热管理系统介绍 随着汽车电动化和智能化的发展&#xff0c;整车能量管理内容增多&#xff0c;对汽车能量管理的要求也越来越高&#xff0c;从整车层面出发对各子系统进行能量统筹管理将成为电动汽车未来的发展趋势&#xff0c;其中汽车热管理是整车能量管理的重要组成部…

CCleaner Pro v6.29.11342 绿色便携版

CCleaner Pro v6.29.11342 绿色便携版 CCleaner是Piriform&#xff08;梨子公司&#xff09;最著名广受好评的系统清理优化及隐私保护软件&#xff0c;也是该公司主打和首发产品&#xff0c;它体积小、扫描速度快&#xff0c;具有强大的自定义清理规则扩展能力。CCleaner是一款…

不做手机控APP:戒掉手机瘾,找回专注与自律

在当今数字化时代&#xff0c;手机已经成为我们生活中不可或缺的一部分。然而&#xff0c;过度依赖手机不仅会分散我们的注意力&#xff0c;影响学习和工作效率&#xff0c;还可能对身心健康造成负面影响。为了帮助用户摆脱手机依赖&#xff0c;重拾自律和专注&#xff0c;一款…

Go 语言中的接口

1、接口与鸭子类型 在 Go 语言中&#xff0c;接口&#xff08;interface&#xff09;是一个核心且至关重要的概念。它为构建灵活、可扩展的软件提供了坚实的基础。要深入理解 Go 的接口&#xff0c;我们必须首先了解一个在动态语言中非常普遍的设计哲学——鸭子类型&#xff0…

在项目中如何巧妙使用缓存

缓存 对于经常访问的数据&#xff0c;每次都从数据库&#xff08;硬盘&#xff09;中获取是比较慢&#xff0c;可以利用性能更高的存储来提高系统响应速度&#xff0c;俗称缓存 。合理使用缓存可以显著降低数据库的压力、提高系统性能。 那么&#xff0c;什么样的数据适合缓存…

SLAM中的非线性优化-2D图优化之零空间(十五)

这节在进行讲解SLAM中一个重要概念&#xff0c;零空间&#xff0c;讲它有啥用呢&#xff1f;因为SLAM中零空间的存在&#xff0c;才需要FEJ或固定约束存在&#xff0c;本节内容不属于2D图优化独有&#xff0c;先看看什么是零空间概念&#xff1b;零空间是一个核心概念&#xff…

如何解决本地DNS解析失败问题?以连接AWS ElastiCache Redis为例

在云服务开发中,DNS解析问题常常成为困扰开发者的隐形障碍。本文将通过AWS ElastiCache Redis连接失败的实际案例,详细介绍如何诊断和解决DNS解析问题,帮助你快速恢复服务连接。 引言 在使用 telnet 或 redis-cli 连接 AWS ElastiCache Redis 时,有时会遇到类似以下错误:…

探索钉钉生态中的宜搭:创建与分享应用的新视界

在当今快速发展的数字化时代&#xff0c;企业对于高效协作和信息管理的需求日益增长。作为阿里巴巴集团旗下的智能工作平台&#xff0c;钉钉不仅为企业提供了强大的沟通工具&#xff0c;其开放的生态系统也为用户带来了无限可能。其中&#xff0c;宜搭&#xff08;YiDa&#xf…

深入理解事务和MVCC

文章目录 事务定义并发事务代码实现 MVCC定义核心机制 事务 定义 什么是事务&#xff1f; 事务是指一组操作要么全部成功&#xff0c;要么全部失败的执行单位。 在数据库中&#xff0c;一个事务通常包含一组SQL语句&#xff0c;系统保证这些语句作为一个整体执行。 为什么引…

用 Python 绘制精美雷达图:多维度材料属性对比可视化全指南

&#x1f31f; 为什么选择雷达图&#xff1f;从材料科学到多维数据对比的可视化利器 在科研和数据分析领域&#xff0c;当我们需要同时展示多个维度的数据对比时&#xff0c;传统的柱状图或折线图往往显得力不从心。这时候&#xff0c;雷达图&#xff08;Radar Chart&#xff…

Excel学习03

超级表与图表 Excel中具有超级表的功能。所谓超级表&#xff08;官方名称为“表格”&#xff0c;快捷键CtrlT&#xff09;是Excel中一个强大的数据管理工具&#xff0c;它将普通的数据区域转换为具有只能功能的交互式表格。 这就是表格变为超级表的样子。超级表默认具备冻结窗…

Netflix 网飞的架构演进过程、Java在网飞中的应用|图解

写在前面 上一篇文章中&#xff0c;我们讲解了网飞当前的架构&#xff0c;但网飞的架构并不是一开始就是这样的&#xff0c;而是不断演进发展才是当前的样子。 这篇文章我们就来讲讲网飞架构的演进过程。 第一阶段&#xff1a;Zuul Gateway REST API 使用 Zuul 作为API网关…