VAE-II: EM?
In the last post on VAE, the ELBO reminded me of the EM (expectation-maximization) algorithm, so it’s worth exploring how the two are related. This connection has been worked out by many people, but I just want to do it myself.
EM recap
Let’s recall EM briefly. See Ng’s notes for more details.
Consider a variable $X$ with some latent $Z$ which secretly affects $X$. Directly finding the MLE from $p_\theta(x)$ is often infeasible, but here we assume the joint $p_\theta(x,z)$ is more tractable. Given any distribution $q$ on the $Z$-space,
\[\begin{align*} \log p_\theta(x) &= \mathbf{E}_{z\sim q} \log \frac{p_\theta(x,z)}{p_\theta(z\mid x)} = \mathbf{E}_{q} \log \frac{p_\theta(x,z)}{q(z)}+ \mathbf{E}_{q}\log\frac{q(z)}{p_\theta(z\mid x)}\\ &=: \mathcal{L}(\theta,q\mid x) + D_{\rm KL}(q(z)\,\|\, p_\theta(z|x))\\ &\ge \mathcal{L}(\theta,q \mid x), \end{align*}\]where the equality holds iff $q(z)=p_\theta(z\mid x)$. The computation is the same as in VAE.
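To see the identity in action, here is a quick numeric sanity check on a tiny discrete model; all the numbers below are arbitrary, chosen just for illustration.

```python
# Sanity check of the identity log p(x) = L(theta, q | x) + KL(q || p(z|x))
# on a tiny discrete model; the numbers are arbitrary placeholders.
import numpy as np

p_z = np.array([0.3, 0.7])            # prior p(z), z in {0, 1}
p_x_given_z = np.array([[0.9, 0.1],   # p(x | z = 0)
                        [0.2, 0.8]])  # p(x | z = 1)

x = 0                                 # the observed value
joint = p_z * p_x_given_z[:, x]       # p(x, z) as a function of z
p_x = joint.sum()                     # marginal likelihood p(x)
posterior = joint / p_x               # p(z | x)

q = np.array([0.5, 0.5])              # any distribution on the z-space
elbo = np.sum(q * np.log(joint / q))          # L(theta, q | x)
kl = np.sum(q * np.log(q / posterior))        # KL(q || p(z|x))

print(np.isclose(np.log(p_x), elbo + kl))     # True: the identity holds
```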
The EM algorithm is an alternating optimization method. Start with an initial guess $\theta^0$. At step $k$, compute two things:
\[\begin{equation} \tag{E-step} Q(\theta,\theta^k) := \sum_{i} \mathcal{L}(\theta,p_{\theta^k}(z\mid x^{(i)})\mid x^{(i)}). \end{equation}\] \[\begin{equation} \tag{M-step} \theta^{k+1}\in {\rm arg}\max Q(\theta,\theta^k). \end{equation}\]The E-step really amounts to choosing the maximizer (we are not careful about uniqueness here):
\[\begin{equation} \tag{E'-step} p_{\theta^k}(z\mid x^{(i)}) = {\rm arg}\max_q \mathcal{L}(\theta^k,q \mid x^{(i)}) \end{equation}\]Note that
\[Q(\theta^k,\theta^k)=\ell(\theta^k),\quad Q(\theta,\theta^k)\le \ell(\theta).\]Recall $\ell(\theta)=\sum_i \log p_\theta(x^{(i)})$ is the log-likelihood. So $Q(\cdot,\theta^k)$ is a lower barrier of $\ell$ at $\theta^k$ (in analysts’ terms).
EM never decreases the likelihood:
\[\ell(\theta^{k+1})\ge Q(\theta^{k+1},\theta^{k}) \ge Q(\theta^k,\theta^k) = \ell(\theta^k).\]The first inequality is $Q(\cdot,\theta^k)\le\ell$, and the second is the definition of the M-step. EM also enjoys nice convergence properties; see J. Wu’s work and my joint work, among others.
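As a minimal sketch, here is EM on a toy model of my choosing, a two-component 1D Gaussian mixture, where both steps have closed forms. The printed log-likelihood never decreases, exactly as the chain of inequalities predicts.

```python
# EM for a two-component 1D Gaussian mixture (a toy model of my choosing).
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-2, 1, 200), rng.normal(3, 1, 300)])

pi = 0.5                                        # weight of component 1
mu, sigma = np.array([-1.0, 1.0]), np.array([1.0, 1.0])
for k in range(20):
    # E-step: responsibilities r[j, i] = p_{theta^k}(z = j | x_i)
    dens = np.stack([(1 - pi) * norm.pdf(x, mu[0], sigma[0]),
                     pi * norm.pdf(x, mu[1], sigma[1])])
    ll = np.log(dens.sum(axis=0)).sum()         # ell(theta^k)
    r = dens / dens.sum(axis=0)
    # M-step: closed-form maximizer of Q(., theta^k)
    nj = r.sum(axis=1)
    mu = (r @ x) / nj
    sigma = np.sqrt((r * (x - mu[:, None]) ** 2).sum(axis=1) / nj)
    pi = nj[1] / len(x)
    print(k, ll)                                # non-decreasing in k
```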
VAE
In the E-step, EM maximizes $\mathcal{L}(\theta^k,\cdot\mid x^{(i)})$ without restriction, while VAE searches among $q_\phi(z\mid x^{(i)})$ with a prescribed Gaussian form. Since $q_\phi$ is Gaussian, there is always a gap to the actual maximizer $p_\theta(z\mid x^{(i)})$ (generally not Gaussian), no matter how sophisticated your NN is. A possible improvement is to prescribe $q_\phi$ in a more flexible form.
Given a candidate $q_i^\star(z):=q_{\phi^\star}(z\mid x^{(i)})$, the M-step of EM maximizes \(\mathcal{L}(\cdot,q^\star_i\mid x^{(i)})\). Recall that the term involving $\theta$ in the VAE loss is the reconstruction loss
\[\mathcal{J}_{\rm recon} = \frac{1}{N}\sum_{i=1}^N \|x^{(i)}-f_\theta(\tilde z^{(i)})\|^2,\]where $\tilde z^{(i)}\sim q^\star_i.$ The M-step can thus be viewed as fixing $\phi$ and minimizing $\mathcal{J}_{\rm recon}$, which is equivalent to maximizing the ELBO over $\theta$ while fixing $\phi$.
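For concreteness, here is a sketch of one such update in PyTorch, freezing the encoder ($\phi$) and stepping only the decoder ($\theta$) against $\mathcal{J}_{\rm recon}$; the layer sizes and the batch are placeholders of my own choosing.

```python
# A sketch of the "M-step view" of one VAE update: freeze the encoder (phi),
# sample z~ from q*_i, take one gradient step on the decoder (theta).
import torch
import torch.nn as nn

enc = nn.Linear(784, 2 * 8)                     # encoder head: (mu, log sigma^2), latent dim 8
dec = nn.Sequential(nn.Linear(8, 256), nn.ReLU(), nn.Linear(256, 784))
opt = torch.optim.SGD(dec.parameters(), lr=1e-3)  # optimize theta only

x = torch.randn(32, 784)                        # a stand-in batch
with torch.no_grad():                           # the E-step result q*_i is held fixed
    mu, logvar = enc(x).chunk(2, dim=1)
z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)  # z~ ~ q*_i
recon = ((x - dec(z)) ** 2).sum(dim=1).mean()   # J_recon
opt.zero_grad()
recon.backward()
opt.step()                                      # approximate M-step
```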
So VAE, in a way, splits the optimization into E/M steps and approximates the maximizers of both. One step of gradient descent loosely corresponds to one EM update.
Now let’s try to make this concrete. Let \(\mathcal{Q}_{x}\) be the space of candidates for $q$ in the E-step.
Define
\[q^\star(z\mid x,\theta):={\rm arg}\max_{q\in \mathcal{Q}_x} \mathcal{L}(\theta,q\mid x),\quad \tilde Q(\theta,\theta^k) = \sum_i \mathcal{L}(\theta, q^\star(z\mid x^{(i)}, \theta^k)\mid x^{(i)}).\]If \(\mathcal{Q}_x\) contains all distributions (as in EM), then $\tilde Q=Q$. For VAE, \(\mathcal{Q}_x=\{\mathcal{N}(z\mid \mu_\phi(x),\sigma^2_\phi(x)) : \phi\}\), the Gaussians realizable by the encoder.
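Here is a sketch of this restricted E-step on a made-up model $p_\theta(x,z)$ (standard normal prior, nonlinear Gaussian likelihood): stochastic gradient ascent over $(\mu,\sigma)$ on a reparameterized estimate of $\mathcal{L}(\theta,q\mid x)$.

```python
# Restricted E-step: search over Q_x = {N(mu, sigma^2)} for the maximizer
# of L(theta, q | x). The model below is made up for illustration.
import math
import torch

def log_joint(x, z):                 # log p_theta(x, z), theta fixed, up to constants
    log_prior = -0.5 * z ** 2
    log_lik = -0.5 * (x - torch.tanh(3.0 * z)) ** 2
    return log_prior + log_lik

x = torch.tensor(0.7)
mu = torch.zeros(1, requires_grad=True)
log_sigma = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([mu, log_sigma], lr=0.05)

for _ in range(500):
    eps = torch.randn(256)                       # reparameterization: z = mu + sigma * eps
    z = mu + torch.exp(log_sigma) * eps
    entropy = log_sigma + 0.5 * math.log(2 * math.pi) + 0.5   # H(q) for Gaussian q
    elbo = log_joint(x, z).mean() + entropy      # E_q log p(x,z) + H(q)
    opt.zero_grad()
    (-elbo).backward()
    opt.step()

print(mu.item(), torch.exp(log_sigma).item())    # the fitted q* in Q_x
```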
By the ELBO identity above,
\[\tilde Q(\theta,\theta^k) \le \tilde Q(\theta,\theta^k) + \sum_i D_{\rm KL}(q^\star(\cdot\mid x^{(i)},\theta^k)\,\|\, p_\theta(\cdot\mid x^{(i)})) = \ell(\theta).\]However, $\tilde Q(\theta^k,\theta^k) < \ell(\theta^k)$ in general, so $\tilde Q(\cdot,\theta^k)$ need not be a barrier at $\theta^k$.
In the M-step, VAE approximates a maximizer of $\tilde Q(\cdot,\theta^k)$ by one step of gradient descent. If the approximation actually works,
\[\ell(\theta^{k+1})\ge \tilde Q(\theta^{k+1},\theta^k) \ge \tilde Q(\theta^k,\theta^k) = \ell(\theta^k)-\gamma_k,\]where $\gamma_k := \sum_i D_{\rm KL}(q^\star(\cdot\mid x^{(i)},\theta^k)\,\|\, p_{\theta^k}(\cdot\mid x^{(i)}))$ is the gap at step $k$. Hopefully $\gamma_k\to 0$, so that $\ell(\theta^k)$ is almost monotone.
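For the made-up model above, one summand of $\gamma_k$ is easy to estimate on a grid; this snippet reuses `log_joint`, `x`, `mu`, and `log_sigma` from the E-step sketch.

```python
# Estimate KL(q* || p_theta(. | x)), one summand of gamma_k, on a grid.
grid = torch.linspace(-5.0, 5.0, 2001)
dz = grid[1] - grid[0]
joint = torch.exp(log_joint(x, grid))            # unnormalized p_theta(x, .)
posterior = joint / (joint.sum() * dz)           # true p_theta(z | x), generally non-Gaussian

sigma = torch.exp(log_sigma.detach())
q = torch.exp(-0.5 * ((grid - mu.detach()) / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))
mask = q > 0                                     # skip tails where q underflows to 0
gamma = (q[mask] * torch.log(q[mask] / posterior[mask]) * dz).sum()
print(gamma.item())                              # > 0 unless the true posterior is Gaussian
```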