mirror of
https://github.com/dragen1860/TensorFlow-2.x-Tutorials.git
synced 2021-05-12 18:32:23 +03:00
Update wgan_train.py
enable training for gp
This commit is contained in:
@@ -69,7 +69,7 @@ def gradient_penalty(discriminator, batch_x, fake_image):
|
||||
|
||||
with tf.GradientTape() as tape:
|
||||
tape.watch([interplate])
|
||||
d_interplote_logits = discriminator(interplate)
|
||||
d_interplote_logits = discriminator(interplate,is_training)
|
||||
grads = tape.gradient(d_interplote_logits, interplate)
|
||||
|
||||
# grads:[b, h, w, c] => [b, -1]
|
||||
@@ -170,4 +170,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user