【笔记】扩散模型(一一):Stable Diffusion XL 理论与实现

论文链接:SDXL: Improving Latent Diffusion Models for High-Resolution Image Synthesis

官方实现:Stability-AI/generative-models

非官方实现:huggingface/diffusers

Stable Diffusion XL (SDXL) 是 Stablility AI 对 Stable Diffusion 进行改进的工作,主要通过一些工程化的手段提高了 SD 模型的生成能力。相比于 Stable Diffusion,SDXL 对模型架构、条件注入、训练策略等都进行了优化,并且还引入了一个额外的 refiner,用于对生成图像进行超分,得到高分辨率图像。

Stable Diffusion XL

模型架构改进

SDXL 对模型的 VAE、UNet 和 text encoder 都进行了改进,下面依次介绍一下。

VAE

相比于 Stable Diffusion,SDXL 对 VAE 模型进行了重新训练,训练时使用了更大的 batchsize(256,Stable Diffusion 使用的则是 9),并且使用了指数移动平均,得到了性能更强的 VAE。

需要注意的是,SD 2.x 的 VAE 相对 SD 1.x 只对 decoder 进行了微调,而 encoder 不变。因此两者的 latent space 是相同的,VAE 可以互相交换使用。但 SDXL 对 encoder 也进行了微调,因此 latent space 发生了变化,SDXL 不能用 SD 的 VAE,SD 也不能用 SDXL 的 VAE。另外,由于 SDXL 的 VAE 在 fp16 下会发生溢出,所以其必须在 fp32 类型下进行推理。

UNet

SDXL 使用了更大的 UNet 模块,具体来说做了以下几个变化:

  1. 为了提高效率,使用了更少的 3 个 stage,而不是 SD 使用的 4 个 stage;
  2. 将 transformer block 移动到更深的 stage,第一个 stage 没有 transformer block;
  3. 使用更多的 transformer block。

详情可以看下边的表格,可以看到第二个 stage 和第三个 stage 分别使用了 2 个和 10 个 transformer block,最后 UNet 的整体参数量变成了大约 3 倍。

模型SDXLSD1.4/1.5SD 2.0/2.1
UNet 参数量2.6 B860 M865 M
Transformer block 数量[0, 2, 10][1, 1, 1, 1][1, 1, 1, 1]
通道倍增系数[1, 2, 4][1, 2, 4, 4][1, 2, 4, 4]

Text Encoder

SDXL 还使用了更强的 text encoder,其同时使用了 OpenCLIP ViT-bigG 和 OpenAI CLIP ViT-L,使用时同时用两个 encoder 处理文本,并将倒数第二层特征拼接起来,得到一个 1280+768=2048 通道的文本特征作为最终使用的文本嵌入。

除此之外,SDXL 还使用 OpenCLIP ViT-bigG 的 pooled text embedding 映射到 time embedding 维度并与之相加,作为辅助的文本条件注入。

其与 SD 1.x 和 2.x 的比较如下表所示:

模型SDXLSD 1.4/1.5SD 2.0/2.1
文本编码器OpenCLIP ViT-bigG & CLIP ViT-LCLIP ViT-LOpenCLIP ViT-H
特征通道数20487681024
Pooled text emb.OpenCLIP ViT-bigGN/AN/A

Refine Model

除了上述结构变化之外,SDXL 还级联了一个 refine model 用来细化模型的生成结果。这个 refine model 相当于一个 img2img 模型,在模型中的位置如下所示:

SDXL 整体架构,refine model 级联在基础模型的后方

这个 refine model 的主要目的是进一步提高图像的生成质量。其是单独训练的,专注于对高质量高分辨率数据的学习,并且只在比较低的 noise level 上(即前 200 个时间步)进行训练。

在推理阶段,首先从 base model 完成正常的生成过程,然后再加一些噪音用 refine model 进一步去噪。这样可以使图像的细节得到一定的提升。

Refine model 的结构与 base model 有所不同,主要体现在以下几个方面:

  1. Refine model 使用了 4 个 stage,特征维度采用了 384(base model 为 320);
  2. Transformer block 在各个 stage 的数量为 [0, 4, 4, 0],最终参数量为 2.3 B,略小于 base model;
  3. 条件注入方面:text encoder 只使用了 OpenCLIP ViT-bigG;并且同样也使用了尺寸和裁剪的条件注入(这个下文会讲);除此之外还使用了 aesthetic score 作为条件。

