[PYTHON] [Bouclier d'épée Pokémon] J'ai essayé de visualiser la base de jugement de l'apprentissage en profondeur en utilisant la classification des trois familles comme exemple

introduction

Jouez-vous à Pokémon? Je l'ai acheté pour la première fois en 10 ans ~~ Je l'ai obtenu du Père Noël. Dans le but d'être une force forte, nous prévoyons de rester à la maison et de sélectionner avec soin pendant les vacances de fin d'année et du Nouvel An. Je me demandais si Advent Calendar pouvait faire quelque chose avec Pokemon Neta, alors j'ai essayé ** une méthode qui montre la base pour juger un modèle d'apprentissage en profondeur ** qui m'intéresse récemment, en utilisant ** Pokemon Three Family Classification comme exemple **. Vu.

Méthode pour montrer la base de jugement du modèle d'apprentissage en profondeur: Qu'est-ce que TCAV?

L'apprentissage en profondeur commence à être mis en œuvre dans la société dans divers domaines, mais il a tendance à être une boîte noire sur la base de laquelle le modèle prend des décisions. Ces dernières années, des recherches sur «l'explication» et «l'interprétable» des modèles sont en cours.

Par conséquent, cette fois, j'aimerais essayer la méthode Quantitative Testing with Concept Activation Vectors (TCAV) adoptée dans ICML 2018.

Aperçu de l'article

Concept de vecteurs d'activation de concept (CAV)

Dérivation de CAV en apprenant un classificateur linéaire entre une image conceptuelle et un contre-exemple aléatoire et en obtenant un vecteur orthogonal à la frontière de décision. (Il est plus rapide de voir la figure ci-dessous).

image.png

Qu'est-ce que tu sais

Commencez par créer un classificateur approprié

Cette fois, mon objectif est de déplacer TCAV, donc j'en ai fait une tâche simple. Créez un classificateur Pokemon Three Family.

Préparation du jeu de données

① Ramper

Les images suivantes ont été collectées à l'aide de icrawler. Je vais mettre le code.

import os
from icrawler.builtin import GoogleImageCrawler

save_dir = '../datasets/hibany'
os.makedirs(save_dir, exist_ok=True)

query = 'Scorbunny'
max_num = 200

google_crawler = GoogleImageCrawler(storage={'root_dir': save_dir})
google_crawler.crawl(keyword=query, max_num=max_num)

② Prétraitement

Seul un traitement minimal.

  1. ① Recadrez manuellement l'image acquise en rampant dans un carré
  2. Redimensionner à 256 x 256
  3. Divisez en train / val / test

Échantillon d'image de trois familles

Les images ont été collectées comme ça. (Au fait, j'ai été une décision rapide pour Hibani. J'adore le type de flamme)

Scorbunny Messon Sarnori
000003.jpg 000003.jpg 000006.png
156 feuilles 147 feuilles 182 feuilles

Les Pokémon suivants autres que les trois familles, les images de personnages et les illustrations surdéformées étaient également confondus, ils sont donc exclus par inspection visuelle. ~~ Kibana San Cool ~~ 000075.png 000237.jpg 000075.png

Création du classificateur

C'est un simple CNN.

image.png

Comme le nombre d'images des données de test est petit (environ 15 feuilles), la précision des données de test flotte, mais nous avons créé un modèle de classification avec une précision qui sera suffisante pour la vérification TCAV. image.png

Vous aurez besoin d'un fichier **. Pb ** pour calculer le CAV, donc enregistrez le modèle au format .pb. Ensuite, préparez-vous à voir ce que le modèle apprend.

Préparation à l'exécution de TCAV

Suivez les étapes ci-dessous pour vous préparer. (Le code que j'ai utilisé cette fois-ci est sur ici. J'écrirai correctement le README plus tard ...)

Étape 1: Préparation d'images conceptuelles (exemples positifs et négatifs)

L'image suivante est préparée pour l'image d'exemple standard. Nous avons préparé plusieurs couleurs en partant de l'hypothèse que nous classerions les trois familles en regardant les couleurs. (Bien que cela fonctionne avec 10 à 20 feuilles, il est préférable d'avoir environ 50 à 200 feuilles)

** Exemple d'image d'exemple **

blanc rouge Bleu Jaune vert noir
000001.jpg 000005.jpg 000009.jpg 000004.jpg 000023.png 000023.png
22 feuilles 20 feuilles 15 feuilles 18 feuilles 21 feuilles 17 feuilles

J'exclus ceux qui ont trop de couleurs 000025.png

** Exemple d'image négatif ** Tout ce qui ne rentre dans aucun des exemples ci-dessus est souhaitable. (Dans ce cas, il est difficile de dire qu'il ne correspond à aucune couleur.) Cette fois, j'ai pris au hasard des images de Caltech256.

La structure du répertoire des images collectées jusqu'à présent est la suivante. Tous les ensembles d'images conceptuelles doivent être des sous-répertoires.

├── datasets
│   ├── for_tcav #Jeu de données pour TCAV
│   │   ├── black
│   │   ├── blue
│   │   ├── green
│   │   ├── hibany
│   │   ├── messon
│   │   ├── random500_0
│   │   ├── random500_1
│   │   ├── random500_2
│   │   ├── random500_3
│   │   ├── random500_4
│   │   ├── random500_5
│   │   ├── red
│   │   ├── sarunori
│   │   ├── white
│   │   └── yellow
│   └── splited #Ensemble de données pour la création de modèles de classification d'images
│       ├── test
│       │   ├── hibany
│       │   ├── messon
│       │   └── sarunori
│       ├── train
│       │   ├── hibany
│       │   ├── messon
│       │   └── sarunori
│       └── validation
│           ├── hibany
│           ├── messon
│           └── sarunori

Étape 2: implémenter le wrapper de modèle

Je vais d'abord le cloner.

git clone [email protected]:tensorflow/tcav.git

Ici, nous allons créer un wrapper pour transmettre les informations du modèle à TCAV. Ajoutez cette classe à tcav / model.py.

class SimepleCNNWrapper_public(PublicImageModelWrapper):
    def __init__(self, sess, model_saved_path, labels_path):
        self.image_value_range = (0, 1)
        image_shape_v3 = [256, 256, 3]
        endpoints_v3 = dict(
            input='conv2d_1_input:0',
            logit='activation_6/Softmax:0',
            prediction='activation_6/Softmax:0',
            pre_avgpool='max_pooling2d_3/MaxPool:0',
            logit_weight='activation_6/Softmax:0',
            logit_bias='dense_1/bias:0',
        )

        self.sess = sess
        super(SimepleCNNWrapper_public, self).__init__(sess,
                                                       model_saved_path,
                                                       labels_path,
                                                       image_shape_v3,
                                                       endpoints_v3,
                                                       scope='import')
        self.model_name = 'SimepleCNNWrapper_public'

Maintenant, vous êtes prêt à partir. Voyons le résultat immédiatement.

résultat

Jetons un coup d'œil aux concepts (les couleurs cette fois) qui sont importants dans chaque classe. Les concepts non marqués d'un * sont importants.

Classe Hibani Classe Messon Classe Sarnori
image.png Rouge / jaune / blanc image.png rouge(!?) image.png vert
000003.jpg 000003.jpg 000006.png

Je pense que Hibani et Sarnori sont comme ça. Le messon est un mystère, il est donc important de le considérer. Si vous changez le nombre d'essais ou le nombre d'images conceptuelles / images cibles pendant l'expérience, les résultats changeront considérablement, donc je pense qu'il est nécessaire d'envisager un peu plus. Cela semble valoir la peine d'essayer diverses choses car cela semble changer en fonction de la façon dont vous choisissez l'image conceptuelle.

Résumé

J'ai essayé une méthode pour montrer la base de jugement du modèle de réseau neuronal. C'était facile pour les humains à interpréter, et le résultat était ** "intuitivement comme ça" **. Cette fois, j'ai choisi la couleur comme image conceptuelle car elle est classée comme une famille à trois familles, mais il est difficile de préparer l'image conceptuelle. .. Vous devez effectuer diverses préparations, mais vous n'avez pas besoin de réapprendre le modèle, et si vous essayez une fois la série d'étapes et que vous vous y habituez, vous pouvez l'utiliser facilement. Veuillez essayer par tous les moyens essayez!

Recommended Posts

[Bouclier d'épée Pokémon] J'ai essayé de visualiser la base de jugement de l'apprentissage en profondeur en utilisant la classification des trois familles comme exemple
J'ai essayé l'histoire courante de l'utilisation du Deep Learning pour prédire la moyenne Nikkei
[TF] J'ai essayé de visualiser le résultat de l'apprentissage en utilisant Tensorboard
J'ai essayé de comparer la précision des modèles d'apprentissage automatique en utilisant kaggle comme thème.
J'ai essayé l'histoire courante de prédire la moyenne Nikkei à l'aide du Deep Learning (backtest)
J'ai essayé d'exécuter le didacticiel de détection d'objets en utilisant le dernier algorithme d'apprentissage en profondeur
J'ai essayé de comprendre attentivement la machine vectorielle de support (Partie 1: J'ai essayé le noyau polynomial / RBF en utilisant MakeMoons comme exemple).
J'ai essayé de visualiser les informations spacha de VTuber
J'ai essayé de compresser l'image en utilisant l'apprentissage automatique
Python pratique 100 coups J'ai essayé de visualiser l'arbre de décision du chapitre 5 en utilisant graphviz
[Deep Learning from scratch] J'ai essayé d'expliquer la confirmation du gradient d'une manière facile à comprendre.
J'ai essayé de vérifier la classification yin et yang des membres hololive par apprentissage automatique
[Python] J'ai essayé de visualiser la relation de suivi de Twitter
[Apprentissage automatique] J'ai essayé de résumer la théorie d'Adaboost
Comprendre la fonction de convolution en utilisant le traitement d'image comme exemple
J'ai essayé d'obtenir l'index de la liste en utilisant la fonction énumérer
Je voulais contester la classification du CIFAR-10 en utilisant l'entraîneur de Chainer
J'ai essayé de visualiser la condition commune des téléspectateurs de la chaîne VTuber
J'ai essayé l'apprentissage en profondeur avec Theano
[Fabric] J'étais accro à l'utilisation de booléen comme argument, alors notez les contre-mesures.
J'ai essayé de transformer l'image du visage en utilisant sparse_image_warp de TensorFlow Addons
J'ai essayé d'obtenir les résultats de Hachinai en utilisant le traitement d'image
J'ai essayé de visualiser la tranche d'âge et la distribution des taux d'Atcoder
J'ai essayé de transcrire les actualités de l'exemple d'intégration commerciale sur Amazon Transcribe
J'ai essayé d'estimer la similitude de l'intention de la question en utilisant Doc2Vec de gensim
J'ai essayé d'extraire et d'illustrer l'étape de l'histoire à l'aide de COTOHA
J'ai essayé de visualiser le texte du roman "Weather Child" avec Word Cloud
J'ai essayé de visualiser le modèle avec la bibliothèque d'apprentissage automatique low-code "PyCaret"
En utilisant COTOHA, j'ai essayé de suivre le cours émotionnel de la course aux meros.
J'ai essayé 200 fois l'échange magique [Pokemon Sword Shield]
Visualisez les effets de l'apprentissage profond / de la régularisation
J'ai essayé de comprendre attentivement la fonction d'apprentissage dans le réseau de neurones sans utiliser la bibliothèque d'apprentissage automatique (première moitié)
J'ai essayé d'obtenir les informations du site .aspx qui est paginé à l'aide de Selenium IDE aussi sans programmation que possible.
J'ai essayé de notifier la mise à jour de "Hameln" en utilisant "Beautiful Soup" et "IFTTT"
[Python] J'ai essayé de juger l'image du membre du groupe d'idols en utilisant Keras
[Python] Deep Learning: J'ai essayé d'implémenter Deep Learning (DBN, SDA) sans utiliser de bibliothèque.
J'ai essayé de visualiser facilement les tweets de JAWS DAYS 2017 avec Python + ELK
J'ai essayé d'extraire le dessin au trait de l'image avec Deep Learning
J'ai essayé de prédire la présence ou l'absence de neige par apprentissage automatique.
Un amateur a essayé le Deep Learning avec Caffe (Introduction)
Apprentissage automatique du sport-Analyse de la J-League à titre d'exemple-②
Un amateur a essayé le Deep Learning en utilisant Caffe (Practice)
J'ai essayé de corriger la forme trapézoïdale de l'image
Un amateur a essayé le Deep Learning avec Caffe (Vue d'ensemble)
J'ai essayé d'utiliser le filtre d'image d'OpenCV
J'ai essayé de vectoriser les paroles de Hinatazaka 46!
J'ai essayé de prédire la victoire ou la défaite de la Premier League en utilisant le SDK Qore
J'ai essayé de notifier la mise à jour de "Devenir romancier" en utilisant "IFTTT" et "Devenir un romancier API"
J'ai essayé d'extraire le texte du fichier image en utilisant Tesseract du moteur OCR
J'ai essayé de visualiser les caractéristiques des nouvelles informations sur les personnes infectées par le virus corona avec wordcloud
[First data science ⑥] J'ai essayé de visualiser le prix du marché des restaurants à Tokyo
J'ai essayé de visualiser les données de course du jeu de course (Assetto Corsa) avec Plotly
J'ai essayé de fonctionner à partir de Postman en utilisant Cisco Guest Shell comme serveur API
J'ai essayé de résumer la forme de base de GPLVM
Techniques pour comprendre la base des décisions d'apprentissage en profondeur
J'ai essayé d'obtenir une AMI en utilisant AWS Lambda
J'ai essayé d'approcher la fonction sin en utilisant le chainer