[PYTHON] Explanation and implementation of PRML Chapter 4

PRML learning notes

Since I am in charge of presenting Chapter 4 of "Pattern Recognition and Machine Learning" (PRML), I would like to write up what I learned along with a bit of commentary. I am one of the many people who have struggled with this book, so I would be very happy if this helps anyone in a similar situation in the future. If you find a mathematical error or know a better way to do something, please do not hesitate to let me know.

Fisher's linear discrimination

2 classes

The section on discriminant functions begins with least squares, but I will skip it, since its conclusion is essentially that "naturally, it does not work well." So we start with Fisher's discriminant for two classes. Here, we look at linear discrimination from the perspective of dimensionality reduction.

We take a D-dimensional input vector $\boldsymbol{x}$ and project it down to one dimension with the following formula.

$$
y = \boldsymbol{w}^T\boldsymbol{x}
$$

We set a threshold on $y$ and classify $\boldsymbol{x}$ as class $C_1$ when $y \ge -w_0$, and as class $C_2$ otherwise. Reducing the dimension inevitably loses information, so we want to choose $\boldsymbol{w}$ so that the projection maximizes the separation between the classes.

Suppose class $C_1$ has $N_1$ points and class $C_2$ has $N_2$ points. The mean vector of each class is then

$$
\boldsymbol{m}_1 = \frac{1}{N_1}\sum_{n \in C_1}\boldsymbol{x}_n, \quad
\boldsymbol{m}_2 = \frac{1}{N_2}\sum_{n \in C_2}\boldsymbol{x}_n
$$
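In NumPy the mean vectors are just column averages, assuming each row of the array is one sample. The data below are made-up Gaussian samples (the centers, sample sizes and $\boldsymbol{w}$ are arbitrary). Because the projection is linear, averaging the projected values $\boldsymbol{w}^T\boldsymbol{x}_n$ gives the same result as projecting the mean vector, which is what the next formula relies on.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical data: each row is one sample x_n
X1 = rng.normal(loc=[5.0, 5.0], size=(100, 2))  # class C_1, N_1 = 100
X2 = rng.normal(loc=[0.0, 0.0], size=(80, 2))   # class C_2, N_2 = 80

# Mean vectors m_1 and m_2
m1_vec = X1.mean(axis=0)
m2_vec = X2.mean(axis=0)

# Projected means: average of y_n = w^T x_n within each class
w = np.array([0.6, 0.8])  # an arbitrary unit vector
m1 = (X1 @ w).mean()
m2 = (X2 @ w).mean()

# m_2 - m_1 matches w^T (m_2 - m_1) up to floating-point error
print(m2 - m1, w @ (m2_vec - m1_vec))
```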

Following the idea of "projecting onto the direction in which the class means are farthest apart," we choose $\boldsymbol{w}$ to maximize

$$
m_2 - m_1 = \boldsymbol{w}^T(\boldsymbol{m}_2 - \boldsymbol{m}_1)
$$

Here, $m_k = \boldsymbol{w}^T\boldsymbol{m}_k$ is the mean of the projected data from $C_k$. Since this quantity can be made arbitrarily large simply by scaling up $\boldsymbol{w}$, we add the constraint that $\boldsymbol{w}$ has unit norm, $\boldsymbol{w}^T\boldsymbol{w} = 1$. This is where the method of Lagrange multipliers comes in; if you know the basics of vector differentiation, it poses no difficulty.

$$
\begin{aligned}
L &= \boldsymbol{w}^T(\boldsymbol{m}_2 - \boldsymbol{m}_1) + \lambda(\boldsymbol{w}^T\boldsymbol{w}-1)\\
\nabla_{\boldsymbol{w}} L &= (\boldsymbol{m}_2 - \boldsymbol{m}_1)+2\lambda\boldsymbol{w} = \boldsymbol{0}\\
\boldsymbol{w} &= -\frac{1}{2\lambda}(\boldsymbol{m}_2 - \boldsymbol{m}_1)\propto(\boldsymbol{m}_2 - \boldsymbol{m}_1)
\end{aligned}
$$
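The Lagrange result can be sanity-checked numerically: among unit vectors, $\boldsymbol{w}^T(\boldsymbol{m}_2 - \boldsymbol{m}_1)$ is largest for the normalized difference of the means (this is also just the Cauchy-Schwarz inequality). The mean vectors below are arbitrary example values, and the comparison against random unit vectors is only an illustration, not a proof.

```python
import numpy as np

# Arbitrary example mean vectors
m1_vec = np.array([5.0, 5.0])
m2_vec = np.array([0.0, 1.0])
d = m2_vec - m1_vec

# Stationary point from the Lagrangian, normalized so that ||w|| = 1
w_star = d / np.linalg.norm(d)
best = w_star @ d  # equals ||m_2 - m_1||

# Compare with many random unit vectors
rng = np.random.default_rng(1)
W = rng.normal(size=(10000, 2))
W /= np.linalg.norm(W, axis=1, keepdims=True)
print(best, (W @ d).max())  # the random maximum never exceeds best
```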

In practice, however, the classes can still overlap considerably after this projection. What we really want is a projection under which points of the same class stay close together while the different classes are far apart. This leads to Fisher's criterion. The within-class variance of the projected data from class $C_k$ is defined as

$$
s_k^2 = \sum_{n \in C_k}(y_n - m_k)^2, \quad y_n = \boldsymbol{w}^T\boldsymbol{x}_n
$$
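A minimal sketch of $s_k^2$ for one class, assuming the samples of $C_k$ are the rows of `Xk`; the function name is hypothetical. Following the formula, $s_k^2$ is a sum of squared deviations rather than an average.

```python
import numpy as np

def within_class_scatter_1d(Xk, w):
    """s_k^2 = sum over n in C_k of (y_n - m_k)^2, with y_n = w^T x_n."""
    y = Xk @ w          # projected samples of class k
    m_k = y.mean()      # projected class mean
    return np.sum((y - m_k) ** 2)

Xk = np.array([[1.0, 0.0], [3.0, 0.0], [2.0, 5.0]])
w = np.array([1.0, 0.0])
print(within_class_scatter_1d(Xk, w))  # 2.0
```

With $\boldsymbol{w} = (1, 0)^T$ the projected values are 1, 3 and 2, whose mean is 2, giving $1 + 1 + 0 = 2$.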

Fisher's criterion is then defined as

$$
J(\boldsymbol{w}) = \frac{(m_2-m_1)^2}{s_1^2 + s_2^2}
$$

The denominator is the total within-class variance, defined as the sum of the variances of the two classes, and the numerator is the between-class variance. In the book, this is rewritten in matrix form as

$$
J(\boldsymbol{w}) = \frac{\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{B}\boldsymbol{w}}{\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{W}\boldsymbol{w}}
$$

where

$$
\begin{aligned}
\boldsymbol{S}_\boldsymbol{B} &= (\boldsymbol{m}_2 - \boldsymbol{m}_1)(\boldsymbol{m}_2 - \boldsymbol{m}_1)^T\\
\boldsymbol{S}_\boldsymbol{W} &= \sum_{k}\sum_{n\in C_k}(\boldsymbol{x}_n-\boldsymbol{m}_k)(\boldsymbol{x}_n-\boldsymbol{m}_k)^T
\end{aligned}
$$

The former is called the between-class covariance matrix and the latter the total within-class covariance matrix. They looked a bit intimidating to me at first, but if you expand the numerator and denominator using $y = \boldsymbol{w}^T\boldsymbol{x}$, you find that they are identical to the original expression. For example, the numerator is $\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{B}\boldsymbol{w} = (\boldsymbol{w}^T(\boldsymbol{m}_2 - \boldsymbol{m}_1))^2 = (m_2 - m_1)^2$.
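The equivalence of the matrix form and the scalar form can also be checked numerically. This sketch uses made-up Gaussian data (each row is one sample) and an arbitrary random $\boldsymbol{w}$; none of these values come from the book.

```python
import numpy as np

rng = np.random.default_rng(2)
X1 = rng.normal(loc=[2.0, 0.0], size=(50, 2))  # hypothetical class C_1
X2 = rng.normal(loc=[0.0, 1.0], size=(60, 2))  # hypothetical class C_2
w = rng.normal(size=2)                         # an arbitrary direction

m1_vec, m2_vec = X1.mean(axis=0), X2.mean(axis=0)
d = (m2_vec - m1_vec)[:, None]  # column vector m_2 - m_1
S_B = d @ d.T                   # between-class covariance matrix
# Rows are samples, so (X - m)^T (X - m) = sum_n (x_n - m)(x_n - m)^T
S_W = (X1 - m1_vec).T @ (X1 - m1_vec) + (X2 - m2_vec).T @ (X2 - m2_vec)

# Matrix form
J_matrix = (w @ S_B @ w) / (w @ S_W @ w)

# Scalar form computed from the projected data y = w^T x
y1, y2 = X1 @ w, X2 @ w
J_scalar = (y2.mean() - y1.mean()) ** 2 / (
    np.sum((y1 - y1.mean()) ** 2) + np.sum((y2 - y2.mean()) ** 2)
)
print(J_matrix, J_scalar)  # agree up to floating-point error
```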

We can therefore find the $\boldsymbol{w}$ that maximizes $J$ by differentiating $J(\boldsymbol{w})$ with respect to $\boldsymbol{w}$ and setting the result to zero.

$$
\begin{aligned}
\frac{\partial J}{\partial \boldsymbol{w}} &= \frac{2(\boldsymbol{S}_\boldsymbol{B}\boldsymbol{w})(\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{W}\boldsymbol{w})-2(\boldsymbol{S}_\boldsymbol{W}\boldsymbol{w})(\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{B}\boldsymbol{w})}{(\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{W}\boldsymbol{w})^2}=0\\
(\boldsymbol{S}_\boldsymbol{W}\boldsymbol{w})(\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{B}\boldsymbol{w}) &= (\boldsymbol{S}_\boldsymbol{B}\boldsymbol{w})(\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{W}\boldsymbol{w})
\end{aligned}
$$

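The key points are that $\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{W}\boldsymbol{w}$ is a scalar, and that when differentiating the quadratic forms we use the fact that the covariance matrices are symmetric, so that $\partial(\boldsymbol{w}^T\boldsymbol{A}\boldsymbol{w})/\partial\boldsymbol{w} = 2\boldsymbol{A}\boldsymbol{w}$ for a symmetric matrix $\boldsymbol{A}$. I will write about this in another article.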
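One way to convince yourself that the quotient-rule gradient is right is a central finite-difference check. In this sketch the matrices are random symmetric stand-ins for $\boldsymbol{S}_\boldsymbol{B}$ and $\boldsymbol{S}_\boldsymbol{W}$ (an assumption for illustration, not computed from data); symmetry is exactly what makes the $2\boldsymbol{A}\boldsymbol{w}$ rule apply.

```python
import numpy as np

rng = np.random.default_rng(3)
A = rng.normal(size=(3, 3))
S_B = A @ A.T               # symmetric PSD stand-in for S_B
B = rng.normal(size=(3, 3))
S_W = B @ B.T + np.eye(3)   # symmetric positive definite stand-in for S_W
w = rng.normal(size=3)

def J(w):
    return (w @ S_B @ w) / (w @ S_W @ w)

# Analytic gradient from the quotient rule
num = w @ S_B @ w
den = w @ S_W @ w
grad_analytic = (2 * (S_B @ w) * den - 2 * (S_W @ w) * num) / den**2

# Central finite differences along each coordinate axis
eps = 1e-6
grad_numeric = np.array(
    [(J(w + eps * e) - J(w - eps * e)) / (2 * eps) for e in np.eye(3)]
)
print(np.max(np.abs(grad_analytic - grad_numeric)))  # should be tiny
```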

As before, what matters here is the direction of $\boldsymbol{w}$, not its magnitude. Now look at $\boldsymbol{S}_\boldsymbol{B}\boldsymbol{w}$:

$$
\boldsymbol{S}_\boldsymbol{B}\boldsymbol{w} = (\boldsymbol{m}_2 - \boldsymbol{m}_1)(\boldsymbol{m}_2 - \boldsymbol{m}_1)^T\boldsymbol{w}
$$

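Because $(\boldsymbol{m}_2 - \boldsymbol{m}_1)^T\boldsymbol{w}$ is a scalar, this is always a vector in the same direction as $(\boldsymbol{m}_2 - \boldsymbol{m}_1)$. Dropping the scalar factors $\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{B}\boldsymbol{w}$ and $\boldsymbol{w}^T\boldsymbol{S}_\boldsymbol{W}\boldsymbol{w}$ and multiplying both sides by $\boldsymbol{S}_\boldsymbol{W}^{-1}$ gives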

$$
\boldsymbol{w} \propto \boldsymbol{S}_\boldsymbol{W}^{-1}(\boldsymbol{m}_2 - \boldsymbol{m}_1)
$$
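A small numerical sketch of the final result, using made-up data with strongly correlated features (an assumption chosen so that $\boldsymbol{S}_\boldsymbol{W}^{-1}$ visibly changes the direction). It checks that $\boldsymbol{S}_\boldsymbol{B}\boldsymbol{v}$ is a multiple of $(\boldsymbol{m}_2 - \boldsymbol{m}_1)$ for any $\boldsymbol{v}$, and that $\boldsymbol{w} \propto \boldsymbol{S}_\boldsymbol{W}^{-1}(\boldsymbol{m}_2 - \boldsymbol{m}_1)$ gives a larger $J$ than the plain mean-difference direction and than any direction on a fine grid. `np.linalg.solve` is used instead of forming the inverse explicitly.

```python
import numpy as np

rng = np.random.default_rng(4)
# Hypothetical data with correlated features
cov = np.array([[1.0, 0.9], [0.9, 1.0]])
X1 = rng.multivariate_normal([2.0, 0.0], cov, size=200)
X2 = rng.multivariate_normal([0.0, 0.0], cov, size=200)

m1_vec, m2_vec = X1.mean(axis=0), X2.mean(axis=0)
d = m2_vec - m1_vec
S_B = np.outer(d, d)
S_W = (X1 - m1_vec).T @ (X1 - m1_vec) + (X2 - m2_vec).T @ (X2 - m2_vec)

def J(w):
    return (w @ S_B @ w) / (w @ S_W @ w)

# S_B v is always the scalar (m_2 - m_1)^T v times (m_2 - m_1)
v = rng.normal(size=2)
assert np.allclose(S_B @ v, (d @ v) * d)

# Fisher direction: w proportional to S_W^{-1} (m_2 - m_1)
w_fisher = np.linalg.solve(S_W, d)
w_fisher /= np.linalg.norm(w_fisher)
w_means = d / np.linalg.norm(d)  # direction from the mean-difference criterion

# J is sign-invariant, so angles in [0, pi] cover every direction
angles = np.linspace(0.0, np.pi, 1801)
J_grid = np.array([J(np.array([np.cos(t), np.sin(t)])) for t in angles])
print(J(w_fisher), J(w_means), J_grid.max())
```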

Once the direction of $\boldsymbol{w}$ is determined, we are done!

Extra: I tried writing some code

fisher_2d.py


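import numpy as np
import matplotlib.pyplot as plt

# Class 1: 100 samples from a 2-D Gaussian centered at (5, 5)
mu1 = [5, 5]
sigma = np.eye(2)
c_1 = np.random.multivariate_normal(mu1, sigma, 100).T  # shape (2, 100)

# Class 2: 100 samples centered at (0, 0) with the same covariance
mu2 = [0, 0]
c_2 = np.random.multivariate_normal(mu2, sigma, 100).T  # shape (2, 100)

# Mean vectors, kept as (2, 1) column vectors
m_1 = np.mean(c_1, axis=1, keepdims=True)
m_2 = np.mean(c_2, axis=1, keepdims=True)

# Total within-class covariance matrix S_W
S_W = (c_1 - m_1) @ (c_1 - m_1).T + (c_2 - m_2) @ (c_2 - m_2).T

# w proportional to S_W^{-1}(m_2 - m_1); solve() avoids forming the inverse
w = np.linalg.solve(S_W, m_2 - m_1)
w = w / np.linalg.norm(w)

# The arrow components are (w[1], -w[0]), i.e. the direction perpendicular to w
plt.quiver(4, 2, w[1, 0], -w[0, 0], angles="xy", units="xy", color="black", scale=0.5)
plt.scatter(c_1[0, :], c_1[1, :])
plt.scatter(c_2[0, :], c_2[1, :])
plt.show()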

Here is the resulting plot, with the arrow drawn using quiver:

(Figure: scatter plot of the two classes with the arrow drawn by quiver)

The direction looks right. Next time, I will write about the multi-class version.