条件注入的改进

SDXL 引入了额外的条件注入来改善训练过程中的数据处理问题,主要包括图像尺寸图像裁剪问题。

图像尺寸条件

Stable Diffusion 的训练通常分为多个阶段,先在 256 分辨率的数据上进行训练,再在 512 分辨率的数据上进行训练,每次训练时需要过滤掉尺寸小于训练尺寸的图像。根据统计,如果直接丢弃所有分辨率不足 256 的图像,会浪费大约 40% 的数据。如果不希望丢弃图像,可以使用超分辨率模型先将图像超分到所需的分辨率,但这样会导致数据质量的降低,影响训练效果。

为了解决这个问题,作者加入了一种额外的图像尺寸条件注入。作者将原图的宽高进行傅立叶编码,然后将特征拼接起来加到 time embedding 上作为额外条件。

在训练时直接使用原图的宽高作为条件,推理的时候可以自定义宽高,生成不同质量的图像,下面的图是一个例子,可以看到当以较小的尺寸为条件时,生成的图比较模糊,反之则清晰且细节更丰富:

SDXL 中的 size conditioning

图像裁剪条件

在 SD 训练时使用的是固定尺寸(例如 512x512),使用时需要对原图进行处理。一般的处理流程是先 resize 到短边为目标尺寸,然后沿着长边进行裁剪。这种裁剪会导致图像部分缺失的问题,例如生成的图像部分会出现部分缺失,就是因为裁剪后的数据是不完整的。

为了解决这个问题,SDXL 在训练时把裁剪位置的坐标也当作条件注入进来。具体做法是把左上角的像素坐标值也进行傅立叶编码+拼接,再加到 time embedding 上,这样模型就能得知使用的数据是在什么位置进行的裁剪。

在推理阶段,只需要将这个条件设置为 (0, 0) 即可得到正常图像。如果设置成其他的值则能得到裁剪图像,例如下边图里的效果。(感觉还是很神奇的,竟然这种条件能 work,而且没有和图像尺寸的条件混淆)

SDXL 中的 crop conditioning

训练策略的改进

多比例微调

在训练阶段使用的都是正方形图像,但是现实图像很多都是有一定的长宽比的图像。因此在训练后的微调阶段,还使用了一种多比例微调的策略。

具体来说,这种方法预先将训练集按照图像长宽比不同分成多个 bucket,在微调时每次随机选取一个 bucket,并从中采样出一个 batch 的数据进行训练。在原论文中给出了一个表格,从表中可以看到选取的长宽比从 0.25(对应 512×2048512\times2048512×2048 分辨率) 到 4(对应 2048×5122048\times5122048×512 分辨率)不等,并且总像素数基本都维持在 102421024^210242 左右。

在这个微调阶段,除了使用尺寸和裁剪条件注入,还使用了 bucket size(也就是生成的目标大小) 作为一个条件,用相同的方式进行了注入。经过这样的条件注入和微调,模型就能生成多种长宽比的图像。

Noise Offset

在多比例微调的阶段,SDXL 还使用了一种 noise offset 的方法,来解决 SD 只能生成中等亮度图像、而无法生成纯黑或者纯白图像的问题。这个问题出现的原因是在训练和采样阶段之间存在一定的 bias,训练时在最后一个时间步的时候实际上是没有加噪的,所以会出现一些问题,解决方案也比较简单,就是在训练的时候给噪声添加一个比较小的 offset 即可。

SDXL 代码解析

这里依然以 diffusers 提供的训练代码为主进行分析,模型架构的改变主要体现在加载的预训练模型中(之后应该会出一期怎么看 huggingface 里的那些文件以及 config.json 的教程),这里主要分析一下各种条件注入和训练策略是怎么实现的。

各种条件注入

首先是尺寸和裁剪的条件注入,在图像进行预处理的阶段就记录下了每张图的原始尺寸以及裁剪位置:

def preprocess_train(examples):images = [image.convert("RGB") for image in examples[image_column]]original_sizes = []all_images = []crop_top_lefts = []for image in images:# 在这里记录原始尺寸original_sizes.append((image.height, image.width))# 调整图片大小image = train_resize(image)# 以 0.5 的概率进行随机翻转if args.random_flip and random.random() < 0.5:image = train_flip(image)# 进行裁剪if args.center_crop:y1 = max(0, int(round((image.height - args.resolution) / 2.0)))x1 = max(0, int(round((image.width - args.resolution) / 2.0)))image = train_crop(image)else:y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))image = crop(image, y1, x1, h, w)# 在这里记录裁剪位置crop_top_left = (y1, x1)crop_top_lefts.append(crop_top_left)image = train_transforms(image)all_images.append(image)examples["original_sizes"] = original_sizesexamples["crop_top_lefts"] = crop_top_leftsexamples["pixel_values"] = all_imagesreturn examples

随后原始尺寸和裁剪位置被进行编码,可以看到下边这部分包含了三部分的条件注入:

def compute_time_ids(original_size, crops_coords_top_left):target_size = (args.resolution, args.resolution)# 包括三部分的条件注入,分别为:# 1. 原始尺寸;2. 裁剪位置;3. 目标尺寸add_time_ids = list(original_size + crops_coords_top_left + target_size)add_time_ids = torch.tensor([add_time_ids])add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)return add_time_idsadd_time_ids = torch.cat([compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
)

最后把 pooled prompt embedding 也加入进来:

unet_added_conditions = {"time_ids": add_time_ids}
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})

这样四种条件注入就都准备好了,在 forward 时直接传到 UNet 的 added_cond_kwargs 参数即可参与计算。这些参数在 get_aug_embed 中被组合起来添加到 time embedding 上:

# pooled text embedding
text_embeds = added_cond_kwargs.get("text_embeds")
# 1. 原始尺寸;2. 裁剪位置;3. 目标尺寸
time_ids = added_cond_kwargs.get("time_ids")
# 处理得到最终加到 time embedding 上的条件嵌入
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)

在上边的代码里用到了两个对象 self.add_time_projself.add_embedding,定义为:

self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

这两个对象中,Timesteps 应该是负责傅立叶编码,TimestepEmbedding 则负责对编码后的结果进行嵌入。

Noise Offset

这个实现很简单,就在加噪前对 noise 随机偏移一下即可:

if args.noise_offset:# https://www.crosslabs.org//blog/diffusion-with-offset-noisenoise += args.noise_offset * torch.randn((model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device)

多尺度微调

根据我的观察,diffusers 里没有直接提供多尺度微调相关的代码,应该是默认在训练之前已经自行处理好了各个 bucket 的图像。印象中前段时间某个组织开源了一份分 bucket 的代码,不过因为当时没保存所以现在找不到了,能找到的主要是 kohya-ss/sd-scripts 的一个实现。

大体的原理是先创建一系列桶,然后对于每张图片,选择长宽比最接近的一个桶,然后进行裁剪,裁剪到和这个桶对应的分辨率相同。由于相邻两个桶之间的分辨率之差为 64,所以最多裁剪 32 像素,对训练的影响并不大。在将图片分桶之后,则可以按照每个桶的数据比例作为概率进行采样。如果某些桶中的数据量不足一个 batch,则把这个桶中的数据都放入一个公共桶中,并以标准的 1024×10241024\times10241024×1024 分辨率进行训练。

如果读者有兴趣自己阅读代码,可以先看 library.model_util 模块中的 make_bucket_resolutions,这个方法创建了一系列分辨率的 bucket,并在 library.train_util.BucketManager 中调用,用来创建 bucket。这个 BucketManager 提供了一个方法 select_bucket,用来为某个特定分辨率的图像选择 bucket。最后在 library.train_util.BaseDataset 中,会对每张图片调用 select_bucket 选择 bucket,再将对应的图片加入到选择的 bucket 中。

总结

感觉 SDXL 是一个比较工程的工作,尤其是对模型架构的修改,比较大力出奇迹。除此之外感觉对数据的理解还是很重要的,除了修改模型架构之外的其他工作都是围绕着数据展开的,这也是比较值得学习的思路。

参考资料:

  1. 深入浅出完整解析Stable Diffusion XL(SDXL)核心基础知识
  2. 扩散模型(七)| SDXL

本文原文以 CC BY-NC-SA 4.0 许可协议发布于 笔记|扩散模型(一一):Stable Diffusion XL 理论与实现,转载请注明出处。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/web/93637.shtml
繁体地址,请注明出处:http://hk.pswp.cn/web/93637.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

