PPO和GRPO算法

        verl 是现在非常火的 rl 框架,而且已经支持了多个 rl 算法(ppo、grpo 等等)。

        过去对 rl 的理解很粗浅(只知道有好多个角色,有的更新权重,有的不更新),也曾硬着头皮看了一些论文和知乎,依然有很多细节不理解,现在准备跟着 verl 的代码梳理一遍两个著名的 rl 算法,毕竟代码不会隐藏任何细节!

        虽然 GRPO 算法是基于 PPO 算法改进来的,但是毕竟更简单,所以我先从 GRPO 的流程开始学习,然后再看 PPO。

GRPO 论文中的展示的总体流程:

论文中这张图主要展示了 GRPO 和 PPO 的区别,隐藏了其他的细节。

图中只能注意到以下几个关键点:

  • 没有 Value Model 和输出 v(value)

  • 同一个 q 得出了一组的 o(从 1 到 G)

  • 计算 A(Advantage) 的算法从 GAE 变成了 Group Computation

  • KL 散度计算不作用于 Reward Model,而是直接作用于 Policy Model

        其他细节看不懂,结合论文也依然比较抽象,因为我完全没有 RL 的知识基础,下文中我们结合代码会再一次尝试理解。

        下面是我根据 verl 代码自己 DIY 的流程图(帮助理解):

01 第一步:Rollout

        第一步是 rollout,rollout 是一个强化学习专用词汇,指的是从一个特定的状态按照某个策略进行一些列动作和状态转移。

        在 LLM 语境下,“某个策略”就是 actor model 的初始状态,“进行一些列动作”指的就是推理,即输入 prompt 输出 response 的过程。

verl/trainer/ppo/ray_trainer.py:

gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

        其背后的实现一般就是是 vllm 或 sglang 这些常见推理框架的离线推理功能,这部分功能相对独立我们先不展开。

权重同步

        一个值得注意的细节是代码里面的 rollout_sharding_manager 实现,它负责每一个大 step 结束后把刚刚训练好的 actor model 参数更新到 vllm 或 sglang。

        这样下一个大 step 的 rollout 采用的就是最新的模型权重(最新的策略)了。

        这是每一个大 step 里面真正要做的第一件事,在真正执行 rollout 之前。

        verl/workers/fsdp_workers.py:

class ActorRolloutRefWorker(Worker):   # ...    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)     def generate_sequences(self, prompts: DataProto):       # ...        with self.rollout_sharding_manager:            # ...            prompts = self.rollout_sharding_manager.preprocess_data(prompts)           output = self.rollout.generate_sequences(prompts=prompts)            output = self.rollout_sharding_manager.postprocess_data(output)

rollout_sharding_manager 的基类是 BaseShardingManager。

verl/workers/sharding_manager/base.py:

class BaseShardingManager:   def __enter__(self):        pass    def __exit__(self, exc_type, exc_value, traceback):        pass    def preprocess_data(self, data: DataProto) -> DataProto:        return data    def postprocess_data(self, data: DataProto) -> DataProto:        return data

  BaseShardingManager 的派生类在各自的 __enter__ 方法中实现了把 Actor Model 的权重 Sync 到 Rollout 实例的逻辑,以保证被 with self.rollout_sharding_manager 包裹的预处理和推理逻辑都是用的最新 Actor Model 权重。

推理 N 次

        此外,GRPO 算法要求对每一个 prompt 都生成多个 response,后续才能根据组间对比得出相对于平均的优势(Advantage)。

verl/trainer/config/ppo_trainer.yaml:

actor_rollout_ref:  rollout:    # number of responses (i.e. num sample times)   n: 1 # > 1 for grpo

        在 _build_rollout 的时候 actor_rollout_ref.rollout.n 被传给了 vLLMRollout 或其他的 Rollout 实现中,从而推理出 n 组 response。

verl/workers/fsdp_workers.py:

class ActorRolloutRefWorker(Worker):    def _build_rollout(self, trust_remote_code=False):        # ...        elif rollout_name == "vllm":            # ...            if vllm_mode == "customized":                rollout = vLLMRollout(                   actor_module=self.actor_module_fsdp,                                  config=self.config.rollout,                   tokenizer=self.tokenizer,                    
model_hf_config=self.actor_model_config,               )

02 第二步:计算 log prob

        log 是 logit,prob 是 probability,合起来就是对数概率,举一个简单的例子来说明什么是 log prob:

词表仅有 5 个词:    
<pad> (ID 0)    
你好 (ID 1)    
世界 (ID 2)   
! (ID 3)    
吗 (ID 4)
prompt:你好
prompt tokens: [1]
response:世界!
response tokens: [2,3]
模型前向传播得到完整的 logits 张量:
[    [-1.0, 0.5, 2.0, -0.5, -1.5],    // 表示 “你好” 后接 “世界” 概率最高,数值为 2.0    [-2.0, -1.0, 0.1, 3.0, 0.2]      // 表示 “你好世界” 后接 “!” 概率最高,数值为 3.0]
对每个 logit 计算 softmax 得到:
[    [-3.65, -2.15, -0.64, -3.15, -4.08],    [-4.34, -3.32, -2.20, -0.20, -2.10]]
提取实际 response 对应的数值:得到 log_probs:
[-0.64, -0.20]

总结下来:

  • 首先计算 prompt + response(来自 rollout)的完整 logits,即每一个 token 的概率分布

  • 截取 response 部分的 logits

  • 对每一个 logits 计算 log_sofmax(先 softmax,然后取对数),取出最终预测的 token 对应的 log_sofmax

  • 最终输出 old_log_probs, size = [batchsize, seq_len]

        此处你可能会有一个疑惑:在上一步 Rollout 的时候我们不是已经进行过完整 batch 的推理了么?

        为什么现在还要重复进行一次 forward 来计算 log_prob,而不是在 generate 的过程中就把 log_prob 保存下来?

答:因为 generate_sequences 阶段为了高效推理,不会保存每一个 token 的 log_prob,相反只关注整个序列的 log_prob。因此需要重新算一遍。

答:另外,vllm 官方 Q&A 中提到了 vllm 框架并不保证 log_probs 的稳定性。因为 pytorch 的 numerical instability 与 vllm 的并发批处理策略导致每一个 token 的 logits/log_probs 结果会略有不同,假如某一个 token 位采样了不同 token id,那么这个误差在后续还会被继续累加。我们在训练过程需要保证 log_probs 的稳定性,因此需要根据已经确定的 token id(即 response)再次 forward 一遍。

old log prob

verl/workers/fsdp_workers.py:

old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)

        指 Actor Model 对整个 batch 的数据(prompt + response)进行 forward 得到的 log_prob

        此处的 “old” 是相对于后续的 actor update 阶段,因为现在 actor model 还没有更新,所以依然采用的是旧策略 (ps:当前 step 的“旧策略”也是上一个大 step 的“新策略”)

ref log prob

verl/trainer/ppo/ray_trainer.py:

ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)

        指 Ref Model 对整个 batch 的数据(prompt + response)进行 forward 得到的 log_prob。

        通常 Ref Model 就是整个强化学习开始之前 Actor Model 最初的模样,换句话说第一个大 step 开始的时候 Actor Model == Ref Model,且 old_log_prob == ref_log_prob。

        Ref Model 的作用是在后续计算 policy loss 之前,计算 KL 散度并作用于 policy loss,目的是让 actor model 不要和最初的 ref model 相差太远。

03第三步:advantage

        advantage 是对一个策略的好坏最直接的评价,其背后就是 Reward Model,甚至也许不是一个 Model,而是一个粗暴的 function,甚至一个 sandbox 把 prompt+response 执行后得出的结果。

        在 verl 中允许使用上述多种 Reward 方案中的一种或多种,并把得出的 score 做合。

verl/trainer/ppo/ray_trainer.py:

# compute reward model score
if self.use_rm:    reward_tensor = self.rm_wg.compute_rm_score(batch)    batch = batch.union(reward_tensor)
if self.config.reward_model.launch_reward_fn_async:    future_reward = compute_reward_async.remote(batch, self.config, self.tokenizer)
else:   reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)

