[PYTHON] Autoencoder dans Chainer (Remarques sur l'utilisation de + trainer)

introduction

Dernière fois J'ai essayé de contester la classification d'image de CIFAR-10 en utilisant le formateur, une nouvelle fonction de Chainer, mais en raison de la puissance de la machine, cela fonctionne. Je n'ai pas pu le confirmer et cela s'est terminé. Donc, cette fois, je vais confirmer comment utiliser le formateur en créant Autoencoder en utilisant MNIST.

Concernant Autoencoder, j'ai fait référence à cet article.

la mise en oeuvre

Créez un réseau qui prend 1000 caractères manuscrits MNIST en entrée et passe par une couche masquée pour obtenir une sortie égale à l'entrée. Le code complet est répertorié ici [https://github.com/trtd56/Autoencoder).

Partie réseau

Le nombre d'unités de calque masquées est limité à 64. De plus, lorsqu'elle est appelée avec hidden = True, la couche masquée peut être sortie.

class Autoencoder(chainer.Chain):
    def __init__(self):
        super(Autoencoder, self).__init__(
                encoder = L.Linear(784, 64),
                decoder = L.Linear(64, 784))

    def __call__(self, x, hidden=False):
        h = F.relu(self.encoder(x))
        if hidden:
            return h
        else:
            return F.relu(self.decoder(h))

Partie création de données

Lisez les données MNIST et créez les données des enseignants et les données de test. Je n'ai pas besoin d'étiquette pour les données de l'enseignant et la sortie est la même que l'entrée, donc je bricole un peu la forme des données.

# Lecture des données MNIST
train, test = chainer.datasets.get_mnist()

# Données des enseignants
train = train[0:1000]
train = [i[0] for i in train]
train = tuple_dataset.TupleDataset(train, train)
train_iter = chainer.iterators.SerialIterator(train, 100)

# Données de test
test = test[0:25]

La modélisation

model = L.Classifier(Autoencoder(), lossfun=F.mean_squared_error)
model.compute_accuracy = False
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

Deux points à noter ici

  1. Définition de la fonction de perte Lors de la définition d'un modèle avec L.Classifier, la fonction de perte semble être softmax_cross_entropy par défaut, mais cette fois je veux utiliser mean_squared_error, je dois donc la définir avec lossfun.

  2. Ne calculez pas la précision Cette fois, nous n'utilisons pas d'étiquettes pour les données des enseignants, nous n'avons donc pas besoin de calculer la précision. Vous devez donc définir compute_accuracy sur False.

Partie d'apprentissage

Je ne pense pas qu'il y ait un besoin particulier d'explication. Depuis que le formateur est devenu disponible, j'ai pu écrire cette partie facilement, ce qui m'a aidé ^^

updater = training.StandardUpdater(train_iter, optimizer, device=-1)
trainer = training.Trainer(updater, (N_EPOCH, 'epoch'), out="result")
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport( ['epoch', 'main/loss']))
trainer.extend(extensions.ProgressBar())

trainer.run()

Vérifiez le résultat

Créez une fonction et tracez le résultat avec matplotlib. L'étiquette originale est imprimée en rouge en haut de l'image. Je n'ai pas ajusté correctement les coordonnées, donc certaines parties sont couvertes ...

À propos, si vous entrez les données de test telles quelles dans cette fonction, l'image des données originales sera sortie.

def plot_mnist_data(samples):
    for index, (data, label) in enumerate(samples):
        plt.subplot(5, 5, index + 1)
        plt.axis('off')
        plt.imshow(data.reshape(28, 28), cmap=cm.gray_r, interpolation='nearest')
        n = int(label)
        plt.title(n, color='red')
    plt.show()

pred_list = []
for (data, label) in test:
    pred_data = model.predictor(np.array([data]).astype(np.float32)).data
    pred_list.append((pred_data, label))
plot_mnist_data(pred_list)

résultat

Voyons comment cela change à mesure que nous augmentons l'époque.

Image originale

epoch_origin.png

16 images dont toutes de 0 à 9. Regardons ces 16 types de changements.

