【模型训练篇】VeRL的使用 - RL(PPO)与源码

继续学习字节家的VeRL,今天来看看VeRL的RL,是VeRL系列的第三篇文章(话说近期好多大事儿,我司发布了Longcat、韩立结婴、阿里周五发布了QWen-Next都是好东西啊,学不过来了damn)

  • 底层分布式能力基础Ray(点击查看):VeRL分布式能力的基础,框架Ray
  • VeRL的原理(点击查看):HybridFlow
  • VeRL的使用(点击查看):普通RL(PPO)
  • VeRL的使用,Agentic RL(多轮RL)
  • VeRL的魔改

前两篇文章分别介绍了VeRL的分布式基础和其底层原理,下面就以RL的PPO为例,同时结合源码,看看具体的使用。

安装

  • 使用docker的话,verl提供了诸多版本可以使用,例如纯净的只包含Verl/CUDA/PyTorch等依赖的base镜像,也有整合了vLLM/SGLang/FSDP/Megatron的application镜像
  • 手动安装的话,要从CUDA/cuDNN等基础库开始,一定会遇到冲突(嗯,一定…)

使用

  1. 首先在Ray的Head节点上执行 ray start --head --dashboard-host=0.0.0.0,之后会得到两个address:
  • 一个是集群内head/worker之间通信用的 GCS address
  • 一个是提交与查看任务/资源监控/查看日志的dashboard地址(使用VSCode插件进行debug的地址也是它)
  1. 然后在每个Ray Worker节点上执行 ray start --address=gcs_address
  2. 最后提交job任务ray job submit --address=dashboard_address -- python3 -m verl.trainer.main_ppo trainer.n_gpus_per_node=8 ... 就可以在dashboard里看到各种信息了

启动后,整体架构如图,前两篇文章介绍过了,就不赘述了:

  • 其中driver进行代表single-controller
  • 其他的 actor/critic/rollout/ref/reward 那些 workers 代表 multi-controller,均对应着各自的 resource group

在这里插入图片描述

下面直接看源码。

源码

首先是入口函数,即main_ppo.py,主要做定义、初始化:在这里插入图片描述

  1. 初始化 Ray cluster 环境
  2. 通过 @ray.remote 定义了一个 远程执行的 class TaskRunner
  3. 定义 actor/rollout worker:通过配置指定使用 fsdpmegatron,并构建 mappingrole_worker_mapping[Role.ActorRollout] = ray.remote(actor_rollout_cls)
  4. 定义 criticworker
  5. 将上述两个worker映射到resourece资源上:mapping[Role.ActorRollout] = global_pool_idmapping[Role.Critic] = global_pool_id
  6. 定义 rewardworkerrole_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) 的同时映射资源 mapping[Role.RewardModel] = "global_pool"
  7. 定义 refworkerrole_worker_mapping[Role.RefPolicy] = ray.remote(ref_policy_cls) 的同时映射资源 mapping[Role.RefPolicy] = "global_pool"
  8. 执行PPO workflow:加载模型、准备dataset、构建RayPPOTrainer、执行 RayPPOTrainer.init_workers()、执行 RayPPOTrainer.fit()
# Initialize the PPO trainer.
trainer = RayPPOTrainer(config=config,tokenizer=tokenizer,processor=processor,role_worker_mapping=self.role_worker_mapping,resource_pool_manager=resource_pool_manager,ray_worker_group_cls=ray_worker_group_cls,reward_fn=reward_fn,val_reward_fn=val_reward_fn,train_dataset=train_dataset,val_dataset=val_dataset,collate_fn=collate_fn,train_sampler=train_sampler,
)
# Initialize the workers of the trainer.
trainer.init_workers()
# Start the training process.
trainer.fit()

然后执行的是核心的RayPPOTrainer,主要就是俩函数,一个是init_workers(),一个是fit() 在这里插入图片描述

先看init_workers()

  1. 根据config配置的资源创建resource pool
  2. 创建hybrid_engine,这是actorrollout的 colocate的复合体
resource_pool = self.resource_pool_manager.get_resource_pool(Role.ActorRollout)
actor_rollout_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.ActorRollout],config=self.config.actor_rollout_ref,role="actor_rollout",
)
self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
  1. 创建critic
resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
critic_cfg = omega_conf_to_dataclass(self.config.critic)
critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)
self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls
  1. 创建ref
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
ref_policy_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RefPolicy],config=self.config.actor_rollout_ref,role="ref",
)
self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls
  1. 创建reward,下面设置用的是reward modelfunction
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls
  1. 创建各自的wroker groupWorkerGroup是一组Wroker的抽象集合,使得driver可以和底层的多个worker进行交互:
