[PYTHON] [PyTorch] Augmentation des données pour la segmentation

0. Qui est la cible de cet article

  1. Ceux qui mettent en œuvre la ** segmentation ** d'image à l'aide de ** PyTorch **
  2. Ceux qui souhaitent gonfler les données avec ** Augmentation des données **
  3. Ceux qui souhaitent appliquer ** exactement le même traitement ** à l'image d'origine et à l'image de masque correspondantes
  4. Surtout ceux qui utilisent ** leur propre ensemble de données ** (les données ne sont pas dans torchvision.datasets)

1. Vue d'ensemble

Principalement pour les ensembles de données de segmentation supervisée ou semi-supervisée

2 problèmes

Avant d'examiner le cas en question, considérons d'abord le cas où il n'y a pas de problème.

2.1 Aucun cas de problème (reconnaissance de classe d'objets, etc.)

Lors de l'exécution de l'augmentation des données avec PyTorch, la conversion est généralement définie comme suit.


transform = torchvision.transforms.Compose([
    #Faire pivoter par degrés d'angle
    transforms.RandomRotation(degrees),
    #Inverser horizontalement
    transforms.RandomHorizontalFlip(),
    #Inverser verticalement
    transforms.RandomVerticalFlip()
])

Je vais le mettre dans l'argument de l'ensemble de données


dataset = HogeDataset.HogeDataset(
    train=True, transform=transform
)

Ce n'est probablement pas un problème pour la reconnaissance des classes d'objets, etc. La raison en est que les données de l'enseignant ne sont pas une image, vous n'avez donc qu'à traiter l'image d'origine.

2.2 Cas problématiques (comme la segmentation)

Le prochain cas problématique La différence avec le cas précédent est que les données de l'enseignant sont données sous forme d'image.


transform = torchvision.transforms.Compose([
    #Faire pivoter par degrés d'angle
    transforms.RandomRotation(degrees),
    #Inverser horizontalement
    transforms.RandomHorizontalFlip(),
    #Inverser verticalement
    transforms.RandomVerticalFlip()
])

Je vais le mettre dans l'argument de l'ensemble de données


dataset = HogeDataset.HogeDataset(
    train=True, transform=transform, target_transform=transform
)

Cependant, cela ne signifie pas que les conversions effectuées sur l'image d'origine et l'image de masque sont compatibles lors de la récupération des données de «HogeDataset». Exemple) Image originale: rotation de 90 degrés, image de masque: rotation de 270 degrés Dans ce cas, même si les données sont gonflées, elles ne fonctionneront pas comme des données d'enseignant. Argument target_transform, pourquoi existez-vous? Cependant, la raison de cette existence est probablement que l'image du masque est également traitée (sans caractère aléatoire) comme «torchvision.transforms.Resize ()» et «torchvision.transforms.ToTensor ()». Je pense que c'est dans

3 solution

Alors, comment pouvons-nous appliquer le même traitement à l'image de masque que l'image d'origine? En guise de solution, vous pouvez créer votre propre classe Dataset comme indiqué ci-dessous.

HogeDataset.py


import os
import glob
import torch
from torchvision import transforms
from torchvision.transforms import functional as tvf
import random
from PIL import Image

DATA_PATH = '[Chemin du répertoire d'image d'origine]'
MASK_PATH = '[Chemin du répertoire de l'image du masque]'
TRAIN_NUM = [Nombre de données d'entraînement]

class HogeDataset(torch.utils.data.Dataset):
    def __init__(self, transform = None, target_transform = None, train = True):
        #transformer et cibler_transform est une transformation non aléatoire telle que la tensorisation
        self.transform = transform
        self.target_transform = target_transform


        data_files = glob.glob(DATA_PATH + '/*.[Extension de fichier]')
        mask_files = glob.glob(MASK_PATH + '/*.[Extension de fichier]')

        self.dataset = []
        self.maskset = []

        #Importer l'image originale
        for data_file in data_files:
            self.dataset.append(Image.open(
                DATA_PATH + os.path.basename(data_file)
            ))

        #Charger l'image du masque
        for mask_file in mask_files:
            self.maskset.append(Image.open(
                MASK_PATH + os.path.basename(mask_file)
            ))

        #Divisé en données d'entraînement et données de test
        if train:
            self.dataset = self.dataset[:TRAIN_NUM]
            self.maskset = self.maskset[:TRAIN_NUM]
        else:
            self.dataset = self.dataset[TRAIN_NUM+1:]
            self.maskset = self.maskset[TRAIN_NUM+1:]

        # Data Augmentation
        #La conversion aléatoire se fait ici
        self.augmented_dataset = []
        self.augmented_maskset = []
        for num in range(len(self.dataset)):
            data = self.dataset[num]
            mask = self.maskset[num]
            #Recadrage aléatoire
            for crop_num in range(16):
                #Position de la culture déterminée par un nombre aléatoire
                i, j, h, w = transforms.RandomCrop.get_params(data, output_size=(250,250))
                cropped_data = tvf.crop(data, i, j, h, w)
                cropped_mask = tvf.crop(mask, i, j, h, w)
                
                #rotation(0, 90, 180,270 degrés)
                for rotation_num in range(4):
                    rotated_data = tvf.rotate(cropped_data, angle=90*rotation_num)
                    rotated_mask = tvf.rotate(cropped_mask, angle=90*rotation_num)
                    
                    #Soit une inversion horizontale, soit une inversion verticale
                    #Inverser(direction horizontale)
                    for h_flip_num in range(2):
                        h_flipped_data = transforms.RandomHorizontalFlip(p=h_flip_num)(rotated_data)
                        h_flipped_mask = transforms.RandomHorizontalFlip(p=h_flip_num)(rotated_mask)
                    
                    """    
                    #Inverser(Direction verticale)
                    for v_flip_num in range(2):
                        v_flipped_data = transforms.RandomVerticalFlip(p=v_flip_num)(h_flipped_data)
                        v_flipped_mask = transforms.RandomVerticalFlip(p=v_flip_num)(h_flipped_mask)
                    """
        
                        #Ajouter des données augmentées de données
                        self.augmented_dataset.append(h_flipped_data)
                        self.augmented_maskset.append(h_flipped_mask)

        self.datanum = len(self.augmented_dataset)

    #Méthode d'acquisition de la taille des données
    def __len__(self):
        return self.datanum

    #Méthode d'acquisition des données
    #La conversion non aléatoire est effectuée ici
    def __getitem__(self, idx):
        out_data = self.augmented_dataset[idx]
        out_mask = self.augmented_maskset[idx]

        if self.transform:
            out_data = self.transform(out_data)

        if self.target_transform:
            out_mask = self.target_transform(out_mask)

        return out_data, out_mask

Ce que nous faisons est simple, nous faisons l'augmentation des données à l'intérieur de __init __ () À ce moment-là, à propos de chaque paire d'images

Traitement complet dans tous les cas

Pour le moment, vous pouvez appliquer le même traitement que l'image d'origine à l'image de masque et effectuer une augmentation des données comme ceci ** [Supplément] Il est recommandé de n'utiliser que le traitement d'inversion horizontale ou verticale, car la combinaison de la rotation et de l'inversion peut provoquer une duplication! !! ** **

4 Comment utiliser

Essayez d'utiliser votre propre classe de jeu de données dans 3


import torch
import torchvision
import HogeDataset

BATCH_SIZE = [Taille du lot]

#Prétraitement
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
target_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224), interpolation=0), 
    torchvision.transforms.ToTensor()
])

