[PYTHON] J'ai essayé de créer une méthode de super résolution / SRCNN ③

Aperçu

Suite de la session précédente. Ceci est la partie 3, le dernier article. La dernière fois: J'ai essayé de créer une méthode de super résolution / SRCNN ① La dernière fois: J'ai essayé de créer une méthode de super résolution / SRCNN ②

table des matières

1.Tout d'abord 2. Environnement PC 3. Description du code 4. À la fin

1.Tout d'abord

La super-résolution est une technologie qui améliore la résolution des images basse résolution et des images animées, et le SRCNN utilise l'apprentissage en profondeur pour mesurer les résultats avec une plus grande précision que les méthodes conventionnelles. C'est la méthode qui a été faite. (Troisième fois)

Le code complet est également publié sur GitHub, veuillez donc vérifier ici. https://github.com/morisumori/srcnn_keras

2. Environnement PC

cpu : intel corei7 8th Gen gpu : NVIDIA GeForce RTX 1080ti os : ubuntu 20.04

3. Description du code

Comme vous pouvez le voir sur GitHub, il se compose principalement de trois codes. ・ Datacreate.py → Programme de génération de jeux de données ・ Model.py → Programme SRCNN ・ Main.py → Programme d'exécution J'ai créé une fonction avec datacreate.py et model.py et l'ai exécutée avec main.py.

__ Cette fois, je vais expliquer main.py. __

Description de model.py

model.py



import model
import data_create
import argparse
import os
import cv2

import numpy as np
import tensorflow as tf

if __name__ == "__main__":
    
    def psnr(y_true, y_pred):
        return tf.image.psnr(y_true, y_pred, 1, name=None)

    train_height = 33
    train_width = 33
    test_height = 700
    test_width = 700

    mag = 3.0
    cut_traindata_num = 10
    cut_testdata_num = 1

    train_file_path = "./train_data"
    test_file_path = "./test_data"

    BATSH_SIZE = 240
    EPOCHS = 1000
    opt = tf.keras.optimizers.Adam(learning_rate=0.0001)

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='srcnn', help='srcnn, evaluate')

    args = parser.parse_args()

    if args.mode == "srcnn":
        train_x, train_y = data_create.save_frame(train_file_path,   #Chemin du fichier contenant l'image à recadrer
                                                cut_traindata_num,  #Nombre d'ensembles de données générés
                                                train_height, #Taille de stockage
                                                train_width,
                                                mag)   #grossissement
                                                
        model = model.SRCNN() 
        model.compile(loss = "mean_squared_error",
                        optimizer = opt,
                        metrics = [psnr])
#https://keras.io/ja/getting-started/faq/
        model.fit(train_x,
                    train_y,
                    epochs = EPOCHS)

        model.save("srcnn_model.h5")

    elif args.mode == "evaluate":
        path = "srcnn_model"
        exp = ".h5"
        new_model = tf.keras.models.load_model(path + exp, custom_objects={'psnr':psnr})

        new_model.summary()

        test_x, test_y = data_create.save_frame(test_file_path,   #Chemin du fichier contenant l'image à recadrer
                                                cut_testdata_num,  #Nombre d'ensembles de données générés
                                                test_height, #Taille de stockage
                                                test_width,
                                                mag)   #grossissement

        pred = new_model.predict(test_x)
        path = "resurt_" + path
        os.makedirs(path, exist_ok = True)
        path = path + "/"

        ps = psnr(tf.reshape(test_y[0], [test_height, test_width, 1]), pred[0])
        print("psnr:{}".format(ps))

        before_res = tf.keras.preprocessing.image.array_to_img(tf.reshape(test_x[0], [test_height, test_width, 1]))
        change_res = tf.keras.preprocessing.image.array_to_img(tf.reshape(test_y[0], [test_height, test_width, 1]))
        y_pred = tf.keras.preprocessing.image.array_to_img(pred[0])

        before_res.save(path + "low_" + str(0) + ".jpg ")
        change_res.save(path + "high_" + str(0) + ".jpg ")
        y_pred.save(path + "pred_" + str(0) + ".jpg ")

    else:
        raise Exception("Unknow --mode")

Le principal est assez long, mais j'ai l'impression que si je peux le raccourcir, je peux faire plus. J'expliquerai le contenu ci-dessous.

import model
import data_create
import argparse
import os
import cv2

import numpy as np
import tensorflow as tf

Ici, nous chargeons une fonction ou un autre fichier dans le même répertoire. datacreate.py, model.py et main.py doivent être dans le même répertoire.

    def psnr(y_true, y_pred):
        return tf.image.psnr(y_true, y_pred, 1, name=None)

