大模型Decoder-Only深入解析

Decoder-Only整体结构

我们以模型Llama-3.1-8B-Instruct为例,打印其结构如下(后面会慢慢解析每一部分,莫慌):

LlamaForCausalLM((model): LlamaModel((embed_tokens): VocabParallelEmbedding(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)(layers): ModuleList((0-31): 32 x LlamaDecoderLayer((self_attn): LlamaAttention((qkv_proj): QKVParallelLinear(in_features=4096, output_features=6144, bias=False, tp_size=1, gather_output=False)(o_proj): RowParallelLinear(input_features=4096, output_features=4096, bias=False, tp_size=1, reduce_results=True)(rotary_emb): Llama3RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=131072, base=500000.0, is_neox_style=True)(attn): RadixAttention())(mlp): LlamaMLP((gate_up_proj): MergedColumnParallelLinear(in_features=4096, output_features=28672, bias=False, tp_size=1, gather_output=False)(down_proj): RowParallelLinear(input_features=14336, output_features=4096, bias=False, tp_size=1, reduce_results=True)(act_fn): SiluAndMul())(input_layernorm): RMSNorm()(post_attention_layernorm): RMSNorm()))(norm): RMSNorm())(lm_head): ParallelLMHead(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)(logits_processor): LogitsProcessor()(pooler): Pooler()
)

Decoder-Only处理流程

我们以Llama-3.1-8B-Instruct模型为例,结合一个具体的聊天对话场景,详细说明Decoder-Only模型的处理流程,从用户输入到最终输出回答。整个过程会逐步拆解,并标注每个步骤的输入输出形状(假设batch_size=1,seq_len=10,hidden_dim=4096,词表大小=128000)。

1. 用户输入与聊天模板处理

场景:用户问:“如何做西红柿炒鸡蛋?”
模型需求:需要根据历史对话和当前问题生成回答。

聊天模板处理
  • 输入文本text:原始用户输入(如“如何做西红柿炒鸡蛋?”)
  • 模板化prompt:模型需要将输入包装成特定格式的prompt,例如:
    [系统指令]:你是一个烹饪助手,请回答以下问题。
    [用户]:如何做西红柿炒鸡蛋?
    [助手]:
    
  • 作用:模板化prompt让模型明确任务目标(如回答问题),并模拟对话上下文。

输入输出形状

  • 输入文本长度:假设为10个字符(实际长度取决于具体输入)。
  • 模板化后的prompt长度:假设为30个字符(包含系统指令、用户问题和占位符)。

2. Tokenizer处理:从prompt到input_ids

步骤

  1. Tokenization:将模板化prompt拆分为模型能理解的Token(如“西红柿”→“西红柿”,“炒”→“炒”)。
  2. 映射到input_ids:每个Token被映射为对应的ID(例如,“西红柿”→1234,“炒”→5678)。

示例
假设模板化Prompt被拆分为10个Token,其input_ids为:

[101, 1234, 5678, 8901, 2345, 6789, 102, 3456, 7890, 102]

(其中101和102是特殊标记,如<BOS><EOS>,表示开始和结束)

输入输出形状

  • input_ids的形状为 (batch_size, seq_len)(1, 10)
  • attention_mask(可选)的形状为 (1, 10),标记哪些位置是有效Token(1)或填充(0)。

3. 嵌入层:input_ids → hidden_states

步骤

  1. Token Embedding:将input_ids映射为高维向量(如4096维)。
  2. Positional Encoding:添加位置信息,让模型知道每个Token在序列中的位置。

示例

  • input_ids [101, 1234, 5678, ...] → 隐藏状态 hidden_states 的形状为 (1, 10, 4096)
  • 每个Token对应的向量包含其语义和位置信息(例如,“西红柿”对应的食物相关特征,以及它在句子中的位置)。

输入输出形状

  • hidden_states 的形状为 (batch_size, seq_len, hidden_dim)(1, 10, 4096)

4. Decoder Block处理:逐层计算

