从代码学习深度强化学习 - 多臂老虎机 PyTorch版

文章目录

  • 前言
  • 创建多臂老虎机环境
  • 多臂老虎机算法基本框架(基类)
  • 1. ε-贪心算法 (Epsilon-Greedy)
  • 2. 随时间衰减的ε-贪婪算法 (Decaying ε-Greedy)
  • 3. 上置信界算法 (Upper Confidence Bound, UCB)
  • 4. 汤普森采样算法 (Thompson Sampling)
  • 总结


前言

欢迎来到“从代码学习深度强化学习”系列!在本篇文章中,我们将深入探讨一个强化学习中的经典问题——多臂老虎机(Multi-Armed Bandit, MAB)

多臂老虎机问题,顾名思义,源于一个赌徒在赌场面对一排老虎机(即“多臂老虎机”)的场景。每个老虎机(“臂”)都有其内在的、未知的获奖概率。赌徒的目标是在有限的回合内,通过选择拉动不同的老虎机,来最大化自己的总收益。

这看似简单的场景,却完美地诠释了强化学习中的一个核心困境:探索(Exploration)与利用(Exploitation)的权衡

  • 利用(Exploitation):选择当前已知收益最高的老虎机。这能保证我们在短期内获得不错的收益,但可能会错过一个潜在收益更高但尚未被充分尝试的选项。
  • 探索(Exploration):尝试那些我们不确定其收益的老虎机。这可能会在短期内牺牲一些收益,但却有机会发现全局最优的选择,从而获得更高的长期总回报。

为了量化算法的性能,我们引入一个重要概念——累积懊悔(Cumulative Regret)。懊悔指的是在某一步选择的动作所带来的期望收益与“上帝视角”下最优动作的期望收益之差。一个优秀的算法,其目标就是最小化在整个过程中的累积懊悔。

在本篇博客中,我们将通过 Python 代码,从零开始实现一个多臂老虎机环境,并逐步实现和对比以下四种经典的求解策略:

  1. ε-贪心算法 (Epsilon-Greedy)
  2. 随时间衰减的ε-贪心算法 (Decaying Epsilon-Greedy)
  3. 上置信界算法 (Upper Confidence Bound, UCB)
  4. 汤普森采样算法 (Thompson Sampling)

关于 PyTorch: 尽管标题提及 PyTorch,但对于 MAB 这种基础问题,使用 NumPy 能更清晰地展示算法的核心逻辑,而无需引入深度学习框架的复杂性。本文中的实现将基于 NumPy,但其核心思想(如价值估计、策略选择)是构建更复杂的深度强化学习算法(如DQN)的基石,在那些场景中 PyTorch 将发挥关键作用。

让我们开始吧!

完整代码:下载链接

创建多臂老虎机环境

多臂老虎机问题可以表示为一个元组 ⟨ A , R ⟩ \langle\mathcal{A},\mathcal{R}\rangle A,R ,其中:

  • A \mathcal{A} A为动作集合,其中一个动作表示拉动一个拉杆。若多臂老虎机一共有 K K K根拉杆,那动作空间就是集合 { a 1 , … , a K } \{a_1,\ldots,a_K\} {a1,,aK},我们用 a t ∈ A a_t\in\mathcal{A} atA表示任意一个动作;
  • R \mathcal{R} R为奖励概率分布,拉动每一根拉杆的动作 a a a都对应一个奖励概率分布 R ( r ∣ a ) \mathcal{R}(r|a) R(ra),不同拉杆的奖励分布通常是不同的。

假设每个时间步只能拉动一个拉杆,多臂老虎机的目标为最大化一段时间步 T T T内累积的奖励: max ⁡ ∑ t = 1 T r t , r t ∼ R ( ⋅ ∣ a t ) \max\sum_{t=1}^Tr_t, r_t\sim\mathcal{R}\left(\cdot|a_t\right) maxt=1Trt,rtR(at)。其中 a t a_t at表示在第 t t t个时间步拉动某一拉杆的动作, r t r_t rt表示动作 a t a_t at获得的奖励。

首先,我们需要一个模拟环境。我们创建一个 BernoulliBandit 类来模拟一个拥有 K 个臂的老虎机。每个臂都服从伯努利分布,即每次拉动它,会以一个固定的概率 p 获得奖励 1(获奖),以 1-p 的概率获得奖励 0(未获奖)。在我们的环境中,这 K 个臂的获奖概率 p 是在初始化时随机生成的,并且对我们的算法(智能体)是未知的。

