From 41bfad5d9d7183d25ab0573891e45b1ef9e246ff Mon Sep 17 00:00:00 2001 From: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com> Date: Mon, 21 Jun 2021 03:05:31 +0000 Subject: [PATCH] add op_compat for the seqpool_cvm_concat_fuse_pass, test=develop --- .../ir/seqpool_cvm_concat_fuse_pass.cc | 42 +++++++++++++++++ .../ir/seqpool_cvm_concat_fuse_pass.h | 2 +- paddle/fluid/operators/compat/cvm.pbtxt | 39 +++++++++++++++ .../operators/compat/sequence_pool.pbtxt | 47 +++++++++++++++++++ 4 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/operators/compat/cvm.pbtxt create mode 100644 paddle/fluid/operators/compat/sequence_pool.pbtxt diff --git a/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.cc b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.cc index 6bff4a05627d3..50364d2d0109d 100644 --- a/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.cc +++ b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.cc @@ -52,6 +52,48 @@ static void GetConcatNodes(ir::Graph* graph, std::vector* concat_nodes) { } } // anonymous namespace +SeqPoolCVMConcatFusePass::SeqPoolCVMConcatFusePass() { + AddOpCompat(OpCompat("sequence_pool")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("MaxIndex") + .End() + .AddAttr("pooltype") + .End() + .AddAttr("pad_value") + .End(); + AddOpCompat(OpCompat("cvm")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("CVM") + .IsTensor() + .End() + .AddOutput("Y") + .IsTensor() + .End() + .AddAttr("use_cvm") + .IsBoolEQ(true) + .End(); + AddOpCompat(OpCompat("concat")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("AxisTensor") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumGE(1) + .End(); +} + void SeqPoolCVMConcatFusePass::ApplyImpl(ir::Graph* graph) const { FusePassBase::Init("seqpool_cvm_concat_fuse", graph); std::vector concat_nodes; diff --git a/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h index b0a3573fb59f9..7680c30e485a8 100644 --- a/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h +++ b/paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h @@ -44,7 +44,7 @@ class Graph; class SeqPoolCVMConcatFusePass : public FusePassBase { public: - virtual ~SeqPoolCVMConcatFusePass() {} + SeqPoolCVMConcatFusePass(); protected: void ApplyImpl(ir::Graph* graph) const override; diff --git a/paddle/fluid/operators/compat/cvm.pbtxt b/paddle/fluid/operators/compat/cvm.pbtxt new file mode 100644 index 0000000000000..ccbeabc1f1511 --- /dev/null +++ b/paddle/fluid/operators/compat/cvm.pbtxt @@ -0,0 +1,39 @@ +type: "cvm" +def { + inputs { + name: "X" + } + inputs { + name: "CVM" + } + outputs { + name: "Y" + } + attrs { + name: "use_cvm" + type: BOOLEAN + } +} +extra { + attrs { + name: "op_role" + type: INT + } + attrs { + name: "op_role_var" + type: STRINGS + } + attrs { + name: "op_namescope" + type: STRING + } + attrs { + name: "op_callstack" + type: STRINGS + } + attrs { + name: "op_device" + type: STRING + } +} + diff --git a/paddle/fluid/operators/compat/sequence_pool.pbtxt b/paddle/fluid/operators/compat/sequence_pool.pbtxt new file mode 100644 index 0000000000000..c45f457fe0d9f --- /dev/null +++ b/paddle/fluid/operators/compat/sequence_pool.pbtxt @@ -0,0 +1,47 @@ +type: "sequence_pool" +def { + inputs { + name: "X" + } + outputs { + name: "Out" + } + outputs { + name: "MaxIndex" + } + attrs { + name: "pooltype" + type: STRING + } + attrs { + name: "pad_value" + type: FLOAT + } +} +extra { + attrs { + name: "is_test" + type: BOOLEAN + } + attrs { + name: "op_role" + type: INT + } + attrs { + name: "op_role_var" + type: STRINGS + } + attrs { + name: "op_namescope" + type: STRING + } + attrs { + name: "op_callstack" + type: STRINGS + } + attrs { + name: "op_device" + type: STRING + } +} +