变分自编码器(VAE)

1. 从自编码器(AE)到变分自编码器(VAE)

自编码器(AutoEncoder, AE)

基本结构:

自编码器是一种无监督学习模型,通常由两个部分组成:

  • 编码器(Encoder):将输入图像映射为一个低维特征向量(code);
  • 解码器(Decoder):将该特征向量还原回图像。

训练的目标是使解码器输出的图像尽可能接近原始输入图像,即最小化重构误差。

应用场景:

AE 广泛应用于图像压缩领域。例如,谷歌曾使用自编码器进行图片传输优化:仅传输编码器提取的特征表示,在客户端通过解码器还原图像,从而实现高效的数据传输。

局限性:

尽管 AE 在图像重建方面表现良好,但它不具备生成新图像的能力。因为其编码空间中的每个点都来源于真实图像的编码结果,无法直接通过随机采样得到有意义的图像。
在这里插入图片描述

思考与改进:如何让 AE 具备图像生成能力?

核心问题:
传统 AE 的特征向量(code)是一个确定性的点,只有这个点可以被解码为图像,其他位置可能生成噪声或无效图
在这里插入图片描述

解决思路:

为了实现图像生成能力,我们希望编码器输出的特征向量能够在潜变量空间中具有某种统计分布特性,使得我们可以从该分布中任意采样,并通过解码器生成合理的图像。

我们将“新月”图片映射成 μ = 1 μ=1 μ=1的正态分布,那么就相当于在1附近加了噪声,此时不仅1表示“新月”,1附近的数值也表示“新月”,只是1的时候最像“新月”。将"满月"映射成 μ = 10 μ=10 μ=10的正态分布,10的附近也都表示“满月”。那么code=5时,就同时拥有了“新月”和“满月”的特点,那么这时候decode出来的大概率就是“半月”了。这就是VAE的思想。

变分自编码器(Variational AutoEncoder, VAE)

核心思想
VAE 的核心在于将编码过程从“输出一个点”转变为“输出一个分布”,通常是高斯分布。具体来说:

  • 编码器不再输出单一的特征向量,而是输出一个均值和方差;
  • 然后通过采样方式,从该分布中获得一个具体的特征向量;
  • 再将其输入解码器进行图像重建。

这样,整个编码空间就变得连续且有规律,任何符合该分布的点都可以被解码为一张有意义的图像。
在这里插入图片描述图片描述
Vae的结构
在这里插入图片描述

变分自编码器(VAE)与传统自编码器(AE)在整体架构上相似,都包含一个编码器(Encoder)和一个解码器(Decoder)。它们的核心区别在于编码器输出的内容不同。

🔹 自编码器(AE)

  • 编码器直接输出一个确定性的向量 z z z,也称为潜变量(latent code);
  • 这个向量是唯一的,不能随机采样;
  • 解码器仅能对来自编码器的点进行有效解码,其他位置的向量可能会生成无意义的图像。

🔹变分自编码器(VAE)

  • 编码器不再输出单一的潜变量 z z z,而是输出:
    • 一组正态分布的均值: μ 1 , μ 2 , . . . . . , μ n μ_1,μ_2,.....,μ_n μ1,μ2,.....,μn
    • 一组对应的标准差: σ 1 , σ 2 , . . . . . . , σ n σ_1,σ_2,......,σ_n σ1,σ2,......,σn
  • 即对于每个维度 i i i,假设其服从高斯分布:
    z i ∼ N ( μ i , σ i 2 ) z_i \sim \mathcal{N}(\mu_i, \sigma_i^2) ziN(μi,σi2)

然后通过重参数化技巧(reparameterization trick),从这些分布中采样出具体的潜变量
z i = μ i + σ i ⋅ ϵ , 其中  ϵ ∼ N ( 0 , 1 ) z_i=\mu_i + \sigma_i \cdot \epsilon, \quad \text{其中 } \epsilon \sim \mathcal{N}(0, 1) zi=μi+σiϵ,其中 ϵN(0,1)
最终得到潜向量:
z = ( z 1 , z 2 , . . . , z n ) z=(z_1,z_2,...,z_n) z=(z1,z2,...,zn)
将其输入解码器进行重建。

公式推导:

在这里插入图片描述
假设我们有 N N N 个样本:

X = ( x 1 , x 2 , ⋯ , x N ) X= \left( x^{1},x^{2}, \cdots,x^{N} \right) X=(x1,x2,,xN)
对其进行似然估计:

∑ i = 1 N log ⁡ P ( x i ) \begin{split}&\sum_{i=1}^{N}\log P(x^{i})\\ \end{split} i=1NlogP(xi)

先单独看看里面某一个样本的似然,某个样本记为 x x x
log ⁡ P ( x ) = log ⁡ P ( x , z ) P ( z ∣ x ) \log P \left( x \right)= \log \frac{P \left( x,z \right)}{P \left( z|x \right)} logP(x)=logP(zx)P(x,z)

最大似然估计
在概率论与数理统计中我们就学习过极大似然估计法,它的用途是从数据中估计概率模型的参数。所谓训练一个生成模型,就是给定一堆样本 x ( 1 ) , … , x ( N ) x^{(1)},…,x^{(N)} x(1),,x(N),从这堆样本中估计出最优的模型参数 θ ∗ θ^∗ θ,要做的事情其实是一样的。(都是利用部分数据去估计总体分布)
这个分布通俗来讲可以理解为数据的共性特征,拿猫的图片来训练VAE,它能生成和训练集不一样的猫,但是它不会生成狗出来,它不是“创造”数据,而是“模仿”训练数据所代表的分布

先验分布 p ( z ) p(z) p(z) 是我们设定的分布,通常是标准正态分布:

z ∼ p ( z ) = N ( 0 , I ) z \sim p(z)=\mathcal{N}(0, I) zp(z)=N(0,I)

