[PYTHON] How to deal with imbalanced data

Introduction

I sometimes have to deal with imbalanced data when working on classification problems at work. While investigating how to handle it, I came across two techniques called "undersampling" and "oversampling", so I will summarize them here.

What is imbalanced data?

Imbalanced data is data in which the distribution of the target variable is heavily skewed. For example, when label 0 makes up 99% of the data and label 1 only 1%, a classification model trained without any special measures is known to classify the minority class poorly. (This is unsurprising: since 99% of the labels are 0, a model can reach 99% accuracy by predicting 0 for everything.) With imbalanced data the goal is usually to identify the minority class, so some countermeasure is needed.
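The point about the 99% / 1% split can be checked with a toy example. The sketch below uses made-up labels (not the dataset used later) and a hypothetical "model" that always predicts the majority class. It only illustrates why accuracy alone is misleading here.

```python
# Toy labels: 990 negatives (0) and 10 positives (1), i.e. a 99% / 1% split.
y_true = [0] * 990 + [1] * 10

# A "model" that ignores its input and always predicts the majority class.
y_pred = [0] * len(y_true)

# Accuracy: share of labels predicted correctly.
correct = sum(t == p for t, p in zip(y_true, y_pred))
accuracy = correct / len(y_true)

# Recall for the minority class: share of true positives that were found.
true_positives = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
recall = true_positives / sum(y_true)

print(f"accuracy: {accuracy:.2f}")  # accuracy: 0.99
print(f"recall  : {recall:.2f}")    # recall  : 0.00
```

Accuracy looks excellent even though not a single minority sample is detected, which is why the implementation below also reports precision, recall, and the F1 score.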

What is undersampling?

Undersampling randomly extracts samples from the majority class so that their number matches the minority class. The method is simple, but it discards data that would be useful for training and reduces the total amount of data, which causes the following two problems.

  1. The variance of the learned classifier becomes large
  2. The posterior distribution obtained after learning is distorted

Countermeasures for these problems have been studied, but they are omitted here.
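The random extraction described above can be sketched in a few lines. This is a minimal illustration using only the standard library and made-up data. The function name `random_undersample` and its parameters are hypothetical, and it is not how imbalanced-learn implements `RandomUnderSampler`.

```python
import random
from collections import Counter


def random_undersample(rows, labels, majority_label, seed=0):
    """Keep all minority rows and an equally sized random subset of majority rows."""
    rng = random.Random(seed)
    majority_idx = [i for i, y in enumerate(labels) if y == majority_label]
    minority_idx = [i for i, y in enumerate(labels) if y != majority_label]
    # Sample majority rows without replacement until both classes are the same size.
    kept_majority = rng.sample(majority_idx, k=len(minority_idx))
    keep = sorted(kept_majority + minority_idx)
    return [rows[i] for i in keep], [labels[i] for i in keep]


# Made-up data: 1000 rows with one feature each, 10 of them positive.
rows = [[float(i)] for i in range(1000)]
labels = [1 if i % 100 == 0 else 0 for i in range(1000)]

x_small, y_small = random_undersample(rows, labels, majority_label=0)
print("before:", dict(Counter(labels)))
print("after :", dict(Counter(y_small)))
```

In this example 980 of the 990 majority rows are thrown away, which makes the loss of training data mentioned above concrete.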

What is oversampling?

Oversampling supplements the minority class with additional data generated from the existing minority data. Here's a simple way to generate new data:

However, this simple method cannot take the correlation between features into account. For example, even though height and weight are correlated, data such as height 170 cm and weight 60 kg may be generated without regard to that relationship, which may affect the accuracy of the model.
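The loss of correlation can be illustrated with a small experiment. The sketch below assumes that the simple method draws each feature independently from its observed values. That is an assumption made for illustration, since the original description of the method is not shown here. The height and weight data are synthetic, and `pearson` is a helper written for this example.

```python
import random


def pearson(xs, ys):
    """Pearson correlation coefficient of two equally long sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / (var_x * var_y) ** 0.5


rng = random.Random(0)

# Synthetic minority samples: weight (kg) follows height (cm) plus some noise.
heights = [rng.gauss(170, 8) for _ in range(200)]
weights = [0.9 * (h - 100) + rng.gauss(0, 3) for h in heights]

# Assumed naive oversampling: draw each feature independently of the other.
new_heights = [rng.choice(heights) for _ in range(2000)]
new_weights = [rng.choice(weights) for _ in range(2000)]

print(f"original  correlation: {pearson(heights, weights):.2f}")
print(f"generated correlation: {pearson(new_heights, new_weights):.2f}")
```

Each generated feature has a realistic distribution on its own, but the generated pairs are close to uncorrelated, so tall-and-light or short-and-heavy combinations appear far more often than in the original data.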

SMOTE (Synthetic Minority Over-sampling TEchnique) is a method that addresses this problem. SMOTE connects each data point of the minority label to nearby data points of the same label with a line segment and generates artificial data at a random point on that segment. Several extensions of the method exist, but they are omitted here.
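The interpolation idea can be sketched as follows. This is a simplified illustration, not the imbalanced-learn implementation: the function name `smote_like`, its parameters, and the toy points are hypothetical, and the neighbour search is a brute-force sort that is only suitable for tiny examples.

```python
import math
import random


def smote_like(minority, n_new, k=3, seed=0):
    """Generate n_new synthetic points by interpolating between minority neighbours."""
    rng = random.Random(seed)
    synthetic = []
    for _ in range(n_new):
        base = rng.choice(minority)
        # The k nearest minority neighbours of the base point, excluding itself.
        neighbours = sorted(
            (p for p in minority if p is not base),
            key=lambda p: math.dist(base, p),
        )[:k]
        neighbour = rng.choice(neighbours)
        t = rng.random()  # position on the segment, 0 <= t < 1
        # New point = base + t * (neighbour - base), computed per feature.
        synthetic.append([b + t * (n - b) for b, n in zip(base, neighbour)])
    return synthetic


# Toy minority points that all lie on the line y = x.
minority = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]
new_points = smote_like(minority, n_new=5)
for x, y in new_points:
    print(f"({x:.2f}, {y:.2f})")
```

Because every synthetic point lies on a segment between two real minority points, the generated points in this toy example stay on the line y = x, so the relationship between the two features is preserved.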

Handling imbalanced data (implementation)

This time, we will use the Credit Card Fraud Detection dataset published on Kaggle. Each record has a value indicating whether the transaction is fraudulent (1 means fraudulent), but most of the values are 0, so the dataset is highly imbalanced.

The Python code is below.

#Import required libraries
import pandas as pd

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score

#Read Credit Card Fraud Detection data
df = pd.read_csv('creditcard.csv')

print(df.shape)
df.head()

Next, the 'Class' column indicates whether each record is fraudulent, so let's look at how imbalanced the data is.

#Check the number of records in each class
df['Class'].value_counts()

You can see that the data is heavily skewed. First, I will build a classification model with gradient boosting trees without any countermeasure.

#Confirmation of missing values
df.isnull().sum()

#Delete missing value data
df = df.dropna()

#Separate data for training and validation
#Columns 1-29 are the features (V1-V28 and Amount); column 0 (Time) is not used
x = df.iloc[:, 1:30]
y = df['Class']
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state = 40)

#Classification model creation
model = GradientBoostingClassifier()
model.fit(x_train, y_train)

#Predict the test data with the trained model
y_pred = model.predict(x_test)

#Accuracy and confusion matrix
print('Confusion matrix(test):\n{}'.format(confusion_matrix(y_test, y_pred)))
print('Accuracy(test) : %.4f' %accuracy_score(y_test, y_pred))
# Confusion matrix(test):
# [[3473    0]
#  [   2   14]]
# Accuracy(test) : 0.9994

#Precision and Recall
tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
print('precision : %.4f'%(tp / (tp + fp)))
print('recall : %.4f'%(tp / (tp + fn)))
# precision : 1.0000
# recall : 0.8750

#F1 score (arguments are y_true first, then y_pred)
f1_score(y_test, y_pred)
# 0.9333333333333333

Next, I will try undersampling.

#Library
from imblearn.under_sampling import RandomUnderSampler

#Save the number of positive examples
positive_count_train = int(y_train.sum())
print('positive count:{}'.format(positive_count_train))

#Downsample negative examples until positives make up 10%
rus = RandomUnderSampler(sampling_strategy={0:positive_count_train*9, 1:positive_count_train}, random_state=40)

#Apply to the training data (fit_sample was removed from imbalanced-learn; use fit_resample)
x_train_resampled, y_train_resampled = rus.fit_resample(x_train, y_train)
print('y_train_undersample:\n{}'.format(pd.Series(y_train_resampled).value_counts()))
# positive count:40
# y_train_undersample:
# 0.0    360
# 1.0     40

#Classification model creation
mod = GradientBoostingClassifier()
mod.fit(x_train_resampled, y_train_resampled)

#Predicted value calculation
y_pred = mod.predict(x_test)

#Accuracy and confusion matrix
print('Confusion matrix(test):\n{}'.format(confusion_matrix(y_test, y_pred)))
print('Accuracy(test) : %.4f' %accuracy_score(y_test, y_pred))
# Confusion matrix(test):
# [[3458   15]
#  [   2   14]]
# Accuracy(test) : 0.9951

#Precision and Recall
tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
print('precision : %.4f'%(tp / (tp + fp)))
print('recall : %.4f'%(tp / (tp + fn)))
# precision : 0.4828
# recall : 0.8750

#F1 score (arguments are y_true first, then y_pred)
f1_score(y_test, y_pred)
# 0.6222222222222222

Looking at the F1 score, the result has gotten worse. I won't investigate in detail this time, but undersampling does not seem to work well here.

Next, I will try oversampling.

#Library
from imblearn.over_sampling import RandomOverSampler

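#Oversample positive examples up to about 10%
#(x_train.shape[0] is the total number of training rows, so class 0 is also padded slightly)
ros = RandomOverSampler(sampling_strategy={0:x_train.shape[0], 1:x_train.shape[0]//9}, random_state = 40)

#Apply to the training data (fit_sample was removed from imbalanced-learn; use fit_resample)
x_train_resampled, y_train_resampled = ros.fit_resample(x_train, y_train)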

#Classification model creation
mod = GradientBoostingClassifier()
mod.fit(x_train_resampled, y_train_resampled)

#Predicted value calculation
y_pred = mod.predict(x_test)

#Accuracy and confusion matrix
print('Confusion matrix(test):\n{}'.format(confusion_matrix(y_test, y_pred)))
print('Accuracy(test) : %.4f' %accuracy_score(y_test, y_pred))
# Confusion matrix(test):
# [[3471    2]
#  [   2   14]]
# Accuracy(test) : 0.9989

#Precision and Recall
tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
print('precision : %.4f'%(tp / (tp + fp)))
print('recall : %.4f'%(tp / (tp + fn)))
# precision : 0.8750
# recall : 0.8750

#F1 score (arguments are y_true first, then y_pred)
f1_score(y_test, y_pred)
# 0.875

Oversampling gives better results than undersampling, but the F1 score is still better when nothing is done. Finally, I will try SMOTE, which oversamples the minority class with synthetic data points instead of duplicates.

#Library
from imblearn.over_sampling import SMOTE

# SMOTE (fit_sample was removed from imbalanced-learn; use fit_resample)
smote = SMOTE(sampling_strategy={0:x_train.shape[0], 1:x_train.shape[0]//9}, random_state = 40)
x_train_resampled_smote, y_train_resampled_smote = smote.fit_resample(x_train, y_train)

#Classification model creation
mod = GradientBoostingClassifier()
mod.fit(x_train_resampled_smote, y_train_resampled_smote)

#Predicted value calculation
y_pred = mod.predict(x_test)

#Accuracy and confusion matrix
print('Confusion matrix(test):\n{}'.format(confusion_matrix(y_test, y_pred)))
print('Accuracy(test) : %.4f' %accuracy_score(y_test, y_pred))
# Confusion matrix(test):
# [[3469    4]
#  [   2   14]]
# Accuracy(test) : 0.9983

#Precision and Recall
tn, fp, fn, tp = confusion_matrix(y_test, y_pred).ravel()
print('precision : %.4f'%(tp / (tp + fp)))
print('recall : %.4f'%(tp / (tp + fn)))
# precision : 0.7778
# recall : 0.8750

#F1 score (arguments are y_true first, then y_pred)
f1_score(y_test, y_pred)
# 0.823529411764706

This result was also underwhelming.
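To compare the four runs side by side, the metrics can be recomputed directly from the confusion matrices printed above. The sketch below copies those counts by hand. The dictionary `results` and the helper `scores` are names introduced here for illustration only.

```python
# Confusion matrices copied from the outputs above, as (tn, fp, fn, tp).
results = {
    "no resampling":      (3473, 0, 2, 14),
    "RandomUnderSampler": (3458, 15, 2, 14),
    "RandomOverSampler":  (3471, 2, 2, 14),
    "SMOTE":              (3469, 4, 2, 14),
}


def scores(tn, fp, fn, tp):
    """Return (precision, recall, F1) for the positive class."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


for name, counts in results.items():
    precision, recall, f1 = scores(*counts)
    print(f"{name:<20} precision={precision:.4f} recall={recall:.4f} f1={f1:.4f}")
```

Recall is identical in all four runs (14 of the 16 fraudulent test records are detected each time). The resampling methods only changed the number of false positives, which is what lowers precision and the F1 score relative to the model trained without resampling.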

Conclusion

Thank you for reading to the end. This time, I summarized how to deal with imbalanced data. The results were not good, but imbalanced data is a problem that comes up frequently in manufacturing, so I hope this serves as a reference.

If you find anything that needs correcting, I would appreciate it if you could let me know.
