[Python] I tried to visualize all decision trees of a random forest with SVG

Overview

This post summarizes how to output all the decision trees of a random forest (RF) into a single SVG file using dtreeviz and svgutils.

Select and display any model from RF

As a first step, I ran the code from the article below as-is. [Try dtreeviz for RandomForest](https://qiita.com/go50/items/38c7757b444db3867b17)

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from dtreeviz.trees import dtreeviz  # dtreeviz 1.x API (2.x replaced it with dtreeviz.model)

# Train a random forest of 100 shallow trees on the iris dataset
iris = load_iris()
clf = RandomForestClassifier(n_estimators=100, max_depth=2)
clf.fit(iris.data, iris.target)

# Each fitted tree is stored in clf.estimators_
estimators = clf.estimators_

# Visualize the first tree; change the index to pick any other tree
viz = dtreeviz(
    estimators[0],
    iris.data,
    iris.target,
    target_name='variety',
    feature_names=iris.feature_names,
    class_names=[str(i) for i in iris.target_names],
)
viz.view()

Random_forest_1.png

This visualizes the first of the 100 decision trees in the forest.

Save all models as SVG files

The program above produces one SVG file per model. I looped over all the decision trees in the RF and saved each one as an SVG file. (Displaying all 100 trees would be tedious, so each one is saved with viz.save_svg() instead of being shown.)

Import tqdm

tqdm shows a progress bar with the elapsed time, so you can see how long the loop takes.

from tqdm import tqdm

Save all models

for estimator in tqdm(estimators):
    viz = dtreeviz(
        estimator,
        iris.data,
        iris.target,
        target_name='variety',
        feature_names=iris.feature_names,
        class_names=[str(i) for i in iris.target_names],
    )
    # Writes the SVG into the temp directory (see save_svg() below)
    viz.save_svg()

The problem

When I checked the output destination (the Temp folder), only one decision-tree SVG file had been saved.

image.png

The cause is the naming rule for the output file: the file name contains the process ID of the running Python process. Because that ID does not change during the loop, the same file name is generated every time, each tree overwrites the previous SVG file, and only the last model remains. The relevant code in site-packages\dtreeviz\trees.py is:

    def save_svg(self):
        """Saves the current object as SVG file in the tmp directory and returns the filename"""
        tmp = tempfile.gettempdir()
        svgfilename = os.path.join(tmp, f"DTreeViz_{os.getpid()}.svg")
        self.save(svgfilename)
        return svgfilename
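To see why the trees overwrite one another, the sketch below rebuilds the file name the same way the quoted save_svg() does, once per tree. It uses only the standard library and does not call dtreeviz; pid_based_name() is a hypothetical helper written for this illustration.

```python
import os
import tempfile


def pid_based_name():
    """Hypothetical helper: builds the file name the same way the quoted save_svg() does."""
    return os.path.join(tempfile.gettempdir(), f"DTreeViz_{os.getpid()}.svg")


# One name per tree, all generated inside the same Python process
names = [pid_based_name() for _ in range(100)]

# The process ID never changes during the loop, so all 100 names are identical
print(len(names), len(set(names)))  # 100 1
```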

Bug fixes

I changed the naming rule so that the file name is generated from the current time instead. First, import datetime in site-packages\dtreeviz\trees.py:

from datetime import datetime

Then modify save_svg():

    def save_svg(self):
        """Saves the current object as SVG file in the tmp directory and returns the filename"""
        tmp = tempfile.gettempdir()
        #svgfilename = os.path.join(tmp, f"DTreeViz_{os.getpid()}.svg")
        now = datetime.now()
        svgfilename = os.path.join(tmp, f"DTreeViz_{now:%Y%m%d_%H%M%S}.svg")
        self.save(svgfilename)
        return svgfilename
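One caveat with this fix: the %Y%m%d_%H%M%S pattern has one-second resolution, so it only avoids collisions as long as each tree takes more than a second to render. The sketch below compares two hypothetical save times 0.5 s apart; timestamp_name_us() is a hypothetical variant that appends microseconds with %f. Putting the tree index into the name would be an even more robust option.

```python
from datetime import datetime, timedelta


def timestamp_name(now):
    """Same pattern as the patched save_svg(): one-second resolution."""
    return f"DTreeViz_{now:%Y%m%d_%H%M%S}.svg"


def timestamp_name_us(now):
    """Hypothetical variant: %f appends microseconds, so sub-second saves differ."""
    return f"DTreeViz_{now:%Y%m%d_%H%M%S_%f}.svg"


t0 = datetime(2020, 10, 1, 12, 0, 0, 100000)  # arbitrary example time
t1 = t0 + timedelta(milliseconds=500)          # a second tree saved 0.5 s later

print(timestamp_name(t0) == timestamp_name(t1))        # True: same name, file overwritten
print(timestamp_name_us(t0) == timestamp_name_us(t1))  # False
print(timestamp_name_us(t1))  # DTreeViz_20201001_120000_600000.svg
```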

Run again

⇒ All decision tree models were successfully output as SVG files.

image.png
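If you would rather not edit files under site-packages, another option is to skip save_svg() and give each tree its own file name. Assuming the dtreeviz 1.x API, where the object returned by dtreeviz() has a save(filename) method, the loop can write straight into the folder that the merge step reads. This is a sketch under that assumption: tree_svg_path() and save_all_trees() are hypothetical helpers, and dtreeviz is imported inside the function so the path helper can be tried without it.

```python
import os
from pathlib import Path


def tree_svg_path(out_dir, index, width=3):
    """Hypothetical helper: zero-padded names (tree_000.svg, tree_001.svg, ...) sort in tree order."""
    return Path(out_dir) / f"tree_{index:0{width}d}.svg"


def save_all_trees(clf, X, y, feature_names, class_names, out_dir="./SVG_files"):
    """Hypothetical helper: save every tree of a fitted RandomForestClassifier as its own SVG.

    Assumes the dtreeviz 1.x API (dtreeviz() returning an object with save(filename)).
    """
    from dtreeviz.trees import dtreeviz  # imported here so tree_svg_path() works without dtreeviz

    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for i, estimator in enumerate(clf.estimators_):
        viz = dtreeviz(
            estimator,
            X,
            y,
            target_name="variety",
            feature_names=feature_names,
            class_names=class_names,
        )
        path = str(tree_svg_path(out_dir, i))  # pass a plain str file name
        viz.save(path)
        paths.append(path)
    return paths
```

With the model from the first section, usage would look like save_all_trees(clf, iris.data, iris.target, iris.feature_names, [str(c) for c in iris.target_names]). Because the index is part of the name, the files also sort in estimator order, which keeps the labels in the combined SVG aligned with clf.estimators_.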

Consolidate all svg files

Looking at these files one by one is tedious, so I combined them into a single file with svgutils. (I can no longer find the page I referred to when using svgutils; I will add a link as soon as I find it again.)

The layout keeps the grid as close to square as possible for the given number of trees, and it can be adjusted quickly (through the cell width and height) if the depth of the trees changes.
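The layout rule can be written as a small pure function, which makes it easy to check how many columns and rows a given number of trees needs and how large the canvas must be. This is a sketch that mirrors the column rule in join_svg() (floor of the square root of N); grid_layout() is a hypothetical helper, and sizes are in the same units as cell_w and cell_h.

```python
import math


def grid_layout(n_trees, cell_w, cell_h):
    """Return (n_cols, n_rows, canvas_w, canvas_h, positions) for a near-square grid.

    positions[i] is the (x, y) of the top-left corner of tree i, filled row by row.
    """
    n_cols = max(1, int(math.sqrt(n_trees)))  # same column rule as join_svg()
    n_rows = math.ceil(n_trees / n_cols)      # enough rows for the remaining trees
    positions = []
    for i in range(n_trees):
        row, col = divmod(i, n_cols)
        positions.append((col * cell_w, row * cell_h))
    return n_cols, n_rows, n_cols * cell_w, n_rows * cell_h, positions


n_cols, n_rows, w, h, pos = grid_layout(100, 400, 300)
print(n_cols, n_rows, w, h)      # 10 10 4000 3000
print(pos[0], pos[11], pos[99])  # (0, 0) (400, 300) (3600, 2700)
```

For 10 trees the same rule gives 3 columns and 4 rows, so the grid stays roughly square without a hand-tuned layout; the computed canvas size is a quick way to sanity-check the fixed canvas used in join_svg().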

Layout design.png

Put the 100 SVG files created above into a dedicated folder (./SVG_files) and run the following program:

import glob
import math
import os

import svgutils.transform as sg


def join_svg(cell_w, cell_h):
    """Tile every SVG in ./SVG_files into one grid and save it as ./SVG_files/output/RF.svg.

    cell_w, cell_h: width and height reserved for each tree (SVG user units).
    """
    svg_file_dir = "./SVG_files"
    # glob() returns files in arbitrary order; sort so the tile order (and the
    # index label) follows the file names, i.e. the order in which trees were saved
    svg_filename_list = sorted(glob.glob(os.path.join(svg_file_dir, "*.svg")))

    # Canvas size is fixed here; enlarge it if the grid does not fit
    fig_tmp = sg.SVGFigure("128cm", "108cm")
    N = len(svg_filename_list)
    # Number of columns: floor(sqrt(N)) keeps the grid close to square
    n_w_cells = max(1, int(math.sqrt(N)))

    plot_list, txt_list = [], []

    for i, target_svg_file in enumerate(svg_filename_list):
        pla_y, pla_x = divmod(i, n_w_cells)  # row and column of this tree
        x, y = cell_w * pla_x, cell_h * pla_y
        print("{}: {} -> grid [x,y] = {},{}, coordinates = {},{}".format(
            i, target_svg_file, pla_x, pla_y, x, y))

        # Load one tree and move it into its cell
        plot_target = sg.fromfile(target_svg_file).getroot()
        plot_target.moveto(x, y)
        plot_list.append(plot_target)

        # Label each cell with its index, slightly offset from the top-left corner
        txt_list.append(sg.TextElement(x + 25, y + 20, str(i), size=12, weight="bold"))

    fig_tmp.append(plot_list)
    fig_tmp.append(txt_list)

    # Create the output folder if needed, then save the combined SVG
    output_dir = os.path.join(svg_file_dir, "output")
    os.makedirs(output_dir, exist_ok=True)
    fig_tmp.save(os.path.join(output_dir, "RF.svg"))


join_svg(400, 300)

Output result

All the files were combined successfully.

RF_output_svg.png

The file is larger than expected (about 10 MB), and it takes a while to render even in Chrome. If other applications are still running, the browser may fail to display it because it runs out of memory.
