Autoencoding beyond pixels using a learned similarity metric
Ein VAE mit nachgeschaltetem GAN-Diskriminator. Für den VAE-Fehler wird anstelle des vom Decoder erzeugten Bildes die aus einer mittleren Schicht des Diskriminators extrahierte Feature-Map verwendet. Da ein gewöhnlicher VAE den Fehler pixelweise misst, entstehen verschwommene Bilder. Misst man den Fehler dagegen auf Ebene der Feature-Map, lassen sich möglicherweise scharfe Bilder erzeugen, während die globalen Merkmale weiterhin reproduziert werden.
Der Encoder lernt mit dem Unterschied zwischen den Feature-Maps des Originalbilds und des decodierten Bilds als Fehler (VAE). Der Decoder lernt zusätzlich mit dem Klassifikationsergebnis des Diskriminators für das decodierte Bild und das zufällig erzeugte Bild als Fehler (VAE + GAN). Der Diskriminator lernt aus den Klassifikationsergebnissen für Originalbilder, decodierte Bilder und zufällig erzeugte Bilder (GAN).
Es scheint sinnvoll zu sein, vorab den VAE mit dem gewöhnlichen pixelweisen Fehler und das GAN mit Originalbildern und zufällig erzeugten Bildern vorzutrainieren (dabei wird der GAN-Generator durch den VAE-Decoder ersetzt, und beide werden unabhängig voneinander trainiert).
Ich experimentiere noch, daher ändere ich den Code möglicherweise später.
Bei DCGAN habe ich mich bemüht, möglichst nah an der ursprünglichen Implementierung zu bleiben; diesmal ist es jedoch eine vereinfachte Version geworden, in der Encoder, Decoder und Diskriminator jeweils mit unterschiedlichen Stichproben trainiert werden.
Ich habe mich am Originalpaper und an den Implementierungen anderer orientiert, aber einige Änderungen vorgenommen.
vaegan.py
from __future__ import (
division,
print_function,
absolute_import
)
from six.moves import range
import tensorflow as tf
import tflearn
import os
import numpy as np
from skimage import io
# --- Output locations -------------------------------------------------------
# TensorBoard log directories for the two pre-training phases.
PRE_VAE_TENSORBOARD_DIR = '/tmp/tflearn_logs/vae/'
PRE_DIS_TENSORBOARD_DIR = '/tmp/tflearn_logs/dis/'

# Checkpoint path prefixes: one per pre-training phase plus the joint model.
PRE_VAE_CHECKPOINT_PATH = '/tmp/vaegan/pre-vae'
PRE_DIS_CHECKPOINT_PATH = '/tmp/vaegan/pre-dis'
CHECKPOINT_PATH = '/tmp/vaegan/model'

# --- Short aliases for the tflearn API used throughout this file -----------
# Model and training machinery.
DNN = tflearn.DNN
Trainer = tflearn.Trainer
TrainOp = tflearn.TrainOp
adam = tflearn.Adam
crossentropy = tflearn.categorical_crossentropy

# Layers and tensor utilities.
input_data = tflearn.input_data
fc = tflearn.fully_connected
conv = tflearn.conv_2d
conv_t = tflearn.conv_2d_transpose
max_pool = tflearn.max_pool_2d
bn = tflearn.batch_normalization
reshape = tflearn.reshape
merge = tflearn.merge

# Activations.
relu = tflearn.relu
elu = tflearn.elu
sigmoid = tflearn.sigmoid
softmax = tflearn.softmax
softplus = tflearn.softplus
# Create the output directories that do not exist yet.  os.mkdir creates a
# single level only, so '/tmp/tflearn_logs' is listed before the TensorBoard
# sub-directories that live inside it.
for _directory in ('/tmp/tflearn_logs',
                   PRE_VAE_TENSORBOARD_DIR,
                   PRE_DIS_TENSORBOARD_DIR,
                   '/tmp/vaegan/'):
    if not os.path.exists(_directory):
        os.mkdir(_directory)
