[PYTHON] Visualisation des cartes et des filtres des fonctionnalités CNN (Tensorflow 2.0)

Aperçu

J'ai regardé la carte des fonctionnalités et les filtres dans le CNN construit avec le modèle de sous-classification.

environnement

-Software- Windows 10 Home Anaconda3 64-bit(Python3.7) VSCode -Library- Tensorflow 2.1.0 opencv-python 4.1.2.30 -Hardware- CPU: Intel core i9 9900K GPU: NVIDIA GeForce RTX2080ti RAM: 16GB 3200MHz

référence

siteKeras: Visualisez CNN à l'aide de Fashion-MNIST

programme

Je le posterai sur Github. https://github.com/himazin331/CNN-Visualization Le référentiel contient un programme de démonstration (cnn_visual.py), un module de visualisation de la carte des caractéristiques (feature_visual.py) et Inclut un module de visualisation de filtre (filter_visual.py).

Code source

Les parties les moins pertinentes sont omises. ** Veuillez noter que le code est sale ... **

cnn_visual.py


import argparse as arg
import os
import sys

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #Masquer le message TF

import tensorflow as tf
import tensorflow.keras.layers as kl

import numpy as np
import matplotlib.pyplot as plt

import feature_visual
import filter_visual

# CNN
class CNN(tf.keras.Model):
    
    def __init__(self, n_out, input_shape):
        super().__init__()

        self.conv1 = kl.Conv2D(16, 4, activation='relu', input_shape=input_shape)
        self.conv2 = kl.Conv2D(32, 4, activation='relu')
        self.conv3 = kl.Conv2D(64, 4, activation='relu')

        self.mp1 = kl.MaxPool2D((2, 2), padding='same')
        self.mp2 = kl.MaxPool2D((2, 2), padding='same')
        self.mp3 = kl.MaxPool2D((2, 2), padding='same')
   
        self.flt = kl.Flatten()
      
        self.link = kl.Dense(1024, activation='relu')
        self.link_class = kl.Dense(n_out, activation='softmax')

    def call(self, x):   
        
        h1 = self.mp1(self.conv1(x))
        h2 = self.mp2(self.conv2(h1))
        h3 = self.mp3(self.conv3(h2))
        
        h4 = self.link(self.flt(h3))

        return self.link_class(h4)

#Apprentissage
class trainer(object):

    def __init__(self, n_out, input_shape):

        self.model = CNN(n_out, input_shape)
        self.model.compile(optimizer=tf.keras.optimizers.Adam(),
                           loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                           metrics=['accuracy'])
        

    def train(self, train_img, train_lab, batch_size, epochs, input_shape, test_img):
        #Apprentissage
        self.model.fit(train_img, train_lab, batch_size=batch_size, epochs=epochs)
        
        print("___Training finished\n\n")
        
        #Visualisation de la carte des fonctionnalités
        feature_visual.feature_vi(self.model, input_shape, train_img)
        #Visualisation des filtres
        filter_visual.filter_vi(self.model)

def main():

    """
Options de ligne de commande
    """

    #Acquisition d'ensembles de données, prétraitement
    (train_img, train_lab), (test_img, _) = tf.keras.datasets.mnist.load_data()
    train_img = tf.convert_to_tensor(train_img, np.float32)
    train_img /= 255
    train_img = train_img[:, :, :, np.newaxis]

    test_img = tf.convert_to_tensor(test_img, np.float32)
    test_img /= 255
    test_img = train_img[:, :, :, np.newaxis]

    #Commencer à apprendre
    print("___Start training...")

    input_shape = (28, 28, 1)

    Trainer = trainer(10, input_shape)
    Trainer.train(train_img, train_lab, batch_size=args.batch_size,
                epochs=args.epoch, input_shape=input_shape, test_img=test_img)

if __name__ == '__main__':
    main()

Résultat d'exécution

Cette fois, on m'a demandé de saisir les numéros manuscrits du MNIST. Le résultat est 10 Epoch et 256 mini-lots.

Carte des caractéristiques

** Couche pliante 1 ** image.png

** Couche de mise en commun 1 ** image.png

** Couche pliante 2 ** image.png

** Couche de mise en commun 2 ** image.png

filtre

** Couche pliante 1 ** image.png

** Couche pliante 2 ** image.png

** Couche pliante 3 ** Comme l'écran est petit et difficile à voir, il est agrandi et rogné par l'édition. image.png

La description

Je vais expliquer le code associé.

