From d4b4357bc01253221fd37f6b81b596d4e8265de1 Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Thu, 17 Sep 2020 10:23:42 +0800 Subject: [PATCH] [Dy2stat] Change the Global Switch Name of ProgramTranslator for API 2.0 (#27203) Change ProgramTranslator.enable_declarative to ProgramTranslator.enable_to_static to meet API 2.0 --- .../dygraph_to_static/program_translator.py | 46 +++++++++++-------- python/paddle/fluid/dygraph/jit.py | 6 +-- python/paddle/hapi/model.py | 10 ++-- 3 files changed, 34 insertions(+), 28 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index e5fce3e6ede15..dbf030ccda16f 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -246,7 +246,7 @@ def __init__(self, function, input_spec=None): self._function_spec = FunctionSpec(function, input_spec) self._program_cache = ProgramCache() self._descriptor_cache = weakref.WeakKeyDictionary() - # Note: Hold a reference to ProgramTranslator for switching `enable_declarative`. + # Note: Hold a reference to ProgramTranslator for switching `enable_to_static`. self._program_trans = ProgramTranslator() def __get__(self, instance, owner): @@ -299,16 +299,17 @@ def __call__(self, *args, **kwargs): """ # 1. call dygraph function directly if not enable `declarative` - if not self._program_trans.enable_declarative: + if not self._program_trans.enable_to_static: logging_utils.warn( - "The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable=False. " - "We will just return dygraph output.") + "The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable to False. " + "We will just return dygraph output. If you would like to get static graph output, please call API " + "ProgramTranslator.enable(True)") return self._call_dygraph_function(*args, **kwargs) - if not in_dygraph_mode() and self._program_trans.enable_declarative: + if not in_dygraph_mode(): raise RuntimeError( "Failed to run the callable object {} decorated by '@paddle.jit.to_static', " - "because it does NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the " + "because it is NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the " "following API: paddle.disable_static().".format( self.dygraph_function)) @@ -723,15 +724,15 @@ def __init__(self): return self._initialized = True self._program_cache = ProgramCache() - self.enable_declarative = True + self.enable_to_static = True - def enable(self, enable_declarative): + def enable(self, enable_to_static): """ Enable or disable the converting from imperative to declarative by ProgramTranslator globally. Args: - enable_declarative (bool): True or False to enable or disable declarative. + enable_to_static (bool): True or False to enable or disable declarative. Returns: None. @@ -760,9 +761,9 @@ def func(x): print(func(x).numpy()) # [[2. 2.]] """ - check_type(enable_declarative, "enable_declarative", bool, + check_type(enable_to_static, "enable_to_static", bool, "ProgramTranslator.enable") - self.enable_declarative = enable_declarative + self.enable_to_static = enable_to_static def get_output(self, dygraph_func, *args, **kwargs): """ @@ -803,10 +804,12 @@ def func(x): assert callable( dygraph_func ), "Input dygraph_func is not a callable in ProgramTranslator.get_output" - if not self.enable_declarative: + if not self.enable_to_static: warnings.warn( - "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable = False. " - "We will just return dygraph output.") + "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. " + "We will just return dygraph output. " + "Please call ProgramTranslator.enable(True) if you would like to get static output." + ) return dygraph_func(*args, **kwargs) try: function_spec = FunctionSpec(dygraph_func) @@ -876,10 +879,11 @@ def func(x): assert callable( dygraph_func ), "Input dygraph_func is not a callable in ProgramTranslator.get_func" - if not self.enable_declarative: + if not self.enable_to_static: warnings.warn( - "The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable=False. We will " - "just return dygraph output.") + "The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable to False. We will " + "just return dygraph output. Please call ProgramTranslator.enable(True) if you would like to get static output." + ) return dygraph_func static_func = convert_to_static(dygraph_func) @@ -929,10 +933,12 @@ def func(x): assert callable( dygraph_func ), "Input dygraph_func is not a callable in ProgramTranslator.get_program" - if not self.enable_declarative: + if not self.enable_to_static: warnings.warn( - "The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable=False." - "We will just return dygraph output.") + "The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable to False." + "We will just return dygraph output. " + "Please call ProgramTranslator.enable(True) if you would like to get static output." + ) return dygraph_func(*args, **kwargs) function_spec = FunctionSpec(dygraph_func) diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 57864efec8a94..834c1a737d73b 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -119,7 +119,7 @@ def func(x): # TODO: remove this decorator after we finalize training API def __impl__(*args, **kwargs): program_translator = ProgramTranslator() - if in_dygraph_mode() or not program_translator.enable_declarative: + if in_dygraph_mode() or not program_translator.enable_to_static: warnings.warn( "The decorator 'dygraph_to_static_func' doesn't work in " "dygraph mode or set ProgramTranslator.enable to False. " @@ -832,9 +832,9 @@ def train(layer, loader, loss_fn, opt): # 1. input check prog_translator = ProgramTranslator() - if not prog_translator.enable: + if not prog_translator.enable_to_static: raise RuntimeError( - "The paddle.jit.save doesn't work when setting ProgramTranslator.enable=False." + "The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False." ) if not isinstance(layer, Layer): raise TypeError( diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py index 2836a151ec356..c445977df1405 100644 --- a/python/paddle/hapi/model.py +++ b/python/paddle/hapi/model.py @@ -1680,7 +1680,7 @@ def get_inout_spec(all_vars, return_name=False): # TODO: # 1. Make it Unnecessary to run model before calling `save_inference_model` for users in dygraph. - # 2. Save correct shape of input, now the interface stores the shape that the user sent to + # 2. Save correct shape of input, now the interface stores the shape that the user sent to # the inputs of the model in running. # 3. Make it Unnecessary to add `@paddle.jit.to_static` for users in dynamic mode. if fluid.in_dygraph_mode(): @@ -1689,9 +1689,9 @@ def get_inout_spec(all_vars, return_name=False): # 1. input check prog_translator = ProgramTranslator() - if not prog_translator.enable_declarative: + if not prog_translator.enable_to_static: raise RuntimeError( - "save_inference_model doesn't work when setting ProgramTranslator.enable=False." + "save_inference_model doesn't work when setting ProgramTranslator.enable to False." ) if not isinstance(layer, Layer): raise TypeError( @@ -1902,8 +1902,8 @@ def _verify_spec(self, specs, is_input=False): assert isinstance(spec, Input) if spec.name is None: raise ValueError( - "Requires Input[{}].name != None, but receive `None` with {}.". - format(i, spec)) + "Requires Input[{}].name != None, but receive `None` with {}." + .format(i, spec)) return out_specs