核心流程

  1. Masked Self-Attention(带掩码的自注意力)

    • 每个Token只能看到自己及之前的Token(防止“偷看”未来内容)。
    • 例如,在生成“西红柿炒鸡蛋”时,模型会先处理“西红柿”,再处理“炒”,确保生成逻辑连贯。
  2. 前馈网络(FFN)

    • 对每个Token的隐藏状态进行非线性变换,增强表达能力。

示例

  • 假设模型有32层Decoder Block,每层都会更新 hidden_states
  • 最终的 hidden_states 保留了完整的上下文信息(如“西红柿炒鸡蛋”的步骤描述)。

输入输出形状

  • 每层Decoder Block的输入输出形状不变,仍为 (1, 10, 4096)

5. LM Head:从hidden_states到下一个词

步骤

  1. 线性层:将最后一个Token的隐藏状态(形状为 (1, 10, 4096))映射到词表维度(128000)。
    • 例如,对最后一个位置(seq_len=9)的隐藏状态取值:hidden_states[:, 9, :] → 形状 (1, 4096)
  2. Softmax:将输出转换为概率分布(每个词的概率)。

示例

  • 假设模型预测下一个词是“步骤一”,其ID为9876,则概率分布中9876的值最高。

输入输出形状

  • 线性层输出形状:(1, 128000)
  • 概率分布形状:(1, 128000)

6. 采样策略:从概率分布到下一个词

方法

  • Top-k采样:从概率最高的前k个词(如k=50)中随机选一个。
  • Greedy Search:直接选概率最高的词(如“步骤一”)。

示例

  • 模型选择“步骤一”作为下一个词,并将其ID(9876)添加到 input_ids 中。
  • 新的 input_ids 变为:[101, 1234, 5678, ..., 9876](长度+1)。

输入输出形状

  • 新的 input_ids 形状为 (1, 11)

7. 迭代生成:重复步骤3-6直到完成

流程

  1. 将新的 input_idshidden_states 送回Decoder Block。
  2. 重复计算,逐步生成完整回答(如“步骤一:热锅凉油…”)。
  3. 直到生成终止标记(如<EOS>)或达到最大长度(如2048 Token)。

示例

  • 生成完整回答后,input_ids 的长度可能变为200(假设生成190个新Token)。
  • 最终的 input_ids 包含原始Prompt和生成的回答。

8. Tokenizer反向处理:从input_ids到用户文本

步骤

  1. 将生成的 input_ids(含prompt和回答)截取回答部分(去掉prompt)。
  2. 使用Tokenizer将 input_ids 转换回自然语言文本(如“步骤一:热锅凉油…”)。

输入输出形状

  • 截取后的 input_ids 形状为 (1, 190)
  • 最终输出文本长度取决于生成内容(如“步骤一:热锅凉油…”)

总结流程图

用户输入 → 模板化Prompt → Tokenizer → input_ids (1,10)  → 嵌入层 → hidden_states (1,10,4096)  → Decoder Block ×32 → hidden_states (1,10,4096)  → LM Head → 概率分布 (1,128000)  → 采样 → 新input_ids (1,11)  → 重复生成 → input_ids (1,200)  → Tokenizer反向 → 用户文本

LlamaForCausalLM结构分析

以模型Llama-3.1-8B-Instruct为例,将一部分子结构信息折叠起来,将显示如下:

LlamaForCausalLM((model): LlamaModel((embed_tokens): VocabParallelEmbedding(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)(layers): ModuleList((0-31): 32 x LlamaDecoderLayer(...))(norm): RMSNorm())(lm_head): ParallelLMHead(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)(logits_processor): LogitsProcessor()(pooler): Pooler()
)

可以看到LlamaForCausalLM主要由几个关键部分组成:model, lm_head, logits_processorpooler。这几个组件作用各不相同,我们现在来介绍一下他们。

1. model:核心解码器结构

(1) embed_tokens:词嵌入层
  • 作用:将输入的Token ID(如“西红柿”→ID=1234)映射为4096维的向量,表示Token的语义和位置信息。
  • 技术细节
    • 使用VocabParallelEmbedding(并行词嵌入,仅需了解,无需深入),支持分布式训练。
    • 词表大小为128256,覆盖多语言和特殊符号(如<BOS><EOS>)。
  • 输入输出形状
    • 输入:(batch_size, seq_len)(1, 10)(假设输入10个Token)
    • 输出:(batch_size, seq_len, hidden_dim)(1, 10, 4096)