del _directory  # do not leave the loop variable behind as a module global
class VAEGAN(object):
    """VAE whose reconstruction error is measured on discriminator features.

    Training runs in three phases, each in its own ``tf.Graph``:

    1. the encoder/decoder pair is pre-trained as a VAE with a pixel-wise
       squared error,
    2. the discriminator is pre-trained on input images versus images decoded
       from random latent vectors,
    3. encoder, decoder and discriminator are trained jointly with one
       ``TrainOp`` each.

    Trained variable values are carried from one graph to the next through
    ``self.trained_values`` (``{scope: {variable_name: value}}``).
    """

    def __init__(self, img_shape, n_first_channel, n_layer, latent_dim,
                 kullback_leibler_ratio, reconstruction_weight_against_detail,
                 vae_learning_rate=0.001, vae_beta1=0.5,
                 discriminator_learning_rate=0.00001, discriminator_beta1=0.5):
        """Store the hyper-parameters; no network is built here.

        Args:
            img_shape: shape of one image; indexed as (height, width,
                channel) by the rest of the file.
            n_first_channel: channel count of the first convolution layer.
            n_layer: number of convolution layers per sub-network; must be
                greater than 1.
            latent_dim: size of the latent vector.
            kullback_leibler_ratio: multiplier applied to the KL term.
            reconstruction_weight_against_detail: multiplier applied to the
                mean of (reconstruction error + KL term).
            vae_learning_rate: Adam learning rate for the 'vae', 'encoder'
                and 'decoder' optimizers.
            vae_beta1: Adam beta1 for those optimizers.
            discriminator_learning_rate: Adam learning rate for the
                discriminator optimizer.
            discriminator_beta1: Adam beta1 for the discriminator optimizer.
        """
        self.img_shape = list(img_shape)
        self.input_shape = [None] + self.img_shape
        self.img_size = img_shape[:2]
        self.n_first_channel = n_first_channel
        self.n_layer = n_layer
        self.kullback_leibler_ratio = kullback_leibler_ratio
        self.reconstruction_weight_against_detail = \
            reconstruction_weight_against_detail
        self.latent_dim = latent_dim
        self.vae_learning_rate = vae_learning_rate
        self.vae_beta1 = vae_beta1
        self.discriminator_learning_rate = discriminator_learning_rate
        self.discriminator_beta1 = discriminator_beta1
        assert self.n_layer > 1, 'n_layer must be more than 1'
        self.vae_pretrainer = None
        self.discriminator_pretrainer = None
        self.trainer = None
        # Set by _set_decoder(); kept as None until then so that decode()
        # can report a clear error instead of an AttributeError.
        self.decoder = None
        self.decoder_graph = tf.Graph()
        self.trained_values = {}

    def _build_vae_pretrainer(self, encoder, decoder):
        """Build the Trainer that pre-trains encoder + decoder as a plain VAE.

        The loss is the per-sample pixel-wise mean squared error plus the KL
        term, averaged and scaled by reconstruction_weight_against_detail.
        """
        inputs = input_data(shape=self.input_shape, name='input_x')
        # Build Network
        mean, log_var = encoder(inputs)
        encoded = self._encode(mean, log_var)
        decoded = decoder(encoded)
        # Loss
        element_wise_loss = self._get_mean_square(decoded, inputs)
        kullback_leibler_divergence = \
            self._get_kullback_leibler_divergence(mean, log_var)
        pretrain_vae_loss = self.reconstruction_weight_against_detail * \
            tf.reduce_mean(element_wise_loss + kullback_leibler_divergence)
        # Trainer
        pretrain_vae_op = TrainOp(loss=pretrain_vae_loss,
                                  optimizer=self._get_optimizer('vae'),
                                  batch_size=128,
                                  name='VAE_pretrainer')
        return Trainer(pretrain_vae_op,
                       tensorboard_dir=PRE_VAE_TENSORBOARD_DIR,
                       tensorboard_verbose=0,
                       checkpoint_path=PRE_VAE_CHECKPOINT_PATH,
                       max_checkpoints=1)

    def _build_discriminator_pretrainer(self, decoder, discriminator):
        """Build the Trainer that pre-trains the discriminator only.

        Input images are paired with the 'is_true' labels and images decoded
        from random latent vectors with the 'is_false' labels. Only variables
        under the discriminator scope are updated.
        """
        inputs = input_data(shape=self.input_shape, name='input_x')
        is_true = input_data(shape=(None, 2), name='is_true')
        is_false = input_data(shape=(None, 2), name='is_false')
        # Build Network
        # NOTE(review): this fully connected layer is only used to obtain a
        # (batch, latent_dim) shape tensor; its output values are never used.
        shape = tf.shape(fc(inputs, self.latent_dim))
        random_image = decoder(self._get_z(shape))
        prediction_origin, _ = discriminator(inputs)
        prediction_random, _ = discriminator(random_image, reuse=True)
        # Loss
        prediction_all = merge([prediction_origin, prediction_random],
                               'concat', axis=0)
        y_all = merge([is_true, is_false], 'concat', axis=0)
        pretrain_discriminator_loss = crossentropy(prediction_all, y_all)
        # Trainer
        pretrain_discriminator_op = TrainOp(
            loss=pretrain_discriminator_loss,
            optimizer=self._get_optimizer('discriminator'),
            batch_size=128,
            trainable_vars=self._get_trainable_variables(discriminator.scope),
            name='Discriminator_pretrainer')
        return Trainer(pretrain_discriminator_op,
                       tensorboard_dir=PRE_DIS_TENSORBOARD_DIR,
                       tensorboard_verbose=0,
                       checkpoint_path=PRE_DIS_CHECKPOINT_PATH,
                       max_checkpoints=1)

    def _build_trainer(self, encoder, decoder, discriminator):
        """Build the joint Trainer with one TrainOp per sub-network.

        Each TrainOp restricts its updates to the variables under its own
        scope ('Encoder', 'Decoder' or 'Discriminator').
        """
        inputs = input_data(shape=self.input_shape, name='input_x')
        is_true = input_data(shape=(None, 2), name='is_true')
        is_false = input_data(shape=(None, 2), name='is_false')
        # Build Network
        mean, log_var = encoder(inputs)
        encoded = self._encode(mean, log_var)
        decoded = decoder(encoded)
        random_image = decoder(self._get_z(tf.shape(mean)), reuse=True)
        # Discriminator outputs for original, decoded and random images
        prediction_origin, feature_map_origin = discriminator(inputs)
        prediction_decoded, feature_map_decoded = \
            discriminator(decoded, reuse=True)
        prediction_random, _ = discriminator(random_image, reuse=True)
        # Loss
        ## Encoder: squared error between discriminator feature maps + KL
        feature_wise_loss = \
            self._get_mean_square(feature_map_decoded, feature_map_origin)
        kullback_leibler_divergence = \
            self._get_kullback_leibler_divergence(mean, log_var)
        encoder_loss = self.reconstruction_weight_against_detail * \
            tf.reduce_mean(feature_wise_loss + kullback_leibler_divergence)
        ## Decoder: encoder loss + cross-entropy of the generated images
        ## against the 'is_true' labels
        prediction_gan = merge([prediction_decoded, prediction_random],
                               'concat', axis=0)
        y_gan = merge([is_true, is_true], 'concat', axis=0)
        gan_generator_loss = crossentropy(prediction_gan, y_gan)
        decoder_loss = (encoder_loss + gan_generator_loss) * 0.5
        ## Discriminator: originals against 'is_true', generated images
        ## against 'is_false'
        prediction_fake = merge([prediction_decoded, prediction_random],
                                'concat', axis=0)
        y_fake = merge([is_false, is_false], 'concat', axis=0)
        real_loss = crossentropy(prediction_origin, is_true)
        fake_loss = crossentropy(prediction_fake, y_fake)
        discriminator_loss = (real_loss + fake_loss) * 0.5
        # Trainer
        encoder_op = TrainOp(
            loss=encoder_loss,
            optimizer=self._get_optimizer('encoder'),
            batch_size=64,
            trainable_vars=self._get_trainable_variables(encoder.scope),
            name='Encoder')
        decoder_op = TrainOp(
            loss=decoder_loss,
            optimizer=self._get_optimizer('decoder'),
            batch_size=64,
            trainable_vars=self._get_trainable_variables(decoder.scope),
            name='Decoder')
        discriminator_op = TrainOp(
            loss=discriminator_loss,
            optimizer=self._get_optimizer('discriminator'),
            batch_size=64,
            trainable_vars=self._get_trainable_variables(discriminator.scope),
            name='Discriminator')
        return Trainer([encoder_op, decoder_op, discriminator_op],
                       checkpoint_path=CHECKPOINT_PATH, max_checkpoints=1)

    def _encode(self, mean, log_var):
        """Sample a latent vector: mean + exp(0.5 * log_var) * epsilon."""
        epsilon = tf.random_normal(tf.shape(mean), name='Epsilon')
        return mean + tf.exp(0.5 * log_var) * epsilon

    def _get_z(self, shape):
        """Return standard-normal noise reshaped to (-1, latent_dim)."""
        z = tf.random_normal(shape, name='RandomZ')
        return reshape(z, (-1, self.latent_dim))

    def _get_kullback_leibler_divergence(self, mean, log_var):
        """Return the per-sample KL term scaled by kullback_leibler_ratio.

        Computes -0.5 * ratio * mean(1 + log_var - mean^2 - exp(log_var))
        over axis 1 (a mean over that axis, not a sum).
        """
        square_mean = tf.pow(mean, 2)
        variance = tf.exp(log_var)
        kullback_leibler_divergence = \
            tf.reduce_mean(1 + log_var - square_mean - variance,
                           reduction_indices=1)
        kullback_leibler_divergence = \
            - 0.5 * self.kullback_leibler_ratio * kullback_leibler_divergence
        return kullback_leibler_divergence

    def _get_mean_square(self, prediction, truth):
        """Return the squared difference averaged over axes 1, 2 and 3."""
        return tf.reduce_mean(tf.squared_difference(prediction, truth),
                              reduction_indices=(1, 2, 3))

    def _get_optimizer(self, type_str):
        """Return the Adam optimizer tensor for the given sub-network.

        'vae', 'encoder' and 'decoder' use the VAE learning rate and beta1;
        any other value uses the discriminator settings.
        """
        if type_str in ['vae', 'encoder', 'decoder']:
            learning_rate = self.vae_learning_rate
            beta1 = self.vae_beta1
        else:  # 'discriminator'
            learning_rate = self.discriminator_learning_rate
            beta1 = self.discriminator_beta1
        opt = adam(learning_rate=learning_rate, beta1=beta1)
        return opt.get_tensor()

    def _get_trainable_variables(self, scope):
        """Return the trainable variables whose name contains '<scope>/'."""
        return [v for v in tflearn.get_all_trainable_variable()
                if scope + '/' in v.name]

    def _get_input_tensor_by_name(self, name):
        """Return the first tensor of the INPUTS collection under `name`."""
        return tf.get_collection(tf.GraphKeys.INPUTS, scope=name)[0]

    def train(self, x, n_sample=None, pretrain_vae_epoch=1,
              pretrain_discriminator_epoch=1, train_epoch=10):
        """Run VAE pre-training, discriminator pre-training and joint training.

        Args:
            x: training images; must expose `.shape` and support slicing.
            n_sample: number of leading samples of `x` to train on. Defaults
                to all of them. `x` is truncated to this many rows so that
                the images and the label arrays always have equal length.
            pretrain_vae_epoch: epochs of VAE pre-training.
            pretrain_discriminator_epoch: epochs of discriminator
                pre-training.
            train_epoch: epochs of joint training.
        """
        if n_sample is None:
            n_sample = x.shape[0]
        else:
            # Previously only the labels were sized by n_sample while x stayed
            # at full length, so the fed arrays disagreed in length.
            x = x[:n_sample]
            n_sample = x.shape[0]
        is_true = np.tile([0., 1.], [n_sample, 1])
        is_false = np.tile([1., 0.], [n_sample, 1])
        encoder = Encoder(self.n_first_channel, self.n_layer, self.latent_dim)
        decoder = Decoder(self.img_shape, self.n_first_channel, self.n_layer)
        discriminator = Discriminator(self.n_first_channel, self.n_layer)

        # Phase 1: pre-train encoder + decoder as a VAE.
        with tf.Graph().as_default():
            self.vae_pretrainer = self._build_vae_pretrainer(encoder, decoder)
            trainer = self.vae_pretrainer
            input_tensor = self._get_input_tensor_by_name('input_x')
            feed_dict = {input_tensor: x}
            trainer.fit(feed_dict, n_epoch=pretrain_vae_epoch,
                        snapshot_epoch=True, shuffle_all=True,
                        run_id='VAE_pretrain')
            self.trained_values[encoder.scope] = \
                self._get_trained_values(trainer, encoder.scope)
            self.trained_values[decoder.scope] = \
                self._get_trained_values(trainer, decoder.scope)

        # Phase 2: pre-train the discriminator against the pre-trained decoder.
        with tf.Graph().as_default():
            self.discriminator_pretrainer = \
                self._build_discriminator_pretrainer(decoder, discriminator)
            trainer = self.discriminator_pretrainer
            self._assign_values(trainer, decoder.scope)
            input_tensor = self._get_input_tensor_by_name('input_x')
            true_tensor = self._get_input_tensor_by_name('is_true')
            false_tensor = self._get_input_tensor_by_name('is_false')
            feed_dict = {input_tensor: x,
                         true_tensor: is_true,
                         false_tensor: is_false}
            trainer.fit(feed_dict, n_epoch=pretrain_discriminator_epoch,
                        snapshot_epoch=True, shuffle_all=True,
                        run_id='Discriminator_pretrain')
            self.trained_values[discriminator.scope] = \
                self._get_trained_values(trainer, discriminator.scope)

        # Phase 3: joint training, starting from the pre-trained values.
        with tf.Graph().as_default():
            self.trainer = self._build_trainer(encoder, decoder, discriminator)
            trainer = self.trainer
            self._assign_values(trainer, encoder.scope)
            self._assign_values(trainer, decoder.scope)
            self._assign_values(trainer, discriminator.scope)
            self._set_decoder(decoder)
            input_tensor = self._get_input_tensor_by_name('input_x')
            true_tensor = self._get_input_tensor_by_name('is_true')
            false_tensor = self._get_input_tensor_by_name('is_false')
            feed_dict = {input_tensor: x,
                         true_tensor: is_true,
                         false_tensor: is_false}
            # One feed dict per TrainOp (encoder, decoder, discriminator).
            self.trainer.fit([feed_dict] * 3, n_epoch=train_epoch,
                             snapshot_step=1000, snapshot_epoch=False,
                             shuffle_all=True, run_id='VAEGAN',
                             callbacks=[CustomCallback(self)])

    def _get_trained_values(self, trainer, scope):
        """Read the current values of the scope's trainable variables.

        Returns a dict mapping variable name to the value fetched from
        `trainer.session`. Variables are looked up in the default graph.
        """
        return {v.name: tflearn.variables.get_value(v, session=trainer.session)
                for v in self._get_trainable_variables(scope)}

    def _assign_values(self, trainer, scope):
        """Load self.trained_values[scope] into the default graph's variables.

        `trainer` only needs a `.session` attribute; the callback also passes
        the decoder DNN here.
        """
        # NOTE: v.assign(...) is invoked on every call, so repeated calls keep
        # adding assign ops to the graph.
        for v in self._get_trainable_variables(scope):
            trainer.session.run(v.assign(self.trained_values[scope][v.name]))

    def _set_decoder(self, decoder):
        """Build a stand-alone decoder DNN inside self.decoder_graph."""
        with self.decoder_graph.as_default():
            inputs = input_data(shape=(None, self.latent_dim))
            net = decoder(inputs)
            self.decoder = DNN(net)

    def decode(self, z):
        """Decode latent vectors `z` with the stand-alone decoder.

        Raises:
            RuntimeError: if no decoder has been built yet, i.e. train() has
                not been called.
        """
        if self.decoder is None:
            raise RuntimeError(
                'decode() was called before train(); no decoder is built yet')
        with self.decoder_graph.as_default():
            return self.decoder.predict(z)
