
[Dy2stat] Support Python3 type annotation #36544

Merged
merged 11 commits on Nov 3, 2021
8 changes: 4 additions & 4 deletions python/paddle/fluid/dygraph/dygraph_to_static/error.py
@@ -208,10 +208,10 @@ def create_message(self):
message_lines.append("")

# Add paddle traceback after user code traceback
paddle_traceback_start_idnex = user_code_traceback_index[
paddle_traceback_start_index = user_code_traceback_index[
-1] + 1 if user_code_traceback_index else 0
for filepath, lineno, funcname, code in self.origin_traceback[
paddle_traceback_start_idnex:]:
paddle_traceback_start_index:]:
traceback_frame = TraceBackFrame(
Location(filepath, lineno), funcname, code)
message_lines.append(traceback_frame.formated_message())
@@ -305,10 +305,10 @@ def _simplify_error_value(self):
error_frame.append("")

# Add paddle traceback after user code traceback
paddle_traceback_start_idnex = user_code_traceback_index[
paddle_traceback_start_index = user_code_traceback_index[
-1] + 1 if user_code_traceback_index else 0
for filepath, lineno, funcname, code in error_traceback[
paddle_traceback_start_idnex:]:
paddle_traceback_start_index:]:
traceback_frame = TraceBackFrame(
Location(filepath, lineno), funcname, code)
error_frame.append(traceback_frame.formated_message())
90 changes: 74 additions & 16 deletions python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py
@@ -15,7 +15,8 @@
from __future__ import print_function

from paddle.utils import gast
from .utils import is_paddle_api, is_dygraph_api, is_numpy_api, index_in_list
from .logging_utils import warn
from .utils import is_paddle_api, is_dygraph_api, is_numpy_api, index_in_list, ast_to_source_code

__all__ = ['AstNodeWrapper', 'NodeVarType', 'StaticAnalysisVisitor']

@@ -57,6 +58,15 @@ class NodeVarType(object):
# If node.node_var_type in TENSOR_TYPES, it can be considered as tensor-dependent.
TENSOR_TYPES = {TENSOR, PADDLE_RETURN_TYPES}

Annotation_map = {
"Tensor": TENSOR,
"paddle.Tensor": TENSOR,
"int": INT,
"float": FLOAT,
"bool": BOOLEAN,
"str": STRING
}

@staticmethod
def binary_op_output_type(in_type1, in_type2):
if in_type1 == in_type2:
@@ -83,6 +93,16 @@ def binary_op_output_type(in_type1, in_type2):
return NodeVarType.UNKNOWN
return max(in_type1, in_type2)

@staticmethod
def type_from_annotation(annotation):
annotation_str = ast_to_source_code(annotation).strip()
if annotation_str in NodeVarType.Annotation_map:
return NodeVarType.Annotation_map[annotation_str]

# Emit a warning if the annotation is not recognized
warn("Currently we don't support annotation: %s" % annotation_str)
return NodeVarType.UNKNOWN
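For context, the lookup above is purely textual: the annotation node is converted back to source code and matched against Annotation_map, so any annotation outside that map (e.g. List[int]) falls back to UNKNOWN with a warning. Below is a minimal standalone sketch of the same idea, illustrative only, using the stdlib ast module (Python 3.9+ for ast.unparse) in place of Paddle's ast_to_source_code helper:

import ast

# Hypothetical stand-in for NodeVarType.Annotation_map; the keys mirror the PR.
ANNOTATION_MAP = {
    "Tensor": "TENSOR", "paddle.Tensor": "TENSOR",
    "int": "INT", "float": "FLOAT", "bool": "BOOLEAN", "str": "STRING",
}

def type_from_annotation(annotation_node):
    # Convert the annotation AST node back to source text and look it up.
    annotation_str = ast.unparse(annotation_node).strip()
    return ANNOTATION_MAP.get(annotation_str, "UNKNOWN")

func = ast.parse("def foo(x: paddle.Tensor, y: int, z: List[int]): pass").body[0]
print([type_from_annotation(a.annotation) for a in func.args.args])
# ['TENSOR', 'INT', 'UNKNOWN']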


class AstNodeWrapper(object):
"""
@@ -316,6 +336,18 @@ def _get_node_var_type(self, cur_wrapper):
self.var_env.set_var_type(target.id, ret_type)
return ret_type

if isinstance(node, gast.AnnAssign):
# TODO(0x45f): Determine whether assignment statements like
# `self.x: float = 2.1` need to be supported.
ret_type = {NodeVarType.type_from_annotation(node.annotation)}
# If the annotation and the value (a Constant) differ in type, use the value's type
if node.value:
ret_type = self.node_to_wrapper_map[node.value].node_var_type
if isinstance(node.target, gast.Name):
Reviewer comment (Contributor):
This branch only handles targets like `x: float = 2.1`. An annotated assignment such as `self.x: float = 2.1` is not handled, because `self.x` is a `gast.Attribute`, not a `gast.Name`. If this PR does not support that case, please leave a TODO here and confirm whether there are real scenarios that need non-`gast.Name` targets.

Author reply (Contributor Author):
The existing handling of `Assign` nodes does not support assignments like `self.x = 2.1` either, so the `AnnAssign` branch also skips `self.x: float = 2.1` for now. Whether such assignments need to be supported can be confirmed later.

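To make the reviewer's point concrete, here is a small illustration (not part of the PR): gast parses x: float = 2.1 into an AnnAssign whose target is a gast.Name, while self.x: float = 2.1 yields a gast.Attribute target, which the branch below does not handle.

import gast  # assumes the gast package Paddle's dygraph_to_static builds on

tree = gast.parse("x: float = 2.1\nself.x: float = 2.1")
for stmt in tree.body:
    # Both statements are AnnAssign nodes, but their targets differ:
    # 'x' is a gast.Name, 'self.x' is a gast.Attribute.
    print(type(stmt).__name__, type(stmt.target).__name__)
# AnnAssign Name
# AnnAssign Attribute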
self.node_to_wrapper_map[node.target].node_var_type = ret_type
self.var_env.set_var_type(node.target.id, ret_type)
return ret_type

if isinstance(node, gast.Name):
if node.id == "None":
return {NodeVarType.NONE}
@@ -325,21 +357,8 @@ def _get_node_var_type(self, cur_wrapper):
parent_node_wrapper = cur_wrapper.parent
if parent_node_wrapper and isinstance(parent_node_wrapper.node,
gast.arguments):
parent_node = parent_node_wrapper.node
var_type = {NodeVarType.UNKNOWN}
if parent_node.defaults:
index = index_in_list(parent_node.args, node)
args_len = len(parent_node.args)
if index != -1 and args_len - index <= len(
parent_node.defaults):
defaults_node = parent_node.defaults[index - args_len]
if isinstance(defaults_node, gast.Constant):
var_type = self._get_constant_node_type(
defaults_node)

# Add node with identified type into cur_env.
self.var_env.set_var_type(node.id, var_type)
return var_type

return self._get_func_argument_type(parent_node_wrapper, node)

return self.var_env.get_var_type(node.id)

@@ -373,3 +392,42 @@ def _get_node_var_type(self, cur_wrapper):
return {NodeVarType.TENSOR}

return {NodeVarType.STATEMENT}

def _get_func_argument_type(self, parent_node_wrapper, node):
"""
Returns type information by parsing annotation or default values.

For example:
1. parse by default values.
foo(x, y=1, z='s') -> x: UNKNOWN, y: INT, z: STR

2. parse by Py3 type annotation.
foo(x: Tensor, y: int, z: str) -> x: Tensor, y: INT, z: STR

3. parse by type annotation and default values.
foo(x: Tensor, y: int, z: str = 'abc') -> x: Tensor, y: INT, z: STR

NOTE: Currently we only support Tensor, int, bool, float and str annotations.
Other, more complex types will be supported later.
"""
assert isinstance(node, gast.Name)

parent_node = parent_node_wrapper.node
var_type = {NodeVarType.UNKNOWN}
if node.annotation is not None:
var_type = {NodeVarType.type_from_annotation(node.annotation)}
self.var_env.set_var_type(node.id, var_type)

# If the annotation and the value (a Constant) differ in type, use the value's type
if parent_node.defaults:
index = index_in_list(parent_node.args, node)
args_len = len(parent_node.args)
if index != -1 and args_len - index <= len(parent_node.defaults):
defaults_node = parent_node.defaults[index - args_len]
if isinstance(defaults_node, gast.Constant):
var_type = self._get_constant_node_type(defaults_node)

# Add node with identified type into cur_env.
self.var_env.set_var_type(node.id, var_type)

return var_type
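One non-obvious detail in the default-value branch is the defaults[index - args_len] indexing: defaults are stored aligned to the tail of the positional argument list, so a negative offset recovers the matching default. A quick standalone illustration with the stdlib ast module (gast represents arguments slightly differently, so this shows the idea rather than Paddle's exact code):

import ast

func = ast.parse("def foo(x, y=1, z='s'): pass").body[0]
args, defaults = func.args.args, func.args.defaults
args_len = len(args)                          # 3 arguments, 2 defaults
for index, arg in enumerate(args):
    if args_len - index <= len(defaults):     # True for y and z, False for x
        default_node = defaults[index - args_len]
        print(arg.arg, ast.literal_eval(default_node))
# y 1
# z s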
3 changes: 2 additions & 1 deletion python/paddle/fluid/dygraph/dygraph_to_static/utils.py
@@ -520,7 +520,8 @@ def remove_if_exit(filepath):

def _inject_import_statements():
import_statements = [
"import paddle", "import paddle.fluid as fluid", "from typing import *",
"import paddle", "from paddle import Tensor",
"import paddle.fluid as fluid", "from typing import *",
"import numpy as np"
]
return '\n'.join(import_statements) + '\n'
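Adding from paddle import Tensor here is what lets a bare Tensor annotation resolve inside the transcribed static code. It also makes the injected header one line longer, which is why every static_abs_lineno_list expectation in the tests below shifts by one. A small illustrative check (mirroring the list above, not Paddle's actual runtime):

# The injected header now contains five statements instead of four.
import_statements = [
    "import paddle", "from paddle import Tensor",
    "import paddle.fluid as fluid", "from typing import *",
    "import numpy as np",
]
header = '\n'.join(import_statements) + '\n'
print(len(header.splitlines()))  # 5 -> transcribed user code starts one line later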
@@ -65,7 +65,7 @@ def set_test_func(self):
self.func = simple_func

def set_static_lineno(self):
self.static_abs_lineno_list = [5, 6, 7]
self.static_abs_lineno_list = [6, 7, 8]

def set_dygraph_info(self):
self.line_num = 3
@@ -149,7 +149,7 @@ def set_test_func(self):
self.func = nested_func

def set_static_lineno(self):
self.static_abs_lineno_list = [5, 7, 8, 9, 10]
self.static_abs_lineno_list = [6, 8, 9, 10, 11]

def set_dygraph_info(self):
self.line_num = 5
@@ -174,7 +174,7 @@ def set_test_func(self):
self.func = decorated_func

def set_static_lineno(self):
self.static_abs_lineno_list = [5, 6]
self.static_abs_lineno_list = [6, 7]

def set_dygraph_info(self):
self.line_num = 2
@@ -208,7 +208,7 @@ def set_test_func(self):
self.func = decorated_func2

def set_static_lineno(self):
self.static_abs_lineno_list = [5, 6]
self.static_abs_lineno_list = [6, 7]

def set_dygraph_info(self):
self.line_num = 2
@@ -57,6 +57,8 @@ def func_to_test3():
h = None
i = False
j = None + 1
k: float = 1.0
l: paddle.Tensor = paddle.to_tensor([1, 2])


result_var_type3 = {
@@ -69,7 +71,9 @@ def func_to_test3():
'g': {NodeVarType.STRING},
'h': {NodeVarType.NONE},
'i': {NodeVarType.BOOLEAN},
'j': {NodeVarType.UNKNOWN}
'j': {NodeVarType.UNKNOWN},
'k': {NodeVarType.FLOAT},
'l': {NodeVarType.PADDLE_RETURN_TYPES}
}
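Worth noting for the two new entries: k gets its type straight from the float annotation, while l's value is a Paddle API call rather than a constant, and the value's inferred type takes precedence over the annotation, hence PADDLE_RETURN_TYPES instead of TENSOR. A tiny illustrative check of that distinction with the stdlib ast module (not part of the test):

import ast

k_stmt = ast.parse("k: float = 1.0").body[0]
l_stmt = ast.parse("l: paddle.Tensor = paddle.to_tensor([1, 2])").body[0]
# k's value is a constant, l's value is a call expression, so only k can be
# typed from the literal; l is typed from the API call's return type.
print(type(k_stmt.value).__name__, type(l_stmt.value).__name__)
# Constant Call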


@@ -139,13 +143,25 @@ def add(x, y):
'add': {NodeVarType.INT}
}


def func_to_test7(a: int, b: float, c: paddle.Tensor, d: float='diff'):
a = True


result_var_type7 = {
'a': {NodeVarType.BOOLEAN},
'b': {NodeVarType.FLOAT},
'c': {NodeVarType.TENSOR},
'd': {NodeVarType.STRING}
}

test_funcs = [
func_to_test1, func_to_test2, func_to_test3, func_to_test4, func_to_test5,
func_to_test6
func_to_test6, func_to_test7
]
result_var_type = [
result_var_type1, result_var_type2, result_var_type3, result_var_type4,
result_var_type5, result_var_type6
result_var_type5, result_var_type6, result_var_type7
]

