目标导向的强化学习:问题定义与 HER 算法详解—强化学习(19)

目录

1、目标导向的强化学习:问题定义

1.1、 核心要素与符号定义

1.2、 核心问题:稀疏奖励困境

1.3、 学习目标

2、HER(Hindsight Experience Replay)算法

2.1、 HER 的核心逻辑

2.2、 算法步骤(结合 DDPG 举例)

2.2.1、步骤 1:收集原始经验

2.2.2、步骤 2:重构经验(核心!)

2.2.3、步骤 3:替代目标生成策略

2.2.4、步骤 4:策略更新

2.3、 为什么 HER 有效?

2.4、公式总结

3、通俗理解

4、完整代码

5、实验结果  


1、目标导向的强化学习:问题定义

目标导向的强化学习(Goal-Conditioned Reinforcement Learning)是一类让智能体通过学习策略,从初始状态达到特定目标的任务。与传统强化学习不同,这类任务的核心是 “目标”—— 智能体的行为需围绕 “达成目标” 展开,而目标本身可能随任务变化(如 “机械臂抓取 A 物体”“机械臂抓取 B 物体” 是两个不同目标)。

1.1、 核心要素与符号定义

  • 状态(State):环境的观测信息,记为s \in \mathcal{S}\mathcal{S}是状态空间)。例如:机械臂的关节角度、物体的坐标。
  • 目标(Goal):智能体需要达成的状态,记为g \in \mathcal{G}\mathcal{G}是目标空间,通常与状态空间重合或相关)。例如:机械臂需抓取的物体坐标。
  • 动作(Action):智能体的行为,记为 a \in \mathcal{A}\mathcal{A} 是动作空间)。例如:机械臂关节的旋转角度。
  • 转移函数:状态 - 动作对到下一状态的映射,记为 s' \sim P(s' | s, a)(P 是状态转移概率)。
  • 奖励函数:衡量 “当前状态与目标的差距”,记为r(s, a, g)。目标导向任务的奖励通常仅与 “状态是否接近目标” 相关,与动作间接相关。

1.2、 核心问题:稀疏奖励困境

目标导向任务的奖励函数通常是稀疏的:仅当状态 s 与目标 g 几乎一致时,才给予正奖励;否则奖励为 0 或负值。 奖励函数示例(机械臂抓取任务):

  • 智能体在绝大多数尝试中(如 99% 的交互)都得不到正奖励,无法判断 “哪些动作有助于接近目标”;
  • 策略更新缺乏有效信号(梯度难以计算),学习效率极低,甚至无法收敛。

1.3、 学习目标

目标导向强化学习的目标是学习一个目标条件策略 \pi(a | s, g),使得在策略引导下,智能体从任意初始状态 \(s_0\) 出发,通过执行动作序列a_0, a_1, ..., a_T,最终达到目标 g 的概率最大化。

2、HER(Hindsight Experience Replay)算法

HER 算法是解决目标导向任务中稀疏奖励问题的经典方法,核心思想是:从 “失败经验” 中 “事后重构” 有效奖励信号—— 即使智能体没达成原定目标,也能通过修改目标,将 “失败轨迹” 转化为 “成功轨迹”,从而提取学习信号

2.1、 HER 的核心逻辑

假设智能体在一次交互中,原定目标是 g,但实际轨迹为\tau = (s_0, a_0, s_1, a_1, ..., s_T),最终状态 s_T \neq g(失败)。 HER 的关键操作是:从轨迹 \tau中选一个状态 s_k作为 “替代目标” \hat{g} = s_k,此时轨迹 \tau对于新目标 \\hat{g} 是 “成功的”(因为 s_T可能接近 \hat{g}),从而可计算有效奖励。

2.2、 算法步骤(结合 DDPG 举例)

HER 通常与离线强化学习算法(如 DDPG)结合使用,流程如下:

2.2.1、步骤 1:收集原始经验

智能体与环境交互,收集轨迹并存储到经验回放池 \mathcal{D}。每条经验是一个五元组:e = (s_t, a_t, r_t, s_{t+1}, g)其中r_t = r(s_t, a_t, g)是基于原定目标 g 的奖励(可能为 0)。

2.2.2、步骤 2:重构经验(核心!)

对回放池中的每条原始经验 e,HER 通过替代目标生成策略选一个新目标 \hat{g},重构出一条 “虚拟成功经验” \hat{e}\hat{e} = (s_t, a_t, \hat{r}_t, s_{t+1}, \hat{g}) 其中\hat{r}_t = r(s_t, a_t, \hat{g}) 是基于新目标 \hat{g} 的奖励(此时可能为正,因为 \hat{g}来自轨迹,s_{t+1} 可能接近 \hat{g})。

