[PYTHON] J'ai essayé de classer MNIST par GNN (avec PyTorch géométrique)

introduction

Salut, je m'appelle DNA1980. Récemment, ** GNN (Graph Neural Network) ** est devenu populaire. Je veux également suivre le flux et gérer les graphiques, mais il existe de nombreuses données graphiques qui existent dans le monde avec lesquelles je ne suis pas familier. Je ne sais pas ce que je fais quand je le classe ... Je pensais que GNN pouvait être appliqué s'il pouvait être incorporé dans un graphe même s'il n'avait pas de structure de graphe depuis le début, comme un réseau, donc je l'ai appliqué au ** MNIST ** préféré de tout le monde.

Si vous n'êtes pas familier avec GNN, il y en a qui ont écrit en détail sur Qiita, donc je vous recommande de lire ceci. Résumé GNN (1): Introduction de GCN

Le code utilisé cette fois et l'ensemble de données créé sont publiés sur Github ici.

environnement

Python 3.7.6 PyTorch 1.4.0 PyTorch geometric 1.4.2

Cette fois, j'ai utilisé PyTorch Geometric comme bibliothèque qui gère GNN.

Créer un jeu de données

Pour appliquer GNN à MNIST, qui est une image bidimensionnelle, il doit être représenté graphiquement.

・ Tous les pixels brillants de 0,4 ou plus sont utilisés comme nœuds. ・ S'il y a des nœuds près de 8 sur l'image d'origine, ajoutez un côté -Utiliser des quantités bidimensionnelles de coordonnées x et y comme quantités de caractéristiques sur chaque nœud.

La conversion a été effectuée selon les règles ci-dessus.

(Comme il était difficile de créer, seules 60000 données pour le train sont utilisées cette fois.)

L'image ressemble à ceci makegraph.png Voici le code utilisé pour créer le jeu de données cette fois. (Au début, je prévoyais de mettre un côté autour de 24, donc il est rembourré en plus, mais ne vous inquiétez pas) Comme il n'y en a pas beaucoup, je l'ai implémenté honnêtement, mais il semble que ce sera plus rapide si vous utilisez Bitboard, etc.


#Appelez les données MNIST à partir d'un fichier gzip pour le rendre bidimensionnel
data = 0
with gzip.open('./train-images-idx3-ubyte.gz', 'rb') as f:
    data = np.frombuffer(f.read(), np.uint8, offset=16)
    data = data.reshape([-1,28,28])
data = np.where(data < 102, -1, 1000)

for e,imgtmp in enumerate(data):
    img = np.pad(imgtmp,[(2,2),(2,2)],"constant",constant_values=(-1))
    cnt = 0

    for i in range(2,30):
        for j in range(2,30):
            if img[i][j] == 1000:
                img[i][j] = cnt
                cnt+=1
    
    edges = []
    #coordonnée y, coordonnée x
    npzahyou = np.zeros((cnt,2))

    for i in range(2,30):
        for j in range(2,30):
            if img[i][j] == -1:
                continue

            #8 Extraire la partie correspondant au voisinage.
            filter = img[i-2:i+3,j-2:j+3].flatten()
            filter1 = filter[[6,7,8,11,13,16,17,18]]

            npzahyou[filter[12]][0] = i-2
            npzahyou[filter[12]][1] = j-2

            for tmp in filter1:
                if not tmp == -1:
                    edges.append([filter[12],tmp])

    np.save("../dataset/graphs/"+str(e),edges)
    np.save("../dataset/node_features/"+str(e),npzahyou)

Classer

Cette fois ・ 6 couches de GCN et 2 couches de couches entièrement connectées ・ L'optimiseur est Adam (tous les paramètres sont par défaut) ・ La taille du mini lot est de 100 ・ Le nombre d'époque est de 150 ・ Utilisez ReLU pour la fonction d'activation ・ Sur les 60 000 données, 50 000 sont utilisées pour le train et le reste est utilisé pour le test.

