Decoder-Only整体结构
我们以模型Llama-3.1-8B-Instruct
为例,打印其结构如下(后面会慢慢解析每一部分,莫慌):
LlamaForCausalLM((model): LlamaModel((embed_tokens): VocabParallelEmbedding(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)(layers): ModuleList((0-31): 32 x LlamaDecoderLayer((self_attn): LlamaAttention((qkv_proj): QKVParallelLinear(in_features=4096, output_features=6144, bias=False, tp_size=1, gather_output=False)(o_proj): RowParallelLinear(input_features=4096, output_features=4096, bias=False, tp_size=1, reduce_results=True)(rotary_emb): Llama3RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=131072, base=500000.0, is_neox_style=True)(attn): RadixAttention())(mlp): LlamaMLP((gate_up_proj): MergedColumnParallelLinear(in_features=4096, output_features=28672, bias=False, tp_size=1, gather_output=False)(down_proj): RowParallelLinear(input_features=14336, output_features=4096, bias=False, tp_size=1, reduce_results=True)(act_fn): SiluAndMul())(input_layernorm): RMSNorm()(post_attention_layernorm): RMSNorm()))(norm): RMSNorm())(lm_head): ParallelLMHead(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)(logits_processor): LogitsProcessor()(pooler): Pooler()
)
Decoder-Only处理流程
我们以Llama-3.1-8B-Instruct模型为例,结合一个具体的聊天对话场景,详细说明Decoder-Only模型的处理流程,从用户输入到最终输出回答。整个过程会逐步拆解,并标注每个步骤的输入输出形状(假设batch_size=1,seq_len=10,hidden_dim=4096,词表大小=128000)。
1. 用户输入与聊天模板处理
场景:用户问:“如何做西红柿炒鸡蛋?”
模型需求:需要根据历史对话和当前问题生成回答。
聊天模板处理
- 输入文本text:原始用户输入(如“如何做西红柿炒鸡蛋?”)
- 模板化prompt:模型需要将输入包装成特定格式的prompt,例如:
[系统指令]:你是一个烹饪助手,请回答以下问题。 [用户]:如何做西红柿炒鸡蛋? [助手]:
- 作用:模板化prompt让模型明确任务目标(如回答问题),并模拟对话上下文。
输入输出形状:
- 输入文本长度:假设为10个字符(实际长度取决于具体输入)。
- 模板化后的prompt长度:假设为30个字符(包含系统指令、用户问题和占位符)。
2. Tokenizer处理:从prompt到input_ids
步骤:
- Tokenization:将模板化prompt拆分为模型能理解的Token(如“西红柿”→“西红柿”,“炒”→“炒”)。
- 映射到input_ids:每个Token被映射为对应的ID(例如,“西红柿”→1234,“炒”→5678)。
示例:
假设模板化Prompt被拆分为10个Token,其input_ids为:
[101, 1234, 5678, 8901, 2345, 6789, 102, 3456, 7890, 102]
(其中101和102是特殊标记,如<BOS>
和<EOS>
,表示开始和结束)
输入输出形状:
input_ids
的形状为(batch_size, seq_len)
→(1, 10)
attention_mask
(可选)的形状为(1, 10)
,标记哪些位置是有效Token(1)或填充(0)。
3. 嵌入层:input_ids → hidden_states
步骤:
- Token Embedding:将input_ids映射为高维向量(如4096维)。
- Positional Encoding:添加位置信息,让模型知道每个Token在序列中的位置。
示例:
- input_ids
[101, 1234, 5678, ...]
→ 隐藏状态hidden_states
的形状为(1, 10, 4096)
。 - 每个Token对应的向量包含其语义和位置信息(例如,“西红柿”对应的食物相关特征,以及它在句子中的位置)。
输入输出形状:
hidden_states
的形状为(batch_size, seq_len, hidden_dim)
→(1, 10, 4096)
4. Decoder Block处理:逐层计算
核心流程:
-
Masked Self-Attention(带掩码的自注意力):
- 每个Token只能看到自己及之前的Token(防止“偷看”未来内容)。
- 例如,在生成“西红柿炒鸡蛋”时,模型会先处理“西红柿”,再处理“炒”,确保生成逻辑连贯。
-
前馈网络(FFN):
- 对每个Token的隐藏状态进行非线性变换,增强表达能力。
示例:
- 假设模型有32层Decoder Block,每层都会更新
hidden_states
。 - 最终的
hidden_states
保留了完整的上下文信息(如“西红柿炒鸡蛋”的步骤描述)。
输入输出形状:
- 每层Decoder Block的输入输出形状不变,仍为
(1, 10, 4096)
5. LM Head:从hidden_states到下一个词
步骤:
- 线性层:将最后一个Token的隐藏状态(形状为
(1, 10, 4096)
)映射到词表维度(128000)。- 例如,对最后一个位置(
seq_len=9
)的隐藏状态取值:hidden_states[:, 9, :]
→ 形状(1, 4096)
。
- 例如,对最后一个位置(
- Softmax:将输出转换为概率分布(每个词的概率)。
示例:
- 假设模型预测下一个词是“步骤一”,其ID为9876,则概率分布中9876的值最高。
输入输出形状:
- 线性层输出形状:
(1, 128000)
- 概率分布形状:
(1, 128000)
6. 采样策略:从概率分布到下一个词
方法:
- Top-k采样:从概率最高的前k个词(如k=50)中随机选一个。
- Greedy Search:直接选概率最高的词(如“步骤一”)。
示例:
- 模型选择“步骤一”作为下一个词,并将其ID(9876)添加到
input_ids
中。 - 新的
input_ids
变为:[101, 1234, 5678, ..., 9876]
(长度+1)。
输入输出形状:
- 新的
input_ids
形状为(1, 11)
7. 迭代生成:重复步骤3-6直到完成
流程:
- 将新的
input_ids
和hidden_states
送回Decoder Block。 - 重复计算,逐步生成完整回答(如“步骤一:热锅凉油…”)。
- 直到生成终止标记(如
<EOS>
)或达到最大长度(如2048 Token)。
示例:
- 生成完整回答后,
input_ids
的长度可能变为200(假设生成190个新Token)。 - 最终的
input_ids
包含原始Prompt和生成的回答。
8. Tokenizer反向处理:从input_ids到用户文本
步骤:
- 将生成的
input_ids
(含prompt和回答)截取回答部分(去掉prompt)。 - 使用Tokenizer将
input_ids
转换回自然语言文本(如“步骤一:热锅凉油…”)。
输入输出形状:
- 截取后的
input_ids
形状为(1, 190)
- 最终输出文本长度取决于生成内容(如“步骤一:热锅凉油…”)
总结流程图
用户输入 → 模板化Prompt → Tokenizer → input_ids (1,10) → 嵌入层 → hidden_states (1,10,4096) → Decoder Block ×32 → hidden_states (1,10,4096) → LM Head → 概率分布 (1,128000) → 采样 → 新input_ids (1,11) → 重复生成 → input_ids (1,200) → Tokenizer反向 → 用户文本
LlamaForCausalLM结构分析
以模型Llama-3.1-8B-Instruct
为例,将一部分子结构信息折叠起来,将显示如下:
LlamaForCausalLM((model): LlamaModel((embed_tokens): VocabParallelEmbedding(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)(layers): ModuleList((0-31): 32 x LlamaDecoderLayer(...))(norm): RMSNorm())(lm_head): ParallelLMHead(num_embeddings=128256, embedding_dim=4096, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)(logits_processor): LogitsProcessor()(pooler): Pooler()
)
可以看到LlamaForCausalLM
主要由几个关键部分组成:model, lm_head, logits_processor和pooler。这几个组件作用各不相同,我们现在来介绍一下他们。
1. model
:核心解码器结构
(1) embed_tokens
:词嵌入层
- 作用:将输入的Token ID(如“西红柿”→ID=1234)映射为4096维的向量,表示Token的语义和位置信息。
- 技术细节:
- 使用VocabParallelEmbedding(并行词嵌入,仅需了解,无需深入),支持分布式训练。
- 词表大小为128256,覆盖多语言和特殊符号(如
<BOS>
、<EOS>
)。
- 输入输出形状:
- 输入:
(batch_size, seq_len)
→(1, 10)
(假设输入10个Token) - 输出:
(batch_size, seq_len, hidden_dim)
→(1, 10, 4096)
- 输入:
(2) layers
:32层Decoder Block
- 核心结构:
- 多头注意力(MHA):通过Grouped-Query Attention (GQA) 提高推理效率(Llama 3.1新增)。
- 查询(Q)、键(K)、值(V)的维度:
d_model=4096
,num_heads=32
,head_dim=128
。 - GQA机制:将K/V头数减少为
num_key_value_heads=8
,降低计算开销。
- 查询(Q)、键(K)、值(V)的维度:
- 前馈网络(MLP):使用SwiGLU激活函数(Sigmoid + Gated Linear Unit),替代传统ReLU。
- 输入:
4096
维 → 中间层:11008
维 → 输出:4096
维。
- 输入:
- 归一化:每层使用RMSNorm(均方根归一化),稳定训练并加速收敛。
- 多头注意力(MHA):通过Grouped-Query Attention (GQA) 提高推理效率(Llama 3.1新增)。
- 输入输出形状:
- 每层输入/输出:
(1, 10, 4096)
(与输入形状一致)
- 每层输入/输出:
(3) norm
:最终归一化层
- 作用:对32层Decoder Block的输出进行最后一次归一化,确保数值稳定性。
- 技术细节:
- 使用RMSNorm,无需计算均值,直接对向量的模长标准化。
- 公式:
hidden_states = hidden_states / sqrt(variance + ε)
,其中ε=1e-6
。
2. lm_head
:语言模型头部
- 作用:将最终的隐藏状态(
hidden_dim=4096
)映射为词表大小(vocab_size=128256
)的概率分布,预测下一个词。 - 技术细节:
- 使用ParallelLMHead(并行线性层),加速大规模词表的计算。
- 参数量:
4096 × 128256 ≈ 5.16B
(占模型总参数量的约76%)。
- 输入输出形状:
- 输入:
(1, 4096)
(取最后一个位置的隐藏状态) - 输出:
(1, 128256)
(每个词的概率值)
- 输入:
3. logits_processor
:概率分布处理器
- 作用:对
lm_head
输出的概率分布进行后处理,控制生成策略。 - 常用功能:
- 温度调节(Temperature):降低温度(
<1
)使输出更确定,升高温度(>1
)增加多样性。 - Top-k/Top-p采样:从概率最高的
k
个词或累积概率达p
的词中随机选择,平衡质量和多样性。 - 重复惩罚(Repetition Penalty):抑制重复生成相同词(如避免“西红柿西红柿”)。
- 温度调节(Temperature):降低温度(
- 输入输出形状:
- 输入:
(1, 128256)
(原始概率分布) - 输出:
(1, 128256)
(处理后的概率分布)
- 输入:
4. pooler
:池化层
- 作用:将整个序列的隐藏状态压缩为固定长度的向量,用于下游任务(如分类、相似度计算)。
- 技术细节:
- 默认取第一个Token(如
<BOS>
)的隐藏状态作为全局表示。 - 或使用平均池化/最大池化,但Llama 3.1通常直接取
<BOS>
。
- 默认取第一个Token(如
- 输入输出形状:
- 输入:
(1, 10, 4096)
(全序列隐藏状态) - 输出:
(1, 4096)
(固定长度的全局向量)
- 输入:
总结:组件协同工作流程
- 输入处理:用户输入文本 → 模板化Prompt →
embed_tokens
→(1, 10, 4096)
- 特征提取:32层Decoder Block →
hidden_states
→(1, 10, 4096)
- 归一化:
norm
→ 稳定输出 - 生成预测:
lm_head
→(1, 128256)
概率分布logits_processor
→ 调整概率分布- 采样生成下一个词 → 更新
input_ids
- 迭代生成:重复步骤1-4,直到生成终止标记(
<EOS>
)或达到最大长度。 - 任务适配:
pooler
提取全局向量 → 用于分类、相似度等任务。
model
:像一个厨师,逐步处理食材(Token)并调整火候(注意力机制)。lm_head
:厨师的“味觉”,决定下一步该加什么调料(预测下一个词)。logits_processor
:厨房的“规则制定者”,确保菜谱不重复且口味可控。pooler
:食客的“总结笔记”,用一句话概括整道菜的风味(全局语义)。