Advanced Deep Learning with Keras


As shown in Listing 5.3.3, the discriminator and adversarial models are modified to accommodate the changes in the discriminator network. We now have two loss functions. The first is the original binary cross-entropy, used to train the discriminator to estimate the probability that the input image is real. The second is for the image classifier that predicts the class label; its output is a one-hot vector of 10 dimensions.

Listing 5.3.3, acgan-mnist-5.3.1.py, shows the changes implemented in the discriminator and adversarial models to accommodate the image classifier of the discriminator network. The two loss functions correspond to the two outputs of the discriminator (a sketch of such a two-output discriminator builder follows the listing):

def build_and_train_models():
    # load MNIST dataset
    (x_train, y_train), (_, _) = mnist.load_data()

    # reshape data for CNN as (28, 28, 1) and normalize
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255

    # train labels
    num_labels = len(np.unique(y_train))
    y_train = to_categorical(y_train)

    model_name = "acgan_mnist"
    # network parameters
    latent_size = 100
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size, image_size, 1)
    label_shape = (num_labels, )

    # build discriminator model
    inputs = Input(shape=input_shape, name='discriminator_input')
    # call discriminator builder with 2 outputs,
    # pred source and labels
    discriminator = gan.discriminator(inputs, num_labels=num_labels)
    # [1] uses Adam, but discriminator converges easily with RMSprop
    optimizer = RMSprop(lr=lr, decay=decay)
    # 2 loss functions: 1) probability image is real
    # 2) class label of the image
    loss = ['binary_crossentropy', 'categorical_crossentropy']
    discriminator.compile(loss=loss,
                          optimizer=optimizer,
                          metrics=['accuracy'])
    discriminator.summary()

    # build generator model
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    labels = Input(shape=label_shape, name='labels')
    # call generator builder with input labels
    generator = gan.generator(inputs, image_size, labels=labels)
    generator.summary()

    # build adversarial model = generator + discriminator
    optimizer = RMSprop(lr=lr*0.5, decay=decay*0.5)
    # freeze the weights of discriminator
    # during adversarial training
    discriminator.trainable = False
    adversarial = Model([inputs, labels],
                        discriminator(generator([inputs, labels])),
                        name=model_name)
    # same 2 loss functions: 1) probability image is real
    # 2) class label of the image
    adversarial.compile(loss=loss,
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary()

    # train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    data = (x_train, y_train)
    params = (batch_size, latent_size, train_steps, num_labels, model_name)
    train(models, data, params)
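
The gan.discriminator() call above comes from the shared gan.py builder module, which is not reproduced in this excerpt. To make the two-output structure concrete, here is a minimal sketch of what such a builder could look like; the layer stack, filter sizes, and names are illustrative assumptions, not the book's gan.py implementation.

from tensorflow.keras.layers import (Activation, Conv2D, Dense, Flatten,
                                     Input, LeakyReLU)
from tensorflow.keras.models import Model

def sketch_discriminator(inputs, num_labels):
    """Illustrative two-output ACGAN discriminator (not the book's gan.py)."""
    x = inputs
    for filters in (32, 64, 128, 256):
        # strided convolutions downsample the image, DCGAN-style
        x = Conv2D(filters=filters, kernel_size=5, strides=2,
                   padding='same')(x)
        x = LeakyReLU(alpha=0.2)(x)
    x = Flatten()(x)
    # first output: probability that the input image is real
    source = Activation('sigmoid', name='source')(Dense(1)(x))
    # second output: predicted class label (one-hot over num_labels classes)
    label = Activation('softmax', name='label')(Dense(num_labels)(x))
    # one model, two outputs -> two loss functions when compiled
    return Model(inputs, [source, label], name='discriminator')

# usage sketch for 28 x 28 x 1 MNIST digits:
# discriminator = sketch_discriminator(Input(shape=(28, 28, 1)), num_labels=10)

With two outputs, the two-element loss list in Listing 5.3.3 is matched positionally: binary_crossentropy trains the source output and categorical_crossentropy trains the label output.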

In Listing 5.3.4, we highlight the changes implemented in the training routine. The main difference compared to the CGAN code is that the output label must be supplied during both discriminator and adversarial training.

As seen in Listing 5.3.4, acgan-mnist-5.3.1.py, the changes implemented in the train function are as follows:

def train(models, data, params):
    """Train the discriminator and adversarial networks
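
The body of train() continues beyond this excerpt. Since its key difference from the CGAN version is that one-hot class labels must be fed to both networks, the following is a minimal sketch of a single ACGAN training step under that scheme; it is not Listing 5.3.4 itself, and the function name train_step and its batch-sampling details are assumptions made for illustration.

import numpy as np

def train_step(models, x_train, y_train,
               batch_size=64, latent_size=100, num_labels=10):
    """One illustrative ACGAN training step (a sketch, not Listing 5.3.4)."""
    generator, discriminator, adversarial = models

    # sample a batch of real images together with their one-hot labels
    idx = np.random.randint(0, x_train.shape[0], size=batch_size)
    real_images, real_labels = x_train[idx], y_train[idx]

    # generate fake images from noise and randomly drawn one-hot labels
    noise = np.random.uniform(-1.0, 1.0, size=[batch_size, latent_size])
    fake_labels = np.eye(num_labels)[np.random.choice(num_labels, batch_size)]
    fake_images = generator.predict([noise, fake_labels])

    # train the discriminator with BOTH targets per image:
    # 1) source (real = 1.0, fake = 0.0) and 2) class label
    x = np.concatenate((real_images, fake_images))
    labels = np.concatenate((real_labels, fake_labels))
    y = np.ones([2 * batch_size, 1])
    y[batch_size:, :] = 0.0
    discriminator.train_on_batch(x, [y, labels])

    # train the adversarial network (generator only, discriminator frozen):
    # the generator is pushed to make its fakes look real AND carry the
    # class label it was conditioned on
    noise = np.random.uniform(-1.0, 1.0, size=[batch_size, latent_size])
    fake_labels = np.eye(num_labels)[np.random.choice(num_labels, batch_size)]
    y = np.ones([batch_size, 1])
    adversarial.train_on_batch([noise, fake_labels], [y, fake_labels])

Compared with CGAN training, the only additions are the label targets passed alongside the source targets in both train_on_batch calls.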