学习安卓APP开发,10年磨一剑,b4a/Android Studio

学习安卓APP开发 记得上次学APP都是在2016年前了&#xff0c;一晃就过去了10年。 当时用ANDROID studio打开一个空项目和编绎分别用了300秒&#xff0c;一下就用了10分钟。 后来买了一台一万多的电脑&#xff0c;CPU换成了I5 8600K 4.2GHZ*6核&#xff0c;再加上M2固态硬盘。 编…

调试技巧(vs2022 C语言)

调试之前首先要保证我们的脑袋是清晰的&#xff0c;我们调试的过程主要是看代码有没有按照我们的想法去运行调试最常使用的几个快捷键F5启动调试&#xff0c;经常用来直接跳到下一个断点处&#xff08;F5通常和F9配合使用&#xff0c;打了断点按F5程序可以直接运行到断点处&…

MySQL深度理解-Innodb底层原理

1.MySQL的内部组件结构大体来说&#xff0c;MySQL可以分为Server层和存储引擎层两部分。2.Server层Server层主要包括连接器、查询缓存、分析器、优化器和执行器等&#xff0c;涵盖MySQL的大多数核心服务功能&#xff0c;以及所有的内置函数&#xff08;如日期、时间、数据和加密…

QFtp在切换目录、上传文件、下载文件、删除文件等一系列操作时,如何按照预期操作指令顺序执行

FTP服务初始化时&#xff0c;考虑到重连、以及commandFinished信号信号执行完成置m_bCmdFinished 为true; void ICore::connectFtpServer() {if(g_pFile nullptr){g_pFile new QFile;}if(g_pFtp){g_pFtp->state();g_pFtp->abort();g_pFtp->deleteLater();g_pFtp n…

JavaSE高级-02

文章目录1. 多线程1.1 创建线程的三种方式多线程的创建方式一&#xff1a;继承Thread类多线程的创建方式二&#xff1a;实现Runnable接口多线程的创建方式三&#xff1a;实现Callable接口三种线程的创建方式对比Thread的常用方法1.2 线程安全线程同步方式一&#xff1a;同步代码…

从舒适度提升到能耗降低再到安全保障,楼宇自控作用关键

在现代建筑的发展历程中&#xff0c;楼宇自动化控制系统&#xff08;BAS&#xff09;已从单纯的设备管理工具演变为集舒适度优化、能耗控制与安全保障于一体的核心技术。随着物联网和人工智能的深度应用&#xff0c;楼宇自控系统正以数据为纽带&#xff0c;重构人与建筑的关系。…

图像分类精度评价的方法——误差矩阵、总体精度、用户精度、生产者精度、Kappa 系数

本文详细介绍 “图像分类精度评价的方法”。 图像分类后&#xff0c;需要评估分类结果的准确性&#xff0c;以判断分类器的性能和结果的可靠性。 常涉及到下面几个概念&#xff08;指标&#xff09; 误差矩阵、总体精度、用户精度、生产者精度和 Kappa 系数。1. 误差矩阵&#…

【科普向-第一篇】数字钥匙生态全景:手机厂商、车厂与协议之争

目录 一、协议标准之争&#xff1a;谁制定规则&#xff0c;谁掌控入口 1.1 ICCE&#xff1a;中国车企主导的自主防线 1.2 ICCOA&#xff1a;手机厂商的生态突围 1.3 CCC&#xff1a;国际巨头的高端壁垒 1.4 协议对比 二、底层技术路线&#xff1a;成本与安全的博弈 2.1B…

dockerfile及docker常用操作

1: docker 编写 Dockerfile 是用于构建 Docker 镜像的文本文件&#xff0c;包含一系列指令和参数&#xff0c;用于定义镜像的构建过程 以下是关键要点&#xff1a; 一、基本结构 ‌FROM‌&#xff1a;必须作为第一条指令&#xff0c;指定基础镜像&#xff08;如 FROM python:3.…

[vibe coding-lovable]lovable是不是ai界的复制忍者卡卡西?

