STORM代码阅读笔记

默认的 分辨率是 [160,240] ,基于 Transformer 的方法不能做高分辨率。

Dataloader

输入是 带有 pose 信息的 RGB 图像

eval datasets

## 采样帧数目 = 20
num_max_future_frames = int(self.timespan * fps) 
## 每次间隔多少个时间 timesteps 取一个context image
num_context_timesteps  = 4

按照STORM 原来的 setting, future_frames = 20 context_image 每次间隔4帧,所以是 context_frame_idx = [0,5,10,15], 在 target_frame 包含了 从[0,20]的所有20帧。

以这样20 帧的 image 作为一个基本的 batch, 进行预测: 进入 model

所以,输入网络的 context_image 对应的 shape (1,4,3,160,240) 输入4个时刻帧的 frame, 每一个 frame 有 3个相机;对应的 context_camtoworlds shape (1,4,3,4,4)

train datasets

第一帧 ID 随机采样, 之后的 context_image 每次间隔 5 帧,比如: [47 52 57 62]target·_frame_id 也是进行随机选取:

  if self.equispaced:context_frame_idx = np.arange(context_frame_idx,context_frame_idx + num_max_future_frames,num_max_future_frames // self.num_context_timesteps,)

随机在 num_future_id 里面 选择 self.num_target_timesteps 选择 4帧作为 target_image 的监督帧

Network

输入网络的 有3个 input: context_image, ray 和 time 的信息

  • context_image: (1,4,3,3,160,240)
  • Ray embedding (1,4,3,6,160,240)
  • time_embedding (1,4,3)
  • 将 image 和 ray_embedding 进行 concat 操作, 得到 x:(12,9,160,240)
 x = rearrange(x, "b t v c h w -> (b t v) c h w")
plucker_embeds = rearrange(plucker_embeds, "b t v h w c-> (b t v) c h w")
x = torch.cat([x, plucker_embeds], dim=1) ## (12,9,160,240)

然后经过3个 embedding , 将这些 feature 映射成为 token:

x = self.patch_embed(x)  # (b t v) h w c2x = self._pos_embed(x)  # (b t v) (h w) c2x = self._time_embed(x, time, num_views=v)

得到 x.shape (12,600,768), 表示一共有12张图像,每个图象 是 600 个 token, 每个 token 的 channel 是768. 然后将这些 token concat 在一起 得到了 (7200,768) 的 feature;

给得到的 token 分别加上可学习的 motion_token, affine_token 和 sky_token. 连接方式都是 concat
这样得到的 feature 为 (7220,768)的 feature

if self.num_motion_tokens > 0:motion_tokens = repeat(self.motion_tokens, "1 k d -> b k d", b=x.shape[0])x = torch.cat([motion_tokens, x], dim=-2)
if self.use_affine_token:affine_token = repeat(self.affine_token, "1 k d -> b k d", b=b)x = torch.cat([affine_token, x], dim=-2)
if self.use_sky_token:sky_token = repeat(self.sky_token, "1 1 d -> b 1 d", b=x.shape[0])x = torch.cat([sky_token, x], dim=-2)
  • 使用 Transformer 进行学习, 得到的 feature 维度不变。:
 x = self.transformer(x)x = self.norm(x) ## shape(7220,768)

运行完之后,可以将学习到的 token提取出来:

if self.use_sky_token:sky_token = x[:, :1] ## (1,1,768)x = x[:, 1:]if self.use_affine_token:affine_tokens = x[:, : self.num_cams] ## (1,3,768)x = x[:, self.num_cams :]if self.num_motion_tokens > 0:motion_tokens = x[:, : self.num_motion_tokens]  ## (1,16,768)x = x[:, self.num_motion_tokens :]

在 Transformer 内部,没有上采样层,也可以实现 这种 per-pixel feature 的学习。
对于 x 进行 GS 的预测,得到 pixel_align 的高斯。 对于每个 patch, 得到的 feature 是 (12,600,768), 通过一个CNN,虽然通道数没有变 (12,600,768), 但是之前 768 可以理解为全局的 语义, 之后的 768 为 一个patch 内部不同像素的语义,他们共享着 全局的 语义信息,但是每个pixel 却又不一样。 通过下面的 unpatchify 函数将将一个patch 的语义拆成 per-pixel 的语义,将每个768维token展开为8×8像素。

 b, t, v, h, w, _ = origins.shape## x_shape: (12,600,768)x = rearrange(x, "b (t v hw) c -> (b t v) hw c", t=t, v=v)## gs_params_shape: (12,600,768),这一步虽然通道没变,但其实是将一个 token 的全局 语义,映射成## token 内部的像素级别的语义gs_params = self.gs_pred(x)## gs_params_shape: (12,12,160,240)### 关键步骤:unpatchify将每个768维token展开为8×8像素gs_params = self.unpatchify(gs_params, hw=(h, w), patch_size=self.unpatch_size)

根据 token展开的 per-pixel feature, 进行3DGS 的属性预测

gs_params = rearrange(gs_params, "(b t v) c h w -> b t v h w c", t=t, v=v)
depth, scales, quats, opacitys, colors = gs_params.split([1, 3, 4, 1, self.gs_dim], dim=-1)
scales = self.scale_act_fn(scales)
opacitys = self.opacity_act_fn(opacitys)
depths = self.depth_act_fn(depth)
colors = self.rgb_act_fn(colors)
means = origins + directions * depths

除了3DGS 的一半属性之外, storm 还额外预测了其他的运动属性,包括:

其中: x: (1,7200,768) 代表 image_token, motion_tokens 是(1,16,768)代表 motion_token. 处理的大致思路是 motion_token 作为 query, 然后 image_token 映射的feature 作为 key, 去结合计算每一个 高斯的 moition_weightsmoition_bases

gs_params = self.forward_motion_predictor(x, motion_tokens, gs_params)
其中:
forward_flow = torch.einsum("b t v h w k, b k c -> b t v h w c", motion_weights, motion_bases)

moition_bases: shape: [1,16,3]
moition_weights: shape: [1,4,3,160,240,16]
forward_flow: shape: [1,4,3,160,240,3]: 是 weights 和bases 结合的结果

GS_param Rendering

  • 取出高斯的各项属性,尤其是 means 和 速度 forward_v: STORM 假设 在这 20帧是出于 匀速直线运动, 其速度时不变的,可能并不合理。我们的方法直接预测 BBX,可能更为准确。
means = rearrange(gs_params["means"], "b t v h w c -> b (t v h w) c")
scales = rearrange(gs_params["scales"], "b t v h w c -> b (t v h w) c")
quats = rearrange(gs_params["quats"], "b t v h w c -> b (t v h w) c")
opacities = rearrange(gs_params["opacities"], "b t v h w -> b (t v h w)")
colors = rearrange(gs_params["colors"], "b t v h w c -> b (t v h w) c")
forward_v = rearrange(gs_params["forward_flow"], "b t v h w c -> b (t v h w) c")

这里得到的 高斯的 mean 是全部由 context_image 得到的, shape (46800,3), 但这其实是 4个 时刻context_frame_idx = [0,5,10,15], 得到的高斯,并不处于同一时间刻度。
通过比较 target_timecontext_time 之间的插值,去得到每一个 target_time 的 3D Gaussian 的坐标means_batched

  if tgt_time.ndim == 3:tdiff_forward = tgt_time.unsqueeze(2) - ctx_time.unsqueeze(1)tdiff_forward = tdiff_forward.view(b * tgt_t, t * v, 1)tdiff_forward_batched = tdiff_forward.repeat_interleave(h * w, dim=1)else:tdiff_forward = tgt_time.unsqueeze(-1) - ctx_time.unsqueeze(-2)tdiff_forward = tdiff_forward.view(b * tgt_t, t, 1)tdiff_forward_batched = tdiff_forward.repeat_interleave(v * h * w, dim=1)forward_translation = forward_v_batched * tdiff_forward_batchedmeans_batched = means_batched + forward_translation ## (20,460800,3) 

使用 gsplatbatch_rasterization 函数:

  rendered_color, rendered_alpha, _ = rasterization(means=means_batched.float(),  ## (20,460800,3)quats=quats_batched.float(),scales=scales_batched.float(),opacities=opacities_batched.float(),colors=colors_batched.float(),viewmats=viewmats_batched,  ## (20,3,4,4)Ks=Ks_batched,  ## (20,3,3,3)width=tgt_w,height=tgt_h,render_mode="RGB+ED",near_plane=self.near,far_plane=self.far,packed=False,radius_clip=radius_clip,)

bug 记录:

当使用单个相机的时候,下面这段代码会把 维度搞错:

  if self.use_affine_token:affine = self.affine_linear(affine_tokens)  # b v (gs_dim * (gs_dim + 1))affine = rearrange(affine, "b v (p q) -> b v p q", p=self.gs_dim)images = torch.einsum("b t v h w p, b v p q -> b t v h w p", images, affine)gs_params["affine"] = affine

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/93942.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/93942.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2025电赛G题-发挥部分-参数自适应FIR滤波器

(1)测评现场提供由RLC元件(各1个)组成的“未知模型电路”。 按照图3所示,探究装置连接该电路的输入和输出端口,对该电路进行 自主学习、建模(不可借助外部测试设备),2分钟…

Linux基础 -- 内核快速向用户态共享内核变量方案之ctl_table

系统化、可直接上手的 /proc/sys sysctl 接口使用文档。内容涵盖:机制原理、适用场景、ctl_table 字段详解、常用解析器(proc_handler)完整清单与选型、最小样例到进阶(范围校验、毫秒→jiffies、字符串、数组、每网络命名空间&a…

【RH124知识点问答题】第3章 从命令行管理文件

1. 怎么理解“Linux中一切皆文件”?Linux是如何组织文件的?(1)“Linux中一切皆文件”的理解和文件组织:在Linux中,“一切皆文件”指的是Linux将各种设备、目录、文件等都视为文件对象进行管理。这种统一的文…

练习javaweb+mysql+jsp

只是简单的使用mysql、简单的练习。 有很多待完善的地方,比如list的servlet页面,应该判断有没有用户的。 比如list.jsp 应该循环list而不是写死 index.jsp 样式可以再优化一下的。比如按钮就特丑。 本文展示了一个简单的MySQL数据库操作练习项目&#x…

使用Nginx部署前端项目

使用Nginx部署前端项目 一、总述二、具体步骤 2.1解压2.2将原来的html文件夹的文件删除,将自己的静态资源文件放进去,点击nginx.exe文件启动项目2.3查看进程中是否有ngix的两个进程在浏览器中输入“localhost:端口号”即可访问。 2.4端口被占用情况处理 …

【论文学习】KAG论文翻译

文章目录KAG: Boosting LLMs in Professional Domains via Knowledge Augmented Generation摘要1 引言2 方法论2.1 LLM友好型知识表示2.2 互索引机制2.2.1 语义分块2.2.2 带丰富语境的的信息抽取2.2.3 领域知识注入与约束2.2.4 文本块向量与知识结构的相互索引2.3 逻辑形式求解…

24黑马SpringCloud安装MybatisPlus插件相关问题解决

目录 一、前言 二、菜单栏没有Other 三、Config Database里的dburl需要加上时区等配置 一、前言 在学习24黑马SpringCloud的MybatisPlus-12.拓展功能-代码生成器课程时,发现由于IDEA版本不同以及MybatisPlus版本更新会出现与视频不一致的相关问题,本博…

人工智能赋能聚合物及复合材料模型应用与实践

近年来,生成式人工智能(包括大语言模型、分子生成模型等)在聚合物及复合材料领域掀起革命性浪潮,其依托数据驱动与机理协同,从海量数据中挖掘构效关系、通过分子结构表示(如 SMILES、BigSMILES)…

MyBatis-Plus3

一、条件构造器和常用接口 1.wapper介绍 MyBatis-Plus 提供了一套强大的条件构造器(Wrapper),用于构建复杂的数据库查询条件。Wrapper 类允许开发者以链式调用的方式构造查询条件,无需编写繁琐的 SQL 语句,从而提高开…

GXP6040K压力传感器可应用于医疗/汽车/家电

GXP6040K 系列压力传感器是一种超小型,为设备小型化做出贡献的高精度半导体压力传感器,适用于生物医学、汽车电子、白色家电等领域。采用标准的SOP6 和 DIP6 封装形式,方便用户进行多种安装方式。 内部核心芯片是利用 MEMS(微机械…

Android ConstraintLayout 使用详解

什么是 ConstraintLayoutConstraintLayout(约束布局)是 Android Studio 2.2 引入的一种新型布局,现已成为 Android 开发中最强大、最灵活的布局管理器之一。它结合了 RelativeLayout 的相对定位和 LinearLayout 的线性布局优势,能…

Unity3D数学第三篇:坐标系与变换矩阵(空间转换篇)

Unity3D数学第一篇:向量与点、线、面(基础篇) Unity3D数学第二篇:旋转与欧拉角、四元数(核心变换篇) Unity3D数学第三篇:坐标系与变换矩阵(空间转换篇) Unity3D数学第…

UV安装并设置国内源

文章目录一、UV下载1.官方一键安装2.github下载安装二、更换国内镜像源(加速下载)方法1:临时环境变量(单次生效)方法2:永久配置(推荐)方法3:命令行直接指定源三、验证镜像…

1 前言:什么是 CICD 为什么要学 CICD

什么是 CI/CD 我的资源库网站:https://www.byteooo.cn 在开发阶段,许多编译工具会将我们的源码编译可使用的文件。例如 vue-cli 的项目会被 webpack 打包编译为浏览器的文件,Java 项目会被编译为 .class/jar 文件以供服务器使用。 但是&am…

GitHub 趋势日报 (2025年07月30日)

📊 由 TrendForge 系统生成 | 🌐 https://trendforge.devlive.org/ 🌐 本日报中的项目描述已自动翻译为中文 📈 今日获星趋势图 今日获星趋势图3579copyparty752supervision664500-AI-Agents-Projects483awesome403prompt-optim…

“非参数化”大语言模型与RAG的关系?

这个问题触及了一个关键的技术细节,两者关系密切,但层面不同: “非参数化”大语言模型是一个更广泛的概念或类别,而RAG(Retrieval-Augmented Generation)是实现这一概念最主流、最具体的一种技术框架。 您可…

LeetCode Hot 100:15. 三数之和

题目给你一个整数数组 nums ,判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k ,同时还满足 nums[i] nums[j] nums[k] 0 。请你返回所有和为 0 且不重复的三元组。注意:答案中不可以包含重复的三元组。示例 1&…

银行回单识别应用场景剖析

银行回单OCR识别技术通过自动化处理纸质或电子回单中的关键信息,显著提升了金融、企业及个人场景下的数据管理效率。以下是其核心应用场景及价值的详细剖析:一、企业财务场景自动化账务处理对账与记账:OCR自动提取交易日期、金额、账号等信息…

React的介绍和特点

1. React是什么? 1.1. React: 用于构建用户界面的JavaScript库1.2. React的官网文档:https://zh-hans.reactjs.org/ 2. React的特点2.1. 声明式编程: 目前整个大前端开发的模式:Vue、React、Flutter、SwiftUI只需要维护…

内核smmu学习

思考 smmu对外提供功能,设备驱动调用smmu 提供的api来配置页表,那其他设备是如何和smmu交互的?iommu 作为将不同smmu硬件的一个抽象封装,其它设备应该只能看到iommu这个封装层,那么iommu这个子系统是如何进行抽象的&a…