From e5f30ab2da39fce6cf8419205beaebe9b749b7a4 Mon Sep 17 00:00:00 2001 From: cyber-pioneer Date: Mon, 6 May 2024 08:42:34 +0000 Subject: [PATCH 1/2] support autogen to remove unused composite in .yaml --- paddle/fluid/primitive/codegen/gen.py | 3 --- .../fluid/primitive/codegen/templates/common.j2 | 8 ++++---- .../rule/vjp/generated/generated_vjp.cc.j2 | 16 ++++++++++++++-- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py index 09c7b0c8729f4..636b18a75aeab 100644 --- a/paddle/fluid/primitive/codegen/gen.py +++ b/paddle/fluid/primitive/codegen/gen.py @@ -119,9 +119,6 @@ 'relu_grad', 'sigmoid_grad', 'silu_grad', - 'exp_grad', - 'log_grad', - 'abs_double_grad', 'softmax_grad', 'sqrt_grad', ] # custom vjp list of composite op diff --git a/paddle/fluid/primitive/codegen/templates/common.j2 b/paddle/fluid/primitive/codegen/templates/common.j2 index 5f7148017ab23..b29401133db03 100644 --- a/paddle/fluid/primitive/codegen/templates/common.j2 +++ b/paddle/fluid/primitive/codegen/templates/common.j2 @@ -33,10 +33,10 @@ template {%- endmacro -%} -{%- macro args(inputs, attrs) -%} {#- Arguments are variable pass into method -#} - {{sequence('', '', ', ', inputs)}} - {%- if inputs|length>0 and attrs|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between inputs and attrs -#} - {{sequence('', '', ', ', attrs)}} +{%- macro args(arg1, arg2) -%} {#- Arguments are variable pass into method -#} + {{sequence('', '', ', ', arg1)}} + {%- if arg1|length>0 and arg2|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between arg1 and arg2 -#} + {{sequence('', '', ', ', arg2)}} {%- endmacro -%} diff --git a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 index a2ac7b1ed64cd..d8a73428a9f99 100644 --- a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 +++ b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 @@ -27,7 +27,7 @@ std::vector> vjp_res; for (auto arg: stop_gradients) { vjp_res.push_back(std::vector(arg.size())); } - {% if 'composite' in api and api.name in vjp_comp_white_list %} + {% if api.name in vjp_comp_white_list %} std::string op_name = "{{api.name}}"; auto need_skip = paddle::prim::StaticCompositeContext::Instance().CheckSkipCompOps(op_name); if (paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled() && !need_skip) { @@ -115,7 +115,19 @@ for (size_t i=0; i< stop_gradients[{{i}}].size(); i++ ) { {% endif %} {% endfor %} {{get_mutable_attribute(api.attrs, api.name)}} -details::{{api.composite.func_name}}({{api.composite.func_args}}); + +{%- set args_names=[] -%} +{%- for i in api.inputs -%} {%- do args_names.append(i.name) -%} {%- endfor -%} +{%- for i in api.attrs -%} + {%- if i is mutable_attribute -%} + {%- do args_names.append(i.name~'_') -%} + {%- else -%} + {%- do args_names.append(i.name) -%} + {%- endif -%} +{%- endfor %} +{%- set outputs_names=[] -%} +{%- for i in api.outputs -%} {%- do outputs_names.append(i.name) -%} {%- endfor -%} +details::{{api.name}}({{common.args(args_names, outputs_names)}}); {% endmacro %} {%- set api_map = {} -%} From 14504c44a7fc2e3793d8bfc3490b179649852e83 Mon Sep 17 00:00:00 2001 From: cyber-pioneer Date: Mon, 6 May 2024 09:36:38 +0000 Subject: [PATCH 2/2] fix bug --- .../templates/rule/vjp/generated/generated_vjp.cc.j2 | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 index d8a73428a9f99..0f6f5f83d33aa 100644 --- a/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 +++ b/paddle/fluid/primitive/codegen/templates/rule/vjp/generated/generated_vjp.cc.j2 @@ -118,13 +118,7 @@ for (size_t i=0; i< stop_gradients[{{i}}].size(); i++ ) { {%- set args_names=[] -%} {%- for i in api.inputs -%} {%- do args_names.append(i.name) -%} {%- endfor -%} -{%- for i in api.attrs -%} - {%- if i is mutable_attribute -%} - {%- do args_names.append(i.name~'_') -%} - {%- else -%} - {%- do args_names.append(i.name) -%} - {%- endif -%} -{%- endfor %} +{%- for i in api.attrs -%} {%- do args_names.append(i.name) -%} {%- endfor %} {%- set outputs_names=[] -%} {%- for i in api.outputs -%} {%- do outputs_names.append(i.name) -%} {%- endfor -%} details::{{api.name}}({{common.args(args_names, outputs_names)}});