[PYTHON] How to use tf.data

This post summarizes what I learned about using tf.data to handle large files.

I referred to the following articles: https://qiita.com/Suguru_Toyohara/items/820b0dad955ecd91c7f3 https://qiita.com/wasnot/items/9b64550237a3c5267bfd https://qiita.com/everylittle/items/a7c31b08d2f76c886a92

What is tf.data

tf.data is TensorFlow's API for building input pipelines that feed data to a model. Using it seems to have the following advantages.

  1. It reduces the time the GPU spends waiting for data, which maximizes training speed
  2. Data that does not fit in memory can be read sequentially
  3. Preprocessing such as data augmentation can be accelerated
  4. The preprocessing steps can be combined into a single pipeline
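
As a rough illustration of points 1, 3 and 4, the sketch below chains parallel preprocessing and prefetching into one pipeline. The toy array and the `preprocess` function are hypothetical stand-ins, and `tf.data.AUTOTUNE` assumes a reasonably recent TensorFlow 2.x (older versions expose it as `tf.data.experimental.AUTOTUNE`).

```python
import numpy as np
import tensorflow as tf

# Toy data standing in for a real dataset (hypothetical values).
features = np.arange(12, dtype=np.float32).reshape(6, 2)

def preprocess(x):
    # Stand-in for real preprocessing or augmentation.
    return x / 10.0

pipeline = (
    tf.data.Dataset.from_tensor_slices(features)
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)  # preprocess elements in parallel
    .batch(2)
    .prefetch(tf.data.AUTOTUNE)  # prepare the next batch while the current one is consumed
)

batches = [b.numpy() for b in pipeline]
print(len(batches), batches[0].shape)
```

Whether this actually speeds up training depends on the model and hardware, so it is worth measuring in your own environment.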

1. Convert and use numpy.array

Convert

You can convert an np.array into a tf.data dataset object and then iterate over it.

python


import numpy as np
import tensorflow as tf

arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr)

for item in dataset:
    print(item)

output


tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)

repeat

The dataset is repeated as many times as specified by the argument.

python


arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).repeat(3)

for item in dataset:
    print(item)

output


tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)
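
When repeat() is called without an argument, the dataset repeats indefinitely, so the stream has to be cut off somewhere. A minimal sketch using take(), with a small toy array chosen only for illustration:

```python
import numpy as np
import tensorflow as tf

arr = np.arange(6).reshape(3, 2)

# repeat() with no argument never ends, so take(7) limits the stream to 7 rows.
endless = tf.data.Dataset.from_tensor_slices(arr).repeat()
first_rows = [row.numpy().tolist() for row in endless.take(7)]
print(first_rows)
```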

batch

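The elements are grouped into batches of the size given by the argument. If the data size is not divisible by the batch size, the last batch is smaller.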

python


arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).batch(2)

for item in dataset:
    print(item)

output


tf.Tensor(
[[0 1 2 3 4]
 [5 6 7 8 9]], shape=(2, 5), dtype=int32)

tf.Tensor(
[[10 11 12 13 14]
 [15 16 17 18 19]], shape=(2, 5), dtype=int32)

tf.Tensor([[20 21 22 23 24]], shape=(1, 5), dtype=int32)
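
If the smaller final batch is a problem (for example, when a model expects a fixed batch size), batch() accepts drop_remainder=True. A minimal sketch on the same 5x5 array:

```python
import numpy as np
import tensorflow as tf

arr = np.arange(25).reshape(5, 5)

# drop_remainder=True discards the last, incomplete batch.
dataset = tf.data.Dataset.from_tensor_slices(arr).batch(2, drop_remainder=True)
shapes = [tuple(batch.shape) for batch in dataset]
print(shapes)
```

Here the fifth row is simply dropped, so only two batches of shape (2, 5) remain.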

shuffle

The argument specifies the buffer size, that is, the range within which elements are swapped. With a value of 1 nothing is shuffled, and with a small value the data is not shuffled enough, so I think it is best to pass the same value as the dataset size.

See here for details on the shuffle buffer size: https://qiita.com/exy81/items/d1388f6f02a11c8f1d7e

python


arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).shuffle(5)

for item in dataset:
    print(item)

output


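tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)

(The order is random and changes on every run; the order shown here is only illustrative.)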
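
The effect of the buffer size can be checked on a simple range of integers. The sketch below assumes nothing beyond the standard shuffle() behavior: with a buffer of 1 the order is unchanged, and with a buffer of 3 each output can only come from the first few not-yet-emitted inputs.

```python
import tensorflow as tf

values = tf.data.Dataset.range(10)

# Buffer size 1: there is only ever one candidate, so the order is preserved.
unshuffled = [int(v.numpy()) for v in values.shuffle(1)]

# Buffer size 3: the element at output position i comes from input positions <= i + 2.
partly = [int(v.numpy()) for v in values.shuffle(3, seed=0)]

print(unshuffled)
print(partly)
```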

combination

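The operations above can be combined. They are executed in the order written, so be careful about the order: if you batch first and then shuffle, only the order of the batches changes and the elements inside each batch stay the same, which is rarely what you want.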

python


arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).repeat(2).shuffle(5).batch(4)

for item in dataset:
    print(item)
    print()

output


tf.Tensor(
[[15 16 17 18 19]
 [ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]], shape=(4, 5), dtype=int32)

tf.Tensor(
[[20 21 22 23 24]
 [ 0  1  2  3  4]
 [20 21 22 23 24]
 [10 11 12 13 14]], shape=(4, 5), dtype=int32)

tf.Tensor(
[[15 16 17 18 19]
 [ 5  6  7  8  9]], shape=(2, 5), dtype=int32)
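
The difference between the two orders can be seen on a small range of integers. This is a minimal sketch with toy values, not part of the original measurements:

```python
import tensorflow as tf

values = tf.data.Dataset.range(8)

# Batch first, then shuffle: only whole batches are reordered.
batch_then_shuffle = [b.numpy().tolist() for b in values.batch(2).shuffle(4)]

# Shuffle first, then batch: individual elements are mixed across batches.
shuffle_then_batch = [b.numpy().tolist() for b in values.shuffle(8).batch(2)]

print(batch_then_shuffle)
print(shuffle_then_batch)
```

In the first result every batch is still one of [0, 1], [2, 3], [4, 5], [6, 7], whatever order they appear in.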

augmentation

You can apply a function to every element with dataset.map().

It is best to build the mapped function from TensorFlow operations, but an ordinary Python function can apparently also be used by wrapping it with tf.py_function inside a function decorated with @tf.function.

python


import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage

def rotate(image):
    return ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)

@tf.function
def rotate_tf(image):
    rotated = tf.py_function(rotate,[image],[tf.int32])
    return rotated[0]

[train_x, train_y], [test_x, test_y] =  tf.keras.datasets.mnist.load_data()
train_x = train_x.reshape(-1,28,28,1)
dataset = tf.data.Dataset.from_tensor_slices(train_x)
dataset = dataset.map(rotate_tf).batch(16)

first_batch = next(iter(dataset))
images = first_batch.numpy().reshape((-1,28,28))

plt.figure(figsize=(4, 4))
for i, image in enumerate(images):
    plt.subplot(4, 4,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(image)
    plt.grid(False)
plt.show()
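
For comparison, here is a sketch of an augmentation written only with TensorFlow operations, so that no tf.py_function is needed. It uses random flips and 90-degree rotations rather than the arbitrary-angle rotation above, and random dummy images instead of MNIST; both choices are assumptions made to keep the example self-contained.

```python
import numpy as np
import tensorflow as tf

# Random dummy images standing in for MNIST (hypothetical data).
images = np.random.randint(0, 256, size=(8, 28, 28, 1), dtype=np.uint8)

def augment(image):
    # Randomly mirror the image, then rotate it by a random multiple of 90 degrees.
    image = tf.image.random_flip_left_right(image)
    k = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
    return tf.image.rot90(image, k=k)

dataset = tf.data.Dataset.from_tensor_slices(images).map(augment).batch(4)
first_batch = next(iter(dataset))
print(first_batch.shape, first_batch.dtype)
```

Because the function consists only of TensorFlow operations, map() can run it inside the graph without going back to the Python interpreter.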

Put x and y together into a dataset

You can also combine multiple arrays, such as the inputs x and the labels y, into one dataset by passing them as a tuple.

python


def make_model():
    tf.keras.backend.clear_session()

    inputs = tf.keras.layers.Input(shape=(28, 28))
    network = tf.keras.layers.Flatten()(inputs)
    network = tf.keras.layers.Dense(100, activation='relu')(network)
    outputs = tf.keras.layers.Dense(10, activation='softmax')(network)

    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy', 
                  metrics=['accuracy'])
    model.summary()
    return model

[x_train, y_train], [x_test, y_test] = tf.keras.datasets.mnist.load_data()

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(x_train.shape[0]).batch(64)
test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test)).shuffle(x_test.shape[0]).batch(64)

model = make_model()
hist = model.fit(train_data, validation_data=test_data,
                 epochs=10, verbose=False)

plt.figure(figsize=(4,4))
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.show()
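
To see what such a dataset yields, it helps to look at element_spec and at a single element. A minimal sketch with small toy arrays (the values are made up for illustration):

```python
import numpy as np
import tensorflow as tf

x = np.arange(12, dtype=np.float32).reshape(4, 3)  # 4 samples with 3 features each
y = np.array([0, 1, 0, 1], dtype=np.int64)         # one label per sample

pairs = tf.data.Dataset.from_tensor_slices((x, y))
print(pairs.element_spec)  # a tuple of two TensorSpec objects

# Each element is an (x, y) tuple, which is the form model.fit() expects.
first_features, first_label = next(iter(pairs))
print(first_features.numpy(), first_label.numpy())
```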

2. Serialize and use

To read data efficiently, it is recommended to serialize it and save it as a set of files of about 100-200 MB each that can be read sequentially. TFRecord makes this easy.

Save

Write the TFRecord file with tf.io.TFRecordWriter().

python


[x_train, y_train], [x_test, y_test] = tf.keras.datasets.mnist.load_data()

def make_example(image, label):
    return tf.train.Example(features=tf.train.Features(feature={
        'x' : tf.train.Feature(float_list=tf.train.FloatList(value=image)),
        'y' : tf.train.Feature(int64_list=tf.train.Int64List(value=label))
    }))

def write_tfrecord(images, labels, filename):
    writer = tf.io.TFRecordWriter(filename)
    for image, label in zip(images, labels):
        ex = make_example(image.ravel().tolist(), [int(label)])
        writer.write(ex.SerializeToString())
    writer.close()

write_tfrecord(x_train, y_train, '../mnist_train.tfrecord')
write_tfrecord(x_test, y_test, '../mnist_test.tfrecord')

The file size seems to be slightly larger than npz.
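
The example above writes everything into one file. To get the "set of files" mentioned earlier, the records can be spread over several shards. The sketch below is one possible way to do that; the helper `write_shards`, the file naming, and the round-robin distribution are all assumptions for illustration, and toy integers are used instead of MNIST.

```python
import os
import tempfile
import tensorflow as tf

def write_shards(values, directory, num_shards):
    """Write integer values round-robin into num_shards TFRecord files (hypothetical helper)."""
    paths = [os.path.join(directory, f"toy-{i:03d}.tfrecord") for i in range(num_shards)]
    writers = [tf.io.TFRecordWriter(p) for p in paths]
    for i, v in enumerate(values):
        ex = tf.train.Example(features=tf.train.Features(feature={
            'v': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(v)]))
        }))
        writers[i % num_shards].write(ex.SerializeToString())
    for w in writers:
        w.close()
    return paths

shard_dir = tempfile.mkdtemp()
shard_paths = write_shards(range(10), shard_dir, num_shards=3)

# TFRecordDataset accepts a list of files; num_parallel_reads reads several at once.
dataset = tf.data.TFRecordDataset(shard_paths, num_parallel_reads=len(shard_paths))
num_records = sum(1 for _ in dataset)
print(num_records)
```

With num_parallel_reads set, records from different shards are interleaved, so the read order is not the same as the write order.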

Read (one record at a time)

Read the file with tf.data.TFRecordDataset().

The records that are read are still serialized and must be parsed. The example below uses tf.io.parse_single_example() for this. If you parse with the same keys that were used when writing and restore the original shape, you get the same elements as in the tf.data dataset before serialization.

python


def parse_features(example):
    features = tf.io.parse_single_example(example, features={
        'x' : tf.io.FixedLenFeature([28, 28], tf.float32),
        'y' : tf.io.FixedLenFeature([1], tf.int64),
    })
    x = features['x']
    y = features['y']
    return x, y

train_dataset = tf.data.TFRecordDataset(filenames='../mnist_train.tfrecord')
train_dataset = train_dataset.map(parse_features).shuffle(60000).batch(512)

test_dataset = tf.data.TFRecordDataset(filenames='../mnist_test.tfrecord')
test_dataset = test_dataset.map(parse_features).shuffle(12000).batch(512)

model = make_model()
hist = model.fit(train_dataset, validation_data=test_dataset,
                 epochs=10, verbose=False)

plt.figure(figsize=(4, 4))
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.show()
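
The write and parse steps can be checked together with a tiny round trip. This is a minimal sketch using a toy 2x2 "image" and a temporary file, both chosen only for illustration:

```python
import os
import tempfile
import numpy as np
import tensorflow as tf

image = np.arange(4, dtype=np.float32).reshape(2, 2)  # toy 2x2 "image"
label = 7

path = os.path.join(tempfile.mkdtemp(), 'toy.tfrecord')
with tf.io.TFRecordWriter(path) as writer:
    ex = tf.train.Example(features=tf.train.Features(feature={
        'x': tf.train.Feature(float_list=tf.train.FloatList(value=image.ravel().tolist())),
        'y': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }))
    writer.write(ex.SerializeToString())

def parse_toy(serialized):
    # The keys and shapes must match what was written; [2, 2] restores the flattened image.
    features = tf.io.parse_single_example(serialized, features={
        'x': tf.io.FixedLenFeature([2, 2], tf.float32),
        'y': tf.io.FixedLenFeature([1], tf.int64),
    })
    return features['x'], features['y']

x, y = next(iter(tf.data.TFRecordDataset(path).map(parse_toy)))
print(x.numpy(), y.numpy())
```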

Read (batch unit)

In fact, parsing a whole batch at once is faster than parsing one record at a time with tf.io.parse_single_example(), so parsing in batch units is recommended.

python


def dict2tuple(feat):
    return feat["x"], feat["y"]

train_dataset = tf.data.TFRecordDataset(filenames='../mnist_train.tfrecord').batch(512).apply(
                    tf.data.experimental.parse_example_dataset({
                    "x": tf.io.FixedLenFeature([28, 28], dtype=tf.float32),
                    "y": tf.io.FixedLenFeature([1], dtype=tf.int64)})).map(dict2tuple)

test_dataset = tf.data.TFRecordDataset(filenames='../mnist_test.tfrecord')
test_dataset = test_dataset.batch(512).apply(
                    tf.data.experimental.parse_example_dataset({
                    "x": tf.io.FixedLenFeature([28, 28], dtype=tf.float32),
                    "y": tf.io.FixedLenFeature([1], dtype=tf.int64)})).map(dict2tuple)

model = make_model()
hist = model.fit(train_dataset, validation_data=test_dataset,
                 epochs=10, verbose=False)

plt.figure(figsize=(4, 4))
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.show()
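
As an alternative to the experimental API, the same batch-wise parsing can probably be written with tf.io.parse_example() inside map(). The sketch below uses toy records with three floats each and a temporary file; it is not the code that was measured in this post.

```python
import os
import tempfile
import tensorflow as tf

path = os.path.join(tempfile.mkdtemp(), 'toy_batch.tfrecord')
with tf.io.TFRecordWriter(path) as writer:
    for i in range(5):
        ex = tf.train.Example(features=tf.train.Features(feature={
            'x': tf.train.Feature(float_list=tf.train.FloatList(value=[float(i)] * 3)),
            'y': tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
        }))
        writer.write(ex.SerializeToString())

feature_spec = {
    'x': tf.io.FixedLenFeature([3], tf.float32),
    'y': tf.io.FixedLenFeature([1], tf.int64),
}

# Batch the serialized strings first, then parse each batch in a single call.
dataset = tf.data.TFRecordDataset(path).batch(2).map(
    lambda serialized: tf.io.parse_example(serialized, feature_spec))

shapes = [(tuple(b['x'].shape), tuple(b['y'].shape)) for b in dataset]
print(shapes)
```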

processing time

I measured the processing time when training the same model on the MNIST data. As expected, parsing one record at a time is quite slow. Parsing in batch units gives about the same speed as in-memory tf.data, so it is quite fast.

Also, in this measurement passing numpy.array directly was the fastest, but in my experience tf.data is clearly faster in practice, so I do not think this means numpy.array is always faster when the data fits in memory. I would appreciate it if you tried various settings in your own environment.
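
For trying this in your own environment, a simple way to compare pipelines is to time one full pass over each dataset. The sketch below is only a starting point: the helper `time_one_pass`, the random toy data, and the two candidate pipelines are assumptions, and it measures iteration only, not training.

```python
import time
import numpy as np
import tensorflow as tf

def time_one_pass(dataset):
    """Return (seconds, number_of_batches) for one full iteration (hypothetical helper)."""
    start = time.perf_counter()
    count = 0
    for _ in dataset:
        count += 1
    return time.perf_counter() - start, count

# Random data with the same per-image shape as MNIST (hypothetical values).
data = np.random.rand(1024, 28, 28).astype(np.float32)

candidates = {
    'batch only': tf.data.Dataset.from_tensor_slices(data).batch(64),
    'batch + prefetch': tf.data.Dataset.from_tensor_slices(data).batch(64).prefetch(1),
}

timings = {name: time_one_pass(ds) for name, ds in candidates.items()}
for name, (seconds, count) in timings.items():
    print(f'{name}: {seconds:.4f} s for {count} batches')
```

Timings on such a small dataset are noisy, so it is better to repeat the measurement several times and to use the real data and model before drawing conclusions.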
