[PYTHON] Schätzung der gemischten Gaußschen Verteilung nach der varianten Bayes'schen Methode

Illustriert

変分混合ガウス分布.gif

Selbst wenn Sie die anfängliche Anzahl von Klassen auf 6 setzen, können Sie sehen, dass sie schließlich zu 3 Klassen konvergiert

Algorithmus

  1. Initialisieren Sie $ r_ {nk} $
  2. Berechnen Sie drei Statistiken
    • N_k = \sum_{n=1}^N r_{nk}
    • {\bar x_k} = \frac{1}{N_k} \sum_{n=1}^N r_{nk}x_n
    • S_k = \frac{1}{N_k} \sum_{n=1}^N r_{nk}(x_n - {\bar x_k})(x_n - {\bar x_k})^T
  3. Mstep: q(\pi) = Dir(\pi|\alpha), q(\mu_k, \Lambda_k) = N(\mu_k|m_k, (\beta_k \Lambda_k)^{-1})W(\Lambda_k|W_k, \nu_k)Fragen.
    • \alpha_k = \alpha_0 + N_k
    • \beta_k = \beta_0 + N_k
    • m_k = \frac{1}{\beta_k}(\beta_0 m_0 + N_k {\bar x_k})
    • W_k^{-1} = W_0^{-1} + N_k S_k + \frac{\beta_0 N_k}{\beta_0 + N_k}({\bar x_k} - m_0)({\bar x_k} - m_0)^T
    • \nu_k = \nu_0 + N_k
  4. Schritt: $ q (Z) = \ Pi_ {n = 1} ^ N \ Pi_ {k = 1} ^ Kr_ {nk} ^ {z_ {nk}} $
    • r_{nk} = \frac{\rho_{nk}}{\sum_{j=1}^K \rho_{nj}}
    • \ln\rho_{nk} = E[\ln\pi_k] + \frac{1}{2} E[\ln |\Lambda_k|] - \frac{D}{2}\ln(2\pi) - \frac{1}{2}E_{\mu_k, \Lambda_k}[(x_n - \mu_k)^T \Lambda_k (x_n - \mu_k)]
    • E_{\mu_k, \Lambda_k}[(x_n - \mu_k)^T \Lambda_k (x_n - \mu_k)] = D\beta_k^{-1} + \nu_k(x_n - m_k)^TW_k(x_n - m_k)
    • E[\ln|\Lambda_k|] = \sum_{i=1}^D \psi(\frac{\nu_k + 1 - i}{2}) + D \ln 2 + \ln|W_k|
    • E[\ln \pi_k] = \psi(\alpha_k) - \psi({\hat \alpha})
    • {\hat \alpha} = \sum_k \alpha_k
  5. Drehen Sie 2 ~ 4, bis es konvergiert

Implementierung

%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from scipy.special import digamma
import matplotlib.cm as cm
plt.style.use("ggplot")

#Abmessungen
D = 2
#Die Anzahl der Daten
N = 2000


#Tatsächlicher Wert
mu1 = [0, 1]
sigma1 = 0.2 * np.eye(D)
N1 = int(N*0.3)
mu2 = [-1, -1]
sigma2 = 0.1 * np.eye(D)
N2 = int(N*0.5)
mu3 = [1, -1]
sigma3 = 0.1 * np.eye(D)
N3 = int(N*0.2)

plt.figure(figsize=(5, 5))
data = np.concatenate([np.random.multivariate_normal(mu1, sigma1, N1), 
                    np.random.multivariate_normal(mu2, sigma2, N2),
                    np.random.multivariate_normal(mu3, sigma3, N3)
                   ])
plt.xlim(-2.1, 2.1)
plt.ylim(-2.1, 2.1)
plt.scatter(data[:, 0], data[:, 1], s=10)
plt.show()

w8dO9h9H66IZwAAAABJRU5ErkJggg==.png

#Anfangsanzahl der Klassen
K = 6
#Ursprünglicher Wert
mu = np.array([[0., -0.5],[0., 0.5], [1., 0.5], [-1., -0.5], [-1, -1.5], [1, 1.5]])
S = np.array([0.1 * np.eye(2) for k in range(K)])
 
#Vorherige Verteilungsparameter
alpha_0 = 1e-3
beta_0 = 1e-3
m_0 = np.zeros((K, D))
nu_0 = 1
W_0 = np.eye(D)
 
#Anfangsparameter
W_k = np.zeros((K, D, D))
E_mu_lam = np.zeros((N, K))
 
def multi_gauss(x, y, mu, sigma):
    return stats.multivariate_normal(mu, sigma).pdf(np.array([x, y]))
 
#1: r_Initialisierung von nk
r = np.ones([N, K]) / K
pi = np.ones(K) / K
g = np.zeros((N, K))
for k in range(K):
    g[:, k] = np.vectorize(lambda x, y: pi[k] * multi_gauss(x, y, mu[k], S[k]))(data[:, 0], data[:, 1])
