[PYTHON] Try Random Erasing Data Augmentation

What is Random Erasing Data Augmentation?

In machine learning, Data Augmentation, which prevents overfitting by transforming the input data, is widely used. Recently, two new Data Augmentation methods have been proposed in the field of image recognition: Random Erasing and Cutout.

Both methods mask a randomly chosen rectangular region of the training images. The difference is that Random Erasing randomizes the size and aspect ratio of the rectangle, while Cutout uses a fixed-size mask. (The Cutout paper also experiments with selectively masking part of the target object, but claims that a fixed-size mask works just as well, so it uses the fixed-size mask for simplicity.) Beyond image classification, Random Erasing has also been shown to be effective for object detection and person re-identification.
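
For comparison, here is a minimal sketch of Cutout (not the code used in this article), assuming the fixed 16x16 mask reported for CIFAR-10 in the Cutout paper; the Random Erasing version, which randomizes both size and aspect ratio, appears in the Implementation section below.

import numpy as np

def cutout(image, mask_size=16):
    """Zero out a fixed-size square patch at a random position.

    image: array of shape (channels, height, width).
    """
    _, h, w = image.shape
    # Center of the mask; the patch is clipped at the image border.
    cy = np.random.randint(h)
    cx = np.random.randint(w)
    half = mask_size // 2
    y1, y2 = max(0, cy - half), min(h, cy + half)
    x1, x2 = max(0, cx - half), min(w, cx + half)
    out = image.copy()
    out[:, y1:y2, x1:x2] = 0.0
    return out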

Image of Random Erasing

(The image shown here is different from the data used in this article.)

Before processing: remove_aug1.jpg
After processing: remove_aug2.jpg

Try Random Erasing Data Augmentation

I decided to give Random Erasing a try. I chose it over Cutout because randomizing the size and aspect ratio of the rectangle seems more effective.

The task is classification of the CIFAR-10 dataset, implemented with Chainer. The source code is below.

After cloning the source code, you can train with the following commands. (It is recommended to change the -p option every time you train, because reusing the same value will overwrite the previously saved data.)

$ python src/download.py
$ python src/dataset.py
$ python src/train.py -g 0 -m vgg_no_fc -p remove_aug --iter 300 -b 128 --lr 0.1 --lr_decay_iter 150,225

Implementation of Random Erasing

Hyperparameters

The hyperparameters related to Random Erasing are as follows: p is the probability of applying erasing to an image, [s_l, s_h] is the range of the erased area as a fraction of the image area, and [r_1, r_2] is the range of the aspect ratio of the erased rectangle.

This time, I chose values close to those in the paper and set them as shown in the table below (a sketch of how they enter the sampling follows the table).

Hyperparameter | Value
p   | 0.5
s_l | 0.02
s_h | 0.4
r_1 | 1/3
r_2 | 3
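
As a reference, the sketch below (a hypothetical helper, not the author's code) shows where each hyperparameter enters the sampling of the erased rectangle; the implementation in the next section hard-codes the same values, sampling the aspect ratio log-uniformly between -log(3) and log(3), which corresponds to [r_1, r_2] = [1/3, 3].

import numpy as np

def sample_erase_region(size, s_l=0.02, s_h=0.4, r_1=1.0 / 3.0, r_2=3.0):
    """Sample (top, left, h, w) of the rectangle to erase on a size x size image."""
    while True:
        s = np.random.uniform(s_l, s_h) * size * size               # erased area
        r = np.exp(np.random.uniform(np.log(r_1), np.log(r_2)))     # aspect ratio, log-uniform
        w = int(np.sqrt(s / r))
        h = int(np.sqrt(s * r))
        top = np.random.randint(0, size)
        left = np.random.randint(0, size)
        if top + h < size and left + w < size:                      # keep the rectangle inside
            return top, left, h, w

# p (= 0.5) is the probability that erasing is applied to an image at all.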

Implementation

The code actually used is as follows. It is implemented as a method of a class that inherits from chainer.datasets.TupleDataset. The part from "# Random erasing start" to "# Random erasing end" is the Random Erasing processing: a random rectangular area is filled with a random value. (It is probably better to match the range of the fill values to the range of the data being used.) The argument x of _transform is an array of input data with shape [number of channels, height, width].

    def _transform(self, x):
        # x: input image as an array of shape (channels, height, width).
        # numpy is assumed to be imported as np at module level.
        image = np.zeros_like(x)
        size = x.shape[2]
        # Random shift offset (up to 4 pixels) and coin flips for mirroring and erasing.
        offset = np.random.randint(-4, 5, size=(2,))
        mirror = np.random.randint(2)
        remove = np.random.randint(2)
        top, left = offset
        left = max(0, left)
        top = max(0, top)
        right = min(size, left + size)
        bottom = min(size, top + size)
        if mirror > 0:
            # Horizontal flip.
            x = x[:, :, ::-1]
        # Shift the image and pad the uncovered border with zeros.
        image[:, size - bottom:size - top, size - right:size - left] = x[:, top:bottom, left:right]
        # Random erasing start
        if remove > 0:
            while True:
                # Erased area: 2% to 40% of the image area.
                s = np.random.uniform(0.02, 0.4) * size * size
                # Aspect ratio: log-uniform in [1/3, 3].
                r = np.random.uniform(-np.log(3.0), np.log(3.0))
                r = np.exp(r)
                w = int(np.sqrt(s / r))
                h = int(np.sqrt(s * r))
                left = np.random.randint(0, size)
                top = np.random.randint(0, size)
                if left + w < size and top + h < size:
                    break
            # Fill the rectangle with a single random value.
            c = np.random.randint(-128, 128)
            image[:, top:top + h, left:left + w] = c
        # Random erasing end
        return image
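
For reference, a transform like this can also be hooked into training with chainer.datasets.TransformDataset. The sketch below is hypothetical (the article implements the transform as a method of a TupleDataset subclass; here _transform is assumed to be available as a plain function of x):

from chainer.datasets import TransformDataset, get_cifar10

# Hypothetical wiring; `_transform` is assumed to be a plain function taking x.
def augment(example):
    x, t = example
    return _transform(x), t

train_raw, test = get_cifar10()  # each example: (float32 array of shape (3, 32, 32), label)
train = TransformDataset(train_raw, augment)
# Note: get_cifar10() scales pixel values to [0, 1] by default, so the fill
# range used in _transform should be matched to that range (as noted above).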

Neural network structure

The network code is shown below. It stacks Convolution and Max Pooling layers in the style of VGG. However, the large fully connected layers are omitted: global average pooling followed by a single linear layer is used instead, which reduces the number of parameters.


import chainer
import chainer.functions as F
import chainer.links as L


class BatchConv2D(chainer.Chain):
    def __init__(self, ch_in, ch_out, ksize, stride=1, pad=0, activation=F.relu):
        super(BatchConv2D, self).__init__(
            conv=L.Convolution2D(ch_in, ch_out, ksize, stride, pad),
            bn=L.BatchNormalization(ch_out),
        )
        self.activation=activation

    def __call__(self, x):
        h = self.bn(self.conv(x))
        if self.activation is None:
            return h
        return self.activation(h)

class VGGNoFC(chainer.Chain):
    def __init__(self):
        super(VGGNoFC, self).__init__(
            bconv1_1=BatchConv2D(3, 64, 3, stride=1, pad=1),
            bconv1_2=BatchConv2D(64, 64, 3, stride=1, pad=1),
            bconv2_1=BatchConv2D(64, 128, 3, stride=1, pad=1),
            bconv2_2=BatchConv2D(128, 128, 3, stride=1, pad=1),
            bconv3_1=BatchConv2D(128, 256, 3, stride=1, pad=1),
            bconv3_2=BatchConv2D(256, 256, 3, stride=1, pad=1),
            bconv3_3=BatchConv2D(256, 256, 3, stride=1, pad=1),
            bconv3_4=BatchConv2D(256, 256, 3, stride=1, pad=1),
            fc=L.Linear(256, 10),
        )

    def __call__(self, x):
        h = self.bconv1_1(x)
        h = self.bconv1_2(h)
        h = F.dropout(F.max_pooling_2d(h, 2), 0.25)
        h = self.bconv2_1(h)
        h = self.bconv2_2(h)
        h = F.dropout(F.max_pooling_2d(h, 2), 0.25)
        h = self.bconv3_1(h)
        h = self.bconv3_2(h)
        h = self.bconv3_3(h)
        h = self.bconv3_4(h)
        h = F.dropout(F.max_pooling_2d(h, 2), 0.25)
        h = F.average_pooling_2d(h, 4, 1, 0)
        h = self.fc(F.dropout(h))
        return h
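
As a sanity check on the structure, the spatial resolution goes 32 → 16 → 8 → 4 through the three pooling stages, and the final 4x4 average pooling reduces it to 1x1 before the linear layer. The following minimal sketch (assuming Chainer v2 or later) confirms the output shape:

import numpy as np
import chainer

model = VGGNoFC()
x = np.zeros((8, 3, 32, 32), dtype=np.float32)  # dummy CIFAR-10 sized batch
with chainer.using_config('train', False):      # disable dropout
    y = model(x)
print(y.shape)  # (8, 10)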

Conditions for learning

The conditions for learning follow the options in the training command shown earlier: the vgg_no_fc model, 300 iterations of training (--iter 300), batch size 128, an initial learning rate of 0.1, and learning-rate decay at 150 and 225. In addition, the random shift and horizontal flip in _transform above are applied to the training data.
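
As one possible realization of these settings, here is a hedged sketch of a Chainer training loop; the optimizer type (MomentumSGD), the decay factor of 0.1, and the interpretation of --iter as epochs are assumptions, not stated in the article:

import chainer
import chainer.links as L
from chainer import optimizers, training
from chainer.training import extensions, triggers

# `train` is the augmented training dataset (e.g. the TransformDataset sketch above).
model = L.Classifier(VGGNoFC())
optimizer = optimizers.MomentumSGD(lr=0.1)   # assumed optimizer
optimizer.setup(model)

train_iter = chainer.iterators.SerialIterator(train, batch_size=128)
updater = training.StandardUpdater(train_iter, optimizer, device=0)  # GPU 0, matching -g 0
trainer = training.Trainer(updater, (300, 'epoch'), out='result')

# Multiply the learning rate by an assumed factor of 0.1 at epochs 150 and 225.
trainer.extend(extensions.ExponentialShift('lr', 0.1),
               trigger=triggers.ManualScheduleTrigger([150, 225], 'epoch'))

trainer.run()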

Result

Using Random Erasing improved accuracy, as shown below.

Method | Test Error (%)
Without Random Erasing | 6.68
With Random Erasing | 5.67

The transition of Training Error and Test Error is shown below. With Random Erasing, the gap between Training Error and Test Error is smaller, which suggests that overfitting is suppressed.

Without Random Erasing: vgg_no_fc_error.png

With Random Erasing: vgg_no_fc_remove_aug_error.png

Conclusion

Since it is a simple method that just masks part of the input image, I was able to try it right away. It was effective this time, but whether it works well under various conditions still needs to be verified. If it does, it may become a standard technique in the future.

It is such a simple method that I personally wonder whether something similar has already been proposed in the past.

References

Zhun Zhong, Liang Zheng, Guoliang Kang, Shaozi Li, Yi Yang. Random Erasing Data Augmentation. arXiv:1708.04896.
Terrance DeVries, Graham W. Taylor. Improved Regularization of Convolutional Neural Networks with Cutout. arXiv:1708.04552.
