双深度Q网络(Double DQN)基础解析与python实例:训练稳定倒立摆

目录

1. 前言

2. Double DQN的核心思想

3. Double DQN 实例:倒立摆

4. Double DQN的关键改进点

5. 双重网络更新策略

6. 总结


1. 前言

在强化学习领域,深度Q网络(DQN)开启了利用深度学习解决复杂决策问题的新篇章。然而,标准DQN存在一个显著问题:Q值的过估计。为解决这一问题,Double DQN应运而生,它通过引入两个网络来减少Q值的过估计,从而提高策略学习的稳定性和性能。本文将深入浅出地介绍Double DQN的核心思想,并通过一个完整python实现案例,帮助大家全面理解强化这一学习算法。

2. Double DQN的核心思想

标准DQN使用同一个网络同时选择动作和评估动作价值,这容易导致Q值的过估计。Double DQN通过将动作选择和价值评估分离到两个不同的网络来解决这个问题:

  1. 一个网络(在线网络)用于选择当前状态下的最佳动作

  2. 另一个网络(目标网络)用于评估这个动作的价值

这种分离减少了自举过程中动作选择和价值评估的关联性,从而有效降低了Q值的过估计。

结构如下:

3. Double DQN 实例:倒立摆

接下来,我们将实现一个完整的Double DQN,解决CartPole平衡问题。这个例子包含了所有关键组件:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
import random
from collections import deque# 1. 定义DQN网络结构
class DQN(nn.Module):def __init__(self, state_dim, action_dim):super(DQN, self).__init__()self.fc1 = nn.Linear(state_dim, 128)self.fc2 = nn.Linear(128, 128)self.fc3 = nn.Linear(128, action_dim)def forward(self, x):x = torch.relu(self.fc1(x))x = torch.relu(self.fc2(x))x = self.fc3(x)return x# 2. 经验回放缓冲区
class ReplayBuffer:def __init__(self, capacity):self.buffer = deque(maxlen=capacity)def add(self, state, action, reward, next_state, done):self.buffer.append((state, action, reward, next_state, done))def sample(self, batch_size):samples = random.sample(self.buffer, batch_size)states, actions, rewards, next_states, dones = zip(*samples)return states, actions, rewards, next_states, donesdef __len__(self):return len(self.buffer)# 3. Double DQN代理
class DoubleDQNAgent:def __init__(self, state_dim, action_dim):self.policy_net = DQN(state_dim, action_dim)self.target_net = DQN(state_dim, action_dim)self.target_net.load_state_dict(self.policy_net.state_dict())self.target_net.eval()self.optimizer = optim.Adam(self.policy_net.parameters(), lr=0.001)self.replay_buffer = ReplayBuffer(10000)self.batch_size = 64self.gamma = 0.99  # 折扣因子self.epsilon = 1.0  # 探索率self.epsilon_decay = 0.995self.min_epsilon = 0.01self.action_dim = action_dim# 根据ε-greedy策略选择动作def select_action(self, state):if random.random() < self.epsilon:return random.randint(0, self.action_dim - 1)else:with torch.no_grad():return self.policy_net(torch.FloatTensor(state)).argmax().item()# 存储经验def store_transition(self, state, action, reward, next_state, done):self.replay_buffer.add(state, action, reward, next_state, done)# 更新网络def update(self):if len(self.replay_buffer) < self.batch_size:return# 从经验回放中采样states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)# 转换为PyTorch张量states = torch.FloatTensor(states)actions = torch.LongTensor(actions)rewards = torch.FloatTensor(rewards)next_states = torch.FloatTensor(next_states)dones = torch.FloatTensor(dones)# 计算当前Q值current_q = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)# 计算目标Q值(使用Double DQN方法)# 使用策略网络选择动作,目标网络评估价值with torch.no_grad():# 从策略网络中选择最佳动作policy_actions = self.policy_net(next_states).argmax(dim=1)# 从目标网络中评估这些动作的值next_q = self.target_net(next_states).gather(1, policy_actions.unsqueeze(1)).squeeze(1)target_q = rewards + self.gamma * next_q * (1 - dones)# 计算损失并优化loss = nn.MSELoss()(current_q, target_q)self.optimizer.zero_grad()loss.backward()self.optimizer.step()# 更新目标网络(软更新)for target_param, policy_param in zip(self.target_net.parameters(), self.policy_net.parameters()):target_param.data.copy_(0.001 * policy_param.data + 0.999 * target_param.data)# 减少探索率self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)# 训练过程def train_double_dqn():# 创建环境env = gym.make('CartPole-v1')state_dim = env.observation_space.shape[0]action_dim = env.action_space.n# 创建代理agent = DoubleDQNAgent(state_dim, action_dim)# 训练参数episodes = 500max_steps = 500# 训练循环for episode in range(episodes):state, _ = env.reset()total_reward = 0for step in range(max_steps):action = agent.select_action(state)next_state, reward, done, _, _ = env.step(action)# 修改奖励以加速学习reward = reward if not done else -10agent.store_transition(state, action, reward, next_state, done)agent.update()total_reward += rewardstate = next_stateif done:break# 每10个episodes更新一次目标网络if episode % 10 == 0:agent.target_net.load_state_dict(agent.policy_net.state_dict())print(f"Episode: {episode + 1}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.2f}")env.close()# 执行训练
if __name__ == "__main__":train_double_dqn()

