16.03.2021

Advanced Deep Learning with Keras


Disentangled Representation GANs

data = (x_train, y_train)
params = (batch_size, latent_size, train_steps, num_labels,
          model_name)
train(models, data, params)

As far as training is concerned, InfoGAN is similar to ACGAN, except that we need to supply c for the continuous code. c is drawn from a normal distribution with a mean of 0.0 and a standard deviation of 0.5. We'll use randomly sampled labels for the fake data and the dataset class labels for the real data to represent the discrete latent code. The following listing highlights the changes made to the training function. As with all the previous GANs, the discriminator and the generator (through the adversarial network) are trained alternately. During adversarial training, the discriminator weights are frozen. Sample generator output images are saved every 500 steps by using the plot_images() function from gan.py.
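The per-batch inputs and targets described above can be sketched in NumPy. This is a minimal sketch, not the code from the listing: the helpers sample_generator_inputs, discriminator_batch and adversarial_targets are hypothetical, and three details are assumptions: the noise vector z is drawn uniformly from [-1, 1], there are two 1-D continuous codes, and the real half of the discriminator batch also gets continuous codes drawn from N(0.0, 0.5). The target lists are ordered [real/fake, labels, code1, code2], assuming a four-output discriminator; in Keras, such arrays would be passed to a method like discriminator.train_on_batch(x, outputs).

```python
import numpy as np


def sample_generator_inputs(batch_size, latent_size, num_labels,
                            num_codes=2, code_std=0.5, rng=None):
    """Sample one batch of InfoGAN generator inputs (hypothetical helper).

    Assumptions: z ~ U(-1, 1) and num_codes independent 1-D
    continuous codes, each drawn from N(0.0, code_std).
    """
    rng = np.random.default_rng() if rng is None else rng
    # incompressible noise vector z (uniform in [-1, 1] is an assumption)
    noise = rng.uniform(-1.0, 1.0, size=(batch_size, latent_size))
    # discrete code: randomly sampled class labels, one-hot encoded
    fake_labels = np.eye(num_labels)[rng.integers(0, num_labels, batch_size)]
    # continuous codes: c ~ N(mean=0.0, std=0.5), one column per code
    codes = [rng.normal(0.0, code_std, size=(batch_size, 1))
             for _ in range(num_codes)]
    return noise, fake_labels, codes


def discriminator_batch(real_images, real_labels, real_codes,
                        fake_images, fake_labels, fake_codes):
    """Stack real then fake samples into one discriminator batch."""
    n_real = len(real_images)
    x = np.concatenate((real_images, fake_images))
    # real/fake target: 1.0 for dataset images, 0.0 for generated ones
    y = np.ones((n_real + len(fake_images), 1))
    y[n_real:, :] = 0.0
    # real data keeps its dataset labels; fake data uses sampled labels
    labels = np.concatenate((real_labels, fake_labels))
    codes = [np.concatenate((r, f)) for r, f in zip(real_codes, fake_codes)]
    return x, [y, labels] + codes


def adversarial_targets(fake_labels, fake_codes):
    """Targets for the adversarial step: fake images pretend to be real."""
    y = np.ones((len(fake_labels), 1))
    return [y, fake_labels] + list(fake_codes)


# usage with dummy MNIST-shaped arrays; zeros stand in for
# generator.predict() so the sketch runs without a Keras model
rng = np.random.default_rng(0)
batch_size, latent_size, num_labels = 4, 62, 10
noise, fake_labels, fake_codes = sample_generator_inputs(
    batch_size, latent_size, num_labels, rng=rng)
fake_images = np.zeros((batch_size, 28, 28, 1))  # stand-in for generator output
real_images = np.ones((batch_size, 28, 28, 1))   # stand-in for an x_train batch
real_labels = np.eye(num_labels)[[0, 1, 2, 3]]   # stand-in for a y_train batch
# assumption: real samples also get codes drawn from N(0.0, 0.5)
real_codes = [rng.normal(0.0, 0.5, size=(batch_size, 1)) for _ in range(2)]
x, outputs = discriminator_batch(real_images, real_labels, real_codes,
                                 fake_images, fake_labels, fake_codes)
adv_outputs = adversarial_targets(fake_labels, fake_codes)
```

Returning the targets as a list keeps their order aligned with the outputs of a multi-output model, so the same structure serves both the discriminator step (real and fake halves) and the adversarial step (fake images labeled 1.0).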

Listing 6.1.4, infogan-mnist-6.1.1.py, shows how the training function for InfoGAN is similar to that of ACGAN. The only difference is that we supply continuous codes sampled from a normal distribution:

def train(models, data, params):
    """Train the Discriminator and Adversarial networks

    Alternately train discriminator and adversarial networks by batch.
    Discriminator is trained first with real and fake images,
    corresponding one-hot labels and continuous codes.
    Adversarial is trained next with fake images pretending to be real,
    corresponding one-hot labels and continuous codes.
    Generate sample images per save_interval.

    # Arguments
        models (Models): Generator, Discriminator, Adversarial models
        data (tuple): x_train, y_train data
        params (tuple): Network parameters
    """
    # the GAN models
    generator, discriminator, adversarial = models
    # images and their one-hot labels
    x_train, y_train = data
    # network parameters
    batch_size, latent_size, train_steps, num_labels, model_name = params
    # the generator image is saved every 500 steps
    save_interval = 500
    # noise vector to see how the generator output evolves

