[PYTHON] How to create data to feed into a CNN (Chainer)

I tried building my own model while learning at a hackathon, but it took me a while to figure out how to prepare the data to feed into a CNN, so I am leaving this memo.

import glob
import os
import random

import numpy as np
from PIL import Image

# Needed by main() below
import chainer
import chainer.links as L
from chainer import training
from chainer.training import extensions


def load_image():
    filepaths = glob.glob('data/*.png')

    datasets = []
    for filepath in filepaths:
        img = Image.open(filepath).convert('L')  # Load with Pillow; 'L' means 8-bit grayscale
        img = img.resize((32, 32))  # Resize to 32x32 pixels
        # The label (an integer >= 0) is the part of the file name before the first '_'.
        # (In my case, I usually put the label at the beginning of the file name.)
        label = int(os.path.basename(filepath).split('_')[0])

        x = np.array(img, dtype=np.float32)
        x = x.reshape(1, 32, 32)  # (channel, height, width)
        t = np.array(label, dtype=np.int32)

        datasets.append((x, t))  # Store each sample as an (x, t) tuple

    random.shuffle(datasets)  # Shuffle before splitting
    train = datasets[:1000]  # The first 1000 samples are for training
    test = datasets[1000:1100]  # The next 100 samples (indices 1000 to 1099) are for testing
    return train, test
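
The label parsing and the train/test split above can be tried out without any image files. The following is only a minimal sketch using the standard library: the helper names `label_from_path` and `split_dataset`, the file names, and the sample counts are all hypothetical and not part of the original code. Unlike the fixed slices `[:1000]` and `[1000:1100]`, which silently return shorter lists when fewer than 1100 files exist, this sketch raises an error in that case.

```python
import os
import random


def label_from_path(filepath):
    """Return the integer label written before the first '_' in the file name."""
    filename = os.path.basename(filepath)
    return int(filename.split('_')[0])


def split_dataset(samples, n_train, n_test, seed=0):
    """Shuffle a copy of samples and cut it into train and test parts."""
    if n_train + n_test > len(samples):
        raise ValueError('not enough samples: %d < %d'
                         % (len(samples), n_train + n_test))
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)  # seeded so the split is repeatable
    return shuffled[:n_train], shuffled[n_train:n_train + n_test]


# Hypothetical file names that follow the "<label>_<anything>.png" convention.
paths = ['data/%d_%03d.png' % (i % 10, i) for i in range(50)]
samples = [(path, label_from_path(path)) for path in paths]
train, test = split_dataset(samples, n_train=40, n_test=10)
print(len(train), len(test))
```

In the real loader the first element of each tuple would be the image array `x` rather than the path; the splitting logic stays the same.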


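def main():
    # The rest follows the cifar10 example bundled with Chainer. `args`, `models`
    # and `TestModeEvaluator` are assumed to be defined as in that example script.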

    class_labels = 10
    train, test = load_image()
    
    model = L.Classifier(models.VGG.VGG(class_labels))
    if args.gpu >= 0:
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()  # Copy the model to the GPU

    optimizer = chainer.optimizers.MomentumSGD(args.learnrate)
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(5e-4))

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)
    # Set up a trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(TestModeEvaluator(test_iter, model, device=args.gpu))

    # Reduce the learning rate by half every 25 epochs.
    trainer.extend(extensions.ExponentialShift('lr', 0.5),
                   trigger=(25, 'epoch'))

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Take a snapshot once args.epoch epochs have passed (i.e. at the end of training)
    trainer.extend(extensions.snapshot(), trigger=(args.epoch, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()
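
`main()` reads `args.gpu`, `args.learnrate`, `args.batchsize`, `args.epoch`, `args.out` and `args.resume`, but the memo does not show where `args` comes from. A minimal sketch with `argparse` could look like the following. The function name `parse_args`, the short flags, and every default value are assumptions chosen for illustration; only the attribute names are taken from the code above.

```python
import argparse


def parse_args(argv=None):
    """Build the options that main() expects to find on `args`."""
    parser = argparse.ArgumentParser(description='Train a CNN on my own images')
    parser.add_argument('--batchsize', '-b', type=int, default=64,
                        help='Number of images in each mini-batch')
    parser.add_argument('--learnrate', '-l', type=float, default=0.05,
                        help='Initial learning rate for MomentumSGD')
    parser.add_argument('--epoch', '-e', type=int, default=300,
                        help='Number of sweeps over the training data')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (a negative value means CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to write results to')
    parser.add_argument('--resume', '-r', default='',
                        help='Snapshot file to resume training from')
    return parser.parse_args(argv)


# An empty list is passed here so the sketch ignores the real command line;
# calling parse_args() with no argument reads sys.argv instead.
args = parse_args([])
print(args.gpu, args.batchsize, args.out)
```

With `args` defined at module level like this, `main()` can be called as written. The empty default for `--resume` keeps the `if args.resume:` branch from running unless a snapshot path is given.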
