From 2e97faf140f2b8829e45e8d09f9669c862ea5efb Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Tue, 29 Jun 2021 13:57:45 +0800 Subject: [PATCH] [ pass_enhance ]transpose_flatten_concat_fuse_pass (#33744) --- .../ir/transpose_flatten_concat_fuse_pass.cc | 49 ++++++++++++++++++- .../ir/transpose_flatten_concat_fuse_pass.h | 5 +- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc index 50d6b97bbea8e..523c216132646 100644 --- a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc +++ b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc @@ -19,7 +19,50 @@ namespace paddle { namespace framework { namespace ir { -void RunTransposeFlattenConcatFuse(ir::Graph *graph, int times) { +TransposeFlattenConcatFusePass::TransposeFlattenConcatFusePass() { + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("axis") + .IsType>() + .End(); + AddOpCompat(OpCompat("flatten2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumGE(0) + .End(); + AddOpCompat(OpCompat("concat")) + .AddInput("X") // Input("X"): vector + .End() + .AddInput("AxisTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({0, 1}) + .End(); +} + +void TransposeFlattenConcatFusePass::RunTransposeFlattenConcatFuse( + ir::Graph *graph, int times) const { const std::string pattern_name = "transpose_flatten" + std::to_string(times) + "_concat_fuse"; @@ -37,6 +80,10 @@ void RunTransposeFlattenConcatFuse(ir::Graph *graph, int times) { auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, Graph *g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass in op compat failed."; + return; + } const int kNumFields = 5; const int kTransOffset = 1; const int kTransOutOffset = 2; diff --git a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h index 939a8c31e5501..7c3ef2986e27e 100644 --- a/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h +++ b/paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.h @@ -16,7 +16,6 @@ #include #include "paddle/fluid/framework/ir/fuse_pass_base.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" namespace paddle { namespace framework { @@ -28,10 +27,14 @@ namespace ir { // structure. class TransposeFlattenConcatFusePass : public FusePassBase { public: + TransposeFlattenConcatFusePass(); virtual ~TransposeFlattenConcatFusePass() {} protected: void ApplyImpl(ir::Graph* graph) const override; + + private: + void RunTransposeFlattenConcatFuse(ir::Graph* graph, int times) const; }; } // namespace ir