for resource_pool, class_dict in self.resource_pool_to_cls.items():worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)wg_dict = self.ray_worker_group_cls(resource_pool=resource_pool,ray_cls_with_init=worker_dict_cls,**wg_kwargs,)spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())all_wg.update(spawn_wg)if self.use_critic:self.critic_wg = all_wg["critic"]self.critic_wg.init_model()if self.use_reference_policy and not self.ref_in_actor: # 需要关注self.ref_policy_wg = all_wg["ref"]self.ref_policy_wg.init_model()if self.use_rm:self.rm_wg = all_wg["rm"]self.rm_wg.init_model()

这里需要注意的是:

  • actorrollout进行colocate的目的:是在rollout和train两个阶段间高效更新参数权重
  • 但是否也同样也colocate ref,取决于是否用了LoRA,因为refactor它们的base基座模型一样,只不过actor lora多了一层lora的适配层,也就是BA矩阵,所以如果用LoRA,可以把rollout/actor/ref同时colocate到一起,更省资源

之后再看fit(),其实就是标准的PPO实现了,下面提取出关键信息:

for prompt in dataloader:output = actor_rollout_ref_wg.generate_sequences(prompt) # old_log_prob = actor_rollout_ref_wg.compute_log_prob(output)ref_log_prob = actor_rollout_ref_wg.compute_ref_log_prob(output)values = critic_wg.compute_values(output)rewards = reward_wg.compute_scores(output)advantages = compute_advantages(values, rewards)output = output.union(old_log_prob).union(ref_log_prob).union(values).union(rewards).union(advantages)actor_rollout_ref_wg.update_actor(output)critic.update_critic(output)

另外,关于driverwroker的数据交互,大致可以分成3步:

  1. driver把数据按DP数量进行切分
  2. 把数据分发给每个worker
  3. 每个worker再将执行的结果进行整合,所以VeRL这里搞了一个语法糖@register
class ActorRolloutRefWorker(Worker):@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)def generate_sequences(self, prompts: DataProto):prompts = prompts.to(torch.cuda.current_device())

上面的注解@register装饰了方法generate_sequence,包含了 dispatch_mode对应的:

  • dispatch_func:把输入dispatch到worker group中的各个worker
  • collect_func:把worker group的各个worker的response collect到一起

VeRL的各种参数,有详细解释,也有展示的图片


下篇文章介绍下如何使用VeRL进行Agentic RL,也就是多轮RL。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/news/922337.shtml
繁体地址,请注明出处:http://hk.pswp.cn/news/922337.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QML Charts组件之折线图的鼠标交互

目录前言相关系列代码示例详解(LineSeriesDemo3.qml)功能概览运行效果代码说明工程下载参考前言 接上文(QML Charts组件之折线图的基础属性),本文将重点介绍LineSeries的鼠标交互,包括:鼠标拖拽…

二值信号量——学习笔记12

本文是笔者在学习 正点原子官方 的《【正点原子】手把手教你学FreeRTOS实时系统》系列视频时整理的笔记。 视频讲解清晰透彻,非常感谢UP主的无私奉献!原课程链接如下: 👉 B站视频链接:​​​​​​【正点原子】手把手教…

裸机开发 时钟配置,EPIT

