
Commit ce82f59

Fixing a bug that prevents model recovery when using multiple worker or ps devices.
1 parent 595b975 · commit ce82f59

2 files changed: +28 -19 lines changed

cloudml-gpu-distributed.yaml

+1 -1

@@ -4,5 +4,5 @@ trainingInput:
   masterType: standard_gpu
   workerCount: 2
   workerType: standard_gpu
-  parameterServerCount: 1
+  parameterServerCount: 2
   parameterServerType: standard
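
The value bumped here, parameterServerCount, determines how many ps tasks end up in the cluster that Cloud ML Engine hands to the trainer, and the trainer's device function shards variables across those tasks. A minimal sketch of that placement behaviour, assuming a TF 1.x cluster spec shaped like this config would produce (the host:port strings and the worker task index are made-up placeholders):

import tensorflow as tf

# Hypothetical cluster matching the config above: 2 workers and 2 ps tasks.
# The addresses are placeholders, not what Cloud ML Engine actually assigns.
cluster = tf.train.ClusterSpec({
    "ps": ["ps0:2222", "ps1:2222"],
    "worker": ["worker0:2222", "worker1:2222"],
    "master": ["master0:2222"],
})

# With more than one ps task, replica_device_setter round-robins variables
# across /job:ps/task:0 and /job:ps/task:1.
device_fn = tf.train.replica_device_setter(
    ps_device="/job:ps",
    worker_device="/job:worker/task:0",  # placeholder task index
    cluster=cluster)

with tf.Graph().as_default():
  with tf.device(device_fn):
    v1 = tf.get_variable("v1", shape=[10])  # placed on /job:ps/task:0
    v2 = tf.get_variable("v2", shape=[10])  # placed on /job:ps/task:1

That per-task variable placement is the piece that interacts with graph recovery in train.py below.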

train.py

+27 -18

@@ -325,10 +325,18 @@ def run(self, start_new_model=False):
 
     target, device_fn = self.start_server_if_distributed()
 
+    meta_filename = self.get_meta_filename(start_new_model, self.train_dir)
+
     with tf.Graph().as_default() as graph:
+
+      if meta_filename:
+        saver = self.recover_model(meta_filename)
+
       with tf.device(device_fn):
 
-        saver = self.recover_or_build_model(start_new_model, self.train_dir)
+        if not meta_filename:
+          saver = self.build_model()
+
         global_step = tf.get_collection("global_step")[0]
         loss = tf.get_collection("loss")[0]
         predictions = tf.get_collection("predictions")[0]
@@ -422,29 +430,30 @@ def remove_training_directory(self, train_dir):
           " when starting a new model. Please delete it manually and" +
           " try again.", task_as_string(self.task))
 
-  def recover_or_build_model(self, start_new_model, train_dir):
-    """Recovers the model from a checkpoint or build it."""
-
-    latest_checkpoint = tf.train.latest_checkpoint(train_dir)
-
+  def get_meta_filename(self, start_new_model, train_dir):
     if start_new_model:
       logging.info("%s: Flag 'start_new_model' is set. Building a new model.",
                    task_as_string(self.task))
-      return self.build_model()
-    elif not latest_checkpoint:
+      return None
+
+    latest_checkpoint = tf.train.latest_checkpoint(train_dir)
+    if not latest_checkpoint:
       logging.info("%s: No checkpoint file found. Building a new model.",
                    task_as_string(self.task))
-      return self.build_model()
-    else:
-      meta_filename = latest_checkpoint + ".meta"
-      if not gfile.Exists(meta_filename):
-        logging.info("%s: No meta graph file found. Building a new model.",
+      return None
+
+    meta_filename = latest_checkpoint + ".meta"
+    if not gfile.Exists(meta_filename):
+      logging.info("%s: No meta graph file found. Building a new model.",
                      task_as_string(self.task))
-        return self.build_model()
-      else:
-        logging.info("%s: Restoring from meta graph file %s",
-                     task_as_string(self.task), meta_filename)
-        return tf.train.import_meta_graph(meta_filename)
+      return None
+    else:
+      return meta_filename
+
+  def recover_model(self, meta_filename):
+    logging.info("%s: Restoring from meta graph file %s",
+                 task_as_string(self.task), meta_filename)
+    return tf.train.import_meta_graph(meta_filename)
 
   def build_model(self):
     """Find the model and build the graph."""
