OpenSTL PredRNNv2 模型复现与自定义数据集训练

OpenSTL PredRNNv2 模型复现与自定义数据集训练

概述

本文将详细介绍如何复现 OpenSTL 中的 PredRNNv2 模型,并使用自定义的 NPY 格式数据集进行训练和预测。我们将从环境配置开始,逐步讲解数据预处理、模型构建、训练过程和预测实现,最终实现输入多张连续时间序列的 500×500 图像并输出相应数量预测图像的目标。

目录

  1. 环境配置与依赖安装
  2. 数据集准备与预处理
  3. PredRNNv2 模型原理与架构
  4. 数据加载器实现
  5. 模型训练流程
  6. 预测与结果可视化
  7. 模型评估与优化
  8. 完整代码实现
  9. 常见问题与解决方案
  10. 总结与展望

1. 环境配置与依赖安装

首先,我们需要创建一个合适的 Python 环境并安装所有必要的依赖包。

# 创建conda环境
conda create -n openstl python=3.8
conda activate openstl# 安装PyTorch (根据CUDA版本选择)
pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116# 安装其他依赖
pip install numpy==1.21.6
pip install opencv-python==4.7.0.72
pip install matplotlib==3.5.3
pip install tensorboard==2.11.2
pip install scikit-learn==1.0.2
pip install tqdm==4.64.1
pip install nni==2.8
pip install timm==0.6.12
pip install einops==0.6.0

接下来,我们需要克隆 OpenSTL 仓库并安装相关依赖:

git clone https://github.com/chengtan9907/OpenSTL.git
cd OpenSTL
git checkout OpenSTL-Lightning
pip install -e .

2. 数据集准备与预处理

我们的数据集是 NPY 格式的文件,每张图像尺寸为 500×500,且文件之间在时间上是连续的。首先,我们需要了解数据集的目录结构:

dataset/
├── train/
│   ├── sequence_001/
│   │   ├── frame_001.npy
│   │   ├── frame_002.npy
│   │   └── ...
│   ├── sequence_002/
│   └── ...
├── valid/
└── test/

2.1 数据预处理类实现

我们需要创建一个数据预处理类,将 NPY 文件转换为模型可用的格式:

import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
import cv2class NPYDataset(Dataset):def __init__(self, data_root, mode='train', input_frames=10, output_frames=10, future_frames=10, transform=None, preprocess=True):"""初始化NPY数据集参数:data_root: 数据根目录mode: 模式 ('train', 'valid', 'test')input_frames: 输入帧数output_frames: 输出帧数future_frames: 未来帧数 (预测帧数)transform: 数据转换函数preprocess: 是否进行预处理"""self.data_root = os.path.join(data_root, mode)self.mode = modeself.input_frames = input_framesself.output_frames = output_framesself.future_frames = future_framesself.transform = transformself.preprocess = preprocess# 获取所有序列self.sequences = []for seq_name in os.listdir(self.data_root):seq_path = os.path.join(self.data_root, seq_name)if os.path.isdir(seq_path):frames = sorted([f for f in os.listdir(seq_path) if f.endswith('.npy')])if len(frames) >= input_frames + future_frames:self.sequences.append((seq_path, frames))# 数据标准化器self.scaler = Noneif preprocess:self._init_scaler()def _init_scaler(self):"""初始化数据标准化器"""print(f"Initializing scaler for {self.mode} mode...")all_data = []for seq_path, frames in self.sequences:for frame_name in frames[:min(100, len(frames))]:  # 使用前100帧计算统计量frame_path = os.path.join(seq_path, frame_name)data = np.load(frame_path)all_data.append(data.flatten())all_data = np.concatenate(all_data).reshape(-1, 1)self.scaler = StandardScaler()self.scaler.fit(all_data)print("Scaler initialized.")def _preprocess_data(self, data):"""预处理数据"""if self.preprocess and self.scaler is not None:original_shape = data.shapedata = data.flatten().reshape(-1, 1)data = self.scaler.transform(data)data = data.reshape(original_shape)return datadef _postprocess_data(self, data):"""后处理数据"""if self.preprocess and self.scaler is not None:original_shape = data.shapedata = data.flatten().reshape(-1, 1)data = self.scaler.inverse_transform(data)data = data.reshape(original_shape)return datadef __len__(self):return len(self.sequences)def __getitem__(self, idx):seq_path, frames = self.sequences[idx]# 随机选择起始帧total_frames = len(frames)max_start = total_frames - self.input_frames - self.future_framesstart_idx = np.random.randint(0, max_start + 1) if self.mode == 'train' else 0# 加载输入帧input_frames = []for i in range(start_idx, start_idx + self.input_frames):frame_path = os.path.join(seq_path, frames[i])frame_data = np.load(frame_path)frame_data = self._preprocess_data(frame_data)input_frames.append(frame_data)# 加载目标帧target_frames = []for i in range(start_idx + self.input_frames, start_idx + self.input_frames + self.future_frames):frame_path = os.path.join(seq_path, frames[i])frame_data = np.load(frame_path)frame_data = self._preprocess_data(frame_data)target_frames.append(frame_data)# 转换为numpy数组input_seq = np.stack(input_frames, axis=0)target_seq = np.stack(target_frames, axis=0)# 添加通道维度input_seq = np.expand_dims(input_seq, axis=1)  # [T, 1, H, W]target_seq = np.expand_dims(target_seq, axis=1)  # [T, 1, H, W]# 转换为张量input_seq = torch.FloatTensor(input_seq)target_seq = torch.FloatTensor(target_seq)if self.transform:input_seq = self.transform(input_seq)target_seq = self.transform(target_seq)return input_seq, target_seq# 数据增强转换
class RandomRotate:def __init__(self, angles=[0, 90, 180, 270]):self.angles = anglesdef __call__(self, x):angle = np.random.choice(self.angles)if angle == 0:return x# 旋转每个帧rotated = []for i in range(x.shape[0]):frame = x[i].numpy()# 对于3D数据,我们需要分别旋转每个通道if len(frame.shape) == 3:frame_rotated = np.stack([cv2.rotate(frame[c], cv2.ROTATE_90_CLOCKWISE) for c in range(frame.shape[0])], axis=0)else:frame_rotated = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)rotated.append(frame_rotated)return torch.FloatTensor(np.stack(rotated, axis=0))class RandomFlip:def __init__(self, p=0.5):self.p = pdef __call__(self, x):if np.random.random() < self.p:# 水平翻转return x.flip(-1)return x

