[Python] I made an AI that nicely crops images using a saliency map

Introduction

In this article, I implement a method for cropping images with a saliency map. While reading the relevant papers, I estimate the saliency map with deep learning in Python/PyTorch.

When people talk about deep learning on images, the usual examples are classifying handwritten digits or detecting people. I hope this article shows that deep learning can do this kind of thing too.

This article is part of the DeNA 20 New Graduate Advent Calendar 2019 on Qiita. Thanks to the Advent calendar for giving me the opportunity to build this!

Intended readers

Advent calendars cover many genres, so the article itself is written for anyone who has done some programming. To actually run the code, I assume you have worked through something like a deep learning tutorial.

The code is written for Jupyter Notebook, so you can easily try it yourself. The code blocks are collapsed; click to expand them when needed.

I only use libraries that come preinstalled on Google Colaboratory. Because the dataset is large, getting all the way through training may take some effort.

Image cropping

Sometimes we want to crop an image to a particular shape. For example, icon images are usually roughly square, so most people have wondered how to crop a picture when signing up for some service. Header images are often somewhat wide, and the required shape is frequently decided case by case. When users crop images to a fixed shape themselves, they can make the result look good, but in many cases the application has to do the cropping automatically.

A simple example

Suppose a page must always display posted images in a tall 1:3 (width:height) format. I chose a tall shape because it seems like a hard condition to crop for.

Take this photo, which I shot thinking "what a nice lobby with a Christmas tree." If I cropped it by hand, I would naturally frame it like this to show the Christmas tree.

However, nobody can look at and crop every image in a large volume of posts, so the process has to be automated. Cropping the center seemed like a safe choice, so I implemented that in Python.

Click here to see the implementation
import numpy as np
import cv2
import matplotlib.pyplot as plt

def crop(image, aspect_rate=(1, 1)):
    """
    Crop the image around its center so that it has the specified aspect ratio.

    Parameters:
    -----------------
    image : ndarray, (h, w, rgb), uint8
    aspect_rate : tuple of int (x, y)
        default : (1, 1)

    Returns:
    -----------------
    cropped_image : ndarray, (h, w, rgb), uint8
    """
    assert image.dtype==np.uint8
    assert image.ndim==3

    im_size = (image.shape[1], image.shape[0]) # tuple of int, (width, height)
    center = (int(round(im_size[0]/2)), int(round(im_size[1]/2))) # tuple of int, (x, y)

    # Compute the following four values
    # box_x : int, x coordinate of the top-left corner of the crop
    # box_y : int, y coordinate of the top-left corner of the crop
    # box_width : int, width of the crop
    # box_height : int, height of the crop
    if im_size[0]>im_size[1]:
        # Landscape: first try using the full height (the short side)
        box_y = 0
        box_height = im_size[1]
        box_width = int(round((im_size[1]/aspect_rate[1])*aspect_rate[0]))
        if box_width>im_size[0]:
            # Too wide: use the full width instead and center vertically
            box_x = 0
            box_width = im_size[0]
            box_height = int(round((im_size[0]/aspect_rate[0])*aspect_rate[1]))
            box_y = int(round(center[1]-(box_height/2)))
        else:
            box_x = int(round(center[0]-(box_width/2)))
    else:
        # Portrait or square: first try using the full width (the short side)
        box_x = 0
        box_width = im_size[0]
        box_height = int(round((im_size[0]/aspect_rate[0])*aspect_rate[1]))
        if box_height>im_size[1]:
            # Too tall: use the full height instead and center horizontally
            box_y = 0
            box_height = im_size[1]
            box_width = int(round((im_size[1]/aspect_rate[1])*aspect_rate[0]))
            box_x = int(round(center[0]-(box_width/2)))
        else:
            box_y = int(round(center[1]-(box_height/2)))

    cropped_image = image[box_y:box_y+box_height, box_x:box_x+box_width]
    return cropped_image

# image: load the image with OpenCV etc. as a NumPy array (BGR -> RGB)
image = cv2.imread("tree.jpg")[:, :, ::-1]
cropped_image = crop(image, aspect_rate=(1, 3))
plt.imshow(cropped_image)
plt.show()

The function first uses the image's short side in full and computes the other side from the aspect ratio. If that length exceeds the image, it uses the long side in full instead and computes the short side from the aspect ratio.

Let's try it.

The Christmas element is gone, and what's left is just a photo of a nice lobby. That's no good. Let's fix this with the power of AI.

Prior work to refer to

This time, I will imitate the saliency-map-based approach that Twitter and Adobe have introduced over the past two years. On Twitter [^1], posted images are cropped nicely when displayed on the timeline. Adobe InDesign has a feature called Content-Aware Fit that crops an image to fit a specified frame.

[^1]: Introducing a neural network that optimally and automatically crops images https://blog.twitter.com/ja_jp/topics/product/2018/0125ML-CR.html

Object detection could serve as an alternative. However, the saliency-map-based method is more versatile, because the region worth showing is not always an object with one of the trained labels.

Cropping using a saliency map

A cropping method using a saliency map was proposed in 2013 in Ardizzone et al.'s paper "Saliency Based Image Cropping" [^2].

What is a saliency map?

When a person looks at an image, where does their gaze go? A **saliency map** expresses this per pixel. For example, the lower left of the figure shows a saliency map derived from gaze measurements collected from many people; a model's job is to compute such a map. The whiter a region, the higher the probability that a gaze point falls there, and the blacker, the lower.

Figure: Example of a saliency map. Upper left: image. Upper right: measured gaze points, marked with red X. Lower left: saliency map. Lower right: saliency map colored and overlaid on the image.

This figure visualizes training data from the SALICON dataset [^3]. The red X marks in the upper right are gaze points collected by having many people view the upper-left image and point with the mouse cursor at where they were looking.

Applying a Gaussian filter to those points yields a map giving, for each pixel, the probability (0 to 1) that a gaze point falls there, as shown in the lower left. This is the ground-truth saliency map we want to predict.
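As a rough illustration of that step, here is a NumPy-only sketch that turns a list of fixation points into a saliency map with a separable Gaussian blur. The helper name fixations_to_saliency, the default sigma, and the choice to normalize so the peak equals 1 are my own assumptions for illustration; SALICON's actual preprocessing may differ.

```python
import numpy as np


def fixations_to_saliency(points, height, width, sigma=5.0):
    """Hypothetical helper: blur (x, y) fixation points into a map in [0, 1].

    Assumes the image is larger than the Gaussian kernel (6 * sigma + 1 pixels).
    """
    # Count how many fixations land on each pixel
    fixation_map = np.zeros((height, width), dtype=np.float64)
    for x, y in points:
        fixation_map[y, x] += 1.0

    # 1-D Gaussian kernel; a 2-D Gaussian blur separates into two 1-D passes
    radius = int(3 * sigma)
    assert min(height, width) > 2 * radius, "image too small for this sigma"
    t = np.arange(-radius, radius + 1)
    kernel = np.exp(-(t ** 2) / (2 * sigma ** 2))
    kernel /= kernel.sum()

    # Blur along each row (x direction), then along each column (y direction)
    blurred = np.apply_along_axis(np.convolve, 1, fixation_map, kernel, mode="same")
    blurred = np.apply_along_axis(np.convolve, 0, blurred, kernel, mode="same")

    # Rescale so the most-looked-at pixel is 1 (assumed normalization)
    if blurred.max() > 0:
        blurred /= blurred.max()
    return blurred


saliency = fixations_to_saliency([(50, 40)], height=100, width=120, sigma=5.0)
print(saliency.shape)  # (100, 120)
```

With more fixation points, pixels where several points overlap get higher values, which matches the idea of "many people looked here".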

Coloring the map and overlaying it on the image, as in the lower right, shows that the cat is very likely to draw the eye. Gaze probabilities close to 1 are shown in red and those close to 0 in blue.

Implementing Ardizzone's method

Let's implement Ardizzone's method. For now, we use the SALICON training data directly as the saliency map. Below are the cat image and its ground-truth saliency map.

How does it work?

The method crops the image so that it includes every pixel whose gaze probability exceeds a threshold. In other words, it keeps only the regions people are likely to look at.

Figure: Pipeline of Ardizzone's method (quoted from the paper [^2])

To summarize this figure in words, there are the following three steps.

- Binarize the saliency map with a threshold (pixels above it become 1, the rest 0).
- Find the bounding box that encloses all pixels equal to 1.
- Crop the original image with that bounding box.
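To make the three steps concrete before bringing in OpenCV, here is a NumPy-only sketch on a tiny synthetic map. The function ardizzone_crop_box is hypothetical; it finds the bounding box with np.nonzero instead of cv2.boundingRect, but returns the same (x, y, width, height) convention used below. It omits the small-region cleanup that the later class adds.

```python
import numpy as np


def ardizzone_crop_box(saliencymap, threshold=0.3):
    """Hypothetical helper: bounding box (x, y, w, h) of pixels above threshold.

    saliencymap: 2-D array with values in [0, 1].
    Returns None if no pixel exceeds the threshold.
    """
    # Step 1: binarize (True where the gaze probability exceeds the threshold)
    binary = saliencymap > threshold
    if not binary.any():
        return None
    # Step 2: rows and columns that contain at least one True pixel
    ys = np.nonzero(binary.any(axis=1))[0]
    xs = np.nonzero(binary.any(axis=0))[0]
    x, y = int(xs[0]), int(ys[0])
    return x, y, int(xs[-1]) - x + 1, int(ys[-1]) - y + 1


# Toy 6x8 saliency map with one bright blob at rows 1-2, columns 2-4
smap = np.zeros((6, 8))
smap[1:3, 2:5] = 0.9
image = np.arange(6 * 8 * 3, dtype=np.uint8).reshape(6, 8, 3)

box_x, box_y, box_w, box_h = ardizzone_crop_box(smap)
# Step 3: crop the original image with the bounding box
cropped = image[box_y:box_y + box_h, box_x:box_x + box_w]
print((box_x, box_y, box_w, box_h), cropped.shape)  # (2, 1, 3, 2) (2, 3, 3)
```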

Binarize

NumPy makes binarization easy. Comparison operators such as > and == are broadcast element-wise, so evaluating ndarray > float yields True or False for every element, and that completes the binarization.

Click here to see the implementation
threshhold = 0.3 # Threshold, float (0<threshhold<1)
saliencymap_path = 'COCO_train2014_000000196971.png' # Saliency map path

saliencymap = cv2.imread(saliencymap_path)[:, :, ::-1] # ndarray, (h, w, rgb), np.uint8 (0-255)
saliencymap = saliencymap[:, :, 0] # ndarray, (h, w), np.uint8 (0-255)
plt.imshow(saliencymap)
plt.show()

threshhold *= 255 # The saliency map read from the image is in 0-255, so rescale the threshold to match

binarized_saliencymap = saliencymap>threshhold # ndarray, (h, w), bool

plt.imshow(binarized_saliencymap)
plt.show()

Figure: Results of binarization

The figure shows the result. By default, matplotlib's plt.imshow() shows large values in yellow and small values in purple.

The threshold is a hyperparameter you can choose freely; this article uses 0.3 throughout.

Finding the bounding box

Next, compute the **bounding box** (the smallest rectangle that encloses them) containing all the 1s (True values) obtained by binarization.

OpenCV provides this as cv2.boundingRect(), so a single call is enough.

Structural Analysis and Shape Descriptors — OpenCV 2.4.13.7 documentation

[Area (contour) features — OpenCV-Python Tutorials 1 documentation](http://labs.eecs.tottori-u.ac.jp/sd/Member/oyamada/OpenCV/html/py_tutorials/py_imgproc/py_contours/py_contour_features/py_contour_features.html)

To draw rectangles in matplotlib, use patches.Rectangle().

matplotlib.patches.Rectangle — Matplotlib 3.1.1 documentation

Click here to see the implementation
import matplotlib.patches as patches

# Convert to a format that OpenCV can handle
binarized_saliencymap = binarized_saliencymap.astype(np.uint8) # ndarray, (h, w), np.uint8 (0 or 1)

box_x, box_y, box_width, box_height = cv2.boundingRect(binarized_saliencymap)
# box_x : int, x coordinate of the top-left corner of the crop
# box_y : int, y coordinate of the top-left corner of the crop
# box_width : int, width of the crop
# box_height : int, height of the crop

#Bounding box drawing
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#00FF00', fill=False)
ax.imshow(binarized_saliencymap)
ax.add_patch(bounding_box)
plt.show()

Figure: Result of finding the bounding box

This gives the bounding box shown in the figure. The rectangle is represented by its top-left coordinates plus its width and height.

Cropping

Crop the image with the bounding box we obtained by slicing the image's ndarray with the bounding box values.

Click here to see the implementation
image_path = 'COCO_train2014_000000196971.jpg' #Image path
image = cv2.imread(image_path)[:, :, ::-1] # ndarray, (h, w, rgb), np.uint8 (0-255)

cropped_image = image[box_y:box_y+box_height, box_x:box_x+box_width] # ndarray, (h, w, rgb), np.uint8 (0-255)

#Visualization
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#00FF00', fill=False)
ax.imshow(image)
ax.add_patch(bounding_box)
plt.show()

plt.imshow(cropped_image)
plt.show()

Figure: Result of cropping with Ardizzone's method

As the figure shows, we obtain the region of the image that people are likely to look at.

Overlaying the colored saliency map

To make the processing easier to follow, let's overlay the colored saliency map and the bounding box on the image. We implement one function that colors the saliency map and another that overlays it on an image.

Click here to see the implementation
def color_saliencymap(saliencymap):
    """
Color and visualize the Salience Map. Set 1 to red and 0 to blue.
    
    Parameters
    ----------------
    saliencymap : ndarray, np.uint8, (h, w) or (h, w, rgb)
    
    Returns
    ----------------
    saliencymap_colored : ndarray, np.uint8, (h, w, rgb)
    """
    assert saliencymap.dtype==np.uint8
    assert (saliencymap.ndim == 2) or (saliencymap.ndim == 3)
    
    saliencymap_colored = cv2.applyColorMap(saliencymap, cv2.COLORMAP_JET)[:, :, ::-1]
    
    return saliencymap_colored

def overlay_saliencymap_and_image(saliencymap_color, image):
    """
Overlay the image with Salience Map.
    
    Parameters
    ----------------
    saliencymap_color : ndarray, (h, w, rgb), np.uint8
    image : ndarray, (h, w, rgb), np.uint8
    
    Returns
    ----------------
    overlaid_image : ndarray(h, w, rgb)
    """
    assert saliencymap_color.ndim==3
    assert saliencymap_color.dtype==np.uint8
    assert image.ndim==3
    assert image.dtype==np.uint8
    im_size = (image.shape[1], image.shape[0])
    saliencymap_color = cv2.resize(saliencymap_color, im_size, interpolation=cv2.INTER_CUBIC)
    overlaid_image = cv2.addWeighted(src1=image, alpha=1, src2=saliencymap_color, beta=0.7, gamma=0)
    return overlaid_image

saliencymap_colored = color_saliencymap(saliencymap) # ndarray, (h, w, rgb), np.uint8
overlaid_image = overlay_saliencymap_and_image(saliencymap_colored, image) # ndarray, (h, w, rgb), np.uint8

#Visualization
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#00FF00', fill=False)
ax.imshow(overlaid_image)
ax.add_patch(bounding_box)
plt.show()

Figure: Colorized saliency map and bounding box overlaid on the image

As the figure shows, the box encloses the regions of the saliency map with high probability, which appear in red.

Supporting any aspect ratio

With Ardizzone's method, the size and aspect ratio of the crop depend on the saliency map. Since I want to crop to a fixed aspect ratio, I need to handle that separately.

Cropping to maximize the total saliency

I couldn't find an existing method for this, so I decided to use the following algorithm to determine the range to cut out.

- Keep the entire region found by Ardizzone's method, and extend it in one direction until it has the specified aspect ratio.
- If the extended region would stick out of the image, use the entire image along that direction and narrow the other direction instead to match the aspect ratio.
- When narrowing, choose the range inside the region from Ardizzone's method that maximizes the total saliency.
- When extending, choose, among ranges that contain the whole region, the one that maximizes the total saliency.

In short, the crop keeps as much of the region found so far as possible while maximizing the sum of the saliency map values.
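The core of both the extending and narrowing cases is a 1-D search: sum the saliency map along one axis to get a profile, then slide a window of the target length and take the position with the largest sum. The minimal sketch below shows only the extending case along the x axis. The function best_window_containing and the toy data are hypothetical; this is a simplified version of the idea, not a replacement for the SaliencyBasedImageCropping class that follows.

```python
import numpy as np


def best_window_containing(profile, box_start, box_len, window_len):
    """Hypothetical helper: start index of the window of length window_len
    that contains [box_start, box_start + box_len) and has the largest sum.

    Assumes box_len <= window_len <= len(profile).
    """
    # Window starts that keep the whole box inside and stay within the profile
    lo = max(0, box_start + box_len - window_len)
    hi = min(len(profile) - window_len, box_start)
    # Sum of every window of length window_len
    sums = np.convolve(profile, np.ones(window_len), mode="valid")
    return lo + int(np.argmax(sums[lo:hi + 1]))


# Toy saliency map: a salient box at x = 4..5, plus extra saliency at x = 8
smap = np.zeros((4, 12))
smap[1:3, 4:6] = 1.0
smap[:, 8] = 0.5
profile = smap.sum(axis=0)  # column sums -> saliency profile along x
print(best_window_containing(profile, box_start=4, box_len=2, window_len=5))  # 4
```

The window starting at x = 4 is chosen because it keeps the salient box and also reaches the extra saliency at x = 8.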

I created a "SaliencyBasedImageCropping" class for cropping; the code so far is collected below.

Click here to see the implementation
import copy

import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

class SaliencyBasedImageCropping:
    """
A class for cropping images using Salience Map. A method that uses the entire range that exceeds a certain threshold[1]To use.
    
* If no pixels exceed the threshold, the entire image is returned.

    [1] Ardizzone, Edoardo, Alessandro Bruno, and Giuseppe Mazzola. "Saliency based image cropping." International Conference on Image Analysis and Processing. Springer, Berlin, Heidelberg, 2013.
        
    Parameters
    ----------------
    aspect_rate : tuple of int (x, y)
If you specify the aspect ratio here,[1]Find the range that maximizes the total value of the Salience Map while using the range obtained by the method of.
    min_size : tuple of int (w, h)
        [1]If each axis of the range obtained by the method of is smaller than this value, the range is evenly expanded starting from the center of the range.
    
    Attributes
    ----------------
    self.aspect_rate : tuple of int (x, y)
    self.min_size : tuple of int (w, h)
    im_size : tuple of int (w, h)
    self.bounding_box_based_on_binary_saliency : list
        [1]Range obtained by the method of
        box_x : int
        box_y : int
        box_width : int
        box_height : int
    self.bounding_box : list
The final cropping range with the adjusted aspect ratio
        box_x : int
        box_y : int
        box_width : int
        box_height : int
    """
    def __init__(self, aspect_rate=None, min_size=(200, 200)):
        assert (aspect_rate is None)or((type(aspect_rate)==tuple)and(len(aspect_rate)==2))
        assert (type(min_size)==tuple)and(len(min_size)==2)
        self.aspect_rate = aspect_rate
        self.min_size = min_size
        self.im_size = None
        self.bounding_box_based_on_binary_saliency = None
        self.bounding_box = None
    
    def _compute_bounding_box_based_on_binary_saliency(self, saliencymap, threshhold):
        """
Ardizzone's method[1]Find the cropping range based on the Salience Map with.
        
        Parameters:
        -----------------
        saliencymap : ndarray, (h, w), np.uint8
            0<=saliencymap<=255
            
        threshhold : float
            0<threshhold<255
        
        Returns:
        -----------------
        bounding_box_based_on_binary_saliency : list
            box_x : int
            box_y : int
            box_width : int
            box_height : int
        
        """
        assert (threshhold>0)and(threshhold<255)
        assert saliencymap.dtype==np.uint8
        assert saliencymap.ndim==2
        
        binarized_saliencymap = saliencymap>threshhold
        # If no pixel in the saliency map exceeds the threshold, treat the whole image as salient
        if not binarized_saliencymap.any():
            binarized_saliencymap[:] = True
        binarized_saliencymap = (binarized_saliencymap.astype(np.uint8))*255
        # binarized_saliencymap : ndarray, (h, w), uint8, 0 or 255

        # Erase small regions with a morphological opening (kernel at least 1 pixel)
        kernel_size = max(1, round(min(self.im_size)*0.02))
        kernel = np.ones((kernel_size, kernel_size), np.uint8)
        binarized_saliencymap = cv2.morphologyEx(binarized_saliencymap, cv2.MORPH_OPEN, kernel)

        box_x, box_y, box_width, box_height = cv2.boundingRect(binarized_saliencymap)
        bounding_box_based_on_binary_saliency = [box_x, box_y, box_width, box_height]
        return bounding_box_based_on_binary_saliency
        
    def _expand_small_bounding_box_to_minimum_size(self, bounding_box):
        """
If the range is smaller than the specified size, widen it. Spread the range evenly starting from the center of the range. If it goes out of the image, spread it to the opposite side.
        
        Parameters:
        -----------------
        bounding_box : list
            box_x : int
            box_y : int
            box_width : int
            box_height : int
        """
        bounding_box = copy.copy(bounding_box) # A shallow copy suffices for a list of ints and keeps the caller's list unchanged

        # axis=0 : x and width, axis=1 : y and height
        for axis in range(2):
            if bounding_box[axis+2]<self.min_size[axis+0]:
                # Integer floor division (np.int no longer exists in recent NumPy)
                bounding_box[axis+0] -= (self.min_size[axis+0]-bounding_box[axis+2])//2
                bounding_box[axis+2] = self.min_size[axis+0]
                if bounding_box[axis+0]<0:
                    bounding_box[axis+0] = 0
                if (bounding_box[axis+0]+bounding_box[axis+2])>self.im_size[axis+0]:
                    bounding_box[axis+0] -= (bounding_box[axis+0]+bounding_box[axis+2]) - self.im_size[axis+0]
        return bounding_box
    
    def _expand_bounding_box_to_specified_aspect_ratio(self, bounding_box, saliencymap):
        """
Expand the range so that it has the specified aspect ratio.
Ardizzone's method[1]Find the range that maximizes the total value of the Salience Map while using the range obtained in step 2 as much as possible.
        
        Parameters
        ----------------
        bounding_box : list
            box_x : int
            box_y : int
            box_width : int
            box_height : int
        saliencymap : ndarray, (h, w), np.uint8
            0<=saliencymap<=255
        """
        assert saliencymap.dtype==np.uint8
        assert saliencymap.ndim==2
        
        bounding_box = copy.copy(bounding_box)

        # axis=0 : x and width, axis=1 : y and height
        if bounding_box[2]>bounding_box[3]:
            long_length_axis = 0
            short_length_axis = 1
        else:
            long_length_axis = 1
            short_length_axis = 0
        
        #In which direction to stretch
        rate1 = self.aspect_rate[long_length_axis]/self.aspect_rate[short_length_axis]
        rate2 = bounding_box[2+long_length_axis]/bounding_box[2+short_length_axis]
        if rate1>rate2:
            moved_axis = long_length_axis
            fixed_axis = short_length_axis
        else:
            moved_axis = short_length_axis
            fixed_axis = long_length_axis
        
        fixed_length = bounding_box[2+fixed_axis]
        moved_length = int(round((fixed_length/self.aspect_rate[fixed_axis])*self.aspect_rate[moved_axis]))
        if moved_length > self.im_size[moved_axis]:
            #When the size of the image is exceeded when stretched
            moved_axis, fixed_axis = fixed_axis, moved_axis
            fixed_length = self.im_size[fixed_axis]
            moved_length = int(round((fixed_length/self.aspect_rate[fixed_axis])*self.aspect_rate[moved_axis]))
            fixed_point = 0
            start_point = bounding_box[moved_axis]
            end_point = bounding_box[moved_axis]+bounding_box[2+moved_axis]
            if fixed_axis==0:
                saliencymap_extracted = saliencymap[start_point:end_point, :]
            elif fixed_axis==1:
                saliencymap_extracted = saliencymap[:, start_point:end_point:]
        else:   
            #When stretched to fit within the size of the image
            start_point = int(bounding_box[moved_axis]+bounding_box[2+moved_axis]-moved_length)
            if start_point<0:
                start_point = 0
            end_point = int(bounding_box[moved_axis]+moved_length)
            if end_point>self.im_size[moved_axis]:
                end_point = self.im_size[moved_axis]
            if fixed_axis==0:
                fixed_point = bounding_box[fixed_axis]
                saliencymap_extracted = saliencymap[start_point:end_point, fixed_point:fixed_point+fixed_length]
            elif fixed_axis==1:
                fixed_point = bounding_box[fixed_axis]
                saliencymap_extracted = saliencymap[fixed_point:fixed_point+fixed_length, start_point:end_point]
        saliencymap_summed_1d = saliencymap_extracted.sum(moved_axis)
        self.saliencymap_summed_slided = np.convolve(saliencymap_summed_1d, np.ones(moved_length), 'valid')
        moved_point = np.array(self.saliencymap_summed_slided).argmax() + start_point
        
        if fixed_axis==0:
            bounding_box = [fixed_point, moved_point, fixed_length, moved_length]
        elif fixed_axis==1:
            bounding_box = [moved_point, fixed_point, moved_length, fixed_length]
        return bounding_box
    
    def crop_center(self, image):
        """     
Crop the center of the image with the specified aspect ratio without using Salience Map.
        
        Parameters:
        -----------------
        image : ndarray, (h, w, rgb), uint8
            
        Returns:
        -----------------
        cropped_image : ndarray, (h, w, rgb), uint8
        """        
        assert image.dtype==np.uint8
        assert image.ndim==3        
        
        im_size = (image.shape[1], image.shape[0]) # tuple of int, (width, height)
        center = (int(round(im_size[0]/2)), int(round(im_size[1]/2))) # tuple of int, (x, y)
        
        if im_size[0]>im_size[1]:
            box_y = 0
            box_height = im_size[1]
            box_width = int(round((im_size[1]/self.aspect_rate[1])*self.aspect_rate[0]))
            if box_width>im_size[0]:
                box_x = 0
                box_width = im_size[0]
                box_height = int(round((im_size[0]/self.aspect_rate[0])*self.aspect_rate[1]))
                box_y = int(round(center[1]-(box_height/2)))
            else:
                box_x = int(round(center[0]-(box_width/2)))

        else:
            box_x = 0
            box_width = im_size[0]
            box_height = int(round((im_size[0]/self.aspect_rate[0])*self.aspect_rate[1]))
            if box_height>im_size[1]:
                box_y = 0
                box_height = im_size[1]
                box_width = int(round((im_size[1]/self.aspect_rate[1])*self.aspect_rate[0]))
                box_x = int(round(center[0]-(box_width/2)))
            else:
                box_y = int(round(center[1]-(box_height/2)))
        
        cropped_image = image[box_y:box_y+box_height, box_x:box_x+box_width]
        return cropped_image
    
    def crop(self, image, saliencymap, threshhold=0.3):
        """     
Cropping using Salience Map.
        
        Parameters:
        -----------------
        image : ndarray, (h, w, rgb), np.uint8
        saliencymap : ndarray, (h, w), np.uint8
            Saliency map's ndarray need not be the same size as image's ndarray. Saliency map is resized within this method.
        threshhold : float
            0 < threshhold <1
            
        Returns:
        -----------------
        cropped_image : ndarray, (h, w, rgb), uint8
        """
        assert (threshhold>0)and(threshhold<1)
        assert image.dtype==np.uint8
        assert image.ndim==3
        assert saliencymap.dtype==np.uint8
        assert saliencymap.ndim==2
        
        threshhold = threshhold*255 # scale to 0 - 255
        self.im_size = (image.shape[1], image.shape[0]) # (width, height)
        saliencymap = cv2.resize(saliencymap, self.im_size, interpolation=cv2.INTER_CUBIC)

        # compute bounding box based on saliency map
        bounding_box_based_on_binary_saliency = self._compute_bounding_box_based_on_binary_saliency(saliencymap, threshhold)
        bounding_box = self._expand_small_bounding_box_to_minimum_size(bounding_box_based_on_binary_saliency)
        if self.aspect_rate is not None:
            bounding_box = self._expand_bounding_box_to_specified_aspect_ratio(bounding_box, saliencymap)
            
        box_y = bounding_box[1]
        box_x = bounding_box[0]
        box_height = bounding_box[3]
        box_width = bounding_box[2]
        
        cropped_image = image[box_y:box_y+box_height, box_x:box_x+box_width]
        
        self.bounding_box_based_on_binary_saliency = bounding_box_based_on_binary_saliency
        self.bounding_box  = bounding_box
        
        return cropped_image

# -------------------
# SETTING
threshhold = 0.3 #Set threshold, float (0<threshhold<1)
saliencymap_path = 'COCO_train2014_000000196971.png' # Saliency map path
image_path = 'COCO_train2014_000000196971.jpg' #Image path
# -------------------

saliencymap = cv2.imread(saliencymap_path)[:, :, ::-1] # ndarray, (h, w, rgb), np.uint8 (0-255)
saliencymap = saliencymap[:, :, 0] # ndarray, (h, w), np.uint8 (0-255)
image = cv2.imread(image_path)[:, :, ::-1] # ndarray, (h, w, rgb), np.uint8 (0-255)

# Visualize the image cropped with the saliency map
cropper = SaliencyBasedImageCropping(aspect_rate=(1, 3))
cropped_image = cropper.crop(image, saliencymap, threshhold=threshhold)
plt.imshow(cropped_image)
plt.show()

# Visualize the saliency map and the bounding boxes
# Red: box adjusted to the specified aspect ratio. Green: box from Ardizzone's method.
saliencymap_colored = color_saliencymap(saliencymap) # ndarray, (h, w, rgb), np.uint8
overlaid_image = overlay_saliencymap_and_image(saliencymap_colored, image) # ndarray, (h, w, rgb), np.uint8
box_x, box_y, box_width, box_height = cropper.bounding_box
box_x_0, box_y_0, box_width_0, box_height_0 = cropper.bounding_box_based_on_binary_saliency
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#FF0000', fill=False)
bounding_box_based_on_binary_saliency = patches.Rectangle(xy=(box_x_0, box_y_0), width=box_width_0, height=box_height_0, ec='#00FF00', fill=False)
ax.imshow(overlaid_image)
ax.add_patch(bounding_box)
ax.add_patch(bounding_box_based_on_binary_saliency)
plt.show()

#Visualization of image with cropped center for comparison
center_cropped_image = cropper.crop_center(image)
plt.imshow(center_cropped_image)
plt.show()

np.convolve() is used to find the range with the largest total saliency.

numpy.convolve — NumPy v1.17 Manual

This is a one-dimensional convolution function. By convolving with an array of all 1s of the length you want to sum, you can calculate the sum for each fixed range as shown below.

array_1d = np.array([1, 2, 3, 4])
print(np.convolve(array_1d, np.ones(2), 'valid')) # [3. 5. 7.]

A plain Python for loop would be slow, so I combine NumPy functions as much as possible.
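For comparison, here are three ways to compute the same window sums: an explicit Python loop, np.convolve as in the class, and a cumulative-sum variant. The cumulative-sum version is an alternative I am adding for illustration; the class itself uses np.convolve. np.convolve computes each window directly, so its cost grows with the window length, while the cumulative-sum version stays linear in the array length.

```python
import numpy as np


def window_sums_loop(a, k):
    # Straightforward but slow for long arrays: one Python-level slice per window
    return np.array([a[i:i + k].sum() for i in range(len(a) - k + 1)], dtype=float)


def window_sums_convolve(a, k):
    # Convolution with k ones; 'valid' keeps only fully overlapping windows
    return np.convolve(a, np.ones(k), mode="valid")


def window_sums_cumsum(a, k):
    # Prefix sums: sum(a[i:i+k]) = c[i+k] - c[i], with c[0] = 0
    c = np.concatenate(([0.0], np.cumsum(a, dtype=float)))
    return c[k:] - c[:-k]


a = np.array([1, 2, 3, 4])
print(window_sums_loop(a, 2), window_sums_convolve(a, 2), window_sums_cumsum(a, 2))
# [3. 5. 7.] [3. 5. 7.] [3. 5. 7.]
```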

I also added a step to the binarization that removes very small regions with a morphological transformation (opening). Such regions are especially likely to appear once the saliency map is estimated with deep learning later in this article, which is why I added it.

Morphology Transformation — OpenCV-Python Tutorials 1 documentation
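As a rough picture of what the opening does, here is a NumPy-only sketch: erosion (a pixel stays on only if its whole k x k neighborhood is on) followed by dilation (a pixel turns on if any pixel in its neighborhood is on). This is my own illustration using sliding_window_view (NumPy 1.20 or later), not how OpenCV implements cv2.morphologyEx. The border handling (outside pixels count as on for erosion and off for dilation) is an assumption chosen so the image border itself does not change the result.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def binary_opening(mask, k=3):
    """Sketch of a k x k binary opening (erosion, then dilation); k must be odd."""
    pad = k // 2
    # Erosion: pad with True so the image border does not erode anything
    padded = np.pad(mask, pad, constant_values=True)
    eroded = sliding_window_view(padded, (k, k)).all(axis=(-2, -1))
    # Dilation: pad with False so the border does not add anything
    padded = np.pad(eroded, pad, constant_values=False)
    return sliding_window_view(padded, (k, k)).any(axis=(-2, -1))


mask = np.zeros((8, 8), dtype=bool)
mask[1:5, 1:5] = True   # 4x4 blob: survives a 3x3 opening
mask[6, 6] = True       # isolated pixel: removed by the opening
opened = binary_opening(mask, k=3)
print(opened[1:5, 1:5].all(), opened[6, 6])  # True False
```

Regions smaller than the kernel disappear, while larger regions keep their shape, which is why the class scales the kernel with the image size.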

Results

The green bounding box is the one before the aspect ratio is adjusted, and the red bounding box is the one after.

Figure: Results of cropping at a 1:3 aspect ratio using the saliency map

As panel (a) shows, the tall crop successfully includes both the cat and what appears to be a hand soap. Compared with panel (b), a plain center crop, it keeps the parts people want to see nicely in frame.

Figure: Results of cropping at a 1:1 aspect ratio using the saliency map

This is the square (1:1) case. The center crop (panel (b)) does contain the whole cat, but the saliency-map crop (panel (a)) cuts out a narrower region, so when both are displayed at the same size, the cat appears larger. In cropping, what matters is not only whether an object is included but also whether it appears large enough.

Implementing SalGAN, a deep learning model that estimates saliency maps, in PyTorch

So far, we cannot crop images we prepared ourselves, because they have no ground-truth saliency map. I want to crop my Christmas tree photo, not a SALICON image, so I will use deep learning to build a model that estimates saliency maps.

The saliency benchmark site "MIT Saliency Benchmark" [^4] lists many methods; this time we will implement SalGAN [^5]. Its score does not seem especially high, but I chose it because the mechanism looked simple.

The authors' implementation [^6] is public, but it uses Lasagne (Theano), a framework I am not very familiar with, so I will write it in PyTorch while referring to it.

What is SalGAN?

"SalGAN: Visual Saliency Prediction with Generative Adversarial Networks" is a paper published in 2017. As the name implies, it estimates saliency maps using **GAN (Generative Adversarial Networks)**.

I will skip explaining GANs, since there are already many easy-to-understand articles; for example, I recommend "GAN (1) Understanding the basic structure" on Qiita. If you know a typical GAN, for which implementations and explanations abound, you can implement SalGAN by focusing on the differences.

Figure: Overall structure of SalGAN (quoted from the paper [^5])

Estimating a saliency map is a per-pixel binary classification problem, because each pixel holds the probability (0 to 1) that a gaze point falls there; it is close to single-class segmentation. Since the model takes an image as input and outputs another image (the saliency map), it is a CNN **Encoder-Decoder model**, as shown in the figure. Pix2Pix [^7] is the well-known GAN for image-to-image tasks, but SalGAN does not use a U-Net structure like Pix2Pix does.

The Encoder-Decoder model could be trained on its own by minimizing the **Binary Cross Entropy** between the output saliency map and the ground truth. SalGAN aims to improve accuracy further by adding a network (the **Discriminator**) that classifies a saliency map as either ground truth or estimated.

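The loss function of the Encoder-Decoder part (the **Generator**) is as follows. In addition to the usual adversarial loss, it adds the Binary Cross Entropy between the estimated Salience Map and the ground truth, with the ratio adjusted by the hyperparameter $\alpha$. Here $S_j$ and $\hat{S}_j$ are the ground-truth and estimated saliency at pixel $j$ ($N$ pixels in total), $I$ is the input image, $D$ is the Discriminator, and $L$ is the binary cross entropy.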

$$\mathcal{L}_{BCE} = -\frac{1}{N}\sum_{j=1}^{N}\left(S_{j}\log\hat{S}_{j}+(1-S_{j})\log(1-\hat{S}_{j})\right)$$
$$\mathcal{L} = \alpha\cdot\mathcal{L}_{BCE} + L(D(I, \hat{S}), 1)$$

The loss function of the Discriminator takes the standard GAN form:

$$\mathcal{L}_{\mathcal{D}} = L(D(I, S), 1)+L(D(I, \hat{S}),0)$$
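To make the formulas concrete, here is a minimal NumPy sketch of the three losses. The function names `bce`, `generator_loss` and `discriminator_loss` are illustrative, not from the paper or the authors' code. It assumes $L$ is binary cross entropy and clips predictions with a small `eps` (an assumption of this sketch) to avoid $\log 0$.

```python
import numpy as np

def bce(target, pred, eps=1e-7):
    """Mean binary cross entropy over all elements (L_BCE with N = number of elements)."""
    target = np.asarray(target, dtype=np.float64)
    pred = np.clip(np.asarray(pred, dtype=np.float64), eps, 1.0 - eps)  # avoid log(0)
    return float(-np.mean(target * np.log(pred) + (1.0 - target) * np.log(1.0 - pred)))

def generator_loss(S, S_hat, d_fake, alpha=0.005):
    """alpha * L_BCE(S, S_hat) + L(D(I, S_hat), 1); d_fake holds D's outputs for estimated maps."""
    return alpha * bce(S, S_hat) + bce(np.ones_like(d_fake), d_fake)

def discriminator_loss(d_real, d_fake):
    """L(D(I, S), 1) + L(D(I, S_hat), 0)."""
    return bce(np.ones_like(d_real), d_real) + bce(np.zeros_like(d_fake), d_fake)

# Toy example: a 2x2 ground-truth map, an estimate, and Discriminator outputs
S = np.array([[1.0, 0.0], [0.0, 1.0]])
S_hat = np.array([[0.9, 0.1], [0.2, 0.8]])
print(bce(S, S_hat))
print(generator_loss(S, S_hat, d_fake=np.array([0.5])))
print(discriminator_loss(d_real=np.array([0.9]), d_fake=np.array([0.1])))
```

The training code later uses `torch.nn.BCELoss`, which also averages over all elements, and divides the Discriminator loss by 2; that only rescales the gradient and does not change what the Discriminator is optimizing for.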

Read a little more

Let's read the paper for the information needed for the implementation. The overall structure is roughly clear from the figure, so I will quote the parts that give the details I want to know.

The encoder part of the network is identical in architecture to VGG-16 (Simonyan and Zisserman, 2015), omitting the final pooling and fully connected layers. The network is initialized with the weights of a VGG-16 model trained on the ImageNet data set for object classification (Deng et al., 2009). Only the last two groups of convolutional layers in VGG-16 are modified during the training for saliency prediction, while the earlier layers remain fixed from the original VGG-16 model.

- The CNN in the Encoder part of the Generator is VGG16
- The last pooling layer and the fully connected layers are excluded
- The initial weights are those trained on ImageNet
- Only the last two groups of convolutional layers are trained
- The weights of the first three groups of convolutional layers stay fixed at their ImageNet values

The decoder architecture is structured in the same way as the encoder, but with the ordering of layers reversed, and with pooling layers being replaced by upsampling layers. Again, ReLU non-linearities are used in all convolution layers, and a final 1 × 1 convolution layer with sigmoid non-linearity is added to produce the saliency map. The weights for the decoder are randomly initialized. The final output of the network is a saliency map in the same size to input image.

- The Decoder mirrors the Encoder, but upsampling layers replace the pooling layers
- The last layer is a 1x1 convolution followed by a sigmoid
- The weights are initialized randomly
- The output has the same size as the input
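For the output to match the input size, the Decoder has to undo the Encoder's downsampling. As a quick check, the pure-Python sketch below (not a PyTorch model; the layer lists are my transcription of VGG16 and of the decoder written later in this article) traces (channels, height, width) for a 192 x 256 input.

```python
# VGG16 convolution channels per layer, "M" = 2x2 max pooling with stride 2
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"]

def encoder_plan():
    # Drop the final pooling layer, as in the paper
    return VGG16_CFG[:-1]

def decoder_plan():
    # Output channels of each decoder conv, "U" = nn.Upsample(scale_factor=2); last entry is the 1x1 conv
    return [512, 512, 512, "U", 512, 512, 512, "U", 256, 256, 256, "U", 128, 128, "U", 64, 64, 1]

def trace(plan, shape):
    """Return (channels, height, width) after applying every layer in `plan`."""
    c, h, w = shape
    for layer in plan:
        if layer == "M":      # pooling halves height and width (floor)
            h, w = h // 2, w // 2
        elif layer == "U":    # upsampling doubles height and width
            h, w = h * 2, w * 2
        else:                 # 3x3 conv with padding=1 (or the final 1x1 conv): size unchanged
            c = layer
    return c, h, w

bottleneck = trace(encoder_plan(), (3, 192, 256))
output = trace(decoder_plan(), bottleneck)
print(bottleneck, output)
```

Four pooling layers shrink the map by a factor of 16, and four upsampling layers bring it back, which is why removing VGG16's final pooling layer matters here.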

The input to the discriminator network is an RGBS image of size 256×192×4 containing both the source image channels and (predicted or ground truth) saliency.

- The Discriminator receives not only the Salience Map but also the original image, stacked as 4 channels (RGB + saliency)
- The input images themselves are resized to 256 x 192 (width x height)
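In PyTorch this stacking is a concatenation along the channel axis (the training code later uses `torch.cat(..., 1)`). The NumPy sketch below shows the same operation, with random arrays standing in for a batch; `make_discriminator_input` is a hypothetical helper name. Note that the paper writes the size as 256×192 (width × height), while the arrays are laid out as (batch, channels, height, width).

```python
import numpy as np

def make_discriminator_input(images, saliency_maps):
    """Stack RGB images (B, 3, H, W) and saliency maps (B, 1, H, W) into RGBS arrays (B, 4, H, W)."""
    if images.shape[0] != saliency_maps.shape[0] or images.shape[2:] != saliency_maps.shape[2:]:
        raise ValueError("batch size and spatial size must match")
    return np.concatenate([images, saliency_maps], axis=1)  # axis 1 = channel axis

images = np.random.rand(2, 3, 192, 256).astype(np.float32)
maps = np.random.rand(2, 1, 192, 256).astype(np.float32)
rgbs = make_discriminator_input(images, maps)
print(rgbs.shape)
```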

We train the networks on the 15,000 images from the SALICON training set using a batch size of 32.

- 15,000 images from the SALICON training set are used
- The batch size is 32

Since I am not trying to reproduce the paper's experiments this time, I do not stick strictly to its details. For example, I use Leaky ReLU, which is generally considered effective, instead of the ReLU used in the paper.

Write code

Now for the code. Since SalGAN is basically a GAN built on a CNN Encoder-Decoder model, I will refer to implementations of similar existing methods. For example, eriklindernoren's GitHub repository has PyTorch implementations of various GANs, and its DCGAN implementation [^8] is a good reference.

Create the Generator and Discriminator classes

The Generator uses the ImageNet-trained VGG16 provided in torchvision [^9]. In SalGAN the weights of the front layers are fixed and only the back layers are trained, so the layers are split into separate modules, such as torchvision.models.vgg16(pretrained=True).features[:17]. You can check which index corresponds to which layer with print(torchvision.models.vgg16(pretrained=True).features).

torchvision.models — PyTorch master documentation
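To see why index 17 is the boundary without downloading any weights, the layer list of torchvision's VGG16 features can be rebuilt from its configuration (13 convolutions, each followed by a ReLU, plus 5 max-pooling layers). This pure-Python sketch assumes the standard VGG16 configuration without batch normalization; the print call above remains the authoritative check.

```python
# VGG16 configuration: numbers are conv output channels, "M" is a max-pooling layer
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"]

def vgg16_feature_layers():
    """Return (layer type, conv group number) in the order they appear in features."""
    layers, group = [], 1
    for v in VGG16_CFG:
        if v == "M":
            layers.append(("MaxPool2d", group))  # a pooling layer closes each group
            group += 1
        else:
            layers.append(("Conv2d", group))
            layers.append(("ReLU", group))
    return layers

layers = vgg16_feature_layers()
frozen = layers[:17]       # same slice as encoder_first
trainable = layers[17:-1]  # same slice as encoder_last (drops the final pooling layer)
print(len(layers), sorted({g for _, g in frozen}), sorted({g for _, g in trainable}))
```

So features[:17] ends with the pooling layer of group 3, and features[17:-1] covers groups 4 and 5, matching "only the last two groups are trained".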

from torch import nn
import torchvision

class Generator(nn.Module):
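    def __init__(self, pretrained=True):
        super().__init__()
        vgg16_features = torchvision.models.vgg16(pretrained=pretrained).features #Load VGG16 once and split it
        self.encoder_first = vgg16_features[:17] #The part to be used with fixed weight (conv groups 1-3)
        self.encoder_last = vgg16_features[17:-1] #Part to learn (conv groups 4-5, without the last pooling layer)
        for param in self.encoder_first.parameters():
            param.requires_grad = False #Keep the ImageNet weights fixed, as in the paper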
        self.decoder = nn.Sequential(
                    nn.Conv2d(512, 512, 3, padding=1), 
                    nn.LeakyReLU(),
                    nn.Conv2d(512, 512, 3, padding=1), 
                    nn.LeakyReLU(),
                    nn.Conv2d(512, 512, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(512, 512, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(512, 512, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(512, 512, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(512, 256, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(256, 256, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(256, 256, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(256, 128, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(128, 128, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(128, 64, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(64, 64, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(64, 1, 1, padding=0),
                    nn.Sigmoid())

    def forward(self, x):
        x = self.encoder_first(x)
        x = self.encoder_last(x)
        x = self.decoder(x)
        return x

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
                    nn.Conv2d(4, 3, 1, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.Conv2d(3, 32, 3, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.MaxPool2d(2, stride=2),
                    nn.Conv2d(32, 64, 3, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.Conv2d(64, 64, 3, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.MaxPool2d(2, stride=2),
                    nn.Conv2d(64, 64, 3, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.Conv2d(64, 64, 3, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.MaxPool2d(2, stride=2))
        self.classifier = nn.Sequential(
                    nn.Linear(64*32*24, 100, bias=True),
                    nn.Tanh(),
                    nn.Linear(100, 2, bias=True),
                    nn.Tanh(),
                    nn.Linear(2, 1, bias=True),
                    nn.Sigmoid())

    def forward(self, x):
        x = self.main(x)
        x = x.view(x.shape[0], -1)
        x = self.classifier(x)
        return x
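The input size of the first nn.Linear, 64*32*24, follows from the layer sizes. The sketch below recomputes it with the standard output-size formulas for convolution and pooling, assuming 192 x 256 inputs; the helper names are hypothetical.

```python
def conv_out(size, kernel, padding, stride=1):
    """Output length of a convolution along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel=2, stride=2):
    """Output length of max pooling along one axis (no padding, floor mode)."""
    return (size - kernel) // stride + 1

def discriminator_feature_size(h, w):
    # First layer: nn.Conv2d(4, 3, 1, padding=1), a 1x1 kernel with padding 1 adds 2 to each side length
    h, w = conv_out(h, 1, 1), conv_out(w, 1, 1)
    for _ in range(3):
        # 3x3 convs with padding=1 keep the size; each MaxPool2d(2, stride=2) halves it (floor)
        h, w = pool_out(h), pool_out(w)
    return 64 * h * w, (h, w)

print(discriminator_feature_size(192, 256))
```

The first layer grows the map to 194 x 258, but after three floor-halving poolings it is still 24 x 32, the same as it would be with padding=0.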

Create a dataset class

A dataset class is needed to read the SALICON dataset. Writing one to fit your dataset and task is a bit tedious, and the PyTorch tutorial is a helpful guide.

Writing Custom Datasets, DataLoaders and Transforms — PyTorch Tutorials 1.3.1 documentation

That tutorial also covers preprocessing with torchvision.transforms. This time the only preprocessing is resizing to 192 x 256 (height x width) and normalization.

torchvision.transforms — PyTorch master documentation
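ToTensor() scales pixel values to [0, 1], and Normalize([0.5], [0.5]) then computes (x - 0.5) / 0.5, mapping them to [-1, 1]; the single mean and std are broadcast over the three RGB channels. The target map only goes through ToTensor(), so it stays in [0, 1], which is what BCELoss expects. A NumPy sketch of the arithmetic (helper names are hypothetical):

```python
import numpy as np

def normalize(x, mean=0.5, std=0.5):
    """Same per-element arithmetic as transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std

def denormalize(y, mean=0.5, std=0.5):
    """Inverse mapping back to [0, 1]."""
    return y * std + mean

pixels = np.array([0.0, 0.25, 0.5, 1.0])  # ToTensor() output range is [0, 1]
print(normalize(pixels))
```

The inverse, denormalize, is what you need before displaying a normalized image with plt.imshow.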

The SALICON dataset can be downloaded from LSUN’17 Saliency Prediction Challenge | SALICON.

import os

import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image

class SALICONDataset(data.Dataset):
    def __init__(self, root_dataset_dir, val_mode = False):
        """
        Dataset class for reading SALICON datasets

        Parameters:
        -----------------
        root_dataset_dir : str
            Directory path above the SALICON dataset
        val_mode : bool (default: False)
            If False, Train data is read. If True, Validation data is read.
        """
        self.root_dataset_dir = root_dataset_dir
        self.imgsets_dir = os.path.join(self.root_dataset_dir, 'SALICON/image_sets')
        self.img_dir = os.path.join(self.root_dataset_dir, 'SALICON/imgs')
        self.distribution_target_dir = os.path.join(self.root_dataset_dir, 'SALICON/algmaps')
        self.img_tail = '.jpg'
        self.distribution_target_tail = '.png'
        self.transform = transforms.Compose([transforms.Resize((192, 256)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
        self.distribution_transform = transforms.Compose([transforms.Resize((192, 256)), transforms.ToTensor()])
        
        if val_mode:
            train_or_val = "val"
        else:
            train_or_val = "train"
        imgsets_file = os.path.join(self.imgsets_dir, '{}.txt'.format(train_or_val))
        files = []
        with open(imgsets_file) as f:
            data_ids = [line.strip() for line in f if line.strip()] #Skip empty lines
        for data_id in data_ids:
            img_file = os.path.join(self.img_dir, '{0}{1}'.format(data_id, self.img_tail))
            distribution_target_file = os.path.join(self.distribution_target_dir, '{0}{1}'.format(data_id, self.distribution_target_tail))
            files.append({
                'img': img_file,
                'distribution_target': distribution_target_file,
                'data_id': data_id
            })
        self.files = files
        
    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        """
        Returns
        -----------
        data : list
            [img, distribution_target, data_id]
        """
        data_file = self.files[index]
        data = []

        img_file = data_file['img']
        img = Image.open(img_file).convert('RGB') #Some COCO images are grayscale, so force 3 channels
        data.append(img)

        distribution_target_file = data_file['distribution_target']
        distribution_target = Image.open(distribution_target_file).convert('L') #Single-channel Salience Map
        data.append(distribution_target)
        
        # transform
        data[0] = self.transform(data[0])
        data[1] = self.distribution_transform(data[1])

        data.append(data_file['data_id'])
        return data

Training

Next is the training code. The key points are how to compute the loss functions and how to alternate between training the Generator and the Discriminator.

Training for 120 epochs, as in the paper, takes several hours on a GPU.
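For a sense of scale, the iteration counts follow from the dataset size and batch size. This sketch assumes the 15,000 training images quoted from the paper (the number in your copy of SALICON may differ, so pass n_images accordingly) and mirrors the loop below, which sends even iterations to the Generator and odd ones to the Discriminator and resets the counter every epoch. The helper name is hypothetical.

```python
import math

def updates_per_epoch(n_images=15000, batch_size=32):
    """Return (batches, Generator updates, Discriminator updates) for one epoch."""
    n_batches = math.ceil(n_images / batch_size)  # DataLoader keeps the last partial batch by default
    n_generator = (n_batches + 1) // 2            # iterations 0, 2, 4, ...
    n_discriminator = n_batches // 2              # iterations 1, 3, 5, ...
    return n_batches, n_generator, n_discriminator

n_batches, n_g, n_d = updates_per_epoch()
print(n_batches, n_g, n_d, n_batches * 120)
```

Each network is therefore updated on only about half of the batches in an epoch.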

import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.autograd import Variable

#-----------------
# SETTING
root_dataset_dir = "" #Directory path above the SALICON dataset
alpha = 0.005 #Hyperparameter of the Generator loss function. The value recommended in the paper is 0.005
epochs = 120
batch_size = 32 #32 in the paper
#-----------------

#Use start time for file name
start_time_stamp = '{0:%Y%m%d-%H%M%S}'.format(datetime.now())

save_dir = "./log/"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#Load data loader
train_dataset = SALICONDataset(
                    root_dataset_dir,
                )
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle=True, num_workers = 4, pin_memory=True, sampler=None)
val_dataset = SALICONDataset(
                    root_dataset_dir,
                    val_mode=True
                )
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = 1, shuffle=False, num_workers = 4, pin_memory=True, sampler=None)

#Load model and loss function
loss_func = torch.nn.BCELoss().to(DEVICE)
generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)

#Definition of optimization method (using the settings in the paper)
optimizer_G = torch.optim.Adagrad([
                {'params': generator.encoder_last.parameters()},
                {'params': generator.decoder.parameters()}
            ], lr=0.0001, weight_decay=3*0.0001)
optimizer_D = torch.optim.Adagrad(discriminator.parameters(), lr=0.0001, weight_decay=3*0.0001)

#Learning
for epoch in range(epochs):
    n_updates = 0 #Iteration count
    n_discriminator_updates = 0
    n_generator_updates = 0
    d_loss_sum = 0
    g_loss_sum = 0
    
    for i, data in enumerate(train_loader):
        imgs = data[0] # ([batch_size, rgb, h, w])
        salmaps = data[1] # ([batch_size, 1, h, w])

        #Create label for Discriminator
        valid = Variable(torch.FloatTensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False).to(DEVICE)
        fake = Variable(torch.FloatTensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False).to(DEVICE)

        imgs = Variable(imgs).to(DEVICE)
        real_salmaps = Variable(salmaps).to(DEVICE)

        #Alternately learn Generator and Discriminator for each iteration
        if n_updates % 2 == 0:
            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()
            gen_salmaps = generator(imgs)
            
            #Combine the original image and the generated Salience Map for input to the Discriminator to create a 4-channel array
            #Do not detach here, so that the adversarial loss also updates the Generator
            fake_d_input = torch.cat((imgs, gen_salmaps), 1) # ([batch_size, rgbs, h, w])
            
            #Calculate the loss function of Generator
            g_loss1 = loss_func(gen_salmaps, real_salmaps)
            g_loss2 = loss_func(discriminator(fake_d_input), valid)
            g_loss = alpha*g_loss1 + g_loss2
            
            g_loss.backward() #Also leaves gradients in the Discriminator, but optimizer_D.zero_grad() clears them before its update
            optimizer_G.step()
            
            g_loss_sum += g_loss.item()
            n_generator_updates += 1
            
        else:
            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()
            
            #Generate the Salience Map of the current batch without tracking gradients for the Generator
            with torch.no_grad():
                gen_salmaps = generator(imgs)
            
            #Combine the original image with the correct / generated Salience Map for input to the Discriminator to create 4-channel arrays
            real_d_input = torch.cat((imgs, real_salmaps), 1) # ([batch_size, rgbs, h, w])
            fake_d_input = torch.cat((imgs, gen_salmaps), 1) # ([batch_size, rgbs, h, w])

            #Calculate the loss function of Discriminator
            real_loss = loss_func(discriminator(real_d_input), valid)
            fake_loss = loss_func(discriminator(fake_d_input), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()
            
            d_loss_sum += d_loss.item()
            n_discriminator_updates += 1
            
        n_updates += 1
        if n_updates%10==0:
            if n_discriminator_updates>0:
                print(
                    "[%d/%d (%d/%d)] [loss D: %f, G: %f]"
                    % (epoch, epochs-1, i, len(train_loader), d_loss_sum/n_discriminator_updates , g_loss_sum/n_generator_updates)
                )
            else:
                print(
                    "[%d/%d (%d/%d)] [loss G: %f]"
                    % (epoch, epochs-1, i, len(train_loader), g_loss_sum/n_generator_updates)
                )                
    
    #Saving weights
    #Save every 5 epochs and the last epoch
    if ((epoch+1)%5==0)or(epoch==epochs-1):
        generator_save_path = '{}.pkl'.format(os.path.join(save_dir, "{}_generator_epoch{}".format(start_time_stamp, epoch)))
        discriminator_save_path = '{}.pkl'.format(os.path.join(save_dir, "{}_discriminator_epoch{}".format(start_time_stamp, epoch)))
        torch.save(generator.state_dict(), generator_save_path)
        torch.save(discriminator.state_dict(), discriminator_save_path)
        
    #Visualize part of Validation data for each epoch
    with torch.no_grad():
        print("validation")
        for i, data in enumerate(val_loader):
            image = Variable(data[0]).to(DEVICE)
            gen_salmap = generator(image) #Use the validation image, not the last training batch
            gen_salmap_np = gen_salmap[0, 0].cpu().numpy()
            
            plt.imshow((image[0].cpu().numpy().transpose(1, 2, 0) * 0.5 + 0.5).clip(0, 1)) #Undo Normalize([0.5], [0.5]) for display
            plt.show()
            plt.imshow(gen_salmap_np)
            plt.show()
            if i==1:
                break

Estimate the Salience Map

Feed an image to the trained SalGAN to estimate its Salience Map, and check how well it works on an image not used for training.

generator_path = "" #Path of Generator weight file (pkl) obtained by learning
image_path = "COCO_train2014_000000196971.jpg" #The path of the image you want to input

generator = Generator().to(DEVICE)
generator.load_state_dict(torch.load(generator_path, map_location=DEVICE)) #map_location allows loading GPU-trained weights on a CPU
generator.eval()

image_pil = Image.open(image_path).convert('RGB') #PIL format image input is assumed for transform
image = np.array(image_pil)
plt.imshow(image)
plt.show()

transform = transforms.Compose([transforms.Resize((192, 256)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
image_torch = transform(image_pil)
image_torch = image_torch.unsqueeze(0).to(DEVICE) # (1, rgb, h, w)
with torch.no_grad():
    pred_saliencymap = generator(image_torch)
    pred_saliencymap = np.array(pred_saliencymap.cpu())[0, 0]
pred_saliencymap = pred_saliencymap/pred_saliencymap.sum() #Scaling so that the sum is 1
pred_saliencymap = ((pred_saliencymap/pred_saliencymap.max())*255).astype(np.uint8) #Convert to np.uint8 so that it can be treated as an image
plt.imshow(pred_saliencymap)
plt.show()
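The two scaling steps in the code above interact: after dividing by the sum, the code divides by the maximum, so the sum normalization has no effect on the final uint8 image. A small NumPy check (to_uint8_map is a hypothetical helper that repeats the max scaling); note that astype(np.uint8) truncates rather than rounds.

```python
import numpy as np

def to_uint8_map(saliency):
    """Scale a non-negative float map so its maximum becomes 255, then truncate to uint8."""
    saliency = np.asarray(saliency, dtype=np.float64)
    return ((saliency / saliency.max()) * 255).astype(np.uint8)

raw = np.array([[0.1, 0.2], [0.4, 0.8]])
sum_normalized = raw / raw.sum()  # probability-distribution form (sums to 1)
print(to_uint8_map(raw))
print(np.array_equal(to_uint8_map(raw), to_uint8_map(sum_normalized)))
```

The sum-to-1 form is still useful if you want to treat the map as a probability distribution, for example when computing evaluation metrics.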


Figure: Example of Salience Map estimated by SalGAN

Panel (b) of this figure shows the estimated Salience Map. Compared with the ground truth (panel (c)) it is coarser, but it does assign high probability to plausible points around the pitcher and the batter.

To judge whether training really went well, you would need to evaluate the model with the various Salience Map metrics used in the paper. Since I have not done that, it is unclear how these results compare with the SalGAN reported in the paper. The main topic this time is cropping, and the model looks qualitatively plausible, so I will move on.
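For reference, two metrics commonly reported in saliency benchmarks such as the MIT Saliency Benchmark are the linear correlation coefficient (CC) and the KL divergence. The NumPy sketch below is a minimal version for two maps of the same shape, written from the usual definitions; it is not the benchmark's official evaluation code, and the small eps is an assumption to avoid division by zero.

```python
import numpy as np

def cc(pred, gt):
    """Pearson linear correlation coefficient between two saliency maps (1 = identical up to scale and offset)."""
    p = (pred - pred.mean()) / pred.std()
    g = (gt - gt.mean()) / gt.std()
    return float(np.mean(p * g))

def kl_div(pred, gt, eps=1e-7):
    """KL divergence of the prediction from the ground truth, both treated as distributions (0 = identical)."""
    p = pred / pred.sum()
    g = gt / gt.sum()
    return float(np.sum(g * np.log(eps + g / (p + eps))))

gt = np.array([[0.1, 0.2], [0.3, 0.4]])
pred = np.array([[0.2, 0.2], [0.2, 0.4]])
print(cc(pred, gt), kl_div(pred, gt))
```

Higher CC and lower KL are better; averaging them over the SALICON validation set would give a first quantitative comparison.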

Crop the image with the estimated Saliency Map

Now we have everything we need. By combining the Salience Map estimated by SalGAN with the cropping class, we can crop images nicely.
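As a reminder of the core idea, the box around the salient region can be found by thresholding the map and taking the extent of the remaining pixels. The sketch below is a simplified, hypothetical helper, not the SaliencyBasedImageCropping class used here; it assumes the threshold is relative to the map's maximum, and it skips the aspect-ratio adjustment.

```python
import numpy as np

def saliency_bounding_box(saliency, threshold=0.3):
    """Hypothetical: (x, y, width, height) of the pixels whose value is >= threshold * max."""
    saliency = np.asarray(saliency, dtype=np.float64)
    mask = saliency >= threshold * saliency.max()  # binarize relative to the peak
    ys, xs = np.nonzero(mask)                      # row (y) and column (x) indices of salient pixels
    x0, y0 = xs.min(), ys.min()
    return int(x0), int(y0), int(xs.max() - x0 + 1), int(ys.max() - y0 + 1)

saliency = np.zeros((6, 8))
saliency[1:3, 2:5] = 1.0  # salient block: rows 1-2, columns 2-4
saliency[5, 7] = 0.1      # weak response below the threshold
print(saliency_bounding_box(saliency))
```

Lowering the threshold lets weaker responses in, so the box grows; this is the trade-off the threshhold setting below controls.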

import matplotlib.patches as patches

# -------------------
# SETTING
threshhold = 0.3 #Set threshold, float (0<threshhold<1)
generator_path = "" #Path of Generator weight file (pkl) obtained by learning
image_path = "COCO_train2014_000000196971.jpg" #The path of the image you want to crop
# -------------------

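generator = Generator().to(DEVICE)
generator.load_state_dict(torch.load(generator_path, map_location=DEVICE)) #map_location allows loading GPU-trained weights on a CPU
generator.eval()

image_pil = Image.open(image_path).convert('RGB') #PIL format image input is assumed for transform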
image = np.array(image_pil)
plt.imshow(image)
plt.show()

transform = transforms.Compose([transforms.Resize((192, 256)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
image_torch = transform(image_pil)
image_torch = image_torch.unsqueeze(0).to(DEVICE) # (1, rgb, h, w)
with torch.no_grad():
    pred_saliencymap = generator(image_torch)
    pred_saliencymap = np.array(pred_saliencymap.cpu())[0, 0]
pred_saliencymap = pred_saliencymap/pred_saliencymap.sum() #Scaling so that the sum is 1
pred_saliencymap = ((pred_saliencymap/pred_saliencymap.max())*255).astype(np.uint8) #Convert to np.uint8 so that it can be treated as an image
plt.imshow(pred_saliencymap)
plt.show()

#Visualization of cropped images using Salience Map
cropper = SaliencyBasedImageCropping(aspect_rate=(1, 3))
cropped_image = cropper.crop(image, pred_saliencymap, threshhold=threshhold) #Use the threshold set above
plt.imshow(cropped_image)
plt.show()

#Visualization of Salience Map and bounding box
#The box adjusted to the specified aspect ratio is red, and the box based on the binarized Salience Map is green.
saliencymap_colored = color_saliencymap(pred_saliencymap) # ndarray, (h, w, rgb), np.uint8
overlaid_image = overlay_saliencymap_and_image(saliencymap_colored, image) # ndarray, (h, w, rgb), np.uint8
box_x, box_y, box_width, box_height = cropper.bounding_box
box_x_0, box_y_0, box_width_0, box_height_0 = cropper.bounding_box_based_on_binary_saliency
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#FF0000', fill=False)
bounding_box_based_on_binary_saliency = patches.Rectangle(xy=(box_x_0, box_y_0), width=box_width_0, height=box_height_0, ec='#00FF00', fill=False)
ax.imshow(overlaid_image)
ax.add_patch(bounding_box)
ax.add_patch(bounding_box_based_on_binary_saliency)
plt.show()

#Visualization of image with cropped center for comparison
center_cropped_image = cropper.crop_center(image)
plt.imshow(center_cropped_image)
plt.show()

Figure: Comparison of crops using the ground truth and using SalGAN (baseball image)

Using the ground truth (panel (a)) and using SalGAN (panel (b)) gave almost the same result: both crop out the batter. Whether people look more at the batter or at the pitcher is probably not something the model has learned well, but I'm glad it arrived at the same crop.

One might ask whether this is cherry-picking that only shows successful results. To measure this quantitatively, you could check how much the two kinds of crops overlap across the SALICON dataset, an IoU-like calculation as in object detection. I will skip it here, since this is something I have only just built.
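The overlap measure mentioned above could be computed per image as the IoU of the two crop boxes and then averaged over the dataset. A minimal sketch, assuming boxes in the same (x, y, width, height) form that cropper.bounding_box is unpacked into above; box_iou and mean_iou are hypothetical names.

```python
def box_iou(box_a, box_b):
    """Intersection over union of two boxes given as (x, y, width, height)."""
    ax, ay, aw, ah = box_a
    bx, by, bw, bh = box_b
    # Width and height of the overlapping rectangle (0 if the boxes do not overlap)
    inter_w = max(0.0, min(ax + aw, bx + bw) - max(ax, bx))
    inter_h = max(0.0, min(ay + ah, by + bh) - max(ay, by))
    inter = inter_w * inter_h
    union = aw * ah + bw * bh - inter
    return inter / union if union > 0 else 0.0

def mean_iou(box_pairs):
    """Average IoU over (ground-truth-based box, SalGAN-based box) pairs."""
    pairs = list(box_pairs)
    return sum(box_iou(a, b) for a, b in pairs) / len(pairs) if pairs else 0.0

# Two 1:3 crops shifted horizontally by half their width
print(box_iou((0, 0, 10, 30), (5, 0, 10, 30)))
```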

I used the cat image from the first half of the article because it is cute, but it is actually part of the training data, so it is not suitable for evaluation. Still, let's take a look.

Figure: Comparison of cropping when using correct data and when using SalGAN (cat image)

Here too, almost the same result was obtained, which is good.

Cut out the Christmas tree

Finally, let's return to the Christmas tree image from the beginning. If we can get a convincing crop of a photo I took myself, which is not in the dataset, we have achieved our goal.

Figure: Comparison of cropping with SalGAN and simply centered (Christmas tree image)

The result is exactly what I wanted. Panel (a), which contains the Christmas tree, looks better than panel (b), which does not. With the power of AI (or something like it), we now have a mechanism that automatically does something close to what a person would do. This may reduce the disappointment of crops that show only empty space, or the work of cropping images by hand.

Summary

This article implemented two papers, cropping with a Salience Map [^2] and SalGAN [^5] for estimating the Salience Map, plus a little extra.

You can build something like this using only publicly available information. Even if you have only gone as far as running tutorials in deep learning or machine learning, I hope you will take on a small challenge and try to build something like this!
