Update wgan_train.py

This commit is contained in:
Jackie Loong
2019-12-24 14:23:06 +08:00
committed by GitHub
parent 81365d989d
commit 1ce4f6b5ce

View File

@@ -4,7 +4,7 @@ os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import numpy as np
import tensorflow as tf
from tensorflow import keras
from scipy.misc import toimage
from PIL import Image
import glob
from gan import Generator, Discriminator
@@ -39,7 +39,8 @@ def save_result(val_out, val_block_size, image_path, color_mode):
if final_image.shape[2] == 1:
final_image = np.squeeze(final_image, axis=2)
toimage(final_image).save(image_path)
# toimage(final_image).save(image_path)
Image.fromarray(final_image).save(image_path)
def celoss_ones(logits):
@@ -71,7 +72,7 @@ def gradient_penalty(discriminator, batch_x, fake_image):
with tf.GradientTape() as tape:
tape.watch([interplate])
d_interplote_logits = discriminator(interplate,is_training)
d_interplote_logits = discriminator(interplate, training=True)
grads = tape.gradient(d_interplote_logits, interplate)
# grads:[b, h, w, c] => [b, -1]