Update wgan_train.py

enable training for gp
This commit is contained in:
kelvinkoh0308
2019-09-18 18:27:57 +09:00
committed by GitHub
parent 439e0a1bfa
commit 9e92b2d994

View File

@@ -69,7 +69,7 @@ def gradient_penalty(discriminator, batch_x, fake_image):
with tf.GradientTape() as tape:
tape.watch([interplate])
d_interplote_logits = discriminator(interplate)
d_interplote_logits = discriminator(interplate,is_training)
grads = tape.gradient(d_interplote_logits, interplate)
# grads:[b, h, w, c] => [b, -1]
@@ -170,4 +170,4 @@ def main():
if __name__ == '__main__':
main()
main()