Introduction to AI creation with Python! Part 1: I tried to classify handwritten digit images and predict which numbers they are.

About this article

Using "classification", which is the basis of deep learning, I would like to make a program that predicts what numbers handwritten numbers are. I will write from a beginner's point of view as much as possible.

What is deep learning?

Simply put, it is a technique for finding regularities in data by using large amounts of data. It consists of the following three processes.

  1. Modeling and compiling
  2. Learning
  3. Inference

What is the classification of deep learning?

Roughly speaking, it means predicting the __type (class) of data__ based on characteristic data.

Characteristic data is things like images or recent time-series data; it acts as a hint for making a prediction. It is also called features or explanatory variables.

When classifying two types of patterns, such as "dog" and "cat", it is called __binary classification__. In this case, the features would be a large number of images of dogs and cats.

Also, when classifying three or more types of patterns such as "sunny", "cloudy", and "rain", it is called multiclass classification.

In this article, I will try to classify images of handwritten numbers from 0 to 9 into actual numbers (multi-class classification).

import

This article uses the tensorflow library. If you don't have tensorflow, please install it with pip install in advance. By the way, this time I use the keras package bundled with tensorflow (tensorflow.keras), which is a different package from the standalone keras: tensorflow keras ≠ standalone keras.

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Activation, Dense, Dropout
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.utils import to_categorical
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

Download handwritten digit data

Since we are classifying handwritten digits, we first need handwritten digit data. One option is to actually write the digits by hand ourselves, but since that is a hassle, let's download a dataset called MNIST instead.

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

The downloaded MNIST data is divided into the following four arrays (two pairs):

  - Handwritten digit data used for training and the label data containing the answers: (train_images, train_labels)
  - Handwritten digit data used for testing and the label data containing the answers: (test_images, test_labels)

Since it is called train_images, you might expect it to contain image files, but the contents are actually arrays. Each image is expressed as an array with 28 rows and 28 columns, and there are many of them.

And train_labels contains the answer for each image. In my environment, if you look at the contents of train_labels[0], it contains the number 5. This means "the image (array) in train_images[0] is a 5."

The structure of (test_images, test_labels) is the same: the answer for test_images[0] is in test_labels[0].
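If you want to check this yourself, printing the shapes of the arrays is a quick way (the sizes in the comments below are the standard MNIST sizes, which is what I assume here):

print(train_images.shape)  # (60000, 28, 28) -> 60,000 images of 28 x 28 pixels
print(train_labels.shape)  # (60000,)        -> 60,000 answer labels
print(test_images.shape)   # (10000, 28, 28) -> 10,000 test images
print(test_labels.shape)   # (10000,)        -> 10,000 test labels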

In the steps that follow, the flow is to train using (train_images, train_labels) and then check the accuracy with (test_images, test_labels).

Confirmation of handwritten digit data

Arrays are not friendly to the human eye, so let's convert about 10 of them to images and check.

for i in range(10):
    plt.subplot(1, 10, i+1)
    plt.imshow(train_images[i], 'gray')
plt.show()


We can confirm that the data does look like handwritten numbers. Next, check the label data.

print(train_labels[0:10])

[5 0 4 1 9 2 1 3 1 4] is output. It matches the handwritten digit images shown above.

Data set preprocessing

The data has been loaded, but it cannot be used for prediction as it is. I would like to format the image array data and the answer label data so that they can be used for prediction. This is called data preprocessing and is a very important task.

According to the blog of a working data scientist, preprocessing occupies most of the work of building an AI.

At the moment, each image is a two-dimensional array with 28 rows and 28 columns.

Convert this to a one-dimensional array of 784 elements using reshape (because 28 × 28 = 784):

train_images = train_images.reshape((train_images.shape[0], 784))
test_images = test_images.reshape((test_images.shape[0], 784))

Next is the label data formatting. The label data is converted to a "one-hot representation", an encoding in which only one element is 1 and all the others are 0. If the answer is "5", it becomes the following: [0,0,0,0,0,1,0,0,0,0] # a 1 at index 5 (counting from 0)

Convert it using to_categorical:

train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
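To confirm that the conversion worked, you can print the first label. The answer for the first image was 5, so a 1 should appear at index 5 (this is just a sanity check; the output below is what I would expect):

print(train_labels[0])
# [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]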

This completes the data preprocessing.

Creating a model

Now that preprocessing is done, it's time to create the model. In a nutshell, a model is like a blueprint: you declare the number of units in the input layer, the hidden layers, the dropout rate, the number of units in the output layer, which activation functions to use, and so on.

model = Sequential()
model.add(Dense(256, activation='sigmoid', input_shape=(784,))) #Input layer
model.add(Dense(128, activation='sigmoid')) #Hidden layer
model.add(Dropout(rate=0.5)) #Drop out
model.add(Dense(10, activation='softmax')) #Output layer

I will explain each one.

Model declaration

model = Sequential() creates a Sequential model. Sequential is a model structure in which multiple layers are stacked on top of each other; imagine a mille-feuille cake.

Input layer

model.add(Dense(256, activation='sigmoid', input_shape=(784,))) is the input layer. Dense represents a fully connected layer, meaning every unit is connected to every unit in the next layer. This time we are creating 256 units.

`activation='sigmoid'` specifies the activation function to use. By writing sigmoid, the sigmoid function is used as the activation function. Activation functions make it possible to capture features even in complex data that cannot be separated linearly.
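For reference, the sigmoid function itself just squashes any input into the range 0 to 1. A minimal NumPy sketch (for illustration only, not what Keras calls internally):

import numpy as np

def sigmoid(x):
    # Squash any real number into the range (0, 1)
    return 1 / (1 + np.exp(-x))

print(sigmoid(0))    # 0.5
print(sigmoid(10))   # close to 1
print(sigmoid(-10))  # close to 0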

`input_shape=(784,)` is the shape (number of elements) of the incoming data.

Hidden layer

model.add(Dense(128, activation='sigmoid')) is a hidden layer. Hidden layers make it possible to capture complex features. You can add as many hidden layers as you like, but if you overdo it, you end up with a classification model that only works well on the training data. This is called "overfitting" or "curve fitting".

Drop out

model.add(Dropout(rate=0.5)) is a dropout layer. Dropout is a technique that disables some units during training to prevent overfitting. The fraction of units to disable is set by rate. It is often said that around 50% is a good default, but it does not have to be exactly that.

Output layer

model.add(Dense(10, activation='softmax')) is the output layer. This time we are classifying into 10 classes, the digits 0 to 9, so the number of units is 10. The softmax function is used as the activation function; its outputs sum to 1, which makes it well suited to multi-class classification. This article only uses sigmoid and softmax, but there are various other activation functions.
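As a reference, here is a minimal NumPy sketch of softmax (illustration only, not the Keras implementation). You can see that the outputs sum to 1, which is why each output can be read as the probability of each class:

import numpy as np

def softmax(x):
    # Exponentiate and normalize so that the outputs sum to 1
    e = np.exp(x - np.max(x))  # subtract the max for numerical stability
    return e / e.sum()

scores = np.array([1.0, 2.0, 0.5, 0.1, 3.0, 0.2, 0.3, 0.4, 0.6, 0.7])
probs = softmax(scores)
print(probs.sum())       # 1.0 (the outputs always sum to 1)
print(np.argmax(probs))  # 4 -> the class with the highest score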

Compiling the model

Next, compile the created model. It is just one line, but it is packed with important information.

model.compile(loss='categorical_crossentropy', optimizer=SGD(lr=0.1), metrics=['acc'])

The following three things must be set when compiling:

  - Loss function
  - Optimization function
  - Evaluation index

What is a loss function?

It is a function that calculates the __error__ between the model's predicted values and the correct answer data. In multi-class classification, a function called categorical_crossentropy (categorical cross-entropy error) is often used. Set it with loss='categorical_crossentropy'.
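To give a feel for what categorical cross-entropy computes, here is a rough NumPy sketch for a single sample (illustration only, not the Keras internals; the predicted probabilities are made-up numbers):

import numpy as np

y_true = np.array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0])  # one-hot answer "5"
y_pred = np.array([0.01, 0.01, 0.02, 0.02, 0.04, 0.80, 0.03, 0.03, 0.02, 0.02])  # model output (sums to 1)

loss = -np.sum(y_true * np.log(y_pred))
print(loss)  # about 0.223; the closer the predicted probability of the answer is to 1, the closer the loss is to 0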

What is an optimization function?

It is a function that adjusts the weights and biases so that the error calculated by the loss function approaches 0. Typical optimizers include SGD and Adam; this time I will use SGD. The argument lr of SGD is the learning rate, a value that determines how much the weights of each layer are updated at each step. This time it is set to 0.1: optimizer=SGD(lr=0.1)
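Conceptually, SGD nudges each weight a little in the direction that reduces the loss, and lr decides how big that nudge is. A toy one-step sketch (conceptual only, not the actual tf.keras internals; the numbers are made up):

learning_rate = 0.1
weight = 0.8
gradient = 0.5  # slope of the loss with respect to this weight (made-up number)

weight = weight - learning_rate * gradient
print(weight)   # 0.75 -> the weight moved slightly in the direction that reduces the loss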

What is an evaluation index?

An index for measuring the performance of the model. The main evaluation metrics are acc and mse. For classification, acc (accuracy), which indicates the rate of correct answers, is commonly used; for predicting numerical values, as in regression, mse (mean squared error) is said to be a good choice. Set it with metrics=['acc']

Learning

Now that everything is ready, let's train. Training uses train_images and train_labels and is done with model.fit.

history = model.fit(train_images, train_labels, batch_size=500, 
    epochs=5, validation_split=0.2)

batch_size sets how many training samples are used for each weight update. The larger it is, the faster training goes, but the more memory it consumes.

epochs is the number of times to iterate over the entire training dataset.

validation_split is the ratio used to split off validation data from the training data. If it is 0.2, 20% of the data will be used for validation.
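As a quick sanity check of what these settings mean (assuming the standard 60,000 MNIST training images):

num_samples = 60000                       # standard MNIST training set size
num_train = int(num_samples * (1 - 0.2))  # 48,000 samples actually used for training
num_val = num_samples - num_train         # 12,000 samples used for validation
updates_per_epoch = num_train // 500      # 96 weight updates per epoch with batch_size=500
print(num_train, num_val, updates_per_epoch)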

When executed, the training status at each epoch is printed. When I see this, I feel like "Wow! I'm making AI!" and get very excited.

Here is what each item in the training output means:

  - loss is the __training data error__; the closer to 0, the better the result.
  - acc is the __training data accuracy__; the closer to 1, the better the result.
  - val_loss is the __validation data error__; the closer to 0, the better the result.
  - val_acc is the __validation data accuracy__; the closer to 1, the better the result.

If you increase the number of epochs too much and overfit, loss and acc become very good numbers while val_loss and val_acc become bad numbers.

In other words, the model becomes like a student who has drilled past exam questions so hard that it can __only solve those past questions__. Set the number of epochs to a reasonable value.

By the way, you can check the training results in history.history.
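history.history is just a Python dictionary. With metrics=['acc'], the keys should look something like the comment below (the exact key names follow the metric name passed to compile):

print(history.history.keys())
# dict_keys(['loss', 'acc', 'val_loss', 'val_acc'])
print(history.history['acc'])  # training accuracy at each epoch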

To display the graph, use the following code.

plt.plot(history.history['acc'], label='acc')
plt.plot(history.history['val_acc'], label='val_acc')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(loc='best')
plt.show()

Evaluation of learning results

When the training is complete, evaluate using the test data. Use model.evaluate for evaluation.

test_loss, test_acc = model.evaluate(test_images, test_labels)
print('loss: {:.3f}\nacc: {:.3f}'.format(test_loss, test_acc))

The acc (accuracy) on the test data is almost the same as the result during training.

Inference

Finally, we run inference on actual data. This lets you get the prediction result for each individual image.

Display the images we will run inference on:

for i in range(10):
    plt.subplot(1, 10, i+1)
    plt.imshow(test_images[i].reshape((28, 28)), 'gray')
plt.show()

Pass these images to model.predict to make inferences. model.predict returns 10 probabilities per image, and np.argmax picks the index with the highest one, which is the predicted digit.

test_predictions = model.predict(test_images[0:10])
test_predictions = np.argmax(test_predictions, axis=1)
print(test_predictions)

The output result is as follows. [7 2 1 0 4 1 4 9 6 9] Everything seems to be correct.

When operating this as an actual AI, you would throw the data you want to predict (rather than test_images) into model.predict. If you save the model, you do not need to train it again.
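For reference, a minimal sketch of saving and reloading the trained model (the file name 'mnist_model.h5' is just an example):

from tensorflow.keras.models import load_model

model.save('mnist_model.h5')          # save the architecture and trained weights to a file
model = load_model('mnist_model.h5')  # later, restore the model without training again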

It's been long, but that's it.

The next article is ↓ https://qiita.com/sw1394/items/f25dea5cf24ce78dbb0c
