Python-Implementierung gemischte Bernoulli-Verteilung

Wir werden die gemischte Bernoulli-Distribution von PRML 9.3.3 implementieren. Als Beispiel für den EM-Algorithmus ist die wahrscheinlichste Schätzung der gemischten Gaußschen Verteilung durch Addition mehrerer Gaußscher Verteilungen üblich, sie kann jedoch auch auf die wahrscheinlichste Schätzung der gemischten Bernoulli-Verteilung durch Addition der Bernoulli-Verteilungen angewendet werden. Die Gaußsche Verteilung hat zwei Parameter, Mittelwert und Varianz, während die Bernoulli-Verteilung nur einen Parameter hat, was eher einfacher ist. Dieses Mal werden wir die gemischte Bernoulli-Verteilung verwenden und sie wie PRML auf MNIST anwenden, um jede Zahl zu gruppieren.

Gemischte Bernoulli-Verteilung

Das diesmal verwendete Modell basiert auf der mehrdimensionalen Bernoulli-Verteilung. Dies repräsentiert die Verteilung von D-dimensionalen Binärvektoren.

{\rm Bern}({\bf x}|{\bf\mu}) = \prod_{i=1}^D \mu_i^{x_i}(1-\mu_i)^{(1-x_i)}

Die gemischte Bernoulli-Verteilung wird erhalten, indem diese mit dem K-dimensionalen Mischungskoeffizienten $ {\ bf \ pi} $ gewichtet und K-Stücke addiert werden. Wenn die Trainingsdaten $ {\ bf X} = \ {{\ bf x} \ _1, \ dots, {\ bf x} \ _N \} $ sind

p({\bf X}|{\bf\mu},{\bf\pi}) = \prod_{n=1}^N\left\{\sum_{k=1}^K\pi_k{\rm Bern}({\bf x}_n|{\bf\mu}_k)\right\}

Wird sein. Führen Sie nun die latente Variable $ {\ bf Z} = \ {{\ bf z} \ _1, \ dots, {\ bf z} \ _N \} $ für jedes Datenelement ein. Der K-dimensionale binäre latente Variablenvektor $ {\ bf z} $ hat nur eine der K-Komponenten, die 1 ist, und alle anderen Komponenten, die 0 sind. Bei den vollständigen Daten $ {\ bf X, Z} $ lautet die Wahrscheinlichkeitsfunktion:

p({\bf X, Z}|{\bf\mu,\pi}) = \prod_{n=1}^N\left\{\prod_{k=1}^K\pi_k^{z_{nk}}{\rm Bern}({\bf x}_n|{\bf\mu}_k)^{z_{nk}}\right\}

Code

import Wenn die mehrdimensionale Bernoulli-Verteilung unverändert verwendet wird, ist die Wahrscheinlichkeit zu gering und für den Computer unpraktisch. Verwenden Sie daher "logsumexp", um den Logarithmus zu verwenden.

import numpy as np
from scipy.misc import logsumexp

Gemischte Bernoulli-Verteilung

Für Personen vom Typ python2 ersetzen Sie bitte @ durch eine Funktion, die das innere Produkt von numpy berechnet.

#Gemischte Bernoulli-Verteilung
class BernoulliMixtureDistribution(object):

    def __init__(self, n_components):
        #Anzahl der Cluster
        self.n_components = n_components

    def fit(self, X, iter_max=100):
        self.ndim = np.size(X, 1)

        #Parameterinitialisierung
        self.weights = np.ones(self.n_components) / self.n_components
        self.means = np.random.uniform(0.25, 0.75, size=(self.n_components, self.ndim))
        self.means /= np.sum(self.means, axis=-1, keepdims=True)

        #EM-Schritt wiederholen
        for i in range(iter_max):
            params = np.hstack((self.weights.ravel(), self.means.ravel()))

            #E Schritt
            stats = self._expectation(X)

            #M Schritt
            self._maximization(X, stats)
            if np.allclose(params, np.hstack((self.weights.ravel(), self.means.ravel()))):
                break
        self.n_iter = i + 1

    #PRML-Formel(9.52)Logistik von
    def _log_bernoulli(self, X):
        np.clip(self.means, 1e-10, 1 - 1e-10, out=self.means)
        return np.sum(X[:, None, :] * np.log(self.means) + (1 - X[:, None, :]) * np.log(1 - self.means), axis=-1)

    def _expectation(self, X):
        #PRML-Formel(9.56)
        log_resps = np.log(self.weights) + self._log_bernoulli(X)
        log_resps -= logsumexp(log_resps, axis=-1)[:, None]
        resps = np.exp(log_resps)
        return resps

    def _maximization(self, X, resps):
        #PRML-Formel(9.57)
        Nk = np.sum(resps, axis=0)

        #PRML-Formel(9.60)
        self.weights = Nk / len(X)

        #PRML-Formel(9.58)
        self.means = (X.T @ resps / Nk).T

Ergebnis

So jupyter notebook 9.3.3 Bei Anwendung der gemischten Bernoulli-Verteilung auf den MNIST-Datensatz (200 zufällig ausgewählte Bilder von jeweils 0 bis 4) ergibt sich der Durchschnitt der einzelnen Bernoulli-Verteilungen wie in der folgenden Abbildung dargestellt. index.png

Am Ende

Da das Lernen des EM-Algorithmus in die lokale Lösung passt (obwohl es in der Realität möglicherweise nicht die lokale Lösung ist), wird nicht nur jede Zahl wie oben gezeigt klar wiedergegeben. Ich fand es schwierig zu lernen, ob es Paare mit ähnlichen Formen wie 1 und 7 und 3 und 8 gab.

Recommended Posts

Python-Implementierung gemischte Bernoulli-Verteilung
PRML Kapitel 5 Python-Implementierung eines Netzwerks mit gemischter Dichte
PRML Kapitel 14 Bedingte gemischte Modell-Python-Implementierung
PRML Kapitel 10 Variante Mixed Gaussian Distribution Python-Implementierung
PRML Kapitel 2 Python-Implementierung von Student t-Distribution
Logistische Verteilung in Python
RNN-Implementierung in Python
ValueObject-Implementierung in Python
SVM-Implementierung in Python
[Python] Implementierung von Clustering mit einem gemischten Gaußschen Modell
Gemischte Gaußsche Verteilung und logsumexp
Schreiben Sie die Beta-Distribution in Python
[Line / Python] Beacon-Implementierungsnotiz
Generieren Sie eine U-Verteilung in Python
EM der gemischten Gaußschen Verteilung
Python-Implementierung des Partikelfilters
Implementierung eines neuronalen Netzwerks in Python
Maxout Beschreibung und Implementierung (Python)
Implementierung der schnellen Sortierung in Python
Python-Implementierung eines selbstorganisierenden Partikelfilters
PRML Kapitel 5 Python-Implementierung für neuronale Netze
Implementierung eines Lebensspiels in Python
Beherrsche die lineare Suche! ~ Python-Implementierungsversion ~
PRML Kapitel 3 Evidence Ungefähre Python-Implementierung
Implementierung von Desktop-Benachrichtigungen mit Python
Python-Implementierung eines nicht rekursiven Segmentbaums
[Python] Gemischtes Gaußsches Modell mit Pyro
Implementierung von Light CNN (Python Keras)
Implementierung der ursprünglichen Sortierung in Python
Implementierung der Dyxtra-Methode durch Python
Ableitung der multivariaten t-Verteilung und Implementierung der Zufallszahlengenerierung durch Python