[PYTHON] J'ai essayé de faire la reconnaissance de caractères manuscrits de Kana Partie 1/3 D'abord à partir de MNIST

Aperçu

J'ai essayé de détecter les caractères en saisissant kana dans l'interface graphique et en utilisant un modèle créé par entraînement préalable par apprentissage automatique.

Tout d'abord, vérifiez la sensation et la précision de CNN avec MNIST, puis donnez les données kana réelles pour l'entraînement, et enfin liez-les à l'interface graphique.

La prochaine fois (2/3): https://qiita.com/tfull_tf/items/968bdb8f24f80d57617e Prochaine fois (3/3): https://qiita.com/tfull_tf/items/d9fe3ab6c1e47d1b2e1e

Le code complet peut être trouvé à l'adresse: https://github.com/tfull/character_recognition

Construction de modèles avec MNIST

Construisez votre propre modèle et exécutez le train, testez le populaire jeu de données numériques manuscrites MNIST pour voir à quel point il est précis.

Étant donné que MNIST contient des données en niveaux de gris 28x28, entrez-le sous la forme (canal, largeur, hauteur) = (1, 28, 28). Puisque les nombres vont de 0 à 9, il y a 10 destinations de classification et 10 probabilités sont sorties.

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout2d(0.3)
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(12 * 12 * 32, 128)
        self.relu3 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.3)
        self.linear2 = nn.Linear(128, 10)
        self.softmax = nn.Softmax(dim = 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool(x)
        x = self.dropout1(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.relu3(x)
        x = self.dropout2(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x

Il est converti en une dimension et passé à travers deux couches entièrement connectées via deux couches de pliage et une couche de regroupement ultérieure. La fonction d'activation est ReLU, et le contour du modèle consiste à insérer un calque de suppression pour éviter le surapprentissage au milieu.

L'acquisition des données

import torchvision

download_flag = not os.path.exists(data_directory + "/mnist")

mnist_train = torchvision.datasets.MNIST(
    data_directory + "/mnist",
    train = True,
    download = download_flag,
    transform = torchvision.transforms.ToTensor()
)

mnist_test = torchvision.datasets.MNIST(
    data_directory + "/mnist",
    train = False,
    download = download_flag,
    transform = torchvision.transforms.ToTensor()
)

Enregistrez les données MNIST localement et utilisez-les. Définissez data_directory pour qu'il soit téléchargé s'il n'existe pas. Ce faisant, je me suis assuré de ne télécharger que la première fois.

Préparation à l'apprentissage

import torch
import torch.optim as optim

train_loader = torch.utils.data.DataLoader(mnist_train,  batch_size = 100,  shuffle = True)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size = 1000, shuffle = False)

model = Model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = 0.001)

Utilisez DataLoader pour récupérer les données en séquence.

Définissez le modèle, la fonction d'erreur et l'algorithme d'optimisation. Nous avons adopté l'erreur d'entropie croisée, Adam.

Entraînement

n_epoch = 2

model.train()

for i_epoch in range(n_epoch):
    for i_batch, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        print("epoch: {}, train: {}, loss: {}".format(i_epoch + 1, i_batch + 1, loss.item()))

Une série d'opérations d'apprentissage sont effectuées dans une boucle, dans laquelle des données d'image (entrées) sont fournies au modèle, la sortie (sortie) et les données de réponse correctes (étiquettes) sont comparées, l'erreur est calculée et la rétro-propagation est effectuée. Je vais.

Je pense que donner chaque donnée une fois n'est pas suffisant pour l'entraînement, alors j'ai mis le nombre d'époques (n_epoch) à 2 et donne à chaque donnée n_epoch fois pour l'entraînement. Le nombre d'époques est mon expérience, mais je pense qu'environ 2 à 3 est juste. Je pense que cela dépend du nombre de données.

Évaluation

correct_count = 0
record_count = 0

model.eval()

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, prediction = torch.max(outputs.data, 1)
        judge = prediction == labels
        correct_count += int(judge.sum())
        record_count += len(judge)

print("Accuracy: {:.2f}%".format(correct_count / record_count * 100))

Les données numériques (entrées) de l'image sont entrées dans le modèle, et la plus élevée des 10 probabilités qui apparaissent est utilisée comme prédiction. Il compare si elle correspond aux données de réponse correctes (étiquettes), renvoie Vrai / Faux et calcule le nombre de Vrai (compte_correct) par rapport au nombre total (compte_enregistrement) pour obtenir le taux de réponse correct.

résultats et discussion

Le résultat a été en moyenne plusieurs fois, environ 97%.

Je pense que la valeur du taux de réponse correcte est élevée, mais j'ai échoué 3 fois en 100 fois. Je pense que ce sera une autre question de savoir si les humains peuvent tolérer cela. Cependant, il y a des caractères sales dans les données d'image MNIST qui sont difficiles à distinguer pour les humains, donc dans ce sens, une erreur de 3% peut être inévitable.

MNIST a 10 choix de 0 à 9, mais comme il y en a plus de 100 en hiragana et katakana pour kana, il sera difficile de classer et vous devez être prêt à une nouvelle baisse du taux de réponse correcte.

Recommended Posts

J'ai essayé de faire la reconnaissance de caractères manuscrits de Kana Partie 1/3 D'abord à partir de MNIST
J'ai essayé de faire la reconnaissance de caractères manuscrits de Kana Partie 2/3 Création et apprentissage de données
J'ai essayé de faire la reconnaissance de caractères manuscrits de Kana Partie 3/3 Coopération avec l'interface graphique en utilisant Tkinter
J'ai essayé d'implémenter Perceptron Part 1 [Deep Learning from scratch]
J'ai créé une API Web
J'ai essayé la reconnaissance manuscrite des caractères des runes avec scikit-learn
Je veux faire des crises de ma tête
J'ai essayé de faire de l'IA pour Smash Bra
Je veux créer du code C ++ à partir de code Python!
J'ai créé un jeu ○ ✕ avec TensorFlow
J'ai essayé de créer une API de reconnaissance d'image simple avec Fast API et Tensorflow
J'ai essayé de faire un "putain de gros convertisseur de littérature"
Suite ・ J'ai essayé de créer Slackbot après avoir étudié Python3
J'ai essayé de déboguer.
J'ai essayé d'effacer la partie négative de Meros
J'ai essayé de créer une application OCR avec PySimpleGUI
[Deep Learning from scratch] J'ai essayé d'expliquer le décrochage
J'ai essayé de créer un générateur qui génère une classe conteneur C # à partir de CSV avec Python
[Première API COTOHA] J'ai essayé de résumer l'ancienne histoire
J'ai essayé de créer une API list.csv avec Python à partir de swagger.yaml
J'ai essayé de créer diverses "données factices" avec Python faker
J'ai essayé de reconnaître le visage de la vidéo (OpenCV: version python)
J'ai essayé de classer MNIST par GNN (avec PyTorch géométrique)
J'ai fait un chronomètre en utilisant tkinter avec python
J'ai essayé la reconnaissance de caractères manuscrits des caractères runiques avec CNN en utilisant Keras
J'ai essayé de créer une interface graphique à trois yeux côte à côte avec Python et Tkinter
J'ai essayé de changer le script python de 2.7.11 à 3.6.0 sur Windows10
J'ai essayé d'obtenir diverses informations de l'API codeforces
J'ai créé un éditeur de texte simple en utilisant PyQt
J'ai essayé d'obtenir rapidement des données d'AS / 400 en utilisant pypyodbc
J'ai essayé d'apprendre PredNet
J'ai essayé d'organiser SVM.
J'ai essayé la reconnaissance faciale avec Face ++
J'ai essayé d'implémenter PCANet
J'ai essayé de réintroduire Linux
J'ai essayé de présenter Pylint
J'ai essayé de résumer SparseMatrix
jupyter je l'ai touché
J'ai essayé d'implémenter StarGAN (1)
[Premier grattage] J'ai essayé de créer un personnage VIP pour Smash Bra [Beautiful Soup] [En plus, analyse de données]
J'ai essayé de créer un système qui ne récupère que les tweets supprimés
J'ai essayé de rendre le deep learning évolutif avec Spark × Keras × Docker
[Python] J'ai essayé d'implémenter un tri stable, alors notez
[Introduction à la simulation] J'ai essayé de jouer en simulant une infection corona ♬ Partie 2
J'ai essayé de créer une expression régulière de "temps" en utilisant Python
[3ème] J'ai essayé de créer un certain outil de type Authenticator avec python
J'ai essayé de créer une expression régulière de "date" en utilisant Python
J'ai essayé de faire un processus d'exécution périodique avec Selenium et Python
J'ai essayé de créer une application de notification de publication à 2 canaux avec Python
J'ai essayé de faire 5 modèles de base d'analyse en 3 ans
Je souhaite créer une liste de paramètres à partir du code CloudFormation (yaml)
J'ai essayé de créer une application todo en utilisant une bouteille avec python
[4th] J'ai essayé de créer un certain outil de type Authenticator avec python
[Python] Japonais simple ⇒ J'ai essayé de créer un outil de traduction en anglais
J'ai essayé de couper une image fixe de la vidéo
[1er] J'ai essayé de créer un certain outil de type Authenticator avec python
J'ai essayé d'extraire des noms de joueurs et de compétences d'articles sportifs
J'ai essayé de faire une étrange citation pour Jojo avec LSTM
J'ai essayé d'obtenir rapidement des données d'AS / 400 en utilisant pypyodbc Préparation 1
J'ai essayé de créer une fonction de similitude d'image avec Python + OpenCV