[PYTHON] Versuchen Sie es mit Pytorchs collate_fn

Pytorch collate_fn ist ein Argument von Dataloader.

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

Dieses Mal möchte ich sein Verhalten und seine Verwendung bestätigen.

Was ist collate_fn?

Wenn das im Datensatz definierte \ _ \ _ getitem \ _ \ _ die Form eines Stapels hat, wird jedes Element (Bild, Ziel usw.) zuerst in einer Liste konsolidiert. Collate_fn manipuliert es wie in Pytroch Official beschrieben und macht es schließlich zu torch.Tensor Es ist eine Funktion.

dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])

Standardmäßig wird torch.stack () verwendet, um Tensor zu erstellen. Sie können jedoch einen erweiterten Stapel erstellen, indem Sie Ihre eigene collate_fn verwenden.

Machen Sie Ihre eigene collate_fn

Das Standardverhalten ist fast das gleiche wie unten. (Obwohl die Anzahl der Rückgaben von \ _ \ _ getitem \ _ \ _ abhängt) Es nimmt Batch als Argument, stapelt es und gibt es zurück.

def collate_fn(batch):
    images, targets= list(zip(*batch))
    images = torch.stack(images)
    targets = torch.stack(targets)
    return images, targets

Sie können den Inhalt Ihrer eigenen collate_fn ändern.

Dieses Mal erstellen wir einen Stapel Objekterkennung. Die Objekterkennung gibt im Wesentlichen das Rechteck des Objekts und seine Beschriftung ein. Da ein Bild jedoch mehrere Rechtecke enthalten kann, müssen Sie beim Stapeln verbinden, welches Bild welches Rechteck ist, und den Index Muss beigefügt sein.

[[label, xc, yx, w, h],
 [                   ],
 [                   ],...]

#Ändern Sie dies nach unten

[[0, label xc, yx, w, h],
 [0,                   ],
 [1,                   ],...]

Die Implementierung selbst ist nicht so schwierig.

def batch_idx_fn(batch):
    images, bboxes = list(zip(*batch))
    targets = []
    for idx, bbox in enumerate(bboxes):
        target = np.zeros((len(bbox), 6))
        target[:, 1:] = bbox
        target[:, 0] = idx
        targets.append(target)
    images = torch.stack(images)
    targets = torch.Tensor(np.concatenate(targets)) # [[batch_idx, label, xc, yx, w, h], ...]
    return images, targets

Wenn Sie es tatsächlich verwenden, wird es wie folgt sein.

test_data_loader = torch.utils.data.DataLoader(
                       test_dataset, 
                       batch_size=1, 
                       shuffle=False, 
                       collate_fn=batch_idx_fn
                       )
print(iter(test_data_loader).next()[0])
# [[0.0000, 0.0000, 0.6001, 0.5726, 0.1583, 0.1119],
# [0.0000, 9.0000, 0.0568, 0.5476, 0.1150, 0.1143],
# [1.0000, 5.0000, 0.8316, 0.4113, 0.1080, 0.3452],
# [1.0000, 0.0000, 0.3476, 0.6494, 0.1840, 0.1548],
# [2.0000, 2.0000, 0.8276, 0.6763, 0.1720, 0.3240],
# [2.0000, 4.0000, 0.1626, 0.0496, 0.0900, 0.0880],
# [2.0000, 5.0000, 0.2476, 0.2736, 0.1400, 0.5413],
# [2.0000, 5.0000, 0.5786, 0.4523, 0.4210, 0.5480],
# [3.0000, 0.0000, 0.4636, 0.4618, 0.0400, 0.1024],
# [3.0000, 0.0000, 0.5706, 0.5061, 0.0380, 0.0683]]

abschließend

Anders als bei der Indizierung in diesem Artikel Wenn sich das Ziel für jede Charge ändert, Wenn das Ziel keine numerischen Daten sind, die nicht gestapelt werden können Sie können es verwenden, wenn Sie denselben Datensatz mit geringfügig unterschiedlichen Änderungen verwenden möchten.

