数学原理:
(1) 前向传播的方差一致性
假设输入 x 的均值为 0,方差为 σx2σ_x^2σx2,权重 W的均值为 0,方差为 σW2σ_W^2σW2,则输出 z=Wxz=Wxz=Wx的方差为:
Var(z)=nin⋅Var(W)⋅Var(x)
Var(z)=n_{in}⋅Var(W)⋅Var(x)
Var(z)=nin⋅Var(W)⋅Var(x)
为了使 Var(z)=Var(x),需要:
nin⋅Var(W)=1 ⟹ Var(W)=1nin
n_{in}⋅Var(W)=1 ⟹ Var(W)=\frac{1}{n_{in}}
nin⋅Var(W)=1 ⟹ Var(W)=nin1
其中 ninn_{in}nin是输入维度(fan_in)。这里乘以 nin 的原因是,输出 z 是由 nin 个输入 x 的线性组合得到的,每个输入 x 都与一个权重 W 相乘。因此,输出 z 的方差是 nin 个独立的 Wx 项的方差之和。
(2) 反向传播的梯度方差一致性
在反向传播过程中,梯度 ∂L∂x\frac{∂L}{∂x}∂x∂L 是通过链式法则计算得到的,其中 L 是损失函数,x 是输入,z 是输出。梯度∂L∂x\frac{∂L}{∂x}∂x∂L可以表示为:
∂L∂x=∂L∂z.∂z∂x
\frac{∂L}{∂x}=\frac{∂L}{∂z}.\frac{∂z}{∂x}
∂x∂L=∂z∂L.∂x∂z
假设 z=Wx,其中 W 是权重矩阵,那么 ∂z∂x=W\frac{∂z}{∂x}=W∂x∂z=W。因此,梯度 ∂L∂x\frac{∂L}{∂x}∂x∂L可以写为: ∂L∂x=∂L∂zW\frac{∂L}{∂x}=\frac{∂L}{∂z}W∂x∂L=∂z∂LW
反向传播时梯度 ∂L∂x\frac{∂L}{∂x}∂x∂L 的方差应与 ∂L∂z\frac{∂L}{∂z}∂z∂L 相同,因此:
nout⋅Var(W)=1 ⟹ Var(W)=1nout
n_{out}⋅Var(W)=1 ⟹ Var(W)=\frac{1}{n_{out}}
nout⋅Var(W)=1 ⟹ Var(W)=nout1
其中 noutn_{out}nout是输出维度(fan_out)。为了保持梯度的方差一致性,我们需要确保每个输入维度 nin 的梯度方差与输出维度 nout 的梯度方差相同。因此,我们需要将 W 的方差乘以 nout,以确保梯度的方差在反向传播过程中保持一致。
(3) 综合考虑
为了同时平衡前向传播和反向传播,Xavier 采用:
Var(W)=2nin+nout
Var(W)=\frac{2}{n_{in}+n_{out}}
Var(W)=nin+nout2
权重从以下分布中采样:
均匀分布:
W∼U(−6nin+nout,6nin+nout)
W\sim\mathrm{U}\left(-\frac{\sqrt{6}}{\sqrt{n_\mathrm{in}+n_\mathrm{out}}},\frac{\sqrt{6}}{\sqrt{n_\mathrm{in}+n_\mathrm{out}}}\right)
W∼U(−nin+nout6,nin+nout6)
在Xavier初始化中,我们选择 a=−6nin+nouta=−\sqrt{\frac{6}{n_{in}+n_{out}}}a=−nin+nout6 和 b=6nin+noutb=\sqrt{\frac{6}{n_{in}+n_{out}}}b=nin+nout6,这样方差为:
Var(W)=(b−a)212=(26nin+nout)212=4⋅6nin+nout12=2nin+nout
Var(W)=\frac{(b−a)^2}{12}=\frac{(2\sqrt{\frac{6}{n_{in}+n_{out}}})^2}{12}=\frac{4⋅\frac{6}{nin+nout}}{12}=\frac{2}{n_{in}+n_{out}}
Var(W)=12(b−a)2=12(2nin+nout6)2=124⋅nin+nout6=nin+nout2
正态分布:
W∼N(0,2nin+nout)
W\sim\mathrm{N}\left(0,\frac{2}{n_\mathrm{in}+n_\mathrm{out}}\right)
W∼N(0,nin+nout2)
N(0,std2) \mathcal{N}(0, \text{std}^2) N(0,std2)
其中 ninn_{\text{in}}nin 是当前层的输入神经元数量,noutn_{\text{out}}nout是输出神经元数量。
在前向传播中,输出的方差受 ninn_{in}nin 影响。在反向传播中,梯度的方差受 noutn_{out}nout 影响。