2.2.3、步骤 3:替代目标生成策略

HER 定义了 4 种常用的替代目标生成策略(以轨迹 \tau = (s_0, s_1, ..., s_T) 为例):

  • Final\hat{g} = s_T(选最终状态);
  • Future\hat{g} = s_k,其中 k \sim \text{Uniform}(t+1, T)(选未来状态);
  • Random\hat{g} = s_k,其中 k \sim \text{Uniform}(0, T)(随机选一个状态);
  • Episode\hat{g} 从同回合的其他轨迹中随机选一个状态(适用于多目标任务)。

2.2.4、步骤 4:策略更新

将原始经验 e 和重构经验\hat{e} 一起放入回放池,用离线算法(如 DDPG)更新策略。 以 DDPG 的 Critic 网络更新为例:

2.3、 为什么 HER 有效?

  • 解决稀疏性:通过重构经验,将 “0 奖励” 转化为 “正奖励”,使奖励信号密集化;
  • 利用失败经验:原本无用的失败轨迹被转化为有效学习样本,提高数据利用率;
  • 通用兼容:HER 是 “经验回放增强技术”,可与 DDPG、SAC 等多种算法结合,无需修改算法核心。

2.4、公式总结

  • 原始经验:e = (s_t, a_t, r(s_t, a_t, g), s_{t+1}, g)
  • 重构经验:\hat{e} = (s_t, a_t, r(s_t, a_t, \hat{g}), s_{t+1}, \hat{g}),其中 \hat{g} \in \{s_0, s_1, ..., s_T\}
  • 策略目标:\max_{\pi} \mathbb{E}_{\pi} \left[ \sum_{t=0}^T \gamma^t r(s_t, a_t, g) \right]

3、通俗理解

用 “快递员送货” 举例:

  • 目标导向任务:快递员(智能体)需要把包裹送到目标地址 g(原定目标),但只有送到 g 才有钱(奖励),中途迷路(失败)则没钱。
  • 稀疏奖励问题:快递员第一次送陌生地址,99% 的概率找不到,长期没钱,不知道往哪走。
  • HER 的做法:快递员虽然没到 g,但路过了 \hat{g}(比如某个超市),就把 “送到超市” 当作新目标,这次 “成功” 能拿到钱,从而学会 “如何到超市”;多次积累后,就能掌握城市路线,最终学会到任意目标 g 的方法。

4、完整代码

