[PYTHON] Easy Grad-CAM mit Pytorch-Gradcam

Einfach zu bedienende Grad-CAM mit Pytorch-Gradcam

Es gibt eine CNN-Visualisierungstechnologie namens Grad-CAM, mit der visualisiert werden kann, welche Funktionen für die Klassifizierung bei der Klassifizierung von Bildern verwendet werden. Auf diese Weise können Sie die Grundlage der Klassifizierungsregeln berücksichtigen und in einigen Fällen das daraus gewonnene Wissen für das Marketing verwenden.

Unten sehen Sie das Ergebnis der Visualisierung des für ein bestimmtes Bild interessanten Merkmalsbetrags mit VGG16.

スクリーンショット 2020-04-28 17.31.13.png

Dies ist das Implementierungsverfahren, das jedoch in der folgenden Form realisiert wird.

(Hier ist das Material, auf das ich mich bei der Implementierung selbst bezogen habe: Visualisierung von CNN durch Grad-CAM mit PyTorch.)

Anfangs habe ich Pytorch verwendet, um selbst zu codieren, aber wenn es von einer parallelen GPU gelernt würde, würde torch.nn.DataParallel das Modell umbrechen und die hierarchische Struktur des Modells würde sich ändern, oder ich würde es durch Feinabstimmung auswählen. Die Definition des Moduls ist jedoch je nach Netzwerk unterschiedlich ... Ich glaube, ich habe Codierungsfähigkeiten, aber jedes Mal, wenn ich ein anderes Netzwerk verwende, ist es ärgerlicher Schweiß

Ich habe mich gefragt, ob jemand eine Bibliothek erstellt hat, in der Grad-CAM problemlos mit Pytorch ohne Stress betrieben werden kann, und ich habe letzte Woche gefischt.

pytorch-gradcam

Es kann die Ergebnisse von GradCAM und GradCAM ++ visualisieren und unterstützt Alexnet, Vgg, Resnet, Densenet und Squeezenet. Vielen Dank!

** Darüber hinaus ist die Installationsmethode einfach, pip install pytorch-gradcam Mach einfach! ** ** **

Obwohl es sich um den Quellcode handelt, kann er durch Ausführen wie folgt visualisiert werden (Im Quellcode wird das trainierte Modell mit dem Datensatz geladen, der bereits densenet161 verwendet, und es wird zum trainierten Modell der 5-Klassen-Klassifizierung). ..

# Basic Modules
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

# PyTorch Modules
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets
import torchvision.transforms as transforms
from torch.utils.data.dataset import Subset
import torchvision.models as models
import torch.optim as optim
from torchvision.utils import make_grid, save_image

# Grad-CAM
from gradcam.utils import visualize_cam
from gradcam import GradCAM, GradCAMpp


device = torch.device("cuda:0" if torch.cuda.is_available()  else "cpu")
model = models.densenet161(pretrained=True)
model.fc = nn.Linear(2048,5)
model = torch.nn.DataParallel(model).to(device)
model.eval()
model.load_state_dict(torch.load('trained_model.pt'))

# Grad-CAM
target_layer = model.module.features
gradcam = GradCAM(model, target_layer)
gradcam_pp = GradCAMpp(model, target_layer)

images = []
#Angenommen, Sie rufen ein Validierungsdatensatz für ein Etikett auf
for path in glob.glob("{}/label1/*".format(config['dataset'])):
    img = Image.open(path)
    torch_img = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])(img).to(device)
    normed_torch_img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(torch_img)[None]
    mask, _ = gradcam(normed_torch_img)
    heatmap, result = visualize_cam(mask, torch_img)

    mask_pp, _ = gradcam_pp(normed_torch_img)
    heatmap_pp, result_pp = visualize_cam(mask_pp, torch_img)
    
    images.extend([torch_img.cpu(), heatmap, heatmap_pp, result, result_pp])
grid_image = make_grid(images, nrow=5)

#Ergebnisse anzeigen
transforms.ToPILImage()(grid_image)

Es ist erforderlich, die Zielschicht im Teil von target_layer = model.module.features anzugeben, um vorsichtig zu sein. Beziehen Sie sich jedoch auf utils.py von Github und benennen Sie die Zielschicht, die jedem Netzwerkmodell entspricht. Sie können es nachschlagen ---> utils.py. Das Folgende ist ein Auszug aus einem Teil, der so wie er ist in utils.py geschrieben ist, aber der Name von target_layer wird für jedes unterstützte Netzwerk beschrieben.