根据贝叶斯公式:

两个随机变量 x x x z z z 的联合概率分布可以写成两种形式: p ( x , z ) = p ( z ∣ x ) p ( x ) = p ( x ∣ z ) p ( z ) p(x, z)=p(z|x)p(x)=p(x|z)p(z) p(x,z)=p(zx)p(x)=p(xz)p(z)

p ( z ∣ x ) = p ( x ∣ z ) p ( z ) p ( x ) p(z|x)=\frac{p(x|z)p(z)}{p(x)} p(zx)=p(x)p(xz)p(z)

很明显分母 p ( x ) = ∫ p ( x ∣ z ) p ( z ) d z p(x)=\int p(x|z) p(z) \, dz p(x)=p(xz)p(z)dz是个高维积分,无法直接计算准确的后验分布 P ( z ∣ x ) P(z|x) P(zx),用的是后验分布的一个近似,记为 q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx),其中 ϕ {\phi} ϕ 代表的是 q q q 分布的参数, q ϕ ( z ∣ x ) q_{\phi}(z|x) qϕ(zx) 相当于编码器, P θ ( x ∣ z ) P_θ(x|z) Pθ(xz) 相当于解码器,训练结束后只需要保留解码器即可,解码器就是我们想要的生成模型。

与前面提到的似然 P θ ( x ) P_{ \theta} \left( x \right) Pθ(x)类似,为了表达的方便,同样隐去参数,得到 q ( x ∣ z ) q \left( x|z \right) q(xz)

log ⁡ P ( x ) \log P \left( x \right) logP(x)等式左右分别在 z 服从 q ( z ∣ x ) q \left( z|x \right) q(zx)分布的条件下求期望

左边是 log ⁡ P ( x ) \log P \left( x \right) logP(x),它不依赖于 z z z,相当于对常数求期望,所以对任何关于 z z z 的分布求期望还是它自己:
E z ∼ q ( z ∣ x ) [ log ⁡ P ( x ) ] = log ⁡ P ( x ) \mathbb{E}_{z\sim q(z|x)}[\log P \left( x \right)]= \log P \left( x \right) Ezq(zx)[logP(x)]=logP(x)
右边展开并拆解:
E z ∼ q ( z ∣ x ) [ log ⁡ P ( x , z ) P ( z ∣ x ) ] = E z ∼ q ( z ∣ x ) [ log ⁡ P ( x , z ) − log ⁡ P ( z ∣ x ) ] \mathbb{E}_{z \sim q(z|x)} \left[\log \frac{P(x, z)}{P(z|x)} \right]=\mathbb{E}_{z \sim q(z|x)} \left[\log P(x, z) - \log P(z|x) \right] Ezq(zx)[logP(zx)P(x,z)]=Ezq(zx)[logP(x,z)logP(zx)]

利用期望的线性性质:
= E z ∼ q ( z ∣ x ) [ log ⁡ P ( x , z ) ] − E z ∼ q ( z ∣ x ) [ log ⁡ P ( z ∣ x ) ] =\mathbb{E}_{z\sim q(z|x)}[\log P(x,z)]-\mathbb{E}_{z\sim q(z|x)}[ \log P(z|x)] =Ezq(zx)[logP(x,z)]Ezq(zx)[logP(zx)]
进一步整理可以把联合分布 P ( x , z ) P(x,z) P(x,z)写成:
P ( x , z ) = P ( x ∣ z ) P ( z ) {P(x,z)=P(x|z)P(z)} P(x,z)=P(xz)P(z)
所以:
log ⁡ P ( x , z ) = log ⁡ P ( x ∣ z ) + log ⁡ P ( z ) \log P \left( x,z \right)= \log P \left( x|z \right)+ \log P \left( z \right) logP(x,z)=logP(xz)+logP(z)
代入上面的式子:
E z ∼ q ( z ∣ x ) [ log ⁡ P ( x , z ) ] = E z ∼ q ( z ∣ x ) [ log ⁡ P ( x ∣ z ) ] + E z ∼ q ( z ∣ x ) [ log ⁡ P ( z ) ] \mathbb{E}_{z\sim q(z|x)}[\log P(x,z)]=\mathbb{E}_{z\sim q(z|x)}[\log P(x|z)]+ \mathbb{E}_{z\sim q(z|x)}[\log P(z)] Ezq(zx)[logP(x,z)]=Ezq(zx)[logP(xz)]+Ezq(zx)[logP(z)]

整理一下:
log ⁡ P ( x ) = E z ∼ q ( z ∣ x ) [ log ⁡ P ( x ∣ z ) ] − D K L ( q ( z ∣ x ) ∥ P ( z ) ) \log P(x)=\mathbb{E}_{z\sim q(z|x)}[\log P(x|z)]-D_{\mathrm{KL}}(q(z|x)\|P(z)) logP(x)=Ezq(zx)[logP(xz)]DKL(q(zx)P(z))
有时候还可以写成:

log ⁡ P ( x ) = E z ∼ q [ log ⁡ P ( x , z ) − log ⁡ q ( z ) ] + D K L ( q ( z ) ∥ P ( z ∣ x ) ) \log P(x)=\mathbb{E}_{z\sim q}[\log P(x,z)-\log q(z)]+D_{\mathrm{KL}}(q(z)\|P( z|x)) logP(x)=Ezq[logP(x,z)logq(z)]+DKL(q(z)P(zx))
D K L ( q ( z ) ∥ p ( z ∣ x ) ) D_{\mathrm{KL}}(q(z)\|p( z|x)) DKL(q(z)p(zx))是大于等于0的(当且仅当两个分布一致时等于0)所以:
log ⁡ p ( x ) ≥ E z ∼ q [ log ⁡ p ( x , z ) − log ⁡ q ( z ) ] ≡ C q \log p(x) \geq \mathbb{E}_{z \sim q}[\log p(x, z) - \log q(z)] \equiv \mathcal{C}_q logp(x)Ezq[logp(x,z)logq(z)]Cq
这个下界就是 ELBO(Evidence Lower BOund):
log ⁡ p ( x ) ≥ E z ∼ q [ log ⁡ p ( x , z ) − log ⁡ q ( z ) ] ⏟ ELBO \boxed{\log p \left( x \right) \geq \underbrace{\mathbb{E}_{z \sim q}[ \log p \left( x,z \right)- \log q \left( z \right)]}_{\text{ELBO}}} logp(x)ELBO Ezq[logp(x,z)logq(z)]