"""
文件名: 19.1
作者: 墨尘
日期: 2025/7/25
项目名: d2l_learning
备注: 
"""
import torch
import torch.nn.functional as F
import numpy as np
import random
from tqdm import tqdm
import collections
import matplotlib.pyplot as pltclass WorldEnv:def __init__(self):self.distance_threshold = 0.15self.action_bound = 1def reset(self):  # 重置环境# 生成一个目标状态, 坐标范围是[3.5~4.5, 3.5~4.5]self.goal = np.array([4 + random.uniform(-0.5, 0.5), 4 + random.uniform(-0.5, 0.5)])self.state = np.array([0, 0])  # 初始状态self.count = 0return np.hstack((self.state, self.goal))def step(self, action):action = np.clip(action, -self.action_bound, self.action_bound)x = max(0, min(5, self.state[0] + action[0]))y = max(0, min(5, self.state[1] + action[1]))self.state = np.array([x, y])self.count += 1dis = np.sqrt(np.sum(np.square(self.state - self.goal)))reward = -1.0 if dis > self.distance_threshold else 0if dis <= self.distance_threshold or self.count == 50:done = Trueelse:done = Falsereturn np.hstack((self.state, self.goal)), reward, done
class PolicyNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim, action_bound):super(PolicyNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)self.fc3 = torch.nn.Linear(hidden_dim, action_dim)self.action_bound = action_bound  # action_bound是环境可以接受的动作最大值def forward(self, x):x = F.relu(self.fc2(F.relu(self.fc1(x))))return torch.tanh(self.fc3(x)) * self.action_boundclass QValueNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(QValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)self.fc3 = torch.nn.Linear(hidden_dim, 1)def forward(self, x, a):cat = torch.cat([x, a], dim=1)  # 拼接状态和动作x = F.relu(self.fc2(F.relu(self.fc1(cat))))return self.fc3(x)
class DDPG:''' DDPG算法 '''def __init__(self, state_dim, hidden_dim, action_dim, action_bound,actor_lr, critic_lr, sigma, tau, gamma, device):self.action_dim = action_dimself.actor = PolicyNet(state_dim, hidden_dim, action_dim,action_bound).to(device)self.critic = QValueNet(state_dim, hidden_dim, action_dim).to(device)self.target_actor = PolicyNet(state_dim, hidden_dim, action_dim,action_bound).to(device)self.target_critic = QValueNet(state_dim, hidden_dim,action_dim).to(device)# 初始化目标价值网络并使其参数和价值网络一样self.target_critic.load_state_dict(self.critic.state_dict())# 初始化目标策略网络并使其参数和策略网络一样self.target_actor.load_state_dict(self.actor.state_dict())self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),lr=actor_lr)self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),lr=critic_lr)self.gamma = gammaself.sigma = sigma  # 高斯噪声的标准差,均值直接设为0self.tau = tau  # 目标网络软更新参数self.action_bound = action_boundself.device = devicedef take_action(self, state):state = torch.tensor([state], dtype=torch.float).to(self.device)action = self.actor(state).detach().cpu().numpy()[0]# 给动作添加噪声,增加探索action = action + self.sigma * np.random.randn(self.action_dim)return actiondef soft_update(self, net, target_net):for param_target, param in zip(target_net.parameters(),net.parameters()):param_target.data.copy_(param_target.data * (1.0 - self.tau) +param.data * self.tau)def update(self, transition_dict):states = torch.tensor(transition_dict['states'],dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions'],dtype=torch.float).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)next_q_values = self.target_critic(next_states,self.target_actor(next_states))q_targets = rewards + self.gamma * next_q_values * (1 - dones)# MSE损失函数critic_loss = torch.mean(F.mse_loss(self.critic(states, actions), q_targets))self.critic_optimizer.zero_grad()critic_loss.backward()self.critic_optimizer.step()# 策略网络就是为了使Q值最大化actor_loss = -torch.mean(self.critic(states, self.actor(states)))self.actor_optimizer.zero_grad()actor_loss.backward()self.actor_optimizer.step()self.soft_update(self.actor, self.target_actor)  # 软更新策略网络self.soft_update(self.critic, self.target_critic)  # 软更新价值网络
class Trajectory:''' 用来记录一条完整轨迹 '''def __init__(self, init_state):self.states = [init_state]self.actions = []self.rewards = []self.dones = []self.length = 0def store_step(self, action, state, reward, done):self.actions.append(action)self.states.append(state)self.rewards.append(reward)self.dones.append(done)self.length += 1class ReplayBuffer_Trajectory:''' 存储轨迹的经验回放池 '''def __init__(self, capacity):self.buffer = collections.deque(maxlen=capacity)def add_trajectory(self, trajectory):self.buffer.append(trajectory)def size(self):return len(self.buffer)def sample(self, batch_size, use_her, dis_threshold=0.15, her_ratio=0.8):batch = dict(states=[],actions=[],next_states=[],rewards=[],dones=[])for _ in range(batch_size):traj = random.sample(self.buffer, 1)[0]step_state = np.random.randint(traj.length)state = traj.states[step_state]next_state = traj.states[step_state + 1]action = traj.actions[step_state]reward = traj.rewards[step_state]done = traj.dones[step_state]if use_her and np.random.uniform() <= her_ratio:step_goal = np.random.randint(step_state + 1, traj.length + 1)goal = traj.states[step_goal][:2]  # 使用HER算法的future方案设置目标dis = np.sqrt(np.sum(np.square(next_state[:2] - goal)))reward = -1.0 if dis > dis_threshold else 0done = False if dis > dis_threshold else Truestate = np.hstack((state[:2], goal))next_state = np.hstack((next_state[:2], goal))batch['states'].append(state)batch['next_states'].append(next_state)batch['actions'].append(action)batch['rewards'].append(reward)batch['dones'].append(done)batch['states'] = np.array(batch['states'])batch['next_states'] = np.array(batch['next_states'])batch['actions'] = np.array(batch['actions'])return batchif __name__ == '__main__':actor_lr = 1e-3critic_lr = 1e-3hidden_dim = 128state_dim = 4action_dim = 2action_bound = 1sigma = 0.1tau = 0.005gamma = 0.98num_episodes = 2000n_train = 20batch_size = 256minimal_episodes = 200buffer_size = 10000device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")random.seed(0)np.random.seed(0)torch.manual_seed(0)env = WorldEnv()replay_buffer = ReplayBuffer_Trajectory(buffer_size)agent = DDPG(state_dim, hidden_dim, action_dim, action_bound, actor_lr,critic_lr, sigma, tau, gamma, device)return_list = []for i in range(10):with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:for i_episode in range(int(num_episodes / 10)):episode_return = 0state = env.reset()traj = Trajectory(state)done = Falsewhile not done:action = agent.take_action(state)state, reward, done = env.step(action)episode_return += rewardtraj.store_step(action, state, reward, done)replay_buffer.add_trajectory(traj)return_list.append(episode_return)if replay_buffer.size() >= minimal_episodes:for _ in range(n_train):transition_dict = replay_buffer.sample(batch_size, True)agent.update(transition_dict)if (i_episode + 1) % 10 == 0:pbar.set_postfix({'episode':'%d' % (num_episodes / 10 * i + i_episode + 1),'return':'%.3f' % np.mean(return_list[-10:])})pbar.update(1)episodes_list = list(range(len(return_list)))plt.plot(episodes_list, return_list)plt.xlabel('Episodes')plt.ylabel('Returns')plt.title('DDPG with HER on {}'.format('GridWorld'))plt.show()

