[PYTHON] Führen Sie eine Inferenz mit dem Chainer 2.0 MNIST-Beispiel durch

Einführung

Ich lerne Deep Learning, indem ich den Beispielcode von Chainer ausführe.

Zum Zeitpunkt des Schreibens dieses Artikels (Juni 2017) ist die neueste Version von Chainer 2.0, jedoch möglicherweise nicht mit 1.x kompatibel, und ältere Versionen des Codes funktionieren möglicherweise nicht. Referenz: Unterschiede zwischen Kettenversionen (Stand 19. Januar 2016)

Dieser Artikel ist ein Chainer 2.0 MNIST-Beispiel, ein Implementierungshinweis für die Inferenz.

Für die Implementierung habe ich auf diesen Artikel verwiesen. Chainer: Tutorial für Anfänger Vol.1

Umgebung

Chainer 2.0 python 2.7.10 Auf CPU ausführen

Code

https://github.com/abechi/chainer_mnist_predict

Chainer 2.0 MNIST-Probe (Original) https://github.com/chainer/chainer/tree/v2.0.0/examples/mnist

1. Prozess (1 Zeile) hinzugefügt, um das trainierte Modell in train_mnist.py zu speichern

train_mnist.py


    # Run the training
    trainer.run()

    chainer.serializers.save_npz('my_mnist.model', model) # Added

2. Führen Sie train_mnist.py aus, um mit dem Lernen zu beginnen

$ python train_mnist.py --epoch 3
GPU: -1
# unit: 1000
# Minibatch-size: 100
# epoch: 3

epoch       main/loss   validation/main/loss  main/accuracy  validation/main/accuracy  elapsed_time
1           0.191836    0.0885223             0.942233       0.9718                    26.099        
2           0.0726428   0.0825069             0.9768         0.974                     53.4849       
3           0.0466335   0.0751425             0.984983       0.9747                    81.2683       
$ ls
my_mnist.model  result/         train_mnist.py*

3. Laden Sie das gespeicherte trainierte Modell und schließen Sie es ab

predict_mnist.py


#!/usr/bin/env python

from __future__ import print_function 

try:
    import matplotlib
    matplotlib.use('Agg')
except ImportError:
    pass

import argparse

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions


# Network definition
class MLP(chainer.Chain):

    def __init__(self, n_units, n_out):
        super(MLP, self).__init__()
        with self.init_scope():
            # the size of the inputs to each layer will be inferred
            self.l1 = L.Linear(None, n_units)  # n_in -> n_units
            self.l2 = L.Linear(None, n_units)  # n_units -> n_units
            self.l3 = L.Linear(None, n_out)  # n_units -> n_out

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)


def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    args = parser.parse_args()

    print('# unit: {}'.format(args.unit))
    print('')

    # Set up a neural network
    model = L.Classifier(MLP(args.unit, 10))

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()

    chainer.serializers.load_npz('my_mnist.model', model)

    x, t = test[0]
    print('label:', t)

    x = x[None, ...]
    y = model.predictor(x)
    y = y.data

    print('predicted_label:', y.argmax(axis=1)[0])

if __name__ == '__main__':
    main()

Predict_Mnist.py liest my_mnist.model und leitet die Bezeichnung für die Testdaten ab.

$ python predict_mnist.py 
# unit: 1000

label: 7
predicted_label: 7

Ich habe das gleiche Etikett wie das richtige Antwortetikett.

Vorsichtsmaßnahmen beim Erstellen von Modellobjekten

train_mnist.py


    # iteration, which will be used by the PrintReport extension below.
    model = L.Classifier(MLP(args.unit, 10))

In train_mnist.py habe ich mit L.Classifier ein Modell erstellt. Sie müssen L.Classifier auch verwenden, wenn Sie während der Inferenz ein Modellobjekt erstellen.

Wenn Sie ein Objekt für das Modell erstellen, ohne L.Classifier zu durchlaufen, wird beim Laden des Modells eine Fehlermeldung angezeigt.

predict_mnist.py


    # Set up a neural network
    model = MLP(args.unit, 10)

Error KeyError: 'l2/b is not a file in the archive'

Referenz Chainer-Modell speichern und laden

Recommended Posts

Führen Sie eine Inferenz mit dem Chainer 2.0 MNIST-Beispiel durch
Was tun, wenn der Chainer (Windows) -Beispielverzeichnis mit WinError 183 beendet wird?
Seq2Seq (1) mit Chainer
MNIST (DCNN) mit Skflow
Verwenden Sie Tensorboard mit Chainer
Erklärt Mnist nach Chainer 1.11.0
Mit Python erstellte Beispieldaten
Lesen Sie das Keras Mnist-Beispiel
Versuchen Sie, RBM mit Chainer zu implementieren.
Lernen Sie mit Chainer elliptische Bahnen
Seq2Seq (3) ~ CopyNet Edition ~ mit Chainer
Verwendung von Chainer mit Jetson TK1
Neuronales Netz beginnend mit Chainer
Bedingte GAN mit Chainer implementiert
[Python] Bayesianische Schätzung mit Pyro
SmoothGrad mit Chainer v2 implementiert
Deep Embedded Clustering mit Chainer 2.0
Ein bisschen im Kettenschiff stecken
Denken Sie an Aussetzer mit MNIST
Versuchen Sie TensorFlow MNIST mit RNN
Funktioniert mit Python und R.