3. PredRNNv2 模型原理与架构

PredRNNv2 是一种改进的循环神经网络,专门用于视频预测任务。它通过引入时空记忆(STM)单元来更好地捕捉时空动态。

3.1 核心组件

import torch
import torch.nn as nn
from einops import rearrangeclass SpatioTemporalLSTMCell(nn.Module):def __init__(self, in_channel, num_hidden, height, width, filter_size, stride, layer_norm):super(SpatioTemporalLSTMCell, self).__init__()self.num_hidden = num_hiddenself.padding = filter_size // 2self._forget_bias = 1.0# 卷积层self.conv_x = nn.Sequential(nn.Conv2d(in_channel, num_hidden * 7, kernel_size=filter_size,stride=stride, padding=self.padding, bias=False),nn.LayerNorm([num_hidden * 7, height, width]))self.conv_h = nn.Sequential(nn.Conv2d(num_hidden, num_hidden * 4, kernel_size=filter_size,stride=stride, padding=self.padding, bias=False),nn.LayerNorm([num_hidden * 4, height, width]))self.conv_m = nn.Sequential(nn.Conv2d(num_hidden, num_hidden * 3, kernel_size=filter_size,stride=stride, padding=self.padding, bias=False),nn.LayerNorm([num_hidden * 3, height, width]))self.conv_o = nn.Sequential(nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=filter_size,stride=stride, padding=self.padding, bias=False),nn.LayerNorm([num_hidden, height, width]))self.conv_last = nn.Conv2d(num_hidden * 2, num_hidden, kernel_size=1,stride=1, padding=0, bias=False)def forward(self, x_t, h_t, c_t, m_t):# 计算门控信号x_concat = self.conv_x(x_t)h_concat = self.conv_h(h_t)m_concat = self.conv_m(m_t)i_x, f_x, g_x, i_x_prime, f_x_prime, g_x_prime, o_x = torch.split(x_concat, self.num_hidden, dim=1)i_h, f_h, g_h, o_h = torch.split(h_concat, self.num_hidden, dim=1)i_m, f_m, g_m = torch.split(m_concat, self.num_hidden, dim=1)i_t = torch.sigmoid(i_x + i_h)f_t = torch.sigmoid(f_x + f_h + self._forget_bias)g_t = torch.tanh(g_x + g_h)c_new = f_t * c_t + i_t * g_ti_t_prime = torch.sigmoid(i_x_prime + i_m)f_t_prime = torch.sigmoid(f_x_prime + f_m + self._forget_bias)g_t_prime = torch.tanh(g_x_prime + g_m)m_new = f_t_prime * m_t + i_t_prime * g_t_primemem = torch.cat((c_new, m_new), 1)o_t = torch.sigmoid(o_x + o_h + self.conv_o(mem))h_new = o_t * torch.tanh(self.conv_last(mem))return h_new, c_new, m_newclass PredRNNv2(nn.Module):def __init__(self, configs):super(PredRNNv2, self).__init__()self.configs = configsself.frame_channel = configs.patch_size * configs.patch_size * configs.img_channelself.num_layers = len(configs.num_hidden)self.num_hidden = configs.num_hiddenself.device = configs.device# 构建网络cell_list = []height = configs.img_height // configs.patch_sizewidth = configs.img_width // configs.patch_sizefor i in range(self.num_layers):in_channel = self.frame_channel if i == 0 else self.num_hidden[i-1]cell_list.append(SpatioTemporalLSTMCell(in_channel, self.num_hidden[i], height, width,configs.filter_size, configs.stride, configs.layer_norm))self.cell_list = nn.ModuleList(cell_list)# 输出层self.conv_last = nn.Conv2d(self.num_hidden[self.num_layers-1], self.frame_channel,kernel_size=1, stride=1, padding=0, bias=False)def forward(self, frames_tensor, mask_true):#  frames_tensor: [batch, length, channel, height, width]batch = frames_tensor.shape[0]height = frames_tensor.shape[3]width = frames_tensor.shape[4]# 初始化隐藏状态和记忆状态next_frames = []h_t = []c_t = []m_t = []for i in range(self.num_layers):zeros = torch.zeros([batch, self.num_hidden[i], height, width]).to(self.device)h_t.append(zeros)c_t.append(zeros)m_t.append(zeros)# 记忆状态memory = torch.zeros([batch, self.num_hidden[0], height, width]).to(self.device)# 序列长度seq_length = self.configs.input_length + self.configs.total_lengthfor t in range(seq_length - 1):# 反向调度采样if self.configs.reverse_scheduled_sampling == 1:if t == 0:net = frames_tensor[:, t]else:# 从真实数据或预测数据中采样net = frames_tensor[:, t] if mask_true[:, t] == 1 else x_genelse:# 常规训练if t < self.configs.input_length:net = frames_tensor[:, t]else:# 从真实数据或预测数据中采样net = frames_tensor[:, t] if mask_true[:, t] == 1 else x_gen# 第一层h_t[0], c_t[0], m_t[0] = self.cell_list[0](net, h_t[0], c_t[0], m_t[0])# 后续层for i in range(1, self.num_layers):h_t[i], c_t[i], m_t[i] = self.cell_list[i](h_t[i-1], h_t[i], c_t[i], m_t[i])# 生成预测x_gen = self.conv_last(h_t[self.num_layers-1])next_frames.append(x_gen)# [length, batch, channel, height, width] -> [batch, length, channel, height, width]next_frames = torch.stack(next_frames, dim=0).permute(1, 0, 2, 3, 4)return next_frames

