[PYTHON] Hinweise zur Verwendung von tf.data

Fassen Sie zusammen, was Sie über tf.data gelernt haben, um mit großen Dateien umgehen zu können

Ich habe auf die folgende Seite verwiesen https://qiita.com/Suguru_Toyohara/items/820b0dad955ecd91c7f3 https://qiita.com/wasnot/items/9b64550237a3c5267bfd https://qiita.com/everylittle/items/a7c31b08d2f76c886a92

Was ist tf.data

Eine Bibliothek zur Tensorflow-Datenversorgung. Es scheint die folgenden Vorteile zu haben, wenn es verwendet wird.

  1. Sie können die GPU-Latenz reduzieren und die Lerngeschwindigkeit maximieren
  2. Daten, die nicht in den Speicher passen, können nacheinander gelesen werden
  3. Die Vorverarbeitung wie die Datenerweiterung kann beschleunigt werden
  4. Stellen Sie eine Vorbehandlungsleitung zusammen

1. Konvertieren Sie numpy.array und verwenden Sie es

Konvertieren

Sie können es verwenden, indem Sie von np.array in ein tf.data-Objekt konvertieren

python


import numpy as np
import tensorflow as tf

arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr)

for item in dataset:
    print(item)

output


tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)

repeat

Ausgabe durch Wiederholen der Argumentationszeiten

python


arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).repeat(3)

for item in dataset:
    print(item)

output


tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int32)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int32)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int32)
tf.Tensor([15 16 17 18 19], shape=(5,), dtype=int32)
tf.Tensor([20 21 22 23 24], shape=(5,), dtype=int32)

batch

Argumente werden gestapelt und ausgegeben

python


arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).batch(2)

for item in dataset:
    print(item)

output


tf.Tensor(
[[0 1 2 3 4]
 [5 6 7 8 9]], shape=(2, 5), dtype=int32)

tf.Tensor(
[[10 11 12 13 14]
 [15 16 17 18 19]], shape=(2, 5), dtype=int32)

tf.Tensor([[20 21 22 23 24]], shape=(1, 5), dtype=int32)

shuffle

Das Argument gibt an, wie weit die Daten ersetzt werden sollen. Wenn das Argument 1 ist, wird es nicht ersetzt, und wenn es ein kleiner Wert ist, wird es nicht ausreichend gemischt. Ich denke, es ist besser, den gleichen Wert wie die Datengröße einzugeben.

Klicken Sie hier für Details zur Shuffle-Größe https://qiita.com/exy81/items/d1388f6f02a11c8f1d7e

python


arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).shuffle(5)

for item in dataset:
    print(item)

output


tf.Tensor(
[[0 1 2 3 4]
 [5 6 7 8 9]], shape=(2, 5), dtype=int32)

tf.Tensor(
[[10 11 12 13 14]
 [15 16 17 18 19]], shape=(2, 5), dtype=int32)

tf.Tensor([[20 21 22 23 24]], shape=(1, 5), dtype=int32)

Kombination

Sie können die oben genannten in Kombination verwenden. Es wird in der richtigen Reihenfolge ausgeführt. Achten Sie also darauf, dass Sie nicht die sinnlose Sache tun, die Charge zu schneiden und dann zu mischen.

python


arr = np.arange(25).reshape(5, 5)
dataset = tf.data.Dataset.from_tensor_slices(arr).repeat(2).shuffle(5).batch(4)

for item in dataset:
    print(item)
    print()

output


tf.Tensor(
[[15 16 17 18 19]
 [ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]], shape=(4, 5), dtype=int32)

tf.Tensor(
[[20 21 22 23 24]
 [ 0  1  2  3  4]
 [20 21 22 23 24]
 [10 11 12 13 14]], shape=(4, 5), dtype=int32)

tf.Tensor(
[[15 16 17 18 19]
 [ 5  6  7  8  9]], shape=(2, 5), dtype=int32)

Argumentation

Sie können die Funktion mit dataset.map () anwenden.

Es ist wünschenswert, dass die anzuwendende Funktion aus der Tensorflow-Funktion besteht, aber es scheint auch möglich zu sein, eine normal geschriebene Funktion mit @ tf.function und tf.py_function zu konvertieren.

python


import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage

def rotate(image):
    return ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)

