从代码学习深度强化学习 - PPO PyTorch版

文章目录

  • 前言
  • PPO 算法简介
    • 从 TRPO 到 PPO
    • PPO 的两种形式:惩罚与截断
  • 代码实践:PPO 解决离散动作空间问题 (CartPole)
    • 环境与工具函数
    • 定义策略与价值网络
    • PPO 智能体核心实现
    • 训练与结果
  • 代码实践:PPO 解决连续动作空间问题 (Pendulum)
    • 环境准备
    • 适用于连续动作的网络
    • PPO 智能体 (连续版)
    • 训练与结果
  • 总结


前言

欢迎来到深度强化学习(DRL)的世界!在众多 DRL 算法中,Proximal Policy Optimization (PPO) 无疑是最受欢迎和广泛应用的算法之一。它由 OpenAI 在 2017 年提出,以其出色的性能、相对简单的实现和稳定的训练过程而著称,成为了许多研究和应用的基准算法。

本篇博客旨在通过一个完整的 PyTorch 实现,带您从代码层面深入理解 PPO 算法。我们将不仅仅是看公式,更是要“动手”,一步步构建、训练和分析 PPO 智能体。为了全面掌握其应用,我们将分别在经典的离散动作空间(CartPole-v1)和连续动作空间(Pendulum-v1)两个环境中进行实践。

无论您是 DRL 的初学者,还是希望巩固 PPO 知识的实践者,相信通过这篇代码驱动的教程,您都能对 PPO 有一个更具体、更深刻的认识。

完整代码:下载链接


PPO 算法简介

在深入代码之前,我们先快速回顾一下 PPO 的核心思想。

从 TRPO 到 PPO

PPO 的思想源于 TRPO(Trust Region Policy Optimization)。TRPO 旨在通过限制每次策略更新的步长,确保更新后的策略不会与旧策略偏离太远,从而保证学习的稳定性。它的优化目标如下:

TRPO 通过一个 KL 散度的约束来限制策略更新的区域,但这个约束的计算过程非常复杂,涉及泰勒展开、共轭梯度、线性搜索等,导致其实现难度大,运算量也非常可观。

PPO 的出现正是为了解决这个问题。它继承了 TRPO 的核心思想,即在更新策略时不要“步子迈得太大”,但采用了更简单、更易于实现的方法。

PPO 的两种形式:惩罚与截断

PPO 主要有两种形式:PPO-PenaltyPPO-Clip

  1. PPO-Penalty (惩罚)
    它将 TRPO 的 KL 散度约束作为一个惩罚项直接放入目标函数中,变成一个无约束的优化问题,并通过一个动态调整的系数 β 来控制惩罚的力度。

  2. PPO-Clip (截断)
    这是更常用的一种形式,也是我们代码将要实现的版本。它直接在目标函数中进行截断(clip),以保证新的参数和旧的参数的差距不会太大。

    其核心思想在于 clip 函数。我们定义一个比率 r(θ) 为新策略与旧策略输出同一动作的概率之比。

    • 优势函数 A > 0 时(即当前动作优于平均水平),我们希望增大这个动作的概率,但 r(θ) 的上限被截断在 1+ε,防止策略更新过于激进。
    • 优势函数 A < 0 时(即当前动作劣于平均水平),我们希望减小这个动作的概率,但 r(θ) 的下限被截断在 1-ε,同样是为了限制更新幅度。

    下图直观地展示了 PPO-Clip 的目标函数 L^Clip 与概率比 r(θ) 的关系:

大量的实验表明,PPO-Clip 的性能通常比 PPO-Penalty 更好且更稳定。因此,我们的代码实践将专注于 PPO-Clip 的实现。

理论铺垫结束,让我们开始编码吧!

代码实践:PPO 解决离散动作空间问题 (CartPole)

我们将从经典的 CartPole-v1 环境开始,它要求智能体通过向左或向右施加力来保持杆子竖直不倒,是一个典型的离散动作空间问题(动作:0-向左,1-向右)。

环境与工具函数

首先,我们定义一些通用的工具函数并初始化环境。这里的核心是 compute_advantage 函数,它实现了广义优势估计(GAE),这是一种在偏差和方差之间取得平衡的优势函数计算方法,对于稳定策略梯度算法的训练至关重要。

PPO离散动作.ipynb

