语言模型 RLHF 实践指南(一):策略网络、价值网络与 PPO 损失函数


在使用 Proximal Policy Optimization(PPO)对语言模型进行强化学习微调(如 RLHF)时,大家经常会问:

  • 策略网络的动作概率是怎么来的?
  • 价值网络的得分是如何计算的?
  • 奖励从哪里来?损失函数怎么构建?
  • 微调后的旧轨迹还能用吗?

这篇文章将以语言模型强化学习微调为例,结合实际实现和数学公式,深入解析 PPO 的关键计算流程。


1️⃣ 策略网络:如何计算动作概率?

策略网络 πθ(a∣s)\pi_\theta(a|s)πθ(as) 用于给出状态 sss 下采取动作 aaa 的概率。

对于语言模型(如 GPT)来说:

  • 状态 sss:Prompt(如“请介绍量子计算”)
  • 动作 aaa:生成的回答(如“量子计算是一种…”)

策略网络的输出是 token 级别的 logits,经 softmax 后得到概率:

outputs = model(input_ids)
logits = outputs.logits                         # [batch_size, seq_len, vocab_size]
probs = torch.softmax(logits, dim=-1)           # 得到 token 概率

对于一个完整回答,其概率为:

πθ(a1:T∣s)=∏t=1Tπθ(at∣s,a<t) \pi_\theta(a_{1:T} | s) = \prod_{t=1}^T \pi_\theta(a_t | s, a_{<t}) πθ(a1:Ts)=t=1Tπθ(ats,a<t)

该概率在 PPO 中用于计算策略概率比:

rt=πθ(at∣st)πθold(at∣st) r_t = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} rt=πθold(atst)πθ(atst)


2️⃣ 价值网络:如何计算状态得分?

价值网络 Vϕ(s)V_\phi(s)Vϕ(s) 预测的是状态 sss 的期望累计奖励,即该 prompt + 回复的“好坏”。

实现方式通常是共享模型底座 + 线性输出层:

hidden_states = outputs.hidden_states         # [batch_size, seq_len, hidden_dim]
value = value_head(hidden_states).squeeze(-1) # 每个 token 对应一个值

通常使用最后一个 token 的 value 作为整段文本的状态值:

Vϕ(s)=value(last_token) V_\phi(s) = \text{value}(\text{last\_token}) Vϕ(s)=value(last_token)
也可以做 mean pooling 等方式。


3️⃣ 奖励函数:怎么定义?

PPO 是一个基于奖励优化的强化学习算法。对于语言模型,一般使用人工偏好、打分器、奖励模型(RM)来计算奖励 RRR

示例方式:

  • 高质量回答奖励高,例如 R=+4R = +4R=+4
  • 差的回答奖励低,例如 R=+1R = +1R=+1
  • 或者使用两个回复的相对排序值差距(ranking loss)

PPO 使用奖励和预测值来计算优势函数(Advantage):

A^t=Rt−Vϕ(st) \hat{A}_t = R_t - V_\phi(s_t) A^t=RtVϕ(st)

也可以用 GAE(广义优势估计)进一步平滑优势值。


4️⃣ PPO 策略损失函数:如何构建?

核心损失函数如下(Clipped Surrogate Objective):

Lpolicy=−Et[min⁡(rtA^t,clip(rt,1−ϵ,1+ϵ)A^t)] L^{\text{policy}} = -\mathbb{E}_t \left[ \min \left( r_t \hat{A}_t, \text{clip}(r_t, 1 - \epsilon, 1 + \epsilon) \hat{A}_t \right) \right] Lpolicy=Et[min(rtA^t,clip(rt,1ϵ,1+ϵ)A^t)]

解释:

  • rtr_trt 是策略概率比
  • A^t\hat{A}_tA^t 是优势函数
  • ϵ\epsilonϵ 是截断系数(常用 0.2)

这个损失保证了策略更新不能偏离旧策略太远,防止训练不稳定。

🔍 第一次微调时,rt=1r_t = 1rt=1

由于初始时当前策略与旧策略相同,有:

rt=πθ(at∣st)πθold(at∣st)=1 r_t = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} = 1 rt=πθold(atst)πθ(atst)=1

所以第一次策略更新实际变成:

Lpolicy=−A^t L^{\text{policy}} = -\hat{A}_t Lpolicy=A^t

相当于标准的策略梯度算法。


5️⃣ PPO 价值损失函数:如何构建?

价值网络使用均方误差损失来拟合奖励:

Lvalue=Et[(Vϕ(st)−Rt)2] L^{\text{value}} = \mathbb{E}_t \left[ \left( V_\phi(s_t) - R_t \right)^2 \right] Lvalue=Et[(Vϕ(st)Rt)2]

也可以加上 value clipping:

Lvalue-clipped=max⁡((Vϕ(st)−Rt)2,(clip(Vϕ(st),Vold−ϵ,Vold+ϵ)−Rt)2) L^{\text{value-clipped}} = \max\left( (V_\phi(s_t) - R_t)^2, (\text{clip}(V_\phi(s_t), V_{\text{old}} - \epsilon, V_{\text{old}} + \epsilon) - R_t)^2 \right) Lvalue-clipped=max((Vϕ(st)Rt)2,(clip(Vϕ(st),Voldϵ,Vold+ϵ)Rt)2)


