
Advanced Deep Learning with Keras


Disentangled Representation GANs

    # first output is probability that feature1 is real
    f1_source = Dense(1)(x)
    f1_source = Activation('sigmoid',
                           name='feature1_source')(f1_source)

    # z1 reconstruction (Q1 network)
    z1_recon = Dense(z_dim)(x)
    z1_recon = Activation('tanh', name='z1')(z1_recon)

    discriminator_outputs = [f1_source, z1_recon]
    dis1 = Model(inputs, discriminator_outputs, name='dis1')
    return dis1
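
Since the builder returns a two-output discriminator, it must be compiled with one loss per output. The following is a minimal sketch of that step; the optimizer settings and loss weights are illustrative assumptions, not the book's exact values:

from keras.optimizers import RMSprop

# dis1 is the two-output discriminator returned by the builder above.
# The hyperparameters below are assumptions for illustration.
optimizer = RMSprop(lr=2e-4, decay=6e-8)
dis1.compile(loss=['binary_crossentropy',  # feature1_source: real vs. fake
                   'mse'],                 # z1: noise-code reconstruction (Q1 loss)
             loss_weights=[1.0, 1.0],      # assumed equal weighting
             optimizer=optimizer,
             metrics=['accuracy'])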

With all builder functions available, StackedGAN is assembled in Listing 6.2.4. Before training StackedGAN, the encoder is pretrained. Note that we already incorporated the three generator loss functions (adversarial, conditional, and entropy) in the adversarial model training. The Q-network shares some common layers with the discriminator model, so its loss function is also incorporated in the discriminator model training.
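
As a rough sketch of how those three generator losses might be attached during compilation of the stack-1 adversarial model (the model name adv1, the choice of MSE for the conditional and entropy terms, and all weights are illustrative assumptions, not the book's exact listing):

from keras.optimizers import RMSprop

# adv1 is assumed to be the stack-1 adversarial model with three outputs:
# source probability, recovered conditional feature, and z1 reconstruction.
loss = ['binary_crossentropy',  # adversarial loss: fool the discriminator
        'mse',                  # conditional loss: recover the feature input
        'mse']                  # entropy loss: recover the z1 noise code
adv1.compile(loss=loss,
             loss_weights=[1.0, 1.0, 1.0],  # assumed equal weighting
             optimizer=RMSprop(lr=1e-4),
             metrics=['accuracy'])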

Listing 6.2.4, stackedgan-mnist-6.2.1.py. Building StackedGAN in Keras:

# imports used by this listing (defined earlier in the source file,
# not shown in this excerpt)
import numpy as np
from keras.datasets import mnist
from keras.utils import to_categorical

def build_and_train_models():
    # load MNIST dataset
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # reshape and normalize images
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255
    x_test = np.reshape(x_test, [-1, image_size, image_size, 1])
    x_test = x_test.astype('float32') / 255

    # number of labels
    num_labels = len(np.unique(y_train))
    # to one-hot vector
    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test)

    model_name = "stackedgan_mnist"
    # network parameters
    batch_size = 64
    train_steps = 40000
