RLHF奖励模型的训练

由于 RLHF 的训练过程中需要依赖大量的人类偏好数据进行学习,因此很难在训练过程中要求人类标注者实时提供偏好反馈。为此,我们需要训练一个模型来替代人类在 RLHF 训练过程中实时提供反馈,这个模型被称为奖励模型

🔸一、 目标函数公式解释

公式如下:

L = − E ( x , y + , y − ) ∼ D [ log ⁡ σ ( r θ ( x , y + ) − r θ ( x , y − ) ) ] − β E ( x , y + ) ∼ D [ ∑ t = 1 T log ⁡ p ( y t + ∣ x , y < t + ) ] L = -\mathbb{E}_{(x, y^+, y^-) \sim D} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \right] - \beta \mathbb{E}_{(x, y^+)\sim D} \left[ \sum_{t=1}^{T} \log p(y^+_t \mid x, y^+_{<t}) \right] L=E(x,y+,y)D[logσ(rθ(x,y+)rθ(x,y))]βE(x,y+)D[t=1Tlogp(yt+x,y<t+)]

含义拆解:

  • x: 输入(如问题或提示语)
  • y+: 正例响应(由人类标注或偏好选择的答案)
  • y-: 负例响应(不好的答案)
  • r_θ(x, y): 奖励模型对 (x, y) 的打分(通常是最后一个 token 的输出经过 reward head 得到)
  • σ: Sigmoid 函数
  • β: 权重超参,控制模仿学习(第二项)对总损失的影响程度

公式两部分含义:

  1. 对比损失(ranking loss)

    $$

    • \log \sigma(r(x, y^+) - r(x, y^-))
      $$
    • 目标是使 正例得分 > 负例得分
    • r(x, y+) ≫ r(x, y-) 时,sigmoid接近1,log接近0 → 损失小,说明模型学得好
  2. 模仿学习损失(语言模型 loss)

    $$

    • \sum_{t=1}^{T} \log p(y^+t \mid x, y^+{<t})
      $$
    • 即:语言模型在给定输入 x 和前缀 y^+_{<t} 的条件下,预测下一个 token 的交叉熵损失
    • 起正则作用,防止奖励模型过度拟合打分而丧失语言生成能力

🔸二、代码结构分析

基于 LLaMA 的奖励模型实现详解(逐行解读 + PyTorch 源码分析)

📦 模块导入

1  import torch
2  import torch.nn as nn
3  import torch.nn.functional as F
4
5  from transformers import LlamaForCausalLM
  • torch:PyTorch 核心包
  • nn:用于定义神经网络模块(如 Linear)
  • F:包含函数式接口(如 loss 函数)
  • LlamaForCausalLM:来自 Transformers 的 LLaMA 语言模型基类,支持自回归文本生成

🧠 模型定义:奖励模型类

7  class LlamaRewardModel(LlamaForCausalLM):
8      def __init__(self, config):
9          super().__init__(config)
10
11         # 初始化线性变换层,将隐状态映射为标量,用于输出最终奖励
12         self.reward_head = nn.Linear(config.hidden_size, 1, bias=False)
  • LlamaRewardModel 继承自 HuggingFace 的 LlamaForCausalLM
  • 增加了一个 reward_head 线性层,用于将模型输出(hidden state)映射为 奖励值(scalar)

🧾 正例/负例打分函数 _forward_rmloss

14 def _forward_rmloss(self, input_ids, attention_mask, **kargs):
18     output = self.model.forward(
19         input_ids=input_ids,
20         attention_mask=attention_mask,
21         return_dict=True,
22         use_cache=False
23     )
25     logits = self.reward_head(output.last_hidden_state).squeeze(-1)
26     return logits
  • 输入:拼接后的 [x, y] 序列
  • self.model.forward(...):获得 LLaMA 模型输出(hidden states)
  • self.reward_head(...):只对最后一层 hidden state 应用线性映射,输出奖励值
  • squeeze(-1):去除最后一维 [batch, 1] -> [batch]

