最近因为工作需求,接触了Latent Diffusion中VAE训练的相关代码,其中损失函数是由名为LPIPSWithDiscriminator的类进行计算的,包括像素级别的重建损失(rec_loss)、感知损失(p_loss)和基于判别器(g_loss)的对抗损失等。在阅读源码过程中有很多疑问,因此用该博客记录一下学习过程和相关思考,如有不对的地方欢迎评论区批评指正。
LPIPSWithDiscriminator
LPIPSWithDiscriminator主要包含以下三个方法,本文主要关注前向过程forward()和自适应权重计算方法calculate_adaptive_weight():
class LPIPSWithDiscriminator(nn.Module):# 初始化相关属性和方法def __init__(self, ):pass# 根据负对数损失(nll_loss)和对抗损失的梯度自适应计算对抗损失的权重def calculate_adaptive_weight(self, ):pass# 前向传播过程,计算各种损失def forward():pass
首先附上前向传播方法的源码,建议先结合注释把源码大概顺一遍:
def forward(self, inputs, reconstructions, posteriors, optimizer_idx,global_step, last_layer=None, cond=None, split="train",weights=None):# 图像级别重建损失rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())if self.perceptual_weight > 0:# 感知损失p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())# 问题1:为什么要把感知损失加到每个像素位置的重建损失上???rec_loss = rec_loss + self.perceptual_weight * p_loss# negative log-likelihood loss# 问题2:为什么要进行这种处理?nll_loss = rec_loss / torch.exp(self.logvar) + self.logvarweighted_nll_loss = nll_lossif weights is not None:weighted_nll_loss = weights*nll_lossweighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]# 和标准正态分布计算KL散度kl_loss = posteriors.kl()kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]# now the GAN partif optimizer_idx == 0:# generator updateif cond is None:assert not self.disc_conditionallogits_fake = self.discriminator(reconstructions.contiguous())else:assert self.disc_conditionallogits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))g_loss = -torch.mean(logits_fake)if self.disc_factor > 0.0:try:# 问题3:为什么计算自适应权重d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)except RuntimeError:assert not self.trainingd_weight = torch.tensor(0.0)else:d_weight = torch.tensor(0.0)disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_losslog = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),"{}/rec_loss".format(split): rec_loss.detach().mean(),"{}/d_weight".format(split): d_weight.detach(),"{}/disc_factor".format(split): torch.tensor(disc_factor),"{}/g_loss".format(split): g_loss.detach().mean(),}return loss, logif optimizer_idx == 1:# second pass for discriminator updateif cond is None:logits_real = self.discriminator(inputs.contiguous().detach())logits_fake = self.discriminator(reconstructions.contiguous().detach())else:logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),"{}/logits_real".format(split): logits_real.detach().mean(),"{}/logits_fake".format(split): logits_fake.detach().mean()}return d_loss, log
我看完源码,有三个问题:
- 问题1(未解决):为什么要把感知损失p_loss和未求均值的重建损失rec_loss相加?因为通过debug可知,p_loss没有空间维度,是单个值,但rec_loss是保留空间维度的,例如输入数据是一张三通道512×512大小的图像, 则rec_loss的维度为(1,3,512,512),由于广播机制,两者相加会导致感知损失p_loss会被添加到每个像素位置的重建损失rec_loss上。这个问题我到现在还没有理解,如果有懂的大佬可以在评论区指出。
- 问题2:为什么负对数损失的计算公式是那样?它的意义是什么?
- 问题3:为什么要自适应的计算对抗损失的权重系数?
问题2
为什么要使用负对数损失?
负对数损失只是最后结果的形式,我们还需要了解具体的过程(我个人的理解,不一定正确):我们通常假设数据服从某种概率分布(通常为高斯分布),因此我们也可以把重建误差看作服从均值为0,方差为的高斯分布:
其概率密度函数为:
对两边取负对数:
不考虑常数部分的话进一步可化简为:
上述公式的具体含义是每个像素位置重建误差,但注意,代码实现里假设每个像素位置的方差是一样的,即默认图像所有区域重建的不确定性(难度)是一样的,这显然是一种简易的方法。我们可以将其看作是对重建误差的缩放,一方面,第二项正则项要求方差越小越好;另一方面,也进一步约束第一项中分子部分的重建误差越小越好。对数方差作为模型的参数是可学习的。
但这里还有个小问题,按照上述公式,重建误差应该使用L2损失,但代码中则使用了L1损失,具体为什么我没有搞懂。。。。
问题3
为什么要自适应的计算对抗损失的权重系数?
首先说答案:目的是为了平衡负对数损失和对抗损失,避免由单个损失主导模型优化的方向,稳定训练过程。结合下面的源码来了解动态权重是如何计算的:
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):if last_layer is not None:# 得到nll_loss对于last_layer参数的梯度nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]else:nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]# 根据两者梯度的比值计算对抗损失的权重d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)# 限制权重的大小范围d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()d_weight = d_weight * self.discriminator_weightreturn d_weight
在讲解源码具体作用前,我们首先需要回忆一下损失函数在模型训练过程中的作用:
在模型优化过程中,起决定作用的是损失函数的梯度,而不是具体的损失函数数值。以爬山举例:损失函数的数值表示你现在的海拔,梯度则表示你所处位置有多陡峭;如果你在山顶(loss很大),但脚下很平缓(梯度很小),那么对应的模型参数更新缓慢;如果你在半山腰(loss不大),但坡很陡峭(梯度很大),对应模型参数更新较快。如果存在多个损失函数,那么这多个损失函数的梯度共同决定了模型参数优化的方向(想象向量相加的结果)和大小(不考虑学习率的话)。结合具体任务,重建损失(图像像素差异)和对抗损失(真实性)相当于从不同角度评价了生成图像的质量。
回到代码,语句torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]可得到负对数损失nll_loss对于last_layer模型参数的梯度nll_grads,在实际训练过程中,last_layer传入的是VAE-decoder最后一层参数,因为这一层离图像空间最近;同理也可得到对抗损失g_loss对于last_layer的梯度g_grads。对抗损失权重计算公式为d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)。如果没有g_grads很大,那么模型优化方向主要由判别损失主导,反之则会由负对数损失主导,因此需要在训练过程中动态的计算判别损失的权重,避免让某个损失主导模型优化方向。当然这相当于一个训练的trick,并不是必须的。