4. 数据加载器实现

接下来,我们需要实现数据加载器,将数据集转换为模型可用的格式:

def create_data_loaders(configs):"""创建训练、验证和测试数据加载器"""# 数据转换if configs.data_augmentation:train_transform = nn.Sequential(RandomRotate(),RandomFlip())else:train_transform = None# 创建数据集train_dataset = NPYDataset(data_root=configs.data_root,mode='train',input_frames=configs.input_length,output_frames=configs.total_length - configs.input_length,future_frames=configs.total_length - configs.input_length,transform=train_transform,preprocess=configs.preprocess_data)valid_dataset = NPYDataset(data_root=configs.data_root,mode='valid',input_frames=configs.input_length,output_frames=configs.total_length - configs.input_length,future_frames=configs.total_length - configs.input_length,transform=None,preprocess=configs.preprocess_data)test_dataset = NPYDataset(data_root=configs.data_root,mode='test',input_frames=configs.input_length,output_frames=configs.total_length - configs.input_length,future_frames=configs.total_length - configs.input_length,transform=None,preprocess=configs.preprocess_data)# 创建数据加载器train_loader = DataLoader(train_dataset,batch_size=configs.batch_size,shuffle=True,num_workers=configs.num_workers,pin_memory=True)valid_loader = DataLoader(valid_dataset,batch_size=configs.batch_size,shuffle=False,num_workers=configs.num_workers,pin_memory=True)test_loader = DataLoader(test_dataset,batch_size=configs.batch_size,shuffle=False,num_workers=configs.num_workers,pin_memory=True)return train_loader, valid_loader, test_loader

5. 模型训练流程

现在,我们实现完整的训练流程,包括损失函数、优化器和学习率调度器:

import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import time
from tqdm import tqdmclass Trainer:def __init__(self, configs, model, train_loader, valid_loader, test_loader):self.configs = configsself.model = modelself.train_loader = train_loaderself.valid_loader = valid_loaderself.test_loader = test_loaderself.device = configs.device# 损失函数self.criterion = nn.MSELoss()# 优化器self.optimizer = optim.Adam(model.parameters(),lr=configs.lr,weight_decay=configs.weight_decay)# 学习率调度器self.scheduler = ReduceLROnPlateau(self.optimizer,mode='min',factor=0.5,patience=5,verbose=True)# 记录训练历史self.train_losses = []self.valid_losses = []self.best_loss = float('inf')# 创建检查点目录os.makedirs(configs.save_dir, exist_ok=True)def train_epoch(self, epoch):"""训练一个epoch"""self.model.train()total_loss = 0progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')for batch_idx, (inputs, targets) in enumerate(progress_bar):inputs = inputs.to(self.device)targets = targets.to(self.device)# 前向传播self.optimizer.zero_grad()outputs = self.model(inputs, mask_true=None)# 计算损失loss = self.criterion(outputs, targets)# 反向传播loss.backward()self.optimizer.step()total_loss += loss.item()progress_bar.set_postfix({'loss': loss.item()})avg_loss = total_loss / len(self.train_loader)self.train_losses.append(avg_loss)return avg_lossdef validate(self):"""验证模型"""self.model.eval()total_loss = 0with torch.no_grad():for inputs, targets in self.valid_loader:inputs = inputs.to(self.device)targets = targets.to(self.device)outputs = self.model(inputs, mask_true=None)loss = self.criterion(outputs, targets)total_loss += loss.item()avg_loss = total_loss / len(self.valid_loader)self.valid_losses.append(avg_loss)return avg_lossdef test(self):"""测试模型"""self.model.eval()total_loss = 0all_outputs = []all_targets = []with torch.no_grad():for inputs, targets in self.test_loader:inputs = inputs.to(self.device)targets = targets.to(self.device)outputs = self.model(inputs, mask_true=None)loss = self.criterion(outputs, targets)total_loss += loss.item()# 保存结果用于后续分析all_outputs.append(outputs.cpu().numpy())all_targets.append(targets.cpu().numpy())avg_loss = total_loss / len(self.test_loader)return avg_loss, np.concatenate(all_outputs, axis=0), np.concatenate(all_targets, axis=0)def save_checkpoint(self, epoch, is_best=False):"""保存检查点"""checkpoint = {'epoch': epoch,'model_state_dict': self.model.state_dict(),'optimizer_state_dict': self.optimizer.state_dict(),'scheduler_state_dict': self.scheduler.state_dict(),'train_losses': self.train_losses,'valid_losses': self.valid_losses,'best_loss': self.best_loss}# 保存最新检查点torch.save(checkpoint, os.path.join(self.configs.save_dir, 'latest_checkpoint.pth'))# 如果是最佳模型,保存为最佳检查点if is_best:torch.save(checkpoint, os.path.join(self.configs.save_dir, 'best_checkpoint.pth'))def load_checkpoint(self, checkpoint_path):"""加载检查点"""checkpoint = torch.load(checkpoint_path)self.model.load_state_dict(checkpoint['model_state_dict'])self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])self.train_losses = checkpoint['train_losses']self.valid_losses = checkpoint['valid_losses']self.best_loss = checkpoint['best_loss']return checkpoint['epoch']def train(self, num_epochs):"""完整训练过程"""start_epoch = 0# 如果存在检查点,加载检查点if self.configs.resume and os.path.exists(os.path.join(self.configs.save_dir, 'latest_checkpoint.pth')):print("Loading checkpoint...")start_epoch = self.load_checkpoint(os.path.join(self.configs.save_dir, 'latest_checkpoint.pth'))print(f"Resumed from epoch {start_epoch}")for epoch in range(start_epoch, num_epochs):print(f"\nEpoch {epoch+1}/{num_epochs}")# 训练train_loss = self.train_epoch(epoch)print(f"Train Loss: {train_loss:.6f}")# 验证valid_loss = self.validate()print(f"Valid Loss: {valid_loss:.6f}")# 更新学习率self.scheduler.step(valid_loss)# 保存检查点is_best = valid_loss < self.best_lossif is_best:self.best_loss = valid_lossself.save_checkpoint(epoch, is_best)# 每5个epoch测试一次if (epoch + 1) % 5 == 0:test_loss, _, _ = self.test()print(f"Test Loss: {test_loss:.6f}")# 最终测试print("\nFinal Testing...")test_loss, outputs, targets = self.test()print(f"Final Test Loss: {test_loss:.6f}")return test_loss, outputs, targets

6. 预测与结果可视化

实现预测功能和结果可视化:

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGridclass Predictor:def __init__(self, configs, model):self.configs = configsself.model = modelself.device = configs.deviceself.model.eval()def predict(self, input_seq):"""预测未来帧"""with torch.no_grad():input_seq = input_seq.to(self.device)output_seq = self.model(input_seq, mask_true=None)return output_seq.cpu()def visualize_results(self, inputs, targets, predictions, save_path=None):"""可视化输入、目标和预测结果"""# 选择第一个批次进行可视化inputs = inputs[0].squeeze()  # [T, H, W]targets = targets[0].squeeze()  # [T, H, W]predictions = predictions[0].squeeze()  # [T, H, W]# 创建子图total_frames = inputs.shape[0] + targets.shape[0]fig = plt.figure(figsize=(20, 10))grid = ImageGrid(fig, 111, nrows_ncols=(3, total_frames), axes_pad=0.1)# 绘制输入帧for i in range(inputs.shape[0]):ax = grid[i]ax.imshow(inputs[i], cmap='viridis')ax.set_title(f'Input {i+1}')ax.axis('off')# 绘制目标帧for i in range(targets.shape[0]):ax = grid[inputs.shape[0] + i]ax.imshow(targets[i], cmap='viridis')ax.set_title(f'Target {i+1}')ax.axis('off')# 绘制预测帧for i in range(predictions.shape[0]):ax = grid[inputs.shape[0] + targets.shape[0] + i]ax.imshow(predictions[i], cmap='viridis')ax.set_title(f'Pred {i+1}')ax.axis('off')plt.tight_layout()if save_path:plt.savefig(save_path, dpi=300, bbox_inches='tight')plt.show()def save_predictions(self, predictions, save_dir):"""保存预测结果为NPY文件"""os.makedirs(save_dir, exist_ok=True)for i, pred_seq in enumerate(predictions):for j, frame in enumerate(pred_seq):frame_path = os.path.join(save_dir, f'batch_{i}_frame_{j}.npy')np.save(frame_path, frame.squeeze())def evaluate_metrics(self, targets, predictions):"""评估预测性能"""from sklearn.metrics import mean_squared_error, mean_absolute_error# 展平数据targets_flat = targets.flatten()predictions_flat = predictions.flatten()# 计算指标mse = mean_squared_error(targets_flat, predictions_flat)mae = mean_absolute_error(targets_flat, predictions_flat)rmse = np.sqrt(mse)# 计算PSNRmax_val = np.max(targets_flat)psnr = 20 * np.log10(max_val / rmse) if rmse > 0 else float('inf')# 计算SSIM (需要安装skimage)try:from skimage.metrics import structural_similarity as ssim_funcssim = ssim_func(targets_flat.reshape(targets.shape), predictions_flat.reshape(targets.shape),data_range=max_val)except ImportError:ssim = 0print("SSIM calculation requires skimage. Install with: pip install scikit-image")return {'MSE': mse,'MAE': mae,'RMSE': rmse,'PSNR': psnr,'SSIM': ssim}