squeeze(-1) 的作用是去掉张量的最后一个维度,前提是该维度的值是 1。
假设 logits 是一个 [batch_size, 1] 的张量:
logits = tensor([[0.73], [0.24], [0.91]]) # shape: [3, 1]
执行:
logits = logits.squeeze(-1)
结果为:
tensor([0.73, 0.24, 0.91]) # shape: [3]

✍️ 模仿学习损失函数 _forward_lmloss

29 def _forward_lmloss(self, prompt_ids, lm_attn_mask, response_ids):
35     outputs = self.model.forward(
36         input_ids=prompt_ids,
37         attention_mask=lm_attn_mask,
38         return_dict=True,
39         use_cache=False,
40     )
42     hidden_states = outputs.last_hidden_state
43     logits = self.lm_head(hidden_states)
44     loss_fct = nn.CrossEntropyLoss()
45     logits = logits.view(-1, self.config.vocab_size)
46     response_ids = response_ids.view(-1)
47     loss = loss_fct(logits, response_ids)
48     return loss
  • prompt_ids[x, y⁺] 拼接后的 token ID
  • 输出 logits:维度 [batch_size, seq_len, vocab_size]
  • 计算交叉熵损失:对所有位置预测的 token 与 response_ids 进行对比

🚀 前向传播函数:组合损失计算

50 def forward(self, sent1_idx, attention_mask_1, sent2_idx,attention_mask_2, labels, prompt_ids, lm_attn_mask, response_ids):

参数说明:

  • sent1_idx: [x, y⁺] 拼接输入(正例)
  • sent2_idx: [x, y⁻] 拼接输入(负例)
  • labels: 全 0 标签,用于对比损失
  • prompt_ids: 与正例相关的 token(用于 LM Loss)
  • response_ids: 正例的 target token(用于 LM Loss)

计算对比损失(Reward Loss)

61 reward0 = self._forward_rmloss(sent1_idx, attention_mask_1)
66 reward1 = self._forward_rmloss(sent2_idx, attention_mask_2)
71 logits = reward0 - reward1
72 rm_loss = F.binary_cross_entropy_with_logits(logits,labels.to(logits.dtype), reduction="mean")
  • 分别计算 r(x, y⁺)r(x, y⁻)

  • 构造 logits = r⁺ - r⁻

  • 用 Binary Cross Entropy Loss 计算 reward loss

    公式对应:
    − log ⁡ ( σ ( r ( x , y + ) − r ( x , y − ) ) ) -\log(\sigma(r(x, y⁺) - r(x, y⁻))) log(σ(r(x,y+)r(x,y)))


计算语言模型损失(Language Modeling Loss)

75 lm_loss = self._forward_lmloss(prompt_ids, lm_attn_mask, response_ids)
  • 与传统语言模型训练一致,使用 CrossEntropyLoss

返回总损失

78 loss = rm_loss + lm_loss
79 return loss
  • 二者直接加和(可选加权项 β,可自己加参数)
  • 模型即同时优化打分能力 + 文本生成能力(联合学习)

🔸四、总结

项目描述
核心思想同时学习奖励模型 r_θ 和保持生成流畅性
优势1. 保留强化学习能力
2. 不失语义与流畅性
应用场景RLHF 的 reward 模型训练阶段,如 OpenAI 的 GPT 训练流程中 Step 2: Train Reward Model
可调参数β 控制生成质量与偏好打分之间的权衡

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/diannao/85178.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

reverse_ssh 建立反向 SSH 连接指南 混淆AV [好东西哟]

目录 &#x1f310; 工具简介 ⚙️ 前提条件 攻击主机 (Linux) 目标主机 (Windows) &#x1f4cb; 详细步骤 步骤 1&#xff1a;安装 Go 环境 步骤 2&#xff1a;安装必要依赖 步骤 3&#xff1a;下载并编译 reverse_ssh 步骤 4&#xff1a;配置密钥 步骤 5&#xff…

Ubuntu 下搭建ESP32 ESP-IDF开发环境,并在windows下用VSCode通过SSH登录Ubuntu开发ESP32应用