1.概念时钟(clock):在电子系统中是一个产生稳定、周期性振荡信号的电路或组件。这个信号像节拍器或心跳一样,为数字电路中的各种操作提供同步时序基准。PLL(phase locked loop)锁相环电路: 倍频PFD(phase fractional P…

Linux-文本三剑客(grep、sed、awk)

Linux-文本三剑客前言一、grep二、sed三、awk模式 -- 正则表达式关系表达式、运算符表达模式匹配表达式动作 输出流程控制参数传递,awk接受外部变量统计数组的使用分组统计练习常用内置函数前言 grep、sed、awk 被称为 “文本三剑客”,它们是处理文本文…

主流反爬虫、反作弊防护与风控对抗手段

文章目录1. 写在前面2. 指纹检测3. 行为验证3. 加固防护4. 链路检测5. 风控埋点6. 游客注册7. 数据防护8. 账号权重9. 反调阻断【🏠作者主页】:吴秋霖 【💼作者介绍】:擅长爬虫与JS加密逆向分析!Python领域优质创作者、…

金蝶云星空插件开发记录(一)

实现目的:新增供应商保存后,触发钉钉审批流程,并根据钉钉审批结果回写是否合格供应商。实现思路:通过BOS平台供在应商管理界面新增两个复选框字段:是否钉钉审批、是否合格供应商,若在新建供应商档案时勾选是…

企业跨区域组网新解:SD-WAN技术打造安全稳定网络体系

前言在数字化浪潮席卷全球的今天,企业跨区域网络互联已成为支撑业务发展的关键基础设施。传统MPLS专线虽性能稳定,但高昂成本和漫长部署周期令众多企业望而却步。SD-WAN技术的出现,正以其智能、灵活和成本效益的优势,重塑企业组网…

Docker 容器化

引言在解释docker是什么之前,我们首先应该先了解的是容器化的概念。什么是容器?就是一个沙箱,在这个沙箱中涵盖了特定应用运行的一切依赖的内容。但他不是一个操作系统,且和底层的操作系统是隔离的。什么是容器化?容器…

LeetCode刷题——hot 100(3)

题目1:矩阵置零题目:问题分析:使用两个布尔数组来分别记录哪行哪列出现了0,当出现0的行和列,对应的布尔数组值置为true。再次遍历数组,当出现行数组和列数组中的值为true,则对应的原数组的值置为…

Ajax-day2(图书管理)-渲染列表

本篇笔记素材来自“黑马程序员” 渲染列表图书管理一、获取数据二、渲染数据完整代码图书管理 Bootstrap 框架渲染列表(查)新增图书(增)删除图书(删)编辑图书(改) 自己的图书数据&a…

MOS管的电路

MOS管的三极都会存在以下三个电容,分别是:Cgs,Cgd,Cds 输入电容CissCgsCgd 输出电容CossCgdCds 反向传输电容CrssCgd,也叫米勒电容 然而,这三个等效电容是构成串并联组合关系,他们并不是独立的,而是相互…

STM32_05_时钟树

时钟 d用来输入数据,CLK就是我们的时钟,CPU1s中72000000HZ个时钟周期STM32的时钟树锁相环HSE时钟源HSI时钟源LSE时钟源LSI时钟源SystemInit函数SetSysClock函数SetSysClockTo72函数SystemInit()后时钟频率大小总结RCC标准库函数定义变量a&…

C语言---判断语句

文章目录1. if 语句2. if...else 语句3. if...else if...else 语句4. switch 语句5. 三元运算符 ( ? : )总结与对比如何选择C语言中的判断语句用于根据给定的条件来决定执行哪一段代码。其核心是条件为真(必须)则执行一段代码,条件为假&…

[硬件电路-212]:电流的本质确实是电子的移动

1. 微观机制:电子的定向漂移与热运动定向漂移(Drift Motion):在导体(如金属)中,自由电子(价电子)受电场驱动,从负端向正端定向移动,形成宏观电流。…

双RFSOC47DR-16通道5GSPS ADC采集模块

16通道5GSPS ADC采集板卡组成如图1所示。该板卡的输入接口为SMA单端输入,ADC采集和处理采用Xilinx公司的XCZU47DR-2FFVE1156I芯片。板卡需配备4路QSFP28光口输出,并需要集成网口、DDR4、SD卡、USB调试口。两块RF-Soc需确保连接通信功能。板卡的16通道需实…

pytest -- 中文文档

前言 零基础1小时快速入门pytest自动化测试教程,全套项目框架实战pytest配置文件可以改变pytest的运行方式,它是一个固定的文件pytest.ini文件,读取配置信息,按指定的方式去运行 非test文件 pytest里面有些文件是非test文件 pyt…

硬件开发2-ARM裸机开发3-IMX6ULL - 引入中断

一、铺垫引入中断 → 按键1、概要:实现按键控制发光二极管和蜂鸣器输入类型的外设:按键(key)2、参考手册内容完成配置过程(1)key 按键原理图(2)core 内核中命名 -- UART1 CTS&#x…

Ansible的 Playbook 模式详解

目录一、Playbook模式1.1 Playbook 的优势1.2 Playbook 的组成1.3 安装 httpd 服务案例1.4 Playbook 命令及常用参数1.5 Playbook 的语法 —— 权限相关1. remote_user2. become3. become_method1.6 Playbook 的通知与触发机制1. notify2. handlers3. 使用示例4. 使用场景1.6 P…

猿辅导Java后台开发面试题及参考答案

int 与 Integer 的区别是什么?若创建数量庞大的数字时使用 Integer,会对重复数字创建新对象吗?int 是 Java 中的基本数据类型,直接存储数值,占用 4 个字节,默认值为 0,不需要通过 new 关键字创建…

代码随想录学习摘抄day9(回溯1-11)

一个朴实无华的目录定义:回溯法也可以叫做回溯搜索法,它是一种搜索的方式。应用场景:回溯法解决的问题都可以抽象为树形结构代码模板题型第77题. 组合思路:每次从集合中选取元素,可选择的范围随着选择的进行而收缩&…