EM as KL minimization

KL pops up again and again in statistics. We have already seen that KL minimization is equivalent to maximum likelihood estimation and also that logistic regression is a KL minimization problem.

The KL divergence also shows up in the classical derivation of the EM algorithm, but the resulting problem doesn’t look like KL minimization at all: there is a lower bound and a variational distribution, one maximizes the lower bound, and it all sounds scary and non-transparent. A much nicer interpretation in terms of free energy minimization was given by Neal and Hinton (1998). They showed that the E-step and the M-step are nothing but two steps of coordinate descent on a free energy objective viewed as a function of two variables.

In this post, I provide an even simpler derivation of the EM algorithm in terms of KL minimization. It is equivalent to the other two derivations but easier to follow (at least for me).

Generic KL minimization

There is really just one way to fit parameters of a probability distribution to observed samples, and it consists in minimizing some measure of distance between the model distribution and the empirical distribution. For example, if $\hat{p}(x)$ is an empirical density and $p_\theta(x)$ is a model density, then an optimal parameter vector can be found by minimizing the KL divergence

\begin{equation} \label{kl_min} \theta^* = \arg \min_\theta \; KL_x \left( \hat{p}(x) \| p_\theta(x) \right). \end{equation}
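In case the equivalence with maximum likelihood mentioned above is not fresh in memory, here is the one-line reminder: with the empirical density $\hat{p}(x) = \frac{1}{N} \sum_{i = 1}^N \delta (x - x_i)$ (the convention used later in this post), the divergence expands as \begin{align*} KL_x \left( \hat{p}(x) \| p_\theta(x) \right) &= E_{x \sim \hat{p}(x)} \left[ \log \hat{p}(x) \right] - E_{x \sim \hat{p}(x)} \left[ \log p_\theta(x) \right] \newline &= \mathrm{const} - \frac{1}{N} \sum_{i = 1}^N \log p_\theta(x_i), \end{align*} so minimizing the KL divergence over $\theta$ is the same as maximizing the average log-likelihood.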

Formulation \eqref{kl_min} is generic and works for all kinds of density parameterizations. In this post, however, we are concerned with mixture models.

Mixture models

Having a mixture model means imposing a certain structure on the density $p_\theta(x)$. Namely, we postulate the following factorization \begin{equation} \label{mixture} p_\theta(x) = \sum_{z} p_\theta(x|z) p_\theta(z) = \sum_{z} p_\theta(x, z) \end{equation} where $z$ is a random variable that we do not get to observe and which we therefore call a hidden variable.
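As a concrete (and purely illustrative) instance of \eqref{mixture}, take a two-component Gaussian mixture in one dimension: $z \in \{1, 2\}$, $p_\theta(z{=}k) = \pi_k$, and $p_\theta(x|z{=}k) = \mathcal{N}(x \mid \mu_k, \sigma_k^2)$. Below is a minimal sketch of the marginal density, with the function name and parameter values made up for the example:

```python
import numpy as np
from scipy.stats import norm

def mixture_density(x, weights, means, stds):
    """Marginal density p_theta(x) = sum_k p_theta(z=k) * p_theta(x | z=k)
    of a 1-D Gaussian mixture (illustrative parameterization)."""
    x = np.atleast_1d(x)
    # component densities p_theta(x | z=k), one row per component
    components = np.stack([norm.pdf(x, loc=m, scale=s) for m, s in zip(means, stds)])
    # mix with the prior p_theta(z=k) = weights[k]
    return np.asarray(weights) @ components

# e.g. pi = (0.3, 0.7), mu = (-1, 2), sigma = (1.0, 0.5)
print(mixture_density([0.0, 1.0], weights=[0.3, 0.7], means=[-1.0, 2.0], stds=[1.0, 0.5]))
```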

KL minimization for mixture models

What we want now is to somehow use the structure of $p_\theta(x)$ described in \eqref{mixture} to simplify the optimization problem \eqref{kl_min}. How can we simplify it? One important observation is that performing full-data maximum likelihood is trivial for many problems. That is, minimizing \begin{equation} \label{full_data} KL(\hat{p}(x,z) \| p_\theta(x,z)) \end{equation} would be much easier than minimizing \eqref{kl_min}. Unfortunately, we do not observe $z$ and therefore do not have $\hat{p}(x,z)$, so we have to suffer. To remedy the situation a bit, we can at least try to approximate it.
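(To see why the full-data problem \eqref{full_data} would be easy, a standard observation not specific to this derivation: since $p_\theta(x, z) = p_\theta(z) p_\theta(x|z)$, the objective splits into independent pieces, \begin{equation*} KL(\hat{p}(x,z) \| p_\theta(x,z)) = \mathrm{const} - E_{(x, z) \sim \hat{p}(x, z)} \left[ \log p_\theta(z) \right] - E_{(x, z) \sim \hat{p}(x, z)} \left[ \log p_\theta(x | z) \right], \end{equation*} and for common component families, e.g. Gaussian components with a categorical $p_\theta(z)$, each expectation is maximized in closed form by simple sample averages over the points with a given value of $z$.)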

And the trick we will use is the second oldest trick in mathematics called “multiply and divide by the same number” (it is the second after “add and subtract the same number”). Let’s introduce a variable $q(z)$ which is a probability distribution over hidden variables. Then we can always write \begin{equation} \label{q_def} KL_x \left( \hat{p}(x) \| p_\theta(x) \right) = KL_{x, z} \left( \hat{p}(x) q(z) \| p_\theta(x) q(z) \right), \end{equation} because the $q(z)$ factors cancel inside the logarithm and $q(z)$ integrates to one. To make \eqref{q_def} look maximally similar to \eqref{full_data}, we would like to have $p_\theta(x, z)$ on the right-hand side instead of $p_\theta(x) q(z)$. Can we do it?

Key inequality

Let’s convince ourselves that the following inequality holds

\begin{equation} \label{ineq} KL_{x, z} \left( \hat{p}(x) q(z) \| p_\theta(x) q(z) \right) \leq KL_{x, z} \left( \hat{p}(x) q(z) \| p_\theta(x, z) \right) \end{equation}

and the bound is tight when $q(z) = p_\theta(z|x)$. Using the same multiply/divide trick, we can obtain \begin{align*} KL_{x, z} \left( \hat{p}(x) q(z) \| p_\theta(x) q(z) \right) &= KL_{x, z} \left( \hat{p}(x) q(z) \| p_\theta(x, z) \right) \newline &-\int \hat{p}(x) KL_{z} \left( q(z) \| p_\theta(z | x) \right) dx. \end{align*} Since the second term is non-negative, we confirm that \eqref{ineq} is true. Moreover, the bound is tight when the second term is equal to zero, which happens precisely when $q(z) = p_\theta(z|x)$.
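If you like to sanity-check such identities numerically, here is a small sketch for discrete $x$ and $z$ with randomly generated distributions (all names are hypothetical, chosen just for the check); it verifies the decomposition above and hence the inequality \eqref{ineq}:

```python
import numpy as np

rng = np.random.default_rng(0)

def normalize(a):
    return a / a.sum()

def kl(p, q):
    """KL divergence between two discrete distributions of the same shape."""
    return np.sum(p * np.log(p / q))

# hypothetical discrete setup: 4 values of x, 3 values of z
p_hat = normalize(rng.random(4))            # empirical p_hat(x)
q = normalize(rng.random(3))                # variational q(z)
p_joint = normalize(rng.random((4, 3)))     # model joint p_theta(x, z)
p_x = p_joint.sum(axis=1)                   # marginal p_theta(x)
p_z_given_x = p_joint / p_x[:, None]        # posterior p_theta(z | x)

lhs = kl(np.outer(p_hat, q), np.outer(p_x, q))   # KL(p_hat q || p_theta(x) q)
rhs = kl(np.outer(p_hat, q), p_joint)            # KL(p_hat q || p_theta(x, z))
gap = sum(p_hat[i] * kl(q, p_z_given_x[i]) for i in range(len(p_hat)))

print(np.isclose(lhs, rhs - gap), bool(lhs <= rhs))   # True True
```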

What it tells us is that the right-hand side in \eqref{ineq} is an upper bound on the original objective function in \eqref{kl_min} for any choice of $q(z)$. By pushing this upper bound down, we can solve the original optimization problem, which is the main insight behind the EM algorithm.

Relation to EM

Let’s state the relation to EM more formally. EM proposes to solve Problem \eqref{kl_min} by solving the following optimization problem

\begin{equation} \label{em} \min_{\theta, q} \; KL_{x, z} \left( \hat{p}(x) q(z) \| p_\theta(x, z) \right) \end{equation}

using coordinate descent. In the E-step, we minimize \eqref{em} with respect to the probability distribution $q$, and in the M-step, we minimize it with respect to the parameters $\theta$.

The E-step for a fixed $\theta_0$ yields \begin{equation} \label{e_step} q(z) = p_{\theta_0}(z|x) \end{equation} if performed in full (see Neal and Hinton (1998) for partial E-steps). Subsequently, in the M-step, we solve the following optimization problem \begin{equation} \label{m_step} \min_\theta \; KL_{x, z} \left( \hat{p}(x) p_{\theta_0}(z|x) \| p_\theta(x, z) \right) \end{equation} to obtain an improved parameter vector $\theta_1$. Alternating E- and M-steps never increases the objective and, under mild regularity conditions, converges to a local optimum (more precisely, a stationary point).
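To make the two steps concrete, here is a minimal sketch of EM for a one-dimensional Gaussian mixture (an illustrative instance, with initialization and defaults chosen ad hoc): the E-step \eqref{e_step} computes the responsibilities $p_{\theta_0}(z{=}k \mid x_i)$, and the M-step \eqref{m_step} re-estimates the mixture weights, means, and standard deviations from them.

```python
import numpy as np
from scipy.stats import norm

def em_gmm(x, n_components=2, n_iter=50, seed=0):
    """EM for a 1-D Gaussian mixture: alternate E- and M-steps,
    i.e. coordinate descent on the joint KL objective above."""
    rng = np.random.default_rng(seed)
    weights = np.full(n_components, 1.0 / n_components)
    means = rng.choice(x, size=n_components, replace=False)
    stds = np.full(n_components, x.std())

    for _ in range(n_iter):
        # E-step: responsibilities r[i, k] = p_theta0(z = k | x_i)
        weighted = np.stack([w * norm.pdf(x, m, s)
                             for w, m, s in zip(weights, means, stds)], axis=1)
        r = weighted / weighted.sum(axis=1, keepdims=True)

        # M-step: closed-form updates of the complete-data objective
        nk = r.sum(axis=0)
        weights = nk / len(x)
        means = (r * x[:, None]).sum(axis=0) / nk
        stds = np.sqrt((r * (x[:, None] - means) ** 2).sum(axis=0) / nk)

    return weights, means, stds

# toy data: two well-separated clusters
rng = np.random.default_rng(42)
x = np.concatenate([rng.normal(-2.0, 0.5, 200), rng.normal(3.0, 1.0, 300)])
print(em_gmm(x))
```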

Where is the lower bound?

If you don’t recognize the familiar lower bound \begin{equation} \label{lower_bound} \mathcal{L}(q, \theta) = -KL_{Z} \left( q(Z) \| p_\theta(X, Z) \right) \end{equation} in \eqref{m_step}, observe that in this post we treat data points $x_i$ as IID samples from a single true data distribution $p(x)$ whereas in Bishop’s book they are considered to be individual samples from $N$ identical distributions $p_i(x)$. More concretely, \begin{equation*} \log p_\theta(X, Z) = \sum_{i = 1}^N \log p_\theta(x_i, Z) \end{equation*} in Bishop’s notation corresponds to having an empirical distribution \begin{equation*} \hat{p}(x) = \frac{1}{N} \sum_{i = 1}^N \delta (x - x_i) \end{equation*} in our notation, and therefore \begin{equation*} E_{x \sim \hat{p}(x)} \left[ \log p_\theta(x, z) \right] = \frac{1}{N} \sum_{i = 1}^N \log p_\theta(x_i, z) \end{equation*} simply differs by a constant factor $1/N$ from Bishop’s convention. With this in mind, it should be straightforward to see that the objective in \eqref{m_step} is, up to an additive constant and the factor $1/N$, the negated lower bound from \eqref{lower_bound}, so that \begin{equation*} \arg \max_\theta \; \mathcal{L}(q, \theta) = \arg \min_\theta \; E_{x \sim \hat{p}(x)} \left[ KL_z \left( q(z) \| p_\theta(x, z) \right) \right]. \end{equation*}

Conclusion

I am looking back at all the text I’ve written and thinking “Man, is it really easier to follow than the original derivation?” Luckily, yes! If you filter out all the details and proofs, everything EM is doing can be summed up in a single formula

\begin{equation*} \min_\theta \; KL_x \left( \hat{p}(x) \| p_\theta(x) \right) = \min_{\theta, q} \; KL_{x, z} \left(\hat{p}(x) q(z) \| p_\theta(x, z)\right). \end{equation*}

That is, in order to minimize the KL divergence between $\hat{p}(x)$ and $p_\theta(x)$, EM introduces an additional optimization variable, the variational distribution $q(z)$ over hidden variables, and reduces the original partial-data optimization problem \eqref{kl_min} to a simpler complete-data optimization problem \eqref{em} that can be efficiently solved by coordinate descent.