[Python] VAEGAN avec TFLearn

Qu'est-ce que le VAEGAN ?

Autoencoding beyond pixels using a learned similarity metric

Il s'agit d'un VAE suivi d'un discriminateur de GAN. Pour calculer l'erreur de reconstruction du VAE, on compare non pas les images produites par le décodeur, mais les cartes de caractéristiques extraites d'une couche intermédiaire du discriminateur. Un VAE classique mesure l'erreur pixel par pixel, ce qui donne des images floues ; en mesurant l'erreur au niveau des cartes de caractéristiques, on peut espérer générer des images nettes tout en reproduisant les caractéristiques globales.

L'encodeur apprend en minimisant l'écart entre les cartes de caractéristiques de l'image originale et de l'image décodée (VAE). Le décodeur apprend en outre à partir du verdict du discriminateur sur les images décodées et sur les images générées aléatoirement (VAE + GAN). Le discriminateur apprend à distinguer les images originales des images décodées et des images générées aléatoirement (GAN).

Il semble préférable de pré-entraîner le VAE avec l'erreur pixel par pixel habituelle, et le GAN avec les images originales et les images générées aléatoirement (en remplaçant le générateur du GAN par le décodeur du VAE et en entraînant les deux indépendamment).

J'en suis encore au stade des expériences, ce code est donc susceptible d'évoluer.

code

Pour DCGAN, je m'étais efforcé de coller à l'implémentation d'origine ; cette fois, je me suis contenté d'une version simplifiée dans laquelle l'encodeur, le décodeur et le discriminateur sont chacun entraînés sur des échantillons (mini-lots) différents.

Je me suis appuyé sur l'article original et sur d'autres implémentations, mais j'y ai apporté quelques modifications :