Ubuntu 下搭建ESP32 ESP-IDF开发环境&#xff0c;网上操作指南很多&#xff0c;本来一直也没有想过要写这么一篇文章。因为我其实不太习惯在linux下开发应用&#xff0c;平时更习惯windows的软件操作&#xff0c;只是因为windows下开发ESP32的应用编译时太慢&#xff0c;让人受…

Rust使用Cargo构建项目

文章目录 你好&#xff0c;Cargo&#xff01;验证Cargo安装使用Cargo创建项目新建项目配置文件解析默认代码结构 Cargo工作流常用命令速查表详细使用说明1. 编译项目2. 运行程序3.快速检查4. 发布版本构建 Cargo的设计哲学约定优于配置工程化优势 开发建议1. 新项目初始化​2. …

免费且好用的PDF水印添加工具

软件介绍 琥珀扫描.zip下载链接&#xff1a;https://pan.quark.cn/s/3a8f432b29aa 今天要给大家推荐一款超实用的PDF添加水印工具&#xff0c;它能够满足用户给PDF文件添加水印的需求&#xff0c;而且完全免费。 这款PDF添加水印的软件有着简洁的界面&#xff0c;操作简便&a…

NW969NW978美光闪存颗粒NW980NW984

NW969NW978美光闪存颗粒NW980NW984 技术解析&#xff1a;NW969、NW978、NW980与NW984的架构创新 美光&#xff08;Micron&#xff09;的闪存颗粒系列&#xff0c;尤其是NW969、NW978、NW980和NW984&#xff0c;代表了存储技术的前沿突破。这些产品均采用第九代3D TLC&#xf…

Mysql常用知识3:Kafka和数据库优化

文章目录 一、分布式消息系统&#xff08;Kafka相关问题5-10&#xff09;5. Kafka如何保证消息不丢失&#xff1f;6. 项目中Kafka具体怎么使用的&#xff1f;7. 消息异常未发送成功怎么解决&#xff1f;8. 重试具体怎么做的&#xff0c;循环吗&#xff1f;9. 重试多次失败怎么办…

常见的RAG文档解析辅助工具汇总及企业选型思考

以下当前比较知名的RAG的文档解析辅助工具的开源项目汇总&#xff0c;包含核心功能、License信息及GitHub地址&#xff1a; 1. RAGFlow 核心功能&#xff1a;支持PDF/扫描件/CAD等23种格式解析&#xff0c;OCR准确率98%&#xff0c;知识图谱融合&#xff0c;混合检索&#xf…

