[Python] [Pokemon Sword/Shield] I tried to visualize the basis of deep learning decisions, using starter Pokemon classification as an example

Introduction

Do you play Pokemon? I bought a Pokemon game for the first time in 10 years ~~(well, Santa gave it to me)~~. I'm planning to stay home over the year-end and New Year holidays and carefully breed strong Pokemon. Since I wanted to do something Pokemon-themed for the Advent Calendar, I tried out **a method for showing the basis of a deep learning model's decisions**, a topic I've been interested in recently, **using starter Pokemon classification as an example**.

A method for showing the basis of a deep learning model's decisions: what is TCAV?

Deep learning is starting to be deployed in society across many fields, but how a model reaches its decisions tends to be a black box. In recent years, research on the "explainability" and "interpretability" of models has been progressing.

So this time I'd like to try Quantitative Testing with Concept Activation Vectors (TCAV), a method presented at ICML 2018.

Paper summary

- **A method for showing the basis of a neural network model's decisions**
  - Rather than the traditional approach of computing importance for each **pixel**, it shows how important a **concept** (color, gender, race, etc.) is for a predicted class.
  - Because it produces **per-class explanations** (≈ global) instead of per-image explanations (≈ local), the explanations are easy for humans to understand.
- **The explanations can be understood without specialized knowledge of ML models**
- **There is no need to retrain or modify the existing model you want to interpret**
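For reference, the TCAV paper quantifies a concept's importance roughly as follows. Here $f_l(x)$ is the activation of layer $l$ for input $x$, $h_{l,k}$ maps that activation to the logit of class $k$, $v_C^l$ is the CAV of concept $C$ at layer $l$ (explained in the next section), and $X_k$ is the set of inputs of class $k$:

```latex
S_{C,k,l}(x) = \nabla h_{l,k}\bigl(f_l(x)\bigr) \cdot v_C^l
\qquad
\mathrm{TCAV}_{Q_{C,k,l}} = \frac{\left|\{\, x \in X_k : S_{C,k,l}(x) > 0 \,\}\right|}{|X_k|}
```

In words, the score is the fraction of class-$k$ inputs for which nudging the layer activation toward the concept direction increases the class logit, so a score near 1 suggests the concept pushes the prediction toward that class.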

The idea behind Concept Activation Vectors (CAVs)

A CAV is derived by training a linear classifier to separate a layer's activations for concept images from its activations for random counterexamples, and taking the vector orthogonal to the decision boundary. (The figure below makes this quicker to grasp.)

(Figure: deriving a CAV from a linear classifier trained on concept images vs. random counterexamples)
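To make the idea concrete, here is a minimal, dependency-free sketch. It assumes you have already extracted layer activations (and, for the score, gradients of the class logit with respect to those activations) as plain lists of floats. It is not the official TCAV implementation: the linear classifier is a hand-rolled logistic regression, and all function names are hypothetical.

```python
import math
import random


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def train_linear_classifier(pos, neg, epochs=200, lr=0.1, seed=0):
    """Logistic regression by stochastic gradient descent.

    pos: activation vectors of concept images (label 1)
    neg: activation vectors of random counterexamples (label 0)
    Returns (weights, bias).
    """
    data = [(x, 1.0) for x in pos] + [(x, 0.0) for x in neg]
    rng = random.Random(seed)
    w = [0.0] * len(data[0][0])
    b = 0.0
    for _ in range(epochs):
        rng.shuffle(data)
        for x, y in data:
            p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
            g = p - y  # gradient of the log loss with respect to the logit
            w = [wi - lr * g * xi for wi, xi in zip(w, x)]
            b -= lr * g
    return w, b


def compute_cav(concept_acts, random_acts):
    """CAV: unit normal of the decision boundary, pointing toward the concept side."""
    w, _ = train_linear_classifier(concept_acts, random_acts)
    norm = math.sqrt(sum(wi * wi for wi in w))
    if norm == 0:
        raise ValueError("classifier weights are all zero")
    return [wi / norm for wi in w]


def tcav_score(class_grads, cav):
    """Fraction of class inputs whose directional derivative along the CAV is positive."""
    positive = sum(1 for g in class_grads if sum(gi * ci for gi, ci in zip(g, cav)) > 0)
    return positive / len(class_grads)
```

For example, with concept activations clustered around (1, 1) and random ones around (-1, -1), the resulting CAV points into the positive quadrant, and `tcav_score` then counts how many of the given gradient vectors have a positive dot product with it.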

What it tells you

- The **concepts** a model has learned can be quantified in a form humans can interpret. Example: for the "zebra" class, the model has learned the "striped" concept rather than the "dotted" one. Because you can examine any layer, you can also see how shallow and deep layers capture coarse and detailed features, respectively.

- You can see the **bias** of a dataset. Example: the "apron" class is associated with the concept "female", and the "rugby ball" class with the concept "white".
- CAVs can also be used to sort images (by similarity to the concept images).
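Sorting by a concept can be sketched as ranking images by the cosine similarity between their layer activations and the CAV, which, as I understand it, is how the TCAV paper sorts images. The activation vectors and function names below are hypothetical placeholders.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between vectors a and b (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    if na == 0 or nb == 0:
        return 0.0
    return dot / (na * nb)


def sort_by_concept(image_acts, cav):
    """image_acts: {image_name: activation_vector}.

    Returns image names ordered from most to least similar to the concept.
    """
    return sorted(image_acts,
                  key=lambda name: cosine_similarity(image_acts[name], cav),
                  reverse=True)
```

With a CAV of `[1.0, 0.0]`, an image whose activation is `[1.0, 0.0]` ranks above one at `[1.0, 1.0]`, which in turn ranks above one at `[0.0, 1.0]`.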

First, build a simple classifier

This time my goal is just to get TCAV running, so I chose a simple task: a classifier for the three starter Pokemon.

Dataset preparation

① Crawling

I collected images with icrawler. Here is the code.

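import os
from icrawler.builtin import GoogleImageCrawler

# Output directory for one class (change save_dir and query for the other classes)
save_dir = '../datasets/hibany'
os.makedirs(save_dir, exist_ok=True)

query = 'Scorbunny'  # search keyword
max_num = 200        # maximum number of images to download

# Download up to max_num Google Image results into save_dir
google_crawler = GoogleImageCrawler(storage={'root_dir': save_dir})
google_crawler.crawl(keyword=query, max_num=max_num)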

② Preprocessing

I kept processing to a minimum.

  1. Manually crop the images collected in ① into squares
  2. Resize to 256 x 256
  3. Divide into train / val / test
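Step 3 could look like the stdlib-only sketch below. The 80/10/10 ratio and the fixed seed are my assumptions (the ratios aren't stated here), and `split_files` is a hypothetical helper; the output keys match the `train` / `validation` / `test` folders in the directory tree later in the post.

```python
import random


def split_files(filenames, val_ratio=0.1, test_ratio=0.1, seed=42):
    """Shuffle filenames reproducibly and split them into train/validation/test."""
    files = sorted(filenames)  # sort first so the result doesn't depend on input order
    random.Random(seed).shuffle(files)
    n_test = round(len(files) * test_ratio)
    n_val = round(len(files) * val_ratio)
    return {
        "test": files[:n_test],
        "validation": files[n_test:n_test + n_val],
        "train": files[n_test + n_val:],
    }
```

For step 2, Pillow's `Image.open(path).convert("RGB").resize((256, 256))` is one way to resize each cropped image before saving it into the split folders.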

Sample starter images

Here is what I collected. (By the way, I picked Scorbunny instantly. I love Fire types.)

Scorbunny (hibany): 156 images
Sobble (messon): 147 images
Grookey (sarunori): 182 images

Images of Pokemon other than the three starters, images of human characters, and heavily stylized (super-deformed) illustrations were also mixed in, so I removed them by visual inspection. ~~Raihan (Kibana) is so cool~~

Building the classifier

It's a simple CNN.

(Figure: CNN architecture)

Because the test set is small (about 15 images), test accuracy fluctuates, but the model is accurate enough for trying out TCAV. (Figure: accuracy results)

Computing CAVs requires the model as a **.pb file** (a frozen TensorFlow graph), so save the model in .pb format. Next, let's set things up to see what the model has learned.

Preparing to run TCAV

Prepare by following the steps below. (The code I used is here. I'll write a proper README later...)

Step 1: Prepare concept images (positive and negative examples)

For the positive examples, I prepared the images below. Based on the hypothesis that the model tells the three starters apart by color, I prepared several color concepts. (It works with 10 to 20 images per concept, but about 50 to 200 is better.)

**Positive example samples**

White: 22 images
Red: 20 images
Blue: 15 images
Yellow: 18 images
Green: 21 images
Black: 17 images

I excluded images with too many colors mixed together.

**Negative example samples**
Ideally, these are images that fit none of the positive examples above. (In this case, though, it's hard to find images that contain none of these colors.) This time I randomly sampled images from Caltech256.
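Drawing the random negative sets can be sketched like this. It assumes you already have a list of Caltech256 image paths; `make_random_sets` is a hypothetical helper, the `random500_` prefix simply mirrors the folder names in the directory tree, and the set sizes are placeholders.

```python
import random


def make_random_sets(pool, n_sets, n_per_set, prefix="random500_", seed=0):
    """Return {folder_name: [paths]} with n_sets independent random samples from pool.

    Each set is drawn without replacement, but different sets may share images.
    """
    if n_per_set > len(pool):
        raise ValueError("n_per_set is larger than the image pool")
    rng = random.Random(seed)
    return {f"{prefix}{i}": rng.sample(pool, n_per_set) for i in range(n_sets)}
```

Each list can then be copied into its own folder under `for_tcav` with `shutil.copy`.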

The directory layout of the images collected so far is shown below. Each set of concept images must be its own subdirectory.

├── datasets
│   ├── for_tcav #Data set for TCAV
│   │   ├── black
│   │   ├── blue
│   │   ├── green
│   │   ├── hibany
│   │   ├── messon
│   │   ├── random500_0
│   │   ├── random500_1
│   │   ├── random500_2
│   │   ├── random500_3
│   │   ├── random500_4
│   │   ├── random500_5
│   │   ├── red
│   │   ├── sarunori
│   │   ├── white
│   │   └── yellow
│   └── splited #Data set for image classification model creation
│       ├── test
│       │   ├── hibany
│       │   ├── messon
│       │   └── sarunori
│       ├── train
│       │   ├── hibany
│       │   ├── messon
│       │   └── sarunori
│       └── validation
│           ├── hibany
│           ├── messon
│           └── sarunori
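Before running TCAV, a quick sanity check of this layout can save a failed run. The sketch below uses hypothetical helper names; it reports folders that are missing or hold fewer images than a threshold, with `min_images=10` taken from the lower bound of the 10 to 20 images mentioned above.

```python
import os

IMAGE_EXTS = (".jpg", ".jpeg", ".png")


def count_images(path):
    """Count files in path whose extension looks like an image."""
    return sum(1 for name in os.listdir(path) if name.lower().endswith(IMAGE_EXTS))


def check_tcav_dirs(root, names, min_images=10):
    """Return a list of problems; an empty list means every folder looks usable."""
    problems = []
    for name in names:
        path = os.path.join(root, name)
        if not os.path.isdir(path):
            problems.append(f"{name}: missing directory")
            continue
        n = count_images(path)
        if n < min_images:
            problems.append(f"{name}: only {n} images (need at least {min_images})")
    return problems
```

For this post's layout, a call might look like `check_tcav_dirs("datasets/for_tcav", ["red", "green", "hibany", "random500_0"])`.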

Step 2: Implement the model wrapper

First, clone the TCAV repository.

git clone git@github.com:tensorflow/tcav.git

Next, create a wrapper that passes the model's information to TCAV. Add this class to tcav/model.py.

class SimepleCNNWrapper_public(PublicImageModelWrapper):
    def __init__(self, sess, model_saved_path, labels_path):
        # Pixel value range the model expects (inputs scaled to [0, 1])
        self.image_value_range = (0, 1)
        # Input shape of the CNN: height, width, channels
        image_shape_v3 = [256, 256, 3]
        # Map TCAV's endpoint keys to tensor names in the frozen .pb graph.
        # These names are specific to this Keras model; check them in your own graph.
        endpoints_v3 = dict(
            input='conv2d_1_input:0',
            logit='activation_6/Softmax:0',
            prediction='activation_6/Softmax:0',
            pre_avgpool='max_pooling2d_3/MaxPool:0',
            logit_weight='activation_6/Softmax:0',
            logit_bias='dense_1/bias:0',
        )

        self.sess = sess
        super(SimepleCNNWrapper_public, self).__init__(sess,
                                                       model_saved_path,
                                                       labels_path,
                                                       image_shape_v3,
                                                       endpoints_v3,
                                                       scope='import')
        self.model_name = 'SimepleCNNWrapper_public'

Now everything is ready. Let's go straight to the results.

Results

Let's look at which concepts (colors, this time) are important for each class. In the plots, concepts marked with * were not statistically significant; the concepts without * are the important ones.

Scorbunny (hibany) class: red / yellow / white
Sobble (messon) class: red (!?)
Grookey (sarunori) class: green
(Figures: TCAV results for each class)

The Scorbunny and Grookey results look about right. Sobble is a mystery and calls for further investigation. The results changed considerably when I varied the number of runs or the number of concept/target images, so a more careful analysis is needed. The results also seem to depend on how the concept images are chosen, so it seems worth experimenting with different choices.
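Because results vary from run to run, TCAV repeats the experiment against several random counterexample sets and tests the resulting TCAV scores for significance (the paper uses a two-sided t-test). The sketch below instead compares a concept's scores with baseline scores from random CAVs using a permutation test on the difference of means; this is a swapped-in technique chosen to stay dependency-free, and the function names and score values are hypothetical.

```python
import random
from statistics import mean


def permutation_p_value(concept_scores, random_scores, n_permutations=10000, seed=0):
    """Two-sided permutation test on the difference in mean TCAV scores."""
    rng = random.Random(seed)
    observed = abs(mean(concept_scores) - mean(random_scores))
    pooled = list(concept_scores) + list(random_scores)
    k = len(concept_scores)
    extreme = 0
    for _ in range(n_permutations):
        rng.shuffle(pooled)
        diff = abs(mean(pooled[:k]) - mean(pooled[k:]))
        if diff >= observed - 1e-12:  # small tolerance for float round-off
            extreme += 1
    # Add-one correction so the p-value is never exactly zero
    return (extreme + 1) / (n_permutations + 1)


def is_significant(concept_scores, random_scores, alpha=0.05):
    """True if the concept's scores differ from the random baseline at level alpha."""
    return permutation_p_value(concept_scores, random_scores) < alpha
```

A concept whose scores are consistently far from the random baseline comes out significant, while one whose scores are indistinguishable from the baseline gets a p-value near 1.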

Summary

I tried a method for showing the basis of a neural network model's decisions. The output was easy for humans to interpret, and the results **intuitively made sense**. This time I chose colors as the concepts because the task was classifying the three starters, but preparing concept images is hard. Some setup is required, but there's no need to retrain the model, and once you've gone through the steps once, it's easy to use. Please give it a try!
