From 2bfbf627087cbb6221d5c535616666b025756c24 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Tue, 22 Jun 2021 12:13:53 +0200 Subject: [PATCH] enabled split op for inference --- paddle/fluid/framework/ir/graph_pattern_detector.cc | 10 +++++----- paddle/fluid/operators/split_op.cc | 5 +++++ .../unittests/mkldnn/test_split_bf16_mkldnn_op.py | 8 ++++++-- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 064da3d941602..573cb7dcd09b0 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2262,11 +2262,11 @@ PDNode *patterns::QuantizePlacement::operator()( PDNode *patterns::Bfloat16Placement::operator()( const std::unordered_set &bfloat16_enabled_op_types) { std::unordered_set supported_op_types = - std::unordered_set({"concat", "conv2d", "conv2d_transpose", - "elementwise_add", "elementwise_mul", - "fc", "fusion_gru", "gelu", "layer_norm", - "matmul", "pool2d", "relu", "reshape2", - "softmax", "sum", "transpose2"}); + std::unordered_set( + {"concat", "conv2d", "conv2d_transpose", "elementwise_add", + "elementwise_mul", "fc", "fusion_gru", "gelu", "layer_norm", + "matmul", "pool2d", "relu", "reshape2", "softmax", "split", "sum", + "transpose2"}); if (!bfloat16_enabled_op_types.empty()) { supported_op_types = bfloat16_enabled_op_types; } diff --git a/paddle/fluid/operators/split_op.cc b/paddle/fluid/operators/split_op.cc index 37a7575c12c2c..661e4ca727bee 100644 --- a/paddle/fluid/operators/split_op.cc +++ b/paddle/fluid/operators/split_op.cc @@ -148,6 +148,11 @@ This operator splits the input tensor into multiple sub-tensors. AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); + AddAttr( + "mkldnn_data_type", + "(string, default \"float32\"). Data type of mkldnn kernel") + .SetDefault("float32") + .InEnum({"float32", "bfloat16"}); } }; diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py index 19407d8944a25..200360859b076 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_split_bf16_mkldnn_op.py @@ -41,7 +41,11 @@ def setUp(self): self.num = 0 self.init_data() self.inputs = {'X': self.x} - self.attrs = {'use_mkldnn': True, 'num': self.num} + self.attrs = { + 'use_mkldnn': True, + 'num': self.num, + 'mkldnn_data_type': "bfloat16" + } if self.axis is not None: self.attrs['axis'] = self.axis @@ -56,7 +60,7 @@ def setUp(self): for i in range(len(self.out))]} def test_check_output(self): - self.check_output(check_dygraph=False) + self.check_output_with_place(core.CPUPlace()) # TODO jakpiase enable grad check(concat op)