I made a data augmentation class for TensorFlow >= 2.0 because ImageDataGenerator can no longer be used

Background

In TensorFlow, Keras' ImageDataGenerator was commonly used to augment image data. It was popular because it lets you train a TensorFlow model while applying various augmentations to the input images across multiple processes.

However, multiprocessing has been deprecated since TensorFlow 2.0, and when you train while augmenting data in multiple processes, the progress bar can suddenly stop without any error being thrown. What makes this particularly painful is that, because there is no error, if you are using a service that bills by the hour you keep paying even though training has stalled.

I was using a generator that read training data from an hdf5 file split across multiple files, running it in multiple processes. Even with TensorFlow < 2.0, training with multiprocessing would sometimes stop after about two days; with TensorFlow >= 2.0 it stalled far more often, sometimes after only about two hours.

Therefore...

So, like ImageDataGenerator, I created a class that lets you train with TensorFlow while easily applying various augmentations to images in multiple processes.

Approach

The recommended data input method for TensorFlow is tensorflow.data.Dataset (hereinafter tf.data.Dataset). With it you can build fast, multi-process data input pipelines, as mentioned here, for example. However, there is not much information about tf.data on Stack Overflow and the like, and although there are posts that try out individual augmentations, I could not find a way to train while easily applying various augmentations to input images the way ImageDataGenerator does.

You can use tf.data to build fast data input pipelines, but the official documentation shows that there are some pitfalls.

1. tf.data.Dataset.from_generator does not augment data in multiple processes

tf.data.Dataset.from_generator lets you wrap a Python generator and train on it as tf.data with the fit() function. My first thought was simply to wrap ImageDataGenerator with this function. However, the Note for from_generator in the official documentation says the following.

Note: The current implementation of Dataset.from_generator() uses tf.numpy_function and inherits the same constraints. In particular, it requires the Dataset- and Iterator-related operations to be placed on a device in the same process as the Python program that called Dataset.from_generator(). The body of generator will not be serialized in a GraphDef, and you should not use this method if you need to serialize your model and restore it in a different environment.

Because it goes through tf.numpy_function, it does not support multi-process execution, so I gave up on this approach.
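To illustrate the alternative, here is a minimal sketch (not tfaug's actual code): instead of wrapping a Python generator with from_generator, which stays tied to a single process, you can build the dataset from tensors and do the per-element work inside map(), which tf.data can parallelize.

```python
import tensorflow as tf

# Build the dataset from in-memory tensors instead of a Python generator.
images = tf.zeros((10, 8, 8, 3))

# The per-element transform runs inside map(), so tf.data can
# parallelize it with num_parallel_calls.
ds = (tf.data.Dataset.from_tensor_slices(images)
      .map(tf.image.flip_left_right,
           num_parallel_calls=tf.data.experimental.AUTOTUNE))

# All 10 elements pass through the parallel map.
count = sum(1 for _ in ds)
```
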

2. Implement with tf ops only, as much as possible

As described in the official documentation for tf.function, if you wrap code in the @tf.function decorator for performance, all of it is automatically converted to TensorFlow graph code. If you call an external library or numpy there, it ends up being wrapped in tf.numpy_function, tf.py_function, etc., and you run into the same restrictions as in 1. Therefore I tried to use the tf.Tensor type for processing and data wherever possible, and where I couldn't, I used only Python standard types.
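A minimal sketch of this policy (my own illustration, not tfaug's internals): keep the function body in pure tf ops so that @tf.function can trace it into a graph, rather than calling numpy and being forced through tf.numpy_function / tf.py_function.

```python
import tensorflow as tf

@tf.function
def rotate90_and_scale(image):
    # tf.image.rot90 and plain tensor arithmetic are both
    # graph-compatible, so no numpy fallback is needed.
    return tf.image.rot90(image) * 0.5

out = rotate90_and_scale(tf.ones((8, 8, 3)))
```
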

3. Augment the label image at the same time

If the input image is transformed, for example by rotation, the label image it is paired with must be transformed in exactly the same way. So I made the class apply exactly the same transformations to the (optional) label image as to the input image.
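One way to keep the pair aligned, as a minimal sketch (not tfaug's internals): draw the random decision once and apply the identical transform to both the input and its label.

```python
import tensorflow as tf

def flip_pair(image, label):
    # Draw the random decision once...
    do_flip = tf.random.uniform([]) > 0.5
    # ...then apply the identical flip to both tensors.
    image = tf.cond(do_flip,
                    lambda: tf.image.flip_left_right(image),
                    lambda: image)
    label = tf.cond(do_flip,
                    lambda: tf.image.flip_left_right(label),
                    lambda: label)
    return image, label

img = tf.reshape(tf.range(16, dtype=tf.float32), (4, 4, 1))
img2, lbl2 = flip_pair(img, tf.identity(img))
# Since the label started identical to the image and received the same
# transform, the two outputs still match exactly.
```
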

Installation

python -m pip install git+https://github.com/piyop/tfaug

How to use

1. Initialization

from tfaug import augment_img 
#set your augment parameters below:
arg_fun = augment_img(rotation=0, 
                      standardize=False,
                      random_flip_left_right=True,
                      random_flip_up_down=True, 
                      random_shift=(.1,.1), 
                      random_zoom=.1,
                      random_brightness=.2,
                      random_saturation=None,
                      training=True) 
                      
"""
augment_img.__init__() sets up the parameters for augmentation.

Parameters
----------
rotation : float, optional
    rotation angle(degree). The default is 0.
standardize : bool, optional
    image standardization. The default is True.
random_flip_left_right : bool, optional
    The default is False.
random_flip_up_down : bool, optional
    The default is False.
random_shift : Tuple[float, float], optional
    randomly shift images.
    vertical direction: (-list[0], list[0])
    horizontal direction: (-list[1], list[1])
    Each value is a ratio of the image size.
    The default is None.
random_zoom : float, optional
    random zoom range -random_zoom to random_zoom.
    value of random_zoom is ratio of image size
    The default is None.
random_brightness : float, optional
    randomly adjust image brightness in the range
    [-max_delta, max_delta).
    The default is None.
random_saturation : Tuple[float, float], optional
    randomly adjust image saturation in the range [lower, upper].
    The default is None.
training : bool, optional
    If False, this class doesn't augment images except for standardization.
    The default is False.
Returns
-------
class instance : Callable[[tf.Tensor, tf.Tensor, bool], Tuple[tf.Tensor,tf.Tensor]]
"""                     

2. Use it in tf.data.Dataset.map()

ds=tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices(image),
                      tf.data.Dataset.from_tensor_slices(label))) \
                    .shuffle(BATCH_SIZE*10).batch(BATCH_SIZE)\
                    .map(arg_fun, num_parallel_calls=tf.data.experimental.AUTOTUNE)
model.fit(ds)

See the tests for a detailed usage example (https://github.com/piyop/tfaug).
