分布变化的模仿学习算法

与传统监督学习不同,直接模仿学习在不同时刻所面临的数据分布可能不同.试设计一个考虑不同时刻数据分布变化的模仿学习算法

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.neighbors import KernelDensity
import matplotlib.pyplot as pltclass TimeAwareImitationLearning:def __init__(self, state_dim, action_dim, hidden_dim=64, device='cpu'):"""初始化时间感知的模仿学习算法state_dim: 状态维度action_dim: 动作维度hidden_dim: 隐藏层维度"""self.state_dim = state_dimself.action_dim = action_dimself.device = device# 策略网络 - 模仿专家行为self.policy = nn.Sequential(nn.Linear(state_dim + 1, hidden_dim),  # +1 是为了包含时间信息nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, action_dim)).to(device)# 判别器网络 - 区分专家和策略生成的轨迹self.discriminator = nn.Sequential(nn.Linear(state_dim + action_dim + 1, hidden_dim),  # +1 是为了包含时间信息nn.ReLU(),nn.Linear(hidden_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, 1),nn.Sigmoid()).to(device)# 优化器self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=1e-3)self.discriminator_optimizer = optim.Adam(self.discriminator.parameters(), lr=1e-3)# 记录训练过程self.train_losses = []def _compute_time_weights(self, expert_times, current_time, sigma=1.0):"""计算时间权重,距离当前时间越近的样本权重越大"""time_diffs = np.abs(expert_times - current_time)weights = np.exp(-time_diffs / (2 * sigma**2))return weights / np.sum(weights)def _compute_mmd_loss(self, expert_states, policy_states, times, current_time):"""计算最大均值差异(MMD)损失,衡量分布差异"""# 计算时间权重weights = self._compute_time_weights(times, current_time)# 对专家状态应用时间权重weighted_expert_states = expert_states * weights.reshape(-1, 1)# 计算MMDexpert_kernel = rbf_kernel(weighted_expert_states, weighted_expert_states)policy_kernel = rbf_kernel(policy_states, policy_states)cross_kernel = rbf_kernel(weighted_expert_states, policy_states)mmd = np.mean(expert_kernel) + np.mean(policy_kernel) - 2 * np.mean(cross_kernel)return mmddef train(self, expert_states, expert_actions, expert_times, epochs=100, batch_size=64):"""训练时间感知的模仿学习模型expert_states: 专家状态序列 [num_samples, state_dim]expert_actions: 专家动作序列 [num_samples, action_dim]expert_times: 专家时间戳 [num_samples]"""num_samples = expert_states.shape[0]expert_states_tensor = torch.FloatTensor(expert_states).to(self.device)expert_actions_tensor = torch.FloatTensor(expert_actions).to(self.device)expert_times_tensor = torch.FloatTensor(expert_times).reshape(-1, 1).to(self.device)for epoch in range(epochs):# 当前"时间" - 使用训练轮次的比例作为时间表示current_time = epoch / epochs# 生成策略动作policy_actions = []for i in range(0, num_samples, batch_size):batch_states = expert_states_tensor[i:i+batch_size]batch_times = torch.full((batch_states.shape[0], 1), current_time).to(self.device)policy_action = self.policy(torch.cat([batch_states, batch_times], dim=1))policy_actions.append(policy_action.detach().cpu().numpy())policy_actions = np.vstack(policy_actions)# 计算MMD损失mmd_loss = self._compute_mmd_loss(expert_states, policy_actions, expert_times, current_time)# 训练判别器for _ in range(5):  # 判别器训练多次# 随机采样批次indices = np.random.randint(0, num_samples, batch_size)batch_expert_states = expert_states_tensor[indices]batch_expert_actions = expert_actions_tensor[indices]batch_expert_times = expert_times_tensor[indices]# 生成策略动作batch_times = torch.full((batch_size, 1), current_time).to(self.device)batch_policy_actions = self.policy(torch.cat([batch_expert_states, batch_times], dim=1))# 计算判别器损失expert_input = torch.cat([batch_expert_states, batch_expert_actions, batch_expert_times], dim=1)policy_input = torch.cat([batch_expert_states, batch_policy_actions, batch_times], dim=1)expert_output = self.discriminator(expert_input)policy_output = self.discriminator(policy_input)# 判别器损失 (最大化区分能力)d_loss = -torch.mean(torch.log(expert_output + 1e-8) + torch.log(1 - policy_output + 1e-8))self.discriminator_optimizer.zero_grad()d_loss.backward()self.discriminator_optimizer.step()# 训练策略网络for _ in range(1):  # 策略网络训练较少次数indices = np.random.randint(0, num_samples, batch_size)batch_states = expert_states_tensor[indices]batch_times = torch.full((batch_size, 1), current_time).to(self.device)# 生成策略动作actions = self.policy(torch.cat([batch_states, batch_times], dim=1))# 计算策略损失 (最小化判别器的区分能力)policy_input = torch.cat([batch_states, actions, batch_times], dim=1)policy_output = self.discriminator(policy_input)# 策略损失 + MMD正则化p_loss = -torch.mean(torch.log(policy_output + 1e-8)) + 0.1 * mmd_lossself.policy_optimizer.zero_grad()p_loss.backward()self.policy_optimizer.step()# 记录损失self.train_losses.append(p_loss.item())if epoch % 100 == 0:print(f"Epoch {epoch}, Loss: {p_loss.item():.4f}, MMD: {mmd_loss:.4f}")def predict(self, state, time):"""根据当前状态和时间预测动作"""state_tensor = torch.FloatTensor(state).reshape(1, -1).to(self.device)time_tensor = torch.FloatTensor([time]).reshape(1, 1).to(self.device)with torch.no_grad():action = self.policy(torch.cat([state_tensor, time_tensor], dim=1))return action.cpu().numpy()[0]def visualize_training(self):"""可视化训练过程"""plt.figure(figsize=(10, 6))plt.plot(self.train_losses)plt.title('Training Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.grid(True)plt.show()# 示例:生成具有时间分布变化的专家数据
def generate_time_varying_expert_data(num_samples=1000, state_dim=2, time_period=1.0):"""生成随时间变化的数据分布"""times = np.linspace(0, time_period, num_samples)states = []actions = []for t in times:# 状态分布随时间变化mean = np.array([np.sin(2 * np.pi * t), np.cos(2 * np.pi * t)])cov = np.diag([0.1 + 0.1 * np.abs(np.sin(np.pi * t)), 0.1 + 0.1 * np.abs(np.cos(np.pi * t))])state = np.random.multivariate_normal(mean, cov)# 动作是状态的函数,也随时间变化action = 2.0 * state * (1.0 + 0.5 * np.sin(2 * np.pi * t))states.append(state)actions.append(action)return np.array(states), np.array(actions), times# 测试算法
def test_time_aware_il():# 生成专家数据state_dim = 2action_dim = 2expert_states, expert_actions, expert_times = generate_time_varying_expert_data(num_samples=2000, state_dim=state_dim, time_period=1.0)# 创建并训练模型model = TimeAwareImitationLearning(state_dim, action_dim)model.train(expert_states, expert_actions, expert_times, epochs=500)# 可视化训练过程model.visualize_training()# 测试不同时间点的策略test_times = np.linspace(0, 1, 5)test_states = np.random.randn(len(test_times), state_dim)plt.figure(figsize=(12, 8))for i, t in enumerate(test_times):plt.subplot(2, 3, i+1)# 真实专家行为expert_mask = (expert_times >= t - 0.1) & (expert_times <= t + 0.1)plt.scatter(expert_states[expert_mask, 0], expert_states[expert_mask, 1], c='blue', alpha=0.5, label='Expert')# 模型预测行为pred_actions = np.array([model.predict(s, t) for s in expert_states[expert_mask]])plt.scatter(pred_actions[:, 0], pred_actions[:, 1], c='red', alpha=0.5, label='Policy')plt.title(f'Time = {t:.2f}')plt.xlabel('State 1')plt.ylabel('State 2')plt.legend()plt.tight_layout()plt.show()if __name__ == "__main__":test_time_aware_il()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/85997.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/85997.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

