Skip to content

Commit

Permalink
add reshape+transpose+matmul_v2 only (#37847)
Browse files Browse the repository at this point in the history
* reshape+transpose+matmul_v2

* in_name->input_name

* fix pr-ci-static-check
  • Loading branch information
sfraczek authored Dec 14, 2021
1 parent 6a85253 commit a922168
Show file tree
Hide file tree
Showing 18 changed files with 591 additions and 71 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ if(WITH_MKLDNN)
pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_v2_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(matmul_v2_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(batch_norm_act_fuse_pass inference DIR mkldnn)
Expand Down Expand Up @@ -190,7 +191,7 @@ endif()
cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass)
cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass reshape_transpose_matmul_v2_mkldnn_fuse_pass)
cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass matmul_v2_transpose_reshape_fuse_pass)
cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass)
cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass)
Expand Down
9 changes: 5 additions & 4 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2711,12 +2711,13 @@ void patterns::DeleteQuantDequantFilterOpPattern::operator()() {
}

PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
bool with_reshape_xshape, bool with_transpose_xshape) {
const std::string &op_name, bool with_reshape_xshape,
bool with_transpose_xshape) {
auto reshape_op =
pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2");
auto transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op(op_name);

auto reshape_in = pattern->NewNode(reshape_in_repr())
->AsInput()
Expand All @@ -2737,7 +2738,7 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(

auto transpose_out = pattern->NewNode(transpose_out_repr())
->AsIntermediate()
->assert_is_op_input("matmul")
->assert_is_op_input(op_name)
->assert_is_op_output("transpose2", "Out");
if (!with_transpose_xshape)
transpose_out->assert_is_only_output_of_op("transpose2");
Expand All @@ -2751,7 +2752,7 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(

auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput()
->assert_is_op_output("matmul", "Out");
->assert_is_op_output(op_name, "Out");

reshape_op->LinksFrom({reshape_in}).LinksTo({reshape_out});
if (with_reshape_xshape) reshape_op->LinksTo({reshape_xshape});
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -1570,7 +1570,8 @@ struct ReshapeTransposeMatmulPattern : public PatternBase {
const std::string& name_scope)
: PatternBase(pattern, name_scope, "reshape_transpose_matmul") {}

PDNode* operator()(bool with_reshape_xshape, bool with_transpose_xshape);
PDNode* operator()(const std::string& op_name, bool with_reshape_xshape,
bool with_transpose_xshape);

PATTERN_DECL_NODE(reshape_in);
PATTERN_DECL_NODE(reshape_op);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ namespace framework {
namespace ir {

ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
op_name_ = "matmul";

AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
Expand Down Expand Up @@ -55,7 +57,7 @@ ReshapeTransposeMatmulMkldnnFusePass::ReshapeTransposeMatmulMkldnnFusePass() {
.IsType<std::vector<int>>()
.End();

AddOpCompat(OpCompat("matmul"))
AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
Expand All @@ -82,17 +84,17 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
patterns::ReshapeTransposeMatmulPattern rtm_pattern(gpd.mutable_pattern(),
name_scope_);

rtm_pattern(with_reshape_xshape, with_transpose_xshape);
rtm_pattern(op_name_, with_reshape_xshape, with_transpose_xshape);

int found_reshape_transpose_matmul_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Op compatible check in "
"reshape_transpose_matmul_mkldnn_fuse_pass failed.";
LOG(WARNING) << "Op compatible check in reshape_transpose_" << op_name_
<< "_mkldnn_fuse_pass failed.";
return;
}
VLOG(4) << "handle ReshapeTransposeMatmulMkldnn fuse";
VLOG(4) << "handle reshape_transpose_" << op_name_ << " fuse";
GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, rtm_pattern);
Expand Down Expand Up @@ -131,8 +133,8 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
} else if (matmul_desc->Inputs().at("Y").at(0) == input_var_name) {
UpdateMatmul("Y");
} else {
throw platform::errors::InvalidArgument(
"Unexpected input to MatMul encountered.");
throw platform::errors::InvalidArgument("Unexpected input to " +
op_name_ + " encountered.");
}

std::unordered_set<const ir::Node *> nodes_to_remove{
Expand All @@ -151,7 +153,7 @@ void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_reshape_transpose_matmul_count
<< " ReshapeTransposeMatmulMkldnn patterns";
<< " ReshapeTransposeMatmul patterns for " << op_name_ << " Op";
if (with_reshape_xshape) msg_ss << " with reshape's xshape";
if (with_transpose_xshape) msg_ss << " with transpose's xshape";
string::PrettyLogDetail(msg_ss.str().c_str());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ namespace ir {
// Fuse pass built on FusePassBase that looks for the reshape/transpose/matmul
// pattern and rewrites it. The op type the pattern has to end in is kept in
// op_name_ so that a subclass can retarget the same logic at another op.
class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase {
 public:
  // Sets op_name_ to "matmul" and declares the op compat requirements.
  ReshapeTransposeMatmulMkldnnFusePass();
  virtual ~ReshapeTransposeMatmulMkldnnFusePass() = default;

 protected:
  void ApplyImpl(ir::Graph* graph) const override;

  // Name scope handed to the pattern object created inside Fuse().
  const std::string name_scope_ = "reshape_transpose_matmul_fuse";

  // Detects and fuses the pattern for one combination of the two flags,
  // which state whether the reshape / transpose XShape outputs are part of
  // the pattern being matched.
  void Fuse(Graph* graph, bool with_reshape_xshape,
            bool with_transpose_xshape) const;

  // Type name of the op the matched pattern must end in.
  std::string op_name_;
};
} // namespace ir
} // namespace framework
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h"

#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
Expand All @@ -37,7 +38,7 @@ Scope* CreateParamScope() {
return param_scope;
}

void TestMain(bool with_xshapes) {
void TestMain(const std::string& op_name, bool with_xshapes) {
// inputs operator output
// -----------------------------------------------
// a1,w1,bias1 fc -> b1
Expand All @@ -46,7 +47,7 @@ void TestMain(bool with_xshapes) {
// a2,w2,bias2 fc -> b2
// b2 reshape -> c2
// c2 transpose -> d2
// (d1, d2) matmul -> (...)
// (d1, d2) matmul(_v2) -> (...)
Layers layers;
auto* a1 = layers.data("a1", {-1, 128, 768});
auto* w1 = layers.data("w1", {768, 768}, true);
Expand All @@ -66,7 +67,11 @@ void TestMain(bool with_xshapes) {
c2->SetShape({-1, 128, 12, 64});
auto* d2 = layers.transpose2(c2, {0, 2, 1, 3});
d2->SetShape({-1, 12, 128, 64});
layers.matmul(d1, d2);
if (op_name == "matmul_v2") {
layers.matmul_v2(d1, d2);
} else {
layers.matmul(d1, d2);
}

std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
Expand All @@ -76,8 +81,8 @@ void TestMain(bool with_xshapes) {
int total_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);

auto pass =
PassRegistry::Instance().Get("reshape_transpose_matmul_mkldnn_fuse_pass");
auto pass = PassRegistry::Instance().Get("reshape_transpose_" + op_name +
"_mkldnn_fuse_pass");
graph.reset(pass->Apply(graph.release()));

int num_reshape_nodes_after = GetNumOpNodes(graph, "reshape2");
Expand All @@ -92,7 +97,7 @@ void TestMain(bool with_xshapes) {
int removed = 8; // 2* reshape, reshape_out, transpose, transpose_out
if (with_xshapes) removed += 2; // transpose_xshape, reshape_xshape
EXPECT_EQ(total_nodes_before - removed, total_nodes_after);
auto* matmul_op_desc = GetOpNodes(graph, "matmul").at(0)->Op();
auto* matmul_op_desc = GetOpNodes(graph, op_name).at(0)->Op();

auto check = [&matmul_op_desc](std::string a) {
std::string shape_str = "fused_reshape_" + a;
Expand All @@ -108,16 +113,27 @@ void TestMain(bool with_xshapes) {

TEST(ReshapeTransposeMatmulMkldnnFusePass,
both_matmul_inputs_reshape_transpose) {
TestMain(false);
TestMain("matmul", false);
}

TEST(ReshapeTransposeMatmulMkldnnFusePass,
both_matmul_inputs_reshape_transpose_one_with_xshapes) {
TestMain(true);
TestMain("matmul", true);
}

// matmul_v2 variant of the shared TestMain scenario, run with the
// with_xshapes flag switched off.
TEST(ReshapeTransposeMatmulV2MkldnnFusePass,
     both_matmulv2_inputs_reshape_transpose) {
  const bool with_xshapes = false;
  TestMain("matmul_v2", with_xshapes);
}

// matmul_v2 variant of the shared TestMain scenario, run with the
// with_xshapes flag switched on.
TEST(ReshapeTransposeMatmulV2MkldnnFusePass,
     both_matmulv2_inputs_reshape_transpose_one_with_xshapes) {
  const bool with_xshapes = true;
  TestMain("matmul_v2", with_xshapes);
}

} // namespace ir
} // namespace framework
} // namespace paddle

USE_PASS(reshape_transpose_matmul_mkldnn_fuse_pass);
USE_PASS(reshape_transpose_matmul_v2_mkldnn_fuse_pass);
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_v2_mkldnn_fuse_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

// Configures the pass for "matmul_v2": stores the target op name and declares
// which inputs, outputs and attributes reshape2, transpose2 and matmul_v2 are
// expected to carry.
//
// NOTE(review): the base-class constructor has already run when this body
// starts; it set op_name_ to "matmul" and added compat entries of its own
// (including one for "reshape2"). How AddOpCompat treats an op name that is
// added a second time is not visible from here -- confirm.
ReshapeTransposeMatmulV2MkldnnFusePass::
    ReshapeTransposeMatmulV2MkldnnFusePass() {
  op_name_ = "matmul_v2";

  // reshape2. "Shape" and "ShapeTensor" are intentionally not declared: the
  // reshape2 op for this pass should not have them.
  auto&& reshape_compat = AddOpCompat(OpCompat("reshape2"));
  reshape_compat.AddInput("X").IsTensor().End();
  reshape_compat.AddOutput("Out").IsTensor().End();
  reshape_compat.AddOutput("XShape").IsOptional().IsTensor().End();
  reshape_compat.AddAttr("shape").IsType<std::vector<int>>().End();

  // transpose2.
  auto&& transpose_compat = AddOpCompat(OpCompat("transpose2"));
  transpose_compat.AddInput("X").IsTensor().End();
  transpose_compat.AddOutput("Out").IsTensor().End();
  transpose_compat.AddOutput("XShape").IsOptional().IsTensor().End();
  transpose_compat.AddAttr("axis").IsType<std::vector<int>>().End();

  // The op named by op_name_ (matmul_v2): two tensor inputs, one tensor
  // output and the two boolean transpose flags.
  auto&& matmul_compat = AddOpCompat(OpCompat(op_name_));
  matmul_compat.AddInput("X").IsTensor().End();
  matmul_compat.AddInput("Y").IsTensor().End();
  matmul_compat.AddOutput("Out").IsTensor().End();
  matmul_compat.AddAttr("trans_x").IsType<bool>().End();
  matmul_compat.AddAttr("trans_y").IsType<bool>().End();
}
} // namespace ir
} // namespace framework
} // namespace paddle

// Registers ReshapeTransposeMatmulV2MkldnnFusePass under the pass name
// reshape_transpose_matmul_v2_mkldnn_fuse_pass.
REGISTER_PASS(reshape_transpose_matmul_v2_mkldnn_fuse_pass,
paddle::framework::ir::ReshapeTransposeMatmulV2MkldnnFusePass);

// Declares the op version combination this pass is registered as capable of
// handling: matmul_v2, transpose2 and reshape2.
// NOTE(review): EQ(op, 0) presumably pins each op to version 0 -- confirm
// against the op version registry.
REGISTER_PASS_CAPABILITY(reshape_transpose_matmul_v2_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("transpose2", 0)
.EQ("reshape2", 0));
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"

namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse Reshape->Transpose->MatMulV2 when MatMulV2 uses mkldnn.
*/

class ReshapeTransposeMatmulV2MkldnnFusePass
: public ReshapeTransposeMatmulMkldnnFusePass {
public:
ReshapeTransposeMatmulV2MkldnnFusePass();
virtual ~ReshapeTransposeMatmulV2MkldnnFusePass() {}

protected:
const std::string name_scope_{"reshape_transpose_matmul_v2_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
13 changes: 13 additions & 0 deletions paddle/fluid/framework/ir/pass_tester_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,19 @@ struct Layers {
return out;
}

// Appends a "matmul_v2" op to block 0 of the program being built, with x and
// y as its "X" and "Y" inputs and trans_x / trans_y stored as attributes.
// Returns the variable obtained from lod_tensor(unique_name()), which is
// wired to the op's "Out" slot.
// `alpha` is accepted but never read by this body.
VarDesc* matmul_v2(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr,
                   bool trans_x = false, bool trans_y = false) {
  VarDesc* result = lod_tensor(unique_name());
  OpDesc* matmul_desc = program_.MutableBlock(0)->AppendOp();
  matmul_desc->SetType("matmul_v2");
  matmul_desc->SetInput("X", {x->Name()});
  matmul_desc->SetInput("Y", {y->Name()});
  matmul_desc->SetOutput("Out", {result->Name()});
  matmul_desc->SetAttr("trans_x", trans_x);
  matmul_desc->SetAttr("trans_y", trans_y);
  return result;
}

VarDesc* transpose2(VarDesc* x, std::vector<int> axis,
bool with_xshape = false) {
VarDesc* out = lod_tensor(unique_name());
Expand Down
Loading

0 comments on commit a922168

Please sign in to comment.