class Encoder(object):
    """Convolutional encoder producing the mean and log-variance vectors.

    All variables are created under the 'Encoder' scope.
    """

    def __init__(self, n_first_channel, n_layer, latent_dim):
        """Remember the layer layout; the network is built on call."""
        self.scope = 'Encoder'
        self.latent_dim = latent_dim
        self.n_layer = n_layer
        self.n_first_channel = n_first_channel

    def __call__(self, x, reuse=False):
        """Apply the encoder to `x` and return `(mean, log_var)`.

        Each of the `n_layer` stages is a stride-2 convolution with kernel
        size 4, batch normalization and ReLU; the channel count doubles at
        every stage. Two separate fully connected heads of size `latent_dim`
        are applied to the final feature map.
        """
        hidden = x
        for depth in range(self.n_layer):
            conv_scope = '{s}/Conv_{n}'.format(s=self.scope, n=depth)
            bn_scope = '{s}/BN_{n}'.format(s=self.scope, n=depth)
            width = self.n_first_channel * 2 ** depth
            hidden = conv(hidden, width, 4, strides=2, reuse=reuse,
                          scope=conv_scope)
            hidden = relu(bn(hidden, reuse=reuse, scope=bn_scope))

        mean_scope = '{s}/Mean'.format(s=self.scope)
        log_var_scope = '{s}/LogVariance'.format(s=self.scope)
        mean = fc(hidden, self.latent_dim, reuse=reuse, scope=mean_scope)
        log_var = fc(hidden, self.latent_dim, reuse=reuse,
                     scope=log_var_scope)
        return mean, log_var
