[PYTHON] J'ai créé une IA qui recadre joliment une image en utilisant Saliency Map

introduction

Dans cet article, je vais implémenter la méthode de recadrage d'images à l'aide de Saliency Map en utilisant l'apprentissage en profondeur avec Python / PyTorch lors de la lecture de l'article.

Lorsque nous parlons d'images en apprentissage profond, nous essayons souvent de classer les nombres manuscrits et de détecter les personnes, mais j'espère que vous verrez que vous pouvez également le faire.

Cet article participe au DeNA 20 New Graduate Advent Calendar 2019 --Qiita. Merci à Advent Calendar de m'avoir donné l'opportunité de le faire!

Hypothèse du lecteur

Comme il existe différents genres de calendriers de l'Avent, si vous venez de lire l'article, il est destiné à tous ceux qui ont touché au programme. Afin de le déplacer, il est supposé pour ceux qui ont fait des choses de type tutoriel d'apprentissage en profondeur.

Le code supposé pour Jupyter Notebook est inclus pour un essai facile, vous pouvez donc le déplacer à portée de main. L'affichage est réduit, cliquez donc pour l'ouvrir si nécessaire.

La bibliothèque utilise uniquement celle déjà installée dans Google Colaboratory. En raison du grand ensemble de données, il peut être un peu difficile d'essayer jusqu'à ce que vous vous entraîniez.

Recadrage d'image

Parfois, je veux recadrer (recadrer) une image d'une manière ou d'une autre. Par exemple, l'image de l'icône est à peu près carrée, donc je pense que tout le monde a réfléchi à la façon de la couper lors de l'inscription à divers services. De plus, l'image d'en-tête est à mi-chemin horizontalement et la forme de l'image est souvent décidée sur place. En revanche, si l'utilisateur le découpe en une forme fixe, vous pouvez faire de votre mieux pour qu'il se sente bien, mais il existe de nombreux cas où il est nécessaire de l'automatiser côté application.

Un petit exemple

Supposons que vous souhaitiez que l'image publiée soit toujours affichée verticalement (1: 3) sur une page. Il est verticalement long car c'est une condition qui semble difficile à couper.

J'ai pris cette photo avec "C'est un joli hall avec un sapin de Noël", si vous la coupez vous-même, bien sûr je la ferai comme ça pour montrer le sapin de Noël.

Cependant, il n'est pas possible pour une personne de voir et de découper toutes les images publiées en grand nombre, elles seront donc automatisées. Eh bien, j'ai décidé de l'implémenter en Python car il serait prudent de couper le milieu.

Cliquez ici pour voir l'implémentation
import numpy as np
import cv2
import matplotlib.pyplot as plt

def crop(image, aspect_rate=(1, 1)):
    """     
Découpez l'image à partir du centre afin qu'elle ait le rapport hauteur / largeur spécifié.
    
    Parameters:
    -----------------
    image : ndarray, (h, w, rgb), uint8
    aspect_rate : tuple of int (x, y)
        default : (1, 1)

    Returns:
    -----------------
    cropped_image : ndarray, (h, w, rgb), uint8
    """        
    assert image.dtype==np.uint8
    assert image.ndim==3        

    im_size = (image.shape[1], image.shape[0]) # tuple of int, (width, height)
    center = (int(round(im_size[0]/2)), int(round(im_size[1]/2))) # tuple of int, (x, y)

    #Trouvez les quatre valeurs suivantes
    # box_x : int,Coordonnée x en haut à gauche pour rogner, box_y : int,Coordonnée y en haut à gauche pour rogner
    # box_width : int,Largeur de découpe, box_height : int,Hauteur de découpe
    if im_size[0]>im_size[1]:
        box_y = 0
        box_height = im_size[1]
        box_width = int(round((im_size[1]/aspect_rate[1])*aspect_rate[0]))
        if box_width>im_size[0]:
            box_x = 0
            box_width = im_size[0]
            box_height = int(round((im_size[0]/aspect_rate[0])*aspect_rate[1]))
            box_y = int(round(center[1]-(box_height/2)))
        else:
            box_x = int(round(center[0]-(box_width/2)))
    else:
        box_x = 0
        box_width = im_size[0]
        box_height = int(round((im_size[0]/aspect_rate[0])*aspect_rate[1]))
        if box_height>im_size[1]:
            box_y = 0
            box_height = im_size[1]
            box_width = int(round((im_size[1]/aspect_rate[1])*aspect_rate[0]))
            box_y = int(round(center[0]-(box_width/2)))
        else:
            box_y = int(round(center[1]-(box_height/2)))

    cropped_image = image[box_y:box_y+box_height, box_x:box_x+box_width]
    return cropped_image
    
#image: Lisez l'image avec OpenCV etc. et faites-en un tableau NumPy
image = cv2.imread("tree.jpg ")[:, :, ::-1]
cropped_image = crop(image, aspect_rate=(1, 3))
plt.imshow(cropped_image)
plt.show()

En supposant que tout le côté long de l'image est utilisé, il est calculé à partir du rapport hauteur / largeur étant donné la longueur du côté court à ce moment.

Je vais essayer.

L'élément de Noël a disparu et c'est juste une belle photo de lobby. C'est mauvais. AI? Faisons quelque chose avec la puissance de.

Choses à se référer

Cette fois, j'essaierai d'imiter celui utilisant Saliency Map que Twitter et Adobe ont introduit au cours des deux dernières années. Sur Twitter [^ 1], lorsque vous publiez une image, elle est bien affichée sur la timeline. En outre, InDesigin d'Adobe dispose d'une fonction appelée Content-Aware Fit qui recadre l'image en fonction de la plage spécifiée.

[^ 1]: Présentation d'un réseau de neurones qui coupe les images de manière optimale et automatique https://blog.twitter.com/ja_jp/topics/product/2018/0125ML-CR.html

La détection d'objets peut être utilisée comme méthode de comparaison. Cependant, la méthode basée sur la carte de saillance est polyvalente en ce qu'elle n'affiche pas toujours les objets avec les étiquettes entraînées.

Recadrage avec la carte de saillance

Une méthode de recadrage utilisant Saliency Map [^ 2] a été proposée en 2013 par l'article d'Ardizzone "Saliency Based Image Cropping".

Qu'est-ce que la carte de saillance?

Où va la ligne de mire lorsqu'une personne voit l'image? ** Saliency Map ** est une version pixelisée de. Par exemple, dans le coin inférieur gauche de la figure, cela est obtenu en mesurant à partir de nombreuses personnes, et la carte de saillance est obtenue en calculant de telles choses. Sur cette figure, les zones blanches sont plus susceptibles d'avoir un point de vue et les zones noires sont moins susceptibles d'avoir un point de vue.

Figure: Exemple de carte de saillance. En haut à gauche: Image. En haut à droite: le point de vue mesuré est indiqué par un X rouge. En bas à gauche: carte de saillance. En bas à droite: Carte de saillance en couleur et superposée à l'image.

Cette figure est une visualisation des données d'apprentissage du jeu de données SALICON [^ 3]. Le X rouge en haut à droite correspond aux données de point de vue obtenues en demandant à de nombreuses personnes de regarder l'image en haut à gauche et de toucher la partie que vous regardez avec le curseur de la souris.

