Letztes Mal verwendete die Template-Matching-Methode, um handschriftliche Zahlen zu identifizieren. Diesmal haben wir jedoch die Methode und die Methode des k-nächsten Nachbarn geändert. ) Wird verwendet.
Zunächst sieht der Umriss des Verfahrens wie folgt aus: Wenn die schwarzen Punkte als Identifikationszieldaten verwendet werden, wie in der folgenden Abbildung gezeigt, werden k Lehrerdaten gesucht, die nahe beieinander liegen, und das Etikett mit der größten Anzahl von Mehrheitsentscheidungen wird als geschätzter Wert ausgewählt. In diesem Beispiel gibt es 3 "5" und 2 "8", so dass "5" als geschätzter Wert übernommen wird. Die Merkmale sind, dass es sich um eine überwachte Analyse handelt und dass alle Daten für die Berechnung verwendet werden, sodass die Speichergröße und der Rechenaufwand verbraucht werden. Es gibt verschiedene Konzepte für "nahe", wie die gewöhnliche Entfernung (euklidische Entfernung) und die Maharanobis-Entfernung (unter Verwendung der Dispersion), aber dieses Mal werden wir die gewöhnliche Entfernung verwenden.
Ich werde es sofort in Python schreiben. Importieren Sie zunächst die erforderlichen Bibliotheken. Ich werde diesmal nicht die Bibliothek für maschinelles Lernen wie sklearn verwenden. (Weil ich alleine bin)
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from collections import defaultdict
Definieren Sie eine Klasse namens TrainDataSet. Da es sich um Lehrerdaten handelt, enthält es die Ergebnisbezeichnung (handschriftliche Nummer) und Pixeldaten, sodass Sie die erforderlichen Daten leicht abrufen können, z. B. nur bestimmte Elemente extrahieren.
class TrainDataSet():
def __init__(self, data):
data = np.array(data)
self.labels = data[:,0]
self.data_set = data[:,1:]
def __repr__(self):
ret = repr(self.labels) + "\n"
ret += repr(self.data_set)
return ret
def get_data_num(self):
return self.labels.size
def get_labels(self, *args):
if args is None:
return self.labels
else:
return self.labels[args[0]]
def get_data_set(self):
return self.data_set
def get_data_set_partial(self, *args):
if args is None:
return self.data_set
else:
return self.data_set[args[0]]
def get_label(self, i):
return self.labels[i]
def get_data(self, i):
return self.data_set[i,:]
def get_data(self,i, j):
return self.data_set[i][j]
Laden Sie dann die Daten. Wenn Sie es versuchen möchten, können Sie die Daten von den folgenden herunterladen. train.csv ... Lehrerdaten (42000) test_small.csv ... Identifikationszieldaten (200 Stück)
size = 28
master_data= np.loadtxt('train_master.csv',delimiter=',',skiprows=1)
test_data= np.loadtxt('test_small.csv',delimiter=',',skiprows=1)
train_data_set = TrainDataSet(master_data)
Es definiert auch eine Funktion, die die Ergebnisse von k Nachbarn aggregiert und eine Liste ausgibt, wie viele Daten für jede numerische Bezeichnung vorhanden sind.
def get_list_sorted_by_val(k_result, k_dist):
result_dict = defaultdict(int)
distance_dict = defaultdict(float)
#Aggregiert nach Nummernschild
for i in k_result:
result_dict[i] += 1
#Gesamtabstand für jedes Nummernschild
for i in range(len(k_dist)):
distance_dict[k_result[i]] += k_dist[i]
#Vom Wörterbuchtyp in Liste konvertieren (zum Sortieren)
result_list = []
order = 0
for key, val in result_dict.items():
order += 1
result_list.append([key, val, distance_dict[key]])
#In ndarray-Typ konvertieren
result_list = np.array(result_list)
return result_list
Nachdem alle Vorbereitungen abgeschlossen sind, beginnt hier der Identifizierungsprozess. Wählen Sie dieses Mal k = 5 Daten als Nachbarschaft.
k = 5
predicted_list = [] #Voraussichtlicher Wert des Nummernetiketts
k_result_list = [] #k Nachbarschaftsliste
k_distances_list = [] #Liste der Abstände zwischen k Zahlen und den zu identifizierenden Daten
# execute k-nearest neighbor method
for i in range(len(test_data)):
#Nehmen Sie den Unterschied zwischen den Identifikationszieldaten und den Lehrerdaten
diff_data = np.tile(test_data[i], (train_data_set.get_data_num(),1)) - train_data_set.get_data_set()
sq_data = diff_data ** 2 #Quadrieren Sie jedes Element und löschen Sie das Zeichen
sum_data = sq_data.sum(axis=1) #Fügen Sie jedes Vektorelement hinzu
distances = sum_data ** 0.5 #Nehmen Sie die Route und verwenden Sie sie als Entfernung
ind = distances.argsort() #Sortieren Sie in aufsteigender Reihenfolge der Entfernung und extrahieren Sie den Index
k_result = train_data_set.get_labels(ind[0:k]) #Nehmen Sie k Stücke aus dem nächsten heraus
k_dist = distances[ind[0:k]] #Extrahieren Sie k Entfernungsinformationen
k_distances_list.append(k_dist)
k_result_list.append(k_result)
#Aggregiert aus k Daten mit Nummernschildern,(Nummernschild,Menge,Entfernung)Generieren Sie eine Liste von
result_list = get_list_sorted_by_val(k_result, k_dist)
candidate = result_list[result_list[:,1].argsort()[::-1]]
counter = 0
min = 0
label_top = 0
#Wenn es mehrere Nummernschilder mit der größten Nummer gibt, wählen Sie das mit dem kleineren Gesamtabstand aus.
result_dict = {}
for d in candidate:
if d[1] in result_dict:
result_dict[d[1]] += [(d[0], d[2])]
else:
result_dict[d[1]] = [(d[0], d[2])]
for d in result_dict[np.max(result_dict.keys())]:
if counter == 0:
label_top = d[0]
min = d[1]
else:
if d[1] < min:
label_top = d[0]
min = d[1]
counter += 1
#Tragen Sie die Ergebnisse in eine Liste ein
predicted_list.append(label_top)
Zeigen Sie das Ergebnis an.
# disp calc result
print "[Predicted Data List]"
for i in range(len(predicted_list)):
print ("%d" % i) + "\t" + str(predicted_list[i])
print "[Detail Predicted Data List]"
print "index k units of neighbors, distances for every k units"
for i in range(len(k_result_list)):
print ("%d" % i) + "\t" + str(k_result_list[i]) + "\t" + str(k_distances_list[i])
Die Ausgabeergebnisdatei lautet hier, und das Ergebnis des Vergleichs des als richtige Antwort identifizierten vorhergesagten Werts lautet [hier](https: //gist.github). Es befindet sich unter com / matsuken92 / 7ca89520ff4e9d2242b0). Dieses Mal habe ich versucht, 200 Teile mit der Methode des nächsten Nachbarn zu identifizieren, aber die Identifikationsrate stieg dramatisch auf 97% (194/200)! Ich denke, es wird praktisch sein, wenn die Identifizierung so gut ist. Letztes Mal Im Fall des durchgeführten Vorlagenabgleichs waren es 80%, daher ist es im Vergleich dazu ziemlich gut.
Die folgenden 6 Daten sind fehlgeschlagen, aber es scheint sogar visuell problematisch zu sein. Die ersten Daten in der unteren Reihe sind 6 oder 4, sogar visuell. Man kann sagen, dass die meisten handgeschriebenen Zahlen mit Ausnahme der subtilen durch die Methode des nächsten Nachbarn k identifiziert werden können.
counter = 0
for d, num in zip(test_data, [3,76,128,132,147,165]):
counter += 1
X, Y = np.meshgrid(range(size),range(size))
Z = test_data[num].reshape(size,size)
Z = Z[::-1,:]
flat_Z = Z.flatten()
plot_digits(X, Y, Z, 2, 3, counter, "pred=%d" % predicted_list[num])
** Daten fehlschlagen **
index | label | pred | k-nearest digits | remarks |
---|---|---|---|---|
3 | 0 | 9 | [ 0. 9. 9. 9. 2.] | Die nächste Nachbarschaft ist 0, aber ... es ist köstlich. |
76 | 9 | 8 | [ 8. 8. 9. 8. 3.] | Es gibt auch eine 9 aber ... |
128 | 7 | 1 | [ 8. 1. 7. 8. 1.] | Setzen Sie keine zusätzlichen Zeilen in 7 ... |
132 | 4??? | 6 | [ 6. 6. 6. 6. 6.] | Dies ist 4 oder 6 oder sogar visuell subtil |
147 | 4 | 7 | [ 7. 7. 7. 7. 7.] | Ich frage mich, ob das 7 ist |
165 | 3 | 2 | [ 3. 2. 2. 2. 3.] | 3 war auch eine gute Linie, aber ... |