Migrated article #3.
This time it is a simple experiment, recorded mainly as a note of the code.
I personally find the EM algorithm hard to use, because the update equations have to be re-derived whenever the model changes. So instead, consider drawing samples directly from the model. Here the sampling is done with pymc3.
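To make the point about model-specific equations concrete, here is a minimal sketch of EM for one particular case: a mixture of Gaussians whose covariances are fixed to the identity. The function name `em_unit_gmm`, the random initialisation, and the iteration count are my own assumptions for illustration, not part of any library. Both the E-step and the M-step below are tied to this specific likelihood, so a different model would need different formulas.

```python
import numpy as np

def em_unit_gmm(X, k, n_iter=50, seed=0):
    """Illustrative EM for a mixture of unit-covariance Gaussians (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    mu = X[rng.choice(n, size=k, replace=False)]  # start the means at random data points
    pi = np.full(k, 1.0 / k)                      # start with equal mixing weights
    for _ in range(n_iter):
        # E-step: responsibilities r[i, j] proportional to pi_j * N(x_i | mu_j, I)
        sq_dist = ((X[:, None, :] - mu[None, :, :]) ** 2).sum(axis=2)
        log_r = np.log(pi) - 0.5 * sq_dist
        log_r -= log_r.max(axis=1, keepdims=True)  # stabilise before exponentiating
        r = np.exp(log_r)
        r /= r.sum(axis=1, keepdims=True)
        # M-step: closed-form updates that hold only for this model
        nk = r.sum(axis=0)
        mu = (r.T @ X) / nk[:, None]
        pi = nk / n
    return pi, mu, r
```

Like any EM run, this can settle in a local optimum depending on the initial means.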
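The data to be clustered were generated as follows.
import numpy as np
import pandas as pd
from sklearn import datasets

N = 1000
# Three isotropic blobs (the default number of centers in make_blobs)
X, y = datasets.make_blobs(n_samples=N, random_state=8)
# Linear transformation that stretches and skews the blobs
transformation = [[0.6, -0.6], [-0.4, 0.8]]
X_aniso = np.dot(X, transformation)
df = pd.DataFrame()
df['x'] = X_aniso.T[0]
df['y'] = X_aniso.T[1]
df['c'] = y  # true cluster label, used only for plotting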
The training data are plotted below. Visually, you can see that there are three clusters.
import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(10, 10))
sns.scatterplot(x='x', y='y', data=df, hue='c')
plt.show()
Let the observed data be $x$, the cluster assignment be $z$, and the parameter of cluster $z$ be $\theta_z$ (here, the cluster mean $\mu_z$). These are assumed to be generated as follows.
\displaystyle{
\begin{aligned}
x_i &\sim N(x_i| \mu_{z_i}, I)\\
\mu_k &\sim N(\mu_k| 0, I)\\
z_i &\sim Cat(z_i|\pi)\\
\pi &\sim Dir(\pi|\alpha)
\end{aligned}
}
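Read from the bottom up, the four lines above describe how one dataset would be generated. The following is a rough NumPy sketch of that forward process, not the inference itself. The sizes `k`, `dim`, `n` and the symmetric `alpha` are assumed values chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
k, dim, n = 3, 2, 5        # illustrative sizes (assumptions, not the article's N=1000)
alpha = np.ones(k)         # assumed symmetric Dirichlet parameter

pi = rng.dirichlet(alpha)                          # pi   ~ Dir(alpha)
mu = rng.normal(0.0, 1.0, size=(k, dim))           # mu_k ~ N(0, I)
z = rng.choice(k, size=n, p=pi)                    # z_i  ~ Cat(pi)
x = mu[z] + rng.normal(0.0, 1.0, size=(n, dim))    # x_i  ~ N(mu_{z_i}, I)
```

Inference runs in the opposite direction: given only `x`, recover plausible values of `z`, `mu`, and `pi`.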
Written as a program with pymc3, the model looks as follows. Since the data clearly contain three clusters, the number of clusters is fixed at 3 for sampling.
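import numpy as np
import pymc3 as pm
import theano.tensor as tt

k = 3
data_dim = len(df.T) - 1  # number of feature columns (everything except 'c')
data_size = len(df)       # number of observations
with pm.Model() as model:
    # Mixing weights: pi ~ Dir(1, ..., 1)
    pi = pm.Dirichlet('p', a=np.ones(k), shape=k)
    # Forbid any mixing weight below 0.1 so that no cluster becomes empty
    pi_min_potential = pm.Potential('pi_min_potential', tt.switch(tt.min(pi) < .1, -np.inf, 0))
    # Cluster assignment of each observation: z_i ~ Cat(pi)
    z = pm.Categorical('z', p=pi, shape=data_size)
    # Cluster means: mu_k ~ N(0, I)
    mus = pm.MvNormal('mus', mu=np.zeros(data_dim), cov=np.eye(data_dim), shape=(k, data_dim))
    # Likelihood: x_i ~ N(mu_{z_i}, I)
    y = pm.MvNormal('obs', mu=mus[z], cov=np.eye(data_dim), observed=df.drop(columns='c').to_numpy())
    tr = pm.sample(10*data_size, random_seed=0, chains=1)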
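The `pm.Potential` line in the model adds an extra term to the model's log-probability. As a plain-Python sketch of what that term evaluates to (the function name `pi_min_log_potential` and the `floor` argument are hypothetical, introduced only for this illustration):

```python
import math

def pi_min_log_potential(pi, floor=0.1):
    """Return 0 when every mixing weight is at least `floor`, and -inf otherwise."""
    return -math.inf if min(pi) < floor else 0.0

print(pi_min_log_potential([0.3, 0.3, 0.4]))    # 0.0
print(pi_min_log_potential([0.05, 0.45, 0.5]))  # -inf
```

A log-probability of minus infinity means the sampler can never accept such a `pi`, so every cluster keeps at least 10% of the weight.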
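From the sampling results, each point is assigned the cluster that appears most often (the mode) among its sampled values of z.
import scipy.stats

# tr['z'] has shape (n_draws, data_size); take the mode over draws for each point.
# np.ravel handles both the (1, N) result of older SciPy and the (N,) result of newer versions.
df['pred'] = np.ravel(scipy.stats.mode(tr['z'], axis=0).mode)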
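For reference, the same per-point mode can be computed with NumPy alone. This is a small sketch on a toy array standing in for the sampled z values; the helper name `columnwise_mode` is hypothetical.

```python
import numpy as np

def columnwise_mode(samples, k):
    """Most frequent label in each column of an (n_draws, n_points) integer array."""
    # counts[j, i] = number of draws in which point i was assigned to cluster j
    counts = np.stack([(samples == j).sum(axis=0) for j in range(k)])
    return counts.argmax(axis=0)  # ties go to the smallest label

# Toy draws: 4 samples (rows) for 3 data points (columns)
samples = np.array([[0, 1, 2],
                    [0, 1, 1],
                    [2, 1, 1],
                    [0, 0, 1]])
print(columnwise_mode(samples, k=3))  # [0 1 1]
```

Dividing `counts` by the number of draws would give, for each point, the fraction of draws spent in each cluster, which is a rough measure of how certain the assignment is.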
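Plot the result with the code below. You can see that the points are clustered cleanly. Note that the predicted cluster indices are arbitrary, so the colors need not match those of the first plot.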
plt.figure(figsize=(10, 10))
sns.scatterplot(x='x', y='y', data=df, hue='pred')
plt.show()
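Because the predicted cluster indices are arbitrary, comparing `pred` with `c` directly can understate the agreement. One way to check the result numerically is to try every relabelling and keep the best one. The helper below is a hypothetical sketch using only the standard library; it is practical only for a small number of clusters such as the 3 used here.

```python
from itertools import permutations

def best_permutation_accuracy(true_labels, pred_labels, k):
    """Highest fraction of matches over all relabellings of the predicted clusters."""
    best = 0.0
    for perm in permutations(range(k)):
        # perm[p] is the true label that predicted cluster p is mapped to
        hits = sum(perm[p] == t for t, p in zip(true_labels, pred_labels))
        best = max(best, hits / len(true_labels))
    return best

true = [0, 0, 1, 1, 2, 2]
pred = [2, 2, 0, 0, 1, 0]
print(best_permutation_accuracy(true, pred, k=3))  # 5 of 6 points agree
```

With the DataFrame from this article, the call would be `best_permutation_accuracy(df['c'], df['pred'], k)`.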
--Bayesian Statistical Modeling with Python: Data Analysis Practice Guide with PyMC