[PYTHON] Nonparametric Bayes with Pyro

Nonparametric Bayes is often described as Bayesian modeling with infinite-dimensional models. The Gaussian process covered in the previous article is also a nonparametric Bayesian model; another well-known one is the Dirichlet process. For the theory, refer to "Nonparametric Bayes (Machine Learning Professional Series)" and this Qiita article.

To begin with, Bayesian inference proceeds in three steps:
(1) Define the prior distribution of the parameters and the probability model of the observed events.
(2) Compute the likelihood from the observed data and the probability model, then multiply it by the prior to obtain the posterior distribution of the parameters (Bayes' theorem).
(3) Compute the predictive distribution of new data from the posterior and the probability model.

When the prior and the posterior in step (2) belong to the same family of distributions, the prior is called a conjugate prior. Just as the beta distribution is the conjugate prior of the binomial model, the Dirichlet distribution is the conjugate prior of the multinomial model. In other words, the Dirichlet distribution is the multivariate generalization of the beta distribution.
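As a small worked example of the Dirichlet-multinomial conjugacy: with a Dirichlet(alpha_1, ..., alpha_K) prior on the category probabilities and observed counts n_1, ..., n_K, the posterior is Dirichlet(alpha_1 + n_1, ..., alpha_K + n_K), and the predictive probability of each category (step (3)) is the posterior mean. The sketch below is plain Python rather than Pyro, and `dirichlet_posterior` and `dirichlet_mean` are hypothetical helpers written for illustration.

```python
from collections import Counter


def dirichlet_posterior(prior_alpha, observations):
    """Conjugate update: Dirichlet(alpha) prior + categorical data -> Dirichlet(alpha + counts).

    prior_alpha: list of positive concentration values, one per category.
    observations: iterable of category indices in range(len(prior_alpha)).
    """
    counts = Counter(observations)
    return [a + counts.get(k, 0) for k, a in enumerate(prior_alpha)]


def dirichlet_mean(alpha):
    """Mean of a Dirichlet distribution: alpha_k / sum(alpha).

    For a Dirichlet posterior this is also the predictive probability
    of each category for the next observation.
    """
    total = sum(alpha)
    return [a / total for a in alpha]


# Uniform prior over 3 categories, then 10 observed category indices.
prior = [1.0, 1.0, 1.0]
obs = [0, 0, 1, 0, 2, 0, 1, 0, 0, 1]
posterior = dirichlet_posterior(prior, obs)
print(posterior)                  # [7.0, 4.0, 2.0]
print(dirichlet_mean(posterior))  # predictive probabilities 7/13, 4/13, 2/13
```

The prior and posterior are both Dirichlet, which is exactly the conjugacy described above; the update is just adding counts.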

The Dirichlet process extends the Dirichlet distribution to infinite dimensions, much as the Gaussian process can be viewed as an infinite-dimensional multivariate Gaussian distribution. However, whereas the Gaussian process is continuous, the Dirichlet process is discrete. Because its draws are discrete, it can serve as a probabilistic model of clusters, which enables clustering that determines the number of clusters automatically: the Dirichlet process mixture model (DPMM). The catch is that the Dirichlet process has infinitely many elements, which makes it hard to implement directly. There are two ways around this:
(1) Approximate the process with a finite number of dimensions instead of infinitely many, using the truncated stick-breaking process (TSB) or a finite symmetric Dirichlet distribution (FSD). This approach can be solved with variational Bayes.
(2) Integrate out (marginalize) the infinitely many elements, using the Chinese restaurant process (CRP) or the Pitman-Yor process. This approach is handled with MCMC rather than variational Bayes.
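To make approach (2) a little more concrete, here is a minimal plain-Python sketch of drawing cluster assignments from a CRP. It is not used in the Pyro implementation that follows, and `sample_crp` is a hypothetical helper. Customer i (with i customers already seated) joins an existing table k with probability n_k / (i + alpha) or opens a new table with probability alpha / (i + alpha), so the number of clusters is not fixed in advance; it grows with the data, and alpha controls how quickly new clusters appear.

```python
import random


def sample_crp(n, alpha, seed=None):
    """Draw cluster assignments for n customers from a Chinese restaurant process.

    Returns (assignments, table_sizes), where assignments[i] is the table
    index of customer i and table_sizes[k] is the occupancy of table k.
    Requires alpha > 0.
    """
    rng = random.Random(seed)
    assignments = []
    table_sizes = []  # table_sizes[k] = number of customers at table k
    for _ in range(n):
        # Relative weights: existing tables by size, plus alpha for a new table.
        # random.choices normalizes them, giving n_k / (i + alpha) and alpha / (i + alpha).
        weights = table_sizes + [alpha]
        k = rng.choices(range(len(weights)), weights=weights)[0]
        if k == len(table_sizes):
            table_sizes.append(1)  # open a new table (new cluster)
        else:
            table_sizes[k] += 1    # join an existing table
        assignments.append(k)
    return assignments, table_sizes


z, sizes = sample_crp(800, alpha=1.0, seed=0)
print(len(sizes), sizes)  # number of clusters and their sizes for this draw
```

Here the infinitely many cluster parameters never appear explicitly; only the clusters that actually receive data exist, which is what "integrating out" buys.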

This time, we implement approach (1) with Pyro.

Dirichlet Process Mixture Model (TSB) with Pyro

For example, clustering with the k-means method or a Gaussian mixture model requires the number of clusters to be specified as a parameter (and tuned by computing AIC or a similar criterion). With a Dirichlet process mixture model, on the other hand, the number of clusters is determined automatically, which makes the clustering more flexible.

Unlike the Gaussian process, Pyro has no dedicated class for the Dirichlet process, so we write the model ourselves. This time, we implement TSB (the truncated stick-breaking process) following the example in the official tutorial.

import torch
import pyro
import pyro.distributions as dist
import matplotlib.pyplot as plt
pyro.set_rng_seed(101)

First, create the data: sample 200 points from each of four 2D Gaussian distributions.

num = 200
data = torch.cat((dist.MultivariateNormal(-8 * torch.ones(2), torch.eye(2)).sample([num]),
                  dist.MultivariateNormal(8 * torch.ones(2), torch.eye(2)).sample([num]),
                  dist.MultivariateNormal(torch.tensor([-5., 5.]), torch.eye(2)).sample([num]),
                  dist.MultivariateNormal(torch.tensor([6., -5.]), torch.eye(2)).sample([num])
                 ))
plt.scatter(data[:, 0], data[:, 1])
plt.show()

[Figure: scatter plot of the 800 generated points]

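We will cluster these 800 points without specifying the number of clusters.

Define a function for the stick-breaking step. It converts the T-1 stick proportions beta into T mixture weights: each weight is the fraction beta_k broken off the stick that remains after the previous breaks, and the last weight is whatever stick is left.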

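import torch.nn.functional as F
def mix_weights(beta):
    # beta: stick proportions, shape (..., T-1)
    # cumulative product of (1 - beta_j): stick length left after each break
    beta1m_cumprod = (1 - beta).cumprod(-1)
    # w_k = beta_k * prod_{j<k} (1 - beta_j); the last weight takes the remaining stick
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)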
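Written out, the weights are w_k = beta_k * prod_{j<k} (1 - beta_j) for k = 1, ..., T-1 and w_T = prod_{j=1}^{T-1} (1 - beta_j), so they always sum to 1. The plain-Python sketch below (a hypothetical helper `stick_breaking`, not part of Pyro) computes the same quantity without tensors, which makes the arithmetic easy to check by hand.

```python
def stick_breaking(betas):
    """Turn T-1 stick proportions into T mixture weights that sum to 1.

    w_k = beta_k * prod_{j<k} (1 - beta_j) for the first T-1 weights;
    the last weight is the stick left over after all breaks.
    """
    weights = []
    remaining = 1.0  # length of stick not yet broken off
    for b in betas:
        weights.append(b * remaining)  # break off a fraction b of what is left
        remaining *= 1.0 - b
    weights.append(remaining)  # final piece closes the stick
    return weights


print(stick_breaking([0.5, 0.5, 0.5]))  # [0.5, 0.25, 0.125, 0.125]
```

`mix_weights(torch.tensor([0.5, 0.5, 0.5]))` should return the same four values as a tensor.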

Next, define the probabilistic model for TSB. alpha is the concentration parameter, which controls how the probability mass is spread across clusters (smaller values concentrate it on fewer clusters). T is the truncation level, that is, the upper limit on the number of clusters; here it is set to 10.

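N = data.shape[0]  # number of data points (800)
T = 10             # truncation level: upper limit on the number of clusters
alpha = 0.1        # concentration parameter
def model(data):
    # stick proportions beta_k ~ Beta(1, alpha), k = 1..T-1
    with pyro.plate("beta_plate", T-1):
        beta = pyro.sample("beta", dist.Beta(1, alpha))
    # cluster centers mu_k ~ N(0, 5I), k = 1..T
    with pyro.plate("mu_plate", T):
        mu = pyro.sample("mu", dist.MultivariateNormal(torch.zeros(2), 5 * torch.eye(2)))
    # each point picks a cluster z from the stick-breaking weights, then x ~ N(mu_z, I)
    with pyro.plate("data", N):
        z = pyro.sample("z", dist.Categorical(mix_weights(beta)))
        pyro.sample("obs", dist.MultivariateNormal(mu[z], torch.eye(2)), obs=data)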
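To see what alpha does in this model, the plain-Python sketch below draws one set of prior weights the same way the model does: beta_k ~ Beta(1, alpha), followed by stick breaking. Each beta_k has mean 1/(1 + alpha), so with alpha = 0.1 the first few sticks tend to take almost all of the mass, while a large alpha spreads it over many components. `sample_prior_weights` is a hypothetical helper, and the 1% cutoff in the loop is an arbitrary choice for illustration.

```python
import random


def sample_prior_weights(alpha, T, seed=None):
    """Draw one set of truncated stick-breaking weights from the prior.

    Mirrors the model above: beta_k ~ Beta(1, alpha) for k = 1..T-1,
    then stick breaking turns them into T weights summing to 1.
    """
    rng = random.Random(seed)
    betas = [rng.betavariate(1.0, alpha) for _ in range(T - 1)]
    weights = []
    remaining = 1.0
    for b in betas:
        weights.append(b * remaining)
        remaining *= 1.0 - b
    weights.append(remaining)
    return weights


for a in (0.1, 1.0, 10.0):
    w = sample_prior_weights(a, T=10, seed=0)
    # number of components receiving at least 1% of the mass in this draw
    print(a, sum(x > 0.01 for x in w))
```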

Pyro's automatic guides (AutoGuide) do not handle discrete latent variables such as z, so we define our own guide function.

from torch.distributions import constraints
def guide(data):
    # variational parameters: kappa for q(beta), tau for q(mu), phi for q(z)
    kappa = pyro.param('kappa', lambda: dist.Uniform(0, 2).sample([T-1]), constraint=constraints.positive)
    tau = pyro.param('tau', lambda: dist.MultivariateNormal(torch.zeros(2), 3 * torch.eye(2)).sample([T]))
    phi = pyro.param('phi', lambda: dist.Dirichlet(1/T * torch.ones(T)).sample([N]), constraint=constraints.simplex)
    # q(beta_k) = Beta(1, kappa_k)
    with pyro.plate("beta_plate", T-1):
        q_beta = pyro.sample("beta", dist.Beta(torch.ones(T-1), kappa))
    # q(mu_k) = N(tau_k, I)
    with pyro.plate("mu_plate", T):
        q_mu = pyro.sample("mu", dist.MultivariateNormal(tau, torch.eye(2)))
    # q(z_n) = Categorical(phi_n): per-point cluster responsibilities
    with pyro.plate("data", N):
        z = pyro.sample("z", dist.Categorical(phi))
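This guide is mean-field: beta, mu, and z are independent under q. One consequence is that the expected mixture weights under q can be read directly from kappa, because each beta_k ~ Beta(1, kappa_k) has mean 1/(1 + kappa_k) and independence lets the expectation pass through the stick-breaking product. The plain-Python sketch below (hypothetical helper `expected_weights`) computes this. It is offered only as an alternative way to inspect the fitted guide; this article itself estimates the weights by averaging phi over the data. After training, something like `expected_weights(pyro.param("kappa").detach().tolist())` could be passed in.

```python
def expected_weights(kappa):
    """E_q[w] when q(beta_k) = Beta(1, kappa_k), independently for each k.

    E[beta_k] = 1 / (1 + kappa_k), and by independence
    E[w_k] = E[beta_k] * prod_{j<k} E[1 - beta_j].
    """
    weights = []
    remaining = 1.0  # prod_{j<k} E[1 - beta_j]
    for k in kappa:
        mean_beta = 1.0 / (1.0 + k)
        weights.append(mean_beta * remaining)
        remaining *= 1.0 - mean_beta
    weights.append(remaining)  # expected weight of the last component
    return weights


print(expected_weights([1.0, 1.0, 1.0]))  # [0.5, 0.25, 0.125, 0.125]
```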

Run variational Bayes (stochastic variational inference with Pyro's SVI and the Trace_ELBO loss).

from pyro.optim import Adam
from pyro.infer import SVI, Trace_ELBO
from tqdm import tqdm
optim = Adam({"lr": 0.05})
svi = SVI(model, guide, optim, loss=Trace_ELBO())
losses = []
num_step = 1000
pyro.clear_param_store()
for j in tqdm(range(num_step)):
    loss = svi.step(data)
    losses.append(loss)
plt.plot(losses)
plt.xlabel("step")
plt.ylabel("loss")

[Figure: ELBO loss over 1,000 SVI steps]

Define a function that discards (truncates) clusters with low posterior weight.

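def truncate(alpha, centers, weights):
    # heuristic cutoff: 1 / (100 * alpha), i.e. 0.1 when alpha = 0.1
    threshold = alpha**-1 / 100.
    print(threshold)
    # keep only clusters whose weight exceeds the cutoff, then renormalize the kept weights
    true_centers = centers[weights > threshold]
    true_weights = weights[weights > threshold] / torch.sum(weights[weights > threshold])
    return true_centers, true_weights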
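The truncation arithmetic is easy to check by hand. The plain-Python sketch below (hypothetical helper `truncate_weights`, operating on a list instead of tensors) applies the same cutoff to some made-up weights and renormalizes what survives.

```python
def truncate_weights(alpha, weights):
    """Plain-Python version of the truncation step.

    Keeps the indices whose weight exceeds 1 / (100 * alpha) and
    renormalizes the surviving weights so they sum to 1.
    """
    threshold = alpha ** -1 / 100.0
    kept = [i for i, w in enumerate(weights) if w > threshold]
    total = sum(weights[i] for i in kept)
    return kept, [weights[i] / total for i in kept]


# Illustrative weights only; with alpha = 0.1 the cutoff is 0.1.
kept, new_w = truncate_weights(0.1, [0.26, 0.24, 0.22, 0.2, 0.05, 0.03])
print(kept)  # [0, 1, 2, 3]
```

Note that with alpha = 0.1 the cutoff is 0.1, so a genuine cluster holding less than roughly 10% of the data would likely be discarded as well; this is one way the result can depend on alpha.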

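Remove the low-weight clusters and plot the result. The weight of each cluster is the average of phi over all data points; the surviving centers tau are drawn in red over the data in blue.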

Bayes_Centers, Bayes_Weights = truncate(alpha, pyro.param("tau").detach(), torch.mean(pyro.param("phi").detach(), dim=0))
plt.scatter(data[:, 0], data[:, 1], color="blue")
plt.scatter(Bayes_Centers[:, 0], Bayes_Centers[:, 1], color="red")

[Figure: data (blue) with the estimated cluster centers (red)]

Four clusters were recognized, and the centers appear to be captured fairly accurately. This run worked well, but with some other values of alpha and T the clustering did not succeed.

Other applications

Structural change estimation

When fitting a linear regression model to time series data, the Dirichlet process lets us consider an unbounded number of linear models. By clustering the time points according to which model they belong to, the time series can be segmented. Reference: https://www.slideshare.net/shotarosano5/in-62843951

HDP-LDA

LDA (Latent Dirichlet Allocation) is an unsupervised learning method used for topic models (LDA Pyro implementation example). HDP-LDA, which uses the HDP (hierarchical Dirichlet process), can determine an appropriate number of topics automatically. It appears to be easy to do with the Python library Gensim. The hierarchical Dirichlet process is modeled by the CRF (Chinese restaurant franchise), which is one step more complicated than the CRP (Chinese restaurant process).

Conclusion

My takeaway is that nonparametric Bayes does not mean a model without parameters; rather, it considers infinitely many parameters (there is no fixed limit on their number). Note, however, that this TSB implementation still requires an upper limit on the number of clusters. Next, I would like to try implementing the CRP, which needs no upper limit.
