Migrationsartikel Nr. 3.
Diesmal ist ein einfaches Experiment. Beachten Sie das Schreiben von Code.
Der EM-Algorithmus ist nicht einfach zu verwenden (ich persönlich denke), da sich die zu lösende Gleichung je nach Modell ändert. Erwägen Sie daher, es durch Stichproben direkt aus dem Modell zu erhalten. Hier erfolgt die Abtastung mit pymc3.
Die abzutastenden Daten wurden wie folgt erzeugt.
N=1000
X, y = datasets.make_blobs(n_samples=N, random_state=8)
transformation = [[0.6, -0.6], [-0.4, 0.8]]
X_aniso = np.dot(X, transformation)
df = pd.DataFrame()
df['x'] = X_aniso.T[0]
df['y'] = X_aniso.T[1]
df['c'] = y
Die Trainingsdaten werden wie folgt aufgezeichnet. Es ist visuell zu sehen, dass es drei Cluster gibt.
plt.figure(figsize=(10, 10))
sns.scatterplot(x='x', y='y', data=df, hue='c')
plt.show()
Die Beobachtungsdaten seien $ x $, der Cluster sei $ z $ und der Parameter sei $ \ theta_z $. Es wird angenommen, dass diese wie folgt erzeugt werden.
\displaystyle{
\begin{aligned}
x_i &\sim N(x|, \mu_{z_i}, I)\\
\mu_k &\sim N(\mu_k| 0, I)\\
z_i &\sim Cat(z_i|\pi)\\
\pi &\sim Dir(\pi|\alpha)
\end{aligned}
}
Wenn dies programmatisch geschrieben wurde, wurde es wie folgt. Die Bibliothek verwendete pymc3. Bei Betrachtung der Daten war klar, dass die Anzahl der Cluster 3 betrug. Daher haben wir hier eine Stichprobe mit 3 Clustern erstellt.
k=3
data_dim = len(df.T) -1
data_size = len(data)
with pm.Model() as model:
pi = pm.Dirichlet('p', a=np.ones(k), shape=k)
pi_min_potential = pm.Potential('pi_min_potential', tt.switch(tt.min(pi) < .1, -np.inf, 0))
z = pm.Categorical('z', p=pi, shape=data_size)
mus = pm.MvNormal('mus', mu=np.zeros(data_dim), cov=np.eye(data_dim), shape=(k, data_dim))
y = pm.MvNormal('obs', mu=mus[z], cov=np.eye(data_dim), observed=df.drop(columns='c').to_numpy())
tr = pm.sample(10*data_size, random_seed=0, chains=1)
Zusätzlich verwendet der Cluster unter den erhaltenen Stichprobenergebnissen den häufigsten Wert.
df['pred'] = scipy.stats.mode(tr['z'], axis=0).mode[0]
Zeichnen Sie die Ergebnisse mit dem folgenden Code. Sie können sehen, dass sie gut gruppiert sind.
plt.figure(figsize=(10, 10))
sns.scatterplot(x='x', y='y', data=df, hue='pred')
plt.show()