39
39
def main (_ ):
40
40
pp .pprint (flags .FLAGS .__flags )
41
41
42
- if not os .path .exists (FLAGS .checkpoint_dir ):
43
- os .makedirs (FLAGS .checkpoint_dir )
44
- if not os .path .exists (FLAGS .sample_dir ):
45
- os .makedirs (FLAGS .sample_dir )
42
+ tl .files .exists_or_mkdir (FLAGS .checkpoint_dir )
43
+ tl .files .exists_or_mkdir (FLAGS .sample_dir )
46
44
47
45
z_dim = 100
48
46
@@ -138,11 +136,6 @@ def main(_):
138
136
if np .mod (iter_counter , FLAGS .sample_step ) == 0 :
139
137
# generate and visualize generated images
140
138
img , errD , errG = sess .run ([net_g2 .outputs , d_loss , g_loss ], feed_dict = {z : sample_seed , real_images : sample_images })
141
- '''
142
- img255 = (np.array(img) + 1) / 2 * 255
143
- tl.visualize.images2d(images=img255, second=0, saveable=True,
144
- name='./{}/train_{:02d}_{:04d}'.format(FLAGS.sample_dir, epoch, idx), dtype=None, fig_idx=2838)
145
- '''
146
139
save_images (img , [8 , 8 ],
147
140
'./{}/train_{:02d}_{:04d}.png' .format (FLAGS .sample_dir , epoch , idx ))
148
141
print ("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD , errG ))
@@ -159,13 +152,13 @@ def main(_):
159
152
# the latest version location
160
153
net_g_name = os .path .join (save_dir , 'net_g.npz' )
161
154
net_d_name = os .path .join (save_dir , 'net_d.npz' )
162
- # this version is for future re-check and visualization analysis
163
- net_g_iter_name = os .path .join (save_dir , 'net_g_%d.npz' % iter_counter )
164
- net_d_iter_name = os .path .join (save_dir , 'net_d_%d.npz' % iter_counter )
165
- tl .files .save_npz (net_g .all_params , name = net_g_name , sess = sess )
166
- tl .files .save_npz (net_d .all_params , name = net_d_name , sess = sess )
167
- tl .files .save_npz (net_g .all_params , name = net_g_iter_name , sess = sess )
168
- tl .files .save_npz (net_d .all_params , name = net_d_iter_name , sess = sess )
155
+ # # this version is for future re-check and visualization analysis
156
+ # net_g_iter_name = os.path.join(save_dir, 'net_g_%d.npz' % iter_counter)
157
+ # net_d_iter_name = os.path.join(save_dir, 'net_d_%d.npz' % iter_counter)
158
+ # tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
159
+ # tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
160
+ # tl.files.save_npz(net_g.all_params, name=net_g_iter_name, sess=sess)
161
+ # tl.files.save_npz(net_d.all_params, name=net_d_iter_name, sess=sess)
169
162
print ("[*] Saving checkpoints SUCCESS!" )
170
163
171
164
if __name__ == '__main__' :
0 commit comments