[PYTHON] I tried to implement SSD with PyTorch now (Dataset)

Introduction

As the title suggests, I implemented Single Shot Multibox Detector (SSD) with PyTorch ([https://github.com/jjjkkkjjj/pytorch_SSD](https://github.com/jjjkkkjjj/pytorch_SSD)). <strike>However, it runs slower than ssd.pytorch and similar implementations (I'll investigate the cause later).</strike> [^1] I worked hard on the abstraction, though, so I think it is highly customizable. (Whether I can properly document how to use it is another matter...) If you search for SSD implementations you will find plenty, but this time,

[^1]: I had simply not set the num_workers initialization argument of DataLoader. It is now as fast as ssd.pytorch.

  • I wanted to customize it freely
  • I wanted to understand SSD by implementing it
  • I had free time because of the coronavirus pandemic

Those are the reasons I implemented it myself. Many articles already explain SSD clearly (see the references), but I would like to summarize it in my own way to organize my thoughts. This time I will summarize everything around the dataset.

  • I tried to implement SSD with PyTorch (Dataset)
  • I tried to implement SSD with PyTorch (model edition)

PyTorch

First of all, I will briefly touch on PyTorch, the deep learning framework used this time. Actually, when I first tried to implement SSD, I used TensorFlow. However,

  • It is difficult to debug
  • There are too many similar functions (compat.v1, compat.v2, etc.?)

so the implementation made little progress. "Difficult to debug" in particular was fatal for me, and I got completely lost. Admittedly, that is probably just because I do not understand TensorFlow well, and learning to use it properly would be worthwhile...

I felt I would never finish at that rate, so I switched to PyTorch, which lets you write operations much like NumPy. After switching, the debugging difficulty I had felt with TensorFlow improved considerably, and the implementation went smoothly. I was already used to NumPy-style operations from implementing Matft (https://github.com/jjjkkkjjj/Matft; see [I made an N-dimensional matrix calculation library Matft with Swift](https://qiita.com/jjjkkkjjj/items/1f2b5c3835b1600d3129)) and MIL (https://github.com/jjjkkkjjj/MIL), so PyTorch was a perfect fit for me.

What is SSD?

First, I will briefly describe what SSD is and then go into the details. SSD is an object detection algorithm that predicts the positions and labels of objects end-to-end. The figure below is only a rough sketch, but given an input image like this, SSD outputs the positions and labels of the objects in a single pass.

end-to-end.png

model.png

What this model does is as follows. I will explain step by step.

  • Dataset
    • Read the dataset (see "Read VOC dataset")
  • Input data
    • RGB image normalized to $[0,1]$ (this time; see Transform)
  • What to predict
    • Offset values from the Default Boxes → position of the bounding box
    • Label
  • Learning flow
    • Create the Default Boxes
    • Input the image and the correct label
    • Assign the correct-label bounding boxes to the Default Boxes (matching strategy)
    • **Normalization of the correct label!!**
    • Calculate the localization loss and the confidence loss (hard negative mining)
  • Test flow
    • Create the Default Boxes
    • Input the image
    • Predict the offset values and labels for the Default Boxes
    • Remove redundant boxes (non-maximum suppression)

Read VOC dataset

The original SSD paper uses the PASCAL VOC2007, [PASCAL VOC2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/), and COCO2014 datasets. <strike>COCO has not been implemented yet, so I will explain the VOC dataset.</strike> First, let's look at the VOC dataset.

Structure

The directory structure is essentially standardized and looks like the following.

voc directory


$ tree -I '*.png|*.xml|*.jpg|*.txt'
└── VOCdevkit
    └── VOC20**
        ├── Annotations
        ├── ImageSets
        │   ├── Action
        │   ├── Layout
        │   ├── Main
        │   └── Segmentation
        ├── JPEGImages
        ├── SegmentationClass
        └── SegmentationObject

What is required for object detection is `Annotations`, `JPEGImages`, and `ImageSets/Main` directly under `VOC20**`. Each is as follows.

  • Annotations --Contains the .xml annotation files. Each one is paired with a .jpg file in `JPEGImages`.
  • JPEGImages --Contains the .jpg image files. Each one is paired with an .xml file in `Annotations`.
  • ImageSets/Main --Contains .txt files that define each image set (subset of the dataset). Each file lists the file names, shared by `Annotations` and `JPEGImages`, of the members of that set.
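As a rough illustration of how these three directories fit together, the sketch below turns the lines of an image-set file into annotation and image paths. The helper name `list_voc_pairs` and the sample IDs are made up for this example. Per-class image-set files carry a second column (a flag), which is why only the first token of each line is kept.

```python
import os

def list_voc_pairs(voc_dir, imageset_lines):
    """Return (annotation path, image path) pairs for each ID in an image-set file."""
    pairs = []
    for line in imageset_lines:
        if not line.strip():
            continue  # skip blank lines
        # Per-class files look like "000007 -1"; keep only the ID.
        image_id = line.split()[0]
        xml_path = os.path.join(voc_dir, 'Annotations', image_id + '.xml')
        jpg_path = os.path.join(voc_dir, 'JPEGImages', image_id + '.jpg')
        pairs.append((xml_path, jpg_path))
    return pairs

# Hypothetical contents of an ImageSets/Main/*.txt file
pairs = list_voc_pairs('VOCdevkit/VOC2007', ['000005', '000007 -1', ''])
for xml_path, jpg_path in pairs:
    print(xml_path, jpg_path)
```

Because the annotation and the image share one ID, a single list of IDs is enough to locate both files.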

Annotation data (.xml file)

The .xml files under the `Annotations` directory look like the following.

Annotations/~~.xml


<annotation>
	<folder>VOC2007</folder>
	<filename>000005.jpg</filename>
	<source>
		<database>The VOC2007 Database</database>
		<annotation>PASCAL VOC2007</annotation>
		<image>flickr</image>
		<flickrid>325991873</flickrid>
	</source>
	<owner>
		<flickrid>archintent louisville</flickrid>
		<name>?</name>
	</owner>
	<size>
		<width>500</width>
		<height>375</height>
		<depth>3</depth>
	</size>
	<segmented>0</segmented>
	<object>
		<name>chair</name>
		<pose>Rear</pose>
		<truncated>0</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>263</xmin>
			<ymin>211</ymin>
			<xmax>324</xmax>
			<ymax>339</ymax>
		</bndbox>
	</object>
	<object>
		<name>chair</name>
		<pose>Unspecified</pose>
		<truncated>0</truncated>
		<difficult>0</difficult>
		<bndbox>
			<xmin>165</xmin>
			<ymin>264</ymin>
			<xmax>253</xmax>
			<ymax>372</ymax>
		</bndbox>
	</object>
        ...
</annotation>

The important points are as follows.

  • <filename> --The .jpeg file that this annotation data corresponds to
  • <object>
    • <name> --Label name
    • <truncated> --Whether the object is entirely visible (0) or only partially visible (1).
    • <difficult> --Whether the object is difficult to recognize (1) or not (0)
    • <bndbox> --Bounding box (position of the object), in corners notation (xmin, ymin, xmax, ymax)
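To make these fields concrete, here is a minimal sketch that reads them with the standard library's `xml.etree.ElementTree`. The function name `parse_voc_annotation` and the shortened XML string are assumptions for illustration; the repository's own reader appears later in this article.

```python
import xml.etree.ElementTree as ET

# Shortened copy of the annotation shown above
SAMPLE_XML = """
<annotation>
    <filename>000005.jpg</filename>
    <size><width>500</width><height>375</height><depth>3</depth></size>
    <object>
        <name>chair</name>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox><xmin>263</xmin><ymin>211</ymin><xmax>324</xmax><ymax>339</ymax></bndbox>
    </object>
    <object>
        <name>chair</name>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox><xmin>165</xmin><ymin>264</ymin><xmax>253</xmax><ymax>372</ymax></bndbox>
    </object>
</annotation>
"""

def parse_voc_annotation(xml_text):
    """Return the image file name and a list of object dicts."""
    root = ET.fromstring(xml_text.strip())
    objects = []
    for obj in root.iter('object'):
        box = obj.find('bndbox')
        objects.append({
            'name': obj.findtext('name'),
            'truncated': obj.findtext('truncated') == '1',
            'difficult': obj.findtext('difficult') == '1',
            # corners notation: [xmin, ymin, xmax, ymax] in pixels
            'bbox': [int(box.findtext(tag)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')],
        })
    return root.findtext('filename'), objects

filename, objects = parse_voc_annotation(SAMPLE_XML)
print(filename, objects)
```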

Implementation

To implement a dataset, you need to extend the Dataset class. You then need to implement `__len__`, which returns the number of samples in the dataset, and `__getitem__`, which returns the input data and the correct label for an `index` within that range.
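The two-method contract can be sketched with a toy in-memory dataset. `ToyDetectionDataset` and its sample values are hypothetical. The `try`/`except` only keeps the sketch runnable when PyTorch is not installed; in real use the class extends `torch.utils.data.Dataset`.

```python
try:
    from torch.utils.data import Dataset  # the real base class
except ImportError:  # fallback so the sketch runs without PyTorch
    Dataset = object

class ToyDetectionDataset(Dataset):
    """Hypothetical in-memory dataset: each sample is (image, targets)."""

    def __init__(self, samples):
        self._samples = list(samples)

    def __len__(self):
        # number of samples in the dataset
        return len(self._samples)

    def __getitem__(self, index):
        # input data and correct label for the given index
        img, targets = self._samples[index]
        return img, targets

# targets rows: [xmin, ymin, xmax, ymax, label index]
dataset = ToyDetectionDataset([
    ('image0', [[0.1, 0.2, 0.5, 0.6, 8]]),
    ('image1', [[0.0, 0.0, 1.0, 1.0, 14], [0.3, 0.3, 0.4, 0.4, 8]]),
])
print(len(dataset), dataset[1][1])
```

Note that the number of target rows differs per image, which is typical for object detection.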

The following implementation does this:

  • Stores the paths of the .xml files directly under `Annotations` in `self._annopaths` as a list
  • Gets the image and bounding boxes for `self._annopaths[index]` from the `index` given to `__getitem__`
  • Reads the image with OpenCV and returns it in **RGB order** (input data)
  • Normalizes the bounding boxes by the width and height of the image
  • Returns the bounding boxes and labels concatenated together (correct label)

  • RGB order is used because the pre-trained VGG model distributed by PyTorch is used. PyTorch's pre-trained models are trained on RGB images normalized with mean = (0.485, 0.456, 0.406) and std = (0.229, 0.224, 0.225). (Reference)

ObjectDetectionDatasetBase


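import numpy as np
import torch

# _DatasetBase is a base class defined elsewhere in the repository.
class ObjectDetectionDatasetBase(_DatasetBase):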
    def __init__(self, ignore=None, transform=None, target_transform=None, augmentation=None):

        ...  # (omitted)

    def __getitem__(self, index):
        """
        :param index: int
        :return:
            img : rgb image(Tensor or ndarray)
            targets : Tensor or ndarray of bboxes and labels [box, label]
            = [xmin, ymin, xmax, ymax, label index(or relu_one-hotted label)]
            or
            = [cx, cy, w, h, label index(or relu_one-hotted label)]
        """
        img = self._get_image(index)
        bboxes, linds, flags = self._get_bbox_lind(index)

        img, bboxes, linds, flags = self.apply_transform(img, bboxes, linds, flags)

        # concatenate bboxes and linds
        if isinstance(bboxes, torch.Tensor) and isinstance(linds, torch.Tensor):
            if linds.ndim == 1:
                linds = linds.unsqueeze(1)
            targets = torch.cat((bboxes, linds), dim=1)
        else:
            if linds.ndim == 1:
                linds = linds[:, np.newaxis]
            targets = np.concatenate((bboxes, linds), axis=1)

        return img, targets

    def apply_transform(self, img, bboxes, linds, flags):
        """
        IMPORTANT: apply transform function in order with ignore, augmentation, transform and target_transform
        :param img:
        :param bboxes:
        :param linds:
        :param flags:
        :return:
            Transformed img, bboxes, linds, flags
        """
        # To Percent mode
        height, width, channel = img.shape
        # bbox = [xmin, ymin, xmax, ymax]
        # [bbox[0] / width, bbox[1] / height, bbox[2] / width, bbox[3] / height]
        bboxes[:, 0::2] /= float(width)
        bboxes[:, 1::2] /= float(height)

        if self.ignore:
            bboxes, linds, flags = self.ignore(bboxes, linds, flags)

        if self.augmentation:
            img, bboxes, linds, flags = self.augmentation(img, bboxes, linds, flags)

        if self.transform:
            img, bboxes, linds, flags = self.transform(img, bboxes, linds, flags)

        if self.target_transform:
            bboxes, linds, flags = self.target_transform(bboxes, linds, flags)

        return img, bboxes, linds, flags
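The "percent mode" step at the top of `apply_transform` can be checked by hand with the first chair box from the annotation sample (image size 500x375). The function names below are illustrative only. Dividing by the image size keeps the boxes valid after the image is resized later.

```python
def to_percent(bbox, width, height):
    """Scale pixel corners [xmin, ymin, xmax, ymax] into the range [0, 1]."""
    xmin, ymin, xmax, ymax = bbox
    return [xmin / width, ymin / height, xmax / width, ymax / height]

def to_pixels(bbox, width, height):
    """Inverse of to_percent for an image of the given size."""
    xmin, ymin, xmax, ymax = bbox
    return [xmin * width, ymin * height, xmax * width, ymax * height]

norm = to_percent([263, 211, 324, 339], width=500, height=375)
print([round(v, 3) for v in norm])

# The same normalized box mapped onto a 300x300 resized image
resized = to_pixels(norm, width=300, height=300)
print([round(v, 1) for v in resized])
```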

VOCDatasetBase


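import os
import xml.etree.ElementTree as ET

import cv2
import numpy as np

# _check_ins and _get_xml_et_value are helper functions defined elsewhere in the repository.
VOC_class_labels = ['aeroplane', 'bicycle', 'bird', 'boat',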
    'bottle', 'bus', 'car', 'cat', 'chair',
    'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant',
    'sheep', 'sofa', 'train', 'tvmonitor']
VOC_class_nums = len(VOC_class_labels)
class VOCSingleDatasetBase(ObjectDetectionDatasetBase):
    def __init__(self, voc_dir, focus, ignore=None, transform=None, target_transform=None, augmentation=None, class_labels=None):
        """
        :param voc_dir: str, voc directory path above 'Annotations', 'ImageSets' and 'JPEGImages'
                e.g.) voc_dir = '~~~~/trainval/VOCdevkit/voc2007'
        :param focus: str, image set name. Assign txt file name under 'ImageSets' directory
        :param ignore: target_transforms.Ignore
        :param transform: instance of transforms
        :param target_transform: instance of target_transforms
        :param augmentation:  instance of augmentations
        :param class_labels: None or list or tuple, if it's None use VOC_class_labels
        """
        super().__init__(ignore=ignore, transform=transform, target_transform=target_transform, augmentation=augmentation)

        self._voc_dir = voc_dir
        self._focus = focus
        self._class_labels = _check_ins('class_labels', class_labels, (list, tuple), allow_none=True)
        if self._class_labels is None:
            self._class_labels = VOC_class_labels

        layouttxt_path = os.path.join(self._voc_dir, 'ImageSets', 'Main', self._focus + '.txt')
        if os.path.exists(layouttxt_path):
            with open(layouttxt_path, 'r') as f:
                filenames = f.read().splitlines()
                filenames = [filename.split()[0] for filename in filenames]
                self._annopaths = [os.path.join(self._voc_dir, 'Annotations', '{}.xml'.format(filename)) for filename in filenames]
        else:
            raise FileNotFoundError('layout: {} was invalid arguments'.format(focus))

    @property
    def class_nums(self):
        return len(self._class_labels)
    @property
    def class_labels(self):
        return self._class_labels

    def _jpgpath(self, filename):
        """
        :param filename: path containing .jpg
        :return: path of jpg
        """
        return os.path.join(self._voc_dir, 'JPEGImages', filename)

    def __len__(self):
        return len(self._annopaths)

    """
    Detail of contents in voc > https://towardsdatascience.com/coco-data-format-for-object-detection-a4c5eaf518c5

    VOC bounding box (xmin, ymin, xmax, ymax)
    """
    def _get_image(self, index):
        """
        :param index: int
        :return:
            rgb image(ndarray)
        """
        root = ET.parse(self._annopaths[index]).getroot()
        img = cv2.imread(self._jpgpath(_get_xml_et_value(root, 'filename')))
        # pytorch's image order is rgb
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img.astype(np.float32)

    def _get_bbox_lind(self, index):
        """
        :param index: int
        :return:
            list of bboxes, list of bboxes' label index, list of flags([difficult, truncated,...])
        """
        linds = []
        bboxes = []
        flags = []

        root = ET.parse(self._annopaths[index]).getroot()
        for obj in root.iter('object'):
            linds.append(self._class_labels.index(_get_xml_et_value(obj, 'name')))

            bndbox = obj.find('bndbox')

            # bbox = [xmin, ymin, xmax, ymax]
            bboxes.append([_get_xml_et_value(bndbox, 'xmin', int), _get_xml_et_value(bndbox, 'ymin', int), _get_xml_et_value(bndbox, 'xmax', int), _get_xml_et_value(bndbox, 'ymax', int)])

            flags.append({'difficult': _get_xml_et_value(obj, 'difficult', int) == 1})#,
                          #'partial': _get_xml_et_value(obj, 'truncated', int) == 1})

        return np.array(bboxes, dtype=np.float32), np.array(linds, dtype=np.float32), flags

Read COCO dataset

Structure

The directory structure resembles VOC in that it is split into annotations and images (`annotations` and `images/{train or val}20**`), but the annotations are handled slightly differently.

├── annotations
│   ├── captions_train2014.json
│   ├── captions_val2014.json
│   ├── instances_train2014.json
│   ├── instances_val2014.json
│   ├── person_keypoints_train2014.json
│   └── person_keypoints_val2014.json
└── images
    ├── train2014
    └── val2014

As you can see, unlike VOC, all annotations are written in a single file. What you need for object detection is the `instances_{train or val}20**.json` file. The format is described in detail on the [official site](http://cocodataset.org/#format-data). Since COCO provides a [Python API](https://github.com/cocodataset/cocoapi), once you know that the annotation file for object detection is `instances_{train or val}20**.json`, you honestly do not need to understand its contents in much detail.

Just in case, here is what the format looks like.

  • instances_{train or val}20**.json --The information is described in the following format.
{
  "info": info, 
  "images": [image], 
  "annotations": [annotation], 
  "licenses": [license],
}

info{
  "year": int, 
  "version": str, 
  "description": str, 
  "contributor": str, 
  "url": str, 
  "date_created": datetime,
}

image{
  "id": int, 
  "width": int, 
  "height": int,
  "file_name": str, 
  "license": int, 
  "flickr_url": str, 
  "coco_url": str, 
  "date_captured": datetime,
}

license{
  "id": int, 
  "name": str,
  "url": str,
}

The `annotation` and `categories` entries for object detection are as follows.

annotation{
  "id": int,
  "image_id": int,
  "category_id": int,
  "segmentation": RLE or [polygon],
  "area": float,
  "bbox": [x,y,width,height],
  "iscrowd": 0 or 1,
}

categories[{
  "id": int, 
  "name": str, 
  "supercategory": str,
}]
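Two details of this format matter for detection: `bbox` is `[x, y, width, height]` rather than corners, and `category_id` values are not contiguous, so labels are looked up by category name. The sketch below works on a tiny hand-written dictionary in the same shape. The data values and the name `to_corner_targets` are assumptions for illustration, and no COCO API call is involved.

```python
# Hand-written stand-in for a parsed instances_*.json file
coco_like = {
    'categories': [
        {'id': 1, 'name': 'person', 'supercategory': 'person'},
        {'id': 62, 'name': 'chair', 'supercategory': 'furniture'},
    ],
    'annotations': [
        {'id': 10, 'image_id': 7, 'category_id': 62,
         'bbox': [165.0, 264.0, 88.0, 108.0], 'iscrowd': 0},
        {'id': 11, 'image_id': 7, 'category_id': 1,
         'bbox': [10.0, 20.0, 30.0, 40.0], 'iscrowd': 1},
    ],
}

def to_corner_targets(data, class_labels):
    """Return (corners bbox, label index, iscrowd flag) for every annotation."""
    id_to_name = {cat['id']: cat['name'] for cat in data['categories']}
    targets = []
    for anno in data['annotations']:
        xmin, ymin, w, h = anno['bbox']
        # label index is the position of the category name in class_labels
        label_index = class_labels.index(id_to_name[anno['category_id']])
        targets.append(([xmin, ymin, xmin + w, ymin + h], label_index, anno['iscrowd'] == 1))
    return targets

targets = to_corner_targets(coco_like, class_labels=['person', 'chair'])
print(targets)
```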

Implementation

Implement it in the same way as VOC. The necessary information is obtained through the COCO object of the API.

from pycocotools.coco import COCO  # COCO Python API

COCO_class_labels = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
                    'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
                    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
                    'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
                    'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
                    'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
                    'kite', 'baseball bat', 'baseball glove', 'skateboard',
                    'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
                    'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
                    'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                    'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
                    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
                    'refrigerator', 'book', 'clock', 'vase', 'scissors',
                    'teddy bear', 'hair drier', 'toothbrush']
COCO_class_nums = len(COCO_class_labels)

COCO2014_ROOT = os.path.join(DATA_ROOT, 'coco', 'coco2014')
class COCOSingleDatasetBase(ObjectDetectionDatasetBase):
    def __init__(self, coco_dir, focus, ignore=None, transform=None, target_transform=None, augmentation=None, class_labels=None):
        """
        :param coco_dir: str, coco directory path above 'annotations' and 'images'
                e.g.) coco_dir = '~~~~/coco2007/trainval'
        :param focus: str or str, directory name under images
                e.g.) focus = 'train2014'
        :param ignore: target_transforms.Ignore
        :param transform: instance of transforms
        :param target_transform: instance of target_transforms
        :param augmentation:  instance of augmentations
        :param class_labels: None or list or tuple, if it's None use COCO_class_labels
        """
        super().__init__(ignore=ignore, transform=transform, target_transform=target_transform, augmentation=augmentation)

        self._coco_dir = coco_dir
        self._focus = focus

        self._class_labels = _check_ins('class_labels', class_labels, (list, tuple), allow_none=True)
        if self._class_labels is None:
            self._class_labels = COCO_class_labels

        self._annopath = os.path.join(self._coco_dir, 'annotations', 'instances_' + self._focus + '.json')
        if os.path.exists(self._annopath):
            self._coco = COCO(self._annopath)
        else:
            raise FileNotFoundError('json: {} was not found'.format('instances_' + self._focus + '.json'))


        # remove no annotation image
        self._imageids = list(self._coco.imgToAnns.keys())

    @property
    def class_nums(self):
        return len(self._class_labels)
    @property
    def class_labels(self):
        return self._class_labels

    def _jpgpath(self, filename):
        """
        :param filename: path containing .jpg
        :return: path of jpg
        """
        return os.path.join(self._coco_dir, 'images', self._focus, filename)

    def __len__(self):
        return len(self._imageids)

    """
    Detail of contents in coco > https://towardsdatascience.com/coco-data-format-for-object-detection-a4c5eaf518c5

    COCO bounding box (xmin, ymin, width, height)
    """
    def _get_image(self, index):
        """
        :param index: int
        :return:
            rgb image(ndarray)
        """

        """
        self._coco.loadImgs(self._imageids[index]): list of dict, contains;
            license: int
            file_name: str
            coco_url: str
            height: int
            width: int
            date_captured: str
            flickr_url: str
            id: int
        """
        filename = self._coco.loadImgs(self._imageids[index])[0]['file_name']
        img = cv2.imread(self._jpgpath(filename))
        # pytorch's image order is rgb
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        return img.astype(np.float32)

    def _get_bbox_lind(self, index):
        """
        :param index: int
        :return:
            list of bboxes, list of bboxes' label index, list of flags([difficult, truncated,...])
        """
        linds = []
        bboxes = []
        flags = []

        # anno_ids is list
        anno_ids = self._coco.getAnnIds(self._imageids[index])

        # annos is list of dict
        annos = self._coco.loadAnns(anno_ids)
        for anno in annos:
            """
            anno's  keys are;
                segmentation: list of float
                area: float
                iscrowd: int, 0 or 1
                image_id: int
                bbox: list of float, whose length is 4
                category_id: int
                id: int
            """
            """
            self._coco.loadCats(anno['category_id']) is list of dict, contains;
                supercategory: str
                id: int
                name: str
            """
            cat = self._coco.loadCats(anno['category_id'])[0]

            linds.append(self.class_labels.index(cat['name']))

            # bbox = [xmin, ymin, w, h]
            xmin, ymin, w, h = anno['bbox']
            # convert to corners
            xmax, ymax = xmin + w, ymin + h
            bboxes.append([xmin, ymin, xmax, ymax])

            """
            flag = {}
            keys = ['iscrowd']
            for key in keys:
                if key in anno.keys():
                    flag[key] = anno[key] == 1
                else:
                    flag[key] = False
            flags.append(flag)
            """
            flags.append({'difficult': anno['iscrowd'] == 1})

        return np.array(bboxes, dtype=np.float32), np.array(linds, dtype=np.float32), flags

Augmentation

Augmentation is not strictly necessary, but the original paper states

"Data augmentation is crucial"

so it appears to be important. The original paper omits the specific method, but there are two main types of augmentation.

  • Geometric Distortions
  • Photometric Distortions

In the following, I will show how this original image is augmented.

image.png

Geometric Distortions

Geometric Distortions consist of the following three methods.

  • Random Expand
    • As the name suggests, the image size is expanded randomly.
    • The margins created by the expansion are filled with the mean rgb_mean = (103.939, 116.779, 123.68) used in normalization.

image.png

  • Random Sample
    • Samples (crops) a random region of the image.
    • The threshold for the degree of overlap (IoU) between the sampled image and the bounding boxes is chosen at random.
      • One of (0.1, 0.3, 0.5, 0.7, 0.9, None)
      • None means no threshold. However, `IoU = 0` (no overlap at all) is excluded.
    • Sampling is repeated until the sampled image exceeds the threshold.

image.png image.png

  • Random Flip
    • Flips the image at random.

image.png

I will omit the implementation for now. The details are here. See here for other image examples.
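Although the implementation is omitted, the two box-level calculations behind these methods are easy to sketch: the IoU used by Random Sample and the coordinate update for a horizontal Random Flip. This is not the repository's code. In particular, the acceptance rule in `patch_is_acceptable` (at least one box must reach the threshold) is an assumption, and the real rule may differ.

```python
def iou(box_a, box_b):
    """IoU of two boxes in corners notation (xmin, ymin, xmax, ymax)."""
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

def patch_is_acceptable(patch, boxes, min_iou):
    """Assumed rule: keep the patch if at least one box overlaps it enough."""
    overlaps = [iou(patch, box) for box in boxes]
    if not overlaps:
        return False
    if min_iou is None:
        return max(overlaps) > 0.0  # no threshold, but IoU = 0 is excluded
    return max(overlaps) >= min_iou

def hflip_box(box):
    """Mirror a normalized box horizontally: x becomes 1 - x, min and max swap."""
    xmin, ymin, xmax, ymax = box
    return [1.0 - xmax, ymin, 1.0 - xmin, ymax]

patch = [0.0, 0.0, 0.5, 0.5]
boxes = [[0.25, 0.25, 0.75, 0.75]]
print(iou(patch, boxes[0]))
print(patch_is_acceptable(patch, boxes, 0.1), patch_is_acceptable(patch, boxes, 0.3))
print(hflip_box([0.1, 0.2, 0.4, 0.6]))
```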

Photometric Distortions

Photometric Distortions consist of the following five methods.

  • Random Brightness
    • Adds a randomly selected value in the range $[-32,32]$ to the RGB values

image.png

  • Random Contrast
    • Multiplies the RGB values by a randomly selected value in the range $[0.5,1.5]$

image.png

  • Random Hue
    • Converts the image to [HSV space](https://ja.wikipedia.org/wiki/HSV%E8%89%B2%E7%A9%BA%E9%96%93) and adds a randomly selected value in the range $[-18,18]$ to the Hue value

image.png

  • Random Saturation
    • Converts the image to [HSV space](https://ja.wikipedia.org/wiki/HSV%E8%89%B2%E7%A9%BA%E9%96%93) and multiplies the Saturation value by a randomly selected value in the range $[0.5,1.5]$
    • This may be a bug...

image.png

  • Random Lighting Noise
    • Randomly swaps the channel values

image.png

I will omit the implementation for now. The details are here. See here for other image examples.
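As a rough idea of what three of these distortions do to pixel values, here is a dependency-free sketch on a list of RGB pixels. It is not the repository's implementation: the function names are invented, the ranges are taken from the list above, and no clipping to $[0,255]$ is applied because the article does not say whether the original does so.

```python
import random

def random_brightness(pixels, rng, delta=32.0):
    """Add one random offset from [-delta, delta] to every RGB value."""
    shift = rng.uniform(-delta, delta)
    return [[c + shift for c in px] for px in pixels]

def random_contrast(pixels, rng, lower=0.5, upper=1.5):
    """Multiply every RGB value by one random factor from [lower, upper]."""
    alpha = rng.uniform(lower, upper)
    return [[c * alpha for c in px] for px in pixels]

def random_lighting_noise(pixels, rng):
    """Apply one random permutation of the three channels to every pixel."""
    order = [0, 1, 2]
    rng.shuffle(order)
    return [[px[i] for i in order] for px in pixels]

rng = random.Random(0)  # fixed seed so the sketch is repeatable
pixels = [[120.0, 64.0, 200.0], [10.0, 20.0, 30.0]]  # two RGB pixels in [0, 255]

brighter = random_brightness(pixels, rng)
contrasted = random_contrast(pixels, rng)
shuffled = random_lighting_noise(pixels, rng)
print(brighter[0], contrasted[0], shuffled[0])
```

Each function draws its random value once per image, so all pixels of one image are changed consistently.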

Transform

Input image preprocessing

This is the preprocessing of the input image.

  • Resize (300x300, 512x512, etc.)
  • Convert the RGB input image from ndarray to torch.Tensor
  • Convert from $[0,255]$ to $[0,1]$
  • **Normalization!!** ← I personally think this is important
    • Without normalization, training does not converge well. Normalization, $x_{norm}=\frac{x-x_{mean}}{x_{std}}$, converts the data to mean 0 and variance 1. In general, rgb_means = (0.485, 0.456, 0.406) and rgb_stds = (0.229, 0.224, 0.225) are used as the mean and standard deviation of the image. (I forget the details, but I believe these are the statistics of the dataset VGG was pre-trained on, namely ImageNet.)
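The scaling and normalization steps can be worked through for a single pixel. The sketch below assumes the means and standard deviations quoted above; the function names are illustrative.

```python
RGB_MEANS = (0.485, 0.456, 0.406)
RGB_STDS = (0.229, 0.224, 0.225)

def normalize_pixel(rgb_0_255):
    """Scale an RGB pixel from [0, 255] to [0, 1], then apply (x - mean) / std."""
    scaled = [c / 255.0 for c in rgb_0_255]
    return [(c - m) / s for c, m, s in zip(scaled, RGB_MEANS, RGB_STDS)]

def denormalize_pixel(normalized):
    """Inverse operation, useful when visualizing a normalized image."""
    return [(v * s + m) * 255.0 for v, m, s in zip(normalized, RGB_MEANS, RGB_STDS)]

white = normalize_pixel([255, 255, 255])
black = normalize_pixel([0, 0, 0])
print([round(v, 3) for v in white])
print([round(v, 3) for v in black])
print([round(v, 1) for v in denormalize_pixel(white)])
```

After normalization the values are roughly in the range -2 to 3 instead of 0 to 1, which is why a normalized image must be denormalized before it is displayed.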

Implementation

The processing above is implemented as follows. PyTorch provides transforms for preprocessing input images, but they are functions for PIL images, so I wrote versions for OpenCV. If you create your own transforms, they must match how the dataset class above calls them. Here, the `apply_transform` method passes `img, bboxes, linds, flags` (the image, bounding boxes, labels, and flag information such as difficult) as arguments. (By the way, although I omitted it, `augmentation` is implemented in the same way.)

apply_transform method


if self.ignore:
    bboxes, linds, flags = self.ignore(bboxes, linds, flags)

if self.augmentation:
    img, bboxes, linds, flags = self.augmentation(img, bboxes, linds, flags)

if self.transform:
    img, bboxes, linds, flags = self.transform(img, bboxes, linds, flags)

if self.target_transform:
    bboxes, linds, flags = self.target_transform(bboxes, linds, flags)

Therefore, you should implement the `__call__(self, img, bboxes, linds, flags)` method.

The ToTensor class, which converts to torch.Tensor, also reorders the image from OpenCV's (h, w, c) layout to the (c, h, w) layout of each image in the (b, c, h, w) input that PyTorch's Conv2d expects.

import logging

import cv2
import numpy as np
import torch

class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, bboxes, labels, flags):
        for t in self.transforms:
            img, bboxes, labels, flags = t(img, bboxes, labels, flags)
        return img, bboxes, labels, flags

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string

class ToTensor(object):
    """
    Note that convert ndarray to tensor and [0-255] to [0-1]
    """
    def __call__(self, img, *args):
        # convert ndarray into Tensor
        # transpose img's tensor (h, w, c) to pytorch's format (c, h, w). (num, c, h, w)
        img = np.transpose(img, (2, 0, 1))
        return (torch.from_numpy(img).float() / 255., *args)

class Resize(object):
    def __init__(self, size):
        """
        :param size: 2d-array-like, (width, height); cv2.resize takes dsize in this order
        """
        self._size = size

    def __call__(self, img, *args):
        return (cv2.resize(img, self._size), *args)


class Normalize(object):
    #def __init__(self, rgb_means=(103.939, 116.779, 123.68), rgb_stds=(1.0, 1.0, 1.0)):
    def __init__(self, rgb_means=(0.485, 0.456, 0.406), rgb_stds=(0.229, 0.224, 0.225)):
        self.means = np.array(rgb_means, dtype=np.float32).reshape((-1, 1, 1))
        if np.any(np.abs(self.means) > 1):
            logging.warning("In general, mean value should be less than 1 because img's range is [0-1]")

        self.stds = np.array(rgb_stds, dtype=np.float32).reshape((-1, 1, 1))

    def __call__(self, img, *args):
        if isinstance(img, torch.Tensor):
            return ((img.float() - torch.from_numpy(self.means)) / torch.from_numpy(self.stds), *args)
        else:
            return ((img.astype(np.float32) - self.means) / self.stds, *args)

Example of use

from data import transforms
transform = transforms.Compose(
        [transforms.Resize((300, 300)),
         transforms.ToTensor(),
         transforms.Normalize(rgb_means=(0.485, 0.456, 0.406), rgb_stds=(0.229, 0.224, 0.225))]
    )
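Any object with the same `__call__(self, img, bboxes, labels, flags)` signature can be chained in this way. As a sketch, the hypothetical `ClipBoxes` transform below clamps normalized box coordinates into $[0,1]$. It is not part of the repository, and a local copy of `Compose` is included only so that the example runs on its own.

```python
class Compose(object):
    """Local copy with the same calling convention as the Compose class above."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, bboxes, labels, flags):
        for t in self.transforms:
            img, bboxes, labels, flags = t(img, bboxes, labels, flags)
        return img, bboxes, labels, flags

class ClipBoxes(object):
    """Hypothetical transform: clamp normalized box coordinates into [0, 1]."""
    def __call__(self, img, bboxes, labels, flags):
        clipped = [[min(max(v, 0.0), 1.0) for v in box] for box in bboxes]
        # img, labels and flags pass through unchanged
        return img, clipped, labels, flags

transform = Compose([ClipBoxes()])
img, bboxes, labels, flags = transform(
    'image',                      # placeholder for an ndarray
    [[-0.1, 0.2, 0.5, 1.3]],      # one box that sticks out of the image
    [8],
    [{'difficult': False}],
)
print(bboxes)
```

Returning all four values, even the untouched ones, is what lets the transforms be applied one after another.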

Target transform

Bounding box, label conversion

This converts the bounding boxes and labels.

  • Convert the bounding boxes from the corners representation to the centroids representation
  • Convert the labels to one-hot vectors
  • Convert from ndarray to torch.Tensor
  • Concatenate the bounding boxes and labels (shape = (box num, 4 (= cx, cy, w, h) + class_nums + 1 (= background)))

Bounding box representations

There are three representations.

  • centroids representation
    • Uses the center coordinates $(c_x, c_y)$ and the width and height $(w, h)$.

$$bbox = (c_x,c_y,w,h)$$

  • corners representation
    • Uses the upper-left coordinates $(x_{min}, y_{min})$ and the lower-right coordinates $(x_{max}, y_{max})$.

$$bbox = (x_{min},y_{min},x_{max},y_{max})$$

  • minmax representation
    • Uses the upper-left coordinates $(x_{min}, y_{min})$ and the lower-right coordinates $(x_{max}, y_{max})$.
    • The order differs from the corners representation.

$$bbox = (x_{min},x_{max},y_{min},y_{max})$$

  • The center coordinates $(c_x, c_y)$ and the width and height $(w, h)$ are related to the upper-left coordinates $(x_{min}, y_{min})$ and the lower-right coordinates $(x_{max}, y_{max})$ as follows

\begin{align}
(c_x,c_y) &= (\frac{x_{min}+x_{max}}{2},\frac{y_{min}+y_{max}}{2}) \\
(w,h) &= (x_{max}-x_{min},y_{max}-y_{min})
\end{align}
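These relations can be checked with a small round trip on one box. The sketch uses plain Python lists; the function names are illustrative and separate from the ToCentroids and ToCorners classes in this article.

```python
def corners_to_centroids(box):
    """(xmin, ymin, xmax, ymax) -> (cx, cy, w, h)"""
    xmin, ymin, xmax, ymax = box
    return [(xmin + xmax) / 2, (ymin + ymax) / 2, xmax - xmin, ymax - ymin]

def centroids_to_corners(box):
    """(cx, cy, w, h) -> (xmin, ymin, xmax, ymax)"""
    cx, cy, w, h = box
    return [cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2]

def corners_to_minmax(box):
    """(xmin, ymin, xmax, ymax) -> (xmin, xmax, ymin, ymax)"""
    xmin, ymin, xmax, ymax = box
    return [xmin, xmax, ymin, ymax]

corners = [0.2, 0.4, 0.6, 0.8]
centroids = corners_to_centroids(corners)
restored = centroids_to_corners(centroids)
print(centroids, restored, corners_to_minmax(corners))
```

Converting to centroids and back returns the original corners up to floating-point rounding.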

Implementation

The processing above is implemented as follows. PyTorch has no target_transforms that handle correct labels for object detection, so you need to write them yourself. Again, the `apply_transform` method passes `bboxes, linds, flags` (the bounding boxes, labels, and flag information such as difficult) as arguments, as shown below.

apply_transform method


if self.ignore:
    bboxes, linds, flags = self.ignore(bboxes, linds, flags)

if self.augmentation:
    img, bboxes, linds, flags = self.augmentation(img, bboxes, linds, flags)

if self.transform:
    img, bboxes, linds, flags = self.transform(img, bboxes, linds, flags)

if self.target_transform:
    bboxes, linds, flags = self.target_transform(bboxes, linds, flags)

Therefore, you only need to implement the `__call__(self, bboxes, linds, flags)` method.

class ToTensor(object):
    def __call__(self, bboxes, labels, flags):
        return torch.from_numpy(bboxes), torch.from_numpy(labels), flags

class ToCentroids(object):
    def __call__(self, bboxes, labels, flags):
        # bbox = [xmin, ymin, xmax, ymax]
        bboxes = np.concatenate(((bboxes[:, 2:] + bboxes[:, :2]) / 2,
                                 (bboxes[:, 2:] - bboxes[:, :2])), axis=1)

        return bboxes, labels, flags

class ToCorners(object):
    def __call__(self, bboxes, labels, flags):
        # bbox = [cx, cy, w, h]
        bboxes = np.concatenate((bboxes[:, :2] - bboxes[:, 2:]/2,
                                 bboxes[:, :2] + bboxes[:, 2:]/2), axis=1)

        return bboxes, labels, flags

class OneHot(object):
    def __init__(self, class_nums, add_background=True):
        self._class_nums = class_nums
        self._add_background = add_background
        if add_background:
            self._class_nums += 1

    def __call__(self, bboxes, labels, flags):
        if labels.ndim != 1:
            raise ValueError('labels might have been already relu_one-hotted or be invalid shape')

        # np.int was removed in NumPy 1.24; the builtin int is equivalent here
        labels = _one_hot_encode(labels.astype(int), self._class_nums)
        labels = np.array(labels, dtype=np.float32)

        return bboxes, labels, flags

Example of use

target_transform = target_transforms.Compose(
        [target_transforms.ToCentroids(),
         target_transforms.OneHot(class_nums=datasets.VOC_class_nums, add_background=True),
         target_transforms.ToTensor()]
    )
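The `_one_hot_encode` helper used by OneHot is not shown in this article, so the following is only a guess at its behavior. It assumes that the background class occupies the last column and stays 0 for every ground-truth object. The function name and that column layout are assumptions.

```python
def one_hot_encode(label_indices, class_nums, add_background=True):
    """Turn label indices into one-hot rows, optionally with a background column."""
    width = class_nums + 1 if add_background else class_nums
    rows = []
    for index in label_indices:
        index = int(index)  # labels arrive as floats in the dataset classes
        if not 0 <= index < class_nums:
            raise ValueError('label index {} is out of range'.format(index))
        row = [0.0] * width
        row[index] = 1.0
        rows.append(row)
    return rows

# 20 VOC classes + 1 background column (assumed to be the last one)
encoded = one_hot_encode([8.0, 14.0], class_nums=20)
print(len(encoded[0]), encoded[0].index(1.0), encoded[1].index(1.0))

# Concatenated with a centroids box, each target row has 4 + 20 + 1 values
target_row = [0.5, 0.5, 0.2, 0.3] + encoded[0]
print(len(target_row))
```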

Conclusion

That is what the dataset processing looks like. As usual, this ended up somewhat half-finished, but I don't think there are many articles on dataset processing, so I hope you find it helpful.

reference
