Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2static] FunctionScopeVisitor Enhance and substitute the original NameVisitor in If #43967

Merged
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def transfer_from_node_type(self, node_wrapper):
BreakContinueTransformer, # break/continue in loops
ReturnTransformer, # return in functions
LogicalTransformer, # logical and/or/not
#CreateVariableTransformer, # create undefined var for if / while / for
LoopTransformer, # for/while -> while_op
IfElseTransformer, # if/else -> cond_op
AssertTransformer, # assert statement
Expand Down
60 changes: 39 additions & 21 deletions python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re

from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.framework import core, Variable
Expand All @@ -21,7 +23,7 @@
from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_VAR_NAME
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, Dygraph2StaticException


def convert_while_loop(cond, body, getter, setter):
Expand All @@ -41,11 +43,9 @@ def convert_while_loop(cond, body, getter, setter):
# If loop_vars is changed during cond callable, then it causes bug, but current logical_and/logical_not/... doesn't change the loop_vars.
pred = cond()
if isinstance(pred, Variable):
loop_vars = _run_paddle_while(cond, body, getter, setter)
_run_paddle_while(cond, body, getter, setter)
else:
loop_vars = _run_py_while(cond, body, getter, setter)

return loop_vars
_run_py_while(cond, body, getter, setter)


def _run_paddle_while(cond, body, getter, setter):
Expand All @@ -61,10 +61,13 @@ def _run_paddle_while(cond, body, getter, setter):


def _run_py_while(cond, body, getter, setter):
loop_vars = getter()
while cond():
loop_vars = body()
return loop_vars
while True:
pred = cond()
if isinstance(pred, Variable):
raise Dygraph2StaticException(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

raise这个error时会触发用户源码行的标记么?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个会触发源码行的标记:
捕获

"python while pred change from bool to variable.")
if not pred: break
body()


def convert_logical_and(x_func, y_func):
Expand Down Expand Up @@ -231,17 +234,32 @@ def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,

def new_true_fn():
set_args(init_args)
outs = true_fn()
_check_no_undefined_var(outs, return_name_ids, 'if_body')
return outs
ret = true_fn()
# IfExpr will return a non-None return value, so we just return ret.
# We assume normal return has no return value.
if ret is None: return get_args()
else: return ret

def new_false_fn():
set_args(init_args)
outs = false_fn()
_check_no_undefined_var(outs, return_name_ids, 'else_body')
return outs

cond_outs = control_flow.cond(pred, new_true_fn, new_false_fn)
ret = false_fn()
if ret is None: return get_args()
else: return ret

try:
cond_outs = control_flow.cond(pred, new_true_fn, new_false_fn, None,
return_name_ids)
except Exception as e:
if re.search("Unsupported return type of true_fn and false_fn in cond",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里最好加个单测确保这行代码按照预期触发,因为报错信息有可能被别人迭代优化后,这个分支就失效了

str(e)):
raise Dygraph2StaticException(
"Your if/else have different return type. TODO: add link to modifty. {}"
.format(str(e)))
if re.search("Incompatible return values of", str(e)):
raise Dygraph2StaticException(
"Your if/else have different number of return value. TODO: add link to modifty. {}"
.format(str(e)))
raise e
return _recover_args_state(cond_outs, get_args, set_args, return_name_ids)


Expand All @@ -251,8 +269,7 @@ def _run_py_ifelse(pred, true_fn, false_fn, get_args, set_args,
Evaluate python original branch function if-else.
"""
py_outs = true_fn() if pred else false_fn()
py_outs = _remove_no_value_return_var(py_outs)
return _recover_args_state(py_outs, get_args, set_args, return_name_ids)
return py_outs


def _remove_no_value_return_var(out):
Expand Down Expand Up @@ -317,9 +334,10 @@ def _recover_args_state(outs, get_args, set_args, return_name_ids):
assert num_outs <= num_args

if num_args == 1:
final_outs = (outs, )
final_outs = (outs, ) if not isinstance(outs,
(list, tuple)) else tuple(outs)
else:
outs = (outs, ) if num_outs == 1 else outs
outs = (outs, ) if num_outs == 1 else tuple(outs)
final_outs = outs + init_args[num_outs:]

set_args(final_outs)
Expand Down
Loading