@@ -325,10 +325,18 @@ def run(self, start_new_model=False):
325
325
326
326
target , device_fn = self .start_server_if_distributed ()
327
327
328
+ meta_filename = self .get_meta_filename (start_new_model , self .train_dir )
329
+
328
330
with tf .Graph ().as_default () as graph :
331
+
332
+ if meta_filename :
333
+ saver = self .recover_model (meta_filename )
334
+
329
335
with tf .device (device_fn ):
330
336
331
- saver = self .recover_or_build_model (start_new_model , self .train_dir )
337
+ if not meta_filename :
338
+ saver = self .build_model ()
339
+
332
340
global_step = tf .get_collection ("global_step" )[0 ]
333
341
loss = tf .get_collection ("loss" )[0 ]
334
342
predictions = tf .get_collection ("predictions" )[0 ]
@@ -422,29 +430,30 @@ def remove_training_directory(self, train_dir):
422
430
" when starting a new model. Please delete it manually and" +
423
431
" try again." , task_as_string (self .task ))
424
432
425
- def recover_or_build_model (self , start_new_model , train_dir ):
426
- """Recovers the model from a checkpoint or build it."""
427
-
428
- latest_checkpoint = tf .train .latest_checkpoint (train_dir )
429
-
433
+ def get_meta_filename (self , start_new_model , train_dir ):
430
434
if start_new_model :
431
435
logging .info ("%s: Flag 'start_new_model' is set. Building a new model." ,
432
436
task_as_string (self .task ))
433
- return self .build_model ()
434
- elif not latest_checkpoint :
437
+ return None
438
+
439
+ latest_checkpoint = tf .train .latest_checkpoint (train_dir )
440
+ if not latest_checkpoint :
435
441
logging .info ("%s: No checkpoint file found. Building a new model." ,
436
442
task_as_string (self .task ))
437
- return self . build_model ()
438
- else :
439
- meta_filename = latest_checkpoint + ".meta"
440
- if not gfile .Exists (meta_filename ):
441
- logging .info ("%s: No meta graph file found. Building a new model." ,
443
+ return None
444
+
445
+ meta_filename = latest_checkpoint + ".meta"
446
+ if not gfile .Exists (meta_filename ):
447
+ logging .info ("%s: No meta graph file found. Building a new model." ,
442
448
task_as_string (self .task ))
443
- return self .build_model ()
444
- else :
445
- logging .info ("%s: Restoring from meta graph file %s" ,
446
- task_as_string (self .task ), meta_filename )
447
- return tf .train .import_meta_graph (meta_filename )
449
+ return None
450
+ else :
451
+ return meta_filename
452
+
453
+ def recover_model (self , meta_filename ):
454
+ logging .info ("%s: Restoring from meta graph file %s" ,
455
+ task_as_string (self .task ), meta_filename )
456
+ return tf .train .import_meta_graph (meta_filename )
448
457
449
458
def build_model (self ):
450
459
"""Find the model and build the graph."""
0 commit comments