(2) layers:32层Decoder Block
  • 核心结构
    • 多头注意力(MHA):通过Grouped-Query Attention (GQA) 提高推理效率(Llama 3.1新增)。
      • 查询(Q)、键(K)、值(V)的维度:d_model=4096num_heads=32head_dim=128
      • GQA机制:将K/V头数减少为num_key_value_heads=8,降低计算开销。
    • 前馈网络(MLP):使用SwiGLU激活函数(Sigmoid + Gated Linear Unit),替代传统ReLU。
      • 输入:4096维 → 中间层:11008维 → 输出:4096维。
    • 归一化:每层使用RMSNorm(均方根归一化),稳定训练并加速收敛。
  • 输入输出形状
    • 每层输入/输出:(1, 10, 4096)(与输入形状一致)
(3) norm:最终归一化层
  • 作用:对32层Decoder Block的输出进行最后一次归一化,确保数值稳定性。
  • 技术细节
    • 使用RMSNorm,无需计算均值,直接对向量的模长标准化。
    • 公式:hidden_states = hidden_states / sqrt(variance + ε),其中ε=1e-6

2. lm_head:语言模型头部

  • 作用:将最终的隐藏状态(hidden_dim=4096)映射为词表大小(vocab_size=128256)的概率分布,预测下一个词。
  • 技术细节
    • 使用ParallelLMHead(并行线性层),加速大规模词表的计算。
    • 参数量:4096 × 128256 ≈ 5.16B(占模型总参数量的约76%)。
  • 输入输出形状
    • 输入:(1, 4096)(取最后一个位置的隐藏状态)
    • 输出:(1, 128256)(每个词的概率值)

3. logits_processor:概率分布处理器

  • 作用:对lm_head输出的概率分布进行后处理,控制生成策略。
  • 常用功能
    • 温度调节(Temperature):降低温度(<1)使输出更确定,升高温度(>1)增加多样性。
    • Top-k/Top-p采样:从概率最高的k个词或累积概率达p的词中随机选择,平衡质量和多样性。
    • 重复惩罚(Repetition Penalty):抑制重复生成相同词(如避免“西红柿西红柿”)。
  • 输入输出形状
    • 输入:(1, 128256)(原始概率分布)
    • 输出:(1, 128256)(处理后的概率分布)

4. pooler:池化层

  • 作用:将整个序列的隐藏状态压缩为固定长度的向量,用于下游任务(如分类、相似度计算)。
  • 技术细节
    • 默认取第一个Token(如<BOS>)的隐藏状态作为全局表示。
    • 或使用平均池化/最大池化,但Llama 3.1通常直接取<BOS>
  • 输入输出形状
    • 输入:(1, 10, 4096)(全序列隐藏状态)
    • 输出:(1, 4096)(固定长度的全局向量)

总结:组件协同工作流程

  1. 输入处理:用户输入文本 → 模板化Prompt → embed_tokens(1, 10, 4096)
  2. 特征提取:32层Decoder Block → hidden_states(1, 10, 4096)
  3. 归一化norm → 稳定输出
  4. 生成预测
    • lm_head(1, 128256) 概率分布
    • logits_processor → 调整概率分布
    • 采样生成下一个词 → 更新 input_ids
  5. 迭代生成:重复步骤1-4,直到生成终止标记(<EOS>)或达到最大长度。
  6. 任务适配pooler 提取全局向量 → 用于分类、相似度等任务。
  • model:像一个厨师,逐步处理食材(Token)并调整火候(注意力机制)。
  • lm_head:厨师的“味觉”,决定下一步该加什么调料(预测下一个词)。
  • logits_processor:厨房的“规则制定者”,确保菜谱不重复且口味可控。
  • pooler:食客的“总结笔记”,用一句话概括整道菜的风味(全局语义)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/87417.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/87417.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

web网页,在线%电商,茶叶,商城,网上商城系统%分析系统demo,于vscode,vue,java,jdk,springboot,mysql数据库

