mirror of
https://github.com/dragen1860/TensorFlow-2.x-Tutorials.git
synced 2021-05-12 18:32:23 +03:00
Update wgan_train.py
This commit is contained in:
@@ -4,7 +4,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow import keras
|
||||
from scipy.misc import toimage
|
||||
from PIL import Image
|
||||
import glob
|
||||
from gan import Generator, Discriminator
|
||||
|
||||
@@ -39,7 +39,8 @@ def save_result(val_out, val_block_size, image_path, color_mode):
|
||||
|
||||
if final_image.shape[2] == 1:
|
||||
final_image = np.squeeze(final_image, axis=2)
|
||||
toimage(final_image).save(image_path)
|
||||
# toimage(final_image).save(image_path)
|
||||
Image.fromarray(final_image).save(image_path)
|
||||
|
||||
|
||||
def celoss_ones(logits):
|
||||
@@ -71,7 +72,7 @@ def gradient_penalty(discriminator, batch_x, fake_image):
|
||||
|
||||
with tf.GradientTape() as tape:
|
||||
tape.watch([interplate])
|
||||
d_interplote_logits = discriminator(interplate,is_training)
|
||||
d_interplote_logits = discriminator(interplate, training=True)
|
||||
grads = tape.gradient(d_interplote_logits, interplate)
|
||||
|
||||
# grads:[b, h, w, c] => [b, -1]
|
||||
|
||||
Reference in New Issue
Block a user