# 导入需要使用的库
import numpy as np  # numpy是支持数组和矩阵运算的科学计算库
import matplotlib.pyplot as plt  # matplotlib是绘图库class BernoulliBandit:"""伯努利多臂老虎机类该类实现了一个多臂老虎机问题的环境,每个拉杆都服从伯努利分布"""def __init__(self, K):"""初始化伯努利多臂老虎机参数:K (int): 拉杆个数,标量属性:probs (numpy.ndarray): 每个拉杆的获奖概率数组,维度为 (K,)best_idx (int): 获奖概率最大的拉杆索引,标量best_prob (float): 最大的获奖概率值,标量K (int): 拉杆总数,标量"""# 随机生成K个0~1之间的数,作为拉动每根拉杆的获奖概率# probs: (K,) - K个拉杆的获奖概率数组self.probs = np.random.uniform(size=K)# 找到获奖概率最大的拉杆索引# best_idx: 标量 - 最优拉杆的索引号self.best_idx = np.argmax(self.probs)# 获取最大的获奖概率# best_prob: 标量 - 最大获奖概率值self.best_prob = self.probs[self.best_idx]# 保存拉杆总数# K: 标量 - 拉杆个数self.K = Kdef step(self, k):"""执行一次拉杆动作当玩家选择了k号拉杆后,根据该拉杆的获奖概率返回奖励结果参数:k (int): 选择的拉杆编号,标量,取值范围为 [0, K-1]返回:int: 奖励结果,标量1 表示获奖0 表示未获奖"""# 根据k号拉杆的获奖概率进行伯努利采样# np.random.rand(): 标量 - 生成[0,1)之间的随机数# self.probs[k]: 标量 - k号拉杆的获奖概率if np.random.rand() < self.probs[k]:return 1  # 获奖else:return 0  # 未获奖# 设定随机种子,使实验具有可重复性
np.random.seed(1)# 设置拉杆数量
# K: 标量 - 多臂老虎机的拉杆个数
K = 10# 创建一个10臂伯努利老虎机实例
# bandit_10_arm: BernoulliBandit对象 - 包含10个拉杆的老虎机
bandit_10_arm = BernoulliBandit(K)# 输出老虎机的基本信息
print("随机生成了一个%d臂伯努利老虎机" % K)
print("获奖概率最大的拉杆为%d号,其获奖概率为%.4f" % (bandit_10_arm.best_idx, bandit_10_arm.best_prob))

运行以上代码,我们创建了一个10臂老虎机,并打印出了最优拉杆的信息。在我们的实验中,1号拉杆是收益最高的,其获奖概率为 0.7203。这个信息算法本身是不知道的,但我们可以用它来计算懊悔。

随机生成了一个10臂伯努利老虎机
获奖概率最大的拉杆为1号,其获奖概率为0.7203

多臂老虎机算法基本框架(基类)

为了方便实现和比较不同的算法,我们先定义一个 Solver 基类。这个基类包含了所有算法都需要共享的功能,例如记录每个臂被拉动的次数、记录历史动作以及计算和更新累积懊悔。具体的决策逻辑(run_one_step)将由各个子类来实现, 需要求解的是选取某根拉杆的策略。

累积懊悔

对于每一个动作,我们定义其期望奖励为 Q ( a ) = E r ∼ R ( ⋅ ∣ a ) [ r ] {Q}(a)=\mathbb{E}_{r\sim\mathcal{R}(\cdot|a)}\begin{bmatrix}r\end{bmatrix} Q(a)=ErR(a)[r]。于是,至少存在一根拉杆,它的期望奖励不小于拉动其他任意一根拉杆,我们将该最优期望奖励表示为 Q ∗ = max ⁡ a ∈ A Q ( a ) Q^*=\max_{a\in\mathcal{A}}Q(a) Q=maxaAQ(a)。为了更加直观、方便地观察拉动一根拉杆的期望奖励离最优拉杆期望奖励的差距,我们引入懊悔(regret)概念。懊悔定义为拉动当前拉杆的动作与最优拉杆的期望奖励差,即 R ( a ) = Q ∗ − Q ( a ) R(a)=Q^*-Q(a) R(a)=QQ(a)累积懊悔(cumulative regret)即操作 T T

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/bicheng/84632.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android学习之Window窗口

