知识蒸馏 Knowledge Distillation 序列的联合概率 分解成 基于历史的条件概率的连乘序列
flyfish
代码实践
论文 Generalized Knowledge Distillation (GKD)
On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes
自回归分解 将 “序列的联合概率” 分解成 “基于历史的条件概率的连乘序列”
自回归
以句子“你现在读的这句话”为例
首先明确句子的token时序顺序(从左到右,依次生成):
第1个token:你 → 第2个token:现 → 第3个token:在 → 第4个token:读 → 第5个token:的 → 第6个token:这 → 第7个token:句 → 第8个token:话
自回归的核心逻辑是:每个“当前token”只依赖“它前面所有已生成的历史token”(即左侧的、早于它的token),与右侧未生成的token无关。具体依赖关系如下:
当前token | 生成顺序 | 依赖的“历史token”(仅左侧已生成的) | 不依赖的内容(右侧未生成的) |
---|---|---|---|
你 | 第1步 | 无(第一个生成的token,仅依赖输入prompt,如“请说一句话:”) | 现、在、读、的、这、句、话 |
现 | 第2步 | 仅依赖第1个token“你” | 在、读、的、这、句、话 |
在 | 第3步 | 依赖第1-2个token“你、现” | 读、的、这、句、话 |
读 | 第4步 | 依赖第1-3个token“你、现、在” | 的、这、句、话 |
的 | 第5步 | 依赖第1-4个token“你、现、在、读” | 这、句、话 |
这 | 第6步 | 依赖第1-5个token“你、现、在、读、的” | 句、话 |
句 | 第7步 | 依赖第1-6个token“你、现、在、读、的、这” | 话 |
话 | 第8步 | 依赖第1-7个token“你、现、在、读、的、这、句” | 无(最后一个token) |
生成过程
语言模型生成“你现在读的这句话”的过程,就像人写字“一笔接一笔、从左到右”:
- 第一步:看到prompt(比如“请写一个日常场景的短句”),先写出第一个字“你”——此时只需要考虑“prompt要求”,不需要想后面要写什么;
- 第二步:写完“你”后,思考“‘你’后面接什么字合理”,选择“现”——只看已经写好的“你”,不看还没写的“在、读、的…”;
- 第三步:写完“你、现”后,思考“‘你现’后面接什么字合理”,选择“在”——只看已经写好的“你、现”,不看后面的“读、的…”;
- 以此类推,直到写完最后一个字“话”——每一步的选择,都只基于“左边已经写好的内容”,右侧的内容是“下一步才会生成的”,当前步骤完全无法预知,自然也无法依赖。
这里为了简单介绍一个字是一个token,实际会有多种token算法
自回归的“自”(自身历史),特指**“自身左侧已生成的时序历史”,而非“整个句子的所有内容”。原则是:
文本生成是“单向时序过程”,后面的token还没被模型生成,当前token只能“回头看左边的历史”,不能“超前看右边的未来”自回归语言生成是严格按“从左到右”的时序展开**,当前token只能依赖“它左边已经生成的历史token”,绝不可能依赖“右边还没生成的token”。
条件链式法则的序列形式
自回归分解的公式
P(y1,…,yL∣x)=∏n=1LP(yn∣y<n,x)P(y_1,\dots,y_L\mid x)=\prod_{n=1}^{L} P\!\left(y_n \mid y_{<n},x\right) P(y1,…,yL∣x)=n=1∏LP(yn∣y<n,x)
本质是条件概率链式法则的直接推广。
回顾基础链式法则:对于随机变量序列 Y1,Y2,…,YLY_1,Y_2,\dots,Y_LY1,Y2,…,YL 和给定的条件变量 XXX,它们的条件联合概率可以分解为一系列条件概率的乘积。这里的关键是:
- 每一步的条件概率 P(yn∣y<n,x)P(y_n \mid y_{<n},x)P(yn∣y<n,x) 只依赖于“之前所有已出现的变量 y<n=(y1,…,yn−1)y_{<n} = (y_1,\dots,y_{n-1})y<n=(y1,…,yn−1)”和“外部条件 xxx”;
- 整个序列的联合概率被拆解为 L个“局部条件概率”的连乘,将高维联合概率的计算转化为低维条件概率的乘积,大幅降低了建模难度。
为什么自回归分解是“最优解”?
语言模型的核心任务是建模文本序列的概率分布(如“一句话是否合理”“下一个词是什么”),而自回归分解完美适配语言的“时序特性”,主要原因有三点:
1. 符合语言的“顺序生成特性”
人类语言天然是时序序列:说话/写作时,我们总是“先说出第一个词,再根据第一个词说第二个词,以此类推”。自回归分解恰好模拟了这一过程——每个词 yny_nyn 的生成只依赖于“已经说过的词 y<ny_{<n}y<n”和“上下文提示 xxx”,与人类语言生成的直觉一致。
2. 降低建模复杂度
直接建模整个序列的联合概率 P(y1,…,yL∣x)P(y_1,\dots,y_L \mid x)P(y1,…,yL∣x) 几乎不可能:对于词汇表大小为 VVV 的语言,序列长度为 LLL 的可能组合有 VLV^LVL 种,无法直接枚举或存储。
而自回归分解将问题转化为逐一生成每个位置的词,只需建模 P(yn∣y<n,x)P(y_n \mid y_{<n},x)P(yn∣y<n,x)——这个条件概率的“输入维度”是“已生成序列 y<ny_{<n}y<n + 提示 xxx”,输出是“下一个词 yny_nyn 的概率分布”,可以用神经网络(如Transformer)高效建模。
3. 支持“增量生成”
语言模型的核心功能之一是生成文本(如续写、翻译),自回归分解天然支持“增量生成”:
- 第一步:根据提示 xxx 生成第一个词 y1y_1y1,即采样自 P(y1∣x)P(y_1 \mid x)P(y1∣x);
- 第二步:根据 xxx 和 y1y_1y1 生成 y2y_2y2,采样自 P(y2∣y1,x)P(y_2 \mid y_1, x)P(y2∣y1,x);
- …
- 第L步:根据 xxx 和 y1,…,yL−1y_1,\dots,y_{L-1}y1,…,yL−1 生成 yLy_LyL,最终得到完整序列。
自回归语言模型的生成逻辑——每一步只需要“看前面的词”,就能继续生成下一个词。
语言模型如何“学习”这个分解?
语言模型的训练目标,本质上是通过海量文本数据,学习自回归分解中的每个条件概率 P(yn∣y<n,x)P(y_n \mid y_{<n},x)P(yn∣y<n,x)。
具体来说:
- 训练数据是大量文本序列(如句子、段落),每个序列可视为 (x,y1,…,yL)(x, y_1,\dots,y_L)(x,y1,…,yL)(其中 xxx 可能是序列的前缀,或空提示);
- 模型通过“预测下一个词”的任务学习条件概率:给定 xxx 和 y<ny_{<n}y<n,模型输出对 yny_nyn 的概率分布,通过交叉熵损失优化,让模型的预测尽可能接近真实文本中 yny_nyn 的出现概率;
- 训练完成后,模型就能对“任意前缀序列”输出下一个词的合理概率分布,从而支持生成连贯的文本。
一句话的自回归分解
以生成句子“我爱吃苹果”(假设 xxx 为空提示,即生成独立句子)为例,其概率可分解为:
P(我,爱,吃,苹果)=P(我)×P(爱∣我)×P(吃∣我,爱)×P(苹果∣我,爱,吃)P(\text{我},\text{爱},\text{吃},\text{苹果}) = P(\text{我}) \times P(\text{爱} \mid \text{我}) \times P(\text{吃} \mid \text{我},\text{爱}) \times P(\text{苹果} \mid \text{我},\text{爱},\text{吃}) P(我,爱,吃,苹果)=P(我)×P(爱∣我)×P(吃∣我,爱)×P(苹果∣我,爱,吃)
- P(我)P(\text{我})P(我):第一个词是“我”的概率(语言模型中,通常会给句首词一个先验分布);
- P(爱∣我)P(\text{爱} \mid \text{我})P(爱∣我):在“我”之后接“爱”的概率(比如“我”后面更可能接“爱”“是”“想”等词,而不是“苹果”“跑步”);
- 以此类推,每个步骤的概率都依赖于前面的词,最终乘积就是整个句子的联合概率。
自回归分解是将“序列的联合概率”分解成“基于历史的条件概率的连乘序列”,而这个分解过程是条件概率链式法则——但它不是“任意的链式法则”,而是加了“自回归约束”的链式法则应用。
第一步:先明确“自回归分解分解的是什么?”
自回归分解的核心目标,是把“一个token序列(比如句子)的整体概率”拆解开,变成“每一步生成token的局部概率”,方便语言模型计算和预测。
比如,对于一个长度为n
的token序列 X = [x₁, x₂, x₃, ..., xₙ]
(x₁
是第一个token,xₙ
是最后一个token),我们要分解的是这个序列的联合概率 P(X) = P(x₁, x₂, x₃, ..., xₙ)
。
为什么要分解?因为直接计算“整个序列同时出现”的联合概率非常困难(可能性太多),但分解成“每一步基于前面内容的条件概率”后,模型只需要预测“当前token在历史token下的概率”,难度会大幅降低。
第二步:自回归分解如何用“条件概率链式法则”?
条件概率链式法则是概率论的基础规则:多个随机变量的联合概率,可分解为第一个变量的边缘概率,乘以第二个变量在第一个变量条件下的概率,再乘以第三个变量在第一、二个变量条件下的概率……以此类推。
对序列 X = [x₁, x₂, ..., xₙ]
,通用的条件概率链式法则是这样的:
P(x1,x2,...,xn)=P(x1)×P(x2∣x1)×P(x3∣x1,x2)×P(x4∣x1,x2,x3)×...×P(xn∣x1,x2,...,xn−1)P(x₁, x₂, ..., xₙ) = P(x₁) × P(x₂ | x₁) × P(x₃ | x₁, x₂) × P(x₄ | x₁, x₂, x₃) × ... × P(xₙ | x₁, x₂, ..., xₙ₋₁) P(x1,x2,...,xn)=P(x1)×P(x2∣x1)×P(x3∣x1,x2)×P(x4∣x1,x2,x3)×...×P(xn∣x1,x2,...,xn−1)
而自回归分解,就是完全遵循这个链式法则,但额外加了一个“自回归约束”:
约束:对于第k
个token xₖ
,它的条件依赖只能是“它前面所有已生成的历史token(x₁到xₖ₋₁)”,绝对不能依赖“它后面未生成的token(xₖ₊₁到xₙ)”。
自回归分解没有“发明新的分解规则”,而是把“条件概率链式法则”直接用在了“时序序列”上——因为语言生成是“从左到右、先有历史再有当前”的单向过程,后面的token还没生成,自然无法成为当前token的依赖,这就和链式法则中“xₖ只依赖x₁到xₖ₋₁”的形式完美匹配。
第三步:用具体例子看“自回归分解的结果”
还是用之前的句子“你现在读的这句话”,对应的token序列 X = [x₁=你, x₂=现, x₃=在, x₄=读, x₅=的, x₆=这, x₇=句, x₈=话]
。
根据自回归分解(基于条件链式法则),这个序列的联合概率会被分解成以下条件概率的连乘序列:
P(你,现,在,读,的,这,句,话)=P(你)(第一个token,无历史,只看边缘概率)×P(现∣你)(第二个token,依赖前1个历史token“你”)×P(在∣你,现)(第三个token,依赖前2个历史token“你、现”)×P(读∣你,现,在)(第四个token,依赖前3个历史token)×P(的∣你,现,在,读)(第五个token,依赖前4个)×P(这∣你,现,在,读,的)(第六个token,依赖前5个)×P(句∣你,现,在,读,的,这)(第七个token,依赖前6个)×P(话∣你,现,在,读,的,这,句)(第八个token,依赖前7个)\begin{align*} P(你,现,在,读,的,这,句,话) &= P(你) \quad \text{(第一个token,无历史,只看边缘概率)} \\ &\times P(现 \mid 你) \quad \text{(第二个token,依赖前1个历史token“你”)} \\ &\times P(在 \mid 你,现) \quad \text{(第三个token,依赖前2个历史token“你、现”)} \\ &\times P(读 \mid 你,现,在) \quad \text{(第四个token,依赖前3个历史token)} \\ &\times P(的 \mid 你,现,在,读) \quad \text{(第五个token,依赖前4个)} \\ &\times P(这 \mid 你,现,在,读,的) \quad \text{(第六个token,依赖前5个)} \\ &\times P(句 \mid 你,现,在,读,的,这) \quad \text{(第七个token,依赖前6个)} \\ &\times P(话 \mid 你,现,在,读,的,这,句) \quad \text{(第八个token,依赖前7个)} \\ \end{align*} P(你,现,在,读,的,这,句,话)=P(你)(第一个token,无历史,只看边缘概率)×P(现∣你)(第二个token,依赖前1个历史token“你”)×P(在∣你,现)(第三个token,依赖前2个历史token“你、现”)×P(读∣你,现,在)(第四个token,依赖前3个历史token)×P(的∣你,现,在,读)(第五个token,依赖前4个)×P(这∣你,现,在,读,的)(第六个token,依赖前5个)×P(句∣你,现,在,读,的,这)(第七个token,依赖前6个)×P(话∣你,现,在,读,的,这,句)(第八个token,依赖前7个)
这个连乘序列,就是“自回归分解”的最终结果——它完全是条件概率链式法则在“语言生成时序”下的直接体现,每一项都是“当前token基于历史的条件概率”,没有任何超出历史的依赖。
自回归分解与条件链式法则的关系
结论 | 具体解释 |
---|---|
分解的对象 | 序列的联合概率(如P(x₁,x₂,...,xₙ) ) |
分解的工具 | 条件概率链式法则(概率论基础规则) |
分解的约束 | 自回归约束:xₖ 仅依赖x₁~xₖ₋₁ (左侧历史,无未来依赖) |
分解的结果 | 一个条件概率的连乘序列(每一项对应一步生成的概率) |