
# encoder1 weights frozen
enc1.trainable = False
# discriminator1 weights frozen
dis1.trainable = False
gen1_inputs = [labels, z1]
gen1_outputs = gen1(gen1_inputs)
adv1_outputs = dis1(gen1_outputs) + [enc1(gen1_outputs)]
# labels + z1 to prob labels are real + z1 recon + feature1 recon
adv1 = Model(gen1_inputs, adv1_outputs, name="adv1")
# loss functions: 1) prob labels are real (adversarial1 loss)
# 2) Q network 1 loss (entropy1 loss)
# 3) conditional1 loss (classifier error)
loss_weights = [1.0, 1.0, 1.0]
loss = ['binary_crossentropy', 'mse', 'categorical_crossentropy']
adv1.compile(loss=loss,
             loss_weights=loss_weights,
             optimizer=optimizer,
             metrics=['accuracy'])
adv1.summary()

# train discriminator and adversarial networks
models = (enc0, enc1, gen0, gen1, dis0, dis1, adv0, adv1)
params = (batch_size, train_steps, num_labels, z_dim, model_name)
train(models, data, params)
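Since adv1 has three outputs, Keras matches the loss list and loss_weights to the outputs by position, and the total loss it optimizes is the weighted sum of the three terms. The following is a minimal, self-contained sketch of this mechanism; the toy shapes, layer names, and targets here are illustrative stand-ins, not the chapter's actual networks:

import numpy as np
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

inputs = Input(shape=(8,))
h = Dense(16, activation='relu')(inputs)
# three outputs, one per loss in the list below
prob = Dense(1, activation='sigmoid', name='prob')(h)
recon = Dense(4, name='recon')(h)
klass = Dense(10, activation='softmax', name='klass')(h)
model = Model(inputs, [prob, recon, klass])
# losses and loss_weights are matched to outputs by position;
# total loss = 1.0*bce(prob) + 1.0*mse(recon) + 1.0*cce(klass)
model.compile(loss=['binary_crossentropy', 'mse', 'categorical_crossentropy'],
              loss_weights=[1.0, 1.0, 1.0],
              optimizer='adam')

x = np.random.normal(size=(32, 8))
y = [np.ones((32, 1)),                          # "real" probability target
     np.random.normal(size=(32, 4)),            # reconstruction target
     np.eye(10)[np.random.randint(0, 10, 32)]]  # one-hot class target
model.train_on_batch(x, y)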

Finally, the training function bears a resemblance to typical GAN training, except that we only train one GAN at a time (that is, GAN₁ then GAN₀). The code is shown in Listing 6.2.5. It is worth noting that the training sequence per iteration is as follows (a toy sketch of this ordering appears after the list):

1. Discriminator₁ and Q₁ networks by minimizing the discriminator and entropy losses
2. Discriminator₀ and Q₀ networks by minimizing the discriminator and entropy losses
3. Adversarial₁ network by minimizing the adversarial, entropy, and conditional losses
4. Adversarial₀ network by minimizing the adversarial, entropy, and conditional losses
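To make the alternating order concrete, here is a self-contained toy sketch that follows the same four-step sequence with tiny stand-in models. Everything in it (the Dense stand-ins for the generators and discriminators, the feat_dim and z_dim sizes, the single binary_crossentropy loss) is an illustrative assumption; the real networks, Q-network outputs, and loss lists are the ones built earlier in this chapter:

import numpy as np
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

feat_dim, z_dim, batch_size = 16, 4, 8

def toy_gan(name):
    """Build stand-in generator, discriminator, and adversarial models."""
    # stand-in generator: latent code -> feature
    z = Input(shape=(z_dim,))
    gen = Model(z, Dense(feat_dim)(z), name='gen' + name)
    # stand-in discriminator: feature -> probability the feature is real
    x = Input(shape=(feat_dim,))
    dis = Model(x, Dense(1, activation='sigmoid')(x), name='dis' + name)
    dis.compile(loss='binary_crossentropy', optimizer='adam')
    # adversarial model: generator + discriminator frozen after dis.compile(),
    # the same freezing pattern used when building adv1 above
    dis.trainable = False
    adv = Model(z, dis(gen(z)), name='adv' + name)
    adv.compile(loss='binary_crossentropy', optimizer='adam')
    return gen, dis, adv

gen1, dis1, adv1 = toy_gan('1')
gen0, dis0, adv0 = toy_gan('0')

real = np.random.normal(size=(batch_size, feat_dim))
z = np.random.normal(size=(batch_size, z_dim))
ones = np.ones((batch_size, 1))
zeros = np.zeros((batch_size, 1))
y = np.concatenate([ones, zeros])

# one training iteration in the order listed above
fake1 = gen1.predict(z, verbose=0)
fake0 = gen0.predict(z, verbose=0)
# 1) discriminator1 (and Q1) on real and fake features
dis1.train_on_batch(np.concatenate([real, fake1]), y)
# 2) discriminator0 (and Q0) on real and fake features
dis0.train_on_batch(np.concatenate([real, fake0]), y)
# 3) adversarial1: train generator1 to fool the frozen discriminator1
adv1.train_on_batch(z, ones)
# 4) adversarial0: train generator0 to fool the frozen discriminator0
adv0.train_on_batch(z, ones)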

Listing 6.2.5, stackedgan-mnist-6.2.1.py, shows the StackedGAN training function in Keras:

def train(models, data, params):
    """Train the discriminator and adversarial networks.

    Alternately train the discriminator and adversarial networks by batch.
    """
