【ConvLSTM第二期】模拟视频帧的时序建模(Python代码实现)

目录

  • 1 准备工作:python库包安装
    • 1.1 安装必要库
  • 案例说明:模拟视频帧的时序建模
    • ConvLSTM概述
    • 损失函数说明
    • (python全代码)
  • 参考

ConvLSTM的原理说明可参见另一博客-【ConvLSTM第一期】ConvLSTM原理。

1 准备工作:python库包安装

1.1 安装必要库

pip install torch torchvision matplotlib numpy

案例说明:模拟视频帧的时序建模

🎯 目标:给定一个人工生成的动态图像序列(例如移动的方块),使用 ConvLSTM 对其进行建模,输出预测结果,并查看输出的维度和特征变化。

ConvLSTM概述

ConvLSTM 的基本结构,包括:

  • ConvLSTMCell:实现了一个时间步的 ConvLSTM 单元,类似于一个“时刻”的神经元。
  • ConvLSTM:实现了多层ConvLSTM结构,能够处理一整个时间序列的视频帧数据。

损失函数说明

MSE(均方误差) 衡量预测值和真实值之间的平均平方差。
在这里插入图片描述

关于训练终止条件:
可以根据 MSE是否达到某个阈值(如 < 0.001)提前终止训练,这是所谓的 “Early Stopping(提前停止)策略”。

(python全代码)

MSE损失函数曲线如下:可知MSE一直在下降,虽然存在振荡
在这里插入图片描述

前9帧图像及预测的第十帧图像得到的动图如下:
在这里插入图片描述

python完整代码如下:

import os
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image# 设置字体
plt.rcParams['font.family'] = 'Times New Roman'# 创建保存图像目录
os.makedirs("./Figures", exist_ok=True)# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# ====================================
# 一、ConvLSTM 模型结构
# ====================================class ConvLSTMCell(nn.Module):def __init__(self, input_channels, hidden_channels, kernel_size, bias=True):super(ConvLSTMCell, self).__init__()padding = kernel_size // 2self.input_channels = input_channelsself.hidden_channels = hidden_channelsself.conv = nn.Conv2d(input_channels + hidden_channels, 4 * hidden_channels, kernel_size, padding=padding, bias=bias)def forward(self, x, h_prev, c_prev):combined = torch.cat([x, h_prev], dim=1)conv_output = self.conv(combined)cc_i, cc_f, cc_o, cc_g = torch.chunk(conv_output, 4, dim=1)i = torch.sigmoid(cc_i)f = torch.sigmoid(cc_f)o = torch.sigmoid(cc_o)g = torch.tanh(cc_g)c = f * c_prev + i * gh = o * torch.tanh(c)return h, cclass ConvLSTM(nn.Module):def __init__(self, input_channels, hidden_channels, kernel_size, num_layers):super(ConvLSTM, self).__init__()self.num_layers = num_layerslayers = []for i in range(num_layers):in_channels = input_channels if i == 0 else hidden_channelslayers.append(ConvLSTMCell(in_channels, hidden_channels, kernel_size))self.layers = nn.ModuleList(layers)def forward(self, input_seq):b, t, c, h, w = input_seq.size()h_t = [torch.zeros(b, layer.hidden_channels, h, w).to(input_seq.device) for layer in self.layers]c_t = [torch.zeros(b, layer.hidden_channels, h, w).to(input_seq.device) for layer in self.layers]for time in range(t):x = input_seq[:, time]for i, layer in enumerate(self.layers):h_t[i], c_t[i] = layer(x, h_t[i], c_t[i])x = h_t[i]return h_t[-1]  # 返回最后一层最后一帧的隐藏状态# ====================================
# 二、生成移动方块序列数据
# ====================================def generate_moving_square_sequence(batch_size, time_steps, height, width):data = torch.zeros((batch_size, time_steps, 1, height, width))for b in range(batch_size):dx = np.random.randint(1, 3)dy = np.random.randint(1, 3)x = np.random.randint(0, width - 6)y = np.random.randint(0, height - 6)for t in range(time_steps):data[b, t, 0, y:y+5, x:x+5] = 1.0x = (x + dx) % (width - 5)y = (y + dy) % (height - 5)return data# ====================================
# 三、模型、损失、优化器
# ====================================class ConvLSTM_Predictor(nn.Module):def __init__(self):super().__init__()self.convlstm = ConvLSTM(input_channels=1, hidden_channels=16, kernel_size=3, num_layers=1)self.decoder = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1)def forward(self, input_seq):hidden = self.convlstm(input_seq)pred = self.decoder(hidden)return predmodel = ConvLSTM_Predictor().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)# ====================================
# 四、训练过程
# ====================================mse_list = []
max_epochs = 100
mse_threshold = 0.001
height, width = 64, 64for epoch in range(max_epochs):model.train()seq = generate_moving_square_sequence(8, 10, height, width).to(device)input_seq = seq[:, :9]target_frame = seq[:, 9, 0].unsqueeze(1)optimizer.zero_grad()output = model(input_seq)loss = criterion(output, target_frame)loss.backward()optimizer.step()mse = loss.item()mse_list.append(mse)print(f"Epoch {epoch+1}/{max_epochs}, MSE: {mse:.6f}")# 提前停止条件if mse < mse_threshold:print(f"✅ 提前停止:MSE 已达到阈值 {mse_threshold}")break# ====================================
# 五、测试与可视化结果
# ====================================model.eval()
with torch.no_grad():test_seq = generate_moving_square_sequence(1, 10, height, width).to(device)input_seq = test_seq[:, :9]true_frame = test_seq[:, 9, 0]pred_frame = model(input_seq)[0, 0].cpu().numpy()# 保存输入帧
for t in range(9):frame = input_seq[0, t, 0].cpu().numpy()plt.imshow(frame, cmap='gray')plt.title(f"Input Frame t={t}")plt.colorbar()plt.savefig(f"./Figures/input_frame_{t}.png")plt.close()# 保存 Ground Truth
plt.imshow(true_frame[0].cpu().numpy(), cmap='gray')
plt.title("Ground Truth Frame t=9")
plt.colorbar()
plt.savefig("./Figures/ground_truth_t9.png")
plt.close()# 保存预测帧
plt.imshow(pred_frame, cmap='gray')
plt.title("Predicted Frame t=9")
plt.colorbar()
plt.savefig("./Figures/predicted_t9.png")
plt.close()# 保存 MSE 曲线图
plt.plot(mse_list)
plt.title("Training MSE Loss")
plt.xlabel("Epoch")
plt.ylabel("MSE")
plt.grid(True)
plt.savefig("./Figures/mse_curve.png")
plt.close()# ---------------- 生成动图 ----------------frames = []# 添加前9帧输入
for t in range(9):img = Image.open(f"./Figures/input_frame_{t}.png")frames.append(img.copy())# 添加预测帧
img = Image.open("./Figures/predicted_t9.png")
frames.append(img.copy())# 保存动图
frames[0].save("./Figures/sequence.gif", save_all=True, append_images=frames[1:], duration=500, loop=0)
print("✅ 所有图像和动图已保存至 ./Figures 文件夹")

参考

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/diannao/85624.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MySQL DDL操作全解析:从入门到精通,包含索引视图分区表等全操作解析

目录 一、DDL 基础概述 1.1 DDL 定义与作用 1.2 DDL 语句分类 1.3 数据类型与存储引擎 1.3.1 数据类型 1.3.2 存储引擎差异 二、基础 DDL 语句详解 2.1 创建数据库与表 2.1.1 创建数据库 2.1.2 创建表 2.2 修改表结构 2.2.1 添加列 2.2.2 修改列属性 2.2.3 删除列…

