add reshape+transpose+matmul_v2 only #37847

Merged
3 changes: 2 additions & 1 deletion paddle/fluid/framework/ir/CMakeLists.txt
@@ -123,6 +123,7 @@ if(WITH_MKLDNN)
pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn)
+ pass_library(reshape_transpose_matmul_v2_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(matmul_v2_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(batch_norm_act_fuse_pass inference DIR mkldnn)
@@ -190,7 +191,7 @@ endif()
cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
- cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass)
+ cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass reshape_transpose_matmul_v2_mkldnn_fuse_pass)
cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass matmul_v2_transpose_reshape_fuse_pass)
cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass)
cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass)
9 changes: 5 additions & 4 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2691,12 +2691,13 @@ void patterns::DeleteQuantDequantFilterOpPattern::operator()() {
}

PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
- bool with_reshape_xshape, bool with_transpose_xshape) {
+ const std::string &op_name, bool with_reshape_xshape,
+ bool with_transpose_xshape) {
auto reshape_op =
pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2");
auto transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
- auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
+ auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op(op_name);

auto reshape_in = pattern->NewNode(reshape_in_repr())
->AsInput()
@@ -2717,7 +2718,7 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(

auto transpose_out = pattern->NewNode(transpose_out_repr())
->AsIntermediate()
- ->assert_is_op_input("matmul")
+ ->assert_is_op_input(op_name)
->assert_is_op_output("transpose2", "Out");
if (!with_transpose_xshape)
transpose_out->assert_is_only_output_of_op("transpose2");
@@ -2731,7 +2732,7 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(

auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput()
- ->assert_is_op_output("matmul", "Out");
+ ->assert_is_op_output(op_name, "Out");

reshape_op->LinksFrom({reshape_in}).LinksTo({reshape_out});
if (with_reshape_xshape) reshape_op->LinksTo({reshape_xshape});
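With the op type passed in as a parameter, a single pattern implementation now serves both the matmul and matmul_v2 fuse passes. A minimal sketch of the call site (mirroring the Fuse() change below; op_name_ is "matmul" in the base pass and "matmul_v2" in the derived one):

GraphPatternDetector gpd;
patterns::ReshapeTransposeMatmulPattern rtm_pattern(gpd.mutable_pattern(), name_scope_);
// Match reshape2 -> transpose2 -> <op_name_> chains:
rtm_pattern(op_name_, with_reshape_xshape, with_transpose_xshape);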
3 changes: 2 additions & 1 deletion paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1556,7 +1556,8 @@ struct ReshapeTransposeMatmulPattern : public PatternBase {
const std::string& name_scope)
: PatternBase(pattern, name_scope, "reshape_transpose_matmul") {}

- PDNode* operator()(bool with_reshape_xshape, bool with_transpose_xshape);
+ PDNode* operator()(const std::string& op_name, bool with_reshape_xshape,
+ bool with_transpose_xshape);

PATTERN_DECL_NODE(reshape_in);
PATTERN_DECL_NODE(reshape_op);
paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.cc
@@ -24,6 +24,8 @@ namespace framework {
namespace ir {

ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
+ op_name_ = "matmul";
+
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
@@ -55,7 +57,7 @@ ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
.IsType<std::vector<int>>()
.End();

- AddOpCompat(OpCompat("matmul"))
+ AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
@@ -82,17 +84,17 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
patterns::ReshapeTransposeMatmulPattern rtm_pattern(gpd.mutable_pattern(),
name_scope_);

- rtm_pattern(with_reshape_xshape, with_transpose_xshape);
+ rtm_pattern(op_name_, with_reshape_xshape, with_transpose_xshape);

int found_reshape_transpose_matmul_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
if (!IsCompat(subgraph, g)) {
- LOG(WARNING) << "Op compatible check in "
- "reshape_transpose_matmul_mkldnn_fuse_pass failed.";
+ LOG(WARNING) << "Op compatible check in reshape_transpose_" << op_name_
+ << "_mkldnn_fuse_pass failed.";
return;
}
- VLOG(4) << "handle ReshapeTransposeMatmulMkldnn fuse";
+ VLOG(4) << "handle reshape_transpose_" << op_name_ << " fuse";
GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, rtm_pattern);
@@ -131,8 +133,8 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
} else if (matmul_desc->Inputs().at("Y").at(0) == input_var_name) {
UpdateMatmul("Y");
} else {
- throw platform::errors::InvalidArgument(
- "Unexpected input to MatMul encountered.");
+ throw platform::errors::InvalidArgument("Unexpected input to " +
+ op_name_ + " encountered.");
}

std::unordered_set<const ir::Node *> nodes_to_remove{
@@ -151,7 +153,7 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_reshape_transpose_matmul_count
<< " ReshapeTransposeMatmulMkldnn patterns";
<< " ReshapeTransposeMatmul patterns for " << op_name_ << " Op";
if (with_reshape_xshape) msg_ss << " with reshape's xshape";
if (with_transpose_xshape) msg_ss << " with transpose's xshape";
string::PrettyLogDetail(msg_ss.str().c_str());
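After a match, the pass removes the reshape2/transpose2 nodes and records their shape/axis attributes directly on the matmul op; the tester below verifies this by checking for the fused attributes. Conceptually (a hedged illustration):

// Before: X -> reshape2(shape) -> transpose2(axis) -> matmul
// After:  X -> matmul with fused_reshape_X = shape, fused_transpose_X = axis
// (and likewise fused_reshape_Y / fused_transpose_Y for the Y input)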
paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h
@@ -28,13 +28,15 @@ namespace ir {
class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase {
public:
ReshapeTransposeMatmulMkldnnFusePass();
+ virtual ~ReshapeTransposeMatmulMkldnnFusePass() {}

protected:
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"reshape_transpose_matmul_fuse"};

void Fuse(Graph* graph, bool with_reshape_xshape,
bool with_transpose_xshape) const;
+ std::string op_name_;
};
} // namespace ir
} // namespace framework
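The protected op_name_ member added here is the extension point: the derived matmul_v2 pass (introduced below) reuses ApplyImpl() and Fuse() as-is and only overwrites op_name_ and the op-compat rules in its own constructor.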
paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc
@@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h"

#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
@@ -37,7 +38,7 @@ Scope* CreateParamScope() {
return param_scope;
}

- void TestMain(bool with_xshapes) {
+ void TestMain(const std::string& op_name, bool with_xshapes) {
// inputs operator output
// -----------------------------------------------
// a1,w1,bias1 fc -> b1
@@ -46,7 +47,7 @@ void TestMain(bool with_xshapes) {
// a2,w2,bias2 fc -> b2
// b2 reshape -> c2
// c2 transpose -> d2
- // (d1, d2) matmul -> (...)
+ // (d1, d2) matmul(_v2) -> (...)
Layers layers;
auto* a1 = layers.data("a1", {-1, 128, 768});
auto* w1 = layers.data("w1", {768, 768}, true);
@@ -66,7 +67,11 @@ void TestMain(bool with_xshapes) {
c2->SetShape({-1, 128, 12, 64});
auto* d2 = layers.transpose2(c2, {0, 2, 1, 3});
d2->SetShape({-1, 12, 128, 64});
- layers.matmul(d1, d2);
+ if (op_name == "matmul_v2") {
+ layers.matmul_v2(d1, d2);
+ } else {
+ layers.matmul(d1, d2);
+ }

std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
@@ -76,8 +81,8 @@ void TestMain(bool with_xshapes) {
int total_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);

- auto pass =
- PassRegistry::Instance().Get("reshape_transpose_matmul_mkldnn_fuse_pass");
+ auto pass = PassRegistry::Instance().Get("reshape_transpose_" + op_name +
+ "_mkldnn_fuse_pass");
graph.reset(pass->Apply(graph.release()));

int num_reshape_nodes_after = GetNumOpNodes(graph, "reshape2");
@@ -92,7 +97,7 @@ void TestMain(bool with_xshapes) {
int removed = 8; // 2* reshape, reshape_out, transpose, transpose_out
if (with_xshapes) removed += 2; // transpose_xshape, reshape_xshape
EXPECT_EQ(total_nodes_before - removed, total_nodes_after);
- auto* matmul_op_desc = GetOpNodes(graph, "matmul").at(0)->Op();
+ auto* matmul_op_desc = GetOpNodes(graph, op_name).at(0)->Op();

auto check = [&matmul_op_desc](std::string a) {
std::string shape_str = "fused_reshape_" + a;
@@ -108,16 +113,27 @@

TEST(ReshapeTransposeMatmulMkldnnFusePass,
both_matmul_inputs_reshape_transpose) {
- TestMain(false);
+ TestMain("matmul", false);
}

TEST(ReshapeTransposeMatmulMkldnnFusePass,
both_matmul_inputs_reshape_transpose_one_with_xshapes) {
- TestMain(true);
+ TestMain("matmul", true);
}

+ TEST(ReshapeTransposeMatmulV2MkldnnFusePass,
+ both_matmulv2_inputs_reshape_transpose) {
+ TestMain("matmul_v2", false);
+ }
+
+ TEST(ReshapeTransposeMatmulV2MkldnnFusePass,
+ both_matmulv2_inputs_reshape_transpose_one_with_xshapes) {
+ TestMain("matmul_v2", true);
+ }

} // namespace ir
} // namespace framework
} // namespace paddle

USE_PASS(reshape_transpose_matmul_mkldnn_fuse_pass);
+ USE_PASS(reshape_transpose_matmul_v2_mkldnn_fuse_pass);
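A hedged usage note (assumes the standard Paddle CMake build tree, not part of this diff): once built with -DWITH_MKLDNN=ON, both passes are exercised by the shared tester registered above:

ctest -R test_reshape_transpose_matmul_mkldnn_fuse_pass -V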
91 changes: 91 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.cc
@@ -0,0 +1,91 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

ReshapeTransposeMatmulV2MkldnnFusePass::
ReshapeTransposeMatmulV2MkldnnFusePass() {
op_name_ = "matmul_v2";

AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
// The reshape2 op for this pass should not have "Shape" and "ShapeTensor"
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();

AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();

AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle

REGISTER_PASS(reshape_transpose_matmul_v2_mkldnn_fuse_pass,
paddle::framework::ir::ReshapeTransposeMatmulV2MkldnnFusePass);

REGISTER_PASS_CAPABILITY(reshape_transpose_matmul_v2_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("transpose2", 0)
.EQ("reshape2", 0));
39 changes: 39 additions & 0 deletions paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h
@@ -0,0 +1,39 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"

namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse Reshape->Transpose->MatMulV2 when MatMulV2 uses mkldnn.
*/

class ReshapeTransposeMatmulV2MkldnnFusePass
: public ReshapeTransposeMatmulMkldnnFusePass {
public:
ReshapeTransposeMatmulV2MkldnnFusePass();
virtual ~ReshapeTransposeMatmulV2MkldnnFusePass() {}

protected:
const std::string name_scope_{"reshape_transpose_matmul_v2_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
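Note that the v2 pass inherits ApplyImpl() and Fuse() from ReshapeTransposeMatmulMkldnnFusePass unchanged; the per-op differences are confined to the constructor (op_name_ plus matmul_v2's trans_x/trans_y attribute checks) and to the pass name under which it is registered.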
13 changes: 13 additions & 0 deletions paddle/fluid/framework/ir/pass_tester_helper.h
@@ -307,6 +307,19 @@ struct Layers {
return out;
}

+ VarDesc* matmul_v2(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr,
+ bool trans_x = false, bool trans_y = false) {
+ VarDesc* out = lod_tensor(unique_name());
+ OpDesc* op = program_.MutableBlock(0)->AppendOp();
+ op->SetType("matmul_v2");
+ op->SetInput("X", {x->Name()});
+ op->SetInput("Y", {y->Name()});
+ op->SetOutput("Out", {out->Name()});
+ op->SetAttr("trans_x", trans_x);
+ op->SetAttr("trans_y", trans_y);
+ return out;
+ }

VarDesc* transpose2(VarDesc* x, std::vector<int> axis,
bool with_xshape = false) {
VarDesc* out = lod_tensor(unique_name());
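A short usage sketch for the new Layers::matmul_v2 helper (illustration only; shapes mirror the tester's transposed inputs). The alpha argument is accepted for signature parity with Layers::matmul but is not used by the helper body above:

Layers layers;
auto* x = layers.data("x", {-1, 12, 128, 64});
auto* y = layers.data("y", {-1, 12, 128, 64});
// Appends a matmul_v2 op computing X * Y^T:
layers.matmul_v2(x, y, /*alpha=*/nullptr, /*trans_x=*/false, /*trans_y=*/true);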