Recommended Posts

Versuchen Sie es mit Pytorchs collate_fn
Versuchen Sie es mit Tkinter
Versuchen Sie es mit Docker-Py
Versuchen Sie es mit einem Ausstecher
Versuchen Sie es mit Geopandas
Versuchen Sie es mit Selen
Versuchen Sie es mit scipy
Versuchen Sie es mit pandas.DataFrame
Versuchen Sie es mit Django-Swiftbrowser
Versuchen Sie es mit matplotlib
Versuchen Sie es mit tf.metrics
Versuchen Sie es mit virtualenv (virtualenvwrapper)
Versuchen Sie es jetzt mit virtualenv
[Kaggle] Versuchen Sie es mit LGBM
Versuchen Sie es mit dem Feed-Parser von Python.
Versuchen Sie es mit Pythons Tkinter
Versuchen Sie es mit Tweepy [Python2.7]
Versuchen Sie, PythonTex mit Texpad zu verwenden.
[Python] Versuchen Sie, Tkinters Leinwand zu verwenden
Versuchen Sie es mit Jupyters Docker-Image
Versuchen Sie es mit Scikit-Learn (1) - K-Clustering nach Durchschnittsmethode
Versuchen Sie es mit matplotlib mit PyCharm
Versuchen Sie es mit Azure Logic Apps
Versuchen Sie es mit Kubernetes Client -Python-
Versuchen Sie es mit der Twitter-API
Versuchen Sie es mit OpenCV unter Windows
Versuchen Sie, Jupyter Notebook dynamisch zu verwenden
Versuchen Sie es mit AWS SageMaker Studio
Versuchen Sie, automatisch mit Selen zu twittern.
Versuchen Sie es mit SQLAlchemy + MySQL (Teil 1)
Versuchen Sie es mit der Twitter-API
Versuchen Sie es mit SQLAlchemy + MySQL (Teil 2)
Versuchen Sie es mit der PeeringDB 2.0-API
Versuchen Sie es mit der Entwurfsfunktion von Pelican
Versuchen Sie es mit pytest-Overview und Samples-
Versuchen Sie es mit Folium mit Anakonda
Versuchen Sie es mit der Admin-API von Janus Gateway
Versuchen Sie es mit Spyder, das in Anaconda enthalten ist
Versuchen Sie es mit Designmustern (Exporter Edition)
Versuchen Sie es mit Pillow auf iPython (Teil 1)
Versuchen Sie es mit Pillow auf iPython (Teil 2)
Versuchen Sie es mit der Pleasant-API (Python / FastAPI).
Versuchen Sie es mit LevelDB mit Python (plyvel)
Versuchen Sie, Nagios mit pynag zu konfigurieren
Versuchen Sie, die Remote-Debugging-Funktion von PyCharm zu verwenden
Versuchen Sie es mit ArUco mit Raspberry Pi
Versuchen Sie es mit billigem LiDAR (Camsense X1)
[Sakura-Mietserver] Versuchen Sie es mit einer Flasche.
Versuchen Sie, Statistiken mit e-Stat abzurufen
Versuchen Sie es mit dem Python Cmd-Modul
Versuchen Sie, Pythons networkx mit AtCoder zu verwenden
Versuchen Sie es mit LeapMotion mit Python
Versuchen Sie es mit der handgeschriebenen Zeichenerkennung (OCR) von GCP.
Versuchen Sie es mit Amazon DynamoDB von Python
Code-Server Lokale Umgebung (3) Verwenden Sie das VSCode-Plugin
Versuchen Sie es mit der Wunderlist-API in Python
Versuchen Sie eine Formel mit Σ mit Python
Versuchen Sie es mit dem Webanwendungsframework Flask
Versuchen Sie es mit Bash unter Windows 10 2 (TensorFlow-Installation)
Versuchen Sie, die Kraken-API mit Python zu verwenden
Versuchen Sie es mit dem $ 6 Rabatt LiDAR (Camsense X1)