[PYTHON] Machine Learning: Supervised --Decision Tree


Understand the decision tree and try it with scikit-learn.


A decision tree classifies data by building a tree structure in which thresholds are set on the features that are important for classification.

The decision tree is a highly interpretable classification model: by visualizing the tree structure, you can see which features are important for classification and at which threshold each split is made. It can also be used for regression.

Decision tree

There are several types of decision tree algorithms, but here we follow the CART algorithm used by scikit-learn.

Decision tree concept

In a tree structure, the start of the tree is called the root node and the ends of the tree are called leaf nodes, as shown in the figure below. For any two connected nodes, the upper one is called the parent node and the lower one is called the child node.


In CART, starting from the root node, each node is split on the feature and threshold that maximize the information gain. Splitting is repeated until the leaf nodes are pure, that is, until all the data contained in each leaf node belong to the same category.

However, splitting until the leaf nodes are pure results in overfitting, so pruning is used to alleviate this.

Decision tree learning

A decision tree is trained by maximizing the information gain $ IG $ in the equation below.

IG(D_{parent}, f) = I(D_{parent}) - \frac{N_{left}}{N_{parent}} I(D_{left}) - \frac{N_{right}}{N_{parent}} I(D_{right})

Here, $ f $ is the feature used for the split, $ D_{parent} $ is the data contained in the parent node, $ D_{left} $ and $ D_{right} $ are the data of the left and right child nodes, $ N_{parent} $ is the number of data points in the parent node, $ N_{left} $ and $ N_{right} $ are the numbers of data points in the left and right child nodes, and $ I $ is the impurity described below.

During training, the lower the impurity of the left and right child nodes, the larger the information gain. At each node, the feature and threshold that give the largest information gain are chosen for the split.

The following three are typical measures of impurity. Here, $ C_i (i = 1, 2, ..., K) $ are the $ K $ categories, $ t $ is a node, and $ P (C_i | t) $ is the probability that data in node $ t $ belong to category $ C_i $.
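The information gain formula above can be sketched in a few lines of NumPy. This is a minimal illustration, not scikit-learn's implementation: the function names `gini` and `information_gain` and the toy data are made up here, and Gini (described later in this article) is used as the impurity $ I $.

```python
import numpy as np

def gini(labels):
    """Gini impurity of a 1-D array of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def information_gain(feature, labels, threshold, impurity=gini):
    """IG for splitting into `feature <= threshold` (left) and `> threshold` (right)."""
    left = labels[feature <= threshold]
    right = labels[feature > threshold]
    n_parent = len(labels)
    return (impurity(labels)
            - len(left) / n_parent * impurity(left)
            - len(right) / n_parent * impurity(right))

# Toy data (made up for illustration): one feature, two categories.
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
y = np.array([0, 0, 0, 1, 1, 1])

ig_perfect = information_gain(x, y, threshold=3.5)  # both child nodes are pure
ig_poor = information_gain(x, y, threshold=1.5)     # right child node stays mixed
print(ig_perfect, ig_poor)
```

The threshold that separates the two categories cleanly gives the larger information gain, which is why it would be chosen for the split.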

The classification error $ I_E $ is not sensitive to changes in the category probabilities within a node and is used for tree pruning, as described below.

I_E = 1 - \max_i P(C_i | t)

Entropy $ I_H $ is 0 when all the data contained in the node belong to the same category, and is largest when the categories are evenly mixed.

I_H = -\sum^K_{i=1} P(C_i | t) \ln P(C_i | t)

Gini impurity $ I_G $ can be interpreted as a criterion for minimizing the probability of misclassification. Like entropy, it is 0 when all the data contained in the node belong to the same category.

I_G = 1 - \sum^K_{i=1} P^2 (C_i | t)

The shape of each function is shown below. Gini impurity is the default criterion in scikit-learn.
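To compare the three measures numerically, here is a small sketch for the two-category case. The function names are made up for this example; each one takes the vector of probabilities $ P (C_i | t) $ for a single node.

```python
import numpy as np

def classification_error(p):
    """I_E = 1 - max_i P(C_i | t)"""
    return 1.0 - np.max(p)

def entropy(p):
    """I_H = -sum_i P(C_i | t) ln P(C_i | t), treating 0 * ln(0) as 0."""
    p = p[p > 0]
    # sum p * ln(1 / p) is the same quantity as -sum p * ln(p)
    return float(np.sum(p * np.log(1.0 / p)))

def gini(p):
    """I_G = 1 - sum_i P(C_i | t)^2"""
    return 1.0 - np.sum(p ** 2)

for p1 in (0.0, 0.1, 0.3, 0.5):
    p = np.array([p1, 1.0 - p1])
    print(f"P(C_1|t)={p1:.1f}  I_E={classification_error(p):.3f}  "
          f"I_H={entropy(p):.3f}  I_G={gini(p):.3f}")
```

All three are 0 for a pure node and reach their maximum when the two categories are equally likely.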


Decision tree pruning

During training, the tree is grown deeper until the leaf nodes are pure, but a tree left in that state overfits, so pruning is performed to alleviate this.

As an evaluation criterion for tree pruning, we define the resubstitution error rate, which is the error rate obtained when the training data are fed back into the tree. The resubstitution error rate $ R (t) $ at a node $ t $ is expressed by the following equation using the classification error $ I_E $ and the marginal probability $ P (t) $ of the node $ t $.

R(t) = \left( 1 - \max_i P(C_i | t) \right) \times P(t) \\
P(t) = \frac{N(t)}{N}

Here, $ N (t) $ is the number of data points contained in node $ t $ and $ N $ is the total number of training data points.
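As a worked example of the equation for $ R (t) $, the sketch below evaluates it for one hypothetical node. The function name and the counts are made up for illustration.

```python
def node_error_rate(class_counts, n_total):
    """R(t) = (1 - max_i P(C_i | t)) * P(t) for one node.

    class_counts: number of training samples of each category in node t
    n_total:      total number of training samples N
    """
    n_t = sum(class_counts)                   # N(t)
    i_e = 1.0 - max(class_counts) / n_t       # classification error I_E at node t
    p_t = n_t / n_total                       # P(t) = N(t) / N
    return i_e * p_t

# Hypothetical node: 40 samples of one category and 10 of another,
# out of 200 training samples in total.
r_t = node_error_rate([40, 10], n_total=200)
print(r_t)
```

Multiplying out the two factors shows that $ R (t) $ is simply the number of training samples misclassified at node $ t $ divided by $ N $, here 10 / 200.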

Tree pruning removes tree branches based on this error rate. With scikit-learn, you can limit the tree to the required number of leaf nodes with the max_leaf_nodes argument.
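The effect of max_leaf_nodes can be checked with a short sketch like the one below. The dataset, the limit of 5 leaves, and random_state=0 are choices made for this example only.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

# Tree grown without any limit on the number of leaves.
full_tree = DecisionTreeClassifier(random_state=0).fit(X, y)

# Tree restricted to at most 5 leaf nodes.
small_tree = DecisionTreeClassifier(max_leaf_nodes=5, random_state=0).fit(X, y)

print("leaves without limit:        ", full_tree.get_n_leaves())
print("leaves with max_leaf_nodes=5:", small_tree.get_n_leaves())
```

Note that scikit-learn applies this limit while growing the tree rather than by cutting branches afterwards, so it is best read as a practical stand-in for the pruning described above.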


Execution environment


- CPU: Intel(R) Core(TM) i7-6700K 4.00GHz
- Windows 10 Pro 1909
- Python 3.6.6
- Matplotlib 3.3.1
- NumPy 1.19.2
- scikit-learn 0.23.2

Program to run

The implemented program is published on GitHub.




Classification by decision tree

I applied a decision tree to the breast cancer dataset that I have been using so far. Because a decision tree determines a threshold for each feature separately, there is no need to standardize the features as preprocessing.

Accuracy 97.37%
Precision, Positive predictive value(PPV) 97.06%
Recall, Sensitivity, True positive rate(TPR) 98.51%
Specificity, True negative rate(TNR) 95.74%
Negative predictive value(NPV) 97.83%
F-Score 97.78%
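An experiment of this kind can be sketched as follows. This is not the program published on GitHub: the train/test split, max_depth=3, and random_state=0 are assumptions made for this example, so the scores it prints will not necessarily match the ones listed above.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

# Hold out 20% of the data for evaluation (assumed split, not the author's).
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0)

# No standardization step: the tree only compares each feature with a threshold.
clf = DecisionTreeClassifier(max_depth=3, random_state=0)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)

accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f_score = f1_score(y_test, y_pred)

print(f"Accuracy  {accuracy:.2%}")
print(f"Precision {precision:.2%}")
print(f"Recall    {recall:.2%}")
print(f"F-Score   {f_score:.2%}")
```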

Visualization and interpretability of decision trees

scikit-learn provides a plot_tree function that visualizes the decision tree, making it easy to see the tree structure of the trained model.
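A minimal sketch of calling plot_tree is shown below. The model settings and the output file name decision_tree.png are assumptions for this example, and the non-interactive Agg backend is selected only so that the script also runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove this to show a window instead
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier, plot_tree

data = load_breast_cancer()
clf = DecisionTreeClassifier(max_depth=3, random_state=0)
clf.fit(data.data, data.target)

fig, ax = plt.subplots(figsize=(14, 7))
annotations = plot_tree(
    clf,
    feature_names=data.feature_names.tolist(),  # label splits with feature names
    class_names=data.target_names.tolist(),     # label nodes with the majority category
    filled=True,                                # color nodes by majority category
    ax=ax,
)
fig.savefig("decision_tree.png")  # hypothetical output file name
plt.close(fig)
```

Each box in the resulting figure shows the split condition, the impurity, the number of samples, and the category counts for that node.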


You can also display the criteria for judgment at each node of the trained model on the command line as follows:

The binary tree structure has 9 nodes and has the following tree structure:
node=0 test node: go to node 1 if X[:, 27] <= 0.1423499956727028 else to node 2.
	node=1 test node: go to node 3 if X[:, 23] <= 957.4500122070312 else to node 4.
	node=2 test node: go to node 5 if X[:, 23] <= 729.5499877929688 else to node 6.
		node=3 leaf node.
		node=4 leaf node.
		node=5 test node: go to node 7 if X[:, 4] <= 0.10830000042915344 else to node 8.
		node=6 leaf node.
			node=7 leaf node.
			node=8 leaf node.

Rules used to predict sample 0: 
decision id node 0 : (X_test[0, 27](= 0.2051) > 0.1423499956727028)
decision id node 2 : (X_test[0, 23](= 844.4) > 729.5499877929688)

The following samples [0, 1] share the node [0] in the tree
It is 11.11%% of all nodes.
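Output of this kind can be produced from the tree_ attribute of a trained model. The sketch below is a simplified version under assumed settings (max_depth=3, random_state=0, trained on the full dataset), so the node numbers and thresholds it prints will differ from the listing above, and it prints nodes in depth-first order.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)

tree = clf.tree_
print(f"The binary tree structure has {tree.node_count} nodes:")

def describe(node=0, depth=0):
    """Print the node and, recursively, its children, indented by depth."""
    indent = "\t" * depth
    left, right = tree.children_left[node], tree.children_right[node]
    if left == right:  # both children are -1 for a leaf node
        print(f"{indent}node={node} leaf node.")
        return
    print(f"{indent}node={node} test node: go to node {left} "
          f"if X[:, {tree.feature[node]}] <= {tree.threshold[node]} "
          f"else to node {right}.")
    describe(left, depth + 1)
    describe(right, depth + 1)

describe()

# Nodes that the first sample passes through, from the root to a leaf.
node_indicator = clf.decision_path(X[:1])
path = node_indicator.indices[node_indicator.indptr[0]:node_indicator.indptr[1]]
print("Nodes visited by sample 0:", path.tolist())
```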

The figure below shows the decision boundaries obtained when performing multiclass classification on the iris dataset.


Regression by decision tree

The data for the regression problem are a sine wave with added random noise. It can be seen that deepening the tree improves its expressive power.
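The effect of depth can be checked numerically with a sketch like the following. The number of samples, the noise level, and the two depths are assumptions made for this example.

```python
import numpy as np
from sklearn.metrics import mean_squared_error
from sklearn.tree import DecisionTreeRegressor

# Noisy sine wave (assumed settings: 80 points on [0, 5), noise std 0.1).
rng = np.random.RandomState(0)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel() + 0.1 * rng.randn(80)

train_mse = {}
for depth in (2, 5):
    reg = DecisionTreeRegressor(max_depth=depth, random_state=0).fit(X, y)
    train_mse[depth] = mean_squared_error(y, reg.predict(X))
    print(f"max_depth={depth}: training MSE = {train_mse[depth]:.4f}")
```

The deeper tree follows the training data more closely. As with classification, a tree that is too deep starts fitting the noise, so a lower training error alone does not mean a better model.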


The figure below shows an example of a decision tree applied to a multi-output regression problem.
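In scikit-learn, multi-output regression needs no special API: passing a two-dimensional target to fit is enough. A minimal sketch, with data made up for illustration (the two outputs are a noisy sine and cosine of the same input):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.RandomState(0)
X = np.sort(2 * np.pi * rng.rand(100, 1), axis=0)

# Two target columns predicted from the same input.
Y = np.column_stack([np.sin(X).ravel(), np.cos(X).ravel()])
Y += 0.05 * rng.randn(*Y.shape)

reg = DecisionTreeRegressor(max_depth=5, random_state=0).fit(X, Y)

X_new = np.array([[0.5], [3.0]])
pred = reg.predict(X_new)
print(pred.shape)  # (2, 2): one row per sample, one column per output
print(pred)
```

A single tree is fitted for both outputs, so every leaf stores one predicted value per output.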