基于Sqoop的MySQL-Hive全量/增量同步解决方案(支持多表批量处理

一、全量同步方案设计 1.1 基础命令模板 sqoop import \ --connect jdbc:mysql://mysql_host:3306/db_name \ --username user \ --password pass \ --table source_table \ --hive-import \ --hive-table target_table \ --hive-overwrite \ # 覆盖已有表 --num-mappers 8 …

前端学习(7)—— HTML + CSS实现博客系统页面

目录 一&#xff0c;效果展示 二&#xff0c;实现博客列表页 2.1 实现导航栏 2.2 实现个人信息 2.3 实现博客列表 三&#xff0c;实现博客正文页 3.2 复用 3.4 实现博客正文 四&#xff0c;实现博客登录页 4.1 版心 4.2 登录框 五&#xff0c;实现博客编辑页 5.1 …

【技能拾遗】——家庭宽带单线复用布线与配置(移动2025版)

&#x1f4d6; 前言&#xff1a;在家庭网络拓扑中&#xff0c;客厅到弱电箱只预埋了一根网线&#xff0c;由于已将广电的有线电视取消并改用IPTV。现在需要解决在客厅布置路由器和观看IPTV问题&#xff0c;这里就用到单线复用技术。 目录 &#x1f552; 1. 拓扑规划&#x1f55…

VTK|实现类似CloundCompare的测量功能

文章目录 CloundCompare在点、线、面三种模式下的显示内容✅ 图1&#xff1a;点模式✅ 图2&#xff1a;线模式✅ 图3&#xff1a;面模式 增加控制菜单栏实现测量功能类如何调用项目git链接 CloundCompare在点、线、面三种模式下的显示内容 点 线 面 三张图展示了 CloudComp…

4000万日订单背后,饿了么再掀即时零售的“效率革命”

当即时零售转向价值深耕&#xff0c;赢面就是综合实力的强弱。 文&#xff5c;郭梦仪 编&#xff5c;王一粟 在硝烟弥漫的外卖行业“三国杀”中&#xff0c;饿了么与淘宝闪购的日订单量竟然突破了4000万单。 而距淘宝闪购正式上线&#xff0c;还不到一个月。 在大额福利优惠…

vedio.ontimeupdate()和video.onloadeddata()

video.onloadeddata &#xff08;&#xff09; video.onloadeddata 是 JavaScript 中用于监听 HTML <video> 元素 「当前帧数据已加载」 的事件处理器。当视频的第一帧画面数据加载完成&#xff08;足以开始播放&#xff09;时&#xff0c;会触发此事件。 1. 基本用法 …

Baklib内容中台革新企业知识实践

Baklib智能知识中枢构建 作为现代企业知识管理的核心架构&#xff0c;Baklib内容中台通过整合多源异构数据形成智能化知识中枢&#xff0c;实现从信息采集到价值转化的全链路管理。其底层采用跨平台数据贯通技术&#xff0c;支持API接口与企业现有CRM、ERP系统无缝对接&#x…

用不太严谨的文字介绍遥测自跟踪天线的基本原理

前两天跟一个客户见面的时候&#xff0c;客户问我&#xff1a;遥测自跟踪天线能够跟踪目标&#xff0c;是什么原理&#xff1f;不需要目标的位置&#xff0c;怎么做到自跟踪的&#xff1f; 突然一瞬间&#xff0c;有点语塞。 难道要介绍天线、馈源、极化、左旋、右旋、和差网…

VS配置redis环境、redis简单封装

一、安装redis数据库 1.下载redis的压缩包 wget https://download.redis.io/releases/redis-6.0.5.tar.g 2.解压缩redis压缩包&#xff0c;一般就在当前路径 tar -zvxf redis-6.0.5.tar.gz -C /usr/local/redis 方便找我把它解压缩在/usr/local/redis&#xff0c;如果没有r…

C++23 已移除特性解析

文章目录 引言C23 已移除特性介绍1. 垃圾收集的支持和基于可达性的泄漏检测&#xff08;P2186R2&#xff09;背景与原理存在的问题移除的影响 2. 混合宽字符串字面量拼接非良构&#xff08;P2201R1&#xff09;宽字符串编码概述混合拼接的问题示例分析移除的意义 3. 不可编码宽…

Cloudflare

Cloudflare 是一个网络基础设施和网站安全服务提供商&#xff0c;它的主要作用是让网站 更快、更安全、更可靠。简单来说&#xff0c;它是一个“护盾 加速器”。 &#x1f9e9; Cloudflare 的主要功能&#xff1a; 1. &#x1f680; 加速网站访问&#xff08;CDN&#xff09…

Spring Boot启动慢?Redis缓存击穿?Kafka消费堆积?——Java后端常见问题排查实战

Spring Boot启动慢&#xff1f;Redis缓存击穿&#xff1f;Kafka消费堆积&#xff1f;——Java后端常见问题排查实战 引言 Java后端系统因其丰富的技术栈和复杂的业务逻辑&#xff0c;常常面临启动延迟、性能瓶颈、异常错误等多种挑战。从核心语言、Web框架到分布式微服务及缓…

数字人引领政务新风尚:智能设备助力政务服务

在信息技术飞速发展的今天&#xff0c;政府机构不断探索提升服务效率和改善服务质量的新途径。实时交互数字人在政务服务中的应用正成为一大亮点&#xff0c;通过将“数字公务员”植入各种横屏智能设备中&#xff0c;为民众办理业务提供全程辅助。这种创新不仅优化了政务大厅的…