- Dans les implémentations de référence, le VAE et le GAN sont pré-entraînés simultanément ; ici, ils sont pré-entraînés l'un après l'autre (d'abord le VAE, puis le GAN).
- Pour l'erreur quadratique moyenne et la divergence de Kullback-Leibler, les dimensions des caractéristiques étaient sommées et la dimension des échantillons moyennée. Mais lorsque je changeais la taille des images ou le nombre de variables latentes, le rapport entre ces termes changeait, ce que je trouvais gênant. J'ai donc tout remplacé par des moyennes (ce qui n'est peut-être pas rigoureux mathématiquement).

vaegan.py


from __future__ import (
    division,
    print_function,
    absolute_import
)
from six.moves import range

import tensorflow as tf
import tflearn

import os
import numpy as np
from skimage import io

# Output locations for TensorBoard logs, checkpoints and sample images.
PRE_VAE_TENSORBOARD_DIR = '/tmp/tflearn_logs/vae/'
PRE_DIS_TENSORBOARD_DIR = '/tmp/tflearn_logs/dis/'
PRE_VAE_CHECKPOINT_PATH = '/tmp/vaegan/pre-vae'
PRE_DIS_CHECKPOINT_PATH = '/tmp/vaegan/pre-dis'
CHECKPOINT_PATH = '/tmp/vaegan/model'

# Short aliases for the TFLearn building blocks used below.
DNN = tflearn.DNN
input_data = tflearn.input_data
fc = tflearn.fully_connected
reshape = tflearn.reshape
conv = tflearn.conv_2d
conv_t = tflearn.conv_2d_transpose
max_pool = tflearn.max_pool_2d
bn = tflearn.batch_normalization
merge = tflearn.merge
sigmoid = tflearn.sigmoid
softmax = tflearn.softmax
softplus = tflearn.softplus
relu = tflearn.relu
elu = tflearn.elu
crossentropy = tflearn.categorical_crossentropy
adam = tflearn.Adam
Trainer = tflearn.Trainer
TrainOp = tflearn.TrainOp

# Make sure every output directory exists. os.makedirs also creates missing
# parents (e.g. /tmp/tflearn_logs), so no separate step is needed for them.
for _output_dir in (PRE_VAE_TENSORBOARD_DIR,
                    PRE_DIS_TENSORBOARD_DIR,
                    os.path.dirname(CHECKPOINT_PATH)):
    if not os.path.isdir(_output_dir):
        os.makedirs(_output_dir)
del _output_dir

class VAEGAN(object):
    """VAE whose reconstruction error is measured by a GAN discriminator.

    Follows "Autoencoding beyond pixels using a learned similarity metric",
    trained with TFLearn in three stages, each built in its own ``tf.Graph``:

    1. pre-train the VAE (encoder + decoder) with a pixel-wise loss;
    2. pre-train the discriminator on real images vs. images decoded from
       random latent vectors;
    3. train encoder, decoder and discriminator together, the VAE
       reconstruction error being measured on a discriminator feature map.

    Weights are carried from one graph to the next as numpy arrays stored in
    ``self.trained_values`` (``{scope: {variable_name: value}}``).
    """

    def __init__(self, img_shape, n_first_channel, n_layer, latent_dim,
                 kullback_leibler_ratio, reconstruction_weight_against_detail,
                 vae_learning_rate=0.001, vae_beta1=0.5,
                 discriminator_learning_rate=0.00001, discriminator_beta1=0.5):
        """Store the hyper-parameters; no graph is built until ``train``.

        Args:
            img_shape: ``(height, width, channels)`` of the training images.
            n_first_channel: channel count of the first conv stage; it doubles
                at every further stage of the encoder and discriminator.
            n_layer: number of conv stages per network (must be > 1).
            latent_dim: size of the latent vector ``z``.
            kullback_leibler_ratio: weight of the KL-divergence term.
            reconstruction_weight_against_detail: factor applied to the whole
                VAE loss (reconstruction + KL).
            vae_learning_rate, vae_beta1: Adam settings for encoder/decoder.
            discriminator_learning_rate, discriminator_beta1: Adam settings
                for the discriminator.
        """
        self.img_shape = list(img_shape)
        self.input_shape = [None] + self.img_shape
        self.img_size = img_shape[:2]
        self.n_first_channel = n_first_channel
        self.n_layer = n_layer
        self.kullback_leibler_ratio = kullback_leibler_ratio
        self.reconstruction_weight_against_detail = reconstruction_weight_against_detail
        self.latent_dim = latent_dim
        self.vae_learning_rate = vae_learning_rate
        self.vae_beta1 = vae_beta1
        self.discriminator_learning_rate = discriminator_learning_rate
        self.discriminator_beta1 = discriminator_beta1

        assert self.n_layer > 1, 'n_layer must be more than 1'

        self.vae_pretrainer = None
        self.discriminator_pretrainer = None
        self.trainer = None
        # Separate graph holding a stand-alone decoder used for sampling.
        self.decoder_graph = tf.Graph()
        self.decoder = None  # DNN built in that graph by _set_decoder
        self.trained_values = {}

    def _build_vae_pretrainer(self, encoder, decoder):
        """Build the stage-1 trainer: VAE with a pixel-wise loss."""
        inputs = input_data(shape=self.input_shape, name='input_x')
        # Build Network
        mean, log_var = encoder(inputs)
        encoded = self._encode(mean, log_var)
        decoded = decoder(encoded)
        # Loss
        element_wise_loss = self._get_mean_square(decoded, inputs)
        kullback_leibler_divergence = \
            self._get_kullback_leibler_divergence(mean, log_var)
        pretrain_vae_loss = self.reconstruction_weight_against_detail *\
            tf.reduce_mean(element_wise_loss + kullback_leibler_divergence)
        # Trainer (all trainable variables of the graph are optimised)
        pretrain_vae_op = TrainOp(loss=pretrain_vae_loss,
                                  optimizer=self._get_optimizer('vae'),
                                  batch_size=128,
                                  name='VAE_pretrainer')

        return Trainer(pretrain_vae_op, tensorboard_dir=PRE_VAE_TENSORBOARD_DIR,
                       tensorboard_verbose=0,
                       checkpoint_path=PRE_VAE_CHECKPOINT_PATH,
                       max_checkpoints=1)

    def _build_discriminator_pretrainer(self, decoder, discriminator):
        """Build the stage-2 trainer: real images vs. decoded random ``z``.

        Only the discriminator's variables are optimised; the decoder is
        expected to hold the stage-1 weights (assigned by ``train``).
        """
        inputs = input_data(shape=self.input_shape, name='input_x')
        is_true = input_data(shape=(None, 2), name='is_true')
        is_false = input_data(shape=(None, 2), name='is_false')
        # Build Network
        # One random latent vector per input image. The batch size is only
        # known at run time, so it is read from the input tensor; the mixed
        # list is auto-packed into a shape tensor by tf.random_normal. (This
        # replaces a throw-away fully connected layer that was created only
        # to obtain this shape.)
        z_shape = [tf.shape(inputs)[0], self.latent_dim]
        random_image = decoder(self._get_z(z_shape))
        prediction_origin, _ = discriminator(inputs)
        prediction_random, _ = discriminator(random_image, reuse=True)
        # Loss
        prediction_all = merge([prediction_origin, prediction_random], 'concat',
                               axis=0)
        y_all = merge([is_true, is_false], 'concat', axis=0)
        pretrain_discriminator_loss = crossentropy(prediction_all, y_all)
        # Trainer
        pretrain_discriminator_op = TrainOp(
            loss=pretrain_discriminator_loss,
            optimizer=self._get_optimizer('discriminator'), batch_size=128,
            trainable_vars=self._get_trainable_variables(discriminator.scope),
            name='Discriminator_pretrainer')

        return Trainer(pretrain_discriminator_op,
                       tensorboard_dir=PRE_DIS_TENSORBOARD_DIR,
                       tensorboard_verbose=0,
                       checkpoint_path=PRE_DIS_CHECKPOINT_PATH,
                       max_checkpoints=1)

    def _build_trainer(self, encoder, decoder, discriminator):
        """Build the stage-3 trainer with one TrainOp per network."""
        inputs = input_data(shape=self.input_shape, name='input_x')
        is_true = input_data(shape=(None, 2), name='is_true')
        is_false = input_data(shape=(None, 2), name='is_false')
        # Build Network
        mean, log_var = encoder(inputs)
        encoded = self._encode(mean, log_var)
        decoded = decoder(encoded)
        random_image = decoder(self._get_z(tf.shape(mean)), reuse=True)
        prediction_origin, feature_map_origin = discriminator(inputs)
        prediction_decoded, feature_map_decoded = discriminator(decoded, reuse=True)
        prediction_random, _ = discriminator(random_image, reuse=True)
        # Loss
        ## Encoder: reconstruction error on discriminator features + KL
        feature_wise_loss = \
            self._get_mean_square(feature_map_decoded, feature_map_origin)
        kullback_leibler_divergence = \
            self._get_kullback_leibler_divergence(mean, log_var)
        encoder_loss = self.reconstruction_weight_against_detail *\
            tf.reduce_mean(feature_wise_loss + kullback_leibler_divergence)

        ## Decoder: VAE loss + make decoded/random images look real
        prediction_gan = merge([prediction_decoded, prediction_random],
                               'concat', axis=0)
        y_gan = merge([is_true, is_true], 'concat', axis=0)
        gan_generator_loss = crossentropy(prediction_gan, y_gan)
        decoder_loss = (encoder_loss + gan_generator_loss) * 0.5
        ## Discriminator: originals are real, decoded/random images are fake
        prediction_fake = merge([prediction_decoded, prediction_random],
                                'concat', axis=0)
        y_fake = merge([is_false, is_false], 'concat', axis=0)
        real_loss = crossentropy(prediction_origin, is_true)
        fake_loss = crossentropy(prediction_fake, y_fake)
        discriminator_loss = (real_loss + fake_loss) * 0.5
        # Trainer: each op only updates the variables of its own network
        encoder_op = TrainOp(
            loss=encoder_loss,
            optimizer=self._get_optimizer('encoder'),
            batch_size=64,
            trainable_vars=self._get_trainable_variables(encoder.scope),
            name='Encoder')
        decoder_op = TrainOp(
            loss=decoder_loss,
            optimizer=self._get_optimizer('decoder'),
            batch_size=64,
            trainable_vars=self._get_trainable_variables(decoder.scope),
            name='Decoder')
        discriminator_op = TrainOp(
            loss=discriminator_loss,
            optimizer=self._get_optimizer('discriminator'),
            batch_size=64,
            trainable_vars=self._get_trainable_variables(discriminator.scope),
            name='Discriminator')
        return Trainer([encoder_op, decoder_op, discriminator_op],
                       checkpoint_path=CHECKPOINT_PATH, max_checkpoints=1)

    def _encode(self, mean, log_var):
        """Reparameterisation trick: ``mean + exp(log_var / 2) * epsilon``."""
        epsilon = tf.random_normal(tf.shape(mean), name='Epsilon')

        return mean + tf.exp(0.5 * log_var) * epsilon

    def _get_z(self, shape):
        """Sample standard-normal latent vectors, reshaped to (-1, latent_dim)."""
        z = tf.random_normal(shape, name='RandomZ')

        return reshape(z, (-1, self.latent_dim))

    def _get_kullback_leibler_divergence(self, mean, log_var):
        """Per-sample KL term, averaged (not summed) over latent dimensions.

        Computes ``-0.5 * ratio * mean(1 + log_var - mean**2 - exp(log_var))``
        along axis 1, where ``ratio`` is ``kullback_leibler_ratio``.
        """
        square_mean = tf.pow(mean, 2)
        variance = tf.exp(log_var)

        kullback_leibler_divergence = \
            tf.reduce_mean(1 + log_var - square_mean - variance,
                           reduction_indices=1)
        kullback_leibler_divergence = \
            - 0.5 * self.kullback_leibler_ratio * kullback_leibler_divergence

        return kullback_leibler_divergence

    def _get_mean_square(self, prediction, truth):
        """Per-sample mean squared error over axes 1-3 (rank-4 inputs)."""
        return tf.reduce_mean(tf.squared_difference(prediction, truth),
                              reduction_indices=(1, 2, 3))

    def _get_optimizer(self, type_str):
        """Return an Adam optimizer configured for the given network type."""
        if type_str in ['vae', 'encoder', 'decoder']:
            learning_rate = self.vae_learning_rate
            beta1 = self.vae_beta1
        else: # 'discriminator'
            learning_rate = self.discriminator_learning_rate
            beta1 = self.discriminator_beta1
        opt = adam(learning_rate=learning_rate, beta1=beta1)

        return opt.get_tensor()

    def _get_trainable_variables(self, scope):
        """Trainable variables of the default graph whose name contains ``scope/``."""
        return [v for v in tflearn.get_all_trainable_variable()
                if scope + '/' in v.name]

    def _get_transferable_variables(self, scope):
        """Variables of ``scope`` whose values must be copied between graphs.

        These are the trainable variables plus the batch-normalisation moving
        statistics. The moving statistics are not trainable, but they are what
        batch normalisation uses outside training, e.g. in the stand-alone
        decoder used by ``decode``. Without them that decoder would normalise
        with its freshly initialised statistics.
        """
        prefix = scope + '/'
        moving_suffixes = ('/moving_mean:0', '/moving_variance:0')
        moving_statistics = [
            v for v in tflearn.variables.get_all_variables()
            if prefix in v.name and v.name.endswith(moving_suffixes)]
        return self._get_trainable_variables(scope) + moving_statistics

    def _get_input_tensor_by_name(self, name):
        """Return the first tensor of the INPUTS collection under ``name``."""
        return tf.get_collection(tf.GraphKeys.INPUTS, scope=name)[0]

    def train(self, x, n_sample=None, pretrain_vae_epoch=1,
              pretrain_discriminator_epoch=1, train_epoch=10):
        """Run the three training stages on the images ``x``.

        Args:
            x: image array of shape ``[n] + img_shape``. The values are
                presumably in [0, 1] to match the decoder's sigmoid output;
                TODO confirm.
            n_sample: number of rows of the real/fake label arrays; defaults
                to ``x.shape[0]``. NOTE(review): should equal ``len(x)``,
                since the labels are fed alongside ``x``.
            pretrain_vae_epoch: epochs of stage 1 (VAE pre-training).
            pretrain_discriminator_epoch: epochs of stage 2.
            train_epoch: epochs of stage 3 (joint training).
        """
        if n_sample is None:
            n_sample = x.shape[0]
        # One-hot labels: [0, 1] means "real", [1, 0] means "fake".
        is_true = np.tile([0., 1.], [n_sample, 1])
        is_false = np.tile([1., 0.], [n_sample, 1])

        encoder = Encoder(self.n_first_channel, self.n_layer, self.latent_dim)
        decoder = Decoder(self.img_shape, self.n_first_channel, self.n_layer)
        discriminator = Discriminator(self.n_first_channel, self.n_layer)

        # Stage 1: VAE pre-training.
        with tf.Graph().as_default():
            self.vae_pretrainer = self._build_vae_pretrainer(encoder, decoder)
            trainer = self.vae_pretrainer

            input_tensor = self._get_input_tensor_by_name('input_x')
            feed_dict = {input_tensor:x}
            trainer.fit(feed_dict, n_epoch=pretrain_vae_epoch,
                        snapshot_epoch=True, shuffle_all=True,
                        run_id='VAE_pretrain')
            self.trained_values[encoder.scope] = \
                self._get_trained_values(trainer, encoder.scope)
            self.trained_values[decoder.scope] = \
                self._get_trained_values(trainer, decoder.scope)

        # Stage 2: discriminator pre-training with the pre-trained decoder.
        with tf.Graph().as_default():
            self.discriminator_pretrainer = \
                self._build_discriminator_pretrainer(decoder, discriminator)
            trainer = self.discriminator_pretrainer
            self._assign_values(trainer, decoder.scope)

            input_tensor = self._get_input_tensor_by_name('input_x')
            true_tensor = self._get_input_tensor_by_name('is_true')
            false_tensor = self._get_input_tensor_by_name('is_false')
            feed_dict = {input_tensor:x,
                         true_tensor:is_true,
                         false_tensor:is_false}
            trainer.fit(feed_dict, n_epoch=pretrain_discriminator_epoch,
                        snapshot_epoch=True, shuffle_all=True,
                        run_id='Discriminator_pretrain')
            self.trained_values[discriminator.scope] = \
                self._get_trained_values(trainer, discriminator.scope)

        # Stage 3: joint training, starting from the pre-trained weights.
        with tf.Graph().as_default():
            self.trainer = self._build_trainer(encoder, decoder, discriminator)
            trainer = self.trainer
            self._assign_values(trainer, encoder.scope)
            self._assign_values(trainer, decoder.scope)
            self._assign_values(trainer, discriminator.scope)
            self._set_decoder(decoder)

            input_tensor = self._get_input_tensor_by_name('input_x')
            true_tensor = self._get_input_tensor_by_name('is_true')
            false_tensor = self._get_input_tensor_by_name('is_false')
            feed_dict = {input_tensor:x,
                         true_tensor:is_true,
                         false_tensor:is_false}
            # One feed dict per TrainOp (encoder, decoder, discriminator).
            self.trainer.fit([feed_dict] * 3, n_epoch=train_epoch,
                             snapshot_step=1000, snapshot_epoch=False,
                             shuffle_all=True, run_id='VAEGAN',
                             callbacks=[CustomCallback(self)])

    def _get_trained_values(self, trainer, scope):
        """Read ``scope``'s transferable variables from ``trainer.session``.

        Returns:
            dict mapping variable name to its current value.
        """
        return {v.name:tflearn.variables.get_value(v, session=trainer.session)
                for v in self._get_transferable_variables(scope)}

    def _assign_values(self, trainer, scope):
        """Load the values stored for ``scope`` into ``trainer.session``.

        ``trainer`` is anything exposing a ``session`` attribute (a TFLearn
        ``Trainer`` or ``DNN``). The values are fed through each variable's
        existing initializer op, which is the same mechanism as
        ``tf.Variable.load``. Building a new ``assign`` op on every call
        instead would keep growing the graph, because ``CustomCallback``
        calls this at every snapshot.

        Raises:
            KeyError: if a variable of the current graph has no stored value.
        """
        values = self.trained_values[scope]
        for variable in self._get_transferable_variables(scope):
            trainer.session.run(
                variable.initializer,
                feed_dict={variable.initial_value: values[variable.name]})

    def _set_decoder(self, decoder):
        """Build a stand-alone decoder DNN in ``self.decoder_graph``."""
        with self.decoder_graph.as_default():
            inputs = input_data(shape=(None, self.latent_dim))
            net = decoder(inputs)
            self.decoder = DNN(net)

    def decode(self, z):
        """Decode latent vectors ``z`` with the stand-alone decoder."""
        with self.decoder_graph.as_default():
            return self.decoder.predict(z)

class Encoder(object):
    """Convolutional encoder producing the parameters of q(z | x).

    Every stage is a 4x4 stride-2 convolution, then batch normalisation and
    ReLU; stage ``i`` has ``n_first_channel * 2 ** i`` channels. Two fully
    connected heads then give the latent mean and log-variance.
    """

    def __init__(self, n_first_channel, n_layer, latent_dim):
        self.scope = 'Encoder'
        self.latent_dim = latent_dim
        self.n_layer = n_layer
        self.n_first_channel = n_first_channel

    def __call__(self, x, reuse=False):
        """Build the encoder on ``x``, sharing weights when ``reuse`` is True.

        Returns:
            ``(mean, log_var)``, two ``latent_dim``-unit layers.
        """
        def scoped(layer_name):
            # Every layer lives under this encoder's variable scope.
            return '{s}/{l}'.format(s=self.scope, l=layer_name)

        net = x
        for stage in range(self.n_layer):
            width = self.n_first_channel * 2 ** stage
            net = conv(net, width, 4, strides=2, reuse=reuse,
                       scope=scoped('Conv_{n}'.format(n=stage)))
            net = bn(net, reuse=reuse, scope=scoped('BN_{n}'.format(n=stage)))
            net = relu(net)

        mean = fc(net, self.latent_dim, reuse=reuse, scope=scoped('Mean'))
        log_var = fc(net, self.latent_dim, reuse=reuse,
                     scope=scoped('LogVariance'))
        return mean, log_var

class Decoder(object):
    """Transposed-convolution decoder mapping latent vectors to images.

    A fully connected layer produces a feature map of size
    ``(height / 2 ** n_layer, width / 2 ** n_layer)``. It has as many
    channels as the encoder's last stage. Each of the ``n_layer`` stages then
    applies batch normalisation, ReLU and a 4x4 stride-2 transposed
    convolution that doubles the spatial size and halves the channel count.
    The last stage outputs ``color_channel`` channels, and a final sigmoid
    maps the result to (0, 1).
    """

    def __init__(self, img_shape, n_first_channel, n_layer):
        self.img_size = img_shape[:2]
        self.color_channel = img_shape[2]
        # Channel count of the deepest feature map, i.e. of the encoder's
        # last stage (n_first_channel * 2 ** (n_layer - 1)).
        self.n_first_channel = n_first_channel * 2 ** (n_layer - 1)
        self.n_layer = n_layer
        self.scope = 'Decoder'

    def __call__(self, z, reuse=False):
        """Build the decoder on ``z``, sharing weights when ``reuse`` is True.

        Raises:
            ValueError: if the image height or width is not a positive
                multiple of ``2 ** n_layer``. Otherwise the decoded images
                would not have the input image size.
        """
        scale = 2 ** self.n_layer
        img_height, img_width = self.img_size[0], self.img_size[1]
        if (img_height < scale or img_width < scale
                or img_height % scale or img_width % scale):
            raise ValueError(
                'image size {h}x{w} must be a positive multiple of '
                '2 ** n_layer = {s}'.format(h=img_height, w=img_width, s=scale))

        net = z

        feature_height = img_height // scale
        feature_width = img_width // scale
        feature_channel = self.n_first_channel

        n_units = feature_height * feature_width * feature_channel
        net = fc(net, n_units, reuse=reuse, scope='{s}/FC'.format(s=self.scope))
        shape = [-1, feature_height, feature_width, feature_channel]
        net = reshape(net, shape)

        for i in range(self.n_layer):
            feature_height *= 2
            feature_width *= 2
            if i < self.n_layer - 1:
                feature_channel //= 2
            else:
                # Last stage outputs the image's colour channels.
                feature_channel = self.color_channel

            net = bn(net, reuse=reuse,
                     scope='{s}/BN_{n}'.format(s=self.scope, n=i))
            net = relu(net)
            net = conv_t(net, feature_channel, 4,
                         [feature_height, feature_width], strides=2,
                         reuse=reuse,
                         scope='{s}/ConvT_{n}'.format(s=self.scope, n=i))

        net = sigmoid(net)

        return net

class Discriminator(object):
    """Convolutional real/fake classifier that also exposes a feature map.

    Each of the ``n_layer`` stages applies a 4x4 convolution with
    ``n_first_channel * 2 ** i`` channels, 2x max pooling, batch
    normalisation and ELU. A 2-unit fully connected layer with softmax gives
    the class probabilities.
    """

    def __init__(self, n_first_channel, n_layer):
        self.n_first_channel = n_first_channel
        self.n_layer = n_layer
        self.scope = 'Discriminator'

    def __call__(self, x, reuse=False):
        """Build the discriminator on ``x``, sharing weights when ``reuse``.

        Returns:
            ``(probabilities, feature_map)``: the 2-way softmax output and the
            output of the last conv stage (used as the learned similarity
            metric for the VAE loss).
        """
        net = x

        for i in range(self.n_layer):
            net = conv(net, self.n_first_channel * 2 ** i, 4, reuse=reuse,
                       scope='{s}/Conv_{n}'.format(s=self.scope, n=i))
            net = max_pool(net, 2)
            net = bn(net, reuse=reuse,
                     scope='{s}/BN_{n}'.format(s=self.scope, n=i))
            net = elu(net)

        # Taken after the loop so it is always defined; previously it was
        # assigned inside the loop and unbound when n_layer < 1.
        feature_reconstruction = net

        net = fc(net, 2, reuse=reuse, scope='{s}/FC'.format(s=self.scope))
        net = softmax(net)

        return net, feature_reconstruction

class CustomCallback(tflearn.callbacks.Callback):
    """TFLearn callback that saves grids of decoded samples during training.

    At every snapshot, a fixed set of ``n_side ** 2`` latent vectors is
    decoded and saved as one ``n_side x n_side`` image grid. At the end of
    training, each grid row instead interpolates linearly between two random
    latent vectors.
    """

    def __init__(self, model, n_side=10):
        """
        Args:
            model: the ``VAEGAN`` being trained.
            n_side: number of tiles per side of the saved grids.
        """
        super(CustomCallback, self).__init__()
        self.model = model
        self.n_side = n_side
        # Fixed latent vectors, so that successive snapshots are comparable.
        self.sample_z = np.random.normal(size=(n_side ** 2, model.latent_dim))

    def _save(self, name, z):
        """Decode ``z`` with the latest weights and save it as a grid image.

        Tile ``(row y, column x)`` holds ``decoded[x + y * n_side]``. Pixel
        values are clipped to [0, 1], then scaled to 0-255 and truncated to
        uint8.
        """
        model = self.model
        n_side = self.n_side
        img_height = model.img_shape[0]
        img_width = model.img_shape[1]
        img_channel = model.img_shape[2]
        image = np.ndarray(shape=(n_side * img_height,
                                  n_side * img_width,
                                  img_channel),
                           dtype=np.float32)

        # Copy the current training weights into the stand-alone decoder
        # graph before decoding.
        model.trained_values = {
            scope:model._get_trained_values(model.trainer, scope)
            for scope in model.trained_values}
        with model.decoder_graph.as_default():
            for scope in model.trained_values:
                model._assign_values(model.decoder, scope)
        decoded = model.decode(z)

        for y in range(n_side):
            for x in range(n_side):
                image[y * img_height : (y + 1) * img_height,
                      x * img_width : (x + 1) * img_width,
                      :] = decoded[x + y * n_side]
        image = np.clip(image, 0, 1)
        image *= 255
        io.imsave(name, image.astype(np.uint8))

    def on_batch_end(self, training_state, snapshot=False):
        """On snapshot batches, save the fixed-``z`` grid tagged by step."""
        if snapshot:
            step = training_state.step

            # NOTE(review): CHECKPOINT_PATH has no trailing separator, so this
            # yields e.g. '/tmp/vaegan/modelimage-1000.png'.
            file_name = '{path}image-{step}.png'.format(path=CHECKPOINT_PATH,
                                                        step=step)
            self._save(file_name, self.sample_z)

    def on_train_end(self, training_state):
        """Save a grid whose rows interpolate between random latent pairs."""
        latent_dim = self.model.latent_dim

        sample_z = np.ndarray(shape=(self.n_side ** 2, latent_dim),
                              dtype=np.float32)
        for row in range(self.n_side):
            start = np.random.normal(size=latent_dim)
            stop = np.random.normal(size=latent_dim)
            # Column j of the row is the j-th of n_side evenly spaced points
            # from start to stop (inclusive), computed per latent dimension.
            z_rows = np.array([np.linspace(start[i], stop[i], num=self.n_side)
                               for i in range(latent_dim)]).T
            sample_z[row * self.n_side : (row + 1) * self.n_side, :] = z_rows

        file_name = '{path}image-final.png'.format(path=CHECKPOINT_PATH)
        self._save(file_name, sample_z)

if __name__ == '__main__':
    # Train on the images of a single CIFAR-10 label (1), pooling the train
    # and test splits. Guarded so that importing this module does not start
    # a download and a long training run.
    (X, Y), (testX, testY) = tflearn.datasets.cifar10.load_data()
    X = np.concatenate((X, testX), axis=0)
    Y = np.concatenate((Y, testY), axis=0)
    X = X[Y == 1]

    img_shape = X.shape[1:]

    vaegan = VAEGAN(img_shape=img_shape, n_first_channel=64, n_layer=4,
                    latent_dim=32, kullback_leibler_ratio=0.01,
                    reconstruction_weight_against_detail=50.0)
    vaegan.train(X, pretrain_vae_epoch=1, pretrain_discriminator_epoch=10,
                 train_epoch=100)

Références

VAEGAN fauxtograph

Articles recommandés

VAEGAN avec TF Learn
DCGAN avec TF Learn