[PYTHON] J'ai essayé de rendre mon propre code source compatible avec Chainer v2 alpha

Chainer v2 alpha

Depuis la sortie de Chainer v2 alpha, j'ai essayé de supporter mon propre code source. Je me suis référé au site suivant.

Environnement d'exploitation

Référentiel utilisé

J'ai créé une branche pour Chainer v2 dans le référentiel de Reconnaissance d'image de CIFAR-10 créé avant.

https://github.com/dsanno/chainer-cifar/tree/chainer_v2

Installation

Comme vous pouvez le voir sur la diapositive Chainer Meetup, j'ai pu l'installer avec la commande suivante. J'ai ajouté --no-cache-dir juste au cas où.

$ pip install chainer --pre --no-cache-dir
$ pip install cupy --no-cache-dir

Je vais l'essayer pour le moment

Comme la modification de Chainer v2 rompt la rétrocompatibilité, je peux m'attendre à ce que cela ne fonctionne pas, mais je vais l'essayer pour le moment.

$ python src/train.py -g 0 -m vgg -p model\temp9 -b 100 --iter 200 --lr 0.1 --optimizer sgd --weight_decay 0.0001 --lr_decay_iter 100,150

L'erreur suivante s'est produite.

Traceback (most recent call last):
  File "src\train.py", line 143, in <module>
    cifar_trainer.fit(train_x, train_y, valid_x, valid_y, test_x, test_y, on_epoch_done)
  File "c:\project_2\chainer\chainer-cifar\src\trainer.py", line 26, in fit
    return self.__fit(x, y, valid_x, valid_y, test_x, test_y, callback)
  File "c:\project_2\chainer\chainer-cifar\src\trainer.py", line 40, in __fit
    loss, acc = self.__forward(x_batch, y[batch_index])
  File "c:\project_2\chainer\chainer-cifar\src\trainer.py", line 75, in __forward
    y = self.net(x, train=train)
  File "c:\project_2\chainer\chainer-cifar\src\net.py", line 360, in __call__
    h = self.bconv1_1(x, train)
  File "c:\project_2\chainer\chainer-cifar\src\net.py", line 28, in __call__
    h = self.bn(self.conv(x), test=not train)
TypeError: __call__() got an unexpected keyword argument 'test'

C'est une erreur que l'argument de «call» de «chainer.links.BatchNormalization» est passé même s'il n'y a pas de «test».

Corrigé pour fonctionner avec Chainer v2

Supprimer le train de l'argument d'appel de chainer.functions.dropout

Depuis Chainer v2, l'argument dropout train` n'est plus nécessaire, supprimez-le.

Exemple de modification:

Avant correction:
h = F.dropout(F.max_pooling_2d(h, 2), 0.25, train=train)
modifié:
h = F.dropout(F.max_pooling_2d(h, 2), 0.25)

Supprimer le test de l'argument d'appel de chainer.links.BatchNormalization

L'argument test de BatchNormalization n'est plus nécessaire, supprimez-le comme dans le cas de dropout.

Avant correction:

class BatchConv2D(chainer.Chain):
    def __init__(self, ch_in, ch_out, ksize, stride=1, pad=0, activation=F.relu):
        super(BatchConv2D, self).__init__(
            conv=L.Convolution2D(ch_in, ch_out, ksize, stride, pad),
            bn=L.BatchNormalization(ch_out),
        )
        self.activation=activation

    def __call__(self, x, train):
        h = self.bn(self.conv(x), test=not train)
        if self.activation is None:
            return h
        return self.activation(h)

Modifié:

class BatchConv2D(chainer.Chain):
    def __init__(self, ch_in, ch_out, ksize, stride=1, pad=0, activation=F.relu):
        super(BatchConv2D, self).__init__(
            conv=L.Convolution2D(ch_in, ch_out, ksize, stride, pad),
            bn=L.BatchNormalization(ch_out),
        )
        self.activation=activation

    def __call__(self, x): #Retirer le train
        h = self.bn(self.conv(x)) #Supprimer le test
        if self.activation is None:
            return h
        return self.activation(h)

Entourez le traitement lorsque vous n'apprenez pas avec chainer.using_config ('train', False)

Suppression des arguments train et test des appels à dropout et BatchNormalization. À ce rythme, ces fonctions fonctionneront en mode d'apprentissage. À partir de Chainer v2, utilisez with chainer.using_config ('train',): pour contrôler si la formation est en cours.

    with chainer.using_config('train', False):
        #Que faire si vous n'apprenez pas(Calcul de la précision des données de test, etc.)