然后用这个 score 计算最终的 advantage。

verl/trainer/ppo/ray_trainer.py:

# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get(    "norm_adv_by_std_in_grpo", True)  
# GRPO adv normalization factorbatch = compute_advantage(    batch,    
adv_estimator=self.config.algorithm.adv_estimator,   gamma=self.config.algorithm.gamma,    
lam=self.config.algorithm.lam,    
num_repeat=self.config.actor_rollout_ref.rollout.n,    norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,)

04第四步:actor update(小循环)

        在 PPOTrainer 中简单地一行调用,背后可是整个 GRPO 算法中最关键的步骤:

actor_output = self.actor_rollout_wg.update_actor(batch)

        在这里,会把上面提到的整个 batch 的数据再根据 actor_rollout_ref.actor.ppo_mini_batch_size 配置的值拆分成很多个 mini batch。

        然后对每一个 mini batch 数据进行一轮 forward + backward + optimize step,也就是小 step。

new log prob

        每一个小 step 中首先会对 mini batch 的数据计算(new)log_prob,第一个小 step 得到的值还是和 old_log_prob 一模一样的。

pg_loss

        然后通过输入所有 Group 的 Advantage 以新旧策略的概率比例(old_log_prob 和 log_prob),得出 pg_loss(Policy Gradient),这是最终用于 backward 的 policy loss 的基础部分。

        再次描述一下 pg_loss 的意义,即衡量当前策略(log_prob)相比于旧策略(old_log_prob),在当前优势函数(advantage)指导下的改进程度。

verl/workers/actor/dp_actor.py:

pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss(    old_log_prob=old_log_prob,    
log_prob=log_prob,    
advantages=advantages,    
response_mask=response_mask,    
cliprange=clip_ratio,    
cliprange_low=clip_ratio_low,    
cliprange_high=clip_ratio_high,    
clip_ratio_c=clip_ratio_c,    
loss_agg_mode=loss_agg_mode,)

entropy loss

        entropy 指策略分布的熵 (Entropy):策略对选择下一个动作(在这里是下一个 token)的不确定性程度。

        熵越高,表示策略输出的概率分布越均匀,选择各个动作的概率越接近,策略的探索性越强;熵越低,表示策略越倾向于选择少数几个高概率的动作,确定性越强。

  entropy_loss 指 entropy 的 平均值,是一个标量,表示探索性高低。

verl/workers/actor/dp_actor.py:

if entropy_coeff != 0:   entropy_loss = agg_loss(loss_mat=entropy, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)   # compute policy loss    policy_loss = pg_loss - entropy_loss * entropy_coeff
else:   policy_loss = pg_loss

计算 KL 散度

        这里用到了前面 Ref Model 推出的 ref_log_prob,用这个来计算 KL 并作用于最后的 policy_loss,保证模型距离 Ref Model(初始的模型)偏差不会太大。

verl/workers/actor/dp_actor.py:

if self.config.use_kl_loss:    ref_log_prob = data["ref_log_prob"]   # compute kl loss    kld = kl_penalty(logprob=log_prob, ref_logprob=ref_log_prob, kl_penalty=self.config.kl_loss_type    )    kl_loss = agg_loss(loss_mat=kld, loss_mask=response_mask, loss_agg_mode=self.config.loss_agg_mode    )    policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef    metrics["actor/kl_loss"] = kl_loss.detach().item()    metrics["actor/kl_coef"] = self.config.kl_loss_coef

反向计算

verl/workers/actor/dp_actor.py:

loss.backward()

        持续循环小 step,直到遍历完所有的 mini batch,Actor Model 就完成了本轮的训练,会在下一个大 step 前把权重 sync 到 Rollout实例当中,准备处理下一个大 batch 数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/pingmian/84003.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PyTorch——优化器(9)

优化器根据梯度调整参数&#xff0c;以达到降低误差 import torch.optim import torchvision from torch import nn from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear from torch.utils.data import DataLoader# 加载CIFAR10测试数据集&#xff0c;设置tr…

