[PYTHON] Exécuter l'inférence avec l'exemple de Chainer 2.0 MNIST

introduction

J'étudie l'apprentissage profond en exécutant l'exemple de code de Chainer.

Au moment de la rédaction de cet article (juin 2017), la dernière version de Chainer est 2.0, mais elle peut ne pas être compatible avec 1.x et les versions antérieures du code peuvent ne pas fonctionner. Référence: Différences entre les versions de chainer (au 19 janvier 2016)

Cet article est un exemple de Chainer 2.0 MNIST, une note d'implémentation pour la conduite de l'inférence.

Pour la mise en œuvre, je me suis référé à cet article. Chainer: Tutoriel pour les débutants Vol.1

environnement

Chainer 2.0 python 2.7.10 Exécuter sur CPU

code

https://github.com/abechi/chainer_mnist_predict

Échantillon Chainer 2.0 MNIST (original) https://github.com/chainer/chainer/tree/v2.0.0/examples/mnist

1. Ajout du processus (1 ligne) pour enregistrer le modèle entraîné dans train_mnist.py

train_mnist.py


    # Run the training
    trainer.run()

    chainer.serializers.save_npz('my_mnist.model', model) # Added

2. Exécutez train_mnist.py pour commencer à apprendre

$ python train_mnist.py --epoch 3
GPU: -1
# unit: 1000
# Minibatch-size: 100
# epoch: 3

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
1           0.191836    0.0885223             0.942233       0.9718                    26.099        
2           0.0726428   0.0825069             0.9768         0.974                     53.4849       
3           0.0466335   0.0751425             0.984983       0.9747                    81.2683       
$ ls
my_mnist.model  result/         train_mnist.py*

3. Charger et déduire le modèle entraîné enregistré

predict_mnist.py


#!/usr/bin/env python

from __future__ import print_function 

try:
    import matplotlib
    matplotlib.use('Agg')
except ImportError:
    pass

import argparse

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions


# Network definition
class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # the size of the inputs to each layer will be inferred
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)  # n_units -> n_out

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)


def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    args = parser.parse_args()

    print('# unit: {}'.format(args.unit))
    print('')

    # Set up a neural network
    model = L.Classifier(MLP(args.unit, 10))

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()

    chainer.serializers.load_npz('my_mnist.model', model)

    x, t = test[0]
    print('label:', t)

    x = x[None, ...]
    y = model.predictor(x)
    y = y.data

    print('predicted_label:', y.argmax(axis=1)[0])

if __name__ == '__main__':
    main()

Predict_mnist.py lit my_mnist.model et déduit l'étiquette des données de test.

$ python predict_mnist.py 
# unit: 1000

label: 7
predicted_label: 7

J'ai la même étiquette que l'étiquette de réponse correcte.

Précautions lors de la création d'objets de modèle

train_mnist.py


    # iteration, which will be used by the PrintReport extension below.
    model = L.Classifier(MLP(args.unit, 10))

Dans train_mnist.py, j'ai créé un modèle en utilisant L.Classifier. Vous devez également utiliser L.Classifier lors de la création d'un objet de modèle lors de l'inférence.

Si vous créez un objet pour le modèle sans passer par L.Classifier, vous obtiendrez une erreur lors du chargement du modèle.

predict_mnist.py


    # Set up a neural network
    model = MLP(args.unit, 10)

Erreur KeyError: 'l2/b is not a file in the archive'

Référence Enregistrer et charger le modèle Chainer

Recommended Posts

Exécuter l'inférence avec l'exemple de Chainer 2.0 MNIST
Que faire si le mnist d'exemple Chainer (Windows) se termine par WinError 183.
Seq2Seq (1) avec chainer
MNIST (DCNN) avec skflow
Utiliser tensorboard avec Chainer
Explique mnist après le chainer 1.11.0
Exemple de données créées avec python
Lire l'exemple de keras mnist
Essayez d'implémenter RBM avec chainer.
Apprenez les orbites elliptiques avec Chainer
Seq2Seq (3) ~ Edition CopyNet ~ avec chainer
Utilisation du chainer avec Jetson TK1
Réseau de neurones commençant par Chainer
Implémentation du GAN conditionnel avec chainer
[Python] Estimation bayésienne avec Pyro
Implémentation de SmoothGrad avec Chainer v2
Clustering embarqué profond avec Chainer 2.0
Un peu coincé dans le chainer
Pensez aux abandons avec MNIST
Essayez TensorFlow MNIST avec RNN
Fonctionne avec Python et R