class Decoder(object):
    """Transposed-convolution decoder mapping a latent vector to an image.

    All variables are created under the 'Decoder' scope.
    """

    def __init__(self, img_shape, n_first_channel, n_layer):
        """Derive the decoder layout from the image shape.

        `img_shape` is indexed as (height, width, channel). The stored
        `n_first_channel` is the given value times 2 ** (n_layer - 1).
        """
        self.scope = 'Decoder'
        self.n_layer = n_layer
        self.img_size = img_shape[:2]
        self.color_channel = img_shape[2]
        self.n_first_channel = n_first_channel * 2 ** (n_layer - 1)

    def __call__(self, z, reuse=False):
        """Decode `z` and return the sigmoid-activated output.

        `z` is projected by a fully connected layer and reshaped to a feature
        map of size (img_size // 2 ** n_layer). Each of the `n_layer` stages
        then applies batch normalization, ReLU and a stride-2 transposed
        convolution that doubles height and width. The channel count is
        halved at every stage except the last, which outputs `color_channel`
        channels. Because of the integer division, the output is smaller than
        `img_size` when a side is not divisible by 2 ** n_layer.
        """
        downscale = 2 ** self.n_layer
        height = self.img_size[0] // downscale
        width = self.img_size[1] // downscale
        channels = self.n_first_channel

        net = fc(z, height * width * channels, reuse=reuse,
                 scope='{s}/FC'.format(s=self.scope))
        net = reshape(net, [-1, height, width, channels])

        last_stage = self.n_layer - 1
        for stage in range(self.n_layer):
            height, width = height * 2, width * 2
            if stage < last_stage:
                channels = channels // 2
            else:
                channels = self.color_channel
            bn_scope = '{s}/BN_{n}'.format(s=self.scope, n=stage)
            conv_t_scope = '{s}/ConvT_{n}'.format(s=self.scope, n=stage)
            net = relu(bn(net, reuse=reuse, scope=bn_scope))
            net = conv_t(net, channels, 4, [height, width], strides=2,
                         reuse=reuse, scope=conv_t_scope)
        return sigmoid(net)
