Skip to content

Commit

Permalink
[Dy2Stat]Support nonlocal mechanism in IF ast transformer (PaddlePadd…
Browse files Browse the repository at this point in the history
…le#43666)

* [Dy2Stat]Support nonlocal mechanism in IF ast transformer

* support prune return vars in cond

* fix unittest

* fix unittest

* fix static check
  • Loading branch information
Aurelius84 authored and sneaxiy committed Jun 27, 2022
1 parent 4e97a1f commit fce1ff6
Show file tree
Hide file tree
Showing 7 changed files with 311 additions and 92 deletions.
70 changes: 59 additions & 11 deletions python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_VAR_NAME
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar


def convert_while_loop(cond, body, loop_vars):
Expand Down Expand Up @@ -188,25 +189,27 @@ def _run_py_logical_not(x):
return not x


def convert_ifelse(pred, true_fn, false_fn, true_args, false_args):
def convert_ifelse(pred, true_fn, false_fn, get_args, set_args,
return_name_ids):
"""
A function representation of a Python ``if/else`` statement.
Args:
pred(bool|Tensor): A boolean Tensor which determines whether to return the result of ``true_fn`` or ``false_fn`` .
true_fn(callable): A callable to be performed if ``pred`` is true.
false_fn(callable): A callable to be performed if ``pred`` is false.
true_args(tuple): Parameters of ``true_fn``.
false_args(tuple): Parameters of ``false_fn``.
get_args(callable): Get all arguments that needed in true_fn and false_fn.
set_args(callable): Update arguments that modified in trure_fn and false_fn.
Returns:
``true_fn(true_args)`` if the predicate ``pred`` is true else ``false_fn(false_args)`` .
``true_fn()`` if the predicate ``pred`` is true else ``false_fn()`` .
"""
if isinstance(pred, Variable):
out = _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args)
out = _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
return_name_ids)
else:
out = _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args)
out = _run_py_ifelse(pred, true_fn, false_fn)

return _remove_no_value_return_var(out)

Expand Down Expand Up @@ -244,14 +247,59 @@ def _remove_no_value_return_var(out):
return out


def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args):
def _check_no_undefined_var(outs, names, branch_name):
if names is None: return
if not isinstance(outs, (list, tuple)):
outs = [outs]
for var, name in zip(list(outs), names):
if isinstance(var, UndefinedVar):
raise ValueError(
"Required '{}' must be initialized both in if-else branch, but found it not initialized in '{}'."
.format(name, branch_name))


def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
return_name_ids):
"""
Paddle cond API will evaluate both ture_fn and false_fn codes.
"""
pred = cast_bool_if_necessary(pred)
return control_flow.cond(pred, lambda: true_fn(*true_args),
lambda: false_fn(*false_args))
init_args = get_args()

def new_true_fn():
set_args(init_args)
outs = true_fn()
_check_no_undefined_var(outs, return_name_ids, 'if_body')
return outs

def new_false_fn():
set_args(init_args)
outs = false_fn()
_check_no_undefined_var(outs, return_name_ids, 'else_body')
return outs

cond_outs = control_flow.cond(pred, new_true_fn, new_false_fn)
# IfExpr's return_name_ids maybe None
if return_name_ids is None:
return cond_outs

# recover args state
num_outs = len(return_name_ids)
num_args = 1 if not isinstance(init_args, tuple) else len(init_args)
assert num_outs <= num_args

if num_args == 1:
final_outs = cond_outs
else:
cond_outs = (cond_outs, ) if num_outs == 1 else cond_outs
final_outs = cond_outs + init_args[num_outs:]

set_args(final_outs)
return final_outs


def _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args):
return true_fn(*true_args) if pred else false_fn(*false_args)
def _run_py_ifelse(pred, true_fn, false_fn):
return true_fn() if pred else false_fn()


def convert_len(var):
Expand Down
Loading

0 comments on commit fce1ff6

Please sign in to comment.