7. 模型评估与优化

实现模型评估和超参数优化功能:

def hyperparameter_optimization(configs):"""超参数优化"""import nni# 获取NNI超参数optimized_params = nni.get_next_parameter()configs.lr = optimized_params.get('lr', configs.lr)configs.batch_size = optimized_params.get('batch_size', configs.batch_size)configs.num_hidden = optimized_params.get('num_hidden', configs.num_hidden)# 创建模型和数据加载器model = PredRNNv2(configs).to(configs.device)train_loader, valid_loader, test_loader = create_data_loaders(configs)# 训练模型trainer = Trainer(configs, model, train_loader, valid_loader, test_loader)test_loss, _, _ = trainer.train(configs.epoch)# 报告最终结果nni.report_final_result(test_loss)return test_lossdef analyze_results(configs, outputs, targets):"""分析预测结果"""predictor = Predictor(configs, None)metrics = predictor.evaluate_metrics(targets, outputs)print("Evaluation Metrics:")for metric, value in metrics.items():print(f"{metric}: {value:.4f}")# 绘制损失曲线plt.figure(figsize=(10, 6))plt.plot(range(len(outputs)), outputs.flatten(), label='Predictions', alpha=0.7)plt.plot(range(len(targets)), targets.flatten(), label='Targets', alpha=0.7)plt.xlabel('Sample Index')plt.ylabel('Value')plt.title('Predictions vs Targets')plt.legend()plt.grid(True)plt.savefig(os.path.join(configs.save_dir, 'predictions_vs_targets.png'), dpi=300)plt.show()return metrics

8. 完整代码实现

现在,我们将所有组件整合到一个完整的脚本中:

import argparse
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from models import PredRNNv2
from data_loader import NPYDataset, create_data_loaders
from trainer import Trainer
from predictor import Predictor
from utils import analyze_resultsdef parse_args():parser = argparse.ArgumentParser(description='PredRNNv2 for NPY dataset')# 数据参数parser.add_argument('--data_root', type=str, default='./dataset', help='数据集根目录')parser.add_argument('--input_length', type=int, default=10, help='输入帧数')parser.add_argument('--total_length', type=int, default=20, help='总帧数(输入+预测)')parser.add_argument('--img_width', type=int, default=500, help='图像宽度')parser.add_argument('--img_height', type=int, default=500, help='图像高度')parser.add_argument('--img_channel', type=int, default=1, help='图像通道数')parser.add_argument('--preprocess_data', type=bool, default=True, help='是否预处理数据')parser.add_argument('--data_augmentation', type=bool, default=True, help='是否使用数据增强')# 模型参数parser.add_argument('--num_hidden', type=list, default=[64, 64, 64, 64], help='每层隐藏单元数')parser.add_argument('--filter_size', type=int, default=5, help='滤波器大小')parser.add_argument('--stride', type=int, default=1, help='步长')parser.add_argument('--patch_size', type=int, default=1, help='补丁大小')parser.add_argument('--layer_norm', type=bool, default=True, help='是否使用层归一化')parser.add_argument('--reverse_scheduled_sampling', type=int, default=0, help='反向调度采样')# 训练参数parser.add_argument('--batch_size', type=int, default=4, help='批次大小')parser.add_argument('--lr', type=float, default=1e-3, help='学习率')parser.add_argument('--weight_decay', type=float, default=0, help='权重衰减')parser.add_argument('--epoch', type=int, default=100, help='训练轮数')parser.add_argument('--num_workers', type=int, default=4, help='数据加载工作线程数')parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help='设备')parser.add_argument('--save_dir', type=str, default='./checkpoints', help='保存目录')parser.add_argument('--resume', type=bool, default=False, help='是否恢复训练')# 其他参数parser.add_argument('--mode', type=str, default='train', choices=['train', 'test', 'predict'], help='运行模式')parser.add_argument('--checkpoint_path', type=str, default='', help='检查点路径')return parser.parse_args()def main():# 解析参数configs = parse_args()# 创建保存目录os.makedirs(configs.save_dir, exist_ok=True)# 创建模型model = PredRNNv2(configs).to(configs.device)print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")if configs.mode == 'train':# 创建数据加载器train_loader, valid_loader, test_loader = create_data_loaders(configs)# 创建训练器并开始训练trainer = Trainer(configs, model, train_loader, valid_loader, test_loader)test_loss, outputs, targets = trainer.train(configs.epoch)# 分析结果analyze_results(configs, outputs, targets)elif configs.mode == 'test':# 加载检查点if configs.checkpoint_path:checkpoint = torch.load(configs.checkpoint_path)model.load_state_dict(checkpoint['model_state_dict'])print(f"Loaded checkpoint from {configs.checkpoint_path}")# 创建数据加载器_, _, test_loader = create_data_loaders(configs)# 测试模型trainer = Trainer(configs, model, None, None, test_loader)test_loss, outputs, targets = trainer.test()print(f"Test Loss: {test_loss:.6f}")# 分析结果metrics = analyze_results(configs, outputs, targets)# 保存结果np.save(os.path.join(configs.save_dir, 'test_outputs.npy'), outputs)np.save(os.path.join(configs.save_dir, 'test_targets.npy'), targets)elif configs.mode == 'predict':# 加载检查点if configs.checkpoint_path:checkpoint = torch.load(configs.checkpoint_path)model.load_state_dict(checkpoint['model_state_dict'])print(f"Loaded checkpoint from {configs.checkpoint_path}")# 创建预测器predictor = Predictor(configs, model)# 加载要预测的数据# 这里假设有一个单独的预测数据集predict_dataset = NPYDataset(data_root=configs.data_root,mode='predict',input_frames=configs.input_length,output_frames=configs.total_length - configs.input_length,future_frames=configs.total_length - configs.input_length,transform=None,preprocess=configs.preprocess_data)predict_loader = DataLoader(predict_dataset,batch_size=configs.batch_size,shuffle=False,num_workers=configs.num_workers,pin_memory=True)all_predictions = []all_inputs = []with torch.no_grad():for inputs, _ in predict_loader:inputs = inputs.to(configs.device)predictions = predictor.predict(inputs)all_predictions.append(predictions.numpy())all_inputs.append(inputs.cpu().numpy())all_predictions = np.concatenate(all_predictions, axis=0)all_inputs = np.concatenate(all_inputs, axis=0)# 保存预测结果output_dir = os.path.join(configs.save_dir, 'predictions')os.makedirs(output_dir, exist_ok=True)for i, (input_seq, pred_seq) in enumerate(zip(all_inputs, all_predictions)):# 保存输入序列for j, frame in enumerate(input_seq):frame_path = os.path.join(output_dir, f'sequence_{i:03d}_input_{j:03d}.npy')np.save(frame_path, frame.squeeze())# 保存预测序列for j, frame in enumerate(pred_seq):frame_path = os.path.join(output_dir, f'sequence_{i:03d}_pred_{j:03d}.npy')np.save(frame_path, frame.squeeze())print(f"Predictions saved to {output_dir}")# 可视化一些结果if len(all_inputs) > 0:sample_idx = 0predictor.visualize_results(all_inputs[sample_idx:sample_idx+1],all_predictions[sample_idx:sample_idx+1],all_predictions[sample_idx:sample_idx+1],save_path=os.path.join(output_dir, 'sample_prediction.png'))if __name__ == '__main__':main()

