Skip to content

Commit

Permalink
enabled split op for inference
Browse files Browse the repository at this point in the history
  • Loading branch information
jakpiase committed Jun 22, 2021
1 parent e72f151 commit 2bfbf62
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
10 changes: 5 additions & 5 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2262,11 +2262,11 @@ PDNode *patterns::QuantizePlacement::operator()(
PDNode *patterns::Bfloat16Placement::operator()(
const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>({"concat", "conv2d", "conv2d_transpose",
"elementwise_add", "elementwise_mul",
"fc", "fusion_gru", "gelu", "layer_norm",
"matmul", "pool2d", "relu", "reshape2",
"softmax", "sum", "transpose2"});
std::unordered_set<std::string>(
{"concat", "conv2d", "conv2d_transpose", "elementwise_add",
"elementwise_mul", "fc", "fusion_gru", "gelu", "layer_norm",
"matmul", "pool2d", "relu", "reshape2", "softmax", "split", "sum",
"transpose2"});
if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types;
}
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/operators/split_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ This operator splits the input tensor into multiple sub-tensors.
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>(
"mkldnn_data_type",
"(string, default \"float32\"). Data type of mkldnn kernel")
.SetDefault("float32")
.InEnum({"float32", "bfloat16"});
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ def setUp(self):
self.num = 0
self.init_data()
self.inputs = {'X': self.x}
self.attrs = {'use_mkldnn': True, 'num': self.num}
self.attrs = {
'use_mkldnn': True,
'num': self.num,
'mkldnn_data_type': "bfloat16"
}

if self.axis is not None:
self.attrs['axis'] = self.axis
Expand All @@ -56,7 +60,7 @@ def setUp(self):
for i in range(len(self.out))]}

def test_check_output(self):
self.check_output(check_dygraph=False)
self.check_output_with_place(core.CPUPlace())


# TODO jakpiase enable grad check(concat op)
Expand Down

0 comments on commit 2bfbf62

Please sign in to comment.