arm-none-eabi-ld: cannot find -lm

arm-none-eabi-ld -Tuser/hc32l13x.lds -o grbl_hc32l13x.elf user/interrupts_hc32l13x.o user/system_hc32l13x.o user/main.o user/startup_hc32l13x.o -lm -Mapgrbl_hc32l13x.map arm-none-eabi-ld: cannot find -lm makefile:33: recipe for target link failed 改为在gcc…

【Python办公】Excel文件批量样式修改器

目录 专栏导读1. 背景介绍2. 项目概述3. 库的安装4. 核心架构设计① 类结构设计② 数据模型1) 文件管理2) 样式配置5. 界面设计与实现① 布局结构② 动态组件生成6. 核心功能实现① 文件选择与管理② 颜色选择功能③ Excel文件处理核心逻辑完整代码结尾专栏导读 🌸 欢迎来到P…

QT的一些介绍

//虽然下面一行代码进行widget和ui的窗口关联&#xff0c;但是如果发生窗口大小变化的时候&#xff0c;里面的布局不会随之变化ui->setupUi(this);//通过下面这行代码进行显示说明&#xff0c;让窗口变化时&#xff0c;布局及其子控件随之变化this->setLayout(ui->ver…

RISC-V向量扩展与GPU协处理:开源加速器设计新范式——对比NVDLA与香山架构的指令集融合方案