9. 常见问题与解决方案

9.1 内存不足问题

当处理 500×500 的大尺寸图像时,可能会遇到内存不足的问题。解决方案:

  1. 使用数据分块:将大图像分割成小块进行处理
  2. 降低批次大小:减少每次处理的样本数量
  3. 使用混合精度训练:使用半精度浮点数减少内存占用
# 混合精度训练示例
from torch.cuda.amp import autocast, GradScalerdef train_epoch_with_amp(self, epoch):"""使用混合精度训练一个epoch"""self.model.train()total_loss = 0scaler = GradScaler()progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')for batch_idx, (inputs, targets) in enumerate(progress_bar):inputs = inputs.to(self.device)targets = targets.to(self.device)# 使用自动混合精度with autocast():outputs = self.model(inputs, mask_true=None)loss = self.criterion(outputs, targets)# 缩放损失并反向传播self.optimizer.zero_grad()scaler.scale(loss).backward()scaler.step(self.optimizer)scaler.update()total_loss += loss.item()progress_bar.set_postfix({'loss': loss.item()})avg_loss = total_loss / len(self.train_loader)self.train_losses.append(avg_loss)return avg_loss

9.2 训练不稳定问题

PredRNNv2 模型训练可能会不稳定,可以尝试以下方法:

  1. 梯度裁剪:防止梯度爆炸
  2. 学习率调度:动态调整学习率
  3. 权重初始化:使用合适的初始化方法
# 梯度裁剪示例
def train_epoch_with_gradient_clipping(self, epoch, clip_value=1.0):"""带梯度裁剪的训练"""self.model.train()total_loss = 0progress_bar = tqdm(self.train_loader, desc=f'Epoch {epoch}')for batch_idx, (inputs, targets) in enumerate(progress_bar):inputs = inputs.to(self.device)targets = targets.to(self.device)self.optimizer.zero_grad()outputs = self.model(inputs, mask_true=None)loss = self.criterion(outputs, targets)loss.backward()# 梯度裁剪torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_value)self.optimizer.step()total_loss += loss.item()progress_bar.set_postfix({'loss': loss.item()})avg_loss = total_loss / len(self.train_loader)self.train_losses.append(avg_loss)return avg_loss

9.3 过拟合问题

当模型在训练集上表现良好但在验证集上表现不佳时,可能存在过拟合问题:

  1. 数据增强:增加数据多样性
  2. 正则化:使用 Dropout 或权重衰减
  3. 早停:在验证损失不再改善时停止训练
