[Python] Randomly sample MNIST data to create a dataset

Overview

I needed to train a model on only part of MNIST rather than the whole dataset. So I wrote a program that randomly extracts n images from the 60,000 MNIST training images, sorts them into one folder per class, and saves them.

Execution environment

Google Colaboratory, PyTorch 1.6.0

Implementation

Save MNIST in image format

Download the MNIST dataset and save the training set as image files, so that images can be randomly extracted from it. For this step I referred to the article "Try using Image Folder with PyTorch".

First, import the required modules.

import os
from PIL import Image
from torchvision.datasets import MNIST
import shutil
import glob
from pprint import pprint
import random
from pathlib import Path
from tqdm import tqdm

If you do not have the required module, install it with pip or conda as appropriate.

Then download MNIST.

mnist_data = MNIST(root='./', train=True, transform=None, download=True)

You may see a UserWarning while downloading MNIST. You can ignore it here, because the downloaded data is only used to export image files, not for training.

Next, export the MNIST images from the downloaded binary files as PNG files.

def makeMnistPng(image_dsets):
    dir_path = './mnist_all/'
    os.makedirs(dir_path, exist_ok=True)
    # Running count per class, so files are named mnist_<label>-<n>.png
    num_img = [0] * 10
    # A single pass over the dataset; each item is a (PIL image, int label) pair
    for image, label in tqdm(image_dsets):
        filename = os.path.join(dir_path, 'mnist_{}-{}.png'.format(label, num_img[label]))
        # Skip files that already exist so the function can be re-run safely
        if not os.path.exists(filename):
            image.save(filename)
        num_img[label] += 1
    print('Successfully created {} MNIST PNG files in {}'.format(sum(num_img), dir_path))

Execute the function.

makeMnistPng(mnist_data)

This saves all 60,000 MNIST training images under mnist_all. If you would rather save the images in a separate folder for each class, use the following version instead.

def makeMnistPng(image_dsets):
    # Per-class variant: saves into ./MNIST_PNG/<label>/mnist_<label>_<n>.png
    root = './MNIST_PNG'
    for idx in range(10):
        os.makedirs(os.path.join(root, str(idx)), exist_ok=True)
    # Running count per class for the file names
    num_img = [0] * 10
    for image, label in tqdm(image_dsets):
        filename = os.path.join(root, str(label), 'mnist_{}_{}.png'.format(label, num_img[label]))
        if not os.path.exists(filename):
            image.save(filename)
        num_img[label] += 1
    print('Successfully created {} MNIST PNG files in {}'.format(sum(num_img), root))
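If you used the per-class layout, you can also draw the same number of images from each class instead of sampling from all 60,000 at once, which guarantees that every digit is represented. The following is a minimal sketch of that alternative (stratified sampling); it is not part of the procedure in this post. It assumes the `./MNIST_PNG/<digit>/` folders created by the function above, and the function name `sample_per_class`, its `k` parameter, and the output folder are hypothetical.

```python
import random
import shutil
from pathlib import Path


def sample_per_class(src_root, dst_root, k, seed=0):
    """Copy k randomly chosen PNGs from each src_root/<class>/ into dst_root/<class>/.

    Hypothetical helper: assumes one sub-directory per class, as produced by the
    per-class version of makeMnistPng.
    """
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    src_root, dst_root = Path(src_root), Path(dst_root)
    copied = {}
    for class_dir in sorted(p for p in src_root.iterdir() if p.is_dir()):
        # Sort so a fixed seed picks the same files regardless of directory listing order
        files = sorted(class_dir.glob('*.png'))
        chosen = rng.sample(files, k)  # raises ValueError if the class has fewer than k files
        out_dir = dst_root / class_dir.name
        out_dir.mkdir(parents=True, exist_ok=True)
        for f in chosen:
            shutil.copy(f, out_dir)
        copied[class_dir.name] = len(chosen)
    return copied
```

For example, `sample_per_class('./MNIST_PNG', './mnist_sampled_balanced', k=10)` would copy 10 images per digit, 100 in total.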

Randomly sample from files in the directory

Now that all the MNIST images are in a single directory, randomly sample n of them and copy them to another directory. The code is taken, almost unchanged, from a reference article.
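The core of the sampling step fits in a few lines of standard-library Python. This sketch (the function name `sample_paths` and the generated file names are illustrative) uses a dedicated `random.Random(seed)` instance rather than seeding the global generator, and sorts the input first: `glob` does not guarantee any particular order, so sorting is what lets a fixed seed reproduce the same sample.

```python
import random


def sample_paths(paths, n, seed=0):
    """Return n distinct paths chosen uniformly at random, reproducibly for a given seed."""
    rng = random.Random(seed)            # independent generator; global random state untouched
    return rng.sample(sorted(paths), n)  # sample without replacement from a fixed order


# Illustrative file names in the same style as the exported MNIST PNGs
paths = ['mnist_{}-{}.png'.format(label, i) for label in range(10) for i in range(6000)]
first = sample_paths(paths, 100)
second = sample_paths(list(reversed(paths)), 100)  # same files, different listing order
print(len(first), len(set(first)), first == second)  # → 100 100 True
```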

Class definition


class FileControler(object):
    def get_file_path(self, input_dir, pattern):
        # Create a Path object for the directory
        path_obj = Path(input_dir)
        # Collect the files that match the glob pattern
        files_path = path_obj.glob(pattern)
        # Convert to POSIX-style strings; sort because glob order is not guaranteed
        files_path_posix = sorted(file_path.as_posix() for file_path in files_path)
        return files_path_posix

    def random_sampling(self, files_path, sample_num, output_dir, fix_seed=True) -> None:
        # Fix the seed so the same files are sampled every time
        if fix_seed:
            random.seed(0)
        # Sample sample_num paths without replacement
        files_path_sampled = random.sample(files_path, sample_num)
        # Create the output directory if it does not exist
        os.makedirs(output_dir, exist_ok=True)
        # Copy the sampled files
        for file_path in files_path_sampled:
            shutil.copy(file_path, output_dir)

Instance creation

file_controler = FileControler()

Directory settings

Set the sampling source directory and the directory to copy the sampled files.

all_file_dir = './mnist_all/'
sampled_dir = './mnist_sampled/'

Get the paths of all files

pattern = '*.png'
files_path = file_controler.get_file_path(all_file_dir, pattern)

print(len(files_path))
# 60000

Sample n files

sample_num = 100
file_controler.random_sampling(files_path, sample_num, sampled_dir)

sampled_files_path = file_controler.get_file_path(sampled_dir, pattern)
print(len(sampled_files_path))
# 100

This randomly samples n images (100 here) from the 60,000 MNIST images.
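Because the 100 images are drawn from all classes at once, the number of images per digit is not fixed, and some classes may end up with noticeably more images than others. A quick check is to count the labels encoded in the file names. This sketch assumes the `mnist_<label>-<n>.png` naming from the export step; `count_classes` is a hypothetical helper.

```python
import os
from collections import Counter


def count_classes(file_paths):
    """Count sampled files per class, reading the label from names like 'mnist_3-12.png'."""
    counts = Counter()
    for path in file_paths:
        name = os.path.basename(path)             # e.g. 'mnist_3-12.png'
        label = name.split('_')[1].split('-')[0]  # text between '_' and '-', e.g. '3'
        counts[int(label)] += 1
    return counts


example = ['./mnist_sampled/mnist_3-12.png', './mnist_sampled/mnist_3-40.png',
           './mnist_sampled/mnist_7-5.png']
print(count_classes(example))  # Counter({3: 2, 7: 1})
```

On the real data, `count_classes(sampled_files_path)` would show how the 100 samples are spread over the ten digits.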

Classification

Next, sort the sampled images into one folder per class so that they can be used as a machine learning dataset.

First, get all the file names in the sampled directory as a list.

files = glob.glob("./mnist_sampled/*")

For each class i, use the in operator to check whether the file name contains the substring `_i` (for example, `_3` in `mnist_3-12.png`), and move the matching files into that class's folder.

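for i in range(10):
    class_dir = os.path.join(sampled_dir, str(i))
    os.makedirs(class_dir, exist_ok=True)
    for x in files:
        # Check only the file name (e.g. 'mnist_3-12.png'), not the whole path
        if '_' + str(i) in os.path.basename(x):
            shutil.move(x, class_dir)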
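The loop above scans the whole file list once per class. An alternative, shown here as a sketch rather than the method used in this post, parses the label from each file name once with a regular expression and moves the file directly. The pattern accepts both the `mnist_3-12.png` and `mnist_3_12.png` names produced by the two export functions; the function name `split_into_class_dirs` is hypothetical, and unlike the loop above it only creates folders for classes that actually appear.

```python
import os
import re
import shutil

# 'mnist_', one digit (the label), then '-' or '_' as used by the two export functions
LABEL_RE = re.compile(r'^mnist_(\d)[-_]\d+\.png$')


def split_into_class_dirs(sampled_dir):
    """Move each mnist_<label>-<n>.png in sampled_dir into sampled_dir/<label>/."""
    moved = 0
    # listdir is evaluated once, so folders created below are not revisited
    for name in sorted(os.listdir(sampled_dir)):
        match = LABEL_RE.match(name)
        if match is None:
            continue  # skip class folders and any unrelated files
        class_dir = os.path.join(sampled_dir, match.group(1))
        os.makedirs(class_dir, exist_ok=True)
        shutil.move(os.path.join(sampled_dir, name), class_dir)
        moved += 1
    return moved
```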

The sampled directory now has the following structure:

./mnist_sampled
├── 0
├── 1
├── 2
├── 3
├── 4
├── 5
├── 6
├── 7
├── 8
└── 9
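This one-folder-per-class layout is what torchvision's `ImageFolder` (the subject of the article referred to at the start) expects: each sub-directory name becomes a class, and classes are indexed in sorted name order, so with folders 0 to 9 the index equals the digit. The standard-library sketch below mimics that mapping so you can inspect it without loading any images; `list_samples` is a hypothetical name, and in practice you would pass the folder to `torchvision.datasets.ImageFolder` directly.

```python
import os


def list_samples(root, extensions=('.png',)):
    """Return (class_to_idx, [(path, class_index), ...]) for a root/<class>/<file> layout."""
    # Class names are the sub-directory names, indexed in sorted order
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            # Keep only image files with an accepted extension
            if fname.lower().endswith(extensions):
                samples.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return class_to_idx, samples
```

For example, `class_to_idx, samples = list_samples('./mnist_sampled')` lists every sampled image together with its class index.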

You have now randomly sampled MNIST images and sorted them by class, giving a dataset ready for training.
