[PYTHON] Ich habe eine KI erstellt, die ein Bild mit Saliency Map gut zuschneidet

Einführung

In diesem Artikel werde ich die Methode zum Zuschneiden von Bildern mithilfe von Saliency Map mithilfe von Deep Learning mit Python / PyTorch beim Lesen des Papiers implementieren.

Wenn wir über Bilder im Deep Learning sprechen, versuchen wir oft, handschriftliche Zahlen zu klassifizieren und Personen zu erkennen, aber ich hoffe, Sie werden sehen, dass Sie dies auch tun können.

Dieser Artikel nimmt an [DeNA 20 New Graduate Adventskalender 2019 - Qiita] teil (https://qiita.com/advent-calendar/2019/dena-20-shinsostu). Vielen Dank an den Adventskalender, der mir die Möglichkeit gegeben hat, es zu schaffen!

Annahme des Lesers

Da es verschiedene Genres von Adventskalendern gibt, ist der Artikel, wenn Sie ihn nur lesen, für alle gedacht, die das Programm berührt haben. Um es zu bewegen, wird es für diejenigen angenommen, die Deep-Learning-Tutorial-ähnliche Dinge getan haben.

Der für Jupyter Notebook angenommene Code ist zum einfachen Testen enthalten, sodass Sie ihn zur Hand haben können. Die Anzeige ist reduziert, klicken Sie also, um sie bei Bedarf zu öffnen.

Die Bibliothek verwendet nur die bereits in [Google Colaboratory] installierte Bibliothek (https://colab.research.google.com/). Aufgrund des großen Datensatzes kann es etwas schwierig sein, es zu versuchen, bis Sie trainieren.

Bildausschnitt

Manchmal möchte ich ein Bild auf irgendeine Weise zuschneiden. Zum Beispiel ist das Symbolbild ungefähr quadratisch, daher denke ich, dass jeder darüber nachgedacht hat, wie es bei der Registrierung für verschiedene Dienste geschnitten werden soll. Außerdem ist das Header-Bild zur Hälfte horizontal lang, und die Form des Bildes wird häufig vor Ort festgelegt. Wenn der Benutzer es in eine feste Form schneidet, können Sie Ihr Bestes geben, damit es sich gut anfühlt. In vielen Fällen ist es jedoch erforderlich, es auf der Anwendungsseite zu automatisieren.

Ein kleines Beispiel

Angenommen, Sie möchten, dass das veröffentlichte Bild immer vertikal (1: 3) auf einer Seite angezeigt wird. Es ist vertikal lang, weil es ein Zustand ist, der schwer zu schneiden scheint.

Dieses Foto habe ich mit "Es ist eine schöne Lobby mit einem Weihnachtsbaum" gemacht. Wenn Sie es selbst schneiden, werde ich es natürlich so machen, um den Weihnachtsbaum zu zeigen.

Es ist einer Person jedoch nicht möglich, alle in großer Anzahl veröffentlichten Bilder zu sehen und auszuschneiden, sodass sie automatisiert werden. Nun, ich habe beschlossen, es in Python zu implementieren, weil es sicher wäre, die Mitte zu schneiden.

Klicken Sie hier, um die Implementierung
anzuzeigen
import numpy as np
import cv2
import matplotlib.pyplot as plt

def crop(image, aspect_rate=(1, 1)):
    """     
Schneiden Sie das Bild von der Mitte aus so aus, dass es das angegebene Seitenverhältnis hat.
    
    Parameters:
    -----------------
    image : ndarray, (h, w, rgb), uint8
    aspect_rate : tuple of int (x, y)
        default : (1, 1)

    Returns:
    -----------------
    cropped_image : ndarray, (h, w, rgb), uint8
    """        
    assert image.dtype==np.uint8
    assert image.ndim==3        

    im_size = (image.shape[1], image.shape[0]) # tuple of int, (width, height)
    center = (int(round(im_size[0]/2)), int(round(im_size[1]/2))) # tuple of int, (x, y)

    #Finden Sie die folgenden vier Werte
    # box_x : int,X-Koordinate oben links zum Zuschneiden, box_y : int,Oben links y-Koordinate zum Zuschneiden
    # box_width : int,Ausschnittbreite, box_height : int,Ausschnitthöhe
    if im_size[0]>im_size[1]:
        box_y = 0
        box_height = im_size[1]
        box_width = int(round((im_size[1]/aspect_rate[1])*aspect_rate[0]))
        if box_width>im_size[0]:
            box_x = 0
            box_width = im_size[0]
            box_height = int(round((im_size[0]/aspect_rate[0])*aspect_rate[1]))
            box_y = int(round(center[1]-(box_height/2)))
        else:
            box_x = int(round(center[0]-(box_width/2)))
    else:
        box_x = 0
        box_width = im_size[0]
        box_height = int(round((im_size[0]/aspect_rate[0])*aspect_rate[1]))
        if box_height>im_size[1]:
            box_y = 0
            box_height = im_size[1]
            box_width = int(round((im_size[1]/aspect_rate[1])*aspect_rate[0]))
            box_y = int(round(center[0]-(box_width/2)))
        else:
            box_y = int(round(center[1]-(box_height/2)))

    cropped_image = image[box_y:box_y+box_height, box_x:box_x+box_width]
    return cropped_image
    
#image: Lesen Sie das Image mit OpenCV usw. und machen Sie es zu einem NumPy-Array
image = cv2.imread("tree.jpg ")[:, :, ::-1]
cropped_image = crop(image, aspect_rate=(1, 3))
plt.imshow(cropped_image)
plt.show()

Unter der Annahme, dass die gesamte lange Seite des Bildes verwendet wird, wird es aus dem Seitenverhältnis unter Berücksichtigung der Länge der kurzen Seite zu diesem Zeitpunkt berechnet.

Ich werde es versuchen.

Das Weihnachtselement ist weg und es ist nur ein schönes Lobbyfoto. Das ist schlecht. KI? Lassen Sie uns etwas mit der Kraft von tun.

Dinge, auf die man sich beziehen sollte

Dieses Mal werde ich versuchen, die mit der Saliency Map zu imitieren, die Twitter und Adobe in den letzten zwei Jahren eingeführt haben. Wenn Sie auf Twitter [^ 1] ein Bild veröffentlichen, wird es auf der Timeline gut angezeigt. Darüber hinaus verfügt InDesigin von Adobe über eine Funktion namens Content-Aware Fit, mit der das Bild gemäß dem angegebenen Bereich zugeschnitten wird.

[^ 1]: Einführung eines neuronalen Netzwerks, das Bilder optimal und automatisch schneidet https://blog.twitter.com/ja_jp/topics/product/2018/0125ML-CR.html

Die Objekterkennung kann als Vergleichsmethode verwendet werden. Die auf Saliency Map basierende Methode ist jedoch insofern vielseitig, als sie nicht immer die Objekte mit den trainierten Beschriftungen anzeigt.

Beschneiden mit Saliency Map

Eine Zuschneidemethode unter Verwendung der Saliency Map [^ 2] wurde 2013 von Ardizzones Artikel "Saliency Based Image Cropping" vorgeschlagen.

Was ist Saliency Map?

Wohin geht die Sichtlinie, wenn eine Person das Bild sieht? ** Saliency Map ** ist eine pixelbasierte Version von. Zum Beispiel wird dies unten links in der Figur durch Messen von vielen Personen erhalten, und die Ausprägungskarte wird durch Berechnen solcher Dinge erhalten. In dieser Figur ist die Wahrscheinlichkeit, einen Blickwinkel zu haben, umso geringer, je weißer der Bereich ist, je höher die Wahrscheinlichkeit ist, einen Blickwinkel zu haben, und der schwarze Bereich.

Abbildung: Beispiel einer Saliency Map. Oben links: Bild. Oben rechts: Der gemessene Blickwinkel wird durch ein rotes X angezeigt. Unten links: Saliency Map. Unten rechts: Saliency Map in Farbe und überlagert das Bild.

Diese Abbildung ist eine Visualisierung der Trainingsdaten des SALICON-Datensatzes [^ 3]. Das rote X oben rechts sind die Ansichtspunktdaten, die erhalten werden, wenn viele Personen das Bild oben links betrachten und den Teil, den Sie betrachten, mit dem Mauszeiger berühren.

Wenn Sie einen Gaußschen Filter basierend auf diesen Daten anwenden, können Sie eine Karte erstellen, die die Wahrscheinlichkeit (0 bis 1) anzeigt, dass es einen Ansichtspunkt in Pixeleinheiten gibt, wie unten links gezeigt. Dies sind die Trainingsdaten der Saliency Map, die Sie berechnen möchten.

Wie unten rechts gezeigt, können Sie beim Ausmalen und Überlagern des Bildes mit hoher Wahrscheinlichkeit feststellen, dass die Katze bemerkt wird. Wenn die Blickwinkelwahrscheinlichkeit nahe bei 1 liegt, ist sie rot, und wenn die Blickwinkelwahrscheinlichkeit nahe bei 0 liegt, ist sie blau.

Implementierte Ardizzone-Methode

Lassen Sie uns die Methode von Ardizzone implementieren. Saliency Map verwendet vorerst die Trainingsdaten des SALICON-Datensatzes unverändert. Dies ist das Bild der Katze und die Lerndaten (korrekte Antwortdaten) der Saliency Map dafür.

Was für eine Methode?

Es ist eine Methode zum Zuschneiden, um alle Pixel über einer bestimmten Wahrscheinlichkeit einzuschließen. Dies bedeutet, dass Sie es nur an Orten platzieren sollten, an denen Sie es wahrscheinlich sehen werden.

fig2.png Abbildung: Ardizzones Methodenpipeline (zitiert aus dem Artikel [^ 2])

Um diese Zahl in Worten zusammenzufassen, gibt es die folgenden drei Schritte.

--Dualisieren Sie die Saliency Map mit einem bestimmten Schwellenwert (setzen Sie ihn auf 1 und 0). --Finden Sie einen Begrenzungsrahmen, der den Bereich von 1 einschließt

Binar

Die Binarisierung ist mit NumPy einfach. NumPy sendet auch die Berechnung der Vergleichsoperatoren (> und ==). Wenn Sie also ndarray> float ausführen, erhalten Sie für jedes Element True oder False, und die Binärisierung ist abgeschlossen. ..

Klicken Sie hier, um die Implementierung
anzuzeigen
threshhold = 0.3 #Schwellenwert einstellen, float (0<threshhold<1)
saliencymap_path = 'COCO_train2014_000000196971.png' #Saliency Map Pfad

saliencymap = cv2.imread(saliencymap_path)[:, :, ::-1] # ndarray, (h, w, rgb), np.uint8 (0-255)
saliencymap = saliencymap[:, :, 0] # ndarray, (h, w), np.uint8 (0-255)
plt.imshow(saliencymap)
plt.show()

threshhold *= 255 #Die aus dem Bild gelesene Saliency Map ist 0-Da es 255 ist, konvertieren Sie den Bereich

binarized_saliencymap = saliencymap>threshhold # ndarray, (h, w), bool

plt.imshow(binarized_saliencymap)
plt.show()

fig3.png Abbildung: Ergebnisse der Binarisierung

Das Ergebnis ist in dieser Abbildung dargestellt. Standardmäßig zeigt matplotlibs plt.imshow () große Werte in Gelb und kleine Werte in Lila.

Der Schwellenwert ist ein Hyperparameter, der beliebig eingestellt werden kann. Dieses Mal wird es im gesamten Artikel auf 0,3 vereinheitlicht.

Fragen Sie nach einem Begrenzungsrahmen

Berechnen Sie einen ** Begrenzungsrahmen ** (nur ein eingeschlossenes Rechteck), der alle durch Binärisierung erhaltenen Einsen (Wahr) enthält.

Dies ist in OpenCVs "cv2.boundingRect ()" implementiert und kann durch einfaches Aufrufen erreicht werden.

Structural Analysis and Shape Descriptors — OpenCV 2.4.13.7 documentation

[Funktionen für Bereiche (Konturen) - Dokumentation zu OpenCV-Python-Tutorials 1](http://labs.eecs.tottori-u.ac.jp/sd/Member/oyamada/OpenCV/html/py_tutorials/py_imgproc/py_contours/py_contour_features/ py_contour_features.html)

Verwenden Sie patches.Rectangle (), um ein Rechteck in matplotlib zu zeichnen.

matplotlib.patches.Rectangle — Matplotlib 3.1.1 documentation

Klicken Sie hier, um die Implementierung
anzuzeigen
import matplotlib.patches as patches

#Konvertieren Sie in ein Format, das von OpenCV verarbeitet werden kann
binarized_saliencymap = binarized_saliencymap.astype(np.uint8) # ndarray, (h, w), np.uint8 (0 or 1)

box_x, box_y, box_width, box_height = cv2.boundingRect(binarized_saliencymap)
# box_x : int,X-Koordinate oben links zum Zuschneiden, box_y : int,Oben links y-Koordinate zum Zuschneiden
# box_width : int,Ausschnittbreite, box_height : int,Ausschnitthöhe

#Begrenzungsrahmenzeichnung
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#00FF00', fill=False)
ax.imshow(binarized_saliencymap)
ax.add_patch(bounding_box)
plt.show()

fig4.png Abbildung: Ergebnis des Abrufs des Begrenzungsrahmens

Sie können den Begrenzungsrahmen wie in dieser Abbildung gezeigt erhalten. Die Informationen des Rechtecks werden als obere linke Koordinate und als Breiten- / Höhenwert gespeichert.

ausgeschnitten

Beschneiden Sie das Bild basierend auf dem erhaltenen Begrenzungsrahmen. Schneiden Sie das ndarray des Bildes mit dem Wert im Begrenzungsrahmen.

Klicken Sie hier, um die Implementierung
anzuzeigen
image_path = 'COCO_train2014_000000196971.jpg' #Bildpfad
image = cv2.imread(image_path)[:, :, ::-1] # ndarray, (h, w, rgb), np.uint8 (0-255)

cropped_image = image[box_y:box_y+box_height, box_x:box_x+box_width] # ndarray, (h, w, rgb), np.uint8 (0-255)

#Visualisierung
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#00FF00', fill=False)
ax.imshow(image)
ax.add_patch(bounding_box)
plt.show()

plt.imshow(cropped_image)
plt.show()

fig5.png Abbildung: Ergebnisse nach Ardizzones Methode

Wie in dieser Figur gezeigt, wurde ein Bild erhalten, in das die Sichtlinie wahrscheinlich gerichtet war.

Überlagerung mit der kolorierten Saliency Map

Um zu sehen, wie es verarbeitet wurde, überlagern wir das Bild mit der kolorierten Saliency Map und dem Begrenzungsrahmen. Implementieren Sie eine Funktion zum Färben der Saliency Map und eine Funktion zum Überlagern der Saliency Map auf dem Bild.

Klicken Sie hier, um die Implementierung
anzuzeigen
def color_saliencymap(saliencymap):
    """
Färben und visualisieren Sie die Saliency Map. 1 ist rot und 0 ist blau.
    
    Parameters
    ----------------
    saliencymap : ndarray, np.uint8, (h, w) or (h, w, rgb)
    
    Returns
    ----------------
    saliencymap_colored : ndarray, np.uint8, (h, w, rgb)
    """
    assert saliencymap.dtype==np.uint8
    assert (saliencymap.ndim == 2) or (saliencymap.ndim == 3)
    
    saliencymap_colored = cv2.applyColorMap(saliencymap, cv2.COLORMAP_JET)[:, :, ::-1]
    
    return saliencymap_colored

def overlay_saliencymap_and_image(saliencymap_color, image):
    """
Überlagern Sie das Bild mit der Saliency Map.
    
    Parameters
    ----------------
    saliencymap_color : ndarray, (h, w, rgb), np.uint8
    image : ndarray, (h, w, rgb), np.uint8
    
    Returns
    ----------------
    overlaid_image : ndarray(h, w, rgb)
    """
    assert saliencymap_color.ndim==3
    assert saliencymap_color.dtype==np.uint8
    assert image.ndim==3
    assert image.dtype==np.uint8
    im_size = (image.shape[1], image.shape[0])
    saliencymap_color = cv2.resize(saliencymap_color, im_size, interpolation=cv2.INTER_CUBIC)
    overlaid_image = cv2.addWeighted(src1=image, alpha=1, src2=saliencymap_color, beta=0.7, gamma=0)
    return overlaid_image

saliencymap_colored = color_saliencymap(saliencymap) # ndarray, (h, w, rgb), np.uint8
overlaid_image = overlay_saliencymap_and_image(saliencymap_colored, image) # ndarray, (h, w, rgb), np.uint8

#Visualisierung
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#00FF00', fill=False)
ax.imshow(overlaid_image)
ax.add_patch(bounding_box)
plt.show()

fig_6.png Abbildung: Bild der kolorierten Saliency Map und des darüber liegenden Begrenzungsrahmens

Wie in dieser Abbildung gezeigt, sind die Bereiche auf der Saliency Map, die mit hoher Wahrscheinlichkeit rot werden, umgeben.

Unterstützung für jedes Seitenverhältnis

Bei der Methode von Ardizzone hängen Größe und Seitenverhältnis von der Saliency Map ab. Aber jetzt, wo ich auf ein bestimmtes Seitenverhältnis zuschneiden möchte, muss ich darüber nachdenken.

Schneiden Sie es so aus, dass der Gesamtwert der Saliency Map groß ist

Ich konnte keine vorhandene Methode dafür finden und entschied mich daher, den folgenden Algorithmus zu verwenden, um den auszuschneidenden Bereich zu bestimmen.

Verwenden Sie den so weit wie möglich erhaltenen Bereich und suchen Sie nach dem Bereich, der den Gesamtwert der Saliency Map maximiert.

Erstellen Sie eine "SaliencyBasedImageCropping-Klasse" zum Zuschneiden und fassen Sie den bisherigen Code zusammen.

Klicken Sie hier, um die Implementierung
anzuzeigen
import copy

import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np

class SaliencyBasedImageCropping:
    """
Eine Klasse zum Zuschneiden von Bildern mit Saliency Map. Eine Methode, die den gesamten Bereich verwendet, der einen bestimmten Schwellenwert überschreitet[1]Benutzen.
    
* Wenn keine Pixel den Schwellenwert überschreiten, wird das gesamte Bild zurückgegeben.

    [1] Ardizzone, Edoardo, Alessandro Bruno, and Giuseppe Mazzola. "Saliency based image cropping." International Conference on Image Analysis and Processing. Springer, Berlin, Heidelberg, 2013.
        
    Parameters
    ----------------
    aspect_rate : tuple of int (x, y)
Wenn Sie hier das Seitenverhältnis angeben,[1]Suchen Sie den Bereich, der den Gesamtwert der Saliency Map maximiert, während Sie den Bereich verwenden, der mit der Methode von erhalten wurde.
    min_size : tuple of int (w, h)
        [1]Wenn jede Achse des durch das Verfahren von erhaltenen Bereichs kleiner als dieser Wert ist, wird der Bereich ausgehend von der Mitte des Bereichs gleichmäßig erweitert.
    
    Attributes
    ----------------
    self.aspect_rate : tuple of int (x, y)
    self.min_size : tuple of int (w, h)
    im_size : tuple of int (w, h)
    self.bounding_box_based_on_binary_saliency : list
        [1]Bereich erhalten durch die Methode von
        box_x : int
        box_y : int
        box_width : int
        box_height : int
    self.bounding_box : list
Der endgültige Zuschneidebereich mit angepasstem Seitenverhältnis
        box_x : int
        box_y : int
        box_width : int
        box_height : int
    """
    def __init__(self, aspect_rate=None, min_size=(200, 200)):
        assert (aspect_rate is None)or((type(aspect_rate)==tuple)and(len(aspect_rate)==2))
        assert (type(min_size)==tuple)and(len(min_size)==2)
        self.aspect_rate = aspect_rate
        self.min_size = min_size
        self.im_size = None
        self.bounding_box_based_on_binary_saliency = None
        self.bounding_box = None
    
    def _compute_bounding_box_based_on_binary_saliency(self, saliencymap, threshhold):
        """
Ardizzones Methode[1]Finden Sie den Erntebereich anhand der Saliency Map.
        
        Parameters:
        -----------------
        saliencymap : ndarray, (h, w), np.uint8
            0<=saliencymap<=255
            
        threshhold : float
            0<threshhold<255
        
        Returns:
        -----------------
        bounding_box_based_on_binary_saliency : list
            box_x : int
            box_y : int
            box_width : int
            box_height : int
        
        """
        assert (threshhold>0)and(threshhold<255)
        assert saliencymap.dtype==np.uint8
        assert saliencymap.ndim==2
        
        binarized_saliencymap = saliencymap>threshhold
        #Wenn die Saliency Map keine Pixel enthält, die den Schwellenwert überschreiten, behandeln Sie alle als überschritten.
        if saliencymap.sum()==0:
            saliencymap+=True
        binarized_saliencymap = (binarized_saliencymap.astype(np.uint8))*255
        # binarized_saliencymap : ndarray, (h, w), uint8, 0 or 255
        
        #Kleine Bereiche werden durch Morphologieverarbeitung gelöscht (Öffnen)
        kernel_size = round(min(self.im_size)*0.02)
        kernel = np.ones((kernel_size, kernel_size))
        binarized_saliencymap = cv2.morphologyEx(binarized_saliencymap, cv2.MORPH_OPEN, kernel)

        box_x, box_y, box_width, box_height = cv2.boundingRect(binarized_saliencymap)
        bounding_box_based_on_binary_saliency = [box_x, box_y, box_width, box_height]
        return bounding_box_based_on_binary_saliency
        
    def _expand_small_bounding_box_to_minimum_size(self, bounding_box):
        """
Wenn der Bereich kleiner als die angegebene Größe ist, erweitern Sie ihn. Verteilen Sie den Bereich gleichmäßig von der Mitte des Bereichs aus. Wenn das Bild nicht mehr angezeigt wird, verteilen Sie es auf der gegenüberliegenden Seite.
        
        Parameters:
        -----------------
        bounding_box : list
            box_x : int
            box_y : int
            box_width : int
            box_height : int
        """
        bounding_box = copy.copy(bounding_box) #Deep Copy, weil ich die Werte der ursprünglichen Liste behalten möchte
        
        # axis=0 : x and witdth, axis=1 : y and hegiht
        for axis in range(2):
            if bounding_box[axis+2]<self.min_size[axis+0]:
                bounding_box[axis+0] -= np.floor((self.min_size[axis+0]-bounding_box[axis+2])/2).astype(np.int)
                bounding_box[axis+2] = self.min_size[axis+0]
                if bounding_box[axis+0]<0:
                    bounding_box[axis+0] = 0
                if (bounding_box[axis+0]+bounding_box[axis+2])>self.im_size[axis+0]:
                    bounding_box[axis+0] -= (bounding_box[axis+0]+bounding_box[axis+2]) - self.im_size[axis+0]
        return bounding_box
    
    def _expand_bounding_box_to_specified_aspect_ratio(self, bounding_box, saliencymap):
        """
Erweitern Sie den Bereich so, dass er das angegebene Seitenverhältnis hat.
Ardizzones Methode[1]Suchen Sie den Bereich, der den Gesamtwert der Saliency Map maximiert, während Sie den in Schritt 2 erhaltenen Bereich so weit wie möglich verwenden.
        
        Parameters
        ----------------
        bounding_box : list
            box_x : int
            box_y : int
            box_width : int
            box_height : int
        saliencymap : ndarray, (h, w), np.uint8
            0<=saliencymap<=255
        """
        assert saliencymap.dtype==np.uint8
        assert saliencymap.ndim==2
        
        bounding_box = copy.copy(bounding_box)

        # axis=0 : x and witdth, axis=1 : y and hegiht    
        if bounding_box[2]>bounding_box[3]:
            long_length_axis = 0
            short_length_axis = 1
        else:
            long_length_axis = 1
            short_length_axis = 0
        
        #In welche Richtung soll man sich strecken?
        rate1 = self.aspect_rate[long_length_axis]/self.aspect_rate[short_length_axis]
        rate2 = bounding_box[2+long_length_axis]/bounding_box[2+short_length_axis]
        if rate1>rate2:
            moved_axis = long_length_axis
            fixed_axis = short_length_axis
        else:
            moved_axis = short_length_axis
            fixed_axis = long_length_axis
        
        fixed_length = bounding_box[2+fixed_axis]
        moved_length = int(round((fixed_length/self.aspect_rate[fixed_axis])*self.aspect_rate[moved_axis]))
        if moved_length > self.im_size[moved_axis]:
            #Wenn es beim Strecken die Größe des Bildes überschreitet
            moved_axis, fixed_axis = fixed_axis, moved_axis
            fixed_length = self.im_size[fixed_axis]
            moved_length = int(round((fixed_length/self.aspect_rate[fixed_axis])*self.aspect_rate[moved_axis]))
            fixed_point = 0
            start_point = bounding_box[moved_axis]
            end_point = bounding_box[moved_axis]+bounding_box[2+moved_axis]
            if fixed_axis==0:
                saliencymap_extracted = saliencymap[start_point:end_point, :]
            elif fixed_axis==1:
                saliencymap_extracted = saliencymap[:, start_point:end_point:]
        else:   
            #Wenn gedehnt, um in die Größe des Bildes zu passen
            start_point = int(bounding_box[moved_axis]+bounding_box[2+moved_axis]-moved_length)
            if start_point<0:
                start_point = 0
            end_point = int(bounding_box[moved_axis]+moved_length)
            if end_point>self.im_size[moved_axis]:
                end_point = self.im_size[moved_axis]
            if fixed_axis==0:
                fixed_point = bounding_box[fixed_axis]
                saliencymap_extracted = saliencymap[start_point:end_point, fixed_point:fixed_point+fixed_length]
            elif fixed_axis==1:
                fixed_point = bounding_box[fixed_axis]
                saliencymap_extracted = saliencymap[fixed_point:fixed_point+fixed_length, start_point:end_point]
        saliencymap_summed_1d = saliencymap_extracted.sum(moved_axis)
        self.saliencymap_summed_slided = np.convolve(saliencymap_summed_1d, np.ones(moved_length), 'valid')
        moved_point = np.array(self.saliencymap_summed_slided).argmax() + start_point
        
        if fixed_axis==0:
            bounding_box = [fixed_point, moved_point, fixed_length, moved_length]
        elif fixed_axis==1:
            bounding_box = [moved_point, fixed_point, moved_length, fixed_length]
        return bounding_box
    
    def crop_center(self, image):
        """     
Beschneiden Sie die Bildmitte mit dem angegebenen Seitenverhältnis, ohne die Saliency Map zu verwenden.
        
        Parameters:
        -----------------
        image : ndarray, (h, w, rgb), uint8
            
        Returns:
        -----------------
        cropped_image : ndarray, (h, w, rgb), uint8
        """        
        assert image.dtype==np.uint8
        assert image.ndim==3        
        
        im_size = (image.shape[1], image.shape[0]) # tuple of int, (width, height)
        center = (int(round(im_size[0]/2)), int(round(im_size[1]/2))) # tuple of int, (x, y)
        
        if im_size[0]>im_size[1]:
            box_y = 0
            box_height = im_size[1]
            box_width = int(round((im_size[1]/self.aspect_rate[1])*self.aspect_rate[0]))
            if box_width>im_size[0]:
                box_x = 0
                box_width = im_size[0]
                box_height = int(round((im_size[0]/self.aspect_rate[0])*self.aspect_rate[1]))
                box_y = int(round(center[1]-(box_height/2)))
            else:
                box_x = int(round(center[0]-(box_width/2)))

        else:
            box_x = 0
            box_width = im_size[0]
            box_height = int(round((im_size[0]/self.aspect_rate[0])*self.aspect_rate[1]))
            if box_height>im_size[1]:
                box_y = 0
                box_height = im_size[1]
                box_width = int(round((im_size[1]/self.aspect_rate[1])*self.aspect_rate[0]))
                box_y = int(round(center[0]-(box_width/2)))
            else:
                box_y = int(round(center[1]-(box_height/2)))
        
        cropped_image = image[box_y:box_y+box_height, box_x:box_x+box_width]
        return cropped_image
    
    def crop(self, image, saliencymap, threshhold=0.3):
        """     
Zuschneiden mit Saliency Map.
        
        Parameters:
        -----------------
        image : ndarray, (h, w, rgb), np.uint8
        saliencymap : ndarray, (h, w), np.uint8
            Saliency map's ndarray need not be the same size as image's ndarray. Saliency map is resized within this method.
        threshhold : float
            0 < threshhold <1
            
        Returns:
        -----------------
        cropped_image : ndarray, (h, w, rgb), uint8
        """
        assert (threshhold>0)and(threshhold<1)
        assert image.dtype==np.uint8
        assert image.ndim==3
        assert saliencymap.dtype==np.uint8
        assert saliencymap.ndim==2
        
        threshhold = threshhold*255 # scale to 0 - 255
        self.im_size = (image.shape[1], image.shape[0]) # (width, height)
        saliencymap = cv2.resize(saliencymap, self.im_size, interpolation=cv2.INTER_CUBIC)

        # compute bounding box based on saliency map
        bounding_box_based_on_binary_saliency = self._compute_bounding_box_based_on_binary_saliency(saliencymap, threshhold)
        bounding_box = self._expand_small_bounding_box_to_minimum_size(bounding_box_based_on_binary_saliency)
        if self.aspect_rate is not None:
            bounding_box = self._expand_bounding_box_to_specified_aspect_ratio(bounding_box, saliencymap)
            
        box_y = bounding_box[1]
        box_x = bounding_box[0]
        box_height = bounding_box[3]
        box_width = bounding_box[2]
        
        cropped_image = image[box_y:box_y+box_height, box_x:box_x+box_width]
        
        self.bounding_box_based_on_binary_saliency = bounding_box_based_on_binary_saliency
        self.bounding_box  = bounding_box
        
        return cropped_image

# -------------------
# SETTING
threshhold = 0.3 #Schwellenwert einstellen, float (0<threshhold<1)
saliencymap_path = 'COCO_train2014_000000196971.png' #Saliency Map Pfad
image_path = 'COCO_train2014_000000196971.jpg' #Bildpfad
# -------------------

saliencymap = cv2.imread(saliencymap_path)[:, :, ::-1] # ndarray, (h, w, rgb), np.uint8 (0-255)
saliencymap = saliencymap[:, :, 0] # ndarray, (h, w), np.uint8 (0-255)
image = cv2.imread(image_path)[:, :, ::-1] # ndarray, (h, w, rgb), np.uint8 (0-255)

#Visualisierung von zugeschnittenen Bildern mit Saliency Map
cropper = SaliencyBasedImageCropping(aspect_rate=(1, 3))
cropped_image = cropper.crop(image, saliencymap, threshhold=0.3)
plt.imshow(cropped_image)
plt.show()

#Visualisierung der Saliency Map und des Begrenzungsrahmens
#Rot für das angegebene Seitenverhältnis, grün vor dem Abgleich
saliencymap_colored = color_saliencymap(saliencymap) # ndarray, (h, w, rgb), np.uint8
overlaid_image = overlay_saliencymap_and_image(saliencymap_colored, image) # ndarray, (h, w, rgb), np.uint8
box_x, box_y, box_width, box_height = cropper.bounding_box
box_x_0, box_y_0, box_width_0, box_height_0 = cropper.bounding_box_based_on_binary_saliency
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#FF0000', fill=False)
bounding_box_based_on_binary_saliency = patches.Rectangle(xy=(box_x_0, box_y_0), width=box_width_0, height=box_height_0, ec='#00FF00', fill=False)
ax.imshow(overlaid_image)
ax.add_patch(bounding_box)
ax.add_patch(bounding_box_based_on_binary_saliency)
plt.show()

#Visualisierung des Bildes mit beschnittener Mitte zum Vergleich
center_cropped_image = cropper.crop_center(image)
plt.imshow(center_cropped_image)
plt.show()

Np.convolve () wird verwendet, um den Bereich zu ermitteln, in dem die Saliency Map den Maximalwert hat.

numpy.convolve — NumPy v1.17 Manual

Dies ist eine eindimensionale Faltungsfunktion. Durch Falten mit einem Array aller Einsen der Länge, die Sie summieren möchten, können Sie die Summe für jeden festen Bereich wie unten gezeigt berechnen.

array_1d = np.array([1, 2, 3, 4])
print(np.convolve(array_1d, np.ones(2), 'valid')) # [3. 5. 7.]

Wenn Sie eine einfache for-Anweisung in Python verwenden, wird diese langsamer, sodass wir die NumPy-Funktionen so weit wie möglich kombinieren.

Darüber hinaus haben wir im Binarisierungsprozess eine Implementierung hinzugefügt, die einen sehr kleinen Bereich durch Morphologiekonvertierung löscht. Insbesondere wenn die Saliency Map danach durch tiefes Lernen erhalten wird, tritt wahrscheinlich ein solcher Bereich auf, sodass diese Implementierung hinzugefügt wird.

Morphology Conversion - OpenCV-Python Tutorials 1 Dokumentation

Siehe das Ergebnis

Der grüne Begrenzungsrahmen wird vor dem Anpassen des Seitenverhältnisses verwendet, und der rote Begrenzungsrahmen wird nach dem Anpassen des Seitenverhältnisses verwendet.

fig_7.png Abbildung: Ausschnitt mit einem Seitenverhältnis von 1: 3 unter Verwendung der Saliency Map

Wie in dieser Abbildung (a) gezeigt, sind eine Katze und eine Handseife in einem vertikal großen Bereich? Es gelang mir, durch Einfügen auszuschneiden. Im Vergleich zu der Abbildung (b), in der das Zentrum gerade ausgeschnitten wurde, ist der Teil, den Menschen sehen möchten, ein gutes Gefühl.

fig8.png Abbildung: Ergebnisse des Zuschneidens mit einem Seitenverhältnis von 1: 1 mithilfe der Saliency Map

Dies ist bei einem Quadrat (1: 1) der Fall. Wenn die Mitte ausgeschnitten ist (Abb. (B)), ist die Katze fest eingeschlossen. Bei Verwendung der Ausnahmekarte (Abb. (A)) wird jedoch ein schmalerer Bereich ausgeschnitten, sodass sie in derselben Größe angezeigt wird. Wenn die Katze größer wird. Beim Zuschneiden ist nicht nur wichtig, ob ein Objekt angezeigt wird, sondern auch, ob es in ausreichender Größe angezeigt wird.

Implementierung eines Modells (SalGAN), das die Saliency Map mithilfe von Deep Learning mit PyTorch schätzt

Sie können das von Ihnen selbst erstellte Bild nicht zuschneiden. Ich möchte das Bild des Weihnachtsbaums zuschneiden, den ich selbst aufgenommen habe, und nicht das Bild des SALICON-Datensatzes. Daher werde ich mithilfe von Deep Learning ein Schätzmodell für die Saliency Map erstellen.

Wenn Sie sich die Benchmark-Site "MIT Saliency Benchmark" [^ 4] für die Saliency Map-Aufgabe ansehen, werden Sie verschiedene Methoden finden, aber dieses Mal werden wir versuchen, SalGAN [^ 5] zu implementieren. Die Punktzahl scheint nicht sehr hoch zu sein, aber ich habe mich dafür entschieden, weil der Mechanismus einfach zu sein schien.

Die Implementierung des Autors [^ 6] wurde ebenfalls veröffentlicht, aber da das Framework mit Lasagne (Theano) nicht sehr vertraut ist, werde ich es in PyTorch schreiben, während ich mich darauf beziehe.

Was ist SalGAN?

"SalGAN: Visual Saliency Prediction with Generative Adversarial Networks" wurde 2017 veröffentlicht. Wie der Name schon sagt, handelt es sich um eine Methode zur Schätzung der Saliency Map unter Verwendung von ** GAN (Generative Adversarial Networks) **.

Ich werde die Erklärung zu GAN weglassen, da es bereits viele leicht verständliche Artikel gibt. Zum Beispiel wird GAN (1) Verstehen der Grundstruktur, die ich nicht mehr hören kann - Qiita empfohlen. Wenn Sie eine typische GAN-Methode kennen, die viel Implementierung und Erklärung enthält, können Sie sie implementieren, indem Sie den Unterschied berücksichtigen.

fig_salgan.png Abbildung: Gesamtstruktur von SalGAN (zitiert aus dem Artikel [^ 5])

Da die Saliency Map für jedes Pixel einen Ansichtspunkt (0 bis 1) hat, kann von einem binären Klassifizierungsproblem für jedes Pixel gesprochen werden. Es liegt in der Nähe einer Ein-Klassen-Segmentierung. Da wir ein Bild eingeben und ein Bild ausgeben möchten (Saliency Map), haben wir ein ** Encoder-Decoder-Modell **, das CNN verwendet, wie in dieser Abbildung gezeigt. Pix2Pix [^ 7] ist berühmt, wenn es um Bild-zu-Bild mit GAN geht, hat aber keine solche U-Net-Struktur.

Im Encoder-Decoder-Modell können Sie auch lernen, die ausgegebene Saliency Map und ** Binary Cross Entropy ** der richtigen Antwortdaten zu reduzieren. Dieses SalGAN versucht jedoch, die Genauigkeit zu verbessern, indem ein Netzwerk (** Diskriminator **) hinzugefügt wird, das die Saliency Map als korrekte Daten oder geschätzte Daten klassifiziert.

Die Verlustfunktion des Encoder-Decoder-Teils (** Generator **) ist wie folgt. Zusätzlich zum üblichen kontradiktorischen Verlust werden die geschätzte Ausnahmekarte und der Abschnitt "Binäre Kreuzentropie" der richtigen Antwortdaten hinzugefügt. Passen Sie den Prozentsatz mit dem Hyperparameter $ \ alpha $ an.

\mathcal{L}\_{BCE} = -\frac{1}{N}\sum\_{j=1}^{N}(S\_{j}\log{(\hat{S}\_{j})}+(1-S\_{j})\log{(1-\hat{S}\_{j})}).
\mathcal{L} = \alpha\cdot\mathcal{L}\_{BCE} + L(D(I, \hat{S}), 1).

Die Verlustfunktion von Discriminator ist wie folgt. Es ist eine allgemeine Form.

\mathcal{L}\_{\mathcal{D}} = L(D(I, S), 1)+L(D(I, \hat{S}),0).

Lesen Sie etwas mehr

Während wir das Papier zitieren, werden wir die für die Implementierung erforderlichen Informationen lesen. Ich habe es irgendwie verstanden, indem ich mir die Gesamtstruktur angesehen habe, aber ich werde nach dem Teil suchen, in dem die Informationen geschrieben sind, die ich ein bisschen mehr wissen möchte.

The encoder part of the network is identical in architecture to VGG-16 (Simonyan and Zisserman, 2015), omitting the final pooling and fully connected layers. The network is initialized with the weights of a VGG-16 model trained on the ImageNet data set for object classification (Deng et al., 2009). Only the last two groups of convolutional layers in VGG-16 are modified during the training for saliency prediction, while the earlier layers remain fixed from the original VGG-16 model.

The decoder architecture is structured in the same way as the encoder, but with the ordering of layers reversed, and with pooling layers being replaced by upsampling layers. Again, ReLU non-linearities are used in all convolution layers, and a final 1 × 1 convolution layer with sigmoid non-linearity is added to produce the saliency map. The weights for the decoder are randomly initialized. The final output of the network is a saliency map in the same size to input image.

--Decoder ist derselbe wie Encoder, fügt jedoch anstelle der Pooling-Schicht eine Upsampling-Schicht ein

The input to the discriminator network is an RGBS image of size 256×192×4 containing both the source image channels and (predicted or ground truth) saliency.

--Injizieren Sie nicht nur die Saliency Map, sondern auch das Originalbild in 4 Kanälen in Discriminator.

We train the networks on the 15,000 images from the SALICON training set using a batch size of 32.

--Verwenden Sie 15.000 Bilder aus dem SALICON-Datensatz

Ich möchte dieses Mal kein Reproduktionsexperiment des Papiers durchführen, daher sind mir die Details bei der Implementierung nicht besonders wichtig. Beispielsweise wird anstelle von ReLU in der Veröffentlichung Leaky ReLU verwendet, das allgemein als wirksam angesehen wird.

Code schreiben

Ich werde den Code schreiben. Da die Basis GAN des Encoder-Decoder-Modells von CNN ist, werden wir auf die Implementierung der bestehenden ähnlichen Methode verweisen. Zum Beispiel hat eriklindernorens GitHub verschiedene GANs in PyTorch implementiert. Die DCGAN-Implementierung [^ 8] sieht gut aus.

Erstellen Sie eine Generator- und eine Diskriminatorklasse

Der Generator verwendet VGG16, das mit ImageNet trainiert wurde und in Fackelvision [^ 9] bereitgestellt wird. In SalGAN ist das Gewicht auf der Vorderseite und das Lernen auf der Rückseite festgelegt. Beschreiben Sie es daher in separaten Ebenen wie "torchvision.models.vgg16 (pretrained = True) .features [: 17]". Sie können mit print (torchvision.models.vgg16 (pretrained = True) .features) überprüfen, welche Nummer welche Ebene ist.

torchvision.models — PyTorch master documentation

Klicken Sie hier, um die Implementierung
anzuzeigen
from torch import nn
import torchvision

class Generator(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()        
        self.encoder_first = torchvision.models.vgg16(pretrained=True).features[:17] #Das zu verwendende Teil mit festem Gewicht
        self.encoder_last = torchvision.models.vgg16(pretrained=True).features[17:-1] #Teil zu lernen
        self.decoder = nn.Sequential(
                    nn.Conv2d(512, 512, 3, padding=1), 
                    nn.LeakyReLU(),
                    nn.Conv2d(512, 512, 3, padding=1), 
                    nn.LeakyReLU(),
                    nn.Conv2d(512, 512, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(512, 512, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(512, 512, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(512, 512, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(512, 256, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(256, 256, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(256, 256, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(256, 128, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(128, 128, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Upsample(scale_factor=2),
                    nn.Conv2d(128, 64, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(64, 64, 3, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(64, 1, 1, padding=0),
                    nn.Sigmoid())

    def forward(self, x):
        x = self.encoder_first(x)
        x = self.encoder_last(x)
        x = self.decoder(x)
        return x

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
                    nn.Conv2d(4, 3, 1, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.Conv2d(3, 32, 3, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.MaxPool2d(2, stride=2),
                    nn.Conv2d(32, 64, 3, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.Conv2d(64, 64, 3, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.MaxPool2d(2, stride=2),
                    nn.Conv2d(64, 64, 3, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.Conv2d(64, 64, 3, padding=1),
                    nn.LeakyReLU(inplace=True),
                    nn.MaxPool2d(2, stride=2))
        self.classifier = nn.Sequential(
                    nn.Linear(64*32*24, 100, bias=True),
                    nn.Tanh(),
                    nn.Linear(100, 2, bias=True),
                    nn.Tanh(),
                    nn.Linear(2, 1, bias=True),
                    nn.Sigmoid())

    def forward(self, x):
        x = self.main(x)
        x = x.view(x.shape[0], -1)
        x = self.classifier(x)
        return x

Erstellen Sie eine Dataset-Klasse

Sie benötigen eine Dataset-Klasse, um das SALICON-Dataset zu lesen. Es ist ein wenig mühsam, entsprechend dem vorbereiteten Datensatz und der vorbereiteten Aufgabe zu schreiben. Wenn es um das Schreiben geht, ist das PyTorch-Tutorial hilfreich.

Writing Custom Datasets, DataLoaders and Transforms — PyTorch Tutorials 1.3.1 documentation

Hier wird auch die Vorverarbeitung mit "torchvision.transforms" beschrieben. Dieses Mal werden wir nur die Größe auf 192 x 256 ändern und normalisieren.

torchvision.transforms — PyTorch master documentation

Der SALICON-Datensatz kann von LSUN’17 Saliency Prediction Challenge | SALICON heruntergeladen werden.

Klicken Sie hier, um die Implementierung
anzuzeigen
import os

import torch.utils.data as data
import torchvision.transforms as transforms

class SALICONDataset(data.Dataset):
    def __init__(self, root_dataset_dir, val_mode = False):
        """
Datensatzklasse zum Lesen von SALICON-Datensätzen
        
        Parameters:
        -----------------
        root_dataset_dir : str
Pfad des Verzeichnisses über dem SALICON-Datensatz
        val_mode : bool (default: False)
Wenn False, lesen Sie die Zugdaten. Wenn True, lesen Sie die Validierungsdaten.
        """
        self.root_dataset_dir = root_dataset_dir
        self.imgsets_dir = os.path.join(self.root_dataset_dir, 'SALICON/image_sets')
        self.img_dir = os.path.join(self.root_dataset_dir, 'SALICON/imgs')
        self.distribution_target_dir = os.path.join(self.root_dataset_dir, 'SALICON/algmaps')
        self.img_tail = '.jpg'
        self.distribution_target_tail = '.png'
        self.transform = transforms.Compose([transforms.Resize((192, 256)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
        self.distribution_transform = transforms.Compose([transforms.Resize((192, 256)), transforms.ToTensor()])
        
        if val_mode:
            train_or_val = "val"
        else:
            train_or_val = "train"
        imgsets_file = os.path.join(self.imgsets_dir, '{}.txt'.format(train_or_val))
        files = []
        for data_id in open(imgsets_file).readlines():
            data_id = data_id.strip()
            img_file = os.path.join(self.img_dir, '{0}{1}'.format(data_id, self.img_tail))
            distribution_target_file = os.path.join(self.distribution_target_dir, '{0}{1}'.format(data_id, self.distribution_target_tail))
            files.append({
                'img': img_file,
                'distribution_target': distribution_target_file,
                'data_id': data_id
            })
        self.files = files
        
    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        """
        Returns
        -----------
        data : list
            [img, distribution_target, data_id]
        """
        data_file = self.files[index]
        data = []

        img_file = data_file['img']
        img = Image.open(img_file)
        data.append(img)

        distribution_target_file = data_file['distribution_target']
        distribution_target = Image.open(distribution_target_file)
        data.append(distribution_target)
        
        # transform
        data[0] = self.transform(data[0])
        data[1] = self.distribution_transform(data[1])

        data.append(data_file['data_id'])
        return data

lernen

Schreiben Sie den Code für den Rest des Lernens. Der Punkt ist, wie man die Verlustfunktion berechnet und wie man Generator und Diskriminator lernt.

Es dauert ungefähr mehrere Stunden, um die gleichen 120 Epochen wie das Papier mit der GPU zu lernen.

Klicken Sie hier, um die Implementierung
anzuzeigen
from datetime import datetime

import torch
from torch.autograd import Variable

#-----------------
# SETTING
root_dataset_dir = "" #Pfad des Verzeichnisses über dem SALICON-Datensatz
alpha = 0.005 #Hyperparameter der Generatorverlustfunktion. Der empfohlene Wert für Papier ist 0.005
epochs = 120
batch_size = 32 #32 in der Zeitung
#-----------------

#Verwenden Sie die Startzeit für den Dateinamen
start_time_stamp = '{0:%Y%m%d-%H%M%S}'.format(datetime.now())

save_dir = "./log/"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#Laden des Datenladers
train_dataset = SALICONDataset(
                    root_dataset_dir,
                )
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle=True, num_workers = 4, pin_memory=True, sampler=None)
val_dataset = SALICONDataset(
                    root_dataset_dir,
                    val_mode=True
                )
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = 1, shuffle=False, num_workers = 4, pin_memory=True, sampler=None)

#Lastmodell und Verlustfunktion
loss_func = torch.nn.BCELoss().to(DEVICE)
generator = Generator().to(DEVICE)
discriminator = Discriminator().to(DEVICE)

#Definition der Optimierungsmethode (unter Verwendung der Einstellungen im Papier)
optimizer_G = torch.optim.Adagrad([
                {'params': generator.encoder_last.parameters()},
                {'params': generator.decoder.parameters()}
            ], lr=0.0001, weight_decay=3*0.0001)
optimizer_D = torch.optim.Adagrad(discriminator.parameters(), lr=0.0001, weight_decay=3*0.0001)

#Lernen
for epoch in range(epochs):
    n_updates = 0 #Iterationszahl
    n_discriminator_updates = 0
    n_generator_updates = 0
    d_loss_sum = 0
    g_loss_sum = 0
    
    for i, data in enumerate(train_loader):
        imgs = data[0] # ([batch_size, rgb, h, w])
        salmaps = data[1] # ([batch_size, 1, h, w])

        #Erstellen Sie ein Label für Discriminator
        valid = Variable(torch.FloatTensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False).to(DEVICE)
        fake = Variable(torch.FloatTensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False).to(DEVICE)

        imgs = Variable(imgs).to(DEVICE)
        real_salmaps = Variable(salmaps).to(DEVICE)

        #Lernen Sie abwechselnd Generator und Diskriminator für jede Iteration
        if n_updates % 2 == 0:
            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()
            gen_salmaps = generator(imgs)
            
            #Kombinieren Sie das Originalbild und die generierte Saliency Map für die Eingabe in den Discriminator, um ein 4-Kanal-Array zu erstellen
            fake_d_input = torch.cat((imgs, gen_salmaps.detach()), 1) # ([batch_size, rgbs, h, w])
            
            #Berechnen Sie die Verlustfunktion des Generators
            g_loss1 = loss_func(gen_salmaps, real_salmaps)
            g_loss2 = loss_func(discriminator(fake_d_input), valid)
            g_loss = alpha*g_loss1 + g_loss2
            
            g_loss.backward()
            optimizer_G.step()
            
            g_loss_sum += g_loss.item()
            n_generator_updates += 1
            
        else:
            # ---------------------
            #  Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()
            
            #Kombinieren Sie das Originalbild und die Ausnahmekarte der richtigen Antwortdaten für die Eingabe in den Diskriminator, um ein 4-Kanal-Array zu erstellen
            real_d_input = torch.cat((imgs, real_salmaps), 1) # ([batch_size, rgbs, h, w])

            #Berechnen Sie die Verlustfunktion von Discriminator
            real_loss = loss_func(discriminator(real_d_input), valid)
            fake_loss = loss_func(discriminator(fake_d_input), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()
            
            d_loss_sum += d_loss.item()
            n_discriminator_updates += 1
            
        n_updates += 1
        if n_updates%10==0:
            if n_discriminator_updates>0:
                print(
                    "[%d/%d (%d/%d)] [loss D: %f, G: %f]"
                    % (epoch, epochs-1, i, len(train_loader), d_loss_sum/n_discriminator_updates , g_loss_sum/n_generator_updates)
                )
            else:
                print(
                    "[%d/%d (%d/%d)] [loss G: %f]"
                    % (epoch, epochs-1, i, len(train_loader), g_loss_sum/n_generator_updates)
                )                
    
    #Gewichte sparen
    #Speichern Sie alle 5 Epochen und die letzte Epoche
    if ((epoch+1)%5==0)or(epoch==epochs-1):
        generator_save_path = '{}.pkl'.format(os.path.join(save_dir, "{}_generator_epoch{}".format(start_time_stamp, epoch)))
        discriminator_save_path = '{}.pkl'.format(os.path.join(save_dir, "{}_discriminator_epoch{}".format(start_time_stamp, epoch)))
        torch.save(generator.state_dict(), generator_save_path)
        torch.save(discriminator.state_dict(), discriminator_save_path)
        
    #Visualisieren Sie einen Teil der Validierungsdaten für jede Epoche
    with torch.no_grad():
        print("validation")
        for i, data in enumerate(val_loader):
            image = Variable(data[0]).to(DEVICE)
            gen_salmap = generator(imgs)
            gen_salmap_np = np.array(gen_salmaps.data.cpu())[0, 0]
            
            plt.imshow(np.array(image[0].cpu()).transpose(1, 2, 0))
            plt.show()
            plt.imshow(gen_salmap_np)
            plt.show()
            if i==1:
                break

Schätzen Sie die Saliency Map

Geben Sie ein Bild in das erlernte SalGAN ein und versuchen Sie, die Saliency Map zu schätzen. Sehen Sie, wie es mit Bildern geschätzt wird, die nicht für das Training verwendet werden.

Klicken Sie hier, um die Implementierung
anzuzeigen
generator_path = "" #Pfad der Generatorgewichtsdatei (pkl), die durch Lernen erhalten wurde
image_path = "COCO_train2014_000000196971.jpg " #Der Pfad des Bildes, das Sie eingeben möchten

generator = Generator().to(DEVICE)
generator.load_state_dict(torch.load(generator_path))

image_pil = Image.open(image_path) #Die Bildeingabe im PIL-Format wird für die Transformation angenommen
image = np.array(image_pil)
plt.imshow(image)
plt.show()

transform = transforms.Compose([transforms.Resize((192, 256)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
image_torch = transform(image_pil)
image_torch = image_torch.unsqueeze(0).to(DEVICE) # (1, rgb, h, w)
with torch.no_grad():
    pred_saliencymap = generator(img_torch)
    pred_saliencymap = np.array(pred_saliencymap.cpu())[0, 0]
pred_saliencymap = pred_saliencymap/pred_saliencymap.sum() #Skaliert so, dass die Summe 1 ist
pred_saliencymap = ((pred/pred.max())*255).astype(np.uint8) #Np, damit es als Bild behandelt werden kann.In uint8 konvertieren
plt.imshow(pred_saliencymap)
plt.show()

fig_15.png

Abbildung: Beispiel einer von SalGAN geschätzten Saliency Map

Die Saliency Map in dieser Abbildung (b) wurde geschätzt. Im Vergleich zu den richtigen Antwortdaten (Abb. (C)) hat es einen guten Eindruck, aber die Wahrscheinlichkeit, Punkte um den Krug und den Teig herum zu überzeugen, kann hoch geschätzt werden.

Wenn Sie gut lernen möchten, müssen Sie anhand verschiedener Indikatoren für die Saliency Map überprüfen, wie dies in diesem Dokument beschrieben wird. Da wir es jetzt nicht verifiziert haben, ist unklar, wie viele Ergebnisse im Vergleich zu dem in dem Papier vorgestellten SalGAN erzielt werden. Dieses Mal ist die Hauptsache das Zuschneiden, also würde ich gerne weitermachen, weil ich so etwas qualitativ gemacht habe.

Beschneiden Sie das Bild mit der geschätzten Saliency Map

Jetzt hast du was du machen willst. Durch Kombinieren der von SalGAN geschätzten Salency Map mit der Zuschneideklasse können Sie das Bild gut zuschneiden.

Klicken Sie hier, um die Implementierung
anzuzeigen
# -------------------
# SETTING
threshhold = 0.3 #Schwellenwert einstellen, float (0<threshhold<1)
generator_path = "" #Pfad der Generatorgewichtsdatei (pkl), die durch Lernen erhalten wurde
image_path = "COCO_train2014_000000196971.jpg " #Pfad des Bildes, das Sie zuschneiden möchten
# -------------------

generator = Generator().to(DEVICE)
generator.load_state_dict(torch.load(generator_path))

image_pil = Image.open(image_path) #Die Bildeingabe im PIL-Format wird für die Transformation angenommen
image = np.array(image_pil)
plt.imshow(image)
plt.show()

transform = transforms.Compose([transforms.Resize((192, 256)), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
image_torch = transform(image_pil)
image_torch = image_torch.unsqueeze(0).to(DEVICE) # (1, rgb, h, w)
with torch.no_grad():
    pred_saliencymap = generator(image_torch)
    pred_saliencymap = np.array(pred_saliencymap.cpu())[0, 0]
pred_saliencymap = pred_saliencymap/pred_saliencymap.sum() #Skaliert so, dass die Summe 1 ist
pred_saliencymap = ((pred_saliencymap/pred_saliencymap.max())*255).astype(np.uint8) #Np, damit es als Bild behandelt werden kann.In uint8 konvertieren
plt.imshow(pred_saliencymap)
plt.show()

#Visualisierung von zugeschnittenen Bildern mit Saliency Map
cropper = SaliencyBasedImageCropping(aspect_rate=(1, 3))
cropped_image = cropper.crop(image, pred_saliencymap, threshhold=0.3)
plt.imshow(cropped_image)
plt.show()

#Visualisierung der Saliency Map und des Begrenzungsrahmens
#Rot für das angegebene Seitenverhältnis, grün vor dem Abgleich
saliencymap_colored = color_saliencymap(pred_saliencymap) # ndarray, (h, w, rgb), np.uint8
overlaid_image = overlay_saliencymap_and_image(saliencymap_colored, image) # ndarray, (h, w, rgb), np.uint8
box_x, box_y, box_width, box_height = cropper.bounding_box
box_x_0, box_y_0, box_width_0, box_height_0 = cropper.bounding_box_based_on_binary_saliency
fig = plt.figure()
ax = plt.axes()
bounding_box = patches.Rectangle(xy=(box_x, box_y), width=box_width, height=box_height, ec='#FF0000', fill=False)
bounding_box_based_on_binary_saliency = patches.Rectangle(xy=(box_x_0, box_y_0), width=box_width_0, height=box_height_0, ec='#00FF00', fill=False)
ax.imshow(overlaid_image)
ax.add_patch(bounding_box)
ax.add_patch(bounding_box_based_on_binary_saliency)
plt.show()

#Visualisierung des Bildes mit beschnittener Mitte zum Vergleich
center_cropped_image = cropper.crop_center(image)
plt.imshow(center_cropped_image)
plt.show()

fig_12.png Abbildung: Vergleich des Zuschneidens bei Verwendung korrekter Antwortdaten und bei Verwendung von SalGAN (Baseballbild)

Bei Verwendung der richtigen Antwortdaten in Abb. (A) und bei Verwendung von SalGAN in Abb. (B) wurde durch Ausschneiden des Teigteils fast das gleiche Ergebnis erzielt. Welcher, der Teig oder der Krug, sieht dich an? Ich glaube nicht, dass ich genug darüber gelernt habe, aber ich bin froh, dass ich es auf die gleiche Weise geschafft habe.

Ist diese Art von Geschichte nicht eine Kirschernte, die nur erfolgreiche Ergebnisse auflistet? Ich habe eine Frage. Wenn Sie dieses Ergebnis quantitativ messen möchten, können Sie sehen, wie sich diese beiden Typen mit dem SALICON-Dataset überschneiden. IoU-ähnliche Berechnungen bei Objekterkennungsaufgaben. Ich werde es jedoch weglassen, weil es eine Geschichte ist, die ich gerade gemacht habe.

Das Bild der Katze, das in der ersten Hälfte des Artikels erschien, wurde verwendet, weil es niedlich ist, aber es handelt sich tatsächlich um Zugdaten, sodass es nicht zur Überprüfung geeignet ist. Aber schauen wir uns das an.

fig_11.png Abbildung: Vergleich des Zuschneidens bei Verwendung korrekter Antwortdaten und bei Verwendung von SalGAN (Katzenbild)

Hier wurde fast das gleiche Ergebnis erzielt. Es war gut.

Schneiden Sie den Weihnachtsbaum aus

Zum Schluss kehren wir zum Bild des ersten Weihnachtsbaumes zurück. Wenn Sie mit einem Foto, das Sie aufgenommen haben und das nicht im Datensatz enthalten ist, einen überzeugenden Ausschnitt erzielen können, haben Sie Ihr Ziel erreicht.

fig_13.png Abbildung: Vergleich von Zuschneiden mit SalGAN und einfach zentriert (Weihnachtsbaumbild)

Ich habe das perfekte Ergebnis erzielt. Die Figur (a) mit dem Weihnachtsbaum sieht besser aus als die Figur (b) ohne. KI? Wir haben einen Mechanismus fertiggestellt, der automatisch Dinge tun kann, die dem nahe kommen, was Menschen mit der Kraft von tun. Dies kann die Enttäuschung beim Zuschneiden leerer Bereiche oder die Arbeit beim manuellen Zuschneiden von Bildern verringern.

Zusammenfassung

Es war die Implementierung von zwei Papieren, die mit Saliency Map [^ 2] und SalGAN [^ 5] zur Schätzung von Saliency Map + Alpha beschnitten wurden.

Sie können so etwas nur mit den öffentlichen Informationen machen. Selbst wenn Sie aufgehört haben, so etwas wie ein Tutorial für tiefes Lernen oder maschinelles Lernen zu bewegen, möchte ich, dass Sie sich einer kleinen Herausforderung stellen und versuchen, so etwas zu machen!

Recommended Posts

Ich habe eine KI erstellt, die ein Bild mit Saliency Map gut zuschneidet
Ich habe eine Android-App erstellt, die Google Map anzeigt
Erstellt ein Bildunterscheidungsmodell (cifar10) unter Verwendung eines Faltungs-Neuronalen Netzwerks
Ich habe ein Extenum-Paket erstellt, das die Enumeration erweitert
Ich habe einen LINE BOT erstellt, der mithilfe der Flickr-API ein Bild von Reis-Terroristen zurückgibt
Ich habe eine KI erstellt, die aus Trivia vorhersagt, und mich dazu gebracht, auf meine Trivia zu schließen. Hee-AI
Ich habe einen Ansible-Installer gemacht
Ich habe ein Anomalieerkennungsmodell erstellt, das unter iOS funktioniert
Ich habe einen Original-Programmführer mit der NHK-Programmführer-API erstellt.
Ich habe einen Xubuntu-Server erstellt.
[Python] Ich habe einen Bildbetrachter mit einer einfachen Sortierfunktion erstellt.
Ich habe einen Anpanman-Maler diskriminiert
Ich habe ein Angular Starter Kit gemacht
Ich habe eine KI gemacht, um zu beurteilen, ob es Alkohol ist oder nicht!
Ich habe ein Plug-In "EZPrinter" erstellt, das Karten-PDF mit QGIS einfach ausgibt.
Ich möchte ein Bild auf Jupyter Notebook mit OpenCV (Mac) anzeigen.
Ich habe eine Chrome-Erweiterung erstellt, die ein Diagramm auf der Amedas-Seite anzeigt
ConSinGAN: Ich habe versucht, GAN zu verwenden, das aus einem Bild generiert werden kann