[PYTHON] Clustering experiment by sampling

Migration article #3.

This post is a simple experiment, kept mainly as a record of the code.

Cluster estimation by sampling

I personally find the EM algorithm awkward to use, because the equations to be solved change depending on the model. So instead, consider drawing samples directly from the model. Here, sampling is performed using pymc3.

Training data

The training data was generated as follows.

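import numpy as np
import pandas as pd
from sklearn import datasets

# Three isotropic blobs (make_blobs default), then a linear map to make them anisotropic
N = 1000
X, y = datasets.make_blobs(n_samples=N, random_state=8)
transformation = [[0.6, -0.6], [-0.4, 0.8]]
X_aniso = np.dot(X, transformation)

# Columns: coordinates x and y, plus the true cluster label c
df = pd.DataFrame()
df['x'] = X_aniso.T[0]
df['y'] = X_aniso.T[1]
df['c'] = y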
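As a small illustration of what `np.dot(X, transformation)` does to each row of `X`, the sketch below applies the same matrix to two unit vectors in plain Python. The helper name `transform_point` is hypothetical and not part of the original code.

```python
def transform_point(point, matrix):
    """Row vector times matrix: the operation np.dot(X, transformation) applies to each row of X."""
    n_cols = len(matrix[0])
    return [sum(p * matrix[r][c] for r, p in enumerate(point)) for c in range(n_cols)]

transformation = [[0.6, -0.6], [-0.4, 0.8]]

# Each unit vector is mapped to the corresponding row of the matrix
print(transform_point([1.0, 0.0], transformation))  # [0.6, -0.6]
print(transform_point([0.0, 1.0], transformation))  # [-0.4, 0.8]
```

Because the rows of the matrix are not orthogonal, the round blobs are sheared and stretched, which gives the elongated clusters in the plot.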

The plot of the training data is as follows. Visually, you can see that there are three clusters.

import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(10, 10))
sns.scatterplot(x='x', y='y', data=df, hue='c')
plt.show()

cluster.png

Model

Let the observed data be $x_i$, the cluster assignment be $z_i$, and the cluster parameter be $\mu_k$. These are assumed to be generated as follows.

\displaystyle{
\begin{aligned}
x_i &\sim N(x_i| \mu_{z_i}, I)\\
\mu_k &\sim N(\mu_k| 0, I)\\
z_i &\sim Cat(z_i|\pi)\\
\pi &\sim Dir(\pi|\alpha)
\end{aligned}
}
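To make the generative story concrete, here is a minimal forward-simulation sketch that draws from the four distributions in the order $\pi$, $\mu_k$, $z_i$, $x_i$. It uses only the Python standard library and is not the pymc3 model itself. The function names `sample_dirichlet` and `simulate` are hypothetical, and the toy sizes are assumptions.

```python
import random

def sample_dirichlet(alpha, rng):
    # A Dirichlet draw is a vector of independent Gamma(alpha_k, 1) draws, normalized to sum to 1
    g = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(g)
    return [v / total for v in g]

def simulate(n, k, dim, alpha, seed=0):
    rng = random.Random(seed)
    pi = sample_dirichlet(alpha, rng)                        # pi ~ Dir(alpha)
    mus = [[rng.gauss(0.0, 1.0) for _ in range(dim)]         # mu_k ~ N(0, I)
           for _ in range(k)]
    zs, xs = [], []
    for _ in range(n):
        z = rng.choices(range(k), weights=pi)[0]             # z_i ~ Cat(pi)
        x = [rng.gauss(m, 1.0) for m in mus[z]]              # x_i ~ N(mu_{z_i}, I)
        zs.append(z)
        xs.append(x)
    return pi, mus, zs, xs

pi, mus, zs, xs = simulate(n=5, k=3, dim=2, alpha=[1.0, 1.0, 1.0])
print(pi)
print(zs)
```

Inference runs this story in reverse: given only the $x_i$, the sampler draws plausible values of $\pi$, $\mu_k$, and $z_i$.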

Written as a program with pymc3, the model looks as follows. The data clearly contains three clusters, so sampling is done with the number of clusters fixed at 3.

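import pymc3 as pm
import theano.tensor as tt

k = 3
data = df[['x', 'y']].to_numpy()  # observed points, shape (N, 2)
data_size, data_dim = data.shape
with pm.Model() as model:
    # Mixing weights: pi ~ Dir(1, ..., 1)
    pi = pm.Dirichlet('p', a=np.ones(k), shape=k)
    # Rule out states in which any mixing weight falls below 0.1
    pi_min_potential = pm.Potential('pi_min_potential', tt.switch(tt.min(pi) < .1, -np.inf, 0))
    # Cluster assignment of each point: z_i ~ Cat(pi)
    z = pm.Categorical('z', p=pi, shape=data_size)

    # Cluster means: mu_k ~ N(0, I)
    mus = pm.MvNormal('mus', mu=np.zeros(data_dim), cov=np.eye(data_dim), shape=(k, data_dim))

    # Likelihood: x_i ~ N(mu_{z_i}, I)
    obs = pm.MvNormal('obs', mu=mus[z], cov=np.eye(data_dim), observed=data)
    tr = pm.sample(10*data_size, random_seed=0, chains=1)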
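The `pi_min_potential` term does not appear in the model equations, so it may help to spell out what it contributes. As I read it, `pm.Potential` adds an extra term to the model's log-probability, and the `tt.switch` expression makes that term negative infinity whenever the smallest mixing weight is below 0.1. The plain-Python sketch below mirrors that expression. The function name and the `threshold` parameter are hypothetical.

```python
import math

def pi_min_log_potential(pi, threshold=0.1):
    """Plain-Python mirror of tt.switch(tt.min(pi) < .1, -np.inf, 0)."""
    return -math.inf if min(pi) < threshold else 0.0

print(pi_min_log_potential([0.3, 0.3, 0.4]))    # 0.0: balanced weights, no penalty
print(pi_min_log_potential([0.05, 0.45, 0.5]))  # -inf: one weight below 0.1, zero probability
```

In effect this truncates the Dirichlet prior, so no cluster can shrink to a negligible share of the data during sampling.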

From the sampling results, each point's cluster is taken to be the mode of its sampled assignments.

import scipy.stats
df['pred'] = np.ravel(scipy.stats.mode(tr['z'], axis=0).mode)  # ravel handles .mode of shape (1, N) or (N,)
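The mode is taken along the sample axis: `tr['z']` holds one row per draw and one column per data point, so each point gets the label it was assigned most often. The following sketch reproduces that on a toy trace using only the standard library. The name `column_modes` and the toy values are made up for illustration.

```python
from collections import Counter

def column_modes(trace):
    """trace: list of draws, each draw a list with one cluster label per data point."""
    n_points = len(trace[0])
    modes = []
    for i in range(n_points):
        counts = Counter(draw[i] for draw in trace)
        best = max(counts.values())
        # On ties, take the smallest label, which matches scipy.stats.mode
        modes.append(min(label for label, c in counts.items() if c == best))
    return modes

# 4 draws (rows) for 3 data points (columns)
toy_trace = [
    [0, 1, 2],
    [0, 1, 1],
    [0, 2, 1],
    [1, 1, 1],
]
print(column_modes(toy_trace))  # [0, 1, 1]
```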

Result

The results are plotted with the code below. You can see that the three clusters are recovered nicely.

plt.figure(figsize=(10, 10))
sns.scatterplot(x='x', y='y', data=df, hue='pred')
plt.show()

cluster_pred.png
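The sampled cluster indices are arbitrary, so the numbers in `pred` need not match the numbers in `c` even when the grouping is the same. One way to put a number on the visual impression, not done in the original experiment, is to try every relabeling and keep the best agreement. This is a sketch with a hypothetical helper name and toy labels, and it is only practical for small k.

```python
from itertools import permutations

def best_label_agreement(true_labels, pred_labels, k):
    """Return (accuracy, mapping) for the relabeling of pred that best matches true."""
    best_acc, best_map = -1.0, None
    for perm in permutations(range(k)):
        # perm[p] is the true label assigned to predicted label p
        hits = sum(1 for t, p in zip(true_labels, pred_labels) if perm[p] == t)
        acc = hits / len(true_labels)
        if acc > best_acc:
            best_acc, best_map = acc, dict(enumerate(perm))
    return best_acc, best_map

# Toy labels: same grouping up to renaming, with one mistake in the last point
true_c = [0, 0, 1, 1, 2, 2]
pred = [2, 2, 0, 0, 1, 0]
acc, mapping = best_label_agreement(true_c, pred, k=3)
print(acc, mapping)
```

With the real data this would be called as `best_label_agreement(df['c'], df['pred'], k=3)`.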

Other comments

Reference

- Bayesian Statistical Modeling with Python: Data Analysis Practice Guide with PyMC
