2. Multivariate analysis spelled out in Python 7-2. Decision tree [differences in splitting criteria]

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

Comparison of the criteria in two-class classification

For a node whose samples belong to class 1 with probability p, the three splitting criteria are: Gini impurity 2p(1-p), entropy -p log2(p) - (1-p) log2(1-p), and misclassification rate 1 - max(p, 1-p). The code below plots all three over 0 <= p <= 1.

#Generate arithmetic progression corresponding to p
xx = np.linspace(0, 1, 50) #Start value 0, end value 1, number of elements 50

plt.figure(figsize=(10, 8))

#Calculate each criterion (entropy is defined as 0 at p = 0 and p = 1; the guard avoids log(0))
gini = [2 * x * (1-x) for x in xx]
entropy = [(-x * np.log2(x) - (1-x) * np.log2(1-x)) if 0 < x < 1 else 0 for x in xx]
misclass = [1 - max(x, 1-x) for x in xx]

#Show graph
plt.plot(xx, gini, label='gini', lw=3, color='b')
plt.plot(xx, entropy, label='entropy', lw=3, color='r')
plt.plot(xx, misclass, label='misclass', lw=3, color='g')

plt.xlabel('p', fontsize=15)
plt.ylabel('criterion', fontsize=15)
plt.legend(fontsize=15)
plt.xticks(fontsize=12) 
plt.yticks(fontsize=12) 
plt.grid()

2_7_2_01.PNG (plot of Gini impurity, entropy, and misclassification rate against p)

#Generate arithmetic progression corresponding to p
xx = np.linspace(0, 1, 50) #Start value 0, end value 1, number of elements 50

plt.figure(figsize=(10, 8))

#Calculate each criterion (entropy written in an algebraically equivalent form; the guard again avoids log(0))
gini = [2 * x * (1-x) for x in xx]
entropy = [(x * np.log((1-x)/x) - np.log(1-x)) / np.log(2) if 0 < x < 1 else 0 for x in xx]
#Entropy scaled by 1/2 so that its maximum (0.5) coincides with that of the Gini impurity
entropy_scaled = [(x * np.log((1-x)/x) - np.log(1-x)) / (2*np.log(2)) if 0 < x < 1 else 0 for x in xx]
misclass = [1 - max(x, 1-x) for x in xx]

#Show graph
plt.plot(xx, gini, label='gini', lw=3, color='b')
plt.plot(xx, entropy, label='entropy', lw=3, color='r', linestyle='dashed')
plt.plot(xx, entropy_scaled, label='entropy(scaled)', lw=3, color='r')
plt.plot(xx, misclass, label='misclass', lw=3, color='g')

plt.xlabel('p', fontsize=15)
plt.ylabel('criterion', fontsize=15)
plt.legend(fontsize=15)
plt.xticks(fontsize=12) 
plt.yticks(fontsize=12) 
plt.grid()

2_7_2_02.PNG (the same plot, with entropy also shown scaled by 1/2)

Differences between the criteria in terms of information gain

1. What is information gain?

Information gain is the decrease in impurity achieved by a split: the impurity of the parent node minus the weighted average of the impurities of the child nodes, each child weighted by its share of the parent's samples. When growing a tree, the split with the larger information gain is preferred.

2. Calculate the information gain for each criterion

The calculations below use a parent node of 8 samples, 4 from each class, split in two candidate ways: branch A produces child nodes with class counts (3, 1) and (1, 3); branch B produces child nodes with class counts (2, 4) and (2, 0). A worked sketch of the computation follows; after that, each criterion is evaluated in turn.
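Before going through the criteria one by one, the definition can be checked with a small generic helper. This is only a minimal sketch, not part of the original article: the names information_gain, gini and impurity are made up for illustration, and the Gini impurity serves as the example impurity function.

def information_gain(parent_p, children, impurity):
    #Impurity of the parent minus the sample-weighted impurity of the children
    #parent_p: class-1 fraction in the parent node
    #children: list of (weight, p) pairs, weight being the child's share of the parent's samples
    return impurity(parent_p) - sum(w * impurity(p) for w, p in children)

def gini(p):
    return 2 * p * (1 - p)

#Branch A: children (3, 1) and (1, 3); branch B: children (2, 4) and (2, 0)
print(information_gain(4/8, [(4/8, 3/4), (4/8, 1/4)], gini)) #0.125
print(information_gain(4/8, [(6/8, 2/6), (2/8, 2/2)], gini)) #approx. 0.167

Swapping the entropy or the misclassification rate in for the impurity argument reproduces the other results computed below.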

➀ Information gain from Gini impurity

#Gini impurity of the parent node
IGg_p = 2 * 1/2 * (1-(1/2))

#Gini impurity of the child nodes of branch A
IGg_A_l = 2 * 3/4 * (1-(3/4)) #left
IGg_A_r = 2 * 1/4 * (1-(1/4)) #right
#Gini impurity of the child nodes of branch B
IGg_B_l = 2 * 2/6 * (1-(2/6)) #left
IGg_B_r = 2 * 2/2 * (1-(2/2)) #right

#Information gain of each branch
IG_gini_A = IGg_p - 4/8 * IGg_A_l - 4/8 * IGg_A_r
IG_gini_B = IGg_p - 6/8 * IGg_B_l - 2/8 * IGg_B_r

print("Information gain of branch A:", IG_gini_A)
print("Information gain of branch B:", IG_gini_B)

2_7_2_04.PNG

➁ Information gain from entropy

#Parent node entropy
IGe_p = (4/8 * np.log((1-4/8)/(4/8)) - np.log(1-4/8)) / (np.log(2))

#Entropy of the child nodes of branch A
IGe_A_l = (3/4 * np.log((1-3/4)/(3/4)) - np.log(1-3/4)) / (np.log(2)) #left
IGe_A_r = (1/4 * np.log((1-1/4)/(1/4)) - np.log(1-1/4)) / (np.log(2)) #right
#Entropy of the child nodes of branch B
IGe_B_l = (2/6 * np.log((1-2/6)/(2/6)) - np.log(1-2/6)) / (np.log(2)) #left
IGe_B_r = (2/2 * np.log((1-2/2+1e-7)/(2/2)) - np.log(1-2/2+1e-7)) / (np.log(2)) #right; +1e-7 avoids taking log(0) for the pure node

#Information gain of each branch
IG_entropy_A = IGe_p - 4/8 * IGe_A_l - 4/8 * IGe_A_r
IG_entropy_B = IGe_p - 6/8 * IGe_B_l - 2/8 * IGe_B_r
print("Information gain of branch A:", IG_entropy_A)
print("Information gain of branch B:", IG_entropy_B)

2_7_2_05.PNG

➂ Information gain due to misclassification rate

#Misclassification rate of parent node
IGm_p = 1 - np.maximum(4/8, 1-4/8)

#Misclassification rate of child node A
IGm_A_l = 1 - np.maximum(3/4, 1-3/4) #left
IGm_A_r = 1 - np.maximum(1/4, 1-1/4) #right
#Misclassification rate of child node B
IGm_B_l = 1 - np.maximum(2/6, 1-2/6) #left
IGm_B_r = 1 - np.maximum(2/2, 1-2/2) #right

#Information gain of each branch
IG_misclass_A = IGm_p - 4/8 * IGm_A_l - 4/8 * IGm_A_r
IG_misclass_B = IGm_p - 6/8 * IGm_B_l - 2/8 * IGm_B_r
print("Information gain of branch A:", IG_misclass_A)
print("Information gain of branch B:", IG_misclass_B)

2_7_2_06.PNG

Summary

                        Branch A   Branch B
Gini impurity           0.125      0.167
Entropy                 0.189      0.311
Misclassification rate  0.250      0.250

Both the Gini impurity and the entropy assign the larger information gain to branch B, so a tree grown with either criterion would pick that split. The misclassification rate gives both branches the same gain and therefore cannot distinguish between them; this insensitivity is the practical difference the comparison illustrates.
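For reference, scikit-learn's DecisionTreeClassifier exposes this choice through its criterion parameter ('gini' or 'entropy'); the misclassification rate is not offered as a splitting criterion. The snippet below is only an illustrative sketch, and the iris dataset is used purely as an example, not taken from this article.

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

#Fit the same tree with each of the two criteria compared above
for criterion in ['gini', 'entropy']:
    tree = DecisionTreeClassifier(criterion=criterion, max_depth=3, random_state=0)
    tree.fit(X_train, y_train)
    print(criterion, tree.score(X_test, y_test))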
