[PYTHON] Correlation visualization of features and objective variables

Introduction

Before building a model for business use or for a data analysis competition, it helps to check quickly how each engineered feature relates to the objective variable. This article presents a function that does this. It targets cases where the objective variable is continuous, such as a sales forecast.

The code below uses Signate's bento demand forecast data. (Data: https://signate.jp/competitions/24/data)

The task for this data is to build a model that predicts the number of bento boxes sold, which is stored in column y.

Column  Header name    Data type  Description
0       datetime       datetime   Date used as the index (yyyy-m-d)
1       y              int        Number of units sold (objective variable)
2       week           char       Day of the week (Monday to Friday)
3       soldout        boolean    Sold-out flag (0: not sold out, 1: sold out)
4       name           varchar    Main menu item
5       kcal           int        Calories of the side dishes (kcal); contains missing values
6       remarks        varchar    Remarks
7       event          varchar    In-house event starting at 13:00 to which lunch can be brought
8       payday         boolean    Payday flag (1: payday)
9       weather        varchar    Weather
10      precipitation  float      Precipitation ("--" when there is none)
11      temperature    float      Temperature

Libraries used

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
%matplotlib inline

Checking the data

df = pd.read_csv('./signate/train.csv')
# df.shape >> (207, 12)
df.head(2)
     datetime    y  week  soldout                    name  kcal  remarks  event  payday  weather  precipitation  temperature
0  2013-11-18   90   Mon        0      Thick sliced squid   NaN      NaN    NaN     NaN     Fine             --         19.8
1  2013-11-19  101   Tue        1  Handmade fillet cutlet   NaN      NaN    NaN     NaN     Fine             --         17.0
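
One thing worth checking at this stage is the column dtypes. Because precipitation uses "--" for days without rain, pandas does not read that column as a number. The following is a minimal sketch with two illustrative rows typed in by hand (the 0.5 value is made up and is not from the Signate file).

```python
import io

import pandas as pd

# Two illustrative rows in the same layout as train.csv (values partly made up)
csv_text = """datetime,y,precipitation,temperature
2013-11-18,90,--,19.8
2013-11-19,101,0.5,17.0
"""

toy = pd.read_csv(io.StringIO(csv_text))

# precipitation is read as text because of "--"; temperature is read as float
print(toy.dtypes)
print(pd.api.types.is_numeric_dtype(toy['precipitation']))
```

This is why the preprocessing step replaces "--" before casting the column to float.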

Data preprocessing

df['precipitation'] = df.precipitation.replace({'--' : '0'}).astype(float)
df = pd.get_dummies(df[['y', 'week', 'soldout', 'kcal', 'payday', 'weather', 'precipitation', 'temperature']])
df['payday'] = df.payday.fillna(0).astype(str)
df.head()
y soldout kcal payday precipitation temperature week_Mon week_Thu week_Wed week_Tue week_Fri weather_fine weather_sunny weather_cloudy weather_light_cloudy weather_rain weather_snow weather_thunder
0 90 0 NaN 0.0 0.0 19.8 1 0 0 0 0 1 0 0 0 0 0 0
1 101 1 NaN 0.0 0.0 17.0 0 0 0 1 0 1 0 0 0 0 0 0
2 118 0 NaN 0.0 0.0 15.5 0 0 1 0 0 1 0 0 0 0 0 0
3 120 1 NaN 0.0 0.0 15.2 0 1 0 0 0 1 0 0 0 0 0 0
4 130 1 NaN 0.0 0.0 16.1 0 0 0 0 1 1 0 0 0 0 0 0
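
To make the two preprocessing steps easier to follow, here is a minimal sketch on a small hand-made frame (the rows are illustrative, not the competition data). It assumes the same column names as above and shows how "--" becomes 0.0 and how get_dummies expands a string column into one column per category.

```python
import pandas as pd

# Illustrative toy rows (not real competition data)
toy = pd.DataFrame({
    'y': [90, 101, 118],
    'week': ['Mon', 'Tue', 'Wed'],
    'precipitation': ['--', '0.5', '--'],
})

# "--" means no precipitation, so map it to "0" before casting to float
toy['precipitation'] = toy['precipitation'].replace({'--': '0'}).astype(float)

# One-hot encode the string column; numeric columns pass through unchanged
encoded = pd.get_dummies(toy)
print(encoded.columns.tolist())
print(encoded)
```

Depending on the pandas version, the dummy columns may be printed as True/False rather than 1/0. Both forms work with make_plot, since it only counts the unique values in each column.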

Checking summary statistics

df.describe()
y soldout kcal precipitation temperature week_Mon week_Thu week_Wed week_Tue week_Fri weather_fine weather_sunny weather_cloudy weather_light_cloudy weather_rain weather_snow weather_thunder
count 207.000000 207.000000 166.000000 207.000000 207.000000 207.000000 207.000000 207.000000 207.000000 207.000000 207.000000 207.000000 207.000000 207.000000 207.000000 207.000000 207.000000
mean 86.623188 0.449275 404.409639 0.113527 19.252174 0.188406 0.207729 0.207729 0.198068 0.198068 0.256039 0.241546 0.256039 0.120773 0.115942 0.004831 0.004831
std 32.882448 0.498626 29.884641 0.659443 8.611365 0.391984 0.406666 0.406666 0.399510 0.399510 0.437501 0.429058 0.437501 0.326653 0.320932 0.069505 0.069505
min 29.000000 0.000000 315.000000 0.000000 1.200000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
25% 57.000000 0.000000 386.000000 0.000000 11.550000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
50% 78.000000 0.000000 408.500000 0.000000 19.800000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
75% 113.000000 1.000000 426.000000 0.000000 26.100000 0.000000 0.000000 0.000000 0.000000 0.000000 1.000000 0.000000 1.000000 0.000000 0.000000 0.000000 0.000000
max 171.000000 1.000000 462.000000 6.500000 34.600000 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000 1.000000
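
In the describe() output, the count for kcal is 166 while every other column has 207, which means kcal has missing values. A minimal sketch for listing the number of missing values per column, using a small hand-made frame (illustrative values only):

```python
import numpy as np
import pandas as pd

# Illustrative toy rows (not real competition data)
toy = pd.DataFrame({
    'y': [90, 101, 118, 120],
    'kcal': [np.nan, np.nan, 404.0, 415.0],
    'temperature': [19.8, 17.0, 15.5, 15.2],
})

# Number of missing values in each column
missing = toy.isna().sum()

# describe() counts only non-missing values,
# so count + missing equals the number of rows for every column
counts = toy.describe().loc['count']

print(missing[missing > 0])
```

make_plot handles this case by dropping the rows where the feature is missing before it plots that feature.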

Function that visualizes features and objective variables

Arguments

--df: DataFrame to visualize
--target: Column name of the objective variable

def make_plot(df, target):
    plt_col = sorted([c for c in df.columns if  c != target and len(df[c].unique()) > 1])
    
    col_num = len(plt_col)
    row_num = col_num // 2 + col_num % 2
    col_num = 2

    # squeeze=False keeps ax two-dimensional even when there is only one row of plots
    fig, ax = plt.subplots(row_num, col_num, figsize=(18, 3*row_num), sharex=False, sharey=False, squeeze=False)
    fig.subplots_adjust(left=0.1, right=0.95, hspace=0.7, wspace=0.4)

    for i,col in enumerate(plt_col):
        tmp = df[[target, col]]
        tmp = tmp[~pd.isna(tmp[col])].reset_index(drop=True)
        if len(tmp[col].unique()) == 1:
            continue

        # Row p and column q of the subplot grid for the i-th feature
        p, q = divmod(i, 2)
        if len(tmp[col].unique()) > 2:
            percentile095 = tmp[col].quantile(0.95)
            over_tmp = tmp[tmp[col] >= percentile095].reset_index(drop=True)
            over_tmp[col] = percentile095

            if tmp[col].min() < 0:
                percentile005 = tmp[col].quantile(0.05)
                under_tmp = tmp[tmp[col] <= percentile005].reset_index(drop=True)
                under_tmp[col] = percentile005

                outof_percentile = tmp[(percentile005 < tmp[col]) & ( tmp[col] < percentile095)].reset_index(drop=True)
                new_tmp = pd.concat([outof_percentile, under_tmp, over_tmp], axis=0)
                if percentile095 == percentile005:
                    new_tmp = tmp.copy()
            else:
                percentile095 = tmp[col].quantile(0.95)
                over_tmp = tmp[tmp[col] >= percentile095].reset_index(drop=True)
                over_tmp[col] = percentile095

                outof_percentile = tmp[tmp[col] < percentile095].reset_index(drop=True)
                new_tmp = pd.concat([outof_percentile, over_tmp], axis=0)
                if percentile095 == 0:
                    new_tmp = tmp.copy()
            ax1 = ax[p,q]
            ax2 = ax1.twinx()
            n, bins, patches = ax1.hist(new_tmp[col], bins='auto', label='Feature value: {}'.format(col), ec='black')

            new_tmp['bins'] = pd.cut(new_tmp[col], bins.tolist(), right=False).values.astype(str)
            if len([f for f in new_tmp[col].unique() if bins[-2] <= f and f < bins[-1]]) > 0:
                new_tmp['bins_start'] = [float(b.split(',')[0].replace('[', '')) for b in new_tmp['bins']]
                bins_max = new_tmp['bins_start'].max()
                nan_value = new_tmp.query('bins_start == @bins_max').bins.unique()[0]
                new_tmp['bins'] = new_tmp['bins'].replace({'nan' : nan_value})
            else:
                new_tmp['bins'] = new_tmp['bins'].replace({'nan' : '[{}, {}]'.format(bins[-2], bins[-1])})

            num_bin = new_tmp.groupby('bins').size().reset_index(name='cnt')
            mean_target_bin = new_tmp.groupby('bins')[target].mean().reset_index().rename(columns={target : '{}_mean'.format(target)})
            # Named aggregation gives a fixed column order: bins, feature_max, feature_min
            center_feature_bin = new_tmp.groupby('bins')[col].agg(feature_max='max', feature_min='min').reset_index()
            center_feature_bin['feature_center'] = (center_feature_bin['feature_max'] + center_feature_bin['feature_min']) / 2
            plt_data = mean_target_bin.merge(center_feature_bin, on='bins', how='inner').merge(num_bin, on='bins', how='inner').sort_values('feature_center', ascending=True).reset_index(drop=True)
            ax2.plot(plt_data['feature_center'], plt_data['{}_mean'.format(target)], label='Mean value of objective variable (for each bin)', marker='.', color='orange')

        else:
            new_tmp = tmp.copy()
            ax1 = ax[p,q]
            ax2 = ax1.twinx()
            bins_list = sorted(new_tmp[col].unique().tolist())
            a = new_tmp.groupby([col]).agg({col : 'size', target : 'mean'}).rename(columns={col : 'count', target : '{}_mean'.format(target)}).reset_index().astype({col : str})
            ax1.bar(a[col], a['count'], label='Feature value: {}'.format(col), ec='black')
            ax2.plot(a[col], a['{}_mean'.format(target)], label='Mean value of objective variable (for each bin)', marker='.', color='orange')

        ax2.hlines([new_tmp[target].mean()], new_tmp[col].min(), new_tmp[col].max(), color="darkred", linestyles='dashed', label='Mean of objective variables (whole data)')

        handler1, label1 = ax1.get_legend_handles_labels()
        handler2, label2 = ax2.get_legend_handles_labels()
        ax1.legend(handler1 + handler2, label1 + label2, borderaxespad=0., bbox_to_anchor=(0, 1.45), loc='upper left', fontsize=9)
        ax1.set_ylabel('count', fontsize=12)
        ax2.set_ylabel('Objective variable', fontsize=12)
        ax1.set_title('{}'.format(col), loc='right', fontsize=12)

    plt.show()

target = 'y'  # column name of the objective variable
make_plot(df, target)  # draw the plots
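
For features with more than two unique values, make_plot caps extreme values before drawing the histogram: values at or above the 95th percentile are set to the 95th percentile, and when the feature contains negative values, values at or below the 5th percentile are set to the 5th percentile. The sketch below shows the same idea in isolation with Series.clip. The helper name clip_tails is hypothetical, and the sketch leaves out the fallback in make_plot that keeps the raw values when the percentiles are degenerate.

```python
import pandas as pd


def clip_tails(s):
    """Cap a numeric Series at its 95th percentile, and also at its
    5th percentile when it contains negative values (hypothetical helper)."""
    upper = s.quantile(0.95)
    lower = s.quantile(0.05) if s.min() < 0 else None
    return s.clip(lower=lower, upper=upper)


# Illustrative data: mostly zeros with one large outlier
values = pd.Series([0.0] * 18 + [1.0, 100.0])
clipped = clip_tails(values)

# The outlier is pulled down to the 95th percentile; the row count is unchanged
print(values.max(), clipped.max())
```

Capping instead of dropping keeps every row in the plot, so a few extreme values no longer stretch the x-axis while the counts still add up to the full data.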

Visualization result

Items to visualize

1. Histogram or bar graph (blue)

--Histogram when the feature is continuous
--Bar graph when the feature is binary

2. Mean value of the objective variable for each bin (orange)

--Line graph of the mean of the objective variable in each bin (or for each value, when the feature is binary)

3. Mean of the objective variable over the whole data (dark red, dashed)
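
The numbers behind the per-bin line and the overall mean line can also be computed without plotting. The following is a minimal sketch on hand-made data (illustrative values only). It assumes the bin edges come from numpy's 'auto' rule, which is what the histogram call in make_plot requests. Unlike make_plot, it uses right-closed bins so that the maximum value falls inside the last bin without extra handling.

```python
import numpy as np
import pandas as pd

# Illustrative data (not the competition data): y falls as the feature rises
toy = pd.DataFrame({
    'temperature': [5.0, 8.0, 12.0, 15.0, 22.0, 25.0, 30.0, 33.0],
    'y': [120, 110, 100, 95, 80, 70, 60, 55],
})

# Bin edges chosen by numpy's 'auto' rule
edges = np.histogram_bin_edges(toy['temperature'], bins='auto')

# include_lowest=True keeps the minimum inside the first bin;
# right-closed bins keep the maximum inside the last bin
toy['bin'] = pd.cut(toy['temperature'], edges, include_lowest=True)

# Row count and mean of the objective variable in each bin
per_bin = toy.groupby('bin', observed=True)['y'].agg(['size', 'mean'])

# Mean of the objective variable over the whole data
overall_mean = toy['y'].mean()

print(per_bin)
print(overall_mean)
```

A bin whose mean sits far from the overall mean suggests that the feature carries information about the objective variable in that range, provided the bin holds enough rows.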

output_7_0.png