也叫变分下界:
log ⁡ P ( x ) ≥ ∫ z log ⁡ P ( x , z ) q ( z ∣ x ) q ( z ∣ x ) d z \log P \left( x \right) \geq \int_{z} \log \frac{P \left( x,z \right)}{q \left( z|x \right)}q \left( z|x \right) dz logP(x)zlogq(zx)P(x,z)q(zx)dz
所以优化目标从最大化似然函数,转换到最大化似然的下界:
max ⁡ θ , ϕ ∫ z log ⁡ P ( x , z ) q ( z ∣ x ) q ( z ∣ x ) d z \max_{\theta,\phi} \int_{z} \log \frac{P \left( x,z \right)}{q \left( z|x \right)}q \left( z|x \right) dz θ,ϕmaxzlogq(zx)P(x,z)q(zx)dz
= max ⁡ ∫ z log ⁡ P ( x ∣ z ) P ( z ) q ( z ∣ x ) q ( z ∣ x ) d z =\max \int_{z} \log \frac{P \left( x|z \right) P \left( z \right)}{q \left( z|x \right)}q \left( z|x \right) dz =maxzlogq(zx)P(xz)P(z)q(zx)dz
= max ⁡ ∫ z ( log ⁡ P ( z ) q ( z ∣ x ) + log ⁡ P ( x ∣ z ) ) q ( z ∣ x ) d z = \max \int_{z} \left( \log \frac{P \left( z \right)}{q \left( z|x \right)}+ \log P \left( x|z \right) \right) q \left( z|x \right) dz =maxz(logq(zx)P(z)+logP(xz))q(zx)dz
= max ⁡ ∫ z log ⁡ P ( x ∣ z ) q ( z ∣ x ) d z − ∫ log ⁡ q ( z ∣ x ) P ( z ) q ( z ∣ x ) d z =\max\int_{z}\log P(x|z)q(z|x)dz-\int\log\frac{q(z|x)}{P(z)}q(z|x )dz =maxzlogP(xz)q(zx)dzlogP(z)q(zx)q(zx)dz
= max ⁡ ( E z ∼ q ( z ∣ x ) [ log ⁡ P ( x ∣ z ) ] ⏟ (1) − K L ( q ( z ∣ x ) ∥ P ( z ) ) ⏟ (2) ) =\max \left( \underbrace{\mathbb{E}_{z \sim q(z|x)}[\log P(x|z)]}_{\text{(1)}} - \underbrace{KL(q(z|x) \| P(z))}_{\text{(2)}} \right) =max (1) Ezq(zx)[logP(xz)](2) KL(q(zx)P(z))

第一项就是从 q ( z ∣ x ) q(z|x) q(zx) 采样 z z z,再让 log ⁡ P ( x ∣ z ) \log P \left( x|z \right) logP(xz)概率最大,重构代价最小,这是可以通过神经网路训练实现的。 同时还要最小化KL散度

重构损失希望 σ → 0 σ→0 σ0,一个分布反差为0 说明这个随机变量取确定值的概率为1,这恰好就是AE的训练目标,但 KL散度项起到了“平衡”的作用:

KL 散度loss的公式如下:(后面有证明过程)

min ⁡ 1 2 ∑ j = 1 J ( σ j 2 + μ j 2 − log ⁡ σ j 2 − 1 ) \min \frac{1}{2} \sum_{j=1}^{J} \left( \sigma_{j}^{2}+ \mu_{j}^{2}- \log \sigma_{j}^{2}-1 \right) min21j=1J(σj2+μj2logσj21) 注意最后一项: − l o g σ j 2 −logσ_j^2 logσj2

  • σ → 0 σ→0 σ0 时, l o g σ j 2 → − ∞ logσ_j^2→-∞ logσj2 , , − l o g σ j 2 → + ∞ -logσ_j^2→+∞ logσj2+
  • 导致 KL 损失变得非常大,从而增加总损失
  • 模型为了减少总损失,就会避免让 σ σ σ 太小

细化目标函数:

1. 最小化KL散度