5、实验结果  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/92950.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/92950.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2025 XYD Summer Camp 7.21 智灵班分班考 · Day1

智灵班分班考 Day1 时间线 8:00 在滨兰实验的远古机房中的一个键盘手感爆炸的电脑上开考。开 T1&#xff0c;推了推发现可以 segment tree 优化 dp&#xff0c;由于按空格需要很大的力气导致马蜂被迫改变。后来忍不住了顶着疼痛按空格。8:30 过了样例&#xff0c;但是没有大样…

基于多种主题分析、关键词提取算法的设计与实现【TF-IDF算法、LDA、NMF分解、BERT主题模型】

文章目录有需要本项目的代码或文档以及全部资源&#xff0c;或者部署调试可以私信博主一、项目背景二、研究目标与意义三、数据获取与处理四、文本分析与主题建模方法1. 传统方法探索2. 主题模型比较与优化3. 深度语义建模与聚类五、研究成果与应用价值六、总结与展望总结每文一…

MDC(Mapped Diagnostic Context) 的核心介绍与使用教程

关于日志框架中 MDC&#xff08;Mapped Diagnostic Context&#xff09; 的核心介绍与使用教程&#xff0c;结合其在分布式系统中的实际应用场景&#xff0c;分模块说明&#xff1a; 一、MDC 简介 MDC&#xff08;映射诊断上下文&#xff09; 是 SLF4J/Logback 提供的一种线程…

Linux随记(二十一)

一、highgo切换leader&#xff0c;follow - 随记 【待写】二、highgo的etcd未授权访问 - 随记 【待写】三、highgo的etcd未授权访问 - 随记 【待写】3.2、etcd的metric未授权访问 - 随记 【待写】四、安装Elasticsearch 7.17.29 和 Elasticsearch 未授权访问【原理扫描】…

Java环境配置之各类组件下载安装教程整理(jdk、idea、git、maven、mysql、redis)

Java环境配置之各类组件下载安装教程整理&#xff08;jdk、idea、git、maven、mysql、redis&#xff09;1.[安装配置jdk8]2.[安装配置idea]3.[安装配置git]4.[安装配置maven]5.[安装配置postman]6.[安装配置redis和可视化工具]7.[安装配置mysql和可视化工具]8.[安装配置docker]…

配置https ssl证书生成

1.可用openssl生成私钥和自签名证书 安装opensslsudo yum install openssl -y 2.生成ssl证书 365天期限sudo openssl req -x509 -nodes -days 365 -newkey rsa:2048 \-keyout /etc/ssl/private/nginx-selfsigned.key \-out /etc/ssl/certs/nginx-selfsigned.crt3、按照提示编…

编程语言Java——核心技术篇(四)集合类详解

言不信者行不果&#xff0c;行不敏者言多滞. 目录 4. 集合类 4.1 集合类概述 4.1.1 集合框架遵循原则 4.1.2 集合框架体系 4.2 核心接口和实现类解析 4.2.1 Collection 接口体系 4.2.1.1 Collection 接口核心定义 4.2.1.2 List接口详解 4.2.1.3 Set 接口详解 4.2.1.4…

GaussDB 数据库架构师(八) 等待事件(1)-概述

1、等待事件概述 等待事件&#xff1a;指当数据库会话(session)因资源竞争或依赖无法继续执行时&#xff0c;进入"等待"状态&#xff0c;此时产生的性能事件即等待事件。 2、等待事件本质 性能瓶颈的信号灯&#xff0c;反映CPU,I/O、锁、网络等关键资源的阻塞情况。…

五分钟系列-文本搜索工具grep

