[PYTHON] How to use tf.data

This article summarizes what I learned about using tf.data to handle large files.

I referred to the following sites: https://qiita.com/Suguru_Toyohara/items/820b0dad955ecd91c7f3 https://qiita.com/wasnot/items/9b64550237a3c5267bfd https://qiita.com/everylittle/items/a7c31b08d2f76c886a92

What is tf.data

tf.data is TensorFlow's library for supplying data to a model. Using it seems to have the following merits.

  1. It can reduce the time the GPU spends waiting for data and so maximize training speed
  2. Data that does not fit in memory can be read sequentially
  3. Preprocessing such as data augmentation can be accelerated
  4. Preprocessing steps can be put together into a single pipeline
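
As a rough illustration of points 1, 3 and 4, a pipeline can run preprocessing in parallel and prepare the next batch while the current one is being consumed. The following is only a minimal sketch, assuming TensorFlow 2.4 or later (for tf.data.AUTOTUNE); the function name scale and the toy array are made up for illustration.

```python
import numpy as np
import tensorflow as tf

# Toy data standing in for a real dataset (hypothetical values).
data = np.arange(12, dtype=np.float32).reshape(6, 2)

def scale(x):
    # Example preprocessing step: divide every value by 10.
    return x / 10.0

pipeline = (
    tf.data.Dataset.from_tensor_slices(data)
    .map(scale, num_parallel_calls=tf.data.AUTOTUNE)  # preprocess in parallel
    .batch(4)
    .prefetch(tf.data.AUTOTUNE)  # prepare the next batch in the background
)

batches = [b.numpy() for b in pipeline]
for b in batches:
    print(b.shape)
```

Whether prefetching actually helps depends on how expensive the preprocessing is compared with one training step, so it is worth measuring in your own environment.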

1. Convert and use numpy.array


You can use a np.array by converting it to a tf.data.Dataset object with from_tensor_slices. Iterating over the dataset yields one row at a time.


import numpy as np
import tensorflow as tf

arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr)

for item in dataset:
    print(item)

tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)


repeat() repeats the dataset the number of times given by the argument.


arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).repeat(3)

for item in dataset:
    print(item)

tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)


batch() groups the elements into batches of the size given by the argument.


arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).batch(2)

for item in dataset:
    print(item)

tf.Tensor(
[[0 1 2 3 4]
 [5 6 7 8 9]], shape=(2, 5), dtype=int32)
tf.Tensor(
[[10 11 12 13 14]
 [15 16 17 18 19]], shape=(2, 5), dtype=int32)
tf.Tensor([[20 21 22 23 24]], shape=(1, 5), dtype=int32)


shuffle() shuffles the elements. The argument specifies the size of the buffer from which the next element is drawn at random. If the argument is 1, nothing is shuffled, and if it is a small value, the data is not shuffled sufficiently, so I think it is better to pass the same value as the data size.

For details on the shuffle buffer size, see https://qiita.com/exy81/items/d1388f6f02a11c8f1d7e
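
The effect of the buffer size can be checked on a small range of integers. This is a minimal sketch; the variable names are made up for illustration, and the shuffled orders differ from run to run.

```python
import numpy as np
import tensorflow as tf

values = np.arange(10)
base = tf.data.Dataset.from_tensor_slices(values)

# buffer_size=1: the buffer only ever holds one element,
# so the original order is kept.
no_shuffle = [int(v.numpy()) for v in base.shuffle(1)]

# buffer_size=2: the first output can only be one of the first two elements.
small_shuffle = [int(v.numpy()) for v in base.shuffle(2)]

# buffer_size equal to the data size: any element can come first.
full_shuffle = [int(v.numpy()) for v in base.shuffle(len(values))]

print(no_shuffle)
print(small_shuffle)
print(full_shuffle)
```

With a buffer of 2, an element near the end of the data can never appear at the beginning of the output, which is why a small buffer does not mix the data well.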


arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).shuffle(5)

for item in dataset:
    print(item)

Each of the five rows is printed as a separate tf.Tensor of shape (5,), in a random order that changes from run to run.


You can combine the operations above. They are executed in the order in which they are chained, so be careful not to do something meaningless such as batching first and then shuffling, which only reorders whole batches.


arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).repeat(2).shuffle(5).batch(4)

for item in dataset:
    print(item)

tf.Tensor(
[[15 16 17 18 19]
 [ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]], shape=(4, 5), dtype=int32)
tf.Tensor(
[[20 21 22 23 24]
 [ 0  1  2  3  4]
 [20 21 22 23 24]
 [10 11 12 13 14]], shape=(4, 5), dtype=int32)
tf.Tensor(
[[15 16 17 18 19]
 [ 5  6  7  8  9]], shape=(2, 5), dtype=int32)
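
The difference the order makes can be seen by comparing the two orderings on the same array. This is a minimal sketch; the variable names are made up for illustration.

```python
import numpy as np
import tensorflow as tf

arr = np.arange(25).reshape(5, 5)
ds = tf.data.Dataset.from_tensor_slices(arr)

# Batch first, then shuffle: only the order of whole batches changes,
# so the rows inside a batch are always consecutive rows of arr.
batch_then_shuffle = [b.numpy() for b in ds.batch(2).shuffle(3)]

# Shuffle first, then batch: rows are mixed before being grouped.
shuffle_then_batch = [b.numpy() for b in ds.shuffle(5).batch(2)]

for b in batch_then_shuffle:
    print(b[:, 0])  # first value of each row in the batch

for b in shuffle_then_batch:
    print(b[:, 0])
```

In the first loop, the rows within each batch are always neighbors in the original array, while in the second loop they can be any rows.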


You can apply a function to every element with dataset.map().

It is desirable that the function be composed of TensorFlow functions, but it seems that an ordinary Python function can also be converted with @tf.function or tf.py_function.


import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from scipy import ndimage

def rotate(image):
    # Rotate by a random angle between -30 and 30 degrees, keeping the shape
    return ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)

def rotate_tf(image):
    # Wrap the SciPy function so that it can be used inside dataset.map()
    rotated = tf.py_function(rotate,[image],[tf.int32])
    return rotated[0]

[train_x, train_y], [test_x, test_y] =  tf.keras.datasets.mnist.load_data()
train_x = train_x.reshape(-1,28,28,1)
dataset = tf.data.Dataset.from_tensor_slices(train_x)
dataset = dataset.map(rotate_tf).batch(16)

first_batch = next(iter(dataset))
images = first_batch.numpy().reshape((-1,28,28))

# Show the 16 rotated images of the first batch in a 4x4 grid
plt.figure(figsize=(4, 4))
for i, image in enumerate(images):
    plt.subplot(4, 4,i+1)
    plt.imshow(image, cmap='gray')
    plt.axis('off')
plt.show()

Put x and y together into a dataset

You can also combine multiple arrays, such as inputs and labels, into one dataset by passing them as a tuple.


def make_model():

    inputs = tf.keras.layers.Input(shape=(28, 28))
    network = tf.keras.layers.Flatten()(inputs)
    network = tf.keras.layers.Dense(100, activation='relu')(network)
    outputs = tf.keras.layers.Dense(10, activation='softmax')(network)

    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy')
    return model

[x_train, y_train], [x_test, y_test] = tf.keras.datasets.mnist.load_data()

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(x_train.shape[0]).batch(64)
test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test)).shuffle(x_test.shape[0]).batch(64)

model = make_model()
hist = model.fit(train_data, validation_data=test_data,
                 epochs=10, verbose=False)

plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.legend()
plt.show()
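
To see what a dataset built from a tuple yields, it helps to inspect it on small arrays. This is a minimal sketch with made-up stand-in data; xs and ys are hypothetical names.

```python
import numpy as np
import tensorflow as tf

# Small stand-in arrays (hypothetical values).
xs = np.arange(12, dtype=np.float32).reshape(6, 2)
ys = np.arange(6, dtype=np.int64)

pairs = tf.data.Dataset.from_tensor_slices((xs, ys)).batch(4)
print(pairs.element_spec)

collected = []
for x_batch, y_batch in pairs:
    # Each element is an (inputs, labels) pair sliced along the first axis.
    collected.append((x_batch.numpy(), y_batch.numpy()))
    print(x_batch.shape, y_batch.numpy())
```

Keras treats each (inputs, labels) pair as one training batch, which is why such a dataset can be passed to model.fit() without a separate y argument.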

2. Serialize and use

To read data efficiently, it is recommended to serialize it and save it as a set of files of about 100-200MB each that can be read sequentially. You can easily do this with TFRecord.
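
Splitting the records over several files might look like the following. This is only a sketch: the helper write_shards, the file name pattern and the round-robin assignment are assumptions made for illustration, and real shards would be sized by bytes rather than by a fixed count.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

def write_shards(values, out_dir, num_shards):
    # Hypothetical helper: record i goes to shard i % num_shards.
    paths = [os.path.join(out_dir, f"part-{i:03d}.tfrecord") for i in range(num_shards)]
    writers = [tf.io.TFRecordWriter(p) for p in paths]
    for i, v in enumerate(values):
        ex = tf.train.Example(features=tf.train.Features(feature={
            "v": tf.train.Feature(int64_list=tf.train.Int64List(value=[int(v)]))}))
        writers[i % num_shards].write(ex.SerializeToString())
    for w in writers:
        w.close()
    return paths

shard_dir = tempfile.mkdtemp()
shard_paths = write_shards(np.arange(10), shard_dir, num_shards=3)

# TFRecordDataset accepts a list of files and reads them one after another.
num_records = sum(1 for _ in tf.data.TFRecordDataset(shard_paths))
print(len(shard_paths), num_records)
```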


Write the TFRecord file with tf.io.TFRecordWriter().


[x_train, y_train], [x_test, y_test] = tf.keras.datasets.mnist.load_data()

def make_example(image, label):
    # image: flat list of floats, label: list containing one int
    return tf.train.Example(features=tf.train.Features(feature={
        'x' : tf.train.Feature(float_list=tf.train.FloatList(value=image)),
        'y' : tf.train.Feature(int64_list=tf.train.Int64List(value=label))
    }))

def write_tfrecord(images, labels, filename):
    writer = tf.io.TFRecordWriter(filename)
    for image, label in zip(images, labels):
        ex = make_example(image.ravel().tolist(), [int(label)])
        # Serialize each example and append it to the file
        writer.write(ex.SerializeToString())
    writer.close()

write_tfrecord(x_train, y_train, '../mnist_train.tfrecord')
write_tfrecord(x_test, y_test, '../mnist_test.tfrecord')

The resulting file size seems to be slightly larger than that of the npz file.

Read (one record at a time)

Read the file with tf.data.TFRecordDataset().

The records that are read are still serialized and must be parsed. In the example below, tf.io.parse_single_example() is used for parsing. If you parse with the same keys as when writing and restore the original shape, you get the same data as the tf.data dataset before serialization.


def parse_features(example):
    # Parse one serialized record; 'x' is reshaped back to 28x28 here
    features = tf.io.parse_single_example(example, features={
        'x' : tf.io.FixedLenFeature([28, 28], tf.float32),
        'y' : tf.io.FixedLenFeature([1], tf.int64),
    })
    x = features['x']
    y = features['y']
    return x, y

train_dataset = tf.data.TFRecordDataset(filenames='../mnist_train.tfrecord')
train_dataset = train_dataset.map(parse_features).shuffle(60000).batch(512)

test_dataset = tf.data.TFRecordDataset(filenames='../mnist_test.tfrecord')
test_dataset = test_dataset.map(parse_features).shuffle(12000).batch(512)

model = make_model()
hist = model.fit(train_dataset, validation_data=test_dataset,
                 epochs=10, verbose=False)

plt.figure(figsize=(4, 4))
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.legend()
plt.show()
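
A small round trip is a convenient way to check that parsing restores what was written. The sketch below uses a tiny 2x3 array in place of an MNIST image; the file name and the function parse_tiny are made up for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

image = np.arange(6, dtype=np.float32).reshape(2, 3)  # tiny stand-in "image"
label = 7

path = os.path.join(tempfile.mkdtemp(), "tiny.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    ex = tf.train.Example(features=tf.train.Features(feature={
        "x": tf.train.Feature(float_list=tf.train.FloatList(value=image.ravel().tolist())),
        "y": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }))
    writer.write(ex.SerializeToString())

def parse_tiny(serialized):
    # Same keys as when writing; the shape given here restores the 2x3 layout.
    feats = tf.io.parse_single_example(serialized, features={
        "x": tf.io.FixedLenFeature([2, 3], tf.float32),
        "y": tf.io.FixedLenFeature([1], tf.int64),
    })
    return feats["x"], feats["y"]

x_back, y_back = next(iter(tf.data.TFRecordDataset(path).map(parse_tiny)))
print(np.array_equal(x_back.numpy(), image), int(y_back.numpy()[0]))
```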

Read (batch unit)

In fact, parsing in batch units is faster than parsing one record at a time with tf.io.parse_single_example(), so parsing in batch units is recommended.


def dict2tuple(feat):
    return feat["x"], feat["y"]

# parse_example_dataset parses a whole batch of serialized records at once
train_dataset = tf.data.TFRecordDataset(filenames='../mnist_train.tfrecord').batch(512).apply(
                tf.data.experimental.parse_example_dataset({
                    "x": tf.io.FixedLenFeature([28, 28], dtype=tf.float32),
                    "y": tf.io.FixedLenFeature([1], dtype=tf.int64)})).map(dict2tuple)

test_dataset = tf.data.TFRecordDataset(filenames='../mnist_test.tfrecord')
test_dataset = test_dataset.batch(512).apply(
                tf.data.experimental.parse_example_dataset({
                    "x": tf.io.FixedLenFeature([28, 28], dtype=tf.float32),
                    "y": tf.io.FixedLenFeature([1], dtype=tf.int64)})).map(dict2tuple)

model = make_model()
hist = model.fit(train_dataset, validation_data=test_dataset,
                 epochs=10, verbose=False)

plt.figure(figsize=(4, 4))
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.legend()
plt.show()
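
Batch-level parsing can also be written with tf.io.parse_example() inside map(), applied after batch(). This is a minimal sketch on a tiny file written on the spot; the file name, the feature values and the function parse_batch are made up for illustration.

```python
import os
import tempfile

import tensorflow as tf

path = os.path.join(tempfile.mkdtemp(), "small.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    for i in range(5):
        ex = tf.train.Example(features=tf.train.Features(feature={
            "x": tf.train.Feature(float_list=tf.train.FloatList(value=[i, i + 0.5])),
            "y": tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
        }))
        writer.write(ex.SerializeToString())

feature_spec = {
    "x": tf.io.FixedLenFeature([2], tf.float32),
    "y": tf.io.FixedLenFeature([1], tf.int64),
}

def parse_batch(serialized_batch):
    # Receives a vector of serialized records and parses them in one call.
    feats = tf.io.parse_example(serialized_batch, feature_spec)
    return feats["x"], feats["y"]

batched = tf.data.TFRecordDataset(path).batch(4).map(parse_batch)
shapes = [(tuple(x.shape), tuple(y.shape)) for x, y in batched]
print(shapes)
```

The key point is that batch() comes before the parsing step, so the parser is called once per batch instead of once per record.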

Processing time

I measured the processing time when the MNIST data was trained with the same model. As expected, parsing one record at a time seems to be quite slow. If you parse in batch units, you get about the same speed as an in-memory tf.data dataset, so it can be said to be quite fast.

Also, in this measurement passing numpy.array directly came out fastest, but in my practical experience tf.data is clearly faster, so I do not think numpy.array is always faster just because the data is in memory. I would appreciate it if you could try various things in your own environment.
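
For trying this in your own environment, a simple timer around one pass over a dataset may be a starting point. This is only a sketch: the helper time_one_pass is hypothetical, it measures the input pipeline alone rather than training, and the numbers depend heavily on the machine.

```python
import time

import numpy as np
import tensorflow as tf

def time_one_pass(dataset):
    # Hypothetical helper: iterate once and return (element count, seconds).
    start = time.perf_counter()
    count = 0
    for _ in dataset:
        count += 1
    return count, time.perf_counter() - start

toy = tf.data.Dataset.from_tensor_slices(np.zeros((1000, 28, 28), dtype=np.float32))

n_single, t_single = time_one_pass(toy.batch(1))     # 1000 tiny batches
n_batched, t_batched = time_one_pass(toy.batch(100))  # 10 larger batches

print(n_single, f"{t_single:.3f}s")
print(n_batched, f"{t_batched:.3f}s")
```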

