[PYTHON] J'ai essayé d'implémenter la régularisation Shake-Shake (ShakeNet) avec PyTorch

Qu'est-ce que la régularisation Shake-Shake?

C'est une de la régularisation. En augmentant les données d'entraînement d'une manière pseudo, il y a un avantage que vous pouvez apprendre lentement pendant une longue période.

Est-ce efficace comme augmentation de données lorsque le nombre de données est petit? Pour le moment, j'aimerais essayer cette fois avec CIFAR10.

shake-shake-regularization.jpg

Une brève description de la régularisation Shake-Shake est présentée ci-dessus. Créez deux blocs redisuels en parallèle dans Resnet et ajoutez les opérations suivantes à la sortie des blocs résiduels.

Pour plus de détails, veuillez vous référer aux articles d'autres personnes. J'ai lu et compris cet article. https://qiita.com/masataka46/items/fc7f31073c89b02f8a04

Autres détails écrits dans l'article

Le document a été conçu de différentes manières.

--Le flux de l'architecture Plain est ReLU → Conv → BN → ReLU → Conv → BN → Mul (avec un nombre aléatoire α) ――Divisé en 3 étapes, chacune avec 4 blocs résiduels --32,64,128 canaux pour chaque étage --Appliquer 3x3 Conv avant l'étape 1

Créer un bloc résiduel

Il existe deux types de resnet, l'architecture simple et l'architecture à goulot d'étranglement, mais cette fois, j'aimerais utiliser l'architecture Plain suivant l'article.

test.py


class ResidualPlainBlock(nn.Module):

    def __init__(self, in_channels, out_channels, stride, padding=0):
        super(ResidualPlainBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.conv1 = nn.Conv2d(in_channels,  out_channels, kernel_size=3,  stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(out_channels, out_channels,  kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.conv1_2 = nn.Conv2d(in_channels,  out_channels, kernel_size=3,  stride=stride, padding=1)
        self.bn1_2 = nn.BatchNorm2d(out_channels)

        self.conv2_2 = nn.Conv2d(out_channels, out_channels,  kernel_size=3, stride=1, padding=1)
        self.bn2_2 = nn.BatchNorm2d(out_channels)

        self.identity = nn.Identity()

        if in_channels != out_channels:
          self.down_avg1 = nn.AvgPool2d(kernel_size=1, stride=1)
          self.down_conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=stride, padding=0)
          self.down_pad1 = nn.ZeroPad2d((1,0,1,0))
          self.down_avg2 = nn.AvgPool2d(kernel_size=1, stride=1)
          self.down_conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=stride, padding=0)

    #Traitement spécial pendant le sous-échantillonnage
    def shortcut(self,x):
      x = F.relu(x)
      h1 = self.down_avg1(x)
      h1 = self.down_conv1(h1)
      h2 = self.down_pad1(x[:,:,1:,1:])
      h2 = self.down_avg1(h2)
      h2 = self.down_conv2(h2)
      return torch.cat((h1,h2),axis=1)


    def forward(self, x):
      if self.training:
        #1er bloc résiduel
          out = self.bn1(self.conv1(F.relu(x)))
          out = self.bn2(self.conv2(F.relu(out)))
          
        #Deuxième bloc résiduel
          out2 = self.bn1_2(self.conv1_2(F.relu(x)))
          out2 = self.bn2_2(self.conv2_2(F.relu(out2)))

          if self.in_channels != self.out_channels:
            output = self.shortcut(x) + ShakeShake.apply(out,out2)
          else:
            output = self.identity(x) + ShakeShake.apply(out,out2)
          
          return output
      else:
          out = self.bn1(self.conv1(F.relu(x)))
          out = self.bn2(self.conv2(F.relu(out)))
          
          out2 = self.bn1_2(self.conv1_2(F.relu(x)))
          out2 = self.bn2_2(self.conv2_2(F.relu(out2)))

          if self.in_channels != self.out_channels:
            output = self.shortcut(x) + (out+out2)*0.5
          else:
            output = self.identity(x) + (out+out2)*0.5
          
          return output

Le constructeur est compliqué, mais je pense que vous pouvez le comprendre en regardant la fonction forward.

Le contenu de la fonction avant est 1: Donnez le x reçu à deux blocs 2: sortie et sortie 2 3: Faire traiter les sorties et les sorties2 par ** ShakeShake.apply () ** 4: Ajouter le raccourci et 3 et les sortir ensemble

Capture 1: ShakeShake.apply ()

Vous pouvez définir une classe appelée classe ShakeShake pour définir le traitement avant et arrière.

test.py


class ShakeShake(torch.autograd.Function):
  @staticmethod
  def forward(ctx, i1, i2):
    alpha = random.random()
    result = i1 * alpha + i2 * (1-alpha)

    return result
  @staticmethod
  def backward(ctx, grad_output):
    beta  = random.random()

    return grad_output * beta, grad_output * (1-beta)

En avant, un nombre aléatoire alpha est généré et multiplié par out et out2.

En arrière, un nouveau nombre aléatoire bêta est généré et appliqué à grad_output (valeur transmise par propagation en retour d'erreur).

Capture 2: changer la vitesse d'apprentissage avec une courbe cosinus

PyTorch vous permet de planifier des taux d'apprentissage.

Mettez en œuvre comme suit.

test.py


learning_rate = 0.02
optimizer = optim.SGD(net.parameters(),lr=learning_rate,momentum=0.9,weight_decay=0.0001)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=0.001)

for i in range(200):
  #for 1epoch
    #Apprenez pour 1 époque ici...
  scheduler.step()

En faisant cela, le taux d'apprentissage fluctuera le long de la courbe cosinus pour chaque époque.

Après avoir défini l'optimiseur, appelez quelque chose appelé ** CosineAnnealingLR **.

Le premier argument est l'optimiseur, le deuxième argument (T_max) est le nombre de pas (numéro d'époque) jusqu'au demi-cycle du cosinus, et le troisième argument est le taux d'apprentissage minimum.

