[PYTHON] Machine Learning: Supervised - Linear Discriminant Analysis

Goal

Understand Fisher's linear discriminant analysis through its mathematical derivation, then try it out with scikit-learn.

It is assumed that you have already learned calculus and linear algebra.

Theory

Fisher's linear discriminant analysis is a supervised method that finds a projection vector $ w $ such that, after the data are projected onto it, the distributions of the different categories overlap as little as possible. Despite the word "discriminant" in its name, in practice it is mostly used for dimensionality reduction.

Fisher's linear discriminant analysis

Derivation of Fisher's Linear Discriminant Analysis

Projecting a data point $ x $ onto $ w $ gives the projected value $ y $:

y = w^T x

We look for the $ w $ that pushes the category distributions of $ y $ as far apart as possible. In the example shown below, the optimal $ w $ projects the blue and orange point clouds onto the hollow points on the black straight line.

[Figure: the blue and orange point clouds projected onto the black line defined by the optimal $ w $]
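As a minimal NumPy sketch (the sample values and the direction are made up for illustration), projecting a whole data matrix onto $ w $ is a single matrix-vector product. Each row of `X` is one sample $ x_i $, and each entry of `y` is the corresponding $ y_i = w^T x_i $.

```python
import numpy as np

# Three hypothetical 2-D samples, one per row
X = np.array([[1.0, 2.0],
              [2.0, 1.0],
              [3.0, 4.0]])

# Hypothetical projection direction, scaled to unit length
w = np.array([1.0, 1.0])
w = w / np.linalg.norm(w)

# y_i = w^T x_i for all samples at once (X @ w stacks the dot products)
y = X @ w
print(y)  # one scalar per sample
```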

Now consider two categories of data, as in the figure above. The mean vectors of categories 1 and 2 are

\mu_1 = \frac{1}{N_1} \sum^{N_1}_{i \in C_1} x_i \\
\mu_2 = \frac{1}{N_2} \sum^{N_2}_{i \in C_2} x_i
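The mean vectors can be computed directly from these definitions. The sketch below uses synthetic Gaussian data, so the class locations, sample counts and random seed are assumptions made purely for illustration. It writes the sum over $ C_k $ divided by $ N_k $ explicitly, which is equivalent to `X.mean(axis=0)`.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical two-category data in 2-D (synthetic, for illustration only)
X1 = rng.normal(loc=[0.0, 0.0], scale=1.0, size=(50, 2))  # samples of C_1
X2 = rng.normal(loc=[3.0, 1.0], scale=1.0, size=(60, 2))  # samples of C_2

# mu_k = (1 / N_k) * sum over i in C_k of x_i
mu1 = X1.sum(axis=0) / X1.shape[0]
mu2 = X2.sum(axis=0) / X2.shape[0]
print(mu1, mu2)
```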

Writing the projected mean vectors as $ m_1 = w^T \mu_1 $ and $ m_2 = w^T \mu_2 $, the difference between the projected means is

m_1 - m_2 = w^T (\mu_1 - \mu_2)

The larger this difference, the better separated the categories are, so we want to choose $ w $ to maximize it. However, $ m_1 - m_2 $ can be made arbitrarily large simply by scaling up $ w $. Since what we actually want is a projection direction, we add the constraint $ \|w\|^2 = 1 $. Even so, maximizing the distance between the means is not enough on its own: the projected categories can still overlap if each is widely spread along $ w $. We therefore also consider the spread of each category. The within-class variances of categories 1 and 2 after projection, $ s^2_1 $ and $ s^2_2 $, are

s^2_1 = \sum^{N_1}_{i \in C_1} (w^T x_i - w^T \mu_1)^2 \\
s^2_2 = \sum^{N_2}_{i \in C_2} (w^T x_i - w^T \mu_2)^2

Since a smaller spread after projection is better, we want to minimize the total within-class variance $ s^2 = s^2_1 + s^2_2 $.
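Here is a small sketch of these projected quantities, again on hypothetical synthetic data. The helper name `projected_stats` is illustrative, not from the original program. It normalizes $ w $ first, which is how the constraint $ \|w\|^2 = 1 $ stops $ m_1 - m_2 $ from growing just by rescaling. Note that $ s^2_k $ here is a sum of squared deviations, not divided by $ N_k $, matching the formulas above.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic two-category data (hypothetical)
X1 = rng.normal([0.0, 0.0], 1.0, size=(50, 2))
X2 = rng.normal([3.0, 1.0], 1.0, size=(60, 2))
mu1, mu2 = X1.mean(axis=0), X2.mean(axis=0)

def projected_stats(w):
    """Return (m1 - m2, s1^2, s2^2) after projecting onto w / ||w||."""
    w = w / np.linalg.norm(w)          # enforce the constraint ||w||^2 = 1
    m1, m2 = w @ mu1, w @ mu2          # projected means m_k = w^T mu_k
    s1 = np.sum((X1 @ w - m1) ** 2)    # s_1^2: squared deviations within C_1
    s2 = np.sum((X2 @ w - m2) ** 2)    # s_2^2: squared deviations within C_2
    return m1 - m2, s1, s2

diff, s1, s2 = projected_stats(np.array([1.0, 0.0]))
print(diff, s1 + s2)
```

Because of the normalization, `projected_stats(np.array([5.0, 0.0]))` returns the same difference as the unit vector.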

We therefore define the Fisher criterion $ J(w) $ below as an objective that accounts for both goals: maximizing the difference between the projected means and minimizing the variance after projection.

J(w) = \frac{(m_1 - m_2)^2}{s^2_1 + s^2_2}
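A minimal sketch of the criterion evaluated straight from the projected samples follows. The function name `fisher_criterion` and the synthetic data are assumptions for illustration. Projecting onto the axis along which the class means differ most gives the larger $ J $. Scaling $ w $ by any nonzero constant leaves $ J $ unchanged, because the numerator and denominator both scale by the square of that constant.

```python
import numpy as np

def fisher_criterion(w, X1, X2):
    """J(w) = (m1 - m2)^2 / (s1^2 + s2^2), computed from the projected samples."""
    y1, y2 = X1 @ w, X2 @ w            # project each category
    m1, m2 = y1.mean(), y2.mean()      # projected means
    s1 = np.sum((y1 - m1) ** 2)        # within-class scatter after projection
    s2 = np.sum((y2 - m2) ** 2)
    return (m1 - m2) ** 2 / (s1 + s2)

rng = np.random.default_rng(0)
X1 = rng.normal([0.0, 0.0], 1.0, size=(50, 2))  # synthetic C_1
X2 = rng.normal([3.0, 1.0], 1.0, size=(60, 2))  # synthetic C_2

J_x = fisher_criterion(np.array([1.0, 0.0]), X1, X2)  # project onto the x-axis
J_y = fisher_criterion(np.array([0.0, 1.0]), X1, X2)  # project onto the y-axis
print(J_x, J_y)
```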

Next, define the between-class covariance matrix $ S_B = (\mu_1 - \mu_2)(\mu_1 - \mu_2)^T $. The between-class variation $ (m_1 - m_2)^2 $ is then

\begin{align}
(m_1 - m_2)^2 &= \left( w^T(\mu_1 - \mu_2) \right)^2 \\
&= \left( w^T(\mu_1 - \mu_2) \right) \left( w^T(\mu_1 - \mu_2) \right)^T \\
&= w^T (\mu_1 - \mu_2)(\mu_1 - \mu_2)^T w \\
&= w^T S_B w
\end{align}

Similarly, the within-class variance $ s^2_k $ of category $ k $ is

\begin{align}
s^2_k &= \sum_{i \in C_k} (y_i - m_k)^2 \\
&= \sum_{i \in C_k} \left( w^T (x_i - \mu_k) \right)^2 \\
&= \sum_{i \in C_k} \left( w^T(x_i - \mu_k) \right) \left( w^T(x_i - \mu_k) \right)^T \\
&= w^T \sum_{i \in C_k} (x_i - \mu_k)(x_i - \mu_k)^T w \\
&= w^T S_k w
\end{align}
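The two identities $ (m_1 - m_2)^2 = w^T S_B w $ and $ s^2_k = w^T S_k w $ can be checked numerically. This sketch uses hypothetical 3-D synthetic data and an arbitrary random $ w $. The helper `scatter` is an illustrative name; it computes $ \sum_{i \in C_k} (x_i - \mu_k)(x_i - \mu_k)^T $ as a product of the centered data matrix with itself. The matrix $ S_B $ is an outer product, so it has rank one.

```python
import numpy as np

rng = np.random.default_rng(1)
X1 = rng.normal([0.0, 0.0, 0.0], 1.0, size=(40, 3))   # synthetic C_1
X2 = rng.normal([2.0, -1.0, 0.5], 1.0, size=(45, 3))  # synthetic C_2
mu1, mu2 = X1.mean(axis=0), X2.mean(axis=0)

# Between-class covariance matrix S_B = (mu1 - mu2)(mu1 - mu2)^T
S_B = np.outer(mu1 - mu2, mu1 - mu2)

def scatter(X):
    """S_k = sum over i in C_k of (x_i - mu_k)(x_i - mu_k)^T."""
    D = X - X.mean(axis=0)   # centered samples, one per row
    return D.T @ D           # equals the sum of outer products of the rows

S1, S2 = scatter(X1), scatter(X2)
w = rng.normal(size=3)       # an arbitrary (unnormalized) direction

# Computed from projected data vs. computed as quadratic forms
between_proj = (w @ mu1 - w @ mu2) ** 2
within1_proj = np.sum((X1 @ w - w @ mu1) ** 2)
print(between_proj, w @ S_B @ w)
print(within1_proj, w @ S1 @ w)
```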

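Here $ S_k = \sum_{i \in C_k} (x_i - \mu_k)(x_i - \mu_k)^T $ is the within-class covariance matrix of category $ k $. Defining the total within-class covariance matrix as $ S_W = S_1 + S_2 $, the total within-class variance $ s^2_1 + s^2_2 $ is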

s^2_1 + s^2_2 = w^T (S_1 + S_2) w = w^T S_W w

Consequently, the Fisher criterion $ J(w) $ becomes

J(w) = \frac{w^T S_B w}{w^T S_W w}

We seek the $ w $ that maximizes this ratio.
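To see that the matrix form agrees with the original definition, the sketch below builds $ S_B $ and $ S_W $ from hypothetical synthetic 2-D data. It defines $ J $ as a ratio of quadratic forms and then maximizes it by brute force over unit vectors. The grid search is only a crude illustration of "maximize $ J $", not how LDA is solved in practice. Angles in $ [0, \pi] $ suffice because $ J(w) = J(-w) $.

```python
import numpy as np

rng = np.random.default_rng(2)
X1 = rng.normal([0.0, 0.0], [1.0, 2.0], size=(50, 2))  # synthetic C_1
X2 = rng.normal([2.0, 2.0], [1.0, 2.0], size=(50, 2))  # synthetic C_2
mu1, mu2 = X1.mean(axis=0), X2.mean(axis=0)

def scatter(X):
    D = X - X.mean(axis=0)
    return D.T @ D

S_B = np.outer(mu1 - mu2, mu1 - mu2)   # between-class covariance matrix
S_W = scatter(X1) + scatter(X2)        # S_W = S_1 + S_2

def J(w):
    """Fisher criterion as a ratio of quadratic forms."""
    return (w @ S_B @ w) / (w @ S_W @ w)

# Crude maximization: evaluate J on unit vectors at many angles, keep the best
angles = np.linspace(0.0, np.pi, 721)
candidates = np.stack([np.cos(angles), np.sin(angles)], axis=1)
scores = np.array([J(w) for w in candidates])
w_grid = candidates[np.argmax(scores)]
print(w_grid, scores.max())
```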

Learning Fisher's Linear Discriminant Analysis

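To find the maximum, we differentiate the Fisher criterion $ J(w) $ with respect to $ w $ and set the derivative to zero. We use the quotient rule together with $ \partial (w^T A w) / \partial w = 2Aw $ for a symmetric matrix $ A $ (both $ S_B $ and $ S_W $ are symmetric):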

\begin{align}
\frac{\partial J(w)}{\partial w} &= \frac{2S_B w \cdot w^TS_Ww - w^TS_Bw \cdot 2S_Ww}{(w^TS_Ww)^2} \\
&= \frac{2}{w^TS_Ww} \left( S_Bw - \frac{w^TS_Bw}{w^TS_Ww} S_Ww \right) = 0
\end{align}
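The analytic gradient above can be sanity-checked against central finite differences. In this sketch, $ S_W $ and $ S_B $ are replaced by random stand-ins: a symmetric positive-definite matrix and a rank-one outer product. The function names are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(3)
A = rng.normal(size=(3, 3))
S_W = A @ A.T + 3.0 * np.eye(3)   # symmetric positive-definite stand-in for S_W
d = rng.normal(size=3)
S_B = np.outer(d, d)              # rank-one stand-in for S_B

def J(w):
    return (w @ S_B @ w) / (w @ S_W @ w)

def grad_J(w):
    """Quotient-rule gradient: (2 S_B w * w^T S_W w - w^T S_B w * 2 S_W w) / (w^T S_W w)^2."""
    num, den = w @ S_B @ w, w @ S_W @ w
    return (2.0 * S_B @ w * den - num * 2.0 * S_W @ w) / den ** 2

def numerical_grad(f, w, h=1e-6):
    """Central finite differences, one coordinate at a time."""
    g = np.zeros_like(w)
    for k in range(w.size):
        e = np.zeros_like(w)
        e[k] = h
        g[k] = (f(w + e) - f(w - e)) / (2.0 * h)
    return g

w = rng.normal(size=3)
print(grad_J(w), numerical_grad(J, w))
```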

Setting $ \lambda = \frac{w^T S_B w}{w^T S_W w} $ gives

\frac{\partial J(w)}{\partial w} = \frac{2}{w^TS_Ww} (S_Bw - \lambda S_Ww) = 0 \\
(S_Bw - \lambda S_Ww) = 0

Therefore, we need to solve the following generalized eigenvalue problem.

S_Bw = \lambda S_Ww
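SciPy can solve this generalized eigenvalue problem directly. `scipy.linalg.eigh(a, b)` accepts a symmetric `a` and a positive-definite `b`, and returns eigenvalues in ascending order. The sketch assumes synthetic data for which $ S_W $ is positive definite. Multiplying $ S_B w = \lambda S_W w $ on the left by $ w^T $ shows that $ \lambda $ equals $ J(w) $ at each eigenvector. The eigenvector with the largest $ \lambda $ is therefore the maximizer.

```python
import numpy as np
from scipy.linalg import eigh

rng = np.random.default_rng(4)
X1 = rng.normal([0.0, 0.0], 1.0, size=(50, 2))   # synthetic C_1
X2 = rng.normal([2.0, 1.0], 1.0, size=(50, 2))   # synthetic C_2
mu1, mu2 = X1.mean(axis=0), X2.mean(axis=0)

def scatter(X):
    D = X - X.mean(axis=0)
    return D.T @ D

S_B = np.outer(mu1 - mu2, mu1 - mu2)
S_W = scatter(X1) + scatter(X2)

# Solve S_B w = lambda S_W w (eigenvalues come back in ascending order)
eigvals, eigvecs = eigh(S_B, S_W)
lam, w = eigvals[-1], eigvecs[:, -1]   # largest eigenvalue and its eigenvector
print(lam, w)
```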

If $ S_W $ is invertible, multiplying both sides by $ S^{-1}_W $ gives

\lambda w = S^{-1}_WS_Bw

which is an ordinary eigenvalue problem. Furthermore,

S_Bw = (\mu_1 - \mu_2)(\mu_1 - \mu_2)^Tw \propto (\mu_1 - \mu_2)

since $ (\mu_1 - \mu_2)^T w $ is a scalar. We therefore obtain

w \propto S^{-1}_WS_Bw \propto S^{-1}_W (\mu_1 - \mu_2)

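This determines the optimal $ w $ up to a scale factor. Since only the projection direction matters, $ w $ can then be normalized to satisfy $ \|w\|^2 = 1 $.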
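As a sketch on hypothetical synthetic data, the closed form takes one linear solve. Here `np.linalg.solve` computes $ S^{-1}_W (\mu_1 - \mu_2) $ without forming the inverse explicitly. The result is cross-checked against the leading eigenvector of the ordinary eigenproblem $ S^{-1}_W S_B w = \lambda w $. The two directions agree up to sign.

```python
import numpy as np

rng = np.random.default_rng(4)
X1 = rng.normal([0.0, 0.0], 1.0, size=(50, 2))   # synthetic C_1
X2 = rng.normal([2.0, 1.0], 1.0, size=(50, 2))   # synthetic C_2
mu1, mu2 = X1.mean(axis=0), X2.mean(axis=0)

def scatter(X):
    D = X - X.mean(axis=0)
    return D.T @ D

S_W = scatter(X1) + scatter(X2)
S_B = np.outer(mu1 - mu2, mu1 - mu2)

# Closed form: w is proportional to S_W^{-1} (mu1 - mu2)
w = np.linalg.solve(S_W, mu1 - mu2)
w /= np.linalg.norm(w)                 # fix the scale so that ||w||^2 = 1

# Cross-check: leading eigenvector of S_W^{-1} S_B
vals, vecs = np.linalg.eig(np.linalg.solve(S_W, S_B))
v = np.real(vecs[:, np.argmax(np.real(vals))])
v /= np.linalg.norm(v)
print(w, v)   # same direction, possibly with opposite sign
```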

Implementation

Execution environment

Hardware

- CPU: Intel(R) Core(TM) i7-6700K 4.00GHz

Software

- Windows 10 Pro 1909
- Python 3.6.6
- Matplotlib 3.3.1
- NumPy 1.19.2
- scikit-learn 0.23.2

Program to run

The implemented program is published on GitHub.

fisher_lda.py
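The contents of fisher_lda.py are not shown here. The following is only a minimal sketch of a similar experiment and is not necessarily what that script does. It assumes scikit-learn's `LinearDiscriminantAnalysis` applied to the iris dataset. That class handles more than two categories by assuming a shared covariance matrix, and it can project onto at most (number of classes - 1) discriminant directions, which is 2 for the three iris species.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

iris = load_iris()
X, y = iris.data, iris.target          # 150 samples, 4 features, 3 species

# Project onto at most (n_classes - 1) = 2 discriminant directions
lda = LinearDiscriminantAnalysis(n_components=2)
Z = lda.fit_transform(X, y)

# Per-species mean of the first discriminant coordinate
for label, name in enumerate(iris.target_names):
    print(name, Z[y == label, 0].mean())
```

Passing `Z[:, 0]` and `Z[:, 1]` to matplotlib's `scatter`, colored by `y`, should give a two-dimensional view of the projected species.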


Result

For this experiment I used the iris dataset bundled with scikit-learn.

The result is shown below. Setosa is cleanly separated from the other two species. Versicolor and virginica partially overlap but still appear reasonably well separated.

[Figure: the iris dataset projected by Fisher's linear discriminant analysis]

References

- scikit-learn User Guide, "1.2. Linear and Quadratic Discriminant Analysis"
- Yuzo Hirai, "First Pattern Recognition", Morikita Publishing, 2012.