#Préparation des données d'entraînement et des données de test
trainset = HogeDataset.HogeDataset(
    train=True,
    transform=transform, 
    target_transform=target_transform
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True
)

testset = EpiDatasets.EpiDatasets(
    train=False,
    transform=transform,
    target_transform=target_transform
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=True
)

Recommended Posts

[PyTorch] Augmentation des données pour la segmentation
Comment augmenter les données avec PyTorch
Afficher l'image après l'augmentation des données avec Pytorch
Essayez la segmentation sémantique (Pytorch)
Préparation du jeu de données pour PyTorch
Augmentation des données avec openCV
Ensemble de données pour l'apprentissage automatique
Python pour l'analyse des données Chapitre 4
Python pour l'analyse des données Chapitre 2
Essayez l'augmentation de l'effacement aléatoire des données
Nouvelle augmentation des données? [Grid Mix]
Conseils et précautions lors de l'analyse des données
Python pour l'analyse des données Chapitre 3
[PyTorch] APPRENTISSAGE DE TRANSFERT POUR LA VISION INFORMATIQUE
Cours Python pour la science des données - techniques utiles
Extraits de code VS pour les analystes de données
Modèle de prétraitement pour l'analyse des données (Python)
Analyse de données pour améliorer POG 3 ~ Analyse de régression ~
Formatage des données pour les graphiques Python / couleur
Site de compétition recommandé pour les data scientists
Paramètres initiaux Mac (pour les analystes de données)