Si vous appliquez un filtre gaussien basé sur ces données, vous pouvez créer une carte qui montre la probabilité (0 à 1) qu'il y ait un point de vue en unités de pixels, comme indiqué en bas à gauche. Il s'agit des données d'entraînement de la carte de saillance que vous souhaitez calculer.

Comme indiqué en bas à droite, si vous le colorez et le superposez sur l'image, vous pouvez voir qu'il y a une forte probabilité que le chat soit remarqué. Lorsque la probabilité du point de vue est proche de 1, elle est rouge et lorsque la probabilité du point de vue est proche de 0, elle est bleue.

Implémentation de la méthode Ardizzone

Implémentons la méthode d'Ardizzone. Pour le moment, Saliency Map utilisera les données d'entraînement de l'ensemble de données SALICON telles quelles. Il s'agit de l'image du chat et des données d'apprentissage (données de réponse correctes) de la carte de saillance.

Quel genre de méthode?

C'est une méthode de recadrage pour inclure tous les pixels au-dessus d'une certaine probabilité. Cela signifie que vous ne devez le placer que dans des endroits où vous risquez de le voir.

fig2.png Figure: Pipeline de méthodes d'Ardizzone (extrait de l'article [^ 2])

Pour résumer ce chiffre en mots, il y a les trois étapes suivantes.

--Saliency Map est binarisé avec un certain seuil (défini sur 1 et 0) --Trouver un cadre englobant la plage de 1 --Cadrer l'image d'origine avec le cadre de sélection

Binar

La binarisation est facile avec NumPy. NumPy diffuse également le calcul des opérateurs de comparaison (> et ==), donc si vous exécutez ndarray> float, vous obtiendrez True ou False de chaque élément et la binarisation est terminée. ..

Cliquez ici pour voir l'implémentation
threshhold = 0.3 #Définir le seuil, float (0<threshhold<1)
saliencymap_path = 'COCO_train2014_000000196971.png' #Chemin de la carte de saillance

saliencymap = cv2.imread(saliencymap_path)[:, :, ::-1] # ndarray, (h, w, rgb), np.uint8 (0-255)
saliencymap = saliencymap[:, :, 0] # ndarray, (h, w), np.uint8 (0-255)
plt.imshow(saliencymap)
plt.show()

threshhold *= 255 #La carte de saillance lue à partir de l'image est 0-Puisqu'il est 255, convertissez la plage

binarized_saliencymap = saliencymap>threshhold # ndarray, (h, w), bool

plt.imshow(binarized_saliencymap)
plt.show()

fig3.png Figure: Résultats de la binarisation

Le résultat est comme indiqué sur cette figure. Par défaut, plt.imshow () de matplotlib affiche les grandes valeurs en jaune et les petites valeurs en violet.

Le seuil est un hyper paramètre qui peut être défini de manière arbitraire. Cette fois, il est unifié à 0,3 tout au long de l'article.

Demandez une boîte englobante

Calculez une ** boîte englobante ** (juste un rectangle fermé) qui contient tous les 1 (True) obtenus par binarisation.