J'ai appris comme.

modèle

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(2, 16)
        self.conv2 = GCNConv(16, 32)
        self.conv3 = GCNConv(32, 48)
        self.conv4 = GCNConv(48, 64)
        self.conv5 = GCNConv(64, 96)
        self.conv6 = GCNConv(96, 128)
        self.linear1 = torch.nn.Linear(128,64)
        self.linear2 = torch.nn.Linear(64,10)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.conv3(x, edge_index)
        x = F.relu(x)
        x = self.conv4(x, edge_index)
        x = F.relu(x)
        x = self.conv5(x, edge_index)
        x = F.relu(x)
        x = self.conv6(x, edge_index)
        x = F.relu(x)
        x, _ = scatter_max(x, data.batch, dim=0)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x

Partie d'apprentissage

data_size = 60000
train_size = 50000
batch_size = 100
epoch_num = 150

def main():
    mnist_list = load_mnist_graph(data_size=data_size)
    device = torch.device('cuda')
    model = Net().to(device)
    trainset = mnist_list[:train_size]
    optimizer = torch.optim.Adam(model.parameters())
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
    testset = mnist_list[train_size:]
    testloader = DataLoader(testset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    history = {
        "train_loss": [],
        "test_loss": [],
        "test_acc": []
    }

    print("Start Train")
    
    model.train()
    for epoch in range(epoch_num):
        train_loss = 0.0
        for i, batch in enumerate(trainloader):
            batch = batch.to("cuda")
            optimizer.zero_grad()
            outputs = model(batch)
            loss = criterion(outputs,batch.t)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.cpu().item()
            if i % 10 == 9:
                progress_bar = '['+('='*((i+1)//10))+(' '*((train_size//100-(i+1))//10))+']'
                print('\repoch: {:d} loss: {:.3f}  {}'
                        .format(epoch + 1, loss.cpu().item(), progress_bar), end="  ")

        print('\repoch: {:d} loss: {:.3f}'
            .format(epoch + 1, train_loss / (train_size / batch_size)), end="  ")
        history["train_loss"].append(train_loss / (train_size / batch_size))

        correct = 0
        total = 0
        batch_num = 0
        loss = 0
        with torch.no_grad():
            for data in testloader:
                data = data.to(device)
                outputs = model(data)
                loss += criterion(outputs,data.t)
                _, predicted = torch.max(outputs, 1)
                total += data.t.size(0)
                batch_num += 1
                correct += (predicted == data.t).sum().cpu().item()

        history["test_acc"].append(correct/total)
        history["test_loss"].append(loss.cpu().item()/batch_num)
        endstr = ' '*max(1,(train_size//1000-39))+"\n"
        print('Test Accuracy: {:.2f} %%'.format(100 * float(correct/total)), end='  ')
        print(f'Test Loss: {loss.cpu().item()/batch_num:.3f}',end=endstr)

    print('Finished Training')

    #Sortie du résultat final
    correct = 0
    total = 0

    with torch.no_grad():
        for data in testloader:
            data = data.to(device)
            outputs = model(data)
            _, predicted = torch.max(outputs, 1)
            total += data.t.size(0)
            correct += (predicted == data.t).sum().cpu().item()
    print('Accuracy: {:.2f} %%'.format(100 * float(correct/total)))

résultat

Le taux de réponse correcte (précision) était de ** 97,74% **. Les changements de perte et de précision des tests sont les suivants. À la fin, cela semble un peu surapprentissage, mais vous pouvez voir que l'apprentissage progresse proprement. loss.png acc.png

J'ai senti que les informations étaient perdues lors de la transformation des données, mais j'ai été surpris qu'elles soient mieux classées que MLP (Reference). Il est intéressant de noter que vous pouvez classer cela simplement en utilisant les coordonnées au lieu de la luminosité des pixels comme quantité de fonctionnalités.

Alors tout le monde a une bonne vie GNN!

Recommended Posts

J'ai essayé de classer MNIST par GNN (avec PyTorch géométrique)
J'ai essayé d'implémenter CVAE avec PyTorch
J'ai essayé d'implémenter la classification des phrases par Self Attention avec PyTorch
J'ai essayé d'implémenter la lecture de Dataset avec PyTorch
J'ai essayé de déplacer GAN (mnist) avec keras
J'ai essayé de classer les boules de dragon par adaline
J'ai essayé de classer les nombres de mnist par apprentissage non supervisé [PCA, t-SNE, k-means]
J'ai essayé de déplacer Faster R-CNN rapidement avec pytorch
J'ai essayé d'implémenter et d'apprendre DCGAN avec PyTorch
[Introduction à Pytorch] J'ai essayé de catégoriser Cifar10 avec VGG16 ♬
J'ai essayé d'implémenter SSD avec PyTorch maintenant (Dataset)
J'ai essayé d'expliquer l'ensemble de données de Pytorch
J'ai essayé d'implémenter SSD avec PyTorch maintenant (édition du modèle)
J'ai essayé d'implémenter Autoencoder avec TensorFlow
J'ai essayé de visualiser AutoEncoder avec TensorFlow
J'ai essayé de commencer avec Hy
J'ai essayé de classer le texte en utilisant TensorFlow
[Introduction à Pytorch] J'ai joué avec sinGAN ♬
J'ai essayé d'implémenter DeepPose avec PyTorch PartⅡ
J'ai essayé de résoudre TSP avec QAOA
765 J'ai essayé d'identifier les trois familles professionnelles par CNN (avec Chainer 2.0.0)
J'ai essayé de classer Oba Hanana et Otani Emiri par apprentissage profond
J'ai essayé de programmer la bulle de tri par langue
J'ai essayé d'utiliser lightGBM, xg boost avec Boruta
J'ai essayé d'apprendre le fonctionnement logique avec TF Learn
J'ai réécrit le code MNIST de Chainer avec PyTorch + Ignite
J'ai essayé de sauvegarder les données avec discorde
J'ai essayé de détecter rapidement un mouvement avec OpenCV
J'ai essayé d'intégrer Keras dans TFv1.1
J'ai essayé d'obtenir des données CloudWatch avec Python
J'ai essayé de sortir LLVM IR avec Python
J'ai essayé de détecter un objet avec M2Det!
J'ai essayé d'automatiser la fabrication des sushis avec python
J'ai essayé de prédire la survie du Titanic avec PyCaret
J'ai essayé d'utiliser Linux avec Discord Bot
J'ai essayé d'étudier DP avec séquence de Fibonacci
J'ai essayé de démarrer Jupyter avec toutes les lumières d'Amazon
J'ai essayé de juger Tundele avec Naive Bays
J'ai essayé de classer Hanana Oba et Emiri Otani par apprentissage profond (partie 2)
J'ai essayé d'implémenter la classification des phrases et la visualisation de l'attention par le japonais BERT avec PyTorch
J'ai essayé d'entraîner la fonction péché avec chainer
J'ai essayé de déplacer l'apprentissage automatique (détection d'objet) avec TouchDesigner
J'ai essayé d'extraire des fonctionnalités avec SIFT d'OpenCV
J'ai essayé de lire et d'enregistrer automatiquement avec VOICEROID2 2
Classer les numéros mnist par keras sans apprentissage par l'enseignant [Auto Encoder Edition]
J'ai essayé de démarrer avec le script python de blender_Part 01
J'ai essayé de toucher un fichier CSV avec Python
J'ai essayé de lire et d'enregistrer automatiquement avec VOICEROID2
J'ai essayé de démarrer avec le script python de blender_Partie 02
J'ai essayé de générer ObjectId (clé primaire) avec pymongo
J'ai essayé d'implémenter le perceptron artificiel avec python
J'ai essayé de créer un pipeline ML avec Cloud Composer
J'ai essayé de découvrir notre obscurité avec l'API Chatwork