经验心得 这也是帮之前一客户加了几个功能&#xff0c;需要掌握crud&#xff0c;前后端开发&#xff0c;前后端怎么对接&#xff0c;前后端通讯是以那种格式&#xff0c;把这些掌握后咱们就可以进行网站开发了。后端记好一定要分层开发&#xff0c;不要像老早一起所有代码写到一…

MybatisPlus-05.核心功能-条件构造器

一.条件构造器 我们前面使用的MP功能主要是根据id进行操作的&#xff0c;并未涉及到复杂查询。而根据id所进行的增删改查操作在MP中都有直接的封装。但是遇到复杂的查询条件时&#xff0c;如何使用MP进行操作是我们要考虑的问题。因此MP为我们提供了条件构造器。 在BaseMapper…

ES6从入门到精通:常用知识点

变量声明ES6引入了let和const替代var。let用于声明可变的块级作用域变量&#xff0c;const用于声明不可变的常量。块级作用域有效避免了变量提升和污染全局的问题。let name Alice; const PI 3.1415;箭头函数箭头函数简化了函数写法&#xff0c;且自动绑定当前上下文的this值…

51单片机教程(十一)- 单片机定时器

11、单片机定时器 项目目标 通过定时器/计数器实现流水灯控制。知识要点 定时器的结构。TMOD和TCON;定时/计数器工作方式;定时/计数器编程步骤;1、项目分析 前面的流水灯的时间控制通过空循环语句来实现,定时不是很精确。本章通过用定时器来控制流水灯任务可以实现精确的时…

基于opencv的疲劳驾驶监测系统

博主介绍&#xff1a;java高级开发&#xff0c;从事互联网行业多年&#xff0c;熟悉各种主流语言&#xff0c;精通java、python、php、爬虫、web开发&#xff0c;已经做了多年的毕业设计程序开发&#xff0c;开发过上千套毕业设计程序&#xff0c;没有什么华丽的语言&#xff0…

Vue 2 和 Vue 3 区别

1. 响应式系统原理 Vue 2&#xff1a;利用Object.defineProperty()实现属性拦截。存在局限性&#xff0c;无法自动监测对象属性增减&#xff0c;需用Vue.set/delete&#xff1b;数组变异方法要重写&#xff1b;深层对象递归转换性能差。Vue 3&#xff1a;采用 ES6 Proxy代理对…

