从代码学习深度强化学习 - 目标导向的强化学习-HER算法 PyTorch版

文章目录

  • 1. 前言:当一个任务有多个目标
  • 2. 目标导向的强化学习 (GoRL) 简介
  • 3. HER算法:化失败为成功的智慧
  • 4. 代码实践:用PyTorch实现HER+DDPG
    • 4.1 自定义环境 (WorldEnv)
    • 4.2 智能体与算法 (DDPG)
    • 4.3 HER的核心:轨迹经验回放
    • 4.4 主流程与训练
  • 5. 训练结果与分析
  • 6. 总结


1. 前言:当一个任务有多个目标

经典的深度强化学习算法,如 PPO、SAC 等,在各自擅长的任务中都取得了非常好的效果。但它们通常都局限在解决单个任务上,换句话说,训练好的算法,在运行时也只能完成一个特定的任务。

想象一个场景:我们想让一个机械臂能把桌子上的任何一个物体重置到任意一个指定位置。对于传统强化学习而言,如果目标物体的初始位置和目标位置每次都变化,那么这就是一个全新的任务。即便任务的“格式”——抓取并移动——是一样的,但策略本身可能需要重新训练。这显然效率极低。

为了解决这类问题,目标导向的强化学习 (Goal-Oriented Reinforcement Learning, GoRL) 应运而生。它的核心思想是学习一个通用策略,这个策略能够根据给定的目标 (goal) 来执行相应的动作,从而用一个模型解决一系列结构相同但目标不同的复杂任务。

然而,在诸如机械臂抓取等真实场景中,奖励往往是稀疏的。只有当机械臂成功将物体放到指定位置时,才会获得正奖励,否则奖励一直为0或-1。在训练初期,智能体很难通过随机探索完成任务并获得奖励,导致学习效率极低。

为了解决稀疏奖励下的学习难题,OpenAI 在2017年提出了事后经验回放 (Hindsight Experience Replay, HER) 算法。HER 的思想极为巧妙:即使我们没有完成预设的目标,但我们总归是完成了“某个”目标。 通过这种“事后诸葛亮”的方式,将失败的经验转化为成功的学习样本,从而极大地提升了在稀疏奖励环境下的学习效率。

本文将从 HER 的基本概念出发,结合一个完整的 PyTorch 代码实例,带你深入理解 HER 是如何与 DDPG 等经典算法结合,并有效解决目标导向的强化学习问题的。

完整代码:下载链接

2. 目标导向的强化学习 (GoRL) 简介

在目标导向的强化学习中,传统的马尔可夫决策过程 (MDP) 被扩展了。除了状态 S、动作 A、转移概率 P 之外,还引入了目标空间 G。策略 π 不仅依赖于当前状态 s,还依赖于目标 g,即 π(a|s, g)

奖励函数 r 也与目标相关,记为 r_g。在本文的设定中,状态 s 包含了智能体自身的信息(例如坐标),而目标 g 则是状态空间中的一个特定子集(例如一个目标坐标)。我们使用一个映射函数 φ 将状态 s 映射到其对应的目标 g

在 GoRL 中,一个常见的挑战是稀疏奖励。例如,只有当智能体达到的状态 s' 对应的目标 φ(s') 与我们期望的目标 g 足够接近时,才给予奖励。这可以用以下公式表示:

其中,δ_g 是一个很小的阈值。这意味着,在绝大多数情况下,智能体得到的奖励都是-1,学习信号非常微弱。

3. HER算法:化失败为成功的智慧

HER 的核心思想在于重新利用失败的轨迹

假设智能体在一次任务(一个 episode)中,目标是 g,但最终没有达到,整个轨迹获得的奖励都是-1。这条“失败”的轨迹对于学习如何达到目标 g 几乎没有帮助。

但 HER 会这样想:虽然智能体没有达到目标 g,但它在轨迹的最后达到了某个状态 s_T。这个状态 s_T 自身就可以被看作是一个目标,我们称之为“事后目标” g' = φ(s_T)。如果我们把这次任务的目标“篡改”为 g',那么这条轨迹就变成了一条成功的轨迹!因为智能体确实达到了 g'

