Advanced Deep Learning with Keras
Chapter 5: Improved GANs

As shown in Listing 5.3.3, the discriminator and adversarial models are modified to accommodate the changes in the discriminator network. We now have two loss functions. The first is the original binary cross-entropy, which trains the discriminator to estimate the probability that the input image is real. The second trains the image classifier to predict the class label; its output is a one-hot vector of 10 dimensions.

Listing 5.3.3, acgan-mnist-5.3.1.py, shows the changes implemented in the discriminator and adversarial models to accommodate the image classifier of the discriminator network. The two loss functions correspond to the two outputs of the discriminator:

def build_and_train_models():
    # load MNIST dataset
    (x_train, y_train), (_, _) = mnist.load_data()
    # reshape data for CNN as (28, 28, 1) and normalize
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255
    # train labels
    num_labels = len(np.unique(y_train))
    y_train = to_categorical(y_train)
    model_name = "acgan_mnist"
    # network parameters
    latent_size = 100
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size, image_size, 1)
    label_shape = (num_labels, )
    # build discriminator model
    inputs = Input(shape=input_shape, name='discriminator_input')
    # call discriminator builder with 2 outputs,
    # pred source and labels
    discriminator = gan.discriminator(inputs, num_labels=num_labels)
    # [1] uses Adam, but discriminator converges easily with RMSprop
    optimizer = RMSprop(lr=lr, decay=decay)
    # 2 loss functions: 1) probability image is real
    # 2) class label of the image
    loss = ['binary_crossentropy', 'categorical_crossentropy']
    discriminator.compile(loss=loss,
                          optimizer=optimizer,
                          metrics=['accuracy'])
    discriminator.summary()
    # build generator model
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    labels = Input(shape=label_shape, name='labels')
    # call generator builder with input labels
    generator = gan.generator(inputs, image_size, labels=labels)
    generator.summary()
    # build adversarial model = generator + discriminator
    optimizer = RMSprop(lr=lr*0.5, decay=decay*0.5)
    # freeze the weights of discriminator
    # during adversarial training
    discriminator.trainable = False
    adversarial = Model([inputs, labels],
                        discriminator(generator([inputs, labels])),
                        name=model_name)
    # same 2 loss functions: 1) probability image is real
    # 2) class label of the image
    adversarial.compile(loss=loss,
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary()
    # train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    data = (x_train, y_train)
    params = (batch_size, latent_size, train_steps,
              num_labels, model_name)
    train(models, data, params)
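To make the two-term loss concrete, the following standalone sketch (not from the book; plain NumPy with made-up discriminator outputs) computes the binary cross-entropy on the real/fake prediction and the categorical cross-entropy on the one-hot label prediction, then sums them the way Keras combines multi-output losses by default:

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over a batch of scalar probabilities."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(-np.mean(y_true * np.log(y_pred)
                          + (1 - y_true) * np.log(1 - y_pred)))

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean categorical cross-entropy over a batch of one-hot targets."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(-np.mean(np.sum(y_true * np.log(y_pred), axis=-1)))

# targets for one real MNIST image of digit 3
source_true = np.array([1.0])               # 1 = real
label_true = np.zeros((1, 10))
label_true[0, 3] = 1.0                      # one-hot class target

# hypothetical discriminator outputs for that image
source_pred = np.array([0.9])               # "probably real"
label_pred = np.full((1, 10), 0.05)
label_pred[0, 3] = 0.55                     # rows sum to 1

# Keras sums the per-output losses into one training objective
total_loss = (binary_crossentropy(source_true, source_pred)
              + categorical_crossentropy(label_true, label_pred))
```

By default, Keras weights the two output losses equally when summing them; a `loss_weights` argument to `compile()` could rebalance the source and label terms if one dominates.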
In Listing 5.3.4, we highlight the changes implemented in the training routine. The main difference compared to the CGAN code is that the class label must be supplied during both discriminator and adversarial training.
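What "supplying the labels" looks like can be sketched in plain NumPy. This is an illustration, not code from the listing: the shapes match the network parameters above, and during the adversarial phase random class labels are drawn and one-hot encoded to serve both as generator input and as the label-loss target.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, latent_size, num_labels = 64, 100, 10

# adversarial phase: sample latent noise plus random fake class labels
noise = rng.uniform(-1.0, 1.0, size=(batch_size, latent_size))
fake_classes = rng.integers(0, num_labels, size=batch_size)
fake_labels = np.eye(num_labels)[fake_classes]   # one-hot, shape (64, 10)

# targets for the adversarial step: label the fakes as real (y = 1)
# and ask the discriminator to recover the sampled class labels, e.g.
# adversarial.train_on_batch([noise, fake_labels], [y_real, fake_labels])
y_real = np.ones((batch_size, 1))
```

During discriminator training, by contrast, the one-hot targets for the real half of the batch come directly from y_train.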
Listing 5.3.4, acgan-mnist-5.3.1.py, shows the changes implemented in the train function:
def train(models, data, params):
    """Train the discriminator and adversarial networks