Advanced Deep Learning with Keras
Chapter 7
Cross-Domain GANs

    1) Build target and source discriminators
    2) Build target and source generators
    3) Build the adversarial network

    Arguments:
    shapes (tuple): source and target shapes
    source_name (string): string to be appended on dis/gen models
    target_name (string): string to be appended on dis/gen models
    kernel_size (int): kernel size for the encoder/decoder or dis/gen models
    patchgan (bool): whether to use patchgan on discriminator
    identity (bool): whether to use identity loss

    Returns:
    (list): 2 generator, 2 discriminator, and 1 adversarial models
    """
    source_shape, target_shape = shapes
    lr = 2e-4
    decay = 6e-8
    gt_name = "gen_" + target_name
    gs_name = "gen_" + source_name
    dt_name = "dis_" + target_name
    ds_name = "dis_" + source_name

    # build target and source generators
    g_target = build_generator(source_shape,
                               target_shape,
                               kernel_size=kernel_size,
                               name=gt_name)
    g_source = build_generator(target_shape,
                               source_shape,
                               kernel_size=kernel_size,
                               name=gs_name)
    print('---- TARGET GENERATOR ----')
    g_target.summary()
    print('---- SOURCE GENERATOR ----')
    g_source.summary()

    # build target and source discriminators
    d_target = build_discriminator(target_shape,
                                   patchgan=patchgan,
                                   kernel_size=kernel_size,
                                   name=dt_name)
    d_source = build_discriminator(source_shape,
                                   patchgan=patchgan,
                                   kernel_size=kernel_size,
                                   name=ds_name)
    print('---- TARGET DISCRIMINATOR ----')
    d_target.summary()
    print('---- SOURCE DISCRIMINATOR ----')
    d_source.summary()

    optimizer = RMSprop(lr=lr, decay=decay)
    d_target.compile(loss='mse',
                     optimizer=optimizer,
                     metrics=['accuracy'])
    d_source.compile(loss='mse',
                     optimizer=optimizer,
                     metrics=['accuracy'])

    # freeze the discriminator weights in the adversarial model
    d_target.trainable = False
    d_source.trainable = False

    # build the computational graph for the adversarial model
    # forward cycle network and target discriminator
    source_input = Input(shape=source_shape)
    fake_target = g_target(source_input)
    preal_target = d_target(fake_target)
    reco_source = g_source(fake_target)

    # backward cycle network and source discriminator
    target_input = Input(shape=target_shape)
    fake_source = g_source(target_input)
    preal_source = d_source(fake_source)
    reco_target = g_target(fake_source)

    # if we use identity loss, add 2 extra loss terms
    # and outputs
    if identity:
        iden_source = g_source(source_input)
        iden_target = g_target(target_input)
        loss = ['mse', 'mse', 'mae', 'mae', 'mae', 'mae']
        loss_weights = [1., 1., 10., 10., 0.5, 0.5]
        inputs = [source_input, target_input]
        outputs = [preal_source,
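The `loss` and `loss_weights` lists in the identity branch determine the single scalar the adversarial model minimizes. For a multi-output model, Keras multiplies each output's loss by its weight and sums the results. The following framework-free sketch reproduces that arithmetic so the relative influence of each term is visible. It is an illustration, not the author's code. The toy one-dimensional "images", the linear stand-in generators `g_target` and `g_source`, the discriminator scores, and the helper `cyclegan_total_loss` are all hypothetical.

The sketch also assumes the six outputs are ordered as two discriminator predictions, two cycle reconstructions, and two identity mappings, matching the order of the `loss` list. The adversarial targets are all ones because the generators are trained to make each discriminator score fakes as real.

```python
# Hypothetical, framework-free sketch of how the six CycleGAN loss terms
# combine into one weighted sum; not the book's Keras implementation.

def mse(y_true, y_pred):
    """Mean squared error, the least-squares adversarial loss ('mse')."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def mae(y_true, y_pred):
    """Mean absolute error, used for the cycle and identity terms ('mae')."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical stand-in generators: g_target maps source -> target,
# g_source maps target -> source.
def g_target(x):
    return [2.0 * v for v in x]


def g_source(y):
    # Deliberately NOT the exact inverse of g_target (0.5 would be),
    # so the cycle-consistency terms come out nonzero.
    return [0.4 * v for v in y]


def cyclegan_total_loss(source, target, preal_source, preal_target,
                        weights=(1.0, 1.0, 10.0, 10.0, 0.5, 0.5)):
    """Return (weighted total, list of the six unweighted terms)."""
    # Forward cycle: source -> fake target -> reconstructed source
    reco_source = g_source(g_target(source))
    # Backward cycle: target -> fake source -> reconstructed target
    reco_target = g_target(g_source(target))
    # Identity: each generator is fed an image already in its output domain
    iden_source = g_source(source)
    iden_target = g_target(target)
    # Generators want the discriminators to output 1 ("real") for fakes
    ones_s = [1.0] * len(preal_source)
    ones_t = [1.0] * len(preal_target)
    terms = [
        mse(ones_s, preal_source),  # adversarial, source discriminator
        mse(ones_t, preal_target),  # adversarial, target discriminator
        mae(source, reco_source),   # forward cycle consistency
        mae(target, reco_target),   # backward cycle consistency
        mae(source, iden_source),   # source identity
        mae(target, iden_target),   # target identity
    ]
    total = sum(w * t for w, t in zip(weights, terms))
    return total, terms


if __name__ == "__main__":
    total, terms = cyclegan_total_loss(source=[1.0, 0.5],
                                       target=[2.0, 1.0],
                                       preal_source=[0.5, 0.5],
                                       preal_target=[0.8, 0.6])
    print([round(t, 3) for t in terms], round(total, 3))
```

With these toy inputs, the six unweighted terms are about 0.25, 0.1, 0.15, 0.3, 0.45, and 1.5, and the weighted total is about 5.825. The two cycle terms contribute 4.5 of that total once multiplied by 10. This shows how the weight of 10 makes cycle consistency dominate the generator objective, while the identity terms act only as a light regularizer at 0.5.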