[PYTHON] Essayez la segmentation sémantique (Pytorch)

en premier

La segmentation sémantique est un type de technologie de reconnaissance d'image qui peut être reconnue pixel par pixel. seg.png

Je laisserai la théorie détaillée séparément, mais j'aimerais essayer la segmentation sémantique en utilisant Pytorch. Cette fois-ci, nous traiterons d'un réseau moins profond et plus simple et pouvant être suffisamment appris même avec un ordinateur portable, plutôt qu'un réseau avec une structure profonde et compliquée comme Seg-Net, U-net ou PSP-net.

L'environnement est CPU: intel(R) core(TM)i5 7200U Mémoire: 8 Go OS: Windows10 python ver3.6.9 pytorch ver1.3.1 numpy ver1.17.4

Créer un jeu de données

Cette fois, j'utiliserai l'image que j'ai moi-même composée. L'image de la ligne supérieure correspond aux données d'entrée et l'image remplie inférieure correspond aux données de l'enseignant. En d'autres termes, il crée automatiquement un réseau qui se remplit comme un logiciel de peinture. input_auto.png correct_auto.png

Créez les données nécessaires à l'apprentissage. imgs correspond à 1000 données d'entrée imgs_ano produit 1000 feuilles de données de sortie (données de l'enseignant) Les carrés et les carrés ne sont pas couverts, et la longueur des côtés et le nombre de carrés sont également déterminés au hasard.

import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader

def rectangle(img, img_ano, centers, max_side):
    """
img… Image bidimensionnelle avec uniquement des lignes carrées
    img_ano… cette image d'artation
centres… liste des coordonnées du centre
    max_côté… 1 de la longueur maximale du côté/2 
    """
    if max_side < 3: #max_Quand le côté est trop petit
        max_side = 4
    #Côté longueur 1/Définir 2
    side_x = np.random.randint(3, int(max_side))
    side_y = np.random.randint(3, int(max_side))    
    
    #Coordonnées du centre,(x, y)Définir
    x = np.random.randint(max_side + 1, img.shape[0] - (max_side + 1))
    y = np.random.randint(max_side + 1, img.shape[1] - (max_side + 1))
    
    #Lorsqu'une position proche de la position centrale passée est incluse,Renvoie les données d'entrée telles quelles
    for center in centers:
        if np.abs(center[0] - x) < (2 *max_side + 1):
            if np.abs(center[1] - y) < (2 * max_side + 1):
                return img, img_ano, centers
            
    img[x - side_x : x + side_x, y - side_y] = 1.0      #Face supérieure
    img[x - side_x : x + side_x, y + side_y] = 1.0      #La partie au fond
    img[x - side_x, y - side_y : y + side_y] = 1.0      #Côté gauche
    img[x + side_x, y - side_y : y + side_y + 1] = 1.0  #côté droit
    img_ano[x - side_x : x + side_x + 1, y - side_y : y + side_y + 1] = 1.0
    centers.append([x, y])
    return img, img_ano, centers


num_images = 1000                                   #Nombre d'images à générer
length = 64                                          #Taille de l'image
imgs = np.zeros([num_images, 1, length, length])     #Générer une matrice zéro,Image d'entrée
imgs_ano = np.zeros([num_images, 1, length, length]) #Image de sortie

for i in range(num_images):
    centers = []
    img = np.zeros([length, length])
    img_ano = np.zeros([64, 64])
    for j in range(6):                       #Générez jusqu'à 6 quads
        img, img_ano, centers = rectangle(img, img_ano, centers, 12) 
    imgs[i, 0, :, :] = img
    imgs_ano[i, 0, :, :] = img_ano
   
imgs = torch.tensor(imgs, dtype = torch.float32)                 #ndarray - torch.tensor
imgs_ano = torch.tensor(imgs_ano, dtype = torch.float32)           #ndarray - torch.tensor
data_set = TensorDataset(imgs, imgs_ano)
data_loader = DataLoader(data_set, batch_size = 100, shuffle = True)

Réseau_1

Créez ensuite une classe de réseau dans Pytorch. Tout d'abord, j'ai utilisé le réseau défini par l'encodeur automatique de précédent tel quel. L'auto-encodeur et la segmentation génèrent une image de la même taille que l'image d'entrée, donc je pourrais l'utiliser (dans le cas de Pytorch). Que se passe-t-il avec Tensorflow?

class ConvAutoencoder(nn.Module):
    def __init__(self):
        super(ConvAutoencoder, self).__init__()
        #Encoder Layers
        self.conv1 = nn.Conv2d(in_channels = 1,
                               out_channels = 16,
                               kernel_size = 3,
                               padding = 1)
        self.conv2 = nn.Conv2d(in_channels = 16,
                               out_channels = 4,
                               kernel_size = 3,
                               padding = 1)
        #Decoder Layers
        self.t_conv1 = nn.ConvTranspose2d(in_channels = 4, out_channels = 16,
                                          kernel_size = 2, stride = 2)
        self.t_conv2 = nn.ConvTranspose2d(in_channels = 16, out_channels = 1,
                                          kernel_size = 2, stride = 2)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        #encode#                           
        x = self.relu(self.conv1(x))        
        x = self.pool(x)                  
        x = self.relu(self.conv2(x))      
        x = self.pool(x)                  
        #decode#
        x = self.relu(self.t_conv1(x))    
        x = self.sigmoid(self.t_conv2(x))
        return x

Apprenons sur ce réseau.

#******Sélectionnez un réseau******
net = ConvAutoencoder()                               
loss_fn = nn.MSELoss()                                #Définition de la fonction de perte
optimizer = optim.Adam(net.parameters(), lr = 0.01)

losses = []                                     #Record de perte pour chaque époque
epoch_time = 30
for epoch in range(epoch_time):
    running_loss = 0.0                          #Calcul de la perte pour chaque époque
    net.train()
    for i, (XX, yy) in enumerate(data_loader):
        optimizer.zero_grad()       
        y_pred = net(XX)
        loss = loss_fn(y_pred, yy)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print("epoch:",epoch, " loss:", running_loss/(i + 1))
    losses.append(running_loss/(i + 1))

#Visualisation de la perte
plt.plot(losses)
plt.ylabel("loss")
plt.xlabel("epoch time")
plt.savefig("loss_auto")
plt.show()

C'est une visualisation de la perte pour chaque époque. L'époque a-t-elle convergé dans une certaine mesure après 30 fois? loss_auto.png

Essayez-le avec une image que vous n'avez pas utilisée pour l'apprentissage. Je peux déterminer la position approximative, mais j'ai l'impression que la zone autour de la frontière n'est pas bien prise. output_auto.png

net.eval()            #Mode d'évaluation
#Générer une image qui n'a pas encore été apprise
num_images = 1
img_test = np.zeros([num_images, 1, length, length])
imgs_test_ano = np.zeros([num_images, 1, length, length])
for i in range(num_images):
    centers = []
    img = np.zeros([length, length])
    img_ano = np.zeros([length, length])
    for j in range(6):
        img, img_ano, centers = rectangle(img, img_ano, centers, 7)
    img_test[i, 0, :, :] = img

img_test = img_test.reshape([1, 1, 64, 64])
img_test = torch.tensor(img_test, dtype = torch.float32)
img_test = net(img_test)             #Transférer l'image générée vers le réseau formé
img_test = img_test.detach().numpy() #torch.tensor - ndarray
img_test = img_test[0, 0, :, :]

plt.imshow(img, cmap = "gray")       #Visualisation des données d'entrée
plt.savefig("input_auto")
plt.show()
plt.imshow(img_test, cmap = "gray")  #Visualisation des données de sortie
plt.savefig("output_auto")
plt.show()
plt.imshow(img_ano, cmap = "gray")   #Corriger les données de réponse
plt.savefig("correct_auto")
plt.plot()

Essayez d'approfondir le réseau.

Je ne pouvais pas obtenir suffisamment de performances avec le modèle précédent, alors j'aimerais approfondir la couche. Ici, ne nous contentons pas d'approfondir, mais ajoutons également une normalisation par lots pour éviter le sur-apprentissage et le suréchantillonnage avec un décodeur. Pour une explication détaillée du suréchantillonnage, cet article était facile à comprendre.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        #encoder
        self.encoder_conv_1 = nn.Sequential(*[
                                            nn.Conv2d(in_channels = 1, 
                                                      out_channels = 6,
                                                      kernel_size = 3,
                                                      padding = 1),
                                            nn.BatchNorm2d(6)
                                            ])
        
        self.encoder_conv_2 = nn.Sequential(*[
                                            nn.Conv2d(in_channels = 6,
                                                      out_channels = 16,
                                                      kernel_size = 3,
                                                      padding = 1),
                                            nn.BatchNorm2d(16)
                                            ])
        self.encoder_conv_3 = nn.Sequential(*[
                                            nn.Conv2d(in_channels = 16,
                                                      out_channels = 32,
                                                      kernel_size = 3,
                                                      padding = 1),
                                            nn.BatchNorm2d(32)
                                            ])
        
        #decoder
        self.decoder_convt_3 = nn.Sequential(*[
                                            nn.ConvTranspose2d(in_channels = 32,
                                                               out_channels = 16,
                                                               kernel_size = 3,
                                                               padding = 1),
                                            nn.BatchNorm2d(16)
                                            ])
        
        self.decoder_convt_2 = nn.Sequential(*[
                                            nn.ConvTranspose2d(in_channels = 16,
                                                               out_channels = 6,
                                                               kernel_size = 3,
                                                               padding = 1),
                                            nn.BatchNorm2d(6)
                                            ])
        
        self.decoder_convt_1 = nn.Sequential(*[
                                            nn.ConvTranspose2d(in_channels = 6,
                                                               out_channels = 1,
                                                               kernel_size = 3,
                                                               padding = 1)
                                            ])
    
    def forward(self, x):
        #encoder
        dim_0 = x.size()                    
        x = F.relu(self.encoder_conv_1(x))                            
        x, indices_1 = F.max_pool2d(x, kernel_size = 2,
                                    stride = 2, 
                                    return_indices = True)  #Enregistrer la position de maxpool avec indice
        dim_1 = x.size()
        x = F.relu(self.encoder_conv_2(x))                            
        x, indices_2 = F.max_pool2d(x, kernel_size = 2,
                                    stride = 2, 
                                    return_indices = True)            
        
        dim_2 = x.size()
        x = F.relu(self.encoder_conv_3(x))
        x, indices_3 = F.max_pool2d(x, kernel_size = 2,
                                    stride = 2, 
                                    return_indices = True)
        
        #decoder
        x = F.max_unpool2d(x, indices_3, kernel_size = 2,
                           stride = 2, output_size = dim_2)
        x = F.relu(self.decoder_convt_3(x))
        
        x = F.max_unpool2d(x, indices_2, kernel_size = 2,
                           stride = 2, output_size = dim_1)           
        x = F.relu(self.decoder_convt_2(x))                           
        
        x = F.max_unpool2d(x, indices_1, kernel_size = 2,
                           stride = 2, output_size = dim_0)           
        x = F.relu(self.decoder_convt_1(x))                           
        x = torch.sigmoid(x)                                       
        
        return x

Il est facile de passer à ce réseau

#******Sélectionnez un réseau******
net = ConvAutoencoder()

Changez simplement l'emplacement de la classe nouvellement créée.

#******Sélectionnez un réseau******
net = Net()

Représentez graphiquement la transition de la perte.

loss_auto.png

Entrez les données qui ne sont pas utilisées pour l'entraînement et comparez-les avec l'image correcte. output.png

Vous pouvez voir que la segmentation est terminée.

À la fin

J'ai essayé une segmentation simple cette fois. C'était un modèle simple qui était loin d'être pratique, mais je sens que j'ai pu saisir l'atmosphère.

Recommended Posts

Essayez la segmentation sémantique (Pytorch)
Ours ... pas de segmentation sémantique
Résumé des problèmes lors de la segmentation sémantique avec Pytorch
Essayez Auto Encoder avec Pytorch
Essayez d'implémenter XOR avec PyTorch
[PyTorch] Augmentation des données pour la segmentation
Vision par ordinateur: segmentation sémantique, partie 2 - segmentation sémantique en temps réel
Essayez le nouveau chaînage du planificateur dans PyTorch 1.4