diff --git a/byteps/tensorflow/ops.py b/byteps/tensorflow/ops.py index 8eaeb9d4d..8b4475c2c 100644 --- a/byteps/tensorflow/ops.py +++ b/byteps/tensorflow/ops.py @@ -80,10 +80,10 @@ def _push_pull(tensor, scope='', name=None): if name is None and not _executing_eagerly(): name = 'BytePSPushPull_%s' % _normalize_name(tensor.name) if scope == '' and not _executing_eagerly(): - try: - scope = tf.get_default_graph().get_name_scope() - except: + if 'v1' in dir(tf.compat): scope = tf.compat.v1.get_default_graph().get_name_scope() + else: + scope = tf.get_default_graph().get_name_scope() if scope != '': scope += '/' full_name = scope + name @@ -118,10 +118,10 @@ def broadcast(tensor, root_rank, scope='', name=None, is_variable=True): if name is None and not _executing_eagerly(): name = 'BytePSBroadcast_%s' % _normalize_name(tensor.name) if scope == '' and not _executing_eagerly(): - try: - scope = tf.get_default_graph().get_name_scope() - except: + if 'v1' in dir(tf.compat): scope = tf.compat.v1.get_default_graph().get_name_scope() + else: + scope = tf.get_default_graph().get_name_scope() if scope != '': scope += '/' full_name = scope + name