Processing notMNIST data for use in Python (and trying to classify it)

I wrote some Python code to handle the notMNIST dataset, which I have been hearing about lately, so I am publishing it here :blush:

MNIST is a test dataset of handwritten digits that most people learning machine learning know. notMNIST, by contrast, is not handwritten digits: it is an image dataset of letters of the alphabet rendered in various fonts.

If you visualize the contents, the dataset looks like the following. The letter shown with each image is the letter that the image is supposed to represent. Look at the second image from the right in the first row. It doesn't look like an "I" at all; it's a house, isn't it? :sweat_smile: As this shows, even a single row that a person glances over contains dubious data, but I think that makes it a very interesting subject.

notMNIST.png

The official page of notMNIST is http://yaroslavvb.blogspot.jp/2011/09/notmnist-dataset.html and the dataset was created by Yaroslav Bulatov.

1. Download data

First, go to http://yaroslavvb.com/upload/notMNIST/ and download notMNIST_large.tar.gz from there.

The file is compressed as tar.gz, so extract it with an appropriate tool. This creates a notMNIST_large folder containing one subfolder per letter: A, B, C, and so on. The Python code below reads this folder structure. I'm using a Jupyter Notebook, in which case the .ipynb file should be in the same directory as the notMNIST_large folder. If you use a .py file instead, place it in that same directory as well.
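
If you would rather do the extraction from Python than with a separate tool, the standard tarfile module can handle it. The following is only a minimal sketch: the helper name extract_archive is my own, and it assumes the archive was downloaded into the current directory.

```python
import os
import tarfile

def extract_archive(archive_path, dest_dir='.'):
    """Extract a .tar.gz archive into dest_dir and return the top-level names."""
    with tarfile.open(archive_path, 'r:gz') as tar:
        # Only extract archives you trust: member paths are not sanitized here.
        tar.extractall(dest_dir)
        top_level = sorted({m.name.split('/')[0] for m in tar.getmembers()})
    return top_level

if __name__ == '__main__':
    # Assumes notMNIST_large.tar.gz sits in the current directory.
    if os.path.exists('notMNIST_large.tar.gz'):
        print(extract_archive('notMNIST_large.tar.gz'))
```

The returned list is handy for confirming that the expected notMNIST_large folder was created.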

2. Python code

The full code is also uploaded to GitHub, but I will include it here as well.

First, import the libraries.

from __future__ import division
import sys, os, pickle

import numpy as np
import numpy.random as rd

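# Note: scipy.misc.imread was removed in SciPy 1.2; on newer versions, imageio.imread is a common substitute.
from scipy.misc import imread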

import matplotlib.pyplot as plt
%matplotlib inline

Define the constants and helper functions: functions for pickling and unpickling, and a function that counts empty files.

image_size = 28
depth = 255

def unpickle(filename):
    with open(filename, 'rb') as fo:
        _dict = pickle.load(fo)
    return _dict

def to_pickle(filename, obj):
    with open(filename, 'wb') as f:
        #pickle.dump(obj, f, -1)
        pickle.Pickler(f, protocol=2).dump(obj)

def count_empty_file(folder):
    cnt = 0
    for file in os.listdir(folder):
        if os.stat(os.path.join(folder, file)).st_size == 0:
            cnt += 1
    return cnt
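
To see how these helpers behave without touching the real dataset, here is a small self-contained sketch. It re-creates the same logic under my own names (pickle_round_trip is hypothetical) and runs it on a throwaway temporary folder.

```python
import os
import pickle
import tempfile

def count_empty_file(folder):
    """Count the zero-byte files directly inside folder."""
    return sum(1 for name in os.listdir(folder)
               if os.stat(os.path.join(folder, name)).st_size == 0)

def pickle_round_trip(obj, path):
    """Save obj with pickle protocol 2 and load it back."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f, protocol=2)
    with open(path, 'rb') as f:
        return pickle.load(f)

# Build a throwaway folder with one empty and one non-empty file.
demo_dir = tempfile.mkdtemp()
open(os.path.join(demo_dir, 'empty.png'), 'wb').close()
with open(os.path.join(demo_dir, 'ok.png'), 'wb') as f:
    f.write(b'not really a png')

n_empty = count_empty_file(demo_dir)
restored = pickle_round_trip({'target': [0, 1, 2]},
                             os.path.join(demo_dir, 'demo.pkl'))
print(n_empty, restored)
```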

I want to save the labels as ints, so prepare dictionaries for converting between letters and numbers.

label_conv = {a: i for a, i in zip('ABCDEFGHIJ', range(10))}
num2alpha = {i: a for i,a in zip(range(10), 'ABCDEFGHIJ')}
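
As a quick illustration of how the two dictionaries fit together, the sketch below rebuilds them and converts a list of int labels back into letters. The decode_labels helper is a hypothetical name of mine, not part of the original code.

```python
letters = 'ABCDEFGHIJ'
label_conv = {a: i for i, a in enumerate(letters)}   # 'A' -> 0, ..., 'J' -> 9
num2alpha = {i: a for a, i in label_conv.items()}    # 0 -> 'A', ..., 9 -> 'J'

def decode_labels(int_labels):
    """Turn a sequence of int labels back into a string of letters."""
    return ''.join(num2alpha[int(i)] for i in int_labels)

decoded = decode_labels([0, 9, 2])
print(decoded)
```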

Read each image file in the folders and store it in a numpy ndarray. At the same time, build the label data, using the folder name as the label. After reading, store the image data under 'data' and the label data under 'target' in a dictionary, and save that object to a file with pickle. Occasionally a file is corrupted, with a size of 0, and cannot be read, so the code skips such files as well as any file that raises an error while being read.

# Check that the folder to be read exists
assert os.path.exists('notMNIST_large')
# assert os.path.exists('notMNIST_small')  # Uncomment this check when also reading notMNIST_small.

for root_dir in ['notMNIST_large']: # ['notMNIST_small', 'notMNIST_large']: # To also process small, use the list with both folders
    folders = [os.path.join(root_dir, d) for d in sorted(os.listdir(root_dir)) 
               if os.path.isdir(os.path.join(root_dir, d))]
    # Count the non-empty files so the arrays can be allocated
    file_cnt = 0
    for folder in folders:

        label_name = os.path.basename(folder)
        file_list = os.listdir(folder)
        file_cnt += len(file_list)-count_empty_file(folder)

    dataset = np.ndarray(shape=(file_cnt, image_size*image_size), dtype=np.float32)
    labels  = np.ndarray(shape=(file_cnt), dtype=int)  # np.int was removed in NumPy 1.24; the builtin int is equivalent
     
    last_num = 0  # Index of the next free row in dataset

    for folder in folders:

        file_list = os.listdir(folder)
        label_name = os.path.basename(folder)

        start_num = last_num  # First row used by this letter
        for file in file_list:

            # Skip files with 0 file size
            if os.stat(os.path.join(folder, file)).st_size == 0:
                continue
            try:
                data = imread(os.path.join(folder, file))
                data   = data.astype(np.float32)
                data  /= depth     # Convert to values in the range 0-1
                dataset[last_num, :] = data.flatten()
                last_num += 1
            except Exception:
                # Files that cannot be read as images are skipped as well
                print('error {}'.format(file))
                continue
        # Label exactly the rows that were filled for this letter
        labels[start_num:last_num] = label_conv[label_name]

    # Keep only the rows that were actually filled (unreadable files leave unused rows)
    notmnist = {}
    notmnist['data'] = dataset[:last_num]
    notmnist['target'] = labels[:last_num]
    to_pickle('{}.pkl'.format(root_dir), notmnist)
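
After saving, it may be worth running a quick sanity check on the dictionary before training on it. The following is a sketch under my own assumptions: summarize_dataset is a hypothetical helper, and the demo uses a tiny synthetic stand-in rather than the real notMNIST arrays.

```python
import numpy as np

def summarize_dataset(notmnist, image_size=28, n_classes=10):
    """Return basic facts about a {'data', 'target'} dict for a quick sanity check."""
    data, target = notmnist['data'], notmnist['target']
    # Each row must be one flattened image, with one label per row.
    assert data.shape == (len(target), image_size * image_size)
    return {
        'n_samples': len(target),
        'min': float(data.min()),
        'max': float(data.max()),
        # Number of images per label 0..n_classes-1
        'per_class': np.bincount(target, minlength=n_classes).tolist(),
    }

# Tiny synthetic stand-in for the real data (3 images, labels A, A, C).
demo = {'data': np.random.rand(3, 28 * 28).astype(np.float32),
        'target': np.array([0, 0, 2])}
summary = summarize_dataset(demo)
print(summary)
```

On the real data, a minimum below 0 or a maximum above 1 would suggest that the scaling step went wrong, and a very uneven per_class list would suggest that a folder was only partly read.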

3. How to use

To use the data, unpickle the file to restore the object. Then, as needed, convert the types and split the data into training data and validation data.

from sklearn.model_selection import train_test_split  # sklearn.cross_validation was removed in scikit-learn 0.20

notmnist = unpickle('notMNIST_large.pkl')   # Assumes notMNIST_large.pkl is in the same folder.
notmnist_data = notmnist['data']
notmnist_target = notmnist['target']


notmnist_data   = notmnist_data.astype(np.float32)
notmnist_target   = notmnist_target.astype(np.int32)
# The pixel values were already scaled to 0-1 when the pickle was created,
# so they are not divided by 255 again here.

# Use 75% of the data for training and the rest for validation

x_train, x_test, y_train, y_test = train_test_split(notmnist_data, notmnist_target)
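
With default arguments, train_test_split shuffles the rows and holds out 25% of them. To make that concrete, here is a rough NumPy-only sketch of the same idea. It is not scikit-learn's actual implementation, and shuffle_split, its parameters, and the demo arrays are hypothetical names of mine.

```python
import numpy as np

def shuffle_split(data, target, test_fraction=0.25, seed=0):
    """Shuffle the rows and split them into training and validation parts."""
    rng = np.random.RandomState(seed)
    order = rng.permutation(len(data))          # random row order
    n_test = int(np.ceil(len(data) * test_fraction))
    test_idx, train_idx = order[:n_test], order[n_test:]
    return data[train_idx], data[test_idx], target[train_idx], target[test_idx]

# Small demo arrays: row i of demo_x is [2*i, 2*i + 1] and its label is i.
demo_x = np.arange(40, dtype=np.float32).reshape(20, 2)
demo_y = np.arange(20)
x_tr, x_te, y_tr, y_te = shuffle_split(demo_x, demo_y)
print(x_tr.shape, x_te.shape)
```

Fixing the seed makes the split reproducible, which helps when comparing classifiers on the same validation data.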

If you want to visualize what the read image looks like, try the display process with the following function.

def draw_digit(digits):
    size = 28
    plt.figure(figsize=(len(digits)*1.5, 2))
    for i, data in enumerate(digits):
        plt.subplot(1, len(digits), i+1)
        X, Y = np.meshgrid(range(size),range(size))
        Z = data[0].reshape(size,size)   # convert from vector to 28x28 matrix
        Z = Z[::-1,:]             # flip vertical
        plt.xlim(0,27)
        plt.ylim(0,27)
        plt.pcolor(X, Y, Z)
        plt.gray()
        plt.title(num2alpha[data[1]])
        plt.tick_params(labelbottom=False)   # hide the tick labels on the x axis
        plt.tick_params(labelleft=False)     # hide the tick labels on the y axis

    plt.show()

The images are displayed in 10 rows and 10 columns.

[draw_digit([[notmnist_data[idx], notmnist_target[idx]] for idx in rd.randint(len(notmnist_data), size=10)]) for i in range(10)]
notMNIST.png

4. Try to classify

Having gone to the trouble of loading the data, I will try classifying it with a Random Forest. (The number of weak learners is set to 100, so training takes some time.)

from sklearn.ensemble import RandomForestClassifier

clf = RandomForestClassifier(n_estimators=100)
clf = clf.fit(x_train, y_train)

For now, let's look at the accuracy on the training data (the complement of the resubstitution error rate).

# Accuracy on the training data (resubstitution)
pred = clf.predict(x_train)
result = [y==p for y, p in zip(y_train,pred)]
np.sum(result)/len(pred)

out


 0.99722555413319358

# Check generalization performance
pred = clf.predict(x_test)
result = [y==p for y, p in zip(y_test,pred)]
np.sum(result)/len(pred)

The generalization performance is also good, at about 91%.

out


 0.91262407487205077
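
The accuracy above is computed with a list comprehension. As an alternative sketch, the same number can be obtained with a vectorized NumPy comparison, and a confusion matrix shows which letters get mixed up. The function names and demo labels below are hypothetical. scikit-learn also provides accuracy_score and confusion_matrix in sklearn.metrics, which you may prefer in practice.

```python
import numpy as np

def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    return float(np.mean(np.asarray(y_true) == np.asarray(y_pred)))

def confusion_matrix(y_true, y_pred, n_classes=10):
    """Count pairs: rows are true labels, columns are predicted labels."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    # np.add.at accumulates correctly even when an index pair repeats.
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm

# Hypothetical labels, just to show the shape of the result.
demo_true = [0, 0, 1, 2, 2, 2]
demo_pred = [0, 1, 1, 2, 2, 0]
acc = accuracy(demo_true, demo_pred)
cm = confusion_matrix(demo_true, demo_pred, n_classes=3)
print(acc)
print(cm)
```

With the classifier from this article, you would pass y_test and pred instead of the demo lists.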

Let's visualize the result of prediction.

#Visualize results
rd.seed(123)
[draw_digit([[x_test[idx], y_test[idx], pred[idx]] for idx in rd.randint(len(x_test), size=10)]) for i in range(10)]

It makes mistakes on images that a person could easily get wrong too, but the images recognizable as letters are almost all classified correctly :smile:

notMNISTresult.png
