Bayesian Inference and Graphical Models: Expectation-Maximization

EM for Gaussian Mixture Models

In this section, we'll develop an approach to estimating model parameters when some of the random variables involved in the model's Bayes net are not observed. We'll begin with the Gaussian mixture model and develop an intuitive version of the method, and then we'll introduce the general version and apply it to a hidden Markov model.

Recall that the Gaussian mixture model (GMM) is a Bayes net with just two random variables: a discrete random variable Z and a random vector X. To draw an observation from this model, we draw Z from a given discrete distribution on \{1,2, \ldots, d\}, and then we draw X from a multivariate normal distribution with mean \mu_Z and covariance \Sigma_Z.

Let's generate observations from a made-up GMM.

using Distributions, Plots, Random, LinearAlgebra, Statistics
include("data-gymnasia/ellipse.jl")
Random.seed!(123)
n = 100
α = 0.4                                          # P(Z = 1)
𝒩₀ = MvNormal([1.0, 1.0], [2.0 1.0; 1.0 2.0])    # component used when Z = 0
𝒩₁ = MvNormal([3.0, 7.0], [1.5 0; 0 0.5])        # component used when Z = 1
X₁ = zeros(n)
X₂ = zeros(n)
Z = zeros(Bool, n)
for i = 1:n
    Z[i] = rand(Bernoulli(α))
    # draw the point from the component selected by Z[i]
    X₁[i], X₂[i] = Z[i] ? rand(𝒩₁) : rand(𝒩₀)
end
scatter(X₁, X₂, color = :gray, legend = false)

Let's think about how we could set about recovering the parameters of this model if all we had were the observations shown in the scatter plot.

One simple idea would be to write down the log likelihood of the data and hand it to an optimization algorithm to find parameters which maximize it. The problem is that, with unobserved variables in the Bayes net, the likelihood of the observed data involves a sum over all possible values of the hidden variables, and the logarithm of that sum is awkward to maximize directly. Instead, let's develop an iterative approach that starts with a bad guess and works to improve it. We begin with arbitrary values for the parameters:

α = 0.6
μ₀ = [3.0,3.0]
μ₁ = [1.0,6.0]
Σ₀ = 1.0*Matrix(I, 2, 2)
Σ₁ = 1.0*Matrix(I, 2, 2)
mixtureplot(X₁,X₂,μ₀,Σ₀,μ₁,Σ₁)

Conceptually, we'd like to fit the blue distribution to the points in roughly the lower half of the figure, leaving the remaining upper points to be fit by the orange distribution. To this end, we come up with a score for each point indicating how much it seems to belong to the blue distribution or orange distribution, based on our current parameter estimates. More precisely, let's compute for each point its conditional probability of having Z = 1 (orange) given the (x_1, x_2) value of the point.

We'll compute these values (the conditional probability of being orange) for each point, and store the result in a vector called Π:

# Π[i] = P(Z = 1 | x_i) under the current parameter estimates (Bayes' rule)
Π = [α*pdf(MvNormal(μ₁,Σ₁),[x₁,x₂]) /
       ((1-α)*pdf(MvNormal(μ₀,Σ₀),[x₁,x₂]) +
        α*pdf(MvNormal(μ₁,Σ₁),[x₁,x₂])) for (x₁,x₂) in zip(X₁,X₂)]

We can visualize the result of this computation by actually coloring each point x_i according to its blueness/orangeness value \pi_i:

mixtureplot(X₁,X₂,μ₀,Σ₀,μ₁,Σ₁,Π)

Next, we fit a multivariate Gaussian to the points we colored blue, and another to the points we colored orange. However, rather than performing the discontinuous operation of snapping each point to "orange" or "blue", we keep the real-valued blueness/orangeness of each point and use it as a weight: point i counts with weight \pi_i toward the orange distribution and with weight 1 - \pi_i toward the blue one when we compute the new means and covariances.

# Update step: weighted estimates, where point i counts with weight Π[i]
# toward component 1 (orange) and with weight 1 - Π[i] toward component 0 (blue)
α = sum(Π)/n
μ₀ = [(1 .- Π) ⋅ X₁, (1 .- Π) ⋅ X₂] / sum(1 .- Π)
μ₁ = [Π ⋅ X₁, Π ⋅ X₂] / sum(Π)
Σ₀ = Matrix(Hermitian(sum((1-π)*([x₁,x₂] - μ₀) * ([x₁,x₂] - μ₀)' for (x₁,x₂,π) in zip(X₁,X₂,Π))/sum(1 .- Π)))
Σ₁ = Matrix(Hermitian(sum(π*([x₁,x₂] - μ₁) * ([x₁,x₂] - μ₁)' for (x₁,x₂,π) in zip(X₁,X₂,Π))/sum(Π)))
# Recompute the conditional probabilities using the updated parameters
Π = [α*pdf(MvNormal(μ₁,Σ₁),[x₁,x₂]) /
       ((1-α)*pdf(MvNormal(μ₀,Σ₀),[x₁,x₂]) +
        α*pdf(MvNormal(μ₁,Σ₁),[x₁,x₂])) for (x₁,x₂) in zip(X₁,X₂)];
mixtureplot(X₁,X₂,μ₀,Σ₀,μ₁,Σ₁,Π)

If you run the cell above a few times, you'll see that it pretty quickly settles on a particular choice for the model parameters.
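Instead of re-running the cell by hand, the iteration can be written as a loop. Below is a minimal sketch that uses only Julia's standard library. The helpers `mvn_pdf`, `em_step` and `loglik` are hypothetical names, not functions from `ellipse.jl`. The Gaussian density is written out by hand in place of Distributions.jl, and the points are stored as the columns of a 2×n matrix. The sketch also records the observed-data log likelihood, which should never decrease from one EM iteration to the next.

```julia
using LinearAlgebra, Random, Statistics

# Multivariate normal density written out by hand, so that this sketch needs
# only Julia's standard library (no Distributions.jl).
function mvn_pdf(x, μ, Σ)
    r = x - μ
    exp(-dot(r, Σ \ r) / 2) / sqrt((2π)^length(μ) * det(Σ))
end

# One EM iteration for a two-component GMM. X is a 2×n matrix with one point
# per column; θ = (α, μ₀, Σ₀, μ₁, Σ₁), where α = P(Z = 1).
function em_step(X, θ)
    α, μ₀, Σ₀, μ₁, Σ₁ = θ
    n = size(X, 2)
    # E-step: Π[i] = P(Z_i = 1 | x_i) under the current parameters
    p₁ = [α * mvn_pdf(X[:, i], μ₁, Σ₁) for i in 1:n]
    p₀ = [(1 - α) * mvn_pdf(X[:, i], μ₀, Σ₀) for i in 1:n]
    Π = p₁ ./ (p₀ .+ p₁)
    # M-step: weighted means and covariances (weights 1 - Π for Z = 0, Π for Z = 1)
    w₀, w₁ = 1 .- Π, Π
    m₀, m₁ = X * w₀ / sum(w₀), X * w₁ / sum(w₁)
    S₀ = (X .- m₀) * Diagonal(w₀) * (X .- m₀)' / sum(w₀)
    S₁ = (X .- m₁) * Diagonal(w₁) * (X .- m₁)' / sum(w₁)
    (mean(Π), m₀, Matrix(Symmetric(S₀)), m₁, Matrix(Symmetric(S₁)))
end

# Observed-data log likelihood, used only to monitor convergence
loglik(X, θ) = sum(log((1 - θ[1]) * mvn_pdf(x, θ[2], θ[3]) + θ[1] * mvn_pdf(x, θ[4], θ[5]))
                   for x in eachcol(X))

# Usage: simulate points from the same two components as above (α = 0.4)
Random.seed!(123)
C₀ = cholesky([2.0 1.0; 1.0 2.0]).L
C₁ = cholesky([1.5 0.0; 0.0 0.5]).L
X = reduce(hcat, [rand() < 0.4 ? [3.0, 7.0] + C₁ * randn(2) : [1.0, 1.0] + C₀ * randn(2)
                  for _ in 1:100])

# Start from the same arbitrary guess as in the text and run 50 iterations
θ = (0.6, [3.0, 3.0], Matrix(1.0I, 2, 2), [1.0, 6.0], Matrix(1.0I, 2, 2))
lls = [loglik(X, θ)]
for _ in 1:50
    global θ = em_step(X, θ)
    push!(lls, loglik(X, θ))
end
```

Computing both E-step and M-step inside `em_step` keeps the parameters as a single tuple. This mirrors the cell above, which updates the parameters and then recomputes Π.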

The General EM Algorithm

Now let's consider a general Bayesian network, with some variables hidden and others observed. How can we generalize the approach we developed for Gaussian mixture models?

The first step in our Gaussian mixture model algorithm was to compute for each data point the conditional distribution of the hidden variable Z given the observed data X. That step is already general: for any collection of Z's and X's, we can compute the conditional distribution of the Z's given the X's.

The second step is trickier to generalize: for the GMM, we used the conditional probabilities for each value of Z as weights and chose model parameters which fit the data in a way that accounted for those weights. In the general case, we'll use the conditional distribution of the Z's given the X's as a probability measure with respect to which we will compute the expected log likelihood and then choose model parameters which maximize that quantity.

More explicitly, we iterate the following steps to convergence:

  1. Work out the likelihood as a function of the parameters, the observed values of the X's (which we will write using lowercase x's), and values for the Z's which we'll pretend we also observed (we write these values as lowercase z's).

  2. We will uppercase the z's to treat them as random variables, and we'll calculate the expectation of the log likelihood function with respect to the conditional distribution of the Z's given the x's, computed using the current values of the parameters.

  3. We maximize the expected log likelihood computed in the second step and update the parameter values to these new optimizing values.
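In symbols, write \theta for the parameters, \theta^{(t)} for their values after t iterations, and p(x, z; \theta) for the joint density of all observed and hidden variables. This notation is introduced here for compactness. Step 1 produces \log p(x, z; \theta), step 2 turns it into Q, and step 3 maximizes Q:

\begin{align*}Q\left(\theta \mid \theta^{(t)}\right) &= \mathbb{E}_{Z \sim p\left(\cdot \mid x;\, \theta^{(t)}\right)}\left[\log p(x, Z;\, \theta)\right] \\ \theta^{(t+1)} &= \arg\max_{\theta}\, Q\left(\theta \mid \theta^{(t)}\right)\end{align*}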

Let's begin by showing that this algorithm is equivalent to the one we introduced previously in the Gaussian mixture case. Suppose 1 \leq i \leq n. The likelihood of Z_i = z_i and X_i = x_i is \alpha f_1(x_i) if z_i = 1 and (1-\alpha)f_0(x_i) if z_i = 0 (where f_j is the Gaussian density with mean \mu_j and covariance \Sigma_j, for j \in \{0, 1\}).

We can write this in a single expression by saying that the likelihood of Z_i = z_i and X_i = x_i is z_{i} \alpha f_{1}\left(\mathbf{x}_{i}\right)+\left(1-z_{i}\right)(1-\alpha) f_{0}\left(\mathbf{x}_{i}\right). So the overall likelihood given all of the "observed" data is the product of these expressions:

\begin{align*}\prod_{i=1}^{n}\left(z_{i} \alpha f_{1}\left(\mathbf{x}_{i}\right)+\left(1-z_{i}\right)(1-\alpha) f_{0}\left(\mathbf{x}_{i}\right)\right)\end{align*}

Since the log of a product is a sum of logs, we get

\begin{align*}\sum_{i=1}^{n}\log\left(z_{i} \alpha f_{1}\left(\mathbf{x}_{i}\right)+\left(1-z_{i}\right)(1-\alpha) f_{0}\left(\mathbf{x}_{i}\right)\right)\end{align*}

for the log likelihood. Finally, taking the expectation, we can use linearity of expectation to look at one term at a time:

\begin{align*}\sum_{i=1}^{n}\mathbb{E}\left[\log\left(Z_{i} \alpha f_{1}\left(\mathbf{x}_{i}\right)+\left(1-Z_{i}\right)(1-\alpha) f_{0}\left(\mathbf{x}_{i}\right)\right)\right]\end{align*}

Note that the distribution of the random variable Z_i is supported at just two points (0 and 1). Therefore, each expectation is easy to compute directly. We evaluate the expression inside it at Z_i = 1 and at Z_i = 0, multiply each result by the probability of that value, and add.

Remember that the probability measure we decided to use for the Z's is their conditional distribution given the specified x's. So let's define \pi_i to be the conditional probability of Z_i = 1 given X_i = x_i (for each i from 1 to n). Then we get an expected log likelihood of

\begin{align*}Q(\theta) &= \mathbb{E}\left[\log \prod_{i=1}^{n}\left(Z_{i} \alpha f_{1}\left(\mathbf{x}_{i}\right)+\left(1-Z_{i}\right)(1-\alpha) f_{0}\left(\mathbf{x}_{i}\right)\right)\right] \\ &= \sum_{i=1}^{n} \Big(\pi_{i}\left[\log \alpha+\log f_{1}\left(\mathbf{x}_{i}\right)\right] \\ & \qquad +\left(1-\pi_{i}\right)\left[\log (1-\alpha)+\log f_{0}\left(\mathbf{x}_{i}\right)\right]\Big)\end{align*}

Differentiating the terms involving \alpha and setting the result to zero, we find that the maximizing \alpha is the mean of the \pi values (\frac{1}{n} \sum_{i=1}^n \pi_i). The remaining terms (involving f_1 and f_0) take the form of a weighted maximum likelihood estimation problem for the normal distribution, and the optimizing parameters for that problem are the weighted sample mean and weighted sample covariance. (We derived this result in the statistics course in the case where the weights are uniform, and here we'll take the generalization of that calculation for granted.)
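Spelled out, the first line below sets the derivative of Q with respect to \alpha to zero. The second line gives the weighted estimates for the Z = 1 component. These are the update formulas computed in the code cell earlier in this section:

\begin{align*}&\frac{\partial Q}{\partial \alpha} = \sum_{i=1}^{n}\left(\frac{\pi_i}{\alpha} - \frac{1-\pi_i}{1-\alpha}\right) = 0 \quad \Longrightarrow \quad \alpha = \frac{1}{n}\sum_{i=1}^{n}\pi_i, \\ &\mu_1 = \frac{\sum_{i=1}^{n} \pi_i \mathbf{x}_i}{\sum_{i=1}^{n} \pi_i}, \qquad \Sigma_1 = \frac{\sum_{i=1}^{n} \pi_i \left(\mathbf{x}_i - \mu_1\right)\left(\mathbf{x}_i - \mu_1\right)^{\top}}{\sum_{i=1}^{n} \pi_i}.\end{align*}

The Z = 0 component uses the same formulas with each \pi_i replaced by 1 - \pi_i.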

EM for Hidden Markov Models

Let's see the general EM algorithm in action on a more complex model. Recall the Hidden Markov Model we discussed in the previous section: the Markov chain starts at state 0 or 1 with probability 1/2 each and has a transition matrix of the form

\begin{align*}P = \left[ \begin{matrix} q & 1-q \\ 1-q & q\end{matrix} \right],\end{align*}

where q \in [0,1]. We assume further that the conditional distribution of each random variable X_j given Z_j is a Gaussian with mean Z_j and unknown variance \sigma^2. So we'll be trying to estimate q and \sigma^2 based on a series of observations of the values of X.

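using Plots, Distributions, OffsetArrays, Random
q = 0.8      # probability that the chain stays in its current state
σ² = 0.25    # variance of each X_j around its hidden state Z_j
P = OffsetArray([q 1-q; 1-q q], 0:1, 0:1)   # transition matrix indexed by states 0 and 1

Random.seed!(1)

# Simulate n steps of the chain, starting at 0 or 1 with probability 1/2 each
function markov_chain(P, n)
    Z = [rand(0:1)]
    for i in 1:n-1
        current_state = Z[end]
        # move to state 0 with probability P[current_state, 0], otherwise to state 1
        push!(Z, rand(Bernoulli(P[current_state, 0])) ? 0 : 1)
    end
    Z
end

Z = markov_chain(P, 100)
X = Z + √(σ²)*randn(100)   # X_j ~ N(Z_j, σ²)
plot(Z, size = (500, 150), ylims = (-4, 4), legend = false)
plot!(X)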

Let's follow the EM algorithm. The first step is to find the likelihood of the observed x's together with values z for the Z's, as if we had observed those too.

The probability that we see Z_1 = z_1 is 1/2 (regardless of the value of z_1). Then the probability that we see Z_2 = z_2 given that Z_1 = z_1 is q if z_1 = z_2 and 1-q if z_1 \ne z_2. We can write this by saying that the conditional probability of Z_2 = z_2 given that Z_1 = z_1 is

\begin{align*}q\mathbf{1}_{z_2 = z_{1}} + (1-q)\mathbf{1}_{z_2 \neq z_{1}}\end{align*}

Likewise, we get a similar factor for Z_3, another for Z_4, and so on. Putting these factors together, we get a likelihood of

\begin{align*}\frac{1}{2}\prod_{j=2}^n\left[ q\mathbf{1}_{z_j = z_{j-1}} + (1-q)\mathbf{1}_{z_j \neq z_{j-1}}\right],\end{align*}

accounting for all of the Z's.

The conditional probability that we see a value of X_1 which is really close to x_1 (given the Z values) is proportional to \frac{1}{\sqrt{2\pi\sigma^2}}\operatorname{e}^{-(x_1-z_1)^2/(2\sigma^2)}, and similarly for x_2, x_3, and so on. All together, we get a likelihood of

\begin{align*}\frac{1}{2}\prod_{j=2}^n \left[q\mathbf{1}_{z_j = z_{j-1}} + (1-q)\mathbf{1}_{z_j \neq z_{j-1}}\right]\prod_{j=1}^n \frac{1}{\sqrt{2\pi \sigma^2}} e^{-\frac{(x_j-z_j)^2}{2\sigma^2}}\end{align*}

Taking the log, we get

\begin{align*}\log(1/2) &+ \sum_{j=2}^n \log([q\mathbf{1}_{z_j = z_{j-1}} + (1-q)\mathbf{1}_{z_j \neq z_{j-1}}]) + \sum_{j=1}^n \log\left(\frac{1}{\sqrt{2\pi \sigma^2}} e^{-\frac{(x_j-z_j)^2}{2\sigma^2}}\right) \\ &= \log(1/2) + \log(q)\sum_{j=2}^n \mathbf{1}_{z_j = z_{j-1}} + \log(1-q) \sum_{j=2}^n \mathbf{1}_{z_j \neq z_{j-1}} - \frac{1}{2\sigma^2}\sum_{j=1}^n (x_j-z_j)^2 - \frac{n}{2} \log(2\pi\sigma^2).\end{align*}
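As a sanity check on this algebra, the sketch below evaluates the simplified log likelihood and compares it with the log of the unsimplified product. The helper names `complete_loglik` and `complete_lik_direct` are hypothetical. The test values are a four-step path and observations (the same ones used in the acceptance-ratio exercise later in this section).

```julia
# Simplified complete-data log likelihood (last line of the display above)
function complete_loglik(z, x, q, σ²)
    n = length(x)
    stays = count(z[j] == z[j-1] for j in 2:n)   # Σ_j 1{z_j = z_{j-1}}
    switches = (n - 1) - stays                    # Σ_j 1{z_j ≠ z_{j-1}}
    sq = sum((x .- z) .^ 2)                       # Σ_j (x_j - z_j)^2
    log(1/2) + stays * log(q) + switches * log(1 - q) - sq / (2σ²) - n / 2 * log(2π * σ²)
end

# Unsimplified product form (before taking logs), for comparison
function complete_lik_direct(z, x, q, σ²)
    lik = 1 / 2
    for j in 2:length(x)
        lik *= z[j] == z[j-1] ? q : 1 - q          # transition factor
    end
    for j in eachindex(x)
        lik *= exp(-(x[j] - z[j])^2 / (2σ²)) / sqrt(2π * σ²)   # Gaussian factor
    end
    lik
end

z = [0, 0, 1, 1]
x = [0.25, -0.3, 0.8, 1.1]
complete_loglik(z, x, 0.8, 0.25) ≈ log(complete_lik_direct(z, x, 0.8, 0.25))
```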

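Finally, we replace the z's with Z's (reflecting that we are going to treat them as random variables for purposes of computing an expected value) and take the expectation with respect to the conditional distribution of the Z's given the x's, computed using the current values of q and \sigma^2.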

\begin{align*}&\mathbb{E}\left[\log(1/2)\right] + \log(q)\mathbb{E}\left[\sum_{j=2}^n \mathbf{1}_{Z_j = Z_{j-1}}\right] + \log(1-q) \mathbb{E}\left[\sum_{j=2}^n \mathbf{1}_{Z_j \neq Z_{j-1}}\right] - \frac{1}{2\sigma^2}\mathbb{E}\left[\sum_{j=1}^n (x_j-Z_j)^2\right] - \mathbb{E}\left[\frac{n}{2} \log(2\pi\sigma^2)\right] \\ &= \log(1/2) + \log(q)\underbrace{\mathbb{E}\left[\sum_{j=2}^n \mathbf{1}_{Z_j = Z_{j-1}}\right]}_{a} + \log(1-q) \underbrace{\mathbb{E}\left[\sum_{j=2}^n \mathbf{1}_{Z_j \neq Z_{j-1}}\right]}_{b} - \frac{1}{2\sigma^2}\underbrace{\mathbb{E}\left[\sum_{j=1}^n (x_j-Z_j)^2\right]}_{c} - \frac{n}{2} \log(2\pi\sigma^2) \\ &= \log(1/2) + a\log(q) + b\log(1-q) - \frac{c}{2\sigma^2} - \frac{n}{2}\log(2\pi\sigma^2).\end{align*}

Our goal is to find new values of q and \sigma^2 which maximize this expression. We can do that by differentiating with respect to each of these parameters and setting the resulting expressions equal to zero. Solving those equations, we get

\begin{align*}q &= \frac{a}{a+b} \\ \sigma^2 &= \frac{c}{n}\end{align*}
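The two derivatives behind these formulas are short. Setting each partial derivative of the expected log likelihood above to zero gives

\begin{align*}\frac{\partial}{\partial q}: \quad \frac{a}{q} - \frac{b}{1-q} = 0 \quad &\Longrightarrow \quad q = \frac{a}{a+b} \\ \frac{\partial}{\partial \sigma^2}: \quad \frac{c}{2\sigma^4} - \frac{n}{2\sigma^2} = 0 \quad &\Longrightarrow \quad \sigma^2 = \frac{c}{n}\end{align*}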

Therefore, the only remaining task is estimating a, b, and c.

Exercise
Give verbal descriptions of the quantities a, b, and c.

Solution. We can describe a as the expected number of times the Markov chain stays in the same state, b as the expected number of Markov chain switches, and c as the expected squared distance between the vector of X's and the vector of Z's.

The main difficulty in estimating a, b, and c is that the conditional distribution of the Z's given the X's is a reasonably complex probability measure. It has to account for the X values as well as the conditional distribution of each Z_j given the value of Z_{j-1}.

Fortunately, we've developed a technique for sampling from complex probability measures: Metropolis-Hastings.

As a reminder, Metropolis-Hastings proceeds by starting from some point in the space we're trying to sample from. In this case, that means starting with a length-n binary string. Then we propose changes to the string and accept or reject them with a probability which is determined by the density values at the current and proposed strings.

To work out the acceptance ratio, we need to compare the values of the desired density function for two paths which differ in one position.

Exercise
In terms of q and \sigma^2, work out the acceptance ratio for the proposal to move from [Z_1, Z_2, Z_3, Z_4] = [0, 0, 1, 1] to [Z_1, Z_2, Z_3, Z_4] = [0, 1, 1, 1] if [X_1, X_2, X_3, X_4] = [0.25, -0.3, 0.8, 1.1].

Solution. We get

\begin{align*}\frac{\frac{1}{2}(1-q)qqf_\sigma(0.25)f_\sigma(-1.3)f_\sigma(-0.2)f_\sigma(0.1)}{\frac{1}{2}q(1-q)qf_\sigma(0.25)f_\sigma(-0.3)f_\sigma(-0.2)f_\sigma(0.1)},\end{align*}

where f_\sigma denotes the Gaussian density with mean 0 and variance \sigma^2. We see that factors not involving the second position (where the change was proposed) cancel. So we're left with

\begin{align*}\frac{(1-q)qf_\sigma(-1.3)}{q(1-q)f_\sigma(-0.3)}.\end{align*}

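In this example, the transition factors (1-q)q and q(1-q) happen to cancel as well, leaving f_\sigma(-1.3)/f_\sigma(-0.3). In general they need not. Let's write down the result of this exercise in general terms. Including only the factors that don't necessarily cancel, we obtain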

\begin{align*}\frac{(q \mathbf{1}_{z_{j-1} \ne z_{j}} + (1-q)\mathbf{1}_{z_{j-1} = z_j}) (q \mathbf{1}_{z_{j+1} \ne z_j} + (1-q)\mathbf{1}_{z_{j+1} = z_j}) \frac{1}{\sqrt{2\pi\sigma^2 }}e^{-\frac{(x_j - (1-z_j))^2}{2\sigma^2 }}} {(q \mathbf{1}_{z_{j-1} = z_j} + (1-q)\mathbf{1}_{z_{j-1} \ne z_j}) (q \mathbf{1}_{z_{j+1} = z_j} + (1-q)\mathbf{1}_{z_{j+1} \ne z_j}) \frac{1}{\sqrt{2\pi\sigma^2 }}e^{-\frac{(x_j - z_j)^2}{2\sigma^2 }}}\end{align*}

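as the Metropolis-Hastings acceptance ratio for flipping position j, for 1 < j < n. When j = 1 or j = n, the transition factor involving the missing neighbor is simply omitted. The proposed flip is then accepted with probability equal to the minimum of 1 and this ratio.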
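The general ratio translates directly into a small function. The sketch below is one way to write it; `flip_ratio` and `acceptance_probability` are hypothetical names. It assumes paths are stored as vectors of 0/1 integers and that positions to flip are proposed symmetrically. The factor 1/\sqrt{2\pi\sigma^2} is dropped because it cancels.

```julia
# Ratio of target densities for flipping position j of a 0/1 path z, keeping
# only the factors that do not cancel (hypothetical helper name).
function flip_ratio(z, x, j, q, σ²)
    n = length(z)
    trans(a, b) = a == b ? q : 1 - q          # transition factor
    emit(xj, zj) = exp(-(xj - zj)^2 / (2σ²))  # Gaussian factor; 1/√(2πσ²) cancels
    znew = 1 - z[j]
    r = emit(x[j], znew) / emit(x[j], z[j])
    if j > 1   # factor linking z_{j-1} and z_j
        r *= trans(z[j-1], znew) / trans(z[j-1], z[j])
    end
    if j < n   # factor linking z_j and z_{j+1}
        r *= trans(znew, z[j+1]) / trans(z[j], z[j+1])
    end
    r
end

# Metropolis-Hastings acceptance probability for a symmetric flip proposal
acceptance_probability(z, x, j, q, σ²) = min(1.0, flip_ratio(z, x, j, q, σ²))

# The exercise above: the transition factors cancel, leaving
# f_σ(-1.3) / f_σ(-0.3) = exp(-3.2) ≈ 0.041 for σ² = 0.25, whatever q is
flip_ratio([0, 0, 1, 1], [0.25, -0.3, 0.8, 1.1], 2, 0.8, 0.25)
```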

Gibbs Sampling

To get more efficient mixing, we'll use a variation of Metropolis-Hastings where we cycle through the positions in order to propose changes, rather than choosing them randomly. This is called Gibbs sampling.
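Here is a minimal sketch of one such sweep, under our own assumptions. The helper `gibbs_sweep!` is hypothetical and may differ from `gibbs_sampler` in `expectation-maximization.jl`. It uses the textbook Gibbs update, which redraws each Z_j from its conditional distribution given all the other Z's and the x's. For a binary variable with flip ratio r (the acceptance ratio above), that conditional flips Z_j with probability r/(1+r). The ratio is computed in log space so the Gaussian factors cannot underflow.

```julia
using Random

# One systematic-scan Gibbs sweep over a 0/1 path z (hypothetical helper).
# Each Z_j is redrawn from its conditional distribution given the other Z's
# and the x's: flip with probability r / (1 + r), r being the flip ratio.
function gibbs_sweep!(z, x, q, σ², rng = Random.default_rng())
    n = length(z)
    ltrans(a, b) = a == b ? log(q) : log(1 - q)   # log transition factor
    for j in 1:n
        znew = 1 - z[j]
        # log r, computed in log space to avoid underflow in the Gaussian factors
        logr = ((x[j] - z[j])^2 - (x[j] - znew)^2) / (2σ²)
        j > 1 && (logr += ltrans(z[j-1], znew) - ltrans(z[j-1], z[j]))
        j < n && (logr += ltrans(znew, z[j+1]) - ltrans(z[j], z[j+1]))
        if rand(rng) < 1 / (1 + exp(-logr))   # equals r / (1 + r)
            z[j] = znew
        end
    end
    z
end

# With a tiny variance the x's pin down the path, so one sweep recovers it
gibbs_sweep!([1, 1, 0, 0], [0.0, 0.0, 1.0, 1.0], 0.8, 0.01)   # [0, 0, 1, 1]

# Successive sweeps give (correlated) draws from the conditional distribution
z = [0, 0, 1, 1]
draws = [copy(gibbs_sweep!(z, [0.25, -0.3, 0.8, 1.1], 0.8, 0.25)) for _ in 1:1000]
```

Unlike a Metropolis step, this update never "rejects": the new value of Z_j is always a draw from its exact conditional distribution.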

Let's take a look at several draws from the conditional distribution of Z given X:

include("data-gymnasia/expectation-maximization.jl")
observations = [plot(gibbs_sampler(X, (q, σ²)), yticks = 0:1) for _ in 1:10]
plot(observations..., layout = (10, 1), size = (700, 700), legend = false)    

We can use these draws from the conditional distribution of Z given X to estimate a, b, and c. For each sample path, we count the number of times it stays in the same state and the number of times it switches states, and we accumulate the squared difference between X and the path. Averaging these quantities over the draws gives the estimates. This method of estimating an expected value using draws from the underlying distribution is called Monte Carlo estimation.
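For concreteness, here is a sketch of the Monte Carlo step. `mc_estimate_abc` is a hypothetical helper, and the script's `estimate_a_b_c` may be implemented differently. `paths` is assumed to be a vector of 0/1 sample paths drawn from the conditional distribution of Z given x. The two paths below are hand-picked, only to show the arithmetic.

```julia
using Statistics

# Monte Carlo estimates of a, b and c: average each per-path quantity over
# draws of Z from its conditional distribution given x (hypothetical helper).
function mc_estimate_abc(paths, x)
    stays(z)    = count(z[j] == z[j-1] for j in 2:length(z))   # same state
    switches(z) = count(z[j] != z[j-1] for j in 2:length(z))   # state changes
    sqdist(z)   = sum((x .- z) .^ 2)                          # Σ_j (x_j - z_j)^2
    mean(stays.(paths)), mean(switches.(paths)), mean(sqdist.(paths))
end

x = [0.25, -0.3, 0.8, 1.1]
paths = [[0, 0, 1, 1], [0, 1, 1, 1]]
a, b, c = mc_estimate_abc(paths, x)   # (2.0, 1.0, ≈1.0025)

# The M-step updates derived earlier
q_new = a / (a + b)        # 2/3
σ²_new = c / length(x)     # ≈ 0.2506
```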

The script expectation-maximization.jl also contains a method for using these draws to estimate a, b, and c:

θ = (q, σ²)
estimate_a_b_c(X, θ)

We see that these estimates make sense. The Markov chain switches about 20% of the time (since 1 - q = 0.2), so a accounts for about 80% of the n - 1 = 99 transitions and b for about 20%. Likewise, the variance used to generate these data was \sigma^2 = 0.25, so c, the accumulated squared difference across all 100 values of j, is approximately (0.25)(100) = 25.

Finally, we can actually perform the expectation-maximization algorithm using the update rules we derived:

q, σ² = em_algorithm(X)
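Putting the pieces together, a self-contained Monte Carlo EM loop might look like the sketch below. It is not the script's `em_algorithm`. The function `mc_em` and its keyword arguments (starting values, number of iterations, Gibbs sweeps, burn-in) are our own choices. It repeats a compact log-space Gibbs sweep so the block runs on its own. Because the E-step is approximated by sampling, the estimates keep jittering slightly from one iteration to the next instead of converging exactly.

```julia
using Random

# Compact log-space Gibbs sweep: flip Z_j with probability r / (1 + r)
function gibbs_sweep!(z, x, q, σ², rng)
    n = length(z)
    ltrans(a, b) = a == b ? log(q) : log(1 - q)
    for j in 1:n
        znew = 1 - z[j]
        logr = ((x[j] - z[j])^2 - (x[j] - znew)^2) / (2σ²)
        j > 1 && (logr += ltrans(z[j-1], znew) - ltrans(z[j-1], z[j]))
        j < n && (logr += ltrans(znew, z[j+1]) - ltrans(z[j], z[j+1]))
        rand(rng) < 1 / (1 + exp(-logr)) && (z[j] = znew)
    end
    z
end

# Monte Carlo EM (hypothetical helper): the E-step averages a, b, c over Gibbs
# draws under the current (q, σ²); the M-step sets q = a/(a+b), σ² = c/n.
function mc_em(x; q = 0.5, σ² = 1.0, iters = 30, sweeps = 200, burnin = 50,
               rng = MersenneTwister(0))
    n = length(x)
    z = Int.(x .> 0.5)                  # crude starting path
    for _ in 1:iters
        a = b = c = 0.0
        for s in 1:(burnin + sweeps)
            gibbs_sweep!(z, x, q, σ², rng)
            s <= burnin && continue     # discard burn-in sweeps
            st = count(z[j] == z[j-1] for j in 2:n)
            a += st
            b += (n - 1) - st
            c += sum((x .- z) .^ 2)
        end
        a, b, c = a / sweeps, b / sweeps, c / sweeps
        q, σ² = a / (a + b), c / n
    end
    q, σ²
end

# Usage: simulate a chain like the one above (q = 0.8, σ² = 0.25) and fit it
rng = MersenneTwister(1)
zs = [rand(rng, 0:1)]
for _ in 2:100
    push!(zs, rand(rng) < 0.8 ? zs[end] : 1 - zs[end])
end
xs = zs .+ sqrt(0.25) .* randn(rng, 100)
q_hat, σ²_hat = mc_em(xs)
```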