for k in range(K):
    r[:, k] = g[:, k] / g.sum(1)

#Illustriert
X, Y = np.meshgrid(np.linspace(-2.1, 2.1), np.linspace(-2.1, 2.1))
cmap_colors = [cm.spring, cm.summer, cm.autumn, cm.winter, cm.Reds_r, cm.Dark2]
colors = ["pink", "green", "orange", "blue", "red", "black"]
plt.figure(figsize=(5, 5))
for k in range(K):
    Z = np.vectorize(lambda x, y: multi_gauss(x, y, mu[k], S[k]))(X, Y)
    plt.contour(X, Y, Z, cmap=cmap_colors[k], alpha=0.5)
 
 
plt.scatter(data[:, 0] , data[:, 1], c = map(lambda x: colors[x], r.argmax(1)), alpha=0.3, s=10)
plt.xlim(-2.1, 2.1)
plt.ylim(-2.1, 2.1)
init_title = "iter: 0"
plt.title(init_title)
#plt.savefig("data/" + init_title + ".png ")
plt.show()

n9OnT9Pc3KzKdIpEGtT0IVbdmfJhbduhUGhT+oCIjkz3A8h8XxDVoLYXa+NDT09P2nyQK1AkEokEDY4MJRKJZDOQwVAikUiQwVAikUgAGQwlEokEkMFQIpFIABkMJRKJBJDBUCKRSAD4HwgVxm0OO01kAAAAAElFTkSuQmCC.png

for i in range(20):
    #2:Berechnen Sie drei Statistiken
    N_k = r.sum(0)
    mu = r.T.dot(data) / np.c_[N_k]
    for k in range(K):
        S[k] = (np.c_[r[:, k]] * (data - mu[k])).T.dot(data - mu[k]) / N_k[k]
 
    #3: Mstep
    alpha = alpha_0 + N_k
    beta = beta_0 + N_k
    m_k = (beta_0 * m_0 + np.c_[N_k] * mu) / np.c_[beta]
    for k in range(K):
        tmp1 = beta_0 * N_k[k] * np.outer(mu[k] - m_0[k], mu[k] - m_0[k]) / (beta_0 + N_k[k])
        tmp2 = LA.inv(W_0) + N_k[k] * S[k] + tmp1
        W_k[k] = LA.inv(tmp2)
    nu_k = nu_0 + N_k
 
    #4: Estep
    E_ln_lam = digamma(nu_k / 2) + digamma((nu_k - 1) / 2) + D * np.log(2) + np.log([LA.norm(w) for w in W_k])
    E_ln_pi = digamma(alpha) - digamma(alpha.sum())
    for k in range(K):
        E_mu_lam[:, k] = D / beta[k] + nu_k [k] * np.diag((data - m_k[k]).dot(W_k[k]).dot((data - m_k[k]).T))
    ro = np.exp(E_ln_pi + E_ln_lam / 2. - D * np.log(2 * np.pi) / 2. - E_mu_lam / 2.)
    r = ro / np.c_[ro.sum(1)]
    r[r < 1e-10] = 1e-10
 
    #Erstellen Sie ein GIF-Diagramm
    plt.figure(figsize=(5, 5))
    X, Y = np.meshgrid(np.linspace(-2.1, 2.1), np.linspace(-2.1, 2.1))
    pi = np.exp(E_ln_pi)
    for k in range(K):
        Z = np.vectorize(lambda x, y: multi_gauss(x, y, mu[k], S[k]))(X, Y)
        if np.exp(E_ln_pi)[k] > 0.01:
            plt.contour(X, Y, Z, cmap=cmap_colors[k], alpha=0.5)
    plt.scatter(data[:, 0] , data[:, 1], c = map(lambda x: colors[x], r.argmax(1)), s=10, alpha=0.3)
    plt.xlim(-2.1, 2.1)
    plt.ylim(-2.1, 2.1)
    title = "iter: {}".format(i+1)
    plt.title(title)
    #plt.savefig("data/" + title + ".png ")
    plt.show()

Recommended Posts

Schätzung der gemischten Gaußschen Verteilung nach der varianten Bayes'schen Methode
EM der gemischten Gaußschen Verteilung
Konzept des Bayes'schen Denkens (2) ... Bayes'sche Schätzung und Wahrscheinlichkeitsverteilung
PRML Kapitel 9 Mixed Gaussian Distribution Python-Implementierung
Schätzung der gemischten Gaußschen Verteilung durch EM-Algorithmus
PRML Kapitel 10 Variante Mixed Gaussian Distribution Python-Implementierung
EM-Algorithmusberechnung für gemischtes Gaußsches Verteilungsproblem
[Python] Implementierung von Clustering mit einem gemischten Gaußschen Modell
Python: Diagramm der zweidimensionalen Datenverteilung (Schätzung der Kerneldichte)
Implementierung der Bayes'schen Varianzschätzung des Themenmodells in Python
Überprüfung der Normalverteilung