Skip to content

Commit

Permalink
add op_compat for the seqpool_cvm_concat_fuse_pass, test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
Shixiaowei02 committed Jun 21, 2021
1 parent 009a163 commit 41bfad5
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 1 deletion.
42 changes: 42 additions & 0 deletions paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,48 @@ static void GetConcatNodes(ir::Graph* graph, std::vector<Node*>* concat_nodes) {
}
} // anonymous namespace

SeqPoolCVMConcatFusePass::SeqPoolCVMConcatFusePass() {
AddOpCompat(OpCompat("sequence_pool"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("MaxIndex")
.End()
.AddAttr("pooltype")
.End()
.AddAttr("pad_value")
.End();
AddOpCompat(OpCompat("cvm"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("CVM")
.IsTensor()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("use_cvm")
.IsBoolEQ(true)
.End();
AddOpCompat(OpCompat("concat"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("AxisTensor")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumGE(1)
.End();
}

void SeqPoolCVMConcatFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("seqpool_cvm_concat_fuse", graph);
std::vector<Node*> concat_nodes;
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Graph;

class SeqPoolCVMConcatFusePass : public FusePassBase {
public:
virtual ~SeqPoolCVMConcatFusePass() {}
SeqPoolCVMConcatFusePass();

protected:
void ApplyImpl(ir::Graph* graph) const override;
Expand Down
39 changes: 39 additions & 0 deletions paddle/fluid/operators/compat/cvm.pbtxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
type: "cvm"
def {
inputs {
name: "X"
}
inputs {
name: "CVM"
}
outputs {
name: "Y"
}
attrs {
name: "use_cvm"
type: BOOLEAN
}
}
extra {
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}

47 changes: 47 additions & 0 deletions paddle/fluid/operators/compat/sequence_pool.pbtxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
type: "sequence_pool"
def {
inputs {
name: "X"
}
outputs {
name: "Out"
}
outputs {
name: "MaxIndex"
}
attrs {
name: "pooltype"
type: STRING
}
attrs {
name: "pad_value"
type: FLOAT
}
}
extra {
attrs {
name: "is_test"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}

0 comments on commit 41bfad5

Please sign in to comment.