While building a machine learning model on tabular data, I had a variable whose distribution I wanted to make identical across classes, so I implemented a function that does this.
The function undersamples each attribute value (class) within each section (bin) of a specified width. Here I use the seaborn titanic data and match the number of people in each 10-year age bin between the 'yes' and 'no' groups of alive.
Environment: Windows 10, conda 4.8.3, Python 3.8.3, pandas 0.25.3, matplotlib 3.3.1, seaborn 0.11.0
Get the data as follows.
load_dataset.py
import seaborn as sns
data = sns.load_dataset("titanic")
For the acquired data, the age distribution by alive (survival status), in 10-year bins, is as follows.
display_graph.py
import matplotlib.pyplot as plt
sns.set_style(style='darkgrid')
fig, ax = plt.subplots(1,1, figsize=(4,4))
ax.set_xticks(range(0,100,10))
ax.set_ylim(0,150)
ax.set_ylabel('the number of people')
sns.distplot(data['age'][data['alive']=='yes'], kde=False, rug=False, bins=range(0,100,10),
             label='alive', ax=ax)
sns.distplot(data['age'][data['alive']=='no'], kde=False, rug=False, bins=range(0,100,10),
             label='dead', ax=ax)
ax.legend()
plt.show()
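The histogram shows the imbalance visually; the same per-bin counts can also be tabulated. The following is a minimal sketch, assuming a small made-up DataFrame (`toy`) in place of the titanic data and assuming non-negative ages, so that the left edge of each 10-year bin can be computed with floor division.

```python
import pandas as pd

# Hypothetical toy data standing in for the titanic 'age' and 'alive' columns.
toy = pd.DataFrame({
    "age":   [3, 8, 12, 15, 17, 22, 25, 28, 29, 34],
    "alive": ["yes", "no", "yes", "yes", "no", "no", "no", "yes", "no", "yes"],
})

# Left edge of the 10-year bin each row falls in: 0 for [0, 10), 10 for [10, 20), ...
age_bin = (toy["age"] // 10 * 10).rename("age_bin")

# Rows: age bins, columns: classes, values: number of people.
counts = pd.crosstab(age_bin, toy["alive"])
print(counts)
```

Bins where the two columns differ are the ones that undersampling has to shrink.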

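I created the following function to match the counts. Rows whose target value is missing (NaN) fall in no section, so they are dropped from the result.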
import pandas as pd
def adjust_number(data, target_column, attribute, period):
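    '''
    Undersample so that every attribute value has the same number of rows in each section of target_column.
    data         :DataFrame to be adjusted
    target_column:Name of the numeric column whose distribution is matched
    attribute    :Name of the column whose values (classes) are balanced within each section
    period       :Section width
    '''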
    ##Initial section setting
    #The lower limit starts at 0 if the minimum of the target column is 0 or more; otherwise it starts at that minimum.
    lower = 0 if data[target_column].min() >= 0 else data[target_column].min()
    #Each section is half-open, [lower, upper), so non-integer values such as 29.5 are not lost between sections.
    upper = lower+period
    data_adjusted = pd.DataFrame() #Stores the adjusted data
    maximum = data[target_column].max() #Maximum value of the target column
    #Repeat until the lower limit exceeds the maximum value
    while lower <= maximum:
        #Extract the rows that fall in the current section
        data_in_range = data[(lower<=data.loc[:,target_column]) & (data.loc[:,target_column]<upper)]
        
        #If the section is empty, or some attribute value has no rows in it, skip to the next section.
        #(With undersampling, a count of 0 for any attribute value would reduce every value to 0 rows.)
        if len(data_in_range) == 0 or set(data[attribute]) != set(data_in_range[attribute]):
            lower += period
            upper += period
            continue
    
        else:
            #Count the rows for each attribute value in the target section
            counts = data_in_range[attribute].value_counts()

            #Undersample each attribute value down to the smallest count
            for att in counts.index:
                sample = data_in_range[data_in_range[attribute]==att].sample(n=counts.min(), random_state=42)

                #Append the sampled rows of this section to the adjusted data
                data_adjusted = pd.concat([data_adjusted, sample],axis=0, ignore_index=True)
        
        #To the next section
        lower += period
        upper += period
    
    return data_adjusted
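For comparison, the same idea can be sketched with `groupby` instead of a `while` loop. This is not the function above but a hypothetical alternative (`undersample_by_bin` is a name made up here), and it assumes that sections aligned to multiples of `period` are acceptable, which matches the function above only when the target column has no negative values.

```python
import pandas as pd

def undersample_by_bin(df, target_column, attribute, period, random_state=42):
    """Hypothetical groupby-based variant: equalize class counts within each bin."""
    # Rows with a missing target value cannot be binned, so drop them first.
    df = df.dropna(subset=[target_column])
    # Left edge of the half-open bin [start, start + period) each row falls in.
    bin_start = df[target_column] // period * period
    n_classes = df[attribute].nunique()

    parts = []
    for _, in_bin in df.groupby(bin_start):
        counts = in_bin[attribute].value_counts()
        counts = counts[counts > 0]  # ignore unused categories, if any
        # Skip bins in which at least one class is absent.
        if len(counts) < n_classes:
            continue
        n = int(counts.min())
        for value in counts.index:
            rows = in_bin[in_bin[attribute] == value]
            parts.append(rows.sample(n=n, random_state=random_state))

    if not parts:
        return df.iloc[0:0]  # empty frame with the same columns
    return pd.concat(parts, ignore_index=True)

# Made-up data: the 30s bin has only 'yes' rows, so it is skipped entirely.
toy = pd.DataFrame({
    "age":   [3, 8, 12, 15, 17, 22, 25, 28, 29, 34],
    "alive": ["yes", "no", "yes", "yes", "no", "no", "no", "yes", "no", "yes"],
})
balanced = undersample_by_bin(toy, "age", "alive", period=10)
print(balanced["alive"].value_counts())
```

Collecting the samples in a list and concatenating once avoids rebuilding the result DataFrame on every iteration.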
The age distribution in 10-year bins after processing with this function (after matching the counts) is as follows. The plot before adjustment is shown again for comparison.
data_adjusted = adjust_number(data, target_column='age', attribute='alive', period=10)
fig, ax = plt.subplots(1,1, figsize=(4,4))
ax.set_xticks(range(0,100,10))
ax.set_ylim(0,150)
ax.set_ylabel('the number of people')
sns.distplot(data_adjusted['age'][data_adjusted['alive']=='yes'], kde=False, rug=False, bins=range(0,100,10),
             label='alive', ax=ax)
sns.distplot(data_adjusted['age'][data_adjusted['alive']=='no'], kde=False, rug=False, bins=range(0,100,10),
             label='dead', ax=ax)
ax.legend()
plt.show()
▼ After correction ▼ Before correction
     
The counts were matched as intended.
Changing the section width from 10 to 5 (matching the number of people in 5-year age bins) also worked without problems.
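Besides comparing the plots by eye, the result can be checked programmatically. The following is a minimal sketch of such a check; `is_balanced` is a hypothetical helper, and it assumes a non-negative target column and a plain (non-categorical) attribute column.

```python
import pandas as pd

def is_balanced(df, target_column, attribute, period):
    """Return True if every bin holds the same number of rows for each class in df."""
    # Left edge of the half-open bin [start, start + period) each row falls in.
    bin_start = df[target_column] // period * period
    counts = pd.crosstab(bin_start, df[attribute])
    # A bin is balanced when its smallest and largest class counts coincide.
    return bool((counts.min(axis=1) == counts.max(axis=1)).all())

# Made-up frames: 'before' has 2 'yes' vs 1 'no' in the 10s bin, 'after' has 1 vs 1.
before = pd.DataFrame({"age": [3, 8, 12, 15, 17],
                       "alive": ["yes", "no", "yes", "yes", "no"]})
after = pd.DataFrame({"age": [3, 8, 12, 17],
                      "alive": ["yes", "no", "yes", "no"]})

print(is_balanced(before, "age", "alive", 10))
print(is_balanced(after, "age", "alive", 10))
```

With the article's variables, it could be called as `is_balanced(data_adjusted, 'age', 'alive', 10)`.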
data_adjusted = adjust_number(data, target_column='age', attribute='alive', period=5)
▼ After correction ▼ Before correction
     
The function also works when the target column is changed to fare.
data_adjusted = adjust_number(data, target_column='fare', attribute='alive', period=30)
▼ After correction ▼ Before correction
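With a continuous column such as fare, how the section boundaries are defined matters. A section written as `lower <= x <= lower + period - 1` leaves out values such as 29.7, whereas a half-open section `lower <= x < lower + period` assigns every value to exactly one section. A minimal sketch with made-up fares:

```python
import pandas as pd

# Hypothetical fares, including a value that falls between 29 and 30.
fares = pd.Series([7.25, 29.0, 29.7, 30.0, 59.99, 60.0])
period = 30

# Closed section [0, 29]: the fare 29.7 is not included.
in_closed = fares[(0 <= fares) & (fares <= period - 1)]
# Half-open section [0, 30): every fare below 30 is included.
in_half_open = fares[(0 <= fares) & (fares < period)]

print(len(in_closed))     # 2
print(len(in_half_open))  # 3

# pd.cut with right=False builds the same half-open sections in one call.
sections = pd.cut(fares, bins=[0, 30, 60, 90], right=False)
print(sections.isna().sum())  # 0: no fare is left without a section
```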
     
It also works when the attribute to balance is changed to sex.
data_adjusted = adjust_number(data, target_column='age', attribute='sex', period=10)
▼ After correction ▼ Before correction
     
That's it. Thank you for visiting.