[PYTHON] I tried to classify MNIST with a GNN (using PyTorch Geometric)

Introduction

Hi, my name is DNA1980. **GNNs (Graph Neural Networks)** have become popular recently. I wanted to follow the trend and work with graphs too, but most of the graph data out there, such as networks, covers domains I'm not familiar with, so I wouldn't really understand what I was classifying. Since a GNN should also be applicable to data that has no graph structure to begin with, as long as it can be converted into a graph, I applied one to everyone's favorite **MNIST**.

If you are not familiar with GNNs, there is a detailed introduction on Qiita that I recommend reading: GNN Summary (1): Introduction of GCN.
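As a quick reference, the GCNConv layer used later in this post implements the graph convolution of Kipf & Welling's GCN: each layer mixes a node's features with its neighbors' through the normalized adjacency matrix, then applies a learned linear map,

H^{(l+1)} = \sigma( \hat{D}^{-1/2} \hat{A} \hat{D}^{-1/2} H^{(l)} W^{(l)} ), \quad \hat{A} = A + I

where A is the adjacency matrix, \hat{D} is the degree matrix of \hat{A}, H^{(l)} are the node features at layer l, W^{(l)} is the layer's weight matrix, and \sigma is the activation function.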

The code and the dataset created for this post are published on GitHub here.

Environment

・ Python 3.7.6
・ PyTorch 1.4.0
・ PyTorch Geometric 1.4.2

This time, I used PyTorch Geometric as the library for handling GNNs.

Creating a dataset

To apply a GNN to MNIST, which consists of 2D images, the images need to be converted into graphs.

・ Every pixel with brightness 0.4 or higher becomes a node.
・ An edge is added between two nodes whenever they fall within each other's 8-neighborhood in the original image.
・ Each node's feature is the two-dimensional vector of its x and y coordinates.

The conversion was performed based on the above rules.

(Since creating the data was tedious, only the 60,000 training images are used this time.)

The conversion looks like this: (makegraph.png) Here is the code used to create the dataset. (At first I was planning to connect edges within the 24-neighborhood, which is why the padding is wider than necessary, but don't worry about it.) Since the data is not that large, I implemented it naively, but it would probably be faster with bitboards or the like.


import gzip
import numpy as np

# Load MNIST images from the gzip file and reshape them into 28x28 arrays
with gzip.open('./train-images-idx3-ubyte.gz', 'rb') as f:
    data = np.frombuffer(f.read(), np.uint8, offset=16)
data = data.reshape([-1, 28, 28])
# Threshold at 102 (0.4 * 255): dark pixels become -1, bright pixels 1000
data = np.where(data < 102, -1, 1000)

for e, imgtmp in enumerate(data):
    # Pad by 2 on each side (left over from the earlier 24-neighborhood plan)
    img = np.pad(imgtmp, [(2, 2), (2, 2)], "constant", constant_values=(-1))
    cnt = 0

    # Replace each bright pixel's marker (1000) with its node index
    for i in range(2, 30):
        for j in range(2, 30):
            if img[i][j] == 1000:
                img[i][j] = cnt
                cnt += 1

    edges = []
    # Node features: y coordinate, x coordinate
    npzahyou = np.zeros((cnt, 2))

    for i in range(2, 30):
        for j in range(2, 30):
            if img[i][j] == -1:
                continue

            # Flatten the 5x5 window; indices 6,7,8,11,13,16,17,18 are the
            # 8-neighborhood of the center pixel (index 12)
            filter = img[i-2:i+3, j-2:j+3].flatten()
            filter1 = filter[[6, 7, 8, 11, 13, 16, 17, 18]]

            # filter[12] is the node index of the center pixel
            npzahyou[filter[12]][0] = i - 2
            npzahyou[filter[12]][1] = j - 2

            # Add an edge to every bright neighbor
            for tmp in filter1:
                if not tmp == -1:
                    edges.append([filter[12], tmp])

    np.save("../dataset/graphs/" + str(e), edges)
    np.save("../dataset/node_features/" + str(e), npzahyou)

Classification

This time, I trained with the following settings:

・ 6 GCN layers followed by 2 fully connected layers
・ Adam as the optimizer (all parameters at their defaults)
・ Mini-batch size of 100
・ 150 epochs
・ ReLU as the activation function
・ Of the 60,000 samples, 50,000 are used for training and the remaining 10,000 for testing

Model

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_scatter import scatter_max

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(2, 16)
        self.conv2 = GCNConv(16, 32)
        self.conv3 = GCNConv(32, 48)
        self.conv4 = GCNConv(48, 64)
        self.conv5 = GCNConv(64, 96)
        self.conv6 = GCNConv(96, 128)
        self.linear1 = torch.nn.Linear(128, 64)
        self.linear2 = torch.nn.Linear(64, 10)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.conv3(x, edge_index)
        x = F.relu(x)
        x = self.conv4(x, edge_index)
        x = F.relu(x)
        x = self.conv5(x, edge_index)
        x = F.relu(x)
        x = self.conv6(x, edge_index)
        x = F.relu(x)
        # Graph-level readout: feature-wise max over the nodes of each graph
        x, _ = scatter_max(x, data.batch, dim=0)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x
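In the forward pass, scatter_max is the readout step: data.batch holds the graph id of every node in the mini-batch, and scatter_max takes the feature-wise maximum over all nodes sharing an id, producing one 128-dimensional vector per graph. A toy illustration with made-up values:

import torch
from torch_scatter import scatter_max

# Two graphs in one batch, three 2-dimensional nodes each
x = torch.tensor([[1., 5.], [2., 4.], [3., 3.],   # nodes of graph 0
                  [6., 0.], [4., 2.], [5., 1.]])  # nodes of graph 1
batch = torch.tensor([0, 0, 0, 1, 1, 1])          # graph id of each node
pooled, _ = scatter_max(x, batch, dim=0)
print(pooled)  # tensor([[3., 5.], [6., 2.]]) -- one vector per graph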

Training part

import torch
from torch import nn
from torch_geometric.data import DataLoader  # PyG DataLoader that batches graphs

data_size = 60000
train_size = 50000
batch_size = 100
epoch_num = 150

def main():
    # load_mnist_graph is the dataset-loading helper (a sketch is shown above)
    mnist_list = load_mnist_graph(data_size=data_size)
    device = torch.device('cuda')
    model = Net().to(device)
    trainset = mnist_list[:train_size]
    optimizer = torch.optim.Adam(model.parameters())
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testset = mnist_list[train_size:]
    testloader = DataLoader(testset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    history = {
        "train_loss": [],
        "test_loss": [],
        "test_acc": []
    }

    print("Start Train")
    
    model.train()
    for epoch in range(epoch_num):
        train_loss = 0.0
        for i, batch in enumerate(trainloader):
            batch = batch.to(device)
            optimizer.zero_grad()
            outputs = model(batch)
            loss = criterion(outputs,batch.t)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.cpu().item()
            if i % 10 == 9:
                progress_bar = '['+('='*((i+1)//10))+(' '*((train_size//100-(i+1))//10))+']'
                print('\repoch: {:d} loss: {:.3f}  {}'
                        .format(epoch + 1, loss.cpu().item(), progress_bar), end="  ")

        print('\repoch: {:d} loss: {:.3f}'
            .format(epoch + 1, train_loss / (train_size / batch_size)), end="  ")
        history["train_loss"].append(train_loss / (train_size / batch_size))

        correct = 0
        total = 0
        batch_num = 0
        loss = 0
        with torch.no_grad():
            for data in testloader:
                data = data.to(device)
                outputs = model(data)
                loss += criterion(outputs,data.t)
                _, predicted = torch.max(outputs, 1)
                total += data.t.size(0)
                batch_num += 1
                correct += (predicted == data.t).sum().cpu().item()

        history["test_acc"].append(correct/total)
        history["test_loss"].append(loss.cpu().item()/batch_num)
        endstr = ' '*max(1,(train_size//1000-39))+"\n"
        print('Test Accuracy: {:.2f}%'.format(100 * float(correct/total)), end='  ')
        print(f'Test Loss: {loss.cpu().item()/batch_num:.3f}',end=endstr)

    print('Finished Training')

    #Final result output
    correct = 0
    total = 0

    with torch.no_grad():
        for data in testloader:
            data = data.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs, 1)
            total += data.t.size(0)
            correct += (predicted == data.t).sum().cpu().item()
    print('Accuracy: {:.2f}%'.format(100 * float(correct/total)))

if __name__ == '__main__':
    main()

Results

The final accuracy was **97.74%**. The loss and test-accuracy curves are shown below (loss.png, acc.png). There seems to be slight overfitting toward the end, but you can see that training progresses smoothly.

I expected information to be lost in the conversion, so I was surprised that this classifies better than an MLP (reference). It's interesting that you can classify this well using only the pixel coordinates, rather than their brightness, as features.

Well then, have a good GNN life, everyone!