点击 “AladdinEdu&#xff0c;同学们用得起的【H卡】算力平台”&#xff0c;H卡级别算力&#xff0c;按量计费&#xff0c;灵活弹性&#xff0c;顶级配置&#xff0c;学生专属优惠 当开源指令集遇上异构计算&#xff0c;RISC-V向量扩展&#xff08;RVV&#xff09;正重塑加速…

自动恢复网络路由配置的安全脚本说明

背景 两个文章 看了就明白 Ubuntu 多网卡路由配置笔记&#xff08;内网 外网同时通 可能有问题&#xff0c;以防万一可以按照个来恢复 sudo ip route replace 192.168.1.0/24 dev eno8403 proto kernel scope link src <你的IP>或者恢复脚本! 如下 误操作路由时&…

创建 Vue 3.0 项目的两种方法对比:npm init vue@latest vs npm init vite@latest

创建 Vue 3.0 项目的两种方法对比&#xff1a;npm init vuelatest vs npm init vitelatest Vue 3.0 作为当前主流的前端框架&#xff0c;官方提供了多种项目创建方式。本文将详细介绍两种最常用的创建方法&#xff1a;Vue CLI 方式 (npm init vuelatest) 和 Vite 方式 (npm in…

Java求职者面试指南:Spring, Spring Boot, Spring MVC, MyBatis技术点深度解析

Java求职者面试指南&#xff1a;Spring, Spring Boot, Spring MVC, MyBatis技术点深度解析 面试官与程序员JY的三轮提问 第一轮&#xff1a;基础概念问题 1. 请解释一下Spring框架的核心容器是什么&#xff1f;它有哪些主要功能&#xff1f; JY回答&#xff1a;Spring框架的…

【修复MySQL 主从Last_Errno:1051报错的几种解决方案】

当MySQL主从集群遇到Last_Errno:1051报错后不要着急&#xff0c;主要有三种解决方案&#xff1a; 方案1: 使用GTID场景&#xff1a; mysql> STOP SLAVE;(2)设置事务号&#xff0c;事务号从Retrieved_Gtid_Set获取 在session里设置gtid_next&#xff0c;即跳过这个GTID …

定位接口偶发超时的实战分析:iOS抓包流程的完整复现

我们通常把“请求超时”归结为网络不稳定、服务器慢响应&#xff0c;但在一次产品灰度发布中&#xff0c;我们遇到的一个“偶发接口超时”问题完全打破了这些常规判断。 这类Bug最大的问题不在于表现&#xff0c;而在于极难重现、不可预测、无法复盘。它不像逻辑Bug那样能从代…

【网工】华为配置专题进阶篇②

目录 ■DHCP NAT BFD 策略路由 ▲掩码与反掩码总结 ▲综合实验 ■DHCP NAT BFD 策略路由 ▲掩码与反掩码总结 使用掩码的场景&#xff1a;IP地址强相关 场景一&#xff1a;IP地址配置 ip address 192.168.1.1 255.255.255.0 或ip address 192.168.1.1 24 场景二&#x…

