Plot and understand the multivariate normal distribution in Python

Introduction

While studying statistics I came across the "multivariate normal distribution", and I plotted it in Python to get an intuitive picture of it. This time I set $n = 2$ and plot the two-dimensional normal distribution, since it is easy to visualize.

Reference

To understand and plot the multivariate normal distribution, I referred to the following articles.

- Understanding the multivariate normal distribution
- Graph the multivariate normal distribution with python to get an image
- Plot the 2D Gaussian distribution

Overview of the multivariate normal distribution

The multivariate normal distribution of $n$ variables is expressed as follows.


f(\vec{x}) = \frac{1}{\sqrt{(2\pi)^n |\Sigma|}} \exp \left\{ -\frac{1}{2} {}^t(\vec{x}-\vec{\mu}) \, \Sigma^{-1} (\vec{x}-\vec{\mu}) \right\}
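As a minimal sketch (not taken from the referenced articles), the formula above can be evaluated directly for a single point with NumPy. `mvn_pdf` is a hypothetical helper name, and $\vec{x}$, $\vec{\mu}$ and $\Sigma$ are the vector and matrix described below. It uses `np.linalg.solve` instead of an explicit inverse, which is a common numerical choice rather than something the formula requires. At the mean of the 2D standard normal distribution ($\vec{\mu} = 0$, $\Sigma = I$) the exponent is 0, so the density should be $1/(2\pi)$.

```python
import numpy as np

def mvn_pdf(x, mu, sigma):
    """Density of the n-variate normal distribution at one point x (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    mu = np.asarray(mu, dtype=float)
    sigma = np.asarray(sigma, dtype=float)
    n = x.shape[0]                         # number of variables
    diff = x - mu                          # x - mu
    # Quadratic form (x - mu)^t Sigma^{-1} (x - mu); solve() avoids forming Sigma^{-1}
    quad = diff @ np.linalg.solve(sigma, diff)
    norm = np.sqrt((2 * np.pi) ** n * np.linalg.det(sigma))
    return np.exp(-0.5 * quad) / norm

# Density at the mean of the 2D standard normal distribution
print(mvn_pdf([0, 0], [0, 0], np.eye(2)))  # about 0.159 (= 1 / (2 * pi))
```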

Since there are $n$ variables, a data point is written as an $n$-dimensional vector. Likewise, there is one mean for each variable, so the mean $\vec{\mu}$ is also written as a vector.


\vec{x}=\begin{pmatrix}x_1 \\ x_2 \\ \vdots \\ x_n\end{pmatrix}, \quad \vec{\mu}=\begin{pmatrix}\mu_1 \\ \mu_2 \\ \vdots \\ \mu_n\end{pmatrix}

Each element $x_i$ is a value of the random variable $X_i$, and $\mu_i$ is the mean of $X_i$. As for the variance, in the multivariate case we must consider not only the spread of each variable but also the correlations between variables, so the **variance-covariance matrix $\Sigma$** is used.


\Sigma = \begin{pmatrix} \sigma_{1}^2 & \cdots & \sigma_{1i} & \cdots & \sigma_{1n}\\ \vdots & \ddots & & & \vdots \\ \sigma_{i1} & & \sigma_{i}^2 & & \sigma_{in} \\ \vdots & & & \ddots & \vdots \\ \sigma_{n1} & \cdots & \sigma_{ni} & \cdots & \sigma_{n}^2 \end{pmatrix}

Here $\sigma_i^2$ is the variance of the $i$-th variable, and $\sigma_{ij} = \sigma_{ji}\ (i \neq j)$ is the covariance between the $i$-th and $j$-th variables, so $\Sigma$ is symmetric. For $n = 2$, the two-dimensional normal distribution is written as follows.

N_2 \left ( \begin{pmatrix}  \mu_x \\  \mu_y \\  \end{pmatrix} , \begin{pmatrix}  \sigma_{x}^2 & \sigma_{xy}\\  \sigma_{xy} & \sigma_{y}^2\\  \end{pmatrix} \right  )
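For $n = 2$ the density can also be written without matrices. Writing $\rho = \sigma_{xy}/(\sigma_x\sigma_y)$ for the correlation coefficient, the determinant and inverse of the $2 \times 2$ matrix above give the following standard explicit form (a worked step added here, not taken from the referenced articles):

```latex
|\Sigma| = \sigma_x^2\sigma_y^2 - \sigma_{xy}^2 = \sigma_x^2\sigma_y^2(1-\rho^2), \qquad
\Sigma^{-1} = \frac{1}{|\Sigma|}\begin{pmatrix} \sigma_y^2 & -\sigma_{xy} \\ -\sigma_{xy} & \sigma_x^2 \end{pmatrix}

f(x, y) = \frac{1}{2\pi\sigma_x\sigma_y\sqrt{1-\rho^2}}
\exp\left\{-\frac{1}{2(1-\rho^2)}\left[\frac{(x-\mu_x)^2}{\sigma_x^2}
- \frac{2\rho(x-\mu_x)(y-\mu_y)}{\sigma_x\sigma_y}
+ \frac{(y-\mu_y)^2}{\sigma_y^2}\right]\right\}
```

With $\rho = 0$ the cross term vanishes and $f(x, y)$ splits into the product of two one-dimensional normal densities.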

Now I would like to plot a two-dimensional normal distribution.

Two-dimensional normal distribution plot

The script to plot the two-dimensional normal distribution is shown below. First, let's plot the case where the two variables are independent of each other and each follows the standard normal distribution.

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from matplotlib import cm

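#Create a grid of points at which to evaluate the density
x = y = np.arange(-20, 20, 0.5)
X, Y = np.meshgrid(x, y)

#Flatten the grid into rows of (x, y), shape (number of points, 2)
z = np.c_[X.ravel(), Y.ravel()]

#A function that returns the probability density of a two-dimensional normal distribution
#(it uses the global mu and sigma defined below)
def gaussian(x):
    #Number of variables n (the number of columns, 2 here)
    n = x.shape[1]
    #Determinant of the variance-covariance matrix
    det = np.linalg.det(sigma)
    #Inverse of the variance-covariance matrix
    inv = np.linalg.inv(sigma)
    #Quadratic form (x - mu)^t Sigma^-1 (x - mu) for every row at once,
    #without building the (points x points) matrix that np.diag(...) would need
    d = x - mu
    quad = np.einsum('ij,jk,ik->i', d, inv, d)
    return np.exp(-quad / 2.0) / np.sqrt((2 * np.pi) ** n * det)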

#Specify the mean value of 2 variables
mu = np.array([0,0])
#Specify a two-variable variance-covariance matrix
sigma = np.array([[1,0],[0,1]])

Z = gaussian(z)
shape = X.shape
Z = Z.reshape(shape)

#Plot a two-dimensional normal distribution
fig = plt.figure(figsize = (15, 15))
ax = fig.add_subplot(111, projection='3d')
    
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.coolwarm)
plt.show()

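The output is shown below. Because the two variables are independent (covariance 0) and have the same variance, the surface is symmetric around the origin with no tilt in any direction. It looks sharp because the variance is 1 while the plotted range is -20 to 20.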

[Figure: Multivariate normal distribution 1]
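One way to see why this surface is so symmetric: when $\Sigma$ is the identity matrix, the joint density factorizes into the product of two one-dimensional standard normal densities. The sketch below checks this numerically on a small grid; `std_normal_1d` is a hypothetical helper, and the grid range and spacing are arbitrary choices.

```python
import numpy as np

def std_normal_1d(t):
    """1D standard normal density (hypothetical helper)."""
    return np.exp(-t ** 2 / 2.0) / np.sqrt(2 * np.pi)

# A coarse grid, flattened into rows of (x, y)
x = y = np.linspace(-3, 3, 13)
X, Y = np.meshgrid(x, y)
pts = np.c_[X.ravel(), Y.ravel()]

# Joint density from the general formula with mu = 0 and Sigma = I
mu = np.zeros(2)
sigma = np.eye(2)
d = pts - mu
quad = np.einsum('ij,jk,ik->i', d, np.linalg.inv(sigma), d)
joint = np.exp(-quad / 2.0) / np.sqrt((2 * np.pi) ** 2 * np.linalg.det(sigma))

# Product of the two marginal 1D densities
product = std_normal_1d(pts[:, 0]) * std_normal_1d(pts[:, 1])

print(np.allclose(joint, product))  # True: f(x, y) = phi(x) * phi(y)
```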

Next, let's plot a differently shaped distribution, using the following means and variance-covariance matrix for the two variables.

#Specify the mean value of 2 variables
mu = np.array([3,1])
#Specify a two-variable variance-covariance matrix
sigma = np.array([[10,5],[5,10]])

The rest of the code is the same as for the plot above.


Z = gaussian(z)
shape = X.shape
Z = Z.reshape(shape)

#Plot a two-dimensional normal distribution
fig = plt.figure(figsize = (15, 15))
ax = fig.add_subplot(111, projection='3d')
    
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=cm.coolwarm)
plt.show()

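The output is shown below. This time the two variables are positively correlated (covariance 5, correlation coefficient $5/\sqrt{10 \cdot 10} = 0.5$), so the distribution is stretched along the diagonal, and its peak is shifted from the origin to the mean $(3, 1)$. Since both variances are 10, it is also much wider than the first plot.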

[Figure: plot of the correlated two-dimensional normal distribution]
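The direction of the diagonal stretch can be read off the variance-covariance matrix. A short sketch, assuming the same `sigma` as above: the eigenvectors of $\Sigma$ give the axes of the elliptical contours, and the square roots of the eigenvalues give the spread along each axis. `np.linalg.eigh` is used because $\Sigma$ is symmetric; it returns the eigenvalues in ascending order.

```python
import numpy as np

sigma = np.array([[10.0, 5.0],
                  [5.0, 10.0]])

# Correlation coefficient rho = sigma_xy / (sigma_x * sigma_y)
rho = sigma[0, 1] / np.sqrt(sigma[0, 0] * sigma[1, 1])
print(rho)  # 0.5

# Eigen-decomposition of the symmetric variance-covariance matrix
eigvals, eigvecs = np.linalg.eigh(sigma)
print(eigvals)  # approximately 5 and 15
print(eigvecs)  # columns: unit eigenvectors along (1, -1) and (1, 1), up to sign

# Standard deviation along each principal axis of the ellipse
print(np.sqrt(eigvals))
```

The larger eigenvalue (15) belongs to the $(1, 1)$ direction, which is why the surface is elongated along the diagonal $y = x$ direction.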

Visualizing the distribution makes it much easier to grasp what was hard to understand from the formulas alone.

Going forward: when studying statistics it is often hard to get an intuitive picture from formulas alone, so I would like to keep actively writing Python code and plotting things to visualize them myself.
