[PYTHON] Autoencoder im Chainer (Hinweise zur Verwendung von + Trainer)

Einführung

Letztes Mal Ich habe versucht, die Bildklassifizierung von CIFAR-10 mit dem Trainer, einer neuen Funktion von Chainer, in Frage zu stellen, aber aufgrund der Maschinenleistung funktioniert es. Ich konnte es nicht bestätigen und es endete. Dieses Mal werde ich die Verwendung des Trainers durch Erstellen eines Autoencoders mit MNIST bestätigen.

In Bezug auf Autoencoder habe ich auf diesen Artikel verwiesen.

Implementierung

Erstellen Sie ein Netzwerk, das 1000 handgeschriebene Zeichen von MNIST als Eingabe verwendet und eine Ausgabe erhält, die der Eingabe über eine verborgene Ebene entspricht. Der gesamte Code ist hier aufgelistet [https://github.com/trtd56/Autoencoder].

Netzwerkteil

Die Anzahl der versteckten Ebeneneinheiten ist auf 64 begrenzt. Wenn mit hidden = True aufgerufen wird, kann die verborgene Ebene ausgegeben werden.

class Autoencoder(chainer.Chain):
    def __init__(self):
        super(Autoencoder, self).__init__(
                encoder = L.Linear(784, 64),
                decoder = L.Linear(64, 784))

    def __call__(self, x, hidden=False):
        h = F.relu(self.encoder(x))
        if hidden:
            return h
        else:
            return F.relu(self.decoder(h))

Datenerstellungsteil

Lesen Sie MNIST-Daten und erstellen Sie Lehrerdaten und Testdaten. Ich bastele ein bisschen an der Form der Daten, weil ich keine Beschriftung für die Lehrerdaten benötige und die Ausgabe mit der Eingabe übereinstimmt.

# MNIST-Daten lesen
train, test = chainer.datasets.get_mnist()

# Lehrerdaten
train = train[0:1000]
train = [i[0] for i in train]
train = tuple_dataset.TupleDataset(train, train)
train_iter = chainer.iterators.SerialIterator(train, 100)

# Testdaten
test = test[0:25]

Modellieren

model = L.Classifier(Autoencoder(), lossfun=F.mean_squared_error)
model.compute_accuracy = False
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

Zwei Punkte, die hier zu beachten sind

  1. Definition der Verlustfunktion Wenn ich ein Modell mit L.Classifier definiere, scheint die Verlustfunktion standardmäßig softmax_cross_entropy zu sein, aber dieses Mal möchte ich mean_squared_error verwenden, also muss ich es mit lossfun definieren.

  2. Berechnen Sie nicht die Genauigkeit Dieses Mal verwenden wir keine Beschriftungen für Lehrerdaten, daher müssen wir die Genauigkeit nicht berechnen. Sie müssen also compute_accuracy auf False setzen.

Lernteil

Ich glaube nicht, dass es eines besonderen Erklärungsbedarfs bedarf. Seit der Verfügbarkeit des Trainers war es hilfreich, diesen Teil einfach schreiben zu können ^^

updater = training.StandardUpdater(train_iter, optimizer, device=-1)
trainer = training.Trainer(updater, (N_EPOCH, 'epoch'), out="result")
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport( ['epoch', 'main/loss']))
trainer.extend(extensions.ProgressBar())

trainer.run()

Überprüfen Sie das Ergebnis

Erstellen Sie eine Funktion und zeichnen Sie das Ergebnis mit matplotlib. Das Originaletikett ist oben im Bild rot gedruckt. Ich habe die Koordinaten nicht richtig angepasst, daher werden einige Teile behandelt ...

Übrigens, wenn Sie die Testdaten so eingeben, wie sie in dieser Funktion sind, wird das Bild der Originaldaten ausgegeben.

def plot_mnist_data(samples):
    for index, (data, label) in enumerate(samples):
        plt.subplot(5, 5, index + 1)
        plt.axis('off')
        plt.imshow(data.reshape(28, 28), cmap=cm.gray_r, interpolation='nearest')
        n = int(label)
        plt.title(n, color='red')
    plt.show()

pred_list = []
for (data, label) in test:
    pred_data = model.predictor(np.array([data]).astype(np.float32)).data
    pred_list.append((pred_data, label))
plot_mnist_data(pred_list)

Ergebnis

Mal sehen, wie es sich ändert, wenn wir die Epoche verlängern.

Original Bild

epoch_origin.png

