[PYTHON] A story of a deep learning beginner trying to classify guitars with a CNN

Overview

Some people have already tried this on Qiita, but I also did it as a study exercise for myself. I classified guitar images with a CNN (ResNet), and here I share some things I learned along the way that may be helpful. (It is not polished, so the code is a little messy, but I will post it as well.)

Table of contents

- About the specific classification method
- About preprocessing
- About the learning method
- About the learning results
- Try and play
- Summary

About the specific classification method

Guitar images are scraped from the web and then preprocessed to augment (inflate) the dataset. By fine-tuning ResNet, a CNN architecture, on the augmented images, I try to train a classifier without spending too much on training cost.

About labels

I chose the following models, since images of them seemed relatively easy to collect.

- Made by Fender
  - Stratocaster
  - Telecaster
  - Jazzmaster
  - Jaguar
  - Mustang (including similar models)
- Made by Gibson
  - Les Paul
  - SG
  - Flying V
  - ES-335
- Acoustic guitar

About preprocessing

The first step is collecting images. This time I used iCrawler. Images are usually collected from Google Image Search, but as of March 12, 2020, the tool's Google crawler seems to be broken because of a specification change on Google's side, so I collected images from Bing instead.

crawling.py


import os

from icrawler.builtin import BingImageCrawler

searching_words = [
                    "Fender Stratocaster",
                    "Fender Telecaster",
                    "Fender Jazzmaster",
                    "Fender Jaguar",
                    "Fender Mustang",
                    "Gibson LesPaul",
                    "Gibson SG",
                    "Gibson FlyingV",
                    "Gibson ES-335",
                    "Acoustic guitar"
                ]
if __name__ == "__main__":
    for word in searching_words:
        if not os.path.isdir('./searched_image/' + word):
            os.makedirs('./searched_image/' + word)
        bing_crawler = BingImageCrawler(storage={ 'root_dir': './searched_image/' + word })
        bing_crawler.crawl(keyword=word, max_num=1000)

After collecting, I manually removed images that seemed unusable (those that do not show the whole guitar body, those containing text, those with hands or other objects in the frame, etc.). As a result, I kept about 100 to 160 images per label. (I specified max_num=1000 in the crawl method, but it only collected about 400 images.)
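Checking how many images survived the manual cleanup is easy to script. Below is a minimal sketch using only the standard library. `count_images_per_label` and its `subdir` parameter are hypothetical helpers, and the extension list is an assumption about what the crawler saved.

```python
from pathlib import Path

# Extensions counted as images (an assumption; adjust to what the crawler saved).
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif"}

def count_images_per_label(root_dir, subdir=""):
    """Return {label: image_count} for every label folder under root_dir.

    subdir: optional folder inside each label folder that holds the kept images
    (for example "all"). Hypothetical parameter, not part of the original scripts.
    """
    counts = {}
    for label_dir in sorted(Path(root_dir).iterdir()):
        if not label_dir.is_dir():
            continue
        image_dir = label_dir / subdir if subdir else label_dir
        if not image_dir.is_dir():
            counts[label_dir.name] = 0
            continue
        # Non-recursive: only files directly inside image_dir are counted.
        counts[label_dir.name] = sum(
            1 for f in image_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts
```

For example, `count_images_per_label("./searched_image", "all")` would report the images under each label's `all` folder, which is where image_preprocessing.py looks for them.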

Next, I preprocess the collected images. Each image was rotated in 45° steps, and each rotated image was also flipped horizontally. As a result, the data grew about 16-fold, to about 1,600 to 2,000 images per label.
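The augmentation loop in image_preprocessing.py can be checked with a little arithmetic. The sketch below uses hypothetical helper names and enumerates the (angle, flip) pairs the loop produces. `range(0, 360, 45)` gives 8 angles, and each is kept both as-is and flipped, for 16 variants. The script also appends the un-rotated original before the loop, so each source image actually yields 17 arrays, and the `(0, False)` variant duplicates that original.

```python
def augmentation_variants(step_deg=45):
    """List the (angle, flipped) pairs produced by the rotate/flip loop."""
    angles = range(0, 360, step_deg)  # 0, 45, ..., 315 -> 8 angles for step 45
    return [(angle, flipped) for angle in angles for flipped in (False, True)]

def samples_per_source_image(step_deg=45, include_original=True):
    """Arrays appended per source image: optionally the original, plus every variant."""
    return len(augmentation_variants(step_deg)) + (1 if include_original else 0)

print(len(augmentation_variants()))   # 16 rotated/flipped variants
print(samples_per_source_image())     # 17, because the original is appended too
```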

image_preprocessing.py


import os
import glob

from PIL import Image
import numpy as np
from sklearn.model_selection import train_test_split 

#Size (width and height in pixels) that each image is resized to
image_size = 224
#Number of training data
traindata = 1000
#Number of test data
testdata = 300

#Input folder name
src_dir = './searched_image'
#Output folder name
dst_dir = './input_guitar_data'

#Label name to identify
labels = [
                    "Fender Stratocaster",
                    "Fender Telecaster",
                    "Fender Jazzmaster",
                    "Fender Jaguar",
                    "Fender Mustang",
                    "Gibson LesPaul",
                    "Gibson SG",
                    "Gibson FlyingV",
                    "Gibson ES-335",
                    "Acoustic guitar"
                ]
#Loading images
for index, label in enumerate(labels):
    #No trailing space in the pattern, otherwise glob matches nothing
    files = glob.glob("{}/{}/all/*.jpg".format(src_dir, label))
        
    #Image converted data
    X = []
    #label
    Y = []

    for file in files:
        #Open image
        img = Image.open(file)
        img = img.convert("RGB")
        
        #===================#Convert to square#===================#
        width, height = img.size
        #If it is vertically long, expand it horizontally
        if width < height:
            result = Image.new(img.mode,(height, height),(255, 255, 255))
            result.paste(img, ((height - width) // 2, 0))
        #If it is horizontally long, expand it vertically
        elif width > height:
            result = Image.new(img.mode,(width, width),(255, 255, 255))
            result.paste(img, (0, (width - height) // 2))
        else:
            result = img

        #Align image size to 224x224 (resize() returns a new image, so reassign it)
        result = result.resize((image_size, image_size))

        data = np.asarray(result)
        X.append(data)
        Y.append(index)

        #===================#Data augmentation#===================#
        for angle in range(0, 360, 45):
            #rotation
            img_r = result.rotate(angle)
            data = np.asarray(img_r)
            X.append(data)
            Y.append(index)

            #Invert
            img_t = img_r.transpose(Image.FLIP_LEFT_RIGHT)
            data = np.asarray(img_t)
            X.append(data)
            Y.append(index)
    
    #Normalization(0~255->0~1)
    X = np.array(X,dtype='float32') / 255.0
    Y = np.array(Y)


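    #Split data into training and test sets (a hold-out split, not cross-validation)
    X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=testdata, train_size=traindata)
    #Store the four arrays in a 1-D object array so np.save accepts arrays of different shapes
    xy = np.empty(4, dtype=object)
    for i, arr in enumerate((X_train, X_test, y_train, y_test)):
        xy[i] = arr
    np.save("{}/{}_{}.npy".format(dst_dir, label, index), xy)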

Save the preprocessed results in an npy file for each label.
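Saving a tuple of differently shaped arrays with np.save relies on NumPy turning it into a pickled object array, which newer NumPy versions refuse to do implicitly. One alternative is np.savez, which stores each array under its own name and loads without pickling. This is a sketch only: `save_split` and `load_split` are hypothetical helpers, and the training script's `np.load(...)` line would have to change to match.

```python
import numpy as np

def save_split(path, X_train, X_test, y_train, y_test):
    """Save one label's four arrays under named keys in a single .npz file."""
    np.savez(path, X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test)

def load_split(path):
    """Load the arrays written by save_split; no allow_pickle is needed."""
    with np.load(path) as data:
        return data["X_train"], data["X_test"], data["y_train"], data["y_test"]
```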

About the learning method

This time I train with ResNet, a typical CNN architecture. My PC does not have an NVIDIA GPU, so training there would run on the CPU alone and take an enormous amount of time. Instead, I ran the following code in a GPGPU environment on Google Colab. (How to use Colab, how to upload files, etc. are omitted.)

import gc

import keras
from keras.applications.resnet50 import ResNet50
from keras.models import Sequential, Model
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense, Input
from keras.callbacks import EarlyStopping 
from keras.utils import np_utils
from keras import optimizers

from sklearn.metrics import confusion_matrix

import numpy as np
import matplotlib.pyplot as plt

#Class label definition
classes = [
                    "Fender Stratocaster",
                    "Fender Telecaster",
                    "Fender Jazzmaster",
                    "Fender Jaguar",
                    "Fender Mustang",
                    "Gibson LesPaul",
                    "Gibson SG",
                    "Gibson FlyingV",
                    "Gibson ES-335",
                    "Acoustic guitar"
                ]
num_classes = len(classes)

#Image size to load
ScaleTo = 224

#Definition of main function
def main():
    #Reading training data
    src_dir = '/content/drive/My Drive/Machine learning/input_guitar_data'

    train_Xs = []
    test_Xs = []
    train_ys = []
    test_ys = []

    for index, class_name in enumerate(classes):
        file = "{}/{}_{}.npy".format(src_dir, class_name, index)
        #Bring a separate learning file
        train_X, test_X, train_y, test_y = np.load(file, allow_pickle=True)

        #Combine data into one
        train_Xs.append(train_X)
        test_Xs.append(test_X)
        train_ys.append(train_y)
        test_ys.append(test_y)

    #Combine the combined data
    X_train = np.concatenate(train_Xs, 0)
    X_test = np.concatenate(test_Xs, 0)
    y_train = np.concatenate(train_ys, 0)
    y_test = np.concatenate(test_ys, 0)

    #Label
    y_train = np_utils.to_categorical(y_train, num_classes)
    y_test = np_utils.to_categorical(y_test, num_classes)


    #Generation of machine learning model
    model, history = model_train(X_train, y_train, X_test, y_test)
    model_eval(model, X_test, y_test)
    #Display learning history
    model_visualization(history)

def model_train(X_train, y_train, X_test, y_test):
    #Load ResNet50. include_top=False because its original fully connected (classification) layers are not needed
    input_tensor = Input(shape=(ScaleTo, ScaleTo, 3))
    resnet50 = ResNet50(include_top=False, weights='imagenet', input_tensor=input_tensor)

    #Creating a fully connected layer
    top_model = Sequential()
    top_model.add(Flatten(input_shape=resnet50.output_shape[1:]))
    top_model.add(Dense(256, activation='relu'))
    top_model.add(Dropout(0.5))
    top_model.add(Dense(num_classes, activation='softmax'))

    #Create a model by combining ResNet50 and a fully connected layer
    resnet50_model = Model(inputs=resnet50.input, outputs=top_model(resnet50.output))

    """
    #Fixed some weights of ResNet50
    for layer in resnet50_model.layers[:100]:
        layer.trainable = False
    """

    #Specify multi-class classification
    resnet50_model.compile(loss='categorical_crossentropy',
            optimizer=optimizers.SGD(lr=1e-3, momentum=0.9),
            metrics=['accuracy'])
    resnet50_model.summary()

    #Execution of learning
    early_stopping = EarlyStopping(monitor='val_loss', patience=0, verbose=1) 
    history = resnet50_model.fit(X_train, y_train,
                        batch_size=75,
                        epochs=25, validation_data=(X_test, y_test),
                        callbacks=[early_stopping])
    #Save model
    resnet50_model.save("/content/drive/My Drive/Machine learning/guitar_cnn_resnet50.h5")
    
    return resnet50_model, history

def model_eval(model, X_test, y_test):
    scores = model.evaluate(X_test, y_test, verbose=1)
    print("test Loss", scores[0])
    print("test Accuracy", scores[1])
    #Calculation of confusion matrix
    predict_classes = model.predict(X_test)
    predict_classes = np.argmax(predict_classes, 1)
    true_classes = np.argmax(y_test, 1)
    print(predict_classes)
    print(true_classes)
    cmx = confusion_matrix(true_classes, predict_classes)
    print(cmx)
    #Erase the model after inference
    del model
    keras.backend.clear_session() #Key step: reset Keras' global state to release memory
    gc.collect()

def model_visualization(history):
    #Graph display of loss value
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.show()

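    #Graph display of accuracy (Keras 2.3+ records 'accuracy', older versions 'acc')
    acc_key = 'accuracy' if 'accuracy' in history.history else 'acc'
    plt.plot(history.history[acc_key])
    plt.plot(history.history['val_' + acc_key])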
    plt.title('model accuracy')
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper left')
    plt.show()
    
if __name__ == "__main__":
    main()
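In main(), np_utils.to_categorical turns each integer label into a one-hot row, and model_eval() later reverses this with np.argmax. Below is a NumPy-only sketch of that round trip. `to_one_hot` is a hypothetical helper, not a Keras function.

```python
import numpy as np

def to_one_hot(labels, num_classes):
    """Integer class indices -> one-hot float32 rows (same idea as to_categorical)."""
    labels = np.asarray(labels, dtype=int)
    one_hot = np.zeros((labels.shape[0], num_classes), dtype="float32")
    one_hot[np.arange(labels.shape[0]), labels] = 1.0  # one 1.0 per row
    return one_hot

y = to_one_hot([0, 3, 9], num_classes=10)
print(y.shape)               # (3, 10)
print(np.argmax(y, axis=1))  # [0 3 9], the original labels again
```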

This time, val acc and the other metrics were better when the weights were not frozen, so the weights of every ResNet50 layer are retrained as well (the freezing code is left commented out). The code specifies epochs=25, but in practice early stopping ended training at the 5th epoch.
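To see why patience=0 ends training so early, here is a simplified pure-Python model of the stopping rule. It mirrors the counter logic of Keras 2.x EarlyStopping as I understand it (monitor='val_loss', min_delta=0, no baseline). It is not the Keras implementation, and the loss values in the example are made up.

```python
def early_stopping_epoch(val_losses, patience=0):
    """Return the 1-based epoch at which training stops, or None if it runs to the end.

    Simplified model of EarlyStopping(monitor='val_loss', min_delta=0):
    an epoch improves only if val_loss is strictly lower than the best so far;
    otherwise a wait counter grows, and training stops once wait >= patience.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# Made-up losses: with patience=0 the first non-improving epoch (5) stops training.
print(early_stopping_epoch([0.9, 0.5, 0.3, 0.2, 0.25, 0.1], patience=0))  # 5
```

With patience=0, a single epoch without improvement is enough to stop, so a larger patience may let training continue past a temporary bump in val_loss.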

About the learning results

The result is as follows.

test Loss 0.09369107168481061
test Accuracy 0.9744

Here is the confusion matrix as well (rows are true labels and columns are predicted labels, both in the order of classes).

[[199   0   1   0   0   0   0   0   0   0]
 [  0 200   0   0   0   0   0   0   0   0]
 [  2   5 191   2   0   0   0   0   0   0]
 [  1   0  11 180   6   0   2   0   0   0]
 [  0   2   0   0 198   0   0   0   0   0]
 [  0   0   0   0   0 288   4   0   6   2]
 [  0   2   0   0   0   0 296   0   2   0]
 [  0   0   0   0   0   0   0 300   0   0]
 [  0   0   0   0   0   0   0   0 300   0]
 [  0   0   0   0   0   0   0   1   0 299]]
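Per-class figures are easier to read off than the raw matrix. The following NumPy sketch assumes the rows and columns follow the classes order, in line with sklearn's convention of true labels on rows and predictions on columns. The matrix is copied from the output above. With these numbers, the Fender Jaguar row has the lowest recall (180 of 200), and most of its errors go to the Jazzmaster.

```python
import numpy as np

# Class order used throughout the article (row/column i of the matrix is classes[i]).
classes = [
    "Fender Stratocaster", "Fender Telecaster", "Fender Jazzmaster",
    "Fender Jaguar", "Fender Mustang", "Gibson LesPaul", "Gibson SG",
    "Gibson FlyingV", "Gibson ES-335", "Acoustic guitar",
]

# Confusion matrix copied from the output above (rows: true, columns: predicted).
cmx = np.array([
    [199,   0,   1,   0,   0,   0,   0,   0,   0,   0],
    [  0, 200,   0,   0,   0,   0,   0,   0,   0,   0],
    [  2,   5, 191,   2,   0,   0,   0,   0,   0,   0],
    [  1,   0,  11, 180,   6,   0,   2,   0,   0,   0],
    [  0,   2,   0,   0, 198,   0,   0,   0,   0,   0],
    [  0,   0,   0,   0,   0, 288,   4,   0,   6,   2],
    [  0,   2,   0,   0,   0,   0, 296,   0,   2,   0],
    [  0,   0,   0,   0,   0,   0,   0, 300,   0,   0],
    [  0,   0,   0,   0,   0,   0,   0,   0, 300,   0],
    [  0,   0,   0,   0,   0,   0,   0,   1,   0, 299],
])

correct = np.diag(cmx)
recall = correct / cmx.sum(axis=1)     # share of each true class predicted correctly
precision = correct / cmx.sum(axis=0)  # share of each predicted class that was right
accuracy = correct.sum() / cmx.sum()   # overall share of correct predictions

for name, r, p in zip(classes, recall, precision):
    print(f"{name:20s} recall={r:.3f} precision={p:.3f}")
print(f"overall accuracy={accuracy:.4f}")
```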

[Figures: model loss and model accuracy per epoch (train vs. test)]

The graphs show that training had already progressed considerably by the end of the first epoch.

Try and play

Next, I try inference with the saved model. For this I made a very rudimentary web application with Flask, which I used for the first time.

graphing.py


import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

def to_graph(image, labels, predicted):
    #=======#Plot and save#=======#
    fig = plt.figure(figsize=(10.24, 5.12))
    fig.subplots_adjust(left=0.2)

    #=======#Write a bar chart#=======#
    ax1 = fig.add_subplot(1,2,1)
    ax1.barh(labels, predicted, color='c', align="center")
    ax1.set_yticks(labels)#y-axis label
    ax1.set_xticks([])#Remove x-axis labels

    #Write numbers in bar charts
    for interval, value in zip(range(0,len(labels)), predicted):
        ax1.text(0.02, interval, value, ha='left', va='center')

    #=======#Insert the identified image#=======#
    ax2 = fig.add_subplot(1,2,2)
    ax2.imshow(image)
    ax2.axis('off')

    return fig

def expand_to_square(input_file):
    """Convert a rectangular image to a square
    input_file: file name to convert
    Return value: converted image
    """
    img = Image.open(input_file)
    img = img.convert("RGB")
    
    width, height = img.size
    #If it is vertically long, expand it horizontally
    if width < height:
        result = Image.new(img.mode,(height, height),(255, 255, 255))
        result.paste(img, ((height - width) // 2, 0))
    #If it is horizontally long, expand it vertically
    elif width > height:
        result = Image.new(img.mode,(width, width),(255, 255, 255))
        result.paste(img, (0, (width - height) // 2))
    else:
        result = img
    
    return result 
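expand_to_square() pads with white so the guitar is not distorted when the image is later resized to 224x224. The sketch below applies the same idea to a NumPy array. `pad_to_square` is a hypothetical helper that mirrors the offsets used above, and it assumes an H x W x C uint8 array.

```python
import numpy as np

def pad_to_square(image, fill=255):
    """Pad an H x W x C array with `fill` so it becomes square, keeping it centered.

    Mirrors expand_to_square(): the shorter side is extended, and the original
    sits at offset (long - short) // 2 along that side.
    """
    height, width = image.shape[:2]
    size = max(height, width)
    # White canvas of the final square size (same channel count and dtype).
    result = np.full((size, size) + image.shape[2:], fill, dtype=image.dtype)
    top = (size - height) // 2
    left = (size - width) // 2
    result[top:top + height, left:left + width] = image
    return result
```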

predict_file.py


import io
import os
import gc

from flask import Flask, request, redirect, url_for
from flask import flash, render_template, make_response

from keras.models import Sequential, load_model
from keras.applications.resnet50 import decode_predictions
import keras

import numpy as np
from PIL import Image
from matplotlib.backends.backend_agg import FigureCanvasAgg

import graphing

classes = [
            "Fender Stratocaster",
            "Fender Telecaster",
            "Fender Jazzmaster",
            "Fender Jaguar",
            "Fender Mustang",
            "Gibson LesPaul",
            "Gibson SG",
            "Gibson FlyingV",
            "Gibson ES-335",
            "Acoustic guitar"
            ]
num_classes = len(classes)
image_size = 224
#'jpeg' added so that .jpeg uploads are also accepted
ALLOWED_EXTENSIONS = set(['png', 'jpg', 'jpeg', 'gif'])


app = Flask(__name__)
#flash() stores messages in the session, which requires a secret key
app.secret_key = os.urandom(24)

def allowed_file(filename):
    return '.' in filename and filename.rsplit('.',1)[1].lower() in ALLOWED_EXTENSIONS

@app.route('/', methods=['GET', 'POST'])
def upload_file():
    if request.method == 'POST':
        if 'file' not in request.files:
            flash('No file')
            return redirect(request.url)
        file = request.files['file']

        if file.filename == '':
            flash('No file')
            return redirect(request.url)

        if file and allowed_file(file.filename):
            virtual_output = io.BytesIO()
            file.save(virtual_output)
            filepath = virtual_output

            model = load_model('./cnn_model/guitar_cnn_resnet50.h5')

            #Convert image to square
            image = graphing.expand_to_square(filepath)
            image = image.convert('RGB')
            #Align image size to 224x224
            image = image.resize((image_size, image_size))
            #Change from image to numpy array and normalize
            data = np.asarray(image) / 255.0
            #Increase the dimensions of the array(3D->4 dimensions)
            data = np.expand_dims(data, axis=0)
            #Make inferences using the learned model
            result = model.predict(data)[0]
            
            #Draw the inference result and the inferred image as a graph
            fig = graphing.to_graph(image, classes, result)
            canvas = FigureCanvasAgg(fig)
            png_output = io.BytesIO()
            canvas.print_png(png_output)
            data = png_output.getvalue()

            response = make_response(data)
            response.headers['Content-Type'] = 'image/png'
            response.headers['Content-Length'] = len(data)

            #Erase the model after inference
            del model
            keras.backend.clear_session()
            gc.collect()

            return response
    return '''
    <!doctype html>
    <html>
        <head>
            <meta charset="UTF-8">
            <title>Let's upload the file and judge</title>
        </head>
        <body>
            <h1>Upload the file and judge!</h1>
            <form method = post enctype = multipart/form-data>
                <p><input type=file name=file>
                <input type=submit value=Upload>
            </form>
        </body>
    </html>
    '''
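The request handler applies the same scaling used in training (divide by 255) and adds a batch dimension before calling predict(), which returns one softmax vector per image. Below is a model-free sketch of those two steps plus a top-k readout. `to_model_input` and `top_k` are hypothetical helpers. In the real app, the probabilities would be `model.predict(data)[0]`.

```python
import numpy as np

def to_model_input(image_array):
    """224 x 224 x 3 uint8 array -> float32 batch of shape (1, 224, 224, 3) in 0..1."""
    data = np.asarray(image_array, dtype="float32") / 255.0  # same scaling as training
    return np.expand_dims(data, axis=0)                       # 3D -> 4D (batch of one)

def top_k(probabilities, labels, k=3):
    """Return the k (label, probability) pairs with the highest probability."""
    probabilities = np.asarray(probabilities)
    order = np.argsort(probabilities)[::-1][:k]  # indices sorted by descending probability
    return [(labels[i], float(probabilities[i])) for i in order]
```

For example, `top_k(result, classes)` would list the three most likely guitars as text alongside the bar chart.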

By the way, if you repeat training and inference in Keras many times, memory usage apparently keeps growing, so the model has to be freed explicitly in the code. (The same applies on Colab.)
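An alternative to loading and deleting the model on every request is to load it once and reuse it. The sketch below shows that pattern with functools.lru_cache. `_load_model_from_disk` is a placeholder standing in for keras.models.load_model. Whether one shared Keras model is safe under Flask's threaded server depends on the Keras/TensorFlow version, so treat this as an assumption to verify rather than a drop-in fix.

```python
from functools import lru_cache

LOAD_CALLS = {"count": 0}  # only for demonstrating how often loading happens

def _load_model_from_disk(path):
    """Placeholder for the real loader, e.g. keras.models.load_model(path)."""
    LOAD_CALLS["count"] += 1
    return {"path": path}  # stand-in object; a real app would return a Keras model

@lru_cache(maxsize=1)
def get_model(path="./cnn_model/guitar_cnn_resnet50.h5"):
    """Load the model on first use, then return the same object on every later call."""
    return _load_model_from_disk(path)
```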

Reference: Fixed the problem that memory usage increases when learning repeatedly with keras

Also, here is the source code of the web application I actually made: Guitar Classification Web App

Trying it out

I actually tried it with my own instrument.

First, the Jazzmaster.

[Figure: Jazzmaster prediction result]

The model also gives some probability to the Jaguar, which shares many features with the Jazzmaster. However, other Jazzmaster images taken from the web are sometimes judged as 99% Jazzmaster, so the classification accuracy cannot be called bad.

Next, the Stratocaster.

[Figure: Stratocaster prediction result]

It was judged to be a Stratocaster with near certainty. A slightly dark, low-contrast photo did not seem to be a problem.

So what happens with a bass, which the model was never trained on? I tried my Jazz Bass-type bass.

[Figure: Jazz Bass prediction result]

I am not sure why it was judged a Mustang, and I am concerned that the probability for SG is also high. Maybe the horns look somewhat similar...?

Summary

This time, by fine-tuning ResNet, a CNN architecture, I was able to build a classifier that was relatively easy to create yet highly accurate. However, with some machine learning models such as CNNs, it is hard to explain why a particular result was produced. So, if I have time, I would like to try visualization methods such as Grad-CAM.

That's all.
