[PYTHON] [PyTorch] Datenerweiterung zur Segmentierung

0. Wer ist das Ziel dieses Artikels?

  1. Diejenigen, die die Bildsegmentierung mit PyTorch implementieren
  2. Diejenigen, die Daten mit ** Data Augmentation ** aufblasen möchten
  3. Diejenigen, die ** genau die gleiche Verarbeitung ** auf das entsprechende Originalbild und Maskenbild anwenden möchten
  4. Insbesondere diejenigen, die ** eigenen Datensatz ** verwenden (Daten nicht in torchvision.datasets)

1. Übersicht

Hauptsächlich für überwachte oder halbüberwachte Segmentierungsdatensätze

2 Probleme

Bevor wir uns den fraglichen Fall ansehen, betrachten wir zunächst den Fall, in dem es kein Problem gibt.

2.1 Kein Problemfall (Objektklassenerkennung usw.)

Bei der Datenerweiterung mit PyTorch wird die Konvertierung normalerweise wie folgt definiert.


transform = torchvision.transforms.Compose([
    #Um Winkelgrade drehen
    transforms.RandomRotation(degrees),
    #Horizontal umkehren
    transforms.RandomHorizontalFlip(),
    #Vertikal umkehren
    transforms.RandomVerticalFlip()
])

Ich werde es in das Argument des Datensatzes setzen


dataset = HogeDataset.HogeDataset(
    train=True, transform=transform
)

Wahrscheinlich ist dies kein Problem für die Objektklassenerkennung usw. Der Grund ist, dass die Lehrerdaten kein Bild sind, sodass Sie nur das Originalbild verarbeiten müssen.

2.2 Problemfälle (wie Segmentierung)

Der nächste Problemfall Der Unterschied zum vorherigen Fall besteht darin, dass die Lehrerdaten als Bild angegeben werden.


transform = torchvision.transforms.Compose([
    #Um Winkelgrade drehen
    transforms.RandomRotation(degrees),
    #Horizontal umkehren
    transforms.RandomHorizontalFlip(),
    #Vertikal umkehren
    transforms.RandomVerticalFlip()
])

Ich werde es in das Argument des Datensatzes setzen


dataset = HogeDataset.HogeDataset(
    train=True, transform=transform, target_transform=transform
)

Dies bedeutet jedoch nicht, dass die in das Originalbild und das Maskenbild vorgenommenen Konvertierungen beim Abrufen von Daten aus "HogeDataset" kompatibel sind. Beispiel) Originalbild: 90-Grad-Drehung, Maskenbild: 270-Grad-Drehung In diesem Fall fungieren die Daten selbst dann nicht als Lehrerdaten, wenn sie aufgeblasen sind. Argument target_transform, warum gibt es dich? Der Grund für diese Existenz ist jedoch wahrscheinlich, dass das Maskenbild auch (ohne Zufälligkeit) wie "torchvision.transforms.Resize ()" und "torchvision.transforms.ToTensor ()" verarbeitet wird. Ich denke es ist in

3 Lösung

Wie können wir also dieselbe Verarbeitung auf das Maskenbild anwenden wie auf das Originalbild? Als Lösung können Sie Ihre eigene Dataset-Klasse wie unten gezeigt erstellen.

HogeDataset.py


import os
import glob
import torch
from torchvision import transforms
from torchvision.transforms import functional as tvf
import random
from PIL import Image

DATA_PATH = '[Ursprünglicher Bildverzeichnispfad]'
MASK_PATH = '[Bildverzeichnispfad maskieren]'
TRAIN_NUM = [Anzahl der Trainingsdaten]

