This entry explains how to visualize a decision tree model in scikit-learn. I use this technique a lot these days, so I am writing it down as a memo and personal cheat sheet. The sample code in this entry was built with the Windows version of Python 3.5.2.
The components required to visualize the decision tree are Graphviz, scikit-learn, and pydotplus.
Graphviz has a different installation method for each OS. I think scikit-learn is often included by default. pydotplus, on the other hand, needs to be installed with pip.
Graphviz stands for Graph Visualization Software. It is a tool that renders graphs described in the DOT language as images. See the documentation for details: http://www.graphviz.org/Documentation.php
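To give a feel for what DOT text looks like, here is a minimal sketch that builds a DOT description of a tiny hand-made tree as a Python string. The helper `make_dot`, the node numbers, and the labels are hypothetical illustrations, not part of Graphviz or scikit-learn.

```python
def make_dot(edges, labels):
    """Build a DOT digraph string from (parent, child) edges and node labels."""
    lines = ["digraph Tree {", "  node [shape=box];"]
    # One line per node, carrying its display label
    for node in sorted(labels):
        lines.append('  {} [label="{}"];'.format(node, labels[node]))
    # One line per directed edge
    for parent, child in edges:
        lines.append("  {} -> {};".format(parent, child))
    lines.append("}")
    return "\n".join(lines)


# Hypothetical three-node tree: one split and two leaves
labels = {0: "petal width <= 0.8", 1: "setosa", 2: "other"}
edges = [(0, 1), (0, 2)]
dot_source = make_dot(edges, labels)
print(dot_source)
```

If this text is saved to a file such as tree.dot, the Graphviz command `dot -Tpng tree.dot -o tree.png` renders it as an image once Graphviz is installed.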
The download pages are as follows.
-- Windows version: http://www.graphviz.org/Download_windows.php
-- RHEL, CentOS version: http://www.graphviz.org/Download_linux_rhel.php
-- Ubuntu version: http://www.graphviz.org/Download_linux_ubuntu.php
-- Source version: http://www.graphviz.org/Download_source.php
Since this entry installs Graphviz in a Windows environment, download the MSI file from the following page and run it. http://www.graphviz.org/Download_windows.php
When you run the downloaded MSI file, the first screen of the setup wizard is displayed. Click Next to proceed.
The version at the time of writing the entry (2017/09/03) is 2.38. Here, proceed with "Everyone" selected so that all users can use it.
This message informs you that the installation is ready. Press "Next" to proceed.
As the component installation progresses, the progress bar fills up. Click Next when the bar is completely filled.
You have successfully installed Graphviz. Click Close to close the window.
Next, install pydotplus.
pydotplus is a Python module for working with the DOT language mentioned earlier. Since this is a Windows environment, we will work in Anaconda Prompt.
After launching Anaconda Prompt, run the command "pip install pydotplus".
(C:\Program Files\Anaconda3) C:\Users\usr********>pip install pydotplus
Collecting pydotplus
Downloading pydotplus-2.0.2.tar.gz (278kB)
100% |################################| 286kB 860kB/s
Requirement already satisfied: pyparsing>=2.0.1 in c:\program files\anaconda3\lib\site-packages (from pydotplus)
Building wheels for collected packages: pydotplus
Running setup.py bdist_wheel for pydotplus ... done
Stored in directory: C:\Users\usr********\AppData\Local\pip\Cache\wheels\43\31\48\e1d60511537b50a8ec28b130566d2fbbe4ac302b0def4baa48
Successfully built pydotplus
Installing collected packages: pydotplus
Successfully installed pydotplus-2.0.2
If the installation succeeds, output like the above is displayed. If no error occurs and the process is not interrupted, the pydotplus installation is complete.
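As a quick sanity check after installation, something like the following sketch can confirm that the Python packages are importable. The helper name `missing_modules` is hypothetical. It only checks that the modules can be found and does not verify their versions.

```python
import importlib.util


def missing_modules(names):
    """Return the names from `names` that cannot be found for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Top-level import names of the packages used in this entry
required = ["sklearn", "pydotplus", "pandas"]
missing = missing_modules(required)
if missing:
    print("Not installed:", ", ".join(missing))
else:
    print("All required modules were found.")
```

Note that scikit-learn is imported as `sklearn`, which is why that name is used in the list.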
Next, edit the environment variables so that pydotplus can find the Graphviz installation. First, find the location of the "bin" directory where Graphviz is installed. You can check it by looking at the properties of "gvedit.exe" in the Start menu.
Add this path ("C:\Program Files (x86)\Graphviz2.38\bin") to the PATH environment variable.
After changing the path, restart the Python IDE (such as PyCharm).
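To check whether the setting took effect, the sketch below looks for the Graphviz `dot` executable on PATH from Python. As an alternative to editing the system settings, it also prepends the bin directory for the current process only. The helper `prepend_to_path` is hypothetical, and the install location is the one assumed in this entry; adjust it to your machine.

```python
import os
import shutil


def prepend_to_path(directory, path_value, sep=os.pathsep):
    """Return path_value with directory placed first, unless already listed."""
    entries = [entry for entry in path_value.split(sep) if entry]
    if directory in entries:
        return path_value
    return sep.join([directory] + entries)


# Assumed install location taken from this entry
graphviz_bin = r"C:\Program Files (x86)\Graphviz2.38\bin"

# Affects only the current Python process, not the system-wide settings
os.environ["PATH"] = prepend_to_path(graphviz_bin, os.environ.get("PATH", ""))

# shutil.which returns the full path of the executable, or None if not found
dot_path = shutil.which("dot")
print("dot executable:", dot_path if dot_path else "not found on PATH")
```

If the last line reports that `dot` was not found, writing the PNG with pydotplus is likely to fail as well, so it is worth rechecking the path before going further.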
I built a RandomForest model using the familiar iris dataset, extracted one decision tree from that model, and visualized it (that is, output it as a PNG image). The sample code is as follows. The print statements used for debugging and for understanding the internal processing are left as they are.
u"""
Visualize the decision tree model.
Visualize a model of a decision tree using Graphviz.
It can be applied not only to decision trees but also to tree-structured models such as random forests.
"""
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
# train_test_split lives in sklearn.model_selection (sklearn.cross_validation is deprecated)
from sklearn.model_selection import train_test_split
#Packages needed to visualize the tree structure of the model
from sklearn import tree
import pydotplus as pdp
import pandas as pd
import numpy as np
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
print(df.head(5))
print(iris.target)
print(iris.target_names)
df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)
print(df.head(5))
#Separate training data and test data
features = df.columns[:4]
label = df["species"]
print(features)
print(label)
print(df[features].head(5))
df_train, df_test, label_train, label_test = train_test_split(df[features], label)
clf = RandomForestClassifier(n_estimators=150)
clf.fit(df_train, label_train)
print("========================================================")
print("Prediction accuracy")
print(clf.score(df_test, label_test))
#Visualize one of the trees to try
estimators = clf.estimators_
file_name = "./tree_visualization.png"
dot_data = tree.export_graphviz(estimators[0],  # Specify one decision tree object
                                out_file=None,  # None returns the DOT data as a string instead of writing a file
                                filled=True,  # When True, each node is colored according to its majority class
                                rounded=True,  # When True, node boxes are drawn with rounded corners
                                feature_names=list(features),  # If not specified, feature names are not shown in the chart
                                class_names=list(iris.target_names),  # If not specified, class names are not shown in the chart
                                special_characters=True  # Allow special characters in the labels
                                )
graph = pdp.graph_from_dot_data(dot_data)
graph.write_png(file_name)
Let's take a look at each part. The following part loads the iris dataset and prepares the training data. The feature names are stored in iris.feature_names, and the objective variable (the iris species) is stored in iris.target. However, iris.target holds integer codes, which are hard for humans to read. Therefore, using the species names in iris.target_names, we set a human-readable objective variable in df['species'].
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
print(df.head(5))
print(iris.target)
print(iris.target_names)
df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)
print(df.head(5))
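The conversion from integer codes to names can be seen in isolation with a small sketch. The four codes below are made-up sample values, not taken from the iris data.

```python
import pandas as pd

# Made-up integer codes; each one is an index into the list of names
codes = [0, 2, 1, 0]
names = ["setosa", "versicolor", "virginica"]

species = pd.Categorical.from_codes(codes, names)
print(list(species))        # ['setosa', 'virginica', 'versicolor', 'setosa']
print(list(species.codes))  # [0, 2, 1, 0]
```

The original codes stay available through the `codes` attribute, so no information is lost by switching to the readable names.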
Next is the part that creates the Random Forest model. The df that stores the training data also contains the objective variable, so the features and the objective variable must be separated before they are given to the model. The feature column names are assigned to features and the objective variable to label. Then train_test_split divides the data into training data and test data. clf holds the RandomForestClassifier object. The number of decision trees is 150 (argument: n_estimators=150). Finally, the model is trained on the training data with the fit() method.
features = df.columns[:4]
label = df["species"]
print(features)
print(label)
print(df[features].head(5))
df_train, df_test, label_train, label_test = train_test_split(df[features], label)
clf = RandomForestClassifier(n_estimators=150)
clf.fit(df_train, label_train)
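When neither test_size nor train_size is given, train_test_split holds out 25% of the rows as test data. The sketch below shows the resulting sizes for the 150-row iris dataset. The random_state=0 argument is an addition for reproducibility and is not used in the code of this entry.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()

# random_state is an assumption added here so the split is the same on every run
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, random_state=0)

print(X_train.shape)  # (112, 4)
print(X_test.shape)   # (38, 4)
```

Without a fixed random_state the rows are shuffled differently each time, so the prediction accuracy printed by the sample code varies slightly between runs.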
And finally, the visualization. The RandomForest object has an attribute called estimators_, which is a list of decision tree objects. Here we visualize the first decision tree (estimators[0]) as a sample and output it as the PNG image file "tree_visualization.png". tree.export_graphviz() generates the DOT data for the tree, and pydotplus renders it to a PNG file. The arguments are explained in the code comments. Note that if you do not specify feature_names and class_names, neither the feature names nor the class names are displayed.
#Visualize one of the trees to try
estimators = clf.estimators_
file_name = "./tree_visualization.png"
dot_data = tree.export_graphviz(estimators[0],  # Specify one decision tree object
                                out_file=None,  # None returns the DOT data as a string instead of writing a file
                                filled=True,  # When True, each node is colored according to its majority class
                                rounded=True,  # When True, node boxes are drawn with rounded corners
                                feature_names=list(features),  # If not specified, feature names are not shown in the chart
                                class_names=list(iris.target_names),  # If not specified, class names are not shown in the chart
                                special_characters=True  # Allow special characters in the labels
                                )
graph = pdp.graph_from_dot_data(dot_data)
graph.write_png(file_name)
As a result, the visualization of the decision tree is obtained as a PNG image.
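Because estimators_ is an ordinary list, the same export can be applied to several trees in a loop. The following is a minimal sketch under a few assumptions: the helper `forest_to_dot` is hypothetical, the forest is kept small (n_estimators=5) so that it trains quickly, and random_state=0 is added for reproducibility. It stops at the DOT text, so it runs even without Graphviz.

```python
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier


def forest_to_dot(forest, feature_names, class_names, max_trees=3):
    """Return DOT source strings for the first max_trees trees of a forest."""
    dot_sources = []
    for estimator in forest.estimators_[:max_trees]:
        dot_sources.append(tree.export_graphviz(
            estimator,
            out_file=None,  # return the DOT text instead of writing a file
            filled=True,
            rounded=True,
            feature_names=list(feature_names),
            class_names=list(class_names),
            special_characters=True))
    return dot_sources


iris = load_iris()
clf = RandomForestClassifier(n_estimators=5, random_state=0)
clf.fit(iris.data, iris.target)

dot_sources = forest_to_dot(clf, iris.feature_names, iris.target_names)
print(len(dot_sources), "trees exported")
print(dot_sources[0][:40])
```

Each string can then be passed to pdp.graph_from_dot_data() and written with write_png() in the same way as in the sample code, for example with a different file name per tree.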