通过这种方式,HER 能够从任何轨迹中都提取出有价值的学习信号,将稀疏的奖励变得稠密。

在具体实现时,HER 会从一条完整的轨迹中,随机采样一个时间步 (s_t, a_t, r_t, s_{t+1}),然后根据一定策略选择一个新的目标 g' 来替换原始目标 g,并根据新目标重新计算奖励 r'

HER 提出了几种选择新目标 g' 的策略,其中最常用也最直观的是 future 策略:在当前时间步 t 之后,从该轨迹中随机选择一个未来状态 s_k (k > t),将其对应的 φ(s_k) 作为新的目标 g'

这种方法保证了新目标是在当前状态之后可以达到的,使得学习过程更加稳定和有效。

HER 作为一个通用的技巧,可以与任何 off-policy 的强化学习算法(如 DQN, DDPG, SAC)结合。在本文的实践中,我们将它与 DDPG 算法相结合。

4. 代码实践:用PyTorch实现HER+DDPG

接下来,我们通过一个完整的 PyTorch 代码项目来学习 HER 的实现。任务非常直观:在一个二维平面上,智能体需要从原点 (0, 0) 移动到一个随机生成的目标点。

4.1 自定义环境 (WorldEnv)

首先,我们定义一个简单的二维世界环境。

  • 状态空间: 4维向量 [agent_x, agent_y, goal_x, goal_y]
  • 动作空间: 2维向量 [move_x, move_y],每个分量的范围是 [-1, 1]
  • 目标: 在每个 episode 开始时,在 [3.5, 4.5] x [3.5, 4.5] 区域内随机生成一个目标点。
  • 奖励: 如果智能体与目标的距离小于阈值 0.15,奖励为 0;否则为 -1
  • 终止条件: 达到目标,或达到最大步数 50
# 自定义环境
import numpy as np
import random
from typing import Tupleclass WorldEnv:"""二维世界环境类,用于目标导向的强化学习任务智能体需要从起始位置移动到随机生成的目标位置"""def __init__(self) -> None:"""初始化环境参数"""# 距离阈值,当智能体与目标的距离小于等于此值时认为任务完成 (标量)self.distance_threshold: float = 0.15# 动作边界,限制每个动作分量的取值范围为[-1, 1] (标量)self.action_bound: float = 1.0# 地图边界,智能体活动范围为[0, 5] x [0, 5] (标量)self.map_bound: float = 5.0# 最大步数,防止无限循环 (标量)self.max_steps: int = 50# 当前状态,智能体在二维平面上的坐标 (2维向量)self.state: np.ndarray = None# 目标位置,智能体需要到达的目标坐标 (2维向量)self.goal: np.ndarray = None# 当前步数计数器 (标量)self.count: int = 0def reset(self) -> np.ndarray:"""重置环境到初始状态Returns:np.ndarray: 包含当前状态和目标位置的观测向量 (4维向量: [state_x, state_y, goal_x, goal_y])"""# 在目标区域[3.5, 4.5] x [3.5, 4.5]内随机生成目标位置 (2维向量)goal_x = 4.0 + random.uniform(-0.5, 0.5)goal_y = 4.0 + random.uniform(-0.5, 0.5)self.goal = np.array([goal_x, goal_y])# 设置智能体初始位置为原点 (2维向量)self.state = np.array([0.0, 0.0])# 重置步数计数器 (标量)self.count = 0# 返回包含状态和目标的观测向量 (4维向量)return np.hstack((self.state, self.goal))def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool]:"""执行一个动作并返回下一个状态、奖励和是否结束Args:action (np.ndarray): 智能体的动作,包含x和y方向的移动量 (2维向量)Returns:Tuple[np.ndarray, float, bool]: - 下一个观测状态 (4维向量: [state_x, state_y, goal_x, goal_y])- 奖励值 (标量)- 是否结束标志 (布尔值)"""# 将动作限制在有效范围内[-action_bound, action_bound] (2维向量)action = np.clip(action, -self.action_bound, self.action_bound)# 计算执行动作后的新位置,并确保在地图边界内[0, map_bound] (标量)new_x = max(0.0, min(self.map_bound, self.state[0] + action[0]))new_y = max(0.0, min(self.map_bound, self.state[1] + action[1]))# 更新智能体位置 (2维向量)self.state = np.array([new_x, new_y])# 增加步数计数 (标量)self.count += 1# 计算当前位置与目标位置之间的欧几里得距离 (标量)distance = np.sqrt(np.sum(np.square(self.state - self.goal)))# 计算奖励:如果距离大于阈值则给予负奖励-1.0,否则给予0奖励 (标量)reward = -1.0 if distance > self.distance_threshold else 0.0# 判断是否结束:距离足够近或达到最大步数 (布尔值)if distance <= self.distance_threshold or self.count >= self.max_steps:done = Trueelse:done = False# 返回新的观测状态、奖励和结束标志# 观测状态包含当前位置和目标位置 (4维向量)return np.hstack((self.state, self.goal)), reward, done

