Last time I used template matching to identify handwritten digits; this time I'll change the approach and use the k-nearest neighbor (k-NN) method.
First, an outline of the method. Given a sample to classify (the black dot in the figure below), we look up the k training samples closest to it and take a majority vote among their labels as the estimate. In this example there are three "5"s and two "8"s, so "5" is adopted as the estimate. Its characteristics are that it is a supervised method, and that it is costly in both memory and computation because every training sample takes part in each classification. There are several notions of "closeness", such as the ordinary (Euclidean) distance and the Mahalanobis distance (which takes variance into account); this time we use the ordinary Euclidean distance.
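Condensed to a few lines of numpy, the core computation looks something like the sketch below. This is just an illustration; the helper knn_predict is hypothetical and separate from the implementation that follows.

import numpy as np
from collections import Counter

def knn_predict(x, train_X, train_y, k=5):
    # Euclidean distance from x to every training sample
    dists = np.sqrt(((train_X - x) ** 2).sum(axis=1))
    # Labels of the k closest samples
    nearest = train_y[dists.argsort()[:k]]
    # Majority vote
    return Counter(nearest).most_common(1)[0][0]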
Let's write it in Python right away. First, import the necessary libraries. I won't use a machine learning library such as sklearn this time (since the point is to implement it myself).
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from collections import defaultdict
Define a class called TrainDataSet. Since this is the training data, it holds both the answer labels (the handwritten digits) and the pixel data, with accessors that make it easy to retrieve just what you need, such as specific elements.
class TrainDataSet():
    def __init__(self, data):
        data = np.array(data)
        self.labels = data[:, 0]
        self.data_set = data[:, 1:]

    def __repr__(self):
        ret = repr(self.labels) + "\n"
        ret += repr(self.data_set)
        return ret

    def get_data_num(self):
        return self.labels.size

    def get_labels(self, *args):
        if not args:
            return self.labels
        else:
            return self.labels[args[0]]

    def get_data_set(self):
        return self.data_set

    def get_data_set_partial(self, *args):
        if not args:
            return self.data_set
        else:
            return self.data_set[args[0]]

    def get_label(self, i):
        return self.labels[i]

    def get_data(self, i, j=None):
        # Return row i, or element (i, j) if a column index is given
        if j is None:
            return self.data_set[i, :]
        return self.data_set[i][j]
Then load the data. If you want to try it yourself, the data can be downloaded from the following:

- train.csv ... training data (42,000 rows)
- test_small.csv ... data to classify (200 rows)
size = 28
master_data = np.loadtxt('train_master.csv', delimiter=',', skiprows=1)
test_data = np.loadtxt('test_small.csv', delimiter=',', skiprows=1)
train_data_set = TrainDataSet(master_data)
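A quick shape check doesn't hurt at this point (the expected shapes below assume the layout described above, with the first column of the training file holding the label):

print(master_data.shape)  # expected (42000, 785): label + 28*28 pixels
print(test_data.shape)    # expected (200, 784): 28*28 pixels only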
Also define a function that aggregates the k nearest neighbors and outputs, for each digit label, how many of the k neighbors carry that label, along with the total distance.
def get_list_sorted_by_val(k_result, k_dist):
    result_dict = defaultdict(int)
    distance_dict = defaultdict(float)

    # Count occurrences of each digit label
    for i in k_result:
        result_dict[i] += 1

    # Total up the distances for each digit label
    for i in range(len(k_dist)):
        distance_dict[k_result[i]] += k_dist[i]

    # Convert the dictionaries to a list (so it can be sorted)
    result_list = []
    for key, val in result_dict.items():
        result_list.append([key, val, distance_dict[key]])

    # Convert to an ndarray
    result_list = np.array(result_list)
    return result_list
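For example, feeding it a hypothetical set of five neighbors (three "5"s and two "8"s) returns one row per label, as (label, count, total distance):

k_result = np.array([5., 8., 5., 5., 8.])
k_dist = np.array([120.3, 130.1, 135.7, 140.2, 150.9])
print(get_list_sorted_by_val(k_result, k_dist))
# [[  5.    3.  396.2]
#  [  8.    2.  281. ]]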
Now that the preparations are complete, the classification itself begins. This time we use k = 5 neighbors.
k = 5
predicted_list = []    # predicted digit labels
k_result_list = []     # lists of the k nearest labels
k_distances_list = []  # lists of the k distances to each target sample

# Execute the k-nearest neighbor method
for i in range(len(test_data)):
    # Take the difference between the target sample and every training sample
    diff_data = (np.tile(test_data[i], (train_data_set.get_data_num(), 1))
                 - train_data_set.get_data_set())
    sq_data = diff_data ** 2         # square each element to drop the sign
    sum_data = sq_data.sum(axis=1)   # sum over each vector's elements
    distances = sum_data ** 0.5      # take the square root to get the distance
    ind = distances.argsort()        # indices sorted by ascending distance

    k_result = train_data_set.get_labels(ind[0:k])  # labels of the k closest
    k_dist = distances[ind[0:k]]                    # their k distances
    k_distances_list.append(k_dist)
    k_result_list.append(k_result)

    # Aggregate the k labels into a list of (label, count, total distance)
    result_list = get_list_sorted_by_val(k_result, k_dist)
    candidate = result_list[result_list[:, 1].argsort()[::-1]]

    counter = 0
    min_dist = 0
    label_top = 0

    # If several labels tie for the highest count,
    # choose the one with the smaller total distance
    result_dict = {}
    for d in candidate:
        if d[1] in result_dict:
            result_dict[d[1]] += [(d[0], d[2])]
        else:
            result_dict[d[1]] = [(d[0], d[2])]

    for d in result_dict[max(result_dict.keys())]:
        if counter == 0:
            label_top = d[0]
            min_dist = d[1]
        else:
            if d[1] < min_dist:
                label_top = d[0]
                min_dist = d[1]
        counter += 1

    # Store the result
    predicted_list.append(label_top)
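The tie-break can be seen in isolation with a small illustrative helper (not used by the loop above) that applies the same rule: among the labels with the highest count, pick the one with the smallest total distance.

def pick_label(result_list):
    # Among the most frequent labels, choose the smallest total distance
    top_count = result_list[:, 1].max()
    tied = result_list[result_list[:, 1] == top_count]
    return tied[tied[:, 2].argmin()][0]

# Hypothetical neighbors: 3 and 2 both appear twice, but 2 is closer in total
print(pick_label(get_list_sorted_by_val(
    np.array([3., 3., 2., 2., 7.]),
    np.array([100., 110., 90., 95., 200.]))))  # -> 2.0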
Display the result.
# Display the calculated results
print("[Predicted Data List]")
for i in range(len(predicted_list)):
    print("%d\t%s" % (i, predicted_list[i]))

print("[Detail Predicted Data List]")
print("index\tlabels of the k neighbors\ttheir k distances")
for i in range(len(k_result_list)):
    print("%d\t%s\t%s" % (i, k_result_list[i], k_distances_list[i]))
The output file is here, and a comparison of the predicted values against the correct answers is [here](https://gist.github.com/matsuken92/7ca89520ff4e9d2242b0). This time I classified 200 samples with the k-nearest neighbor method, and the recognition rate jumped to 97% (194/200)! At this level it starts to feel practical. The template matching I did last time achieved 80%, so this is a considerable improvement.
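For reference, the accuracy figure can be recomputed in a few lines, assuming you have a file of correct answers (the file name test_small_answer.csv here is hypothetical):

answers = np.loadtxt('test_small_answer.csv', delimiter=',', skiprows=1)
correct = sum(1 for p, a in zip(predicted_list, answers) if p == a)
print("accuracy: %d / %d = %.1f%%"
      % (correct, len(answers), 100.0 * correct / len(answers)))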
The following six samples failed, but they look troublesome even to the eye. The first one in the lower row could be read as either a 6 or a 4, which is a subtle call... It's fair to say the k-nearest neighbor method identifies almost everything apart from these borderline handwritten digits.
counter = 0
for num in [3, 76, 128, 132, 147, 165]:
    counter += 1
    X, Y = np.meshgrid(range(size), range(size))
    Z = test_data[num].reshape(size, size)  # reshape the flat pixels to 28x28
    Z = Z[::-1, :]                          # flip vertically so the digit is upright
    plot_digits(X, Y, Z, 2, 3, counter, "pred=%d" % predicted_list[num])
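plot_digits itself is not defined in this article. A minimal sketch of a compatible helper, assuming it simply draws one digit as a grayscale heat map into cell index of a rows x cols subplot grid:

def plot_digits(X, Y, Z, rows, cols, index, title):
    # Draw one 28x28 digit into one cell of the subplot grid
    plt.subplot(rows, cols, index)
    plt.pcolormesh(X, Y, Z, cmap=cm.gray_r, shading='auto')
    plt.title(title)
    plt.xticks([])
    plt.yticks([])

With this definition in place before the loop, a final plt.show() displays the 2 x 3 grid of failure cases.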
**Details of the failed data**
index | label | pred | k-nearest digits | remarks
---|---|---|---|---
3 | 0 | 9 | [ 0. 9. 9. 9. 2.] | The single nearest neighbor is 0, but it was outvoted... so close.
76 | 9 | 8 | [ 8. 8. 9. 8. 3.] | There is one 9 among the neighbors, but...
128 | 7 | 1 | [ 8. 1. 7. 8. 1.] | Please don't add extra strokes to a 7...
132 | 4??? | 6 | [ 6. 6. 6. 6. 6.] | Even visually, it's hard to say whether this is a 4 or a 6
147 | 4 | 7 | [ 7. 7. 7. 7. 7.] | This one really does look like a 7
165 | 3 | 2 | [ 3. 2. 2. 2. 3.] | The 3 put up a good fight, but...