VAE-I basics
We go over VAE (variational auto-encoder), a popular generative model introduced by Kingma and Welling in Auto-Encoding Variational Bayes.
Let $x\in \mathbb{R}^d$ be the target (e.g. images) and $z\in \mathbb{R}^m$ a latent variable with prior $\pi(z)$. Here $m<d$, and $z$ can be viewed as a compression of $x$, or a parametrization of the implicit data manifold. The goal is to learn an encoder $q_\phi(z\mid x)$ that approximates the true posterior $p_\theta(z\mid x)$.
Some preliminaries
The KL-divergence measures closeness of two probability distributions (densities):
\[D_{\rm KL}(q\|p) := \int q\log \frac{q}{p} = \int p\frac{q}{p}\log\frac{q}{p}\ge 0,\]by Jensen's inequality.
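To spell out the Jensen step: since $t\mapsto t\log t$ is convex and $\int p\,\frac{q}{p}=\int q=1$,
\[\int p\,\frac{q}{p}\log\frac{q}{p} \ \ge\ \Big(\int p\,\frac{q}{p}\Big)\log\Big(\int p\,\frac{q}{p}\Big) = 1\cdot\log 1 = 0.\]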
We adopt the following viewpoint, which I assume is well-known. Our goal is to model $X$ by a parametrized family $p_\theta(x)$, i.e. to obtain a point estimate of $\theta$. Given data $\lbrace x^{(i)} \rbrace_{i=1}^N,$
\[\begin{align*} D_{\rm KL}(p_{\rm data}\| p_\theta) &= \int p_{\rm data} \log \frac{p_{\rm data}}{p_\theta} = -H(p_{\rm data}) - \int p_{\rm data}\log p_\theta \\ &\approx -H(p_{\rm data}) - \frac{1}{N}\sum_{i=1}^N\log p_\theta(x^{(i)}). \end{align*}\]Minimizing such KL is equivalent to maximizing $\ell(\theta)=\sum \log p_\theta(x^{(i)})$, the usual maximum likelihood estimator (MLE).
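As a small numerical illustration (my own toy example, not part of the original setup), take the one-dimensional family $p_\theta=\mathcal{N}(\theta,1)$: the empirical proxy $-\frac{1}{N}\sum_i\log p_\theta(x^{(i)})$ for the KL (up to the constant entropy term) is minimized at the MLE, the sample mean.

```python
import numpy as np

# Toy check: for p_theta = N(theta, 1), the empirical negative average
# log-likelihood (the KL proxy, up to the constant entropy term) is
# minimized at the sample mean, i.e. the MLE.
rng = np.random.default_rng(0)
data = rng.normal(loc=2.0, scale=1.0, size=10_000)   # samples from p_data

def neg_avg_log_lik(theta, x):
    # -1/N * sum_i log N(x_i | theta, 1)
    return 0.5 * np.mean((x - theta) ** 2) + 0.5 * np.log(2 * np.pi)

thetas = np.linspace(0.0, 4.0, 401)
losses = [neg_avg_log_lik(t, data) for t in thetas]
print("argmin over grid :", thetas[int(np.argmin(losses))])  # ~ 2.0
print("sample mean (MLE):", data.mean())                     # ~ 2.0
```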
Loss function
A key observation:
\[\begin{align*} \log p_\theta(x) &= \mathbf{E}_{z\sim q_\phi(z\mid x)}\log\frac{p_\theta(x,z)}{p_\theta(z\mid x)} = \mathbf{E}_{q}\log\frac{p_\theta(x,z)}{q_\phi(z\mid x)} +\mathbf{E}_q \log\frac{q_\phi(z\mid x)}{p_\theta(z\mid x)}\\ &=: \mathcal{L}(\theta,\phi\mid x) + D_{\rm KL}(q_\phi(\cdot\mid x)\| p_\theta(\cdot\mid x))\\ &\ge \mathcal{L}(\theta,\phi \mid x). \end{align*}\]$\mathcal{L}$ is called the Evidence Lower BOund (ELBO). The first equality holds because $p_\theta(x,z)/p_\theta(z\mid x)=p_\theta(x)$ for every $z$, so taking the expectation over $q$ changes nothing.
Since the KL term is nonnegative, VAE drops it and maximizes $\mathcal{L}$ alone: pushing up the lower bound both pushes up $\log p_\theta(x)$ (aiming for MLE) and, for fixed $\theta$, pulls $q_\phi(\cdot\mid x)$ toward the posterior $p_\theta(\cdot\mid x)$.
VAE assumes
- $\pi(z)=\mathcal{N}(z\mid 0,I_m)$.
- $q_\phi(z\mid x)=\mathcal{N}(z\mid\mu,\Sigma),$ where $\Sigma={\rm diag}(\sigma_1^2,\cdots,\sigma_m^2)$ and $(\mu,\log\sigma) = {\rm NN}_\phi(x)$ is the output of a neural net (the encoder).
- $p_\theta(x\mid z)=\mathcal{N}(x\mid f_\theta(z),\sigma_{\rm dec}^2I_d),$ where $f_\theta(z)={\rm NN}_\theta(z)$ is another neural net (the decoder).
One can use a more general $\Sigma$, but the diagonal one should be good enough in experiments. I imagine the prescribed diagonal form forces the network to rotate the latent coordinates into the right directions. A code sketch of these parametrizations is given below.
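Concretely, the parametrizations above might look like the following PyTorch sketch; the hidden sizes, layer choices, and class names are my own and only for illustration.

```python
import torch.nn as nn

class Encoder(nn.Module):
    """q_phi(z|x) = N(z | mu(x), diag(sigma(x)^2)); returns (mu, log_sigma)."""
    def __init__(self, d, m, hidden=256):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(d, hidden), nn.ReLU())
        self.mu = nn.Linear(hidden, m)
        self.log_sigma = nn.Linear(hidden, m)

    def forward(self, x):
        h = self.body(x)
        return self.mu(h), self.log_sigma(h)

class Decoder(nn.Module):
    """p_theta(x|z) = N(x | f_theta(z), sigma_dec^2 I_d); returns the mean f_theta(z)."""
    def __init__(self, d, m, hidden=256):
        super().__init__()
        self.f = nn.Sequential(nn.Linear(m, hidden), nn.ReLU(), nn.Linear(hidden, d))

    def forward(self, z):
        return self.f(z)
```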
With the assumptions, it remains to compute the objective function given a dataset $\lbrace x^{(i)} \rbrace_{i=1}^N.$
For a single sample $x$,
\[\begin{align*} \mathcal{L}(\theta,\phi\mid x) &= \mathbb{E}_{z\sim q_\phi(z\mid x)} \log\frac{p_\theta(x,z)}{q_\phi(z\mid x)} = \mathbb{E}_q\log \frac{p_\theta(x\mid z)\pi(z)}{q_\phi(z\mid x)} = \mathbb{E}_q\log p_\theta(x\mid z) - D_{\rm KL}(q_\phi(\cdot\mid x)\| \pi). \end{align*}\]The KL term is explicit:
\[\begin{align*} 2D_{\rm KL}(q_\phi(\cdot\mid x)\| \pi) &= \int \left\{ -(z-\mu)^\top \Sigma^{-1}(z-\mu) + z^\top z - \log\det\Sigma \right\}\mathcal{N}(z\mid \mu,\Sigma)\,\mathrm{d}z\\ &= -{\rm tr}(\Sigma^{-1}{\rm Var}[z]) + {\rm tr}\,{\rm Var}[z]+\mu^\top\mu - \log\det\Sigma \\ &= \|\mu_\phi(x)\|^2 +\|\sigma_\phi(x)\|^2 - 2\,\mathbf{1}^\top\log\sigma_\phi(x)-m. \end{align*}\]The remaining term in $\mathcal{L}$ is computed by Monte-Carlo: sample $\tilde z_1,\cdots,\tilde z_L\sim q_\phi(z\mid x)$ and
\[\begin{align*} \mathbb{E}_q\log p_\theta(x\mid z) &\approx \frac{1}{L}\sum_{\ell=1}^L \log\mathcal{N}(x\mid f_\theta(\tilde z_\ell), \sigma^2_{\rm dec}I_d)\\ &= - \frac{1}{2\sigma^2_{\rm dec}L}\sum_{\ell=1}^L \|x-f_\theta(\tilde z_\ell)\|^2 - \frac{d}{2}\log(2\pi\sigma^2_{\rm dec}). \end{align*}\]In practice, $L=1$.
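As a quick sanity check of the closed-form KL above (my own verification, not part of the original derivation), one can compare it against a Monte-Carlo estimate using reparametrized samples $z=\mu+\sigma\odot\varepsilon$:

```python
import torch

torch.manual_seed(0)
m = 4
mu = torch.randn(m)
log_sigma = 0.3 * torch.randn(m)
sigma = log_sigma.exp()

# Closed form: 2*KL = ||mu||^2 + ||sigma||^2 - 2 * sum(log sigma) - m
kl_closed = 0.5 * (mu.pow(2).sum() + sigma.pow(2).sum() - 2 * log_sigma.sum() - m)

# Monte-Carlo: KL = E_q[ log q(z|x) - log pi(z) ] with z = mu + sigma * eps
eps = torch.randn(100_000, m)
z = mu + sigma * eps
log_q = torch.distributions.Normal(mu, sigma).log_prob(z).sum(-1)
log_pi = torch.distributions.Normal(0.0, 1.0).log_prob(z).sum(-1)
kl_mc = (log_q - log_pi).mean()

print(kl_closed.item(), kl_mc.item())  # the two values should agree closely
```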
Ignoring the constants, a practical loss function is
\[\begin{align*} &\mathcal{J}(\theta,\phi) = \mathcal{J}_{\rm recon} + \mathcal{J}_{\rm KL}\\ =& \frac{1}{N}\sum_{i=1}^N \|x^{(i)}-f_\theta(\tilde z^{(i)})\|^2 + \frac{\sigma_{\rm dec}^2}{N}\sum_{i=1}^N\left(\|\mu_\phi(x^{(i)})\|^2 + \|\sigma_\phi(x^{(i)})\|^2 - 2\,\mathbf{1}^\top \log\sigma_\phi(x^{(i)})\right), \end{align*}\]where $\tilde z^{(i)}=\mu_\phi(x^{(i)}) + \sigma_\phi(x^{(i)})\odot \varepsilon_i, \varepsilon_i\sim \mathcal{N}(0,I_m)$ is the reparametrization trick, which lets gradients flow through the sampling step. Norms here are $\ell^2$.
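Putting the pieces together, here is a sketch of this objective on one mini-batch, reusing the Encoder/Decoder sketch above and treating $\sigma_{\rm dec}$ as a fixed hyperparameter (the function name `vae_loss` is my own):

```python
import torch

def vae_loss(x, encoder, decoder, sigma_dec=1.0):
    """J = J_recon + J_KL on a batch x of shape (batch, d), with L = 1."""
    mu, log_sigma = encoder(x)                 # parameters of q_phi(z|x)
    eps = torch.randn_like(mu)
    z = mu + log_sigma.exp() * eps             # reparametrization trick
    x_hat = decoder(z)                         # f_theta(z)

    recon = (x - x_hat).pow(2).sum(-1).mean()  # ||x - f_theta(z)||^2, batch average
    kl = (mu.pow(2) + (2 * log_sigma).exp() - 2 * log_sigma).sum(-1).mean()
    return recon + sigma_dec ** 2 * kl
```

One then minimizes $\mathcal{J}$ jointly in $(\theta,\phi)$ over mini-batches with a stochastic optimizer such as Adam.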
A classical auto-encoder (AE) consists of an encoder \(z=\mathcal{E}_\phi(x)\) and a decoder \(\hat x=\mathcal{D}_\theta(z)\), trained by minimizing the reconstruction loss \(\|\mathcal{D}_\theta(\mathcal{E}_\phi(x))-x\|\). The same loss appears in VAE, now with additional probabilistic structure, and \(\mathcal{J}_{\rm KL}\) can be viewed as a regularizer.
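For contrast, a bare-bones classical AE might look like this (again only a sketch, with arbitrary layer sizes and a class name of my choosing):

```python
import torch.nn as nn

class AutoEncoder(nn.Module):
    """Deterministic encoder/decoder trained only on the reconstruction loss,
    with no prior or KL term imposed on the latent code z."""
    def __init__(self, d, m, hidden=256):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(d, hidden), nn.ReLU(), nn.Linear(hidden, m))
        self.dec = nn.Sequential(nn.Linear(m, hidden), nn.ReLU(), nn.Linear(hidden, d))

    def forward(self, x):
        return self.dec(self.enc(x))
```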
Another derivation of the loss
By the assumptions, we only prescribe $p_\theta(x\mid z)$, $q_\phi(z\mid x)$, and a prior $\pi$ for $Z$. Taking $p_{\rm data}(x)$ as the marginal for $X$ on the encoder side, define the joint distributions $q_\phi(x,z):=p_{\rm data}(x)\,q_\phi(z\mid x)$ and $p_\theta(x,z):=\pi(z)\,p_\theta(x\mid z)$. Omitting $\theta,\phi$, their KL is
\[\begin{align*} D_{\rm KL}(q(x,z)\,\|\, p(x,z)) &= \int\int q\log\frac{q}{p} =\int p_{\rm data}(x)\,dx \int q(z\mid x)\log\frac{q(z\mid x)\,p_{\rm data}(x)}{p(x,z)}dz\\ &= - H(p_{\rm data}) - \mathbf{E}_{p_{\rm data}} \int q(z\mid x)\log\frac{p(x,z)}{q(z\mid x)}dz\\ &\approx - H(p_{\rm data}) - \frac{1}{N}\sum_{i=1}^N \mathcal{L}(\theta,\phi\mid x^{(i)}). \end{align*}\]So VAE is really minimizing the KL divergence between the two joints: since $H(p_{\rm data})$ does not depend on $(\theta,\phi)$, this is equivalent to maximizing the average ELBO.