@tf.function
def rotate_tf(image):
    rotated = tf.py_function(rotate,[image],[tf.int32])
    return rotated[0]

[train_x, train_y], [test_x, test_y] =  tf.keras.datasets.mnist.load_data()
train_x = train_x.reshape(-1,28,28,1)
dataset = tf.data.Dataset.from_tensor_slices(train_x)
dataset = dataset.map(rotate_tf).batch(16)

first_batch = next(iter(dataset))
images = first_batch.numpy().reshape((-1,28,28))

plt.figure(figsize=(4, 4))
for i, image in enumerate(sample_images):
    plt.subplot(4, 4,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(image)
    plt.grid(False)
plt.show()

image.png

Fügen Sie x und y zu einem Datensatz zusammen

Sie können auch mehrere Daten zu einem Datensatz kombinieren

python


def make_model():
    tf.keras.backend.clear_session()

    inputs = tf.keras.layers.Input(shape=(28, 28))
    network = tf.keras.layers.Flatten()(inputs)
    network = tf.keras.layers.Dense(100, activation='relu')(network)
    outputs = tf.keras.layers.Dense(10, activation='softmax')(network)

    model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer='rmsprop', loss='sparse_categorical_crossentropy', 
                  metrics=['accuracy'])
    model.summary()
    return model

[x_train, y_train], [x_test, y_test] = tf.keras.datasets.mnist.load_data()

train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(x_train.shape[0]).batch(64)
test_data = tf.data.Dataset.from_tensor_slices((x_test, y_test)).shuffle(x_test.shape[0]).batch(64)

model = make_model()
hist = model.fit(train_data, validation_data=test_data,
                 epochs=10, verbose=False)

plt.figure(figsize=(4,4))
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.show()

2. Serialisieren und verwenden

Um die Daten effizient zu lesen, wird empfohlen, die Daten zu serialisieren und als Satz von 100-200 MB-Dateien zu speichern, die kontinuierlich gelesen werden können. Sie können dies problemlos mit TFRecord tun.

sparen

Exportieren Sie die TFRecord-Datei mit tf.io.TFRecordWriter ()

python


[x_train, y_train], [x_test, y_test] = tf.keras.datasets.mnist.load_data()

def make_example(image, label):
    return tf.train.Example(features=tf.train.Features(feature={
        'x' : tf.train.Feature(float_list=tf.train.FloatList(value=image)),
        'y' : tf.train.Feature(int64_list=tf.train.Int64List(value=label))
    }))

def write_tfrecord(images, labels, filename):
    writer = tf.io.TFRecordWriter(filename)
    for image, label in zip(images, labels):
        ex = make_example(image.ravel().tolist(), [int(label)])
        writer.write(ex.SerializeToString())
    writer.close()

write_tfrecord(x_train, y_train, '../mnist_train.tfrecord')
write_tfrecord(x_test, y_test, '../mnist_test.tfrecord')

image.png Die Dateigröße scheint etwas größer als npz zu sein.

Lesen (jeweils 1 Datensatz)

Lesen Sie mit tf.data.TFRecordDataset ()

Die gelesenen Daten werden serialisiert und müssen analysiert werden. Im folgenden Beispiel wird tf.io.parse_single_example () zum Parsen verwendet. Wenn Sie es mit demselben Schlüssel aufrufen wie beim Schreiben und Umformen, entspricht es tf.data vor der Serialisierung.

python


def parse_features(example):
    features = tf.io.parse_single_example(example, features={
        'x' : tf.io.FixedLenFeature([28, 28], tf.float32),
        'y' : tf.io.FixedLenFeature([1], tf.int64),
    })
    x = features['x']
    y = features['y']
    return x, y

train_dataset = tf.data.TFRecordDataset(filenames='../mnist_train.tfrecord')
train_dataset = train_dataset.map(parse_features).shuffle(60000).batch(512)

test_dataset = tf.data.TFRecordDataset(filenames='../mnist_test.tfrecord')
test_dataset = test_dataset.map(parse_features).shuffle(12000).batch(512)

model = make_model()
hist = model.fit(train_dataset, validation_data=test_dataset,
                 epochs=10, verbose=False)

plt.figure(figsize=(4, 4))
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.show()

Lesen (Batch-Einheit)

