PyTorch's collate_fn is an argument to DataLoader.
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)
In this post, I would like to confirm its behavior and usage.
When a batch is formed, the samples returned by the __getitem__ defined in the dataset (image, target, etc.) are first gathered into a list. As described in the official PyTorch documentation, collate_fn is the function that takes this list and ultimately turns it into torch.Tensor:
dataset_iter = iter(dataset)
for indices in batch_sampler:
    yield collate_fn([next(dataset_iter) for _ in indices])
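To make this concrete, here is a minimal sketch (the tiny ToyDataset, its sizes, and inspect_collate are made up for illustration): collate_fn simply receives a plain Python list of whatever __getitem__ returns, one element per sample in the batch.

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        # each sample is an (image, target) tuple
        return torch.rand(3, 8, 8), torch.tensor(idx)

def inspect_collate(batch):
    # batch is a list like [(image0, target0), (image1, target1)]
    print(type(batch), len(batch), type(batch[0]))
    images, targets = zip(*batch)
    return torch.stack(images), torch.stack(targets)

loader = DataLoader(ToyDataset(), batch_size=2, collate_fn=inspect_collate)
images, targets = next(iter(loader))
print(images.shape, targets.shape)  # torch.Size([2, 3, 8, 8]) torch.Size([2])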
By default it just turns the samples into Tensors with torch.stack(), but you can build more advanced batches by writing your own collate_fn.
The default behavior is roughly equivalent to the following (although the number of return values depends on __getitem__): it takes a batch as an argument, stacks it, and returns it.
def collate_fn(batch):
    images, targets = list(zip(*batch))
    images = torch.stack(images)
    targets = torch.stack(targets)
    return images, targets
By writing your own collate_fn, you can change what the batch contains.
Here we will build batches for object detection. Object detection basically takes an object's bounding box and its label as input, but since a single image may contain multiple boxes, when batching we need to keep track of which box belongs to which image, so a batch index must be attached.
[[label, xc, yc, w, h],
 [ ],
 [ ], ...]

# Change this to the following

[[0, label, xc, yc, w, h],
 [0, ],
 [1, ], ...]
The implementation itself is not that difficult.
import numpy as np
import torch

def batch_idx_fn(batch):
    images, bboxes = list(zip(*batch))
    targets = []
    for idx, bbox in enumerate(bboxes):
        # one row per box: [batch_idx, label, xc, yc, w, h]
        target = np.zeros((len(bbox), 6))
        target[:, 1:] = bbox
        target[:, 0] = idx
        targets.append(target)
    images = torch.stack(images)
    targets = torch.Tensor(np.concatenate(targets))
    return images, targets
When you actually use it, it will be as follows.
test_data_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=batch_idx_fn
)
print(next(iter(test_data_loader))[1])  # targets
# [[0.0000, 0.0000, 0.6001, 0.5726, 0.1583, 0.1119],
# [0.0000, 9.0000, 0.0568, 0.5476, 0.1150, 0.1143],
# [1.0000, 5.0000, 0.8316, 0.4113, 0.1080, 0.3452],
# [1.0000, 0.0000, 0.3476, 0.6494, 0.1840, 0.1548],
# [2.0000, 2.0000, 0.8276, 0.6763, 0.1720, 0.3240],
# [2.0000, 4.0000, 0.1626, 0.0496, 0.0900, 0.0880],
# [2.0000, 5.0000, 0.2476, 0.2736, 0.1400, 0.5413],
# [2.0000, 5.0000, 0.5786, 0.4523, 0.4210, 0.5480],
# [3.0000, 0.0000, 0.4636, 0.4618, 0.0400, 0.1024],
# [3.0000, 0.0000, 0.5706, 0.5061, 0.0380, 0.0683]]
Besides attaching an index as in this article, I think a custom collate_fn is useful when the target changes from batch to batch, when the targets are not numerical data that can simply be stacked, or when you want to reuse the same Dataset with slight variations.
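As one example of the "targets that cannot simply be stacked" case, here is a minimal sketch (pad_targets_fn and the padding value of -1 are assumptions for illustration, not part of the original article) that pads variable-length label tensors to a common length so they can be batched into a single tensor:

import torch

def pad_targets_fn(batch):
    # each sample is (image, labels), where labels is a 1-D tensor of varying length
    images, labels = zip(*batch)
    images = torch.stack(images)
    max_len = max(len(l) for l in labels)
    # pad with -1 so real labels can be told apart from padding
    padded = torch.full((len(labels), max_len), -1, dtype=torch.long)
    for i, l in enumerate(labels):
        padded[i, :len(l)] = l
    return images, padded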