| @register_layer_finder('densenet') |
|:--|
| def find_densenet_layer(arch, target_layer_name): |
|     """Find densenet layer to calculate GradCAM and GradCAM++ |
|     Args: |
|         arch: default torchvision densenet models |
|         target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below. |
|             target_layer_name = 'features' |
|             target_layer_name = 'features_transition1' |
|             target_layer_name = 'features_transition1_norm' |
|             target_layer_name = 'features_denseblock2_denselayer12' |
|             target_layer_name = 'features_denseblock2_denselayer12_norm1' |
|             target_layer_name = 'features_denseblock2_denselayer12_norm1' |
|             target_layer_name = 'classifier' |
|     Return: |
|         target_layer: found layer. this layer will be hooked to get forward/backward pass information. |
|     """ |

Der Grund, warum das Modul vor dem Feature in target_layer = model.module.features eingeklemmt ist, liegt darin, dass es das trainierte Modell auf der parallelen GPU mit DataParallel verwendet. Wenn Sie mehr wissen möchten, lesen Sie bitte die Stolperpunkte auf einer parallelen GPU [PyTorch] Der Stolperpunkt auf einer parallelen GPU mit DataParallel / 12/08 / post-110 /). Wenn Sie kein Lernmodell in Form von parallelen GPUs erstellt haben, benötigen Sie kein Modul.

Zusammenfassung

Dieses Mal habe ich ein Modul eingeführt, mit dem Grad-CAM mit Pytorch einfach durchgeführt werden kann.

Recommended Posts

Easy Grad-CAM mit Pytorch-Gradcam
Einfaches Debuggen mit ipdb
Einfache TopView mit OpenCV
Einfache toxische Umgebung mit Jenkins
[Analyse des gemeinsamen Auftretens] Einfache Analyse des gemeinsamen Auftretens mit Python! [Python]
Einfache Ordnersynchronisation mit Python
Machen Sie es mit der Syntax einfach
Einfaches Web-Scraping mit Scrapy
Einfache Python-Kompilierung mit NUITKA-Utilities
Einfacher HTTP-Server mit Python
Einfache Proxy-Anmeldung mit Django-Hijack
Einfache Vorhersage von Zeitreihen mit Prophet
[Python] Einfache Parallelverarbeitung mit Joblib
Seien Sie vorsichtig mit einfachen Methodenreferenzen
Einfacher Slackbot mit Docker und Errbot
Einfache Jupyter-Umgebungskonstruktion mit Cloud9
Einfache GUI App mit Tkinter Text
Einfach pyspark mit conda installieren
Einfache Python + OpenCV-Programmierung mit Canopy
Einfache Mailübertragung mit Eile Python3
Score-CAM-Implementierung mit Keras. Vergleich mit Grad-CAM
Bayesianische Optimierung, die mit Python sehr einfach ist
Einfacher AWS S3-Test mit MinIO
Einfache japanische Schrifteinstellung mit matplotlib
Visualisieren Sie Ihre Daten ganz einfach mit Python Seaborn.
Einfach mit Slack mit Bot #NowPlaying
Zeichnen Sie einfach Diagramme mit matplotlib
Einfache parallele Ausführung mit Python-Unterprozess
Einfache Animation mit matplotlib (mp4, gif)
Einfache Bereitstellungsumgebung mit Gaffer + Fabric
Einfache Schlüsselwortextraktion mit TermExtract für Python
[Python] Super einfacher Test mit Assert-Anweisung
[Python] Einfache Überprüfung des Argumenttyps mit Datenklasse
Einfache JSON-Formatierung mit Standard-Linux-Funktionen
Einfache Filterverwaltung mit der Option -m von Python
Machen Sie GUI-Apps mit tkinter ganz einfach
[Einfach] AI automatische Erkennung mit einer Webkamera!
Einfache Einführung der Spracherkennung mit Python
Einfache mehrsprachige C / C ++ - Bindung mit CMake + SWIG
[Easy Python] Lesen von Excel-Dateien mit openpyxl
Einfache Web-App mit Python + Flask + Heroku
Verarbeiten Sie Bilder in Python ganz einfach mit Pillow
[Easy Python] Lesen von Excel-Dateien mit Pandas
Einfach! Verwenden Sie gensim und word2vec mit MAMP.
Einfaches Web-Scraping mit Python und Ruby
[Python] Probieren Sie mit Keras-RL ganz einfach erweitertes Lernen (DQN) aus
Einfache REST-API mit API Gateway / Lambda / DynamoDB