4. Double DQN的关键改进点

  1. 双网络结构:通过将动作选择(策略网络)和价值评估(目标网络)分离,减少了Q值的过估计。

  2. 经验回放:通过存储和随机采样历史经验,打破了数据的相关性,提高了学习稳定性。

  3. ε-greedy策略:平衡探索与利用,随着训练进行逐渐减少探索概率。

目标网络在Double DQN中扮演着非常重要的角色:

  • 它为策略网络提供稳定的目标Q值

  • 通过延迟更新,减少了目标Q值的波动

  • 与策略网络共同工作,实现了动作选择和价值评估的分离

5. 双重网络更新策略

在Double DQN中,我们使用了软更新(soft update)策略来更新目标网络:

for target_param, policy_param in zip(self.target_net.parameters(), self.policy_net.parameters()):target_param.data.copy_(0.001 * policy_param.data + 0.999 * target_param.data)

这种软更新方式比传统的目标网络定期硬更新(hard update)更平滑,有助于训练过程的稳定。

6. 总结

本文通过详细讲解Double DQN的原理,并提供了完整的python实现代码,展示了如何应用这一先进强化学习算法解决实际问题。与标准DQN相比,Double DQN通过引入双网络结构,有效解决了Q值过估计问题,提高了策略学习的稳定性和最终性能。Double DQN是强化学习领域的一个重要进步,为后续更高级的算法(如Dueling DQN、C51、Rainbow DQN等)奠定了基础。通过理解Double DQN的原理和实现,读者可以为进一步探索复杂强化学习算法打下坚实基础。在实际应用中,可以根据具体任务调整网络结构、超参数(如学习率、折扣因子、经验回放缓冲区大小等)以及探索策略,以获得最佳性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/diannao/84515.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用KubeKey快速部署k8s v1.31.8集群

实战环境涉及软件版本信息&#xff1a; 使用kubekey部署k8s 1. 操作系统基础配置 设置主机名、DNS解析、时钟同步、防火墙关闭、ssh免密登录等等系统基本设置 dnf install -y curl socat conntrack ebtables ipset ipvsadm 2. 安装部署 K8s 2.1 下载 KubeKey ###地址 https…

SQL:窗口函数(Window Functions)

目录 什么是窗口函数&#xff1f; 基本语法结构 为什么要用窗口函数&#xff1f; 常见的窗口函数分类 1️⃣ 排名类函数 2️⃣ 聚合类函数&#xff08;不影响原始行&#xff09; 3️⃣ 值访问函数 窗口范围说明&#xff08;ROWS / RANGE&#xff09; 什么是窗口函数&a…

相机内参 opencv

