Skip to content

Commit

Permalink
[ Dy2Static ]Change NameVisitor in while to FunctionScopeAnalysis (#4…
Browse files Browse the repository at this point in the history
…4155)

* change NameVisitor to FunctionScopeAnalysis

* polish the logic of undefined var in while_loop. create vars after body execution

* replace old NameVisitor in while and fix all CI

* Togather with CreateVariableTransformer

* add create_variable_transformer

* fix bugs

* merge

* fix some error, TODO: ForNodePreTransform ahead

* merge for unite PR

* fix conflict with base_transformer PR

* fix ci errors, fix [for i in range()] error

* fix according to code review
  • Loading branch information
2742195759 authored Jul 12, 2022
1 parent 8759c78 commit c5c6026
Show file tree
Hide file tree
Showing 17 changed files with 337 additions and 213 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer
from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransformer
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import ReturnTransformer
from paddle.fluid.dygraph.dygraph_to_static.create_variable_transformer import CreateVariableTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer

Expand Down Expand Up @@ -96,7 +97,7 @@ def transfer_from_node_type(self, node_wrapper):
BreakContinueTransformer, # break/continue in loops
ReturnTransformer, # return in functions
LogicalTransformer, # logical and/or/not
#CreateVariableTransformer, # create undefined var for if / while / for
CreateVariableTransformer, # create undefined var for if / while / for
LoopTransformer, # for/while -> while_op
IfElseTransformer, # if/else -> cond_op
AssertTransformer, # assert statement
Expand Down
169 changes: 50 additions & 119 deletions python/paddle/fluid/dygraph/dygraph_to_static/base_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_VAR_LEN_PREFIX
from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_VAR_NAME_PREFIX
from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_ZIP_TO_LIST_PREFIX
from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_TARGET_PREFIX
from paddle.fluid.dygraph.dygraph_to_static.utils import FOR_ITER_ITERATOR_PREFIX


class BaseTransformer(gast.NodeTransformer):
Expand Down Expand Up @@ -119,32 +121,20 @@ def replace(s):