# 早停实现
class EarlyStopping:def __init__(self, patience=10, min_delta=0):self.patience = patienceself.min_delta = min_deltaself.counter = 0self.best_loss = Noneself.early_stop = Falsedef __call__(self, val_loss):if self.best_loss is None:self.best_loss = val_losselif val_loss > self.best_loss - self.min_delta:self.counter += 1if self.counter >= self.patience:self.early_stop = Trueelse:self.best_loss = val_lossself.counter = 0return self.early_stop# 在训练循环中使用早停
early_stopping = EarlyStopping(patience=10)for epoch in range(num_epochs):# 训练和验证...if early_stopping(valid_loss):print("Early stopping triggered")break

10. 总结与展望

本文详细介绍了如何复现 OpenSTL 中的 PredRNNv2 模型,并使用自定义的 NPY 格式数据集进行训练和预测。我们涵盖了从环境配置、数据预处理、模型构建到训练和评估的完整流程。

10.1 主要成果

  1. 完整的数据处理流程:实现了针对 NPY 格式数据的加载、预处理和增强功能
  2. PredRNNv2 模型复现:成功实现了 PredRNNv2 模型的核心组件和完整架构
  3. 训练框架:构建了完整的训练、验证和测试流程,包括损失函数、优化器和学习率调度
  4. 预测与可视化:实现了预测功能和结果可视化,便于分析模型性能
  5. 问题解决方案:提供了针对常见问题(内存不足、训练不稳定、过拟合)的解决方案

10.2 未来工作方向

  1. 模型优化:尝试更先进的视频预测模型,如 SimVP、PhyDNet 等
  2. 多模态融合:结合其他传感器数据(如气象数据、地理信息)提高预测精度
  3. 实时预测:优化模型推理速度,实现实时预测功能
  4. 不确定性量化:增加对预测结果不确定性的估计
  5. 部署优化:将模型部署到生产环境,支持大规模数据处理

通过本文的指导和代码实现,读者应该能够成功复现 PredRNNv2 模型,并在自己的数据集上进行训练和预测。希望这项工作能够为视频预测任务的研究和应用提供有价值的参考。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/99239.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/99239.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux内核IPv4隧道模式封装机制剖析

概述 在Linux网络栈中,XFRM(Transform)子系统负责实现IPsec等安全协议的功能。其中,xfrm4_mode_tunnel.c是实现IPv4隧道模式封装的核心模块,为IPv4数据包提供隧道模式的封装和解封装能力。本文将深入分析这一模块的实现机制。 模块架构与功能 该模块通过注册到XFRM框架…

OPC Client第10讲:实现主界面;获取初始界面传来的所有配置信息config【C++读写Excel:xlnx;ODBC;缓冲区】

接前面代码内容&#xff1a; OPC Client第6讲&#xff08;wxwidgets&#xff09;&#xff1a;Logger.h日志记录文件&#xff08;单例模式&#xff09;&#xff1b;登录后的主界面_wx.logger-CSDN博客 OPC Client第8讲&#xff1a;OPC UA&#xff1b;KEPServerEX创建OPC服务器…

快速入门HarmonyOS应用开发(一)

目录 前言 一、准备工作 二、实战开发 2.1、Navigation简介 2.2、页面路由开发 2.2.1、创建常量 2.2.2、创建字符串资源 2.2.3、创建float资源 2.2.4、创建color资源 2.2.5、创建数据实体 2.2.6、创建页面路由表 2.2.7、创建Navigation根容器 2.2.8、创建NavDesti…

AI 进课堂 - 语文教学流程重塑

AI 进课堂 - 语文教学流程重塑执教语文十余年&#xff0c;备课案头的参考书堆得比学生作业本还高&#xff0c;批改作文时红笔芯换得比粉笔还勤。 直到去年把 JBoltAI 请进课堂&#xff0c;那些重复机械的工作突然有了新解法&#xff0c;连课堂上孩子们的眼神都亮了许多 —— 这…

用户是否可以同时使用快照和备份来保护云服务器数据安全?

在云计算环境中&#xff0c;云服务器已成为企业和个人数据存储、应用部署和业务运营的重要平台。随着业务数据量的不断增长&#xff0c;数据安全和业务连续性成为用户关注的核心问题。云服务器提供的快照和备份功能为用户提供了有效的数据保护手段&#xff0c;但很多人会疑问&a…

RDS-MYSQL,这个RDS是什么?和mysql有什么区别?

好的&#xff0c;这是一个非常常见且重要的问题。我用最通俗易懂的方式给你解释清楚。 一、大白话解释 你可以把 MySQL 和 RDS MySQL 的关系&#xff0c;想象成&#xff1a;MySQL&#xff1a;就像是你自己买零件组装的一台电脑。 你需要自己挑选CPU、内存、硬盘、主板&#xff…

arcgis中实现四色/五色法制图

四色定理是图论中的一个著名定理&#xff0c;它指出在任何地图上&#xff0c;只需四种颜色就足以使任何相邻的区域&#xff08;拥有共同边界线段&#xff0c;而非单个点&#xff09;颜色不同。五色定理则是另一个更早被证明的、较弱但更易证的定理。在地图制图中&#xff0c;这…

Spring如何巧妙解决循环依赖问题

