[PYTHON] Arbre de décision (classification)

Qu'est-ce qu'un arbre de décision?

Dans le domaine de l'apprentissage automatique, l'arbre de décision est un modèle prédictif, et les résultats de l'observation pour un certain élément conduisent à la conclusion sur la valeur cible de l'élément. Les nœuds internes correspondent aux variables et les branches aux nœuds enfants indiquent les valeurs possibles de la variable. La feuille (point final) représente la valeur prédite de la variable objectif par rapport à la valeur de la variable représentée par l'itinéraire depuis la racine. wikipedia

Ceci est difficile à comprendre, donc cela ressemble à ceci lorsqu'il est représenté dans un diagramme. 207486.jpg

Il s'agit d'une classification des arbres de décision qui détermine s'il faut ou non jouer au tennis. De cette manière, les décisions quotidiennes peuvent également être représentées par un arbre de décision.

À propos de la profondeur de l'arbre de décision

Il y a une profondeur appelée la profondeur de l'arbre de décision, et dans l'exemple de jouer ou non au tennis, le temps → humidité correspond à la profondeur de l'arbre de décision. Par contre, non seulement le taux d'humidité est bas, mais en réponse au cas de «parce qu'il a plu la veille, quel est l'état du manteau maintenant?», «Quel temps faisait-il la veille? Vous pouvez le faire plus rigoureusement en posant beaucoup de questions. Cependant, si vous ajoutez de plus en plus de règles de classification à l'aide de la méthode ci-dessus, vous vous retrouverez avec un grand nombre de nœuds et un arbre de décision très profond. Approprié) est affiché et un modèle à forte variance (la quantité de variation des valeurs prédites) a été complété.

Relation entre max_depth (profondeur de l'arbre de décision) et précision

Lorsque max_depth est égal à 3-4, la précision devient maximale, mais lorsque max_depth est égale ou supérieure à 4, la précision diminue. Cela est dû au surapprentissage causé par le fait de rendre max_depth trop grand. De cette manière, l'arbre déterminé doit être élagué (coupé) à la profondeur appropriée. Vous pouvez améliorer les performances de généralisation du modèle par l'élagage. Dans le programme, max_depth, qui maximise la précision, correspond à 3.

download.png

Description des données

Il s'agit d'un ensemble de données qui résume les données de diagnostic du cancer du sein. Chaque cas a 32 valeurs, y compris des valeurs d'inspection, ce qui en fait un ensemble de données riche en variables. Chaque cas est accompagné d'un résultat de diagnostic d'une tumeur bénigne ou d'une tumeur maligne, et l'apprentissage de la classification est effectué en utilisant cela comme variable objective.

code


%%time
import matplotlib.pyplot as plt 
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

#Lire les données sur le cancer du sein
cancer_data = load_breast_cancer()

#Séparation des données d'entraînement et des données de test
train_X, test_X, train_y, test_y = train_test_split(cancer_data.data, cancer_data.target, random_state=0)

#Réglage de la condition
max_score = 0
accuracy = []
depth_list = [i for i in range(1, 15)]

#Exécution de l'arbre de décision
for depth in tqdm(depth_list):
    clf = DecisionTreeClassifier(criterion="entropy", splitter="random", max_depth=depth, min_samples_split=4, min_samples_leaf=1, random_state=56)
    clf.fit(train_X, train_y)
    accuracy.append(clf.score(test_X, test_y))
    if max_score < clf.score(test_X, test_y):
        max_score = clf.score(test_X, test_y)
        depth_ = depth

#Convertir l'arbre de décision en fichier dot
clf = DecisionTreeClassifier(criterion="entropy", splitter="random", max_depth=depth_, min_samples_split=4, min_samples_leaf=1, random_state=56)
clf.fit(train_X, train_y)
tree.export_graphviz(clf, out_file="tree.dot", feature_names=cancer_data.feature_names, class_names=cancer_data.target_names, filled=True)

#Graphique graphique
plt.plot(depth_list, accuracy)
plt.title("Accuracy change")
plt.xlabel("max_depth")
plt.ylabel("Accuracy")
plt.show()

print("max_depth:{}".format(depth_))
print("Meilleur score:{}".format(max_score))

production

max_depth:3
Meilleur score:0.9300699300699301
CPU times: user 228 ms, sys: 6.02 ms, total: 234 ms
Wall time: 237 ms

Résumé

Je pense que l'arbre de décision est un algorithme assez facile à comprendre dans l'apprentissage automatique, j'ai donc pensé qu'il serait facile à utiliser pour expliquer à des personnes qui ne comprennent pas l'apprentissage automatique.

Recommended Posts

Arbre de décision (classification)
Arbre de décision (load_iris)
Arbre de décision et forêt aléatoire
Qu'est-ce qu'un arbre de décision?
[Python] Tutoriel personnel sur l'arbre de décision
Machine Learning: Supervisé - Arbre de décision
Arbre de décision (pour les débutants) -Édition de code-
Créer un arbre déterminé avec scikit-learn
Apprentissage automatique ③ Résumé de l'arbre de décision
Comprendre l'arbre de décision et classer les documents
J'ai essayé de comprendre l'arbre de décision (CART) pour classer soigneusement
Créer un arbre de décision à partir de 0 avec Python (1. Présentation)