[PYTHON] K-nearest neighbor method (multiclass classification)

What is the K-nearest neighbor method?

KNN (K-Nearest Neighbors) is a method for class discrimination. The training data is plotted in a vector space. When an unknown data point arrives, the K training points closest to it are taken, and the class of the unknown point is estimated by a majority vote among them.
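The procedure can be sketched from scratch in NumPy. This is a minimal illustration, not how scikit-learn implements it internally. It assumes Euclidean distance and a plain majority vote in which ties go to whichever label `Counter` saw first. The function name `knn_predict` and the toy points are hypothetical.

```python
import numpy as np
from collections import Counter

def knn_predict(X_train, y_train, x_new, k=3):
    """Hypothetical helper: predict the class of one point by majority vote of its k nearest neighbors."""
    # Euclidean distance from x_new to every training point
    distances = np.linalg.norm(X_train - x_new, axis=1)
    # Indices of the k smallest distances
    nearest = np.argsort(distances)[:k]
    # Majority vote over the labels of those k neighbors
    votes = Counter(y_train[nearest])
    return votes.most_common(1)[0][0]

# Hypothetical toy data: two classes, 'A' and 'B', in 2-D
X_train = np.array([[1.0, 1.0], [1.5, 2.0], [2.0, 1.0],
                    [5.0, 5.0], [5.5, 4.5], [6.0, 5.5]])
y_train = np.array(['A', 'A', 'A', 'B', 'B', 'B'])

print(knn_predict(X_train, y_train, np.array([5.0, 4.0]), k=3))  # B
```

Because nothing is "trained" beyond storing the data, all of the work happens at prediction time, when the distances to every stored point are computed.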

For example, in the case of the figure below, class discrimination proceeds as follows.

1. Plot the known data (training data) as yellow and purple circles.
2. Choose K, for example K = 3.
3. When an unknown data point (the red star) arrives, take the three known points closest to it.
4. Estimate the class of the unknown point by a majority vote among those three points. Here, the red star is estimated to belong to Class B.

[Figure: Screenshot 2016-05-04 3.33.02.png, known data of two classes and an unknown red star, K = 3]
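The four steps can be reproduced with scikit-learn's `KNeighborsClassifier`. The coordinates below are hypothetical stand-ins for the figure, because its actual values are not given. They were chosen so that the single closest point is class A but two of the three closest are class B. This shows why the vote over K points matters.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Hypothetical coordinates standing in for the figure: class 'A' (yellow) and class 'B' (purple)
X_known = np.array([[1.0, 1.0], [1.0, 2.0], [2.0, 1.0], [0.5, 1.5], [3.2, 3.0],   # class A
                    [4.0, 4.0], [4.0, 3.0], [5.0, 4.0], [5.0, 5.0]])              # class B
y_known = np.array(['A'] * 5 + ['B'] * 4)
star = np.array([[3.5, 3.3]])  # the unknown point (the red star)

# Steps 1 and 2: store the known data and choose K = 3
knn = KNeighborsClassifier(n_neighbors=3).fit(X_known, y_known)

# Step 3: the three nearest known points, sorted by distance
dist, idx = knn.kneighbors(star)
print(y_known[idx[0]])  # ['A' 'B' 'B']

# Step 4: majority vote (2 x B against 1 x A)
print(knn.predict(star))  # ['B']

# With K = 1 only the single closest point counts, and the answer flips to A
knn1 = KNeighborsClassifier(n_neighbors=1).fit(X_known, y_known)
print(knn1.predict(star))  # ['A']
```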

Preparing the data

Load the iris dataset with scikit-learn.

{get_iris_dataset.py}


import pandas as pd
from pandas import DataFrame
from sklearn.datasets import load_iris

iris = load_iris()  # Load the iris dataset
X = iris.data       # Explanatory variables (features used to estimate the class)
Y = iris.target     # Objective variable (class label 0, 1 or 2)

# Convert the iris data to DataFrames
iris_data = DataFrame(X, columns=['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'])
iris_target = DataFrame(Y, columns=['Species'])

# iris_target holds values 0 to 2, which are hard to read, so convert them to the iris species names
def flower(num):
    """Name conversion function"""
    if num == 0:
        return 'Setosa'
    elif num == 1:
        return 'Versicolour'
    else:
        return 'Virginica'

iris_target['Species'] = iris_target['Species'].apply(flower)
iris = pd.concat([iris_data, iris_target], axis=1)  # Features and species name in one DataFrame
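As an alternative to the `flower` helper, the Bunch returned by `load_iris` already carries the species names in label order as `target_names`. Indexing that array with the integer labels converts all of them in one step. Note that these names are lowercase and spelled `versicolor`, so they differ slightly from the labels used in this post.

```python
import pandas as pd
from sklearn.datasets import load_iris

data = load_iris()
# target_names holds the species names in label order: index 0, 1, 2
print(data.target_names)  # ['setosa' 'versicolor' 'virginica']

# Integer-array indexing maps every label to its name, replacing the if/elif helper
df = pd.DataFrame(data.data, columns=['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'])
df['Species'] = data.target_names[data.target]
print(df['Species'].value_counts())
```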

Data overview

{describe_iris.py}


iris.head()
[Output: first five rows of iris (Screenshot 2016-05-04 3.45.38.png)]

[Figure: Screenshot 2016-05-04 3.50.37.png, … length and width data]
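Beyond `head()`, pandas can summarize the whole table and each species separately. This sketch rebuilds the same DataFrame so it runs on its own. It uses `target_names` for the species column, so the names are lowercase here.

```python
import pandas as pd
from sklearn.datasets import load_iris

data = load_iris()
iris = pd.DataFrame(data.data, columns=['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'])
iris['Species'] = data.target_names[data.target]

# Summary statistics (count, mean, std, min, quartiles, max) for each numeric column
print(iris.describe())

# Mean of each measurement per species
means = iris.groupby('Species').mean()
print(means.round(2))
```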

Draw a pair plot with seaborn to get an overview of each class.

{desplay_each_data.py}


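import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')
%matplotlib inline

# Jupyter/IPython magic above shows plots inline.
# hue: color the points by this column; height: size of each panel in inches (called size in old seaborn)
sns.pairplot(iris, hue='Species', height=2)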
[Figure: pair plot of the four measurements colored by species (Screenshot 2016-05-04 3.54.59.png)]

Setosa (blue dots) looks easy to classify. Versicolour (green dots) and Virginica (red dots) might be separable by Petal Length. That is my impression, at least.
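This impression can be checked numerically by looking at the range of Petal Length within each species. The threshold of 2.5 below is a hypothetical value chosen for illustration. It is not something the post derives.

```python
import pandas as pd
from sklearn.datasets import load_iris

data = load_iris()
iris = pd.DataFrame(data.data, columns=['Sepal Length', 'Sepal Width', 'Petal Length', 'Petal Width'])
iris['Species'] = data.target_names[data.target]

# Minimum and maximum Petal Length within each species
ranges = iris.groupby('Species')['Petal Length'].agg(['min', 'max'])
print(ranges)

# Hypothetical single-threshold rule: Petal Length < 2.5 means setosa
rule_says_setosa = iris['Petal Length'] < 2.5
print((rule_says_setosa == (iris['Species'] == 'setosa')).mean())  # 1.0: setosa is perfectly separated

# The versicolor and virginica ranges overlap, so no single Petal Length threshold separates them perfectly
overlap = ranges.loc['versicolor', 'max'] > ranges.loc['virginica', 'min']
print(overlap)  # True
```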

Trying it out

Run KNN with scikit-learn.

{do_knn.py}


from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split  # For the train/test split (sklearn.cross_validation in old versions)

# Split into train and test data. test_size sets the fraction of test data; random_state fixes the random seed.
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.4, random_state=3)

knn = KNeighborsClassifier(n_neighbors=6)  # Create the classifier. n_neighbors: the number K
knn.fit(X_train, Y_train)                  # Fit the model
Y_pred = knn.predict(X_test)               # Predict the test data

# Import the metrics module and check the accuracy
from sklearn import metrics
metrics.accuracy_score(Y_test, Y_pred)     # Prediction accuracy
> 0.94999999999999996

The accuracy is about 95%.
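Accuracy alone does not show which species are being confused in this three-class problem. A confusion matrix shows that. This sketch repeats the same split and K = 6. The exact numbers may differ from the post depending on the scikit-learn version, so no output is quoted here.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn import metrics

X, Y = load_iris(return_X_y=True)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.4, random_state=3)

knn = KNeighborsClassifier(n_neighbors=6).fit(X_train, Y_train)
Y_pred = knn.predict(X_test)

# knn.score computes the same accuracy directly from the test data
acc = metrics.accuracy_score(Y_test, Y_pred)
print(acc, knn.score(X_test, Y_test))

# Rows: true class (0, 1, 2); columns: predicted class.
# Off-diagonal entries count test samples assigned to the wrong species.
cm = metrics.confusion_matrix(Y_test, Y_pred)
print(cm)
print(metrics.classification_report(Y_test, Y_pred, target_names=load_iris().target_names))
```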

The accuracy depends on K. Since it is not obvious which K to use, I tried a range of values and plotted how the accuracy changes.

{create_graph_knn_accracy_change_k.py}


k_range = range(1, 90)  # K = 1 to 89 (K cannot exceed the 90 training samples)
accuracy = []
for k in k_range:
    knn = KNeighborsClassifier(n_neighbors=k)  # Create the classifier
    knn.fit(X_train, Y_train)                  # Fit the model
    Y_pred = knn.predict(X_test)               # Predict the test data
    accuracy.append(metrics.accuracy_score(Y_test, Y_pred))  # Store the accuracy

plt.plot(k_range, accuracy)
plt.xlabel('K')
plt.ylabel('Accuracy')

The result for K = 1 to 89:

[Figure: test accuracy versus K (Screenshot 2016-05-04 4.10.41.png)]

K = 3 or so seems to be enough. Once K exceeds about 30, the accuracy deteriorates. This time there are only 90 training samples (60% of the 150), which is only about 30 per class. When K exceeds 30, even if every training point of the correct class is among the nearest neighbors, the remaining neighbors can only come from other classes. So the accuracy is expected to get steadily worse as K grows.
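The per-class counts behind this reasoning can be checked directly. The sketch below repeats the same split. For a class c with n_c training points, at least K - n_c of the K neighbors must come from other classes. The chosen K values are arbitrary examples.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, Y = load_iris(return_X_y=True)
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.4, random_state=3)

# Number of training samples per class (labels 0, 1, 2)
counts = np.bincount(Y_train)
print(len(Y_train), counts)

# Class c has only counts[c] training points, so for a given K at least
# K - counts[c] of the neighbors must belong to some other class.
for k in (3, 30, 45, 60, 89):
    min_from_other_classes = np.maximum(k - counts, 0)
    print(k, min_from_other_classes)
```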