Cette fois, j'ai utilisé psnr comme critère pour juger de la qualité de l'image générée, c'est donc la définition. psnr est appelé le rapport signal / bruit de crête et, en termes simples, cela revient à calculer la différence entre les valeurs de pixel des images que vous souhaitez comparer. Je n'expliquerai pas en détail ici, mais cet article est plutôt détaillé et plusieurs méthodes d'évaluation sont décrites.

    train_height = 33 #taille des données de train
    train_width = 33
    test_height = 700 #taille des données de test
    test_width = 700

    mag = 3.0 #Je ne l'utilise pas, mais je l'ai inclus dans la fonction.
    cut_traindata_num = 10 #Combien de feuilles de données sont générées à partir d'une photo dans la génération de données de train
    cut_testdata_num = 1 #Combien de feuilles de données sont générées à partir d'une photo lors de la génération des données de test.

    train_file_path = "./train_data" #Chemin du fichier contenant l'image à recadrer
    test_file_path = "./test_data"

    BATSH_SIZE = 240 #batchsize
    EPOCHS = 1000 #numéro d'époque
    opt = tf.keras.optimizers.Adam(learning_rate=0.0001) #optimizer

Ici, la valeur utilisée cette fois est définie. Si vous regardez github, c'est bien si vous avez un config.py séparé, mais comme ce n'est pas un programme à grande échelle, il est résumé.

Quant à la taille des données de formation, les données de train ont été adoptées parce que le document indiquait qu'elles étaient de 33 * 33. Le test est juste surdimensionné pour une visualisation facile. Le nombre de données est égal à 10 fois le nombre d'images contenues dans le fichier. (Si 800 feuilles, le nombre de données est de 8 000)

Cette fois, j'ai utilisé DIV2K Dataset, qui est souvent utilisé pour la super-résolution. Puisque la qualité des données est bonne, on dit qu'une certaine précision peut être obtenue avec une petite quantité de données.

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='srcnn', help='srcnn, evaluate')

    args = parser.parse_args()

Je voulais séparer l'apprentissage et l'évaluation du modèle ici, donc je l'ai fait comme ça pour que je puisse le sélectionner avec --mode. Je n'expliquerai pas en détail, donc je publierai la documentation officielle de python. https://docs.python.org/ja/3/library/argparse.html

if args.mode == "srcnn":
        train_x, train_y = data_create.save_frame(train_file_path,   #Chemin du fichier contenant l'image à recadrer
                                                cut_traindata_num,  #Nombre d'ensembles de données générés
                                                train_height, #Taille de stockage
                                                train_width,
                                                mag)   #grossissement
                                                
        model = model.SRCNN() 
        model.compile(loss = "mean_squared_error",
                        optimizer = opt,
                        metrics = [psnr])
#https://keras.io/ja/getting-started/faq/
        model.fit(train_x,
                    train_y,
                    epochs = EPOCHS)

        model.save("srcnn_model.h5")

J'apprends ici. Si vous sélectionnez srcnn (la méthode sera décrite plus tard), ce programme fonctionnera.

Dans data_create.save_frame, la fonction appelée save_frame de data_create.py est lue et rendue disponible. Maintenant que les données sont dans train_x et train_y, chargez le modèle de la même manière et compilez et ajustez.

Voir documentation keras pour plus d'informations sur la compilation et plus encore. Nous utilisons les mêmes que les papiers.

Enfin, enregistrez le modèle et vous avez terminé.

elif args.mode == "evaluate":
        path = "srcnn_model"
        exp = ".h5"
        new_model = tf.keras.models.load_model(path + exp, custom_objects={'psnr':psnr})

        new_model.summary()

        test_x, test_y = data_create.save_frame(test_file_path,   #Chemin du fichier contenant l'image à recadrer
                                                cut_testdata_num,  #Nombre d'ensembles de données générés
                                                test_height, #Taille de stockage
                                                test_width,
                                                mag)   #grossissement

        pred = new_model.predict(test_x)
        path = "resurt_" + path
        os.makedirs(path, exist_ok = True)
        path = path + "/"

        ps = psnr(tf.reshape(test_y[0], [test_height, test_width, 1]), pred[0])
        print("psnr:{}".format(ps))

        before_res = tf.keras.preprocessing.image.array_to_img(tf.reshape(test_x[0], [test_height, test_width, 1]))
        change_res = tf.keras.preprocessing.image.array_to_img(tf.reshape(test_y[0], [test_height, test_width, 1]))
        y_pred = tf.keras.preprocessing.image.array_to_img(pred[0])

        before_res.save(path + "low_" + str(0) + ".jpg ")
        change_res.save(path + "high_" + str(0) + ".jpg ")
        y_pred.save(path + "pred_" + str(0) + ".jpg ")

    else:
        raise Exception("Unknow --mode")