class HogeDataset(torch.utils.data.Dataset):
    def __init__(self, transform = None, target_transform = None, train = True):
        #transformieren und zielen_Transformation ist eine nicht zufällige Transformation wie Tensorisierung
        self.transform = transform
        self.target_transform = target_transform


        data_files = glob.glob(DATA_PATH + '/*.[Dateierweiterung]')
        mask_files = glob.glob(MASK_PATH + '/*.[Dateierweiterung]')

        self.dataset = []
        self.maskset = []

        #Originalbild importieren
        for data_file in data_files:
            self.dataset.append(Image.open(
                DATA_PATH + os.path.basename(data_file)
            ))

        #Maskenbild laden
        for mask_file in mask_files:
            self.maskset.append(Image.open(
                MASK_PATH + os.path.basename(mask_file)
            ))

        #Unterteilt in Trainingsdaten und Testdaten
        if train:
            self.dataset = self.dataset[:TRAIN_NUM]
            self.maskset = self.maskset[:TRAIN_NUM]
        else:
            self.dataset = self.dataset[TRAIN_NUM+1:]
            self.maskset = self.maskset[TRAIN_NUM+1:]

        # Data Augmentation
        #Hier erfolgt eine zufällige Konvertierung
        self.augmented_dataset = []
        self.augmented_maskset = []
        for num in range(len(self.dataset)):
            data = self.dataset[num]
            mask = self.maskset[num]
            #Zufällige Ernte
            for crop_num in range(16):
                #Ernteposition durch Zufallszahl bestimmt
                i, j, h, w = transforms.RandomCrop.get_params(data, output_size=(250,250))
                cropped_data = tvf.crop(data, i, j, h, w)
                cropped_mask = tvf.crop(mask, i, j, h, w)
                
                #Drehung(0, 90, 180,270 Grad)
                for rotation_num in range(4):
                    rotated_data = tvf.rotate(cropped_data, angle=90*rotation_num)
                    rotated_mask = tvf.rotate(cropped_mask, angle=90*rotation_num)
                    
                    #
                    #Umkehren(Horizontale Richtung)
                    for h_flip_num in range(2):
                        h_flipped_data = transforms.RandomHorizontalFlip(p=h_flip_num)(rotated_data)
                        h_flipped_mask = transforms.RandomHorizontalFlip(p=h_flip_num)(rotated_mask)
                    
                    """    
                    #Umkehren(Vertikale Richtung)
                    for v_flip_num in range(2):
                        v_flipped_data = transforms.RandomVerticalFlip(p=v_flip_num)(h_flipped_data)
                        v_flipped_mask = transforms.RandomVerticalFlip(p=v_flip_num)(h_flipped_mask)
                    """
        
                        #Daten hinzufügen Erweiterte Daten
                        self.augmented_dataset.append(h_flipped_data)
                        self.augmented_maskset.append(h_flipped_mask)

        self.datanum = len(self.augmented_dataset)

    #Methode zur Erfassung der Datengröße
    def __len__(self):
        return self.datanum

    #Datenerfassungsmethode
    #Hier erfolgt eine nicht zufällige Konvertierung
    def __getitem__(self, idx):
        out_data = self.augmented_dataset[idx]
        out_mask = self.augmented_maskset[idx]

        if self.transform:
            out_data = self.transform(out_data)

        if self.target_transform:
            out_mask = self.target_transform(out_mask)

        return out_data, out_mask

Was wir tun, ist einfach: Wir führen eine Datenerweiterung in __init __ () durch Zu diesem Zeitpunkt etwa jedes Bildpaar

Umfassende Bearbeitung in allen Fällen

Derzeit können Sie die gleiche Verarbeitung wie das Originalbild auf das Maskenbild anwenden und eine solche Datenvergrößerung durchführen ** [Ergänzung] Es wird empfohlen, nur die horizontale oder vertikale Inversionsverarbeitung zu verwenden, da die Kombination von Rotation und Inversion zu Doppelarbeit führen kann! !! ** **.

4 Verwendung

Versuchen Sie, Ihre eigene Dataset-Klasse in 3 zu verwenden


import torch
import torchvision
import HogeDataset

BATCH_SIZE = [Chargengröße]

#Vorverarbeitung
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
target_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224), interpolation=0), 
    torchvision.transforms.ToTensor()
])

#Vorbereitung von Trainingsdaten und Testdaten
trainset = HogeDataset.HogeDataset(
    train=True,
    transform=transform, 
    target_transform=target_transform
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True
)

testset = EpiDatasets.EpiDatasets(
    train=False,
    transform=transform,
    target_transform=target_transform
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=True
)

Recommended Posts

[PyTorch] Datenerweiterung zur Segmentierung
Aufblasen von Daten (Datenerweiterung) mit PyTorch
Zeigen Sie das Bild nach der Datenerweiterung mit Pytorch an
Versuchen Sie es mit semantischer Segmentierung (Pytorch)
Datensatzvorbereitung für PyTorch
Datenerweiterung mit openCV
Datensatz für maschinelles Lernen
Python für die Datenanalyse Kapitel 4
Python für die Datenanalyse Kapitel 2
Versuchen Sie die zufällige Löschung von Daten
Neue Datenerweiterung? [Grid Mix]
Tipps und Vorsichtsmaßnahmen bei der Datenanalyse
Python für die Datenanalyse Kapitel 3
[PyTorch] ÜBERTRAGUNGSLERNEN FÜR COMPUTERVISION
Python-Kurs für datenwissenschaftlich-nützliche Techniken
VS-Codefragmente für Datenanalysten
Vorverarbeitungsvorlage für die Datenanalyse (Python)
Datenanalyse zur Verbesserung von POG 3 ~ Regressionsanalyse ~
Datenformatierung für Python / Farbdiagramme
Empfohlene Wettbewerbsseite für Datenwissenschaftler
Mac-Grundeinstellungen (für Datenanalysten)