Utilisez chainer.config.train pour distinguer s'il apprend ou non

chainer.config a été ajouté à partir de Chainer v2, et il est maintenant possible de juger s'il apprend, si une rétro-propagation est nécessaire, etc. avec config. J'avais l'habitude de juger si j'apprenais ou non avec l'argument train de ma propre fonction comme indiqué ci-dessous, mais à partir de la v2, l'argument train n'est pas nécessaire et il peut être jugé avec configuration.config.train. C'est bon.

Avant correction:

def my_func(x, train=True):
    if train:
        #Traitement pendant l'apprentissage
    else:
        #Que faire si vous n'apprenez pas

Modifié:

def my_func(x):
    if chainer.config.train:
        #Traitement pendant l'apprentissage
    else:
        #Que faire si vous n'apprenez pas

Si la rétro-propagation n'est pas nécessaire, placez-la dans chainer.using_config ('train', False)

Entourez les processus qui ne nécessitent pas de rétro-propagation avec chainer.using_config ('train', False). Ceci s'applique aux cas où l'indicateur volatile était activé lorsque la chainer.Variable a été générée.

Non requis pour Chainer v2 alpha mais sera requis à l'avenir (après la version bêta)

Suppression de l'argument volatile dans le chainer.

Il reste dans la phase alpha v2, mais le "volatile" dans "chainer.Variable" sera supprimé à l'avenir. Au lieu de volatile, il sera contrôlé par chainer.using_config ('enable_backprop',). Puisqu'il est possible de passer le tableau Numpy et le tableau Cupy au lieu de «Variable» à l'appel de «chainer.functions» et «chainer.links», je pense qu'il existe une option pour supprimer le processus de génération de «Variable» également.

Avant correction:

    x = Variable(xp.asarray(batch_x), volatile=Train)

Modifié:

    with chainer.using_config('enable_backprop', False):
        x = Variable(xp.asarray(batch_x))

Exécution après modification

c:\project_2\chainer-cifar>python src\train.py -g 0 -m vgg -p model\temp -b 100 --iter 200 --lr 0.1 --optimizer sgd --weight_decay 0.0001 --lr_decay_iter 100,150
DEBUG: nvcc STDOUT mod.cu
Bibliothèque C:/Users/user_name/AppData/Local/Theano/compiledir_Windows-10-10.0.14393-Intel64_Family_6_Model_58_Stepping_9_GenuineIntel-2.7.11-64/tmpiwxtcf/265abc51f7c376c224983485238ff1a5.lib et objet C:/Users/user_name/AppData/Local/Theano/compiledir_Windows-10-10.0.14393-Intel64_Family_6_Model_58_Stepping_9_GenuineIntel-2.7.11-64/tmpiwxtcf/265abc51f7c376c224983485238ff1a5.Créer exp

Using gpu device 0: GeForce GTX 1080 (CNMeM is disabled, cuDNN 5105)
C:\Users\user_name\Anaconda\lib\site-packages\theano-0.8.2-py2.7.egg\theano\sandbox\cuda\__init__.py:600: UserWarning: Your cuDNN version is more recent than the one Theano officially supports. If you see any problems, try updating Theano or downgrading cuDNN to version 5.
  warnings.warn(warn)
loading dataset...
start training
epoch 0 done
train loss: 2.29680542204 error: 85.5222222221
valid loss: 1.95620539665 error: 81.3800000548
test  loss: 1.95627536774 error: 80.6099999845
test time: 1.04036228008s
elapsed time: 23.5432411172
epoch 1 done
train loss: 1.91133875476 error: 76.8000000185
valid loss: 1.83026596069 error: 73.6399999559
test  loss: 1.8381768012 error: 73.2900000066
test time: 0.993011643337s

Un avertissement est apparu autour de Theano avant de passer à Chainer v2, mais cela semble fonctionner.

finalement

Ce n'est pas difficile à corriger pour Chainer v2, mais comme il y avait de nombreux endroits où «dropout» et «BatchNormalization» étaient utilisés, la quantité de correction a été augmentée en conséquence. À la suite du correctif, le code est un peu plus propre car l'argument train que certaines fonctions avaient n'est plus nécessaire. Je pense que beaucoup de code implémenté pour la v1 ne fonctionnera pas dans la v2, donc je pense qu'il y a de nombreux cas où même si j'essaie de déplacer le code pour la v1 que j'ai récupéré immédiatement après la publication officielle de la v2, cela ne fonctionne pas.

Recommended Posts

