Skip to content

Commit

Permalink
[ pass_enhance ]transpose_flatten_concat_fuse_pass (#33744)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wangzheee authored Jun 29, 2021
1 parent 07eeb36 commit 2e97faf
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 2 deletions.
49 changes: 48 additions & 1 deletion paddle/fluid/framework/ir/transpose_flatten_concat_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,50 @@ namespace paddle {
namespace framework {
namespace ir {

void RunTransposeFlattenConcatFuse(ir::Graph *graph, int times) {
TransposeFlattenConcatFusePass::TransposeFlattenConcatFusePass() {
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("flatten2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumGE(0)
.End();
AddOpCompat(OpCompat("concat"))
.AddInput("X") // Input("X"): vector<tensors>
.End()
.AddInput("AxisTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({0, 1})
.End();
}

void TransposeFlattenConcatFusePass::RunTransposeFlattenConcatFuse(
ir::Graph *graph, int times) const {
const std::string pattern_name =
"transpose_flatten" + std::to_string(times) + "_concat_fuse";

Expand All @@ -37,6 +80,10 @@ void RunTransposeFlattenConcatFuse(ir::Graph *graph, int times) {

auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
const int kNumFields = 5;
const int kTransOffset = 1;
const int kTransOutOffset = 2;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include <memory>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
Expand All @@ -28,10 +27,14 @@ namespace ir {
// structure.
class TransposeFlattenConcatFusePass : public FusePassBase {
public:
TransposeFlattenConcatFusePass();
virtual ~TransposeFlattenConcatFusePass() {}

protected:
void ApplyImpl(ir::Graph* graph) const override;

private:
void RunTransposeFlattenConcatFuse(ir::Graph* graph, int times) const;
};

} // namespace ir
Expand Down

0 comments on commit 2e97faf

Please sign in to comment.