See the original paper, "mixup: Beyond Empirical Risk Minimization": https://arxiv.org/pdf/1710.09412.pdf
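```python
import numpy as np
import sklearn.utils

def mixup(x, y, batch_size, alpha=0.2):
    # One mixing weight per sample, drawn from Beta(alpha, alpha)
    l = np.random.beta(alpha, alpha, batch_size)
    # Shuffle x and y together so every (x, y) pair stays aligned
    x_, y_ = sklearn.utils.shuffle(x, y)
    # Reshape l to (batch_size, 1, ..., 1) so it broadcasts over x
    shape = tuple(-1 if i == 0 else 1 for i in range(len(x.shape)))
    x = l.reshape(shape) * x + (1 - l).reshape(shape) * x_
    # Same for y
    shape = tuple(-1 if i == 0 else 1 for i in range(len(y.shape)))
    y = l.reshape(shape) * y + (1 - l).reshape(shape) * y_
    return x, y
```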
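Written out per sample, the function above applies the following rule. Here $\pi$ (my notation, not the source's) stands for the random permutation applied by the shuffle, and each $\lambda_i$ is drawn independently, as in the code:

$$
\lambda_i \sim \mathrm{Beta}(\alpha, \alpha), \qquad
\tilde{x}_i = \lambda_i x_i + (1 - \lambda_i)\, x_{\pi(i)}, \qquad
\tilde{y}_i = \lambda_i y_i + (1 - \lambda_i)\, y_{\pi(i)}
$$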
| Argument name | Description (type) |
|---|---|
| x | Input data (numpy.ndarray) |
| y | Output data (numpy.ndarray) |
| batch_size | Batch size; must equal the number of samples in x and y (int) |
| alpha | Parameter α of the symmetric Beta(α, α) distribution from which the mixing weights are drawn (float) |
`mixup` assumes that `x` and `y` are passed one batch at a time, with the batch along the first axis, and that the dataset itself is reshuffled after each epoch.
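As a rough sketch of that calling pattern, the loop below reshuffles a toy dataset at each epoch and mixes every batch. The dataset, shapes, and epoch count are hypothetical, and `sklearn.utils.shuffle` is swapped for an equivalent NumPy permutation so the example only needs NumPy. Passing `len(xb)` as `batch_size` keeps the weights aligned even if the last batch is smaller.

```python
import numpy as np

def mixup(x, y, batch_size, alpha=0.2):
    # Same logic as above, but pairs samples with np.random.permutation
    # instead of sklearn.utils.shuffle (assumption: equivalent pairing)
    l = np.random.beta(alpha, alpha, batch_size)
    perm = np.random.permutation(len(x))
    x_, y_ = x[perm], y[perm]
    shape = tuple(-1 if i == 0 else 1 for i in range(x.ndim))
    x = l.reshape(shape) * x + (1 - l).reshape(shape) * x_
    shape = tuple(-1 if i == 0 else 1 for i in range(y.ndim))
    y = l.reshape(shape) * y + (1 - l).reshape(shape) * y_
    return x, y

# Hypothetical toy dataset: 100 "images" of 8x8x3 with one-hot labels for 10 classes
X = np.random.rand(100, 8, 8, 3)
Y = np.eye(10)[np.random.randint(0, 10, size=100)]
batch_size = 20

for epoch in range(2):
    # Reshuffle the whole dataset at the start of every epoch
    order = np.random.permutation(len(X))
    X, Y = X[order], Y[order]
    for start in range(0, len(X), batch_size):
        xb = X[start:start + batch_size]
        yb = Y[start:start + batch_size]
        # Use the actual batch length so the weights match the batch
        xb_mix, yb_mix = mixup(xb, yb, len(xb), alpha=0.2)
        # A training step on (xb_mix, yb_mix) would go here (not shown)
```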
```python
l = np.random.beta(alpha, alpha, batch_size)
```
This draws the mixing weights, one per sample, from Beta(α, α). `l` has shape `(batch_size,)`.
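To get a feel for α: for α < 1 the Beta(α, α) density is U-shaped, so most weights land near 0 or 1 and most mixed samples stay close to one of the two originals. The short sketch below (the seed, sample count, and the 0.1/0.9 thresholds are arbitrary choices) measures how often a weight falls in those extremes.

```python
import numpy as np

rng = np.random.default_rng(0)

def extreme_fraction(alpha, n=10_000):
    # Draw n mixing weights from Beta(alpha, alpha), as mixup() does
    l = rng.beta(alpha, alpha, n)
    # Fraction of weights close to 0 or 1, i.e. samples that are barely mixed
    return np.mean((l < 0.1) | (l > 0.9))

for alpha in (0.2, 1.0, 2.0):
    print(alpha, extreme_fraction(alpha))
```

With α = 0.2 the large majority of weights are extreme, α = 1.0 gives the uniform distribution, and α = 2.0 concentrates weights around 0.5.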
```python
x_, y_ = sklearn.utils.shuffle(x, y)
```
This shuffles `x` and `y` while preserving the correspondence between them, so `y_[i]` is still the label of `x_[i]`.
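The key point is that one permutation is applied to both arrays. The sketch below shows this with a plain NumPy permutation instead of `sklearn.utils.shuffle` (a swapped-in equivalent, not the code used in this post), on toy arrays where the pairing is easy to check.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.arange(10).reshape(5, 2)   # row i is [2i, 2i + 1]
y = np.arange(5)                  # label i belongs to row i

# A single permutation indexes both arrays, so each (x, y) pair stays together
perm = rng.permutation(len(x))
x_, y_ = x[perm], y[perm]

print(x_[:, 0] == 2 * y_)         # every entry is True
```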
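```python
shape = tuple(-1 if i == 0 else 1 for i in range(len(x.shape)))
```
This computes a shape that makes `l` broadcastable against `x`: the batch axis is kept and every other axis becomes 1, e.g. `(-1, 1, 1, 1)` for 4-D `x`.
For example, if `x` has shape `(batch_size, width, height, ch)` and `l` has shape `(batch_size,)`, then `l * x` does not compute what we want: NumPy broadcasting aligns trailing axes, so `l` is matched against the `ch` axis rather than the batch axis. Depending on the shape of `x`, this either silently gives a wrong result or raises an error such as `operands could not be broadcast ...`.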
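Both failure modes can be reproduced directly. In this sketch (shapes chosen for illustration), the first case has `ch == batch_size`, so `l * x` runs but weights the channels instead of the samples; the second has a different `ch` and raises a `ValueError`. Reshaping `l` first fixes both.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size = 4
l = rng.beta(0.2, 0.2, batch_size)             # shape (4,)

# Case 1: ch happens to equal batch_size -> no error, but the wrong axis
x = np.ones((batch_size, 2, 2, 4))
wrong = l * x                                  # l is aligned with the last (ch) axis
right = l.reshape(-1, 1, 1, 1) * x             # l is aligned with the batch axis
print(np.allclose(wrong[0, 0, 0], l))          # True: weights vary along ch
print(np.allclose(right[:, 0, 0, 0], l))       # True: weights vary along batch

# Case 2: ch differs from batch_size -> broadcasting fails
x = np.ones((batch_size, 2, 2, 3))
raised = False
try:
    l * x
except ValueError as e:
    raised = True
    print(e)                                   # "operands could not be broadcast together ..."
```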
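```python
x = l.reshape(shape) * x + (1 - l).reshape(shape) * x_
```
This blends each sample of `x` with its partner in the shuffled copy `x_`, using weight `l` for the original and `1 - l` for the partner.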
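A small worked example makes the arithmetic concrete. The weights here are fixed by hand instead of drawn from the Beta distribution, and the "shuffled" copy is simply `x` with its rows swapped.

```python
import numpy as np

x = np.array([[1.0, 0.0],
              [0.0, 1.0]])
x_ = np.array([[0.0, 1.0],      # "shuffled" copy: rows swapped
               [1.0, 0.0]])
l = np.array([0.7, 0.2])        # fixed weights instead of Beta draws

shape = (-1, 1)                 # what the tuple expression yields for 2-D x
mixed = l.reshape(shape) * x + (1 - l).reshape(shape) * x_
print(mixed)
# [[0.7 0.3]
#  [0.8 0.2]]
```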
```python
shape = tuple(-1 if i == 0 else 1 for i in range(len(y.shape)))
```
Likewise, this computes a shape that lets `l` broadcast against `y`, whose number of axes usually differs from that of `x`.
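The same rule covers inputs and labels of any rank. The helper below is hypothetical (it is not part of this post's code, and it is unrelated to NumPy's `np.broadcast_shapes`); it builds the same tuple more compactly so the results for typical shapes are easy to see.

```python
import numpy as np

def broadcast_shape(a):
    # Hypothetical helper: (-1, 1, ..., 1) keeps the batch axis and adds
    # singleton axes for the rest, matching the tuple(...) expression in mixup()
    return (-1,) + (1,) * (a.ndim - 1)

print(broadcast_shape(np.zeros((32, 28, 28, 1))))  # (-1, 1, 1, 1)  image batch
print(broadcast_shape(np.zeros((32, 10))))         # (-1, 1)        one-hot labels
print(broadcast_shape(np.zeros((32,))))            # (-1,)          1-D targets
```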
```python
y = l.reshape(shape) * y + (1 - l).reshape(shape) * y_
```
This blends `y` with the same weights that were used for `x`.
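Assuming `y` holds one-hot labels, the blended rows become soft labels whose entries still sum to 1, because λ + (1 − λ) = 1. A quick check with toy labels (the label values and seed are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
y = np.eye(3)[[0, 1, 2, 0]]          # one-hot labels, shape (4, 3)
y_ = y[rng.permutation(len(y))]      # shuffled copy
l = rng.beta(0.2, 0.2, len(y))       # per-sample weights

mixed_y = l.reshape(-1, 1) * y + (1 - l).reshape(-1, 1) * y_
print(mixed_y.sum(axis=1))           # each row sums to 1 (up to rounding)
```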
| Return value name | Description (type) |
|---|---|
| x | Mixed input data (numpy.ndarray) |
| y | Mixed output data (numpy.ndarray) |
The figures below show the result of applying mixup to an identity matrix used as input data, for α from 0.1 to 0.9.
(Figure: mixup applied to the identity matrix, one panel each for α = 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9; the images are not available.)
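A minimal sketch for reproducing this kind of experiment follows. The matrix size and seed are assumptions (the post does not state them), the identity matrix is also passed as `y` only because `mixup` requires labels, and `sklearn.utils.shuffle` is again replaced by a NumPy permutation. Each mixed row has at most two non-zero entries that sum to 1, and the resulting matrices can be displayed with, for example, matplotlib's `imshow`.

```python
import numpy as np

def mixup(x, y, batch_size, alpha=0.2):
    # Same as above, with a NumPy permutation in place of sklearn.utils.shuffle
    l = np.random.beta(alpha, alpha, batch_size)
    perm = np.random.permutation(len(x))
    x_, y_ = x[perm], y[perm]
    shape = tuple(-1 if i == 0 else 1 for i in range(x.ndim))
    x = l.reshape(shape) * x + (1 - l).reshape(shape) * x_
    shape = tuple(-1 if i == 0 else 1 for i in range(y.ndim))
    y = l.reshape(shape) * y + (1 - l).reshape(shape) * y_
    return x, y

n = 10                   # hypothetical size; the original figure size is not stated
eye = np.eye(n)
np.random.seed(0)
results = {}
for alpha in (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9):
    # The identity matrix is used as the input (and, for convenience, as y)
    mixed, _ = mixup(eye, eye, n, alpha=alpha)
    results[alpha] = mixed
```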