epoch = 1

epoch_1.png

C'est comme une tempête de sable à la télé et je ne sais pas ce que c'est à ce stade.

epoch = 5

epoch_5.png

J'ai enfin vu quelque chose comme un nombre, mais je ne le sais toujours pas.

epoch = 10

epoch_10.png

Les formes 0, 1, 3, etc. deviennent progressivement visibles. Le 6 de la deuxième rangée est toujours écrasé et je ne suis pas sûr.

epoch = 20

epoch_20.png

Je peux presque voir les chiffres.

epoch = 100

epoch_100.png

J'ai essayé d'avancer à 100 à la fois. La forme du 6 dans la deuxième rangée, qui était presque écrasée, est maintenant visible. Ce sera clair si vous augmentez le nombre d'époques, mais cette fois, c'est à ici.

en conclusion

C'était amusant de voir comment le réseau reconnaissait les nombres comme des nombres. ~~ trainer est pratique, mais soyez prudent car diverses parties comme la fonction de perte sont automatiquement déterminées. ~~ (Fixé le 2016.08.10) C'était la spécification de Classifer, pas le formateur, que la fonction de perte était définie par défaut sur soft_max_cross_entropy. La fonction de perte est spécifiée lors de la définition du programme de mise à jour utilisé dans le formateur, mais généralement celui défini dans l'optimiseur semble être lié.

Recommended Posts

Autoencoder dans Chainer (Remarques sur l'utilisation de + trainer)
Remarques sur l'utilisation de pywinauto
Remarques sur l'utilisation des featuretools
Python: comment utiliser async avec
Pour utiliser virtualenv avec PowerShell
Comment utiliser l'homebrew dans Debian
Remarques sur la rédaction de requirements.txt
[Hyperledger Iroha] Remarques sur l'utilisation du SDK Python
Remarques sur l'utilisation de la guimauve dans la bibliothèque de schémas
Comment utiliser mecab, neologd-ipadic sur colab
Comment utiliser OpenVPN avec Ubuntu 18.04.3 LTS
Comment utiliser Cmder avec PyCharm (Windows)
Comment utiliser l'Assistant Google sur Windows 10
Comment utiliser Ass / Alembic avec HtoA
Pour utiliser python, mettez pyenv sur macOS avec PyCall
Comment utiliser le japonais avec le tracé NLTK
Comment utiliser le notebook Jupyter avec ABCI
Comment utiliser la commande CUT (avec exemple)
Comment utiliser SQLAlchemy / Connect avec aiomysql
Comment utiliser le pilote JDBC avec Redash
Comment utiliser la trace GCP avec la télémétrie ouverte
Stratégie sur la façon de monétiser avec Python Java
Comment installer OpenGM sur OSX avec macports
Comment utiliser tkinter avec python dans pyenv
Comment utiliser xml.etree.ElementTree
Comment utiliser Python-shell
Remarques sur l'utilisation de tf.data
Comment utiliser virtualenv
Comment utiliser Seaboan
Comment utiliser la correspondance d'image
Comment utiliser le shogun
Comment utiliser Pandas 2
Comment utiliser Virtualenv
Comment utiliser numpy.vectorize
Comment utiliser partiel
Comment utiliser Bio.Phylo
Comment utiliser SymPy
Comment utiliser WikiExtractor.py
Comment utiliser IPython
Comment utiliser virtualenv
Comment utiliser Matplotlib
Comment utiliser iptables
Comment utiliser numpy
Comment utiliser TokyoTechFes2015
Comment utiliser venv
Comment utiliser le dictionnaire {}
Comment utiliser Pyenv
Comment utiliser la liste []
Comment utiliser python-kabusapi
Comment utiliser OptParse
Comment utiliser le retour
Comment utiliser pyenv-virtualenv
Comment utiliser imutils
Comment obtenir la clé sur Amazon S3 avec Boto 3, exemple de mise en œuvre, notes
Comment utiliser xgboost: classification multi-classes avec des données d'iris
Comment utiliser le contrôleur audio C216 sur Arch Linux