diff --git a/train.py b/train.py index 3161918..1672f8a 100755 --- a/train.py +++ b/train.py @@ -23,6 +23,8 @@ def train(): for epoch in range(flags.n_epoch): for step, batch_images in enumerate(images): + if batch_images.shape[0] != flags.batch_size: # if the remaining data in this epoch < batch_size + break step_time = time.time() with tf.GradientTape(persistent=True) as tape: # z = tf.distributions.Normal(0., 1.).sample([flags.batch_size, flags.z_dim]) #tf.placeholder(tf.float32, [None, z_dim], name='z_noise')