[Machine learning] Writing the k-nearest neighbor method (k-NN) in Python from scratch and recognizing handwritten digits

Last time I used template matching to identify handwritten digits; this time I switch methods and use the k-nearest neighbor method (k-NN) instead.

Overview of the k-nearest neighbor method

First, an outline of the method. Given a sample to classify (the black dot in the figure below), find the k training samples closest to it and take the label that wins the majority vote among them as the estimate. In this example there are three "5"s and two "8"s, so "5" is adopted as the estimated value. Its main characteristics: it is a supervised method, and since every training sample takes part in every prediction, it is expensive in both memory and computation. "Close" can be defined in several ways, such as the ordinary Euclidean distance or the Mahalanobis distance (which accounts for variance); this time we use the ordinary Euclidean distance.

(Figure: majority vote among the k = 5 nearest neighbors — plot4.png)
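The idea above can be sketched with a toy example. This is a minimal illustration on made-up 2-D points (five training points, labels "5" and "8" as in the figure), not the article's digit data:

```python
import numpy as np
from collections import Counter

# Toy training set: five 2-D points with digit labels (made-up data)
train_X = np.array([[0.0, 0.0], [0.1, 0.1], [0.2, 0.0],   # three "5"s
                    [1.0, 1.0], [1.1, 0.9]])              # two "8"s
train_y = np.array([5, 5, 5, 8, 8])

def knn_predict(x, train_X, train_y, k=5):
    # Euclidean distance from x to every training point
    distances = np.sqrt(((train_X - x) ** 2).sum(axis=1))
    # Labels of the k closest points
    nearest = train_y[distances.argsort()[:k]]
    # Majority vote among those labels
    return Counter(nearest).most_common(1)[0][0]

print(knn_predict(np.array([0.3, 0.2]), train_X, train_y, k=5))  # -> 5
```

With k = 5 all five neighbors vote, and "5" wins 3 to 2 — exactly the situation described above.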

Implementation

Let's write it in Python. First, import the necessary libraries. No machine learning library such as scikit-learn is used this time (the whole point is to do it by hand).

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from collections import defaultdict

Define a class called TrainDataSet. Since this is the training data, it holds both the answer labels (the handwritten digits) and the pixel data, with accessors that make it easy to pull out just what you need, such as specific rows.

class TrainDataSet:
    def __init__(self, data):
        data = np.array(data)
        self.labels = data[:, 0]      # first column: digit labels
        self.data_set = data[:, 1:]   # remaining columns: pixel values

    def __repr__(self):
        return repr(self.labels) + "\n" + repr(self.data_set)

    def get_data_num(self):
        return self.labels.size

    def get_labels(self, *args):
        if not args:                  # *args is an empty tuple, never None
            return self.labels
        return self.labels[args[0]]

    def get_data_set(self):
        return self.data_set

    def get_data_set_partial(self, *args):
        if not args:
            return self.data_set
        return self.data_set[args[0]]

    def get_label(self, i):
        return self.labels[i]

    def get_data(self, i, j=None):
        # Python has no method overloading; a second "def get_data" would
        # simply replace the first, so an optional argument is used instead
        if j is None:
            return self.data_set[i, :]
        return self.data_set[i][j]

Then load the data. If you want to try it yourself, these are the files used: train_master.csv ... training data (42,000 rows), test_small.csv ... samples to identify (200 rows).

size = 28
master_data = np.loadtxt('train_master.csv', delimiter=',', skiprows=1)
test_data   = np.loadtxt('test_small.csv', delimiter=',', skiprows=1)

train_data_set = TrainDataSet(master_data)
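If you don't have the CSV files at hand, an array in the same layout (first column = digit label, remaining 784 columns = pixel values of a 28×28 image) can be generated synthetically. This is purely illustrative random data, not the real dataset:

```python
import numpy as np

rng = np.random.default_rng(0)
n, size = 10, 28
labels = rng.integers(0, 10, size=(n, 1))             # fake digit labels
pixels = rng.integers(0, 256, size=(n, size * size))  # fake 28x28 pixel rows
fake_master = np.hstack([labels, pixels]).astype(float)

# The same column slicing TrainDataSet performs internally
print(fake_master[:, 0].size)     # -> 10   (labels)
print(fake_master[:, 1:].shape)   # -> (10, 784)   (pixel matrix)
```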

Also define a function that tallies the k neighbors and returns, for each digit label, how many of the k neighbors carry that label, together with their total distance.

def get_list_sorted_by_val(k_result, k_dist):
    result_dict = defaultdict(int)
    distance_dict = defaultdict(float)

    # Count votes per digit label
    for label in k_result:
        result_dict[label] += 1

    # Total distance per digit label
    for label, dist in zip(k_result, k_dist):
        distance_dict[label] += dist

    # Build rows of (label, count, total distance) and convert to ndarray
    result_list = [[key, val, distance_dict[key]]
                   for key, val in result_dict.items()]
    return np.array(result_list)
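As a quick sanity check, here is what the aggregation produces for a hypothetical neighborhood of labels [5, 8, 5, 5, 8] with made-up distances (the function is re-declared so the snippet runs standalone):

```python
import numpy as np
from collections import defaultdict

def get_list_sorted_by_val(k_result, k_dist):
    result_dict = defaultdict(int)
    distance_dict = defaultdict(float)
    for label in k_result:                     # count votes per label
        result_dict[label] += 1
    for label, dist in zip(k_result, k_dist):  # total distance per label
        distance_dict[label] += dist
    return np.array([[key, val, distance_dict[key]]
                     for key, val in result_dict.items()])

# Hypothetical neighborhood: labels of the 5 nearest samples and their distances
rows = get_list_sorted_by_val([5, 8, 5, 5, 8], [1.0, 2.0, 1.5, 1.2, 2.5])
print(rows)
# -> [[5.  3.  3.7]
#     [8.  2.  4.5]]
```

Label 5 gets 3 votes with total distance 3.7; label 8 gets 2 votes with total distance 4.5.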

Now that the preparations are complete, the identification itself begins. This time k = 5 neighbors are used.

k = 5
predicted_list = []    # predicted digit labels
k_result_list  = []    # labels of the k nearest neighbors
k_distances_list = []  # distances from each target to its k neighbors

# execute k-nearest neighbor method
for i in range(len(test_data)):

    # Difference between the target sample and every training sample
    diff_data = np.tile(test_data[i], (train_data_set.get_data_num(), 1)) \
                - train_data_set.get_data_set()

    sq_data   = diff_data ** 2       # square each element to drop the sign
    sum_data  = sq_data.sum(axis=1)  # sum over the pixel dimension
    distances = sum_data ** 0.5      # square root gives the Euclidean distance
    ind = distances.argsort()        # indices sorted by ascending distance
    k_result = train_data_set.get_labels(ind[0:k])  # labels of the k closest samples
    k_dist   = distances[ind[0:k]]   # the corresponding k distances

    k_distances_list.append(k_dist)
    k_result_list.append(k_result)

    # Aggregate the k labels into rows of (label, count, total distance)
    result_list = get_list_sorted_by_val(k_result, k_dist)
    candidate = result_list[result_list[:, 1].argsort()[::-1]]

    # Group the candidate rows by vote count
    result_dict = {}
    for d in candidate:
        result_dict.setdefault(d[1], []).append((d[0], d[2]))

    # If several labels tie for the most votes, pick the one with the
    # smaller total distance
    top_count = max(result_dict.keys())  # plain max(); np.max on dict_keys fails in Python 3
    label_top = min(result_dict[top_count], key=lambda t: t[1])[0]

    predicted_list.append(label_top)

Display the results.

# display the results
print("[Predicted Data List]")
for i in range(len(predicted_list)):
    print("%d\t%s" % (i, predicted_list[i]))

print("[Detail Predicted Data List]")
print("index\tlabels of the k neighbors\tdistances to the k neighbors")
for i in range(len(k_result_list)):
    print("%d\t%s\t%s" % (i, k_result_list[i], k_distances_list[i]))

The output file and the comparison of the predictions against the correct answers are [here](https://gist.github.com/matsuken92/7ca89520ff4e9d2242b0). Identifying the 200 samples with the k-nearest neighbor method, the recognition rate jumped to 97% (194/200)! That level of accuracy starts to feel practical. Last time, template matching managed 80%, so this is a considerable improvement.
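The recognition rate is just the fraction of predictions that match the true labels. A minimal sketch with made-up arrays (the variable names are illustrative, not from the article's code):

```python
import numpy as np

# Hypothetical example: 10 predictions vs. true labels, one mistake
predicted = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
answer    = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 3])

accuracy = np.mean(predicted == answer)  # fraction of exact matches
print("accuracy = %.0f%% (%d/%d)" % (accuracy * 100,
                                     (predicted == answer).sum(), len(answer)))
# -> accuracy = 90% (9/10)
```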

Analysis of the misclassified data

The following 6 samples were misclassified, but they are troublesome even to the eye. The first sample in the lower row could be read as either 6 or 4 even visually, which is a borderline case... It is fair to say that the k-nearest neighbor method identifies almost everything apart from such ambiguous handwritten digits.


# Plot the 6 misclassified samples
# (plot_digits is the plotting helper from the previous article)
counter = 0
for num in [3, 76, 128, 132, 147, 165]:
    counter += 1
    X, Y = np.meshgrid(range(size), range(size))
    Z = test_data[num].reshape(size, size)
    Z = Z[::-1, :]   # flip vertically so the digit is drawn upright
    plot_digits(X, Y, Z, 2, 3, counter, "pred=%d" % predicted_list[num])

(Figure: the six misclassified digits — knn_fault.png)

**Failure details**

| index | label | pred | k-nearest labels | remarks |
|---|---|---|---|---|
| 3 | 0 | 9 | [0. 9. 9. 9. 2.] | The nearest neighbor is 0, but... so close. |
| 76 | 9 | 8 | [8. 8. 9. 8. 3.] | There is one 9 in there, but... |
| 128 | 7 | 1 | [8. 1. 7. 8. 1.] | Don't add extra strokes to a 7... |
| 132 | 4 (?) | 6 | [6. 6. 6. 6. 6.] | Whether this is a 4 or a 6 is unclear even visually |
| 147 | 4 | 7 | [7. 7. 7. 7. 7.] | This one really does look like a 7 |
| 165 | 3 | 2 | [3. 2. 2. 2. 3.] | The 3 was a close call, but... |
