[PYTHON] Easy Grad-CAM avec pytorch-gradcam

Grad-CAM facile à utiliser avec pytorch-gradcam

Il existe une technologie de visualisation CNN appelée Grad-CAM, qui permet de visualiser les fonctionnalités utilisées pour la classification lors de la classification des images. En faisant cela, vous pouvez considérer la base des règles de classification et, dans certains cas, utiliser les connaissances acquises grâce à elles pour le marketing.

Vous trouverez ci-dessous le résultat de la visualisation de la quantité de fonctionnalités d'intérêt pour une certaine image à l'aide de VGG16.

スクリーンショット 2020-04-28 17.31.13.png

C'est la procédure de mise en œuvre, mais elle est réalisée sous la forme suivante.

(Voici le matériel auquel j'ai fait référence lors de sa mise en œuvre moi-même: Visualisation de CNN par Grad-CAM avec PyTorch.)

Au départ, j'utilisais pytorch pour coder moi-même, mais s'il était appris par un GPU parallèle, torch.nn.DataParallel envelopperait le modèle et la structure hiérarchique du modèle changerait, ou je la sélectionnerais par un réglage fin. Cependant, la définition du module est différente selon le réseau ... je pense que j'ai des capacités de codage, mais à chaque fois que j'utilise un réseau différent, c'est une sueur agaçante

Je me demandais si quelqu'un avait créé une bibliothèque qui pourrait facilement exécuter Grad-CAM avec pytorch sans ressentir de stress, et je pêchais la semaine dernière.

pytorch-gradcam

Il peut visualiser les résultats de GradCAM et GradCAM ++, et prend en charge alexnet, vgg, resnet, densenet et squeezenet. Merci beaucoup!

** De plus, la méthode d'installation est simple, pip install pytorch-gradcam Faites juste! ** **

Bien qu'il s'agisse du code source, il peut être visualisé en exécutant comme suit (dans le code source, le modèle entraîné est chargé avec l'ensemble de données utilisant déjà densenet161, et il devient le modèle entraîné de la classification en 5 classes) ..

# Basic Modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

# PyTorch Modules
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets
import torchvision.transforms as transforms
from torch.utils.data.dataset import Subset
import torchvision.models as models
import torch.optim as optim
from torchvision.utils import make_grid, save_image

# Grad-CAM
from gradcam.utils import visualize_cam
from gradcam import GradCAM, GradCAMpp


device = torch.device("cuda:0" if torch.cuda.is_available()  else "cpu")
model = models.densenet161(pretrained=True)
model.fc = nn.Linear(2048,5)
model = torch.nn.DataParallel(model).to(device)
model.eval()
model.load_state_dict(torch.load('trained_model.pt'))

# Grad-CAM
target_layer = model.module.features
gradcam = GradCAM(model, target_layer)
gradcam_pp = GradCAMpp(model, target_layer)

images = []
#En supposant que vous appelez un ensemble de données de validation pour une étiquette
for path in glob.glob("{}/label1/*".format(config['dataset'])):
    img = Image.open(path)
    torch_img = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])(img).to(device)
    normed_torch_img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(torch_img)[None]
    mask, _ = gradcam(normed_torch_img)
    heatmap, result = visualize_cam(mask, torch_img)

    mask_pp, _ = gradcam_pp(normed_torch_img)
    heatmap_pp, result_pp = visualize_cam(mask_pp, torch_img)
    
    images.extend([torch_img.cpu(), heatmap, heatmap_pp, result, result_pp])
grid_image = make_grid(images, nrow=5)

#Voir les résultats
transforms.ToPILImage()(grid_image)

Il est nécessaire de spécifier la couche cible dans la partie de target_layer = model.module.features pour être prudent, mais reportez-vous à utils.py de Github et nommez la cible_layer correspondant à chaque modèle de réseau. Vous pouvez le rechercher ---> utils.py. Ce qui suit est un extrait d'une partie écrite dans utils.py telle quelle, mais le nom de target_layer est décrit pour chaque réseau pris en charge.