目录 1️⃣核心功能​​ ​​2️⃣基本语法​​ 3️⃣​​常用选项 & 功能详解​​ ​​4️⃣经典应用场景 & 示例​​ 5️⃣​​重要的提示 & 技巧​​ ​​6️⃣总结​​ grep 是 Linux/Unix 系统中功能强大的​​文本搜索工具​​&#xff0c;其名称源自 …

Java面试题及详细答案120道之(041-060)

《前后端面试题》专栏集合了前后端各个知识模块的面试题&#xff0c;包括html&#xff0c;javascript&#xff0c;css&#xff0c;vue&#xff0c;react&#xff0c;java&#xff0c;Openlayers&#xff0c;leaflet&#xff0c;cesium&#xff0c;mapboxGL&#xff0c;threejs&…

【尝试】本地部署openai-whisper,通过 http请求识别

安装whisper的教程&#xff0c;已在 https://blog.csdn.net/qq_23938507/article/details/149394418 和 https://blog.csdn.net/qq_23938507/article/details/149326290 中说明。 1、创建whisperDemo1.py from fastapi import FastAPI, UploadFile, File import whisper i…

Visual Studio 的常用快捷键

Visual Studio 作为主流的开发工具&#xff0c;提供了大量快捷键提升编码效率。以下按功能分类整理常用快捷键&#xff0c;涵盖基础操作、代码编辑、调试等场景&#xff08;以 Visual Studio 2022 为例&#xff0c;部分快捷键可在「工具 > 选项 > 环境 > 键盘」中自定…

Triton Server部署Embedding模型

在32核CPU、无GPU的服务器上&#xff0c;使用Python后端和ONNX后端部署嵌入模型&#xff0c;并实现并行调用和性能优化策略。方案一&#xff1a;使用Python后端部署Embedding模型 Python后端提供了极大的灵活性&#xff0c;可以直接在Triton中运行您熟悉的sentence-transformer…

Java动态调试技术原理

本文转载自 美团技术团队胡健的Java 动态调试技术原理及实践, 通过学习java agent方式进行动态调试了解目前很多大厂开源的一些基于此的调试工具。 简介 断点调试是我们最常使用的调试手段&#xff0c;它可以获取到方法执行过程中的变量信息&#xff0c;并可以观察到方法的执…

人工智能-python-OpenCV 图像基础认知与运用

文章目录OpenCV 图像基础认知与运用1. OpenCV 简介与安装OpenCV 的优势安装 OpenCV2. 图像的基本概念2.1. 图像的存储格式2.2. 图像的表示3. 图像的基本操作3.1. 创建图像窗口3.2. 读取与显示图像3.3. 保存图像3.4. 图像切片与区域提取3.5. 图像大小调整4. 图像绘制与注释4.1. …

Windows电脑添加、修改打印机的IP地址端口的方法

本文介绍在Windows电脑中&#xff0c;为打印机添加、修改IP地址&#xff0c;从而解决电脑能找到打印机、但是无法打印问题的方法。最近&#xff0c;办公室的打印机出现问题——虽然在电脑的打印机列表能找到这个打印机&#xff0c;但是选择打印时&#xff0c;就会显示文档被挂起…

告别复杂配置!Spring Boot优雅集成百度OCR的终极方案

1. 准备工作 1.1 注册百度AI开放平台 访问百度AI开放平台 注册账号并登录 进入控制台 → 文字识别 → 创建应用 记录下API Key和Secret Key 2. 项目配置 2.1 添加依赖 (pom.xml) <dependencies><!-- Spring Boot Web --><dependency><groupId>o…

「iOS」——内存五大分区

UI学习iOS-底层原理 24&#xff1a;内存五大区总览一、栈区&#xff08;Stack&#xff09;1.1 核心特性1.2 优缺点1.3函数栈与栈帧1.3 堆栈溢出风险二、堆区&#xff08;Heap&#xff09;;2.1 核心特性2.2 与栈区对比三、全局 / 静态区&#xff08;Global/Static&#xff09;3.…

每日一题【删除有序数组中的重复项 II】

删除有序数组中的重复项 II思路class Solution { public:int removeDuplicates(vector<int>& nums) {if(nums.size()<2){return nums.size();}int index 2;for (int i 2; i < nums.size();i ) {if(nums[i] ! nums[index-2]) {nums[index]nums[i];}}return ind…

兼容性问题记录

1、dialog设置高度MATCH_PARENT全屏后&#xff0c;三星机型和好像是一加&#xff0c;会带出顶部状态栏&#xff0c;设置隐藏状态栏属性无效。解决方法&#xff1a;高度不设置为MATCH_PARENT&#xff0c;通过windowmanager.getdefaultdisplay来获取并设置高度&#xff0c;再设置…