Tatsächlich ist es schneller, stapelweise zu analysieren, als mit tf.io.parse_single_example () jeweils einen Datensatz zu analysieren. Daher wird empfohlen, stapelweise zu analysieren.

python


def dict2tuple(feat):
    return feat["x"], feat["y"]

train_dataset = tf.data.TFRecordDataset(filenames='../mnist_train.tfrecord').batch(512).apply(
                    tf.data.experimental.parse_example_dataset({
                    "x": tf.io.FixedLenFeature([28, 28], dtype=tf.float32),
                    "y": tf.io.FixedLenFeature([1], dtype=tf.int64)})).map(dict2tuple)

test_dataset = tf.data.TFRecordDataset(filenames='../mnist_test.tfrecord')
test_dataset = test_dataset.batch(512).apply(
                    tf.data.experimental.parse_example_dataset({
                    "x": tf.io.FixedLenFeature([28, 28], dtype=tf.float32),
                    "y": tf.io.FixedLenFeature([1], dtype=tf.int64)})).map(dict2tuple)

model = make_model()
hist = model.fit(train_dataset, validation_data=test_dataset,
                 epochs=10, verbose=False)

plt.figure(figsize=(4, 4))
plt.plot(hist.history['loss'], label='loss')
plt.plot(hist.history['val_loss'], label='val_loss')
plt.show()

Verarbeitungszeit

Ich habe die Verarbeitungszeit gemessen, als Mnist-Daten mit demselben Modell trainiert wurden. Immerhin scheint es, dass das Parsen eines Datensatzes nach dem anderen ziemlich langsam sein wird. Wenn Sie in Batch-Einheiten verarbeiten, können Sie die gleiche Geschwindigkeit wie im Speicher tf.data erhalten, sodass gesagt werden kann, dass es ziemlich schnell ist.

Auch hier ist das Ergebnis, dass numpy.array so schnell wie es ist, aber in der Praxis ist tf.data offensichtlich schneller, so dass numpy.array nicht schneller ist, wenn es sich im Speicher befindet. Ich denke. Wir würden uns freuen, wenn Sie verschiedene Dinge in Ihrer eigenen Umgebung ausprobieren könnten.

Recommended Posts

Hinweise zur Verwendung von tf.data
Verwendung von xml.etree.ElementTree
Verwendung von virtualenv
Wie benutzt man Seaboan?
Verwendung von Image-Match
Wie man Shogun benutzt
Verwendung von Pandas 2
Verwendung von Virtualenv
Verwendung von numpy.vectorize
Verwendung von pytest_report_header
Wie man teilweise verwendet
Wie man Bio.Phylo benutzt
Verwendung von SymPy
Wie man x-means benutzt
Verwendung von WikiExtractor.py
Verwendung von IPython
Verwendung von virtualenv
Wie benutzt man Matplotlib?
Verwendung von iptables
Wie benutzt man numpy?
Verwendung von TokyoTechFes2015
Wie benutzt man venv
Verwendung des Wörterbuchs {}
Wie benutzt man Pyenv?
Verwendung der Liste []
Wie man Python-Kabusapi benutzt
Verwendung von OptParse
Verwendung von return
Wie man Imutils benutzt
Verwendung von Qt Designer
Verwendung der Suche sortiert
[gensim] Verwendung von Doc2Vec
python3: Verwendung der Flasche (2)
Verstehen Sie, wie man Django-Filter verwendet
Verwendung des Generators
[Python] Verwendung von Liste 1
Verwendung von FastAPI ③ OpenAPI
Wie benutzt man Python Argparse?
Verwendung von IPython Notebook
Wie man Pandas Rolling benutzt
[Hinweis] Verwendung von virtualenv
Verwendung von Redispy-Wörterbüchern
Python: Wie man pydub benutzt
[Python] Verwendung von checkio
[Go] Verwendung von "... (3 Perioden)"
So bedienen Sie GeoIp2 von Django
[Python] Verwendung von input ()
Wie benutzt man den Dekorateur?
[Einführung] Verwendung von open3d
Wie benutzt man Python Lambda?
So verwenden Sie Jupyter Notebook
[Python] Verwendung von virtualenv
python3: Verwendung der Flasche (3)
python3: Wie man eine Flasche benutzt
So verwenden Sie Google Colaboratory
Verwendung von Python-Bytes
Verwendung von cron (persönliches Memo)