Android Window机制学习笔记 在使用Window Flag实现界面全屏功能时&#xff0c;发现自身对Android Window机制缺乏系统认知&#xff0c;因此进行了专项学习与整理。 本文主要参考以下优质资料&#xff1a; Android的Window详解Android官方Window文档 Window基本概念 1. Win…

华为云 Flexus+DeepSeek 征文|搭建部署Dify-LLM推理引擎,赋能AI Agent智能体实现动态联网搜索能力

华为云 Flexus 云服务器 X 实例专门为 AI 应用场景设计。它提供了强大的计算能力&#xff0c;能够满足 DeepSeek 模型以及后续搭建 AI Agent 智能体过程中对于数据处理和模型运行的高要求。在网络方面&#xff0c;具备高速稳定的网络带宽&#xff0c;这对于需要频繁联网搜索信息…

Python 100个常用函数全面解析

Python 100个常用函数全面解析 1. 类型转换函数 1.1 int() 将字符串或数字转换为整数。 # 基本用法 int(123) # 123 int(3.14) # 3# 指定进制转换 int(1010, 2) # 10 (二进制转十进制) int(FF, 16) # 255 (十六进制转十进制)# 临界值处理 int() # ValueError: …

分享在日常开发中常用的ES6知识点【面试常考】

前言 在日常的业务开发中&#xff0c;可以熟悉运用掌握的知识点快速解决问题很重要。这篇分享JS相关的知识点&#xff0c;主要就是对数据的处理。 注意&#xff1a;本篇分享的知识点&#xff0c;只是起到一个抛砖引玉的作用&#xff0c;详情的使用和更多的ES6知识点还请参考官…

CHI协议验证中的异常及边界验证

CHI协议验证中的异常及边界验证 针对 CHI 协议的错误注入工具、覆盖率衡量方法及实际项目中的投入平衡 CHI 协议作为多核系统中复杂的缓存一致性协议,验证其行为需要强大的工具和方法来执行错误注入和边界条件测试,并衡量测试覆盖率。以下详细讨论常用工具、覆盖率评估方法及…

技术专栏|LLaMA家族——模型架构

LLaMA的模型架构与GPT相同&#xff0c;采用了Transformer中的因果解码器结构&#xff0c;并在此基础上进行了多项关键改进&#xff0c;以提升训练稳定性和模型性能。LLaMA的核心架构如图 3.14 所示&#xff0c;融合了后续提出的多种优化方法&#xff0c;这些方法也在其他模型&a…

电脑插入多块移动硬盘后经常出现卡顿和蓝屏

当电脑在插入多块移动硬盘后频繁出现卡顿和蓝屏问题时&#xff0c;可能涉及硬件资源冲突、驱动兼容性、供电不足或系统设置等多方面原因。以下是逐步排查和解决方案&#xff1a; 1. 检查电源供电问题 问题原因&#xff1a;多块移动硬盘同时运行可能导致USB接口供电不足&#x…

Go 语言实现高性能 EventBus 事件总线系统(含网络通信、微服务、并发异步实战)

前言 在现代微服务与事件驱动架构&#xff08;EDA&#xff09;中&#xff0c;事件总线&#xff08;EventBus&#xff09; 是实现模块解耦与系统异步处理的关键机制。 本文将以 Go 语言为基础&#xff0c;从零构建一个高性能、可扩展的事件总线系统&#xff0c;深入讲解&#…

npm ERR! @biomejs/biome@1.9.4 postinstall: `node scripts/postinstall.js`

npm install 报错如下, npm ERR! code ELIFECYCLE npm ERR! errno 1 npm ERR! @biomejs/biome@1.9.4 postinstall: `node scripts/postinstall.js` npm ERR! Exit status 1 npm ERR! npm ERR! Failed at the @biomejs/biome@1.9.4 postinstall script. npm ERR! This is pro…

APMPlus × veFaaS 一键开启函数服务性能监控,让函数运行全程可观测

资料来源&#xff1a;火山引擎-开发者社区 近年来&#xff0c;无服务器架构&#xff08;Serverless&#xff09;的崛起让开发者得以从基础设施的复杂性中解放&#xff0c;专注于业务逻辑创新。但随着采用率提升&#xff0c;新的问题开始出现——函数实例的短暂生命周期、动态变…

