diff --git a/byteps/tensorflow/ops.py b/byteps/tensorflow/ops.py index 870fbaf76..8b4475c2c 100644 --- a/byteps/tensorflow/ops.py +++ b/byteps/tensorflow/ops.py @@ -80,7 +80,12 @@ def _push_pull(tensor, scope='', name=None): if name is None and not _executing_eagerly(): name = 'BytePSPushPull_%s' % _normalize_name(tensor.name) if scope == '' and not _executing_eagerly(): - scope = tf.compat.v1.get_default_graph().get_name_scope() + if 'v1' in dir(tf.compat): + scope = tf.compat.v1.get_default_graph().get_name_scope() + else: + scope = tf.get_default_graph().get_name_scope() + if scope != '': + scope += '/' full_name = scope + name full_name = full_name.encode("ascii") TF_LIB_CTYPES.byteps_tensorflow_declare_tensor(ctypes.c_char_p(full_name)) @@ -113,7 +118,12 @@ def broadcast(tensor, root_rank, scope='', name=None, is_variable=True): if name is None and not _executing_eagerly(): name = 'BytePSBroadcast_%s' % _normalize_name(tensor.name) if scope == '' and not _executing_eagerly(): - scope = tf.compat.v1.get_default_graph().get_name_scope() + if 'v1' in dir(tf.compat): + scope = tf.compat.v1.get_default_graph().get_name_scope() + else: + scope = tf.get_default_graph().get_name_scope() + if scope != '': + scope += '/' full_name = scope + name full_name = full_name.encode("ascii") TF_LIB_CTYPES.byteps_tensorflow_declare_tensor(ctypes.c_char_p(full_name))