LLM场景下的强化学习【PPO】

适合本身对强化学习有基本了解

一、什么是强化学习

一句话:在当前状态(State)下,智能体(Agent)与环境(Environment)交互,并采取动作(Action)进入下一状态,过程中获得奖励(Reward,有正向有负向),从而实现从环境中学习。

在LLM场景下,提到RL一般是指RLHF(人类偏好对齐),此时上述关键概念介绍如下:

  • Agent: 语言模型本身,例如GPT、LLaMA。

  • Environment: 训练阶段,环境是奖励模型RM,它基于人类标注的偏好数据对生成的文本评分。在部署阶段,环境是真实用户,用户的实际反馈(如点击、停留时间、修正等)作为隐式奖励。

  • State: 当前文本上下文(Prompt+已生成的Token序列),State会受到模型的上下文窗口限制,需要注意力机制捕捉长程依赖。

  • Action: 生成下一个Token。动作空间是整个词表,每一步选择一个token生成文本。特点是actions具备强依赖性,单个action的奖励不好评估,需要对action序列进行评估。

  • Reward: 奖励模型对完整响应的打分,评价维度一般包含文本的流畅性、相关性、准确性、用户满意度等

二、强化学习基本概念

强化学习的目标是最大化累积奖励的期望

1、马尔科夫决策过程(Markov Decision Process,MDP

强化学习过程是一个马尔可夫决策过程,这是指下一个状态和奖励仅依赖于当前状态和动作,即满足马尔可夫性。

2、动作空间A

整个词表

3、状态空间S

当前环境中所有可能状态的集合。

状态 = 当前文本序列(Prompt + 已生成的Token序列)。

因此状态空间理论上接近无限,但实际会受到上下文窗口大小的限制。

例如,若上下文窗口为2048,词表大小为50k,则状态空间规模约为  {(50k)}^{2048}

4、策略

智能体在某一状态下选择动作的准则,一般有“确定性策略”和“随机性策略”,在LLM场景下,由于每个动作输出的是词表的概率分布,所以都采用“随机性策略”。

5、运动轨迹

智能体与环境交互过程中产生的一系列状态、动作和奖励的序列。

LLM场景下的特点是中间动作的奖励均为0,仅最后一个动作的奖励有意义。这种设定会导致稀疏奖励问题 (Sparse Rewards),使得策略难以学习中间token的长期价值。

6、累积奖励

累积奖励(Total Reward)是轨迹中所有奖励的未折扣总和,即 G = \sum _{t=0}^Tr_t由于LLM场景下,中间奖励全部为0,所以累积奖励等于最后一步的奖励。

(ps:从概念的定义来推断,能得出这个结论,但是还不确定是否存在中间奖励不为0的情况)

还需要提到的概念是累积折扣奖励G = \sum _{t=0}^T \gamma^t r_t会为奖励增加一个取值在[0,1]之间的折扣因子,作用是更关注近期奖励。

7、回报

从某一时刻开始的所有未来奖励的累积,一般使用累积折扣奖励作为回报。

G_t = \sum _{k=0}^T \gamma^k r_{t+k}

8、价值函数

  • 状态价值函数 V^\pi(s):评估当前文本状态的长期预期奖励(如:"如何泡茶? 先烧"的价值)。

  • 动作价值函数 Q^\pi(s,a):评估特定Token选择的价值(如:在状态"如何泡茶? 先"下,生成"烧" vs "准"的预期收益)。

这里我们可以发现动作价值函数和状态价值函数之间是有关联的:

Q^\pi(s,a) = E[r_t + \gamma V^\pi(s_{t+1})]

举例理解就是,在状态"如何泡茶? 先"下,采取动作"烧"带来的长期预期价值,等于"烧"这个动作的奖励 + "如何泡茶? 先烧"这个状态的状态价值 的期望。这里取期望是因为"烧"这个动作的奖励是不确定的,因为我们的策略在更新,动作的奖励也会变化。只是在实际中为了简化计算,会直接把期望给去掉,得到:

Q^\pi(s,a) = r_t + \gamma V^\pi(s_{t+1})

9、优势

优势定义为动作价值函数 减去 状态价值函数

A^\pi(s,a) = Q^\pi(s,a) - V^\pi(s) = r_t + \gamma V^\pi(s_{t+1}) - V^\pi(s_{t})

衡量的是在状态s下,执行动作a比执行其他动作的优势

让我们用一个马里奥游戏的例子来直观理解,为什么优势=动作价值函数-状态价值函数

  • 假设你在玩马里奥游戏,你来到了画面的某一帧。

  • 你在这一帧下有3个选择:顶金币,踩乌龟,跳过乌龟。你现在想知道执行“顶金币”的动作比别的动作好多少。

  • 你先执行了“顶金币”的动作(即现在你采取了某个确定的策略),在这之后你又玩了若干回合游戏。在每一次回合开始时,你都从(这一帧,顶金币)这个状态-动作对出发,玩到游戏结束。在每一回合中,你都记录下从(这一帧,顶金币)出发,一直到回合结束的累积奖励。你将这若干轮回合的奖励求平均,就计算出从(这一帧,顶金币)出发后的累积奖励期望,即动作价值函数 Q^\pi(s,a)

  • 现在你重新回到这一帧,对于“顶金币”,“踩乌龟”,“跳过乌龟”这三个动作,你按照当前的策略(对于每个回合,你有你自己当前的策略),从这三者中采样动作(注意,我们没有排除掉“顶金币”),并继续玩这个游戏直到回合结束,你记录下从这一帧出发一直到回合结束的累积回报。重复上面这个过程若干次,然后你将这若干轮回合的奖励求平均,就计算出从这一帧出发后的累积奖励期望,我们记其为 V^\pi(s)

  • 你会发现不管是Q还是V,下标都有一个 π ,这是因为它们和你当前采取的策略是相关的。

  • 从直觉上,我们取 A^\pi(s,a) = Q^\pi(s,a) - V^\pi(s) 这个差值,就可以衡量在某个状态下,执行某个动作,要比其它的动作好多少了。这个差值,我们可以理解为“优势”(advantage),这个优势更合理地帮助我们衡量了单步的奖励。

  • 当优势越大时,说明一个动作比其它动作更好,所以这时候我们要提升这个动作的概率

10、GAE优势

GAE(Generalized Advantage Estimator,广义优势估计器),目的是解决优势计算中的方差-偏差的问题。

在计算优势时,如果依赖局部近似,也就是TD优势(时序差分优势)的方案,可能会忽略长期回报,导致策略更新方向出错(偏差大问题);

如果依赖完整运动轨迹,也就是MC优势(蒙特卡洛优势)的方案,能够反映长期回报,但是对噪声敏感(方差大问题)

GAE通过引入一个新的超参数λ来平衡方差-偏差的问题

当λ接近0时,GAE优势会更依赖短期步骤(低方差),从而缓解方差大的问题,但会忽略长期信息导致高偏差;

当λ接近1时,GAE优势会更依赖未来多个步骤(低偏差),从而缓解偏差大的问题,但累积噪声会放大方差。

按照经验,λ的取值一般在0.9~0.99.

上面这个描述可能不好理解,下面详细解释下:

首先我们需要明确下,计算“优势”的目的,就是为了判断当前状态下,我们采取的某个动作到底好不好。判断的方式直观来看有两种,第一种是看采取这个动作后能不能马上带来收益;第二种是看采取这个动作后对我长期的收益有什么影响。

然后来对比三种优势的计算公式:

A_t^{GAE} = \sum _{l=0}^\infty (\gamma \lambda ) ^l\delta _{t+l}

A_t^{TD} = \delta _{t} = r_t + \gamma V^\pi(s_{t+1}) - V^\pi(s_{t})

A_t^{MC} = G_t - V(s_t)

从公式中,可以直观看到,TD优势只依赖于当前状态和下一状态的状态价值,以及当前状态下,所采取的动作带来的奖励 r_t

所以我们说TD优势依赖局部近似,也就是说它看的是“采取这个动作后能不能马上带来收益”,这种判断方式的问题也很明显,有时短期能带来收益,但是长期来看这个动作并不好。

比如我们的LLM要生成一句话「“如何泡茶?先烧开水,再准备茶叶”」,现在LLM已经生成了「“如何泡茶?先”」,接下来要从{“烧”、“准”}中选择一个动作,计算TD优势,发现“准”的短期收益更大,那么策略就会向着让“准”生成概率更大的方向来更新,最后生成的答案变成了「“如何泡茶?先准备茶叶,再烧开水”」,从长期来看,这个策略是不对的,整个策略的更新方向出错。将这个问题称为“偏差大”问题。

再看看MC优势的公式,使用当前状态的回报减去当前状态的价值,需要注意下,回报是我们根据采样得到的实际得到的奖励,而状态价值是我们使用Critic模型估计得到的预期奖励,所以两者之差其实是实际奖励与预期之间的差值。如果实际奖励比预期要好,说明我们采取的动作带来的未来收益比平均未来收益要高,应该在策略更新中增加采取该动作的概率。

这种方法的问题在于计算 G_t 时,我们是根据当前状态和当前策略,采样了一条运动轨迹,这种采样实际上随机性很大,因为对于未来的每一个动作,实际都是“概率分布”,我们对每个动作的采样,其实都是一个随机事件,两次采样的轨迹可能对每个动作的选择都不同。

所以MC优势的缺点是采样到的运动轨迹对噪声敏感,用MC优势指导策略更新会时好时坏,存在“方差大”问题。

最后看GAE公式,当λ接近0时,只有l=0的项有意义,退化成TD优势公式。当λ接近1时,退化成MC优势公式。所以λ越小,越能缓解方差问题,λ越大,越能缓解偏差问题,我们可以找到一个合适的λ值,从而在方差问题与偏差问题中找到一个平衡点。

11、策略更新方法

  • 基于策略的方法:以往方法是基于策略梯度,引入可学习参数θ,然后构造目标函数。在LLM场景下,一般使用PPO算法。

三、PPO算法

PPO算法的实现方式有两种:PPO-Penalty和PPO-Clip

  • PPO-Penalty:使用KL散度衡量新策略与旧策略之间的分布差异,作用是作为惩罚项,如果KL散度大,则惩罚也大,会减少策略更新幅度,否则惩罚小,对策略更新的约束也小。

  • PPO-Clip:引入一个参数ε,当策略更步长大于1+ε或者小于1-ε时,直接采取裁剪,将步长限制在1-ε到1+ε之间。由于Clip方法效果更好,所以LLM场景中默认使用Clip

PPO算法中有两个很重要的目标函数,至于为什么会有两个,在后面会讲

1. Actor目标函数/策略目标函数

 

公式中前面的项对应绿色线,后面的Clip项对应蓝色线,取min操作,得到的是红色线。

当A>0时,说明当前动作是优于平均的,我们应该让策略向着以后多采取这个动作的方向(x轴正向)更新,但是当新策略更新到与旧策略的比值等于1+ε时,我们就停止继续正向更新,从而实现策略更新的约束。

当A<0时,说明当前动作是低于平均的,我们就应该让更新后的策略降低采取当前动作的概率,也即向着x轴负向更新,同理,更新到新旧策略比值等于1-ε时,就停止负向更新。 

举个例子,模型在状态s时,可以采取的动作有输出{"you", "me", "he"}。

此时假设采取动作为输出"me",然后计算优势 A^\pi(s, ``me") ,假设值是0.8,说明"me"这个动作是优于平均的。

然后我们还需要计算旧策略下,在状态s的条件下采取动作"me"的概率,\pi_{old}(``me"|s_t) ,假设值是0.5

以及在当前策略下,在状态s的条件下采取动作"me"的概率,\pi_{t}(``me"|s_t) ,假设值是0.6

按照我们的目标函数,由于A>0,所以我们是希望在接下来更新的策略中,在状态s下,采取动作"me"的概率能够提高

假设超参数ε设置为0.3

当前策略与旧策略的概率比 r_t(\theta) = \frac{\pi_\theta(``me"|s_t)}{\pi_{old}(``me"|s_t)} = \frac{0.6}{0.5} = 1.2 < 1 + \epsilon = 1.3 ,所以在接下来的若干次策略更新中,我们的理想状态是,最终版本的策略,在状态s下,采取动作"me"的概率等于0.65,也就是达到新旧策略概率比上限1.3

2. Critic目标函数/价值目标函数

非常复杂的公式,要方便理解只需要看最后一个公式。

其中R_t是真实价值,等于在状态s下,旧策略的状态价值 + 当前策略的GAE优势(相当于当前策略带来的价值增长)。

目标函数的目标可以这么理解,在状态s下,让优化后的Critic Model预测当前策略的状态价值V_\pi(s_t),使其尽可能逼近真实价值R_t

Clip项的作用是约束更新后的模型预测出的状态价值不会过度偏离初始Critic模型预测出的状态价值。

四、RLHF-PPO

下图参考:https://zhuanlan.zhihu.com/p/13467768873

 整个过程大致三个阶段:

阶段1:基于预训练模型,用监督数据训练一个SFT model,用人类偏好数据训练一个Reward model。

阶段2:PPO过程,部署下述四个模型

来源model用途是否更新
SFT ModelActor Model要被优化的大模型,要学习的策略模型。更新
Reference Model为了控制Actor学习的分布与原始模型的分布相差不会太远的参考模型(通过loss中增加KL项来达实现)。不更新
Reward ModelCritic Model对每个状态做打分的价值模型,预估总收益 Vt。更新
Reward Model事先训练好的ORM(Outcome Reward Model),对整个生成的结果打分,计算即时收益 Rt。不更新

阶段3: 

大致流程如上图所示,逐个步骤来拆解:

(1)采样一批prompt作为输入;每一个prompt其实就是一个Environment。

(2)调用Actor Model(policy model),输出预测token概率分布(即输出动作);需注意,仅第一步的输入是q(prompt),后续的输入是 o_t(prompt+预测token);

(3)调用Reference Model,输入是 o_t(prompt+预测token),输出是旧策略(即最开始的策略)预测的token概率分布。Actor Model与Reference Model输出的新旧策略的概率分布会计算概率比r;需注意,采取clip策略,此处不需要完整计算KL。

r(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{ref}(a_t|s_t)}

(4)调用Reward Model,输入是完整的qa对(prompt+完整答案),输出是完整答案的得分,也即最终奖励 R_TR_T 作用是用于计算每个每个时间步的奖励 r_t

(5)调用Value Model,输入是 o_t,输出是当前策略的状态价值  V_\pi(s_t)。需注意,旧策略的状态价值 V_\pi^{old}(s) 是在最开始完成计算并保存。

(6)通过状态价值 V_\pi(s_t) 和 R_T,可以计算GAE优势 A_t^{GAE}

(7)通过 r = \frac{\pi_\theta(a_t|s_t)}{\pi_{ref}(a_t|s_t)}和 A_t^{GAE},可以代入Actor目标函数计算策略损失;通过 V_\pi(s_t)V_\pi^{old}(s)和 A_t^{GAE},可以代入Critic目标函数计算价值损失。从而完成两个模型的更新。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/913480.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/913480.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python爬虫实战:研究chardet库相关技术

1. 引言 1.1 研究背景与意义 在互联网信息爆炸的时代,网络数据采集技术已成为信息获取、数据分析和知识发现的重要手段。Python 作为一种高效的编程语言,凭借其丰富的第三方库和简洁的语法,成为爬虫开发的首选语言之一。然而,在网络数据采集中,文本编码的多样性和不确定…

回溯题解——全排列【LeetCode】

46. 全排列 一、算法逻辑&#xff08;逐步通顺讲解每一步思路&#xff09; 该算法使用了典型的 回溯&#xff08;backtracking&#xff09; 状态数组 思路&#xff0c;逐层递归生成排列。 题目目标&#xff1a;给定一个无重复整数数组 nums&#xff0c;返回其所有可能的全排…

RICE模型或KANO模型在具体UI评审时的运用经验

模型是抽象的产物,结合场景才好说明(数据为非精确实际数据,仅供参考,勿照搬)。 ​​案例一:RICE模型解决「支付流程优化」vs「首页动效升级」优先级争议​​ ​​背景​​:APP电商模块在迭代中面临两个需求冲突——支付团队主张优化支付失败提示(减少用户流失),设计…

缓存中间件

缓存与分布式锁 即时性、数据一致要求不高的 访问量大且更新频率不高的数据 &#xff08;读多&#xff0c;写少&#xff09; 常用缓存中间件 redis Spring 如果用spring的情况下&#xff0c;由于redis没有受spring的管理&#xff0c; 则我们需要自己先写一个redis的配置类&…

大语言模型全方位解析:从基础认知到RESTful API应用

文章目录 前言一、初见大模型1.1 大语言模型基本知识了解&#xff08;一&#xff09;日常可能用到的大语言模型&#xff08;二&#xff09;大模型的作用&#xff08;三&#xff09;核心价值 1.2 大模型与人工智能关系1.3 大语言模型的“前世今生”与发展1.3.1 大语言模型的发展…

网安系列【11】之目录穿越与文件包含漏洞详解

文章目录 前言一 目录穿越漏洞1.1 什么是目录穿越&#xff1f;1.2 目录穿越的原理1.3 目录穿越的常见形式1.3.1 基本形式1.3.2 编码绕过1.3.3 绝对路径攻击 1.4 实战案例解析1.4.1 案例1&#xff1a;简单的目录穿越1.4.2 案例2&#xff1a;编码绕过 1.5 目录穿越的危害 二、文件…

uri-url-HttpServletRequest

1. 使用HttpServletRequest UrlPathHelper 解析 出 url路径 org.springframework.web.util.UrlPathHelper 是 Spring 框架中用于处理 HTTP 请求路径的一个工具类&#xff0c;它帮助解析和处理与请求路径相关的细节。特别是 getLookupPathForRequest(HttpServletRequest request…

Ubuntu22.04安装p4显卡 nvidia-utils-570-server 570.133.20驱动CUDA Version: 12.8

Ubuntu22.04安装p4显卡 nvidia-utils-570-server 570.133.20驱动CUDA Version: 12.8专业显卡就是专业显卡&#xff0c;尽管p4已经掉到了白菜价&#xff0c;官方的支持却一直都保持&#xff0c;比如它可以装上cuda12.8,这真的出乎我意料。NVIDIA Tesla P4显卡的主要情况Pascal架…

工业日志AI大模型智能分析系统-前端实现

目录 主要架构 前端项目结构 1. 核心实现代码 1.1 API服务封装 (src/api/log.ts) 1.2 TS类型定义 (src/types/api.ts) 1.3 Pinia状态管理 (src/stores/logStore.ts) 1.4 日志分析页面 (src/views/LogAnalysis.vue) 1.5 日志详情组件 (src/components/LogDetail.vue) 2…

C++内存泄漏排查

引言 C内存泄漏问题的普遍性与危害内存泄漏排查大赛的背景与目标文章结构和主要内容概述 内存泄漏的基本概念 内存泄漏的定义与类型&#xff08;显式、隐式、循环引用等&#xff09;C中常见的内存泄漏场景&#xff08;指针管理不当、资源未释放等&#xff09;内存泄漏对程序性能…

20250706-4-Docker 快速入门(上)-常用容器管理命令_笔记

一、常用管理命令1. 选项&#xfeff;&#xfeff;1&#xff09;ls&#xfeff;功能&#xff1a;列出容器常用参数&#xff1a;-a&#xff1a;查看所有容器包含退出的-q&#xff1a;列出所有容器ID-l&#xff1a;列出最新创建的容器状态使用技巧&#xff1a;容器很多时使用dock…

基于 Camunda BPM 的工作流引擎示例项目

项目介绍 这是一个基于 Camunda BPM 的工作流引擎示例项目&#xff0c;包含完整的后台接口和前端页面&#xff0c;实现了流程的设计、部署、执行等核心功能。 技术栈 后端 Spring Boot 2.7.9Camunda BPM 7.18.0MySQL 8.0JDK 1.8 前端 Vue 3Element PlusBpmn.jsVite 功能…

Day06_刷题niuke20250707

试卷01&#xff1a; 单选题 C 1. 在C中,一个程序无论由多少个源程序文件组成,其中有且仅有一个主函数main().说法是否正确&#xff1f; A 正确 B 错误 正确答案&#xff1a;A 官方解析&#xff1a; 在C程序设计中,一个完整的程序确实有且仅有一个main函数作为程序的入口点,这…

洛谷 P5788 【模板】单调栈

题目背景模板题&#xff0c;无背景。2019.12.12 更新数据&#xff0c;放宽时限&#xff0c;现在不再卡常了。题目描述给出项数为 n 的整数数列 a1…n​。定义函数 f(i) 代表数列中第 i 个元素之后第一个大于 ai​ 的元素的下标&#xff0c;即 f(i)mini<j≤n,aj​>ai​​{…

linux系统运行时_安全的_备份_还原_方法rsync

1.问题与需求 问题: 新部署的机器设备(主控RK3588), 没有经过烧录定制镜像, 研发部署, 直接组装发送到客户现场需要通过frpc远程部署: 安装ros2 python包 docker镜像 环境配置 自启动配置 SN设备信息写自动部署脚本, 实现一键部署升级无奈物联网卡做了白名单限制, apt 和…

18套精美族谱Excel模板,助力家族文化传承!

【资源分享】18套精美族谱Excel模板&#xff0c;助力家族文化传承&#xff01; &#x1f3af; 本文分享一套完整的家族谱系资源&#xff0c;包含18个精心设计的Excel模板&#xff0c;从基础模板到专业图表&#xff0c;满足各类家族的族谱制作需求。 一、为什么要制作族谱&…

MySQL Galera Cluster企业级部署

一、MySQL Galera Cluster简介 主要特点 同步复制&#xff1a; 所有的写操作&#xff08;包括插入、更新、删除&#xff09;在集群中的所有节点上都是同步的。这意味着每个节点上的数据是完全一致的。 多主节点&#xff1a; 集群中的每个节点都是主节点。所有节点都可以处理读…

HTTP 重定向

什么是 HTTP 重定向&#xff1f; HTTP 重定向&#xff08;HTTP Redirect&#xff09; 是服务器向客户端&#xff08;通常是浏览器&#xff09;发出的指令&#xff0c;告诉客户端某个请求的资源已被移到新的位置。重定向通常通过发送一个特殊的 HTTP 状态码&#xff08;例如 3x…

本地加载非在线jar包设置

项目中存在私有jar包&#xff0c;提示在线获取不到&#xff0c;需要先获取到完整的jar包在打进maven中再在项目中进行maven依赖引入 mvn install:install-file -DfileD:\tools\maven\apache-maven-3.5.2\local_repository2\org\ahjk\SixCloudCommon\1.0\SixCloudCommon-1.0-SN…

Codeforces Round 979 (Div. 2)

A c[1]-b[1]0&#xff0c;之后每个c[1]-b[1]最大都是maxa-mina&#xff0c;最大和最小放前两个 B ans2^(a1)-2^s-1&#xff0c;1一个最小 C 我们可以把式子化为(....)||(....)||(....)括号里没有||&#xff0c;如果括号全是1那么A赢&#xff0c;A尽量选择把1选在一起 D …