VAE

Vanilla VAE#

理论推导#

VAE的损失函数如下

$$ \mathcal{L} = \log p(x)\int_z q(\underbrace{z}_{\text{latent var.}}|x)dz = \int_z q(z|x)\log p(x)dx $$

由于$p(x)=\frac{p(x,z)}{p(z|x)}$,有

$$RHS=\int_z q(z|x)\log \left(\frac{p(x,z)}{p(z|x)}\right)dz$$

进行增项

$$RHS=\int_z q(z|x)\log \left( \frac{p(x,z)}{q(z|x)}\cdot \frac{q(z|x)}{p(z|x)} \right)dz$$

由于$\log(\cdot)$的性质,可以进行拆项:

$$RHS1=\underbrace{\int_z q(z|x)\log\left(\frac{p(x,z)}{q(z|x)}\right)dz}_{ELBO,\text{denoted as }L_b}$$

$$RHS2=\underbrace{\int_z q(z|x)\log \left(\frac{q(z|x)}{p(z|x)} \right)}_{\text{KL div}.,KL(q(z|x)\Vert p(z|x))\ge 0}$$

由于KL散度项$\ge 0$,有

$$\log p(x) \ge L_b = \int_z q(z|x)\log\left(\frac{p(x,z)}{q(z|x)}\right)dz$$

由于$p(x,z)=p(x|z)p(z)$,可以对上式中的$p(x|z)$进行展开

$$RHS=\int_z q(z|x)\log \left(\frac{p(x|z)p(z)}{q(z|x)}\right)$$

根据对数函数性质,进行拆项处理:

$$RHS3=\underbrace{\int_z q(z|x)\log \left( \frac{p(z)}{q(z|x)}\right)dz}_{(3):\text{-KL div.},-KL(q(z|x)\Vert p(z))}+\int_z q(z|x)\log p(x|z)$$

最小化损失函数,相当于最大化ELBO。

[损失项1:$RHS3$式,负KL散度项]

对于$RHS3$进行拆项处理:

$$\int_z q(z|x)\log \left( \frac{p(z)}{q(z|x)}\right)dz = \int_z q(z|x)\log p(z)dz-\int_z q(z|x)\log q(z|x)dz$$

根据期望的定义,上式可以写成期望形式:

$$RHS=E_q(\log p(z))-E_q(\log q(z|x))$$

考虑对$\log p(z)$进行化简:

$$p(z)=\frac{1}{(2\pi)^{J/2}\vert I\vert^{1/2}}\exp\left(-\frac{1}{2} z^Tz\right)$$

则有

$$\log p(z)=-\frac{p}{2}\log(2\pi)-\frac{1}{2}J-\frac{1}{2}\sum_{j=1}^{J}z_j^2$$

而$E_q(z)=\mu^2+\sigma^2$,有

$$E_q(\log p(z))=-\frac{J}{2}\log(2\pi)-\frac{1}{2}J-\frac{1}{2}\sum_{j=1}^{J}(\mu_j^2+\sigma_j^2)$$

同理,有

$$E_q(\log q(z|x))=-\frac{J}{2}\log(2\pi)-\frac{1}{2}\sum_{j=1}^{J}\log\sigma_j^2-\frac{1}{2}\sum_{j=1}^{J}(1+\sigma_j^2)$$

综上,

$$RHS=\frac{1}{2}\sum_{j=1}^{J}\left(1+\log\sigma_j^2-\mu_j-e^{\log\sigma_j^2}\right)$$

[损失项2: $RHS4$式,重建概率项]

$$\max \int_z q(z|x)p(x|z)dz=\max E_{z\sim q(z|x)} \left(\log p(x|z)\right) \tag*{}$$

即最大化$x$经过Encoder-Decoder后的重建概率,又即最小化重建损失。