[PYTHON] Je voulais contester la classification du CIFAR-10 en utilisant l'entraîneur de Chainer

introduction

L'autre jour ici j'ai appris que Chainer peut être écrit de manière très concise, j'ai donc contesté la classification d'image de CIFAR-10 que je voulais essayer auparavant. j'ai essayé Je voulais écrire ..., mais je n'ai qu'un environnement CPU médiocre, donc je n'ai pas pu confirmer l'exécution. Je l'ai déplacé toute la journée et avancé d'environ 2 époques, donc c'est probablement correct ... ^^; Concernant la mise en œuvre, je me suis référé au blog de ici.

la mise en oeuvre

Chargez l'image de CIFAR-10

Téléchargez et chargez les données CIFAR-10 depuis ici. Puisqu'il est comme pickle, il est lu par la fonction suivante.

def unpickle(file):
    fp = open(file, 'rb')
    if sys.version_info.major == 2:
        data = pickle.load(fp)
    elif sys.version_info.major == 3:
        data = pickle.load(fp, encoding='latin-1')                                                                                    
    fp.close()

    return data

réseau neuronal

J'ai fait référence au blog que j'ai présenté plus tôt. Je ne sais toujours pas comment concevoir cette zone ...

class Cifar10Model(chainer.Chain):

    def __init__(self):
        super(Cifar10Model,self).__init__(
                conv1 = F.Convolution2D(3, 32, 3, pad=1),
                conv2 = F.Convolution2D(32, 32, 3, pad=1),
                conv3 = F.Convolution2D(32, 32, 3, pad=1),
                conv4 = F.Convolution2D(32, 32, 3, pad=1),
                conv5 = F.Convolution2D(32, 32, 3, pad=1),
                conv6 = F.Convolution2D(32, 32, 3, pad=1),
                l1 = L.Linear(512, 512),
                l2 = L.Linear(512,10))

    def __call__(self, x, train=True):
        h = F.relu(self.conv1(x))
        h = F.max_pooling_2d(F.relu(self.conv2(h)), 2)
        h = F.relu(self.conv3(h))
        h = F.max_pooling_2d(F.relu(self.conv4(h)), 2)
        h = F.relu(self.conv5(h))
        h = F.max_pooling_2d(F.relu(self.conv6(h)), 2)
        h = F.dropout(F.relu(self.l1(h)), train=train)
        return self.l2(h)

Lecture des données

Je suis un peu bouché ici. Lors de l'utilisation du nouvel entraîneur de fonctions de Chainer, je transmets les données que je veux apprendre à l'itérateur, mais dans des tutoriels, etc.

train_iter = chainer.iterators.SerialIterator(train, 100)
test_iter = chainer.iterators.SerialIterator(test, 100,repeat=False, shuffle=False)

Je ne savais pas comment passer le label etc. Après de nombreuses recherches, j'ai trouvé que je devais utiliser Tuple_dataset.

train = chainer.tuple_dataset.TupleDataset(train_data, train_label)

Cela semble bon de faire comme ça.

Vous trouverez ci-dessous le code complet de la partie lue.

x_train = None
y_train = []
for i in range(1,6):
    data_dic = unpickle("cifar-10-batches-py/data_batch_{}".format(i))
    if i == 1:
        x_train = data_dic['data']
    else:
        x_train = np.vstack((x_train, data_dic['data']))
    y_train += data_dic['labels']

test_data_dic = unpickle("cifar-10-batches-py/test_batch")
x_test = test_data_dic['data']
x_test = x_test.reshape(len(x_test),3,32,32)
y_test = np.array(test_data_dic['labels'])
x_train = x_train.reshape((len(x_train),3, 32, 32))
y_train = np.array(y_train)
x_train = x_train.astype(np.float32)
x_test = x_test.astype(np.float32)
x_train /= 255
x_test/=255                                                                                                                     
y_train = y_train.astype(np.int32)
y_test = y_test.astype(np.int32)

train = tuple_dataset.TupleDataset(x_train, y_train)
test = tuple_dataset.TupleDataset(x_test, y_test)

Partie d'apprentissage

J'apprends avec le réseau neuronal que j'ai défini plus tôt. Le code est juste une petite modification du tutoriel MNIST. Je suis surpris de pouvoir l'écrire d'une manière incroyablement concise.


model = L.Classifier(Cifar10Model())
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

train_iter = chainer.iterators.SerialIterator(train, 100)
test_iter = chainer.iterators.SerialIterator(test, 100,repeat=False, shuffle=False)

updater = training.StandardUpdater(train_iter, optimizer, device=-1)
trainer = training.Trainer(updater, (40, 'epoch'), out="logs")
trainer.extend(extensions.Evaluator(test_iter, model, device=-1))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport( ['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy']))
trainer.extend(extensions.ProgressBar())                                                                                          
trainer.run()

résultat

Lorsque vous l'exécutez, une barre de progression apparaîtra pour vous dire combien vous apprenez.

Screenshot from 2016-08-06 18:04:00.png

Estimated time to finish: 6 days

~~ J'ai abandonné ~~ (Corrigé le 15.08.2016) J'ai fait de mon mieux

J'ai lu le journal de sortie dans le dictionnaire et l'ai représenté graphiquement avec matplotlib