设计模式——抽象工厂设计模式(创建型)

摘要 抽象工厂设计模式是一种创建型设计模式&#xff0c;旨在提供一个接口&#xff0c;用于创建一系列相关或依赖的对象&#xff0c;无需指定具体类。它通过抽象工厂、具体工厂、抽象产品和具体产品等组件构建&#xff0c;相比工厂方法模式&#xff0c;能创建一个产品族。该模…

Express教程【006】:使用Express写接口

文章目录 8、使用Express写接口8.1 创建API路由模块8.2 编写GET接口8.3 编写POST接口 8、使用Express写接口 8.1 创建API路由模块 1️⃣新建routes/apiRouter.js路由模块&#xff1a; /*** 路由模块*/ // 1-导入express const express require(express); // 2-创建路由对象…

【iOS(swift)笔记-14】App版本不升级时本地数据库sqlite更新逻辑二

App版本不升级时&#xff0c;又想即时更新本地数据库怎么办&#xff1f; 办法二&#xff1a;从服务器下载最新的sqlite数据替换掉本地的数据&#xff08;注意是数据不是文件&#xff09; 稍加调整&#xff0c; // &#xff01;&#xff01;&#xff01;注意&#xff01;&…

Mac电脑_钥匙串操作选项变灰的情况下如何删除?

Mac电脑_钥匙串操作选项变灰的情况下如何删除&#xff1f; 这时候 可以使用相关的终端命令进行操作。 下面附加文章《Mac电脑_钥匙串操作的终端命令》。 《Mac电脑_钥匙串操作的终端命令》 &#xff08;来源&#xff1a;百度~百度AI 发布时间&#xff1a;2025-06&#xff09;…

对接系统外部服务组件技术方案

概述 当前系统需与多个外部系统对接,然而外部系统稳定性存在不确定性。对接过程中若出现异常,需依靠双方的日志信息来定位问题,但若日志信息不够完整,会极大降低问题定位效率。此外,问题发生后,很大程度上依赖第三方的重试机制,若第三方缺乏完善的重试机制,就需要手动…

WAF绕过,网络层面后门分析,Windows/linux/数据库提权实验

一、WAF绕过文件上传漏洞 win7&#xff1a;10.0.0.168 思路&#xff1a;要想要绕过WAF&#xff0c;第一步是要根据上传的内容找出来被拦截的原因。对于文件上传有三个可以考虑的点&#xff1a;文件后缀名&#xff0c;文件内容&#xff0c;文件类型。 第二步是根据找出来的拦截原…

一文学会c++中的内存管理知识点

文章目录 c/c内存管理c语言动态内存管理c动态内存管理new/delete自定义类型妙用operator new和operator delete malloc/new&#xff0c;free/delete区别 c/c内存管理 int globalVar 1;static int staticGlobalVar 1;void Test(){static int staticVar 1;int localVar 1;in…

深入解析Linux死锁:原理、原因及解决方案

Linux死锁是系统资源管理的致命陷阱&#xff0c;平均每年导致全球数据中心约​​3.7亿小时​​的服务中断。本文深度剖析死锁形成的​​四个必要条件​​和六种典型死锁场景&#xff0c;结合Linux内核源码层级的资源管理机制&#xff0c;揭示文件系统锁、内存分配、多线程同步等…

SKUA-GOCAD入门教程-第八节 线的创建与编辑2

8.1.3根据线创建曲线 (1)从线生成线 这个命令可以将一组曲线合并为一条曲线。每个输入曲线都会成为新曲线内的一个部分。 1、选择 Curve commands > New > Curves 打开对话框。 图1 根据曲线创建曲线 在“name”框中

『uniapp』把接口的内容下载为txt本地保存 / 读取本地保存的txt文件内容(详细图文注释)

