Add an elementwise + activation fusion pass. #36541

Merged
merged 94 commits into from
Mar 14, 2022

Changes from 16 commits

Commits
75ef32f
Add elementwise add and activation fuse pass
tsocha Oct 5, 2021
71c9502
Fix copy ellision
tsocha Oct 6, 2021
035d8da
More flexible pattern detector
tsocha Oct 7, 2021
007462d
More flexible fusion pass
tsocha Oct 7, 2021
115dbc4
Update lists for pass
tsocha Oct 8, 2021
125afa0
Add support for Pow operator
tsocha Oct 12, 2021
18e91ed
Add support for more activation types
tsocha Oct 12, 2021
ac327fe
Style
tsocha Oct 12, 2021
c94037a
Rename fusion pass
tsocha Oct 13, 2021
301a583
First version of tests
tsocha Oct 15, 2021
f6f3bdb
Dirty version of pass
tsocha Oct 18, 2021
f44e8ff
Polished version
tsocha Oct 19, 2021
05ce32a
Update pbtxt
tsocha Oct 19, 2021
11bd667
Style
tsocha Oct 19, 2021
f5ac4b4
Update names
tsocha Oct 19, 2021
b555dc6
Style
tsocha Oct 19, 2021
4d64bc7
Use PADDLE_ENFORCE_EQ
tsocha Oct 19, 2021
e80d874
Save error message to variable
tsocha Oct 19, 2021
e8beef3
WO for error checks
tsocha Oct 19, 2021
42684ac
CR
tsocha Oct 20, 2021
cfa9a5a
Static style check
tsocha Oct 21, 2021
9fb6d03
Add missing 'activation_scale' attribute
tsocha Oct 21, 2021
b691b7e
Add relu6 and sigmoid activations
tsocha Oct 22, 2021
6bd1342
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Oct 22, 2021
13d5f97
Style
tsocha Oct 22, 2021
11471bd
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Oct 25, 2021
73828f5
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Oct 27, 2021
0c5ff36
Fix fuse list formating
tsocha Nov 2, 2021
83eab78
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Nov 2, 2021
b17a704
Sync filenames for fuse pass files
tsocha Nov 2, 2021
8281179
Fix cmake after move
tsocha Nov 2, 2021
899bcd6
Fix registration
tsocha Nov 2, 2021
6bc3b18
Fix pass name in tests
tsocha Nov 2, 2021
ec52e50
Add missing activations to checker
tsocha Nov 4, 2021
8c46991
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Nov 4, 2021
72c295f
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Nov 5, 2021
006ceeb
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Nov 8, 2021
05f21ee
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Nov 15, 2021
05e1432
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Nov 18, 2021
9d97cf7
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Nov 22, 2021
d8431a4
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Nov 24, 2021
5e4d1a4
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Dec 14, 2021
4ce460d
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Dec 15, 2021
36d6d45
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Jan 10, 2022
86adeb5
WIPS
tsocha Jan 13, 2022
a6f3917
Working mul op
tsocha Jan 26, 2022
bb4683e
Working sub
tsocha Jan 27, 2022
57547bc
Working Add
tsocha Jan 27, 2022
11c2faf
Merge commit '3115d005aa6ec9fe2ae37332be6652124b9e5543' into el_add_a…
tsocha Jan 27, 2022
951425c
Remove pten includes
tsocha Jan 27, 2022
f7bcb93
Merge commit 'c3796061c385491738ef7e27e0e89fccd75877f3' into el_add_a…
tsocha Jan 27, 2022
6f6f860
Remove some forward declarations
tsocha Jan 27, 2022
47a602b
Merge commit '7e6a2190ddff25362d395667e397f5174f2346a2' into el_add_a…
tsocha Jan 28, 2022
c943734
Remove Includes
tsocha Jan 28, 2022
4285933
Fixes
tsocha Jan 28, 2022
eedad0d
Remove default kernels
tsocha Jan 28, 2022
0dca583
Add check if post_ops attributes are avaliable
tsocha Jan 28, 2022
df214b1
Style
tsocha Jan 31, 2022
f60566a
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Jan 31, 2022
bb4eb4f
Code adjustment
tsocha Jan 31, 2022
4e1ff4b
Register default kernels
tsocha Jan 31, 2022
1eb45cb
We have year 2022 not 2021...
tsocha Feb 1, 2022
1d794f3
Fast review fixes
tsocha Feb 1, 2022
9a3fa88
Review Fix
tsocha Feb 2, 2022
9bb5419
Rename one_dnn -> onednn
tsocha Feb 2, 2022
862a5dd
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Feb 2, 2022
0bedc39
Style after review
tsocha Feb 2, 2022
40be137
Fast and dirty fix for quantization
tsocha Feb 2, 2022
1426eb3
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Feb 2, 2022
7670b5c
Update tests
tsocha Feb 3, 2022
29c2f66
Style
tsocha Feb 3, 2022
7dae446
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Feb 9, 2022
dd9ae1d
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Feb 9, 2022
5a634f0
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Feb 9, 2022
0318442
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Feb 10, 2022
8b98063
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Feb 11, 2022
84e09f6
Fix mkldnn_quantizer config
tsocha Feb 14, 2022
f188574
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Feb 14, 2022
17f794c
Add Joanna's suggestion.
tsocha Feb 15, 2022
ba3f3f6
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Feb 15, 2022
1e1346c
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Feb 16, 2022
2c470f5
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Mar 1, 2022
8a834af
Check if operator is explicitly disables on OneDNN
tsocha Mar 1, 2022
49d9a62
Try to use unregistered attributes
tsocha Mar 1, 2022
5bc87e0
Style
tsocha Mar 2, 2022
2ccab7a
Merge remote-tracking branch 'upstream/develop' into el_add_ect_fuse_…
tsocha Mar 2, 2022
54e297d
Test new framework
tsocha Mar 2, 2022
0ffa4d5
FXI
tsocha Mar 3, 2022
6f08f6c
Merge remote-tracking branch 'upstream/develop' into el_add_ect_fuse_…
tsocha Mar 3, 2022
eab52ad
FXII
tsocha Mar 3, 2022
64a8277
Update test
tsocha Mar 3, 2022
c2def84
Style
tsocha Mar 4, 2022
a813e62
Merge remote-tracking branch 'upstream/develop' into el_add_ect_fuse_…
tsocha Mar 4, 2022
eb6b5c1
Merge remote-tracking branch 'upstream/develop' into el_add_act_fuse
tsocha Mar 10, 2022
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -118,6 +118,7 @@ if(WITH_MKLDNN)
pass_library(fc_mkldnn_pass inference DIR mkldnn)
pass_library(interpolate_mkldnn_pass inference DIR mkldnn)
pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(elt_act_onednn_fuse_pass inference DIR mkldnn)
pass_library(cpu_quantize_placement_pass base DIR mkldnn)
pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
30 changes: 30 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -917,6 +917,36 @@ PDNode *patterns::ConvActivation::operator()(
return activation_out_var;
}

PDNode *patterns::ElementwiseActivation::operator()(
paddle::framework::ir::PDNode *elementwise_a, std::string elementwise_type,
std::string activation_type) {
// Create Operators
elementwise_a->assert_is_op_input(elementwise_type, "X");
auto *elementwise_op =
pattern->NewNode(elementwise_repr())->assert_is_op(elementwise_type);
auto *activation_op =
pattern->NewNode(activation_repr())->assert_is_op(activation_type);
// Create variables
auto *elementwise_b = pattern->NewNode(elementwise_b_repr())
->AsInput()
->assert_is_op_input(elementwise_type, "Y");
// intermediate variable, will be removed in the IR after fuse.
auto *elementwise_out_var =
pattern->NewNode(elementwise_out_repr())
->AsIntermediate()
->assert_is_only_output_of_op(elementwise_type)
->assert_is_op_input(activation_type);
// output
auto *activation_out_var = pattern->NewNode(activation_out_repr())
->AsOutput()
->assert_is_op_output(activation_type);

elementwise_op->LinksFrom({elementwise_a, elementwise_b})
.LinksTo({elementwise_out_var});
activation_op->LinksFrom({elementwise_out_var}).LinksTo({activation_out_var});
return activation_out_var;
}

PDNode *patterns::SeqConvEltAddRelu::operator()(
paddle::framework::ir::PDNode *seqconv_input) {
// Create Operators
22 changes: 22 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -487,6 +487,28 @@ struct ConvActivation : public PatternBase {
PATTERN_DECL_NODE(activation_out);
};

// Elementwise with Activation
// op: elementwise + activation
// named nodes:
// elementwise_a, elementwise_b,
// elementwise_out, elementwise,
// activation_out, activation
struct ElementwiseActivation : public PatternBase {
ElementwiseActivation(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "elementwise_add_activation") {}

PDNode* operator()(PDNode* elementwise_a, std::string elementwise_type,
std::string activation_type);

// declare operator node's name
PATTERN_DECL_NODE(elementwise);
PATTERN_DECL_NODE(activation);
// declare variable node's name
PATTERN_DECL_NODE(elementwise_b);
PATTERN_DECL_NODE(elementwise_out);
PATTERN_DECL_NODE(activation_out);
};

// SEQCONV with Elementwise_Add ReLU
// op: seqconv + elementwise_add + relu
// named nodes:
139 changes: 139 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/elt_act_onednn_fuse_pass.cc
@@ -0,0 +1,139 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/elt_act_onednn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

using string::PrettyLogDetail;

void ElementwiseActivationOneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::string> act_types = {"relu", "tanh", "leaky_relu",
"swish", "hardswish", "sqrt",
"abs", "clip", "gelu"};
std::vector<std::string> elt_types = {"elementwise_add", "elementwise_sub",
"elementwise_mul"};

for (const auto &elt_type : elt_types)
for (const auto &act_type : act_types) {
if (act_type == "swish")
FuseElementwiseAct(graph, elt_type, act_type,
{{"beta", "activation_alpha"}});
else if (act_type == "clip")
FuseElementwiseAct(
graph, elt_type, act_type,
{{"min", "activation_alpha"}, {"max", "activation_beta"}});
else
FuseElementwiseAct(
graph, elt_type, act_type,
{{"alpha", "activation_alpha"}, {"beta", "activation_beta"}});
}
}

void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
Graph *graph, const std::string &elt_type, const std::string &act_type,
const std::unordered_map<std::string, std::string> &attr_map) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("elementwise_act", graph);

GraphPatternDetector gpd;
auto *elementwise_input = gpd.mutable_pattern()
->NewNode(elt_type + "_act/elementwise_input")
->AsInput()
->assert_is_op_input(elt_type, "X");
patterns::ElementwiseActivation elementwise_act_pattern(gpd.mutable_pattern(),
elt_type + "_act");
elementwise_act_pattern(elementwise_input, elt_type, act_type);

int found_elementwise_activation_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "Fuse " << elt_type << " with activation op.";
// Elementwise Add output
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_act_pattern);
// ACT output
GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out,
elementwise_act_pattern);
// ops
GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise,
elementwise_act_pattern);
GET_IR_NODE_FROM_SUBGRAPH(activation, activation, elementwise_act_pattern);

auto *elementwise_op = elementwise->Op();

if (elementwise_op->HasAttr("use_mkldnn")) {
PADDLE_ENFORCE(
BOOST_GET_CONST(bool, elementwise_op->GetAttr("use_mkldnn")),
platform::errors::PreconditionNotMet(
"The " + elt_type +
"+Act fusion may happen only when oneDNN library "
"is used."));
}

auto *activation_op = activation->Op();
for (const auto &attr : attr_map) {
if (activation_op->HasAttr(attr.first)) {
elementwise_op->SetAttr(attr.second,
activation_op->GetAttr(attr.first));
}
}

elementwise_op->SetAttr("activation_type", act_type);

elementwise_op->SetAttr("use_mkldnn", true);

elementwise_op->SetOutput("Out", {activation_out->Name()});

IR_OP_VAR_LINK(elementwise, activation_out);
GraphSafeRemoveNodes(g, {activation, elementwise_out});
found_elementwise_activation_count++;
};

gpd(graph, handler);
AddStatis(found_elementwise_activation_count);
PrettyLogDetail("--- fused %d %s with %s activation",
found_elementwise_activation_count, elt_type, act_type);
}

} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(elt_act_onednn_fuse_pass,
paddle::framework::ir::ElementwiseActivationOneDNNPass);
REGISTER_PASS_CAPABILITY(elt_act_onednn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("elementwise_add", 1)
.LE("elementwise_sub", 1)
.LE("elementwise_mul", 1)
.LE("elementwise_div", 1)
.LE("elementwise_pow", 1)
.LE("relu", 0)
.LE("tanh", 0)
.LE("leaky_relu", 1)
.LE("swish", 0)
.LE("hard_swish", 0)
.LE("sqrt", 0)
.LE("abs", 0)
.LE("clip", 1)
.LE("gelu", 0));
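For orientation, the registered pass can be exercised directly on an in-memory graph. Below is a minimal, test-style sketch assuming the usual PassRegistry/Graph helpers from Paddle's IR test utilities; it is illustrative only and not part of this diff.

// Illustrative sketch only, not part of this PR: apply the registered pass
// to a Graph built from a ProgramDesc, test-style.
#include <memory>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"

void RunEltActFusePass(paddle::framework::ProgramDesc* prog) {
  auto graph = std::make_unique<paddle::framework::ir::Graph>(*prog);
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "elt_act_onednn_fuse_pass");
  // After Apply(), each matched elementwise+activation pair is collapsed into
  // a single elementwise op carrying the activation_type / activation_alpha /
  // activation_beta attributes.
  graph.reset(pass->Apply(graph.release()));
}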
46 changes: 46 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/elt_act_onednn_fuse_pass.h
@@ -0,0 +1,46 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>
#include <unordered_map>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
namespace ir {

/*
* \brief Fuse the Elementwise and activation operators into a single
* oneDNN Elementwise operator with an activation post-op.
*/
class ElementwiseActivationOneDNNPass : public FusePassBase {
public:
virtual ~ElementwiseActivationOneDNNPass() {}

protected:
void ApplyImpl(ir::Graph *graph) const override;

void FuseElementwiseAct(
ir::Graph *graph, const std::string &elt_types,
const std::string &act_types,
const std::unordered_map<std::string, std::string> &attr_map) const;
};

} // namespace ir
} // namespace framework
} // namespace paddle
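The "post-op" mentioned in the class comment refers to oneDNN's eltwise post-operation mechanism: the activation is appended to the elementwise (binary) primitive instead of running as a separate op. A rough sketch of that mechanism using the plain oneDNN 2.x API follows; it is illustrative, relu is just an example algorithm, and the kernel-side wiring is not part of this diff.

// Illustrative sketch of the oneDNN post-op mechanism the fused op relies on
// (oneDNN 2.x API assumed; not part of this PR).
#include "dnnl.hpp"

dnnl::primitive_attr MakeFusedEltwiseAttr(float activation_alpha,
                                          float activation_beta) {
  dnnl::post_ops post_operations;
  // Append an eltwise (here: relu) post-op with scale 1.0 and the alpha/beta
  // values taken from the fused attributes.
  post_operations.append_eltwise(1.0f, dnnl::algorithm::eltwise_relu,
                                 activation_alpha, activation_beta);
  dnnl::primitive_attr attr;
  attr.set_post_ops(post_operations);
  return attr;
}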
4 changes: 2 additions & 2 deletions paddle/fluid/framework/new_executor/workqueue.cc
@@ -166,7 +166,7 @@ std::unique_ptr<WorkQueue> CreateMultiThreadedWorkQueue(
"WorkQueueOptions.num_threads must be "
"greater than 1."));
std::unique_ptr<WorkQueue> ptr(new WorkQueueImpl(options));
return std::move(ptr);
return ptr;
}

std::unique_ptr<WorkQueueGroup> CreateWorkQueueGroup(
@@ -176,7 +176,7 @@ std::unique_ptr<WorkQueueGroup> CreateWorkQueueGroup(
"For a WorkQueueGroup, the number of WorkQueueOptions "
"must be greater than 1."));
std::unique_ptr<WorkQueueGroup> ptr(new WorkQueueGroupImpl(queues_options));
return std::move(ptr);
return ptr;
}

} // namespace framework
2 changes: 1 addition & 1 deletion paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -252,7 +252,7 @@ void CpuPassStrategy::EnableMKLDNN() {
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass",
// "fc_act_mkldnn_fuse_pass",
"batch_norm_act_fuse_pass",
"batch_norm_act_fuse_pass", "elt_act_onednn_fuse_pass",
// TODO(intel): Please fix the bug on windows.
// https://github.com/PaddlePaddle/Paddle/issues/29710
// "mkldnn_inplace_pass", // This pass should be activated after
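Because the pass is appended to CpuPassStrategy::EnableMKLDNN(), it is picked up whenever oneDNN is enabled on an inference config. A minimal usage sketch with a placeholder model path (illustrative, not part of the diff):

// Illustrative sketch: enabling oneDNN so that elt_act_onednn_fuse_pass runs
// during inference graph optimization ("./model_dir" is a placeholder).
#include <memory>
#include "paddle_inference_api.h"

std::shared_ptr<paddle_infer::Predictor> BuildPredictor() {
  paddle_infer::Config config("./model_dir");
  config.EnableMKLDNN();  // activates the CpuPassStrategy MKLDNN pass list
  return paddle_infer::CreatePredictor(config);
}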
12 changes: 12 additions & 0 deletions paddle/fluid/operators/compat/elementwise_add.pbtxt
@@ -31,6 +31,18 @@ extra {
name: "y_data_format"
type: STRING
}
attrs {
name: "activation_type"
type: STRING
}
attrs {
name: "activation_alpha"
type: FLOAT
}
attrs {
name: "activation_beta"
type: FLOAT
}
attrs {
name: "Scale_x"
type: FLOAT
12 changes: 12 additions & 0 deletions paddle/fluid/operators/compat/elementwise_div.pbtxt
@@ -23,6 +23,18 @@ extra {
name: "y_data_format"
type: STRING
}
attrs {
name: "activation_type"
type: STRING
}
attrs {
name: "activation_alpha"
type: FLOAT
}
attrs {
name: "activation_beta"
type: FLOAT
}
attrs {
name: "Scale_x"
type: FLOAT
12 changes: 12 additions & 0 deletions paddle/fluid/operators/compat/elementwise_mul.pbtxt
@@ -23,6 +23,18 @@ extra {
name: "y_data_format"
type: STRING
}
attrs {
name: "activation_type"
type: STRING
}
attrs {
name: "activation_alpha"
type: FLOAT
}
attrs {
name: "activation_beta"
type: FLOAT
}
attrs {
name: "Scale_x"
type: FLOAT
12 changes: 12 additions & 0 deletions paddle/fluid/operators/compat/elementwise_pow.pbtxt
@@ -23,6 +23,18 @@ extra {
name: "y_data_format"
type: STRING
}
attrs {
name: "activation_type"
type: STRING
}
attrs {
name: "activation_alpha"
type: FLOAT
}
attrs {
name: "activation_beta"
type: FLOAT
}
attrs {
name: "Scale_x"
type: FLOAT
12 changes: 12 additions & 0 deletions paddle/fluid/operators/compat/elementwise_sub.pbtxt
@@ -23,6 +23,18 @@ extra {
name: "y_data_format"
type: STRING
}
attrs {
name: "activation_type"
type: STRING
}
attrs {
name: "activation_alpha"
type: FLOAT
}
attrs {
name: "activation_beta"
type: FLOAT
}
attrs {
name: "Scale_x"
type: FLOAT
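These extra attributes are what a oneDNN elementwise kernel can later read to decide which post-op to attach. A hedged sketch of such a lookup, using the standard ExecutionContext attribute API (the names mirror the pbtxt entries above; the kernel change itself is outside this PR):

// Illustrative sketch only: reading the fused-activation attributes inside an
// operator kernel; the real oneDNN elementwise kernel is not part of this diff.
#include <string>
#include "paddle/fluid/framework/operator.h"

struct FusedActivationArgs {
  std::string type;    // e.g. "relu", "clip", ...
  float alpha = 0.0f;  // taken from activation_alpha
  float beta = 0.0f;   // taken from activation_beta
};

FusedActivationArgs GetFusedActivation(
    const paddle::framework::ExecutionContext& ctx) {
  FusedActivationArgs args;
  if (ctx.HasAttr("activation_type")) {
    args.type = ctx.Attr<std::string>("activation_type");
  }
  if (ctx.HasAttr("activation_alpha")) {
    args.alpha = ctx.Attr<float>("activation_alpha");
  }
  if (ctx.HasAttr("activation_beta")) {
    args.beta = ctx.Attr<float>("activation_beta");
  }
  return args;
}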