Skip to content
Snippets Groups Projects

params search

Merged Li, Honglin (PG/R - Elec Electronic Eng) requested to merge multi_head into master
5 files changed (+42 −16)
Compare changes
  • Side-by-side
  • Inline
Files
5
+7 −5
@@ -11,19 +11,21 @@ from tqdm import tqdm
class Train_Gans:
def __init__(self, args):
self.args = args
optimizer = Adam(args.lr, args.decay)
D_optimizer = Adam(args.d_lr, args.d_decay)
G_optimizer = Adam(args.g_lr, args.g_decay)
# Baseline model
self.baseline_model = build_generator(args, beta=1)
self.baseline_model.compile(loss='categorical_crossentropy',
optimizer=optimizer,
optimizer=D_optimizer,
metrics=['accuracy'])
# Build and compile the discriminator
self.discriminator = build_discriminator(args)
self.discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
optimizer=D_optimizer,
metrics=['accuracy'])
# Build the generator
self.generator = build_generator(args)
input_data = Input(shape=args.img_shape) if args.conv else Input(shape=(np.prod(args.img_shape),))
@@ -32,7 +34,7 @@ class Train_Gans:
valid = self.discriminator([input_data, fake_labels])
self.combined = Model(input_data, valid)
self.combined.compile(loss=['binary_crossentropy'],
optimizer=optimizer, metrics=['accuracy'])
optimizer=G_optimizer, metrics=['accuracy'])
self.history = {'Loss': {'Discriminator': [], 'Generator': []},
'Accuracy': {'Discriminator': [], 'Generator': []}
}
@@ -57,7 +59,7 @@ class Train_Gans:
d_loss_fake = self.discriminator.train_on_batch([valid_imgs, gen_labels], fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
g_loss = self.combined.train_on_batch(valid_imgs, valid)
desc = "Epoch %d/%d [D loss: %f, acc.: %.2f%%] [G loss: %f, acc.: %.2f%%]" \
desc = "Epoch %d/%d [D loss: %f, acc.: %.2f%%] [G loss: %f, acc.: %.2f%%] " \
% (epoch, self.args.epochs, d_loss[0], 100 * d_loss[1], g_loss[0], 100 * g_loss[1])
self.history['Loss']['Discriminator'].append(d_loss[0])
self.history['Accuracy']['Discriminator'].append(d_loss[1])
Loading