"""
强化学习工具函数集
包含广义优势估计(GAE)和数据平滑处理功能
"""import torch
import numpy as npdef compute_advantage(gamma, lmbda, td_delta):"""计算广义优势估计(Generalized Advantage Estimation,GAE)GAE是一种在强化学习中用于减少策略梯度方差的技术,通过对时序差分误差进行指数加权平均来估计优势函数,平衡偏差和方差的权衡。参数:gamma (float): 折扣因子,维度: 标量取值范围[0,1],决定未来奖励的重要性lmbda (float): GAE参数,维度: 标量  取值范围[0,1],控制偏差-方差权衡lmbda=0时为TD(0)单步时间差分,lmbda=1时为蒙特卡洛方法用采样到的奖励-状态价值估计td_delta (torch.Tensor): 时序差分误差序列,维度: [时间步数]包含每个时间步的TD误差值返回:torch.Tensor: 广义优势估计值,维度: [时间步数]与输入td_delta维度相同的优势函数估计数学公式:A_t^GAE(γ,λ) = Σ_{l=0}^∞ (γλ)^l * δ_{t+l}其中 δ_t = r_t + γV(s_{t+1}) - V(s_t) 是TD误差"""# 将PyTorch张量转换为NumPy数组进行计算# td_delta维度: [时间步数] -> [时间步数]td_delta = td_delta.detach().numpy() # 因为A用来求g的,需要梯度,防止梯度向下传播# 初始化优势值列表,用于存储每个时间步的优势估计# advantage_list维度: 最终为[时间步数]advantage_list = []# 初始化当前优势值,从序列末尾开始反向计算# advantage维度: 标量advantage = 0.0# 从时间序列末尾开始反向遍历TD误差# 反向计算是因为GAE需要利用未来的信息# delta维度: 标量(td_delta中的单个元素)for delta in td_delta[::-1]:  # [::-1]实现序列反转# GAE递归公式:A_t = δ_t + γλA_{t+1}# gamma * lmbda * advantage: 来自未来时间步的衰减优势值# delta: 当前时间步的TD误差# advantage维度: 标量advantage = gamma * lmbda * advantage + delta# 将计算得到的优势值添加到列表中# advantage_list维度: 逐步增长到[时间步数]advantage_list.append(advantage)# 由于是反向计算,需要将结果列表反转回正确的时间顺序# advantage_list维度: [时间步数](时间顺序已恢复)advantage_list.reverse()# 将NumPy列表转换回PyTorch张量并返回# 返回值维度: [时间步数]return torch.tensor(advantage_list, dtype=torch.float)def moving_average(data, window_size):"""计算移动平均值,用于平滑奖励曲线该函数通过滑动窗口的方式对时间序列数据进行平滑处理,可以有效减少数据中的噪声,使曲线更加平滑美观。常用于强化学习中对训练过程的奖励曲线进行可视化优化。参数:data (list): 原始数据序列,维度: [num_episodes]包含需要平滑处理的数值数据(如每轮训练的奖励值)window_size (int): 移动窗口大小,维度: 标量决定了平滑程度,窗口越大平滑效果越明显但也会导致更多的数据点丢失返回:list: 移动平均后的数据,维度: [len(data) - window_size + 1]返回的数据长度会比原数据少 window_size - 1 个元素这是因为需要足够的数据点来计算第一个移动平均值示例:>>> data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]  # 维度: [10]>>> smoothed = moving_average(data, 3)       # window_size = 3>>> print(smoothed)  # 输出: [2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]  维度: [8]"""# 边界检查:如果数据长度小于窗口大小,直接返回原数据# 这种情况下无法计算移动平均值# data维度: [num_episodes], window_size维度: 标量if len(data) < window_size:return data# 初始化移动平均值列表# moving_avg维度: 最终为[len(data) - window_size + 1]moving_avg = []# 遍历数据,计算每个窗口的移动平均值# i的取值范围: 0 到 len(data) - window_size# 循环次数: len(data) - window_size + 1# 每次循环处理一个滑动窗口位置for i in range(len(data) - window_size + 1):# 提取当前窗口内的数据切片# window_data维度: [window_size]# 包含从索引i开始的连续window_size个元素# 例如:当i=0, window_size=3时,提取data[0:3]window_data = data[i:i + window_size]# 计算当前窗口内数据的算术平均值# np.mean(window_data)维度: 标量# 将平均值添加到结果列表中moving_avg.append(np.mean(window_data))# 返回移动平均后的数据列表# moving_avg维度: [len(data) - window_size + 1]return moving_avg``````python
"""
强化学习环境初始化模块
用于创建和配置OpenAI Gym环境
"""import gym# 环境配置
# 定义要使用的强化学习环境名称
# CartPole-v1是经典的平衡杆控制问题:
# - 状态空间:4维连续空间(车位置、车速度、杆角度、杆角速度)
# - 动作空间:2维离散空间(向左推车、向右推车)
# - 目标:保持杆子平衡尽可能长的时间
# env_name维度: 标量(字符串)
env_name = 'CartPole-v1'# 创建强化学习环境实例
# gym.make()函数根据环境名称创建对应的环境对象
# 该环境对象包含了状态空间、动作空间、奖励函数等定义
# env维度: gym.Env对象(包含状态空间[4]和动作空间[2]的环境实例)
# env.observation_space.shape: (4,) - 观测状态维度
# env.action_space.n: 2 - 离散动作数量
env = gym.make(env_name)