J'ai essayé de rendre mon propre code source compatible avec Chainer v2 alpha
J'ai essayé de faire mon propre BOT lycéenne avec le style Rinna avec LINE BOT (Python & Heroku)
J'ai essayé d'entraîner la fonction péché avec chainer
J'ai essayé d'apprendre mon propre ensemble de données en utilisant Chainer Trainer
J'ai essayé de créer une application OCR avec PySimpleGUI
J'ai essayé de faire une simulation de séparation de source sonore en temps réel avec l'apprentissage automatique Python
J'ai essayé de créer diverses "données factices" avec Python faker
J'ai essayé d'implémenter ListNet d'apprentissage de rang avec Chainer
J'ai essayé de créer une interface graphique à trois yeux côte à côte avec Python et Tkinter
J'ai essayé de faire d'Othello AI que j'ai appris 7,2 millions de mains par apprentissage profond avec Chainer
[5e] J'ai essayé de créer un certain outil de type Authenticator avec python
[2nd] J'ai essayé de créer un certain outil de type Authenticator avec python
J'ai essayé de rendre le deep learning évolutif avec Spark × Keras × Docker
[3ème] J'ai essayé de créer un certain outil de type Authenticator avec python
J'ai essayé de faire un processus d'exécution périodique avec Selenium et Python
J'ai essayé de créer une application de notification de publication à 2 canaux avec Python
J'ai essayé de créer une application todo en utilisant une bouteille avec python
[4th] J'ai essayé de créer un certain outil de type Authenticator avec python
[1er] J'ai essayé de créer un certain outil de type Authenticator avec python
J'ai essayé de faire une étrange citation pour Jojo avec LSTM
J'ai essayé de créer une fonction de similitude d'image avec Python + OpenCV
J'ai essayé de créer un mécanisme de contrôle exclusif avec Go
J'ai essayé d'implémenter Autoencoder avec TensorFlow
J'ai essayé de visualiser AutoEncoder avec TensorFlow
J'ai essayé de commencer avec Hy
J'ai créé un capteur d'ouverture / fermeture (lien Twitter) avec TWE-Lite-2525A
J'ai essayé d'améliorer la précision de mon propre réseau neuronal
765 J'ai essayé d'identifier les trois familles professionnelles par CNN (avec Chainer 2.0.0)
J'ai essayé de créer un LINE BOT "Sakurai-san" avec API Gateway + Lambda
[AWS] [GCP] J'ai essayé de rendre les services cloud faciles à utiliser avec Python
J'ai essayé d'obtenir le code d'authentification de l'API Qiita avec Python.
J'ai essayé de faire un signal avec Raspeye 4 (édition Python)
J'ai essayé d'apprendre l'angle du péché et du cos avec le chainer
J'ai essayé d'implémenter CVAE avec PyTorch
J'ai créé une API Web
J'ai essayé de résoudre TSP avec QAOA
[Zaif] J'ai essayé de faciliter le commerce de devises virtuelles avec Python
J'ai essayé de créer un service de raccourcissement d'url sans serveur avec AWS CDK
J'ai fait de mon mieux pour retourner au Lasso
J'ai essayé de faire un processus périodique avec CentOS7, Selenium, Python et Chrome
J'ai essayé de publier mon propre module pour pouvoir l'installer
J'ai fait une application d'envoi de courrier simple avec tkinter de Python
Quand j'ai essayé de créer un VPC avec AWS CDK mais que je n'ai pas pu le faire
[Analyse des brevets] J'ai essayé de créer une carte des brevets avec Python sans dépenser d'argent
J'ai créé une API de recherche de château avec Elasticsearch + Sudachi + Go + echo
J'ai essayé de faire la reconnaissance de caractères manuscrits de Kana Partie 3/3 Coopération avec l'interface graphique en utilisant Tkinter
J'ai fait de mon mieux pour créer une fonction d'optimisation, mais cela n'a pas fonctionné.
J'ai essayé de créer une API de reconnaissance d'image simple avec Fast API et Tensorflow
J'ai essayé de rendre le deep learning évolutif avec Spark × Keras × Docker 2 Multi-host edition
J'ai essayé de prédire l'année prochaine avec l'IA
J'ai essayé d'implémenter la lecture de Dataset avec PyTorch
J'ai essayé d'utiliser lightGBM, xg boost avec Boruta
J'ai essayé d'apprendre le fonctionnement logique avec TF Learn
J'ai essayé de déplacer GAN (mnist) avec keras
J'ai essayé de sauvegarder les données avec discorde
J'ai essayé de détecter rapidement un mouvement avec OpenCV
J'ai essayé d'intégrer Keras dans TFv1.1
J'ai essayé Flask avec des conteneurs distants de VS Code