class ForLoopTuplePreTransformer(BaseTransformer):
"""
ForNodeVisitor parses 3 type statements (Here var is VarBase(Tensor) or python variable):
1). for x in range(var[*]|var.numpy()[*])
2). for x in var|var.numpy()
3). for i, x in enumerate(var|var.numpy())
We chose these 3 types because they are easier (x can be variable name iterating in var).
However, users can write tuples in Python for loop, such as
1). for var1, var2 in var|var.numpy()
2). for t in enumerate(var|var.numpy())
2). for i, (var1, var2, va3) in enumerate(var|var.numpy())
To handle these case, this method will do the rewrite tuple pre-process:
1). Non-enumerate case: for var1, var2 in var|var.numpy() will be re-written as:
for FOR_ITER_TUPLE_PREFIX_x in var | var.numpy():
var1 = FOR_ITER_TUPLE_PREFIX_x[0]
var2 = FOR_ITER_TUPLE_PREFIX_x[1]
2). Enumerate out tuple case: for t in enumerate(var|var.numpy) will be rewritten as:
for FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x in enumerate(var|var.numpy):
t = (FOR_ITER_TUPLE_INDEX_PREFIX_x, FOR_ITER_TUPLE_PREFIX_x)
3). Enumerate inner tuple case: for i, (var1, (var2, va3)) in enumerate(var|var.numpy()) will
be re-written as:
for i, FOR_ITER_TUPLE_PREFIX_x in var | var.numpy():
var1 = FOR_ITER_TUPLE_PREFIX_x[0]
var2 = FOR_ITER_TUPLE_PREFIX_x[1][0]
var3 = FOR_ITER_TUPLE_PREFIX_x[1][1]
""" pre-process of for loop.
>>> for A in B:
>>> C
will be changed into :
>>> UUID_iterator = _jst.Indexable(B) # make iterator-only to indexable list.
>>> for UUID_target in UUID_iterator:
>>> A = _jst.Unpack(UUID_target, structure)
>>> C
make the later loop_transform have unified type:
>>> for target in iter:
>>> body
"""

def __init__(self, wrapper_root):
Expand All @@ -155,104 +145,45 @@ def transform(self):
self.visit(self.root)

def visit_For(self, node):
if self.is_for_enumerate_iter(node):
if isinstance(node.target, (gast.Name, gast.Attribute)):
# Out tuple case
out_tuple_name = ast_to_source_code(node.target).strip()
tuple_iter_name = unique_name.generate(
FOR_ITER_TUPLE_INDEX_PREFIX)
tuple_var_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX)
node.target = gast.Tuple(elts=[
gast.Name(id=tuple_iter_name,
ctx=gast.Store(),
annotation=None,
type_comment=None),
gast.Name(id=tuple_var_name,
ctx=gast.Store(),
self.generic_visit(node)
tuple_target = unique_name.generate(FOR_ITER_TARGET_PREFIX)
tuple_iterator = unique_name.generate(FOR_ITER_ITERATOR_PREFIX)
origin_tuple_node = node.target
assign_iterator_node = gast.parse(
f"{tuple_iterator} = _jst.Indexable({ast_to_source_code(node.iter).strip()})"
).body[0]
node.target = gast.Name(id=tuple_target,
ctx=gast.Store(),
annotation=None,
type_comment=None)
node.iter = gast.Name(id=tuple_iterator,
ctx=gast.Load(),
annotation=None,
type_comment=None)
],
ctx=gast.Store())
node.body.insert(
0,
gast.Assign(targets=[
gast.Name(id=out_tuple_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=gast.Tuple(elts=[
gast.Name(id=tuple_iter_name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
gast.Name(id=tuple_var_name,
ctx=gast.Load(),
annotation=None,
type_comment=None)
],
ctx=gast.Load())))
elif isinstance(node.target, (gast.List, gast.Tuple)) and len(
node.target.elts) >= 2 and isinstance(
node.target.elts[1], (gast.List, gast.Tuple)):
# Inner tuple case
inner_tuple_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX)
origin_inner_tuple_node = node.target.elts[1]
node.target.elts[1] = gast.Name(id=inner_tuple_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
node.body[0:0] = self.tuple_to_stmts(origin_inner_tuple_node,
inner_tuple_name)
elif self.is_for_iter(node) and isinstance(node.target,
(gast.List, gast.Tuple)):
# Non-enumrate case:
tuple_name = unique_name.generate(FOR_ITER_TUPLE_PREFIX)
origin_tuple_node = node.target
node.target = gast.Name(id=tuple_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
node.body[0:0] = self.tuple_to_stmts(origin_tuple_node, tuple_name)
return node

def tuple_to_stmts(self, node, tuple_name, idx=[]):
if not isinstance(node, (gast.Tuple, gast.List)):
value_node_str = tuple_name
for i in idx:
value_node_str = value_node_str + "[{}]".format(i)

node_str = ast_to_source_code(node).strip()
assign_node_str = "{} = {}".format(node_str, value_node_str)
assign_node = gast.parse(assign_node_str).body[0]
return [assign_node]

# isinstance(node, (gast.Tuple, gast.List))
node.body[0:0] = self.tuple_to_stmts(origin_tuple_node, tuple_target)
# return a list will insert a list of node replace the original for node.
return [assign_iterator_node, node]

def tuple_node_to_unpack_structure(self, node):
""" Create a sequence to represents the structure of nest.
For example: `a, (b,c), [d,e,f]` is represented by
`[1, [1,1], [1,1,1]]`. the `1` is just a notation.
Specially, `a` is represented by `1`.
"""
ret = []
for i, element in enumerate(node.elts):
ret += self.tuple_to_stmts(node.elts[i], tuple_name, idx + [i])
if not isinstance(node, (gast.Tuple, gast.List)):
return 1
for element in node.elts:
ret.append(self.tuple_node_to_unpack_structure(element))
return ret

def is_for_iter(self, for_node):
assert isinstance(for_node,
gast.For), "Input node is not gast.For node."
if isinstance(for_node.iter, (gast.Name, gast.Attribute)):
return True
elif isinstance(for_node.iter, gast.Call) and isinstance(
for_node.iter.func,
gast.Attribute) and for_node.iter.func.attr == 'numpy':
return True
elif isinstance(for_node.iter, gast.Subscript):
return True
else:
return False

def is_for_enumerate_iter(self, for_node):
assert isinstance(for_node,
gast.For), "Input node is not gast.For node."
return isinstance(for_node.iter, gast.Call) and isinstance(
for_node.iter.func,
gast.Name) and for_node.iter.func.id == "enumerate"
def tuple_to_stmts(self, node, tuple_name):
structure_str = str(self.tuple_node_to_unpack_structure(node))
node_str = ast_to_source_code(node).strip()
assign_node_str = f"{node_str} = _jst.Unpack({tuple_name}, {structure_str})"
assign_node = gast.parse(assign_node_str).body[0]
return [assign_node]


class SplitAssignTransformer(BaseTransformer):
Expand Down
15 changes: 10 additions & 5 deletions python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,24 @@ def _no_need_convert_call(self, node):
Determines whether a function needs to be transformed by `convert_call`.
It doesn't need to be transformed when a function satisfies the following conditions:
1. It's a api of paddle
2. It's a python builtin function not include `len` and `zip`
2. It's a python builtin function not include `len`, `zip`, `range` and `enumerate`
"""
assert isinstance(node, gast.Call)
if is_paddle_api(node):
return True