Le modèle de réseau est un CNN avec la structure suivante.

Modèle de réseau


# CNN
class CNN(tf.keras.Model):
    
    def __init__(self, n_out, input_shape):
        super().__init__()

        self.conv1 = kl.Conv2D(16, 4, activation='relu', input_shape=input_shape)
        self.conv2 = kl.Conv2D(32, 4, activation='relu')
        self.conv3 = kl.Conv2D(64, 4, activation='relu')

        self.mp1 = kl.MaxPool2D((2, 2), padding='same')
        self.mp2 = kl.MaxPool2D((2, 2), padding='same')
        self.mp3 = kl.MaxPool2D((2, 2), padding='same')
   
        self.flt = kl.Flatten()
      
        self.link = kl.Dense(1024, activation='relu')
        self.link_class = kl.Dense(n_out, activation='softmax')

    def call(self, x):   
        
        h1 = self.mp1(self.conv1(x))
        h2 = self.mp2(self.conv2(h1))
        h3 = self.mp3(self.conv3(h2))
        
        h4 = self.link(self.flt(h3))

        return self.link_class(h4)

La visualisation de la carte des entités est effectuée dans feature_visual.py.

feature_visual.py


import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #Masquer le message TF

import tensorflow as tf

import numpy as np
import matplotlib.pyplot as plt

#Visualisation de la carte des fonctionnalités
def feature_vi(model, input_shape, test_img):
        
    #Reconstruction du modèle
    x = tf.keras.Input(shape=input_shape)
    model_vi = tf.keras.Model(inputs=x, outputs=model.call(x))
     
    #Sortie de configuration réseau
    model_vi.summary()
    print("")
        
    #Obtenir des informations sur la couche
    feature_vi = []
    feature_vi.append(model_vi.get_layer('input_1'))
    feature_vi.append(model_vi.get_layer('conv2d'))
    feature_vi.append(model_vi.get_layer('max_pooling2d'))
    feature_vi.append(model_vi.get_layer('conv2d_1'))
    feature_vi.append(model_vi.get_layer('max_pooling2d_1'))

    #Extraction de données aléatoires
    idx = int(np.random.randint(0, len(test_img), 1))
    img = test_img[idx]
    img = img[None, :, :, :]

    for i in range(len(feature_vi)-1):
            
        #Acquisition de la carte des caractéristiques
        feature_model = tf.keras.Model(inputs=feature_vi[0].input, outputs=feature_vi[i+1].output)
        feature_map = feature_model.predict(img)
        feature_map = feature_map[0]
        feature = feature_map.shape[2]
            
        #Définition du nom de la fenêtre
        fig = plt.gcf()
        fig.canvas.set_window_title(feature_vi[i+1].name + " feature-map visualization")
            
        #production
        for j in range(feature):
            plt.subplots_adjust(wspace=0.4, hspace=0.8)
            plt.subplot(feature/6 + 1, 6, j+1)
            plt.xticks([])
            plt.yticks([])
            plt.xlabel(f'filter {j}')
            plt.imshow(feature_map[:,:,j])
        plt.show()

Vous ne pouvez pas utiliser le modèle de classe CNN tel quel. En effet, la couche d'entrée n'est pas définie. Dans le cas d'une implémentation dans le modèle SubClassing, ** ajoutez une couche d'entrée au modèle ** comme indiqué ci-dessous.

    #Reconstruction du modèle
    x = tf.keras.Input(shape=input_shape)
    model_vi = tf.keras.Model(inputs=x, outputs=model.call(x))

Ensuite, préparez une liste et ** ajoutez la couche d'entrée et les informations de couche arbitraire à la liste **. Cette fois, je veux voir la sortie de la première couche de pliage, de la première couche de mise en commun maximale, de la deuxième couche de pliage et de la deuxième couche de mise en commun maximale. Décrivez comme suit.

    #Obtenir des informations sur la couche
    feature_vi = []
    feature_vi.append(model_vi.get_layer('input_1'))
    feature_vi.append(model_vi.get_layer('conv2d'))
    feature_vi.append(model_vi.get_layer('max_pooling2d'))
    feature_vi.append(model_vi.get_layer('conv2d_1'))
    feature_vi.append(model_vi.get_layer('max_pooling2d_1'))

Ensuite, préparez les données d'entrée. Les données de test correspondant à l'index sont acquises en utilisant une valeur numérique aléatoire comme index. Puisque la forme des données de test acquises est (28, 28, 1), nous ajouterons la dimension du nombre d'éléments de données.

    #Extraction de données aléatoires
    idx = int(np.random.randint(0, len(test_img), 1))
    img = test_img[idx]
    img = img[None, :, :, :]

