16.03.2021 Views

Advanced Deep Learning with Keras


Improved GANs

The preceding table compares the loss functions of GAN and LSGAN.
Minimizing Equation 5.2.1, the discriminator loss function, implies that
the MSE between the classification of real data and the true label 1.0
should be close to zero. In addition, the MSE between the classification
of fake data and the true label 0.0 should be close to zero.
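The two MSE terms can be illustrated with a small plain-Python sketch. The function names and the sample scores are hypothetical, and the two terms are assumed to be summed with equal weight and without any constant scaling factor; Equation 5.2.1 remains the authoritative form.

```python
def mse(predictions, targets):
    """Mean squared error between two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(predictions)


def lsgan_discriminator_loss(real_scores, fake_scores):
    """MSE of real scores against 1.0 plus MSE of fake scores against 0.0.

    Equal weighting of the two terms is an assumption made for illustration.
    """
    real_loss = mse(real_scores, [1.0] * len(real_scores))
    fake_loss = mse(fake_scores, [0.0] * len(fake_scores))
    return real_loss + fake_loss


# A discriminator that scores real data near 1.0 and fake data near 0.0
# incurs a small loss; one that confuses the two incurs a large loss.
good = lsgan_discriminator_loss([0.9, 1.1], [0.1, -0.1])
poor = lsgan_discriminator_loss([0.2, 0.4], [0.8, 0.6])
print(good, poor)
```

Because the scores are not squashed by a sigmoid, they may fall outside the range [0, 1], which is why the sample values include 1.1 and -0.1.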

As in GANs, the LSGAN discriminator is trained to distinguish real from
fake data samples. Minimizing Equation 5.2.2 means fooling the
discriminator into thinking that the generated fake samples are real,
with label 1.0.
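The generator objective can be sketched the same way. The function name and the sample scores below are hypothetical, and any constant scaling factor in Equation 5.2.2 is omitted; the sketch only shows that the loss shrinks as the discriminator scores for generated samples approach 1.0.

```python
def lsgan_generator_loss(fake_scores):
    """MSE between discriminator scores for generated samples and label 1.0."""
    return sum((s - 1.0) ** 2 for s in fake_scores) / len(fake_scores)


# Scores near 1.0 mean the discriminator was fooled, so the loss is small.
fooled = lsgan_generator_loss([0.95, 1.05])
# Scores near 0.0 mean the fakes were detected, so the loss is large.
caught = lsgan_generator_loss([0.0, 0.1])
print(fooled, caught)
```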

Implementing LSGAN using the DCGAN code from the previous chapter as the
basis requires only a few changes. As shown in Listing 5.2.1, the
discriminator sigmoid activation is removed. The discriminator is built
by calling:

discriminator = gan.discriminator(inputs, activation=None)

The generator is similar to the original DCGAN:

generator = gan.generator(inputs, image_size)

Both the discriminator and adversarial loss functions are replaced by mse.
All the network parameters are the same as in DCGAN. The network model of
LSGAN in Keras is similar to Figure 4.2.1, except that the discriminator
output uses a linear activation, that is, no output activation. The
training process is similar to that of DCGAN and is provided by the
utility function:

gan.train(models, x_train, params)
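The body of gan.train is not shown in this section. As a rough sketch, assuming it follows the usual alternating DCGAN procedure (a discriminator update on real and fake samples, then a generator update through the adversarial model), the target labels for the two phases could be built as follows. Both function names are hypothetical and are not part of the gan module.

```python
def discriminator_phase_labels(batch_size):
    """Targets for a batch of real samples followed by the same number of fakes."""
    return [1.0] * batch_size + [0.0] * batch_size


def adversarial_phase_labels(batch_size):
    """Targets for the generator update: fake samples are labeled as real."""
    return [1.0] * batch_size


print(discriminator_phase_labels(2))
print(adversarial_phase_labels(2))
```

With the mse loss, these labels serve as regression targets for the unbounded discriminator output rather than as class probabilities.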

Listing 5.2.1, lsgan-mnist-5.2.1.py, shows that the discriminator and
generator are the same as in DCGAN except for the discriminator output
activation and the use of the MSE loss function:

def build_and_train_models():
    # MNIST dataset
    (x_train, _), (_, _) = mnist.load_data()

    # reshape data for CNN as (28, 28, 1) and normalize
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255

    model_name = "lsgan_mnist"
    # network parameters
    # the latent or z vector is 100-dim
    latent_size = 100

