Pytorch collate_fn ist ein Argument von Dataloader.
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
Dieses Mal möchte ich sein Verhalten und seine Verwendung bestätigen.
Wenn das im Datensatz definierte \ _ \ _ getitem \ _ \ _ die Form eines Stapels hat, wird jedes Element (Bild, Ziel usw.) zuerst in einer Liste konsolidiert. Collate_fn manipuliert es wie in Pytroch Official beschrieben und macht es schließlich zu torch.Tensor Es ist eine Funktion.
dataset_iter = iter(dataset)
for indices in batch_sampler:
yield collate_fn([next(dataset_iter) for _ in indices])
Standardmäßig wird torch.stack () verwendet, um Tensor zu erstellen. Sie können jedoch einen erweiterten Stapel erstellen, indem Sie Ihre eigene collate_fn verwenden.
Das Standardverhalten ist fast das gleiche wie unten. (Obwohl die Anzahl der Rückgaben von \ _ \ _ getitem \ _ \ _ abhängt) Es nimmt Batch als Argument, stapelt es und gibt es zurück.
def collate_fn(batch):
images, targets= list(zip(*batch))
images = torch.stack(images)
targets = torch.stack(targets)
return images, targets
Sie können den Inhalt Ihrer eigenen collate_fn ändern.
Dieses Mal erstellen wir einen Stapel Objekterkennung. Die Objekterkennung gibt im Wesentlichen das Rechteck des Objekts und seine Beschriftung ein. Da ein Bild jedoch mehrere Rechtecke enthalten kann, müssen Sie beim Stapeln verbinden, welches Bild welches Rechteck ist, und den Index Muss beigefügt sein.
[[label, xc, yx, w, h],
[ ],
[ ],...]
#Ändern Sie dies nach unten
[[0, label xc, yx, w, h],
[0, ],
[1, ],...]
Die Implementierung selbst ist nicht so schwierig.
def batch_idx_fn(batch):
images, bboxes = list(zip(*batch))
targets = []
for idx, bbox in enumerate(bboxes):
target = np.zeros((len(bbox), 6))
target[:, 1:] = bbox
target[:, 0] = idx
targets.append(target)
images = torch.stack(images)
targets = torch.Tensor(np.concatenate(targets)) # [[batch_idx, label, xc, yx, w, h], ...]
return images, targets
Wenn Sie es tatsächlich verwenden, wird es wie folgt sein.
test_data_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
collate_fn=batch_idx_fn
)
print(iter(test_data_loader).next()[0])
# [[0.0000, 0.0000, 0.6001, 0.5726, 0.1583, 0.1119],
# [0.0000, 9.0000, 0.0568, 0.5476, 0.1150, 0.1143],
# [1.0000, 5.0000, 0.8316, 0.4113, 0.1080, 0.3452],
# [1.0000, 0.0000, 0.3476, 0.6494, 0.1840, 0.1548],
# [2.0000, 2.0000, 0.8276, 0.6763, 0.1720, 0.3240],
# [2.0000, 4.0000, 0.1626, 0.0496, 0.0900, 0.0880],
# [2.0000, 5.0000, 0.2476, 0.2736, 0.1400, 0.5413],
# [2.0000, 5.0000, 0.5786, 0.4523, 0.4210, 0.5480],
# [3.0000, 0.0000, 0.4636, 0.4618, 0.0400, 0.1024],
# [3.0000, 0.0000, 0.5706, 0.5061, 0.0380, 0.0683]]
Anders als bei der Indizierung in diesem Artikel Wenn sich das Ziel für jede Charge ändert, Wenn das Ziel keine numerischen Daten sind, die nicht gestapelt werden können Sie können es verwenden, wenn Sie denselben Datensatz mit geringfügig unterschiedlichen Änderungen verwenden möchten.
Recommended Posts