[PYTHON] Entscheidungsbaum (Klassifikation)

Was ist ein Entscheidungsbaum?

Im Bereich des maschinellen Lernens ist der Entscheidungsbaum ein Vorhersagemodell, und die Beobachtungsergebnisse für einen bestimmten Gegenstand führen zu einer Schlussfolgerung über den Zielwert des Gegenstands. Die internen Knoten entsprechen den Variablen, und die Verzweigungen zu den untergeordneten Knoten geben die möglichen Werte der Variablen an. Das Blatt (Endpunkt) repräsentiert den vorhergesagten Wert der Zielvariablen in Bezug auf den Variablenwert, der durch die Route von der Wurzel dargestellt wird. wikipedia

Dies ist schwer zu verstehen, daher sieht es so aus, wenn es tatsächlich in einem Diagramm dargestellt wird. 207486.jpg

Dies ist eine Klassifizierung von Entscheidungsbäumen, die bestimmt, ob Tennis gespielt werden soll oder nicht. Auf diese Weise können alltägliche Entscheidungen auch durch einen Entscheidungsbaum dargestellt werden.

Über die Tiefe des Entscheidungsbaums

Es gibt eine Tiefe, die als Tiefe des Entscheidungsbaums bezeichnet wird, und im Beispiel, ob Tennis gespielt werden soll oder nicht, entspricht das Wetter → Luftfeuchtigkeit der Tiefe des Entscheidungsbaums. Auf der anderen Seite ist nicht nur die Luftfeuchtigkeit niedrig, sondern als Reaktion auf den Fall "Weil es am Tag zuvor geregnet hat, wie ist der Zustand des Mantels jetzt?" Sie können es strenger machen, indem Sie viele Fragen stellen. Wenn Sie jedoch mit der oben beschriebenen Methode immer mehr Klassifizierungsregeln hinzufügen, erhalten Sie eine große Anzahl von Knoten und einen sehr tiefen Entscheidungsbaum. Dies ist ein Übertraining (auch eine Überanpassung) für Trainingsdaten. Geeignet) wird gezeigt, und ein Modell mit hoher Varianz (das Ausmaß der Variation der vorhergesagten Werte) wurde fertiggestellt.

Beziehung zwischen max_depth (Tiefe des Entscheidungsbaums) und Genauigkeit

Wenn max_depth 3-4 ist, wird die Genauigkeit maximal, aber wenn max_depth 4 oder höher ist, nimmt die Genauigkeit ab. Dies ist auf Überlernen zurückzuführen, das dadurch verursacht wird, dass max_depth zu groß wird. Auf diese Weise muss der ermittelte Baum auf die entsprechende Baumtiefe beschnitten (geschnitten) werden. Sie können die Generalisierungsleistung des Modells durch Bereinigen verbessern. Im Programm ist max_depth, das die Genauigkeit maximiert, wenn 3.

download.png

Beschreibung der Daten

Dies ist ein Datensatz, der diagnostische Daten für Brustkrebs zusammenfasst. Jeder Fall hat 32 Werte, einschließlich Inspektionswerte, was ihn zu einem variablen Datensatz macht. Jeder Fall wird von einem Diagnoseergebnis eines gutartigen Tumors oder eines bösartigen Tumors begleitet, und das Klassifizierungslernen wird unter Verwendung dieses als Zielvariable durchgeführt.

Code


%%time
import matplotlib.pyplot as plt 
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

#Brustkrebsdaten lesen
cancer_data = load_breast_cancer()

#Trennung von Trainingsdaten und Testdaten
train_X, test_X, train_y, test_y = train_test_split(cancer_data.data, cancer_data.target, random_state=0)

#Zustandseinstellung
max_score = 0
accuracy = []
depth_list = [i for i in range(1, 15)]

#Ausführung des Entscheidungsbaums
for depth in tqdm(depth_list):
    clf = DecisionTreeClassifier(criterion="entropy", splitter="random", max_depth=depth, min_samples_split=4, min_samples_leaf=1, random_state=56)
    clf.fit(train_X, train_y)
    accuracy.append(clf.score(test_X, test_y))
    if max_score < clf.score(test_X, test_y):
        max_score = clf.score(test_X, test_y)
        depth_ = depth

#Konvertieren Sie den Entscheidungsbaum in eine Punktdatei
clf = DecisionTreeClassifier(criterion="entropy", splitter="random", max_depth=depth_, min_samples_split=4, min_samples_leaf=1, random_state=56)
clf.fit(train_X, train_y)
tree.export_graphviz(clf, out_file="tree.dot", feature_names=cancer_data.feature_names, class_names=cancer_data.target_names, filled=True)

#Diagrammplot
plt.plot(depth_list, accuracy)
plt.title("Accuracy change")
plt.xlabel("max_depth")
plt.ylabel("Accuracy")
plt.show()

print("max_depth:{}".format(depth_))
print("Bestes Ergebnis:{}".format(max_score))

Ausgabe

max_depth:3
Bestes Ergebnis:0.9300699300699301
CPU times: user 228 ms, sys: 6.02 ms, total: 234 ms
Wall time: 237 ms

Zusammenfassung

Ich denke, der Entscheidungsbaum ist ein ziemlich leicht verständlicher Algorithmus für maschinelles Lernen, daher dachte ich, er wäre einfach zu verwenden, wenn man Leuten erklärt, die maschinelles Lernen nicht verstehen.

Recommended Posts

Entscheidungsbaum (Klassifikation)
Entscheidungsbaum (load_iris)
Entscheidungsbaum und zufälliger Wald
Was ist ein Entscheidungsbaum?
[Python] Persönliches Tutorial zum Entscheidungsbaum
Maschinelles Lernen: Überwacht - Entscheidungsbaum
Entscheidungsbaum (für Anfänger) -Code Edition-
Erstellen eines bestimmten Baums mit Scikit-Learn
Maschinelles Lernen ③ Zusammenfassung des Entscheidungsbaums
Verstehen Sie den Entscheidungsbaum und klassifizieren Sie Dokumente
Ich habe versucht, den entscheidenden Baum (CART) zu verstehen, um ihn sorgfältig zu klassifizieren
Erstellen Sie mit Python einen Entscheidungsbaum von 0 (1. Übersicht)