[PYTHON] How to find the correlation for categorical variables


When analyzing data, you often look at how the variables in it are correlated. For two numerical variables you can check the correlation coefficient, but what if one or both of them are categorical? I looked into it, so here is a summary.

How to check the correlation

Numeric vs Numeric

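In this case you can use the well-known (Pearson) correlation coefficient. For n pairs of values $(x_i, y_i)$ with means $\bar{x}$ and $\bar{y}$, it is defined as follows.

r = \frac{\sum_{i=1}^{n}(x_i-\bar{x})(y_i-\bar{y})}{\sqrt{\sum_{i=1}^{n}(x_i-\bar{x})^2}\sqrt{\sum_{i=1}^{n}(y_i-\bar{y})^2}}
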
In Python, the correlation coefficient can be computed with the corr() method of pandas.DataFrame.

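import numpy as np
import pandas as pd

# two independent random integer variables, 100 samples each
x=np.random.randint(1, 10, 100)
y=np.random.randint(1, 10, 100)

data=pd.DataFrame({'x':x, 'y': y})
print(data.corr())  # correlation matrix; the off-diagonal value should be close to 0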


A value of 0 means no (linear) correlation. A value close to 1 means a strong positive correlation, and a value close to -1 means a strong negative correlation.
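The following is a minimal, dependency-free sketch of the definition above. It shows what the formula computes, and its results can be checked by hand. The function name pearson_r is hypothetical. In practice, DataFrame.corr() does this for you.

```python
import math

def pearson_r(xs, ys):
    """Pearson correlation coefficient computed directly from its definition."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # numerator: sum of products of the deviations from the means
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    # denominator: square roots of the sums of squared deviations
    sx = math.sqrt(sum((x - mean_x) ** 2 for x in xs))
    sy = math.sqrt(sum((y - mean_y) ** 2 for y in ys))
    return cov / (sx * sy)

print(pearson_r([1, 2, 3, 4], [2, 4, 6, 8]))  # perfectly linear, increasing: close to 1
print(pearson_r([1, 2, 3, 4], [8, 6, 4, 2]))  # perfectly linear, decreasing: close to -1
print(pearson_r([1, 2, 3], [1, 3, 1]))        # no linear trend: 0
```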

Category vs Numeric

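In this case, use a statistic called the correlation ratio. For numerical values $y_i$ ($i = 1, \dots, n$) split into categories $c$, it is defined as follows.

r = \frac{\sum_{c} n_c(\bar{y}_c-\bar{y})^2}{\sum_{i=1}^{n}(y_i-\bar{y})^2}

Here $n_c$ is the number of items in category $c$, $\bar{y}_c$ is the mean of category $c$, and $\bar{y}$ is the overall mean. The denominator is the sum of squares of the total deviation.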

For a concrete example, see the referenced article. The numerator measures how far the category means lie from the overall mean. The farther apart the categories are, the larger the numerator and the stronger the correlation.

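The correlation ratio ranges from 0 to 1. A value of 0 means no correlation (every category mean equals the overall mean), and a value close to 1 means a strong correlation. Unlike the correlation coefficient, it has no sign.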
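To make the formula concrete, here is a small dependency-free sketch that computes the correlation ratio straight from its definition. The function name correlation_ratio_from_definition and the toy data are hypothetical, chosen so the result can be verified by hand.

```python
from collections import defaultdict

def correlation_ratio_from_definition(categories, values):
    """Correlation ratio: between-category sum of squares / total sum of squares."""
    overall_mean = sum(values) / len(values)
    # group the numerical values by category
    groups = defaultdict(list)
    for c, v in zip(categories, values):
        groups[c].append(v)
    # numerator: sum over categories of count * (category mean - overall mean)^2
    between = sum(len(g) * (sum(g) / len(g) - overall_mean) ** 2 for g in groups.values())
    # denominator: sum of squares of total deviation
    total = sum((v - overall_mean) ** 2 for v in values)
    return between / total

cats = ["a", "a", "a", "b", "b", "b"]
print(correlation_ratio_from_definition(cats, [1, 2, 3, 7, 8, 9]))  # 54 / 58, about 0.931
print(correlation_ratio_from_definition(cats, [1, 5, 9, 9, 5, 1]))  # equal category means: 0.0
```

In the first case the category means (2 and 8) are far from the overall mean (5), so the ratio is close to 1. In the second case both category means equal the overall mean, so the numerator vanishes.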

In Python, it can be calculated as follows (based on the referenced article).

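def correlation_ratio(cat_key, num_key, data):
    # keep only the rows where both values are present
    pair = data[[cat_key, num_key]].dropna()
    categorical, numerical = pair[cat_key], pair[num_key]

    mean = numerical.mean()  # overall average
    all_var=((numerical-mean)**2).sum()  #Sum of squares of total deviation
    unique_cat = categorical.unique()
    categorical_num=[numerical[categorical==cat] for cat in unique_cat]
    #Number of items in each category × (category average - overall average)^2
    categorical_var=[len(x)*(x.mean()-mean)**2 for x in categorical_num]

    r = sum(categorical_var) / all_var
    return r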

Category vs Category

We use a statistic called Cramér's coefficient of association (Cramér's V). It is defined as follows.

V = \sqrt{\frac{\chi^2}{n(k-1)}}

Here $\chi^2$ is the chi-square statistic of the cross table of the two variables, n is the number of data items, and k is the smaller of the two numbers of categories. For details on the chi-square statistic, see the chi-square test reference at the end. Roughly speaking, it expresses how much the distribution within each category differs from the overall distribution. In other words, it measures how far the observed counts are from the counts expected if the two variables were independent. Cramér's V ranges from 0 to 1. A value close to 0 means no correlation, and a value close to 1 means a strong correlation. Like the correlation ratio, it has no sign.
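As an illustration, the following dependency-free sketch computes the chi-square statistic (without Yates' correction) and then Cramér's V from raw category labels. The function name cramers_v_from_definition and the toy data are hypothetical. For the same data, the chi-square value should agree with scipy.stats.chi2_contingency(table, False).

```python
import math
from collections import Counter

def cramers_v_from_definition(xs, ys):
    """Cramér's V = sqrt(chi2 / (n * (k - 1))), chi2 computed without Yates' correction."""
    n = len(xs)
    row_totals = Counter(xs)
    col_totals = Counter(ys)
    observed = Counter(zip(xs, ys))  # cells never seen count as 0
    chi2 = 0.0
    for r, r_total in row_totals.items():
        for c, c_total in col_totals.items():
            # expected count if the two variables were independent
            expected = r_total * c_total / n
            chi2 += (observed[(r, c)] - expected) ** 2 / expected
    # k: the smaller of the two numbers of categories
    k = min(len(row_totals), len(col_totals))
    return math.sqrt(chi2 / (n * (k - 1)))

sex = ["m", "m", "f", "f"]
print(cramers_v_from_definition(sex, ["yes", "yes", "no", "no"]))  # 1.0: perfectly associated
print(cramers_v_from_definition(sex, ["yes", "no", "yes", "no"]))  # 0.0: independent
```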

In Python, it can be calculated as follows ([here](https://qiita.com/shngt/items/45da2d30acf9e84924b7#%E3%82%AF%E3%83%A9%E3%83%A1%E3%83%BC%E3%83%AB%E3%81%AE%E9%80%A3%E9%96%A2%E4%BF%82%E6%95%B0)).

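import numpy as np
import pandas as pd
import scipy.stats as st

def cramerV(x, y, data):
    table=pd.crosstab(data[x], data[y])  # cross table of the two categorical variables
    # chi-square statistic without Yates' continuity correction
    x2, p, dof, e=st.chi2_contingency(table, False)

    n = table.values.sum()  # number of data items
    k = min(table.shape)  # the smaller number of categories
    r = np.sqrt(x2 / (n * (k - 1)))
    return r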

Computing all the statistics at once

That alone would just rehash the referenced article, so I wrote a function that computes all of these statistics for every pair of columns in a DataFrame at once. No need to look them up one by one!

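from pandas.api.types import is_float_dtype, is_integer_dtype

def is_categorical(data, key):
    col = data[key]
    if is_integer_dtype(col):
        # integers with only a few distinct values (e.g. class labels) count as categories
        return col.nunique() < 6
    elif is_float_dtype(col):
        return False
    # anything else (strings, booleans, ...) counts as a category
    return True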
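def get_corr(data, categorical_keys=None):
    keys = data.columns
    if categorical_keys is None:
        # decide automatically from each column's dtype
        categorical_keys=[key for key in keys if is_categorical(data, key)]

    # one matrix per statistic; cells where a statistic does not apply stay NaN
    corr = pd.DataFrame(np.nan, index=keys, columns=keys)
    corr_ratio = corr.copy()
    corr_cramer = corr.copy()

    for key1 in keys:
        for key2 in keys:

            if (key1 in categorical_keys) and (key2 in categorical_keys):
                # category vs category: Cramér's V
                r=cramerV(key1, key2, data)
                corr_cramer.loc[key1, key2]=r

            elif (key1 in categorical_keys) and (key2 not in categorical_keys):
                # category vs numeric: correlation ratio
                r=correlation_ratio(cat_key=key1, num_key=key2, data=data)
                corr_ratio.loc[key1, key2]=r

            elif (key1 not in categorical_keys) and (key2 in categorical_keys):
                # numeric vs category: correlation ratio with the roles swapped
                r=correlation_ratio(cat_key=key2, num_key=key1, data=data)
                corr_ratio.loc[key1, key2]=r

            else:
                # numeric vs numeric: Pearson correlation coefficient
                r=data[key1].corr(data[key2])
                corr.loc[key1, key2]=r

    return corr, corr_ratio, corr_cramer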

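Unless categorical_keys is given, each column is classified automatically from its dtype. Integer columns with fewer than 6 distinct values and non-numeric columns are treated as categorical, and float columns are treated as numerical.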

Let's apply it to the Titanic data.

data=data.drop(["PassengerId", "Name", "Ticket", "Cabin"], axis=1)
category=["Survived", "Pclass", "Sex", "Embarked"]

corr, corr_ratio, corr_cramer=get_corr(data, category)

The results can also be visualized with a seaborn heatmap.

import matplotlib.pyplot as plt
import seaborn as sns
sns.heatmap(corr_cramer, vmin=0, vmax=1)  # Cramér's V lies between 0 and 1
plt.show()



My explanations of the statistics are rough, so please see the pages listed in the references for details. Even when I summarize things like this, I end up forgetting them and looking them up again, so I try to automate as much as possible with functions like these. The source code is also on GitHub, so feel free to use it.


References

[Calculate the relationship between variables of various scales (Python)](https://qiita.com/shngt/items/45da2d30acf9e84924b7#%E3%82%AF%E3%83%A9%E3%83%A1%E3%83%BC%E3%83%AB%E3%81%AE%E9%80%A3%E9%96%A2%E4%BF%82%E6%95%B0)
Correlation analysis
Correlation ratio
Chi-square test / Cramér's coefficient of association
