Wan2.1 Code Notes

If GPU memory is insufficient, you can run UMT5 first and then run WanPipeline (see the FLUX.1 code notes for the same trick), or use ComfyUI.
Below, random tensors stand in for the UMT5 embeddings.

import torch
from diffusers.utils import export_to_video
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
model_id = "Wan-AI/Wan2___1-T2V-1___3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(
    model_id, tokenizer=None, text_encoder=None, vae=vae, torch_dtype=torch.bfloat16
)
flow_shift = 3.0  # 5.0 for 720P, 3.0 for 480P
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
pipe.to("cuda")

prompt_embeds = torch.randn(1, 226, 4096).to("cuda")           # random stand-in for the UMT5 embedding
negative_prompt_embeds = torch.randn(1, 226, 4096).to("cuda")  # random stand-in

output = pipe(
    prompt=None,
    negative_prompt=None,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    num_inference_steps=1,
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=6.0,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
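
For reference, a minimal sketch of how the real UMT5 embeddings could be computed in a separate pass before loading the DiT (my sketch based on the standard transformers API; the actual pipe.encode_prompt additionally zeroes out positions beyond each prompt's true token length):

import torch
from transformers import AutoTokenizer, UMT5EncoderModel

model_id = "Wan-AI/Wan2___1-T2V-1___3B-Diffusers"
tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = UMT5EncoderModel.from_pretrained(
    model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "A cat walks on the grass, realistic"  # hypothetical example prompt
inputs = tokenizer(
    prompt, max_length=226, padding="max_length", truncation=True, return_tensors="pt"
).to("cuda")
with torch.no_grad():
    prompt_embeds = text_encoder(inputs.input_ids, attention_mask=inputs.attention_mask).last_hidden_state

# Save the embeddings, free the text encoder, then load WanPipeline with text_encoder=None as above.
torch.save(prompt_embeds.cpu(), "prompt_embeds.pt")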


The steps in WanPipeline are basically the same as for text-to-image generation:
1. Check inputs;
2. Define parameters;
3. Encode the prompt;
4. Prepare timesteps;
5. Prepare latents;
6. Denoising loop, then decode (see the sketch below).
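
A heavily simplified sketch of step 6, the classifier-free-guidance denoising loop (my distillation of WanPipeline.__call__, reusing pipe and the embeddings from the snippet above; dtype juggling and the latents_mean/latents_std un-normalization before vae.decode are omitted):

guidance_scale = 6.0
pipe.scheduler.set_timesteps(50, device="cuda")
latents = torch.randn(1, 16, 21, 60, 104, device="cuda", dtype=torch.bfloat16)
for t in pipe.scheduler.timesteps:
    timestep = t.expand(latents.shape[0])
    # conditional and unconditional passes through the DiT
    noise_pred = pipe.transformer(
        hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, return_dict=False
    )[0]
    noise_uncond = pipe.transformer(
        hidden_states=latents, timestep=timestep, encoder_hidden_states=negative_prompt_embeds, return_dict=False
    )[0]
    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
    latents = pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

video = pipe.vae.decode(latents.to(pipe.vae.dtype), return_dict=False)[0]  # un-normalization omitted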
[Figure 5] Wan-VAE compresses the temporal dimension by a factor of 4, and the spatial height and width by a factor of 8 each.
With 16 latent channels, the latent shape works out to (1, 16, 21, 60, 104):

num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1  # (81 - 1) // 4 + 1 = 21
shape = (
    batch_size,                                    # 1
    num_channels_latents,                          # 16
    num_latent_frames,                             # 21
    int(height) // self.vae_scale_factor_spatial,  # 480 // 8 = 60
    int(width) // self.vae_scale_factor_spatial,   # 832 // 8 = 104
)

WanTransformer3DModel

In patchify, WanTransformer3DModel applies a 3D convolution with kernel size (and stride) (1, 2, 2) and flattens the result into a (B, L, D) sequence, where B is the batch size, L = (1 + T/4) × H/16 × W/16, and D is the transformer hidden dimension (1536 for the 1.3B model).
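
A quick numeric check of L for the 81-frame, 480×832 example:

num_frames, height, width = 81, 480, 832
t = (num_frames - 1) // 4 + 1   # VAE temporal compression x4 -> 21 latent frames
h, w = height // 8, width // 8  # VAE spatial compression x8 -> 60 x 104
L = t * (h // 2) * (w // 2)     # (1, 2, 2) patchify -> 21 * 30 * 52
print(L)                        # 32760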

self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

rotary_emb = self.rope(hidden_states)  # (1, 1, 32760, 64)

hidden_states = self.patch_embedding(hidden_states)  # (1, 1536, 21, 30, 52)
hidden_states = hidden_states.flatten(2).transpose(1, 2)  # (1, 32760, 1536)

temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
    timestep, encoder_hidden_states, encoder_hidden_states_image
)  # (1, 1536), (1, 9216), (1, 226, 1536), None
timestep_proj = timestep_proj.unflatten(1, (6, -1))  # (1, 6, 1536)

if encoder_hidden_states_image is not None:
    encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

for block in self.blocks:
    hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)

# 5. Output norm, projection & unpatchify
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)

# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
# first device rather than the last device, which hidden_states ends up
# on.
shift = shift.to(hidden_states.device)
scale = scale.to(hidden_states.device)

hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
hidden_states = self.proj_out(hidden_states)

hidden_states = hidden_states.reshape(
    batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

if USE_PEFT_BACKEND:
    # remove `lora_scale` from each PEFT layer
    unscale_lora_layers(self, lora_scale)

if not return_dict:
    return (output,)
return Transformer2DModelOutput(sample=output)

WanTransformerBlock

The 1.3B model has 30 WanTransformerBlocks with a DiT structure. All 30 blocks share the temb parameters; each block only learns a per-block offset, self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5). Extensive experiments show that this design reduces the parameter count by about 25%, and that at the same parameter scale it significantly improves performance.

class WanTransformerBlock(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
    ) -> torch.Tensor:
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
            self.scale_shift_table + temb.float()
        ).chunk(6, dim=1)

        # 1. Self-attention
        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

        # 2. Cross-attention
        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
        hidden_states = hidden_states + attn_output

        # 3. Feed-forward
        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
            hidden_states
        )
        ff_output = self.ffn(norm_hidden_states)
        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)

        return hidden_states
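
A rough back-of-the-envelope count of where the saving comes from (my assumption: the baseline is a standard DiT in which every block owns its own dim -> 6*dim adaLN Linear):

dim, num_blocks = 1536, 30

per_block_adaln_linear = (dim + 1) * 6 * dim  # weight + bias of a per-block dim -> 6*dim Linear
per_block_offset_table = 6 * dim              # Wan's learned (1, 6, dim) scale_shift_table

print(num_blocks * per_block_adaln_linear)    # ~425M parameters for per-block adaLN
print(num_blocks * per_block_offset_table)    # ~0.28M parameters for the shared-temb design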

vae.decode

The decode loops over all latent frames, one frame at a time:

x = self.post_quant_conv(z)  # (1, 16, 21, 60, 104)
for i in range(num_frame):  # 21
    self._conv_idx = [0]
    if i == 0:
        out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
    else:
        out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
        out = torch.cat([out, out_], 2)
out = torch.clamp(out, min=-1.0, max=1.0)
self.clear_cache()

In the 1.3B model, feat_cache is a list of length 33; a full decode pass goes through 33 WanCausalConv3d layers.

def _count_conv3d(model):
    count = 0
    for m in model.modules():
        if isinstance(m, WanCausalConv3d):
            count += 1
    return count

self._conv_num = _count_conv3d(self.decoder)
self._conv_idx = [0]
self._feat_map = [None] * self._conv_num

WanDecoder3d is the class that performs the VAE decode. Following Figure 5, it applies the temporal upscaling twice and the spatial upscaling three times.
CACHE_T = 2: the cache holds the last two frames.
Cache handling: apart from the first frames, which need special treatment, each feat_cache entry holds two frames; they are concatenated with the current frame to form a three-frame window for the next computation. At the same time, the entry's last frame together with the current frame becomes the updated cache, as shown in Figure 6.

class WanDecoder3d(nn.Module):
    def forward(self, x, feat_cache=None, feat_idx=[0]):
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat(
                    [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
                )
            x = self.conv_in(x, feat_cache[idx])  # (1, 384, 1, 60, 104)
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_in(x)

        ## middle
        x = self.mid_block(x, feat_cache, feat_idx)  # (1, 384, 1, 60, 104)

        ## upsamples
        for up_block in self.up_blocks:
            x = up_block(x, feat_cache, feat_idx)
            # (1, 192, 2, 120, 208), (1, 192, 4, 240, 416), (1, 96, 4, 480, 832), (1, 96, 4, 480, 832)

        ## head
        x = self.norm_out(x)
        x = self.nonlinearity(x)
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat(
                    [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
                )
            x = self.conv_out(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv_out(x)
        return x

[Figure 6]

WanCausalConv3d

class WanCausalConv3d(nn.Conv3d):
    r"""
    A custom 3D causal convolution layer with feature caching support.

    This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
    caching for efficient inference.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

        # Set up causal padding
        self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
        self.padding = (0, 0, 0)

    def forward(self, x, cache_x=None):
        padding = list(self._padding)
        if cache_x is not None and self._padding[4] > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            padding[4] -= cache_x.shape[2]
        x = F.pad(x, padding)
        return super().forward(x)
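
To see why caching two frames is enough, here is a small self-contained check (my own demo, not library code): for a temporal kernel size of 3, frame-by-frame processing with the last two input frames cached reproduces the full causal convolution exactly.

import torch
import torch.nn.functional as F
from torch import nn

conv = nn.Conv3d(4, 4, kernel_size=(3, 1, 1))  # temporal kernel 3, as in time_conv

def causal_full(x):
    # pad two zero frames at the front of the time axis -> causal convolution
    return conv(F.pad(x, (0, 0, 0, 0, 2, 0)))

def causal_chunked(x):
    b, c, _, h, w = x.shape
    cache = torch.zeros(b, c, 2, h, w)  # zero "past", matching the initial causal padding
    outs = []
    for t in range(x.shape[2]):
        frame = x[:, :, t : t + 1]
        outs.append(conv(torch.cat([cache, frame], dim=2)))  # 3-frame window -> 1 output frame
        cache = torch.cat([cache, frame], dim=2)[:, :, -2:]  # keep only the last two input frames
    return torch.cat(outs, dim=2)

x = torch.randn(1, 4, 5, 8, 8)
assert torch.allclose(causal_full(x), causal_chunked(x), atol=1e-5)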

Spatially, self.resample uses nn.Upsample to enlarge height and width by 2x, followed by a Conv2d that halves the channel count.
Temporally, time_conv (a causal Conv3d) doubles the output channels; the forward pass then turns the doubled channels into a doubled frame count (see the sketch after the code below).

if mode == "upsample2d":
    self.resample = nn.Sequential(
        WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
        nn.Conv2d(dim, dim // 2, 3, padding=1),
    )
elif mode == "upsample3d":
    self.resample = nn.Sequential(
        WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
        nn.Conv2d(dim, dim // 2, 3, padding=1),
    )
    self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))


class WanUpsample(nn.Upsample):
    def forward(self, x):
        return super().forward(x.float()).type_as(x)
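
The temporal doubling itself happens in WanResample's forward pass: the 2*dim channels produced by time_conv are split into two frame groups and interleaved along the time axis. A minimal shape-only sketch of that rearrangement (simplified from the actual forward):

import torch

b, c, t, h, w = 1, 192, 4, 8, 8
y = torch.randn(b, 2 * c, t, h, w)           # output of time_conv: channels doubled

y = y.reshape(b, 2, c, t, h, w)              # split the doubled channels into two groups
y = torch.stack((y[:, 0], y[:, 1]), dim=3)   # (b, c, t, 2, h, w): pair each frame with its new frame
y = y.reshape(b, c, t * 2, h, w)             # interleave -> temporal length doubled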

upsample blocks

ModuleList(
  (0): WanUpBlock(
    (resnets): ModuleList(
      (0-2): 3 x WanResidualBlock(
        (nonlinearity): SiLU()
        (norm1): WanRMS_norm()
        (conv1): WanCausalConv3d(384, 384, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (norm2): WanRMS_norm()
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): WanCausalConv3d(384, 384, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (conv_shortcut): Identity()
      )
    )
    (upsamplers): ModuleList(
      (0): WanResample(
        (resample): Sequential(
          (0): WanUpsample(scale_factor=(2.0, 2.0), mode='nearest-exact')
          (1): Conv2d(384, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (time_conv): WanCausalConv3d(384, 768, kernel_size=(3, 1, 1), stride=(1, 1, 1))
      )
    )
  )
  (1): WanUpBlock(
    (resnets): ModuleList(
      (0): WanResidualBlock(
        (nonlinearity): SiLU()
        (norm1): WanRMS_norm()
        (conv1): WanCausalConv3d(192, 384, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (norm2): WanRMS_norm()
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): WanCausalConv3d(384, 384, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (conv_shortcut): WanCausalConv3d(192, 384, kernel_size=(1, 1, 1), stride=(1, 1, 1))
      )
      (1-2): 2 x WanResidualBlock(
        (nonlinearity): SiLU()
        (norm1): WanRMS_norm()
        (conv1): WanCausalConv3d(384, 384, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (norm2): WanRMS_norm()
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): WanCausalConv3d(384, 384, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (conv_shortcut): Identity()
      )
    )
    (upsamplers): ModuleList(
      (0): WanResample(
        (resample): Sequential(
          (0): WanUpsample(scale_factor=(2.0, 2.0), mode='nearest-exact')
          (1): Conv2d(384, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (time_conv): WanCausalConv3d(384, 768, kernel_size=(3, 1, 1), stride=(1, 1, 1))
      )
    )
  )
  (2): WanUpBlock(
    (resnets): ModuleList(
      (0-2): 3 x WanResidualBlock(
        (nonlinearity): SiLU()
        (norm1): WanRMS_norm()
        (conv1): WanCausalConv3d(192, 192, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (norm2): WanRMS_norm()
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): WanCausalConv3d(192, 192, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (conv_shortcut): Identity()
      )
    )
    (upsamplers): ModuleList(
      (0): WanResample(
        (resample): Sequential(
          (0): WanUpsample(scale_factor=(2.0, 2.0), mode='nearest-exact')
          (1): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
    )
  )
  (3): WanUpBlock(
    (resnets): ModuleList(
      (0-2): 3 x WanResidualBlock(
        (nonlinearity): SiLU()
        (norm1): WanRMS_norm()
        (conv1): WanCausalConv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (norm2): WanRMS_norm()
        (dropout): Dropout(p=0.0, inplace=False)
        (conv2): WanCausalConv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1))
        (conv_shortcut): Identity()
      )
    )
  )
)

