[PYTHON] Easy Grad-CAM with pytorch-gradcam

Easily run Grad-CAM with pytorch-gradcam

Grad-CAM is a CNN visualization technique that shows which regions of an image the network relied on when classifying it. This lets you examine the basis for a classification, and in some cases the insights obtained can be put to use elsewhere, for example in marketing.

Below is an example that uses VGG16 to visualize the regions of an image that the model focuses on.

[Figure: screenshot, 2020-04-28 17.31.13.png]

The procedure is implemented roughly as follows.

- Take the gradients of the class score with respect to the output of the final convolutional layer and apply Global Average Pooling to them
- Use the pooled values as the weight of each channel of that layer for the class
- Multiply each channel by its weight and sum over the channels
- Pass the result through the ReLU function
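The steps above can be sketched in a few lines of NumPy. This is a minimal illustration, not the code of pytorch-gradcam: the function name `grad_cam` is hypothetical, and it assumes you have already extracted the feature maps of the final convolutional layer and the gradients of the class score with respect to them. The final scaling to [0, 1] is an extra step added here only to make the map easy to display.

```python
import numpy as np

def grad_cam(activations, gradients):
    """Compute a Grad-CAM map for one image.

    activations: array of shape (C, H, W), output of the last conv layer.
    gradients:   array of shape (C, H, W), d(class score) / d(activations).
    """
    # 1. Global-average-pool the gradients: one weight per channel.
    weights = gradients.mean(axis=(1, 2))                # shape (C,)
    # 2. Weight each channel and sum over channels.
    cam = np.tensordot(weights, activations, axes=1)     # shape (H, W)
    # 3. ReLU keeps only the regions that support the class.
    cam = np.maximum(cam, 0.0)
    # 4. (Added for display) scale to [0, 1], guarding against an all-zero map.
    peak = cam.max()
    return cam / peak if peak > 0 else cam

# Toy example: two channels on a 2x2 feature map.
acts = np.zeros((2, 2, 2))
acts[0, 0, 0] = 1.0   # channel 0 fires at the top-left
acts[1, 1, 1] = 1.0   # channel 1 fires at the bottom-right
grads = np.stack([np.full((2, 2), 2.0), np.full((2, 2), -1.0)])
cam = grad_cam(acts, grads)
```

In the toy example, channel 0 has a positive weight and channel 1 a negative one, so only the top-left location survives the ReLU.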

(Here is the material that I referred to when implementing it myself: Visualization of CNN by Grad-CAM with PyTorch.)

At first I wrote Grad-CAM myself in PyTorch. However, when a model has been trained on multiple GPUs, torch.nn.DataParallel wraps the model and its module hierarchy changes, and with fine-tuned models the module definitions differ from network to network. I can code well enough, but having to adjust the code every time I use a different network is a hassle.

I wondered whether someone had written a library that runs Grad-CAM in PyTorch without the stress, and last week I went looking and found one.

pytorch-gradcam

It can visualize the results of GradCAM and GradCAM ++, and supports alexnet, vgg, resnet, densenet, and squeezenet. Thank you very much!

Moreover, installation is easy: just run pip install pytorch-gradcam!

The source code is below; running it produces the visualization. (In this code, a densenet161 model that has already been trained on my dataset is loaded; it is a 5-class classifier.)

# Basic modules
import glob

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
# Jupyter/IPython magic; remove this line when running as a plain script
%matplotlib inline

# PyTorch modules
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms, datasets
from torch.utils.data.dataset import Subset
import torchvision.models as models
import torch.optim as optim
from torchvision.utils import make_grid, save_image

# Grad-CAM
from gradcam.utils import visualize_cam
from gradcam import GradCAM, GradCAMpp


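device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = models.densenet161(pretrained=True)
# DenseNet's final layer is `classifier` (it has no `fc`); replace it with a 5-class head.
# This must match the architecture that was used when 'trained_model.pt' was saved.
model.classifier = nn.Linear(model.classifier.in_features, 5)
# Wrap in DataParallel because the checkpoint was saved from a DataParallel model
model = torch.nn.DataParallel(model).to(device)
model.load_state_dict(torch.load('trained_model.pt', map_location=device))
model.eval()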

# Grad-CAM
target_layer = model.module.features
gradcam = GradCAM(model, target_layer)
gradcam_pp = GradCAMpp(model, target_layer)

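# Preprocessing: resize and convert to a tensor in [0, 1], then normalize with ImageNet statistics
to_tensor = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

images = []
# Loop over the validation images of one label.
# `config['dataset']` is assumed to hold the dataset root directory (defined elsewhere).
for path in glob.glob("{}/label1/*".format(config['dataset'])):
    img = Image.open(path).convert('RGB')  # force 3 channels
    torch_img = to_tensor(img).to(device)
    normed_torch_img = normalize(torch_img)[None]  # add a batch dimension

    # Grad-CAM
    mask, _ = gradcam(normed_torch_img)
    heatmap, result = visualize_cam(mask, torch_img)

    # Grad-CAM++
    mask_pp, _ = gradcam_pp(normed_torch_img)
    heatmap_pp, result_pp = visualize_cam(mask_pp, torch_img)

    # One row per image: original, both heatmaps, both overlays
    images.extend([torch_img.cpu(), heatmap, heatmap_pp, result, result_pp])
grid_image = make_grid(images, nrow=5)

# View results
transforms.ToPILImage()(grid_image)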

One thing to be careful about is the target layer, which is specified here with target_layer = model.module.features. The layer names that can be used for each network model are documented in utils.py on GitHub, so you can look them up there. The following is a verbatim excerpt from utils.py; the docstring of the finder function for each supported network lists the available target_layer names.

@register_layer_finder('densenet')
def find_densenet_layer(arch, target_layer_name):
    """Find densenet layer to calculate GradCAM and GradCAM++
    Args:
        arch: default torchvision densenet models
        target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below.
            target_layer_name = 'features'
            target_layer_name = 'features_transition1'
            target_layer_name = 'features_transition1_norm'
            target_layer_name = 'features_denseblock2_denselayer12'
            target_layer_name = 'features_denseblock2_denselayer12_norm1'
            target_layer_name = 'features_denseblock2_denselayer12_norm1'
            target_layer_name = 'classifier'
    Return:
        target_layer: found layer. this layer will be hooked to get forward/backward pass information.
    """
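The names in the docstring read as underscore-separated paths through the model's module hierarchy (for example, 'features_transition1_norm' points at features, then transition1, then norm). The sketch below illustrates that reading only; it is not the library's implementation. The helper `resolve_layer` is hypothetical, and `toy_model` is a stand-in built from plain objects rather than real layers.

```python
from types import SimpleNamespace

def resolve_layer(root, target_layer_name):
    """Walk an underscore-separated name down a nested module tree."""
    layer = root
    for part in target_layer_name.split("_"):
        layer = getattr(layer, part)
    return layer

# Stand-in for a DenseNet module tree (plain objects, not real layers).
toy_model = SimpleNamespace(
    features=SimpleNamespace(
        transition1=SimpleNamespace(norm="transition1.norm"),
    ),
    classifier="classifier",
)

found = resolve_layer(toy_model, "features_transition1_norm")
```

With a real torchvision DenseNet, the same path would correspond to model.features.transition1.norm, and that module object is what you would pass as target_layer.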

Also, the reason module appears before features in target_layer = model.module.features is that the model was trained on multiple GPUs with DataParallel, which stores the original model under the attribute module. For more on the pitfalls of multi-GPU training, see "[PyTorch] The point of stumbling on parallel GPU using DataParallel" (/ 12/08 / post-110 /). If your model was not trained with DataParallel, you do not need module; write target_layer = model.features instead.
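The extra module level also shows up in checkpoints: the state_dict of a DataParallel-wrapped model has keys that start with "module.". If you would rather load such a checkpoint into a plain (unwrapped) model, one option is to strip that prefix first. The helper below is a hypothetical sketch that works on any dictionary of keys; the key names in the example are made up for illustration.

```python
from collections import OrderedDict

def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned

# Illustrative keys only; the values would normally be tensors.
saved = OrderedDict([
    ("module.features.conv0.weight", 1),
    ("module.classifier.bias", 2),
])
plain = strip_module_prefix(saved)
```

The result could then be passed to load_state_dict on an unwrapped model, in which case target_layer would be model.features rather than model.module.features.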

Summary

In this post, I introduced a library that makes it easy to run Grad-CAM with PyTorch.