4.2 智能体与算法 (DDPG)

我们选择 DDPG (深度确定性策略梯度) 作为基础的 off-policy 算法。DDPG 包含一个 Actor (策略网络) 和一个 Critic (Q值网络),非常适合处理连续动作空间问题。

  • PolicyNet: Actor 网络,输入状态 s(包含目标 g),输出一个确定性的动作 a
  • QValueNet: Critic 网络,输入状态 s 和动作 a,输出该状态-动作对的Q值。
  • DDPG: 算法主类,集成了 Actor 和 Critic,并包含目标网络、优化器、软更新和 update 逻辑。这里的实现是标准的 DDPG。
# 要训练的智能体和采用的算法
import torch
import torch.nn.functional as F
import numpy as np
from typing import Dict, Anyclass PolicyNet(torch.nn.Module):"""策略网络(Actor网络)用于输出连续动作空间中的动作值"""def __init__(self, state_dim: int, hidden_dim: int, action_dim: int, action_bound: float) -> None:"""初始化策略网络Args:state_dim (int): 状态空间维度 (标量)hidden_dim (int): 隐藏层神经元数量 (标量)action_dim (int): 动作空间维度 (标量)action_bound (float): 动作边界值,动作取值范围为[-action_bound, action_bound] (标量)"""super(PolicyNet, self).__init__()# 第一个全连接层:状态维度 -> 隐藏层维度self.fc1 = torch.nn.Linear(state_dim, hidden_dim)# 第二个全连接层:隐藏层维度 -> 隐藏层维度self.fc2 = torch.nn.Linear(hidden_dim, hidden_dim)# 输出层:隐藏层维度 -> 动作维度(本环境中动作维度为2)self.fc3 = torch.nn.Linear(hidden_dim, action_dim)# 动作边界,用于将输出限制在有效范围内 (标量)self.action_bound = action_bounddef forward(self, x: torch.Tensor) -> torch.Tensor:"""前向传播计算动作输出Args:x (torch.Tensor): 输入状态 (batch_size, state_dim)Returns:torch.Tensor: 输出动作,范围在[-action_bound, action_bound] (batch_size, action_dim)"""# 通过两个隐藏层,使用ReLU激活函数 (batch_size, hidden_dim)x = F.relu(self.fc2(F.relu(self.fc1(x))))# 输出层使用tanh激活函数,将输出限制在[-1, 1],然后乘以action_bound# 得到范围在[-action_bound, action_bound]的动作 (batch_size, action_dim)return torch.tanh(self.fc3(x)) * self.action_boundclass QValueNet(torch.nn.Module):"""Q值网络(Critic网络)用于评估给定状态和动作的Q值"""def __init__(

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/pingmian/94057.shtml
繁体地址,请注明出处:http://hk.pswp.cn/pingmian/94057.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端 H5分片上传 vue实现大文件

