[Python] Gaussian mixture model with Pyro

I tried fitting a Gaussian mixture model with Pyro. The code follows the official example, with supplementary explanation added where appropriate.

Environment

Windows 10
Python: 3.7.7
Jupyter Notebook: 1.0.0
PyTorch: 1.5.1
Pyro: 1.4.0
scipy: 1.5.2
numpy: 1.19.1
matplotlib: 3.3.0
seaborn: 0.10.1

import os

import numpy as np
from scipy import stats

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

pyro.set_rng_seed(0)
pyro.enable_validation(True)

Data preparation

Load the iris dataset from seaborn and use the petal_length column as the target data.

df = sns.load_dataset('iris')
data = torch.tensor(df['petal_length'], dtype=torch.float32)
sns.swarmplot(data=df, x='petal_length')

image.png

Looking at the plot, splitting the data into two clusters seems reasonable. [^ 1]

Model settings

In Pyro, the probabilistic model is described in a function called model. Here we apply a Gaussian mixture model in which the clusters of the data points $x_1, \cdots, x_n \in \mathbb{R}$ are $z_1, \cdots, z_n \in \{1, \cdots, K\}$.

\begin{align}
p &\sim Dir(\tau_0/K, \cdots, \tau_0/K) \\
z_i &\sim Cat(p) \\
\mu_k &\sim N(\mu_0, \sigma_0^2) \\
\sigma_k &\sim InvGamma(\alpha_0, \beta_0) \\
x_i &\sim N(\mu_{z_i}, \sigma_{z_i}^2)
\end{align}

$K$ is the number of clusters, and $\tau_0, \mu_0, \sigma_0, \alpha_0, \beta_0$ are the parameters of the prior distributions. [^ 2]
We perform Bayesian inference on $\mu_1, \cdots, \mu_K$ and $\sigma_1, \cdots, \sigma_K$, and build a model that computes the clusters $z_1, \cdots, z_n$ probabilistically.

K = 2  # Fixed number of clusters
TAU_0 = 1.0
MU_0 = 0.0
SIGMA_0_SQUARE = 10.0
ALPHA_0 = 1.0
BETA_0 = 1.0

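@config_enumerate
def model(data):
    # Concentration of the Dirichlet prior over the mixture weights
    alpha = torch.full((K,), fill_value=TAU_0)
    p = pyro.sample('p', dist.Dirichlet(alpha))
    with pyro.plate('cluster_param_plate', K):
        # Note: dist.Normal takes a standard deviation (scale) as its second
        # argument, so the value 10.0 is used as the scale here
        mu = pyro.sample('mu', dist.Normal(MU_0, SIGMA_0_SQUARE))
        sigma = pyro.sample('sigma', dist.InverseGamma(ALPHA_0, BETA_0))

    with pyro.plate('data_plate', len(data)):
        z = pyro.sample('z', dist.Categorical(p))
        # Select each data point's cluster parameters by indexing with z
        pyro.sample('x', dist.Normal(mu[z], sigma[z]), obs=data)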

@config_enumerate is a decorator that marks the discrete variable sampled by pyro.sample('z', dist.Categorical(p)) for parallel enumeration.
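
To make the generative process in the formulas above concrete, here is a minimal sketch of one forward sample written with only the Python standard library. The function name sample_gmm_prior and its default arguments are my own assumptions for illustration; it is not part of Pyro and it follows the formulas (including $\tau_0/K$ and $\sigma_0$ as a standard deviation), not necessarily every detail of the Pyro code.

```python
import random


def sample_gmm_prior(n, K=2, tau0=1.0, mu0=0.0, sigma0=10.0 ** 0.5,
                     alpha0=1.0, beta0=1.0, seed=0):
    """Draw one forward sample (p, mu, sigma, z, x) from the mixture model.

    Hypothetical helper for illustration only.
    """
    rng = random.Random(seed)
    # p ~ Dir(tau0/K, ..., tau0/K): normalize independent Gamma draws.
    g = [rng.gammavariate(tau0 / K, 1.0) for _ in range(K)]
    total = sum(g)
    p = [v / total for v in g]
    # mu_k ~ N(mu0, sigma0^2); random.gauss takes a standard deviation.
    mu = [rng.gauss(mu0, sigma0) for _ in range(K)]
    # sigma_k ~ InvGamma(alpha0, beta0): the reciprocal of a Gamma draw
    # with shape alpha0 and rate beta0 (gammavariate takes a scale, 1/rate).
    sigma = [1.0 / rng.gammavariate(alpha0, 1.0 / beta0) for _ in range(K)]
    # z_i ~ Cat(p), then x_i ~ N(mu_{z_i}, sigma_{z_i}^2).
    z = rng.choices(range(K), weights=p, k=n)
    x = [rng.gauss(mu[k], sigma[k]) for k in z]
    return p, mu, sigma, z, x


p, mu, sigma, z, x = sample_gmm_prior(n=150)
print(p, mu, sigma)
```

Each call produces one set of parameters and 150 data points, which is the same kind of draw that tracing the Pyro model produces in the next section.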

Check the sampled value

By using poutine.trace, you can inspect the values that are sampled when data is passed to model.

trace_model = poutine.trace(model).get_trace(data)
tuple(trace_model.nodes.keys())
> ('_INPUT',
   'p',
   'cluster_param_plate',
   'mu',
   'sigma',
   'data_plate',
   'z',
   'x',
   '_RETURN')

trace_model.nodes is an OrderedDict that holds the keys above. _INPUT refers to the data given to model, _RETURN refers to the return value of model (None in this case), and the others refer to the sample sites and plates defined in model.

As a test, let's check the value of p. This is a parameter sampled from $Dir(\tau_0/K, \cdots, \tau_0/K)$. trace_model.nodes['p'] is also a dict, and the sampled value is stored under the 'value' key.

trace_model.nodes['p']['value']
> tensor([0.8638, 0.1362])

Next, let's check the value of z, which represents the cluster of each data point.

trace_model.nodes['z']['value']
> tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1,
          0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0,
          0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
          0, 0, 0, 0, 1, 0])

Given the value of p, 0 should be sampled much more often than 1, and that is exactly what we see. Note that these are samples from the prior distribution, so they are not yet meaningful estimates.

Guide settings

In Pyro, the posterior distribution is specified in the guide. pyro.infer.autoguide.AutoDelta is a class for MAP estimation.

guide = AutoDelta(poutine.block(model, expose=['p', 'mu', 'sigma']))

poutine.block is used to select which sites the guide estimates. AutoDelta does not seem to be able to handle the discrete variable z, so z is left out of expose. z is estimated after the distribution has been fitted.

Calling this guide on the data returns the estimated parameters as a dict.

guide(data)
> {'p': tensor([0.5000, 0.5000], grad_fn=<ExpandBackward>),
   'mu': tensor([4.0607, 2.8959], grad_fn=<ExpandBackward>),
   'sigma': tensor([1.3613, 1.6182], grad_fn=<ExpandBackward>)}

At the moment it only returns the initial values, but after fitting with SVI it will return the MAP estimates.
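
To illustrate what MAP estimation means here, the following is a small sketch on a much simpler problem than the mixture model: the mean of a Gaussian with known standard deviation and a normal prior. It compares the closed-form MAP estimate with plain gradient ascent on the log joint density. The function names, the toy data, the learning rate and the step count are all my own assumptions; this is not how AutoDelta is implemented internally.

```python
def map_mean_closed_form(xs, sigma, mu0, sigma0):
    """Closed-form MAP estimate of the mean under a N(mu0, sigma0^2) prior."""
    prec_prior = 1.0 / sigma0 ** 2
    prec_lik = len(xs) / sigma ** 2
    return (mu0 * prec_prior + sum(xs) / sigma ** 2) / (prec_prior + prec_lik)


def map_mean_gradient_ascent(xs, sigma, mu0, sigma0, lr=0.01, steps=2000):
    """Maximize log p(xs | mu) + log p(mu) by gradient ascent (toy settings)."""
    mu = 0.0
    for _ in range(steps):
        # Gradient of the log likelihood plus gradient of the log prior.
        grad = sum(x - mu for x in xs) / sigma ** 2 - (mu - mu0) / sigma0 ** 2
        mu += lr * grad
    return mu


xs = [1.3, 1.4, 1.5, 1.7]  # made-up values for illustration
print(map_mean_closed_form(xs, sigma=1.0, mu0=0.0, sigma0=10.0))
print(map_mean_gradient_ascent(xs, sigma=1.0, mu0=0.0, sigma0=10.0))
```

Both functions should agree to many decimal places. The fitting step in the next section does conceptually the same thing for p, mu and sigma at once, using Adam instead of plain gradient ascent.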

Distribution fitting

The guide above estimates the other parameters without estimating z. In other words, z has to be marginalized out when computing the loss. To do this, set the loss for stochastic variational inference to TraceEnum_ELBO().

optim = pyro.optim.Adam({'lr': 1e-3})
svi = SVI(model, guide, optim, loss=TraceEnum_ELBO())
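
Marginalizing z means summing the joint density over every possible cluster for each data point. As a rough sketch of that idea (not of Pyro's actual implementation), the marginal log-likelihood of a mixture can be computed with a log-sum-exp over the clusters. The helper names below are hypothetical and the example values are made up.

```python
import math


def normal_logpdf(x, mu, sigma):
    """Log density of N(mu, sigma^2) at x, with sigma a standard deviation."""
    return (-0.5 * ((x - mu) / sigma) ** 2
            - math.log(sigma) - 0.5 * math.log(2.0 * math.pi))


def logsumexp(values):
    """Numerically stable log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def marginal_log_likelihood(xs, p, mu, sigma):
    """Sum over data of log sum_k p_k N(x | mu_k, sigma_k^2)."""
    total = 0.0
    for x in xs:
        total += logsumexp([math.log(p[k]) + normal_logpdf(x, mu[k], sigma[k])
                            for k in range(len(p))])
    return total


print(marginal_log_likelihood([1.4, 4.7, 5.1], p=[0.5, 0.5],
                              mu=[5.0, 1.5], sigma=[1.0, 0.5]))
```

Because z no longer appears in this quantity, it can be optimized with respect to p, mu and sigma alone, which is why the guide does not need to handle z.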

Run the fitting.

NUM_STEPS = 3000
pyro.clear_param_store()

history = []
for step in range(1, NUM_STEPS + 1):
    loss = svi.step(data)
    history.append(loss)
    if step % 100 == 0:
        print(f'STEP: {step} LOSS: {loss}')

The plot of loss at each step looks like this:

plt.figure()
plt.plot(history)
plt.title('Loss')
plt.grid()
plt.xlim(0, 3000)
plt.show()

image.png

The loss has flattened out, so we can consider the estimation to have converged.
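
Judging convergence by eye is fine for a notebook, but a simple numeric check is also possible. The sketch below compares the mean loss over the last window of steps with the mean over the window before it. The function name has_converged, the window size and the tolerance are arbitrary assumptions of mine, not a Pyro feature.

```python
def has_converged(history, window=100, rel_tol=1e-3):
    """Return True if the mean loss changed by less than rel_tol (relative)
    between the last two windows of `window` steps. Hypothetical helper."""
    if len(history) < 2 * window:
        return False  # not enough steps to compare two windows
    recent = sum(history[-window:]) / window
    previous = sum(history[-2 * window:-window]) / window
    return abs(recent - previous) <= rel_tol * max(abs(previous), 1e-12)


# Toy loss curves for illustration only.
flat = [250.0] * 300
decreasing = [1000.0 - i for i in range(300)]
print(has_converged(flat), has_converged(decreasing))
```

With the history list recorded during fitting, calling has_converged(history) would give a yes/no answer under these assumed thresholds.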

Confirmation of estimated distribution

Get the estimates of $p, \mu, \sigma$ from guide.

map_params = guide(data)
p = map_params['p']
mu = map_params['mu']
sigma = map_params['sigma']
print(p)
print(mu)
print(sigma)
> tensor([0.6668, 0.3332], grad_fn=<ExpandBackward>)
  tensor([4.9049, 1.4618], grad_fn=<ExpandBackward>)
  tensor([0.8197, 0.1783], grad_fn=<ExpandBackward>)

Plot the distribution. In the figure below, the x markers show the data values.

x = np.arange(0, 10, 0.01)
# Weighted density of each component: p_k * N(x | mu_k, sigma_k^2)
y1 = p[0].item() * stats.norm.pdf(x, loc=mu[0].item(), scale=sigma[0].item())
y2 = p[1].item() * stats.norm.pdf(x, loc=mu[1].item(), scale=sigma[1].item())

plt.figure()
plt.plot(x, y1, color='red', label='z=0')
plt.plot(x, y2, color='blue', label='z=1')
plt.scatter(data.numpy(), np.zeros(len(data)), color='black', alpha=0.3, marker='x')
plt.legend()
plt.show()

image.png

The estimated distribution fits the data well.
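
As a sanity check on curves like these, each component is a mixture weight times a normal density, so the area under it should be close to the weight itself. Here is a small sketch using the trapezoidal rule with the estimates printed earlier; the helper names and the integration range are my own assumptions.

```python
import math


def normal_pdf(x, mu, sigma):
    """Density of N(mu, sigma^2) at x, with sigma a standard deviation."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))


def weighted_area(weight, mu, sigma, lo=0.0, hi=10.0, n=10000):
    """Trapezoidal integral of weight * N(x | mu, sigma^2) over [lo, hi]."""
    h = (hi - lo) / n
    ys = [weight * normal_pdf(lo + i * h, mu, sigma) for i in range(n + 1)]
    return h * (sum(ys) - 0.5 * (ys[0] + ys[-1]))


# Values taken from the printed estimates of p, mu and sigma.
area0 = weighted_area(0.6668, 4.9049, 0.8197)
area1 = weighted_area(0.3332, 1.4618, 0.1783)
print(area0, area1, area0 + area1)
```

The two areas should come out close to the mixture weights and sum to roughly 1.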

Cluster estimation

First, set the parameters estimated by guide in model. In Pyro, this is done by replaying the model against a trace of the guide.

trace_guide_map = poutine.trace(guide).get_trace(data)
model_map = poutine.replay(model, trace=trace_guide_map)

Check the parameters set in model. Here, only $\mu$ is checked.

trace_model_map = poutine.trace(model_map).get_trace(data)
trace_model_map.nodes['mu']['value']
> tensor([4.9048, 1.4618], grad_fn=<ExpandBackward>)

It matches the value of $\mu$ in guide. Next, estimate the value of $z$ for each data point, using pyro.infer.infer_discrete.

model_map = infer_discrete(model_map, first_available_dim=-2)

first_available_dim=-2 is a setting that avoids a conflict with the dimension used by data_plate. With this, the estimate of $z$ is set in model and can be obtained from a trace.

trace_model_map = poutine.trace(model_map).get_trace(data)
z_inferred = trace_model_map.nodes['z']['value']
z_inferred
> tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0])
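
To see where an assignment like this comes from, here is a sketch that computes the posterior probability of each cluster for a single data point, given fixed p, mu and sigma, and then picks the most probable cluster. The function names are hypothetical. As far as I understand, infer_discrete by default (temperature=1) samples z from this posterior, while temperature=0 takes the most probable value; for clusters as well separated as these, the two should almost always agree.

```python
import math


def responsibilities(x, p, mu, sigma):
    """Posterior p(z = k | x) for each cluster k, computed in log space."""
    # The constant -0.5 * log(2 * pi) is shared by all clusters and cancels.
    log_w = [math.log(p[k]) - math.log(sigma[k])
             - 0.5 * ((x - mu[k]) / sigma[k]) ** 2
             for k in range(len(p))]
    m = max(log_w)
    w = [math.exp(v - m) for v in log_w]
    s = sum(w)
    return [v / s for v in w]


def most_likely_cluster(x, p, mu, sigma):
    """Index of the cluster with the highest posterior probability."""
    r = responsibilities(x, p, mu, sigma)
    return max(range(len(r)), key=r.__getitem__)


# Values taken from the printed estimates of p, mu and sigma.
p_hat = [0.6668, 0.3332]
mu_hat = [4.9049, 1.4618]
sigma_hat = [0.8197, 0.1783]
print(most_likely_cluster(1.4, p_hat, mu_hat, sigma_hat))
print(most_likely_cluster(5.0, p_hat, mu_hat, sigma_hat))
```

A short petal (1.4) is assigned to cluster 1 and a long petal (5.0) to cluster 0, which is consistent with the tensor above.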

Let's plot the data for each value of $z$.

# Convert the tensor to a NumPy array before storing it in the DataFrame
df['z'] = trace_model_map.nodes['z']['value'].numpy()
df['z'] = df['z'].apply(lambda z: f'z={z}')
sns.swarmplot(data=df, x='petal_length', y='z')

image.png

You can see that the clusters have been estimated well.

In conclusion

I built a Gaussian mixture model with Pyro and fitted it. I'm used to object-oriented thinking, so I found it a bit tedious to use poutine.trace to retrieve the estimated values. In practice, it seems better to create a class like GaussianMixtureModel and hide the process of extracting the values inside it. I will keep working with Pyro to deepen my understanding.

[^ 1]: The iris data is originally a 3-class classification dataset, but we do not consider the original classes here.
[^ 2]: In Pyro's example, LogNormal is used as the distribution of $\sigma_k$, but this time Inverse Gamma is used instead, which is commonly used as a conjugate prior for the scale of a Gaussian distribution (strictly speaking, for its variance).
