[PYTHON] Decision tree (classification)

What is a decision tree?

In machine learning, a decision tree is a predictive model that draws conclusions about an item's target value from observations of that item. Internal nodes correspond to variables, and the branches to child nodes indicate the possible values of that variable. Each leaf (end point) holds the predicted value of the target variable for the combination of variable values along the path from the root to that leaf. (Wikipedia)
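To make the node/branch/leaf vocabulary concrete, here is a small sketch that fits a shallow scikit-learn tree and prints its rules with `sklearn.tree.export_text`. The Iris dataset and `max_depth=2` are arbitrary choices for illustration, not part of the original post. Note that scikit-learn splits numeric features on thresholds (`feature <= t`), so each branch covers a range of values rather than a single value.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Iris is used only as a small illustrative dataset (an assumption, not from the post)
iris = load_iris()

# A shallow tree keeps the printed structure short
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

# Lines with "<=" / ">" are tests at internal nodes (the branches),
# and lines ending in "class: ..." are leaves holding the predicted value
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```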

This is hard to picture from the definition alone, so here is what it looks like as a diagram.

[Figure: decision tree for deciding whether to play tennis]

This is a classification tree that decides whether to play tennis. As this example shows, everyday decisions can also be represented as decision trees.
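The same idea can be written by hand. The sketch below assumes the classic textbook version of the play-tennis tree (outlook first, then humidity or wind); the tree in the figure may differ, and `TENNIS_TREE` and `predict` are hypothetical names for illustration. Prediction simply follows branches from the root until it reaches a leaf.

```python
# Hypothetical encoding of the classic "play tennis" tree; the figure above may differ.
# Internal nodes are dicts {"feature": name, "branches": {value: subtree}},
# leaves are plain strings holding the prediction.
TENNIS_TREE = {
    "feature": "outlook",
    "branches": {
        "sunny": {"feature": "humidity", "branches": {"high": "no", "normal": "yes"}},
        "overcast": "yes",
        "rain": {"feature": "wind", "branches": {"strong": "no", "weak": "yes"}},
    },
}


def predict(node, sample):
    """Follow branches from the root until a leaf (a string) is reached."""
    while isinstance(node, dict):
        value = sample[node["feature"]]  # the question asked at this internal node
        node = node["branches"][value]   # take the branch matching the answer
    return node


print(predict(TENNIS_TREE, {"outlook": "sunny", "humidity": "high", "wind": "weak"}))  # no
print(predict(TENNIS_TREE, {"outlook": "rain", "humidity": "high", "wind": "weak"}))   # yes
```

The depth of this tree is 2: at most two questions are asked before reaching a leaf.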

About the depth of the decision tree

A decision tree has a depth. In the tennis example, the chain of questions weather → humidity determines the depth of the tree. Instead of only asking whether the humidity is low, you could classify more rigorously by asking further questions, such as "Is the humidity high because it rained the day before?", "What state is the court in now?", or "What was the weather the day before?". However, if you keep adding classification rules this way, you end up with a very deep tree with a large number of nodes. This is overfitting to the training data: the result is a model with high variance, meaning its predictions change a lot depending on the particular training data it was fitted on.
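The effect of depth can be checked directly. This sketch is an illustration, not the post's code (`random_state=0` for the trees and `max_depth=3` are assumptions); it compares an unrestricted tree with a depth-limited one on the same breast cancer split used later. Expect the unrestricted tree to reach (near) 100% training accuracy; comparing that with its test accuracy shows how much of the fit is memorization.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
train_X, test_X, train_y, test_y = train_test_split(X, y, random_state=0)

# No max_depth: the tree keeps splitting until every leaf is pure
deep = DecisionTreeClassifier(random_state=0).fit(train_X, train_y)
# A depth limit stops growth early
shallow = DecisionTreeClassifier(max_depth=3, random_state=0).fit(train_X, train_y)

for name, model in [("unlimited", deep), ("max_depth=3", shallow)]:
    print(f"{name}: depth={model.get_depth()}, leaves={model.get_n_leaves()}, "
          f"train={model.score(train_X, train_y):.3f}, test={model.score(test_X, test_y):.3f}")
```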

Relationship between max_depth (decision tree depth) and Accuracy

Accuracy peaks when max_depth is around 3 to 4 and decreases as max_depth grows beyond that. The drop is caused by overfitting from making max_depth too large, so the decision tree must be pruned (cut back) to an appropriate depth. Pruning improves the generalization performance of the model. In the program, the max_depth that maximizes Accuracy on the test data is 3.

[Figure: Accuracy versus max_depth]
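Limiting max_depth stops the tree from growing (pre-pruning). scikit-learn (0.22 and later) also supports minimal cost-complexity post-pruning through the `ccp_alpha` parameter: a full tree is grown and then its weakest subtrees are cut back. This sketch is an addition, not the method used in this post; it lists the effective alphas with `cost_complexity_pruning_path` and shows how the node count shrinks as alpha grows. In practice you would choose alpha by cross-validation rather than by test accuracy.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
train_X, test_X, train_y, test_y = train_test_split(X, y, random_state=0)

# Effective alphas at which subtrees of the fully grown tree get pruned away
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(train_X, train_y)
ccp_alphas = path.ccp_alphas

# Refit one tree per alpha: larger alpha means more aggressive pruning and fewer nodes
node_counts = []
for alpha in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(train_X, train_y)
    node_counts.append(clf.tree_.node_count)
    print(f"alpha={alpha:.4f}  nodes={clf.tree_.node_count}  "
          f"test acc={clf.score(test_X, test_y):.3f}")
```

The last alpha in the path prunes the tree all the way down to a single root node.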

Description of the data

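This dataset summarizes breast cancer diagnostic data (the Breast Cancer Wisconsin (Diagnostic) dataset). In its original form each case has 32 values: an ID, the diagnosis, and 30 features computed from a digitized image of a fine needle aspirate of a breast mass, which makes it a multivariate dataset. scikit-learn's load_breast_cancer provides 569 cases with those 30 features as the input. Each case is labeled with a diagnosis of either a benign or a malignant tumor, and this label is the target variable for the classification.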
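A quick way to confirm the shape and class balance before modeling is to inspect the loaded object. This is a small sketch added for illustration; it only uses attributes that `load_breast_cancer` returns.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer

cancer_data = load_breast_cancer()

# Feature matrix: one row per case, one column per measured feature
print(cancer_data.data.shape)           # (569, 30)
print(cancer_data.target_names)         # ['malignant' 'benign']
# Class counts: index 0 = malignant, index 1 = benign
print(np.bincount(cancer_data.target))  # [212 357]
# First few feature names
print(cancer_data.feature_names[:3])
```

The classes are somewhat imbalanced (about 63% benign), which is worth keeping in mind when reading plain accuracy numbers.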

code


%%time
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
from sklearn.model_selection import train_test_split
from tqdm import tqdm  # progress bar for the depth loop

# Load the breast cancer data
cancer_data = load_breast_cancer()

# Split into training data and test data (default split: 75% / 25%)
train_X, test_X, train_y, test_y = train_test_split(cancer_data.data, cancer_data.target, random_state=0)

# Settings: candidate depths and containers for the results
max_score = 0
accuracy = []
depth_list = list(range(1, 15))

# Fit one decision tree per max_depth and record its test accuracy
for depth in tqdm(depth_list):
    clf = DecisionTreeClassifier(criterion="entropy", splitter="random", max_depth=depth, min_samples_split=4, min_samples_leaf=1, random_state=56)
    clf.fit(train_X, train_y)
    score = clf.score(test_X, test_y)  # score() returns the mean accuracy
    accuracy.append(score)
    if max_score < score:
        max_score = score
        depth_ = depth

# Refit with the best depth and export the tree to a Graphviz dot file
clf = DecisionTreeClassifier(criterion="entropy", splitter="random", max_depth=depth_, min_samples_split=4, min_samples_leaf=1, random_state=56)
clf.fit(train_X, train_y)
tree.export_graphviz(clf, out_file="tree.dot", feature_names=cancer_data.feature_names, class_names=cancer_data.target_names, filled=True)

# Plot Accuracy against max_depth
plt.plot(depth_list, accuracy)
plt.title("Accuracy change")
plt.xlabel("max_depth")
plt.ylabel("Accuracy")
plt.show()

print("max_depth:{}".format(depth_))
print("Best score:{}".format(max_score))

output

max_depth:3
Best score:0.9300699300699301
CPU times: user 228 ms, sys: 6.02 ms, total: 234 ms
Wall time: 237 ms
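One caveat: the program chooses max_depth by looking at test accuracy, so the reported best score of about 0.93 is likely somewhat optimistic, because the test set was also used for model selection. A common alternative is to select max_depth by cross-validation on the training data only and evaluate on the test set once at the end. The sketch below does this with `GridSearchCV`, reusing the post's tree settings; `cv=5` is an assumption, and the selected depth may differ from 3.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

cancer_data = load_breast_cancer()
train_X, test_X, train_y, test_y = train_test_split(
    cancer_data.data, cancer_data.target, random_state=0)

# Same tree settings as the post; only max_depth is searched
base = DecisionTreeClassifier(criterion="entropy", splitter="random",
                              min_samples_split=4, min_samples_leaf=1, random_state=56)

# 5-fold cross-validation on the training set only; the test set stays untouched
search = GridSearchCV(base, param_grid={"max_depth": list(range(1, 15))},
                      cv=5, scoring="accuracy")
search.fit(train_X, train_y)

print("best max_depth:", search.best_params_["max_depth"])
print("mean CV accuracy:", search.best_score_)
# GridSearchCV refits the best model on all training data by default (refit=True)
print("test accuracy:", search.score(test_X, test_y))
```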

Summary

I think the decision tree is one of the easier machine learning algorithms to understand, so it seems well suited to explaining machine learning to people who are unfamiliar with it.
