[PYTHON] [Nonparametric Bayes] Estimating the number of clusters using the Dirichlet process

Overview

Hello, this is kwashi. Within machine learning, unsupervised learning is attracting more and more attention. A well-known example of unsupervised learning is the topic model, a technique for estimating latent meaning. One typical use is to discover latent structure such as the categories of news articles. In nonparametric topic models, the Dirichlet process is used to infer the topic distribution of each document, including the number of topics (categories).

I will also introduce a paper on sound source separation. Bayesian Nonparametrics for Microphone Array Processing explains how to estimate the directions of multiple sound sources, together with the separated sounds, from the signals that those sources produce at multiple microphones (a microphone array). As the title suggests, the paper uses nonparametric Bayes: the number of sound sources is estimated with the Dirichlet process. The number of sound sources is very important, because it tells us how many separated sounds should be generated.

In this article, we use the Dirichlet process as the nonparametric Bayesian method and work through an example that simultaneously estimates the number of clusters and the mean and variance of each normal distribution.

Purpose

In this article, we estimate the mean and variance of each normal distribution from training data generated by the mixture of normal distributions described below. First, we explain how to estimate each normal distribution when the number of clusters is specified, and then how to estimate the number of clusters at the same time.

The parameters of the three normal distributions are means (-8, 0, 4), standard deviations (1.8, 1.5, 1.3), and mixing ratios (0.2, 0.5, 0.3). The figure below shows the generated probability density (green line) and a histogram of the generated data. The program that defines the mixture and generates the data is also shown below.

img001.png


import pymc3 as pm
import numpy as np
import theano.tensor as tt
import scipy.stats as stats
from scipy import optimize
import matplotlib.pyplot as plt

np.random.seed(53536)

xmin = -15.
xmax = 10.
xsize = 200
x = np.linspace(xmin, xmax, xsize)
pi_k = np.array([0.2, 0.5, 0.3])   # mixing ratios
loc_x = np.array([-8, 0, 4])       # means

# weighted component densities on the grid (scale = standard deviation)
norm1 = stats.norm.pdf(x, loc=loc_x[0], scale=1.8) * pi_k[0]
norm2 = stats.norm.pdf(x, loc=loc_x[1], scale=1.5) * pi_k[1]
norm3 = stats.norm.pdf(x, loc=loc_x[2], scale=1.3) * pi_k[2]

# mixture density, normalized into discrete probabilities over the grid points
npdf = norm1 + norm2 + norm3
npdf /= npdf.sum()

# draw 4000 training samples from the grid x with probabilities npdf
y = np.random.choice(x, size=4000, p=npdf)

Estimating the normal distributions with a fixed number of clusters

In this chapter, we explain how to estimate the means and variances of the three normal distributions with the number of clusters fixed to 3 in advance. When the number of clusters is fixed, there are many ways to estimate the parameters of the normal distributions, such as the EM algorithm, variational Bayes, and Markov chain Monte Carlo (MCMC). Here we use MCMC.

The generative model of the normal mixture is shown in the program below. The mixing ratios of the normal distributions are drawn from a Dirichlet distribution, and a categorical distribution with these ratios generates an ID (z) indicating which group each data point belongs to. Then, for each normal distribution, we place prior distributions on the mean and standard deviation.

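cluster = 3  # number of mixture components, fixed in advance

with pm.Model() as model:
  # mixing ratios drawn from a symmetric Dirichlet prior
  p = pm.Dirichlet('p', a=np.ones(cluster))
  # latent cluster ID of each data point
  z = pm.Categorical('z', p=p, shape=y.shape[0])

  # priors on each component's mean and standard deviation
  mu = pm.Normal('mu', mu=y.mean(), sd=10, shape=cluster)
  sd = pm.HalfNormal('sd', sd=10, shape=cluster)

  # likelihood; named obs so the data array y is not overwritten
  obs = pm.Normal('obs', mu=mu[z], sd=sd[z], observed=y)

  trace = pm.sample(1000)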
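To make the generative story concrete, the NumPy sketch below performs ancestral sampling from the same priors as the model above: mixing ratios from a symmetric Dirichlet, a cluster ID z for each data point, means from Normal(mu0, 10), and standard deviations from HalfNormal(10). The helper name `sample_generative_model` and the default prior center `mu0 = 0` are assumptions for illustration (the PyMC3 model centers the prior at `y.mean()`).

```python
import numpy as np

def sample_generative_model(n, cluster, mu0=0.0, rng=None):
    """Hypothetical helper: ancestral sampling from the fixed-K normal mixture model."""
    rng = np.random.default_rng() if rng is None else rng
    # p ~ Dirichlet(1, ..., 1): mixing ratios
    p = rng.dirichlet(np.ones(cluster))
    # mu_k ~ Normal(mu0, 10); sd_k ~ HalfNormal(10), i.e. |Normal(0, 10)|
    mu = rng.normal(mu0, 10.0, size=cluster)
    sd = np.abs(rng.normal(0.0, 10.0, size=cluster))
    # z_i ~ Categorical(p): cluster ID of each data point
    z = rng.choice(cluster, size=n, p=p)
    # y_i ~ Normal(mu[z_i], sd[z_i])
    y = rng.normal(mu[z], sd[z])
    return p, mu, sd, z, y

# suffixed names so the article's data array y is not overwritten
p_s, mu_s, sd_s, z_s, y_s = sample_generative_model(4000, 3, rng=np.random.default_rng(0))
print(p_s.round(3), np.bincount(z_s, minlength=3) / len(z_s))
```

The empirical cluster frequencies should track the sampled mixing ratios; inference with MCMC runs this story in reverse, recovering p, mu, and sd from y alone.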

However, this model is slow to sample because of the discrete latent variable z. We therefore marginalize z out, p(y|θ) = Σ_z p(y|z,θ) p(z|θ), and rewrite the program as follows.

with pm.Model() as model:
  p = pm.Dirichlet('p', a=np.ones(cluster))
  mu = pm.Normal('mu', mu=y.mean(), sd=10, shape=cluster)
  sd = pm.HalfNormal('sd', sd=10, shape=cluster)

  # z is summed out: each observation gets the density sum_k p_k * N(y | mu_k, sd_k)
  obs = pm.NormalMixture('obs', w=p, mu=mu, sd=sd, observed=y)

  trace = pm.sample(3000, chains=1)
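The marginalization step can be checked numerically. The NumPy-only sketch below (function names are hypothetical) computes the marginalized log-likelihood Σ_i log Σ_k p_k N(y_i | μ_k, σ_k) with a numerically stable log-sum-exp, which, as far as I understand, is the per-observation quantity a mixture likelihood such as `pm.NormalMixture` evaluates. It then compares the result with explicitly summing the joint p(y_i, z_i = k | θ) over every value of k. The demo data and parameters are illustrative.

```python
import numpy as np

def normal_logpdf(y, mu, sd):
    """Log density of Normal(mu, sd) evaluated at y (broadcasts over arrays)."""
    return -0.5 * np.log(2 * np.pi * sd**2) - (y - mu) ** 2 / (2 * sd**2)

def mixture_loglik(y, w, mu, sd):
    """Marginalized log-likelihood: sum_i log sum_k w_k N(y_i | mu_k, sd_k)."""
    # shape (n, K): log w_k + log N(y_i | mu_k, sd_k)
    log_terms = np.log(w) + normal_logpdf(y[:, None], mu, sd)
    # stable log-sum-exp over the K components
    m = log_terms.max(axis=1, keepdims=True)
    return float(np.sum(m[:, 0] + np.log(np.exp(log_terms - m).sum(axis=1))))

def joint_sum_loglik(y, w, mu, sd):
    """Same quantity, summing p(y_i, z_i = k) over k explicitly in a loop."""
    total = 0.0
    for yi in y:
        p_yi = sum(w[k] * np.exp(normal_logpdf(yi, mu[k], sd[k])) for k in range(len(w)))
        total += np.log(p_yi)
    return float(total)

rng = np.random.default_rng(1)
y_demo = rng.normal(0.0, 3.0, size=50)
w_demo = np.array([0.2, 0.5, 0.3])
mu_demo = np.array([-8.0, 0.0, 4.0])
sd_demo = np.array([1.8, 1.5, 1.3])
print(mixture_loglik(y_demo, w_demo, mu_demo, sd_demo),
      joint_sum_loglik(y_demo, w_demo, mu_demo, sd_demo))
```

Both numbers agree up to floating-point error. The vectorized form needs no per-point z, which is what removes the discrete variable from the sampler.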

The inference results are shown in the figure below. We can see that the means (-8, 0, 4) (mu), standard deviations (1.8, 1.5, 1.3) (sd), and mixing ratios (0.2, 0.5, 0.3) (p) of the normal distributions are all estimated well.

img002.png

Estimating the normal distributions with an unknown number of clusters

In the previous chapter, we estimated the parameters of the normal distributions assuming that the number of clusters was known. In this chapter, we introduce the Dirichlet process to estimate the parameters when the number of clusters is unknown. The focus here is on building intuition for the Dirichlet process. For details, [Continued: Easy-to-Understand Pattern Recognition, Introduction to Unsupervised Learning](https://www.amazon.co.jp/%E7%B6%9A%E3%83%BB%E3%82%8F%E3%81%8B%E3%82%8A%E3%82%84%E3%81%99%E3%81%84%E3%83%91%E3%82%BF%E3%83%BC%E3%83%B3%E8%AA%8D%E8%AD%98%E2%80%95%E6%95%99%E5%B8%AB%E3%81%AA%E3%81%97%E5%AD%A6%E7%BF%92%E5%85%A5%E9%96%80%E2%80%95-%E7%9F%B3%E4%BA%95-%E5%81%A5%E4%B8%80%E9%83%8E/dp/427421530X) is very helpful, so please read it.

A brief explanation of the Dirichlet process (DP): we write H ~ DP(a, H_0), where a is the concentration parameter (intuitively, something like a variance) and H_0 is the base distribution (intuitively, something like a mean). Each draw H from the DP is itself a distribution, which is why the DP is sometimes called a distribution over distributions. My personal intuition is this: in Gaussian process regression, the regression function itself is inferred with a Gaussian process, and in the same way, the Dirichlet process infers the distribution itself.
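To see what "a distribution over distributions" means in practice, the sketch below draws one H from a DP using a truncated version of the stick-breaking construction described below, and then samples from H. The following are assumptions for illustration, not taken from the original: a base distribution H_0 = Normal(0, 1), concentration 2.0, truncation K = 100, renormalized truncated weights, and the helper name `sample_dp`.

```python
import numpy as np

def sample_dp(alpha, K, rng, base_loc=0.0, base_scale=1.0):
    """Hypothetical helper: approximate draw H ~ DP(alpha, H_0) by truncated
    stick-breaking, with H_0 = Normal(base_loc, base_scale).
    Returns atoms and weights, so H = sum_k w_k * delta(locs_k)."""
    b = rng.beta(1.0, alpha, size=K)
    w = b * np.concatenate(([1.0], np.cumprod(1.0 - b[:-1])))
    w /= w.sum()                                        # renormalize the truncated weights
    locs = rng.normal(base_loc, base_scale, size=K)     # atoms drawn from H_0
    return locs, w

rng = np.random.default_rng(0)
dp_locs, dp_w = sample_dp(alpha=2.0, K=100, rng=rng)

# H is itself a (discrete) distribution: samples from it repeat a few atoms,
# whereas samples from the continuous H_0 are almost surely all distinct.
draws_from_H = rng.choice(dp_locs, size=1000, p=dp_w)
draws_from_H0 = rng.normal(0.0, 1.0, size=1000)
print(len(np.unique(draws_from_H)), len(np.unique(draws_from_H0)))
```

The repeated atoms are what make the DP useful for clustering: data points that share an atom share a cluster.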

There are two well-known constructions of the Dirichlet process: the Chinese restaurant process (CRP) and the stick-breaking process (SBP). In this chapter, we use the SBP, which is given by the following formula, where K is the number of components and α is the concentration parameter (the a above). As K goes to infinity, the SBP yields an infinite-dimensional Dirichlet distribution. In practice, K is set to a finite constant.

{\pi_k = b_k \prod_{j=1}^{k-1} (1-b_j),\quad b_k \sim {\rm Beta}(b;1,\alpha) }
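As a worked example of the stick-breaking recursion, take K = 3 and suppose the Beta draws happen to be b_1 = b_2 = b_3 = 0.5 (values chosen only for illustration). Each π_k is b_k times the stick remaining after the first k-1 breaks:

{\pi_1 = b_1 = 0.5,\quad \pi_2 = b_2(1-b_1) = 0.25,\quad \pi_3 = b_3(1-b_1)(1-b_2) = 0.125}

The stick left over, (1-b_1)(1-b_2)(1-b_3) = 0.125, is the mass that truncation at K = 3 leaves unassigned. This is why K is chosen large enough in practice that the leftover becomes negligible.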

Here π_k is the mixture ratio of cluster k. The SBP program and its results are shown below. However, the SBP alone produces only the mixture ratios, not the positions (means) of the components. Therefore, the positions are drawn from the base distribution as in the following formula and placed on the horizontal axis.

{\theta_k \sim H_0, \quad \text{for } k = 1, \ldots, K}

The figure below shows the mixture ratio generated at each position. When a is small, most of the weight is concentrated on a few positions. When a is large, the weight is scattered over many positions, spreading out across the base distribution.

img007.png

def stick_breaking(a, H, K):
  '''
  a: concentration parameter
  H: base distribution (frozen scipy.stats distribution, e.g. stats.norm(0, 5))
  K: number of components (truncation level)

  Returns
  locs: positions drawn from H (array of length K)
  w: mixture ratios (array of length K)
  '''
  # s_k ~ Beta(1, a), e.g. [0.02760315 0.1358357  0.02517414 0.11310199 0.21462781]
  s = stats.beta.rvs(1, a, size=K)
  # w_1 = s_1, and w_k = s_k * prod_{j<k} (1 - s_j) for k > 1
  # e.g. [0.02760315 0.13208621 0.0211541  0.09264824 0.15592888]
  w = s * np.concatenate(([1.], np.cumprod(1 - s[:-1])))

  # theta_k ~ H: position (mean) of each component
  locs = H.rvs(size=K)
  return locs, w
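The effect of a can be quantified as well as plotted. The sketch below reuses the same recursion in NumPy (the helper names `stick_breaking_weights` and `atoms_for_mass` are hypothetical) and counts how many of the largest weights are needed to cover 95% of the mass, averaged over 500 draws, for a small and a large a. The truncation K = 200 and the 95% threshold are arbitrary illustrative choices.

```python
import numpy as np

def stick_breaking_weights(a, K, rng):
    """Truncated stick-breaking weights, same recursion as stick_breaking above."""
    s = rng.beta(1.0, a, size=K)
    # w_k = s_k * prod_{j<k} (1 - s_j)
    return s * np.concatenate(([1.0], np.cumprod(1.0 - s[:-1])))

def atoms_for_mass(w, mass=0.95):
    """Smallest number of the largest weights whose sum reaches `mass`."""
    cum = np.cumsum(np.sort(w)[::-1])
    return int(np.searchsorted(cum, mass) + 1)

rng = np.random.default_rng(0)
K = 200
mean_atoms = {}
for a in (0.5, 10.0):
    counts = [atoms_for_mass(stick_breaking_weights(a, K, rng)) for _ in range(500)]
    mean_atoms[a] = float(np.mean(counts))
    print(f"a={a}: about {mean_atoms[a]:.1f} atoms carry 95% of the weight")
```

A small a puts almost all the weight on a handful of atoms (few clusters), while a large a spreads it over many, matching the qualitative picture in the figure.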

Next, the figure below shows normal mixtures for K = 5, with mixture ratios generated by the SBP and means generated from the base distribution (a normal distribution). The standard deviation of every component is fixed. As the figure shows, a wide variety of normal mixtures can be expressed by changing the SBP parameter a and the base distribution.

imga_1.png
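The construction behind that figure can be sketched as follows: draw weights by stick-breaking, draw each mean from the base distribution H_0, and sum weighted normal densities with a shared standard deviation. The helper name `sbp_mixture_density`, the base distribution Normal(0, 5), and the shared standard deviation 1.0 are assumptions for illustration, not the settings used for the original figure.

```python
import numpy as np

def sbp_mixture_density(x, a, K, base_loc=0.0, base_scale=5.0, sd=1.0, rng=None):
    """Hypothetical helper: density on grid x of a K-component normal mixture whose
    weights come from stick-breaking and whose means come from
    H_0 = Normal(base_loc, base_scale). Every component uses the same sd."""
    rng = np.random.default_rng() if rng is None else rng
    s = rng.beta(1.0, a, size=K)
    w = s * np.concatenate(([1.0], np.cumprod(1.0 - s[:-1])))
    locs = rng.normal(base_loc, base_scale, size=K)     # theta_k ~ H_0
    # shape (len(x), K): normal density of each component at each grid point
    comp = np.exp(-0.5 * ((x[:, None] - locs) / sd) ** 2) / (sd * np.sqrt(2 * np.pi))
    return comp @ w, w, locs

x_grid = np.linspace(-30, 30, 2001)
dens, weights, means = sbp_mixture_density(x_grid, a=1.0, K=5, rng=np.random.default_rng(2))
# Riemann-sum integral of the density; it matches weights.sum(), and both are
# below 1 because truncation at K = 5 drops the leftover stick.
print(dens.sum() * (x_grid[1] - x_grid[0]), weights.sum())
```

Re-running with different seeds or values of a gives differently shaped mixtures, which is the flexibility the figure illustrates.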

Now, using the fact that the SBP generates mixture ratios in this way, we estimate the parameters of the normal distributions when the number of clusters is unknown. The following program is written in PyMC3. Only the mixture ratios come from the SBP. The SBP hyperparameter a is drawn from a gamma distribution, the means from a normal distribution, and the standard deviations from a half-normal distribution. K is set to a finite number (20).

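def stick_breaking_DP(a, K):
  # b_k ~ Beta(1, a)
  b = pm.Beta('B', 1., a, shape=K)
  # w_k = b_k * prod_{j<k} (1 - b_j); with finite K the weights sum to slightly less than 1
  w = b * pm.math.concatenate([[1.], tt.extra_ops.cumprod(1. - b)[:-1]])
  return w

K = 20  # truncation level: an upper bound on the number of clusters

with pm.Model() as model:
  a = pm.Gamma('a', 1., 1.)  # concentration parameter
  w = pm.Deterministic('w', stick_breaking_DP(a, K))
  mu = pm.Normal('mu', mu=y.mean(), sd=10, shape=K)
  sd = pm.HalfNormal('sd', sd=10, shape=K)

  # named obs so the data array y is not overwritten
  obs = pm.NormalMixture('obs', w=w, mu=mu, sd=sd, observed=y)

  trace = pm.sample(1000, chains=1)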
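Why is K = 20 enough? Here is a rough check, assuming independent b_j ~ Beta(1, α). Because E[b_j] = 1/(1+α), the expected stick length left over after K breaks is E[∏_{j=1}^{K}(1-b_j)] = (α/(1+α))^K. The sketch below (the helper name is hypothetical) evaluates this for a few values of α. For α = 1 it is 0.5^20, about 1e-6, so truncation loses very little mass unless α becomes large.

```python
import numpy as np

def expected_leftover(alpha, K):
    """E[prod_{j=1}^K (1 - b_j)] for independent b_j ~ Beta(1, alpha)."""
    return (alpha / (1.0 + alpha)) ** K

for alpha in (0.5, 1.0, 3.0):
    print(f"alpha={alpha}: expected leftover mass with K=20 is {expected_leftover(alpha, 20):.2e}")

# Monte Carlo check of the closed form for alpha = 3
rng = np.random.default_rng(0)
b = rng.beta(1.0, 3.0, size=(100000, 20))
mc_leftover = np.prod(1.0 - b, axis=1).mean()
print(mc_leftover, expected_leftover(3.0, 20))
```

If the posterior of a moves to large values, K should be increased accordingly.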

The estimated mixing ratios are shown below. Only three of the mixture ratios take large values, so we can estimate that there are three clusters (the true values are 0.2, 0.5, and 0.3). The estimated means and standard deviations were also almost the same as in the fixed-cluster case. The important point is that we can estimate the number of clusters and the parameters of each normal distribution at the same time.

img006.png
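Reading "count the large weights" off the figure can also be done programmatically. The helper below is a hypothetical sketch. It assumes the posterior samples of `w` are available as an array of shape (draws, K); in PyMC3 this would be `trace['w']` for the model above. The threshold 0.05 is an arbitrary choice, not something prescribed by the method. The usage example feeds it synthetic Dirichlet draws as a stand-in, not the article's actual results.

```python
import numpy as np

def estimate_num_clusters(w_samples, threshold=0.05):
    """Hypothetical helper: number of components whose posterior-mean weight
    exceeds `threshold`, plus a histogram of the same count per draw.

    w_samples: array of shape (n_draws, K), e.g. trace['w'].
    """
    w_samples = np.asarray(w_samples)
    n_from_mean = int((w_samples.mean(axis=0) > threshold).sum())
    per_draw = (w_samples > threshold).sum(axis=1)
    return n_from_mean, np.bincount(per_draw)

# synthetic stand-in for trace['w'] with K = 20: three dominant components
rng = np.random.default_rng(0)
alpha_fake = np.concatenate(([200.0, 500.0, 300.0], np.ones(17)))
w_fake = rng.dirichlet(alpha_fake, size=1000)
n_clusters, counts_per_draw = estimate_num_clusters(w_fake)
print(n_clusters, counts_per_draw)
```

The per-draw histogram shows how certain the posterior is about the number of clusters, which a single plot of mean weights can hide.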

Summary

In this article, as a nonparametric Bayesian method, we explained how to use the Dirichlet process to estimate the number of clusters at the same time as the parameters of the normal distributions. Because I prioritized intuition over rigor, some explanations may feel imprecise. If so, I would appreciate your feedback.