用uniapp开发APP上传视频文件&#xff0c;大文件可以上传成功&#xff0c;但是一旦打包为H5的代码&#xff0c;就会一提示链接超时&#xff0c;我的代码中是实现的上传到阿里云 如果需要看全文的私信我 官方开发文档地址 前端&#xff1a;使用分片上传的方式上传大文件_对象…

Linux服务器Systemctl命令详细使用指南

目录 1. 基本语法 2. 基础命令速查表 3. 常用示例 3.1 部署新服务后&#xff0c;设置开机自启并启动 3.2 检查系统中所有失败的服务并尝试修复 3.3 查看系统中所有开机自启的服务 4. 总结 以下是 systemctl 使用指南&#xff0c;涵盖服务管理、单元操作、运行级别控制、…

【JVM内存结构系列】二、线程私有区域详解:程序计数器、虚拟机栈、本地方法栈——搞懂栈溢出与线程隔离

上一篇文章我们搭建了JVM内存结构的整体框架,知道程序计数器、虚拟机栈、本地方法栈属于“线程私有区域”——每个线程启动时会单独分配内存,线程结束后内存直接释放,无需GC参与。这三个区域看似“小众”,却是理解线程执行逻辑、排查栈溢出异常的关键,也是面试中高频被问的…

红帽认证升级华为openEuler证书活动!

如果您有红帽证书&#xff0c;可以升级以下相应的证书&#xff1a;&#x1f447; 有RHCSA证书&#xff0c;可以99元升级openEuler HCIA 有RHCE证书&#xff0c;可以99元升级openEuler HCIP 有RHCA证书&#xff0c;可以2100元升级openEuler HCIE 现金激励&#xff1a;&#x1f4…

迭代器模式与几个经典的C++实现

迭代器模式详解1. 定义与意图迭代器模式&#xff08;Iterator Pattern&#xff09; 是一种行为设计模式&#xff0c;它提供一种方法顺序访问一个聚合对象中的各个元素&#xff0c;而又不暴露该对象的内部表示。主要意图&#xff1a;为不同的聚合结构提供统一的遍历接口。将遍历…

epoll 陷阱:隧道中的高级负担

上周提到了 tun/tap 转发框架的数据通道结构和优化 tun/tap 转发性能优化&#xff0c;涉及 RingBuffer&#xff0c;packetization 等核心话题。我也给出了一定的数据结构以及处理逻辑&#xff0c;但竟然没有高尚的 epoll&#xff0c;本文说说它&#xff0c;因为它不适合。 epo…

微前端架构常见框架

1. iframe 这里指的是每个微应用独立开发部署,通过 iframe 的方式将这些应用嵌入到父应用系统中,几乎所有微前端的框架最开始都考虑过 iframe,但最后都放弃,或者使用部分功能,原因主要有: url 不同步。浏览器刷新 iframe url 状态丢失、后退前进按钮无法使用。 UI 不同…

SQL Server更改日志模式:操作指南与最佳实践!

全文目录&#xff1a;开篇语**前言****摘要****概述&#xff1a;SQL Server 的日志模式****日志模式的作用****三种日志模式**1. **简单恢复模式&#xff08;Simple&#xff09;**2. **完整恢复模式&#xff08;Full&#xff09;**3. **大容量日志恢复模式&#xff08;Bulk-Log…

git的工作使用中实际经验

老输入烦人的密码 每次我git pull的时候都要叫我输入三次烦人的密码&#xff0c;问了deepseek也没有尝试成功 出现 enter passphrase for key ‘~/.ssh/id_rsa’ 的原因: 在生成key的时候,没有注意,不小心设置了密码, 导致每次提交的时候都会提示要输入密码, 也就是上面的提示…

科技赋能,宁夏农业绘就塞上新“丰”景

在贺兰山的巍峨身影下&#xff0c;在黄河水的温柔滋养中&#xff0c;宁夏这片古老而神奇的土地&#xff0c;正借助农业科技的磅礴力量&#xff0c;实现从传统农耕到智慧农业的华丽转身&#xff0c;奏响一曲科技与自然和谐共生的壮丽乐章。一、数字农业&#xff1a;开启智慧种植…

imx6ull-驱动开发篇36——Linux 自带的 LED 灯驱动实验