6️⃣ 总损失函数:包含 entropy 奖励

完整的 PPO 损失函数通常为:

L=Lpolicy+c1⋅Lvalue−c2⋅H(πθ) L = L^{\text{policy}} + c_1 \cdot L^{\text{value}} - c_2 \cdot H(\pi_\theta) L=Lpolicy+c1Lvaluec2H(πθ)

  • H(πθ)H(\pi_\theta)H(πθ):策略的熵,用于鼓励探索(entropy bonus)
  • c1,c2c_1, c_2c1,c2:超参数,通常取 0.5 和 0.01

熵越高表示策略更随机,防止策略过早收敛到确定动作。


7️⃣ 微调后,旧轨迹还能继续用吗?

不能。PPO 是 on-policy 算法。

每轮策略更新后,旧轨迹(state, action, reward, old prob)就过时了,必须重新采样:

  • 旧策略生成的样本反映不了当前策略的行为
  • 若继续使用,会引入策略偏移(policy mismatch)

因此,PPO 的标准训练循环是:

  1. 用当前策略生成轨迹
  2. 固定轨迹,训练 N 个 epoch
  3. 更新策略后丢弃旧轨迹
  4. 重复采样新数据

✅ 总结回顾

项目内容说明
策略概率模型输出 logits → softmax 得到 token 概率
策略损失PPO clipped objective,基于概率比和优势函数
价值得分Value head 输出一个实数,预测状态期望奖励
奖励函数来自人工打分或奖励模型,指导优势函数计算
是否复用轨迹❌ 不能复用旧轨迹,策略更新后必须重新采样

🔚 写在最后

理解 PPO 中策略概率、价值得分、损失函数之间的关系,是成功实现 RLHF、SFT + RL 微调语言模型的基础。

这些原理不只是公式,更影响着你训练是否稳定、样本是否有效、微调是否收敛。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/90290.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/90290.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

日常--记一次gitlab Runner配置与CI/CD环境搭建流程

文章目录一、前言二、相关知识1.相关定义1.什么是 CI&#xff1f;2.什么是 CD&#xff1f;2.CI/CD 构建块与工具链3.为什么要使用 CI/CD&#xff1f;三、准备四、实现1.Runner安装与配置1.更新源2.安装Runner3.注册Runner4.启动Runner5.查看Runner信息2.CI/CD流程测试1.CI/CD构…

东方仙盟AI数据中间件使用教程:开启数据交互与自动化应用新时代——仙盟创梦IDE

一、启动未来之窗AI 二、初始化数据接口三、便捷接口数据进入东方仙盟获取接口标准四、同步参数仙界界牌&#xff0c;冥界界牌&#xff0c;仙盟界牌 五、开始同步六、东方仙盟青云剑魂架构在当今数字化浪潮下&#xff0c;数据的采集、处理与传输成为众多应用场景的核心需求。而…

Rust 仿射类型(Affine Types)

在 Rust 中&#xff0c;仿射类型&#xff08;Affine Types&#xff09; 是所有权系统的理论基础&#xff0c;它规定了每个值有且仅有一次使用机会。这与线性类型&#xff08;必须恰好使用一次&#xff09;有所不同&#xff0c;允许值未被使用就被丢弃。Rust 中的仿射类型核心特…

python库 arrow 库的各种案例的使用详解(更人性化的日期时间处理)

文章目录 一、arrow概述1.1 arrow介绍1.2 安装 arrow1.3 注意事项二、基本使用2.1 创建 Arrow 对象2.2 格式化输出2.3 时间运算三、高级功能3.1 时区处理3.2 时间范围3.3 时间间隔四、实际应用案例4.1 日志时间处理4.2 会议时间提醒4.3 国际化时间显示5. Arrow 与 datetime 互操…

window 服务器上部署前端静态资源以及nginx 配置

最近搞了一台境外服务器 这种境外服务器是不可以配置域名的 但是可以使用ip访问 但是如果需要 配置 需要下载nginx nginx: download 我这个是windows 的 服务器 所以下载windows 的nginx 下载完成以后 这个里面的html 文件 就是前端项目 里面必须要有index.html文件 部署…

行业实践案例:医疗行业数据治理的挑战与突破

“医疗数据不仅是资源,更关乎生命。” ——医疗行业的数据治理,是合规、安全、质量与智能化的多重挑战。 📘 本文目录 为什么医疗行业亟需数据治理? 医疗行业数据治理的独特挑战 医疗数据治理体系设计原则 关键能力模块与实践案例 工具选型与落地建议 总结与下一步 1️⃣ …

单细胞转录组学和空间转录组学数据的整合方法

文章目录问题1&#xff1a;现有技术是否可以拿取固定数目的细胞进行组合形成spot问题2&#xff1a;是否有关于这方面的研究问题3&#xff1a;相关论文推荐一、细胞反卷积的核心目标与挑战二、单细胞与空间转录组数据的整合方法分类1. 概率型方法&#xff08;Probabilistic-base…

