-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
38 lines (32 loc) · 1.57 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
## Source https://www.tensorflow.org/tutorials/generative/dcgan#create_a_gif
def generate_and_save_images(model, epoch, test_input):
    """Plot a grid of model outputs, save it as a PNG, and display it.

    Args:
        model: Callable invoked as ``model(test_input, training=False)``.
            Its result is indexed as ``predictions[i, :, :, 0]``, so it is
            assumed to be a 4-D batch of images -- TODO confirm the exact
            layout against the generator definition.
        epoch: Integer formatted into the output file name
            ``image_at_epoch_NNNN.png`` (zero-padded to four digits), which
            is written to the current working directory.
        test_input: Passed to ``model`` unchanged.
    """
    # Notice `training` is set to False.
    # This is so all layers run in inference mode (batchnorm).
    predictions = model(test_input, training=False)

    # The grid is 5x5, so at most 25 samples are drawn. Clamp to what the
    # model actually returned so a smaller seed batch does not index out of
    # bounds. Assumes an eagerly evaluated result that supports len().
    num_images = min(25, len(predictions))

    fig = plt.figure(figsize=(10, 10))
    for i in range(num_images):
        plt.subplot(5, 5, i + 1)
        # Rescale channel 0 before plotting; presumably maps a [-1, 1]
        # (tanh) output to [0, 255] -- verify against the generator.
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='binary')
        plt.axis('off')

    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    # This function is called once per epoch; close the figure explicitly so
    # open figures do not pile up on backends that keep them alive after show().
    plt.close(fig)
def train_dcgan(gan, dataset, batch_size, num_features, epochs=5):
    """Alternately train the discriminator and the generator of a GAN.

    For every batch, the discriminator is first trained on generated images
    (label 0) concatenated with real images (label 1); then the combined
    ``gan`` model is trained on fresh noise with label 1 while the
    discriminator is marked non-trainable. After each epoch, and once more
    after the last one, sample images are written via
    ``generate_and_save_images``.

    Args:
        gan: Model whose ``layers`` attribute unpacks into exactly
            ``(generator, discriminator)``.
        dataset: Iterable of real-image batches. NOTE(review): the label
            tensor has ``2 * batch_size`` rows, so every batch is expected to
            contain exactly ``batch_size`` images (e.g. a dataset batched
            with ``drop_remainder=True``) -- confirm at the call site.
        batch_size: Number of noise vectors drawn per step; also fixes the
            number of rows in the label tensors.
        num_features: Length of each noise vector fed to the generator.
        epochs: Number of passes over ``dataset``.

    Note:
        Reads a module-level ``seed`` (not defined in this function) as the
        fixed generator input for the progress images.
    """
    generator, discriminator = gan.layers

    # The label tensors depend only on batch_size, so build them once instead
    # of re-creating them from Python lists on every training step.
    fake_then_real_labels = tf.constant(
        [[0.]] * batch_size + [[1.]] * batch_size)
    all_real_labels = tf.constant([[1.]] * batch_size)

    for epoch in tqdm(range(epochs)):
        print("Epoch {}/{}".format(epoch + 1, epochs))
        for X_batch in dataset:
            # Phase 1: train the discriminator on fakes followed by reals;
            # the concatenation order must match fake_then_real_labels.
            noise = tf.random.normal(shape=[batch_size, num_features])
            generated_images = generator(noise)
            X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
            discriminator.trainable = True
            discriminator.train_on_batch(X_fake_and_real, fake_then_real_labels)

            # Phase 2: train the generator through the combined model, with
            # the discriminator flagged non-trainable and the fakes labelled
            # as real.
            noise = tf.random.normal(shape=[batch_size, num_features])
            discriminator.trainable = False
            gan.train_on_batch(noise, all_real_labels)

        # Produce images for the GIF as we go
        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch + 1, seed)

    # Generate once more after the final epoch.
    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, seed)