[ Dy2Static | Controlflow ] While + Cond support for python container. #45105

Merged
2 changes: 1 addition & 1 deletion paddle/phi/core/dense_tensor_impl.cc
@@ -53,7 +53,7 @@ void DenseTensor::check_memory_size() const {
"Tensor's dimension is out of bound."
"Tensor's dimension must be equal or less than the size of its "
"memory."
"But received Tensor's dimension is d%, memory's size is %d.",
"But received Tensor's dimension is %d, memory's size is %d.",
numel() * SizeOf(dtype()),
memory_size()));
}
@@ -25,6 +25,7 @@
from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_VAR_NAME
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, Dygraph2StaticException
from paddle.fluid.layers.utils import copy_mutable_vars


def indexable(x, code=None):
@@ -92,7 +93,10 @@ def _run_paddle_while(cond, body, getter, setter):
# NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Tensors.
def new_body_fn(*args):
""" wrap the body() and add return value for `while_loop`
the args may be differ from getter().
"""
mutable_loop_vars = args
setter(mutable_loop_vars)
body()
return getter()

@@ -110,7 +114,7 @@ def new_cond_fn(*args):
setter(loop_vars) # change the non-local var to variable
# the variables may be modified to inner vars inside the loop; set them back after while_loop
loop_vars = control_flow.while_loop(new_cond_fn, new_body_fn, loop_vars)
setter(loop_vars) # change the non-local var to variable
setter(loop_vars) # change back to loop_vars
return loop_vars
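
To make the wrapping above concrete, here is a minimal pure-Python sketch of the getter/setter pattern used by _run_paddle_while: while_loop only threads positional loop_vars, so the wrapper functions write them back into the enclosing scope before running the user's cond/body. plain_while_loop and run_while below are illustrative stand-ins, not Paddle APIs.

def plain_while_loop(cond_fn, body_fn, loop_vars):
    # stand-in for control_flow.while_loop: call body while cond holds
    while cond_fn(*loop_vars):
        loop_vars = body_fn(*loop_vars)
    return loop_vars

def run_while(cond, body, getter, setter):
    # mirrors the wrapping in _run_paddle_while above
    def new_cond_fn(*args):
        setter(args)            # sync the positional args back to the outer vars
        return cond()

    def new_body_fn(*args):
        setter(args)            # args may differ from getter() after one step
        body()
        return getter()         # return the possibly rebuilt loop vars

    loop_vars = getter()
    loop_vars = plain_while_loop(new_cond_fn, new_body_fn, loop_vars)
    setter(loop_vars)
    return loop_vars

# usage: i counts to 3 while a mutable Python list is threaded through the loop
i, box = 0, [0]
def getter(): return (i, box)
def setter(vals):
    global i, box
    i, box = vals
def cond(): return i < 3
def body():
    global i
    i += 1
    box[0] += 10
print(run_while(cond, body, getter, setter))    # (3, [30])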


@@ -287,15 +291,17 @@ def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
init_args = get_args()

def new_true_fn():
set_args(init_args)
# init_args may contain mutable python containers like [var, 2]; copy them as in while_loop
set_args(copy_mutable_vars(init_args))
ret = true_fn()
# IfExpr will return a non-None return value, so we just return ret.
# We assume normal return has no return value.
if ret is None: return get_args()
else: return ret

def new_false_fn():
set_args(init_args)
# init_args may contain mutable python containers like [var, 2]; copy them as in while_loop
set_args(copy_mutable_vars(init_args))
ret = false_fn()
if ret is None: return get_args()
else: return ret
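
The copy above matters because, in static-graph mode, both branches are traced and both call set_args on the same init_args; without a copy, a mutable container changed in place while tracing one branch would leak into the other. A minimal pure-Python sketch of the failure mode and of the fix (copy_mutable_vars_sketch is an illustrative stand-in for the real helper):

import copy

def copy_mutable_vars_sketch(args):
    # shallow-copy lists and dicts, leave everything else (e.g. tensors) alone
    return [copy.copy(a) if isinstance(a, (list, dict)) else a for a in args]

state = None
def set_args(args):
    global state
    state = args

# without copying: "tracing" the first branch mutates the shared list ...
init_args = ([1, 2], "shared")
set_args(init_args)
state[0].append(3)
print(init_args[0])     # [1, 2, 3]  -- the change leaks into the other branch

# with copying: each branch starts from an untouched snapshot of init_args
init_args = ([1, 2], "shared")
set_args(copy_mutable_vars_sketch(init_args))
state[0].append(3)
print(init_args[0])     # [1, 2]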
@@ -21,6 +21,7 @@
from paddle.fluid import unique_name
from paddle.fluid.framework import Variable
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, create_undefined_variable
from paddle.fluid.layers.utils import map_structure, is_sequence

__all__ = [
'create_bool_as_type',
@@ -63,9 +64,12 @@ def to_static_variable(x):
if isinstance(x, six.integer_types):
return paddle.full(shape=[1], dtype='int64', fill_value=x)
if isinstance(x, UndefinedVar) or x is None:
""" for early return case, we need a variable to represent None, current we use data_layer_not_check.
"""
for early return case, we need a variable to represent None, current we use data_layer_not_check.
"""
return create_undefined_variable()
if is_sequence(x):
return map_structure(to_static_variable, x)
return x
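
With the is_sequence/map_structure branch, to_static_variable can walk nested Python containers and convert each leaf, instead of handling only scalars and tensors. A rough pure-Python sketch of that recursion (the *_sketch helpers mimic, but are not, the Paddle utilities):

def is_sequence_sketch(x):
    return isinstance(x, (list, tuple, dict))

def map_structure_sketch(fn, s):
    # apply fn to every leaf of a nested list/tuple/dict, keeping the structure
    if isinstance(s, dict):
        return {k: map_structure_sketch(fn, v) for k, v in s.items()}
    if isinstance(s, (list, tuple)):
        return type(s)(map_structure_sketch(fn, v) for v in s)
    return fn(s)

def to_static_variable_sketch(x):
    # stand-ins for the paddle.full(...) conversions; bool must be checked
    # before int because bool is a subclass of int in Python
    if isinstance(x, bool):
        return ("tensor", "bool", x)
    if isinstance(x, float):
        return ("tensor", "float64", x)
    if isinstance(x, int):
        return ("tensor", "int64", x)
    if is_sequence_sketch(x):
        return map_structure_sketch(to_static_variable_sketch, x)
    return x

print(to_static_variable_sketch([True, 2, {"lr": 0.1}]))
# [('tensor', 'bool', True), ('tensor', 'int64', 2), {'lr': ('tensor', 'float64', 0.1)}]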


7 changes: 5 additions & 2 deletions python/paddle/fluid/layers/control_flow.py
@@ -1327,8 +1327,11 @@ def create_var_like(o_var):
if isinstance(o_var,
(Variable, ) + support_ret_buildin_type) or o_var is None:
return create_undefined_variable()
if isinstance(o_var, (tuple, list)):
return [create_undefined_variable() for i in range(len(o_var))]
if is_sequence(o_var):
"""
Create undefined variables mirroring a nested Python container (list or dict) inside the body of while_loop.
"""
return map_structure(lambda x: create_undefined_variable(), o_var)

if len(output_vars) != len(loop_vars):
raise ValueError("The length of output_vars should be the same as that of loop_vars.")
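
The analogous change in create_var_like mirrors an arbitrary nested Python container, replacing every leaf with an undefined placeholder. A small self-contained sketch of that recursion ("UNDEF" stands in for create_undefined_variable(); names are illustrative, not Paddle APIs):

def undefined_like(o_var):
    # replace every leaf of a nested list/tuple/dict with a placeholder
    if isinstance(o_var, dict):
        return {k: undefined_like(v) for k, v in o_var.items()}
    if isinstance(o_var, (list, tuple)):
        return type(o_var)(undefined_like(v) for v in o_var)
    return "UNDEF"

print(undefined_like({"history": [1, 2], "best": 7}))
# {'history': ['UNDEF', 'UNDEF'], 'best': 'UNDEF'}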