之前已经完整的拆解了CLIP中所用到的ResNet、ViT和Transformer三个模型(CLIP拆解-CSDN博客),这篇将讲解model.py实现中的其他细节。
1.关于ResNet模型中vision_head的设置
ResNet:
vision_heads = vision_width * 32 // 64
ViT:
vision_heads = vision_width // 64
ResNet需要乘32是因为经过前面卷积处理后输入AttentionPool2d的是width*32,所以计算head的时候要把这个考虑进去。至于这里的64是分为多头后每一个头的embed的通道数,ResNet通常取64,ViT-B常取768
2.关于conver_weights
convert_weights()
是为了节省显存、提高推理速度,将模型中适合的权重转换为 fp16。
(1)half()的作用 就是把fp32转为fp16,如果输入本身是 fp16,那将不进行任何处理。
(2)一些结构不建议转化为fp16,因为转化后会不稳定,所以选择性的处理
def _convert_weights_to_fp16(l):if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):l.weight.data = l.weight.data.half()if l.bias is not None:l.bias.data = l.bias.data.half()if isinstance(l, nn.MultiheadAttention):for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:tensor = getattr(l, attr)if tensor is not None:tensor.data = tensor.data.half()for name in ["text_projection", "proj"]:if hasattr(l, name):attr = getattr(l, name)if attr is not None:attr.data = attr.data.half()
下面是常见的不建议使用fp16的模块:
模块/操作 | 原因说明 |
---|---|
LayerNorm / BatchNorm | 均值/方差运算容易数值下溢,精度敏感 |
Softmax / LogSoftmax | 输出接近 0 或 1,fp16 下舍入误差大 |
Sigmoid / Tanh | 对小输入不敏感,精度损失后容易失效 |
CrossEntropyLoss | 包含 log(softmax) ,fp16 精度不足导致数值不稳定 |
Attention (部分实现) | scaled dot-product 会导致爆炸,尤其是大输入或长序列时 |
Exp , Div , Log | 本身不稳定,数值小容易下溢出为 0 |
3.模型输入也要相应的进行转化,否则会遇到类型不匹配的问题
解决方法1:使用autocast
from torch.cuda.amp import autocastwith autocast():output = model(x) # 自动在每一层内部管理精度转换
但autocast只针对模块的外部类型来判断是否进行类型转化(如nn.Linear, nn.Conv2d),但是自定义的模块(类)autocast不会进行类型转换(autocast只是解决了类型不匹配的问题,但是低精度产生的梯度爆炸等问题无法解决,由反向传播时gradscaler解决)
问题场景 | AMP 是否能处理 | 说明 |
---|---|---|
输入是 fp16,模块需要 fp32 | ✅ autocast() 会自动转换 | |
自定义模块内部 + ,/ 导致类型错 | ❌ 你要自己管理,AMP 不管你自写的算子 | |
梯度为 0 或爆炸 | ✅ GradScaler() 自动放大/还原 | |
权重混用不同精度 | ✅ 支持 | |
推理时类型优化(加速,混用不同精度) | ✅ 只用 autocast() 即可 |
解决方法2:手动转化类型
# 例如 LayerNorm 中人为转 float32:
def forward(self, x):orig_type = x.dtyperet = super().forward(x.float()) # 保证 LayerNorm 在 float32 下执行return ret.to(orig_type)
4.关于forward的输出
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
logit_scale是缩放因子,定义是self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
logits_per_image是图像视角下的相似度分布,用于计算图像到文本的对比损失
logits_per_text是文本视角下的相似度分布,和图像视角下对称。
5.关于权重初始化
(1)ResNet的bn3初始化为0
for resnet_block in [self.visual.layer1,self.visual.layer2,self.visual.layer3,self.visual.layer4]:for name, param in resnet_block.named_parameters():if name.endswith("bn3.weight"):nn.init.zeros_(param)
手动初始化bn3.weight为0确保为恒等映射,从而防止残差支路输出不稳定、扰动太大的问题。
(2)CLIP中的手动初始化和自动初始化
CLIP只手动初始化了一些对训练稳定性或性能影响较大的模块,如embedding和位置编码(nanoGPT中也对这两个部分进行了手动初始化)、QKVC投影、transformer最后输出的初始化
def initialize_parameters(self):nn.init.normal_(self.token_embedding.weight, std=0.02)nn.init.normal_(self.positional_embedding, std=0.01)if isinstance(self.visual, ModifiedResNet):if self.visual.attnpool is not None:std = self.visual.attnpool.c_proj.in_features ** -0.5nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:for name, param in resnet_block.named_parameters():if name.endswith("bn3.weight"):nn.init.zeros_(param)proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)attn_std = self.transformer.width ** -0.5fc_std = (2 * self.transformer.width) ** -0.5for block in self.transformer.resblocks:nn.init.normal_(block.attn.in_proj_weight, std=attn_std)nn.init.normal_(block.attn.out_proj.weight, std=proj_std)nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)if self.text_projection is not None:nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
***与nanoGPT的_init_weights对比
# mainself.apply(self._init_weights)# apply special scaled init to the residual projections, per GPT-2 paperfor pn, p in self.named_parameters():if pn.endswith('c_proj.weight'):torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))#init_weightdef _init_weights(self, module):if isinstance(module, nn.Linear):torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)if module.bias is not None:torch.nn.init.zeros_(module.bias)elif isinstance(module, nn.Embedding):torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
*GPT
GPT 结构对初始化非常敏感,GPT 使用残差连接 + LayerNorm,梯度传播对初始权重分布非常依赖。所以在初始化的时候Linear和Embedding的weight的mean都初始化为0
*CLIP
CLIP更复杂,只初始化关键敏感部件,如embedding、positional encoding、attention等。
***目前总结到的经验
*建议手动初始化:
模块类型 | 初始化建议 | 原因 |
---|---|---|
Embedding | 手动正态初始化(如 std=0.01~0.02) | 防止稀疏索引导致偏置 |
Q/K/V Linear | 手动初始化(如 std=1/√d_k ) | 防止 attention dot-product 初始值爆炸 |
Positional Embedding | 正态初始化 | 因为是 learnable 参数,数值不宜过大 |
残差 block 最后一层(如 BN3) | 初始化为 0 | 初始退化为恒等映射,提高收敛性 |
任何“关键分支”的 projection 层 | 建议初始化 | 如 CLIP 的 text_projection , image_projection |
一般不主动初始化:
模块类型 | 理由 |
---|---|
Conv2d , Linear | 默认初始化已很好,除非有论文要求 |
LayerNorm , BatchNorm | 默认 weight=1 , bias=0 是最优策略 |
非残差中的普通线性层 | 默认即可 |
(3)初始化时std的设置
① attn_std = self.transformer.width ** -0.5
标准的transformer初始化方法
②fc_std = (2 * self.transformer.width) ** -0.5
用于初始化FFN中的前向Linear层,第一层输出通道很大(通常是 4×),为了避免输出激活过大,std 要适当减小。
x → Linear(width, 4*width) → GELU → Linear(4*width, width)
③proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
用于 Residual AttentionBlock 最后投影的 Linear 层
来源:来自论文 Understanding the Difficulty of Training Transformers,特别适用于 深层 Transformer(如 GPT-3, CLIP)。
核心思想是:
如果模型深度是 L 层,那每个 residual branch 叠加的方差也会增加,应该将其 std 缩小为 1/sqrt(2L)以稳定整体输出。
6.关于build_model的参数的使用
(1)
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
这里使用visual.conv1.weight的第一个维度的大小作为width,conv2d的weight的形状是(out_channle, in_channel, patch_size[0], patch_size[1])。
另外这里补充一下ViT patch和传统CNN卷积核的区别:
传统CNN是使用多个小卷积堆叠构建大感受野(kernel_size较小,stride小于kernel_size允许重叠),而ViT是使用一个大kernel,把整块patch当作token(kernel_size较大,stride=kernel_size,即不重复采样)
(2)
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
每个 Transformer block 里会有一个 nn.MultiheadAttention
模块,对应权重名如:visual.transformer.resblocks.0.attn.in_proj_weight
(3)
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_resolution = vision_patch_size * grid_size
这里image_resolution是因为
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
(4)几个易混淆的概念
名字 | 意义 | 举例值 | 类似于 |
---|---|---|---|
vision_width | 通道维度 | 64、128、256 等 | CNN 中的输出 channels |
output_width | 特征图尺寸 | 7、14 等 | feature map 的宽度 |
patch_size | patch 的边长 | 32 | ViT 中的切片大小‘ |
(5)ResNet中image_resolution = output_width * 32
*32是因为在ResNet中总共下采样了5次
模块 | 操作类型 | 输出尺寸 |
---|---|---|
conv1 | stride=2 | 变成 H/2 × W/2 |
stem_pool | AvgPool2d(2) | 变成 H/4 × W/4 |
layer1 | 无下采样 | 尺寸不变 |
layer2 | stride=2 | 变成 H/8 × W/8 |
layer3 | stride=2 | 变成 H/16 × W/16 |
layer4 | stride=2 | 变成 H/32 × W/32 ✅ 最终输出 |
attnpool | 空间尺寸 = H/32 × W/32 |
(6)删除state_dict中的一些辅助信息字段
for key in ["input_resolution", "context_length", "vocab_size"]:if key in state_dict:del state_dict[key]
这些不是模型参数的一部分,加载模型权重前必须删掉,否则会引起state_dict键不匹配