q ( z ∣ x ) q(z|x) q(zx) 需要逼近 p ( z ) p(z) p(z),而 P ( z ) ∼ N ( 0 , I ) P \left( z \right) \sim N \left( 0,I \right) P(z)N(0,I)的多维高斯分布,并且各个维度之间相互独立,要让KL最小, q ( z ∣ x ) q(z|x) q(zx)也要是一个多维高斯分布
那么最小化其KL散度,只需要对每一个维度求KL最小即可,单独看某一个维度,设某一个维度为 z j z_j zj ,设 q ( z j ∣ x ) ∼ N ( μ ϕ , σ ϕ 2 ) q \left( z_{j}|x \right) \sim N \left( \mu_{ \phi}, \sigma_{ \phi}^{2} \right) q(zjx)N(μϕ,σϕ2)(后续为了简便,也将参数 ϕ \phi ϕ 隐去):
min ⁡ K L ( q ( z j ∣ x ) ∣ ∣ P ( z j ) ) \min KL \left( q \left( z_{j}|x \right)||P \left( z_{j} \right) \right) minKL(q(zjx)∣∣P(zj))
= min ⁡ ∫ q ( z j ∣ x ) log ⁡ q ( z j ∣ x ) P ( z j ) d z j = \min \int q \left( z_{j}|x \right) \log \frac{q \left( z_{j}|x \right)}{P \left( z_{j} \right)}dz_{j} =minq(zjx)logP(zj)q(zjx)dzj
= min ⁡ ∫ q ( z j ∣ x ) log ⁡ N ( μ , σ 2 ) N ( 0 , 1 ) d z j = \min \int q \left( z_{j}|x \right) \log \frac{N \left( \mu, \sigma^{2} \right)}{N \left( 0,1 \right)}dz_{j} =minq(zjx)logN(0,1)N(μ,σ2)dzj
= min ⁡ ∫ q ( z j ∣ x ) log ⁡ 1 2 π σ e x p { − ( z j − μ ) 2 2 σ 2 } 1 2 π e x p { − z j 2 2 } d z j =\min\int q(z_{j}|x)\log\frac{\frac{1}{\sqrt{2\pi}\sigma}\mathrm{ exp}\{-\frac{(z_{j}-\mu)^{2}}{2\sigma^{2}}\}}{\frac{1}{\sqrt{2\pi}}\mathrm{ exp}\{-\frac{z_{j}^{2}}{2}\}}dz_{j} =minq(zjx)log2π 1exp{2zj2}2π σ1exp{2σ2(zjμ)2}dzj
= min ⁡ ∫ q ( z j ∣ x ) log ⁡ 1 σ exp ⁡ { − ( z j − μ ) 2 2 σ 2 } exp ⁡ { − z j 2 2 } d z j = \min \int q \left( z_{j}|x \right) \log \frac{ \frac{1}{ \sigma} \exp \{- \frac{ \left( z_{j}- \mu \right)^{2}}{2 \sigma^{2}} \}}{ \exp \{- \frac{z_{j}^{2}}{2} \}}dz_{j} =minq(zjx)logexp{2zj2}σ1exp{2σ2(zjμ)2}dzj
= min ⁡ ∫ q ( z j ∣ x ) ( log ⁡ 1 σ + log ⁡ exp ⁡ { − ( z j − μ ) 2 2 σ 2 } exp ⁡ { − z j 2 2 } ) d z j =\min\int q(z_{j}|x)\left(\log\frac{1}{\sigma}+\log\frac{\exp\{- \frac{(z_{j}-\mu)^{2}}{2\sigma^{2}}\}}{\exp\{-\frac{z_{j}^{2}}{2}\}}\right)dz_ {j} =minq(zjx) logσ1+logexp{2zj2}exp{2σ2(zjμ)2} dzj
= min ⁡ ∫ q ( z j ∣ x ) ( log ⁡ 1 σ + log ⁡ exp ⁡ { − ( z j − μ ) 2 2 σ 2 + z j 2 2 } ) d z j =\min\int q(z_{j}|x)\left(\log\frac{1}{\sigma}+\log\exp\{-\frac{( z_{j}-\mu)^{2}}{2\sigma^{2}}+\frac{z_{j}^{2}}{2}\}\right)dz_{j} =minq(zjx)(logσ1+logexp{2σ2(zjμ)2+2zj2})dzj
log以e为底:
= min ⁡ ∫ q ( z j ∣ x ) ( log ⁡ 1 σ − ( z j − μ ) 2 2 σ 2 + z j 2 2 ) d z j = \min \int q \left( z_{j}|x \right) \left( \log \frac{1}{ \sigma}- \frac{ \left( z_{j}- \mu \right)^{2}}{2 \sigma^{2}}+ \frac{z_{j}^{2}}{2} \right) dz_{j} =minq(zjx)(logσ12σ2(zjμ)2+2zj2)dzj
= min ⁡ ∫ q ( z j ∣ x ) 1 2 ( 2 log ⁡ 1 σ − ( z j − μ ) 2 σ 2 + z j 2 ) d z j = \min \int q \left( z_{j}|x \right) \frac{1}{2} \left( 2 \log \frac{1}{ \sigma}- \frac{ \left( z_{j}- \mu \right)^{2}}{ \sigma^{2}}+z_{j}^{2} \right) dz_{j} =minq(zjx)21(2logσ1σ2(zjμ)2+zj2)dzj

即:
min ⁡ K L ( q ( z j ∣ x ) ∣ ∣ P ( z j ) ) = min ⁡ 1 2 ∫ q ( z j ∣ x ) ( z j 2 − log ⁡ σ 2 − ( z j − μ ) 2 σ 2 ) d z j \min KL(q(z_{j}|x)||P(z_{j})) = \min\frac{1}{2}\int q(z_{j}|x)\left(z_{j}^{2 }-\log\sigma^{2}-\frac{(z_{j}-\mu)^{2}}{\sigma^{2}}\right)dz_{j} minKL(q(zjx)∣∣P(zj))=min21q(zjx)(zj2logσ2σ2(zjμ)2)dzj

可以分为三部分:
min ⁡ K L ( q ( z j ∣ x ) ∣ ∣ P ( z j ) ) = min ⁡ 1 2 ( ∫ q ( z j ∣ x ) z j 2 d z j − ∫ q ( z j ∣ x ) log ⁡ σ 2 d z j − ∫ z q ( z j ∣ x ) ( z j − μ ) 2 σ 2 d z j ) \min KL(q(z_{j}|x)||P(z_{j}))=\min\frac{1}{2}\Biggl(\int q(z_{j}|x)z_{j}^{2} dz_{j}-\int q(z_{j}|x)\log\sigma^{2}dz_{j}-\int_{z}q(z_{j}|x)\frac{(z_{j}-\mu)^{2}}{ \sigma^{2}}dz_{j}\Biggr) minKL(q(zjx)∣∣P(zj))=min21(q(zjx)zj2dzjq(zjx)logσ2dzjzq(zjx)σ2(zjμ)2dzj)

  1. ∫ z q ( z j ∣ x ) z j 2 d z = E [ z j 2 ] = D ( z j ) + E [ z j ] 2 = σ 2 + μ 2 \int_{z}q(z_{j}|x)z_{j}^{2}dz=\mathbb{E}[z_{j}^{2}]=D(z_{j})+\mathbb{E}[z_{j}] ^{2}=\sigma^{2}+\mu^{2} zq(zjx)zj2dz=E[zj2]=D(zj)+E[zj]2=σ2+μ2