Construisez un modèle feature_model avec l'entrée comme couche d'entrée et la sortie comme sortie de chaque couche. Passez ensuite les données d'entrée avec `prédire 'et obtenez la sortie de la couche.

        #Acquisition de la carte des caractéristiques
        feature_model = tf.keras.Model(inputs=feature_vi[0].input, outputs=feature_vi[i+1].output)
        feature_map = feature_model.predict(img)
        feature_map = feature_map[0]
        feature = feature_map.shape[2]

Après cela, tracez la sortie de la couche et répétez-la comme la sortie de la couche suivante.


La visualisation des filtres se fait dans fileter_visual.py.

filter_visual.py


import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' #Masquer le message TF

import tensorflow as tf

import numpy as np
import matplotlib.pyplot as plt

#Visualisation des filtres
def filter_vi(model):
        
    vi_layer = []
        
    #Calque à visualiser
    vi_layer.append(model.get_layer('conv2d'))
    vi_layer.append(model.get_layer('conv2d_1'))
    vi_layer.append(model.get_layer('conv2d_2'))
        
    for i in range(len(vi_layer)):      
            
        #Obtenir un filtre de calque
        target_layer = vi_layer[i].get_weights()[0]
        filter_num = target_layer.shape[3]
            
        #Définition du nom de la fenêtre
        fig = plt.gcf()
        fig.canvas.set_window_title(vi_layer[i].name + " filter visualization")
            
        #production
        for j in range(filter_num):
            plt.subplots_adjust(wspace=0.4, hspace=0.8)
            plt.subplot(filter_num/6 + 1, 6, j+1)
            plt.xticks([])
            plt.yticks([])
            plt.xlabel(f'filter {j}')  
            plt.imshow(target_layer[ :, :, 0, j], cmap="gray") 
        plt.show()

Comme pour la visualisation de la carte des entités, ajoutez à la liste des couches convolutives qui correspondent au filtre que vous souhaitez voir.

    vi_layer = []
        
    #Calque à visualiser
    vi_layer.append(model.get_layer('conv2d'))
    vi_layer.append(model.get_layer('conv2d_1'))
    vi_layer.append(model.get_layer('conv2d_2'))

Obtenez le ** filtre de la couche cible avec get_weights () [0] **. Au fait, vous pouvez obtenir le biais en écrivant get_weights () [1].

La forme du filtre obtenu est (H, W, I_C, O_C). I_C est le nombre de canaux d'entrée et O_C est le nombre de canaux de sortie.

        #Obtenir un filtre de calque
        target_layer = vi_layer[i].get_weights()[0]
        filter_num = target_layer.shape[3]

Après cela, sortez le filtre et répétez comme le filtre suivant.

en conclusion

Je voulais voir la carte des fonctionnalités et les filtres, alors je l'ai recherchée et l'ai mise en œuvre avec divers changements. La carte des fonctionnalités est intéressante à regarder, mais je ne sais pas quel est le filtre, donc ce n'est pas intéressant ~~. Ces dernières années, il semble que l'IA explicable (IA) ait attiré l'attention, mais j'attends avec impatience le moment où les humains pourront comprendre pourquoi de tels filtres peuvent la reconnaître.

Recommended Posts

Visualisation des cartes et des filtres des fonctionnalités CNN (Tensorflow 2.0)
Agrégation et visualisation des nombres accumulés
pix2pix tensorflow2 Enregistrement d'essais et d'erreurs
Visualisation de corrélation entre la quantité de caractéristiques et la variable objective
Collection de recettes comparant les versions 1 et 2 de TensorFlow (partie 1)
Record of TensorFlow mnist Expert Edition (Visualisation de TensorBoard)
Analyse des données financières par pandas et leur visualisation (2)
Analyse des données financières par pandas et leur visualisation (1)
Estimation la plus probable de la moyenne et de la variance avec TensorFlow
Vue d'ensemble et astuces de Seaborn avec visualisation de données statistiques
[Ingénierie de contrôle] Visualisation et analyse du contrôle PID et de la réponse par étapes
Visualisation de la connexion entre le malware et le serveur de rappel
Bibliothèque DNN (Deep Learning): Comparaison de chainer et TensorFlow (1)
Comment visualiser les données par variable explicative et variable objective