c++学习-this指针

1.基本概念 非静态成员函数都会默认传递this指针&#xff08;静态成员函数属于类本身&#xff0c;不属于某个实例对象&#xff09;&#xff0c;方便访问对象对类成员变量和 成员函数。 2.基本使用 编译器实际处理类成员函数&#xff0c;this是第一个隐藏的参数&#xff0c;类…

【Oracle】数据仓库

个人主页&#xff1a;Guiat 归属专栏&#xff1a;Oracle 文章目录 1. 数据仓库概述1.1 为什么需要数据仓库1.2 Oracle数据仓库架构1.3 Oracle数据仓库关键技术 2. 数据仓库建模2.1 维度建模基础2.2 星形模式设计2.3 雪花模式设计2.4 缓慢变化维度&#xff08;SCD&#xff09;处…

css-塞贝尔曲线

文章目录 1、定义2、使用和解释 1、定义 cubic-bezier() 函数定义了一个贝塞尔曲线(Cubic Bezier)语法&#xff1a;cubic-bezier(x1,y1,x2,y2) 2、使用和解释 x1,y1,x2,y2&#xff0c;表示两个点的坐标P1(x1,y1),P2(x2,y2)将以一条直线放在范围只有 1 的坐标轴中&#xff0c;并…

函数式接口实现分页查询

你提供的 PageResult 类是一个非常完整、功能齐全的分页结果封装类&#xff0c;它包含了&#xff1a; 当前页数据&#xff08;list&#xff09;总记录数&#xff08;totalCount&#xff09;总页数&#xff08;totalPage&#xff09;当前页码&#xff08;pageNo&#xff09;每页…

Global Security Markets 第 10 章衍生品知识点总结​

一、衍生品的定义与本质 衍生品&#xff0c;作为一种金融工具&#xff0c;其价值并非独立存在&#xff0c;而是紧密依赖于其他资产&#xff0c;如常见的股票、债券、商品&#xff0c;或者市场变量&#xff0c;像利率、汇率、股票指数等。这意味着衍生品的价格波动&#xff0c;…

DJango知识-模型类

一.项目创建 在想要将项目创键的目录下,输入cmd (进入命令提示符)在cmd中输入:Django-admin startproject 项目名称 (创建项目)cd 项目名称 (进入项目)Django-admin startapp 程序名称 (创建程序)python manage.py runserver 8080 (运行程序)将弹出的网址复制到浏览器中…

八股学习-JS的闭包

一.闭包的定义 闭包是指函数和其周围的词法环境的引用的组合。 简单来说&#xff0c;就是函数可以记住并访问其在定义时的作用域内的变量&#xff0c;即使该函数在其它作用域调用。 也就是说&#xff0c;闭包让你可以在一个内层函数中访问到其外层函数的作用域。 function …

qt使用笔记二:main.cpp详解

Qt中main.cpp文件详解 main.cpp是Qt应用程序的入口文件&#xff0c;包含程序的启动逻辑。下面我将详细解析其结构和功能。 基本结构 一个典型的Qt main.cpp 文件结构如下&#xff1a; #include <QApplication> // 或者 QGuiApplication/QCoreApplication #include &…

如何构建船舵舵角和船的航向之间的动力学方程?它是一个一阶惯性环节吗?

提问 船舵和船的航向之间的动力学方程是什么&#xff1f;是一个一阶惯性环节吗&#xff1f; 回答 船舵和船的航向&#xff08;航向角&#xff09;之间的动力学关系并不是一个简单的一阶惯性环节&#xff0c;虽然在某些简化控制模型中可以近似为一阶系统。实际上&#xff0c;…

抖去推--短视频矩阵系统源码开发

一、开发短视频矩阵系统的源码需要以下步骤&#xff1a; 确定系统需求&#xff1a; 根据客户的具体业务目标&#xff0c;明确系统需实现的核心功能模块&#xff0c;例如用户注册登录、视频内容上传与管理、多维度视频浏览与推荐、用户互动&#xff08;评论、点赞、分享&#xf…