Ceci est implémenté dans `cv2.boundingRect () 'd'OpenCV et peut être réalisé en l'appelant simplement.

Structural Analysis and Shape Descriptors — OpenCV 2.4.13.7 documentation

[Fonctionnalités de zone (contour) - documentation OpenCV-Python Tutorials 1](http://labs.eecs.tottori-u.ac.jp/sd/Member/oyamada/OpenCV/html/py_tutorials/py_imgproc/py_contours/py_contour_features/ py_contour_features.html)

Utilisez patches.Rectangle () pour dessiner un rectangle dans matplotlib.

matplotlib.patches.Rectangle — Matplotlib 3.1.1 documentation

Cliquez ici pour voir l'implémentation
import matplotlib.patches as patches

#Convertir en un format qui peut être géré par OpenCV
binarized_saliencymap = binarized_saliencymap.astype(np.uint8) # ndarray, (h, w), np.uint8 (0 or 1)

box_x, box_y, box_width, box_height = cv2.boundingRect(binarized_saliencymap)
# box_x : int,Coordonnée x en haut à gauche pour rogner, box_y : int,Coordonnée y en haut à gauche pour rogner
# box_width : int,Largeur de découpe, box_height : int,Hauteur de découpe

#Dessin de la boîte englobante
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#00FF00', fill=False)
ax.imshow(binarized_saliencymap)
ax.add_patch(bounding_box)
plt.show()

fig4.png Figure: Résultat de l'obtention du cadre de sélection

Vous pouvez obtenir la boîte englobante comme indiqué sur cette figure. Les informations du rectangle sont conservées en tant que coordonnée supérieure gauche et valeur de largeur / hauteur.

coupé

Recadrez l'image en fonction du cadre de sélection obtenu. Découpez le ndarray de l'image en utilisant la valeur dans le cadre de sélection.

Cliquez ici pour voir l'implémentation
image_path = 'COCO_train2014_000000196971.jpg' #Chemin de l'image
image = cv2.imread(image_path)[:, :, ::-1] # ndarray, (h, w, rgb), np.uint8 (0-255)

cropped_image = image[box_y:box_y+box_height, box_x:box_x+box_width] # ndarray, (h, w, rgb), np.uint8 (0-255)

#Visualisation
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#00FF00', fill=False)
ax.imshow(image)
ax.add_patch(bounding_box)
plt.show()

plt.imshow(cropped_image)
plt.show()

fig5.png Figure: Résultats découpés par la méthode Ardizzone

Comme le montre cette figure, une image a été obtenue dans laquelle la ligne de visée était susceptible d'être dirigée.

Superposition avec la carte de saillance colorisée

Pour faciliter la visualisation de son traitement, superposons la carte de saillance colorisée et le cadre de sélection sur l'image. Implémentez une fonction pour colorer la carte de saillance et une fonction pour superposer la carte de saillance sur l'image.

Cliquez ici pour voir l'implémentation
def color_saliencymap(saliencymap):
    """
Colorez et visualisez la carte de saillance. 1 est rouge et 0 est bleu.
    
    Parameters
    ----------------
    saliencymap : ndarray, np.uint8, (h, w) or (h, w, rgb)
    
    Returns
    ----------------
    saliencymap_colored : ndarray, np.uint8, (h, w, rgb)
    """
    assert saliencymap.dtype==np.uint8
    assert (saliencymap.ndim == 2) or (saliencymap.ndim == 3)
    
    saliencymap_colored = cv2.applyColorMap(saliencymap, cv2.COLORMAP_JET)[:, :, ::-1]
    
    return saliencymap_colored

def overlay_saliencymap_and_image(saliencymap_color, image):
    """
Superposez l'image avec la carte de saillance.
    
    Parameters
    ----------------
    saliencymap_color : ndarray, (h, w, rgb), np.uint8
    image : ndarray, (h, w, rgb), np.uint8
    
    Returns
    ----------------
    overlaid_image : ndarray(h, w, rgb)
    """
    assert saliencymap_color.ndim==3
    assert saliencymap_color.dtype==np.uint8
    assert image.ndim==3
    assert image.dtype==np.uint8
    im_size = (image.shape[1], image.shape[0])
    saliencymap_color = cv2.resize(saliencymap_color, im_size, interpolation=cv2.INTER_CUBIC)
    overlaid_image = cv2.addWeighted(src1=image, alpha=1, src2=saliencymap_color, beta=0.7, gamma=0)
    return overlaid_image

saliencymap_colored = color_saliencymap(saliencymap) # ndarray, (h, w, rgb), np.uint8
overlaid_image = overlay_saliencymap_and_image(saliencymap_colored, image) # ndarray, (h, w, rgb), np.uint8

#Visualisation
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#00FF00', fill=False)
ax.imshow(overlaid_image)
ax.add_patch(bounding_box)
plt.show()

fig_6.png Figure: Image de la carte de saillance colorisée et du cadre de sélection superposés

Comme le montre cette figure, vous pouvez voir que les zones de la carte de saillance susceptibles de devenir rouges sont entourées.

Prise en charge de tous les formats d'image

Avec la méthode d'Ardizzone, la taille et le rapport hauteur / largeur dépendent de la carte de saillance. Mais maintenant que je veux recadrer à un certain rapport hauteur / largeur, je dois y réfléchir.

Découpez pour que la valeur totale de la carte de saillance soit grande

Je n'ai pas trouvé de méthode existante pour cela, j'ai donc décidé d'utiliser l'algorithme suivant pour déterminer la plage à découper.

Utilisez la plage obtenue autant que possible pour trouver la plage qui maximise la valeur totale de la carte de saillance.

Créez une «classe SaliencyBasedImageCropping» pour le recadrage et résumez le code ci-dessous.

Cliquez ici pour voir l'implémentation
import copy

import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

class SaliencyBasedImageCropping:
    """
Une classe pour recadrer des images à l'aide de Saliency Map. Une méthode qui utilise toute la plage qui dépasse un certain seuil[1]Utiliser.
    
* Si aucun pixel ne dépasse le seuil, l'image entière est renvoyée.

    [1] Ardizzone, Edoardo, Alessandro Bruno, and Giuseppe Mazzola. "Saliency based image cropping." International Conference on Image Analysis and Processing. Springer, Berlin, Heidelberg, 2013.
        
    Parameters
    ----------------
    aspect_rate : tuple of int (x, y)
Si vous spécifiez le rapport hauteur / largeur ici,[1]Trouvez la plage qui maximise la valeur totale de la carte de saillance tout en utilisant la plage obtenue par la méthode de.
    min_size : tuple of int (w, h)
        [1]Si chaque axe de la plage obtenue par la méthode de est plus petit que cette valeur, la plage est étendue uniformément à partir du centre de la plage.
    
    Attributes
    ----------------
    self.aspect_rate : tuple of int (x, y)
    self.min_size : tuple of int (w, h)
    im_size : tuple of int (w, h)
    self.bounding_box_based_on_binary_saliency : list
        [1]Gamme obtenue par la méthode de
        box_x : int
        box_y : int
        box_width : int
        box_height : int
    self.bounding_box : list
La plage de recadrage finale avec un rapport hauteur / largeur ajusté
        box_x : int
        box_y : int
        box_width : int
        box_height : int
    """
    def __init__(self, aspect_rate=None, min_size=(200, 200)):
        assert (aspect_rate is None)or((type(aspect_rate)==tuple)and(len(aspect_rate)==2))
        assert (type(min_size)==tuple)and(len(min_size)==2)
        self.aspect_rate = aspect_rate
        self.min_size = min_size
        self.im_size = None
        self.bounding_box_based_on_binary_saliency = None
        self.bounding_box = None
    
    def _compute_bounding_box_based_on_binary_saliency(self, saliencymap, threshhold):
        """
Méthode Ardizzone[1]Trouvez la plage de rognage en vous basant sur la carte de saillance.
        
        Parameters:
        -----------------
        saliencymap : ndarray, (h, w), np.uint8
            0<=saliencymap<=255
            
        threshhold : float
            0<threshhold<255
        
        Returns:
        -----------------
        bounding_box_based_on_binary_saliency : list
            box_x : int
            box_y : int
            box_width : int
            box_height : int
        
        """
        assert (threshhold>0)and(threshhold<255)
        assert saliencymap.dtype==np.uint8
        assert saliencymap.ndim==2
        
        binarized_saliencymap = saliencymap>threshhold
        #Si aucun pixel de la carte de saillance ne dépasse le seuil, traitez tout comme dépassé.
        if saliencymap.sum()==0:
            saliencymap+=True
        binarized_saliencymap = (binarized_saliencymap.astype(np.uint8))*255
        # binarized_saliencymap : ndarray, (h, w), uint8, 0 or 255
        
        #Les petites zones sont effacées par le traitement de la morphologie (ouverture)
        kernel_size = round(min(self.im_size)*0.02)
        kernel = np.ones((kernel_size, kernel_size))
        binarized_saliencymap = cv2.morphologyEx(binarized_saliencymap, cv2.MORPH_OPEN, kernel)

        box_x, box_y, box_width, box_height = cv2.boundingRect(binarized_saliencymap)
        bounding_box_based_on_binary_saliency = [box_x, box_y, box_width, box_height]
        return bounding_box_based_on_binary_saliency
        
    def _expand_small_bounding_box_to_minimum_size(self, bounding_box):
        """
Si la plage est inférieure à la taille spécifiée, élargissez-la. Répartissez uniformément la cuisinière à partir du centre de la cuisinière. S'il sort de l'image, étalez-le sur le côté opposé.
        
        Parameters:
        -----------------
        bounding_box : list
            box_x : int
            box_y : int
            box_width : int
            box_height : int
        """
        bounding_box = copy.copy(bounding_box) #Copie profonde parce que je souhaite conserver les valeurs de la liste d'origine
        
        # axis=0 : x and witdth, axis=1 : y and hegiht
        for axis in range(2):
            if bounding_box[axis+2]<self.min_size[axis+0]:
                bounding_box[axis+0] -= np.floor((self.min_size[axis+0]-bounding_box[axis+2])/2).astype(np.int)
                bounding_box[axis+2] = self.min_size[axis+0]
                if bounding_box[axis+0]<0:
                    bounding_box[axis+0] = 0
                if (bounding_box[axis+0]+bounding_box[axis+2])>self.im_size[axis+0]:
                    bounding_box[axis+0] -= (bounding_box[axis+0]+bounding_box[axis+2]) - self.im_size[axis+0]
        return bounding_box
    
    def _expand_bounding_box_to_specified_aspect_ratio(self, bounding_box, saliencymap):
        """
Développez la plage afin qu'elle ait le rapport hauteur / largeur spécifié.
Méthode Ardizzone[1]Trouvez la plage qui maximise la valeur totale de la carte de saillance tout en utilisant autant que possible la plage obtenue à l'étape 2.
        
        Parameters
        ----------------
        bounding_box : list
            box_x : int
            box_y : int
            box_width : int
            box_height : int
        saliencymap : ndarray, (h, w), np.uint8
            0<=saliencymap<=255
        """
        assert saliencymap.dtype==np.uint8
        assert saliencymap.ndim==2
        
        bounding_box = copy.copy(bounding_box)

        # axis=0 : x and witdth, axis=1 : y and hegiht    
        if bounding_box[2]>bounding_box[3]:
            long_length_axis = 0
            short_length_axis = 1
        else:
            long_length_axis = 1
            short_length_axis = 0
        
        #Dans quelle direction s'étirer
        rate1 = self.aspect_rate[long_length_axis]/self.aspect_rate[short_length_axis]
        rate2 = bounding_box[2+long_length_axis]/bounding_box[2+short_length_axis]
        if rate1>rate2:
            moved_axis = long_length_axis
            fixed_axis = short_length_axis
        else:
            moved_axis = short_length_axis
            fixed_axis = long_length_axis
        
        fixed_length = bounding_box[2+fixed_axis]
        moved_length = int(round((fixed_length/self.aspect_rate[fixed_axis])*self.aspect_rate[moved_axis]))
        if moved_length > self.im_size[moved_axis]:
            #S'il dépasse la taille de l'image lorsqu'elle est étirée
            moved_axis, fixed_axis = fixed_axis, moved_axis
            fixed_length = self.im_size[fixed_axis]
            moved_length = int(round((fixed_length/self.aspect_rate[fixed_axis])*self.aspect_rate[moved_axis]))
            fixed_point = 0
            start_point = bounding_box[moved_axis]
            end_point = bounding_box[moved_axis]+bounding_box[2+moved_axis]
            if fixed_axis==0:
                saliencymap_extracted = saliencymap[start_point:end_point, :]
            elif fixed_axis==1:
                saliencymap_extracted = saliencymap[:, start_point:end_point:]
        else:   
            #Lorsqu'il est étiré pour s'adapter à la taille de l'image
            start_point = int(bounding_box[moved_axis]+bounding_box[2+moved_axis]-moved_length)
            if start_point<0:
                start_point = 0
            end_point = int(bounding_box[moved_axis]+moved_length)
            if end_point>self.im_size[moved_axis]:
                end_point = self.im_size[moved_axis]
            if fixed_axis==0:
                fixed_point = bounding_box[fixed_axis]
                saliencymap_extracted = saliencymap[start_point:end_point, fixed_point:fixed_point+fixed_length]
            elif fixed_axis==1:
                fixed_point = bounding_box[fixed_axis]
                saliencymap_extracted = saliencymap[fixed_point:fixed_point+fixed_length, start_point:end_point]
        saliencymap_summed_1d = saliencymap_extracted.sum(moved_axis)
        self.saliencymap_summed_slided = np.convolve(saliencymap_summed_1d, np.ones(moved_length), 'valid')
        moved_point = np.array(self.saliencymap_summed_slided).argmax() + start_point
        
        if fixed_axis==0:
            bounding_box = [fixed_point, moved_point, fixed_length, moved_length]
        elif fixed_axis==1:
            bounding_box = [moved_point, fixed_point, moved_length, fixed_length]
        return bounding_box
    
    def crop_center(self, image):
        """     
Recadrez le centre de l'image avec le rapport hauteur / largeur spécifié sans utiliser la carte de saillance.
        
        Parameters:
        -----------------
        image : ndarray, (h, w, rgb), uint8
            
        Returns:
        -----------------
        cropped_image : ndarray, (h, w, rgb), uint8
        """        
        assert image.dtype==np.uint8
        assert image.ndim==3        
        
        im_size = (image.shape[1], image.shape[0]) # tuple of int, (width, height)
        center = (int(round(im_size[0]/2)), int(round(im_size[1]/2))) # tuple of int, (x, y)
        
        if im_size[0]>im_size[1]:
            box_y = 0
            box_height = im_size[1]
            box_width = int(round((im_size[1]/self.aspect_rate[1])*self.aspect_rate[0]))
            if box_width>im_size[0]:
                box_x = 0
                box_width = im_size[0]
                box_height = int(round((im_size[0]/self.aspect_rate[0])*self.aspect_rate[1]))
                box_y = int(round(center[1]-(box_height/2)))
            else:
                box_x = int(round(center[0]-(box_width/2)))

        else:
            box_x = 0
            box_width = im_size[0]
            box_height = int(round((im_size[0]/self.aspect_rate[0])*self.aspect_rate[1]))
            if box_height>im_size[1]:
                box_y = 0
                box_height = im_size[1]
                box_width = int(round((im_size[1]/self.aspect_rate[1])*self.aspect_rate[0]))
                box_y = int(round(center[0]-(box_width/2)))
            else:
                box_y = int(round(center[1]-(box_height/2)))
        
        cropped_image = image[box_y:box_y+box_height, box_x:box_x+box_width]
        return cropped_image
    
    def crop(self, image, saliencymap, threshhold=0.3):
        """     
Recadrage à l'aide de la carte de saillance.
        
        Parameters:
        -----------------
        image : ndarray, (h, w, rgb), np.uint8
        saliencymap : ndarray, (h, w), np.uint8
            Saliency map's ndarray need not be the same size as image's ndarray. Saliency map is resized within this method.
        threshhold : float
            0 < threshhold <1
            
        Returns:
        -----------------
        cropped_image : ndarray, (h, w, rgb), uint8
        """
        assert (threshhold>0)and(threshhold<1)
        assert image.dtype==np.uint8
        assert image.ndim==3
        assert saliencymap.dtype==np.uint8
        assert saliencymap.ndim==2
        
        threshhold = threshhold*255 # scale to 0 - 255
        self.im_size = (image.shape[1], image.shape[0]) # (width, height)
        saliencymap = cv2.resize(saliencymap, self.im_size, interpolation=cv2.INTER_CUBIC)

        # compute bounding box based on saliency map
        bounding_box_based_on_binary_saliency = self._compute_bounding_box_based_on_binary_saliency(saliencymap, threshhold)
        bounding_box = self._expand_small_bounding_box_to_minimum_size(bounding_box_based_on_binary_saliency)
        if self.aspect_rate is not None:
            bounding_box = self._expand_bounding_box_to_specified_aspect_ratio(bounding_box, saliencymap)
            
        box_y = bounding_box[1]
        box_x = bounding_box[0]
        box_height = bounding_box[3]
        box_width = bounding_box[2]
        
        cropped_image = image[box_y:box_y+box_height, box_x:box_x+box_width]
        
        self.bounding_box_based_on_binary_saliency = bounding_box_based_on_binary_saliency
        self.bounding_box  = bounding_box
        
        return cropped_image

# -------------------
# SETTING
threshhold = 0.3 #Définir le seuil, float (0<threshhold<1)
saliencymap_path = 'COCO_train2014_000000196971.png' #Chemin de la carte de saillance
image_path = 'COCO_train2014_000000196971.jpg' #Chemin de l'image
# -------------------

saliencymap = cv2.imread(saliencymap_path)[:, :, ::-1] # ndarray, (h, w, rgb), np.uint8 (0-255)
saliencymap = saliencymap[:, :, 0] # ndarray, (h, w), np.uint8 (0-255)
image = cv2.imread(image_path)[:, :, ::-1] # ndarray, (h, w, rgb), np.uint8 (0-255)

#Visualisation d'images recadrées à l'aide de la carte de saillance
cropper = SaliencyBasedImageCropping(aspect_rate=(1, 3))
cropped_image = cropper.crop(image, saliencymap, threshhold=0.3)
plt.imshow(cropped_image)
plt.show()

#Visualisation de la carte de saillance et de la boîte englobante
#Rouge pour le rapport hauteur / largeur spécifié, vert avant la correspondance
saliencymap_colored = color_saliencymap(saliencymap) # ndarray, (h, w, rgb), np.uint8
overlaid_image = overlay_saliencymap_and_image(saliencymap_colored, image) # ndarray, (h, w, rgb), np.uint8
box_x, box_y, box_width, box_height = cropper.bounding_box
box_x_0, box_y_0, box_width_0, box_height_0 = cropper.bounding_box_based_on_binary_saliency
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#FF0000', fill=False)
bounding_box_based_on_binary_saliency = patches.Rectangle(xy=(box_x_0, box_y_0), width=box_width_0, height=box_height_0, ec='#00FF00', fill=False)
ax.imshow(overlaid_image)
ax.add_patch(bounding_box)
ax.add_patch(bounding_box_based_on_binary_saliency)
plt.show()

#Visualisation de l'image avec centre recadré pour comparaison
center_cropped_image = cropper.crop_center(image)
plt.imshow(center_cropped_image)
plt.show()

Np.convolve () est utilisé pour trouver la plage où la carte de saillance a la valeur maximale.

numpy.convolve — NumPy v1.17 Manual

Il s'agit d'une fonction de convolution unidimensionnelle. En convoluant avec un tableau de tous les 1 de la longueur que vous souhaitez additionner, vous pouvez calculer la somme pour chaque plage fixe comme indiqué ci-dessous.

array_1d = np.array([1, 2, 3, 4])
print(np.convolve(array_1d, np.ones(2), 'valid')) # [3. 5. 7.]

Si vous utilisez une simple instruction for sur Python, elle ralentira, nous combinerons donc autant que possible les fonctions NumPy.

De plus, dans le processus de binarisation, nous avons ajouté une implémentation qui efface une très petite zone par conversion de morphologie. Surtout lorsque la carte de saillance est obtenue par apprentissage en profondeur après cela, une telle zone est susceptible de se produire, cette implémentation est donc ajoutée.

Conversion de morphologie - documentation OpenCV-Python Tutorials 1

Voir le résultat

Le cadre de sélection vert est utilisé avant de régler le rapport hauteur / largeur et le cadre de sélection rouge est utilisé après le réglage du rapport hauteur / largeur.

fig_7.png Figure: Résultats du recadrage avec un rapport hauteur / largeur de 1: 3 à l'aide de la carte de saillance

Comme le montre cette figure (a), un chat et un savon pour les mains à une distance verticale? J'ai réussi à découper en insérant. Par rapport à la figure (b) où le centre vient d'être découpé, la partie que les humains veulent voir est placée dans un bon sentiment.

fig8.png Figure: Résultats du recadrage avec un rapport hauteur / largeur de 1: 1 à l'aide de la carte de saillance

C'est le cas pour un carré (1: 1). Lorsque le centre est découpé (Fig. (B)), le chat est fermement contenu, mais lors de l'utilisation de la carte de saillance (Fig. (A)), une zone plus étroite est découpée, donc elle est affichée dans la même taille. Si le chat grossit. Il est important non seulement de savoir si un objet est affiché, mais aussi s'il est montré dans une taille suffisante pour le recadrage.

Implémentation d'un modèle (SalGAN) qui estime la Saliency Map en utilisant le deep learning avec PyTorch

Vous ne pouvez pas recadrer l'image que vous avez préparée vous-même. Je veux recadrer l'image de l'arbre de Noël que j'ai pris moi-même, pas l'image du jeu de données SALICON, donc j'utiliserai l'apprentissage en profondeur pour créer un modèle d'estimation de la carte de saillance.

Si vous regardez le site de référence des tâches de Saliency Map "MIT Saliency Benchmark" [^ 4], vous pouvez voir différentes méthodes, mais cette fois je vais essayer d'implémenter SalGAN [^ 5]. Le score ne semble pas très élevé, mais j'ai choisi cela parce que le mécanisme semblait simple.

L'implémentation de l'auteur [^ 6] a également été publiée, mais comme le framework n'est pas très familier avec Lasagne (Theano), je vais l'écrire dans PyTorch en y faisant référence.

Qu'est-ce que SalGAN?

«SalGAN: Visual Saliency Prediction with Generative Adversarial Networks» est un article publié en 2017. Comme son nom l'indique, il s'agit d'une méthode d'estimation de la carte de saillance en utilisant ** GAN (Generative Adversarial Networks) **.

Je vais omettre l'explication sur le GAN car il existe déjà de nombreux articles faciles à comprendre. Par exemple, GAN (1) Comprendre la structure de base que je n'entends plus - Qiita est recommandé. Si vous connaissez une méthode GAN typique qui a beaucoup d'implémentation et d'explications, vous pouvez l'implémenter en tenant compte de la différence.

fig_salgan.png Figure: Structure globale de SalGAN (tirée de l'article [^ 5])

Puisque la carte de saillance a un point de vue (0 à 1) pour chaque pixel, on peut dire qu'il s'agit d'un problème de classification binaire pour chaque pixel. C'est proche de la segmentation à classe unique. Puisque nous voulons entrer une image et générer une image (Saliency Map), nous avons un ** modèle Encoder-Decoder ** qui utilise CNN comme indiqué dans cette figure. Pix2Pix [^ 7] est célèbre en matière d'image à image utilisant le GAN, mais il n'a pas une structure U-Net comme celle-là.

Dans le modèle Encoder-Decoder, vous pouvez également apprendre à réduire la carte de saillance en sortie et ** Entropie croisée binaire ** des données de réponse correctes. Cependant, ce SalGAN tente d'améliorer la précision en ajoutant un réseau (** Discriminator **) qui classe la carte de saillance comme des données correctes ou des données estimées.

La fonction de perte de la partie Encoder-Decoder (** Generator **) est la suivante. En plus de la perte d'adversaire habituelle, la carte de saillance estimée et la section Entropie croisée binaire des données de réponse correcte sont ajoutées. Ajustez le pourcentage avec le paramètre hyper $ \ alpha $.

\mathcal{L}\_{BCE} = -\frac{1}{N}\sum\_{j=1}^{N}(S\_{j}\log{(\hat{S}\_{j})}+(1-S\_{j})\log{(1-\hat{S}\_{j})}).
\mathcal{L} = \alpha\cdot\mathcal{L}\_{BCE} + L(D(I, \hat{S}), 1).

La fonction de perte de Discriminator est la suivante. C'est une forme générale.

\mathcal{L}\_{\mathcal{D}} = L(D(I, S), 1)+L(D(I, \hat{S}),0).

Lire un peu plus

En citant l'article, nous lirons les informations nécessaires à la mise en œuvre. J'ai en quelque sorte compris en regardant la structure globale, mais je chercherai la partie où sont écrites les informations que je veux en savoir un peu plus.

The encoder part of the network is identical in architecture to VGG-16 (Simonyan and Zisserman, 2015), omitting the final pooling and fully connected layers. The network is initialized with the weights of a VGG-16 model trained on the ImageNet data set for object classification (Deng et al., 2009). Only the last two groups of convolutional layers in VGG-16 are modified during the training for saliency prediction, while the earlier layers remain fixed from the original VGG-16 model.

--VGG16 est utilisé pour CNN de la partie encodeur du générateur

The decoder architecture is structured in the same way as the encoder, but with the ordering of layers reversed, and with pooling layers being replaced by upsampling layers. Again, ReLU non-linearities are used in all convolution layers, and a final 1 × 1 convolution layer with sigmoid non-linearity is added to produce the saliency map. The weights for the decoder are randomly initialized. The final output of the network is a saliency map in the same size to input image.

The input to the discriminator network is an RGBS image of size 256×192×4 containing both the source image channels and (predicted or ground truth) saliency.

--Injectez non seulement la carte de saillance, mais aussi l'image originale dans Discriminator sur 4 canaux. --En premier lieu, entrez l'image à 256 x 192

We train the networks on the 15,000 images from the SALICON training set using a batch size of 32.

--Utilisez 15000 images de l'ensemble de données SALICON

Je ne veux pas faire une expérience de reproduction du papier cette fois, donc je ne suis pas particulier sur les détails lors de sa mise en œuvre. Par exemple, au lieu de ReLU dans le document, Leaky ReLU, qui est généralement considéré comme efficace, est utilisé.

Ecrire le code

J'écrirai le code. Puisque la base est GAN du modèle Encoder-Decoder de CNN, nous nous référerons à l'implémentation de la méthode similaire existante. Par exemple, le GitHub d'eriklindernoren a divers GAN implémentés dans PyTorch. L'implémentation DCGAN [^ 8] semble bonne.

Créer une classe Generator and Discriminator

Generator utilise VGG16 formé avec ImageNet, qui est fourni dans torchvision [^ 9]. Dans SalGAN, le poids est fixé sur la face avant et l'apprentissage est sur la face arrière, alors décrivez-le en couches séparées comme torchvision.models.vgg16 (pretrained = True) .features [: 17]. Vous pouvez vérifier quel nombre est quel calque avec print (torchvision.models.vgg16 (pretrained = True) .features).

torchvision.models — PyTorch master documentation

Cliquez ici pour voir l'implémentation
from torch import nn
import torchvision

class Generator(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()        
        self.encoder_first = torchvision.models.vgg16(pretrained=True).features[:17] #La pièce à utiliser avec un poids fixe
        self.encoder_last = torchvision.models.vgg16(pretrained=True).features[17:-1] #Partie à apprendre
        self.decoder = nn.Sequential(
                    nn.Conv2d(512, 512, 3, padding=1), 
                    nn.LeakyReLU(),
                    nn.Conv2d(512, 512, 3, padding=1), 
                    nn.LeakyReLU(),
                    nn.Conv2d(512, 512, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(512, 512, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(512, 512, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(512, 512, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(512, 256, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(256, 256, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(256, 256, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(256, 128, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(128, 128, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(128, 64, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(64, 64, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(64, 1, 1, padding=0),
                    nn.Sigmoid())

    def forward(self, x):
        x = self.encoder_first(x)
        x = self.encoder_last(x)
        x = self.decoder(x)
        return x

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
                    nn.Conv2d(4, 3, 1, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.Conv2d(3, 32, 3, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.MaxPool2d(2, stride=2),
                    nn.Conv2d(32, 64, 3, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.Conv2d(64, 64, 3, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.MaxPool2d(2, stride=2),
                    nn.Conv2d(64, 64, 3, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.Conv2d(64, 64, 3, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.MaxPool2d(2, stride=2))
        self.classifier = nn.Sequential(
                    nn.Linear(64*32*24, 100, bias=True),
                    nn.Tanh(),
                    nn.Linear(100, 2, bias=True),
                    nn.Tanh(),
                    nn.Linear(2, 1, bias=True),
                    nn.Sigmoid())

    def forward(self, x):
        x = self.main(x)
        x = x.view(x.shape[0], -1)
        x = self.classifier(x)
        return x

Créer une classe de jeu de données

Vous avez besoin d'une classe d'ensemble de données pour lire l'ensemble de données SALICON. Il est un peu difficile d'écrire en fonction de l'ensemble de données et de la tâche préparés. En ce qui concerne l'écriture, le didacticiel PyTorch sera utile.

Writing Custom Datasets, DataLoaders and Transforms — PyTorch Tutorials 1.3.1 documentation

Le prétraitement à l'aide de «torchvision.transforms» est également décrit ici. Cette fois, nous ne redimensionnerons qu'à 192 x 256 et normaliserons.

torchvision.transforms — PyTorch master documentation

L'ensemble de données SALICON peut être téléchargé depuis LSUN'17 Saliency Prediction Challenge | SALICON.

Cliquez ici pour voir l'implémentation
import os

import torch.utils.data as data
import torchvision.transforms as transforms

class SALICONDataset(data.Dataset):
    def __init__(self, root_dataset_dir, val_mode = False):
        """
Classe de jeu de données pour la lecture du jeu de données SALICON
        
        Parameters:
        -----------------
        root_dataset_dir : str
Le chemin du répertoire au-dessus du jeu de données SALICON
        val_mode : bool (default: False)
Si False, lisez les données du train. Si True, lisez les données de validation.
        """
        self.root_dataset_dir = root_dataset_dir
        self.imgsets_dir = os.path.join(self.root_dataset_dir, 'SALICON/image_sets')
        self.img_dir = os.path.join(self.root_dataset_dir, 'SALICON/imgs')
        self.distribution_target_dir = os.path.join(self.root_dataset_dir, 'SALICON/algmaps')
        self.img_tail = '.jpg'
        self.distribution_target_tail = '.png'
        self.transform = transforms.Compose([transforms.Resize((192, 256)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
        self.distribution_transform = transforms.Compose([transforms.Resize((192, 256)), transforms.ToTensor()])
        
        if val_mode:
            train_or_val = "val"
        else:
            train_or_val = "train"
        imgsets_file = os.path.join(self.imgsets_dir, '{}.txt'.format(train_or_val))
        files = []
        for data_id in open(imgsets_file).readlines():
            data_id = data_id.strip()
            img_file = os.path.join(self.img_dir, '{0}{1}'.format(data_id, self.img_tail))
            distribution_target_file = os.path.join(self.distribution_target_dir, '{0}{1}'.format(data_id, self.distribution_target_tail))
            files.append({
                'img': img_file,
                'distribution_target': distribution_target_file,
                'data_id': data_id
            })
        self.files = files
        
    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        """
        Returns
        -----------
        data : list
            [img, distribution_target, data_id]
        """
        data_file = self.files[index]
        data = []

        img_file = data_file['img']
        img = Image.open(img_file)
        data.append(img)

        distribution_target_file = data_file['distribution_target']
        distribution_target = Image.open(distribution_target_file)
        data.append(distribution_target)
        
        # transform
        data[0] = self.transform(data[0])
        data[1] = self.distribution_transform(data[1])

        data.append(data_file['data_id'])
        return data

apprendre

Écrivez le code pour le reste de l'apprentissage. Le point est de savoir comment calculer la fonction de perte et comment apprendre Générateur et Discriminateur.

Il faut environ plusieurs heures pour apprendre les mêmes 120 époques que le papier utilisant le GPU.

Cliquez ici pour voir l'implémentation
from datetime import datetime

import torch
from torch.autograd import Variable

#-----------------
# SETTING
root_dataset_dir = "" #Le chemin du répertoire au-dessus du jeu de données SALICON
alpha = 0.005 #Hyper paramètre de la fonction de perte du générateur. La valeur recommandée pour le papier est 0.005
epochs = 120
batch_size = 32 #32 dans le papier
#-----------------

#Utiliser l'heure de début pour le nom du fichier
start_time_stamp = '{0:%Y%m%d-%H%M%S}'.format(datetime.now())

save_dir = "./log/"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#Chargement du chargeur de données
train_dataset = SALICONDataset(
                    root_dataset_dir,
                )
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle=True, num_workers = 4, pin_memory=True, sampler=None)
val_dataset = SALICONDataset(
                    root_dataset_dir,
                    val_mode=True
                )
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = 1, shuffle=False, num_workers = 4, pin_memory=True, sampler=None)

#Modèle de charge et fonction de perte
loss_func = torch.nn.BCELoss().to(DEVICE)
generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)

#Définition de la méthode d'optimisation (en utilisant les paramètres du papier)
optimizer_G = torch.optim.Adagrad([
                {'params': generator.encoder_last.parameters()},
                {'params': generator.decoder.parameters()}
            ], lr=0.0001, weight_decay=3*0.0001)
optimizer_D = torch.optim.Adagrad(discriminator.parameters(), lr=0.0001, weight_decay=3*0.0001)

#Apprentissage
for epoch in range(epochs):
    n_updates = 0 #Nombre d'itérations
    n_discriminator_updates = 0
    n_generator_updates = 0
    d_loss_sum = 0
    g_loss_sum = 0
    
    for i, data in enumerate(train_loader):
        imgs = data[0] # ([batch_size, rgb, h, w])
        salmaps = data[1] # ([batch_size, 1, h, w])

        #Créer une étiquette pour Discriminateur
        valid = Variable(torch.FloatTensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False).to(DEVICE)
        fake = Variable(torch.FloatTensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False).to(DEVICE)

        imgs = Variable(imgs).to(DEVICE)
        real_salmaps = Variable(salmaps).to(DEVICE)

        #Apprenez alternativement Générateur et Discriminateur pour chaque itération
        if n_updates % 2 == 0:
            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()
            gen_salmaps = generator(imgs)
            
            #Combinez l'image originale et la carte de saillance générée pour l'entrée dans le discriminateur pour créer un tableau à 4 canaux
            fake_d_input = torch.cat((imgs, gen_salmaps.detach()), 1) # ([batch_size, rgbs, h, w])
            
            #Calculer la fonction de perte du générateur
            g_loss1 = loss_func(gen_salmaps, real_salmaps)
            g_loss2 = loss_func(discriminator(fake_d_input), valid)
            g_loss = alpha*g_loss1 + g_loss2
            
            g_loss.backward()
            optimizer_G.step()
            
            g_loss_sum += g_loss.item()
            n_generator_updates += 1
            
        else:
            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()
            
            #Combinez l'image d'origine et la carte de saillance des données de réponse correctes pour l'entrée dans le discriminateur pour créer un tableau à 4 canaux
            real_d_input = torch.cat((imgs, real_salmaps), 1) # ([batch_size, rgbs, h, w])

            #Calculer la fonction de perte de Discriminator
            real_loss = loss_func(discriminator(real_d_input), valid)
            fake_loss = loss_func(discriminator(fake_d_input), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()
            
            d_loss_sum += d_loss.item()
            n_discriminator_updates += 1
            
        n_updates += 1
        if n_updates%10==0:
            if n_discriminator_updates>0:
                print(
                    "[%d/%d (%d/%d)] [loss D: %f, G: %f]"
                    % (epoch, epochs-1, i, len(train_loader), d_loss_sum/n_discriminator_updates , g_loss_sum/n_generator_updates)
                )
            else:
                print(
                    "[%d/%d (%d/%d)] [loss G: %f]"
                    % (epoch, epochs-1, i, len(train_loader), g_loss_sum/n_generator_updates)
                )                
    
    #Économiser des poids
    #Sauver toutes les 5 époques et la dernière époque
    if ((epoch+1)%5==0)or(epoch==epochs-1):
        generator_save_path = '{}.pkl'.format(os.path.join(save_dir, "{}_generator_epoch{}".format(start_time_stamp, epoch)))
        discriminator_save_path = '{}.pkl'.format(os.path.join(save_dir, "{}_discriminator_epoch{}".format(start_time_stamp, epoch)))
        torch.save(generator.state_dict(), generator_save_path)
        torch.save(discriminator.state_dict(), discriminator_save_path)
        
    #Visualisez une partie des données de validation pour chaque époque
    with torch.no_grad():
        print("validation")
        for i, data in enumerate(val_loader):
            image = Variable(data[0]).to(DEVICE)
            gen_salmap = generator(imgs)
            gen_salmap_np = np.array(gen_salmaps.data.cpu())[0, 0]
            
            plt.imshow(np.array(image[0].cpu()).transpose(1, 2, 0))
            plt.show()
            plt.imshow(gen_salmap_np)
            plt.show()
            if i==1:
                break

Estimer la carte de saillance

Entrez une image dans le SalGAN appris et essayez d'estimer la carte de saillance. Voyez comment il est estimé avec des images qui ne sont pas utilisées pour la formation.

Cliquez ici pour voir l'implémentation
generator_path = "" #Chemin du fichier de poids du générateur (pkl) obtenu par apprentissage
image_path = "COCO_train2014_000000196971.jpg " #Le chemin de l'image que vous souhaitez saisir

generator = Generator().to(DEVICE)
generator.load_state_dict(torch.load(generator_path))

image_pil = Image.open(image_path) #L'entrée d'image au format PIL est supposée pour la transformation
image = np.array(image_pil)
plt.imshow(image)
plt.show()

transform = transforms.Compose([transforms.Resize((192, 256)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
image_torch = transform(image_pil)
image_torch = image_torch.unsqueeze(0).to(DEVICE) # (1, rgb, h, w)
with torch.no_grad():
    pred_saliencymap = generator(img_torch)
    pred_saliencymap = np.array(pred_saliencymap.cpu())[0, 0]
pred_saliencymap = pred_saliencymap/pred_saliencymap.sum() #Mise à l'échelle pour que la somme soit 1
pred_saliencymap = ((pred/pred.max())*255).astype(np.uint8) #Np pour qu'il puisse être traité comme une image.Convertir en uint8
plt.imshow(pred_saliencymap)
plt.show()

fig_15.png

Figure: Exemple de carte de saillance estimée par SalGAN

La carte de saillance de cette figure (b) a été estimée. Comparé aux données de réponse correcte (Fig. (C)), cela donne une grande impression, mais la probabilité de convaincre des points autour du lanceur et du frappeur peut être estimée élevée.

Si vous voulez bien apprendre, vous devez vérifier avec divers indicateurs pour la carte de saillance comme cela est fait dans le document. Comme nous ne l'avons pas vérifié maintenant, on ne sait pas combien de résultats sont obtenus par rapport au SalGAN introduit dans l'article. Cette fois, l'essentiel est le recadrage, alors j'aimerais passer à autre chose parce que j'ai fait quelque chose comme ça qualitativement.

Recadrer l'image avec la carte de saillance estimée

Maintenant vous avez ce que vous voulez faire. En combinant la carte de salence estimée par SalGAN avec la classe de recadrage, vous pouvez recadrer l'image joliment.

Cliquez ici pour voir l'implémentation
# -------------------
# SETTING
threshhold = 0.3 #Définir le seuil, float (0<threshhold<1)
generator_path = "" #Chemin du fichier de poids du générateur (pkl) obtenu par apprentissage
image_path = "COCO_train2014_000000196971.jpg " #Chemin de l'image que vous souhaitez recadrer
# -------------------

generator = Generator().to(DEVICE)
generator.load_state_dict(torch.load(generator_path))

image_pil = Image.open(image_path) #L'entrée d'image au format PIL est supposée pour la transformation
image = np.array(image_pil)
plt.imshow(image)
plt.show()

transform = transforms.Compose([transforms.Resize((192, 256)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
image_torch = transform(image_pil)
image_torch = image_torch.unsqueeze(0).to(DEVICE) # (1, rgb, h, w)
with torch.no_grad():
    pred_saliencymap = generator(image_torch)
    pred_saliencymap = np.array(pred_saliencymap.cpu())[0, 0]
pred_saliencymap = pred_saliencymap/pred_saliencymap.sum() #Mise à l'échelle pour que la somme soit 1
pred_saliencymap = ((pred_saliencymap/pred_saliencymap.max())*255).astype(np.uint8) #Np pour qu'il puisse être traité comme une image.Convertir en uint8
plt.imshow(pred_saliencymap)
plt.show()

#Visualisation d'images recadrées à l'aide de la carte de saillance
cropper = SaliencyBasedImageCropping(aspect_rate=(1, 3))
cropped_image = cropper.crop(image, pred_saliencymap, threshhold=0.3)
plt.imshow(cropped_image)
plt.show()

#Visualisation de la carte de saillance et de la boîte englobante
#Rouge pour le rapport hauteur / largeur spécifié, vert avant la correspondance
saliencymap_colored = color_saliencymap(pred_saliencymap) # ndarray, (h, w, rgb), np.uint8
overlaid_image = overlay_saliencymap_and_image(saliencymap_colored, image) # ndarray, (h, w, rgb), np.uint8
box_x, box_y, box_width, box_height = cropper.bounding_box
box_x_0, box_y_0, box_width_0, box_height_0 = cropper.bounding_box_based_on_binary_saliency
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#FF0000', fill=False)
bounding_box_based_on_binary_saliency = patches.Rectangle(xy=(box_x_0, box_y_0), width=box_width_0, height=box_height_0, ec='#00FF00', fill=False)
ax.imshow(overlaid_image)
ax.add_patch(bounding_box)
ax.add_patch(bounding_box_based_on_binary_saliency)
plt.show()

#Visualisation de l'image avec centre recadré pour comparaison
center_cropped_image = cropper.crop_center(image)
plt.imshow(center_cropped_image)
plt.show()

fig_12.png Figure: Comparaison du recadrage lors de l'utilisation de données de réponse correctes et lors de l'utilisation de SalGAN (image de baseball)

Lors de l'utilisation des données de réponse correctes sur la figure (A) et lors de l'utilisation de SalGAN sur la figure (B), presque le même résultat a été obtenu en découpant la partie de la pâte. Quel est le frappeur ou le lanceur? Je ne pense pas avoir suffisamment appris à ce sujet, mais je suis heureux de l'avoir traversé de la même manière.

Ce genre d'histoire n'est-il pas une sélection de cerises qui ne répertorie que les résultats réussis? J'ai une question. Si vous souhaitez mesurer ce résultat de manière quantitative, vous pouvez voir comment ces deux types se chevauchent avec l'ensemble de données SALICON. Calculs de type IoU dans les tâches de détection d'objets. Cependant, je vais l'omettre car c'est une histoire que je viens de faire maintenant.

L'image du chat qui est apparue dans la première moitié de l'article a été utilisée parce qu'elle est mignonne, mais il s'agit en fait de données Train, donc elle ne convient pas à la vérification. Mais jetons un œil.

fig_11.png Figure: Comparaison du recadrage lors de l'utilisation de données de réponse correctes et lors de l'utilisation de SalGAN (image de chat)

Presque le même résultat a été obtenu ici. C'était bon.

Découpez le sapin de Noël

Enfin, nous reviendrons sur l'image du premier sapin de Noël. Si vous pouvez faire un recadrage convaincant avec une photo que vous avez prise qui ne figure pas dans l'ensemble de données, vous avez atteint votre objectif.

fig_13.png Figure: Comparaison de la culture avec SalGAN et simplement centrée (image de l'arbre de Noël)

J'ai obtenu le résultat parfait. La figure (a) avec l'arbre de Noël est meilleure que la figure (b) sans elle. AI? Nous avons mis au point un mécanisme qui peut automatiquement faire des choses qui sont proches de ce que les gens font avec le pouvoir de. Cela peut réduire la déception de recadrer les zones vierges ou réduire le travail de recadrage manuel des images.

Résumé

C'était l'implémentation de deux articles, recadrant en utilisant Saliency Map [^ 2] et SalGAN [^ 5] pour estimer Saliency Map + alpha.

Vous pouvez créer quelque chose comme ça avec uniquement les informations publiques. Même si vous avez arrêté de déplacer quelque chose comme un tutoriel en deep learning ou en machine learning, j'aimerais que vous releviez un petit défi et que vous essayiez de faire quelque chose comme ça!

Recommended Posts

J'ai créé une IA qui recadre joliment une image en utilisant Saliency Map
J'ai créé une application Android qui affiche Google Map
Création d'un modèle de discrimination d'image (cifar10) à l'aide d'un réseau neuronal convolutif
J'ai créé un package extenum qui étend enum
J'ai créé un LINE BOT qui renvoie une image de riz terroriste en utilisant l'API Flickr
J'ai créé une IA qui prédit des anecdotes et m'a fait déduire mes anecdotes. Hee-AI
J'ai créé un installateur Ansible
J'ai créé un modèle de détection d'anomalies qui fonctionne sur iOS
J'ai créé un guide de programme original en utilisant l'API de guide de programme NHK.
J'ai créé un serveur Xubuntu.
[Python] J'ai créé une visionneuse d'images avec une fonction de tri simple.
J'ai fait un peintre discriminateur Anpanman
J'ai fait un kit de démarrage angulaire
J'ai fait une IA pour juger si c'est de l'alcool ou non!
J'ai créé un plug-in "EZPrinter" qui génère facilement des PDF cartographiques avec QGIS.
Je souhaite afficher une image sur Jupyter Notebook à l'aide d'OpenCV (mac)
J'ai créé une extension Chrome qui affiche un graphique sur la page Amedas
ConSinGAN: J'ai essayé d'utiliser le GAN qui peut être généré à partir d'une image