定义策略与价值网络

PPO 是一种 Actor-Critic 架构的算法。我们需要定义两个网络:

  • 策略网络 (PolicyNet):作为 Actor,输入状态,输出一个动作的概率分布。
  • 价值网络 (ValueNet):作为 Critic,输入状态,输出该状态的价值估计 V(s)。
"""
PPO(Proximal Policy Optimization)算法实现
包含策略网络、价值网络和PPO智能体的完整定义
"""import torch
import torch.nn.functional as F
import numpy as npclass PolicyNet(torch.nn.Module):"""策略网络(Actor Network)用于输出动作概率分布,指导智能体如何选择动作"""def __init__(self, state_dim, hidden_dim, action_dim):"""初始化策略网络参数:state_dim (int): 状态空间维度,维度: 标量对于CartPole-v1环境,state_dim=4hidden_dim (int): 隐藏层神经元数量,维度: 标量控制网络的表达能力action_dim (int): 动作空间维度,维度: 标量对于CartPole-v1环境,action_dim=2"""super(PolicyNet, self).__init__()# 第一层全连接层:状态输入 -> 隐藏层# 输入维度: [batch_size, state_dim] -> 输出维度: [batch_size, hidden_dim]self.fc1 = torch.nn.Linear(state_dim, hidden_dim)# 第二层全连接层:隐藏层 -> 动作概率# 输入维度: [batch_size, hidden_dim] -> 输出维度: [batch_size, action_dim]self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):"""前向传播过程参数:x (torch.Tensor): 输入状态,维度: [batch_size, state_dim]返回:torch.Tensor: 动作概率分布,维度: [batch_size, action_dim]每行为一个状态对应的动作概率分布,概率和为1"""# 第一层 + ReLU激活函数# x维度: [batch_size, state_dim] -> [batch_size, hidden_dim]x = F.relu(self.fc1(x))# 第二层 + Softmax激活函数,输出概率分布# x维度: [batch_size, hidden_dim] -> [batch_size, action_dim]# dim=1表示在第1维(动作维度)上进行softmax,确保每行概率和为1return F.softmax(self.fc2(x), dim=1)class ValueNet(torch.nn.Module):"""价值网络(Critic Network)用于估计状态价值函数V(s),评估当前状态的好坏"""def __init__(self, state_dim, hidden_dim):"""初始化价值网络参数:state_dim (int): 状态空间维度,维度: 标量对于CartPole-v1环境,state_dim=4hidden_dim (int): 隐藏层神经元数量,维度: 标量控制网络的表达能力"""super(ValueNet, self).__init__()# 第一层全连接层:状态输入 -> 隐藏层# 输入维度: [batch_size, state_dim] -> 输出维度: [batch_size, hidden_dim]self.fc1 = torch.nn.Linear(state_dim, hidden_dim)# 第二层全连接层:隐藏层 -> 状态价值(标量)# 输入维度: [batch_size, hidden_dim] -> 输出维度: [batch_size, 1]self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x):"""前向传播过程参数:x (torch.Tensor): 输入状态,维度: [batch_size, state_dim]返回:torch.Tensor: 状态价值估计,维度: [batch_size, 1]每行为一个状态对应的价值估计"""# 第一层 + ReLU激活函数# x维度: [batch_size, state_dim] -> [batch_size, hidden_dim]x = F.relu(self.fc1(x))# 第二层,输出状态价值(无激活函数,可以输出负值)# x维度: [batch_size, hidden_dim] -> [batch_size, 1]return self.fc2(x)

PPO 智能体核心实现

这是我们 PPO 算法的核心。PPO 类封装了 Actor 和 Critic,并实现了 take_action(动作选择)和 update(网络更新)两个关键方法。请特别关注 update 函数,它完整地实现了 PPO-Clip 的目标函数计算和参数更新逻辑。