class Discriminator(object):
    """Convolutional two-class classifier that also exposes a feature map.

    All variables are created under the 'Discriminator' scope.
    """

    def __init__(self, n_first_channel, n_layer):
        """Remember the layer layout; the network is built on call."""
        self.scope = 'Discriminator'
        self.n_layer = n_layer
        self.n_first_channel = n_first_channel

    def __call__(self, x, reuse=False):
        """Classify `x` and return `(prediction, feature_map)`.

        Each of the `n_layer` stages is a convolution with kernel size 4,
        2x max pooling, batch normalization and ELU; the channel count
        doubles at every stage. `feature_map` is the activation of the last
        stage and `prediction` is the softmax over a 2-unit fully connected
        layer applied to it.
        """
        hidden = x
        final_stage = self.n_layer - 1
        for stage in range(self.n_layer):
            conv_scope = '{s}/Conv_{n}'.format(s=self.scope, n=stage)
            bn_scope = '{s}/BN_{n}'.format(s=self.scope, n=stage)
            hidden = conv(hidden, self.n_first_channel * 2 ** stage, 4,
                          reuse=reuse, scope=conv_scope)
            hidden = max_pool(hidden, 2)
            hidden = elu(bn(hidden, reuse=reuse, scope=bn_scope))
            if stage == final_stage:
                # Activation handed back for the feature-wise loss.
                feature_reconstruction = hidden
        logits = fc(hidden, 2, reuse=reuse,
                    scope='{s}/FC'.format(s=self.scope))
        return softmax(logits), feature_reconstruction