Dans le cas ci-dessus, le taux d'apprentissage chute de 0,02 à 0,001 à 50 époques, puis revient à 50 époques, puis diminue à 50 époques, et ainsi de suite.

Résultat d'exécution

** La précision était de 89,43% **.

C'est subtil parce que c'est plus de 95% dans le journal.

Cependant, la précision du ResNet normal, qui n'est pas spécialement conçu, est d'environ 80%, il semble donc que la précision soit meilleure que cela.

Le bleu est train_acc et l'orange est test_acc. 17-89.43.png

En fait, il y a des endroits qui ne sont pas mis en œuvre selon le journal.

--Seulement 200 époques sont entraînées au lieu de 1800 époques --Dans l'article, le taux d'apprentissage maximal est fixé à 0,2, mais dans ce cas, l'erreur est devenue nan, alors je l'ai changé en 0,02. ―― Y a-t-il une erreur dans la lecture du journal?

finalement

La régularisation Shake-Shake attire l'attention en tant que méthode de régularisation puissante.

Récemment, il semble qu'une nouvelle méthode appelée Shake Drop ait été conçue, je vais donc l'implémenter également.

Recommended Posts

J'ai essayé d'implémenter la régularisation Shake-Shake (ShakeNet) avec PyTorch
J'ai essayé d'implémenter Attention Seq2Seq avec PyTorch
J'ai essayé d'implémenter VQE avec Blueqat
J'ai créé Word2Vec avec Pytorch
J'ai essayé d'implémenter DeepPose avec PyTorch
[Introduction à Pytorch] J'ai joué avec sinGAN ♬
J'ai essayé d'implémenter DeepPose avec PyTorch PartⅡ
J'ai essayé d'implémenter CVAE avec PyTorch
J'ai essayé d'implémenter la lecture de Dataset avec PyTorch
J'ai réécrit le code MNIST de Chainer avec PyTorch + Ignite
Jouez avec PyTorch
J'ai implémenté CycleGAN (1)
Validation croisée avec PyTorch
À partir de PyTorch
J'ai implémenté ResNet!
J'ai essayé de déplacer Faster R-CNN rapidement avec pytorch
J'ai essayé d'implémenter et d'apprendre DCGAN avec PyTorch
[Introduction à Pytorch] J'ai essayé de catégoriser Cifar10 avec VGG16 ♬
J'ai essayé de mettre en œuvre le co-filtrage (recommandation) avec redis et python
J'ai essayé d'implémenter SSD avec PyTorch maintenant (Dataset)
J'ai eu une erreur lors de l'utilisation de Tensorboard avec Pytorch
Utilisez RTX 3090 avec PyTorch
J'ai joué avec wordcloud!
Qiskit: j'ai implémenté VQE
Installer la diffusion de la torche avec PyTorch 1.7
J'ai essayé d'implémenter l'algorithme FloodFill avec TRON BATTLE de CodinGame
J'ai essayé de classer MNIST par GNN (avec PyTorch géométrique)
J'ai essayé d'implémenter SSD avec PyTorch maintenant (édition du modèle)