figure_1.png

figure_2.png

en conclusion

~~ Je n'ai pas pu confirmer que le résultat était correct, mais j'ai appris à utiliser ~~ trainer. Après tout, GPU est indispensable pour étudier le Deep Learning

Recommended Posts

Je voulais contester la classification du CIFAR-10 en utilisant l'entraîneur de Chainer
J'ai essayé de transformer l'image du visage en utilisant sparse_image_warp de TensorFlow Addons
J'ai essayé d'obtenir les résultats de Hachinai en utilisant le traitement d'image
J'ai essayé d'estimer la similitude de l'intention de la question en utilisant Doc2Vec de gensim
Je voulais faire attention au comportement des arguments par défaut de Python
J'ai essayé d'extraire et d'illustrer l'étape de l'histoire à l'aide de COTOHA
J'ai essayé l'histoire courante de l'utilisation du Deep Learning pour prédire la moyenne Nikkei
En utilisant COTOHA, j'ai essayé de suivre le cours émotionnel de la course aux meros.
Je voulais jouer avec la courbe de Bézier
J'ai essayé de corriger la forme trapézoïdale de l'image
Je souhaite personnaliser l'apparence de zabbix
J'ai essayé d'utiliser le filtre d'image d'OpenCV
J'ai essayé de vectoriser les paroles de Hinatazaka 46!
J'ai essayé de prédire la détérioration de la batterie lithium-ion en utilisant le SDK Qore
J'ai essayé de notifier la mise à jour de "Hameln" en utilisant "Beautiful Soup" et "IFTTT"
[Python] J'ai essayé de juger l'image du membre du groupe d'idols en utilisant Keras
J'ai fait un script pour enregistrer la fenêtre active en utilisant win32gui de Python
Continuer à relever les défis de Cyma en utilisant le service OCR de Google Cloud Platform
J'ai essayé de prédire la victoire ou la défaite de la Premier League en utilisant le SDK Qore
J'ai essayé de notifier la mise à jour de "Devenir romancier" en utilisant "IFTTT" et "Devenir un romancier API"
J'ai essayé de résumer la forme de base de GPLVM
Python pratique 100 coups J'ai essayé de visualiser l'arbre de décision du chapitre 5 en utilisant graphviz
Je voulais utiliser la bibliothèque Python de MATLAB
Je veux bien comprendre les bases de Bokeh
J'ai essayé d'extraire le texte du fichier image en utilisant Tesseract du moteur OCR
Je souhaite prendre une capture d'écran du site sur Docker en utilisant n'importe quelle police
J'ai essayé d'approcher la fonction sin en utilisant le chainer
J'ai examiné l'argument class_weight de la fonction softmax_cross_entropy de Chainer.
J'ai essayé d'utiliser l'API de Sakenowa Data Project
J'ai essayé de visualiser les informations spacha de VTuber
J'ai essayé d'effacer la partie négative de Meros
Je voulais juste extraire les données de la date et de l'heure souhaitées avec Django
[Échec] Je voulais générer des phrases en utilisant TextRegressor de Flair
Je veux automatiser ssh en utilisant la commande expect!
L'histoire de l'utilisation de Circleci pour construire des roues Manylinux
J'ai essayé la méthode la plus simple de classification de documents multi-étiquettes
J'ai essayé d'identifier la langue en utilisant CNN + Melspectogram
J'ai essayé de compléter le graphe de connaissances en utilisant OpenKE
J'ai essayé de classer les voix des acteurs de la voix
Je souhaite augmenter la sécurité de la connexion SSH
J'ai essayé de compresser l'image en utilisant l'apprentissage automatique
J'ai essayé de résumer les opérations de chaîne de Python
[Bouclier d'épée Pokémon] J'ai essayé de visualiser la base de jugement de l'apprentissage en profondeur en utilisant la classification des trois familles comme exemple
Je voulais faire fonctionner le moteur avec une tarte à la râpe, alors j'ai essayé d'utiliser la carte de commande du moteur de Waveshare
J'ai essayé de comparer la précision des modèles d'apprentissage automatique en utilisant kaggle comme thème.
J'ai essayé de vérifier la classification yin et yang des membres hololive par apprentissage automatique
J'ai essayé de prédire l'infection d'une nouvelle pneumonie en utilisant le modèle SIR: ☓ Wuhan edition ○ Hubei province edition
J'ai essayé d'automatiser la construction d'un environnement pratique à l'aide de l'API SoftLayer d'IBM Cloud
Osez remplir le formulaire sans utiliser de sélénium
J'ai essayé de trouver l'entropie de l'image avec python
[Courses de chevaux] J'ai essayé de quantifier la force du cheval de course
J'ai essayé d'obtenir les informations de localisation du bus Odakyu
Je voulais résoudre le concours de programmation Panasonic 2020 avec Python
J'ai essayé de trouver la moyenne de plusieurs colonnes avec TensorFlow
Je veux automatiser ssh en utilisant la commande expect! partie 2
J'ai essayé de refactoriser le modèle CNN de TensorFlow en utilisant TF-Slim
J'ai essayé de simuler l'optimisation des publicités à l'aide de l'algorithme Bandit
J'ai essayé la reconnaissance faciale du problème du rire en utilisant Keras.