视场角定相机内参 import numpy as np import cv2 import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3Ddef calculate_camera_intrinsics(image_width640, image_height480, fov55, is_horizontalTrue):"""计算相机内参矩阵参数:image_w…

MATLAB 各个工具箱 功能说明

​ 想必大家在安装MATLAB时&#xff0c;或多或少会疑惑应该安装哪些工具箱。笔者遇到了两种情况——只安装了MATLAB主程序&#xff0c;老师让用MATLAB的时候却发现没有安装对应安装包&#xff1b;第二次安装学聪明了&#xff0c;全选安装&#xff0c;嗯……占用了20多个G。 ​…

学习日记-day14-5.23

完成目标&#xff1a; 学习java下半段课程 知识点&#xff1a; 1.多态转型 知识点 核心内容 重点 多态转型 向上转型&#xff08;父类引用指向子类对象&#xff09; 与向下转型&#xff08;强制类型转换&#xff09;的机制与区别 向上转型自动完成&#xff0c;向下转型需…

【编程语言】【Java】一篇文章学习java,复习完善知识体系

第一章 Java基础 1.1 变量与数据类型 1.1.1 基本数据类型 1.1.1.1 整数类型&#xff08;byte、short、int、long&#xff09; 在 Java 中&#xff0c;整数类型用于表示没有小数部分的数字&#xff0c;不同的整数类型有不同的取值范围和占用的存储空间&#xff1a; byte&am…

汇量科技前端面试题及参考答案

数组去重的方法有哪些&#xff1f; 在 JavaScript 中&#xff0c;数组去重是一个常见的操作&#xff0c;有多种方法可以实现这一目标。每种方法都有其适用场景和性能特点&#xff0c;下面将详细介绍几种主要的去重方法。 使用 Set 数据结构 Set 是 ES6 引入的一种新数据结构&a…

Git实战演练,模拟日常使用,快速掌握命令

01 引言 上一期借助Idea&#xff0c;完成了Git仓库的建立、配置、代码提交等操作&#xff0c;初步入门了Git的使用。然而日常开发中经常面临各种各样的问题&#xff0c;入门级的命令远远不够使用。 这一期&#xff0c;我们将展开介绍Git的日常处理命令&#xff0c;解决日常问…

wordpress主题开发中常用的12个模板文件

在WordPress主题开发中&#xff0c;有多种常用的模板文件&#xff0c;它们负责控制网站不同部分的显示内容和布局&#xff0c;以下是一些常见的模板文件&#xff1a; 1.index.php 这是WordPress主题的核心模板文件。当没有其他更具体的模板文件匹配当前页面时&#xff0c;Wor…

数据库blog5_数据库软件架构介绍(以Mysql为例)

&#x1f33f;软件的架构 &#x1f342;分类 软件架构总结为两种主要类型&#xff1a;一体式架构和分布式架构 ● 一体化架构 一体式架构是一种将所有功能集成到一个单一的、不可分割的应用程序中的架构模式。这种架构通常是一个大型的、复杂的单一应用程序&#xff0c;包含所…

离线服务器算法部署环境配置

本文将详细记录我如何为一台全新的离线服务器配置必要的运行环境&#xff0c;包括基础编译工具、NVIDIA显卡驱动以及NVIDIA-Docker&#xff0c;以便顺利部署深度学习算法。 前提条件&#xff1a; 目标离线服务器已安装操作系统&#xff08;本文以Ubuntu 18.04为例&#xff09…

chromedp -—— 基于 go 的自动化操作浏览器库

chromedp chromedp 是一个用于 Chrome 浏览器的自动化测试工具&#xff0c;基于 Go 语言开发&#xff0c;专门用于控制和操作 Chrome 浏览器实例。 chromedp 安装 go get -u github.com/chromedp/chromedp基于chromedp 实现的的简易学习通刷课系统 目前实现的功能&#xff…

高级特性实战:死信队列、延迟队列与优先级队列(三)

四、优先级队列&#xff1a;优先处理重要任务 4.1 优先级队列概念解析 优先级队列&#xff08;Priority Queue&#xff09;是一种特殊的队列数据结构&#xff0c;它与普通队列的主要区别在于&#xff0c;普通队列遵循先进先出&#xff08;FIFO&#xff09;的原则&#xff0c;…

python打卡day34

GPU训练及类的call方法 知识点回归&#xff1a; CPU性能的查看&#xff1a;看架构代际、核心数、线程数GPU性能的查看&#xff1a;看显存、看级别、看架构代际GPU训练的方法&#xff1a;数据和模型移动到GPU device上类的call方法&#xff1a;为什么定义前向传播时可以直接写作…

Newtonsoft Json序列化数据不序列化默认数据

问题描述 数据在序列号为json时,一些默认值也序列化了,像旋转rot都是0、缩放scal都是1,这样的默认值完全可以去掉,减少和服务器通信数据量 核心代码 数据结构字段增加[DefaultValue(1.0)]属性,缩放的默认值为1 public class Vec3DataOne{[DefaultValue(1.0)] public flo…

可增添功能的鼠标右键优化工具

软件介绍 本文介绍一款能优化Windows电脑的软件&#xff0c;它可以让鼠标右键菜单添加多种功能。 软件基本信息 这款名为Easy Context Menu的鼠标右键菜单工具非常小巧&#xff0c;软件大小仅1.14MB&#xff0c;打开即可直接使用&#xff0c;无需进行安装。 添加功能列举 它…

Gemini 2.5 Pro 一次测试

您好&#xff0c;您遇到的重定向循环问题&#xff0c;即在 /user/messaging、/user/login?return_to/user/messaging 和 /user/login 之间反复跳转&#xff0c;通常是由于客户端的身份验证状态检查和页面重定向逻辑存在冲突或竞争条件。 在分析了您提供的代码&#xff08;特别…

vue3前端后端地址可配置方案

在开发vue3项目过程中&#xff0c;需要切换不同的服务器部署&#xff0c;代码中配置的服务需要可灵活配置&#xff0c;不随着run npm build把网址打包到代码资源中&#xff0c;不然每次切换都需要重新run npm build。需要一个配置文件可以修改服务地址&#xff0c;而打包的代码…

大模型微调与高效训练

随着预训练大模型(如BERT、GPT、ViT、LLaMA、CLIP等)的崛起,人工智能进入了一个新的范式:预训练-微调(Pre-train, Fine-tune)。这些大模型在海量数据上学习到了通用的、强大的表示能力和世界知识。然而,要将这些通用模型应用于特定的下游任务或领域,通常还需要进行微调…

编程技能:字符串函数10,strchr

专栏导航 本节文章分别属于《Win32 学习笔记》和《MFC 学习笔记》两个专栏&#xff0c;故划分为两个专栏导航。读者可以自行选择前往哪个专栏。 &#xff08;一&#xff09;WIn32 专栏导航 上一篇&#xff1a;编程技能&#xff1a;字符串函数09&#xff0c;strncmp 回到目录…