From 091e11c24c00464288371678dda06a1dfeac4ee5 Mon Sep 17 00:00:00 2001 From: Hao Date: Thu, 11 Jul 2019 16:47:32 +0800 Subject: [PATCH] Update train.py --- train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/train.py b/train.py index 3161918..1672f8a 100755 --- a/train.py +++ b/train.py @@ -23,6 +23,8 @@ def train(): for epoch in range(flags.n_epoch): for step, batch_images in enumerate(images): + if batch_images.shape[0] != flags.batch_size: # if the remaining data in this epoch < batch_size + break step_time = time.time() with tf.GradientTape(persistent=True) as tape: # z = tf.distributions.Normal(0., 1.).sample([flags.batch_size, flags.z_dim]) #tf.placeholder(tf.float32, [None, z_dim], name='z_noise')