


d_loss = (dA_loss + dB_loss) / 2

# Translate images back to the original domain
reconstr_A = generator_BA(fake_B, training=True)
reconstr_B = generator_AB(fake_A, training=True)
# Identity mapping: each generator applied to images already in its
# target domain should leave them unchanged
id_A = generator_BA(imgs_A, training=True)
id_B = generator_AB(imgs_B, training=True)

# Weighted sum of the generator losses: adversarial terms (weight 1),
# cycle-consistency terms (weight 10), and identity terms (weight 0.1)
gen_loss = tf.math.reduce_sum([
    1 * tf.math.reduce_mean(mean_squared_error(logits_fake_A, valid)),
    1 * tf.math.reduce_mean(mean_squared_error(logits_fake_B, valid)),
    10 * tf.math.reduce_mean(mean_squared_error(reconstr_A, imgs_A)),
    10 * tf.math.reduce_mean(mean_squared_error(reconstr_B, imgs_B)),
    0.1 * tf.math.reduce_mean(mean_squared_error(id_A, imgs_A)),
    0.1 * tf.math.reduce_mean(mean_squared_error(id_B, imgs_B)),
])

# Compute and apply the discriminator gradients from the d_tape
# recorded above
gradients_of_d = d_tape.gradient(d_loss,
    discriminator_A.trainable_variables +
    discriminator_B.trainable_variables)
discriminator_optimizer.apply_gradients(zip(gradients_of_d,
    discriminator_A.trainable_variables +
    discriminator_B.trainable_variables))

# Compute and apply the generator gradients from the g tape
gradients_of_generator = g.gradient(gen_loss,
    generator_AB.trainable_variables + generator_BA.trainable_variables)
optimizer.apply_gradients(zip(gradients_of_generator,
    generator_AB.trainable_variables + generator_BA.trainable_variables))

return dA_loss, dB_loss, gen_loss
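Note that d_tape and g are two separate tf.GradientTape contexts opened at the top of train_batch, which falls outside this excerpt. The following is a minimal, self-contained sketch of that two-tape pattern; the stand-in models, optimizer settings, and input shape are illustrative assumptions, not the networks defined earlier in the chapter:

import tensorflow as tf

# Stand-in models and optimizers, only to make the sketch runnable
generator = tf.keras.Sequential([tf.keras.layers.Dense(4)])
discriminator = tf.keras.Sequential([tf.keras.layers.Dense(1)])
gen_optimizer = tf.keras.optimizers.Adam(2e-4)
disc_optimizer = tf.keras.optimizers.Adam(2e-4)
mse = tf.keras.losses.MeanSquaredError()

@tf.function
def train_step(real):
    # One tape per sub-network, so each loss can be differentiated
    # with respect to its own variables
    with tf.GradientTape() as g, tf.GradientTape() as d_tape:
        fake = generator(real, training=True)
        logits_real = discriminator(real, training=True)
        logits_fake = discriminator(fake, training=True)
        # Least-squares GAN losses, as in the chapter's code
        d_loss = 0.5 * (mse(tf.ones_like(logits_real), logits_real) +
                        mse(tf.zeros_like(logits_fake), logits_fake))
        gen_loss = mse(tf.ones_like(logits_fake), logits_fake)
    # Each (non-persistent) tape is queried exactly once, after the
    # with block has closed
    disc_optimizer.apply_gradients(zip(
        d_tape.gradient(d_loss, discriminator.trainable_variables),
        discriminator.trainable_variables))
    gen_optimizer.apply_gradients(zip(
        g.gradient(gen_loss, generator.trainable_variables),
        generator.trainable_variables))
    return d_loss, gen_loss

train_step(tf.random.normal((8, 4)))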

We define checkpoints to save the model weights:

checkpoint_dird_A = './training_checkpointsd_A'
checkpoint_prefixd_A = os.path.join(checkpoint_dird_A, "ckpt_{epoch}")
checkpoint_dird_B = './training_checkpointsd_B'
checkpoint_prefixd_B = os.path.join(checkpoint_dird_B, "ckpt_{epoch}")
checkpoint_dirg_AB = './training_checkpointsg_AB'
checkpoint_prefixg_AB = os.path.join(checkpoint_dirg_AB, "ckpt_{epoch}")
checkpoint_dirg_BA = './training_checkpointsg_BA'
checkpoint_prefixg_BA = os.path.join(checkpoint_dirg_BA, "ckpt_{epoch}")

Let us now combine it all and train the network for 50 epochs. Please remember that in the paper the network was trained for 200 epochs, so our results will not be as good:

def train(trainA_, trainB_, epochs):
    for epoch in range(epochs):
        start = time.time()
        for batch_i, (imgs_A, imgs_B) in enumerate(zip(trainA_, trainB_)):
            dA_loss, dB_loss, g_loss = train_batch(imgs_A, imgs_B)
            if batch_i % 1000 == 0:
                test_imgA = next(iter(test_A))
                test_imgB = next(iter(test_B))
                print('Time taken for epoch {} batch index {} is {} seconds\n'.format(
                    epoch, batch_i, time.time() - start))
                print("discriminator A: ", dA_loss.numpy())
                print("discriminator B: ", dB_loss.numpy())
                print("generator: {}\n".format(g_loss))
                fig, axs = plt.subplots(2, 2, figsize=(6, 3),
                    sharey=True, sharex=True)
                gen_outputA = generator_AB(test_imgA, training=False)
                gen_outputB = generator_BA(test_imgB, training=False)
                axs[0, 0].imshow(test_imgA[0] * 0.5 + 0.5)
                axs[0, 0].set_title("Generator A Input")
                axs[0, 1].imshow(gen_outputA[0] * 0.5 + 0.5)
                axs[0, 1].set_title("Generator A Output")
                axs[1, 0].imshow(test_imgB[0] * 0.5 + 0.5)
                axs[1, 0].set_title("Generator B Input")
                axs[1, 1].imshow(gen_outputB[0] * 0.5 + 0.5)
                axs[1, 1].set_title("Generator B Output")
                plt.show()
        discriminator_A.save_weights(checkpoint_prefixd_A.format(epoch=epoch))
        discriminator_B.save_weights(checkpoint_prefixd_B.format(epoch=epoch))
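Because save_weights writes TensorFlow-format checkpoints under per-epoch prefixes, training can later be resumed (or the models reloaded for inference) from the newest checkpoint. Here is a minimal sketch, assuming the checkpoint directories defined above and that the generator weights are saved analogously (those save calls fall outside this excerpt):

# save_weights in TensorFlow format maintains a `checkpoint` metadata file
# in each directory, which tf.train.latest_checkpoint uses to find the
# newest prefix
latest = tf.train.latest_checkpoint(checkpoint_dird_A)
if latest is not None:
    discriminator_A.load_weights(latest)
latest = tf.train.latest_checkpoint(checkpoint_dirg_AB)
if latest is not None:
    generator_AB.load_weights(latest)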