[PYTHON] CIFAR-10 classification implemented in about 60 lines in PyTorch

Overview

This article is for people who think, "I haven't implemented deep learning yet, but I'd like to give it a try."

This time, we implement everything from loading the image data to plotting the training error and generalization error in a simple graph, in about 60 lines.

"Actual" is when you don't consider line breaks and comment statements to make the code easier to read, and the actual code is around 100 lines.

**I'll put the full code at the end of the article.**

Process flow

Roughly speaking, the flow is as follows.

- Data preparation (using CIFAR-10)
- Model definition (using ResNet-50)
- Definition of the loss function / optimization method
- Training / inference
- Result output, model saving, etc.

Let's take a look at the source code.

What to import

First, let's import the libraries we'll use this time.

test.py


import torch
import numpy as np
import torch.nn as nn
from torch import optim
import torch.nn.init as init
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import Dataset,DataLoader
import torchvision.datasets as dsets
import matplotlib.pyplot as plt

That's already 10 lines just for the imports (laughs).

Data preparation

For image data, we use a famous data set called CIFAR-10.

The "10" means there are 10 classes: 50,000 training images (5,000 per class) and 10,000 test images (1,000 per class) are provided.

You can load CIFAR-10 as follows.

test.py


#Loading images
batch_size = 100

train_data = dsets.CIFAR10(root='./tmp/cifar-10', train=True, download=False, transform=transforms.Compose([transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)]))

train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True)

test_data = dsets.CIFAR10(root='./tmp/cifar-10', train=False, download=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]))

test_loader = DataLoader(test_data,batch_size=batch_size,shuffle=False)

Set download=True the first time you run this so that the CIFAR-10 dataset is downloaded.

The transform=transforms.Compose([...]) part is long, but it defines the preprocessing applied to the image data.

- RandomHorizontalFlip: randomly flips the image horizontally
- ToTensor: converts the image to a tensor
- Normalize: normalizes the pixel values
- RandomErasing: randomly erases part of the image, adding noise to some of the data

Roughly speaking, that's what they do. They make the data easier to handle numerically and help prevent overfitting through data augmentation.

There are many other transforms, so please refer to the torchvision documentation.
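As a quick sanity check (a minimal sketch, assuming train_data and test_data were created as above), you can confirm the dataset sizes and look at one transformed sample:


print(len(train_data), len(test_data))  # 50000 10000
img, label = train_data[0]              # indexing applies the transform pipeline
print(img.shape, label)                 # torch.Size([3, 32, 32]) and an integer class label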

Check if GPU can be used

Before defining the model, check whether the GPU can be used for computation. A GPU is recommended because it is many times faster.

test.py


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

If a GPU is available, this prints cuda; if only the CPU can be used, it prints cpu.

Model definition

The model is the core of deep learning, but these days you can use excellent ready-made models without having to design one yourself.

This time we will use a model called ResNet (specifically, ResNet-50).

test.py


model_ft = models.resnet50(pretrained=True)
model_ft.fc = nn.Linear(model_ft.fc.in_features, 10)
net = model_ft.to(device)

models.resnet50(pretrained=True) lets you use a pretrained ResNet model. It's that easy...

If you set pretrained=False instead, you get a ResNet that has not been trained yet, but since training from scratch takes a long time, True is recommended.

The second line replaces the output layer so that it has 10 outputs, one per class. It could just as well be 40 or 100 for a different number of classes.

The third line assigns the model to a variable called net. The .to(device) call moves the model to the GPU if one is available (roughly speaking).
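As a quick check (a sketch, separate from the main script), you can pass a dummy batch through the network and confirm that the output has 10 values per image, one per class:


# Sketch: verify the replaced output layer (assumes net and device from above)
dummy = torch.randn(4, 3, 32, 32).to(device)   # 4 fake CIFAR-sized images
with torch.no_grad():
    out = net(dummy)
print(out.shape)                               # torch.Size([4, 10])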

Definition of loss function / optimization method

For the model to learn, a loss function must measure the error between its predictions and the correct answers, and the optimization method determines how that error is reduced.

Here, we use cross entropy for the loss function and SGD for the optimization algorithm.

test.py


criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9,weight_decay=0.00005)

lr is the learning rate; it should be neither too large nor too small. Here it is set to 0.01.

weight_decay is weight decay, a regularization method that also helps prevent overfitting.
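As a small illustration (a sketch with made-up numbers, separate from the training script): nn.CrossEntropyLoss takes the raw model outputs (logits) for each class together with the integer labels of the correct classes, and returns the mean cross-entropy over the batch:


logits = torch.tensor([[2.0, 0.5, 0.1],
                       [0.2, 1.5, 0.3]])       # batch of 2 samples, 3 classes
targets = torch.tensor([0, 1])                 # correct class index for each sample
print(nn.CrossEntropyLoss()(logits, targets))  # scalar tensor with the mean loss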

Learning / inference

Now that everything is ready, we can move on to the learning and inference phase.

test.py


loss,epoch_loss,count = 0,0,0
acc_list = []
loss_list = []
for i in range(50):
  
  #Learn from here
  net.train()
  
  for j,data in enumerate(train_loader,0):
    optimizer.zero_grad()

    #1:Read training data
    inputs,labels = data
    inputs = inputs.to(device)
    labels = labels.to(device)

    #2:calculate
    outputs = net(inputs)

    #3:Find the error
    loss = criterion(outputs,labels)

    #4:Learn from error
    loss.backward()
    optimizer.step()

    epoch_loss += loss.item()   # accumulate the scalar loss value
    count += 1
    print('%d: %.3f'%(j+1,loss))

  print('%depoch:mean_loss=%.3f\n'%(i+1,epoch_loss/count))
  loss_list.append(epoch_loss/count)

  epoch_loss = 0
  count = 0
  correct = 0
  total = 0
  accuracy = 0.0

  #Inference from here
  net.eval()
 
  for j,data in enumerate(test_loader,0):

    #Prepare test data
    inputs,labels = data
    inputs = inputs.to(device)
    labels = labels.to(device)

    #calculate
    outputs = net(inputs)

    #Find the predicted value
    _,predicted = torch.max(outputs.data,1)

    #Calculate accuracy
    correct += (predicted == labels).sum().item()
    total += batch_size

  accuracy = 100.*correct / total
  acc_list.append(accuracy)

  print('epoch:%d Accuracy(%d/%d):%f'%(i+1,correct,total,accuracy))
  torch.save(net.state_dict(),'Weight'+str(i+1))

It's a little long.

The learning phase runs from net.train() up to net.eval(), and the inference phase comes after net.eval().
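One optional refinement, not used in the code above but safe to add: gradients are not needed during inference, so the test loop can be wrapped in torch.no_grad() to save memory. A minimal sketch of the idea:


correct, total = 0, 0
net.eval()
with torch.no_grad():                  # disable gradient tracking during inference
    for inputs, labels in test_loader:
        outputs = net(inputs.to(device))
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels.to(device)).sum().item()
        total += labels.size(0)
print(100. * correct / total)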

Learning phase

In the code above, the learning phase follows this flow:

  1. Read the training data
  2. Run the computation
  3. Compute the error between the predicted values and the correct answers
  4. Update the model based on the error from step 3

Each iteration of the inner for loop loads batch_size images (defined in "Data preparation"). The loop ends once all the training images have been processed.

(Example) With 50,000 training images and a batch size of 100, the loop runs 500 times.
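Incidentally, the number of batches per epoch can be read directly from the DataLoader (a one-line check, assuming the loaders defined in "Data preparation"):


print(len(train_loader), len(test_loader))  # 500 100 when batch_size = 100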

Inference phase

In the code above, the inference phase follows this flow:

  1. Read the test data
  2. Run the computation
  3. Determine the predicted class
  4. Compute the accuracy (correct answer rate)

Steps 1 and 2 are the same as in the learning phase.

Step 3 computes the predicted class, as in "This image is class 5!". The values obtained from the computation in step 2 are **scores indicating, for each class, how likely the loaded image is to belong to that class**.

For example, if one image is read and the output is [0.01, 0.04, 0.95], that means:

- the probability of class 1 is 0.01 (1%)
- the probability of class 2 is 0.04 (4%)
- the probability of class 3 is 0.95 (95%)

In this case, the predicted value is class 3.
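In the code, torch.max(outputs.data, 1) performs exactly this selection: it returns the maximum score and its index along the class dimension. A tiny sketch with the made-up numbers above (note that tensor indices start at 0, so index 2 corresponds to "class 3" in the description):


outputs = torch.tensor([[0.01, 0.04, 0.95]])   # one image, scores for three classes
max_score, predicted = torch.max(outputs, 1)   # max along the class dimension
print(max_score, predicted)                    # tensor([0.9500]) tensor([2])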

torch.save(net.state_dict(), 'Weight' + str(i + 1)) saves the learned weights after each epoch.
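To reuse a saved file later, load it back into a model with the same architecture (a sketch; 'Weight10' here is just an example of one of the saved file names):


# Sketch: restore saved weights into a freshly built ResNet-50
model = models.resnet50(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 10)
model.load_state_dict(torch.load('Weight10', map_location=device))
model = model.to(device)
model.eval()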

Graph output

The lists acc_list and loss_list defined in the learning/inference phase store the mean training loss and the test accuracy for each epoch.

To graph this, do the following:

test.py


plt.plot(acc_list)
plt.show()
plt.plot(loss_list)
plt.show()

This is the simplest graph output method.
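If you want slightly more readable graphs, you can add titles and axis labels, or save the figures to files (a sketch, assuming the acc_list and loss_list built above):


plt.figure()
plt.plot(acc_list)
plt.title('Test accuracy per epoch')
plt.xlabel('epoch')
plt.ylabel('accuracy (%)')
plt.savefig('accuracy.png')

plt.figure()
plt.plot(loss_list)
plt.title('Mean training loss per epoch')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.savefig('loss.png')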

By the way, the graphs produced by running this code are shown below.

[Figure: accuracy per epoch (acc_short.png)]
[Figure: training error per epoch (loss_short.png)]

The accuracy dips for a moment around epochs 8 and 9, and ends up just under 88%.

Finally

This time, the article walked through a full deep learning workflow in about 60 lines.

**Of course, more work is needed to improve accuracy**, but if you just want to get a basic pipeline running, I hope this serves as a reference.

Finally, I'll put the whole code here.

test.py


import torch
import numpy as np
import torch.nn as nn
from torch import optim
import torch.nn.init as init
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import Dataset,DataLoader
import torchvision.datasets as dsets
import matplotlib.pyplot as plt

#Loading images
batch_size = 100
train_data = dsets.CIFAR10(root='./tmp/cifar-10', train=True, download=False, transform=transforms.Compose([transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False)]))
train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True)
test_data = dsets.CIFAR10(root='./tmp/cifar-10', train=False, download=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]))
test_loader = DataLoader(test_data,batch_size=batch_size,shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model_ft = models.resnet50(pretrained=True)
model_ft.fc = nn.Linear(model_ft.fc.in_features, 10)
net = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9,weight_decay=0.00005)

loss, epoch_loss, count = 0, 0, 0
acc_list = []
loss_list = []

# Training / inference
for i in range(50):
  
  #Learn from here
  net.train()
  
  for j,data in enumerate(train_loader,0):
    optimizer.zero_grad()

    #1:Read training data
    inputs,labels = data
    inputs = inputs.to(device)
    labels = labels.to(device)

    #2:calculate
    outputs = net(inputs)

    #3:Find the error
    loss = criterion(outputs,labels)

    #4:Learn from error
    loss.backward()
    optimizer.step()

    epoch_loss += loss.item()   # accumulate the scalar loss value
    count += 1
    print('%d: %.3f'%(j+1,loss))

  print('%depoch:mean_loss=%.3f\n'%(i+1,epoch_loss/count))
  loss_list.append(epoch_loss/count)

  epoch_loss, count = 0, 0
  correct,total = 0, 0
  accuracy = 0.0

  #Inference from here
  net.eval()
 
  for j,data in enumerate(test_loader,0):

    #Prepare test data
    inputs,labels = data
    inputs = inputs.to(device)
    labels = labels.to(device)

    #calculate
    outputs = net(inputs)

    #Find the predicted value
    _,predicted = torch.max(outputs.data,1)

    #Calculate accuracy
    correct += (predicted == labels).sum().item()
    total += batch_size

  accuracy = 100.*correct / total
  acc_list.append(accuracy)

  print('epoch:%d Accuracy(%d/%d):%f'%(i+1,correct,total,accuracy))
  torch.save(net.state_dict(),'Weight'+str(i+1))

plt.plot(acc_list)
plt.show()
plt.plot(loss_list)
plt.show()