class PPO:"""PPO(Proximal Policy Optimization)算法实现采用截断方式防止策略更新过大,确保训练稳定性"""def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,lmbda, epochs, eps, gamma, device):"""初始化PPO智能体参数:state_dim (int): 状态空间维度,维度: 标量hidden_dim (int): 隐藏层神经元数量,维度: 标量action_dim (int): 动作空间维度,维度: 标量actor_lr (float): Actor网络学习率,维度: 标量critic_lr (float): Critic网络学习率,维度: 标量lmbda (float): GAE参数λ,维度: 标量,取值范围[0,1]epochs (int): 每次更新的训练轮数,维度: 标量eps (float): PPO截断参数ε,维度: 标量,通常取0.1-0.3gamma (float): 折扣因子γ,维度: 标量,取值范围[0,1]device (torch.device): 计算设备(CPU或GPU),维度: 标量"""# 初始化Actor网络(策略网络)# 网络参数维度:fc1权重[state_dim, hidden_dim], fc2权重[hidden_dim, action_dim]self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)#

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/88580.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/88580.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PortsWiggerLab: Blind OS command injection with output redirection

实验目的This lab contains a blind OS command injection vulnerability in the feedback function.The application executes a shell command containing the user-supplied details. The output from the command is not returned in the response. However, you can use o…

星云穿越与超光速飞行特效的前端实现原理与实践

文章目录 1,引言2,特效设计思路3,技术原理解析1. 星点的三维分布2. 视角推进与星点运动3. 三维到二维的投影4. 星点的视觉表现5. 色彩与模糊处理4,关键实现流程图5,应用场景与优化建议6,总结1,引言 在现代网页开发中,炫酷的视觉特效不仅能提升用户体验,还能为产品增添…

【Linux】C++项目分层架构:核心三层与关键辅助

C 项目分层架构全指南&#xff1a;核心三层 关键辅助一、核心三层架构 传统的三层架构&#xff08;或三层体系结构&#xff09;是构建健壮系统的基石&#xff0c;包括以下三层&#xff1a; 1. 表现层&#xff08;Presentation Layer&#xff09; 负责展示和输入处理&#xff0…

【机器学习】保序回归平滑校准算法

保序回归平滑校准算法&#xff08;SIR&#xff09;通过分桶合并线性插值解决广告预估偏差问题&#xff0c;核心是保持原始排序下纠偏。具体步骤&#xff1a;1&#xff09;按预估分升序分桶&#xff0c;统计每个分桶的后验CTR&#xff1b;2&#xff09;合并逆序桶重新计算均值&a…

项目开发日记

框架整理学习UIMgr&#xff1a;一、数据结构与算法 1.1 关键数据结构成员变量类型说明m_CtrlsList<PageInfo>当前正在显示的所有 UI 页面m_CachesList<PageInfo>已打开过、但现在不显示的页面&#xff08;缓存池&#xff09; 1.2 算法逻辑查找缓存页面&#xff1a;…

60 美元玩转 Li-Fi —— 开源 OpenVLC 平台入门(附 BeagleBone Black 驱动简单解析)

60 美元玩转 Li-Fi —— 开源 OpenVLC 平台入门&#xff08;附 BeagleBone Black 及驱动解析&#xff09;一、什么是 OpenVLC&#xff1f; OpenVLC 是由西班牙 IMDEA Networks 研究所推出的开源可见光通信&#xff08;VLC / Li-Fi&#xff09;研究平台。它把硬件、驱动、协议栈…

Python性能优化

Python 以其简洁和易用性著称,但在某些计算密集型或大数据处理场景下,性能可能成为瓶颈。幸运的是,通过一些巧妙的编程技巧,我们可以显著提升Python代码的执行效率。本文将介绍8个实用的性能优化技巧,帮助你编写更快、更高效的Python代码。   一、优化前的黄金法则:先测…

easyui碰到想要去除顶部栏按钮边框

只需要加上 plain"true"<a href"javascript:void(0)" class"easyui-linkbutton" iconCls"icon-add" plain"true"onclick"newCheck()">新增</a>

C++字符串详解:原理、操作及力扣算法实战

一、C字符串简介在C中&#xff0c;字符串的处理方式主要有两种&#xff1a;字符数组&#xff08;C风格字符串&#xff09;和std::string类。虽然字符数组是C语言遗留的底层实现方式&#xff0c;但现代C更推荐使用std::string类&#xff0c;其封装了复杂的操作逻辑&#xff0c;提…

