Skip to content

Commit

Permalink
[Prim][VJP] Support autogen to remove unused composite in .yaml (#64054)
Browse files Browse the repository at this point in the history
* support autogen to remove unused composite in .yaml

* fix bug
  • Loading branch information
cyber-pioneer authored May 7, 2024
1 parent 9401dfe commit 643a9db
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
3 changes: 0 additions & 3 deletions paddle/fluid/primitive/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,6 @@
'relu_grad',
'sigmoid_grad',
'silu_grad',
'exp_grad',
'log_grad',
'abs_double_grad',
'softmax_grad',
'sqrt_grad',
] # custom vjp list of composite op
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/primitive/codegen/templates/common.j2
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ template <typename T>
{%- endmacro -%}


{%- macro args(inputs, attrs) -%} {#- Arguments are variables passed into the method -#}
{{sequence('', '', ', ', inputs)}}
{%- if inputs|length>0 and attrs|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between inputs and attrs -#}
{{sequence('', '', ', ', attrs)}}
{%- macro args(arg1, arg2) -%} {#- Arguments are variables passed into the method -#}
{{sequence('', '', ', ', arg1)}}
{%- if arg1|length>0 and arg2|length > 0 -%} {{", "}} {%- endif -%} {#- append comma between arg1 and arg2 -#}
{{sequence('', '', ', ', arg2)}}
{%- endmacro -%}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ std::vector<std::vector<paddle::Tensor>> vjp_res;
for (auto arg: stop_gradients) {
vjp_res.push_back(std::vector<paddle::Tensor>(arg.size()));
}
{% if 'composite' in api and api.name in vjp_comp_white_list %}
{% if api.name in vjp_comp_white_list %}
std::string op_name = "{{api.name}}";
auto need_skip = paddle::prim::StaticCompositeContext::Instance().CheckSkipCompOps(op_name);
if (paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled() && !need_skip) {
Expand Down Expand Up @@ -115,7 +115,13 @@ for (size_t i=0; i< stop_gradients[{{i}}].size(); i++ ) {
{% endif %}
{% endfor %}
{{get_mutable_attribute(api.attrs, api.name)}}
details::{{api.composite.func_name}}<LazyTensor>({{api.composite.func_args}});

{%- set args_names=[] -%}
{%- for i in api.inputs -%} {%- do args_names.append(i.name) -%} {%- endfor -%}
{%- for i in api.attrs -%} {%- do args_names.append(i.name) -%} {%- endfor %}
{%- set outputs_names=[] -%}
{%- for i in api.outputs -%} {%- do outputs_names.append(i.name) -%} {%- endfor -%}
details::{{api.name}}<LazyTensor>({{common.args(args_names, outputs_names)}});
{% endmacro %}

{%- set api_map = {} -%}
Expand Down

0 comments on commit 643a9db

Please sign in to comment.