D ( X ) = E ( X 2 ) − [ E ( X ) ] 2 D \left( X \right)=E \left( X^{2} \right)- \left[ E \left( X \right) \right]^{2} D(X)=E(X2)[E(X)]2

  1. ∫ z q ( z j ∣ x ) log ⁡ σ 2 d z j = log ⁡ σ 2 ∫ z q ( z j ∣ x ) d z j = log ⁡ σ 2 \int_{z}q(z_{j}|x)\log\sigma^{2}dz_{j}=\log\sigma^{2}\int_{z}q(z_{j}|x)dz_{j}= \log\sigma^{2} zq(zjx)logσ2dzj=logσ2zq(zjx)dzj=logσ2
  2. ∫ z q ( z j ∣ x ) ( z j − μ ) 2 σ 2 d z j = 1 σ 2 ∫ z q ( z j ∣ x ) ( z j − μ ) 2 d z j = 1 σ 2 E [ ( z j − μ ) 2 ] = 1 \int_{z}q(z_{j}|x)\frac{(z_{j}-\mu)^{2}}{\sigma^{2}}dz_{j}=\frac{1}{\sigma^{2} }\int_{z}q(z_{j}|x)(z_{j}-\mu)^{2}dz_{j}=\frac{1}{\sigma^{2}}\mathbb{E}[(z_{j} -\mu)^{2}]=1 zq(zjx)σ2(zjμ)2dzj=σ21zq(zjx)(zjμ)2dzj=σ21E[(zjμ)2]=1

第三步公式推导如下:
1 σ 2 E [ ( z j − μ ) 2 ] \frac{1}{ \sigma^{2}}E \left[ \left( z_{j}- \mu \right)^{2} \right] σ21E[(zjμ)2]
= 1 σ 2 E [ z j 2 − 2 μ z j + μ 2 ] = \frac{1}{ \sigma^{2}}E \left[ z_{j}^{2}-2 \mu z_{j}+ \mu^{2} \right] =σ21E[zj22μzj+μ2]
= 1 σ 2 ( E ( z j 2 ) − 2 μ E ( z j ) + μ 2 ) = \frac{1}{ \sigma^{2}} \left( E \left( z_{j}^{2} \right)-2 \mu E \left( z_{j} \right)+ \mu^{2} \right) =σ21(E(zj2)2μE(zj)+μ2)
= 1 σ 2 ( σ 2 + μ 2 − 2 μ 2 + μ 2 ) = \frac{1}{ \sigma^{2}} \left( \sigma^{2}+ \mu^{2}-2 \mu^{2}+ \mu^{2} \right) =σ21(σ2+μ22μ2+μ2)

所以对于所有维度的KL散度,有(假设隐变量有J维):
min ⁡ 1 2 ∑ j = 1 J ( σ j 2 + μ j 2 − log ⁡ σ j 2 − 1 ) \min \frac{1}{2} \sum_{j=1}^{J} \left( \sigma_{j}^{2}+ \mu_{j}^{2}- \log \sigma_{j}^{2}-1 \right) min21j=1J(σj2+μj2logσj21)

2. 最小重构损失:

E z ∼ q ( z ∣ x ) [ log ⁡ P ( x ∣ z ) ] {\mathbb{E}_{z \sim q(z|x)}[\log P(x|z)]} Ezq(zx)[logP(xz)]

这个积分通常没有解析解,只能通过采样估计。

E z ∼ q ( z ∣ x ) [ log ⁡ P ( x ∣ z ) ] = ∫ z log ⁡ P ( x ∣ z ) q ( z ∣ x ) d z ≈ 1 n ∑ i = 1 n log ⁡ P ( x i ∣ z i ) \mathbb{E}_{z \sim q( z \mid x )}\left[\log P(x|z)\right]= \int_{z} \log P(x|z)q(z|x)dz \approx\frac{1}{n}\sum_{i=1}^{n}\log P(x^{i}|z^{i}) Ezq(zx)[logP(xz)]=zlogP(xz)q(zx)dzn1i=1nlogP(xizi)

使用蒙特卡罗近似,从 q ( z ∣ x ; ϕ ) q(z|x;ϕ) q(zxϕ)采样n个z样本,然后带入 log ⁡ P ( x ∣ z ) \log P \left( x|z \right) logP(xz),求出均值,不过在实践中,对于每个样本 x,只需要从 q ( z ∣ x ; ϕ ) {q( z \mid x;ϕ )} q(zxϕ)采样一次 z 就能正常训练 VAE 了,效率很高,可见重要性采样的思想还是很有效的。

所以我们可以写成:
E z ∼ q ( z ∣ x ; ϕ ) [ log ⁡ p ( x ∣ z ; θ ) ] ≈ log ⁡ p ( x ∣ z ; θ ) , z ∼ q ( z ∣ x ; ϕ ) \mathbb{E}_{z \sim q(z|x; \phi)}[\log p(x|z; \theta)] \approx \log p(x|z; \theta), \quad z \sim q(z|x; \phi) Ezq(zx;ϕ)[logp(xz;θ)]logp(xz;θ),zq(zx;ϕ)

在论文中,其假设 P ( x ∣ z ) ∼ N ( μ , σ 2 I ) P \left( x|z\right) \sim N \left( \mu, \sigma^{2}I \right) P(xz)N(μ,σ2I),其中 μ \mu μ σ \sigma σ 是需要神经网络去逼近,
由于各维度相互独立,概率密度函数形式如下:

p ( x ∣ z ; θ ) = ∏ i = 1 D N ( x i ; μ i , σ i 2 ) = ( ∏ i = 1 D 1 2 π σ i ) exp ⁡ ( − 1 2 ∑ i = 1 D ( x i − μ i ) 2 σ i 2 ) p(x|z; \theta)=\prod_{i=1}^{D} \mathcal{N}(x_i; \mu_i, \sigma_i^2)=\left( \prod_{i=1}^{D} \frac{1}{\sqrt{2\pi\sigma_i}} \right) \exp\left( -\frac{1}{2} \sum_{i=1}^{D} \frac{(x_i - \mu_i)^2}{\sigma_i^2} \right) p(xz;θ)=i=1DN(xi;μi,σi2)=(i=1D2πσi 1)exp(21i=1Dσi2(xiμi)2)

取对数后得到:
log ⁡ p ( x ∣ z ; θ ) = − D 2 log ⁡ ( 2 π ) − 1 2 ∑ i = 1 D log ⁡ σ i 2 − 1 2 ∑ i = 1 D ( x i − μ i ) 2 σ i 2 \log p(x|z;\theta)=-\frac{D}{2} \log(2\pi) - \frac{1}{2} \sum_{i=1}^{D} \log \sigma_i^2 - \frac{1}{2} \sum_{i=1}^{D} \frac{(x_i - \mu_i)^2}{\sigma_i^2} logp(xz;θ)=2Dlog(2π)21i=1Dlogσi221i=1Dσi2(xiμi)2

为了简化模型,通常会做以下两个简化假设:

  1. 所有维度的方差相同,即 σ i 2 = σ 2 σ_i^2=σ^2 σi2=σ2是一个常数;
  2. 解码器只需输出均值 μ μ μ,而不需要输出方差;

于是上式变为:
log ⁡ p ( x ∣ z ; θ ) = C − 1 2 σ 2 ∑ i = 1 D ( x i − μ i ) 2 \log p(x|z; \theta)=C - \frac{1}{2\sigma^2} \sum_{i=1}^{D} (x_i - \mu_i)^2 logp(xz;θ)=C2σ21i=1D(xiμi)2

log ⁡ p ( x ∣ z ; θ ) = C − 1 2 σ 2 ∑ i = 1 D ( x i − μ i ) 2 \log p(x|z; \theta)=C - \frac{1}{2\sigma^2} \sum_{i=1}^{D} (x_i - \mu_i)^2 logp(xz;θ)=C2σ21i=1D(xiμi)2
其中 C 是与 μ 无关的常数项。
因此,在最大化对数似然时,我们只需最小化平方误差项:
arg ⁡ max ⁡ θ log ⁡ p ( x ∣ z ; θ ) = arg ⁡ min ⁡ θ ∑ i = 1 D ( x i − μ i ) 2 \arg\max_{\theta} \log p(x|z; \theta)=\arg\min_{\theta} \sum_{i=1}^{D} (x_i - \mu_i)^2 argθmaxlogp(xz;θ)=argθmini=1D(xiμi)2
最大化这个等价于最小化 MSE(均方误差)。

重参数化技巧详解:实现 VAE 可微采样的关键

VAE 的编码器与 autoencoder 的编码器不同,autoencoder 的编码器直接输出隐变量 z 的值,而 VAE 编码器输出的是高斯分布的参数 μ , σ 2 \mu, \sigma^{2} μ,σ2,隐变量 z 是从分布 N ( z ; μ , σ 2 I ) N \left( z; \mu, \sigma^{2}I \right) N(z;μ,σ2I)中采样得到的。

z ∼ q ( z ∣ x ) = N ( μ , σ 2 ) z \sim q(z|x)=\mathcal{N}(\mu, \sigma^2) zq(zx)=N(μ,σ2)
但问题来了:
采样操作不是可导的,所以不能直接通过反向传播优化编码器参数 μ , σ \mu, \sigma μ,σ

解决方案:重参数化技巧(Reparameterization Trick)
为了解决这个问题,VAE 使用了所谓的 重参数化技巧,也称为 路径梯度估计器(Pathwise Gradient Estimator, PGE)。

核心思想
将“采样”这一不可导的操作从参数依赖中分离出来,将参数依赖转移到一个确定性的、可导的变换函数中。

假定 q ( z ∣ x ) ∼ N ( μ , σ 2 ) q \left( z|x \right) \sim N \left( \mu, \sigma^{2} \right) q(zx)N(μ,σ2),那么可以构造一个概率分布 p ( ϵ ) ∼ N ( 0 , 1 ) p \left( \epsilon \right) \sim N \left( 0,1 \right) p(ϵ)N(0,1)。有:
z = μ + ϵ σ z= \mu+ \epsilon \sigma z=μ+ϵσ
我们从 p ( ϵ ) p \left( \epsilon \right) p(ϵ)采样,然后利用上述公式,就相当于得到了从 q ( z ∣ x ) q(z|x) q(zx)采样得采样值

证明:
E ( z ) = E ( μ ) + E ( ϵ ) σ = μ V a r ( z ) = E [ ( z − μ ) 2 ] = E [ ( ϵ σ ) 2 ] = σ 2 \begin{split}\mathbb{E} \left( z \right)=& \mathbb{E} \left( \mu \right)+ \mathbb{E} \left( \epsilon \right) \sigma= \mu \\ Var \left( z \right)=& \mathbb{E} \left[ \left( z- \mu \right)^{2} \right]= \mathbb{E} \left[ \left( \epsilon \sigma \right)^{2} \right]= \sigma^{2} \end{split} E(z)=Var(z)=E(μ)+E(ϵ)σ=μE[(zμ)2]=E[(ϵσ)2]=σ2
在 PyTorch 中实现:

# 编码器输出 mu 和 log_var(数值稳定)
mu, log_var = encoder(x)
sigma = torch.exp(0.5 * log_var)# 从标准正态分布采样 epsilon
epsilon = torch.randn_like(sigma)# 重参数化
z = mu + sigma * epsilon

代码实战:实现生成手写数字图片

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
class VAE(nn.Module):def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):super(VAE, self).__init__()# 编码器网络self.encoder = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, latent_dim * 2))# 解码器网络self.decoder = nn.Sequential(nn.Linear(latent_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, input_dim),nn.Sigmoid())def encode(self, x):h = self.encoder(x)mu, log_var = torch.chunk(h, 2, dim=1)return mu, log_vardef reparameterize(self, mu, log_var):std = torch.exp(0.5*log_var)eps = torch.randn_like(std)z = mu + eps*stdreturn zdef decode(self, z):return self.decoder(z)def forward(self, x):mu, log_var = self.encode(x.view(-1, 784))z = self.reparameterize(mu, log_var)return self.decode(z), mu, log_vardef loss_function(recon_x, x, mu, log_var):BCE = nn.functional.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())return BCE + KLD
# 数据加载
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
vae = VAE()
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
for epoch in range(10):vae.train()train_loss = 0for i, (data, _) in enumerate(train_loader):optimizer.zero_grad()recon_batch, mu, logvar = vae(data)loss = loss_function(recon_batch, data, mu, logvar)loss.backward()train_loss += loss.item()optimizer.step()print(f'Epoch {epoch}, Loss: {train_loss / len(train_loader.dataset)}')
with torch.no_grad():sample = torch.randn(64, 20).to('cpu')  # 假设latent_dim为20sample = vae.decode(sample).view(64, 1, 28, 28)
plt.figure(figsize=(10, 10))
for i in range(64):ax = plt.subplot(8, 8, i + 1)plt.imshow(sample[i].squeeze().numpy(), cmap='gray')plt.axis("off")
plt.show()

代码相当简洁,简单跑了10个epoch,结果见图,值得一提得是这还是无条件生成,有兴趣得还可以了解一下Conditional VAE,或者叫CVAE

在这里插入图片描述

总结:

来自苏剑林的 VAE博客:变分自编码器一:原来是这么一回事

VAE的本质是什么?VAE虽然也称是AE(AutoEncoder)的一种,但它的做法(或者说它对网络的诠释)是别具一格的。在VAE中,它的Encoder有两个,一个用来计算均值,一个用来计算方差,这已经让人意外了:Encoder不是用来Encode的,是用来算均值和方差的,这真是大新闻了,还有均值和方差不都是统计量吗,怎么是用神经网络来算的?

事实上,我觉得VAE从让普通人望而生畏的变分和贝叶斯理论出发,最后落地到一个具体的模型中,虽然走了比较长的一段路,但最终的模型其实是很接地气的:它本质上就是在我们常规的自编码器的基础上,对encoder的结果(在VAE中对应着计算均值的网络)加上了“高斯噪声”,使得结果decoder能够对噪声有鲁棒性;而那个额外的KLloss(目的是让均值为0,方差为1),事实上就是相当于对encoder的一个正则项,希望encoder出来的东西有零均值。

那另外一个encoder(对应着计算方差的网络)的作用呢?它是用来动态调节噪声的强度的。直觉上来想,当decoder还没有训练好时(重构误差远大于KLloss),就会适当降低噪声(KLloss增加),使得拟合起来容易一些(重构误差开始下降);反之,如果decoder训练得还不错时(重构误差小于KLloss),这时候噪声就会增加(KLloss减少),使得拟合更加困难了(重构误差又开始增加),这时候decoder就要想办法提高它的生成能力了。

公式推导学习:https://www.bilibili.com/video/BV1op421S7Ep/?spm_id_from=333.337.search-card.all.click
其他参考资料:http://www.gwylab.com/note-vae.html
李宏毅教程:https://www.bilibili.com/video/av15889450/?p=33
博客园:从极大似然估计到变分自编码器 - VAE 公式推导
博客园:AE(自动编码器)与VAE(变分自动编码器)简单理解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/diannao/88878.shtml
繁体地址,请注明出处:http://hk.pswp.cn/diannao/88878.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ChatboxAI 搭载 GPT 与 DeepSeek,引领科研与知识库管理变革

文章摘要:本文深入探讨 ChatboxAI 在科研领域的应用优势。ChatboxAI 集成多模型,支持全平台,能高效管理科研知识,助力文献检索、实验设计与论文撰写,提升科研效率与质量,同时保障数据安全。其知识库功能可整…

【无刷电机FOC进阶基础准备】【04 clark变换、park变换、等幅值变换】

目录 clark变换park变换等幅值变换 其实我不太记得住什么是clark变换、park变换,我每次要用到这个名词的时候都会上网查一下,因为这就是两个名词而已,但是我能记住的是他们背后的含义。 经过【从零开始实现stm32无刷电机FOC】系列后应该对cla…

Sentinel的流控策略

在 Sentinel 中,流控策略(Flow Control Strategy)用于定义如何处理请求的流量,并决定在流量达到某个阈值时采取的行动。流控策略是实现系统稳定性和高可用性的核心机制,尤其在高并发环境中,确保服务不会因过…

Ubuntu Extension Manager 插件卸载

Ubuntu 上使用Extension Manager 安装插件,但目前无法在Extension Manager 中卸载。 卸载方式可以通过 gnome-extensions 命令进行卸载: Usage:gnome-extensions COMMAND [ARGS…]Commands:help Print helpversion Print versionenable Enabl…

深度学习中Embedding原理讲解

我们用最直白的方式来理解深度学习中 Embedding(嵌入) 的概念。 核心思想一句话: Embedding 就是把一些复杂、离散的东西(比如文字、类别、ID)转换成计算机更容易理解和计算的“数字密码”,这些“数字密码…

