From 0138c6a6d698611a45563ec320735e61e596f008 Mon Sep 17 00:00:00 2001 From: Yimin Jiang Date: Sat, 29 Jun 2019 14:39:42 +0800 Subject: [PATCH] tensorflow: improve ops.py compatibility --- byteps/tensorflow/ops.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/byteps/tensorflow/ops.py b/byteps/tensorflow/ops.py index 785f3f183..8eaeb9d4d 100644 --- a/byteps/tensorflow/ops.py +++ b/byteps/tensorflow/ops.py @@ -80,7 +80,10 @@ def _push_pull(tensor, scope='', name=None): if name is None and not _executing_eagerly(): name = 'BytePSPushPull_%s' % _normalize_name(tensor.name) if scope == '' and not _executing_eagerly(): - scope = tf.compat.v1.get_default_graph().get_name_scope() + try: + scope = tf.get_default_graph().get_name_scope() + except: + scope = tf.compat.v1.get_default_graph().get_name_scope() if scope != '': scope += '/' full_name = scope + name @@ -115,7 +118,10 @@ def broadcast(tensor, root_rank, scope='', name=None, is_variable=True): if name is None and not _executing_eagerly(): name = 'BytePSBroadcast_%s' % _normalize_name(tensor.name) if scope == '' and not _executing_eagerly(): - scope = tf.compat.v1.get_default_graph().get_name_scope() + try: + scope = tf.get_default_graph().get_name_scope() + except: + scope = tf.compat.v1.get_default_graph().get_name_scope() if scope != '': scope += '/' full_name = scope + name