16.03.2021 Views

Advanced Deep Learning with Keras

Chapter 4

    discriminator.trainable = False
    # adversarial = generator + discriminator
    adversarial = Model(inputs,
                        discriminator(generator(inputs)),
                        name=model_name)
    adversarial.compile(loss='binary_crossentropy',
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary()

    # train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    params = (batch_size, latent_size, train_steps, model_name)
    train(models, x_train, params)

Listing 4.2.4 shows the function dedicated to training the discriminator and adversarial networks. Because the training is customized, the usual fit() function is not used. Instead, train_on_batch() is called to run a single gradient update on the given batch of data. The generator is then trained through the adversarial network.
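To make the role of train_on_batch() and of the frozen discriminator more concrete, the following is a minimal sketch rather than the book's listing. It assumes tensorflow.keras is installed and uses two tiny, hypothetical Dense networks in place of the DCGAN generator and discriminator; the sizes and the 'sgd' optimizer are illustrative choices only. A single call to adversarial.train_on_batch() should change the generator weights while leaving the discriminator weights untouched.

```python
import numpy as np
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

latent_size = 8   # illustrative, far smaller than a real DCGAN latent vector
data_size = 4     # illustrative stand-in for the flattened image size

# Toy generator: noise -> fake "data" (hypothetical architecture)
g_in = Input(shape=(latent_size,))
generator = Model(g_in, Dense(data_size, activation='sigmoid')(g_in),
                  name='generator')

# Toy discriminator: "data" -> probability of being real
d_in = Input(shape=(data_size,))
discriminator = Model(d_in, Dense(1, activation='sigmoid')(d_in),
                      name='discriminator')
discriminator.compile(loss='binary_crossentropy', optimizer='sgd')

# Freeze the discriminator before building the adversarial model
discriminator.trainable = False
inputs = Input(shape=(latent_size,))
adversarial = Model(inputs, discriminator(generator(inputs)),
                    name='adversarial')
adversarial.compile(loss='binary_crossentropy', optimizer='sgd')

# Snapshot the weights, then run exactly one gradient update
d_before = [w.copy() for w in discriminator.get_weights()]
g_before = [w.copy() for w in generator.get_weights()]
noise = np.random.uniform(-1.0, 1.0, size=(16, latent_size))
adversarial.train_on_batch(noise, np.ones((16, 1)))

d_unchanged = all(np.array_equal(b, a)
                  for b, a in zip(d_before, discriminator.get_weights()))
g_changed = any(not np.array_equal(b, a)
                for b, a in zip(g_before, generator.get_weights()))
print('discriminator unchanged:', d_unchanged)
print('generator updated:', g_changed)
```

Unlike fit(), which iterates over a whole dataset for a number of epochs, each train_on_batch() call performs one update, which is what lets the two networks be trained in alternation.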

Training first randomly picks a batch of real images from the dataset and labels it as real (1.0). The generator then produces a batch of fake images, which is labeled as fake (0.0). The two batches are concatenated and used to train the discriminator.
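The batch assembly described above can be sketched with NumPy alone. This is an illustration under stated assumptions, not the book's code: x_train is replaced by random arrays with the MNIST image shape, fake_generator() is a hypothetical stand-in for the generator's prediction call, and sampling the noise uniformly from [-1.0, 1.0] is an assumed choice.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, latent_size = 64, 100   # illustrative values

# Stand-in for the MNIST training set: 1000 images of shape 28x28x1
x_train = rng.random((1000, 28, 28, 1)).astype('float32')


def fake_generator(noise):
    """Hypothetical stand-in for generator.predict(noise)."""
    return rng.random((noise.shape[0], 28, 28, 1)).astype('float32')


# 1) Randomly pick a batch of real images from the dataset
indices = rng.integers(0, x_train.shape[0], size=batch_size)
real_images = x_train[indices]

# 2) Generate a batch of fake images from noise (assumed uniform in [-1, 1])
noise = rng.uniform(-1.0, 1.0, size=(batch_size, latent_size))
fake_images = fake_generator(noise)

# 3) Concatenate: the first half is real, the second half is fake
x = np.concatenate((real_images, fake_images))

# 4) Label real images 1.0 and fake images 0.0
y = np.ones((2 * batch_size, 1), dtype='float32')
y[batch_size:, :] = 0.0

# With real Keras models, discriminator.train_on_batch(x, y) would follow here.
print(x.shape, y.shape)
```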

After this is completed, the generator produces a new batch of fake images, this time labeled as real (1.0). This batch is used to train the adversarial network. The two networks are trained alternately for about 40,000 steps. At regular intervals, the MNIST digits generated from a given noise vector are saved on the filesystem. By the last training step, the network should have converged. The generator model is then saved to a file so that the trained model can easily be reused for future MNIST digit generation. Only the generator is saved, since it is the part of the GAN that is useful for generating new MNIST digits. For example, we can generate new and random MNIST digits by executing:

python3 dcgan-mnist-4.2.1.py --generator=dcgan_mnist.h5
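The adversarial half of each training step and the periodic sampling described above can be outlined as follows. This is a skeleton under assumptions, not the listing itself: the Keras calls are left as comments, train_steps is shortened from about 40,000 to keep the sketch fast, and save_interval = 500 and the 16-sample noise vector are hypothetical values.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, latent_size = 64, 100
train_steps, save_interval = 2000, 500   # hypothetical, shortened values

# One noise vector reused at every checkpoint, so saved samples are comparable
sample_noise = rng.uniform(-1.0, 1.0, size=(16, latent_size))

saved_steps = []
for step in range(train_steps):
    # (the discriminator update on real + fake images would happen first)

    # Fresh noise for the adversarial update; the fake images it produces
    # are deliberately labeled as real (1.0)
    noise = rng.uniform(-1.0, 1.0, size=(batch_size, latent_size))
    y = np.ones((batch_size, 1), dtype='float32')
    # adversarial.train_on_batch(noise, y) would be called here

    # At regular intervals, generate digits from the same noise vector
    if (step + 1) % save_interval == 0:
        # generator.predict(sample_noise) and image saving would go here
        saved_steps.append(step + 1)

print(saved_steps)
```

Labeling the fake batch as real is what pushes the generator, through the frozen discriminator, toward images that the discriminator classifies as real.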

Listing 4.2.4, dcgan-mnist-4.2.1.py, shows the function that trains the discriminator and adversarial networks:

def train(models, x_train, params):
    """Train the Discriminator and Adversarial Networks

    Alternately train Discriminator and Adversarial networks by batch.
    Discriminator is trained first with properly labeled real and fake images.
    Adversarial is trained next with fake images pretending to be real.
    Generate sample images per save_interval.