C'est finalement l'explication du dernier. Commencez par charger le modèle que vous avez enregistré précédemment afin de pouvoir utiliser psnr. Ensuite, générez un ensemble de données pour le test et générez une image avec prédire.

Je voulais connaître la valeur psnr sur place, alors je l'ai calculée. Je voulais enregistrer l'image, alors je l'ai convertie du tenseur en un tableau numpy, l'ai sauvegardée, et finalement c'est fait! high_0.jpg C'est une image de haute qualité. (L'image d'origine) low_0.jpg Il s'agit d'une image de faible qualité avec un filtre gaussien. pred_0.jpg Il s'agit de l'image générée par le modèle. C'est beau. La valeur psnr était-elle d'environ 34?

4. À la fin

Cela fait longtemps que je ne l'ai pas divisé en trois articles, mais merci d'avoir lu. Je continuerai à travailler sur diverses choses à l'avenir. Si vous avez des questions ou des commentaires, n'hésitez pas à nous contacter!

Recommended Posts

J'ai essayé de créer une méthode de super résolution / SRCNN ①
J'ai essayé de créer une méthode de super résolution / SRCNN ③
J'ai essayé de créer une méthode de super résolution / SRCNN ②
J'ai essayé de créer une méthode de super résolution / ESPCN
[Go + Gin] J'ai essayé de créer un environnement Docker
J'ai essayé "Comment obtenir une méthode décorée en Python"
J'ai essayé de créer un linebot (implémentation)
J'ai essayé de créer un linebot (préparation)
Je veux créer un environnement Python
J'ai créé une API Web
J'ai essayé de déboguer.
J'ai essayé de créer un environnement de développement Mac Python avec pythonz + direnv
J'ai créé un jeu ○ ✕ avec TensorFlow
J'ai essayé de faire un "putain de gros convertisseur de littérature"
J'ai essayé d'implémenter un pseudo pachislot en Python
Je souhaite créer facilement un environnement de développement basé sur un modèle
J'ai essayé de créer un pipeline ML avec Cloud Composer
Le débutant de la CTF a tenté de créer un serveur problématique (Web) [Problème]
J'ai essayé de simuler la méthode de calcul de la moyenne des coûts en dollars
J'ai ajouté une fonction à CPython (construction et compréhension de la structure)
J'ai essayé de dessiner un diagramme de configuration à l'aide de diagrammes
J'ai essayé d'apprendre PredNet
J'ai essayé d'organiser SVM.
J'ai essayé d'implémenter PCANet
J'ai essayé de réintroduire Linux
J'ai essayé de présenter Pylint
J'ai essayé de résumer SparseMatrix
jupyter je l'ai touché
J'ai essayé d'implémenter StarGAN (1)
J'ai essayé de créer un environnement avec WSL + Ubuntu + VS Code dans un environnement Windows
J'ai essayé de créer une classe pour rechercher des fichiers avec la méthode Glob de Python dans VBA
J'ai essayé de mettre en œuvre le modèle de base du réseau neuronal récurrent
J'ai essayé l'algorithme de super résolution "PULSE" dans un environnement Windows
J'ai essayé d'implémenter un automate cellulaire unidimensionnel en Python
J'ai essayé de créer automatiquement un rapport avec la chaîne de Markov
[Chaîne de Markov] J'ai essayé de lire les citations en Python.
J'ai essayé de commencer avec Hy ・ Définir une classe
J'ai essayé d'automatiser [une certaine tâche] à l'aide d'une tarte à la râpe
J'ai trébuché lorsque j'ai essayé d'installer Basemap, donc un mémorandum
J'ai essayé de trier une colonne FizzBuzz aléatoire avec un tri à bulles.
J'ai essayé de créer un bot pour annoncer un événement Wiire
J'ai fait un chronomètre en utilisant tkinter avec python
J'ai essayé d'écrire dans un modèle de langage profondément appris
J'ai créé un éditeur de texte simple en utilisant PyQt
J'ai essayé de créer un service qui vend des données apprises par machine à une vitesse explosive avec Docker
J'ai essayé d'implémenter Deep VQE
J'ai essayé de créer l'API Quip
J'ai essayé de toucher Python (installation)
J'ai essayé de mettre en place une validation contradictoire
J'ai essayé d'expliquer l'ensemble de données de Pytorch
J'ai essayé l'authentification vocale Watson (Speech to Text)
J'ai touché l'API de Tesla
J'ai essayé de m'organiser à propos de MCMC.
J'ai essayé d'implémenter Realness GAN
J'ai essayé de déplacer le ballon