Skip to content

Commit

Permalink
[Dy2static] FunctionScopeVisitor Enhance and substitute the original …
Browse files Browse the repository at this point in the history
…NameVisitor in If (#43967)

* add support for control flow block analysis

* move FunctionNameLivenessAnalysis into utils

* pass test_ifelse.py

* remove duplicate data_layer_not_check

* pass the test_ifelse.py

* fix unittest error .

* fix all ci error in first version

* temporay disable CreateVariableTransformer

* fix ci errors

* fix function name liveness analysis bugs

* modifty def cond

* fix

* fix ci error - v2

* fix by code review

* change return_name_ids -> return_name
  • Loading branch information
2742195759 authored Jul 6, 2022
1 parent bbe9955 commit b603dd5
Show file tree
Hide file tree
Showing 19 changed files with 629 additions and 801 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def transfer_from_node_type(self, node_wrapper):
BreakContinueTransformer, # break/continue in loops
ReturnTransformer, # return in functions
LogicalTransformer, # logical and/or/not
#CreateVariableTransformer, # create undefined var for if / while / for
LoopTransformer, # for/while -> while_op
IfElseTransformer, # if/else -> cond_op
AssertTransformer, # assert statement
Expand Down
60 changes: 39 additions & 21 deletions python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re

from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.framework import core, Variable
Expand All @@ -21,7 +23,7 @@
from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_VAR_NAME
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, Dygraph2StaticException


def convert_while_loop(cond, body, getter, setter):
Expand All @@ -41,11 +43,9 @@ def convert_while_loop(cond, body, getter, setter):
# If loop_vars is changed during cond callable, then it causes bug, but current logical_and/logical_not/... doesn't change the loop_vars.
pred = cond()
if isinstance(pred, Variable):
loop_vars = _run_paddle_while(cond, body, getter, setter)
_run_paddle_while(cond, body, getter, setter)
else:
loop_vars = _run_py_while(cond, body, getter, setter)

return loop_vars
_run_py_while(cond, body, getter, setter)


def _run_paddle_while(cond, body, getter, setter):
Expand All @@ -61,10 +61,13 @@ def _run_paddle_while(cond, body, getter, setter):


def _run_py_while(cond, body, getter, setter):
loop_vars = getter()
while cond():
loop_vars = body()
return loop_vars
while True:
pred = cond()
if isinstance(pred, Variable):
raise Dygraph2StaticException(
"python while pred change from bool to variable.")
if not pred: break
body()


def convert_logical_and(x_func, y_func):
Expand Down Expand Up @@ -231,17 +234,32 @@ def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,

def new_true_fn():
set_args(init_args)
outs = true_fn()
_check_no_undefined_var(outs, return_name_ids, 'if_body')
return outs
ret = true_fn()
# IfExpr will return a non-None return value, so we just return ret.
# We assume normal return has no return value.
if ret is None: return get_args()
else: return ret

def new_false_fn():
set_args(init_args)
outs = false_fn()
_check_no_undefined_var(outs, return_name_ids, 'else_body')
return outs

cond_outs = control_flow.cond(pred, new_true_fn, new_false_fn)
ret = false_fn()
if ret is None: return get_args()
else: return ret

try:
cond_outs = control_flow.cond(pred, new_true_fn, new_false_fn, None,
return_name_ids)
except Exception as e:
if re.search("Unsupported return type of true_fn and false_fn in cond",
str(e)):
raise Dygraph2StaticException(
"Your if/else have different return type. TODO: add link to modifty. {}"
.format(str(e)))
if re.search("Incompatible return values of", str(e)):
raise Dygraph2StaticException(
"Your if/else have different number of return value. TODO: add link to modifty. {}"
.format(str(e)))
raise e
return _recover_args_state(cond_outs, get_args, set_args, return_name_ids)


Expand All @@ -251,8 +269,7 @@ def _run_py_ifelse(pred, true_fn, false_fn, get_args, set_args,
Evaluate python original branch function if-else.
"""
py_outs = true_fn() if pred else false_fn()
py_outs = _remove_no_value_return_var(py_outs)
return _recover_args_state(py_outs, get_args, set_args, return_name_ids)
return py_outs


def _remove_no_value_return_var(out):
Expand Down Expand Up @@ -317,9 +334,10 @@ def _recover_args_state(outs, get_args, set_args, return_name_ids):
assert num_outs <= num_args

if num_args == 1:
final_outs = (outs, )
final_outs = (outs, ) if not isinstance(outs,
(list, tuple)) else tuple(outs)
else:
outs = (outs, ) if num_outs == 1 else outs
outs = (outs, ) if num_outs == 1 else tuple(outs)
final_outs = outs + init_args[num_outs:]

set_args(final_outs)
Expand Down
Loading

0 comments on commit b603dd5

Please sign in to comment.