VAE
Vanilla VAE#
理论推导#
VAE的损失函数如下
$$ \mathcal{L} = \log p(x)\int_z q(\underbrace{z}_{\text{latent var.}}|x)dz = \int_z q(z|x)\log p(x)dx $$
由于$p(x)=\frac{p(x,z)}{p(z|x)}$,有
$$RHS=\int_z q(z|x)\log \left(\frac{p(x,z)}{p(z|x)}\right)dz$$
进行增项
$$RHS=\int_z q(z|x)\log \left( \frac{p(x,z)}{q(z|x)}\cdot \frac{q(z|x)}{p(z|x)} \right)dz$$
由于$\log(\cdot)$的性质,可以进行拆项:
$$RHS1=\underbrace{\int_z q(z|x)\log\left(\frac{p(x,z)}{q(z|x)}\right)dz}_{ELBO,\text{denoted as }L_b}$$
$$RHS2=\underbrace{\int_z q(z|x)\log \left(\frac{q(z|x)}{p(z|x)} \right)}_{\text{KL div}.,KL(q(z|x)\Vert p(z|x))\ge 0}$$
由于KL散度项$\ge 0$,有
$$\log p(x) \ge L_b = \int_z q(z|x)\log\left(\frac{p(x,z)}{q(z|x)}\right)dz$$
由于$p(x,z)=p(x|z)p(z)$,可以对上式中的$p(x|z)$进行展开
$$RHS=\int_z q(z|x)\log \left(\frac{p(x|z)p(z)}{q(z|x)}\right)$$
根据对数函数性质,进行拆项处理:
$$RHS3=\underbrace{\int_z q(z|x)\log \left( \frac{p(z)}{q(z|x)}\right)dz}_{(3):\text{-KL div.},-KL(q(z|x)\Vert p(z))}+\int_z q(z|x)\log p(x|z)$$
最小化损失函数,相当于最大化ELBO。
[损失项1:$RHS3$式,负KL散度项]
对于$RHS3$进行拆项处理:
$$\int_z q(z|x)\log \left( \frac{p(z)}{q(z|x)}\right)dz = \int_z q(z|x)\log p(z)dz-\int_z q(z|x)\log q(z|x)dz$$
根据期望的定义,上式可以写成期望形式:
$$RHS=E_q(\log p(z))-E_q(\log q(z|x))$$
考虑对$\log p(z)$进行化简:
$$p(z)=\frac{1}{(2\pi)^{J/2}\vert I\vert^{1/2}}\exp\left(-\frac{1}{2} z^Tz\right)$$
则有
$$\log p(z)=-\frac{p}{2}\log(2\pi)-\frac{1}{2}J-\frac{1}{2}\sum_{j=1}^{J}z_j^2$$
而$E_q(z)=\mu^2+\sigma^2$,有
$$E_q(\log p(z))=-\frac{J}{2}\log(2\pi)-\frac{1}{2}J-\frac{1}{2}\sum_{j=1}^{J}(\mu_j^2+\sigma_j^2)$$
同理,有
$$E_q(\log q(z|x))=-\frac{J}{2}\log(2\pi)-\frac{1}{2}\sum_{j=1}^{J}\log\sigma_j^2-\frac{1}{2}\sum_{j=1}^{J}(1+\sigma_j^2)$$
综上,
$$RHS=\frac{1}{2}\sum_{j=1}^{J}\left(1+\log\sigma_j^2-\mu_j-e^{\log\sigma_j^2}\right)$$
[损失项2: $RHS4$式,重建概率项]
$$\max \int_z q(z|x)p(x|z)dz=\max E_{z\sim q(z|x)} \left(\log p(x|z)\right) \tag*{}$$
即最大化$x$经过Encoder-Decoder后的重建概率,又即最小化重建损失。