【Java EE】SpringBoot 配置文件、日志和单元测试

1. 什么是配置文件在我们的计算机上诸如 C:/Users&#xff0c;C:/Windows&#xff0c;.config&#xff0c;.xml 都是配置文件&#xff0c;配置文件主要为了解决硬编码带来的问题。硬编码是将数据直接写在程序的源代码中&#xff0c;代码写死后再想改变就很麻烦。因此&#xff0…

CMake实践:常见的调试技巧

目录 1.简介 2.用 message() 输出关键信息 2.1.message简介 2.2.常用模式及作用 2.3.核心用法示例 2.4.常见问题及解决 3.查看缓存变量&#xff1a;cmake -L 与缓存文件 3.1.列出所有缓存变量&#xff08;cmake -L&#xff09; 3.2.直接查看 / 删除 CMakeCache.txt 4…

爬虫-第一个爬虫程序

浏览器里面都是html数据&#xff0c;拿到的都是页面源代码&#xff0c;可以用自己的方式打开测试。打开浏览器decode找charset

从SEO到GEO:优化策略如何应对传统搜索与AI搜索的巨变

AI 搜索与传统搜索结果优化之间有什么重叠之处&#xff1f; 为了帮助确定主要的差异&#xff0c;以及那些重叠程度最高的区域&#xff0c;我创建了一个比较&#xff08;我会保持更新&#xff09;&#xff0c;通过搜索行为、优化领域、结果展示和交付&#xff0c;以及要跟踪的 K…

mysql5.7系列-InnoDB的MVCC实现原理

谈到数据库事务都要提一下ACID 特性&#xff1a; 原子性&#xff08;Atomicity&#xff09;&#xff1a;事务中的操作要么全部执行&#xff0c;要么全部不执行。 一致性&#xff08;Consistency&#xff09;&#xff1a;事务执行前后&#xff0c;数据库的状态必须是一致的。 …

力扣-287.寻找重复数

题目链接 287.寻找重复数 class Solution {public int findDuplicate(int[] nums) {int low nums[0];int fast nums[nums[0]];//1.快慢指针找相遇点while (low ! fast) {low nums[low];fast nums[nums[fast]];}//2.双指针找入环点int pre 0;while (pre ! low) {pre num…

Java 大视界 -- Java 大数据在智能教育个性化学习计划制定与动态调整中的应用(338)

Java 大视界 -- Java 大数据在智能教育个性化学习计划制定与动态调整中的应用&#xff08;338&#xff09; 引言&#xff1a;正文&#xff1a;一、Java 构建的学习行为数据采集与分析体系1.1 全场景数据接入引擎1.2 家校协同数据交互模块1.3 学习特征提取与建模 二、Java 驱动的…

uniapp返回webview返回小程序并且跳转回webview

webview页面提示&#xff1a;wx一定要导入sdk// 返回小程序&#xff0c;并携带当前 WebView 的 URL 和状态wx.miniProgram.postMessage({type: requestPayment,data: {webviewUrl: window.location.href,orderNum: this.orderNum,type: requestPayment}})setTimeout(() > {w…

[java: Cleaner]-一文述之

Cleaner Cleaner 是 Java 9 引入的资源清理机制&#xff0c;用于在对象被垃圾回收后自动或手动执行清理操作&#xff0c;替代 finalize()&#xff0c;安全、异步且高效。 public final class Cleaner {final CleanerImpl impl;static {CleanerImpl.setCleanerImplAccess(new Fu…

知识库中如何确实嵌入文本块大小?语义完整性与检索颗粒度的平衡机制

一、文本块大小确定的理论基础与历史演进 1.1 概念起源与发展脉络 文本块&#xff08;Text Chunk&#xff09; 这一概念最初源于信息检索领域的实践需求。早期的全文检索系统面临着一个根本性矛盾&#xff1a;如何在保持文档语义完整性的同时&#xff0c;实现高效的信息定位。这…

C/C++ 实现在快速排序Quick Sort中的三种分区方式

1. 简介神说, 要有光. 于是就有了光. 神说要有快排, 于是就有了快排. 快速排序Quick Sort的发明者 托尼 霍尔 是1980年的图灵奖得主. 快速排序就是他发明的. 当时发明的背景是: 由于霍尔要高效地对俄语词汇进行排序以优化翻译程序, 而当时的排序算法(如冒泡, 插入排序)效率较低…

Flink TiDB CDC 环境配置与验证

一、TiDB 数据库核心配置 1. 启用 TiCDC 服务 确保 TiDB 集群已部署 TiCDC 组件&#xff08;版本需兼容 Flink CDC 3.0.1&#xff09;&#xff0c;并启动同步服务&#xff1a; # 示例&#xff1a;启动 TiCDC 捕获 changefeed cdc cli changefeed create \--pd"localhos…

2025年数据挖掘与计算机科学国际会议 (DMCS 2025)

2025 International Conference on Data Mining and Computer Science【一】、大会信息 会议简称&#xff1a;DMCS 2025 大会地点&#xff1a;中国广州 收录检索&#xff1a;提交Ei Compendex,CPCI,CNKI,Google Scholar等【二】会议简介2025年数…