[PYTHON] PyTorch Learning Note 2 (I tried using a pre-trained model)

Introduction

This is a continuation of PyTorch Learning Memo (I made the same model as Keras). This time, I built a model for the same CIFAR-10 dataset using a pre-trained model.

1. Final code

The code is on GitHub at the following URL.

https://github.com/makaishi2/sample-data/blob/master/notebooks/cifar10_resnet.ipynb

The key points of the implementation are as follows.

1.1 Loading the data

Last time, a single transform was shared by the training and validation data; this time I separated the two. The key point is to apply "data augmentation" to the training data. Details are explained in 2.2 Inflating training data.

#Define the transforms
import torchvision.transforms as transforms

#For validation data: only convert to a tensor and normalize
transform = transforms.Compose([
  transforms.ToTensor(),
  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

#For training data: apply a random horizontal flip and Random Erasing in addition to normalization
#(RandomErasing works on tensors, so it must come after ToTensor)
transform_train = transforms.Compose([
  transforms.RandomHorizontalFlip(p=0.5),
  transforms.ToTensor(),
  transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
  transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)
])
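As a rough illustration of what the validation transform does numerically: for an 8-bit image, ToTensor() turns an HxWxC array with values 0 to 255 into a CxHxW tensor with values in [0, 1], and Normalize(mean, std) then computes (x - mean) / std per channel, so a mean and std of 0.5 map every value into [-1, 1]. The NumPy sketch below mimics those two steps on a made-up 2x2 RGB image; the helper names are hypothetical and this is not torchvision's code.

```python
import numpy as np

def to_tensor_like(img_uint8):
    """Mimic ToTensor(): HWC uint8 in [0, 255] -> CHW float32 in [0, 1]."""
    return img_uint8.transpose(2, 0, 1).astype(np.float32) / 255.0

def normalize_like(x, mean, std):
    """Mimic Normalize(): per-channel (x - mean) / std on a CHW array."""
    mean = np.asarray(mean, dtype=np.float32).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float32).reshape(-1, 1, 1)
    return (x - mean) / std

# Hypothetical 2x2 RGB image containing only the extreme pixel values 0 and 255
img = np.array([[[0, 0, 0], [255, 255, 255]],
                [[255, 0, 255], [0, 255, 0]]], dtype=np.uint8)

x = normalize_like(to_tensor_like(img), (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
print(x.shape)           # (3, 2, 2)
print(x.min(), x.max())  # -1.0 1.0
```

A value of 0 becomes (0 - 0.5) / 0.5 = -1 and a value of 255 becomes (1 - 0.5) / 0.5 = 1, which is why this setting is often described as scaling the input to [-1, 1].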

1.2 Loading the model

The code for loading the model is as follows. torchvision provides several pre-trained models whose weights can be loaded simply by calling a function.

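import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models

#Use the GPU if one is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#Load ResNet-50 together with its pre-trained weights
#(in torchvision 0.13 and later the preferred form is
# models.resnet50(weights=models.ResNet50_Weights.DEFAULT))
model_ft = models.resnet50(pretrained = True)

#Replace the last layer so that it outputs 10 classes
model_ft.fc = nn.Linear(model_ft.fc.in_features, 10)

#Move the model to the GPU (if available)
net = model_ft.to(device)

#Use cross entropy as the loss function
criterion = nn.CrossEntropyLoss()

#After trying several settings for the optimizer, this gave the best results
optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)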

The available models are listed at the following link: https://pytorch.org/docs/stable/torchvision/models.html

To use one of these models, load it and then replace only the final layer with a new one that has the desired number of outputs:

model_ft.fc = nn.Linear(model_ft.fc.in_features, 10)

That alone gives the desired classification model. Incidentally, this model was originally trained on 224x224 images. At first I wasn't sure whether it was okay to feed 32x32 images such as CIFAR-10 into such a large model, but in the end it does not seem to be a problem. Perhaps many of the weights are not used effectively and are simply wasted. Conversely, if you want to use images with a higher resolution than the original model expects, such as 1024x1024, you would typically reduce them to 224x224 in preprocessing.
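Why does a 32x32 input still work with a network built around 224x224 images? Just before fc, ResNet has AdaptiveAvgPool2d(output_size=(1, 1)) (see the end of the model overview in 1.3), which averages each of the 2048 channels over however many spatial positions remain, so fc always receives 2048 values per image. Below is a minimal NumPy sketch of that averaging (an illustration, not PyTorch's implementation; the helper name is hypothetical), comparing a 1x1 final feature map, as produced by 32x32 inputs, with the 7x7 map produced by 224x224 inputs, using random stand-in data.

```python
import numpy as np

def global_avg_pool(feature_map):
    """Same result as AdaptiveAvgPool2d((1, 1)) followed by flattening.

    feature_map: array of shape (batch, channels, height, width)
    returns:     array of shape (batch, channels)
    """
    return feature_map.mean(axis=(2, 3))

rng = np.random.default_rng(0)
small = rng.standard_normal((4, 2048, 1, 1))  # last feature map for 32x32 inputs
large = rng.standard_normal((4, 2048, 7, 7))  # last feature map for 224x224 inputs

print(global_avg_pool(small).shape)  # (4, 2048)
print(global_avg_pool(large).shape)  # (4, 2048)
```

In both cases the fc layer sees the same 2048 features, so nn.Linear(2048, 10) fits either input size; with a 1x1 map, though, the "average" is just the single remaining value.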

1.3 Model overview

To get an overview of the model, simply evaluate net in a notebook cell (or print(net)) at this point. The latter part of the output for ResNet-50 is shown below.

 Bottleneck(
      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (4): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (5): Bottleneck(
      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=10, bias=True)
)

The full output is really long (only the latter part is shown above), and you can see that this is a fairly complicated neural network.

1.4 Model size display

To see the output shape and number of parameters of each layer, run the following code at this point.

#Model summary view
#(torchsummary is a separate package: pip install torchsummary)
from torchsummary import summary
summary(net,(3,32,32))

The results are as follows.

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 16, 16]           9,408
       BatchNorm2d-2           [-1, 64, 16, 16]             128
              ReLU-3           [-1, 64, 16, 16]               0
         MaxPool2d-4             [-1, 64, 8, 8]               0
            Conv2d-5             [-1, 64, 8, 8]           4,096
       BatchNorm2d-6             [-1, 64, 8, 8]             128
              ReLU-7             [-1, 64, 8, 8]               0
            Conv2d-8             [-1, 64, 8, 8]          36,864
       BatchNorm2d-9             [-1, 64, 8, 8]             128
             ReLU-10             [-1, 64, 8, 8]               0
           Conv2d-11            [-1, 256, 8, 8]          16,384
      BatchNorm2d-12            [-1, 256, 8, 8]             512
           Conv2d-13            [-1, 256, 8, 8]          16,384
      BatchNorm2d-14            [-1, 256, 8, 8]             512
             ReLU-15            [-1, 256, 8, 8]               0
       Bottleneck-16            [-1, 256, 8, 8]               0
           Conv2d-17             [-1, 64, 8, 8]          16,384
      BatchNorm2d-18             [-1, 64, 8, 8]             128
             ReLU-19             [-1, 64, 8, 8]               0
           Conv2d-20             [-1, 64, 8, 8]          36,864
      BatchNorm2d-21             [-1, 64, 8, 8]             128
             ReLU-22             [-1, 64, 8, 8]               0
           Conv2d-23            [-1, 256, 8, 8]          16,384
      BatchNorm2d-24            [-1, 256, 8, 8]             512
             ReLU-25            [-1, 256, 8, 8]               0
       Bottleneck-26            [-1, 256, 8, 8]               0
           Conv2d-27             [-1, 64, 8, 8]          16,384
      BatchNorm2d-28             [-1, 64, 8, 8]             128
             ReLU-29             [-1, 64, 8, 8]               0
           Conv2d-30             [-1, 64, 8, 8]          36,864
      BatchNorm2d-31             [-1, 64, 8, 8]             128
             ReLU-32             [-1, 64, 8, 8]               0
           Conv2d-33            [-1, 256, 8, 8]          16,384
      BatchNorm2d-34            [-1, 256, 8, 8]             512
             ReLU-35            [-1, 256, 8, 8]               0
       Bottleneck-36            [-1, 256, 8, 8]               0
           Conv2d-37            [-1, 128, 8, 8]          32,768
      BatchNorm2d-38            [-1, 128, 8, 8]             256
             ReLU-39            [-1, 128, 8, 8]               0
           Conv2d-40            [-1, 128, 4, 4]         147,456
      BatchNorm2d-41            [-1, 128, 4, 4]             256
             ReLU-42            [-1, 128, 4, 4]               0
           Conv2d-43            [-1, 512, 4, 4]          65,536
      BatchNorm2d-44            [-1, 512, 4, 4]           1,024
           Conv2d-45            [-1, 512, 4, 4]         131,072
      BatchNorm2d-46            [-1, 512, 4, 4]           1,024
             ReLU-47            [-1, 512, 4, 4]               0
       Bottleneck-48            [-1, 512, 4, 4]               0
           Conv2d-49            [-1, 128, 4, 4]          65,536
      BatchNorm2d-50            [-1, 128, 4, 4]             256
             ReLU-51            [-1, 128, 4, 4]               0
           Conv2d-52            [-1, 128, 4, 4]         147,456
      BatchNorm2d-53            [-1, 128, 4, 4]             256
             ReLU-54            [-1, 128, 4, 4]               0
           Conv2d-55            [-1, 512, 4, 4]          65,536
      BatchNorm2d-56            [-1, 512, 4, 4]           1,024
             ReLU-57            [-1, 512, 4, 4]               0
       Bottleneck-58            [-1, 512, 4, 4]               0
           Conv2d-59            [-1, 128, 4, 4]          65,536
      BatchNorm2d-60            [-1, 128, 4, 4]             256
             ReLU-61            [-1, 128, 4, 4]               0
           Conv2d-62            [-1, 128, 4, 4]         147,456
      BatchNorm2d-63            [-1, 128, 4, 4]             256
             ReLU-64            [-1, 128, 4, 4]               0
           Conv2d-65            [-1, 512, 4, 4]          65,536
      BatchNorm2d-66            [-1, 512, 4, 4]           1,024
             ReLU-67            [-1, 512, 4, 4]               0
       Bottleneck-68            [-1, 512, 4, 4]               0
           Conv2d-69            [-1, 128, 4, 4]          65,536
      BatchNorm2d-70            [-1, 128, 4, 4]             256
             ReLU-71            [-1, 128, 4, 4]               0
           Conv2d-72            [-1, 128, 4, 4]         147,456
      BatchNorm2d-73            [-1, 128, 4, 4]             256
             ReLU-74            [-1, 128, 4, 4]               0
           Conv2d-75            [-1, 512, 4, 4]          65,536
      BatchNorm2d-76            [-1, 512, 4, 4]           1,024
             ReLU-77            [-1, 512, 4, 4]               0
       Bottleneck-78            [-1, 512, 4, 4]               0
           Conv2d-79            [-1, 256, 4, 4]         131,072
      BatchNorm2d-80            [-1, 256, 4, 4]             512
             ReLU-81            [-1, 256, 4, 4]               0
           Conv2d-82            [-1, 256, 2, 2]         589,824
      BatchNorm2d-83            [-1, 256, 2, 2]             512
             ReLU-84            [-1, 256, 2, 2]               0
           Conv2d-85           [-1, 1024, 2, 2]         262,144
      BatchNorm2d-86           [-1, 1024, 2, 2]           2,048
           Conv2d-87           [-1, 1024, 2, 2]         524,288
      BatchNorm2d-88           [-1, 1024, 2, 2]           2,048
             ReLU-89           [-1, 1024, 2, 2]               0
       Bottleneck-90           [-1, 1024, 2, 2]               0
           Conv2d-91            [-1, 256, 2, 2]         262,144
      BatchNorm2d-92            [-1, 256, 2, 2]             512
             ReLU-93            [-1, 256, 2, 2]               0
           Conv2d-94            [-1, 256, 2, 2]         589,824
      BatchNorm2d-95            [-1, 256, 2, 2]             512
             ReLU-96            [-1, 256, 2, 2]               0
           Conv2d-97           [-1, 1024, 2, 2]         262,144
      BatchNorm2d-98           [-1, 1024, 2, 2]           2,048
             ReLU-99           [-1, 1024, 2, 2]               0
      Bottleneck-100           [-1, 1024, 2, 2]               0
          Conv2d-101            [-1, 256, 2, 2]         262,144
     BatchNorm2d-102            [-1, 256, 2, 2]             512
            ReLU-103            [-1, 256, 2, 2]               0
          Conv2d-104            [-1, 256, 2, 2]         589,824
     BatchNorm2d-105            [-1, 256, 2, 2]             512
            ReLU-106            [-1, 256, 2, 2]               0
          Conv2d-107           [-1, 1024, 2, 2]         262,144
     BatchNorm2d-108           [-1, 1024, 2, 2]           2,048
            ReLU-109           [-1, 1024, 2, 2]               0
      Bottleneck-110           [-1, 1024, 2, 2]               0
          Conv2d-111            [-1, 256, 2, 2]         262,144
     BatchNorm2d-112            [-1, 256, 2, 2]             512
            ReLU-113            [-1, 256, 2, 2]               0
          Conv2d-114            [-1, 256, 2, 2]         589,824
     BatchNorm2d-115            [-1, 256, 2, 2]             512
            ReLU-116            [-1, 256, 2, 2]               0
          Conv2d-117           [-1, 1024, 2, 2]         262,144
     BatchNorm2d-118           [-1, 1024, 2, 2]           2,048
            ReLU-119           [-1, 1024, 2, 2]               0
      Bottleneck-120           [-1, 1024, 2, 2]               0
          Conv2d-121            [-1, 256, 2, 2]         262,144
     BatchNorm2d-122            [-1, 256, 2, 2]             512
            ReLU-123            [-1, 256, 2, 2]               0
          Conv2d-124            [-1, 256, 2, 2]         589,824
     BatchNorm2d-125            [-1, 256, 2, 2]             512
            ReLU-126            [-1, 256, 2, 2]               0
          Conv2d-127           [-1, 1024, 2, 2]         262,144
     BatchNorm2d-128           [-1, 1024, 2, 2]           2,048
            ReLU-129           [-1, 1024, 2, 2]               0
      Bottleneck-130           [-1, 1024, 2, 2]               0
          Conv2d-131            [-1, 256, 2, 2]         262,144
     BatchNorm2d-132            [-1, 256, 2, 2]             512
            ReLU-133            [-1, 256, 2, 2]               0
          Conv2d-134            [-1, 256, 2, 2]         589,824
     BatchNorm2d-135            [-1, 256, 2, 2]             512
            ReLU-136            [-1, 256, 2, 2]               0
          Conv2d-137           [-1, 1024, 2, 2]         262,144
     BatchNorm2d-138           [-1, 1024, 2, 2]           2,048
            ReLU-139           [-1, 1024, 2, 2]               0
      Bottleneck-140           [-1, 1024, 2, 2]               0
          Conv2d-141            [-1, 512, 2, 2]         524,288
     BatchNorm2d-142            [-1, 512, 2, 2]           1,024
            ReLU-143            [-1, 512, 2, 2]               0
          Conv2d-144            [-1, 512, 1, 1]       2,359,296
     BatchNorm2d-145            [-1, 512, 1, 1]           1,024
            ReLU-146            [-1, 512, 1, 1]               0
          Conv2d-147           [-1, 2048, 1, 1]       1,048,576
     BatchNorm2d-148           [-1, 2048, 1, 1]           4,096
          Conv2d-149           [-1, 2048, 1, 1]       2,097,152
     BatchNorm2d-150           [-1, 2048, 1, 1]           4,096
            ReLU-151           [-1, 2048, 1, 1]               0
      Bottleneck-152           [-1, 2048, 1, 1]               0
          Conv2d-153            [-1, 512, 1, 1]       1,048,576
     BatchNorm2d-154            [-1, 512, 1, 1]           1,024
            ReLU-155            [-1, 512, 1, 1]               0
          Conv2d-156            [-1, 512, 1, 1]       2,359,296
     BatchNorm2d-157            [-1, 512, 1, 1]           1,024
            ReLU-158            [-1, 512, 1, 1]               0
          Conv2d-159           [-1, 2048, 1, 1]       1,048,576
     BatchNorm2d-160           [-1, 2048, 1, 1]           4,096
            ReLU-161           [-1, 2048, 1, 1]               0
      Bottleneck-162           [-1, 2048, 1, 1]               0
          Conv2d-163            [-1, 512, 1, 1]       1,048,576
     BatchNorm2d-164            [-1, 512, 1, 1]           1,024
            ReLU-165            [-1, 512, 1, 1]               0
          Conv2d-166            [-1, 512, 1, 1]       2,359,296
     BatchNorm2d-167            [-1, 512, 1, 1]           1,024
            ReLU-168            [-1, 512, 1, 1]               0
          Conv2d-169           [-1, 2048, 1, 1]       1,048,576
     BatchNorm2d-170           [-1, 2048, 1, 1]           4,096
            ReLU-171           [-1, 2048, 1, 1]               0
      Bottleneck-172           [-1, 2048, 1, 1]               0
AdaptiveAvgPool2d-173           [-1, 2048, 1, 1]               0
          Linear-174                   [-1, 10]          20,490
================================================================
Total params: 23,528,522
Trainable params: 20,490
Non-trainable params: 23,508,032
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 5.86
Params size (MB): 89.75
Estimated Total Size (MB): 95.63
----------------------------------------------------------------
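The Param # column can be checked by hand: a bias-free Conv2d has k x k x C_in x C_out weights, a BatchNorm2d has 2 x C learnable parameters (weight and bias; the running mean and variance are buffers, not parameters), and a Linear layer has in x out weights plus one bias per output. The small Python sketch below (helper names are hypothetical) reproduces a few rows of the table. It also shows that the 20,490 trainable parameters reported here are exactly those of the new fc layer, which suggests this summary was taken with the pre-trained layers frozen as in section 3; with the code in 1.2, all 23,528,522 parameters would be trainable.

```python
def conv2d_params(k, c_in, c_out, bias=False):
    """Learnable parameters of a k x k Conv2d (ResNet's convolutions have no bias)."""
    return k * k * c_in * c_out + (c_out if bias else 0)

def batchnorm2d_params(c):
    """BatchNorm2d learns one weight (gamma) and one bias (beta) per channel."""
    return 2 * c

def linear_params(n_in, n_out):
    """Linear layer: weight matrix plus one bias per output."""
    return n_in * n_out + n_out

print(conv2d_params(7, 3, 64))      # Conv2d-1:      9408
print(batchnorm2d_params(64))       # BatchNorm2d-2: 128
print(conv2d_params(1, 512, 1024))  # Conv2d-87:     524288
print(conv2d_params(3, 512, 512))   # Conv2d-144:    2359296
print(linear_params(2048, 10))      # Linear-174:    20490

# Total minus the new fc layer gives the non-trainable (frozen) part
total = 23_528_522
print(total - linear_params(2048, 10))  # 23508032
```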

By the final stage, the height and width have both shrunk to 1 (a 1x1 feature map), and I was a little worried whether this still makes sense as a neural network. I investigated this point in 2.3 Model selection, so please refer to the results there.
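The shrinking height and width in the table follow from the usual output-size formula for convolution and pooling layers, out = floor((in + 2 x padding - kernel) / stride) + 1. The Python sketch below applies it to the layers of torchvision's ResNet-50 that change the spatial size: the 7x7 stride-2 conv1 (padding 3), the 3x3 stride-2 max pool (padding 1), and the 3x3 stride-2 convolution (padding 1) at the start of layer2, layer3 and layer4. Stride-1 layers keep the size unchanged and are skipped; the table and function names are hypothetical.

```python
def out_size(size, kernel, stride, padding):
    """Output height/width of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

# (name, kernel, stride, padding) of the ResNet-50 layers that change the spatial size
RESNET50_DOWNSAMPLING = [
    ("conv1", 7, 2, 3),
    ("maxpool", 3, 2, 1),
    ("layer2", 3, 2, 1),
    ("layer3", 3, 2, 1),
    ("layer4", 3, 2, 1),
]

def trace_sizes(input_size):
    """Spatial size after each downsampling step, starting from the input."""
    sizes = [input_size]
    for _name, kernel, stride, padding in RESNET50_DOWNSAMPLING:
        sizes.append(out_size(sizes[-1], kernel, stride, padding))
    return sizes

print(trace_sizes(32))   # [32, 16, 8, 4, 2, 1]
print(trace_sizes(224))  # [224, 112, 56, 28, 14, 7]
```

So a 32x32 CIFAR-10 image reaches the end of layer4 as a single position per channel, matching the 16, 8, 4, 2, 1 sizes in the summary, whereas the 224x224 images the weights were trained on still have a 7x7 map at that point.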

1.5 Main training loop

The processing is the same as in PyTorch Learning Memo (I made the same model as Keras), but I changed the output during training to look like Keras. I have also added more comments than before, which should make it much easier to follow.

#Number of epochs and lists for recording the learning history
nb_epoch = 20
train_loss_list, train_acc_list, val_loss_list, val_acc_list = [], [], [], []

for i in range(nb_epoch):
  train_loss = 0
  train_acc = 0
  val_loss = 0
  val_acc = 0

  #Training phase
  net.train()

  for images, labels in train_loader:

    #Reset the gradients left over from the previous batch (must come before backward())
    optimizer.zero_grad()

    #Move the training batch to the GPU (if available)
    images = images.to(device)
    labels = labels.to(device)

    #Forward propagation
    outputs = net(images)

    #Loss calculation. loss.item() is the mean over the batch, so dividing the sum
    #by the dataset size below gives roughly the mean loss divided by the batch size
    loss = criterion(outputs, labels)
    train_loss += loss.item()

    #Backpropagation and parameter update
    loss.backward()
    optimizer.step()

    #Predicted class = index of the largest output
    predicted = outputs.max(1)[1]

    #Count the correct answers (.item() turns the count into a Python int)
    train_acc += (predicted == labels).sum().item()

  #Loss and accuracy for the training data
  avg_train_loss = train_loss / len(train_loader.dataset)
  avg_train_acc = train_acc / len(train_loader.dataset)

  #Evaluation phase
  net.eval()
  with torch.no_grad():

    for images, labels in test_loader:

      #Move the test batch to the GPU (if available)
      images = images.to(device)
      labels = labels.to(device)

      #Forward propagation
      outputs = net(images)

      #Loss calculation
      loss = criterion(outputs, labels)
      val_loss += loss.item()

      #Predicted class
      predicted = outputs.max(1)[1]

      #Count the correct answers
      val_acc += (predicted == labels).sum().item()

    #Loss and accuracy for the validation data
    avg_val_loss = val_loss / len(test_loader.dataset)
    avg_val_acc = val_acc / len(test_loader.dataset)

  print (f'Epoch [{(i+1)}/{nb_epoch}], loss: {avg_train_loss:.5f} acc: {avg_train_acc:.5f} val_loss: {avg_val_loss:.5f}, val_acc: {avg_val_acc:.5f}')
  train_loss_list.append(avg_train_loss)
  train_acc_list.append(avg_train_acc)
  val_loss_list.append(avg_val_loss)
  val_acc_list.append(avg_val_acc)
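In the loop, outputs.max(1) returns both the maximum value and its index along the class dimension, so outputs.max(1)[1] is the predicted class, the same as an argmax. The NumPy sketch below mirrors how the number of correct answers is counted, using made-up logits for a hypothetical batch of three images and four classes.

```python
import numpy as np

# Hypothetical network outputs (logits) for a batch of 3 images and 4 classes
outputs = np.array([[0.1, 2.0, -1.0, 0.3],
                    [1.5, 0.2,  0.1, 0.0],
                    [0.0, 0.1,  0.2, 3.0]])
labels = np.array([1, 2, 3])

# Same role as outputs.max(1)[1] in PyTorch: index of the largest logit in each row
predicted = outputs.argmax(axis=1)
print(predicted)  # [1 0 3]

# Number of correct answers in this batch, then the batch accuracy
n_correct = int((predicted == labels).sum())
print(n_correct)  # 2
print(n_correct / len(labels))
```

Summing these counts over all batches and dividing by the dataset size gives the accuracy values printed for each epoch.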

1.6 Learning curve display

The following code plots the learning curves for both the loss function value and the accuracy. Here are the code and example results:

Loss function value
#Learning curve (loss function value)
import matplotlib.pyplot as plt
import numpy as np

plt.figure(figsize=(8,6))
plt.plot(val_loss_list,label='Validation', lw=2, c='b')
plt.plot(train_loss_list,label='Training', lw=2, c='k')
plt.title('Learning curve (loss function value)')
plt.xticks(np.arange(0, 21, 2), size=14)
plt.yticks(size=14)
plt.grid(lw=2)
plt.legend(fontsize=14)
plt.show()

[Figure: learning curve of the loss function value]

Accuracy
#Learning curve (accuracy)
plt.figure(figsize=(8,6))
plt.plot(val_acc_list,label='Validation', lw=2, c='b')
plt.plot(train_acc_list,label='Training', lw=2, c='k')
plt.title('Learning curve (accuracy)')
plt.xticks(np.arange(0, 21, 2), size=14)
plt.yticks(size=14)
plt.grid(lw=2)
plt.legend(fontsize=14)
plt.show()

[Figure: learning curve of the accuracy]

2. Tuning

2.1 Optimization parameters

I will omit the details, but after trying several settings for the optimization parameters, I concluded that the following seems to work best.

#After trying several settings for the optimizer, this gave the best results
optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)

For example, I also tried Adam as the optimizer, but learning was actually slower than with SGD. This seems to be a commonly reported experience when fine-tuning pre-trained models.
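For reference, with momentum=0.9 and PyTorch's defaults (dampening=0, nesterov=False, weight_decay=0), torch.optim.SGD keeps a velocity buffer v for each parameter p with gradient g and updates v = momentum * v + g, then p = p - lr * v. The NumPy sketch below applies that rule to a toy quadratic loss with the lr used here; it only illustrates the update rule, it is not torch.optim.SGD itself, and the function name and toy loss are hypothetical.

```python
import numpy as np

def sgd_momentum(grad_fn, p0, lr=0.001, momentum=0.9, steps=1000):
    """Plain SGD with momentum (no dampening, no Nesterov, no weight decay)."""
    p = np.asarray(p0, dtype=float)
    v = np.zeros_like(p)
    for _ in range(steps):
        g = grad_fn(p)
        v = momentum * v + g  # velocity buffer accumulates past gradients
        p = p - lr * v        # parameter update
    return p

# Toy loss L(p) = 0.5 * ||p - target||^2, whose gradient is p - target
target = np.array([3.0, -2.0])
p_final = sgd_momentum(lambda p: p - target, p0=[0.0, 0.0])
print(p_final)  # very close to [3, -2]
```

Because the velocity accumulates past gradients, a momentum of 0.9 makes the effective step in a consistent gradient direction roughly 1 / (1 - 0.9) = 10 times lr.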

2.2 Inflating training data

Here again is the definition of the transform for the training data.

transform_train = transforms.Compose([
  transforms.RandomHorizontalFlip(p=0.5),   #Random transform 1: horizontal flip
  transforms.ToTensor(),
  transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
  transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)  #Random transform 2: Random Erasing
])

There are two random transforms here, RandomHorizontalFlip and RandomErasing. They randomly "flip the image horizontally" and "erase a rectangular region of the image", respectively, which increases the variation of the original training data. To check how effective this data augmentation is, I trained the same model with the following four patterns and compared the results (all the learning curves below are for the validation data).

None: no randomization
Invert: horizontal flip only
Erase: Random Erasing only
Both: both horizontal flip and Random Erasing
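To make the two augmentations concrete, here is a simplified NumPy sketch of what they do to a single CxHxW image. The erasing part follows the idea behind RandomErasing with the same parameters (pick a target area between 2% and 33% of the image and an aspect ratio between 0.3 and 3.3 on a log scale, then fill that rectangle with 0), but it was written for this note and is not torchvision's implementation; the function names are hypothetical.

```python
import math
import numpy as np

def random_horizontal_flip(img, rng, p=0.5):
    """With probability p, mirror a CHW image left to right."""
    return img[:, :, ::-1].copy() if rng.random() < p else img

def random_erasing(img, rng, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3),
                   value=0.0, attempts=10):
    """With probability p, fill one random rectangle of a CHW image with `value`."""
    if rng.random() >= p:
        return img
    _, h, w = img.shape
    log_ratio = (math.log(ratio[0]), math.log(ratio[1]))
    for _ in range(attempts):
        erase_area = h * w * rng.uniform(*scale)
        aspect = math.exp(rng.uniform(*log_ratio))
        eh = int(round(math.sqrt(erase_area * aspect)))
        ew = int(round(math.sqrt(erase_area / aspect)))
        if 0 < eh < h and 0 < ew < w:
            top = rng.integers(0, h - eh + 1)
            left = rng.integers(0, w - ew + 1)
            out = img.copy()
            out[:, top:top + eh, left:left + ew] = value
            return out
    return img  # no rectangle fitted: leave the image unchanged

rng = np.random.default_rng(0)
img = np.ones((3, 32, 32), dtype=np.float32)  # stand-in for a normalized CIFAR-10 image
aug = random_erasing(random_horizontal_flip(img, rng), rng, p=1.0)
print(aug.shape, int((aug == 0).any(axis=0).sum()), "pixels erased")
```

Because each transform fires independently with probability 0.5, a given training image may be flipped, erased, both, or left unchanged in each epoch.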

Loss function value

[Figure: validation loss for None, Invert, Erase and Both]

Accuracy

[Figure: validation accuracy for None, Invert, Erase and Both]

As the graphs show, both techniques improve accuracy, and since their effects appear to be independent, they can be used in combination.

2.3 Model selection

Next, I swapped the model being loaded to see how the accuracy changes. The four models I tried are below: for ResNet, three variants with 18, 50, and 152 layers, plus vgg19_bn.

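model_ft = models.resnet50(pretrained = True)
# model_ft = models.resnet18(pretrained = True)
# model_ft = models.resnet152(pretrained = True)
# model_ft = models.vgg19_bn(pretrained = True)
# Note: VGG has no fc attribute; for vgg19_bn, replace the last classifier layer instead:
# model_ft.classifier[6] = nn.Linear(model_ft.classifier[6].in_features, 10)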

The resulting learning curves (for the validation data) are shown below.

Loss function value

[Figure: validation loss for resnet18, resnet50, resnet152 and vgg19_bn]

Accuracy

[Figure: validation accuracy for resnet18, resnet50, resnet152 and vgg19_bn]

First, comparing ResNet18 and ResNet50, ResNet50 is better, which suggests that deepening the network to about 50 layers is worthwhile. On the other hand, ResNet152 performs about the same as or worse than ResNet50, which suggests that for low-resolution data such as CIFAR-10, deepening ResNet beyond 50 layers does not help. Finally, vgg19_bn ("bn" indicates that the model incorporates a technique called Batch Normalization) gave by far the best results. ResNet came after VGG in the evolution of models, but for low-resolution data like CIFAR-10, a model with a simpler structure like VGG19 may be more accurate.

3. Unresolved issues

There is actually one point I still don't understand. This example was originally intended as a sample of "transfer learning". For transfer learning, where the pre-trained weights are kept fixed and only the new final layer is trained, the model should be loaded with the following code.

#Load ResNet-50 with its pre-trained weights
model_ft = models.resnet50(pretrained = True)

#Freeze all pre-trained parameters (no gradients are computed for them)
for param in model_ft.parameters():
  param.requires_grad = False

#Replace only the last layer (number of output classes = 10);
#the newly created layer is trainable by default
model_ft.fc = nn.Linear(model_ft.fc.in_features, 10)

#Move the model to the GPU (if available)
net = model_ft.to(device)

#The loss function is cross entropy
criterion = nn.CrossEntropyLoss()

#Pass only the final layer's parameters to the optimizer
optimizer = optim.SGD(net.fc.parameters(), lr=0.001, momentum=0.9)

However, no matter how many times I try, the results look like the following: the accuracy is poor, and the validation accuracy stays at around 54% even after 20 epochs. If anyone knows the reason, I would appreciate it if you could let me know.

Epoch [1/20], loss: 0.01858 acc: 0.35722 val_loss: 0.01631, val_acc: 0.45390
Epoch [2/20], loss: 0.01665 acc: 0.43014 val_loss: 0.01546, val_acc: 0.47900
Epoch [3/20], loss: 0.01618 acc: 0.44372 val_loss: 0.01509, val_acc: 0.49140
Epoch [4/20], loss: 0.01587 acc: 0.45340 val_loss: 0.01479, val_acc: 0.49960
Epoch [5/20], loss: 0.01562 acc: 0.45984 val_loss: 0.01468, val_acc: 0.50190
Epoch [6/20], loss: 0.01548 acc: 0.46750 val_loss: 0.01429, val_acc: 0.51690
Epoch [7/20], loss: 0.01537 acc: 0.47066 val_loss: 0.01421, val_acc: 0.51670
Epoch [8/20], loss: 0.01524 acc: 0.47468 val_loss: 0.01412, val_acc: 0.51730
Epoch [9/20], loss: 0.01511 acc: 0.47828 val_loss: 0.01418, val_acc: 0.51480
Epoch [10/20], loss: 0.01504 acc: 0.47974 val_loss: 0.01401, val_acc: 0.52380
Epoch [11/20], loss: 0.01497 acc: 0.48228 val_loss: 0.01390, val_acc: 0.52660
Epoch [12/20], loss: 0.01491 acc: 0.48466 val_loss: 0.01389, val_acc: 0.53050
Epoch [13/20], loss: 0.01485 acc: 0.48774 val_loss: 0.01375, val_acc: 0.53290
Epoch [14/20], loss: 0.01481 acc: 0.48686 val_loss: 0.01379, val_acc: 0.52630
Epoch [15/20], loss: 0.01476 acc: 0.48818 val_loss: 0.01380, val_acc: 0.53200
Epoch [16/20], loss: 0.01476 acc: 0.48818 val_loss: 0.01361, val_acc: 0.54190
Epoch [17/20], loss: 0.01476 acc: 0.48762 val_loss: 0.01362, val_acc: 0.53570
Epoch [18/20], loss: 0.01464 acc: 0.49438 val_loss: 0.01352, val_acc: 0.53940
Epoch [19/20], loss: 0.01462 acc: 0.49442 val_loss: 0.01349, val_acc: 0.54160
Epoch [20/20], loss: 0.01464 acc: 0.49370 val_loss: 0.01355, val_acc: 0.53710

Learning curve (loss function)

[Figure: learning curve of the loss function value for the transfer-learning run]

Learning curve (accuracy)

[Figure: learning curve of the accuracy for the transfer-learning run]
