Skip to content

Commit

Permalink
support autogen to remove unused composite in .yaml
Browse files Browse the repository at this point in the history
  • Loading branch information
cyber-pioneer committed May 6, 2024
1 parent 41b7c57 commit 73ce407
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
8 changes: 4 additions & 4 deletions paddle/fluid/primitive/codegen/templates/common.j2
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ template <typename T>
{%- endmacro -%}


{%- macro args(inputs, attrs) -%} {#- Arguments are variable pass into method -#}
{{sequence('', '', ', ', inputs)}}
{%- if inputs|length>0 and attrs|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between inputs and attrs -#}
{{sequence('', '', ', ', attrs)}}
{%- macro args(arg1, arg2) -%} {#- Arguments are variable pass into method -#}
{{sequence('', '', ', ', arg1)}}
{%- if arg1|length>0 and arg2|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between arg1 and arg2 -#}
{{sequence('', '', ', ', arg2)}}
{%- endmacro -%}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ std::vector<std::vector<paddle::Tensor>> vjp_res;
for (auto arg: stop_gradients) {
vjp_res.push_back(std::vector<paddle::Tensor>(arg.size()));
}
{% if 'composite' in api and api.name in vjp_comp_white_list %}
{% if api.name in vjp_comp_white_list %}
std::string op_name = "{{api.name}}";
auto need_skip = paddle::prim::StaticCompositeContext::Instance().CheckSkipCompOps(op_name);
if (paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled() && !need_skip) {
Expand Down Expand Up @@ -115,7 +115,19 @@ for (size_t i=0; i< stop_gradients[{{i}}].size(); i++ ) {
{% endif %}
{% endfor %}
{{get_mutable_attribute(api.attrs, api.name)}}
details::{{api.composite.func_name}}<LazyTensor>({{api.composite.func_args}});

{%- set args_names=[] -%}
{%- for i in api.inputs -%} {%- do args_names.append(i.name) -%} {%- endfor -%}
{%- for i in api.attrs -%}
{%- if i is mutable_attribute -%}
{%- do args_names.append(i.name~'_') -%}
{%- else -%}
{%- do args_names.append(i.name) -%}
{%- endif -%}
{%- endfor %}
{%- set outputs_names=[] -%}
{%- for i in api.outputs -%} {%- do outputs_names.append(i.name) -%} {%- endfor -%}
details::{{api.name}}<LazyTensor>({{common.args(args_names, outputs_names)}});
{% endmacro %}

{%- set api_map = {} -%}
Expand Down

0 comments on commit 73ce407

Please sign in to comment.