We will implement the Bernoulli mixture distribution from PRML 9.3.3. The usual example of the EM algorithm is maximum likelihood estimation of a Gaussian mixture, a weighted sum of several Gaussian distributions, but EM applies just as well to maximum likelihood estimation of a Bernoulli mixture, a weighted sum of Bernoulli distributions. A Gaussian distribution has two parameters (mean and variance), whereas a Bernoulli distribution has only one, so this case is actually simpler. As in PRML, we will apply the Bernoulli mixture to MNIST and cluster the images by digit.
The model is built from the multivariate Bernoulli distribution, which describes D-dimensional binary vectors ${\bf x}$. Its parameter is ${\bf\mu} = (\mu_1, \dots, \mu_D)$, where $\mu_i$ is the probability that $x_i = 1$:
{\rm Bern}({\bf x}|{\bf\mu}) = \prod_{i=1}^D \mu_i^{x_i}(1-\mu_i)^{(1-x_i)}
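To make the formula concrete, here is a minimal sketch that evaluates ${\rm Bern}({\bf x}|{\bf\mu})$ directly for a small D. The helper name `bernoulli_pmf` is hypothetical and is not part of the implementation below.

```python
import numpy as np


def bernoulli_pmf(x, mu):
    """Bern(x|mu) = prod_i mu_i^{x_i} (1 - mu_i)^{1 - x_i} for a binary vector x."""
    x = np.asarray(x, dtype=float)
    mu = np.asarray(mu, dtype=float)
    # x_i = 1 picks mu_i, x_i = 0 picks (1 - mu_i); the product runs over all D dimensions
    return float(np.prod(mu ** x * (1.0 - mu) ** (1.0 - x)))


# D = 3: 0.9 * (1 - 0.2) * 0.5 = 0.36 (up to floating-point rounding)
print(bernoulli_pmf([1, 0, 1], [0.9, 0.2, 0.5]))
```

For D = 784 (one MNIST image) this direct product becomes extremely small, which is why the implementation works with logarithms.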
The Bernoulli mixture is obtained by weighting K such distributions with the K-dimensional mixing coefficients ${\bf\pi}$ and adding them together. Given training data ${\bf X} = \{{\bf x}_1, \dots, {\bf x}_N\}$, the likelihood is
p({\bf X}|{\bf\mu},{\bf\pi}) = \prod_{n=1}^N\left\{\sum_{k=1}^K\pi_k{\rm Bern}({\bf x}_n|{\bf\mu}_k)\right\}
Now introduce a latent variable for each data point, ${\bf Z} = \{{\bf z}_1, \dots, {\bf z}_N\}$. Each ${\bf z}_n$ is a K-dimensional binary vector in which exactly one component is 1 and all the others are 0, so $z_{nk} = 1$ means that ${\bf x}_n$ was generated by component k. Given the complete data $\{{\bf X}, {\bf Z}\}$, the likelihood function is:
p({\bf X, Z}|{\bf\mu,\pi}) = \prod_{n=1}^N\left\{\prod_{k=1}^K\pi_k^{z_{nk}}{\rm Bern}({\bf x}_n|{\bf\mu}_k)^{z_{nk}}\right\}
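Taking logarithms turns both likelihoods into sums, which is how they are evaluated in practice. The sketch below (hypothetical helper names `log_bernoulli_matrix`, `mixture_log_likelihood`, `complete_data_log_likelihood`; NumPy and SciPy assumed) computes $\ln p({\bf X}|{\bf\mu},{\bf\pi})$, where the sum over k inside the log needs logsumexp, and $\ln p({\bf X},{\bf Z}|{\bf\mu},{\bf\pi})$, where $z_{nk}$ simply selects one term per data point.

```python
import numpy as np
from scipy.special import logsumexp


def log_bernoulli_matrix(X, mus):
    """ln Bern(x_n|mu_k) for every pair (n, k); X is (N, D), mus is (K, D)."""
    return X @ np.log(mus).T + (1.0 - X) @ np.log(1.0 - mus).T   # shape (N, K)


def mixture_log_likelihood(X, mus, pis):
    """ln p(X|mu, pi) = sum_n ln sum_k pi_k Bern(x_n|mu_k)."""
    return float(np.sum(logsumexp(np.log(pis) + log_bernoulli_matrix(X, mus), axis=1)))


def complete_data_log_likelihood(X, Z, mus, pis):
    """ln p(X, Z|mu, pi) = sum_n sum_k z_nk {ln pi_k + ln Bern(x_n|mu_k)}."""
    return float(np.sum(Z * (np.log(pis) + log_bernoulli_matrix(X, mus))))


X = np.array([[1.0, 0.0], [0.0, 1.0]])      # N = 2, D = 2
mus = np.array([[0.9, 0.1], [0.1, 0.9]])    # K = 2
pis = np.array([0.5, 0.5])
Z = np.array([[1.0, 0.0], [0.0, 1.0]])      # one-hot: x_1 from component 1, x_2 from component 2

print(mixture_log_likelihood(X, mus, pis))            # 2 * ln(0.5*0.81 + 0.5*0.01) = 2 * ln(0.41)
print(complete_data_log_likelihood(X, Z, mus, pis))   # 2 * ln(0.5*0.81) = 2 * ln(0.405)
```

Because ${\bf Z}$ fixes which component produced each point, the complete-data log-likelihood has no log of a sum, which is what makes the M step solvable in closed form.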
Implementation
Multiplying the D-dimensional Bernoulli probabilities directly makes the likelihood so small that it underflows on a computer, so we work with log-likelihoods and normalize them with logsumexp.
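The underflow problem and the logsumexp fix can be seen in a few lines. This is an illustrative sketch: the per-pixel probability 0.3 and the three log-probabilities are made-up numbers, not values from the MNIST run.

```python
import numpy as np
from scipy.special import logsumexp

D = 784                          # e.g. one 28x28 binary MNIST image
p = np.full(D, 0.3)              # probability assigned to each observed pixel value

direct = np.prod(p)              # 0.3**784 is about 1e-410: underflows to 0.0 in float64
log_direct = np.sum(np.log(p))   # about -943.9: no problem in the log domain
print(direct, log_direct)

# Normalizing log-probabilities of K = 3 components, as the E step does
log_unnorm = np.array([-943.9, -950.0, -960.0])
naive = np.exp(log_unnorm)       # every entry underflows to 0.0, so naive / naive.sum() is undefined
log_resp = log_unnorm - logsumexp(log_unnorm)
resp = np.exp(log_resp)          # well-defined and sums to 1
print(resp)
```

logsumexp subtracts the maximum before exponentiating, so the largest term becomes exp(0) = 1 and nothing underflows to an all-zero vector.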
import numpy as np
from scipy.special import logsumexp  # scipy.misc.logsumexp was removed in SciPy 1.3
If you are using Python 2 (the @ matrix-multiplication operator requires Python 3.5 or later), replace @ with np.dot.
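A quick check, using made-up toy arrays, that the `@` expression in `_maximization` and its `np.dot` counterpart give the same result. The block itself runs on Python 3; on Python 2 only the `np.dot` form would be used.

```python
import numpy as np

np.random.seed(0)
X = np.random.randint(0, 2, size=(6, 4)).astype(float)   # (N, D) toy binary data
resps = np.random.dirichlet(np.ones(3), size=6)          # (N, K) toy responsibilities, rows sum to 1
Nk = np.sum(resps, axis=0)

means_at = (X.T @ resps / Nk).T           # form used in _maximization (Python 3.5+)
means_dot = (np.dot(X.T, resps) / Nk).T   # equivalent form without the @ operator
print(np.allclose(means_at, means_dot))   # True
```

For 2-D arrays, `@` and `np.dot` both perform ordinary matrix multiplication, so either spelling yields the (K, D) array of means.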
# Bernoulli mixture distribution
class BernoulliMixtureDistribution(object):
    def __init__(self, n_components):
        # Number of clusters K
        self.n_components = n_components
    def fit(self, X, iter_max=100):
        self.ndim = np.size(X, 1)
        # Parameter initialization: equal mixing coefficients; means drawn uniformly
        # from (0.25, 0.75), then normalized so each row sums to 1 (as in PRML 9.3.3)
        self.weights = np.ones(self.n_components) / self.n_components
        self.means = np.random.uniform(0.25, 0.75, size=(self.n_components, self.ndim))
        self.means /= np.sum(self.means, axis=-1, keepdims=True)
        # Alternate E and M steps until the parameters stop changing
        for i in range(iter_max):
            params = np.hstack((self.weights.ravel(), self.means.ravel()))
            # E step
            stats = self._expectation(X)
            # M step
            self._maximization(X, stats)
            if np.allclose(params, np.hstack((self.weights.ravel(), self.means.ravel()))):
                break
        self.n_iter = i + 1
    # Logarithm of PRML eq. (9.52) for every (data point, component) pair: shape (N, K)
    def _log_bernoulli(self, X):
        # Clip the means so that log(0) never occurs
        np.clip(self.means, 1e-10, 1 - 1e-10, out=self.means)
        return np.sum(X[:, None, :] * np.log(self.means) + (1 - X[:, None, :]) * np.log(1 - self.means), axis=-1)
    def _expectation(self, X):
        # PRML eq. (9.56): responsibilities, normalized in the log domain with logsumexp
        log_resps = np.log(self.weights) + self._log_bernoulli(X)
        log_resps -= logsumexp(log_resps, axis=-1)[:, None]
        resps = np.exp(log_resps)
        return resps
    def _maximization(self, X, resps):
        # PRML eq. (9.57): effective number of points assigned to each component
        Nk = np.sum(resps, axis=0)
        # PRML eq. (9.60): mixing coefficients
        self.weights = Nk / len(X)
        # PRML eq. (9.58): means as responsibility-weighted averages of the data
        self.means = (X.T @ resps / Nk).T
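To check that these updates behave as expected without downloading MNIST, here is a compact, self-contained sketch. `em_bernoulli_mixture` is a hypothetical function that re-implements the same E and M steps as plain functions; unlike the class, it skips the row normalization of the initial means, and the synthetic data (two 20-pixel prototypes with 10% of the bits flipped) is made up purely for illustration.

```python
import numpy as np
from scipy.special import logsumexp


def em_bernoulli_mixture(X, K, iter_max=200, seed=0):
    """Compact EM for a Bernoulli mixture, following the same E and M steps as the class above."""
    rng = np.random.RandomState(seed)
    N, D = X.shape
    pis = np.full(K, 1.0 / K)
    mus = rng.uniform(0.25, 0.75, size=(K, D))   # random start (row normalization skipped here)
    for _ in range(iter_max):
        mus = np.clip(mus, 1e-10, 1 - 1e-10)     # avoid log(0)
        # E step (PRML 9.56): log responsibilities, normalized with logsumexp
        log_r = np.log(pis) + X @ np.log(mus).T + (1 - X) @ np.log(1 - mus).T
        log_r -= logsumexp(log_r, axis=1, keepdims=True)
        r = np.exp(log_r)
        # M step (PRML 9.57, 9.58, 9.60)
        Nk = r.sum(axis=0)
        new_pis = Nk / N
        new_mus = (X.T @ r / Nk).T
        converged = np.allclose(pis, new_pis) and np.allclose(mus, new_mus)
        pis, mus = new_pis, new_mus
        if converged:
            break
    return pis, mus, r


# Synthetic check: two 20-pixel prototypes, each bit flipped with probability 0.1
rng = np.random.RandomState(1)
proto = np.array([[1] * 10 + [0] * 10, [0] * 10 + [1] * 10], dtype=float)
true_labels = rng.randint(0, 2, size=200)
flips = rng.uniform(size=(200, 20)) < 0.1
X = np.abs(proto[true_labels] - flips)          # XOR of prototype and flip mask
pis, mus, r = em_bernoulli_mixture(X, K=2)
print(np.round(mus, 2))   # each row should be close to one prototype (roughly 0.9 / 0.1 pattern)
```

With well-separated synthetic clusters like these, the learned means should line up with the prototypes; real digits overlap far more, which is where local optima start to matter.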
In the Jupyter notebook, I applied the Bernoulli mixture to the MNIST dataset as in PRML 9.3.3 (200 randomly picked images of each digit from 0 to 4). The means of the individual Bernoulli components are shown in the figure below.
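The data preparation step is not shown in the code above, so here is a hedged sketch of one way to do it. `make_binary_subset` is a hypothetical helper; it assumes the images are already loaded as a (N, 784) array with pixel values scaled to [0, 1], and binarizes with a threshold of 0.5 (the choice described in PRML; the notebook's actual preprocessing is not shown). Random arrays stand in for MNIST so the sketch runs offline.

```python
import numpy as np


def make_binary_subset(images, labels, digits=(0, 1, 2, 3, 4), n_per_digit=200,
                       threshold=0.5, seed=0):
    """Randomly pick n_per_digit images of each digit and binarize them.

    images: (N, 784) array with pixel values scaled to [0, 1] (assumed)
    labels: (N,) array of integer digit labels
    """
    rng = np.random.RandomState(seed)
    idx = np.concatenate([
        rng.choice(np.flatnonzero(labels == d), n_per_digit, replace=False)
        for d in digits
    ])
    X = (images[idx] > threshold).astype(float)   # pixel -> 1 if above threshold, else 0
    return X, labels[idx]


# Toy stand-in for MNIST so that the sketch runs without downloading anything
rng = np.random.RandomState(1)
images = rng.uniform(size=(1000, 784))
labels = np.arange(1000) % 10                      # 100 images per digit
X, y = make_binary_subset(images, labels, n_per_digit=50)
print(X.shape, np.bincount(y))                     # (250, 784) and 50 of each digit 0-4

# After fitting, each component mean can be viewed as a 28x28 image, e.g. with a fitted
# (hypothetical) model: plt.imshow(model.means[k].reshape(28, 28), cmap="gray")
```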

Because the EM algorithm converges to a local optimum (and in practice it may not even be a true local optimum), not every digit is cleanly captured by its own component, as the figure shows. My impression is that learning was especially difficult for pairs of digits with similar shapes, such as 1 and 7 or 3 and 8.
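One way to see which digits end up sharing a component is to cross-tabulate the hard cluster assignments against the true labels. EM itself never sees the labels; they are used here only for evaluation. `cluster_label_table` is a hypothetical helper, and the toy responsibilities below are made up.

```python
import numpy as np


def cluster_label_table(resps, labels, n_labels=10):
    """Cross-tabulate hard cluster assignments against true digit labels.

    resps: (N, K) responsibilities from the E step
    labels: (N,) true digit labels, used only for evaluation
    Returns a (K, n_labels) table whose entry [k, d] counts images of digit d assigned to cluster k.
    """
    clusters = np.argmax(resps, axis=1)                   # most responsible component per image
    table = np.zeros((resps.shape[1], n_labels), dtype=int)
    np.add.at(table, (clusters, np.asarray(labels)), 1)   # unbuffered += 1 per (cluster, label) pair
    return table


# Toy example: cluster 0 absorbs a 1 and a 7, cluster 1 absorbs a 3 and an 8
resps = np.array([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.4, 0.6]])
print(cluster_label_table(resps, [1, 7, 3, 8]))
```

A row with large counts in two columns indicates a component that has merged two digits.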