[PYTHON] Get to know the feelings of gradient boosting trees

In machine learning on tabular data, reaching for gradient boosting trees as the first move is about as common as opening with the pawn push to 7-6 (P-7f) in shogi.

Once you have applied gradient boosting, it is even more interesting to see which features are likely to be effective, so let's display the feature importances of a gradient boosting model.

Execution environment and data

I ran this on Google Colaboratory. It should also run without problems in your own Jupyter environment or as a Python script.

The notebook is below. https://colab.research.google.com/drive/1N1gtzTHFRKsbm88NyuEKqBr9wNS3tU7K?usp=sharing

The data is the speed dating dataset available from the following page. https://knowledge-ja.domo.com/Training/Self-Service_Training/Onboarding_Resources/Fun_Sample_Datasets

This is data from an experiment in which participants had a 4-minute date with each of the other participants, then answered whether they wanted to date that person again and rated the date. The data looks very interesting, but it has nearly 200 columns with no explanation, which makes it hard to read, so this time I will focus on how to run things without worrying too much about the data itself.

Source code and execution results

First, get the data.

! wget https://knowledge-ja.domo.com/@api/deki/files/5950/Speed_Dating_Data.csv

Install the package for encoding text columns.

! pip install category_encoders

Load CSV data into a data frame.

import pandas as pd
speed_date = pd.read_csv("Speed_Dating_Data.csv", encoding='cp932')
speed_date
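
With nearly 200 columns, a quick look at dtypes and missing-value ratios can help decide which columns need encoding or cleanup. The following is a minimal sketch on a made-up frame (the toy values are assumptions, not rows from the dataset); on the real data the same calls would be applied to speed_date.

```python
import pandas as pd

# Toy stand-in for speed_date; values are invented for illustration.
df = pd.DataFrame({
    "match": [0, 1, 0, 0],
    "field": ["Law", None, "Economics", "Law"],
    "income": ["69,487.00", None, None, "43,185.00"],
})

# Columns that do not hold numbers are candidates for encoding or cleanup.
non_numeric_cols = df.select_dtypes(exclude="number").columns.tolist()

# Fraction of missing values per column, largest first.
missing_ratio = df.isna().mean().sort_values(ascending=False)

print(non_numeric_cols)
print(missing_ratio)
```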

Preprocess the data. The preprocessing is a little rough.

import category_encoders as encoders
label_cols = ['field', 'from', 'career', 'undergra']

# Build the target and the features, excluding ID columns and the like
exist_match_df = speed_date

object_col = exist_match_df['match']
object_col = object_col.values.astype(int)

feature_cols = exist_match_df.iloc[:,13:97]
feature_cols = feature_cols.drop('dec_o', axis=1)
col_names = feature_cols.columns.values
# Strip thousands separators so these string columns can be cast to float later
for col in ['zipcode', 'income', 'tuition', 'mn_sat']:
    feature_cols[col] = feature_cols[col].str.replace(',', '', regex=False)

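# handle_unknown='value' encodes unseen categories as -1 (older category_encoders releases named this option 'impute')
ordinal_encoder = encoders.OrdinalEncoder(cols=label_cols, handle_unknown='value')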
feature_cols = ordinal_encoder.fit_transform(feature_cols)

feature_cols = feature_cols.values.astype(float)
feature_cols
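
To make the two preprocessing steps concrete, here is a minimal sketch on a made-up three-row frame (the values are illustrative, not taken from the dataset). It strips thousands separators so the strings can be cast to float, and it uses pd.factorize as a stand-in for ordinal encoding. category_encoders.OrdinalEncoder follows the same idea, but its exact integer codes and its handling of missing values may differ.

```python
import pandas as pd

# Toy frame; column names mirror the dataset, values are invented for illustration.
toy = pd.DataFrame({
    "field": ["Law", "Economics", "Law"],
    "income": ["69,487.00", None, "43,185.00"],
})

# Remove thousands separators, then convert to float (missing values become NaN).
toy["income"] = toy["income"].str.replace(",", "", regex=False).astype(float)

# Map each distinct string to an integer in order of first appearance.
codes, categories = pd.factorize(toy["field"])
toy["field"] = codes

print(toy)
print(list(categories))
```

Both rows with "Law" receive the same integer, which is all a tree model needs in order to split on the column.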

Train a model on the preprocessed data.

import xgboost as xgb
from sklearn.model_selection import cross_validate, cross_val_predict, KFold

kfold = KFold(n_splits=5)
score_func = ["accuracy", "precision_macro", "recall_macro", "f1_macro"]
 
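# Early stopping needs a validation set (eval_set), which cross_validate does not
# pass to fit(), so early_stopping_rounds is omitted and n_estimators is set explicitly.
clf = xgb.XGBClassifier(objective="binary:logistic", max_depth=10, n_estimators=100)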
 
score = cross_validate(clf, feature_cols, object_col, cv=kfold, scoring=score_func, return_estimator=True)
 
print('acc:       ' + str(score["test_accuracy"].mean()))
print('precision: ' + str(score["test_precision_macro"].mean()))
print('recall:    ' + str(score["test_recall_macro"].mean()))
print('F1:        ' + str(score["test_f1_macro"].mean()))

acc:       0.8350436362341039
precision: 0.6755307380632243
recall:    0.5681596439505251
F1:        0.5779607716750095

Judging from recall and the other metrics, I can't say the model learned very well, but the main topic this time is what comes next.
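
A gap like this between accuracy and macro recall is typical when one class is rare. As a worked illustration with made-up counts (not the actual confusion matrix from this run), suppose there are 1000 samples, 160 of them are matches, and the model finds only 30 of those:

```python
# Hypothetical confusion-matrix counts, chosen only to illustrate the arithmetic.
tn, fp = 820, 20   # true class 0 (no match): 840 samples
fn, tp = 130, 30   # true class 1 (match): 160 samples

accuracy = (tp + tn) / (tp + tn + fp + fn)

# Macro recall is the unweighted mean of the per-class recalls.
recall_no_match = tn / (tn + fp)
recall_match = tp / (tp + fn)
recall_macro = (recall_no_match + recall_match) / 2

print(f"accuracy:     {accuracy:.3f}")
print(f"recall_macro: {recall_macro:.3f}")
```

With these counts accuracy is 0.85 while macro recall is only about 0.58, because the rare class is mostly missed even though the common class is predicted well.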

Output of execution result

Passing return_estimator=True to cross_validate() earlier makes the fitted estimator of each fold available. Each estimator exposes how strongly each explanatory variable contributed as feature_importances_, so we output that.

import numpy as np

estimators = score["estimator"]
sum_score = np.zeros(len(col_names))
 
for i in range(5):
    sum_score += estimators[i].feature_importances_
 
df_score = pd.DataFrame(sum_score/5, index=col_names, columns=["score"])
df_score.sort_values("score", ascending=False)

In this way, you can see which variables the model treated as important during training.
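
The averaging loop above can also be written without a hard-coded fold count. The following is a minimal sketch with NumPy that uses invented importance vectors in place of each estimator's feature_importances_ (the feature names and numbers are placeholders, not results from this run):

```python
import numpy as np

# Stand-ins for the per-fold feature_importances_ arrays (3 folds, 4 features).
fold_importances = [
    np.array([0.10, 0.50, 0.30, 0.10]),
    np.array([0.20, 0.40, 0.30, 0.10]),
    np.array([0.30, 0.30, 0.30, 0.10]),
]
names = np.array(["age", "attr", "fun", "shar"])  # placeholder feature names

# Stack into shape (n_folds, n_features) and average over the fold axis.
mean_importance = np.mean(fold_importances, axis=0)

# Indices of features from most to least important.
order = np.argsort(mean_importance)[::-1]
for name, value in zip(names[order], mean_importance[order]):
    print(f"{name}: {value:.3f}")
```

On the real run, fold_importances would be built as [est.feature_importances_ for est in score["estimator"]], so the code keeps working if n_splits changes.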
