[PYTHON] Ich habe versucht, Kanas handschriftliche Zeichenerkennung Teil 1/3 zuerst von MNIST zu machen

Überblick

Ich habe versucht, Zeichen zu erkennen, indem ich Kana in die GUI eingegeben und ein Modell verwendet habe, das durch vorheriges Training durch maschinelles Lernen erstellt wurde.

Überprüfen Sie zuerst das Gefühl und die Genauigkeit von CNN mit MNIST, geben Sie dann die tatsächlichen Kana-Daten für das Training an und verknüpfen Sie sie schließlich mit der GUI.

Nächstes Mal (2/3): https://qiita.com/tfull_tf/items/968bdb8f24f80d57617e Nächstes Mal (3/3): https://qiita.com/tfull_tf/items/d9fe3ab6c1e47d1b2e1e

Den gesamten Code finden Sie unter: https://github.com/tfull/character_recognition

Modellbau mit MNIST

Erstellen Sie Ihr eigenes Modell und fahren Sie mit dem Zug. Testen Sie anhand des beliebten handgeschriebenen Ziffern-Datensatzes MNIST, wie genau er ist.

Da es sich bei MNIST um 28x28-Graustufendaten handelt, geben Sie diese als (Kanal, Breite, Höhe) = (1, 28, 28) ein. Da die Zahlen 0 bis 9 sind, gibt es 10 Klassifizierungsziele und 10 Wahrscheinlichkeiten werden ausgegeben.

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout2d(0.3)
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(12 * 12 * 32, 128)
        self.relu3 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.3)
        self.linear2 = nn.Linear(128, 10)
        self.softmax = nn.Softmax(dim = 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool(x)
        x = self.dropout1(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.relu3(x)
        x = self.dropout2(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x

Es wird in eine Dimension umgewandelt und über zwei Faltschichten und eine nachfolgende Pooling-Schicht durch zwei vollständig verbundene Schichten geleitet. Die Aktivierungsfunktion ist ReLU, und der Umriss des Modells besteht darin, eine Dropout-Ebene einzufügen, um ein Überlernen in der Mitte zu verhindern.

Datenerfassung

import torchvision

download_flag = not os.path.exists(data_directory + "/mnist")

mnist_train = torchvision.datasets.MNIST(
    data_directory + "/mnist",
    train = True,
    download = download_flag,
    transform = torchvision.transforms.ToTensor()
)

mnist_test = torchvision.datasets.MNIST(
    data_directory + "/mnist",
    train = False,
    download = download_flag,
    transform = torchvision.transforms.ToTensor()
)

Speichern Sie die MNIST-Daten lokal und verwenden Sie sie. Definieren Sie data_directory so, dass es heruntergeladen wird, wenn es nicht vorhanden ist. Auf diese Weise habe ich sichergestellt, dass nur das erste Mal heruntergeladen wird.

Vorbereitung zum Lernen

import torch
import torch.optim as optim

train_loader = torch.utils.data.DataLoader(mnist_train,  batch_size = 100,  shuffle = True)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size = 1000, shuffle = False)

model = Model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = 0.001)

Verwenden Sie DataLoader, um die Daten nacheinander abzurufen.

Stellen Sie das Modell, die Fehlerfunktion und den Optimierungsalgorithmus ein. Wir haben den Kreuzentropiefehler Adam übernommen.

Ausbildung

n_epoch = 2

model.train()

for i_epoch in range(n_epoch):
    for i_batch, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        print("epoch: {}, train: {}, loss: {}".format(i_epoch + 1, i_batch + 1, loss.item()))

Eine Reihe von Trainingsoperationen wird in einer Schleife ausgeführt, z. B. das Geben von Bilddaten (Eingaben) an ein Modell, das Vergleichen der Ausgabe (Ausgabe) mit den richtigen Antwortdaten (Beschriftungen), das Auffinden des Fehlers und das Zurückgeben. Ich werde.

Ich denke, dass es für das Training nicht ausreicht, alle Daten einmal anzugeben. Deshalb setze ich die Anzahl der Epochen (n_epoch) auf 2 und gebe jedem Daten n_epoch-Zeiten für das Training. Die Anzahl der Epochen ist meine Erfahrung, aber ich denke, dass ungefähr 2 bis 3 genau richtig sind. Ich denke, es hängt von der Anzahl der Daten ab.

Auswertung

correct_count = 0
record_count = 0

model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, prediction = torch.max(outputs.data, 1)
        judge = prediction == labels
        correct_count += int(judge.sum())
        record_count += len(judge)

print("Accuracy: {:.2f}%".format(correct_count / record_count * 100))

Wir geben die numerischen Daten (Eingaben) des Bildes in das Modell ein, und die höchste der 10 herauskommenden Wahrscheinlichkeiten ist das Auswahlergebnis (Vorhersage). Es vergleicht, ob es mit den richtigen Antwortdaten (Labels) übereinstimmt, gibt True / False zurück und berechnet die Anzahl von True (korrekte_Zahl) in Bezug auf die Gesamtzahl (record_count), um die richtige Antwortrate zu erhalten.

Resultate und Diskussionen

Das Ergebnis lag im Durchschnitt mehrfach bei etwa 97%.

Ich denke, der Wert der richtigen Antwortrate ist hoch, aber ich habe 3 Mal in 100 Mal versagt. Ich denke, es wird eine andere Sache sein, ob Menschen dies tolerieren können. Es gibt jedoch einige schmutzige Zeichen in den MNIST-Bilddaten, die für den Menschen schwer zu unterscheiden sind. In diesem Sinne kann ein Fehler von 3% unvermeidbar sein.

MNIST hat 10 Auswahlmöglichkeiten von 0 bis 9, aber da es mehr als 100 in Hiragana und Katakana für Kana gibt, ist es schwierig zu klassifizieren und Sie müssen auf einen weiteren Rückgang der richtigen Antwortrate vorbereitet sein.

Recommended Posts

Ich habe versucht, Kanas handschriftliche Zeichenerkennung Teil 1/3 zuerst von MNIST zu machen
Ich habe versucht, Kanas handschriftliche Zeichenerkennung durchzuführen. Teil 2/3 Datenerstellung und Lernen
Ich habe versucht, Kanas handschriftliche Zeichenerkennung Teil 3/3 Zusammenarbeit mit der GUI mithilfe von Tkinter durchzuführen
Ich habe versucht, Perceptron Teil 1 [Deep Learning von Grund auf neu] zu implementieren.
Ich habe eine Web-API erstellt
Ich habe versucht, Runenfiguren mit Scikit-Learn handschriftlich zu erkennen
Ich möchte Passungen aus meinem Kopf machen
Ich habe versucht, KI für Smash Bra zu machen
Ich möchte C ++ - Code aus Python-Code erstellen!
Ich habe ein ○ ✕ Spiel mit TensorFlow gemacht
Ich habe versucht, eine einfache Bilderkennungs-API mit Fast API und Tensorflow zu erstellen
Ich habe versucht, einen "verdammt großen Literaturkonverter" zu machen.
Fortsetzung ・ Ich habe versucht, Slackbot zu erstellen, nachdem ich Python3 studiert habe
Ich habe versucht zu debuggen.
Ich habe versucht, den negativen Teil von Meros zu löschen
Ich habe versucht, eine OCR-App mit PySimpleGUI zu erstellen
Ich habe versucht, Dropout zu erklären
Ich habe versucht, einen Generator zu erstellen, der mit Python eine C # -Containerklasse aus CSV generiert
[Erste COTOHA-API] Ich habe versucht, die alte Geschichte zusammenzufassen
Ich habe versucht, API list.csv mit Python aus swagger.yaml zu erstellen
Ich habe versucht, mit Python faker verschiedene "Dummy-Daten" zu erstellen
Ich habe versucht, das Gesicht aus dem Video zu erkennen (OpenCV: Python-Version)
Ich habe versucht, MNIST nach GNN zu klassifizieren (mit PyTorch-Geometrie).
Ich habe eine Stoppuhr mit tkinter mit Python gemacht
Ich habe versucht, die handschriftliche Zeichenerkennung von Runenzeichen mit CNN mithilfe von Keras zu erkennen
Ich habe versucht, die Benutzeroberfläche neben Python und Tkinter dreiäugig zu gestalten
Ich habe versucht, das Python-Skript unter Windows 10 von 2.7.11 auf 3.6.0 zu ändern
Ich habe versucht, verschiedene Informationen von der Codeforces-API abzurufen
Ich habe mit PyQt einen einfachen Texteditor erstellt
Ich habe versucht, mit pypyodbc schnell Daten von AS / 400 abzurufen
Ich habe versucht, PredNet zu lernen
Ich habe versucht, SVM zu organisieren.
Ich habe versucht, das Gesicht mit Face ++ zu erkennen
Ich habe versucht, PCANet zu implementieren
Ich habe versucht, Linux wieder einzuführen
Ich habe versucht, Pylint vorzustellen
Ich habe versucht, SparseMatrix zusammenzufassen
jupyter ich habe es berührt
Ich habe versucht, StarGAN (1) zu implementieren.
[Erstes Scraping] Ich habe versucht, einen VIP-Charakter für Smash Bra [Beautiful Soup] zu erstellen. [Zusätzlich Datenanalyse]
Ich habe versucht, ein System zu erstellen, das nur gelöschte Tweets abruft
Ich habe versucht, Deep Learning mit Spark × Keras × Docker skalierbar zu machen
[Python] Ich habe versucht, eine stabile Sortierung zu implementieren
[Einführung in die Simulation] Ich habe versucht, durch Simulation einer Koronainfektion zu spielen ♬ Teil 2
Ich habe versucht, mit Python einen regulären Ausdruck von "Zeit" zu erstellen
[3.] Ich habe versucht, mit Python ein bestimmtes Authenticator-ähnliches Tool zu erstellen
Ich habe versucht, mit Python einen regulären Ausdruck von "Datum" zu erstellen
Ich habe versucht, mit Selenium und Python einen regelmäßigen Ausführungsprozess durchzuführen
Ich habe versucht, mit Python eine 2-Kanal-Post-Benachrichtigungsanwendung zu erstellen
Ich habe versucht, in 3 Jahren 5 Muster der Analysebasis zu erstellen
Ich möchte eine Parameterliste aus CloudFormation-Code (yaml) erstellen.
Ich habe versucht, eine ToDo-App mit einer Flasche mit Python zu erstellen
[4.] Ich habe versucht, mit Python ein bestimmtes Authenticator-ähnliches Tool zu erstellen
[Python] Einfaches Japanisch ⇒ Ich habe versucht, ein englisches Übersetzungswerkzeug zu erstellen
Ich habe versucht, ein Standbild aus dem Video auszuschneiden
[1.] Ich habe versucht, mit Python ein bestimmtes Authenticator-ähnliches Tool zu erstellen
Ich habe versucht, Spieler- und Fertigkeitsnamen aus Sportartikeln zu extrahieren
Ich habe versucht, Jojo mit LSTM ein seltsames Zitat zu machen
Ich habe versucht, mit pypyodbc Preparation 1 schnell Daten von AS / 400 abzurufen
Ich habe versucht, mit Python + OpenCV eine Bildähnlichkeitsfunktion zu erstellen