
Generative Adversarial Networks

```python
checkpoint_prefixg_AB = os.path.join(checkpoint_dirg_AB, "ckpt_{epoch}")
checkpoint_dirg_BA = './training_checkpointsg_BA'
checkpoint_prefixg_BA = os.path.join(checkpoint_dirg_BA, "ckpt_{epoch}")
```
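The `{epoch}` in these prefixes is a plain `str.format` placeholder. `os.path.join` leaves it untouched, and the training loop fills it in later with `.format(epoch=epoch)`, so each epoch is written to its own checkpoint prefix. The minimal sketch below shows that expansion using only the standard library. The directory name and the `checkpoint_path` helper are hypothetical and are not part of the original code.

```python
import os

# Hypothetical directory, used only to illustrate the naming pattern.
example_dir = "./training_checkpoints_example"

# os.path.join does not touch the {epoch} placeholder; it stays in the string.
example_prefix = os.path.join(example_dir, "ckpt_{epoch}")


def checkpoint_path(prefix: str, epoch: int) -> str:
    """Fill the {epoch} placeholder, as the training loop does with .format(epoch=epoch)."""
    return prefix.format(epoch=epoch)


for epoch in range(3):
    # Prints one distinct path per epoch, ending in ckpt_0, ckpt_1, ckpt_2.
    print(checkpoint_path(example_prefix, epoch))
```

With TensorFlow 2.x `tf.keras`, a prefix without an `.h5` extension is saved in TensorFlow checkpoint format (an `.index` file plus data shards). Keras 3 instead expects weight filenames ending in `.weights.h5`, so the template may need that suffix there.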

Let us now put it all together and train the network for 50 epochs. Keep in mind that the network in the original paper was trained for 200 epochs, so our results will not be as good:

```python
import time
import matplotlib.pyplot as plt

def train(trainA_, trainB_, epochs):
    for epoch in range(epochs):
        start = time.time()
        # Iterate over paired batches from the two domains
        for batch_i, (imgs_A, imgs_B) in enumerate(zip(trainA_, trainB_)):
            dA_loss, dB_loss, g_loss = train_batch(imgs_A, imgs_B)
            # Every 1,000 batches, report the losses and show sample translations
            if batch_i % 1000 == 0:
                test_imgA = next(iter(test_A))
                test_imgB = next(iter(test_B))
                print('Time taken for epoch {} batch index {} is {} seconds\n'.format(
                    epoch, batch_i, time.time() - start))
                print("discriminator A: ", dA_loss.numpy())
                print("discriminator B: ", dB_loss.numpy())
                print("generator: {}\n".format(g_loss.numpy()))
                fig, axs = plt.subplots(2, 2, figsize=(6, 3), sharey=True, sharex=True)
                gen_outputA = generator_AB(test_imgA, training=False)
                gen_outputB = generator_BA(test_imgB, training=False)
                # Pixel values are in [-1, 1]; map them back to [0, 1] for imshow
                axs[0, 0].imshow(test_imgA[0] * 0.5 + 0.5)
                axs[0, 0].set_title("Generator A Input")
                axs[0, 1].imshow(gen_outputA[0] * 0.5 + 0.5)
                axs[0, 1].set_title("Generator A Output")
                axs[1, 0].imshow(test_imgB[0] * 0.5 + 0.5)
                axs[1, 0].set_title("Generator B Input")
                axs[1, 1].imshow(gen_outputB[0] * 0.5 + 0.5)
                axs[1, 1].set_title("Generator B Output")
                plt.show()
        # At the end of each epoch, checkpoint the weights under an epoch-specific prefix
        discriminator_A.save_weights(checkpoint_prefixd_A.format(epoch=epoch))
        discriminator_B.save_weights(checkpoint_prefixd_B.format(epoch=epoch))
```
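Two details of this loop are easy to miss. Python's built-in `zip` stops at the shorter of the two datasets, so if domain A has more batches than domain B, the surplus A batches are not visited in that epoch. The `* 0.5 + 0.5` before each `imshow` call undoes a [-1, 1] scaling; we assume here, as the code implies but does not show, that the images were normalized to [-1, 1] and that the generators emit values in the same range (for example, through a tanh output). The small NumPy sketch below demonstrates both behaviors on toy data. The helper names `paired_batches` and `to_display_range` are hypothetical.

```python
import numpy as np


def paired_batches(batches_a, batches_b):
    """Pair batches from two domains the way enumerate(zip(trainA_, trainB_)) does.

    zip stops as soon as the shorter iterable is exhausted, so extra batches
    in the larger domain are skipped for that epoch.
    """
    return list(enumerate(zip(batches_a, batches_b)))


def to_display_range(img):
    """Map pixel values from [-1, 1] to [0, 1], the img * 0.5 + 0.5 step.

    Assumes the images were normalized to [-1, 1] earlier in the pipeline.
    """
    return np.asarray(img) * 0.5 + 0.5


# Hypothetical toy data: 5 batches in domain A, only 3 in domain B.
pairs = paired_batches(range(5), ["b0", "b1", "b2"])
print(len(pairs))  # 3

# Maps -1 to 0.0, 0 to 0.5 and 1 to 1.0.
print(to_display_range(np.array([-1.0, 0.0, 1.0])))
```

If the two domains differ greatly in size, one option is to repeat or resample the smaller dataset so that every batch of the larger one is used each epoch.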

