[PYTHON] [PyStan] Try Graphical Lasso with Stan.

Hello, this is @kenmatsu4. It has been a long time since my last blog post. This article is for day 23 of the Stan Advent Calendar 2016.

In this post, I will use Stan to try Graphical Lasso, a method that estimates the precision matrix (the inverse of the variance-covariance matrix) with L1 regularization. The full code has been uploaded to GitHub.

1. Generate test data

First, generate random numbers that follow a multivariate normal distribution. This time, we generate 300 samples of 6-dimensional data with the mean and covariance below. We then force $x_4$ and $x_6$, and $x_5$ and $x_6$, to be correlated, so that $x_4$ and $x_5$ acquire an indirect correlation. Originally $x_4$ and $x_5$ are uncorrelated, but under the influence of $x_6$ both move together with the fluctuations of $x_6$, so two variables that are not directly related to each other appear to be correlated. ← (*)

python


import numpy as np
import scipy.stats as st

M = 6               # Number of variables
size = 300          # Number of samples
random_state = 0    # Random seed (assumed value; the original seed is not shown)

m = np.zeros(M)
# Covariance matrix for generating multivariate normal random numbers
cov = [[   1, .29, .49, .10, .30,  0],
       [ .29,   1,-.49, .49, .40,  0],
       [ .49,-.49,   1,   0,   0,  0],
       [ .10, .49,   0,   1,   0,  0],
       [ .30, .40,   0,   0,   1,  0],
       [   0,   0,   0,   0,   0,  1]]

# Create 6 variables
X = st.multivariate_normal.rvs(mean=m, cov=cov, size=size, random_state=random_state)

# Correlate x4 and x6 (*)
X[:, 3] += 0.6 * X[:, 5]

# Correlate x5 and x6 (*)
X[:, 4] += 0.6 * X[:, 5]
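The two forced correlations above form a linear transformation, so the population covariance of the generated data can be computed exactly. The following is a minimal NumPy sketch. The mixing matrix B is a name introduced here for illustration. The sketch shows where the indirect correlation between $x_4$ and $x_5$ comes from. A sample of 300 points will deviate somewhat from these population values.

```python
import numpy as np

# Covariance used to draw the original six variables (same values as above)
cov = np.array([[   1, .29, .49, .10, .30,  0],
                [ .29,   1,-.49, .49, .40,  0],
                [ .49,-.49,   1,   0,   0,  0],
                [ .10, .49,   0,   1,   0,  0],
                [ .30, .40,   0,   0,   1,  0],
                [   0,   0,   0,   0,   0,  1]])

# "x4 += 0.6*x6" and "x5 += 0.6*x6" written as one linear map x' = B x
# (B is a hypothetical name used only for this illustration)
B = np.eye(6)
B[3, 5] = 0.6
B[4, 5] = 0.6

# Population covariance and correlation of the transformed variables
cov_true = B @ cov @ B.T
sd = np.sqrt(np.diag(cov_true))
corr_true = cov_true / np.outer(sd, sd)

# x4 and x5 started uncorrelated; now they share 0.6 * 0.6 * Var(x6)
print(round(cov_true[3, 4], 4), round(corr_true[3, 4], 4))  # 0.36 0.2647
```

In the same way, Var($x_4$) becomes 1 + 0.36 = 1.36 and Cov($x_4$, $x_6$) becomes 0.6.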

Figure: Image of indirect correlation

The pairplot (scatter-plot matrix) of the generated data is shown below.

Figure: Pairplot of the generated data

Next, we compute the correlation matrix and the partial correlation matrix. Partial correlation indicates the degree of direct correlation between two variables, excluding indirect correlation of the kind shown in the previous figure. Let's look at the partial correlations right away.

The partial correlation matrix can be expressed using the elements $\lambda_{ij}$ of the precision matrix $\Lambda$ (the inverse of the variance-covariance matrix):

\hat{\rho}_{ij} = {-\lambda_{ij} \over \sqrt{\lambda_{ii}\lambda_{jj}}}
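The formula can be applied to a whole matrix at once. Here is a minimal NumPy sketch. The function name is hypothetical, and the random data stand in for the real X. Taken literally, the formula gives -1 on the diagonal, so this sketch sets the diagonal to 1 by convention.

```python
import numpy as np

def partial_corr_from_precision(prec):
    """rho_ij = -lambda_ij / sqrt(lambda_ii * lambda_jj), diagonal set to 1."""
    prec = np.asarray(prec, dtype=float)
    d = np.sqrt(np.diag(prec))
    rho = -prec / np.outer(d, d)
    # The formula yields -1 on the diagonal; report 1 there by convention
    np.fill_diagonal(rho, 1.0)
    return rho

# Usage: the "naive" estimate from the inverse of the sample covariance
rng = np.random.default_rng(0)
X_demo = rng.standard_normal((300, 6))      # placeholder data
S = np.cov(X_demo, rowvar=False)
rho_naive = partial_corr_from_precision(np.linalg.inv(S))
print(np.round(rho_naive, 2))
```

For example, a precision matrix [[2, -1], [-1, 2]] gives a partial correlation of 1/2 = 0.5.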

partial_corr_est.png

The right panel is the partial correlation matrix. When it is computed from the variance-covariance matrix estimated directly from the data, every entry is affected by noise and none is clearly zero, so the structure of the data is hard to see. It looks as if all the variables are related to one another, which should not be the case.

Therefore, I would like to remove the influence of noise with L1 regularization (that is, the Lasso) and work with the resulting estimates of the variance-covariance matrix and the precision matrix.

2. Graphical Lasso

A Markov network that assumes a multivariate normal distribution is called a Gaussian graphical model. The relationships between variables can be read as a graph through the precision matrix $\Lambda$, which is a parameter of this distribution. If the $(i, j)$ element $\lambda_{ij}$ of the precision matrix is non-zero, $x_i$ and $x_j$ are directly correlated (conditionally dependent given all the other variables). If it is zero, they are conditionally independent. This structure is represented by a graph such as the one below, called a graphical model.
Figure: Graphical model
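Reading the graph off a precision matrix amounts to listing its non-zero off-diagonal entries. Below is a minimal NumPy sketch. The function name and the tolerance tol are assumptions for illustration. The chain example also shows that a zero in $\Lambda$ does not imply a zero in the covariance $\Lambda^{-1}$, which is exactly the indirect correlation discussed above.

```python
import numpy as np

def edges_from_precision(prec, tol=1e-8):
    """Pairs (i, j), i < j, whose precision entry is non-zero.

    tol is a hypothetical cut-off for treating tiny values as zero.
    """
    prec = np.asarray(prec, dtype=float)
    p = prec.shape[0]
    return [(i, j) for i in range(p) for j in range(i + 1, p)
            if abs(prec[i, j]) > tol]

# Chain x1 - x2 - x3: x1 and x3 are linked only through x2
prec_chain = np.array([[ 2.0, -1.0,  0.0],
                       [-1.0,  2.0, -1.0],
                       [ 0.0, -1.0,  2.0]])
print(edges_from_precision(prec_chain))   # [(0, 1), (1, 2)]
# ...yet x1 and x3 are still correlated in the covariance matrix
print(np.linalg.inv(prec_chain)[0, 2])    # about 0.25
```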

If the data $\boldsymbol{x}$ follow a multivariate normal distribution with mean ${\bf 0}$, the density can be written in terms of the precision matrix $\Lambda$ as

\mathcal{N}(\boldsymbol{x} | {\bf 0}, \Lambda^{-1}) = {|\Lambda|^{1/2} \over (2\pi)^{M/2}} \exp\left( -{1 \over 2} \boldsymbol{x}^\intercal \Lambda \boldsymbol{x} \right)

However, suppose every non-zero entry is taken to mean a direct correlation. Because noise makes nearly every estimated entry non-zero, the graph above would then show direct correlations between almost all pairs of variables. We therefore need a way to obtain a precision matrix that is as sparse as possible. A sparse solution can be obtained by placing a Laplace prior $p(\Lambda)$ on the precision matrix $\Lambda$. The posterior distribution is then

p(\Lambda|\boldsymbol{x}^{(1)},\dots,\boldsymbol{x}^{(N)}) \propto p(\Lambda)\prod_{n=1}^{N}\mathcal{N}(\boldsymbol{x}^{(n)}|{\bf 0}, \Lambda^{-1})
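The log of the product over $n$ (the log-likelihood) can be written compactly with $\mathrm{tr}(S\Lambda)$, using $\sum_n \boldsymbol{x}^{(n)\intercal}\Lambda\boldsymbol{x}^{(n)} = N\,\mathrm{tr}(S\Lambda)$. The following sketch checks this numerically against SciPy's multivariate_normal.logpdf, which takes the covariance $\Lambda^{-1}$ rather than $\Lambda$. The function name, the example $\Lambda$, and the random data are all assumptions for illustration.

```python
import numpy as np
from scipy.stats import multivariate_normal

def log_likelihood_prec(X, Lam):
    """sum_n ln N(x^(n) | 0, Lam^{-1}), written with the precision matrix Lam."""
    N, M = X.shape
    _, logdet = np.linalg.slogdet(Lam)
    S = X.T @ X / N  # scatter matrix of zero-mean data
    # sum_n x^T Lam x = N * tr(S Lam)
    return 0.5 * N * (logdet - M * np.log(2 * np.pi) - np.trace(S @ Lam))

rng = np.random.default_rng(0)
X_demo = rng.standard_normal((50, 3))       # placeholder data
Lam = np.array([[ 2.0, -0.5, 0.0],
                [-0.5,  1.5, 0.3],
                [ 0.0,  0.3, 1.0]])         # an arbitrary positive-definite example
direct = multivariate_normal.logpdf(X_demo, mean=np.zeros(3),
                                    cov=np.linalg.inv(Lam)).sum()
print(np.isclose(log_likelihood_prec(X_demo, Lam), direct))  # True
```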

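Taking the logarithm of the posterior above, dropping constants, and rescaling (dividing by $N/2$ and absorbing the prior's scale into $\alpha$), the MAP estimate of $\Lambda$ is the maximizer of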

\ln |\Lambda| - \mathrm{tr}(S\Lambda)-\alpha\|\Lambda\|_1

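where $S = {1 \over N}\sum_{n=1}^{N}\boldsymbol{x}^{(n)}\boldsymbol{x}^{(n)\intercal}$ is the sample covariance matrix (the data are assumed to have mean ${\bf 0}$). This is exactly the Graphical Lasso objective, so placing a Laplace prior on $\Lambda$ is equivalent to applying L1 regularization.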
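The step from the posterior to this objective can be written out as follows. This assumes an element-wise Laplace prior $p(\Lambda) \propto \exp(-\|\Lambda\|_1 / b)$ with scale $b$ (a symbol introduced here) and uses $\boldsymbol{x}^\intercal\Lambda\boldsymbol{x} = \mathrm{tr}(\Lambda\boldsymbol{x}\boldsymbol{x}^\intercal)$:

```latex
\begin{aligned}
\ln p(\Lambda \mid \boldsymbol{x}^{(1)},\dots,\boldsymbol{x}^{(N)})
  &= \sum_{n=1}^{N} \ln \mathcal{N}(\boldsymbol{x}^{(n)} \mid {\bf 0}, \Lambda^{-1})
     + \ln p(\Lambda) + \text{const} \\
  &= {N \over 2}\ln|\Lambda|
     - {1 \over 2}\sum_{n=1}^{N} \boldsymbol{x}^{(n)\intercal}\Lambda\,\boldsymbol{x}^{(n)}
     - {1 \over b}\|\Lambda\|_1 + \text{const} \\
  &= {N \over 2}\left( \ln|\Lambda| - \mathrm{tr}(S\Lambda)
     - {2 \over N b}\|\Lambda\|_1 \right) + \text{const}
\end{aligned}
```

Matching the last line with the objective gives $\alpha = 2/(Nb)$, or equivalently $b = 2/(N\alpha)$. If this derivation holds, a Laplace prior with scale $1/\alpha$ corresponds to a penalty of $2\alpha/N$ in the objective rather than $\alpha$. That is worth keeping in mind when comparing a Laplace-prior model with Scikit-Learn at the same $\alpha$. The equivalence also concerns the posterior mode (MAP), not the posterior mean.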

3. Find the solution of Graphical Lasso using Scikit-Learn

Scikit-Learn provides GraphLasso (renamed GraphicalLasso in scikit-learn 0.20), which implements Graphical Lasso. By default it solves the problem with an optimization method called coordinate descent. Let's try this first.

It's very easy to use: just call fit as usual, and you get the covariance and precision matrices.

python


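try:
    from sklearn.covariance import GraphicalLasso  # scikit-learn >= 0.20
except ImportError:
    from sklearn.covariance import GraphLasso as GraphicalLasso  # older releases

alpha = 0.2  # L1 regularization parameter
model = GraphicalLasso(alpha=alpha,
                       max_iter=100,
                       verbose=True,
                       assume_centered=True)

model.fit(X)
cov_ = model.covariance_   # Covariance matrix
prec_ = model.precision_   # Precision matrix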
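It can help to evaluate the objective $\ln|\Lambda| - \mathrm{tr}(S\Lambda) - \alpha\|\Lambda\|_1$ directly. That puts the Scikit-Learn estimate and other estimates (for example, the Stan result later in this post) on the same footing. Below is a NumPy sketch. The function name and the penalize_diagonal switch are hypothetical. With assume_centered=True, $S$ would presumably be X.T @ X / len(X).

```python
import numpy as np

def glasso_objective(Lam, S, alpha, penalize_diagonal=True):
    """ln|Lam| - tr(S Lam) - alpha * ||Lam||_1 (larger is better).

    penalize_diagonal is a hypothetical switch: some formulations
    penalize only the off-diagonal entries.
    """
    Lam = np.asarray(Lam, dtype=float)
    try:
        L = np.linalg.cholesky(Lam)  # fails if Lam is not positive definite
    except np.linalg.LinAlgError:
        return -np.inf
    logdet = 2.0 * np.sum(np.log(np.diag(L)))
    penalty = np.abs(Lam).sum()
    if not penalize_diagonal:
        penalty -= np.abs(np.diag(Lam)).sum()
    return logdet - np.trace(S @ Lam) - alpha * penalty

# Usage with a tiny example: ln|I| = 0, tr(S I) = 2, penalty = 0.2 * 2
S_demo = np.array([[1.0, 0.3], [0.3, 1.0]])
print(glasso_objective(np.eye(2), S_demo, alpha=0.2))  # about -2.4
```

For example, glasso_objective(prec_, X.T @ X / len(X), alpha) would score the estimate above.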

The resulting variance-covariance matrix and precision matrix are shown below.

glasso_cov_prec.png

The correlation matrix and the partial correlation matrix computed from them are shown below. Among the correlations we forced among $x_4$, $x_5$, and $x_6$, you can see that the partial correlation between $x_4$ and $x_5$ is 0.

corr_pcorr_sklearn.png

Being able to obtain the partial correlation matrix from a sparse precision matrix is nice, but this post is part of the Stan Advent Calendar 2016. As I wrote earlier, this L1 regularization can be interpreted as placing a Laplace prior on $\Lambda$. So let's try it with Stan.

4. Find the solution of Graphical Lasso using Stan

So, let's write Stan code that does the same thing as Graphical Lasso. In Stan, the Laplace distribution is called the double exponential distribution (double_exponential), so we use that. The Stan code is as follows.

glasso.stan


data {
  int N;           // Sample size
  int P;           // Feature size
  matrix[N, P] X;  // Data
  real alpha;      // L1 parameter (the Laplace scale is 1/alpha)
}
parameters {
  corr_matrix[P] Lambda; // Precision matrix (corr_matrix fixes its diagonal to 1)
}
model {
  vector[P] zeros = rep_vector(0, P);

  // Precision matrix follows a Laplace (double exponential) prior
  to_vector(Lambda) ~ double_exponential(0, 1/alpha);

  for (j in 1:N) {
    // Each row of X follows a zero-mean multivariate normal with precision Lambda
    X[j]' ~ multi_normal_prec(zeros, Lambda);
  }
}
generated quantities {
  matrix[P, P] Sigma = inverse(Lambda);  // Covariance matrix
}
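In the Stan code above, double_exponential(0, 1/alpha) is the Laplace distribution with location 0 and scale $1/\alpha$. Its log density is $\ln(\alpha/2) - \alpha|\lambda|$, which is an L1 penalty up to a constant. Here is a quick check, assuming SciPy, whose scipy.stats.laplace uses the same location/scale parameterization.

```python
import numpy as np
from scipy.stats import laplace

alpha = 0.2           # L1 parameter, as in the Scikit-Learn example
scale = 1.0 / alpha   # scale passed to double_exponential(0, 1/alpha)

y = np.array([-1.0, 0.0, 0.5, 3.0])
# Laplace log density: -log(2 * scale) - |y - 0| / scale = log(alpha / 2) - alpha * |y|
manual = np.log(alpha / 2) - alpha * np.abs(y)
print(np.allclose(laplace.logpdf(y, loc=0.0, scale=scale), manual))  # True
```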

Here is the Python code that calls this Stan model.

python


import pystan  # PyStan 2.x API

%time sm = pystan.StanModel(file='glasso.stan')
print('Compile finished.')

n_sample = 1000  # Number of post-warm-up draws per chain
n_warm   = 1000  # Number of warm-up iterations per chain
n_chain  = 4     # Number of chains
stan_data = {'N': X.shape[0], 'P': X.shape[1], 'X': X, 'alpha': alpha}
%time fit = sm.sampling(data=stan_data, chains=n_chain, iter=n_sample+n_warm, warmup=n_warm)
print('Sampling finished.')

print(fit)

Here is the result. The Rhat of the diagonal elements of Lambda is nan. Because Lambda is declared as corr_matrix, its diagonal is fixed at exactly 1 and does not vary across draws, so Rhat cannot be computed. All other elements have Rhat = 1.0.

out


Inference for Stan model: anon_model_31ac7e216f1b5eccff16f1394bd9827e.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

              mean se_mean     sd   2.5%    25%    50%    75%  97.5%  n_eff   Rhat
Lambda[0,0]    1.0     0.0    0.0    1.0    1.0    1.0    1.0    1.0   4000    nan
Lambda[1,0]  -0.31  8.7e-4   0.05   -0.4  -0.34  -0.31  -0.28  -0.21   2891    1.0
Lambda[2,0]  -0.42  7.1e-4   0.04   -0.5  -0.45  -0.42  -0.39  -0.33   3735    1.0
Lambda[3,0]   0.04  1.0e-3   0.05  -0.07 2.1e-3   0.04   0.07   0.14   2808    1.0
Lambda[4,0]  -0.12  9.1e-4   0.05  -0.22  -0.15  -0.11  -0.08-9.5e-3   3437    1.0
Lambda[5,0]   0.02  1.0e-3   0.06  -0.09  -0.01   0.02   0.06   0.13   3014    1.0
Lambda[0,1]  -0.31  8.7e-4   0.05   -0.4  -0.34  -0.31  -0.28  -0.21   2891    1.0
Lambda[1,1]    1.0 1.5e-189.0e-17    1.0    1.0    1.0    1.0    1.0   3633    nan
Lambda[2,1]   0.47  6.3e-4   0.04   0.39   0.44   0.47    0.5   0.55   4000    1.0
Lambda[3,1]  -0.31  7.6e-4   0.05   -0.4  -0.34  -0.31  -0.28  -0.22   3810    1.0
Lambda[4,1]  -0.19  9.4e-4   0.05  -0.29  -0.22  -0.19  -0.15  -0.08   3021    1.0
Lambda[5,1]    0.2  8.8e-4   0.05    0.1   0.16    0.2   0.23    0.3   3395    1.0
Lambda[0,2]  -0.42  7.1e-4   0.04   -0.5  -0.45  -0.42  -0.39  -0.33   3735    1.0
Lambda[1,2]   0.47  6.3e-4   0.04   0.39   0.44   0.47    0.5   0.55   4000    1.0
Lambda[2,2]    1.0 3.6e-188.7e-17    1.0    1.0    1.0    1.0    1.0    594    nan
Lambda[3,2]  -0.11  8.9e-4   0.05  -0.22  -0.15  -0.11  -0.08  -0.01   3623    1.0
Lambda[4,2]  -0.04  9.1e-4   0.05  -0.15  -0.08  -0.04-5.8e-3   0.07   3642    1.0
Lambda[5,2]   0.03  9.0e-4   0.05  -0.08-9.2e-3   0.03   0.06   0.13   3495    1.0
Lambda[0,3]   0.04  1.0e-3   0.05  -0.07 2.1e-3   0.04   0.07   0.14   2808    1.0
Lambda[1,3]  -0.31  7.6e-4   0.05   -0.4  -0.34  -0.31  -0.28  -0.22   3810    1.0
Lambda[2,3]  -0.11  8.9e-4   0.05  -0.22  -0.15  -0.11  -0.08  -0.01   3623    1.0
Lambda[3,3]    1.0 2.0e-181.2e-16    1.0    1.0    1.0    1.0    1.0   3553    nan
Lambda[4,3]  -0.02  9.3e-4   0.06  -0.13  -0.06  -0.02   0.02   0.09   3503    1.0
Lambda[5,3]  -0.38  7.5e-4   0.04  -0.47  -0.41  -0.38  -0.35  -0.29   3591    1.0
Lambda[0,4]  -0.12  9.1e-4   0.05  -0.22  -0.15  -0.11  -0.08-9.5e-3   3437    1.0
Lambda[1,4]  -0.19  9.4e-4   0.05  -0.29  -0.22  -0.19  -0.15  -0.08   3021    1.0
Lambda[2,4]  -0.04  9.1e-4   0.05  -0.15  -0.08  -0.04-5.8e-3   0.07   3642    1.0
Lambda[3,4]  -0.02  9.3e-4   0.06  -0.13  -0.06  -0.02   0.02   0.09   3503    1.0
Lambda[4,4]    1.0 2.0e-181.2e-16    1.0    1.0    1.0    1.0    1.0   3633    nan
Lambda[5,4]  -0.36  7.2e-4   0.05  -0.45  -0.39  -0.36  -0.33  -0.27   4000    1.0
Lambda[0,5]   0.02  1.0e-3   0.06  -0.09  -0.01   0.02   0.06   0.13   3014    1.0
Lambda[1,5]    0.2  8.8e-4   0.05    0.1   0.16    0.2   0.23    0.3   3395    1.0
Lambda[2,5]   0.03  9.0e-4   0.05  -0.08-9.2e-3   0.03   0.06   0.13   3495    1.0
Lambda[3,5]  -0.38  7.5e-4   0.04  -0.47  -0.41  -0.38  -0.35  -0.29   3591    1.0
Lambda[4,5]  -0.36  7.2e-4   0.05  -0.45  -0.39  -0.36  -0.33  -0.27   4000    1.0
Lambda[5,5]    1.0 2.2e-181.3e-16    1.0    1.0    1.0    1.0    1.0   3381    nan
Sigma[0,0]    1.31  1.1e-3   0.07   1.19   1.26    1.3   1.35   1.45   3507    1.0
Sigma[1,0]    0.26  1.3e-3   0.08   0.11   0.21   0.27   0.32   0.43   4000    1.0
Sigma[2,0]    0.45  1.3e-3   0.08   0.29   0.39   0.44   0.51   0.62   4000    1.0
Sigma[3,0]     0.1  1.2e-3   0.08  -0.05   0.05    0.1   0.15   0.25   4000    1.0
Sigma[4,0]    0.23  1.2e-3   0.08   0.09   0.18   0.23   0.28   0.38   4000    1.0
Sigma[5,0]    0.03  1.2e-3   0.08  -0.13  -0.02   0.03   0.08   0.18   4000    1.0
Sigma[0,1]    0.26  1.3e-3   0.08   0.11   0.21   0.27   0.32   0.43   4000    1.0
Sigma[1,1]    1.55  1.5e-3   0.09   1.38   1.48   1.54   1.61   1.74   4000    1.0
Sigma[2,1]   -0.56  1.4e-3   0.09  -0.74  -0.62  -0.56  -0.49  -0.39   4000    1.0
Sigma[3,1]    0.41  1.3e-3   0.08   0.24   0.35    0.4   0.46   0.57   4000    1.0
Sigma[4,1]    0.29  1.3e-3   0.08   0.14   0.24   0.29   0.34   0.46   4000    1.0
Sigma[5,1]   -0.04  1.3e-3   0.08   -0.2  -0.09  -0.04   0.02   0.13   4000    1.0
Sigma[0,2]    0.45  1.3e-3   0.08   0.29   0.39   0.44   0.51   0.62   4000    1.0
Sigma[1,2]   -0.56  1.4e-3   0.09  -0.74  -0.62  -0.56  -0.49  -0.39   4000    1.0
Sigma[2,2]    1.47  1.3e-3   0.08   1.32   1.41   1.46   1.52   1.65   4000    1.0
Sigma[3,2]  2.9e-3  1.3e-3   0.08  -0.15  -0.05 1.3e-3   0.06   0.16   4000    1.0
Sigma[4,2]    0.04  1.2e-3   0.08  -0.12  -0.02   0.03   0.09   0.19   4000    1.0
Sigma[5,2]    0.07  1.3e-3   0.08  -0.08   0.02   0.08   0.13   0.23   4000    1.0
Sigma[0,3]     0.1  1.2e-3   0.08  -0.05   0.05    0.1   0.15   0.25   4000    1.0
Sigma[1,3]    0.41  1.3e-3   0.08   0.24   0.35    0.4   0.46   0.57   4000    1.0
Sigma[2,3]  2.9e-3  1.3e-3   0.08  -0.15  -0.05 1.3e-3   0.06   0.16   4000    1.0
Sigma[3,3]    1.36  1.1e-3   0.07   1.23    1.3   1.35    1.4   1.51   4000    1.0
Sigma[4,3]    0.31  1.2e-3   0.08   0.17   0.26   0.31   0.36   0.47   4000    1.0
Sigma[5,3]    0.55  1.4e-3   0.09   0.39   0.49   0.55    0.6   0.73   4000    1.0
Sigma[0,4]    0.23  1.2e-3   0.08   0.09   0.18   0.23   0.28   0.38   4000    1.0
Sigma[1,4]    0.29  1.3e-3   0.08   0.14   0.24   0.29   0.34   0.46   4000    1.0
Sigma[2,4]    0.04  1.2e-3   0.08  -0.12  -0.02   0.03   0.09   0.19   4000    1.0
Sigma[3,4]    0.31  1.2e-3   0.08   0.17   0.26   0.31   0.36   0.47   4000    1.0
Sigma[4,4]    1.29  9.9e-4   0.06   1.19   1.25   1.29   1.33   1.43   4000    1.0
Sigma[5,4]    0.53  1.3e-3   0.08   0.38   0.47   0.52   0.58    0.7   4000    1.0
Sigma[0,5]    0.03  1.2e-3   0.08  -0.13  -0.02   0.03   0.08   0.18   4000    1.0
Sigma[1,5]   -0.04  1.3e-3   0.08   -0.2  -0.09  -0.04   0.02   0.13   4000    1.0
Sigma[2,5]    0.07  1.3e-3   0.08  -0.08   0.02   0.08   0.13   0.23   4000    1.0
Sigma[3,5]    0.55  1.4e-3   0.09   0.39   0.49   0.55    0.6   0.73   4000    1.0
Sigma[4,5]    0.53  1.3e-3   0.08   0.38   0.47   0.52   0.58    0.7   4000    1.0
Sigma[5,5]    1.42  1.3e-3   0.08   1.28   1.36   1.41   1.47   1.59   4000    1.0
lp__        -713.2    0.06   2.67 -719.3 -714.9 -712.9 -711.2 -709.0   1983    1.0

Samples were drawn using NUTS at Sat Dec 24 00:05:39 2016.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

It's a little hard to see with just this, so let's draw a graph.

python


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Retrieve the posterior draws (PyStan 2.x API); extract once and reuse
samples = fit.extract()
Lambda = samples["Lambda"]
Sigma  = samples["Sigma"]

# EAP (posterior mean) estimates
EAP_Sigma  = np.mean(Sigma, axis=0)
EAP_Lambda = np.mean(Lambda, axis=0)

# Visualize the EAP estimates
# (label: axis tick labels, img_path: output directory; both defined elsewhere in the notebook)
plt.figure(figsize=(10, 4))
ax = plt.subplot(121)
sns.heatmap(pd.DataFrame(EAP_Sigma), annot=EAP_Sigma, fmt='0.2f', ax=ax, xticklabels=label, yticklabels=label)
plt.title("Graphical Lasso with Stan: Covariance matrix")

ax = plt.subplot(122)
sns.heatmap(pd.DataFrame(EAP_Lambda), annot=EAP_Lambda, fmt='0.2f', ax=ax, xticklabels=label, yticklabels=label)
plt.title("Graphical Lasso with Stan: Precision matrix")
plt.savefig(img_path + "glasso_stan_cov_prec.png", dpi=128)
plt.show()

glasso_stan_cov_prec.png

python


# Correlation matrix from the EAP covariance matrix
P = EAP_Sigma.shape[0]
EAP_cor = np.empty_like(EAP_Sigma)
for i in range(P):
    for j in range(P):
        EAP_cor[i, j] = EAP_Sigma[i, j] / np.sqrt(EAP_Sigma[i, i] * EAP_Sigma[j, j])

# Partial correlation matrix from the EAP precision matrix
# (as written, the diagonal comes out as -1)
EAP_rho = np.empty_like(EAP_Lambda)
for i in range(P):
    for j in range(P):
        EAP_rho[i, j] = -EAP_Lambda[i, j] / np.sqrt(EAP_Lambda[i, i] * EAP_Lambda[j, j])

plt.figure(figsize=(11, 4))
ax = plt.subplot(122)
sns.heatmap(pd.DataFrame(EAP_rho), annot=EAP_rho, fmt='0.2f', ax=ax, xticklabels=label, yticklabels=label)
plt.title("Partial correlation coefficient with Stan")

ax = plt.subplot(121)
sns.heatmap(pd.DataFrame(EAP_cor), annot=EAP_cor, fmt='0.2f', ax=ax, xticklabels=label, yticklabels=label)
plt.title("Correlation coefficient with Stan")
plt.savefig(img_path + "corr_pcorr_stan.png", dpi=128)
plt.show()

Because the estimates are averages of random MCMC draws, no element becomes exactly 0. In that sense, the sparsifying effect of L1 regularization (exact zeros) cannot be reproduced by sampling with Stan.

corr_pcorr_stan.png

For comparison, here is the result from Scikit-Learn.

corr_pcorr_sklearn.png

The values differ slightly, but the structure is similar. The indirect correlation between $x_4$ and $x_5$ also nearly vanishes in the partial correlation matrix. Below, the histograms of the posterior draws are overlaid with the Scikit-Learn estimates.

Visualization of sampling results

Figure: Histogram of sampling results

grid_dist_plot1.png

Enlarged version. The red lines are the Scikit-Learn estimates, and the dotted lines are the 2.5% and 97.5% points of the posterior distribution. Some estimates fall outside, but a reasonable fraction lies inside the 95% credible interval, so I think the results are broadly the same (excluding the diagonal elements). Since not every estimate lies inside the credible interval, a little more tuning may be needed.

grid_dist_plot2.png
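The check described above counts how many Scikit-Learn estimates fall between the posterior 2.5% and 97.5% points, and it can be computed directly. Below is a sketch. It assumes the draws are an array of shape (n_draws, P, P), as returned by PyStan 2's fit.extract()["Lambda"], and the function name is hypothetical. The Stan model fixes the diagonal of Lambda to 1, while Scikit-Learn's precision matrix has a free diagonal. Comparing off-diagonal entries or scale-free quantities such as partial correlations is therefore probably more meaningful.

```python
import numpy as np

def interval_coverage(draws, point, lower=2.5, upper=97.5, off_diagonal_only=True):
    """Fraction of entries of `point` inside the posterior percentile interval."""
    draws = np.asarray(draws, dtype=float)
    point = np.asarray(point, dtype=float)
    lo = np.percentile(draws, lower, axis=0)
    hi = np.percentile(draws, upper, axis=0)
    inside = (point >= lo) & (point <= hi)
    if off_diagonal_only:
        # Drop the diagonal, which corr_matrix fixes to 1 in the Stan model
        inside = inside[~np.eye(point.shape[0], dtype=bool)]
    return inside.mean()

# Synthetic example: standard-normal draws, one point estimate far outside
rng = np.random.default_rng(1)
draws = rng.normal(0.0, 1.0, size=(4000, 2, 2))
point = np.array([[0.0, 0.0],
                  [5.0, 0.0]])
print(interval_coverage(draws, point))                           # 0.5
print(interval_coverage(draws, point, off_diagonal_only=False))  # 0.75
```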

Trace plot

trace_plot.png
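As noted above, the posterior means from Stan never contain exact zeros. One common heuristic keeps an edge only when the central credible interval of $\lambda_{ij}$ excludes 0. This heuristic is an assumption here: it is not something this post does, and it is not part of Graphical Lasso itself. Below is a sketch with a hypothetical function name, assuming draws shaped (n_draws, P, P).

```python
import numpy as np

def edges_by_interval(draws, level=95.0):
    """Pairs (i, j), i < j, whose central posterior interval excludes 0."""
    tail = (100.0 - level) / 2.0
    lo = np.percentile(draws, tail, axis=0)
    hi = np.percentile(draws, 100.0 - tail, axis=0)
    p = draws.shape[1]
    return [(i, j) for i in range(p) for j in range(i + 1, p)
            if lo[i, j] > 0 or hi[i, j] < 0]

# Synthetic draws: only the (0, 1) entry is clearly away from 0
rng = np.random.default_rng(2)
mean = np.array([[ 1.0, -0.4, 0.0 ],
                 [-0.4,  1.0, 0.01],
                 [ 0.0,  0.01, 1.0]])
draws = mean + rng.normal(0.0, 0.05, size=(4000, 3, 3))
print(edges_by_interval(draws))  # [(0, 1)]
```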

5. Conclusion

I learned that Graphical Lasso's L1 regularization corresponds to a Laplace prior on the parameters, so I wanted to try it with Stan. I obtained a partial correlation matrix with a similar structure, but it differs slightly from the Scikit-Learn result, so I would like to investigate this further.

References

"Anomaly Detection and Change Detection" (Machine Learning Professional Series), Tsuyoshi Ide and Masashi Sugiyama

Stan Modeling Language User's Guide and Reference Manual ⇒ http://www.uvm.edu/~bbeckage/Teaching/DataAnalysis/Manuals/stan-reference-2.8.0.pdf

Partial Correlation Coefficient ⇒ http://www.ae.keio.ac.jp/lab/soc/takeuchi/lectures/5_Parcor.pdf