Windows 下搭建 Zephyr 开发环境

1. 系统要求 操作系统&#xff1a;Windows 10/11&#xff08;64位&#xff09;磁盘空间&#xff1a;至少 8GB 可用空间&#xff08;Zephyr 及其工具链较大&#xff09;权限&#xff1a;管理员权限&#xff08;部分工具需要&#xff09; 2. 安装必要工具 winget安装依赖工具&am…

三分算法与DeepSeek辅助证明是单峰函数

前置 单峰函数有唯一的最大值&#xff0c;最大值左侧的数值严格单调递增&#xff0c;最大值右侧的数值严格单调递减。 单谷函数有唯一的最小值&#xff0c;最小值左侧的数值严格单调递减&#xff0c;最小值右侧的数值严格单调递增。 三分的本质 三分和二分一样都是通过不断缩…

安全月报 | 傲盾DDoS攻击防御2025年5月简报

引言 在2025年5月&#xff0c;全球数字化进程高歌猛进&#xff0c;各行各业深度融入数字浪潮&#xff0c;人工智能、物联网、大数据等前沿技术蓬勃发展&#xff0c;进一步夯实了数字经济的基石。然而&#xff0c;在这看似繁荣的数字生态背后&#xff0c;网络安全威胁正以惊人的…

【Spring】Spring哪些源码解决了哪些问题P1

欢迎来到啾啾的博客&#x1f431;。 记录学习点滴。分享工作思考和实用技巧&#xff0c;偶尔也分享一些杂谈&#x1f4ac;。 有很多很多不足的地方&#xff0c;欢迎评论交流&#xff0c;感谢您的阅读和评论&#x1f604;。 目录 Spring是怎么处理请求的&#xff1f;Spring请求方…

坚持每日Codeforces三题挑战:Day 4 - 题目详解(2025-06-07,难度:1000, 1100, 1400)

前言&#xff1a; 此文章主要是记录每天的codeforces刷题&#xff0c;还有就是给其他打算法竞赛的人一点点点点小小的帮助&#xff08;毕竟现在实力比较菜&#xff0c;题目比较简单&#xff0c;但我还是会认真写题解&#xff09;。 之前忙学校事情&#xff0c;懈怠了一段时间…

6.7本日总结

一、英语 复习默写list10list19&#xff0c;07年第3篇阅读 二、数学 学习线代第一讲&#xff0c;写15讲课后题 三、408 学习计组第二章&#xff0c;写计组习题 四、总结 本周结束线代第一讲和计组第二章&#xff0c;之后学习计网4.4&#xff0c;学完计网4.4之后开操作系…

PGSR : 基于平面的高斯溅射高保真表面重建【全流程分析与测试!】【2025最新版!!】

【PGSR】: 基于平面的高斯溅射高保真表面重建 前言 三维表面重建是计算机视觉和计算机图形学领域的核心问题之一。随着Neural Radiance Fields (NeRF)和3D Gaussian Splatting (3DGS)技术的发展&#xff0c;从多视角RGB图像重建高质量三维表面成为了研究热点。今天我们要深入…

从认识AI开始-----AutoEncoder:生成模型的起点

前言 从15年开始&#xff0c;在深度学习的重要模型中&#xff0c;AutoEncoder&#xff08;自编码器&#xff09;可以说是打开生成模型世界的起点。它不仅是压缩与重建的工具&#xff0c;更是VAE、GAN、DIffusion等复杂生成模型的思想起源。其实AutoEncoder并不复杂&#xff0c;…

解决MySQL8.4报错ERROR 1524 (HY000): Plugin ‘mysql_native_password‘ is not loaded

最近使用了MySQL8.4 , 服务启动成功,但是就是无法登陆,并且报错: ERROR 1524 (HY000): Plugin mysql_native_password is not loaded 使用如下的命令也报错 mysql -u root -p -P 3306 问题分析: 在MySQL 8.0版本中,默认的认证插件从mysql_native_password变更为cachi…