class CustomCallback(tflearn.callbacks.Callback):
    """Training callback that saves grids of decoded sample images.

    A fixed set of random latent vectors is decoded at every snapshot, and a
    grid of linear interpolations between random latent vectors is decoded
    once training ends.
    """

    def __init__(self, model, n_side=10):
        """Draw the fixed latent samples used for the snapshot images.

        Args:
            model: object providing `latent_dim`, `img_shape`, `trainer`,
                `decoder`, `decoder_graph`, `trained_values`, `decode`,
                `_get_trained_values` and `_assign_values`.
            n_side: number of images per row and per column of the grid.
        """
        super(CustomCallback, self).__init__()
        self.model = model
        self.n_side = n_side
        self.sample_z = np.random.normal(size=(n_side ** 2, model.latent_dim))

    def _save(self, name, z):
        """Decode `z` and write the images to `name` as an n_side x n_side grid.

        The current variable values are first read from `model.trainer` and
        copied into the stand-alone decoder, because that decoder lives in a
        separate graph.
        """
        model = self.model
        n_side = self.n_side
        img_height = model.img_shape[0]
        img_width = model.img_shape[1]
        img_channel = model.img_shape[2]
        # Every tile is overwritten below, so no initialisation is needed.
        image = np.empty(shape=(n_side * img_height,
                                n_side * img_width,
                                img_channel),
                         dtype=np.float32)

        # Refresh the cached values from the training session ...
        model.trained_values = {
            scope: model._get_trained_values(model.trainer, scope)
            for scope in model.trained_values}
        # ... and push them into the decoder graph.
        with model.decoder_graph.as_default():
            for scope in model.trained_values:
                model._assign_values(model.decoder, scope)

        decoded = model.decode(z)
        for row in range(n_side):
            for col in range(n_side):
                image[row * img_height:(row + 1) * img_height,
                      col * img_width:(col + 1) * img_width,
                      :] = decoded[col + row * n_side]
        # Scale from [0, 1] to [0, 255] before the cast to uint8.
        image = np.clip(image, 0, 1)
        image *= 255
        io.imsave(name, image.astype(np.uint8))

    def on_batch_end(self, training_state, snapshot=False):
        """Save a sample image whenever the trainer takes a snapshot."""
        if snapshot:
            step = training_state.step
            # NOTE(review): CHECKPOINT_PATH is '/tmp/vaegan/model', so this
            # yields '/tmp/vaegan/modelimage-<step>.png' -- confirm that no
            # separator is wanted.
            file_name = '{path}image-{step}.png'.format(path=CHECKPOINT_PATH,
                                                        step=step)
            self._save(file_name, self.sample_z)

    def on_train_end(self, training_state):
        """Save a final image of latent-space interpolations.

        Each grid row interpolates linearly between two freshly drawn random
        latent vectors in `n_side` steps.
        """
        latent_dim = self.model.latent_dim
        sample_z = np.empty(shape=(self.n_side ** 2, latent_dim),
                            dtype=np.float32)
        for row in range(self.n_side):
            start = np.random.normal(size=latent_dim)
            stop = np.random.normal(size=latent_dim)
            # One linspace per latent dimension, transposed to
            # (n_side, latent_dim).
            z_rows = np.array([np.linspace(start[i], stop[i], num=self.n_side)
                               for i in range(latent_dim)]).T
            sample_z[row * self.n_side:(row + 1) * self.n_side, :] = z_rows
        file_name = '{path}image-final.png'.format(path=CHECKPOINT_PATH)
        self._save(file_name, sample_z)
# Load CIFAR-10 and pool the training and test splits into one data set.
(X, Y), (testX, testY) = tflearn.datasets.cifar10.load_data()
X = np.concatenate([X, testX], axis=0)
Y = np.concatenate([Y, testY], axis=0)

# Keep only the images whose label equals 1.
X = X[Y == 1]
img_shape = X.shape[1:]

# Build the model and run all three training phases.
vaegan = VAEGAN(
    img_shape=img_shape,
    n_first_channel=64,
    n_layer=4,
    latent_dim=32,
    kullback_leibler_ratio=0.01,
    reconstruction_weight_against_detail=50.0)
vaegan.train(
    X,
    pretrain_vae_epoch=1,
    pretrain_discriminator_epoch=10,
    train_epoch=100)