16 Bilder, darunter alle von 0 bis 9. Schauen wir uns diese 16 Arten von Änderungen an.

epoch = 1

epoch_1.png

Es ist wie ein Sandsturm im Fernsehen und ich weiß nicht, was es zu diesem Zeitpunkt ist.

epoch = 5

epoch_5.png

Ich habe endlich so etwas wie eine Nummer gesehen, aber ich weiß es immer noch nicht.

epoch = 10

epoch_10.png

Die Formen von 0, 1, 3 usw. werden allmählich sichtbar. Die 6 in der zweiten Reihe ist immer noch zerquetscht und ich bin mir nicht sicher.

epoch = 20

epoch_20.png

Ich kann die Zahlen fast sehen.

epoch = 100

epoch_100.png

Ich habe versucht, sofort auf 100 vorzurücken. Die Form von 6 in der zweiten Reihe, die fast zerquetscht wurde, ist jetzt sichtbar. Es wird klar sein, ob Sie die Anzahl der Epochen erhöhen, aber diesmal liegt es an Ihnen.

abschließend

Es hat Spaß gemacht zu sehen, wie das Netzwerk Zahlen als Zahlen erkannte. ~~ Trainer ist praktisch, aber seien Sie vorsichtig, da verschiedene Teile wie die Verlustfunktion automatisch ermittelt werden. ~~ (Behoben am 10.08.2016) Es war die Classifer-Spezifikation, nicht der Trainer, dass die Verlustfunktion standardmäßig auf soft_max_cross_entropy gesetzt war. Die Verlustfunktion wird beim Definieren des im Trainer verwendeten Updaters angegeben, aber normalerweise scheint der im Optimierer festgelegte Updater verknüpft zu sein.

Recommended Posts

Autoencoder im Chainer (Hinweise zur Verwendung von + Trainer)
Hinweise zur Verwendung von Pywinauto
Hinweise zur Verwendung von featuretools
Python: So verwenden Sie Async mit
So verwenden Sie virtualenv mit PowerShell
Wie benutzt man Homebrew in Debian?
Hinweise zum Schreiben von require.txt
[Hyperledger Iroha] Hinweise zur Verwendung des Python SDK
Hinweise zur Verwendung von Marshmallow in der Schemabibliothek
Wie man Mecab, neologd-ipadic auf Colab verwendet
Verwendung von OpenVPN mit Ubuntu 18.04.3 LTS
Verwendung von Cmder mit PyCharm (Windows)
So verwenden Sie Google Assistant unter Windows 10
Wie man Ass / Alembic mit HtoA benutzt
So verwenden Sie Python in Pyenv unter MacOS mit PyCall
Verwendung von Japanisch mit NLTK-Plot
Verwendung des Jupyter-Notebooks mit ABCI
Verwendung des CUT-Befehls (mit Beispiel)
Verwendung von SQLAlchemy / Connect mit aiomysql
Verwendung des JDBC-Treibers mit Redash
Verwendung der GCP-Ablaufverfolgung mit offener Telemetrie
Strategie zur Monetarisierung mit Python Java
So installieren Sie OpenGM unter OSX mit Macports
Wie man tkinter mit Python in Pyenv benutzt
Verwendung von xml.etree.ElementTree
Wie benutzt man Python-Shell
Hinweise zur Verwendung von tf.data
Verwendung von virtualenv
Wie benutzt man Seaboan?
Verwendung von Image-Match
Wie man Shogun benutzt
Verwendung von Pandas 2
Verwendung von Virtualenv
Verwendung von numpy.vectorize
Wie man teilweise verwendet
Wie man Bio.Phylo benutzt
Verwendung von SymPy
Verwendung von WikiExtractor.py
Verwendung von IPython
Verwendung von virtualenv
Wie benutzt man Matplotlib?
Verwendung von iptables
Wie benutzt man numpy?
Verwendung von TokyoTechFes2015
Wie benutzt man venv
Verwendung des Wörterbuchs {}
Wie benutzt man Pyenv?
Verwendung der Liste []
Wie man Python-Kabusapi benutzt
Verwendung von OptParse
Verwendung von return
Wie man Imutils benutzt
Hinweise zum Implementieren des Schlüssels unter Amazon S3 mit Boto 3, Implementierungsbeispiel, Hinweise
Verwendung von xgboost: Mehrklassenklassifizierung mit Irisdaten
Verwendung von C216 Audio Controller unter Arch Linux