(3)Java+Playwright自动化测试-启动浏览器

1.简介 前边两章文章已经将环境搭建好了,今天就在Java项目搭建环境中简单地实践一下: 启动两大浏览器。 接下来我们在Windows系统中启动浏览器即可,其他平台系统的浏览器类似的启动方法,照猫画虎就可以了。 但是在实践过程中&am…

使用OpenWebUI与DeepSeek交互

Open WebUI 是针对 LLM 用户友好的 WebUI,支持的 LLM 运行程序包括阿里百炼、 Ollama、OpenAI 兼容的 API。这里主要讲在Docker环境下安装与本地Ollame和百炼API Key配置 一、安装Docker 1. CentOS # 设置为阿里云的源 sudo yum install -y yum-utils sudo yum-config-mana…

Github 2025-06-25 C开源项目日报 Top9

根据Github Trendings的统计,今日(2025-06-25统计)共有9个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量C项目9C++项目1raylib: 用于视频游戏编程的简单易用图形库 创建周期:3821 天开发语言:C协议类型:zlib LicenseStar数量:18556 个Fork数量:1…

【数据标注师】2D标注

目录 一、 **2D标注知识体系框架**二、 **五阶能力培养体系**▶ **阶段1:基础规则内化(1-2周)**▶ **阶段2:复杂场景处理技能**▶ **阶段3:专业工具 mastery**▶ **阶段4:领域深度专精▶ **阶段5&#xff1…

深入浅出Node.js后端开发

让我们来理解Node.js的核心——事件循环和异步编程模型。在Node.js中,所有的I/O操作都是非阻塞的,这意味着当一个请求开始等待I/O操作完成时(如读取文件或数据库操作),Node.js不会阻塞后续操作,而是继续执行…

C++11的内容

1.支持花括号初始化 void test1() {vector<string> v1 { "asd","asd","add" };vector<string> v2{ "asd","asd","add" };map<string, int> m1{ {"asd",1},{"asd",2},{&q…

AI代码助手实践指南

概述与发展趋势 核心理念 发展方向&#xff1a;从代码补全 → 代码生成 → 整个工程服务价值转换&#xff1a;从单纯写代码 → 需求驱动的代码生成功能扩展&#xff1a;超越编写层面&#xff0c;涵盖测试环境搭建等 核心价值点 低价值动作识别&#xff1a;debug、代码评审、…

.net反编译工具

.NET 反编译工具大揭秘 在.NET 开发的世界里&#xff0c;有时候我们需要对已编译的.NET 程序集进行反编译&#xff0c;将 DLL 或 EXE 文件还原为可读的源代码形式&#xff0c;这在学习、调试、代码分析等方面都有着重要的作用。今天&#xff0c;就让我们一起深入了解一些流行的…

mac docker desktop 安装 oracle

1.登录 oracle 官网&#xff0c;选择镜像 https://container-registry.oracle.com/ords/f?p113:1:6104693702564::::FSP_LANGUAGE_PREFERENCE:&cs3CAuGEkeY6APmlAELFJ0uYU5M8_O8aTEufSKZHFf12lu1sUk5fsdbCzJAni9jVaCYXf-SNM_8e3VYr1V4QMBq1A 2.登录认证 oracle 账号 doc…

【redis使用场景——缓存——数据过期策略 】

redis使用场景——缓存——数据过期策略 定期删除&#xff08;Active Expiration&#xff09;1. 快速模式&#xff08;Fast Expiration Cycle&#xff09;工作流程&#xff1a;特点&#xff1a;优点&#xff1a; 2. 慢速模式&#xff08;Slow Expiration Cycle&#xff09;工作…

智能体Manus和实在Agent的区别

在当今数字化时代&#xff0c;AI 已经深度融入我们的生活和工作。曾经&#xff0c;像 ChatGPT 这样的传统 AI&#xff0c;虽然能在很多方面给我们提供帮助&#xff0c;比如写邮件时它妙笔生花&#xff0c;分析数据时头头是道&#xff0c;可却在最后一步掉了链子 —— 它不会点击…

Prism框架实战:WPF企业级开发全解

以下是一个完整的WPF项目示例&#xff0c;使用Prism框架实现依赖注入、导航、复合命令、模块化和聚合事件功能。项目结构清晰&#xff0c;包含核心功能实现&#xff1a; 项目结构 PrismDemoApp/ ├── PrismDemoApp (主项目) │ ├── Views/ │ │ ├── ShellView…

单片机学习笔记---AD/DA工作原理(含运算放大器的工作原理)

目录 AD/DA介绍 硬件电路模型 硬件电路 运算放大器 DA原理 T型电阻网络DA转换器 PWM型DA转换器 AD原理 逐次逼近型AD转换器 AD/DA性能指标 XPT2046 XPT2046时序 AD/DA介绍 AD&#xff08;Analog to Digital&#xff09;&#xff1a;模拟-数字转换&#xff0c;将模拟…

matlab实现相控超声波成像

相控超声波成像仿真检测探伤 数据接收 换能器开发 Phased Array Codes/Matlab Examples.pptx , 513230 Phased Array Codes/MATLAB M_files/delay_laws2D.m , 1027 Phased Array Codes/MATLAB M_files/delay_laws2D_int.m , 3290 Phased Array Codes/MATLAB M_files/delay_law…

Stable Diffusion入门-ControlNet 深入理解 第二课:ControlNet模型揭秘与使用技巧

大家好&#xff0c;欢迎回到Stable Diffusion入门-ControlNet 深入理解系列的第二课&#xff01; 如果你还记得第一篇文章的内容 - 我们已经了解了 ControlNet 的基础概念&#xff1a;它通过预处理器和模型两个强力模块&#xff0c;赋予了AI绘画前所未有的精准控制。 还没看过…