玛哈特零件矫平机:精密制造中的平整度守护者

在精密制造、模具、冲压、钣金加工、汽车零部件、航空航天以及电子设备等众多工业领域&#xff0c;零件的平整度&#xff08;Flatness&#xff09;是一项至关重要的质量指标。微小的翘曲、扭曲或弯曲都可能导致装配困难、功能失效、外观缺陷甚至影响整机性能。为了消除零件在加…

std::make_shared简化智能指针 `std::shared_ptr` 的创建过程,并提高性能(减少内存分配次数,提高缓存命中率)

std::make_shared 是 C 标准库中的一个函数模板&#xff0c;用于简化智能指针 std::shared_ptr 的创建过程。引入 std::make_shared 的主要原因是提高代码的安全性、性能和可读性。以下是详细分析&#xff1a; 1. 安全性提升 避免显式调用 new 导致的错误 在不使用 std::make…

JDK版本如何丝滑切换

一句话总结 》》》步骤分为&#xff1a; 下载对应JDK配置环境变量 下载JDK 如何下载JDK这里不必多提&#xff0c;提出一点&#xff0c;就是多个版本的JDK最好放在一个文件夹里&#xff08;忽略我的java文件夹&#xff0c;这里都是不同的jdk版本&#xff09;&#xff1a; 配置环…

Rust 通用代码生成器:莲花,红莲尝鲜版三十六,哑数据模式图片初始化功能介绍

Rust 通用代码生成器&#xff1a;莲花&#xff0c;红莲尝鲜版三十六&#xff0c;哑数据模式图片初始化功能介绍 Rust 通用代码生成器莲花&#xff0c;红莲尝鲜版三十六。支持全线支持图片预览&#xff0c;可以直接输出带图片的哑数据模式快速原型。哑数据模式和枚举支持图片。…

45. Jump Game II

目录 题目描述 贪心 题目描述 45. Jump Game II 贪心 正向查找可到达的最大位置 时间复杂度O(n) class Solution { public:int jump(vector<int>& nums) {int n nums.size();if(n 1)return 0;int cur_cover 0;int cover 0;int res 0;for(int i 0;i < …

model.classifier 通常指模型的分类头 是什么,详细举例说明在什么部位,发挥什么作用

model.classifier 通常指模型的分类头 是什么,详细举例说明在什么部位,发挥什么作用 在深度学习模型中,分类头(Classifier Head)是指模型末端用于完成分类任务的组件,通常是一个或多个全连接层(线性层)。它的作用是将模型提取的高层语义特征映射到具体的分类标签空间。…

机器学习+城市规划第十四期:利用半参数地理加权回归来实现区域带宽不同的规划任务

机器学习城市规划第十四期&#xff1a;利用半参数地理加权回归来实现区域带宽不同的规划任务 引言 在城市规划中&#xff0c;如何根据不同地区的地理特征来制定有效的规划方案是一个关键问题。不同区域的需求和规律是不同的&#xff0c;因此我们必须考虑到地理空间的差异性。…

Kivy的ButtonBehavior学习

Kivy的ButtonBehavior学习 ButtonBehavior 简介1、主要特点2、基本用法3、主要事件4、常用属性5、方法代码示例 文档&#xff1a;https://kivy.org/doc/stable/api-kivy.uix.behaviors.button.html#kivy.uix.behaviors.button.ButtonBehavior ButtonBehavior 简介 ButtonBeha…

WPS中将在线链接转为图片

WPS中将在线链接转为图片 文章目录 WPS中将在线链接转为图片一&#xff1a;解决方案1、下载图片&#xff0c;精确匹配&#xff08;会员功能&#xff09;2、将在线链接直接转为图片 一&#xff1a;解决方案 1、下载图片&#xff0c;精确匹配&#xff08;会员功能&#xff09; …

API:解锁数字化协作的钥匙及开放实现路径深度剖析

API:解锁数字化协作的钥匙及开放实现路径深度剖析 一、API 的概念与本质 (一)定义与基本原理 API(Application Programming Interface,应用程序编程接口)是一组定义、协议和工具,用于构建和集成软件应用程序。它如同一个精心设计的合约,详细规定了软件组件之间相互交…