Skip to content

Commit 27d83c3

Browse files
committed
update download
1 parent 2cfa53b commit 27d83c3

File tree

2 files changed

+11
-17
lines changed

2 files changed

+11
-17
lines changed

download.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
- MNIST dataset
88
"""
99
from __future__ import print_function
10-
import os, sys, gzip, json, shutil, zipfile, argparse, subprocess
10+
import os, sys, gzip, json, shutil, zipfile, argparse, subprocess, requests
11+
from tqdm import tqdm
1112
from six.moves import urllib
1213

1314
parser = argparse.ArgumentParser(description='Download dataset for DCGAN.')

main.py

+9-16
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,8 @@
3939
def main(_):
4040
pp.pprint(flags.FLAGS.__flags)
4141

42-
if not os.path.exists(FLAGS.checkpoint_dir):
43-
os.makedirs(FLAGS.checkpoint_dir)
44-
if not os.path.exists(FLAGS.sample_dir):
45-
os.makedirs(FLAGS.sample_dir)
42+
tl.files.exists_or_mkdir(FLAGS.checkpoint_dir)
43+
tl.files.exists_or_mkdir(FLAGS.sample_dir)
4644

4745
z_dim = 100
4846

@@ -138,11 +136,6 @@ def main(_):
138136
if np.mod(iter_counter, FLAGS.sample_step) == 0:
139137
# generate and visualize generated images
140138
img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z : sample_seed, real_images: sample_images})
141-
'''
142-
img255 = (np.array(img) + 1) / 2 * 255
143-
tl.visualize.images2d(images=img255, second=0, saveable=True,
144-
name='./{}/train_{:02d}_{:04d}'.format(FLAGS.sample_dir, epoch, idx), dtype=None, fig_idx=2838)
145-
'''
146139
save_images(img, [8, 8],
147140
'./{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, idx))
148141
print("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD, errG))
@@ -159,13 +152,13 @@ def main(_):
159152
# the latest version location
160153
net_g_name = os.path.join(save_dir, 'net_g.npz')
161154
net_d_name = os.path.join(save_dir, 'net_d.npz')
162-
# this version is for future re-check and visualization analysis
163-
net_g_iter_name = os.path.join(save_dir, 'net_g_%d.npz' % iter_counter)
164-
net_d_iter_name = os.path.join(save_dir, 'net_d_%d.npz' % iter_counter)
165-
tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
166-
tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
167-
tl.files.save_npz(net_g.all_params, name=net_g_iter_name, sess=sess)
168-
tl.files.save_npz(net_d.all_params, name=net_d_iter_name, sess=sess)
155+
# # this version is for future re-check and visualization analysis
156+
# net_g_iter_name = os.path.join(save_dir, 'net_g_%d.npz' % iter_counter)
157+
# net_d_iter_name = os.path.join(save_dir, 'net_d_%d.npz' % iter_counter)
158+
# tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
159+
# tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
160+
# tl.files.save_npz(net_g.all_params, name=net_g_iter_name, sess=sess)
161+
# tl.files.save_npz(net_d.all_params, name=net_d_iter_name, sess=sess)
169162
print("[*] Saving checkpoints SUCCESS!")
170163

171164
if __name__ == '__main__':

0 commit comments

Comments
 (0)