| @register_layer_finder('densenet') |
|:--|
| def find_densenet_layer(arch, target_layer_name): |
|     """Find densenet layer to calculate GradCAM and GradCAM++ |
|     Args: |
|         arch: default torchvision densenet models |
|         target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below. |
|             target_layer_name = 'features' |
|             target_layer_name = 'features_transition1' |
|             target_layer_name = 'features_transition1_norm' |
|             target_layer_name = 'features_denseblock2_denselayer12' |
|             target_layer_name = 'features_denseblock2_denselayer12_norm1' |
|             target_layer_name = 'features_denseblock2_denselayer12_norm1' |
|             target_layer_name = 'classifier' |
|     Return: |
|         target_layer: found layer. this layer will be hooked to get forward/backward pass information. |
|     """ |

De plus, la raison pour laquelle le module est pris en sandwich avant la fonctionnalité dans target_layer = model.module.features est parce qu'il utilise le modèle entraîné sur le GPU parallèle utilisant DataParallel. Si vous voulez en savoir plus, veuillez vous référer ici pour les points de trébuchement sur le GPU parallèle [PyTorch] Le point de trébucher sur le GPU parallèle en utilisant DataParallel / 12/08 / post-110 /). Si vous n'avez pas créé de modèle d'apprentissage sous la forme de GPU parallèles, vous n'avez pas besoin d'un module.

Résumé

Cette fois, j'ai introduit un module pour réaliser facilement Grad-CAM avec pytorch.

Recommended Posts

Easy Grad-CAM avec pytorch-gradcam
Débogage facile avec ipdb
TopView facile avec OpenCV
Environnement toxique facile avec Jenkins
[Analyse de co-occurrence] Analyse de co-occurrence facile avec Python! [Python]
Synchronisation facile des dossiers avec Python
Rendre avec la syntaxe facile
Grattage Web facile avec Scrapy
Compilation facile de Python avec NUITKA-Utilities
Serveur HTTP facile avec Python
Connexion proxy facile avec django-hijack
Prédiction de séries chronologiques facile avec Prophet
[Python] Traitement parallèle facile avec Joblib
Soyez prudent avec les références de méthodes faciles
Easy Slackbot avec Docker et Errbot
Construction d'un environnement Jupyter facile avec Cloud9
Application GUI facile avec Tkinter Text
Installez facilement pyspark avec conda
Programmation facile Python + OpenCV avec Canopy
Transmission de courrier facile avec Hâte Python3
Implémentation Score-CAM avec keras. Comparaison avec Grad-CAM
Optimisation bayésienne très simple avec Python
Tests faciles d'AWS S3 avec MinIO
Réglage facile de la police japonaise avec matplotlib
Visualisez facilement vos données avec Python seaborn.
Facile avec Slack en utilisant Bot #NowPlaying
Dessinez facilement des graphiques avec matplotlib
Exécution parallèle facile avec le sous-processus python
Animation facile avec matplotlib (mp4, gif)
Environnement de déploiement facile avec gaffer + tissu
Extraction de mots-clés facile avec TermExtract pour Python
[Python] Test super facile avec instruction assert
[Python] Vérification simple du type d'argument avec la classe de données
Formatage JSON facile avec les fonctions Linux standard
Gestion facile des filtres avec l'option -m de Python
Rendez les applications GUI super faciles avec tkinter
[Facile] Reconnaissance automatique AI avec une webcam!
Introduction facile de la reconnaissance vocale avec Python
Liaison multilingue C / C ++ facile avec CMake + SWIG
[Easy Python] Lecture de fichiers Excel avec openpyxl
Application Web facile avec Python + Flask + Heroku
Traitez facilement des images en Python avec Pillow
[Easy Python] Lecture de fichiers Excel avec des pandas
Facile! Utilisez gensim et word2vec avec MAMP.
Scraping Web facile avec Python et Ruby
[Python] Essayez facilement l'apprentissage amélioré (DQN) avec Keras-RL
API REST facile avec API Gateway / Lambda / DynamoDB