func_str = ast_to_source_code(node.func).strip()
try:
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin_len, is_builtin, is_builtin_zip
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin
need_convert_builtin_func_list = {
'len',
'zip',
'range',
'enumerate',
}
is_builtin = eval("is_builtin({})".format(func_str))
is_builtin_len = eval("is_builtin_len({})".format(func_str))
is_builtin_zip = eval("is_builtin_zip({})".format(func_str))
return is_builtin and not is_builtin_len and not is_builtin_zip
need_convert = func_str in need_convert_builtin_func_list
return is_builtin and not need_convert
except Exception:
return False

Expand Down
34 changes: 19 additions & 15 deletions python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len, convert_zip
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_range, convert_enumerate
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
Expand Down Expand Up @@ -64,25 +65,22 @@ def __init__(self, not_convert=False):
self.not_convert = not_convert


def is_builtin(func):
if isinstance(func, types.BuiltinFunctionType):
def is_builtin(func, name=None):
""" predict whether a function is a builtin function with name={name}.
if name == None, then any builtin function will return True
"""

def name_judge():
return name is None or func.__name__ == name

if isinstance(func, types.BuiltinFunctionType) and name_judge():
return True
elif func in six.moves.builtins.__dict__.values():
elif func in six.moves.builtins.__dict__.values() and name_judge():
return True
else:
return False


def is_builtin_len(func):
if isinstance(func, types.BuiltinFunctionType) and func.__name__ == 'len':
return True
return False


def is_builtin_zip(func):
return is_builtin(func) and func.__name__ == 'zip'


def is_unsupported(func):
"""
Checks whether the func is supported by dygraph to static graph.
Expand Down Expand Up @@ -165,12 +163,18 @@ def dyfunc(x):
.format(func))
return func

if is_builtin_len(func):
if is_builtin(func, "len"):
return convert_len

if is_builtin_zip(func):
if is_builtin(func, "zip"):
return convert_zip

if is_builtin(func, "range"):
return convert_range

if is_builtin(func, "enumerate"):
return convert_enumerate

if is_builtin(func) or is_unsupported(func):
return func

Expand Down
Loading

0 comments on commit c5c6026

Please sign in to comment.