2. Multivariate analysis spelled out in Python 7-1. Decision tree (scikit-learn)

**Here, let's first go through an example of a classification tree.**

⑴ Import libraries

#Class that builds a decision tree classification model
from sklearn.tree import DecisionTreeClassifier
#Module with decision tree utilities (export_graphviz is used below)
from sklearn import tree

#Sample datasets bundled with scikit-learn
from sklearn import datasets
#Utility for splitting data into training and test sets
from sklearn.model_selection import train_test_split

#Class to display images in a Notebook
from IPython.display import Image
#Library that turns DOT data into a graph (rendering requires Graphviz)
import pydotplus
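pydotplus only parses DOT text in Python; rendering it to an image (for example with `create_png()`) calls the external Graphviz `dot` program. A quick check before step ⑸ might look like this. Using `shutil.which` is an assumption on my part as a rough proxy: pydotplus has its own search logic, so treat this as a hint, not a guarantee.

```python
import shutil

# Look for the Graphviz "dot" executable on PATH.
# pydotplus needs Graphviz to render images; this check is only an approximation
# of how pydotplus itself locates the program.
dot_path = shutil.which("dot")
if dot_path is None:
    print("Graphviz 'dot' was not found on PATH; install Graphviz before calling create_png().")
else:
    print("Graphviz found at", dot_path)
```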

⑵ Loading the data

iris = datasets.load_iris()
print(iris)
The iris dataset contains the following variables:

| No. | Variable name | Meaning | Note | Data type |
|---|---|---|---|---|
| 1 | sepal length | Sepal length | Continuous (cm) | float64 |
| 2 | sepal width | Sepal width | Continuous (cm) | float64 |
| 3 | petal length | Petal length | Continuous (cm) | float64 |
| 4 | petal width | Petal width | Continuous (cm) | float64 |
| 5 | species | Species | setosa=0, versicolor=1, virginica=2 | int64 |
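If you prefer to look at the data as a table, scikit-learn can return it as a pandas DataFrame. The sketch below assumes scikit-learn 0.23 or later (where `as_frame=True` was added) and pandas installed; the `species` column is something I add for readability by mapping the integer codes 0, 1, 2 to `target_names`.

```python
from sklearn.datasets import load_iris

# as_frame=True (scikit-learn >= 0.23) returns pandas objects instead of NumPy arrays
iris = load_iris(as_frame=True)
df = iris.frame  # four feature columns plus a "target" column

# Add a readable species column by indexing target_names with the integer codes
df["species"] = iris.target_names[df["target"].to_numpy()]

print(df.dtypes)
print(df.head())
```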
#Names of the explanatory variables (features)
print(iris.feature_names)

#Shape of the explanatory variables (rows, columns)
print(iris.data.shape)

#Show the first 5 rows of the explanatory variables
iris.data[0:5, :]

[Figure 2_7_1_01.PNG: first five rows of the explanatory variables]

#Class names of the objective variable
print(iris.target_names)

#Shape of the objective variable
print(iris.target.shape)

#Show the objective variable (class codes 0, 1, 2)
iris.target

[Figure 2_7_1_02.PNG: output of the objective variable cells above]
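Before building a classifier it is worth checking how many samples each class has, since a heavily unbalanced target would make plain accuracy misleading. A minimal sketch using `np.bincount` on the integer codes:

```python
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()

# Count how many samples belong to each class code 0, 1, 2
counts = np.bincount(iris.target)
for name, n in zip(iris.target_names, counts):
    print(f"{name}: {n}")
```

The iris data is perfectly balanced, with 50 samples per species.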

⑶ Data preprocessing

#Store explanatory variables and objective variables respectively
X = iris.data
y = iris.target

#Split into training and test sets (by default 25% of the rows go to the test set)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
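With the default `test_size=0.25`, the 150 rows are split into 112 training and 38 test rows. If you want the class proportions kept roughly equal in both subsets, `train_test_split` also accepts `stratify=y`; the stratified variant below is an optional alternative, not what the article uses.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)

# Default split as in the article: 25% test, rows shuffled with a fixed seed
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
print(X_train.shape, X_test.shape)  # (112, 4) (38, 4)

# Optional: stratify=y keeps the class ratios (nearly) identical in both subsets
Xs_train, Xs_test, ys_train, ys_test = train_test_split(X, y, random_state=0, stratify=y)
print(np.bincount(ys_test))  # test-set samples per class
```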

⑷ Building and evaluating the decision tree model

#Create the classifier: Gini impurity as the split criterion, tree depth limited to 3
clf = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=0)

#Fit the decision tree to the training data (fit returns the fitted classifier)
model = clf.fit(X_train, y_train)

#Accuracy (fraction of correct predictions) on the training and test data
print('Accuracy (train): {:.3f}'.format(model.score(X_train, y_train)))
print('Accuracy (test): {:.3f}'.format(model.score(X_test, y_test)))

[Figure 2_7_1_03.PNG: accuracy on the training and test data]
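For a classifier, `score()` is simply the share of samples whose predicted class equals the true class. The sketch below recomputes it by hand and then, as an extra experiment not in the original, refits the tree with several `max_depth` values to show how the depth limit trades training fit against test accuracy. It rebuilds the data and model so it runs on its own.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

model = DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=0).fit(X_train, y_train)

# score() for a classifier equals the fraction of correct predictions
manual_acc = np.mean(model.predict(X_test) == y_test)
print(f"score(): {model.score(X_test, y_test):.3f}, manual: {manual_acc:.3f}")

# Effect of the depth limit (None means no limit on depth)
for depth in [1, 2, 3, 4, None]:
    m = DecisionTreeClassifier(max_depth=depth, random_state=0).fit(X_train, y_train)
    print(depth, f"train={m.score(X_train, y_train):.3f}", f"test={m.score(X_test, y_test):.3f}")
```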

⑸ Drawing a tree diagram

  1. **Convert the decision tree model to DOT data**
  2. **Build a graph from the DOT data**
  3. **Render it as PNG and display it in the Notebook**
#Convert the decision tree model to DOT data
dot_data = tree.export_graphviz(model,                              #Fitted decision tree model
                                out_file = None,                    #Return the DOT source as a string instead of writing a file
                                feature_names = iris.feature_names, #Display names of the features
                                class_names = iris.target_names,    #Display names of the classes
                                filled = True)                      #Color each node by its majority class

#Build a graph object from the DOT data
graph = pydotplus.graph_from_dot_data(dot_data)

#Render as PNG and display it in the Notebook
Image(graph.create_png())

[Figure 2_7_1_04.PNG: decision tree diagram]
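If Graphviz is not available, scikit-learn (0.21 and later) can draw a similar diagram directly with matplotlib through `sklearn.tree.plot_tree`. This is an alternative to the pydotplus route above, not the article's method; the output filename `iris_plot_tree.png` is hypothetical.

```python
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=0)
model = DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=0).fit(X_train, y_train)

# plot_tree draws with matplotlib, so neither Graphviz nor pydotplus is needed
fig, ax = plt.subplots(figsize=(12, 8))
annotations = plot_tree(
    model,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,  # color nodes by majority class, like filled=True in export_graphviz
    ax=ax,
)
fig.savefig("iris_plot_tree.png")  # hypothetical output filename
plt.show()
```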

How to read a tree diagram

[Figure 2_7_1_05.PNG: how to read each node of the tree diagram]
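The numbers shown in each box of the diagram can also be read from the fitted model's `tree_` attribute. The sketch below prints the split rule, Gini impurity, sample count and majority class for every node, and checks the root's impurity against the formula 1 − Σ p_k², where p_k is the share of class k in the node. The helper `gini` is my own illustration, not part of scikit-learn. The majority class is taken with `argmax` on `tree_.value`, which works whether that array holds counts or proportions (its contents changed between scikit-learn versions).

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=0)
model = DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=0).fit(X_train, y_train)
t = model.tree_  # low-level array representation of the fitted tree

def gini(labels):
    """Hypothetical helper: Gini impurity 1 - sum_k p_k^2 of a label array."""
    p = np.bincount(labels) / len(labels)
    return 1.0 - np.sum(p ** 2)

# The root node holds all training samples, so its impurity is the Gini of y_train
print(f"root gini by hand: {gini(y_train):.3f}, stored: {t.impurity[0]:.3f}")

for node in range(t.node_count):
    is_leaf = t.children_left[node] == t.children_right[node]  # both are -1 at a leaf
    predicted = iris.target_names[np.argmax(t.value[node])]    # majority class in the node
    if is_leaf:
        rule = "leaf"
    else:
        rule = f"{iris.feature_names[t.feature[node]]} <= {t.threshold[node]:.2f}"
    print(f"node {node}: {rule}, gini={t.impurity[node]:.3f}, "
          f"samples={t.n_node_samples[node]}, class={predicted}")
```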

Supplement

#Export the diagram to a PNG file
graph.write_png("iris.png")

#Download the file from Google Colaboratory (works only inside Colab)
from google.colab import files
files.download('iris.png')

[Figure 2_7_1_06.PNG]
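Besides a PNG, the split rules can be saved as plain text with `sklearn.tree.export_text` (scikit-learn 0.21 and later), which needs neither Graphviz nor Colab. A minimal sketch; the filename `iris_rules.txt` is hypothetical, and leaves are labeled with the integer class codes 0, 1, 2.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, random_state=0)
model = DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=0).fit(X_train, y_train)

# export_text renders the split rules as indented plain text
rules = export_text(model, feature_names=list(iris.feature_names))
print(rules)

# Save the rules to a text file ("iris_rules.txt" is a hypothetical filename)
with open("iris_rules.txt", "w") as f:
    f.write(rules)
```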
