From 455e91cea635b205a67ada6296b796976cfef554 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Thu, 23 Sep 2021 17:36:05 +0200 Subject: [PATCH 01/11] added shape oneDNN kernel --- .../fluid/operators/mkldnn/shape_mkldnn_op.cc | 58 +++++++++++++ paddle/fluid/operators/shape_op.cc | 23 +++++ paddle/fluid/platform/mkldnn_helper.h | 17 ++++ .../unittests/mkldnn/test_shape_mkldnn_op.py | 84 +++++++++++++++++++ 4 files changed, 182 insertions(+) create mode 100644 paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py diff --git a/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc new file mode 100644 index 0000000000000..8ac47350fff59 --- /dev/null +++ b/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc @@ -0,0 +1,58 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/shape_op.h" +#include "paddle/fluid/platform/mkldnn_helper.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" + +namespace paddle { +namespace operators { + +using paddle::framework::Tensor; +using paddle::framework::LoDTensor; + +template +class ShapeMKLDNNKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* in_var = ctx.InputVar("Input"); + auto* out = ctx.Output("Out"); + + framework::DDim in_dims; + if (in_var->IsType()) { + in_dims = in_var->Get().value().dims(); + } else { + in_dims = in_var->Get().dims(); + } + + out->Resize({in_dims.size()}); + auto out_data = out->mutable_data(platform::CPUPlace()); + for (int i = 0; i < in_dims.size(); ++i) { + out_data[i] = in_dims[i]; + } + + out->set_layout(framework::DataLayout::kMKLDNN); + out->set_format(platform::GetPlainMKLDNNFormat(in_dims.size())); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(shape, MKLDNN, paddle::platform::CPUPlace, + ops::ShapeMKLDNNKernel, + ops::ShapeMKLDNNKernel, + ops::ShapeMKLDNNKernel, + ops::ShapeMKLDNNKernel); diff --git a/paddle/fluid/operators/shape_op.cc b/paddle/fluid/operators/shape_op.cc index dd135b89714da..e42481afb1472 100644 --- a/paddle/fluid/operators/shape_op.cc +++ b/paddle/fluid/operators/shape_op.cc @@ -35,6 +35,21 @@ class ShapeOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Out", {in_dim.size()}); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input"); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } + protected: framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const framework::Tensor &tensor, @@ -58,6 +73,14 @@ Shape Operator. Return the shape of the input. )DOC"); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr( + "mkldnn_data_type", + "(string, default \"float32\"). Data type of mkldnn kernel") + .SetDefault("float32") + .InEnum({"float32", "bfloat16", "int8"}); } }; diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index f14f92cb51fdb..c84de93a018ce 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -328,6 +328,23 @@ inline mkldnn::memory::format_tag GetMKLDNNFormat( return mkldnn::memory::format_tag::undef; } +inline mkldnn::memory::format_tag GetPlainMKLDNNFormat(int tensor_rank) { + switch (tensor_rank) { + case 1: + return mkldnn::memory::format_tag::a; + case 2: + return mkldnn::memory::format_tag::ab; + case 3: + return mkldnn::memory::format_tag::abc; + case 4: + return mkldnn::memory::format_tag::abcd; + case 5: + return mkldnn::memory::format_tag::abcde; + } + + return mkldnn::memory::format_tag::abcdef; +} + inline mkldnn::memory::format_tag GetMKLDNNFormat(const mkldnn::memory memory) { auto mem_desc = memory.get_desc(); return GetMKLDNNFormat(mem_desc); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py new file mode 100644 index 0000000000000..ebb87edb98700 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py @@ -0,0 +1,84 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16 +import paddle +from paddle.fluid import core +from paddle.fluid.op import Operator + + +class TestShapeOneDNNOp(OpTest): + def setUp(self): + self.op_type = "shape" + self.config() + self.attrs = {'use_mkldnn': True} + self.inputs = {'Input': np.zeros(self.shape).astype(self.dtype)} + self.outputs = {'Out': np.array(self.shape)} + + def config(self): + self.shape = [5, 7, 4] + self.dtype = np.float32 + + def test_check_output(self): + self.check_output() + + +class TestShape1DOneDNNOp(TestShapeOneDNNOp): + def config(self): + self.shape = [2] + + +class TestShape4DBF16OneDNNOp(TestShapeOneDNNOp): + def config(self): + self.shape = [10, 2, 3, 5] + self.dtype = np.uint16 + + +class TestShape6DBF16OneDNNOp(TestShapeOneDNNOp): + def config(self): + self.shape = [10, 2, 3, 4, 5, 2] + self.dtype = np.uint16 + + +class TestShape3DINT8OneDNNOp(TestShapeOneDNNOp): + def config(self): + self.shape = [10, 2, 3] + self.dtype = np.int8 + + +class TestShape5DINT8OneDNNOp(TestShapeOneDNNOp): + def config(self): + self.shape = [10, 2, 3, 4, 3] + self.dtype = np.int8 + + +class TestShape2DUINT8OneDNNOp(TestShapeOneDNNOp): + def config(self): + self.shape = [7, 11] + self.dtype = np.uint8 + + +class TestShape3DUINT8OneDNNOp(TestShapeOneDNNOp): + def config(self): + self.shape = [2, 7, 11] + self.dtype = np.uint8 + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() From a685a606504ce101b4168917c510f88380c46b47 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Thu, 23 Sep 2021 17:45:47 +0200 Subject: [PATCH 02/11] removed unnecessary import from test --- .../paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py index ebb87edb98700..c4df04fb458b3 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py @@ -16,7 +16,7 @@ import unittest import numpy as np -from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, convert_float_to_uint16 +from paddle.fluid.tests.unittests.op_test import OpTest import paddle from paddle.fluid import core from paddle.fluid.op import Operator From d629217cc9216b54cc5ead6f668f54de723d341c Mon Sep 17 00:00:00 2001 From: jakpiase Date: Thu, 23 Sep 2021 19:20:12 +0200 Subject: [PATCH 03/11] added skipping tests for GPU --- .../fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py index c4df04fb458b3..d514d24e0bb41 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py @@ -16,12 +16,13 @@ import unittest import numpy as np -from paddle.fluid.tests.unittests.op_test import OpTest +from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool import paddle from paddle.fluid import core from paddle.fluid.op import Operator +@OpTestTool.skip_if_not_cpu_bf16() class TestShapeOneDNNOp(OpTest): def setUp(self): self.op_type = "shape" @@ -35,7 +36,7 @@ def config(self): self.dtype = np.float32 def test_check_output(self): - self.check_output() + self.check_output_with_place(core.CPUPlace()) class TestShape1DOneDNNOp(TestShapeOneDNNOp): From 02570d958d0ab116485eb7cdd43d4c5f4c0863fd Mon Sep 17 00:00:00 2001 From: jakpiase Date: Tue, 11 Jan 2022 15:58:41 +0100 Subject: [PATCH 04/11] refactoring --- .../fluid/operators/mkldnn/shape_mkldnn_op.cc | 3 +- .../unittests/mkldnn/test_shape_mkldnn_op.py | 33 +++---------------- 2 files changed, 6 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc index 3ce91bfafc32d..dddd453c84527 100644 --- a/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc @@ -15,7 +15,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/shape_op.h" #include "paddle/fluid/platform/mkldnn_helper.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" namespace paddle { namespace operators { @@ -27,7 +26,7 @@ template class ShapeMKLDNNKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - cons auto* in_var = ctx.InputVar("Input"); + const auto* in_var = ctx.InputVar("Input"); auto* out = ctx.Output("Out"); framework::DDim in_dims; diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py index d514d24e0bb41..943024a3c37de 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py @@ -23,7 +23,7 @@ @OpTestTool.skip_if_not_cpu_bf16() -class TestShapeOneDNNOp(OpTest): +class TestShape3DFP32OneDNNOp(OpTest): def setUp(self): self.op_type = "shape" self.config() @@ -39,47 +39,24 @@ def test_check_output(self): self.check_output_with_place(core.CPUPlace()) -class TestShape1DOneDNNOp(TestShapeOneDNNOp): - def config(self): - self.shape = [2] - - -class TestShape4DBF16OneDNNOp(TestShapeOneDNNOp): - def config(self): - self.shape = [10, 2, 3, 5] - self.dtype = np.uint16 - - -class TestShape6DBF16OneDNNOp(TestShapeOneDNNOp): +class TestShape6DBF16OneDNNOp(TestShape3DFP32OneDNNOp): def config(self): self.shape = [10, 2, 3, 4, 5, 2] self.dtype = np.uint16 -class TestShape3DINT8OneDNNOp(TestShapeOneDNNOp): +class TestShape9DINT8OneDNNOp(TestShape3DFP32OneDNNOp): def config(self): - self.shape = [10, 2, 3] + self.shape = [1, 2, 3, 4, 5, 6, 7, 8, 9] self.dtype = np.int8 -class TestShape5DINT8OneDNNOp(TestShapeOneDNNOp): - def config(self): - self.shape = [10, 2, 3, 4, 3] - self.dtype = np.int8 - - -class TestShape2DUINT8OneDNNOp(TestShapeOneDNNOp): +class TestShape2DUINT8OneDNNOp(TestShape3DFP32OneDNNOp): def config(self): self.shape = [7, 11] self.dtype = np.uint8 -class TestShape3DUINT8OneDNNOp(TestShapeOneDNNOp): - def config(self): - self.shape = [2, 7, 11] - self.dtype = np.uint8 - - if __name__ == '__main__': paddle.enable_static() unittest.main() From 1820dab6635c5ffcf24fa2aec1ff2d2d57daa848 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Fri, 4 Feb 2022 19:56:31 +0100 Subject: [PATCH 05/11] refactored shape kernel --- .../fluid/operators/mkldnn/shape_mkldnn_op.cc | 23 ++++--------------- paddle/fluid/platform/mkldnn_helper.h | 9 -------- .../unittests/mkldnn/test_shape_mkldnn_op.py | 2 +- 3 files changed, 6 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc index dddd453c84527..ea93d1afcad48 100644 --- a/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,27 +23,14 @@ using paddle::framework::Tensor; using paddle::framework::LoDTensor; template -class ShapeMKLDNNKernel : public framework::OpKernel { +class ShapeMKLDNNKernel : public ShapeKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - const auto* in_var = ctx.InputVar("Input"); - auto* out = ctx.Output("Out"); - - framework::DDim in_dims; - if (in_var->IsType()) { - in_dims = in_var->Get().value().dims(); - } else { - in_dims = in_var->Get().dims(); - } - - out->Resize({in_dims.size()}); - auto out_data = out->mutable_data(platform::CPUPlace()); - for (int i = 0; i < in_dims.size(); ++i) { - out_data[i] = in_dims[i]; - } + ShapeKernel::Compute(ctx); + auto* out = ctx.Output("Out"); out->set_layout(framework::DataLayout::kMKLDNN); - out->set_format(platform::GetPlainMKLDNNFormat(in_dims.size())); + out->set_format(platform::GetPlainMKLDNNFormat(out->dims().size())); } }; } // namespace operators diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 7a528cf8d6be1..9dbfe7013fae8 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -346,31 +346,22 @@ inline dnnl::memory::format_tag GetPlainMKLDNNFormat(int tensor_rank) { switch (tensor_rank) { case 1: return dnnl::memory::format_tag::a; - break; case 2: return dnnl::memory::format_tag::ab; - break; case 3: return dnnl::memory::format_tag::abc; - break; case 4: return dnnl::memory::format_tag::abcd; - break; case 5: return dnnl::memory::format_tag::abcde; - break; case 6: return dnnl::memory::format_tag::abcdef; - break; case 7: return dnnl::memory::format_tag::abcdefg; - break; case 8: return dnnl::memory::format_tag::abcdefgh; - break; case 9: return dnnl::memory::format_tag::abcdefghi; - break; default: PADDLE_THROW(platform::errors::Unimplemented( "Paddle support tensors with rank in range <1, 9>, but received " diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py index 943024a3c37de..41e6344a0a17f 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 5765a8a1386c72dc245aaa0ebe6264de1d744738 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Fri, 4 Feb 2022 20:33:36 +0100 Subject: [PATCH 06/11] added tests in new framework --- .../ir/inference/test_mkldnn_shape_op.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py new file mode 100644 index 0000000000000..4e1f6f31fe3d4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py @@ -0,0 +1,63 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import MkldnnAutoScanTest +from program_config import TensorConfig, ProgramConfig, OpConfig +import numpy as np +from functools import partial +import unittest +from hypothesis import given +import hypothesis.strategies as st + + +class TestMkldnnMishOp(MkldnnAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self, *args, **kwargs): + def generate_input(*args, **kwargs): + return np.random.random(kwargs['in_shape']).astype(kwargs['in_dtype']) + + shape_op = OpConfig( + type="shape", + inputs={"Input": ["input_data"]}, + outputs={"Out": ["output_data"]}) + + program_config = ProgramConfig( + ops=[shape_op], + weights={}, + inputs={ + "input_data": TensorConfig(data_gen=partial(generate_input, + *args, **kwargs)), + }, + outputs=["output_data"]) + + yield program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_mkldnn=True) + yield config, (1e-5, 1e-5) + + @given( + in_shape=st.lists( + st.integers( + min_value=1, max_value=3), min_size=1, max_size=9), + in_dtype=st.sampled_from([np.float32, np.uint16, np.int8, np.uint8])) + + def test(self, *args, **kwargs): + self.run_test(quant=False, *args, **kwargs) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From dc9956991ba496cd206810a2d45278e355aed207 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Fri, 4 Feb 2022 20:34:34 +0100 Subject: [PATCH 07/11] removed one line --- paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc index ea93d1afcad48..780c6e7f153e7 100644 --- a/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc @@ -20,7 +20,6 @@ namespace paddle { namespace operators { using paddle::framework::Tensor; -using paddle::framework::LoDTensor; template class ShapeMKLDNNKernel : public ShapeKernel { From 5033ffd8344af4768cd4c3e5238a68a8b5d197f3 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Fri, 4 Feb 2022 20:35:18 +0100 Subject: [PATCH 08/11] minor change --- .../fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py index 4e1f6f31fe3d4..484e880f7de93 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py @@ -21,7 +21,7 @@ import hypothesis.strategies as st -class TestMkldnnMishOp(MkldnnAutoScanTest): +class TestMkldnnShapeOp(MkldnnAutoScanTest): def is_program_valid(self, program_config: ProgramConfig) -> bool: return True From 7dfdcb6daaf4d4a5f40966c844c83fafad6152cf Mon Sep 17 00:00:00 2001 From: jakpiase Date: Fri, 4 Feb 2022 20:36:06 +0100 Subject: [PATCH 09/11] added newline at EOF --- .../fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py index 484e880f7de93..7a07828b57d97 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py @@ -60,4 +60,4 @@ def test(self, *args, **kwargs): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() From 834c9897916179699da4cf70c48b24f0218b1554 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Fri, 4 Feb 2022 21:53:36 +0100 Subject: [PATCH 10/11] added formatting --- .../tests/unittests/ir/inference/test_mkldnn_shape_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py index 7a07828b57d97..5b23669b98daa 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py @@ -27,7 +27,8 @@ def is_program_valid(self, program_config: ProgramConfig) -> bool: def sample_program_configs(self, *args, **kwargs): def generate_input(*args, **kwargs): - return np.random.random(kwargs['in_shape']).astype(kwargs['in_dtype']) + return np.random.random(kwargs['in_shape']).astype(kwargs[ + 'in_dtype']) shape_op = OpConfig( type="shape", @@ -54,7 +55,6 @@ def sample_predictor_configs(self, program_config): st.integers( min_value=1, max_value=3), min_size=1, max_size=9), in_dtype=st.sampled_from([np.float32, np.uint16, np.int8, np.uint8])) - def test(self, *args, **kwargs): self.run_test(quant=False, *args, **kwargs) From c4755323b2719692b938e5cb5b6900e2cdd116d3 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Sat, 5 Feb 2022 16:36:33 +0100 Subject: [PATCH 11/11] added attributes as extra --- paddle/fluid/operators/shape_op.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/shape_op.cc b/paddle/fluid/operators/shape_op.cc index e42481afb1472..5b7ccdde81097 100644 --- a/paddle/fluid/operators/shape_op.cc +++ b/paddle/fluid/operators/shape_op.cc @@ -75,12 +75,14 @@ Return the shape of the input. )DOC"); AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") - .SetDefault(false); + .SetDefault(false) + .AsExtra(); AddAttr( "mkldnn_data_type", "(string, default \"float32\"). Data type of mkldnn kernel") .SetDefault("float32") - .InEnum({"float32", "bfloat16", "int8"}); + .InEnum({"float32", "bfloat16", "int8"}) + .AsExtra(); } };