1. 从自编码器(AE)到变分自编码器(VAE)
自编码器(AutoEncoder, AE)
基本结构:
自编码器是一种无监督学习模型,通常由两个部分组成:
- 编码器(Encoder):将输入图像映射为一个低维特征向量(code);
- 解码器(Decoder):将该特征向量还原回图像。
训练的目标是使解码器输出的图像尽可能接近原始输入图像,即最小化重构误差。
应用场景:
AE 广泛应用于图像压缩领域。例如,谷歌曾使用自编码器进行图片传输优化:仅传输编码器提取的特征表示,在客户端通过解码器还原图像,从而实现高效的数据传输。
局限性:
尽管 AE 在图像重建方面表现良好,但它不具备生成新图像的能力。因为其编码空间中的每个点都来源于真实图像的编码结果,无法直接通过随机采样得到有意义的图像。
思考与改进:如何让 AE 具备图像生成能力?
核心问题:
传统 AE 的特征向量(code)是一个确定性的点,只有这个点可以被解码为图像,其他位置可能生成噪声或无效图
解决思路:
为了实现图像生成能力,我们希望编码器输出的特征向量能够在潜变量空间中具有某种统计分布特性,使得我们可以从该分布中任意采样,并通过解码器生成合理的图像。
我们将“新月”图片映射成 μ = 1 μ=1 μ=1的正态分布,那么就相当于在1附近加了噪声,此时不仅1表示“新月”,1附近的数值也表示“新月”,只是1的时候最像“新月”。将"满月"映射成 μ = 10 μ=10 μ=10的正态分布,10的附近也都表示“满月”。那么code=5时,就同时拥有了“新月”和“满月”的特点,那么这时候decode出来的大概率就是“半月”了。这就是VAE的思想。
变分自编码器(Variational AutoEncoder, VAE)
核心思想
VAE 的核心在于将编码过程从“输出一个点”转变为“输出一个分布”,通常是高斯分布。具体来说:
- 编码器不再输出单一的特征向量,而是输出一个均值和方差;
- 然后通过采样方式,从该分布中获得一个具体的特征向量;
- 再将其输入解码器进行图像重建。
这样,整个编码空间就变得连续且有规律,任何符合该分布的点都可以被解码为一张有意义的图像。
Vae的结构:
变分自编码器(VAE)与传统自编码器(AE)在整体架构上相似,都包含一个编码器(Encoder)和一个解码器(Decoder)。它们的核心区别在于编码器输出的内容不同。
🔹 自编码器(AE)
- 编码器直接输出一个确定性的向量 z z z,也称为潜变量(latent code);
- 这个向量是唯一的,不能随机采样;
- 解码器仅能对来自编码器的点进行有效解码,其他位置的向量可能会生成无意义的图像。
🔹变分自编码器(VAE)
- 编码器不再输出单一的潜变量 z z z,而是输出:
- 一组正态分布的均值: μ 1 , μ 2 , . . . . . , μ n μ_1,μ_2,.....,μ_n μ1,μ2,.....,μn
- 一组对应的标准差: σ 1 , σ 2 , . . . . . . , σ n σ_1,σ_2,......,σ_n σ1,σ2,......,σn
- 即对于每个维度 i i i,假设其服从高斯分布:
z i ∼ N ( μ i , σ i 2 ) z_i \sim \mathcal{N}(\mu_i, \sigma_i^2) zi∼N(μi,σi2)
然后通过重参数化技巧(reparameterization trick),从这些分布中采样出具体的潜变量
z i = μ i + σ i ⋅ ϵ , 其中 ϵ ∼ N ( 0 , 1 ) z_i=\mu_i + \sigma_i \cdot \epsilon, \quad \text{其中 } \epsilon \sim \mathcal{N}(0, 1) zi=μi+σi⋅ϵ,其中 ϵ∼N(0,1)
最终得到潜向量:
z = ( z 1 , z 2 , . . . , z n ) z=(z_1,z_2,...,z_n) z=(z1,z2,...,zn)
将其输入解码器进行重建。
公式推导:
假设我们有 N N N 个样本:
X = ( x 1 , x 2 , ⋯ , x N ) X= \left( x^{1},x^{2}, \cdots,x^{N} \right) X=(x1,x2,⋯,xN)
对其进行似然估计:
∑ i = 1 N log P ( x i ) \begin{split}&\sum_{i=1}^{N}\log P(x^{i})\\ \end{split} i=1∑NlogP(xi)
先单独看看里面某一个样本的似然,某个样本记为 x x x:
log P ( x ) = log P ( x , z ) P ( z ∣ x ) \log P \left( x \right)= \log \frac{P \left( x,z \right)}{P \left( z|x \right)} logP(x)=logP(z∣x)P(x,z)
最大似然估计
在概率论与数理统计中我们就学习过极大似然估计法,它的用途是从数据中估计概率模型的参数。所谓训练一个生成模型,就是给定一堆样本 x ( 1 ) , … , x ( N ) x^{(1)},…,x^{(N)} x(1),…,x(N),从这堆样本中估计出最优的模型参数 θ ∗ θ^∗ θ∗,要做的事情其实是一样的。(都是利用部分数据去估计总体分布)
这个分布通俗来讲可以理解为数据的共性特征,拿猫的图片来训练VAE,它能生成和训练集不一样的猫,但是它不会生成狗出来,它不是“创造”数据,而是“模仿”训练数据所代表的分布
先验分布 p ( z ) p(z) p(z) 是我们设定的分布,通常是标准正态分布:
z ∼ p ( z ) = N ( 0 , I ) z \sim p(z)=\mathcal{N}(0, I) z∼p(z)=N(0,I)
根据贝叶斯公式:
两个随机变量 x x x 和 z z z 的联合概率分布可以写成两种形式: p ( x , z ) = p ( z ∣ x ) p ( x ) = p ( x ∣ z ) p ( z ) p(x, z)=p(z|x)p(x)=p(x|z)p(z) p(x,z)=p(z∣x)p(x)=p(x∣z)p(z)
p ( z ∣ x ) = p ( x ∣ z ) p ( z ) p ( x ) p(z|x)=\frac{p(x|z)p(z)}{p(x)} p(z∣x)=p(x)p(x∣z)p(z)
很明显分母 p ( x ) = ∫ p ( x ∣ z ) p ( z ) d z p(x)=\int p(x|z) p(z) \, dz p(x)=∫p(x∣z)p(z)dz是个高维积分,无法直接计算准确的后验分布 P ( z ∣ x ) P(z|x) P(z∣x),用的是后验分布的一个近似,记为 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(z∣x),其中 ϕ {\phi} ϕ 代表的是 q q q 分布的参数, q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(z∣x) 相当于编码器, P θ ( x ∣ z ) P_θ(x|z) Pθ(x∣z) 相当于解码器,训练结束后只需要保留解码器即可,解码器就是我们想要的生成模型。
与前面提到的似然 P θ ( x ) P_{ \theta} \left( x \right) Pθ(x)类似,为了表达的方便,同样隐去参数,得到 q ( x ∣ z ) q \left( x|z \right) q(x∣z)
log P ( x ) \log P \left( x \right) logP(x)等式左右分别在 z 服从 q ( z ∣ x ) q \left( z|x \right) q(z∣x)分布的条件下求期望
左边是 log P ( x ) \log P \left( x \right) logP(x),它不依赖于 z z z,相当于对常数求期望,所以对任何关于 z z z 的分布求期望还是它自己:
E z ∼ q ( z ∣ x ) [ log P ( x ) ] = log P ( x ) \mathbb{E}_{z\sim q(z|x)}[\log P \left( x \right)]= \log P \left( x \right) Ez∼q(z∣x)[logP(x)]=logP(x)
右边展开并拆解:
E z ∼ q ( z ∣ x ) [ log P ( x , z ) P ( z ∣ x ) ] = E z ∼ q ( z ∣ x ) [ log P ( x , z ) − log P ( z ∣ x ) ] \mathbb{E}_{z \sim q(z|x)} \left[\log \frac{P(x, z)}{P(z|x)} \right]=\mathbb{E}_{z \sim q(z|x)} \left[\log P(x, z) - \log P(z|x) \right] Ez∼q(z∣x)[logP(z∣x)P(x,z)]=Ez∼q(z∣x)[logP(x,z)−logP(z∣x)]
利用期望的线性性质:
= E z ∼ q ( z ∣ x ) [ log P ( x , z ) ] − E z ∼ q ( z ∣ x ) [ log P ( z ∣ x ) ] =\mathbb{E}_{z\sim q(z|x)}[\log P(x,z)]-\mathbb{E}_{z\sim q(z|x)}[ \log P(z|x)] =Ez∼q(z∣x)[logP(x,z)]−Ez∼q(z∣x)[logP(z∣x)]
进一步整理可以把联合分布 P ( x , z ) P(x,z) P(x,z)写成:
P ( x , z ) = P ( x ∣ z ) P ( z ) {P(x,z)=P(x|z)P(z)} P(x,z)=P(x∣z)P(z)
所以:
log P ( x , z ) = log P ( x ∣ z ) + log P ( z ) \log P \left( x,z \right)= \log P \left( x|z \right)+ \log P \left( z \right) logP(x,z)=logP(x∣z)+logP(z)
代入上面的式子:
E z ∼ q ( z ∣ x ) [ log P ( x , z ) ] = E z ∼ q ( z ∣ x ) [ log P ( x ∣ z ) ] + E z ∼ q ( z ∣ x ) [ log P ( z ) ] \mathbb{E}_{z\sim q(z|x)}[\log P(x,z)]=\mathbb{E}_{z\sim q(z|x)}[\log P(x|z)]+ \mathbb{E}_{z\sim q(z|x)}[\log P(z)] Ez∼q(z∣x)[logP(x,z)]=Ez∼q(z∣x)[logP(x∣z)]+Ez∼q(z∣x)[logP(z)]
整理一下:
log P ( x ) = E z ∼ q ( z ∣ x ) [ log P ( x ∣ z ) ] − D K L ( q ( z ∣ x ) ∥ P ( z ) ) \log P(x)=\mathbb{E}_{z\sim q(z|x)}[\log P(x|z)]-D_{\mathrm{KL}}(q(z|x)\|P(z)) logP(x)=Ez∼q(z∣x)[logP(x∣z)]−DKL(q(z∣x)∥P(z))
有时候还可以写成:
log P ( x ) = E z ∼ q [ log P ( x , z ) − log q ( z ) ] + D K L ( q ( z ) ∥ P ( z ∣ x ) ) \log P(x)=\mathbb{E}_{z\sim q}[\log P(x,z)-\log q(z)]+D_{\mathrm{KL}}(q(z)\|P( z|x)) logP(x)=Ez∼q[logP(x,z)−logq(z)]+DKL(q(z)∥P(z∣x))
D K L ( q ( z ) ∥ p ( z ∣ x ) ) D_{\mathrm{KL}}(q(z)\|p( z|x)) DKL(q(z)∥p(z∣x))是大于等于0的(当且仅当两个分布一致时等于0)所以:
log p ( x ) ≥ E z ∼ q [ log p ( x , z ) − log q ( z ) ] ≡ C q \log p(x) \geq \mathbb{E}_{z \sim q}[\log p(x, z) - \log q(z)] \equiv \mathcal{C}_q logp(x)≥Ez∼q[logp(x,z)−logq(z)]≡Cq
这个下界就是 ELBO(Evidence Lower BOund):
log p ( x ) ≥ E z ∼ q [ log p ( x , z ) − log q ( z ) ] ⏟ ELBO \boxed{\log p \left( x \right) \geq \underbrace{\mathbb{E}_{z \sim q}[ \log p \left( x,z \right)- \log q \left( z \right)]}_{\text{ELBO}}} logp(x)≥ELBO Ez∼q[logp(x,z)−logq(z)]
也叫变分下界:
log P ( x ) ≥ ∫ z log P ( x , z ) q ( z ∣ x ) q ( z ∣ x ) d z \log P \left( x \right) \geq \int_{z} \log \frac{P \left( x,z \right)}{q \left( z|x \right)}q \left( z|x \right) dz logP(x)≥∫zlogq(z∣x)P(x,z)q(z∣x)dz
所以优化目标从最大化似然函数,转换到最大化似然的下界:
max θ , ϕ ∫ z log P ( x , z ) q ( z ∣ x ) q ( z ∣ x ) d z \max_{\theta,\phi} \int_{z} \log \frac{P \left( x,z \right)}{q \left( z|x \right)}q \left( z|x \right) dz θ,ϕmax∫zlogq(z∣x)P(x,z)q(z∣x)dz
= max ∫ z log P ( x ∣ z ) P ( z ) q ( z ∣ x ) q ( z ∣ x ) d z =\max \int_{z} \log \frac{P \left( x|z \right) P \left( z \right)}{q \left( z|x \right)}q \left( z|x \right) dz =max∫zlogq(z∣x)P(x∣z)P(z)q(z∣x)dz
= max ∫ z ( log P ( z ) q ( z ∣ x ) + log P ( x ∣ z ) ) q ( z ∣ x ) d z = \max \int_{z} \left( \log \frac{P \left( z \right)}{q \left( z|x \right)}+ \log P \left( x|z \right) \right) q \left( z|x \right) dz =max∫z(logq(z∣x)P(z)+logP(x∣z))q(z∣x)dz
= max ∫ z log P ( x ∣ z ) q ( z ∣ x ) d z − ∫ log q ( z ∣ x ) P ( z ) q ( z ∣ x ) d z =\max\int_{z}\log P(x|z)q(z|x)dz-\int\log\frac{q(z|x)}{P(z)}q(z|x )dz =max∫zlogP(x∣z)q(z∣x)dz−∫logP(z)q(z∣x)q(z∣x)dz
= max ( E z ∼ q ( z ∣ x ) [ log P ( x ∣ z ) ] ⏟ (1) − K L ( q ( z ∣ x ) ∥ P ( z ) ) ⏟ (2) ) =\max \left( \underbrace{\mathbb{E}_{z \sim q(z|x)}[\log P(x|z)]}_{\text{(1)}} - \underbrace{KL(q(z|x) \| P(z))}_{\text{(2)}} \right) =max (1) Ez∼q(z∣x)[logP(x∣z)]−(2) KL(q(z∣x)∥P(z))
第一项就是从 q ( z ∣ x ) q(z|x) q(z∣x) 采样 z z z,再让 log P ( x ∣ z ) \log P \left( x|z \right) logP(x∣z)概率最大,重构代价最小,这是可以通过神经网路训练实现的。 同时还要最小化KL散度
重构损失希望 σ → 0 σ→0 σ→0,一个分布反差为0 说明这个随机变量取确定值的概率为1,这恰好就是AE的训练目标,但 KL散度项起到了“平衡”的作用:
KL 散度loss的公式如下:(后面有证明过程)
min 1 2 ∑ j = 1 J ( σ j 2 + μ j 2 − log σ j 2 − 1 ) \min \frac{1}{2} \sum_{j=1}^{J} \left( \sigma_{j}^{2}+ \mu_{j}^{2}- \log \sigma_{j}^{2}-1 \right) min21j=1∑J(σj2+μj2−logσj2−1) 注意最后一项: − l o g σ j 2 −logσ_j^2 −logσj2
- 当 σ → 0 σ→0 σ→0 时, l o g σ j 2 → − ∞ logσ_j^2→-∞ logσj2→−∞ , , , − l o g σ j 2 → + ∞ -logσ_j^2→+∞ −logσj2→+∞
- 导致 KL 损失变得非常大,从而增加总损失
- 模型为了减少总损失,就会避免让 σ σ σ 太小
细化目标函数:
1. 最小化KL散度
q ( z ∣ x ) q(z|x) q(z∣x) 需要逼近 p ( z ) p(z) p(z),而 P ( z ) ∼ N ( 0 , I ) P \left( z \right) \sim N \left( 0,I \right) P(z)∼N(0,I)的多维高斯分布,并且各个维度之间相互独立,要让KL最小, q ( z ∣ x ) q(z|x) q(z∣x)也要是一个多维高斯分布
那么最小化其KL散度,只需要对每一个维度求KL最小即可,单独看某一个维度,设某一个维度为 z j z_j zj ,设 q ( z j ∣ x ) ∼ N ( μ ϕ , σ ϕ 2 ) q \left( z_{j}|x \right) \sim N \left( \mu_{ \phi}, \sigma_{ \phi}^{2} \right) q(zj∣x)∼N(μϕ,σϕ2)(后续为了简便,也将参数 ϕ \phi ϕ 隐去):
min K L ( q ( z j ∣ x ) ∣ ∣ P ( z j ) ) \min KL \left( q \left( z_{j}|x \right)||P \left( z_{j} \right) \right) minKL(q(zj∣x)∣∣P(zj))
= min ∫ q ( z j ∣ x ) log q ( z j ∣ x ) P ( z j ) d z j = \min \int q \left( z_{j}|x \right) \log \frac{q \left( z_{j}|x \right)}{P \left( z_{j} \right)}dz_{j} =min∫q(zj∣x)logP(zj)q(zj∣x)dzj
= min ∫ q ( z j ∣ x ) log N ( μ , σ 2 ) N ( 0 , 1 ) d z j = \min \int q \left( z_{j}|x \right) \log \frac{N \left( \mu, \sigma^{2} \right)}{N \left( 0,1 \right)}dz_{j} =min∫q(zj∣x)logN(0,1)N(μ,σ2)dzj
= min ∫ q ( z j ∣ x ) log 1 2 π σ e x p { − ( z j − μ ) 2 2 σ 2 } 1 2 π e x p { − z j 2 2 } d z j =\min\int q(z_{j}|x)\log\frac{\frac{1}{\sqrt{2\pi}\sigma}\mathrm{ exp}\{-\frac{(z_{j}-\mu)^{2}}{2\sigma^{2}}\}}{\frac{1}{\sqrt{2\pi}}\mathrm{ exp}\{-\frac{z_{j}^{2}}{2}\}}dz_{j} =min∫q(zj∣x)log2π1exp{−2zj2}2πσ1exp{−2σ2(zj−μ)2}dzj
= min ∫ q ( z j ∣ x ) log 1 σ exp { − ( z j − μ ) 2 2 σ 2 } exp { − z j 2 2 } d z j = \min \int q \left( z_{j}|x \right) \log \frac{ \frac{1}{ \sigma} \exp \{- \frac{ \left( z_{j}- \mu \right)^{2}}{2 \sigma^{2}} \}}{ \exp \{- \frac{z_{j}^{2}}{2} \}}dz_{j} =min∫q(zj∣x)logexp{−2zj2}σ1exp{−2σ2(zj−μ)2}dzj
= min ∫ q ( z j ∣ x ) ( log 1 σ + log exp { − ( z j − μ ) 2 2 σ 2 } exp { − z j 2 2 } ) d z j =\min\int q(z_{j}|x)\left(\log\frac{1}{\sigma}+\log\frac{\exp\{- \frac{(z_{j}-\mu)^{2}}{2\sigma^{2}}\}}{\exp\{-\frac{z_{j}^{2}}{2}\}}\right)dz_ {j} =min∫q(zj∣x) logσ1+logexp{−2zj2}exp{−2σ2(zj−μ)2} dzj
= min ∫ q ( z j ∣ x ) ( log 1 σ + log exp { − ( z j − μ ) 2 2 σ 2 + z j 2 2 } ) d z j =\min\int q(z_{j}|x)\left(\log\frac{1}{\sigma}+\log\exp\{-\frac{( z_{j}-\mu)^{2}}{2\sigma^{2}}+\frac{z_{j}^{2}}{2}\}\right)dz_{j} =min∫q(zj∣x)(logσ1+logexp{−2σ2(zj−μ)2+2zj2})dzj
log以e为底:
= min ∫ q ( z j ∣ x ) ( log 1 σ − ( z j − μ ) 2 2 σ 2 + z j 2 2 ) d z j = \min \int q \left( z_{j}|x \right) \left( \log \frac{1}{ \sigma}- \frac{ \left( z_{j}- \mu \right)^{2}}{2 \sigma^{2}}+ \frac{z_{j}^{2}}{2} \right) dz_{j} =min∫q(zj∣x)(logσ1−2σ2(zj−μ)2+2zj2)dzj
= min ∫ q ( z j ∣ x ) 1 2 ( 2 log 1 σ − ( z j − μ ) 2 σ 2 + z j 2 ) d z j = \min \int q \left( z_{j}|x \right) \frac{1}{2} \left( 2 \log \frac{1}{ \sigma}- \frac{ \left( z_{j}- \mu \right)^{2}}{ \sigma^{2}}+z_{j}^{2} \right) dz_{j} =min∫q(zj∣x)21(2logσ1−σ2(zj−μ)2+zj2)dzj
即:
min K L ( q ( z j ∣ x ) ∣ ∣ P ( z j ) ) = min 1 2 ∫ q ( z j ∣ x ) ( z j 2 − log σ 2 − ( z j − μ ) 2 σ 2 ) d z j \min KL(q(z_{j}|x)||P(z_{j})) = \min\frac{1}{2}\int q(z_{j}|x)\left(z_{j}^{2 }-\log\sigma^{2}-\frac{(z_{j}-\mu)^{2}}{\sigma^{2}}\right)dz_{j} minKL(q(zj∣x)∣∣P(zj))=min21∫q(zj∣x)(zj2−logσ2−σ2(zj−μ)2)dzj
可以分为三部分:
min K L ( q ( z j ∣ x ) ∣ ∣ P ( z j ) ) = min 1 2 ( ∫ q ( z j ∣ x ) z j 2 d z j − ∫ q ( z j ∣ x ) log σ 2 d z j − ∫ z q ( z j ∣ x ) ( z j − μ ) 2 σ 2 d z j ) \min KL(q(z_{j}|x)||P(z_{j}))=\min\frac{1}{2}\Biggl(\int q(z_{j}|x)z_{j}^{2} dz_{j}-\int q(z_{j}|x)\log\sigma^{2}dz_{j}-\int_{z}q(z_{j}|x)\frac{(z_{j}-\mu)^{2}}{ \sigma^{2}}dz_{j}\Biggr) minKL(q(zj∣x)∣∣P(zj))=min21(∫q(zj∣x)zj2dzj−∫q(zj∣x)logσ2dzj−∫zq(zj∣x)σ2(zj−μ)2dzj)
- ∫ z q ( z j ∣ x ) z j 2 d z = E [ z j 2 ] = D ( z j ) + E [ z j ] 2 = σ 2 + μ 2 \int_{z}q(z_{j}|x)z_{j}^{2}dz=\mathbb{E}[z_{j}^{2}]=D(z_{j})+\mathbb{E}[z_{j}] ^{2}=\sigma^{2}+\mu^{2} ∫zq(zj∣x)zj2dz=E[zj2]=D(zj)+E[zj]2=σ2+μ2
D ( X ) = E ( X 2 ) − [ E ( X ) ] 2 D \left( X \right)=E \left( X^{2} \right)- \left[ E \left( X \right) \right]^{2} D(X)=E(X2)−[E(X)]2
- ∫ z q ( z j ∣ x ) log σ 2 d z j = log σ 2 ∫ z q ( z j ∣ x ) d z j = log σ 2 \int_{z}q(z_{j}|x)\log\sigma^{2}dz_{j}=\log\sigma^{2}\int_{z}q(z_{j}|x)dz_{j}= \log\sigma^{2} ∫zq(zj∣x)logσ2dzj=logσ2∫zq(zj∣x)dzj=logσ2
- ∫ z q ( z j ∣ x ) ( z j − μ ) 2 σ 2 d z j = 1 σ 2 ∫ z q ( z j ∣ x ) ( z j − μ ) 2 d z j = 1 σ 2 E [ ( z j − μ ) 2 ] = 1 \int_{z}q(z_{j}|x)\frac{(z_{j}-\mu)^{2}}{\sigma^{2}}dz_{j}=\frac{1}{\sigma^{2} }\int_{z}q(z_{j}|x)(z_{j}-\mu)^{2}dz_{j}=\frac{1}{\sigma^{2}}\mathbb{E}[(z_{j} -\mu)^{2}]=1 ∫zq(zj∣x)σ2(zj−μ)2dzj=σ21∫zq(zj∣x)(zj−μ)2dzj=σ21E[(zj−μ)2]=1
第三步公式推导如下:
1 σ 2 E [ ( z j − μ ) 2 ] \frac{1}{ \sigma^{2}}E \left[ \left( z_{j}- \mu \right)^{2} \right] σ21E[(zj−μ)2]
= 1 σ 2 E [ z j 2 − 2 μ z j + μ 2 ] = \frac{1}{ \sigma^{2}}E \left[ z_{j}^{2}-2 \mu z_{j}+ \mu^{2} \right] =σ21E[zj2−2μzj+μ2]
= 1 σ 2 ( E ( z j 2 ) − 2 μ E ( z j ) + μ 2 ) = \frac{1}{ \sigma^{2}} \left( E \left( z_{j}^{2} \right)-2 \mu E \left( z_{j} \right)+ \mu^{2} \right) =σ21(E(zj2)−2μE(zj)+μ2)
= 1 σ 2 ( σ 2 + μ 2 − 2 μ 2 + μ 2 ) = \frac{1}{ \sigma^{2}} \left( \sigma^{2}+ \mu^{2}-2 \mu^{2}+ \mu^{2} \right) =σ21(σ2+μ2−2μ2+μ2)
所以对于所有维度的KL散度,有(假设隐变量有J维):
min 1 2 ∑ j = 1 J ( σ j 2 + μ j 2 − log σ j 2 − 1 ) \min \frac{1}{2} \sum_{j=1}^{J} \left( \sigma_{j}^{2}+ \mu_{j}^{2}- \log \sigma_{j}^{2}-1 \right) min21j=1∑J(σj2+μj2−logσj2−1)
2. 最小重构损失:
E z ∼ q ( z ∣ x ) [ log P ( x ∣ z ) ] {\mathbb{E}_{z \sim q(z|x)}[\log P(x|z)]} Ez∼q(z∣x)[logP(x∣z)]
这个积分通常没有解析解,只能通过采样估计。
E z ∼ q ( z ∣ x ) [ log P ( x ∣ z ) ] = ∫ z log P ( x ∣ z ) q ( z ∣ x ) d z ≈ 1 n ∑ i = 1 n log P ( x i ∣ z i ) \mathbb{E}_{z \sim q( z \mid x )}\left[\log P(x|z)\right]= \int_{z} \log P(x|z)q(z|x)dz \approx\frac{1}{n}\sum_{i=1}^{n}\log P(x^{i}|z^{i}) Ez∼q(z∣x)[logP(x∣z)]=∫zlogP(x∣z)q(z∣x)dz≈n1i=1∑nlogP(xi∣zi)
使用蒙特卡罗近似,从 q ( z ∣ x ; ϕ ) q(z|x;ϕ) q(z∣x;ϕ)采样n个z样本,然后带入 log P ( x ∣ z ) \log P \left( x|z \right) logP(x∣z),求出均值,不过在实践中,对于每个样本 x,只需要从 q ( z ∣ x ; ϕ ) {q( z \mid x;ϕ )} q(z∣x;ϕ)采样一次 z 就能正常训练 VAE 了,效率很高,可见重要性采样的思想还是很有效的。
所以我们可以写成:
E z ∼ q ( z ∣ x ; ϕ ) [ log p ( x ∣ z ; θ ) ] ≈ log p ( x ∣ z ; θ ) , z ∼ q ( z ∣ x ; ϕ ) \mathbb{E}_{z \sim q(z|x; \phi)}[\log p(x|z; \theta)] \approx \log p(x|z; \theta), \quad z \sim q(z|x; \phi) Ez∼q(z∣x;ϕ)[logp(x∣z;θ)]≈logp(x∣z;θ),z∼q(z∣x;ϕ)
在论文中,其假设 P ( x ∣ z ) ∼ N ( μ , σ 2 I ) P \left( x|z\right) \sim N \left( \mu, \sigma^{2}I \right) P(x∣z)∼N(μ,σ2I),其中 μ \mu μ 和 σ \sigma σ 是需要神经网络去逼近,
由于各维度相互独立,概率密度函数形式如下:
p ( x ∣ z ; θ ) = ∏ i = 1 D N ( x i ; μ i , σ i 2 ) = ( ∏ i = 1 D 1 2 π σ i ) exp ( − 1 2 ∑ i = 1 D ( x i − μ i ) 2 σ i 2 ) p(x|z; \theta)=\prod_{i=1}^{D} \mathcal{N}(x_i; \mu_i, \sigma_i^2)=\left( \prod_{i=1}^{D} \frac{1}{\sqrt{2\pi\sigma_i}} \right) \exp\left( -\frac{1}{2} \sum_{i=1}^{D} \frac{(x_i - \mu_i)^2}{\sigma_i^2} \right) p(x∣z;θ)=i=1∏DN(xi;μi,σi2)=(i=1∏D2πσi1)exp(−21i=1∑Dσi2(xi−μi)2)
取对数后得到:
log p ( x ∣ z ; θ ) = − D 2 log ( 2 π ) − 1 2 ∑ i = 1 D log σ i 2 − 1 2 ∑ i = 1 D ( x i − μ i ) 2 σ i 2 \log p(x|z;\theta)=-\frac{D}{2} \log(2\pi) - \frac{1}{2} \sum_{i=1}^{D} \log \sigma_i^2 - \frac{1}{2} \sum_{i=1}^{D} \frac{(x_i - \mu_i)^2}{\sigma_i^2} logp(x∣z;θ)=−2Dlog(2π)−21i=1∑Dlogσi2−21i=1∑Dσi2(xi−μi)2
为了简化模型,通常会做以下两个简化假设:
- 所有维度的方差相同,即 σ i 2 = σ 2 σ_i^2=σ^2 σi2=σ2是一个常数;
- 解码器只需输出均值 μ μ μ,而不需要输出方差;
于是上式变为:
log p ( x ∣ z ; θ ) = C − 1 2 σ 2 ∑ i = 1 D ( x i − μ i ) 2 \log p(x|z; \theta)=C - \frac{1}{2\sigma^2} \sum_{i=1}^{D} (x_i - \mu_i)^2 logp(x∣z;θ)=C−2σ21i=1∑D(xi−μi)2
log p ( x ∣ z ; θ ) = C − 1 2 σ 2 ∑ i = 1 D ( x i − μ i ) 2 \log p(x|z; \theta)=C - \frac{1}{2\sigma^2} \sum_{i=1}^{D} (x_i - \mu_i)^2 logp(x∣z;θ)=C−2σ21i=1∑D(xi−μi)2
其中 C 是与 μ 无关的常数项。
因此,在最大化对数似然时,我们只需最小化平方误差项:
arg max θ log p ( x ∣ z ; θ ) = arg min θ ∑ i = 1 D ( x i − μ i ) 2 \arg\max_{\theta} \log p(x|z; \theta)=\arg\min_{\theta} \sum_{i=1}^{D} (x_i - \mu_i)^2 argθmaxlogp(x∣z;θ)=argθmini=1∑D(xi−μi)2
最大化这个等价于最小化 MSE(均方误差)。
重参数化技巧详解:实现 VAE 可微采样的关键
VAE 的编码器与 autoencoder 的编码器不同,autoencoder 的编码器直接输出隐变量 z 的值,而 VAE 编码器输出的是高斯分布的参数 μ , σ 2 \mu, \sigma^{2} μ,σ2,隐变量 z 是从分布 N ( z ; μ , σ 2 I ) N \left( z; \mu, \sigma^{2}I \right) N(z;μ,σ2I)中采样得到的。
z ∼ q ( z ∣ x ) = N ( μ , σ 2 ) z \sim q(z|x)=\mathcal{N}(\mu, \sigma^2) z∼q(z∣x)=N(μ,σ2)
但问题来了:
采样操作不是可导的,所以不能直接通过反向传播优化编码器参数 μ , σ \mu, \sigma μ,σ
解决方案:重参数化技巧(Reparameterization Trick)
为了解决这个问题,VAE 使用了所谓的 重参数化技巧,也称为 路径梯度估计器(Pathwise Gradient Estimator, PGE)。
核心思想:
将“采样”这一不可导的操作从参数依赖中分离出来,将参数依赖转移到一个确定性的、可导的变换函数中。
假定 q ( z ∣ x ) ∼ N ( μ , σ 2 ) q \left( z|x \right) \sim N \left( \mu, \sigma^{2} \right) q(z∣x)∼N(μ,σ2),那么可以构造一个概率分布 p ( ϵ ) ∼ N ( 0 , 1 ) p \left( \epsilon \right) \sim N \left( 0,1 \right) p(ϵ)∼N(0,1)。有:
z = μ + ϵ σ z= \mu+ \epsilon \sigma z=μ+ϵσ
我们从 p ( ϵ ) p \left( \epsilon \right) p(ϵ)采样,然后利用上述公式,就相当于得到了从 q ( z ∣ x ) q(z|x) q(z∣x)采样得采样值
证明:
E ( z ) = E ( μ ) + E ( ϵ ) σ = μ V a r ( z ) = E [ ( z − μ ) 2 ] = E [ ( ϵ σ ) 2 ] = σ 2 \begin{split}\mathbb{E} \left( z \right)=& \mathbb{E} \left( \mu \right)+ \mathbb{E} \left( \epsilon \right) \sigma= \mu \\ Var \left( z \right)=& \mathbb{E} \left[ \left( z- \mu \right)^{2} \right]= \mathbb{E} \left[ \left( \epsilon \sigma \right)^{2} \right]= \sigma^{2} \end{split} E(z)=Var(z)=E(μ)+E(ϵ)σ=μE[(z−μ)2]=E[(ϵσ)2]=σ2
在 PyTorch 中实现:
# 编码器输出 mu 和 log_var(数值稳定)
mu, log_var = encoder(x)
sigma = torch.exp(0.5 * log_var)# 从标准正态分布采样 epsilon
epsilon = torch.randn_like(sigma)# 重参数化
z = mu + sigma * epsilon
代码实战:实现生成手写数字图片
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
class VAE(nn.Module):def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):super(VAE, self).__init__()# 编码器网络self.encoder = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, latent_dim * 2))# 解码器网络self.decoder = nn.Sequential(nn.Linear(latent_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, input_dim),nn.Sigmoid())def encode(self, x):h = self.encoder(x)mu, log_var = torch.chunk(h, 2, dim=1)return mu, log_vardef reparameterize(self, mu, log_var):std = torch.exp(0.5*log_var)eps = torch.randn_like(std)z = mu + eps*stdreturn zdef decode(self, z):return self.decoder(z)def forward(self, x):mu, log_var = self.encode(x.view(-1, 784))z = self.reparameterize(mu, log_var)return self.decode(z), mu, log_vardef loss_function(recon_x, x, mu, log_var):BCE = nn.functional.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())return BCE + KLD
# 数据加载
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
vae = VAE()
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
for epoch in range(10):vae.train()train_loss = 0for i, (data, _) in enumerate(train_loader):optimizer.zero_grad()recon_batch, mu, logvar = vae(data)loss = loss_function(recon_batch, data, mu, logvar)loss.backward()train_loss += loss.item()optimizer.step()print(f'Epoch {epoch}, Loss: {train_loss / len(train_loader.dataset)}')
with torch.no_grad():sample = torch.randn(64, 20).to('cpu') # 假设latent_dim为20sample = vae.decode(sample).view(64, 1, 28, 28)
plt.figure(figsize=(10, 10))
for i in range(64):ax = plt.subplot(8, 8, i + 1)plt.imshow(sample[i].squeeze().numpy(), cmap='gray')plt.axis("off")
plt.show()
代码相当简洁,简单跑了10个epoch,结果见图,值得一提得是这还是无条件生成,有兴趣得还可以了解一下Conditional VAE,或者叫CVAE
总结:
来自苏剑林的 VAE博客:变分自编码器一:原来是这么一回事
VAE的本质是什么?VAE虽然也称是AE(AutoEncoder)的一种,但它的做法(或者说它对网络的诠释)是别具一格的。在VAE中,它的Encoder有两个,一个用来计算均值,一个用来计算方差,这已经让人意外了:Encoder不是用来Encode的,是用来算均值和方差的,这真是大新闻了,还有均值和方差不都是统计量吗,怎么是用神经网络来算的?
事实上,我觉得VAE从让普通人望而生畏的变分和贝叶斯理论出发,最后落地到一个具体的模型中,虽然走了比较长的一段路,但最终的模型其实是很接地气的:它本质上就是在我们常规的自编码器的基础上,对encoder的结果(在VAE中对应着计算均值的网络)加上了“高斯噪声”,使得结果decoder能够对噪声有鲁棒性;而那个额外的KLloss(目的是让均值为0,方差为1),事实上就是相当于对encoder的一个正则项,希望encoder出来的东西有零均值。
那另外一个encoder(对应着计算方差的网络)的作用呢?它是用来动态调节噪声的强度的。直觉上来想,当decoder还没有训练好时(重构误差远大于KLloss),就会适当降低噪声(KLloss增加),使得拟合起来容易一些(重构误差开始下降);反之,如果decoder训练得还不错时(重构误差小于KLloss),这时候噪声就会增加(KLloss减少),使得拟合更加困难了(重构误差又开始增加),这时候decoder就要想办法提高它的生成能力了。
公式推导学习:https://www.bilibili.com/video/BV1op421S7Ep/?spm_id_from=333.337.search-card.all.click
其他参考资料:http://www.gwylab.com/note-vae.html
李宏毅教程:https://www.bilibili.com/video/av15889450/?p=33
博客园:从极大似然估计到变分自编码器 - VAE 公式推导
博客园:AE(自动编码器)与VAE(变分自动编码器)简单理解