CMU15445-2024fall-project1踩坑经历

p1目录&#xff1a;lRU\_K替换策略LRULRU\_K大体思路SetEvictableRecordAccessSizeEvictRemoveDisk SchedulerBufferPoolNewPageDeletePageFlashPage/FlashAllPageCheckReadPage/CheckWritePagePageGuard并发设计主逻辑感谢CMU的教授们给我们分享了如此精彩的一门课程&#xff…

【C语言进阶】带你由浅入深了解指针【第四期】:数组指针的应用、介绍函数指针

前言上一期讲了数组指针的原理&#xff0c;这一期接着上一期讲述数组指针的应用以及数组参数、函数参数。首先看下面的代码进行上一期内容的复习&#xff0c;pc应该是什么类型&#xff1f;char* arr[5] {0}; xxx pc &arr;分析&#xff1a;①首先判断arr是一个数组&#x…

在HTML中CSS三种使用方式

一、行内样式在标签<>中输入style "属性&#xff1a;属性值;"。(中等使用频率)不利于CSS样式的复用&#xff1b;违背了CSS内容和样式分离的设计理念&#xff0c;后期难以维护。<p style"color: red">这是div中的p元素</p>二、内部样式在…

汽车功能安全-软件单元验证 (Software Unit Verification)【用例导出方法、输出物】8

文章目录1 软件单元验证用例导出方法2 测试用例完整性度量标准3 验证环境要求4 软件单元验证的工作产品1 软件单元验证用例导出方法 为确保软件单元测试的测试案例规范符合9.4.2要求&#xff0c;应通过表8所列方法开发测试用例。 表8 软件单元测试用例的得出方法&#xff1a; …

MySQL内置函数(8)

文章目录前言一、日期函数二、字符串函数三、数学函数四、其它函数总结前言 其实在之前的几篇中我们也用到了内置函数&#xff0c;现在我们再来系统学习一下它&#xff01; 一、日期函数 函数名称描述current_date()获取当前日期current_time()获取当前时间current_timestamp(…

苍穹外卖项目日记(day04)

苍穹外卖|项目日记(day04) 前言: 今天主要是接口开发, 涉及的新东西不多, 需要注意的只有多表联查和修改的逻辑,今日难点: 1.菜品的停起售状态设置 2.套餐的停起售状态设置 3.动态sql中的 useGeneratedKeys 与 keyProperty 两个参数 一. 菜品的停起售状态设置 ​ 在菜品的停售中…

React之旅-05 List Key

每个React的初学者&#xff0c;在调试程序时&#xff0c;都会遇到这样的警告&#xff1a;Warning: Each child in a list should have a unique "key" prop. 如下面的代码&#xff1a; const list [Learn React, Learn GraphQL];const ListWithoutKey () > (&l…

[特殊字符] 人工智能技术全景:从基础理论到前沿应用的深度解析

&#x1f680; 人工智能技术全景&#xff1a;从基础理论到前沿应用的深度解析 在这个AI驱动的时代&#xff0c;理解人工智能的核心技术和应用场景已成为技术人员的必备技能。本文将带你深入探索AI的发展脉络、核心技术差异以及在各行业的创新应用。 文章目录&#x1f680; 人工…

Go语言教程-环境搭建

前言 Go&#xff08;又称 Golang&#xff09;是由 Google 开发的一种 开源、静态类型、编译型 编程语言&#xff0c;于 2009 年正式发布。它旨在解决现代软件开发中的高并发、高性能和可维护性问题&#xff0c;尤其适合 云计算、微服务、分布式系统 等领域。 Go 语言国际官网…

windows指定某node及npm版本下载

下载并安装 nvm-windowshttps://github.com/coreybutler/nvm-windows/releases&#xff08;选择 nvm-setup.zip&#xff09;。打开命令提示符&#xff08;管理员权限&#xff09;&#xff0c;安装 Node.js v16.15.0&#xff1a; nvm install 16.15.0 nvm use 16.15.0 验证node版…

每天一个前端小知识 Day 28 - Web Workers / 多线程模型在前端中的应用实践

Web Workers / 多线程模型在前端中的应用实践&#x1f9e0; 一、为什么前端需要多线程&#xff1f; 单线程 JS 的瓶颈&#xff1a;浏览器主线程不仅负责执行 JS&#xff0c;还要负责&#xff1a; UI 渲染&#xff08;DOM/CSS&#xff09;用户事件处理&#xff08;点击、输入&am…