Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2static] FunctionScopeVisitor Enhance and substitute the original NameVisitor in If #43967

Merged

Conversation

2742195759
Copy link
Contributor

@2742195759 2742195759 commented Jun 30, 2022

PR types

Others

PR changes

Others

Describe

FunctionScopeVisitor Enhance and substitute the original NameVisitor in If.
这个PR主要目标是增强了FunctionScopeVisitor的功能,支持控制流的 name 解析。并且替换了IF的NameVisitor和名字解析逻辑。其余的修改是针对上述修改的单测修复。
新的FunctionScopeVisitor主要有如下几个优点:

  1. 支持闭包分析:目前的FunctionScopeVIsitor支持了闭包的模拟机制。因此不会出现闭包中的变量被误解析的问题。之前的if就有这种问题。
  2. 统一了名字解析逻辑并简化。为了保证正确性,我们将FunctionScope中解析的所有需要修改的name作为了控制流的输入。这样可以保证正确性、简单性和鲁棒性
  3. 新增的了UndefinedVar来解决true、false分支定义临时变量的问题
  4. Disable 了变长的 return ,消除了之前变长return引入的bug。也因此消除了 run_python_if 中的不一致性。不会引用两个分支中的名字,不会出现name error。

TODO:

  1. 添加新解决的单测。
  2. 删除未删除干净的代码。具体文件有: return_transformer、ifelse_transformer 等
  3. 逻辑思考:是否 while 这种具有额外的参数的 body 函数会影响 FunctionScope 解析。

cond_outs = control_flow.cond(pred, new_true_fn, new_false_fn,
return_name_ids)
except Exception as e:
if re.search("Unsupported return type of true_fn and false_fn in cond",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里最好加个单测确保这行代码按照预期触发,因为报错信息有可能被别人迭代优化后,这个分支就失效了

while True:
pred = cond()
if isinstance(pred, Variable):
raise Dygraph2StaticException(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

raise这个error时会触发用户源码行的标记么?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个会触发源码行的标记:
捕获

if re.search("Unsupported return type of true_fn and false_fn in cond",
str(e)):
raise Dygraph2StaticException(
"Your if/else have different return type. TODO: add link to modifty."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"Your if/else have different return type. TODO: add link to modifty."
# TODO: add link to modifty.
"Your if/else has different return type. %s" % e

这里如果还没有加TODO的话,最好先把之前的err msg re-throw下

"Incompatible return values of true_fn and false_fn in cond",
str(e)):
raise Dygraph2StaticException(
"Your if/else have different number of return value. TODO: add link to modifty."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

@@ -52,6 +52,8 @@ def __init__(self, wrapper_root):
), "Type of input node should be AstNodeWrapper, but received %s ." % type(
wrapper_root)
self.root = wrapper_root.node
FunctionNameLivenessAnalysis(
self.root) # name analysis of current ast tree.
self.static_analysis_visitor = StaticAnalysisVisitor(self.root)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.static_analysis_visitor 没有用到了,这个我们是不是可以删除掉了?

def create_undefined_var_like(variable):
""" create a undefined var with the same shape and dtype like varaible.
"""
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

后续可以把RETURN_NO_VALUE_MAGIC_NUM 也放到这个文件,这里import虽然时动态import,单最好让utils成为一个叶子结点的文件,可以被其他文件import。可以后续PR优化

from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
var = data_layer_not_check(unique_name.generate("undefined_var"),
variable.shape, variable.dtype)
assign(RETURN_NO_VALUE_MAGIC_NUM, var)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里为什么还需要一个assign?不可以通过name来过滤么

template = """
def {func_name}():
nonlocal {nonlocal_vars}
{nonlocal_vars}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果 nonlocal_vars 是空,还需要return names么?

@@ -1159,7 +1181,10 @@ def assign_skip_lod_tensor_array(input, output):
Assign input to output, but skip the process of copying LoDTensorArray unless it's created in while_block.
"""
if not isinstance(input, Variable) and not isinstance(input, core.VarBase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if not isinstance(input, Variable) and not isinstance(input, core.VarBase):
if not isinstance(input, (Variablecore.VarBase)):

@@ -2377,7 +2402,7 @@ def copy_var_to_parent_block(var, layer_helper):
return parent_block_var


def cond(pred, true_fn=None, false_fn=None, name=None):
def cond(pred, true_fn=None, false_fn=None, return_name_ids=None, name=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def cond(pred, true_fn=None, false_fn=None, return_name_ids=None, name=None):
def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):

为了兼容性,新增参数一般放到最后

@2742195759 2742195759 changed the title [Dy2static] FunctionScopeVisitor Enhance and substitute the original NameVisitor in If. [Dy2static] FunctionScopeVisitor Enhance and substitute the original NameVisitor in If Jul 5, 2022
@2742195759 2742195759 marked this pull request as draft July 5, 2022 17:31
@2742195759 2742195759 marked this pull request as ready for review July 5, 2022 17:32
Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@Aurelius84 Aurelius84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@@ -2423,6 +2454,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
true. The default value is ``None`` .
false_fn(callable, optional): A callable to be performed if ``pred`` is
false. The default value is ``None`` .
return_names: A list of strings to represents the name of returned vars. useful to debug.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return_names: -> return_names(支持的数据类型, optional): XXX. Default is None, means XXX.
useful to debug,太 Chinese English 了,要改一下。
顺序在 name 后面。

@2742195759 2742195759 merged commit b603dd5 into PaddlePaddle:develop Jul 6, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants