手搓多模态-08 主模型的搭建(上)

前情回顾

在之前的章节我们已经构建好了视觉编码器,预处理模块,以及gemma模型的顶层。gemma模型的顶层,主要是构建图中圈出的输入,它把视觉编码器里每个图像patch的编码维度对齐到自然语言token的嵌入维度,并组装成了一个大的输入向量。同时在模型的顶层,我们准备好了位置id 以及attention mask,用来在后面的模型层计算旋转位置编码和注意力得分矩阵。接下来,我们要开始构建gemma模型的架构了。

顶层模型 GemmaForCausalLM

还记得吗,在之前的paligemma模型的顶层,我们有一个GemmaForCausalLM,然后我们通过下面的代码把输入传入了语言模型:

self.language_model = GemmaForCausalLM(config.text_config)
		outputs = self.language_model(
			inputs_embeds = input_embeds,
			position_ids = position_ids,
			attention_mask = attention_mask,
			kv_cache = kv_cache,
			**kwargs
		)

现在我们首先要实现这个GemmaForCausalLM。

一般模型的上层是对整个模型逻辑的简单封装,故这里GemmaForCausalLM的作用很简单,它仅仅把上下文编码后的注意力嵌入通过一个MLP转换为不同token的输出概率,也就是logits,然后返回给上层,从而让上层根据概率分布来采样下一个要输出的token是什么。

先给出代码:

class GemmaForCausalLM(nn.Module): ## 匹配
	def __init__(self,config:GemmaConfig): ##CasualLM实际上是Transformer模型加一个投影层,即将嵌入转换为对数概率
		super().__init__()
		self.config = config
		self.model = GemmaModel(config)
		self.vocab_size = config.vocab_size
		self.lm_head = nn.Linear(config.hidden_size,config.vocab_size,bias=False)	def get_input_embeddings(self): ##这里返回的是模型对象本身
		return self.model.embed_tokens	def tie_weights(self):
		self.lm_head.weight = self.model.embed_tokens.weight	def forward(
		self,
		attention_mask: Optional[torch.Tensor] = None,
		inputs_embeds: Optional[torch.FloatTensor] = None,
		kv_cache: Optional[KVCache] = None,
		position_ids: Optional[torch.Tensor] = None
	):
		'''
		input: [Batch_size, Seq_len, Hidden_size]
		output: [Batch_size, Seq_len, Vocab_size]
		'''
		## [Batch_size, Seq_len, Hidden_size]
		outputs = self.model(
			attention_mask = attention_mask,
			inputs_embeds = inputs_embeds,
			kv_cache = kv_cache,
			position_ids = position_ids
		)		hidden_states = outputs
		logits = self.lm_head(hidden_states) #lm_head负责将hidden_states映射到vocab_size维度的向量,即logits
		logits = logits.float()		return_data = {
			"logits": logits 
		}
		if kv_cache is not None:
			return_data["kv_cache"] = kv_cache ##这里kv cache是要传递下去的,因为自回归的逻辑下,后面生成的token的注意力计算要能够通过kv cache来看到之前的token的kv		return return_data

以上便是顶层模型的前向传递过程:

  • 就是通过 GemmaModel 生成的注意力嵌入来计算logits
  • 注意:由于我们在推理过程中,后续的token计算要用到之前的kv,所以kv cache必须在推理的过程中依次传递下去,同时也要返回给上层,从而在下一次推理运算的时候有kv cache可以传入。
  • 我们之前用到了参数捆绑的策略,即token嵌入的模型参数等于嵌入反解码成logits的模型参数,所以我们提供这两个函数供上层调用:
def get_input_embeddings(self): ##这里返回的是模型对象本身return self.model.embed_tokensdef tie_weights(self):self.lm_head.weight = self.model.embed_tokens.weight

GemmaModel

GemmaModel里面实际上就是一个注意力块的序列,就像一个注意力块数组一样,而该层需要做的仅仅是将输入在不同的注意力块里依次传递,并把最后一个注意力块的输出返回给上层即可。

class GemmaModel(nn.Module): ## 匹配def __init__(self,config:GemmaConfig):super().__init__()
		self.config = config
		self.hidden_size = config.hidden_size
		self.embed_tokens = nn.Embedding(config.vocab_size,config.hidden_size,padding_idx=config.pad_token_id)
		self.layers = nn.ModuleList([GemmaLayer(config, _) for _ in range(config.num_hidden_layers)])
		self.norm = GemmaRMSNorm(config.hidden_size,eps=config.rms_norm_eps) ##Root Mean Square Normalization均方根标准化,该论文表明并不一定要标准化到标准正态分布,而是只要方差为1就可以def forward(
		self,
		attention_mask: Optional[torch.Tensor] = None,
		inputs_embeds: Optional[torch.FloatTensor] = None,
		kv_cache: Optional[KVCache] = None,
		position_ids: Optional[torch.Tensor] = None):#[Batch_size, Seq_len, Hidden_size]
		hidden_states = inputs_embeds
		normalizer = torch.tensor(self.hidden_size ** 0.5,dtype= inputs_embeds.dtype)
		hidden_states = hidden_states * normalizerfor layer in self.layers:
			hidden_states = layer(
				hidden_states = hidden_states,
				attention_mask = attention_mask,
				kv_cache = kv_cache,
				position_ids = position_ids)## 均方根归一化,不改变shape
		hidden_states = self.norm(hidden_states)return hidden_states

这里我们用一个nn.ModuleList来存储所有的GemmaLayer,一个GemmaLayer实际上就是一个attention 块。值得注意的是,在每个attention块内部我们将会做两次归一化,但是每个attention layer的输出不会做归一化,为了使得上层的计算能拿到归一化后的结果,我们在整个list前向传递完了之后再补一个normalization的过程:

hidden_states = self.norm(hidden_states)
  • 注意:我们此处用的是RMSNorm,即均方根归一化,关于这个归一化与之前的其他归一化的不同我们会在文末补充一些资料。

有人可能想问,为什么嵌入模型会放到这里:self.embed_tokens

这是因为paligemma的作者是这么实现的,而我们将从huggingface来导入整个模型的参数,所以我们的架构也必须和作者一样才能正确导入参数,所以我们不得不放在这里。

GemmaLayer

在一个attention块里面我们有一个多头注意力层和一个前向传播网络,以及两个归一化,但我们实际的实现中会把归一化提前,即add&norm -> attention -> add&norm -> ff。

这也就是为什么上面提到在layer的输出没有做归一化。

代码如下:

class GemmaLayer(nn.Module): ##匹配def __init__(self,config:GemmaConfig,layer_idx:int): ##layer_idx是当前layer的索引,辅助attention存储kv_cachesuper().__init__()self.config = configself.layer_idx = layer_idxself.hidden_size = config.hidden_sizeself.intermediate_size = config.intermediate_sizeself.input_layernorm = GemmaRMSNorm(config.hidden_size,eps=config.rms_norm_eps)self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,eps=config.rms_norm_eps)self.mlp = GemmaMLP(config)self.self_attn = GemmaAttention(config,layer_idx)def forward(self,hidden_states: torch.Tensor,attention_mask: Optional[torch.Tensor] = None,kv_cache: Optional[KVCache] = None,position_ids: Optional[torch.Tensor] = None)-> Tuple[torch.Tensor,Optional[Tuple[torch.FloatTensor,torch.FloatTensor]]]:
		'''input: [Batch_size, Seq_len, Hidden_size]output: [Batch_size, Seq_len, Hidden_size]		'''residual = hidden_stateshidden_states = self.input_layernorm(hidden_states)hidden_states,_ = self.self_attn(hidden_states = hidden_states,attention_mask = attention_mask,kv_cache = kv_cache,position_ids = position_ids)hidden_states = residual + hidden_statesresidual = hidden_stateshidden_states = self.post_attention_layernorm(hidden_states)hidden_states = self.mlp(hidden_states)hidden_states = residual + hidden_statesreturn hidden_states

  • 在这里的两个归一化我们也用RMSNorm来进行归一化,注意除了归一化,我们还要处理好残差。
  • 残差的作用是防止梯度为0导致训练缓慢。

RMSNorm

在前面的第四章节:手搓多模态-04 归一化介绍 里面我们介绍了BatchNormalization和LayerNormalization,我们了解到以下信息:

  • 归一化是为了防止不同模型层的输入输出不稳定,分布不均匀导致的训练速度过慢
  • BN 依赖于batch 的规模,而batch的规模过大会导致训练速度变相过慢
  • LN 通过对单个样本的所有特征进行标准化规避了BN的问题,主要做法是对单个样本的所有特征计算均值和方差,从而将其分布转换为0-1分布。

RMSNormalization,又称均方差归一化,是由论文《Root Mean Square Layer Normalization》提出的,该文章发现,其实分布不稳定的问题和均值没有关系,主要是方差的问题,所以只需要特征的方差稳定即可,不需要计算均值,这样可以减少计算的时间,从而加速训练。

论文提出用均方根来对每个值进行缩放,从而使得方差更小,如图所示。

其中,a_i 表示缩放前的特征值,RMS(a)表示所有特征值计算出来的均方根,g是一个可学习的参数向量,b是偏置。

在paligemma的实现中,RMSNorm的代码如下:

class GemmaRMSNorm(nn.Module): ##匹配
	def __init__(self,dim,eps=1e-6): ##dim是hidden_size
		super().__init__()
		self.dim = dim
		self.eps = eps
		self.weight = nn.Parameter(torch.ones(dim))	def _norm(self,x):		return x * torch.rsqrt(x.pow(2).mean(dim = -1,keepdim=True) + self.eps) ##rsqrt表示平方的倒数,self.eps是防止分母为0	def forward(self,x):
		x = self._norm(x)
		output = x * (1.0 + self.weight.float()) ##论文中的可学习参数g
		return output.type_as(x)

其中特征的维度为嵌入的维度大小。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/85269.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/85269.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Matlab 角点探测

文章目录 一、简介二、实现代码三、实现效果一、简介 这里实现一种角点探测功能,其思路仍然是借助图像的局部梯度信息,实现亚像素精度的角点定位。该功能核心思想是利用角点周围的局部梯度信息,通过加权最小二乘优化的方式迭代调整角点位置,使定位精度达到亚像素级别。整个…

错误监控----比如实现sentry一些思路

错误监控 ⼀、引⾔ 1.为什么需要前端错误监控 你的脚本在哪些边界条件下会报错? 你的脚本和样式兼容性如何? 有哪些地区不能正常访问你的⽹站? 出现问题之后,你如何快速定位排查,把损失降到最低? 如果你想解…

linux内核调试

1. 前置安装 1.1 编译好的内核 参考: https://blog.csdn.net/qq_51950769/article/details/148596916 1.2 编译busybox BusyBox 是一个非常轻量级的多合一工具箱,常被称为“Linux 的瑞士军刀”。 简单来说: 它把很多常用的 Linux 命令&am…

什么是曲面细分

什么是曲面细分 在CAD格式中,通常使用曲线和数学函数来定义曲面和实体。这些曲面的精确度和光滑度非常适用于制造过程。但是,现代GPU芯片针对由三角形网格体组成的曲面的渲染计算进行了高度优化。通常,实时渲染器和虚幻之类的游戏引擎只能处…

CANFD加速是什么?和CANFD有什么区别?

文章目录 摘要什么是CANFD加速?CAN FD的基本原理:仲裁阶段(Arbitration Phase):数据阶段(Data Phase):关键特性:优势:总结摘要 下面的截图,大家肯定不陌生,在使用CAN设备上位机的时候,已经选择了CANFD,但还有一个选项是“CANFD加速”,那CANFD加速和不加速有什么…

minio 启动失败--Incorrect Usage: flag provided but not defined: -consoleaddress

根据错误信息 flag provided but not defined: -consoleaddress,这表明 Minio 服务启动时使用了未定义的命令行参数 --consoleaddress,导致启动失败。这个问题与 Minio 版本兼容性有关。 问题原因 参数名称变更: Minio 版本 > RELEASE.20…

基于Rust的Polars学习笔记

基于Rust的Polars学习笔记 Polars 学习笔记 Cargo.toml通用配置 [package] name = "rustP" version = "0.1.0" edition = "2024"[dependencies] polars = { version = "0.48.1", features = ["full"]}Quickstart use po…

SpringBoot扩展——定时任务!

定时任务 项目开发中会涉及很多需要定时执行的代码,如每日凌晨对前一日的数据进行汇总,或者系统缓存的清理、对每日的数据进行分析和总结等需求,这些都是定时任务。单体系统和分布式系统的分布式任务有很大的区别,单体系统就一个…

RTDETRv2 pytorch 官方版自己数据集训练遇到的问题解决

rtdetrv2 训练问题遇到的问题。 pip install torch2.0.1 torchvision0.15.2 torchaudio2.0.2 --index-url https://download.pytorch.org/whl/cu117 1 Please make sure torchvision version > 0.15.2 发现自己实际装的是 torchvison0.15.2cu117 修改_misc.py中修改为…

Linux系统移植⑤:uboot启动流程详解-board_init_f执行过程

