[PYTHON] I tried to build a super-resolution method / SRCNN ③

Overview

This continues from the previous article. It is part 3, the final article in the series.
Previous articles: I tried to build a super-resolution method / SRCNN ①, I tried to build a super-resolution method / SRCNN ②

Table of contents

1. First of all
2. PC environment
3. Code description
4. At the end

1. First of all

Super-resolution is a technique for increasing the resolution of low-resolution images and video. SRCNN is a super-resolution method that uses deep learning to produce more accurate results than conventional methods. (This is the third article.)

The full code is also posted on GitHub, so please check there. https://github.com/morisumori/srcnn_keras

2. PC environment

cpu : Intel Core i7 8th Gen
gpu : NVIDIA GeForce RTX 1080ti
os : Ubuntu 20.04

3. Code description

As you can see on GitHub, the project mainly consists of three files.
・ data_create.py → Dataset generation program
・ model.py → SRCNN program
・ main.py → Execution program
The functions are defined in data_create.py and model.py and executed from main.py.

This time, I will explain main.py.

Description of main.py

main.py



import model
import data_create
import argparse
import os
import cv2

import numpy as np
import tensorflow as tf

if __name__ == "__main__":
    
    def psnr(y_true, y_pred):
        return tf.image.psnr(y_true, y_pred, 1, name=None)

    train_height = 33
    train_width = 33
    test_height = 700
    test_width = 700

    mag = 3.0
    cut_traindata_num = 10
    cut_testdata_num = 1

    train_file_path = "./train_data"
    test_file_path = "./test_data"

    BATSH_SIZE = 240
    EPOCHS = 1000
    opt = tf.keras.optimizers.Adam(learning_rate=0.0001)

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='srcnn', help='srcnn, evaluate')

    args = parser.parse_args()

    if args.mode == "srcnn":
        train_x, train_y = data_create.save_frame(train_file_path,   #Path of the file containing the image to be cropped
                                                cut_traindata_num,  #Number of datasets generated
                                                train_height, #Storage size
                                                train_width,
                                                mag)   #magnification
                                                
        model = model.SRCNN() 
        model.compile(loss = "mean_squared_error",
                        optimizer = opt,
                        metrics = [psnr])
#https://keras.io/ja/getting-started/faq/
        model.fit(train_x,
                    train_y,
                    epochs = EPOCHS)

        model.save("srcnn_model.h5")

    elif args.mode == "evaluate":
        path = "srcnn_model"
        exp = ".h5"
        new_model = tf.keras.models.load_model(path + exp, custom_objects={'psnr':psnr})

        new_model.summary()

        test_x, test_y = data_create.save_frame(test_file_path,   #Path of the file containing the image to be cropped
                                                cut_testdata_num,  #Number of datasets generated
                                                test_height, #Storage size
                                                test_width,
                                                mag)   #magnification

        pred = new_model.predict(test_x)
        path = "resurt_" + path
        os.makedirs(path, exist_ok = True)
        path = path + "/"

        ps = psnr(tf.reshape(test_y[0], [test_height, test_width, 1]), pred[0])
        print("psnr:{}".format(ps))

        before_res = tf.keras.preprocessing.image.array_to_img(tf.reshape(test_x[0], [test_height, test_width, 1]))
        change_res = tf.keras.preprocessing.image.array_to_img(tf.reshape(test_y[0], [test_height, test_width, 1]))
        y_pred = tf.keras.preprocessing.image.array_to_img(pred[0])

        before_res.save(path + "low_" + str(0) + ".jpg")
        change_res.save(path + "high_" + str(0) + ".jpg")
        y_pred.save(path + "pred_" + str(0) + ".jpg")

    else:
        raise Exception("Unknown --mode")

main.py is fairly long, and my impression is that it could be shortened further. Below, I will explain its contents.

import model
import data_create
import argparse
import os
import cv2

import numpy as np
import tensorflow as tf

Here we import the libraries and the other files in the same directory. data_create.py, model.py and main.py must all be in the same directory.

    def psnr(y_true, y_pred):
        return tf.image.psnr(y_true, y_pred, 1, name=None)

I used PSNR as the criterion for judging the quality of the generated images, so it is defined here. PSNR stands for peak signal-to-noise ratio; put simply, it measures how much the pixel values of two images differ, and a higher value means the images are closer. I will omit a detailed explanation here, but this article covers it in relative detail and also describes several other evaluation methods.
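To make the definition concrete, here is a minimal sketch of the usual PSNR formula, 10 * log10(max_val^2 / MSE), written in plain Python. It is only an illustration of what tf.image.psnr computes with max_val = 1; the function name psnr_db and the sample pixel values are made up for this example and are not part of main.py.

```python
import math

def psnr_db(y_true, y_pred, max_val=1.0):
    """PSNR in dB between two equally sized flat sequences of pixel values."""
    if len(y_true) != len(y_pred):
        raise ValueError("inputs must have the same length")
    # Mean squared error over all pixels
    mse = sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)
    if mse == 0:
        return math.inf  # identical images have no noise at all
    return 10.0 * math.log10(max_val ** 2 / mse)

# Hypothetical pixel values in the range 0..1
original = [0.0, 0.5, 1.0, 0.25]
degraded = [0.1, 0.4, 0.9, 0.35]  # every pixel is off by 0.1

print(round(psnr_db(original, degraded), 2))  # roughly 20 dB
```

An error of 0.1 on every pixel gives an MSE of about 0.01, so the PSNR comes out at about 20 dB. Halving the error raises the value by about 6 dB.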

    train_height = 33 #Training data size
    train_width = 33
    test_height = 700 #Test data size
    test_width = 700

    mag = 3.0 #Not actually used, but passed to the function anyway
    cut_traindata_num = 10 #Number of crops generated from each image for the training data
    cut_testdata_num = 1 #Number of crops generated from each image for the test data

    train_file_path = "./train_data" #Path of the folder containing the images to be cropped
    test_file_path = "./test_data"

    BATSH_SIZE = 240 #Batch size
    EPOCHS = 1000 #Number of epochs
    opt = tf.keras.optimizers.Adam(learning_rate=0.0001) #Optimizer

These are the settings used this time. They could be split out into a separate config.py, as you often see on GitHub, but since this is not a large program I kept them together here.

The training patch size is 33 * 33 because that is the size stated in the paper. The test size is simply made large so that the results are easy to view. The number of training samples is 10 times the number of images in the folder. (With 800 images, that gives 8,000 samples.)
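The arithmetic above can be sketched as follows. The helper names are hypothetical and only for illustration. The figure of 800 images is the example from the text, 32 is the Keras default batch size when fit is called without batch_size (as in the listing in this article), and 240 is the value of BATSH_SIZE, shown for comparison.

```python
import math

def dataset_size(num_images, crops_per_image):
    """Number of training samples when each image yields a fixed number of crops."""
    return num_images * crops_per_image

def steps_per_epoch(num_samples, batch_size):
    """Number of batches per epoch; the last batch may be smaller."""
    return math.ceil(num_samples / batch_size)

num_samples = dataset_size(800, 10)
print(num_samples)                        # 8000
print(steps_per_epoch(num_samples, 32))   # 250
print(steps_per_epoch(num_samples, 240))  # 34
```

With the default batch size one epoch takes 250 steps. Passing a batch size of 240 to fit would reduce that to 34 larger steps.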

For the data, I used the DIV2K dataset, which is often used for super-resolution. Because the image quality is high, it is said that reasonable accuracy can be obtained even with a small amount of data.

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='srcnn', help='srcnn, evaluate')

    args = parser.parse_args()

I wanted to separate training and evaluation of the model, so I made them selectable with --mode. I will not explain argparse in detail; please see the official Python documentation. https://docs.python.org/ja/3/library/argparse.html
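As a small illustration of how the --mode switch behaves, here is a self-contained sketch. The build_parser helper and the choices argument are my additions for this example and are not in main.py. With choices, argparse itself rejects an unknown mode, so the final else branch would no longer be reached.

```python
import argparse

def build_parser():
    parser = argparse.ArgumentParser()
    # choices is an addition for this sketch: argparse exits with an error
    # message if --mode is anything other than the listed values.
    parser.add_argument('--mode', type=str, default='srcnn',
                        choices=['srcnn', 'evaluate'],
                        help='srcnn (train) or evaluate')
    return parser

parser = build_parser()

# Passing a list simulates the command line without running a real script.
print(parser.parse_args([]).mode)                      # srcnn (the default)
print(parser.parse_args(['--mode', 'evaluate']).mode)  # evaluate
```

On the command line this corresponds to running python main.py --mode srcnn to train and python main.py --mode evaluate to evaluate.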

    if args.mode == "srcnn":
        train_x, train_y = data_create.save_frame(train_file_path,   #Path of the file containing the image to be cropped
                                                cut_traindata_num,  #Number of datasets generated
                                                train_height, #Storage size
                                                train_width,
                                                mag)   #magnification
                                                
        model = model.SRCNN() 
        model.compile(loss = "mean_squared_error",
                        optimizer = opt,
                        metrics = [psnr])
#https://keras.io/ja/getting-started/faq/
        model.fit(train_x,
                    train_y,
                    epochs = EPOCHS)

        model.save("srcnn_model.h5")

This is where training happens. If you pass --mode srcnn (which is also the default), this branch runs.

data_create.save_frame calls the save_frame function defined in data_create.py. Once the data is in train_x and train_y, the model is built from model.py in the same way, then compiled and fitted.

See the Keras documentation for details on compile and the other calls. I use the same settings as the paper.

Finally, the model is saved and training is done.

    elif args.mode == "evaluate":
        path = "srcnn_model"
        exp = ".h5"
        new_model = tf.keras.models.load_model(path + exp, custom_objects={'psnr':psnr})

        new_model.summary()

        test_x, test_y = data_create.save_frame(test_file_path,   #Path of the file containing the image to be cropped
                                                cut_testdata_num,  #Number of datasets generated
                                                test_height, #Storage size
                                                test_width,
                                                mag)   #magnification

        pred = new_model.predict(test_x)
        path = "resurt_" + path
        os.makedirs(path, exist_ok = True)
        path = path + "/"

        ps = psnr(tf.reshape(test_y[0], [test_height, test_width, 1]), pred[0])
        print("psnr:{}".format(ps))

        before_res = tf.keras.preprocessing.image.array_to_img(tf.reshape(test_x[0], [test_height, test_width, 1]))
        change_res = tf.keras.preprocessing.image.array_to_img(tf.reshape(test_y[0], [test_height, test_width, 1]))
        y_pred = tf.keras.preprocessing.image.array_to_img(pred[0])

        before_res.save(path + "low_" + str(0) + ".jpg")
        change_res.save(path + "high_" + str(0) + ".jpg")
        y_pred.save(path + "pred_" + str(0) + ".jpg")

    else:
        raise Exception("Unknown --mode")

Finally, the last part. First, the model saved earlier is loaded; psnr is passed through custom_objects so that the custom metric can be used. Next, a test dataset is generated and predict produces the output image.

I wanted to know the PSNR value on the spot, so I calculate and print it. I also wanted to save the images, so I convert each array to an image with array_to_img and save it, and that is everything!
high_0.jpg: the high-quality image (the original image).
low_0.jpg: the low-quality image, blurred with a Gaussian filter.
pred_0.jpg: the image generated by the model.
The result looks noticeably cleaner. I believe the PSNR value was about 34.
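One thing to keep in mind when comparing the saved files: as far as I understand, array_to_img rescales the values to 0..255 by default (scale=True) using the minimum and maximum of each array. The sketch below imitates that min-max scaling in plain Python on a flat list of pixel values. The function to_uint8 is hypothetical, and the exact behaviour of Keras should be checked in its documentation.

```python
def to_uint8(pixels):
    """Min-max rescale a flat list of floats to integers in 0..255."""
    lowest = min(pixels)
    shifted = [p - lowest for p in pixels]        # smallest value becomes 0
    highest = max(shifted)
    if highest != 0:                              # avoid dividing by zero on flat images
        shifted = [p / highest * 255.0 for p in shifted]
    return [int(p) for p in shifted]              # truncate, like casting to uint8

print(to_uint8([0.0, 0.25, 1.0]))  # [0, 63, 255]
print(to_uint8([0.2, 0.6]))        # [0, 255]
```

The second call shows that values which do not span the full range are stretched to it. If the scaling works this way, the saved JPEGs can differ in contrast from the raw arrays, while the printed PSNR is computed on the raw values before any such scaling.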

4. At the end

Splitting this into three articles made it quite long, but thank you for reading. I will continue to work on various things in the future. If you have any questions or comments, please feel free to contact me!
