Probably one of the most impactful papers of the deep learning era is Auto-Encoding Variational Bayes (Kingma and Welling, 2014). The paper proposes a framework to perform scalable and effective inference for a broad family of scenarios, leveraging large amounts of data and leading to the popular Variational Autoencoder (VAE) model.
To do so, it presents a probabilistic formulation in which the observed variable $x$ is conditioned on some latent, unobserved variable $z$. The relationship between both variables is captured by Bayes’ theorem:
$$ p(z|x) = \frac{p(x|z) p(z)}{p(x)} $$
This relationship opens the door to many useful applications, like inferring the latent code $z^{(i)}$ given an observation $x^{(i)}$, or extending the dataset $X$ with synthetic data by sampling from $p(z)$. However, it can’t be used directly in the general case, due to the intractability of $p(x)$ (usually called the evidence or marginal likelihood), and preexisting approaches to estimate $p(x)$, like importance sampling, are infeasible in complex scenarios.
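To make the intractability concrete, here is a minimal sketch (the `decoder` function, the noise level and all numbers are made-up toy choices): the evidence is an integral over every possible latent code, $p(x) = \int_z p(x|z) p(z) dz$, and the naive Monte Carlo estimator below draws $z$ from the prior and averages the likelihood. In a 1-D toy this still works, but as the latent space grows, the fraction of prior samples with non-negligible likelihood collapses and the estimator becomes useless.

```python
import numpy as np

rng = np.random.default_rng(0)

def decoder(z):
    """Toy nonlinear 'decoder' mapping a latent code to an observation mean."""
    return np.tanh(3.0 * z)

def likelihood(x, z, sigma=0.1):
    """p(x|z) = N(x; decoder(z), sigma^2)."""
    return np.exp(-0.5 * ((x - decoder(z)) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

# Naive Monte Carlo estimate of the evidence p(x) = E_{p(z)}[p(x|z)],
# with prior p(z) = N(0, 1). Most prior samples land where the likelihood
# is tiny, which is what makes this approach hopeless for complex models.
x = 0.95
z_samples = rng.standard_normal(100_000)
print(likelihood(x, z_samples).mean())  # crude estimate of p(x)
```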
As a solution, the paper proposes to rewrite the marginal log-likelihood as the following expression, which, as we’ll see, leads to a tractable optimization objective:
$$ log~p(x) = D_{KL}\Big(q(z|x)\Vert p(z|x) \Big) - D_{KL}\Big(q(z|x)\Vert p(z) \Big) + \mathbb{E}_{q(z|x)} \Big[ log~p(x | z) \Big] $$
But it doesn’t provide the explicit derivation, and at first glance the relation to Bayes’ theorem is not really apparent, due to the new distribution $q(z|x)$ and the new operators ($D_{KL}$, $\mathbb{E}$) involved.
Note that other sources online already provide derivations, but (as far as I could find) they either invoke Jensen’s inequality to drop terms along the way, or start from the KL divergence itself, resulting in incomplete or circular derivations. This 2020 post shares the goal of deriving the expression above directly from Bayes, but it also makes use of inequalities to do so.
While all of the above are perfectly legitimate, a direct and complete derivation allows us to better understand what is going on. And as it turns out, it is quite simple, but a bit tricky to reproduce on the spot if you don’t work with it frequently. For this reason I thought it would be a nice addition; hopefully it helps others as well!
We’ll use only a few basic elements: Bayes’ theorem itself, basic logarithm rules, the fact that any probability density integrates to 1, and the definitions of the expectation and the KL divergence.
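For reference, the last two, written in the same notation as above, are:

$$ \mathbb{E}_{q(z|x)} \Big[ f(z) \Big] = \int_z q(z|x) f(z) dz \qquad\quad D_{KL}\Big(q(z|x)\Vert p(z) \Big) = \int_z q(z|x)~log~\frac{q(z|x)}{p(z)} dz $$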
The full derivation can be done in 4 steps, one per line below (I omit the distribution parameters for notational simplicity; they aren’t necessary to understand the derivation):
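$$
\begin{aligned}
log~p(x) &= log~p(x|z) + log~p(z) - log~p(z|x) && \text{(1: Bayes)} \\
&= \int_z q(z|x) \Big[ log~p(x|z) + log~p(z) - log~p(z|x) + log~q(z|x) - log~q(z|x) \Big] dz && \text{(2: multiply by 1, add 0)} \\
&= \int_z q(z|x)~log~\frac{q(z|x)}{p(z|x)} dz - \int_z q(z|x)~log~\frac{q(z|x)}{p(z)} dz + \int_z q(z|x)~log~p(x|z) dz && \text{(3: group terms)} \\
&= D_{KL}\Big(q(z|x)\Vert p(z|x) \Big) - D_{KL}\Big(q(z|x)\Vert p(z) \Big) + \mathbb{E}_{q(z|x)} \Big[ log~p(x | z) \Big] && \text{(4: apply definitions)}
\end{aligned}
$$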
Steps 1, 3 and 4 only involve the definitions above, basic logarithm rules and the regrouping of terms (together with the linearity of the integral). In step 2, the equivalence is maintained because, given any probability distribution $q(z|x)$, we are adding $0 = log~q(z|x) - log~q(z|x)$ and then multiplying by $1 = \int_z q(z|x) dz$, which is valid because the left-hand side doesn’t depend on $z$. Nice! But this is much more than a nifty algebraic trick: it provides a lot of meaning and practicality, and it is the essence of variational inference approaches.
Let’s start by remembering that our original goal was to overcome the intractability of the evidence $p(x)$. The approach here is to introduce a new probability distribution $q(z|x)$ (let’s not worry about its exact form for the moment). With that, we are able to decompose the log-evidence into three terms. And since a KL divergence can’t be negative, we have:
$$ log~p(x) \geq - D_{KL}\Big(q(z|x)\Vert p(z) \Big) + \mathbb{E}_{q(z|x)} \Big[ log~p(x | z) \Big] =: \text{ELBO} $$
ELBO stands for Evidence Lower Bound. From this we can easily see that the greater the ELBO, the closer we get to the actual log-evidence. Therefore, the objective is to maximize the ELBO in order to approximate the evidence.
Furthermore, since $log~p(x)$ is a constant that doesn’t depend on $q$, the larger the ELBO, the smaller the divergence $D_{KL}\Big(q(z|x)\Vert p(z|x) \Big)$. A smaller divergence between the true posterior and our variational posterior means that our model better reflects the actual distribution. So the maximal ELBO will also correspond to the best variational fit for the posterior.
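As a toy sanity check (the model and all numbers here are made up for illustration, not taken from the paper), consider a fully conjugate Gaussian setup where everything is available in closed form: $p(z) = \mathcal{N}(0, 1)$ and $p(x|z) = \mathcal{N}(z, 1)$, so that $p(x) = \mathcal{N}(0, 2)$ and the true posterior is $p(z|x) = \mathcal{N}(x/2, 1/2)$. For a Gaussian variational posterior $q(z|x) = \mathcal{N}(\mu, s^2)$, the ELBO never exceeds $log~p(x)$ and matches it exactly when $q$ equals the true posterior:

```python
import numpy as np

# Toy conjugate-Gaussian model (made up for illustration):
# p(z) = N(0, 1), p(x|z) = N(z, 1)  =>  p(x) = N(0, 2), p(z|x) = N(x/2, 1/2)

def log_evidence(x):
    """Closed-form log p(x) = log N(x; 0, 2)."""
    return -0.5 * np.log(2 * np.pi * 2.0) - x**2 / 4.0

def elbo(x, mu, s2):
    """Closed-form ELBO for a variational posterior q(z|x) = N(mu, s2)."""
    expected_loglik = -0.5 * np.log(2 * np.pi) - 0.5 * ((x - mu) ** 2 + s2)
    kl_q_prior = 0.5 * (s2 + mu**2 - 1.0 - np.log(s2))
    return expected_loglik - kl_q_prior

x = 1.7
print(log_evidence(x))            # true log-evidence
print(elbo(x, mu=0.0, s2=1.0))    # arbitrary q: strictly below log p(x)
print(elbo(x, mu=x / 2, s2=0.5))  # q = true posterior: ELBO equals log p(x)
```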
The idea of reformulating the expression to arrive at an optimization objective that approximates the log-evidence is nice, but not enough: the first term still contains the intractable true posterior $p(z|x)$, and we still need a concrete $q(z|x)$ that makes the remaining terms computable.
Luckily, assuming we choose an appropriate $q(z|x)$, we can drop the first term $D_{KL}\Big(q(z|x)\Vert p(z|x) \Big)$: it is non-negative, so removing it leaves us exactly with the lower bound from before, and only tractable computations remain. We can also see that $q(z|x)$ plays the role of an encoder, converting observations into latent probability distributions. Those latent distributions should stay as close to the prior $p(z)$ as possible, to minimize the negative impact of the second term $- D_{KL}\Big(q(z|x)\Vert p(z) \Big)$ on the ELBO that we want to maximize. This is interesting, since it requires us to explicitly state our prior expectation about how $z$ should be distributed.
Finally, $p(x|z)$ acts as a decoder, effectively converting a latent representation $z$ back into an observation. Our objective tells us that we want to maximize the expected log-likelihood of the observation under this decoder, taken over the whole latent distribution. This also makes sense: if we sampled from the latent distribution, we’d like the resulting reconstruction to be as close as possible to the original observation (i.e. to assign it a high likelihood).
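To make the encoder/decoder reading concrete, here is a minimal sketch of how the negative ELBO could be computed as a training loss, assuming a Gaussian encoder, a Bernoulli decoder and a standard-normal prior (the `TinyVAE` name, the layer sizes and the single-sample estimate are arbitrary illustrative choices, not the exact setup of the paper); the sampling line is the reparametrization trick mentioned below:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyVAE(nn.Module):
    """Sketch of a VAE loss: Gaussian encoder q(z|x), Bernoulli decoder p(x|z),
    standard-normal prior p(z). Sizes are arbitrary illustrative choices."""

    def __init__(self, x_dim=784, z_dim=16, h_dim=256):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU())
        self.enc_mu = nn.Linear(h_dim, z_dim)
        self.enc_logvar = nn.Linear(h_dim, z_dim)
        self.dec = nn.Sequential(nn.Linear(z_dim, h_dim), nn.ReLU(),
                                 nn.Linear(h_dim, x_dim))

    def loss(self, x):
        # Encoder: x -> parameters of q(z|x)
        h = self.enc(x)
        mu, logvar = self.enc_mu(h), self.enc_logvar(h)
        # One sample z ~ q(z|x) via the reparametrization trick
        z = mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)
        # Decoder: z -> logits of p(x|z)
        logits = self.dec(z)

        # E_{q(z|x)}[log p(x|z)], estimated with the single sample above
        rec = -F.binary_cross_entropy_with_logits(logits, x, reduction="none").sum(-1)
        # D_KL(q(z|x) || p(z)) in closed form for two diagonal Gaussians
        kl = 0.5 * (logvar.exp() + mu**2 - 1.0 - logvar).sum(-1)
        return -(rec - kl).mean()  # negative ELBO, averaged over the batch
```

A single latent sample per data point tends to be enough in practice, since the expectation is re-estimated at every training step anyway.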
Starting from the goal of approximating the evidence, we’ve seen a concise, comprehensive and direct reformulation of Bayes’ theorem that results in maximizing the ELBO. Along the way, we were able to give a proper interpretation to all steps and elements involved.
The focus was on the ELBO derivation: we haven’t covered further important details, like the reparametrization trick, since those are well covered in the paper and elsewhere.
An interesting question is how tight the ELBO, as a lower bound, actually is. This is in fact an active research field, and recent work has found tighter, efficient bounds; see e.g. Tightening Bounds for Variational Inference by Revisiting Perturbation Theory by Bamler et al. (JSTAT 2019). Another interesting revision of the VAE idea are the Wasserstein Auto-Encoders (WAEs) proposed by Tolstikhin et al. (ICLR 2018).