[PYTHON] Understand the Decision Tree and classify documents

Introduction

This time, I have summarized the decision tree, one of the basic machine learning algorithms. To understand Random Forest and XGBoost, which are built on decision trees, I want to get a firm grasp of the basics first.

Reference

To understand decision trees, I referred to the following books.

- Practical machine learning with scikit-learn and TensorFlow
- Make and understand! Introduction to ensemble learning algorithm

Decision Tree

Outline of decision tree

**A decision tree is an algorithm that starts at the "root" and follows branches according to conditions until it reaches the "leaf" that best matches those conditions.** From the training data, it builds nodes whose conditions are expressions on the explanatory variables, and so automatically produces a model that outputs its prediction at the "leaves". It can handle both classification and regression problems. The resulting trees are called **classification trees** and **regression trees**, respectively.
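To make "following branches from the root to a leaf" concrete, here is a minimal sketch that walks a fitted scikit-learn tree by hand using the low-level `tree_` arrays (`children_left`, `children_right`, `feature`, `threshold`, `value`). The helper `predict_one` is hypothetical and written only for illustration. It assumes scikit-learn's rule of going left when `x[feature] <= threshold` and marking leaves with a left child of -1.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
# scikit-learn trees work on float32 internally, so cast to match its comparisons
X = iris.data.astype(np.float32)
y = iris.target
clf = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)


def predict_one(tree, x):
    """Hypothetical helper: trace one sample from the root down to a leaf."""
    t = tree.tree_
    node = 0  # node 0 is the root
    while t.children_left[node] != -1:  # -1 means "no child", i.e. a leaf
        if x[t.feature[node]] <= t.threshold[node]:
            node = t.children_left[node]   # condition met: go left (True)
        else:
            node = t.children_right[node]  # condition not met: go right (False)
    # The leaf's prediction is the class with the largest value at that leaf
    return tree.classes_[np.argmax(t.value[node])]


manual = np.array([predict_one(clf, x) for x in X])
print((manual == clf.predict(X)).all())
```

If the traversal is right, the hand-written predictions agree with `clf.predict` for every sample.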

The advantages and disadvantages are listed below. I think being **easy to interpret** is the main reason to use a decision tree.

Advantages

- Easy to interpret (you can clearly visualize which features the decisions are based on)
- Data on different scales can be used as is (no preprocessing such as scaling is required)

Disadvantages

- Classification performance is not particularly high
- Prone to overfitting
- Not well suited to data with linear structure (axis-aligned splits can only approximate a linear boundary as a staircase)

What a decision tree looks like

Let's train a decision tree on sklearn's iris dataset and draw it.

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()

# Use only the petal length and petal width features
X = iris.data[:, 2:]
y = iris.target

tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X, y)

# Write the trained tree to a Graphviz dot file
export_graphviz(
    tree_clf,
    out_file="iris_tree.dot",
    feature_names=iris.feature_names[2:],
    class_names=iris.target_names,
    rounded=True,
    filled=True,
)

Then run the following on the command line to convert the dot file to PNG:

$ dot -Tpng iris_tree.dot -o iris_tree.png

This produces the decision tree below. Each node branches depending on whether its condition is met (True or False), and the nodes at the ends of the branches, the "leaves", give the final prediction.

iris_tree.png
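If Graphviz is not installed, a quick alternative is `sklearn.tree.export_text` (available since scikit-learn 0.21), which prints the same tree as indented text. This is a minimal sketch that retrains the same iris model as above.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
X = iris.data[:, 2:]  # petal length and petal width
y = iris.target
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X, y)

# Each "|---" line is one branch condition; "class:" lines are the leaves
report = export_text(tree_clf, feature_names=iris.feature_names[2:])
print(report)
```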

Decision tree learning and impurity

**Impurity** is an index of how well each node of the decision tree separates the data with its conditional branch. There are two common ways to measure impurity: **Gini impurity** and **entropy**.

Gini impurity

Let $n$ be the number of classes in the data and $p_{i}$ the proportion of data in the node that belongs to class $i$. The Gini impurity is then given by the following formula.


Gini(p) = \sum_{i=1}^{n} p_{i}(1-p_{i}) = 1-\sum_{i=1}^{n} p_{i}^2

The formula alone is hard to grasp, so let's look at a concrete example: the green node in the decision tree shown earlier. Its **value** field shows how many training samples of each class reach this node, and its **class** field is the class this node predicts (the majority class).

iris_tree.png

The **value** of the green node is [0, 49, 5]: 0 samples of setosa, 49 of versicolor, and 5 of virginica reach this node. Since its **class** is versicolor, the ideal would be for every sample here to be versicolor. The samples of other classes are the "impurities", and the decision tree measures them quantitatively with the **impurity** index.

Let's calculate the Gini impurity for this example. The node holds 0 + 49 + 5 = 54 samples.


1 - (\dfrac{0}{54})^2 - (\dfrac{49}{54})^2 - (\dfrac{5}{54})^2 \approx 0.168

This matches the **gini** value shown in the green node.
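The same calculation is easy to reproduce in Python. This is a minimal sketch: `gini_impurity` is a hypothetical helper that takes the per-class sample counts of a node (the **value** field) and applies $1-\sum_i p_i^2$.

```python
def gini_impurity(counts):
    """Gini impurity of a node, given its per-class sample counts."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)


# Class counts of the green node: [setosa, versicolor, virginica]
green_node = [0, 49, 5]
print(round(gini_impurity(green_node), 3))  # 0.168
print(gini_impurity([50, 0, 0]))            # 0.0 for a pure node
```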

Entropy

Entropy can also be used as the impurity measure. Entropy is an information-theoretic concept that expresses **disorder**: the larger the entropy, the more mixed the node, that is, **the higher the impurity**. The concept of entropy is explained in a previous article, so please refer to that as well.

Entropy is expressed by the following formula.


H(p) = -\sum_{i=1}^{n} p_{i}log_{2}(p_{i})

Let's calculate the entropy of the green node, as we did for Gini impurity. The class with 0 samples contributes nothing, because $0 \log_{2} 0$ is taken to be 0.


-(\dfrac{49}{54})\log_{2}(\dfrac{49}{54}) - (\dfrac{5}{54})\log_{2}(\dfrac{5}{54}) \approx 0.445
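Here is a minimal Python sketch of the same entropy calculation. `entropy` is a hypothetical helper taking per-class counts. It skips empty classes, following the convention that $0 \log_2 0 = 0$.

```python
import math


def entropy(counts):
    """Entropy (base 2) of a node, given its per-class sample counts."""
    total = sum(counts)
    # Classes with zero samples are skipped: 0 * log2(0) is treated as 0
    return -sum((c / total) * math.log2(c / total) for c in counts if c > 0)


print(round(entropy([0, 49, 5]), 3))    # 0.445 for the green node
print(round(entropy([50, 50, 50]), 3))  # 1.585, i.e. log2(3) for three equal classes
```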

Passing criterion='entropy' to DecisionTreeClassifier in sklearn trains the model with entropy as the impurity measure. In the decision tree below, each node shows **entropy** instead of **gini**.

iris_tree.png
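The impurity of every node can also be read directly from the fitted model, without drawing it. This is a sketch assuming the same iris setup as above. `tree_.impurity` and `tree_.n_node_samples` are per-node arrays on scikit-learn's fitted tree, with index 0 being the root. Because the three iris classes are balanced (50 each), the root's entropy should be $\log_2 3$.

```python
import math

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data[:, 2:]  # petal length and petal width
y = iris.target

tree_clf = DecisionTreeClassifier(criterion="entropy", max_depth=2, random_state=42)
tree_clf.fit(X, y)

t = tree_clf.tree_
for node_id in range(t.node_count):
    # node id, entropy of the node, number of training samples reaching it
    print(node_id, round(t.impurity[node_id], 3), t.n_node_samples[node_id])

print(math.isclose(t.impurity[0], math.log2(3)))  # root holds 50/50/50 samples
```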

In most cases there seems to be little difference in the resulting trees between **Gini impurity** and **entropy**. Gini impurity has the slight advantage of being faster to compute, since it needs no logarithm.

CART training algorithm

There are several decision tree algorithms. This time we introduce the most basic one, CART (Classification and Regression Tree), which sklearn uses (in an optimized form). CART creates only binary (Yes or No) branches at each node and chooses each split so as to minimize the cost function shown below.

Consider splitting the data into two at a threshold $t_{k}$ on a single feature $k$. The optimal $k$ and $t_{k}$ are found by minimizing the following loss function.


L(k, t_{k}) = \dfrac{n_{left}}{n}Gini_{left} + \dfrac{n_{right}}{n}Gini_{right}

- $Gini_{left}$ / $Gini_{right}$: the impurity of the left and right subsets after the split
- $n_{left}$ / $n_{right}$: the number of samples in the left and right subsets, and $n$ is the number of samples in the node before the split

In other words, the loss is the average impurity of the two child nodes, weighted by the number of samples in each, and the optimal feature and threshold are the ones that minimize it.
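The search for $k$ and $t_{k}$ can be written as a brute-force loop. This is a simplified sketch, not sklearn's optimized implementation. `gini` and `best_split` are hypothetical helpers. They assume candidate thresholds are the midpoints between consecutive distinct values of a feature, and each candidate is scored with the weighted loss $L(k, t_k)$ above.

```python
import numpy as np
from sklearn.datasets import load_iris


def gini(labels):
    """Gini impurity of a set of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)


def best_split(X, y):
    """Return (k, t_k, loss) minimizing the weighted Gini impurity of the two children."""
    n = len(y)
    best = (None, None, np.inf)
    for k in range(X.shape[1]):
        values = np.unique(X[:, k])  # sorted distinct values of feature k
        # Candidate thresholds: midpoints between consecutive distinct values
        for t_k in (values[:-1] + values[1:]) / 2:
            left = X[:, k] <= t_k
            right = ~left
            loss = left.sum() / n * gini(y[left]) + right.sum() / n * gini(y[right])
            if loss < best[2]:
                best = (k, t_k, loss)
    return best


iris = load_iris()
k, t_k, loss = best_split(iris.data[:, 2:], iris.target)
print(k, round(t_k, 2), round(loss, 3))  # 0 2.45 0.333 (petal length <= 2.45)
```

On the petal features, this recovers the root split "petal length <= 2.45" seen in the tree above. It separates all 50 setosa samples and leaves a 50/50 mix on the right, so the loss is $\frac{100}{150} \times 0.5 = \frac{1}{3}$.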

Decision tree parameters

Left unconstrained, a decision tree keeps splitting until further splits no longer reduce the impurity. Because a decision tree **adapts closely to the given training data**, letting it grow branches to any depth makes it overfit easily.

As a countermeasure, we set parameters that constrain the shape of the tree. Typical examples are the maximum depth of the tree ("max_depth" in sklearn), the minimum number of samples required to split a node ("min_samples_split" in sklearn), and the minimum number of samples that must remain in a leaf ("min_samples_leaf" in sklearn).
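Here is a small sketch of how these parameters change the shape of the tree. It compares an unconstrained tree with a constrained one on the full iris data. The specific values (max_depth=2, min_samples_split=20, min_samples_leaf=10) are arbitrary choices for illustration. `get_depth()` and `get_n_leaves()` exist from scikit-learn 0.21.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data, iris.target

# No limits: the tree grows until further splits stop reducing impurity
deep = DecisionTreeClassifier(random_state=42).fit(X, y)
# Illustrative limits on depth and on the number of samples per split and per leaf
shallow = DecisionTreeClassifier(
    max_depth=2, min_samples_split=20, min_samples_leaf=10, random_state=42
).fit(X, y)

for name, clf in [("unconstrained", deep), ("constrained", shallow)]:
    print(name, "depth:", clf.get_depth(), "leaves:", clf.get_n_leaves())

# Leaves are the nodes without children (children_left == -1)
t = shallow.tree_
is_leaf = t.children_left == -1
print("smallest leaf:", t.n_node_samples[is_leaf].min())  # at least min_samples_leaf
```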

Document classification using decision trees

Creating a decision tree model using sklearn

Below, we actually build a decision tree model with the library.

Library used

scikit-learn 0.21.3

Dataset

sklearn makes it easy to build a decision tree model. This time we use the "livedoor news corpus" as the dataset. For details of the dataset and the morphological analysis method, please refer to my previously posted article.

Japanese text must first be split into morphemes, so after morphological analysis of all the articles, the results are stored in a data frame like the one below.

Screenshot 2020-01-13 21.07.38.png

The rightmost column contains each article after morphological analysis, with the morphemes separated by half-width spaces. We use this column to build the decision tree model.

Model learning

We build the decision tree model with sklearn. The main parameters are as follows.

| Parameter | Meaning |
|---|---|
| criterion | {"gini", "entropy"}: whether to use Gini impurity or entropy as the impurity measure |
| max_depth | Maximum depth allowed for the tree (if not set, a very complicated tree may be grown, which hurts generalization performance) |
| min_samples_split | Minimum number of samples required to split a node (if this is small, the model may split the data very finely, which affects generalization performance) |
| min_samples_leaf | Minimum number of samples that must remain in each leaf (as above, a small value may produce a model that splits very finely, which affects generalization performance) |

This time, we convert each article into Bag-of-Words format (a vector counting how many times each word appears in the article) and feed the vectors into the decision tree. We try to classify IT-related articles ('it-life-hack') against sports-related articles ('sports-watch').
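To see what Bag-of-Words format looks like, here is a toy sketch with three hypothetical, already tokenized (space-separated) English documents standing in for the morphologically analyzed articles. The custom `token_pattern` used in this article keeps one-character tokens, which sklearn's default pattern drops. Presumably that matters for Japanese, where many morphemes are a single character.

```python
from sklearn.feature_extraction.text import CountVectorizer

# Hypothetical toy documents, already split into space-separated tokens
docs = [
    "new product is released",
    "the player scored a goal",
    "a new player",
]

vectorizer = CountVectorizer(token_pattern=r"(?u)\b\w+\b")
X = vectorizer.fit_transform(docs)  # sparse matrix: one row per document, one column per word

vocab = vectorizer.vocabulary_  # word -> column index
print(sorted(vocab))
print(X.toarray())  # each cell counts how often a word occurs in a document

# The default token_pattern, r"(?u)\b\w\w+\b", ignores one-character tokens such as "a"
print("a" in CountVectorizer().fit(docs).vocabulary_)  # False
```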


import pickle

import pandas as pd  # the pickled object is a pandas DataFrame
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumes the DataFrame after morphological analysis has already been pickled
# (column 1: article category, column 3: space-separated morphemes)
with open('df_wakati.pickle', 'rb') as f:
    df = pickle.load(f)

# Keep only the two article types we want to classify
ddf = df[(df[1] == 'sports-watch') | (df[1] == 'it-life-hack')].reset_index(drop=True)

# Convert the sentences to Bag-of-Words format with sklearn's CountVectorizer
# (this token_pattern also keeps one-character tokens)
vectorizer = CountVectorizer(token_pattern=r"(?u)\b\w+\b")
X = vectorizer.fit_transform(ddf[3])

# Convert the article type to a numeric label
label_map = {'it-life-hack': 0, 'sports-watch': 1}
y = ddf[1].map(label_map)

# Split into training data (70%) and evaluation data (30%)
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, random_state=42)

# Train the model
tree_clf = DecisionTreeClassifier(criterion='gini', max_depth=4, min_samples_leaf=10, random_state=42)
tree_clf.fit(X_train, y_train)

# Print the accuracy on the evaluation data
print(tree_clf.score(X_test, y_test))

The printed accuracy is shown below. It is higher than I expected: the words used in IT articles and in sports articles seem to differ clearly.

0.9529190207156308

Now let's visualize the model created this time.

from sklearn.tree import export_graphviz

export_graphviz(
    tree_clf,
    out_file="text_classification.dot",
    # get_feature_names() works in scikit-learn 0.21.3; from 1.0 on, use get_feature_names_out()
    feature_names=vectorizer.get_feature_names(),
    class_names=['it-life-hack', 'sports-watch'],
    rounded=True,
    filled=True,
)

Then run the following command:

$ dot -Tpng text_classification.dot -o text_classification.png

iris_tree.png

You can see that intuitively convincing words such as "product" and "player" are chosen as the conditions for splitting the nodes. A big strength of decision trees is that the model is not a black box.

Next, I would like to study the methods built on decision trees, such as Random Forest, AdaBoost, XGBoost, and LightGBM, step by step.
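Besides drawing the tree, the fitted model's `feature_importances_` can list the words that drive the splits. This is a self-contained sketch on a hypothetical six-document toy corpus, not the livedoor data. Column indices are mapped back to words through `vocabulary_`, which works across scikit-learn versions. In this toy corpus only "player" perfectly separates the two classes, so it should receive all of the importance.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.tree import DecisionTreeClassifier

# Hypothetical toy corpus: 0 = IT article, 1 = sports article
docs = ["new product launch", "product review", "smartphone app",
        "player goal", "player match", "player wins"]
labels = [0, 0, 0, 1, 1, 1]

vectorizer = CountVectorizer(token_pattern=r"(?u)\b\w+\b")
X = vectorizer.fit_transform(docs)
clf = DecisionTreeClassifier(random_state=42).fit(X, labels)

# Rebuild the column-index -> word lookup from vocabulary_
words = np.empty(len(vectorizer.vocabulary_), dtype=object)
for word, idx in vectorizer.vocabulary_.items():
    words[idx] = word

# Sort words by importance, highest first
order = np.argsort(clf.feature_importances_)[::-1]
for idx in order[:3]:
    print(words[idx], round(clf.feature_importances_[idx], 3))
```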
