diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index 45a567b57f25c..0346e4f1efda8 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -21,6 +21,7 @@ from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_VAR_NAME +from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar def convert_while_loop(cond, body, loop_vars): @@ -188,7 +189,8 @@ def _run_py_logical_not(x): return not x -def convert_ifelse(pred, true_fn, false_fn, true_args, false_args): +def convert_ifelse(pred, true_fn, false_fn, get_args, set_args, + return_name_ids): """ A function representation of a Python ``if/else`` statement. @@ -196,17 +198,18 @@ def convert_ifelse(pred, true_fn, false_fn, true_args, false_args): pred(bool|Tensor): A boolean Tensor which determines whether to return the result of ``true_fn`` or ``false_fn`` . true_fn(callable): A callable to be performed if ``pred`` is true. false_fn(callable): A callable to be performed if ``pred`` is false. - true_args(tuple): Parameters of ``true_fn``. - false_args(tuple): Parameters of ``false_fn``. + get_args(callable): Get all arguments that needed in true_fn and false_fn. + set_args(callable): Update arguments that modified in trure_fn and false_fn. Returns: - ``true_fn(true_args)`` if the predicate ``pred`` is true else ``false_fn(false_args)`` . + ``true_fn()`` if the predicate ``pred`` is true else ``false_fn()`` . """ if isinstance(pred, Variable): - out = _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args) + out = _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args, + return_name_ids) else: - out = _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args) + out = _run_py_ifelse(pred, true_fn, false_fn) return _remove_no_value_return_var(out) @@ -244,14 +247,59 @@ def _remove_no_value_return_var(out): return out -def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args): +def _check_no_undefined_var(outs, names, branch_name): + if names is None: return + if not isinstance(outs, (list, tuple)): + outs = [outs] + for var, name in zip(list(outs), names): + if isinstance(var, UndefinedVar): + raise ValueError( + "Required '{}' must be initialized both in if-else branch, but found it not initialized in '{}'." + .format(name, branch_name)) + + +def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args, + return_name_ids): + """ + Paddle cond API will evaluate both ture_fn and false_fn codes. + """ pred = cast_bool_if_necessary(pred) - return control_flow.cond(pred, lambda: true_fn(*true_args), - lambda: false_fn(*false_args)) + init_args = get_args() + + def new_true_fn(): + set_args(init_args) + outs = true_fn() + _check_no_undefined_var(outs, return_name_ids, 'if_body') + return outs + + def new_false_fn(): + set_args(init_args) + outs = false_fn() + _check_no_undefined_var(outs, return_name_ids, 'else_body') + return outs + + cond_outs = control_flow.cond(pred, new_true_fn, new_false_fn) + # IfExpr's return_name_ids maybe None + if return_name_ids is None: + return cond_outs + + # recover args state + num_outs = len(return_name_ids) + num_args = 1 if not isinstance(init_args, tuple) else len(init_args) + assert num_outs <= num_args + + if num_args == 1: + final_outs = cond_outs + else: + cond_outs = (cond_outs, ) if num_outs == 1 else cond_outs + final_outs = cond_outs + init_args[num_outs:] + + set_args(final_outs) + return final_outs -def _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args): - return true_fn(*true_args) if pred else false_fn(*false_args) +def _run_py_ifelse(pred, true_fn, false_fn): + return true_fn() if pred else false_fn() def convert_len(var): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index 9a29a535ab236..bff41c9b9ae02 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -16,6 +16,7 @@ import six import copy +import textwrap from collections import defaultdict # gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST). @@ -29,10 +30,14 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper -from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node +from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var +from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_nonlocal_stmt_node TRUE_FUNC_PREFIX = 'true_fn' FALSE_FUNC_PREFIX = 'false_fn' +GET_ARGS_FUNC_PREFIX = 'get_args' +SET_ARGS_FUNC_PREFIX = 'set_args' +ARGS_NAME = '__args' class IfElseTransformer(gast.NodeTransformer): @@ -56,13 +61,16 @@ def transform(self): def visit_If(self, node): self.generic_visit(node) - new_vars_stmts, true_func_node, false_func_node, return_name_ids = transform_if_else( + new_vars_stmts, true_func_node, false_func_node, get_args_node, set_args_node, return_name_ids = transform_if_else( node, self.root) new_node = create_convert_ifelse_node(return_name_ids, node.test, - true_func_node, false_func_node) + true_func_node, false_func_node, + get_args_node, set_args_node) - return new_vars_stmts + [true_func_node, false_func_node] + [new_node] + return new_vars_stmts + [ + get_args_node, set_args_node, true_func_node, false_func_node + ] + [new_node] def visit_Call(self, node): # Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]` @@ -80,7 +88,7 @@ def visit_IfExp(self, node): self.generic_visit(node) new_node = create_convert_ifelse_node(None, node.test, node.body, - node.orelse, True) + node.orelse, None, None, True) # Note: A blank line will be added separately if transform gast.Expr # into source code. Using gast.Expr.value instead to avoid syntax error # in python. @@ -192,6 +200,12 @@ def visit_Assign(self, node): self.generic_visit(node) def visit_FunctionDef(self, node): + # NOTE: We skip to visit names of get_args and set_args, because they contains + # nonlocal statement such as 'nonlocal x, self' where 'self' should not be + # parsed as returned value in contron flow. + if GET_ARGS_FUNC_PREFIX in node.name or SET_ARGS_FUNC_PREFIX in node.name: + return + if not self._in_range: self.generic_visit(node) return @@ -269,7 +283,7 @@ def get_name_ids(nodes, after_node=None, end_node=None): return name_visitor.name_ids -def parse_cond_args(parent_ids_dict, +def parse_cond_args(parent_ids, var_ids_dict, modified_ids_dict=None, ctx=gast.Load): @@ -307,24 +321,9 @@ def parse_cond_args(parent_ids_dict, # ``` # # In the above case, `v` should not be in the args of cond() - arg_name_ids = list(set(arg_name_ids) & set(parent_ids_dict)) - - arg_name_ids.sort() - args = [ - gast.Name(id=name_id, - ctx=gast.Load(), - annotation=None, - type_comment=None) for name_id in arg_name_ids - ] - arguments = gast.arguments(args=args, - posonlyargs=[], - vararg=None, - kwonlyargs=[], - kw_defaults=None, - kwarg=None, - defaults=[]) + arg_name_ids = set(arg_name_ids) & set(parent_ids) - return arguments + return arg_name_ids def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict, @@ -454,10 +453,35 @@ def _vars_loaded(ids_dict): return return_ids, modified_vars_from_parent, new_vars_to_create +def _valid_nonlocal_names(return_name_ids, nonlocal_names): + """ + All var in return_name_ids should be in nonlocal_names. + Moreover, we will always put return_name_ids in front of nonlocal_names. + + For Example: + + return_name_ids: [x, y] + nonlocal_names : [a, y, b, x] + + Return: + nonlocal_names : [x, y, a, b] + """ + assert isinstance(return_name_ids, list) + for name in return_name_ids: + if name not in nonlocal_names: + raise ValueError( + "Required returned var '{}' must be in 'nonlocal' statement '', but not found." + .format(name)) + nonlocal_names.remove(name) + + return return_name_ids + nonlocal_names + + def transform_if_else(node, root): """ Transform ast.If into control flow statement of Paddle static graph. """ + # TODO(liym27): Consider variable like `self.a` modified in if/else node. parent_name_ids = get_name_ids([root], end_node=node) body_name_ids = get_name_ids(node.body) @@ -480,73 +504,134 @@ def transform_if_else(node, root): for name in new_vars_to_create: # NOTE: Consider variable like `self.a` modified in if/else node. if "." not in name: - create_new_vars_in_parent_stmts.append( - create_static_variable_gast_node(name)) - - modified_name_ids = modified_name_ids_from_parent | new_vars_to_create + create_new_vars_in_parent_stmts.append(create_undefined_var(name)) + + parent_ids_set = set() + for k, ctxs in parent_name_ids.items(): + if any([not isinstance(ctx, gast.Load) for ctx in ctxs]): + parent_ids_set.add(k) + + trun_args = parse_cond_args(parent_ids_set, body_name_ids, + modified_name_ids_from_parent) + false_args = parse_cond_args(parent_ids_set, orelse_name_ids, + modified_name_ids_from_parent) + nonlocal_names = list(trun_args | false_args | new_vars_to_create) + nonlocal_names.sort() + # NOTE: All var in return_name_ids should be in nonlocal_names. + nonlocal_names = _valid_nonlocal_names(return_name_ids, nonlocal_names) + + # TODO(dev): Need a better way to deal this. + if ARGS_NAME in nonlocal_names: + nonlocal_names.remove(ARGS_NAME) + + nonlocal_stmt_node = [create_nonlocal_stmt_node(nonlocal_names)] + + empty_arg_node = gast.arguments(args=[], + posonlyargs=[], + vararg=None, + kwonlyargs=[], + kw_defaults=None, + kwarg=None, + defaults=[]) true_func_node = create_funcDef_node( - node.body, + nonlocal_stmt_node + node.body, name=unique_name.generate(TRUE_FUNC_PREFIX), - input_args=parse_cond_args(parent_name_ids, body_name_ids, - modified_name_ids), + input_args=empty_arg_node, return_name_ids=return_name_ids) false_func_node = create_funcDef_node( - node.orelse, + nonlocal_stmt_node + node.orelse, name=unique_name.generate(FALSE_FUNC_PREFIX), - input_args=parse_cond_args(parent_name_ids, orelse_name_ids, - modified_name_ids), + input_args=empty_arg_node, return_name_ids=return_name_ids) - return create_new_vars_in_parent_stmts, true_func_node, false_func_node, return_name_ids + + get_args_node = create_get_args_node(nonlocal_names) + set_args_node = create_set_args_node(nonlocal_names) + + return create_new_vars_in_parent_stmts, true_func_node, false_func_node, get_args_node, set_args_node, return_name_ids + + +def create_get_args_node(names): + """ + Create get_args function as follows: + + def get_args_0(): + nonlocal x, y + """ + assert isinstance(names, (list, tuple)) + template = """ + def {func_name}(): + nonlocal {vars} + return {vars} + """ + func_def = template.format( + func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX), + vars=",".join(names)) + return gast.parse(textwrap.dedent(func_def)).body[0] + + +def create_set_args_node(names): + """ + Create set_args function as follows: + + def set_args_0(__args): + nonlocal x, y + x, y = __args + """ + assert isinstance(names, (list, tuple)) + template = """ + def {func_name}({args}): + nonlocal {vars} + {vars} = {args} + """ + func_def = template.format( + func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX), + args=ARGS_NAME, + vars=",".join(names)) + return gast.parse(textwrap.dedent(func_def)).body[0] def create_convert_ifelse_node(return_name_ids, pred, true_func, false_func, + get_args_func, + set_args_func, is_if_expr=False): """ Create `paddle.jit.dy2static.convert_ifelse( - pred, true_fn, false_fn, true_args, false_args)` + pred, true_fn, false_fn, get_args, set_args, return_name_ids)` to replace original `python if/else` statement. """ - def create_name_nodes(name_ids): + def create_name_str(name_ids): + """ + Return "('x', 'y')" for [x, y] + """ if not name_ids: - return gast.Tuple(elts=[], ctx=gast.Load()) + return 'None' - gast_names = [ - gast.Name(id=name_id, - ctx=gast.Load(), - annotation=None, - type_comment=None) for name_id in name_ids - ] - name_node = gast.Tuple(elts=gast_names, ctx=gast.Load()) - return name_node + names_str = ["'%s'" % name for name in name_ids] + return "(%s, )" % ','.join(names_str) if is_if_expr: - true_args = gast.Tuple(elts=[], ctx=gast.Load()) - false_args = gast.Tuple(elts=[], ctx=gast.Load()) true_func_source = "lambda : {}".format(ast_to_source_code(true_func)) false_func_source = "lambda : {}".format(ast_to_source_code(false_func)) else: - true_args = gast.Tuple(elts=true_func.args.args, ctx=gast.Load()) - false_args = gast.Tuple(elts=false_func.args.args, ctx=gast.Load()) true_func_source = true_func.name false_func_source = false_func.name convert_ifelse_layer = gast.parse( '_jst.convert_ifelse(' - '{pred}, {true_fn}, {false_fn}, {true_args}, {false_args})'.format( + '{pred}, {true_fn}, {false_fn}, {get_args}, {set_args}, {return_name_ids})' + .format( pred=ast_to_source_code(pred), true_fn=true_func_source, false_fn=false_func_source, - true_args=ast_to_source_code(true_args), - false_args=ast_to_source_code(false_args))).body[0].value - - if return_name_ids: - _, cond_node = create_assign_node(return_name_ids, convert_ifelse_layer) - else: # No variables can be returned if no assign statement in if.body. - cond_node = gast.Expr(value=convert_ifelse_layer) + get_args=get_args_func.name if not is_if_expr else + 'lambda: None', #TODO: better way to deal with this + set_args=set_args_func.name + if not is_if_expr else 'lambda args: None', + return_name_ids=create_name_str(return_name_ids))).body[0] - return cond_node + return convert_ifelse_layer diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 0afe42e3e296b..2df8169a3efe1 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -87,6 +87,23 @@ def visit(self, node): ]) +class UndefinedVar: + + def __init__(self, name): + self.name = name + + def check(self): + raise UnboundLocalError( + "local variable '{}' should be created before using it.") + + +def saw(x): + if isinstance(x, UndefinedVar): + return x.check() + else: + return x + + def getfullargspec(target): if hasattr(inspect, "getfullargspec"): return inspect.getfullargspec(target) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py index 263c3cbae9579..e823813acaacb 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py @@ -25,7 +25,7 @@ __all__ = [ 'create_bool_as_type', 'create_fill_constant_node', 'create_static_variable_gast_node', 'data_layer_not_check', - 'to_static_variable', 'to_static_variable_gast_node' + 'to_static_variable', 'to_static_variable_gast_node', 'create_undefined_var' ] @@ -74,6 +74,17 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0): need_check_feed=False) +def create_undefined_var(name): + func_code = "{} = _jst.UndefinedVar('{}')".format(name, name) + return gast.parse(func_code).body[0] + + +def create_nonlocal_stmt_node(names): + assert isinstance(names, (list, tuple)) + func_code = "nonlocal {}".format(','.join(names)) + return gast.parse(func_code).body[0] + + def to_static_variable_gast_node(name): func_code = "{} = _jst.to_static_variable({})".format(name, name) return gast.parse(func_code).body[0] diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py index cf8be6640300e..75bac135424aa 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py @@ -72,33 +72,51 @@ def dyfunc_with_if_else(x_v, label=None): name='__return_value_init_0') __return_value_0 = __return_value_init_0 - def true_fn_0(x_v): + def get_args_0(): + nonlocal x_v + return x_v + + def set_args_0(__args): + nonlocal x_v + x_v = __args + + def true_fn_0(): + nonlocal x_v x_v = x_v - 1 return x_v - def false_fn_0(x_v): + def false_fn_0(): + nonlocal x_v x_v = x_v + 1 return x_v - x_v = _jst.convert_ifelse( - fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ), - (x_v, )) + _jst.convert_ifelse( + fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, get_args_0, + set_args_0, ('x_v', )) + + def get_args_1(): + nonlocal __return_value_0, label, x_v + return __return_value_0, label, x_v + + def set_args_1(__args): + nonlocal __return_value_0, label, x_v + __return_value_0, label, x_v = __args - def true_fn_1(__return_value_0, label, x_v): + def true_fn_1(): + nonlocal __return_value_0, label, x_v loss = fluid.layers.cross_entropy(x_v, label) __return_0 = _jst.create_bool_as_type(label is not None, True) __return_value_0 = loss return __return_value_0 - def false_fn_1(__return_value_0, label, x_v): + def false_fn_1(): + nonlocal __return_value_0, label, x_v __return_1 = _jst.create_bool_as_type(label is not None, True) __return_value_0 = x_v return __return_value_0 - __return_value_0 = _jst.convert_ifelse(label is not None, true_fn_1, - false_fn_1, - (__return_value_0, label, x_v), - (__return_value_0, label, x_v)) + _jst.convert_ifelse(label is not None, true_fn_1, false_fn_1, + get_args_1, set_args_1, ('__return_value_0', )) return __return_value_0 @@ -111,33 +129,51 @@ def dyfunc_with_if_else(x_v, label=None): name='__return_value_init_1') __return_value_1 = __return_value_init_1 - def true_fn_2(x_v): + def get_args_2(): + nonlocal x_v + return x_v + + def set_args_2(__args): + nonlocal x_v + x_v = __args + + def true_fn_2(): + nonlocal x_v x_v = x_v - 1 return x_v - def false_fn_2(x_v): + def false_fn_2(): + nonlocal x_v x_v = x_v + 1 return x_v - x_v = _jst.convert_ifelse( - fluid.layers.mean(x_v)[0] > 5, true_fn_2, false_fn_2, (x_v, ), - (x_v, )) + _jst.convert_ifelse( + fluid.layers.mean(x_v)[0] > 5, true_fn_2, false_fn_2, get_args_2, + set_args_2, ('x_v', )) + + def get_args_3(): + nonlocal __return_value_1, label, x_v + return __return_value_1, label, x_v + + def set_args_3(__args): + nonlocal __return_value_1, label, x_v + __return_value_1, label, x_v = __args - def true_fn_3(__return_value_1, label, x_v): + def true_fn_3(): + nonlocal __return_value_1, label, x_v loss = fluid.layers.cross_entropy(x_v, label) __return_2 = _jst.create_bool_as_type(label is not None, True) __return_value_1 = loss return __return_value_1 - def false_fn_3(__return_value_1, label, x_v): + def false_fn_3(): + nonlocal __return_value_1, label, x_v __return_3 = _jst.create_bool_as_type(label is not None, True) __return_value_1 = x_v return __return_value_1 - __return_value_1 = _jst.convert_ifelse(label is not None, true_fn_3, - false_fn_3, - (__return_value_1, label, x_v), - (__return_value_1, label, x_v)) + _jst.convert_ifelse(label is not None, true_fn_3, false_fn_3, + get_args_3, set_args_3, ('__return_value_1', )) return __return_value_1 @@ -166,6 +202,7 @@ def test_program_translator(self): answer = get_source_code(StaticCode2.dyfunc_with_if_else) program_translator = ProgramTranslator() code = program_translator.get_code(dyfunc_with_if_else) + # print(code) self.assertEqual(answer, code) diff --git a/python/paddle/jit/dy2static/__init__.py b/python/paddle/jit/dy2static/__init__.py index 030d5499c2ca9..ebe3ba716ffc2 100644 --- a/python/paddle/jit/dy2static/__init__.py +++ b/python/paddle/jit/dy2static/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .base import saw +from .base import UndefinedVar from .convert_call_func import convert_call # noqa: F401 from .convert_operators import cast_bool_if_necessary # noqa: F401 from .convert_operators import convert_assert # noqa: F401 diff --git a/python/paddle/jit/dy2static/base.py b/python/paddle/jit/dy2static/base.py new file mode 100644 index 0000000000000..8b902f386174c --- /dev/null +++ b/python/paddle/jit/dy2static/base.py @@ -0,0 +1,19 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import print_function + +from ...fluid.dygraph.dygraph_to_static.utils import saw # noqa: F401 +from ...fluid.dygraph.dygraph_to_static.utils import UndefinedVar # noqa: F401 + +__all__ = []