Linux系统移植⑤:uboot启动流程详解-board_init_f执行过程 _main 中会调用 board_init_f 函数。 board_init_f 函数主要有两个工作: ①初始化一系列外设,比如串口、定时器,或者打印一些消息等。 ②初始化 gd 的各个成员变量&am…

Git命令与代码仓库管理

步骤一、完成Gitee码云上账号注册并新建代码仓库。 1.1 新建代码仓库 1.2 填写信息并创建 1.3 获取仓库地址 https://gitee.com/dog-kidney/2022082206.git 步骤二、建立本地代码仓库,并连接到远程代码仓库。 2.1初始化 git init 2.2添加仓库 git remote add o…

资源占用多,Linux 系统中如何降低 CPU 资源消耗并提升利用率?

在 Linux 系统中降低 CPU 资源消耗并提升利用率,需从系统服务优化、进程管理、资源调度及内核参数调整等多维度入手。以下是适用于各类 Linux 发行版的通用优化方案,涵盖基础操作与进阶策略: 一、服务与进程优化:减少无效资源占用 1. 关闭冗余系统服务 查看运行中的服务 …

技术与情感交织的一生 (八)

目录 融合 东西厂公 接风宴 头痛 “巴巴罗萨” 突击 推进 助攻 96小时 寒冬 食堂 反攻 消耗 Delphi 西厂 内困 外患 “敦刻尔克” 多线作战 大撤退 资源 融合 东西厂公 初次来到纸箱厂,是主厂区,感觉很大,相对西面正在…

webuploader分片上传示例,服务端上传文件到腾讯云CDN Teo 应用示例

本文环境:php7.3.4 CI3.0框架 一、大概步骤: (1)利用百度的webuploader插件,将大文件分片上传的自己的服务器 (2)利用腾讯云接口从本服务器上传到腾讯云 二、详细代码: 1、进入…

LeetCode 632.最小区间

你有 k 个 非递减排列 的整数列表。找到一个 最小 区间&#xff0c;使得 k 个列表中的每个列表至少有一个数包含在其中。 我们定义如果 b-a < d-c 或者在 b-a d-c 时 a < c&#xff0c;则区间 [a,b] 比 [c,d] 小。 示例 1&#xff1a; 输入&#xff1a;nums [[4,10,…

篇章五 系统性能优化——资源优化——CPU优化(2)

目录 1.高级并发模式 1.1 工作窃取&#xff08;Work Stealing&#xff09; 1.工作窃取模式 2.ForkJoinPool实现 3.具体例子 1.2 结构化并发&#xff08;Structured Concurrency&#xff09; 1.结构化并发模式 2.Java 19 的 StructuredTaskScope 3.具体例子 1.3 对比与…

《中国电信运营商骨干网:历史、现状与未来演进》系列 第四篇:后发先至——中国移动CMNET的快速扩张与IP专网布局

摘要&#xff1a; 本文深入探讨中国移动骨干网CMNET (AS9808) 的发展历程、网络架构及其与中国电信扁平化策略的差异。同时&#xff0c;解析其为承载高价值业务而构建的IP专用承载网的定位、结构与技术特点。最后&#xff0c;展望中国移动在5G、云计算和算力网络时代&#xff0…

R情感分析:解码文本中的情感

基于之前关于文本聚类和文本模型的博客&#xff0c;我们现在可以深入探讨一个经典主题 - 情感分析。情感分析通过计算方式识别和分类文本中的情感&#xff0c;帮助理解公众意见或消费者反馈。 什么是情感分析&#xff1f; 情感分析确定文本背后的情感基调&#xff0c;将其分类…

云徙渠道订货系统:赋能企业渠道管理的数字化引擎

在当今商业竞争日益激烈的环境下&#xff0c;企业如何高效管理和优化渠道成为关键问题。云徙渠道订货系统凭借其强大的数字化能力&#xff0c;为企业提供了全新的渠道管理解决方案&#xff0c;助力企业在复杂多变的市场环境中保持竞争力。 从渠道管理的痛点出发 传统渠道管理方…

Nacos基础使用(二):nacos作为配置中心

一、Nacos 配置中心核心属性 在学习nacos 作为配置中心的使用之前&#xff0c;先看下Nacos 作为配置中心时的三个属性&#xff0c;即&#xff1a; 命名空间、配置分组、配置集ID&#xff08;习惯称为配置文件ID&#xff09;&#xff1b;在使用Nacos 作为配置中心 的过程中可以通…