[Python] I thoroughly explained the theory and implementation of decision trees

Introduction

This time, I will summarize the theory of decision trees.

I hope you will follow along with me.

Theory

First, let's summarize the theory of decision trees.

Outline of decision tree

The visualization of the decision tree is as follows.

This time, I visualized a classification of the iris dataset using scikit-learn's export_graphviz. image.png

A decision tree is an algorithm that builds a classification or regression model by splitting the data according to certain conditions, as shown in the image above. Classification trees, which classify, and regression trees, which perform regression, are collectively called decision trees.

Because it is a fairly simple algorithm, it tends to be less accurate than other, more complex algorithms.

However, the interpretability of the model is very high.

It is called a decision tree because the model looks like a tree as shown above.

Typical algorithms include CART and C5.0.

Please refer to the article here for the algorithm.

Here, we will deal with the CART algorithm, which repeatedly performs binary splits.

About the splitting criteria of decision trees

So far, I think you have a rough overview of the decision tree.

Here, we will consider the criteria for branching in classification trees and regression trees.

Classification tree criteria

Now, let's think about the criteria for branching the classification tree.

To state the conclusion first, a classification tree chooses the feature and threshold for each split so that impurity is minimized.

Impurity is a numerical measure of how mixed the classes are, and it is expressed using the misclassification rate, the Gini index, or the cross entropy error. When all the observations in a node belong to a single class, the impurity is 0.

Regression tree criteria

A regression tree defines a cost function, the mean squared error, and selects the features and thresholds so that the weighted sum of the cost functions over the nodes is minimized.

About formulas

Classification tree

Let's express the impurity using mathematical formulas. Consider the following figure. image.png

The proportion of observations of class k in the region $R_m$, which contains $N_m$ observations, is defined as follows.

\hat{p}_{mk} = \frac{1}{N_m} \sum_{x_i\in R_m}^{}I(y_i = k)

Pay attention to the third row from the top of the figure. m is the region number: m = 1 is the region with gini = 0.168, and m = 2 is the region with gini = 0.043. k is the class label; this time the classes are defined as class 1, class 2, and class 3, reading the value field from left to right.

The formula may look difficult, but the actual calculation is as follows.

\hat{p}_{11} = \frac{0}{54} \quad \hat{p}_{12} = \frac{49}{54}  \quad \hat{p}_{13} = \frac{5}{54} 

I hope this gives you a sense of what the formula means.

Using this $\hat{p}_{mk}$, the impurity is expressed by one of the following three functions.

Misclassification rate

\frac{1}{N_m} \sum_{x_i\in R_m}^{}I(y_i \neq k(m)) = 1-\hat{p}_{mk(m)}

Here, $k(m)$ denotes the most frequent class in node $m$.

Gini index

1 - \sum_{k=1}^{K}\hat{p}_{mk}^2

Cross entropy error

-\sum_{k=1}^{K}\hat{p}_{mk}\log\hat{p}_{mk}
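To make the definitions concrete, here is a minimal NumPy sketch (the function impurities and the use of the class counts 0, 49, 5 from the left node of the figure are my own illustration, not sklearn's internal code) that computes $\hat{p}_{mk}$ and the three impurity measures for a single node.

import numpy as np

def impurities(class_counts):
    """Misclassification rate, Gini index, and cross entropy of one node."""
    counts = np.asarray(class_counts, dtype=float)
    p = counts / counts.sum()           # \hat{p}_{mk}: class proportions in the node
    p_nz = p[p > 0]                     # drop zero proportions to avoid log(0)
    misclassification = 1.0 - p.max()   # 1 - \hat{p}_{mk(m)}
    gini = 1.0 - np.sum(p ** 2)
    cross_entropy = -np.sum(p_nz * np.log(p_nz))
    return misclassification, gini, cross_entropy

# Left node in the third row of the figure: class counts 0, 49, 5
print(impurities([0, 49, 5]))  # the Gini index comes out to roughly 0.168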

The impurity function used by default in sklearn is the Gini index, so let's actually calculate it.

image.png

Let's calculate the impurity of the third row using the Gini index.

The Gini index of the node gini = 0.168 on the left is as follows.

1 - (\frac{0}{54})^2 - (\frac{49}{54})^2 - (\frac{5}{54})^2 = 0.168

Naturally, the answer is 0.168. This is the calculation that sklearn performs internally. Let's also calculate the right node, with gini = 0.043.

1 - (\frac{0}{46})^2 - (\frac{1}{46})^2 - (\frac{45}{46})^2 = 0.043

This also matches. Now let's calculate the overall impurity by weighting each node by its number of observations. It becomes the following formula.

\frac{54}{100} \times 0.168 + \frac{46}{100} \times 0.043 \approx 0.111

With this, the overall impurity can be derived. The decision tree builds its model by selecting the features and thresholds that minimize this weighted impurity.
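To double-check this weighted sum, here is a small sketch along the same lines (the gini helper is my own, and the node sizes 54 and 46 are read off the figure).

import numpy as np

def gini(class_counts):
    p = np.asarray(class_counts, dtype=float) / np.sum(class_counts)
    return 1.0 - np.sum(p ** 2)

left, right = [0, 49, 5], [0, 1, 45]      # class counts of the two child nodes
n_left, n_right = sum(left), sum(right)   # 54 and 46 observations
n_total = n_left + n_right                # 100 observations in the parent node

weighted = n_left / n_total * gini(left) + n_right / n_total * gini(right)
print(weighted)  # about 0.110 (0.111 above, since the Gini values were rounded first)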

Regression tree

In a regression tree, the cost function is defined as follows.

\hat{c}_m = \frac{1}{N_m}\sum_{x_i \in R_m}^{}y_i\\
Q_m(T) = \frac{1}{N_m} \sum_{x_i \in R_m}(y_i - \hat{c}_m)^2

The cost function is the mean squared error, because $\hat{c}_m$ represents the average of the observations contained in that node.

Since this cost function is calculated for each node, the features and thresholds are chosen so that the weighted sum of the cost functions is minimized.
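As a minimal sketch of these two formulas (the function node_cost and the sample targets are my own illustration), the node average $\hat{c}_m$ and the cost $Q_m(T)$ can be computed as follows.

import numpy as np

def node_cost(y_in_node):
    """Average prediction and mean squared error of a single node."""
    y = np.asarray(y_in_node, dtype=float)
    c_hat = y.mean()                    # \hat{c}_m: average of the observations in the node
    cost = np.mean((y - c_hat) ** 2)    # Q_m(T): mean squared error around that average
    return c_hat, cost

# Hypothetical target values falling into one node
print(node_cost([1.2, 0.8, 1.0, 1.4]))  # (1.1, 0.05)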

Implementation

By now, you should have understood the theory of classification trees and regression trees.

Now let's implement the decision tree.

Implementation of classification tree

Now let's implement a classification tree.

First, create the data to classify.

from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
import numpy as np
import mglearn
import graphviz

moons = make_moons(n_samples=300, noise=0.2, random_state=0)
X = moons[0]
Y = moons[1]
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state=0, stratify=Y)

plt.figure(figsize=(12, 8))
mglearn.discrete_scatter(X[:, 0], X[:, 1], Y)
plt.show()

image.png

We will create a model to classify this data.

Below is the code.

clf_model = DecisionTreeClassifier(max_depth=3)
clf_model.fit(X_train, Y_train)
print(clf_model.score(X_test, Y_test))

0.8933333333333333

The accuracy was reasonable.

Let's visualize the model of this classification tree. Below is the code.

dot_data = export_graphviz(clf_model)
graph = graphviz.Source(dot_data)
graph.render('moon-tree', format='png')

image.png

For visualization using graphviz, please refer to the article here.

As you can see, the decision tree model is highly interpretable. Since max_depth=3 is set, this model has a depth of 3.

Let's visualize the model after classification. Below is the code.

plt.figure(figsize=(12, 8))
_x1 = np.linspace(X[:, 0].min() - 0.5, X[:, 0].max() + 0.5, 100)
_x2 = np.linspace(X[:, 1].min() - 0.5, X[:, 1].max() + 0.5, 100)
x1, x2 = np.meshgrid(_x1, _x2)
X_stack = np.hstack((x1.ravel().reshape(-1, 1), x2.ravel().reshape(-1, 1)))
y_pred = clf_model.predict(X_stack).reshape(x1.shape)
custom_cmap = ListedColormap(['mediumblue', 'orangered'])
plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap)
mglearn.discrete_scatter(X[:, 0], X[:, 1], Y)
plt.show()

image.png

Please refer to the article here for how to change the color using grid points.

_x1 and _x2 specify the range of the grid points in the x-axis and y-axis directions, and x1, x2 = np.meshgrid(_x1, _x2) creates the grid points.

In X_stack = np.hstack((x1.ravel().reshape(-1, 1), x2.ravel().reshape(-1, 1))), each 100 × 100 grid array is flattened into a one-dimensional array, reshaped into a 10000 × 1 two-dimensional array, and the two arrays are stacked horizontally to form 10000 × 2 data.

In y_pred = clf_model.predict(X_stack).reshape(x1.shape), the model predicts 0 or 1 for each of the 10000 points, and the result is reshaped into 100 × 100 data. On one side of the line separating the data the prediction is 0, and on the other side it is 1.

plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap) draws the filled contours, and the colors are specified by cmap=custom_cmap.

This is the end of the implementation of the classification tree.

Regression tree implementation

Now let's implement a regression tree.

Let's prepare the data and draw it.

import mglearn
from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import export_graphviz
import graphviz

X, Y = mglearn.datasets.make_wave(n_samples=200)

plt.figure(figsize=(12, 8))
plt.plot(X, Y, 'bo', ms=15)
plt.show()

image.png

The data looks like this. Now let's create a model.

tree_reg_model = DecisionTreeRegressor(max_depth=3)
tree_reg_model.fit(X, Y)
print(tree_reg_model.score(X, Y))

0.7755211625482443

You can display $R^2$ with score.

score(self, X, y[, sample_weight]) Returns the coefficient of determination R^2 of the prediction.
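As a quick check (continuing from the code above), the value returned by score should match sklearn.metrics.r2_score computed on the same data.

from sklearn.metrics import r2_score

# score() and r2_score() should agree on the same data
print(tree_reg_model.score(X, Y))
print(r2_score(Y, tree_reg_model.predict(X)))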

The accuracy is not very good.

Let's visualize the model with the following code.

dot_data = export_graphviz(tree_reg_model)
graph = graphviz.Source(dot_data)
graph.render('wave-tree', format='png')

image.png

As you can see, the regression tree (like the classification tree) also produces a highly interpretable model.

Now let's illustrate the regression line with the following code.

X1 = np.linspace(X.min() - 1, X.max() + 1, 1000).reshape(-1, 1)
y_pred = tree_reg_model.predict(X1)
plt.xlabel('x', fontsize=10)
plt.ylabel('y', fontsize=10, rotation=-0)
plt.plot(X, Y, 'bo', ms=5)
plt.plot(X1, y_pred, 'r-', linewidth=3)
plt.show()

image.png

As you can see from the plot, the fit is not very accurate.

Implementation without specifying the depth

Now let's implement it without specifying the depth.

It's the same except that you don't specify the depth, so let's put the same steps together in a function.

import mglearn
from sklearn.tree import DecisionTreeRegressor
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import export_graphviz
import graphviz

X, Y = mglearn.datasets.make_wave(n_samples=200)

tree_reg_model_2 = DecisionTreeRegressor()

tree_reg_model_2.fit(X, Y)

print(tree_reg_model_2.score(X, Y))

1.0

$R^2$ is now 1. What kind of model is it? Let's visualize the branching below.

def graph_export(model):
    dot_data = export_graphviz(model)
    graph = graphviz.Source(dot_data)
    graph.render('test', format='png')

graph_export(tree_reg_model_2)

image.png

The branching turned out to be overwhelming; it is no longer possible to read how the tree splits.

Let's illustrate the model with the following code.

def plot_regression_predictions(tree_reg, x, y):
    x1 = np.linspace(x.min() - 1, x.max() + 1, 500).reshape(-1, 1)
    y_pred = tree_reg.predict(x1)
    plt.xlabel('x', fontsize=10)
    plt.ylabel('y', fontsize=10, rotation=-0)
    plt.plot(x, y, 'bo', ms=5)
    plt.plot(x1, y_pred, 'r-', linewidth=1)
    plt.show()


plot_regression_predictions(tree_reg_model_2, X, Y)

image.png

As you can see from the figure above, this is clearly overfitting.

This makes it impossible to predict unknown data, so you can understand the importance of setting the depth of the tree appropriately.
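As one way to see this (a sketch, not part of the original experiment: it adds a train/test split to the wave data, which the regression examples above did not use), you can compare the train and test $R^2$ for several depths.

import mglearn
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X, Y = mglearn.datasets.make_wave(n_samples=200)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, random_state=0)

# An unlimited tree fits the training data perfectly but generalizes worse,
# while a moderate max_depth balances the two scores
for depth in [1, 2, 3, 5, 10, None]:
    model = DecisionTreeRegressor(max_depth=depth, random_state=0)
    model.fit(X_train, Y_train)
    print(depth, model.score(X_train, Y_train), model.score(X_test, Y_test))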

At the end

That's all for this time.

Thank you for reading this far.