mv重命名报错:-bash:syntax error near unexpected token ‘(‘

文章目录 一、报错背景二、解决方法2.1、方法一&#xff1a;文件名加引号2.2、方法二&#xff1a;特殊字符前加\进行转义 一、报错背景 在linux上对一文件执行重命名时报错。原因是该文件名包含空格与括号。 文件名如下&#xff1a; aa &#xff08;1).txt执行命令及报错如下…

AWS 开源 Strands Agents SDK,简化 AI 代理开发流程

最近&#xff0c;亚马逊网络服务&#xff08;AWS&#xff09;宣布推出 Strands Agents(https://github.com/strands-agents/sdk-python)&#xff0c;这一开源软件开发工具包&#xff08;SDK&#xff09;采用模型驱动的方法&#xff0c;助力开发者仅用数行代码即可构建并运行人工…

利用 AI 打造的开发者工具集合

如图. 我利用 AI 开发了这个网站花了半个小时. 目前就上了 四个 我想到的工具。 大家可以自行体验下&#xff1a;https://xiaojinzi123.github.io 本文并不是宣传什么产品. 只是感概 Ai 真的改变我的工作方式啊. 虽然现在 AI 对于一些已有的项目进行更改代码. 由于不了解业务,…

[自然语言处理]计算语言的熵

一、要求利用给定的中英文语料&#xff0c;分别计算英语字母、英语单词、汉字、汉语词的熵&#xff0c;并和已公开结果比较&#xff0c;思考汉语的熵对汉语编码和处理的影响。二、实验内容2.1 统计英文语料的熵1.代码(1)计算英文字母的熵import math #计算每个英文字母的熵 def…

如何处理“协议异常”错误

在Java中&#xff0c;“协议异常”通常是指在网络通信或者处理特定协议相关操作时出现的异常。以下是一些处理“协议异常”错误的方法&#xff1a;一、理解协议异常的类型和原因HTTP协议异常原因&#xff1a;在进行HTTP通信时&#xff0c;可能会因为请求格式错误、响应状态码异…

Spark 4.0的VariantType 类型以及内部存储

背景 本文基于Spark 4.0 总结 Spark中的 VariantType 类型,用尽量少的字节来存储Json的格式化数据 分析 这里主要介绍 Variant 的存储,我们从VariantBuilder.buildJson方法(把对应的json数据存储为VariantType类型)开始: public static Variant parseJson(JsonParser …

跨越十年的C++演进:C++20新特性全解析

跨越十年的C演进系列&#xff0c;分为5篇&#xff0c;本文为第四篇&#xff0c;后续会持续更新C23~ 前3篇如下&#xff1a; 跨越十年的C演进&#xff1a;C11新特性全解析 跨越十年的C演进&#xff1a;C14新特性全解析 跨越十年的C演进&#xff1a;C17新特性全解析 C20标准…

LeetCode--40.组合总和II

前言&#xff1a;如果你做出来了39题&#xff0c;但是遇到40题就不会做了&#xff0c;那我建议你去再好好缕清39题的思路&#xff0c;再来看这道题&#xff0c;会有种豁然开朗的感觉解题思路&#xff1a;这道题其实与39题基本一致&#xff0c;所以本次题解是借着39题为基础来讲…

Docker Desktop 安装到D盘(包括镜像下载等)+ 汉化

目录 一、 开启电脑虚拟化 1. 搜索并打开控制面板 2. 点击程序 3. 点击启用或关闭 Windows 功能 4. 打开相关功能 5. 没有Hyper-V的情况&#xff1a; 二、配置环境 1. 更新 WSL 到最新版 2. 设置 WSL 2为默认版本 3. 安装 Ubuntu 三. WSL 迁移到D盘 1. 停止运行wsl…

基于 OpenCV 的图像 ROI 切割实现

一、引言 在计算机视觉领域&#xff0c;我们经常需要处理各种各样的图像数据。有时候&#xff0c;我们只对图像中的某一部分区域感兴趣&#xff0c;例如在一张人物照片中&#xff0c;我们可能只关注人物的脸部。在这种情况下&#xff0c;将我们感兴趣的区域从整个图像中切割出…

Linux操作系统01

一、操作系统简史 二、Linux诞生与分支 三、Linux内核与发行版 内核版本号&#xff1a;cat /proc/version 、 u name -a 操作系统内核漏洞 【超详细】CentOS编译安装升级新内核_centos源码编译安装新版本内核 ntfs-CSDN博客 四、虚拟机 五、Docker容器技术 典型靶场集成环境…

Chrome 下载文件时总是提示“已阻止不安全的下载”的解决方案

解决 Chrome 谷歌浏览器下载文件时提示“已阻止不安全的下载”的问题。 ‍ 前言 最近更新 Chrome 后&#xff0c;下载文件时总是提示“已拦截未经验证的下载内容”、“已阻止不安全的下载”&#xff1a; ‍ 身为一个互联网冲浪高手&#xff0c;这些提醒非常没有必要&#x…

RocketMQ延迟消息是如何实现的?

RocketMQ的延迟消息实现机制非常巧妙&#xff0c;其核心是通过多级时间轮 定时任务 消息重投递来实现的。以下是详细实现原理&#xff1a; ⏰ 一、延迟消息的核心设计 预设延迟级别&#xff08;非任意时间&#xff09; RocketMQ不支持任意时间延迟&#xff0c;而是预设了18个…

D3 面试题100道之(21-40)

这里是D3的面试题,我们从第 21~40题 开始逐条解答。一共100道,陆续发布中。 🟩 面试题(第 21~40 题) 21. D3 中的数据绑定机制是怎样的? D3 的数据绑定机制通过 selection.data() 方法实现。它将数据数组与 DOM 元素进行一一对应,形成三种状态: Update Selection:已…