什么是循环依赖&#xff1f;循环依赖是指两个或多个Bean之间相互依赖&#xff0c;形成闭环的情况。例如&#xff1a;AService依赖BService&#xff0c;而BService又依赖AService。这种场景下&#xff0c;传统的创建顺序无法满足依赖注入的要求。Spring的三级缓存机制Spring通过…

CUDA 中Thrust exclusive_scan使用详解

1. 基本概念Thrust 是 NVIDIA CUDA 提供的类似 C STL 的并行算法库。Scan (前缀和)&#xff1a;给定数组 [a0, a1, a2, ...]&#xff0c;产生前缀和序列。Exclusive Scan (排他前缀和)&#xff1a; 输出位置 i 存放的是输入数组中 0 到 i-1 的累积结果。换句话说&#xff0c;结…

Linux -- 信号【上】

目录 一、信号的引入 1、信号概念 2、signal函数 普通标准信号详解表 3、前台/后台进程 3.1 概念 3.2 查看后台进程 3.3 后台进程拉回前台 3.4 终止后台进程 3.5 暂停前台进程 3.6 回复运行后台进程 4、发信号的本质 二、信号的产生 1、终端按键 2、系统调用 2…

Altium Designer(AD)自定义PCB外观颜色

目录 1视图设置界面介绍 2PCB阻焊层颜色设置 2.1进入视图设置界面 2.2阻焊层颜色设置 2.3顶层和底层阻焊层颜色设置 2.4顶层阻焊层试图效果 2.5底层阻焊层试图效果 3设置PCB丝印颜色设置 3.1找到丝印设置选项 3.2设置顶层和底层丝印颜色 3.3顶层丝印 3.4底层丝印 4…

5天改造,节能50%!冷能改造如何实现“不停产节能”?

你有没有发现一个现象&#xff1f;很多工厂老板一提到节能改造&#xff0c;第一反应就是摇头。不是不想省电费&#xff0c;而是怕停产。停产一天损失几十万&#xff0c;改造周期动辄几个月&#xff0c;这账怎么算都不划算。但如果我告诉你&#xff0c;有一种改造方式&#xff0…

【Flink】窗口

目录窗口窗口的概念窗口的分类滚动窗口&#xff08;Tumbling Windows&#xff09;滑动窗口&#xff08;Sliding Windows&#xff09;会话窗口&#xff08;Session Windows&#xff09;全局窗口&#xff08;Global Windows&#xff09;窗口API概览窗口函数增量聚合函数ReduceFun…

攻击路径(4):API安全风险导致敏感数据泄漏

本文是《攻防演练 | JS泄露到主机失陷[1]》的学习笔记&#xff0c;欢迎大家阅读原文。攻击路径通过未授权访问攻击获取敏感数据通过SQL注入攻击获取服务器权限通过凭据访问攻击获取数据库权限和敏感数据和应用权限安全风险与加固措施通过未授权访问攻击获取敏感数据、通过SQL注…

机器学习面试题:请介绍一下你理解的集成学习算法

集成学习&#xff08;Ensemble Learning&#xff09;的核心思想是“集思广益”&#xff0c;它通过构建并结合多个基学习器&#xff08;Base Learner&#xff09;来完成学习任务&#xff0c;从而获得比单一学习器更显著优越的泛化性能。俗话说&#xff0c;“三个臭皮匠&#xff…

Invalid bound statement (not found): com.XXX.XXx.service.xxx无法执行service

org.apache.ibatis.binding.BindingException: Invalid bound statement (not found): com.xxx.xxx.service.CitytownService.selectCitytown 出现无法加载sevice层的时候&#xff0c;如下图所示1&#xff0c;处理方法是&#xff0c;先看下注解MapperScan内的包地址&#xff0c…

泛型(Generics)what why when【前端TS】

我总是提醒自己一定要严谨严谨严谨 目录TypeScript 泛型 (Generics)1. 什么是泛型&#xff1f;2. 为什么需要泛型&#xff1f;3. 泛型常见用法3.1 函数泛型3.2 接口泛型3.3 类泛型3.4 泛型约束3.5 泛型默认值3.6 多个泛型参数4. 泛型应用场景TypeScript 泛型 (Generics) 1. 什…

分布式协议与算法实战-协议和算法篇

05丨Paxos算法&#xff08;一&#xff09;&#xff1a;如何在多个节点间确定某变量的值? 提到分布式算法&#xff0c;就不得不提 Paxos 算法&#xff0c;在过去几十年里&#xff0c;它基本上是分布式共识的代名词&#xff0c;因为当前最常用的一批共识算法都是基于它改进的。比…

9.13 9.15 JavaWeb(事务管理、AOP P172-P182)

事务管理事务概念事务是一组操作的集合&#xff0c;是一个不可分割的工作单位&#xff0c;这些操作要么同时成功&#xff0c;要么同时失败操作开启事务&#xff08;一组操作开始前&#xff0c;开启事务&#xff09;&#xff1a;start transaction / begin提交事务&#xff08;这…

检索融合方法- Distribution-Based Score Fusion (DBSF)

在信息检索&#xff08;IR&#xff09;、推荐系统和多模态检索中&#xff0c;我们常常需要融合来自多个检索器或模型的结果。不同检索器可能对同一文档打出的分数差异很大&#xff0c;如果直接简单加权&#xff0c;很容易出现某个检索器“主导融合结果”的情况。 Distribution…