[PYTHON] What is a decision tree?

What is a decision tree?

- A decision tree is an algorithm that builds a tree-like structure by repeatedly splitting the data according to a simple **criterion**.

- It can be applied to both **classification and regression problems**.

- Decision trees are rarely used on their own; they are mostly used as the building block of applied methods such as random forests.

How are the splitting feature and threshold chosen?

The split is chosen so that **(impurity before the split) - (impurity after the split)** is as large as possible.

In other words, the split is chosen so that **the impurity after the split** becomes as small as possible.

**Impurity** is an index of how mixed the classes of the observations in a node are.

For classification problems, the ideal node contains observations of only one class (impurity = 0).
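The search for a threshold can be sketched in a few lines. This is a minimal illustration, assuming NumPy, a single numeric feature, and the Gini index (one of the impurity functions listed below) as the impurity; `gini` and `best_threshold` are hypothetical helper names, not scikit-learn API. It tries the midpoint between every pair of adjacent sorted values and keeps the threshold with the lowest weighted impurity after the split, which is the same as the largest impurity decrease.

```python
import numpy as np

def gini(labels):
    """Gini index of a set of class labels: 1 - sum_k p_k^2."""
    if len(labels) == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_threshold(x, y):
    """Hypothetical helper: try every midpoint between sorted distinct values of
    one feature and return (threshold, weighted impurity after the split)."""
    values = np.unique(x)  # sorted distinct values
    best_t, best_imp = None, np.inf
    for t in (values[:-1] + values[1:]) / 2:
        left, right = y[x <= t], y[x > t]
        # Impurity after the split = sample-weighted average of the child impurities
        imp = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
        if imp < best_imp:
            best_t, best_imp = t, imp
    return best_t, best_imp

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
y = np.array([0, 0, 0, 1, 1, 1])
t, imp = best_threshold(x, y)
print(t, imp)         # 3.5 0.0
print(gini(y) - imp)  # impurity decrease: 0.5
```

A full tree learner would repeat this search over every candidate feature, keep the best (feature, threshold) pair, and then recurse into each child node.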

Common functions for measuring impurity are:

- Misclassification rate (non-differentiable)
- **Gini index** (differentiable)
- Cross-entropy (differentiable)

(The scikit-learn default is the **Gini index**.)
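The three impurity functions can be compared on a vector of class proportions `p`. This is a sketch assuming NumPy; the function names are hypothetical, and cross-entropy is computed in bits (log base 2). In scikit-learn the choice is made with the `criterion` parameter of `DecisionTreeClassifier` (`"gini"` by default, `"entropy"` as an alternative).

```python
import numpy as np

def misclassification(p):
    # 1 - proportion of the majority class
    return 1.0 - np.max(p)

def gini(p):
    # 1 - sum of squared class proportions
    return 1.0 - np.sum(p ** 2)

def cross_entropy(p):
    # sum_k p_k * log2(1 / p_k); classes with p_k = 0 contribute 0
    p = p[p > 0]
    return np.sum(p * np.log2(1.0 / p))

for p in [np.array([1.0, 0.0]), np.array([0.9, 0.1]), np.array([0.5, 0.5])]:
    print(p, misclassification(p), gini(p), cross_entropy(p))
```

All three are 0 for a pure node and largest when the classes are evenly mixed.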

Concrete example


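Consider a split that sends 54 of 100 observations to the left node (class counts 0, 49, 5) and 46 to the right node (class counts 0, 1, 45). The Gini index of each child node is:

**Left: 1 - (0/54)^2 - (49/54)^2 - (5/54)^2 ≈ 0.168**

**Right: 1 - (0/46)^2 - (1/46)^2 - (45/46)^2 ≈ 0.043**

Therefore, the **overall impurity** after the split is the sample-weighted average **54/100 × 0.168 + 46/100 × 0.043 ≈ 0.110** (using the unrounded child impurities).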
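The arithmetic of this example can be checked directly. This sketch assumes NumPy and a hypothetical helper `gini_from_counts`; the parent node's counts are taken as the sum of the two children's counts, which holds for any binary split. It also computes the quantity the split maximizes, (impurity before the split) - (impurity after the split).

```python
import numpy as np

def gini_from_counts(counts):
    """Gini index of a node given its per-class observation counts."""
    counts = np.asarray(counts, dtype=float)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

left = [0, 49, 5]             # class counts in the left child (54 observations)
right = [0, 1, 45]            # class counts in the right child (46 observations)
parent = np.add(left, right)  # counts before the split: [0, 50, 50]

n_left, n_right = sum(left), sum(right)
n = n_left + n_right

g_left = gini_from_counts(left)
g_right = gini_from_counts(right)
after = n_left / n * g_left + n_right / n * g_right  # weighted impurity after the split
before = gini_from_counts(parent)                     # impurity before the split

print(f"{g_left:.3f} {g_right:.3f} {after:.3f}")  # 0.168 0.043 0.110
print(f"{before - after:.3f}")                    # 0.390
```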

Advantages and disadvantages of decision trees

Merits

- Easy to understand
- Applicable to both classification and regression
- Applicable to a wide range of problems
- No need to standardize the data or create dummy variables
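The "no need to standardize" point can be checked empirically. This sketch assumes scikit-learn's `StandardScaler` and the same `make_moons` setup as Experiment ① below. Each split compares one feature with a threshold, so an increasing rescaling of a feature only moves the thresholds; with the same `random_state`, both trees should partition the data identically and make the same predictions.

```python
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

X, y = make_moons(n_samples=200, noise=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Standardize each feature to mean 0 and standard deviation 1 (fit on training data only)
scaler = StandardScaler().fit(X_train)

raw_tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
scaled_tree = DecisionTreeClassifier(random_state=0).fit(scaler.transform(X_train), y_train)

pred_raw = raw_tree.predict(X_test)
pred_scaled = scaled_tree.predict(scaler.transform(X_test))

# True if standardization changed none of the test predictions
print((pred_raw == pred_scaled).all())
```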

Demerits

- **Large variance** (susceptible to outliers)
- **Easy to overfit** (nonparametric model)
- The prediction surface is not smooth
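The non-smooth prediction surface follows from the fact that every leaf outputs a single constant. This sketch uses synthetic sine data (an assumption, standing in for a real dataset) and shows that a depth-3 regression tree, evaluated on a dense grid, produces no more distinct values than it has leaves, and a depth-3 tree has at most 2^3 = 8 leaves.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Synthetic 1-D data: a noisy sine wave (illustrative assumption)
rng = np.random.RandomState(0)
X = rng.uniform(-3, 3, size=(100, 1))
y = np.sin(X).ravel() + rng.normal(scale=0.2, size=100)

tree = DecisionTreeRegressor(max_depth=3).fit(X, y)

# Predict on a dense grid: the output can only take one value per leaf
grid = np.linspace(-3, 3, 1000).reshape(-1, 1)
pred = tree.predict(grid)

print("leaves:", tree.get_n_leaves(), "distinct predictions:", len(np.unique(pred)))
```

Plotting `pred` against `grid` gives a step function rather than a smooth curve.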

How to avoid overfitting?

To prevent overfitting, it is important to tune the **hyperparameters**.

Specifically, set an upper limit on the depth of the tree (**max_depth**) and a minimum number of observations that each leaf node must contain (**min_samples_leaf**).
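One way to see the effect of `max_depth` is to compare training and test accuracy as the limit grows. This sketch assumes the same `make_moons` data and split as Experiment ① below; a widening gap between training and test accuracy is the usual sign of overfitting. Because the points are continuous and no two share the same coordinates, a tree with no depth limit can separate every training point and reach a training accuracy of 1.0.

```python
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_moons(n_samples=200, noise=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

results = {}
for depth in [1, 2, 3, 5, None]:  # None = no depth limit (scikit-learn's default)
    clf = DecisionTreeClassifier(max_depth=depth, random_state=0).fit(X_train, y_train)
    train_acc, test_acc = clf.score(X_train, y_train), clf.score(X_test, y_test)
    results[depth] = (train_acc, test_acc)
    print(f"max_depth={depth}: train={train_acc:.3f}, test={test_acc:.3f}")
```

The same loop can be run over `min_samples_leaf` values to see how a larger minimum leaf size smooths the model.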

Experiment ① (Classification problem)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Two interleaving half-moons: 200 points, 2 features, binary labels
X, y = make_moons(n_samples=200, noise=0.1, random_state=0)

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# No depth limit (max_depth defaults to None), but every leaf must hold at least 10 samples
tree_clf = DecisionTreeClassifier(min_samples_leaf=10).fit(X_train, y_train)
# Depth limited to 3
tree_clf_3 = DecisionTreeClassifier(max_depth=3).fit(X_train, y_train)

# Accuracy on the held-out test set
print(tree_clf.score(X_test, y_test))
print(tree_clf_3.score(X_test, y_test))


from matplotlib.colors import ListedColormap

def plot_decision_boundary(model, X, y):
    # Build a 100 x 100 grid that covers the data plus a 0.5 margin
    _x1 = np.linspace(X[:, 0].min() - 0.5, X[:, 0].max() + 0.5, 100)
    _x2 = np.linspace(X[:, 1].min() - 0.5, X[:, 1].max() + 0.5, 100)
    x1, x2 = np.meshgrid(_x1, _x2)
    X_new = np.c_[x1.ravel(), x2.ravel()]
    # Predict a class for every grid point and shade each predicted region
    y_pred = model.predict(X_new).reshape(x1.shape)
    custom_cmap = ListedColormap(["mediumblue", "orangered"])
    plt.contourf(x1, x2, y_pred, cmap=custom_cmap, alpha=0.3)

def plot_dataset(X, y):
    # Class 0: blue circles, class 1: red triangles
    plt.plot(X[:, 0][y == 0], X[:, 1][y == 0], "bo", ms=15)
    plt.plot(X[:, 0][y == 1], X[:, 1][y == 1], "r^", ms=15)
    plt.xlabel("$x_1$", fontsize=30)
    plt.ylabel("$x_2$", fontsize=30, rotation=0)

plt.figure(figsize=(24,8))
plt.subplot(121)
plot_decision_boundary(tree_clf,X,y)
plot_dataset(X,y)

plt.subplot(122)
plot_decision_boundary(tree_clf_3,X,y)
plot_dataset(X,y)

plt.show()
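Besides plotting the decision regions, the learned rules can be printed as text, which illustrates the "easy to understand" merit. This sketch refits the depth-3 tree from Experiment ① so that it runs on its own, and uses scikit-learn's `export_text`; the feature names `x1` and `x2` are chosen here to match the plot axes.

```python
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

X, y = make_moons(n_samples=200, noise=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

tree_clf_3 = DecisionTreeClassifier(max_depth=3).fit(X_train, y_train)

# Print the learned if/else rules, one threshold test per line
rules = export_text(tree_clf_3, feature_names=["x1", "x2"])
print(rules)
print("depth:", tree_clf_3.get_depth(), "leaves:", tree_clf_3.get_n_leaves())
```

Each rectangular region in the right-hand plot corresponds to one leaf (one `class:` line) in this printout.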


Experiment ② (regression problem)

import mglearn
from sklearn.tree import DecisionTreeRegressor

# One-dimensional toy regression data from the mglearn package
reg_X, reg_y = mglearn.datasets.make_wave(n_samples=100)

# Unrestricted tree vs. tree limited to depth 3
tree_reg = DecisionTreeRegressor().fit(reg_X, reg_y)
tree_reg_3 = DecisionTreeRegressor(max_depth=3).fit(reg_X, reg_y)

def plot_regression_predictions(model, X, y):
    # Evaluate the model on a dense grid to draw its prediction curve
    x1 = np.linspace(X.min() - 1, X.max() + 1, 500).reshape(-1, 1)
    y_pred = model.predict(x1)
    plt.xlabel("x", fontsize=30)
    plt.ylabel("y", fontsize=30, rotation=0)
    plt.plot(X, y, "bo", ms=15)
    plt.plot(x1, y_pred, "r-", linewidth=6)

plt.figure(figsize=(24, 8))

plt.subplot(121)
plot_regression_predictions(tree_reg, reg_X, reg_y)

plt.subplot(122)
plot_regression_predictions(tree_reg_3, reg_X, reg_y)

plt.show()
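The step-shaped curves in this regression experiment come from how a regression tree predicts: with scikit-learn's default squared-error criterion, each leaf outputs the mean target value of the training samples that fall into it. This sketch uses synthetic sine data (an assumption, since it avoids the mglearn dependency) and checks that property with `apply`, which returns the leaf index of each sample.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Synthetic 1-D data: a noisy sine wave (illustrative assumption)
rng = np.random.RandomState(0)
X = rng.uniform(-3, 3, size=(100, 1))
y = np.sin(X).ravel() + rng.normal(scale=0.2, size=100)

tree = DecisionTreeRegressor(max_depth=3).fit(X, y)

leaf_ids = tree.apply(X)  # leaf index for each training sample
pred = tree.predict(X)
for leaf in np.unique(leaf_ids):
    in_leaf = leaf_ids == leaf
    # Leaf id, number of samples, mean target in the leaf, and the tree's prediction there
    print(leaf, in_leaf.sum(), y[in_leaf].mean(), pred[in_leaf][0])
```

The unrestricted tree in the left-hand plot shows the extreme case: with no depth limit, leaves shrink until they hold very few points, so the curve follows individual observations.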

