[PYTHON] Visualize data and understand correlation at the same time

Introduction

When analyzing data, you will probably use graphs to visualize it. It would be convenient if a statistic showing the correlation between the two variables could be displayed on the same graph. So I made it possible to display the appropriate statistic on the appropriate graph, chosen according to the content of each variable (categorical or numerical).

Review so far

Here is a summary of the graphs suited to each combination of variable types, and of the statistics that express the correlation between them, which I have covered so far. Please see the articles below for details.

(Figure: sns_corr_summary.png)

Visualization method of data by explanatory variable and objective variable
How to find the correlation for categorical variables
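For reference, these are the standard definitions of the three statistics, written in the form the helper functions in this article compute them. Note that the value returned as the "correlation ratio" is the between-group share of the total variation, usually written as eta squared; some texts reserve the name "correlation ratio" for its square root.

```latex
% Cramer's V for an R x C contingency table with n observations and chi-squared statistic \chi^2
V = \sqrt{\frac{\chi^2}{n\,(\min(R, C) - 1)}}

% Correlation ratio (squared form): groups g with n_g members and group mean \bar{y}_g,
% overall mean \bar{y}
\eta^2 = \frac{\sum_g n_g (\bar{y}_g - \bar{y})^2}{\sum_i (y_i - \bar{y})^2}

% Pearson correlation coefficient for two numerical variables
r = \frac{\sum_i (x_i - \bar{x})(y_i - \bar{y})}{\sqrt{\sum_i (x_i - \bar{x})^2 \, \sum_i (y_i - \bar{y})^2}}
```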

Put the right statistics on the right graph

I modified the method I created previously, which draws the right graph according to the content of each variable (categorical or numerical), so that it also puts the right statistic on that graph (see: "I made a method to automatically select and visualize appropriate graphs for a pandas DataFrame").

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.stats as st

def visualize_data(data, target_col, categorical_keys=None):
     
    keys=data.keys()
        
    if categorical_keys is None:
        
        categorical_keys=keys[[is_categorical(data, key) for key in keys]]
   
    for key in keys:
        
        if key==target_col:
            continue
            
        length=10
        subplot_size=(length, length/2)
        
        if (key in categorical_keys) and (target_col in categorical_keys):

            # Both categorical: Cramer's V
            r=cramerV(key, target_col, data)
            
            fig, axes=plt.subplots(1, 2, figsize=subplot_size)
            sns.countplot(x=key, data=data, ax=axes[0])
            sns.countplot(x=key, data=data, hue=target_col, ax=axes[1])
            axes[1].set_title(f"Cramer's V: {r:.3f}")
            plt.tight_layout()
            plt.show()

        elif (key in categorical_keys) and (target_col not in categorical_keys):

            # Categorical explanatory variable, numerical target: correlation ratio
            r=correlation_ratio(cat_key=key, num_key=target_col, data=data)
            
            fig, axes=plt.subplots(1, 2, figsize=subplot_size)
            sns.countplot(x=key, data=data, ax=axes[0])
            sns.violinplot(x=key, y=target_col, data=data, ax=axes[1])
            axes[1].set_title(f"Correlation ratio: {r:.3f}")
            plt.tight_layout()
            plt.show()

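        elif (key not in categorical_keys) and (target_col in categorical_keys):

            # Numerical explanatory variable, categorical target: correlation ratio
            r=correlation_ratio(cat_key=target_col, num_key=key, data=data)
            
            fig, axes=plt.subplots(1, 2, figsize=subplot_size)
            # histplot replaces the deprecated distplot (seaborn >= 0.11)
            sns.histplot(data=data, x=key, ax=axes[0])
            # hue splits the histogram by target class and draws its own legend,
            # so no separate FacetGrid figure (and no plt.close() hack) is needed
            sns.histplot(data=data, x=key, hue=target_col, ax=axes[1])
            axes[1].set_title(f"Correlation ratio: {r:.3f}")
            plt.tight_layout()
            plt.show()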

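        else:

            # Both numerical: Pearson correlation coefficient.
            # Series.corr avoids DataFrame.corr, which fails on non-numeric columns in pandas >= 2.0
            r=data[key].corr(data[target_col])
            
            sg=sns.jointplot(x=key, y=target_col, data=data, height=length*2/3)
            # Put the title above the top marginal histogram so it does not overlap the plot
            sg.ax_marg_x.set_title(f"Pearson r: {r:.3f}")
            plt.show()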

The following helper functions are used along the way.

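def is_categorical(data, key):  # Determine whether a column is a categorical variable
    
    col=data[key]
    
    # dtype checks cover every integer width (int32, int64, ...),
    # unlike comparing the dtype with the string 'int'
    if pd.api.types.is_integer_dtype(col):
        # Integer columns with only a few distinct values are treated as categorical
        return col.nunique()<6
    
    elif pd.api.types.is_float_dtype(col):
        return False
    
    else:
        # Strings, booleans, etc.
        return True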

def correlation_ratio(cat_key, num_key, data):  # Find the correlation ratio
    
    # Keep only rows where both the category and the value are present,
    # so the total and between-group sums use the same observations
    subset=data[[cat_key, num_key]].dropna()
    categorical=subset[cat_key]
    numerical=subset[num_key]
    
    mean=numerical.mean()
    # Total sum of squares
    all_var=((numerical-mean)**2).sum()
    
    # Between-group sum of squares: group size * (group mean - overall mean)^2
    group_stats=numerical.groupby(categorical).agg(["count", "mean"])
    categorical_var=group_stats["count"]*(group_stats["mean"]-mean)**2

    # Between-group / total variation (eta squared)
    r=categorical_var.sum()/all_var
    
    return r
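To check the formula by hand, here is a tiny worked example. It uses a standalone re-implementation (the name `eta_squared` and the toy data are hypothetical, not part of the original code) that follows the same definition: between-group sum of squares divided by total sum of squares. With groups A = {1, 2, 3} and B = {4, 5, 6}, the overall mean is 3.5, the total sum of squares is 17.5, and the between-group sum is 3·1.5² + 3·1.5² = 13.5, so the ratio is 13.5 / 17.5 ≈ 0.771. With only two groups, this equals the squared Pearson correlation between a 0/1 group indicator and the values, which gives an independent cross-check.

```python
import numpy as np

# Hypothetical toy data: two groups of three values each
groups = np.array(["A", "A", "A", "B", "B", "B"])
values = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])


def eta_squared(groups, values):
    """Between-group sum of squares divided by total sum of squares."""
    grand_mean = values.mean()
    total_ss = ((values - grand_mean) ** 2).sum()
    between_ss = sum(
        (groups == g).sum() * (values[groups == g].mean() - grand_mean) ** 2
        for g in np.unique(groups)
    )
    return between_ss / total_ss


eta2 = eta_squared(groups, values)

# Two groups only: eta squared equals the squared point-biserial correlation,
# i.e. the squared Pearson r between a 0/1 group indicator and the values
indicator = (groups == "B").astype(float)
r_pb = np.corrcoef(indicator, values)[0, 1]

print(round(eta2, 4))       # 0.7714
print(round(r_pb ** 2, 4))  # 0.7714
```

With three or more groups the indicator trick no longer applies, which is why the article uses the sum-of-squares form.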

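def cramerV(x, y, data):  # Find Cramer's V (association between two categorical variables)
    
    table=pd.crosstab(data[x], data[y])
    # Chi-squared statistic without Yates' continuity correction
    x2, p, dof, e=st.chi2_contingency(table, correction=False)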
    
    n=table.sum().sum()
    r=np.sqrt(x2/(n*(np.min(table.shape)-1)))

    return r
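A quick sanity check for Cramér's V: for a 2x2 table it equals the absolute value of the phi coefficient, |ad − bc| / sqrt((a+b)(c+d)(a+c)(b+d)). The sketch below uses a hypothetical table and a standalone function `cramers_v` (an illustrative name, mirroring the formula in `cramerV`). For [[20, 10], [10, 20]] every expected count is 15, so χ² = 4 · 5²/15 ≈ 6.667 and V = sqrt(6.667 / 60) = 1/3. Disabling the correction matters here: `scipy.stats.chi2_contingency` applies Yates' correction by default when there is one degree of freedom, which would give a smaller χ² for a 2x2 table.

```python
import numpy as np
import pandas as pd
import scipy.stats as st

# Hypothetical 2x2 contingency table (rows: categories of x, columns: categories of y)
table = pd.DataFrame([[20, 10], [10, 20]], index=["x0", "x1"], columns=["y0", "y1"])


def cramers_v(table):
    """Cramer's V from a contingency table, without Yates' correction."""
    chi2, p, dof, expected = st.chi2_contingency(table, correction=False)
    n = table.to_numpy().sum()
    return np.sqrt(chi2 / (n * (min(table.shape) - 1)))


v = cramers_v(table)

# For a 2x2 table, V equals |phi|
(a, b), (c, d) = table.to_numpy()
phi = (a * d - b * c) / np.sqrt((a + b) * (c + d) * (a + c) * (b + d))

print(round(v, 4), round(abs(phi), 4))  # 0.3333 0.3333
```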

Let's apply it to the Titanic data (train.csv). Only part of the result is shown.

train_data=pd.read_csv("train.csv")
train_data=train_data.drop(["PassengerId", "Name", "Ticket", "Cabin"], axis=1)

categories=["Survived", "Pclass", "Sex", "Embarked"]
visualize_data(train_data, "Survived", categories)
(Figures: countplot_corr.png, distplot_corr.png)
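If you only want the numbers, for example to rank every column against the target before deciding which graphs to look at, the same type-based dispatch can return a table instead of drawing. This is a minimal sketch, not part of the original method: `correlation_summary`, `_cramers_v` and `_eta_squared` are hypothetical names, and the toy DataFrame merely stands in for the Titanic columns. The statistic choice follows `visualize_data`: Cramér's V for two categorical variables, the correlation ratio (eta squared) when exactly one is categorical, and Pearson's r otherwise.

```python
import numpy as np
import pandas as pd
import scipy.stats as st


def _cramers_v(x, y):
    # Cramer's V from the crosstab of two categorical Series (no Yates' correction)
    table = pd.crosstab(x, y)
    chi2 = st.chi2_contingency(table, correction=False)[0]
    return np.sqrt(chi2 / (table.to_numpy().sum() * (min(table.shape) - 1)))


def _eta_squared(cat, num):
    # Between-group sum of squares / total sum of squares, ignoring rows with NaN
    df = pd.DataFrame({"cat": cat, "num": num}).dropna()
    grand_mean = df["num"].mean()
    total_ss = ((df["num"] - grand_mean) ** 2).sum()
    stats = df.groupby("cat")["num"].agg(["count", "mean"])
    between_ss = (stats["count"] * (stats["mean"] - grand_mean) ** 2).sum()
    return between_ss / total_ss


def correlation_summary(data, target_col, categorical_keys):
    """Return one row per column: which statistic was used and its value."""
    rows = []
    target_is_cat = target_col in categorical_keys
    for key in data.columns:
        if key == target_col:
            continue
        key_is_cat = key in categorical_keys
        if key_is_cat and target_is_cat:
            name, value = "Cramer's V", _cramers_v(data[key], data[target_col])
        elif key_is_cat:
            name, value = "correlation ratio", _eta_squared(data[key], data[target_col])
        elif target_is_cat:
            name, value = "correlation ratio", _eta_squared(data[target_col], data[key])
        else:
            name, value = "Pearson r", data[key].corr(data[target_col])
        rows.append({"column": key, "statistic": name, "value": value})
    return pd.DataFrame(rows).set_index("column")


# Hypothetical toy data standing in for the Titanic columns
toy = pd.DataFrame({
    "Survived": [0, 0, 1, 1, 0, 1],
    "Sex": ["male", "male", "female", "female", "male", "female"],
    "Fare": [7.0, 8.0, 70.0, 50.0, 9.0, 60.0],
})
summary = correlation_summary(toy, "Survived", ["Survived", "Sex"])
print(summary)
```

Because the statistics live on different scales (Pearson's r can be negative, the other two lie between 0 and 1), compare values only within the same statistic.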

Finally

I have summarized the methods I have made so far. Now you can visualize the data and grasp the correlation in a single step. The source code is on GitHub, so feel free to use it!