基于STM32电子密码锁

基于STM32电子密码锁 &#xff08;程序&#xff0b;原理图&#xff0b;PCB&#xff0b;设计报告&#xff09; 功能介绍 具体功能&#xff1a; 1.正确输入密码前提下&#xff0c;开锁并有正确提示&#xff1b; 2.错误输入密码情况下&#xff0c;蜂鸣器报警并短暂锁定键盘&…

前端基础知识CSS系列 - 14(CSS提高性能的方法)

一、前言 每一个网页都离不开css&#xff0c;但是很多人又认为&#xff0c;css主要是用来完成页面布局的&#xff0c;像一些细节或者优化&#xff0c;就不需要怎么考虑&#xff0c;实际上这种想法是不正确的 作为页面渲染和内容展现的重要环节&#xff0c;css影响着用户对整个…

判断 NI Package Manager (NIPM) 版本与 LabVIEW 2019 兼容性

​判断依据 1. 查阅 LabVIEW 2019 自述文件 LabVIEW 2019 自述文件中包含系统要求&#xff0c;可通过 NI 官网访问。文件提到使用 NIPM 安装&#xff0c;但未明确最低版本要求&#xff0c;需结合其他信息判断。 2. 参考 NI 官方兼容性文档 NI 官方文档指出 LabVIEW 运行引擎与…

Django 安装指南

Django 安装指南 引言 Django 是一个高级的 Python Web 框架,用于快速开发安全且实用的网站。本文将详细介绍如何在您的计算机上安装 Django,以便您能够开始使用这个强大的工具。 安装前的准备 在开始安装 Django 之前,请确保您的计算机满足以下条件: 操作系统:Django…

Spring MVC参数绑定终极手册:单多参对象集合JSON文件上传精讲

我们通过浏览器访问不同的路径&#xff0c;就是在发送不同的请求&#xff0c;在发送请求时&#xff0c;可能会带一些参数&#xff0c;本文将介绍了Spring MVC中处理不同请求参数的多种方式 一、传递单个参数 接收单个参数&#xff0c;在Spring MVC中直接用方法中的参数就可以…

synchronized 做了哪些优化?

Java 中的 synchronized 关键字是保证线程安全的基本机制&#xff0c;随着 JVM 的发展&#xff0c;它经历了多次优化以提高性能。 1. 锁升级机制&#xff08;锁膨胀&#xff09; JDK 1.6 引入了偏向锁→轻量级锁→重量级锁的升级机制&#xff0c;避免了一开始就使用重量级锁&…

三甲医院AI医疗样本数据集分类与收集全流程节点分析(下)

3.3 典型案例分析 —— 以某三甲医院为例 为了更深入地了解三甲医院 AI 医疗样本数据收集的实际情况,本研究选取了具有代表性的某三甲医院作为案例进行详细分析。该医院作为区域医疗中心,在医疗技术、设备和人才方面具有显著优势,同时在医疗信息化建设和 AI 应用方面也进行…

设置程序开机自动启动

在Windows系统中&#xff0c;有几种方法可以将程序设置为开机自动启动。下面我将介绍最常用的三种方法&#xff0c;并提供一个C#实现示例。 方法一&#xff1a;使用启动文件夹&#xff08;最简单&#xff09; 按下 Win R 键打开运行对话框 输入 shell:startup 并回车 将你的…

多源异构数据接入与实时分析:衡石科技的技术突破

在数字化转型的浪潮中&#xff0c;企业每天产生的数据量呈指数级增长。这些数据来自CRM系统、IoT设备、日志文件、社交媒体、交易平台等众多源头&#xff0c;格式各异、结构混乱、流速不一。传统的数据处理方式如同在无数孤立的岛屿间划着小船传递信息&#xff0c;效率低下且无…

JVM——Synchronized:同步锁的原理及应用

引入 在多线程编程的世界里&#xff0c;共享资源的访问控制就像一场精心设计的交通管制&#xff0c;而Synchronized作为Java并发编程的基础同步机制&#xff0c;扮演着"交通警察"的关键角色。 并发编程的核心矛盾 当多个线程同时访问共享资源时&#xff0c;"…