目录 预览效果思路分析downloadTxt 方法readTxt 方法 完整代码总结 欢迎关注 『uniapp』 专栏&#xff0c;持续更新中 欢迎关注 『uniapp』 专栏&#xff0c;持续更新中 预览效果 思路分析 downloadTxt 方法 该方法主要完成两个任务&#xff1a; 下载 txt 文件&#xff1a;通…

攻防世界-unseping

进入环境 在获得的场景中发现PHP代码并进行分析 编写PHP编码 得到 Tzo0OiJlYXNlIjoyOntzOjEyOiIAZWFzZQBtZXRob2QiO3M6NDoicGluZyI7czoxMDoiAGVhc2UAYXJncyI7YToxOntpOjA7czozOiJwd2QiO319 将其传入 想执行ls&#xff0c;但是发现被过滤掉了 使用环境变量进行绕过 $a new…

IP查询与网络风险的关系

网络风险场景与IP查询的关联 网络攻击、恶意行为、数据泄露等风险事件频发&#xff0c;而IP地址作为网络设备的唯一标识&#xff0c;承载着关键线索。例如&#xff0c;在DDoS恶意行为中&#xff0c;攻击者利用大量IP地址发起流量洪泛&#xff1b;恶意行为通过变换IP地址绕过封…

pikachu通关教程-XSS

XSS XSS漏洞原理 XSS被称为跨站脚本攻击&#xff08;Cross Site Scripting&#xff09;&#xff0c;由于和层叠样式表&#xff08;Cascading Style Sheets&#xff0c;CSS&#xff09;重名&#xff0c;改为XSS。主要基于JavaScript语言进行恶意攻击&#xff0c;因为js非常灵活…

【时时三省】(C语言基础)数组作为函数参数

山不在高&#xff0c;有仙则名。水不在深&#xff0c;有龙则灵。 ----CSDN 时时三省 调用有参函数时&#xff0c;需要提供实参。例如sin ( x )&#xff0c;sqrt ( 2&#xff0c;0 )&#xff0c;max ( a&#xff0c;b )等。实参可以是常量、变量或表达式。数组元素的作用与变量…

硬件工程师笔记——555定时器应用Multisim电路仿真实验汇总

目录 一 555定时器基础知识 二、引脚功能 三、工作模式 1. 单稳态模式&#xff1a; 2. 双稳态模式&#xff08;需要外部电路辅助&#xff09;&#xff1a; 3. 无稳态模式&#xff08;多谐振荡器&#xff09;&#xff1a; 4. 可控脉冲宽度调制&#xff08;PWM&#xff09…

C++11特性:enum class(强枚举类型)详解

C11引入的 enum class&#xff08;强枚举类型&#xff09;解决了传统枚举的多个问题&#xff1a; 防止枚举值泄漏到外部作用域&#xff1b;禁止不同枚举间的隐式转换&#xff1b;允许指定底层数据类型优化内存&#xff1b;避免命名空间污染。 其基本语法为 enum class Name{.…

【QT】QString 与QString区别

在C中&#xff0c;QString 和 QString& 有本质区别&#xff0c;尤其是在参数传递和内存管理方面&#xff1a; 1. QString&#xff08;按值传递&#xff09; 创建副本&#xff1a;传递时会创建完整的字符串副本内存开销&#xff1a;可能涉及深拷贝&#xff08;特别是大字符…

提升四级阅读速度方法

以下是针对四级英语阅读速度提升的系统性解决方案&#xff0c;结合最新考试规律和高效训练方法&#xff0c;分五个核心模块整理&#xff1a; &#x1f680; ​​一、基础提速训练&#xff08;消除生理障碍&#xff09;​​ ​​扩大视幅范围​​ 从逐词阅读升级为 ​​意群阅读…

6.4 note

构造矩阵 class Solution { private: vector<int> empty {}; // 返回每个数字(-1)所在的序号&#xff0c;可以是行或列, 如果为空则无效 vector<int> topoSort(int k, vector<vector<int>>& conditions) { // 构建一个图…