在火影忍者的世界里&#xff0c;卡卡西也被称为复制忍者&#xff0c;因为大部分忍术都可以被其Copy! 截图提示:实现这个效果 -> 发给Lovalbe -> 生成的的效果如下&#xff0c;虽然不是1比1还原&#xff0c;但是这个效果也很惊艳。 这个交互设计&#xff0c;这个UI效果&am…

技术赋能安全:智慧工地构建城市建设新防线

城市建设的热潮中&#xff0c;工地安全始终是关乎生命与发展的核心议题。江西新余火灾等事故的沉痛教训&#xff0c;暴露了传统工地监管的诸多短板——流动焊机“行踪难觅”&#xff0c;无证动火作业屡禁不止&#xff0c;每一次监管缺位都可能引发灾难性后果。如今&#xff0c;…

Sublime Text 代码编辑器(Mac中文)

原文地址&#xff1a;Sublime Text Mac 代码编辑器 sublime text Mac一款轻量级的文本编辑器&#xff0c;拥有丰富的功能和插件。 它支持多种编程语言&#xff0c;包括C、Java、Python、Ruby等&#xff0c;可以帮助程序员快速编写代码。 Sublime Text的界面简洁、美观&#…

如何制定项目时间线,合理预计?

制定一份现实可行且行之有效的项目时间线&#xff0c;是一个系统性的分解、估算与排序过程&#xff0c;而非简单的日期罗列。核心步骤包括&#xff1a;明确项目范围与可交付成果、利用工作分解结构&#xff08;WBS&#xff09;进行任务拆解、科学估算各项任务的持续时间、识别并…

RSA详解

一、RSA 简介RSA 是一种公钥密码体制&#xff0c;由罗纳德・李维斯特&#xff08;Ron Rivest&#xff09;、阿迪・萨莫尔&#xff08;Adi Shamir&#xff09;和伦纳德・阿德曼&#xff08;Leonard Adleman&#xff09;于 1977 年提出&#xff0c;算法名称由他们三人姓氏的首字母…

Linux获取物理硬盘总容量

获取物理硬盘总容量: 1.查看单个硬盘: 使用 lsblk 或 fdisk -l (需要 sudo) 命令。它们会直接列出物理硬盘 (sda, nvme0n1 等) 和它们的分区,并显示硬盘的总物理容量。 abcd四块物理盘,只挂载使用3块,留一块未使用 最常见的原因通常是配置了热备盘(RAID 1/5/6/10 等冗余…

STM32学习笔记14-I2C硬件控制

I2C外设简介STM32内部集成了硬件I2C收发电路&#xff08;硬件收发器&#xff1a;自动生产波形&#xff0c;自动翻转电平等&#xff09;&#xff0c;可以由硬件自动执行时钟生成、起始终止条件生成、应答位收发、数据收发等功能&#xff0c;减轻CPU的负担——软件只需要写入控制…

电子电气架构 --- 软件开发数字化转型

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 简单,单纯,喜欢独处,独来独往,不易合同频过着接地气的生活,除了生存温饱问题之外,没有什么过多的欲望,表面看起来很高冷,内心热情,如果你身…

我国空间站首次应用专业领域 AI大模型

据中国载人航天工程办公室消息&#xff0c;北京时间2025年8月15日22时47分&#xff0c;经过约6.5小时的出舱活动&#xff0c;神舟二十号乘组航天员陈冬、陈中瑞、王杰密切协同&#xff0c;在空间站机械臂和地面科研人员的配合支持下&#xff0c;圆满完成既定任务&#xff0c;出…

WPF真入门教程35--手搓WPF出真汁【蜀味正道CS版】

1、项目介绍 本项目采用多层架构设计&#xff0c;使用wpf&#xff0c;Panuon.UI.Silver控件库&#xff0c;AduSkin皮肤&#xff0c;MVVM等技术开发具有复杂交互和视觉效果的CS应用程序。WPF适用于企业级桌面应用&#xff1a;如ERP、CRM系统&#xff0c;需复杂表单和报表。WPF适…

JMeter与大模型融合应用之构建AI智能体:评审性能测试脚本

JMeter与大模型融合应用之构建AI智能体&#xff1a;评审性能测试脚本 一、引言 随着DevOps和持续测试的普及&#xff0c;性能测试已成为软件开发生命周期中不可或缺的环节。Apache JMeter作为最流行的开源性能测试工具之一&#xff0c;被广泛应用于各种性能测试场景。然而&…