在之前的文章里&#xff0c;我们掌握了无设备树和有设备树这两种 platform 驱动的开发方式。但实际上有现成的&#xff0c;Linux 内核的 LED 灯驱动采用 platform 框架&#xff0c;我们只需要按照要求在设备树文件中添加相应的 LED 节点即可。本讲内容&#xff0c;我们就来学习…

深度学习中主流激活函数的数学原理与PyTorch实现综述

1. 定义与作用什么是激活函数&#xff1f;激活函数有什么用&#xff1f;答&#xff1a;激活函数&#xff08;Activation Function&#xff09;是一种添加到人工神经网络中的函数&#xff0c;旨在帮助网络学习数据中的复杂模式。类似于人类大脑中基于神经元的模型&#xff0c;激…

Linux高效备份:rsync + inotify实时同步

一、rsync 简介 rsync&#xff08;Remote Sync&#xff09;是 Linux 系统下的数据镜像备份工具&#xff0c;支持本地复制、远程同步&#xff08;通过 SSH 或 rsync 协议&#xff09;&#xff0c;是一个快速、安全、高效的增量备份工具。二、rsync 特性 支持镜像保存整个目录树和…

一种通过模板输出Docx的方法

起因在2个群里都有网友讨论这个问题&#xff0c;俺就写了一个最简单的例子。其实&#xff0c;我们经常遇到一些Docx的输出的需求&#xff0c;“用模板文件进行处理”是最简单的一个方法&#xff0c;如果想预览也简单 DevExpress 、Teleric 都可以&#xff0c;而且也支持 Web 、…

探索 List 的奥秘:自己动手写一个 STL List✨

&#x1f4d6;引言大家好&#xff01;今天我们要一起来揭开 C 中 list 容器的神秘面纱——不是直接用 STL&#xff0c;而是亲手实现一个简化版的 list&#xff01;&#x1f389;你是不是曾经好奇过&#xff1a;list 是怎么做到高效插入和删除的&#xff1f;&#x1f50d;迭代器…

mysql占用高内存排查与解决

mysql占用高内存排查-- 查看当前全局内存使用情况&#xff08;需要启用 performance_schema&#xff09; SELECT * FROM sys.memory_global_total; -- 查看总内存使用 SELECT * FROM sys.memory_global_by_current_bytes LIMIT 10; -- 按模块分类查看内存使用排行memory/perfor…

构建真正自动化知识工作的AI代理

引言&#xff1a;新一代生产力范式的黎明 自动化知识工作的人工智能代理&#xff08;AI Agent&#xff09;&#xff0c;或称“智能体”&#xff0c;正迅速从理论构想演变为重塑各行各业生产力的核心引擎。这些AI代理被定义为能够感知环境、进行自主决策、动态规划、调用工具并持…

青少年机器人技术(四级)等级考试试卷-实操题(2021年12月)

更多内容和历年真题请查看网站&#xff1a;【试卷中心 -----> 电子学会 ----> 机器人技术 ----> 四级】 网站链接 青少年软件编程历年真题模拟题实时更新 青少年机器人技术&#xff08;四级&#xff09;等级考试试卷-实操题&#xff08;2021年12月&#xff09; …

最新短网址源码,防封。支持直连、跳转。 会员无广

最新短网址源码&#xff0c;防封。支持直连、跳转。 会员无广告1.可将长网址自动缩短为短网址&#xff0c;方便记忆和使用。2.短网址默认为临时有效&#xff0c;可付费升级为永久有效&#xff0c;接入支付后可自动完成&#xff0c;无需人工操作。3.系统支持设置图片/文字/跳转页…

缓存-变更事件捕捉、更新策略、本地缓存和热key问题

缓存-基础知识 熟悉计算机基础的同学们都知道&#xff0c;服务的存储大多是多层级的&#xff0c;呈现金字塔类型。通常来说本机存储比通过网络通信的外部存储更快&#xff08;现在也不一定了&#xff0c;因为网络传输速度很快&#xff0c;至少可以比一些过时的本地存储设备速度…