Skip to content

Commit

Permalink
tensorflow: improve ops.py compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
ymjiang committed Jun 29, 2019
1 parent 3fe43cf commit 0138c6a
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions byteps/tensorflow/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ def _push_pull(tensor, scope='', name=None):
if name is None and not _executing_eagerly():
name = 'BytePSPushPull_%s' % _normalize_name(tensor.name)
if scope == '' and not _executing_eagerly():
scope = tf.compat.v1.get_default_graph().get_name_scope()
try:
scope = tf.get_default_graph().get_name_scope()
except:
scope = tf.compat.v1.get_default_graph().get_name_scope()
if scope != '':
scope += '/'
full_name = scope + name
Expand Down Expand Up @@ -115,7 +118,10 @@ def broadcast(tensor, root_rank, scope='', name=None, is_variable=True):
if name is None and not _executing_eagerly():
name = 'BytePSBroadcast_%s' % _normalize_name(tensor.name)
if scope == '' and not _executing_eagerly():
scope = tf.compat.v1.get_default_graph().get_name_scope()
try:
scope = tf.get_default_graph().get_name_scope()
except:
scope = tf.compat.v1.get_default_graph().get_name_scope()
if scope != '':
scope += '/'
full_name = scope + name
Expand Down

0 comments on commit 0138c6a

Please sign in to comment.