From 1ecc4cf466cf6291799d9aaddd71fac3fd48b3f2 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Wed, 2 Dec 2020 14:36:18 +0000 Subject: [PATCH 01/37] added external reorder to profiler --- paddle/fluid/framework/data_layout_transform.cc | 5 ++++- paddle/fluid/platform/profiler_helper.h | 9 ++++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/data_layout_transform.cc b/paddle/fluid/framework/data_layout_transform.cc index 8563b5b6d3695..f54c0f5e2064a 100644 --- a/paddle/fluid/framework/data_layout_transform.cc +++ b/paddle/fluid/framework/data_layout_transform.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/framework/data_layout_transform.h" - +#include "paddle/fluid/platform/profiler.h" #include #include "paddle/fluid/operators/math/math_function.h" @@ -194,6 +194,9 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p); mkldnn::stream astream(cpu_engine); + #ifdef PADDLE_WITH_MKLDNN + platform::RecordEvent record_reorder("ext_reorder", platform::EventRole::kUniqueOp); + #endif reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); astream.wait(); } else { diff --git a/paddle/fluid/platform/profiler_helper.h b/paddle/fluid/platform/profiler_helper.h index c79195aa0db0d..7e22a0e68552a 100644 --- a/paddle/fluid/platform/profiler_helper.h +++ b/paddle/fluid/platform/profiler_helper.h @@ -712,10 +712,17 @@ void AnalyzeEvent( } } for (size_t j = 0; j < table_size; ++j) { - if (child_index[j] == 0) { + if (child_index[j] == 0) { // pushes and counts only parents, ensures that time will not be counted twice main_event_items.push_back(event_items[j]); total += event_items[j].total_time; } + else if (child_index[j] == 1 && event_items[j].name.find("reorder") != std::string::npos){ + size_t first_slash_pos = event_items[j].name.find('/'); + if(first_slash_pos != std::string::npos){ + std::string fname = event_items[j].name.substr(0, first_slash_pos); + child_map->insert(std::pair(fname, event_items[j])); + } + } } // average time for (auto &item : main_event_items) { From 5c02f899db416ea8bd582b4c3242ef43908f64f6 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Mon, 22 Mar 2021 15:12:00 +0100 Subject: [PATCH 02/37] added mkldnn reduce op kernel --- .../reduce_ops/mkldnn/reduce_max_mkldnn_op.cc | 44 ++++++ .../mkldnn/reduce_mean_mkldnn_op.cc | 43 +++++ .../reduce_ops/mkldnn/reduce_min_mkldnn_op.cc | 43 +++++ .../reduce_ops/mkldnn/reduce_mkldnn_op.h | 84 ++++++++++ .../reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc | 56 +++++++ paddle/fluid/operators/reduce_ops/reduce_op.h | 26 ++++ paddle/fluid/platform/mkldnn_reuse.h | 34 +++- .../mkldnn/test_reduce_bf16_mkldnn_op.py | 147 ++++++++++++++++++ .../unittests/mkldnn/test_reduce_mkldnn_op.py | 145 +++++++++++++++++ tools/static_mode_white_list.py | 2 + 10 files changed, 616 insertions(+), 8 deletions(-) create mode 100644 paddle/fluid/operators/reduce_ops/mkldnn/reduce_max_mkldnn_op.cc create mode 100644 paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc create mode 100644 paddle/fluid/operators/reduce_ops/mkldnn/reduce_min_mkldnn_op.cc create mode 100644 paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h create mode 100644 paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py create mode 100644 
python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_max_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_max_mkldnn_op.cc new file mode 100644 index 0000000000000..a380604e1f866 --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_max_mkldnn_op.cc @@ -0,0 +1,44 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h" + +namespace paddle { +namespace operators { + +using paddle::framework::LoDTensor; +using paddle::framework::Tensor; +using paddle::platform::CPUDeviceContext; +using paddle::platform::CreateKey; +using paddle::platform::MKLDNNGetDataType; +using paddle::platform::MKLDNNMemDesc; +using platform::to_void_cast; + + +template +class ReduceMaxMKLDNNKernel : public ReduceMKLDNNKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + this->RunKernel(ctx, dnnl::algorithm::reduction_max); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(reduce_max, MKLDNN, paddle::platform::CPUPlace, + ops::ReduceMaxMKLDNNKernel, + ops::ReduceMaxMKLDNNKernel); diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc new file mode 100644 index 0000000000000..c0c4fc7f70032 --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc @@ -0,0 +1,43 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h" + +namespace paddle { +namespace operators { + +using paddle::framework::LoDTensor; +using paddle::framework::Tensor; +using paddle::platform::CPUDeviceContext; +using paddle::platform::CreateKey; +using paddle::platform::MKLDNNGetDataType; +using paddle::platform::MKLDNNMemDesc; +using platform::to_void_cast; + + +template +class ReduceMeanMKLDNNKernel : public ReduceMKLDNNKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + this->RunKernel(ctx, dnnl::algorithm::reduction_mean); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(reduce_mean, MKLDNN, paddle::platform::CPUPlace, + ops::ReduceMeanMKLDNNKernel); diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_min_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_min_mkldnn_op.cc new file mode 100644 index 0000000000000..398760fd27854 --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_min_mkldnn_op.cc @@ -0,0 +1,43 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h" + +namespace paddle { +namespace operators { + +using paddle::framework::LoDTensor; +using paddle::framework::Tensor; +using paddle::platform::CPUDeviceContext; +using paddle::platform::CreateKey; +using paddle::platform::MKLDNNGetDataType; +using paddle::platform::MKLDNNMemDesc; +using platform::to_void_cast; + + +template +class ReduceMinMKLDNNKernel : public ReduceMKLDNNKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + this->RunKernel(ctx, dnnl::algorithm::reduction_min); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(reduce_min, MKLDNN, paddle::platform::CPUPlace, + ops::ReduceMinMKLDNNKernel); diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h new file mode 100644 index 0000000000000..f1d4273144c2a --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/platform/mkldnn_reuse.h" + + +namespace paddle { +namespace operators { + +using paddle::framework::LoDTensor; +using paddle::framework::Tensor; +using paddle::platform::CPUDeviceContext; +using paddle::platform::CreateKey; +using paddle::platform::MKLDNNGetDataType; +using paddle::platform::MKLDNNMemDesc; +using platform::to_void_cast; + + +template +class ReduceMKLDNNKernel : public framework::OpKernel { + public: + void RunKernel(const framework::ExecutionContext& ctx, dnnl::algorithm reduction_type) const { + auto& dev_ctx = + ctx.template device_context(); + const auto& onednn_engine = dev_ctx.GetEngine(); + + const auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + + int out_dtype = ctx.Attr("out_dtype"); + int in_dtype = ctx.Attr("in_dtype"); + + auto x_dims = input->dims(); + auto x_rank = x_dims.size(); + + auto dims = ctx.Attr>("dim"); // dims to reduce + bool reduce_all = ctx.Attr("reduce_all"); + //bool keep_dim = ctx.Attr("keep_dim"); + // Change data formats + + + + platform::ReductionMKLDNNHandler handler( + reduction_type, 0.0f, 0.0f, dev_ctx, onednn_engine, + ctx.GetPlace(), input, output, + ctx.InputName("X"), dims); + + auto src_memory_p = handler.AcquireSrcMemory(input); + auto dst_memory_p = handler.AcquireDstMemory(output); + + std::unordered_map reduction_args = { + {DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + auto reduction_p = handler.AcquireForwardPrimitive(); + + auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); + reduction_p->execute(astream, reduction_args); + astream.wait(); + + output->set_layout(framework::DataLayout::kMKLDNN); + output->set_format( + platform::GetMKLDNNFormat(dst_memory_p->get_desc().reshape( + paddle::framework::vectorize(output->dims())))); + } + + +}; + +} // namespace operators +} // namespace paddle + + diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc new file mode 100644 index 0000000000000..2851f2bb6e209 --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc @@ -0,0 +1,56 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h" + +namespace paddle { +namespace operators { + +using paddle::framework::LoDTensor; +using paddle::framework::Tensor; +using paddle::platform::CPUDeviceContext; +using paddle::platform::CreateKey; +using paddle::platform::MKLDNNGetDataType; +using paddle::platform::MKLDNNMemDesc; +using platform::to_void_cast; + + +template +class ReduceSumMKLDNNKernel : public ReduceMKLDNNKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + this->RunKernel(ctx, dnnl::algorithm::reduction_sum); + } +}; + + +template +class ReduceSumGradMKLDNNKernel : public ReduceMKLDNNKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + this->RunKernel(ctx, dnnl::algorithm::reduction_sum); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(reduce_sum, MKLDNN, paddle::platform::CPUPlace, + ops::ReduceSumMKLDNNKernel, + ops::ReduceSumMKLDNNKernel); + +REGISTER_OP_KERNEL(reduce_sum_grad, MKLDNN, paddle::platform::CPUPlace, + ops::ReduceSumGradMKLDNNKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 25f9453571ac6..b3fddec07d552 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -489,6 +489,29 @@ class ReduceOp : public framework::OperatorWithKernel { } } } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + // choose cudnn kernel if the runtime supported. + framework::LibraryType library_{framework::LibraryType::kPlain}; + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + +#ifdef PADDLE_WITH_MKLDNN + if (library_ == framework::LibraryType::kPlain && + this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + + if (input_data_type == framework::proto::VarType::FP16) { + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::InvalidArgument( + "float16 can only be used on GPU place")); + } + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class ReduceOpUseInputPlace : public ReduceOp { @@ -579,6 +602,9 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker { "(int, default -1)" "The dtype of output, default value is -1, the dtype is same as intput") .SetDefault(-1); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); AddComment(string::Sprintf(R"DOC( %s Operator. 
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 0503c3f71a802..d2107c39f6da5 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -630,7 +630,7 @@ class ReductionMKLDNNHandler const float eps, const MKLDNNDeviceContext& dev_ctx, const mkldnn::engine engine, platform::Place cpu_place, const Tensor* x, const Tensor* y, - const std::string& uniq_name) + const std::string& uniq_name, std::vector dims = {}) : platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), @@ -647,12 +647,31 @@ class ReductionMKLDNNHandler const auto src_tz = framework::vectorize(x->dims()); const auto dst_tz = framework::vectorize(y->dims()); - // For oneDNN dimensionality should match so we need to - // extend Y tensor dims with values of 1 (before and after pattern) - int j = 0; - std::vector dst_tz_ex(src_tz.size(), 1); - for (size_t i = 0; i < src_tz.size(); ++i) { - dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++]; + std::vector dst_tz_ex; + + if(dims.empty()) { + // For oneDNN dimensionality should match so we need to + // extend Y tensor dims with values of 1 (before and after pattern) + int j = 0; + + for (size_t i = 0; i < src_tz.size(); ++i) { + dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++]; + } + } else { + if (dst_tz.size() == 1) { // reduce_all + for(size_t j = 0 ; j < src_tz.size() ; ++j) { + dst_tz_ex.push_back(1); + } + } else { + for(auto &elem : src_tz) { + dst_tz_ex.push_back(elem); + } + + for(size_t i = 0; i < dims.size(); ++i) { + dims[i] = (dims[i] >= 0) ? dims[i] : src_tz.size() + dims[i]; // because dims can be counted backwards, "-1" = last dimension + dst_tz_ex[dims[i]] = 1; + } + } } const auto src_md = dnnl::memory::desc( @@ -663,7 +682,6 @@ class ReductionMKLDNNHandler this->AcquireForwardPrimitiveDescriptor(algo, src_md, dst_md, p, eps); } } -}; template class ActivationMKLDNNHandler diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py new file mode 100644 index 0000000000000..4efe4a6b2e4cb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py @@ -0,0 +1,147 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import unittest +import numpy as np +from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16 +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard +from paddle.fluid.framework import convert_np_dtype_to_dtype_ + + +@skip_check_grad_ci( + reason="not implemented") +class TestReduceSumBF16DefaultONEDNNOp(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.use_mkldnn = True + x_fp32 = np.random.random((5, 6, 10)).astype("float32") + x_bf16 = convert_float_to_uint16(x_fp32) + self.inputs = {'X': x_bf16} + self.outputs = {'Out': x_fp32.sum(axis=0)} + self.attrs = { + 'use_mkldnn': self.use_mkldnn + } + + def test_check_output(self): + self.check_output(check_dygraph=False) + + +#@skip_check_grad_ci( +# reason="not implemented") +#class TestReduceSumONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +# def setUp(self): +# self.op_type = "reduce_sum" +# self.use_mkldnn = True +# self.inputs = {'X': np.random.random((5, 10, 5, 5)).astype("float32")} +# self.attrs = { +# 'use_mkldnn': self.use_mkldnn, +# 'dim': [2] +# } +# self.outputs = {'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))} +# +# +#@skip_check_grad_ci( +# reason="reduce_max is discontinuous non-derivable function," +# " its gradient check is not supported by unittest framework.") +#class TestReduceMaxONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +# """Remove Max with subgradient from gradient check to confirm the success of CI.""" +# +# def setUp(self): +# self.op_type = "reduce_max" +# self.use_mkldnn = True +# self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} +# self.attrs = { +# 'dim': [-1], +# 'use_mkldnn' : self.use_mkldnn +# } +# self.outputs = { +# 'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim'])) +# } +# +# def test_check_output(self): +# self.check_output() +# +#@skip_check_grad_ci( +# reason="not implemented") +#class TestReduceSumToScalarONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +# def setUp(self): +# self.op_type = "reduce_sum" +# self.use_mkldnn = True +# self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float32")} +# self.attrs = { +# 'reduce_all': True, +# 'use_mkldnn': self.use_mkldnn +# } +# self.outputs = {'Out': self.inputs['X'].sum()} +# +# +#@skip_check_grad_ci( +# reason="reduce_min is discontinuous non-derivable function," +# " its gradient check is not supported by unittest framework.") +#class TestReduceMinONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +# """Remove Min with subgradient from gradient check to confirm the success of CI.""" +# +# def setUp(self): +# self.op_type = "reduce_min" +# self.use_mkldnn = True +# self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} +# self.attrs = { +# 'dim': [2], +# 'use_mkldnn': self.use_mkldnn +# } +# self.outputs = { +# 'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim'])) +# } +# +# +#@skip_check_grad_ci( +# reason="reduce_min is discontinuous non-derivable function," +# " its gradient check is not supported by unittest framework.") +#class TestReduceMeanONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +# def setUp(self): +# self.op_type = "reduce_mean" +# self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} +# self.outputs = {'Out': self.inputs['X'].sum(axis=0) / self.inputs['X'].shape[0]} +# +# +#@skip_check_grad_ci( +# reason="reduce_min is discontinuous non-derivable function," +# " its gradient check is not 
supported by unittest framework.") +#class TestReduceSumKeepDimsONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +# def setUp(self): +# self.op_type = "reduce_sum" +# self.use_mkldnn = True +# self.inputs = { +# 'X': np.random.random((2, 5, 3, 2, 2)).astype("float32") +# } +# self.attrs = { +# 'dim': (2, 3, 4), +# 'keep_dim': True, +# 'use_mkldnn': True +# } +# self.outputs = { +# 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']), +# keepdims=self.attrs['keep_dim']) +# } + + +if __name__ == '__main__': + import paddle + paddle.enable_static() + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py new file mode 100644 index 0000000000000..431e37570022c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py @@ -0,0 +1,145 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci +import paddle +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard +from paddle.fluid.framework import convert_np_dtype_to_dtype_ + + +@skip_check_grad_ci( + reason="not implemented") +class TestReduceSumDefaultONEDNNOp(OpTest): + def setUp(self): + self.op_type = "reduce_sum" + self.use_mkldnn = True + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} + self.outputs = {'Out': self.inputs['X'].sum(axis=0)} + self.attrs = { + 'use_mkldnn': self.use_mkldnn + } + + def test_check_output(self): + self.check_output() + + +@skip_check_grad_ci( + reason="not implemented") +class TestReduceSumONEDNNOp(TestReduceSumDefaultONEDNNOp): + def setUp(self): + self.op_type = "reduce_sum" + self.use_mkldnn = True + self.inputs = {'X': np.random.random((5, 10, 5, 5)).astype("float32")} + self.attrs = { + 'use_mkldnn': self.use_mkldnn, + 'dim': [2] + } + self.outputs = {'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))} + + +@skip_check_grad_ci( + reason="reduce_max is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceMaxONEDNNOp(TestReduceSumDefaultONEDNNOp): + """Remove Max with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_max" + self.use_mkldnn = True + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} + self.attrs = { + 'dim': [-1], + 'use_mkldnn' : self.use_mkldnn + } + self.outputs = { + 'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim'])) + } + + def test_check_output(self): + self.check_output() + +@skip_check_grad_ci( + reason="not implemented") +class TestReduceSumToScalarONEDNNOp(TestReduceSumDefaultONEDNNOp): + def setUp(self): + self.op_type = "reduce_sum" + self.use_mkldnn = True + self.inputs = {'X': 
np.random.random((5, 6, 2, 10)).astype("float32")} + self.attrs = { + 'reduce_all': True, + 'use_mkldnn': self.use_mkldnn + } + self.outputs = {'Out': self.inputs['X'].sum()} + + +@skip_check_grad_ci( + reason="reduce_min is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceMinONEDNNOp(TestReduceSumDefaultONEDNNOp): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_min" + self.use_mkldnn = True + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} + self.attrs = { + 'dim': [2], + 'use_mkldnn': self.use_mkldnn + } + self.outputs = { + 'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim'])) + } + + +@skip_check_grad_ci( + reason="reduce_min is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceMeanONEDNNOp(TestReduceSumDefaultONEDNNOp): + def setUp(self): + self.op_type = "reduce_mean" + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} + self.outputs = {'Out': self.inputs['X'].sum(axis=0) / self.inputs['X'].shape[0]} + + +@skip_check_grad_ci( + reason="reduce_min is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceSumKeepDimsONEDNNOp(TestReduceSumDefaultONEDNNOp): + def setUp(self): + self.op_type = "reduce_sum" + self.use_mkldnn = True + self.inputs = { + 'X': np.random.random((2, 5, 3, 2, 2)).astype("float32") + } + self.attrs = { + 'dim': (2, 3, 4), + 'keep_dim': True, + 'use_mkldnn': True + } + self.outputs = { + 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']), + keepdims=self.attrs['keep_dim']) + } + + +if __name__ == '__main__': + import paddle + paddle.enable_static() + unittest.main() diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index dc537cb2684bb..c20560b4be525 100644 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -420,6 +420,8 @@ 'test_reader_reset', 'test_recurrent_op', 'test_reduce_op', + 'test_reduce_mkldnn_op', + 'test_reduce_bf16_mkldnn_op', 'test_ref_by_trainer_id_op', 'test_registry', 'test_regularizer', From 4147b25cee16443e9b811eafc1656df3fc4f995e Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Tue, 23 Mar 2021 09:13:12 +0100 Subject: [PATCH 03/37] refactored reduce op --- .../mkldnn/elementwise_add_mkldnn_op.cc | 2 +- .../mkldnn/elementwise_mkldnn_op.h | 12 + .../mkldnn/elementwise_mul_mkldnn_op.cc | 2 +- .../reduce_ops/mkldnn/reduce_max_mkldnn_op.cc | 10 - .../mkldnn/reduce_mean_mkldnn_op.cc | 13 +- .../reduce_ops/mkldnn/reduce_min_mkldnn_op.cc | 13 +- .../reduce_ops/mkldnn/reduce_mkldnn_op.h | 38 +-- .../reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc | 22 -- paddle/fluid/platform/mkldnn_reuse.h | 33 +-- .../mkldnn/test_reduce_bf16_mkldnn_op.py | 272 +++++++++++------- .../unittests/mkldnn/test_reduce_mkldnn_op.py | 135 ++++++--- 11 files changed, 315 insertions(+), 237 deletions(-) diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc index b43dddfcf19db..e3e256812203a 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc @@ -86,7 +86,7 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel { platform::ReductionMKLDNNHandler handler_sum( 
dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine, ctx.GetPlace(), dout, dy, - ctx.InputName(framework::GradVarName("Out"))); + ctx.InputName(framework::GradVarName("Out")), CalculateBroadcastedDims(dout, dy)); auto dy_memory_p = handler_sum.AcquireDstMemory(dy); auto reduction_p = handler_sum.AcquireForwardPrimitive(); reduction_p->execute(astream, {{DNNL_ARG_SRC, *reorder_src_memory_p}, diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h index df827117a0d30..247deb46bd4b9 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h @@ -81,5 +81,17 @@ class EltwiseMKLDNNKernel : public framework::OpKernel { z->set_format(platform::GetMKLDNNFormat(*dst_memory)); } }; + + inline std::vector CalculateBroadcastedDims(const Tensor* x, const Tensor* y){ + const auto src_tz = framework::vectorize(x->dims()); + const auto dst_tz = framework::vectorize(y->dims()); + + int j = 0; + std::vector dst_tz_ex(src_tz.size(), 1); + for (size_t i = 0; i < src_tz.size(); ++i) + dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++]; + + return dst_tz_ex; + } } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc index c9209cc39d5e3..bb1231111e2c9 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc @@ -105,7 +105,7 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel { platform::ReductionMKLDNNHandler handler_sum( dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, mkldnn_engine, ctx.GetPlace(), dout, dy, - ctx.InputName(framework::GradVarName("Out"))); + ctx.InputName(framework::GradVarName("Out")), CalculateBroadcastedDims(dout, dy)); auto dy_memory_p = handler_sum.AcquireDstMemory(dy); auto reduction_p = handler_sum.AcquireForwardPrimitive(); // As source we use mem object with results from binary operation diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_max_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_max_mkldnn_op.cc index a380604e1f866..a24caa718f28f 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_max_mkldnn_op.cc +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_max_mkldnn_op.cc @@ -12,21 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/fluid/platform/mkldnn_reuse.h" #include "paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h" namespace paddle { namespace operators { -using paddle::framework::LoDTensor; -using paddle::framework::Tensor; -using paddle::platform::CPUDeviceContext; -using paddle::platform::CreateKey; -using paddle::platform::MKLDNNGetDataType; -using paddle::platform::MKLDNNMemDesc; -using platform::to_void_cast; - - template class ReduceMaxMKLDNNKernel : public ReduceMKLDNNKernel { public: diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc index c0c4fc7f70032..903bf5f24a43d 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc @@ -12,21 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/platform/mkldnn_reuse.h" #include "paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h" namespace paddle { namespace operators { -using paddle::framework::LoDTensor; -using paddle::framework::Tensor; -using paddle::platform::CPUDeviceContext; -using paddle::platform::CreateKey; -using paddle::platform::MKLDNNGetDataType; -using paddle::platform::MKLDNNMemDesc; -using platform::to_void_cast; - - template class ReduceMeanMKLDNNKernel : public ReduceMKLDNNKernel { public: @@ -40,4 +30,5 @@ class ReduceMeanMKLDNNKernel : public ReduceMKLDNNKernel { namespace ops = paddle::operators; REGISTER_OP_KERNEL(reduce_mean, MKLDNN, paddle::platform::CPUPlace, - ops::ReduceMeanMKLDNNKernel); + ops::ReduceMeanMKLDNNKernel, + ops::ReduceMeanMKLDNNKernel); diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_min_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_min_mkldnn_op.cc index 398760fd27854..f39938893b654 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_min_mkldnn_op.cc +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_min_mkldnn_op.cc @@ -12,21 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/platform/mkldnn_reuse.h" #include "paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h" namespace paddle { namespace operators { -using paddle::framework::LoDTensor; -using paddle::framework::Tensor; -using paddle::platform::CPUDeviceContext; -using paddle::platform::CreateKey; -using paddle::platform::MKLDNNGetDataType; -using paddle::platform::MKLDNNMemDesc; -using platform::to_void_cast; - - template class ReduceMinMKLDNNKernel : public ReduceMKLDNNKernel { public: @@ -40,4 +30,5 @@ class ReduceMinMKLDNNKernel : public ReduceMKLDNNKernel { namespace ops = paddle::operators; REGISTER_OP_KERNEL(reduce_min, MKLDNN, paddle::platform::CPUPlace, - ops::ReduceMinMKLDNNKernel); + ops::ReduceMinMKLDNNKernel, + ops::ReduceMinMKLDNNKernel); diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h index f1d4273144c2a..62d4abe759ebb 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -14,16 +14,11 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/mkldnn_reuse.h" - namespace paddle { namespace operators { using paddle::framework::LoDTensor; using paddle::framework::Tensor; -using paddle::platform::CPUDeviceContext; -using paddle::platform::CreateKey; -using paddle::platform::MKLDNNGetDataType; -using paddle::platform::MKLDNNMemDesc; using platform::to_void_cast; @@ -38,23 +33,16 @@ class ReduceMKLDNNKernel : public framework::OpKernel { const auto* input = ctx.Input("X"); auto* output = ctx.Output("Out"); - int out_dtype = ctx.Attr("out_dtype"); - int in_dtype = ctx.Attr("in_dtype"); - - auto x_dims = input->dims(); - auto x_rank = x_dims.size(); - - auto dims = ctx.Attr>("dim"); // dims to reduce + auto reduce_dims = ctx.Attr>("dim"); bool reduce_all = ctx.Attr("reduce_all"); - //bool keep_dim = ctx.Attr("keep_dim"); - // Change data formats - + bool keep_dim = ctx.Attr("keep_dim"); + std::vector output_dims = CalculateOutputDims(input, output, reduce_dims, reduce_all, keep_dim); platform::ReductionMKLDNNHandler handler( reduction_type, 0.0f, 0.0f, dev_ctx, onednn_engine, ctx.GetPlace(), input, output, - ctx.InputName("X"), dims); + ctx.InputName("X"), output_dims); auto src_memory_p = handler.AcquireSrcMemory(input); auto dst_memory_p = handler.AcquireDstMemory(output); @@ -69,13 +57,29 @@ class ReduceMKLDNNKernel : public framework::OpKernel { reduction_p->execute(astream, reduction_args); astream.wait(); + output->set_layout(framework::DataLayout::kMKLDNN); output->set_format( platform::GetMKLDNNFormat(dst_memory_p->get_desc().reshape( paddle::framework::vectorize(output->dims())))); } - +private: + std::vector CalculateOutputDims(const Tensor* input, const Tensor* output, std::vector& reduce_dims, bool reduce_all, bool keep_dim) const{ + if(keep_dim) + return framework::vectorize(output->dims()); + + if(reduce_all) + return std::vector (framework::vectorize(input->dims()).size(), 1); + + std::vector output_dims(framework::vectorize(input->dims())); + for(size_t i = 0; i < reduce_dims.size() ; ++i){ + reduce_dims[i] = (reduce_dims[i] >= 0) ? reduce_dims[i] : input->dims().size() + reduce_dims[i]; // dims can be counted backwards, "-1" = last dimension + output_dims[reduce_dims[i]] = 1; + } + + return output_dims; + } }; } // namespace operators diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc index 2851f2bb6e209..30823231b0fa6 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc @@ -12,21 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/fluid/platform/mkldnn_reuse.h" #include "paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h" namespace paddle { namespace operators { -using paddle::framework::LoDTensor; -using paddle::framework::Tensor; -using paddle::platform::CPUDeviceContext; -using paddle::platform::CreateKey; -using paddle::platform::MKLDNNGetDataType; -using paddle::platform::MKLDNNMemDesc; -using platform::to_void_cast; - - template class ReduceSumMKLDNNKernel : public ReduceMKLDNNKernel { public: @@ -35,15 +25,6 @@ class ReduceSumMKLDNNKernel : public ReduceMKLDNNKernel { } }; - -template -class ReduceSumGradMKLDNNKernel : public ReduceMKLDNNKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - this->RunKernel(ctx, dnnl::algorithm::reduction_sum); - } -}; - } // namespace operators } // namespace paddle @@ -51,6 +32,3 @@ namespace ops = paddle::operators; REGISTER_OP_KERNEL(reduce_sum, MKLDNN, paddle::platform::CPUPlace, ops::ReduceSumMKLDNNKernel, ops::ReduceSumMKLDNNKernel); - -REGISTER_OP_KERNEL(reduce_sum_grad, MKLDNN, paddle::platform::CPUPlace, - ops::ReduceSumGradMKLDNNKernel); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index b880cf8f1cb8c..51864753d7d1d 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -638,7 +638,7 @@ class ReductionMKLDNNHandler const float eps, const MKLDNNDeviceContext& dev_ctx, const mkldnn::engine engine, platform::Place cpu_place, const Tensor* x, const Tensor* y, - const std::string& uniq_name, std::vector dims = {}) + const std::string& uniq_name, std::vector output_dims) : platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), @@ -653,43 +653,16 @@ class ReductionMKLDNNHandler platform::errors::InvalidArgument("Wrong format set for X tensor.")); const auto src_tz = framework::vectorize(x->dims()); - const auto dst_tz = framework::vectorize(y->dims()); - - std::vector dst_tz_ex; - - if(dims.empty()) { - // For oneDNN dimensionality should match so we need to - // extend Y tensor dims with values of 1 (before and after pattern) - int j = 0; - - for (size_t i = 0; i < src_tz.size(); ++i) { - dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++]; - } - } else { - if (dst_tz.size() == 1) { // reduce_all - for(size_t j = 0 ; j < src_tz.size() ; ++j) { - dst_tz_ex.push_back(1); - } - } else { - for(auto &elem : src_tz) { - dst_tz_ex.push_back(elem); - } - - for(size_t i = 0; i < dims.size(); ++i) { - dims[i] = (dims[i] >= 0) ? 
dims[i] : src_tz.size() + dims[i]; // because dims can be counted backwards, "-1" = last dimension - dst_tz_ex[dims[i]] = 1; - } - } - } const auto src_md = dnnl::memory::desc( src_tz, platform::MKLDNNGetDataType(), x->format()); const auto dst_md = memory::desc( - dst_tz_ex, platform::MKLDNNGetDataType(), x->format()); + output_dims, platform::MKLDNNGetDataType(), x->format()); this->AcquireForwardPrimitiveDescriptor(algo, src_md, dst_md, p, eps); } } +}; template class ActivationMKLDNNHandler diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py index 4efe4a6b2e4cb..e0287e08ea479 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py @@ -42,104 +42,182 @@ def test_check_output(self): self.check_output(check_dygraph=False) -#@skip_check_grad_ci( -# reason="not implemented") -#class TestReduceSumONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): -# def setUp(self): -# self.op_type = "reduce_sum" -# self.use_mkldnn = True -# self.inputs = {'X': np.random.random((5, 10, 5, 5)).astype("float32")} -# self.attrs = { -# 'use_mkldnn': self.use_mkldnn, -# 'dim': [2] -# } -# self.outputs = {'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))} -# -# -#@skip_check_grad_ci( -# reason="reduce_max is discontinuous non-derivable function," -# " its gradient check is not supported by unittest framework.") -#class TestReduceMaxONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): -# """Remove Max with subgradient from gradient check to confirm the success of CI.""" -# -# def setUp(self): -# self.op_type = "reduce_max" -# self.use_mkldnn = True -# self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} -# self.attrs = { -# 'dim': [-1], -# 'use_mkldnn' : self.use_mkldnn -# } -# self.outputs = { -# 'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim'])) -# } -# -# def test_check_output(self): -# self.check_output() -# -#@skip_check_grad_ci( -# reason="not implemented") -#class TestReduceSumToScalarONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): -# def setUp(self): -# self.op_type = "reduce_sum" -# self.use_mkldnn = True -# self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float32")} -# self.attrs = { -# 'reduce_all': True, -# 'use_mkldnn': self.use_mkldnn -# } -# self.outputs = {'Out': self.inputs['X'].sum()} -# -# -#@skip_check_grad_ci( -# reason="reduce_min is discontinuous non-derivable function," -# " its gradient check is not supported by unittest framework.") -#class TestReduceMinONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): -# """Remove Min with subgradient from gradient check to confirm the success of CI.""" -# -# def setUp(self): -# self.op_type = "reduce_min" -# self.use_mkldnn = True -# self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} -# self.attrs = { -# 'dim': [2], -# 'use_mkldnn': self.use_mkldnn -# } -# self.outputs = { -# 'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim'])) -# } -# -# -#@skip_check_grad_ci( -# reason="reduce_min is discontinuous non-derivable function," -# " its gradient check is not supported by unittest framework.") -#class TestReduceMeanONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): -# def setUp(self): -# self.op_type = "reduce_mean" -# self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} -# self.outputs = {'Out': self.inputs['X'].sum(axis=0) / self.inputs['X'].shape[0]} -# -# -#@skip_check_grad_ci( -# 
reason="reduce_min is discontinuous non-derivable function," -# " its gradient check is not supported by unittest framework.") -#class TestReduceSumKeepDimsONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): -# def setUp(self): -# self.op_type = "reduce_sum" -# self.use_mkldnn = True -# self.inputs = { -# 'X': np.random.random((2, 5, 3, 2, 2)).astype("float32") -# } -# self.attrs = { -# 'dim': (2, 3, 4), -# 'keep_dim': True, -# 'use_mkldnn': True -# } -# self.outputs = { -# 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']), -# keepdims=self.attrs['keep_dim']) -# } +@skip_check_grad_ci( + reason="not implemented") +class TestReduceSum4DONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): + def setUp(self): + self.op_type = "reduce_sum" + self.use_mkldnn = True + x_fp32 = np.random.random((5, 10, 5, 5)).astype("float32") + x_bf16 = convert_float_to_uint16(x_fp32) + self.inputs = {'X': x_bf16} + self.attrs = { + 'use_mkldnn': self.use_mkldnn, + 'dim': [2] + } + self.outputs = {'Out': x_fp32.sum(axis=tuple(self.attrs['dim']))} + + +@skip_check_grad_ci( + reason="not implemented") +class TestReduceSum4DReduceAllWithoutReduceAllAttributeONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): + def setUp(self): + self.op_type = "reduce_sum" + self.use_mkldnn = True + x_fp32 = np.random.normal(size=(2, 3, 5, 6)).astype('float32') + x_bf16 = convert_float_to_uint16(x_fp32) + self.inputs = {'X': x_bf16} + self.attrs = { + 'use_mkldnn': self.use_mkldnn, + 'dim': [0, 1, 2, 3] + } + self.outputs = {'Out': x_fp32.sum(axis=tuple(self.attrs['dim']))} + + +@skip_check_grad_ci( + reason="not implemented") +class TestReduceSum4DReduceAllWithoutReduceAllAttributeNegativeDimsONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): + def setUp(self): + self.op_type = "reduce_sum" + self.use_mkldnn = True + x_fp32 = np.random.normal(size=(2, 7, 3, 5)).astype('float32') + x_bf16 = convert_float_to_uint16(x_fp32) + self.inputs = {'X': x_bf16} + self.attrs = { + 'use_mkldnn': self.use_mkldnn, + 'dim': [-1, -2, -3, -4] + } + self.outputs = {'Out': x_fp32.sum(axis=tuple(self.attrs['dim']))} + + +@skip_check_grad_ci( + reason="not implemented") +class TestReduceSum5DKeepDimsONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): + def setUp(self): + self.op_type = "reduce_sum" + self.use_mkldnn = True + x_fp32 = np.random.random((2, 5, 3, 2, 2)).astype("float32") + x_bf16 = convert_float_to_uint16(x_fp32) + self.inputs = {'X': x_bf16} + self.attrs = { + 'dim': (2, 3, 4), + 'keep_dim': True, + 'use_mkldnn': True + } + self.outputs = { + 'Out': x_fp32.sum(axis=tuple(self.attrs['dim']), + keepdims=self.attrs['keep_dim']) + } + + +@skip_check_grad_ci( + reason="not implemented") +class TestReduceSum5DReduceAllKeepDimsONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): + def setUp(self): + self.op_type = "reduce_sum" + self.use_mkldnn = True + x_fp32 = np.random.normal(size=(2, 5, 3, 2, 4)).astype('float32') + x_bf16 = convert_float_to_uint16(x_fp32) + self.inputs = {'X': x_bf16} + self.attrs = { + 'reduce_all': True, + 'keep_dim': True, + 'use_mkldnn': True + } + self.outputs = { + 'Out': x_fp32.sum(keepdims=self.attrs['keep_dim']) + } + + +@skip_check_grad_ci( + reason="not implemented") +class TestReduceSum4DReduceAllONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): + def setUp(self): + self.op_type = "reduce_sum" + self.use_mkldnn = True + x_fp32 = np.random.normal(size=(4, 3, 2, 3)).astype('float32') + x_bf16 = convert_float_to_uint16(x_fp32) + self.inputs = {'X': x_bf16} + self.attrs = { + 'reduce_all': True, + 'use_mkldnn': self.use_mkldnn + } + self.outputs = {'Out': 
x_fp32.sum()} + + +@skip_check_grad_ci( + reason="reduce_max is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceMax3DONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): + """Remove Max with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_max" + self.use_mkldnn = True + x_fp32 = np.random.random((5, 6, 10)).astype("float32") + x_bf16 = convert_float_to_uint16(x_fp32) + self.inputs = {'X': x_bf16} + self.attrs = { + 'dim': [-1], + 'use_mkldnn' : self.use_mkldnn + } + self.outputs = { + 'Out': x_fp32.max(axis=tuple(self.attrs['dim'])) + } + +@skip_check_grad_ci( + reason="reduce_max is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceMax4DNegativeAndPositiveDimsONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): + """Remove Max with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_max" + self.use_mkldnn = True + x_fp32 = np.random.random((5, 6, 10, 9)).astype("float32") + x_bf16 = convert_float_to_uint16(x_fp32) + self.inputs = {'X': x_bf16} + self.attrs = { + 'dim': [-1, 0, 1], + 'use_mkldnn' : self.use_mkldnn + } + self.outputs = { + 'Out': x_fp32.max(axis=tuple(self.attrs['dim'])) + } + +@skip_check_grad_ci( + reason="reduce_min is discontinuous non-derivable function," + " its gradient check is not supported by unittest framework.") +class TestReduceMin3DONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + + def setUp(self): + self.op_type = "reduce_min" + self.use_mkldnn = True + x_fp32 = np.random.random((5, 6, 10)).astype("float32") + x_bf16 = convert_float_to_uint16(x_fp32) + self.inputs = {'X': x_bf16} + self.attrs = { + 'dim': [2], + 'use_mkldnn': self.use_mkldnn + } + self.outputs = { + 'Out': x_fp32.min(axis=tuple(self.attrs['dim'])) + } + + +@skip_check_grad_ci( + reason="not implemented") +class TestReduceMean3DONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): + def setUp(self): + self.op_type = "reduce_mean" + self.use_mkldnn = True + x_fp32 = np.random.random((5, 6, 10)).astype("float32") + x_bf16 = convert_float_to_uint16(x_fp32) + self.inputs = {'X': x_bf16} + self.attrs = { 'use_mkldnn' : self.use_mkldnn } + self.outputs = {'Out': x_fp32.sum(axis=0) / x_fp32.shape[0]} if __name__ == '__main__': import paddle diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py index 431e37570022c..83a10960b1fab 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py @@ -42,7 +42,7 @@ def test_check_output(self): @skip_check_grad_ci( reason="not implemented") -class TestReduceSumONEDNNOp(TestReduceSumDefaultONEDNNOp): +class TestReduceSum4DONEDNNOp(TestReduceSumDefaultONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -55,29 +55,75 @@ def setUp(self): @skip_check_grad_ci( - reason="reduce_max is discontinuous non-derivable function," - " its gradient check is not supported by unittest framework.") -class TestReduceMaxONEDNNOp(TestReduceSumDefaultONEDNNOp): - """Remove Max with subgradient from gradient check to confirm the success of CI.""" + reason="not implemented") +class 
TestReduceSum4DReduceAllWithoutReduceAllAttributeONEDNNOp(TestReduceSumDefaultONEDNNOp): + def setUp(self): + self.op_type = "reduce_sum" + self.use_mkldnn = True + self.inputs = {'X': np.random.random((5, 10, 5, 5)).astype("float32")} + self.attrs = { + 'use_mkldnn': self.use_mkldnn, + 'dim': [0, 1, 2, 3] + } + self.outputs = {'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))} + +@skip_check_grad_ci( + reason="not implemented") +class TestReduceSum4DReduceAllWithoutReduceAllAttributeNegativeDimsONEDNNOp(TestReduceSumDefaultONEDNNOp): def setUp(self): - self.op_type = "reduce_max" + self.op_type = "reduce_sum" self.use_mkldnn = True - self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} + self.inputs = {'X': np.random.random((5, 10, 5, 5)).astype("float32")} self.attrs = { - 'dim': [-1], - 'use_mkldnn' : self.use_mkldnn + 'use_mkldnn': self.use_mkldnn, + 'dim': [-1, -2, -3, -4] } + self.outputs = {'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))} + + +@skip_check_grad_ci( + reason="not implemented") +class TestReduceSum5DKeepDimsONEDNNOp(TestReduceSumDefaultONEDNNOp): + def setUp(self): + self.op_type = "reduce_sum" + self.use_mkldnn = True + self.inputs = { + 'X': np.random.random((2, 5, 3, 2, 2)).astype("float32") + } + self.attrs = { + 'dim': (2, 3, 4), + 'keep_dim': True, + 'use_mkldnn': True + } self.outputs = { - 'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim'])) + 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']), + keepdims=self.attrs['keep_dim']) } - def test_check_output(self): - self.check_output() @skip_check_grad_ci( reason="not implemented") -class TestReduceSumToScalarONEDNNOp(TestReduceSumDefaultONEDNNOp): +class TestReduceSum5DReduceAllKeepDimsONEDNNOp(TestReduceSumDefaultONEDNNOp): + def setUp(self): + self.op_type = "reduce_sum" + self.use_mkldnn = True + self.inputs = { + 'X': np.random.random((2, 5, 3, 2, 2)).astype("float32") + } + self.attrs = { + 'reduce_all': True, + 'keep_dim': True, + 'use_mkldnn': True + } + self.outputs = { + 'Out': self.inputs['X'].sum(keepdims=self.attrs['keep_dim']) + } + + +@skip_check_grad_ci( + reason="not implemented") +class TestReduceSum4DReduceAllONEDNNOp(TestReduceSumDefaultONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -90,55 +136,70 @@ def setUp(self): @skip_check_grad_ci( - reason="reduce_min is discontinuous non-derivable function," + reason="reduce_max is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") -class TestReduceMinONEDNNOp(TestReduceSumDefaultONEDNNOp): - """Remove Min with subgradient from gradient check to confirm the success of CI.""" +class TestReduceMax3DONEDNNOp(TestReduceSumDefaultONEDNNOp): + """Remove Max with subgradient from gradient check to confirm the success of CI.""" def setUp(self): - self.op_type = "reduce_min" + self.op_type = "reduce_max" self.use_mkldnn = True self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} self.attrs = { - 'dim': [2], - 'use_mkldnn': self.use_mkldnn + 'dim': [-1], + 'use_mkldnn' : self.use_mkldnn } self.outputs = { - 'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim'])) + 'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim'])) } @skip_check_grad_ci( - reason="reduce_min is discontinuous non-derivable function," + reason="reduce_max is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") -class TestReduceMeanONEDNNOp(TestReduceSumDefaultONEDNNOp): - def setUp(self): - 
self.op_type = "reduce_mean" - self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} - self.outputs = {'Out': self.inputs['X'].sum(axis=0) / self.inputs['X'].shape[0]} +class TestReduceMax4DNegativeAndPositiveDimsONEDNNOp(TestReduceSumDefaultONEDNNOp): + """Remove Max with subgradient from gradient check to confirm the success of CI.""" + def setUp(self): + self.op_type = "reduce_max" + self.use_mkldnn = True + self.inputs = {'X': np.random.random((5, 6, 10, 9)).astype("float32")} + self.attrs = { + 'dim': [-1, 0, 1], + 'use_mkldnn' : self.use_mkldnn + } + self.outputs = { + 'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim'])) + } @skip_check_grad_ci( reason="reduce_min is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") -class TestReduceSumKeepDimsONEDNNOp(TestReduceSumDefaultONEDNNOp): +class TestReduceMin3DONEDNNOp(TestReduceSumDefaultONEDNNOp): + """Remove Min with subgradient from gradient check to confirm the success of CI.""" + def setUp(self): - self.op_type = "reduce_sum" + self.op_type = "reduce_min" self.use_mkldnn = True - self.inputs = { - 'X': np.random.random((2, 5, 3, 2, 2)).astype("float32") - } + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} self.attrs = { - 'dim': (2, 3, 4), - 'keep_dim': True, - 'use_mkldnn': True - } + 'dim': [2], + 'use_mkldnn': self.use_mkldnn + } self.outputs = { - 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']), - keepdims=self.attrs['keep_dim']) + 'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim'])) } +@skip_check_grad_ci( + reason="not implemented") +class TestReduceMean3DONEDNNOp(TestReduceSumDefaultONEDNNOp): + def setUp(self): + self.op_type = "reduce_mean" + self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} + self.outputs = {'Out': self.inputs['X'].sum(axis=0) / self.inputs['X'].shape[0]} + + if __name__ == '__main__': import paddle paddle.enable_static() From 726846f2fb846811d525689f3c11076027a64828 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Tue, 23 Mar 2021 09:17:32 +0100 Subject: [PATCH 04/37] reverted old file --- paddle/fluid/platform/profiler_helper.h | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/paddle/fluid/platform/profiler_helper.h b/paddle/fluid/platform/profiler_helper.h index cc69438050e46..ae4d75113cd06 100644 --- a/paddle/fluid/platform/profiler_helper.h +++ b/paddle/fluid/platform/profiler_helper.h @@ -730,7 +730,7 @@ void AnalyzeEvent( } } for (size_t j = 0; j < table_size; ++j) { - if (child_index[j] == 0) { // pushes and counts only parents, ensures that time will not be counted twice + if (child_index[j] == 0) { main_event_items.push_back(event_items[j]); total += event_items[j].total_time; } else if ((child_index[j] == 1 && @@ -746,13 +746,6 @@ void AnalyzeEvent( std::pair(fname, event_items[j])); } } - else if (child_index[j] == 1 && event_items[j].name.find("reorder") != std::string::npos){ - size_t first_slash_pos = event_items[j].name.find('/'); - if(first_slash_pos != std::string::npos){ - std::string fname = event_items[j].name.substr(0, first_slash_pos); - child_map->insert(std::pair(fname, event_items[j])); - } - } } // average time for (auto &item : main_event_items) { From 6763404ad814ac6c42a62147ee591f3a59235dcb Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Tue, 23 Mar 2021 09:34:00 +0100 Subject: [PATCH 05/37] added clang formatting --- .../mkldnn/elementwise_add_mkldnn_op.cc | 3 +- .../mkldnn/elementwise_mkldnn_op.h | 23 +-- 
.../mkldnn/elementwise_mul_mkldnn_op.cc | 3 +- .../reduce_ops/mkldnn/reduce_max_mkldnn_op.cc | 2 +- .../mkldnn/reduce_mean_mkldnn_op.cc | 2 +- .../reduce_ops/mkldnn/reduce_min_mkldnn_op.cc | 2 +- .../reduce_ops/mkldnn/reduce_mkldnn_op.h | 47 +++--- .../reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc | 2 +- paddle/fluid/operators/reduce_ops/reduce_op.h | 8 +- paddle/fluid/platform/mkldnn_reuse.h | 3 +- .../mkldnn/test_reduce_bf16_mkldnn_op.py | 141 +++++++----------- .../unittests/mkldnn/test_reduce_mkldnn_op.py | 103 +++++-------- 12 files changed, 138 insertions(+), 201 deletions(-) diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc index e3e256812203a..8f519de075760 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc @@ -86,7 +86,8 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel { platform::ReductionMKLDNNHandler handler_sum( dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine, ctx.GetPlace(), dout, dy, - ctx.InputName(framework::GradVarName("Out")), CalculateBroadcastedDims(dout, dy)); + ctx.InputName(framework::GradVarName("Out")), + CalculateBroadcastedDims(dout, dy)); auto dy_memory_p = handler_sum.AcquireDstMemory(dy); auto reduction_p = handler_sum.AcquireForwardPrimitive(); reduction_p->execute(astream, {{DNNL_ARG_SRC, *reorder_src_memory_p}, diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h index 247deb46bd4b9..0e35c0db04588 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h @@ -82,16 +82,17 @@ class EltwiseMKLDNNKernel : public framework::OpKernel { } }; - inline std::vector CalculateBroadcastedDims(const Tensor* x, const Tensor* y){ - const auto src_tz = framework::vectorize(x->dims()); - const auto dst_tz = framework::vectorize(y->dims()); - - int j = 0; - std::vector dst_tz_ex(src_tz.size(), 1); - for (size_t i = 0; i < src_tz.size(); ++i) - dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++]; - - return dst_tz_ex; - } +inline std::vector CalculateBroadcastedDims(const Tensor* x, + const Tensor* y) { + const auto src_tz = framework::vectorize(x->dims()); + const auto dst_tz = framework::vectorize(y->dims()); + + int j = 0; + std::vector dst_tz_ex(src_tz.size(), 1); + for (size_t i = 0; i < src_tz.size(); ++i) + dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 
1 : dst_tz[j++]; + + return dst_tz_ex; +} } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc index bb1231111e2c9..1c246e8d18937 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc @@ -105,7 +105,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel { platform::ReductionMKLDNNHandler handler_sum( dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, mkldnn_engine, ctx.GetPlace(), dout, dy, - ctx.InputName(framework::GradVarName("Out")), CalculateBroadcastedDims(dout, dy)); + ctx.InputName(framework::GradVarName("Out")), + CalculateBroadcastedDims(dout, dy)); auto dy_memory_p = handler_sum.AcquireDstMemory(dy); auto reduction_p = handler_sum.AcquireForwardPrimitive(); // As source we use mem object with results from binary operation diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_max_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_max_mkldnn_op.cc index a24caa718f28f..b18c16c8c71f7 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_max_mkldnn_op.cc +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_max_mkldnn_op.cc @@ -21,7 +21,7 @@ template class ReduceMaxMKLDNNKernel : public ReduceMKLDNNKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - this->RunKernel(ctx, dnnl::algorithm::reduction_max); + this->RunKernel(ctx, dnnl::algorithm::reduction_max); } }; diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc index 903bf5f24a43d..a9eed0d7eb042 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc @@ -21,7 +21,7 @@ template class ReduceMeanMKLDNNKernel : public ReduceMKLDNNKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - this->RunKernel(ctx, dnnl::algorithm::reduction_mean); + this->RunKernel(ctx, dnnl::algorithm::reduction_mean); } }; diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_min_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_min_mkldnn_op.cc index f39938893b654..ce63a1485471f 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_min_mkldnn_op.cc +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_min_mkldnn_op.cc @@ -21,7 +21,7 @@ template class ReduceMinMKLDNNKernel : public ReduceMKLDNNKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - this->RunKernel(ctx, dnnl::algorithm::reduction_min); + this->RunKernel(ctx, dnnl::algorithm::reduction_min); } }; diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h index 62d4abe759ebb..a7e6b8c85685d 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -21,11 +21,11 @@ using paddle::framework::LoDTensor; using paddle::framework::Tensor; using platform::to_void_cast; - template class ReduceMKLDNNKernel : public framework::OpKernel { public: - void RunKernel(const framework::ExecutionContext& ctx, dnnl::algorithm reduction_type) const { + void RunKernel(const framework::ExecutionContext& ctx, + dnnl::algorithm reduction_type) const { auto& 
dev_ctx = ctx.template device_context(); const auto& onednn_engine = dev_ctx.GetEngine(); @@ -37,19 +37,18 @@ class ReduceMKLDNNKernel : public framework::OpKernel { bool reduce_all = ctx.Attr("reduce_all"); bool keep_dim = ctx.Attr("keep_dim"); - std::vector output_dims = CalculateOutputDims(input, output, reduce_dims, reduce_all, keep_dim); + std::vector output_dims = + CalculateOutputDims(input, output, reduce_dims, reduce_all, keep_dim); platform::ReductionMKLDNNHandler handler( - reduction_type, 0.0f, 0.0f, dev_ctx, onednn_engine, - ctx.GetPlace(), input, output, - ctx.InputName("X"), output_dims); + reduction_type, 0.0f, 0.0f, dev_ctx, onednn_engine, ctx.GetPlace(), + input, output, ctx.InputName("X"), output_dims); auto src_memory_p = handler.AcquireSrcMemory(input); auto dst_memory_p = handler.AcquireDstMemory(output); std::unordered_map reduction_args = { - {DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_DST, *dst_memory_p}}; + {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; auto reduction_p = handler.AcquireForwardPrimitive(); @@ -57,25 +56,33 @@ class ReduceMKLDNNKernel : public framework::OpKernel { reduction_p->execute(astream, reduction_args); astream.wait(); - output->set_layout(framework::DataLayout::kMKLDNN); output->set_format( platform::GetMKLDNNFormat(dst_memory_p->get_desc().reshape( paddle::framework::vectorize(output->dims())))); } -private: - std::vector CalculateOutputDims(const Tensor* input, const Tensor* output, std::vector& reduce_dims, bool reduce_all, bool keep_dim) const{ - if(keep_dim) - return framework::vectorize(output->dims()); + private: + std::vector CalculateOutputDims(const Tensor* input, + const Tensor* output, + std::vector& reduce_dims, + bool reduce_all, + bool keep_dim) const { + if (keep_dim) return framework::vectorize(output->dims()); + + if (reduce_all) + return std::vector(framework::vectorize(input->dims()).size(), + 1); - if(reduce_all) - return std::vector (framework::vectorize(input->dims()).size(), 1); - std::vector output_dims(framework::vectorize(input->dims())); - for(size_t i = 0; i < reduce_dims.size() ; ++i){ - reduce_dims[i] = (reduce_dims[i] >= 0) ? reduce_dims[i] : input->dims().size() + reduce_dims[i]; // dims can be counted backwards, "-1" = last dimension - output_dims[reduce_dims[i]] = 1; + for (size_t i = 0; i < reduce_dims.size(); ++i) { + reduce_dims[i] = + (reduce_dims[i] >= 0) + ? 
reduce_dims[i] + : input->dims().size() + reduce_dims[i]; // dims can be counted + // backwards, "-1" = + // last dimension + output_dims[reduce_dims[i]] = 1; } return output_dims; @@ -84,5 +91,3 @@ class ReduceMKLDNNKernel : public framework::OpKernel { } // namespace operators } // namespace paddle - - diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc index 30823231b0fa6..4676589e68910 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc @@ -21,7 +21,7 @@ template class ReduceSumMKLDNNKernel : public ReduceMKLDNNKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - this->RunKernel(ctx, dnnl::algorithm::reduction_sum); + this->RunKernel(ctx, dnnl::algorithm::reduction_sum); } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index b3fddec07d552..ab1b8de7c00e6 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -499,9 +499,9 @@ class ReduceOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx, input_data_type)) { - return framework::OpKernelType(input_data_type, ctx.GetPlace(), - framework::DataLayout::kMKLDNN, - framework::LibraryType::kMKLDNN); + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif @@ -603,7 +603,7 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker { "The dtype of output, default value is -1, the dtype is same as intput") .SetDefault(-1); AddAttr("use_mkldnn", - "(bool, default false) Only used in mkldnn kernel") + "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); AddComment(string::Sprintf(R"DOC( %s Operator. 
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 51864753d7d1d..0c45da63edd70 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -638,7 +638,8 @@ class ReductionMKLDNNHandler const float eps, const MKLDNNDeviceContext& dev_ctx, const mkldnn::engine engine, platform::Place cpu_place, const Tensor* x, const Tensor* y, - const std::string& uniq_name, std::vector output_dims) + const std::string& uniq_name, + std::vector output_dims) : platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py index e0287e08ea479..844408516a547 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py @@ -24,8 +24,7 @@ from paddle.fluid.framework import convert_np_dtype_to_dtype_ -@skip_check_grad_ci( - reason="not implemented") +@skip_check_grad_ci(reason="not implemented") class TestReduceSumBF16DefaultONEDNNOp(OpTest): def setUp(self): self.op_type = "reduce_sum" @@ -34,191 +33,151 @@ def setUp(self): x_bf16 = convert_float_to_uint16(x_fp32) self.inputs = {'X': x_bf16} self.outputs = {'Out': x_fp32.sum(axis=0)} - self.attrs = { - 'use_mkldnn': self.use_mkldnn - } + self.attrs = {'use_mkldnn': self.use_mkldnn} def test_check_output(self): self.check_output(check_dygraph=False) -@skip_check_grad_ci( - reason="not implemented") -class TestReduceSum4DONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +@skip_check_grad_ci(reason="not implemented") +class TestReduceSum4DBF16ONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True x_fp32 = np.random.random((5, 10, 5, 5)).astype("float32") - x_bf16 = convert_float_to_uint16(x_fp32) + x_bf16 = convert_float_to_uint16(x_fp32) self.inputs = {'X': x_bf16} - self.attrs = { - 'use_mkldnn': self.use_mkldnn, - 'dim': [2] - } + self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [2]} self.outputs = {'Out': x_fp32.sum(axis=tuple(self.attrs['dim']))} -@skip_check_grad_ci( - reason="not implemented") -class TestReduceSum4DReduceAllWithoutReduceAllAttributeONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +@skip_check_grad_ci(reason="not implemented") +class TestReduceSum4DReduceAllWithoutReduceAllAttributeBF16ONEDNNOp( + TestReduceSumBF16DefaultONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True x_fp32 = np.random.normal(size=(2, 3, 5, 6)).astype('float32') - x_bf16 = convert_float_to_uint16(x_fp32) + x_bf16 = convert_float_to_uint16(x_fp32) self.inputs = {'X': x_bf16} - self.attrs = { - 'use_mkldnn': self.use_mkldnn, - 'dim': [0, 1, 2, 3] - } + self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [0, 1, 2, 3]} self.outputs = {'Out': x_fp32.sum(axis=tuple(self.attrs['dim']))} -@skip_check_grad_ci( - reason="not implemented") -class TestReduceSum4DReduceAllWithoutReduceAllAttributeNegativeDimsONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +@skip_check_grad_ci(reason="not implemented") +class TestReduceSum4DReduceAllWithoutReduceAllAttributeNegativeDimsBF16ONEDNNOp( + TestReduceSumBF16DefaultONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True x_fp32 = np.random.normal(size=(2, 7, 3, 5)).astype('float32') - x_bf16 = convert_float_to_uint16(x_fp32) + x_bf16 = 
convert_float_to_uint16(x_fp32) self.inputs = {'X': x_bf16} - self.attrs = { - 'use_mkldnn': self.use_mkldnn, - 'dim': [-1, -2, -3, -4] - } + self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [-1, -2, -3, -4]} self.outputs = {'Out': x_fp32.sum(axis=tuple(self.attrs['dim']))} -@skip_check_grad_ci( - reason="not implemented") -class TestReduceSum5DKeepDimsONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +@skip_check_grad_ci(reason="not implemented") +class TestReduceSum5DKeepDimsONEDNNOp(TestReduceSumBF16DefaultBF16ONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True x_fp32 = np.random.random((2, 5, 3, 2, 2)).astype("float32") - x_bf16 = convert_float_to_uint16(x_fp32) + x_bf16 = convert_float_to_uint16(x_fp32) self.inputs = {'X': x_bf16} - self.attrs = { - 'dim': (2, 3, 4), - 'keep_dim': True, - 'use_mkldnn': True - } + self.attrs = {'dim': (2, 3, 4), 'keep_dim': True, 'use_mkldnn': True} self.outputs = { 'Out': x_fp32.sum(axis=tuple(self.attrs['dim']), - keepdims=self.attrs['keep_dim']) + keepdims=self.attrs['keep_dim']) } -@skip_check_grad_ci( - reason="not implemented") -class TestReduceSum5DReduceAllKeepDimsONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +@skip_check_grad_ci(reason="not implemented") +class TestReduceSum5DReduceAllKeepDimsBF16ONEDNNOp( + TestReduceSumBF16DefaultONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True x_fp32 = np.random.normal(size=(2, 5, 3, 2, 4)).astype('float32') - x_bf16 = convert_float_to_uint16(x_fp32) + x_bf16 = convert_float_to_uint16(x_fp32) self.inputs = {'X': x_bf16} - self.attrs = { - 'reduce_all': True, - 'keep_dim': True, - 'use_mkldnn': True - } - self.outputs = { - 'Out': x_fp32.sum(keepdims=self.attrs['keep_dim']) - } + self.attrs = {'reduce_all': True, 'keep_dim': True, 'use_mkldnn': True} + self.outputs = {'Out': x_fp32.sum(keepdims=self.attrs['keep_dim'])} -@skip_check_grad_ci( - reason="not implemented") -class TestReduceSum4DReduceAllONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +@skip_check_grad_ci(reason="not implemented") +class TestReduceSum4DReduceAllBF16ONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True x_fp32 = np.random.normal(size=(4, 3, 2, 3)).astype('float32') - x_bf16 = convert_float_to_uint16(x_fp32) + x_bf16 = convert_float_to_uint16(x_fp32) self.inputs = {'X': x_bf16} - self.attrs = { - 'reduce_all': True, - 'use_mkldnn': self.use_mkldnn - } + self.attrs = {'reduce_all': True, 'use_mkldnn': self.use_mkldnn} self.outputs = {'Out': x_fp32.sum()} @skip_check_grad_ci( reason="reduce_max is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") -class TestReduceMax3DONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +class TestReduceMax3DBF16ONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): """Remove Max with subgradient from gradient check to confirm the success of CI.""" def setUp(self): self.op_type = "reduce_max" self.use_mkldnn = True x_fp32 = np.random.random((5, 6, 10)).astype("float32") - x_bf16 = convert_float_to_uint16(x_fp32) + x_bf16 = convert_float_to_uint16(x_fp32) self.inputs = {'X': x_bf16} - self.attrs = { - 'dim': [-1], - 'use_mkldnn' : self.use_mkldnn - } - self.outputs = { - 'Out': x_fp32.max(axis=tuple(self.attrs['dim'])) - } + self.attrs = {'dim': [-1], 'use_mkldnn': self.use_mkldnn} + self.outputs = {'Out': x_fp32.max(axis=tuple(self.attrs['dim']))} @skip_check_grad_ci( reason="reduce_max is discontinuous non-derivable function," " its gradient check is not supported 
by unittest framework.") -class TestReduceMax4DNegativeAndPositiveDimsONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +class TestReduceMax4DNegativeAndPositiveDimsBF16ONEDNNOp( + TestReduceSumBF16DefaultONEDNNOp): """Remove Max with subgradient from gradient check to confirm the success of CI.""" def setUp(self): self.op_type = "reduce_max" self.use_mkldnn = True x_fp32 = np.random.random((5, 6, 10, 9)).astype("float32") - x_bf16 = convert_float_to_uint16(x_fp32) + x_bf16 = convert_float_to_uint16(x_fp32) self.inputs = {'X': x_bf16} - self.attrs = { - 'dim': [-1, 0, 1], - 'use_mkldnn' : self.use_mkldnn - } - self.outputs = { - 'Out': x_fp32.max(axis=tuple(self.attrs['dim'])) - } + self.attrs = {'dim': [-1, 0, 1], 'use_mkldnn': self.use_mkldnn} + self.outputs = {'Out': x_fp32.max(axis=tuple(self.attrs['dim']))} + @skip_check_grad_ci( reason="reduce_min is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") -class TestReduceMin3DONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +class TestReduceMin3DBF16ONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): """Remove Min with subgradient from gradient check to confirm the success of CI.""" def setUp(self): self.op_type = "reduce_min" self.use_mkldnn = True x_fp32 = np.random.random((5, 6, 10)).astype("float32") - x_bf16 = convert_float_to_uint16(x_fp32) + x_bf16 = convert_float_to_uint16(x_fp32) self.inputs = {'X': x_bf16} - self.attrs = { - 'dim': [2], - 'use_mkldnn': self.use_mkldnn - } - self.outputs = { - 'Out': x_fp32.min(axis=tuple(self.attrs['dim'])) - } + self.attrs = {'dim': [2], 'use_mkldnn': self.use_mkldnn} + self.outputs = {'Out': x_fp32.min(axis=tuple(self.attrs['dim']))} -@skip_check_grad_ci( - reason="not implemented") -class TestReduceMean3DONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +@skip_check_grad_ci(reason="not implemented") +class TestReduceMean3DBF16ONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): def setUp(self): self.op_type = "reduce_mean" self.use_mkldnn = True x_fp32 = np.random.random((5, 6, 10)).astype("float32") - x_bf16 = convert_float_to_uint16(x_fp32) + x_bf16 = convert_float_to_uint16(x_fp32) self.inputs = {'X': x_bf16} - self.attrs = { 'use_mkldnn' : self.use_mkldnn } + self.attrs = {'use_mkldnn': self.use_mkldnn} self.outputs = {'Out': x_fp32.sum(axis=0) / x_fp32.shape[0]} + if __name__ == '__main__': import paddle paddle.enable_static() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py index 83a10960b1fab..c0cb593662828 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py @@ -24,114 +24,89 @@ from paddle.fluid.framework import convert_np_dtype_to_dtype_ -@skip_check_grad_ci( - reason="not implemented") +@skip_check_grad_ci(reason="not implemented") class TestReduceSumDefaultONEDNNOp(OpTest): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} self.outputs = {'Out': self.inputs['X'].sum(axis=0)} - self.attrs = { - 'use_mkldnn': self.use_mkldnn - } + self.attrs = {'use_mkldnn': self.use_mkldnn} def test_check_output(self): self.check_output() -@skip_check_grad_ci( - reason="not implemented") +@skip_check_grad_ci(reason="not implemented") class TestReduceSum4DONEDNNOp(TestReduceSumDefaultONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True self.inputs = {'X': 
np.random.random((5, 10, 5, 5)).astype("float32")} - self.attrs = { - 'use_mkldnn': self.use_mkldnn, - 'dim': [2] + self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [2]} + self.outputs = { + 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) } - self.outputs = {'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))} -@skip_check_grad_ci( - reason="not implemented") -class TestReduceSum4DReduceAllWithoutReduceAllAttributeONEDNNOp(TestReduceSumDefaultONEDNNOp): +@skip_check_grad_ci(reason="not implemented") +class TestReduceSum4DReduceAllWithoutReduceAllAttributeONEDNNOp( + TestReduceSumDefaultONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True self.inputs = {'X': np.random.random((5, 10, 5, 5)).astype("float32")} - self.attrs = { - 'use_mkldnn': self.use_mkldnn, - 'dim': [0, 1, 2, 3] + self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [0, 1, 2, 3]} + self.outputs = { + 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) } - self.outputs = {'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))} -@skip_check_grad_ci( - reason="not implemented") -class TestReduceSum4DReduceAllWithoutReduceAllAttributeNegativeDimsONEDNNOp(TestReduceSumDefaultONEDNNOp): +@skip_check_grad_ci(reason="not implemented") +class TestReduceSum4DReduceAllWithoutReduceAllAttributeNegativeDimsONEDNNOp( + TestReduceSumDefaultONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True self.inputs = {'X': np.random.random((5, 10, 5, 5)).astype("float32")} - self.attrs = { - 'use_mkldnn': self.use_mkldnn, - 'dim': [-1, -2, -3, -4] + self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [-1, -2, -3, -4]} + self.outputs = { + 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) } - self.outputs = {'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))} -@skip_check_grad_ci( - reason="not implemented") +@skip_check_grad_ci(reason="not implemented") class TestReduceSum5DKeepDimsONEDNNOp(TestReduceSumDefaultONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True - self.inputs = { - 'X': np.random.random((2, 5, 3, 2, 2)).astype("float32") - } - self.attrs = { - 'dim': (2, 3, 4), - 'keep_dim': True, - 'use_mkldnn': True - } + self.inputs = {'X': np.random.random((2, 5, 3, 2, 2)).astype("float32")} + self.attrs = {'dim': (2, 3, 4), 'keep_dim': True, 'use_mkldnn': True} self.outputs = { 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']), keepdims=self.attrs['keep_dim']) } -@skip_check_grad_ci( - reason="not implemented") +@skip_check_grad_ci(reason="not implemented") class TestReduceSum5DReduceAllKeepDimsONEDNNOp(TestReduceSumDefaultONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True - self.inputs = { - 'X': np.random.random((2, 5, 3, 2, 2)).astype("float32") - } - self.attrs = { - 'reduce_all': True, - 'keep_dim': True, - 'use_mkldnn': True - } + self.inputs = {'X': np.random.random((2, 5, 3, 2, 2)).astype("float32")} + self.attrs = {'reduce_all': True, 'keep_dim': True, 'use_mkldnn': True} self.outputs = { 'Out': self.inputs['X'].sum(keepdims=self.attrs['keep_dim']) } -@skip_check_grad_ci( - reason="not implemented") +@skip_check_grad_ci(reason="not implemented") class TestReduceSum4DReduceAllONEDNNOp(TestReduceSumDefaultONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float32")} - self.attrs = { - 'reduce_all': True, - 'use_mkldnn': self.use_mkldnn - } + self.attrs = {'reduce_all': True, 'use_mkldnn': 
self.use_mkldnn} self.outputs = {'Out': self.inputs['X'].sum()} @@ -145,10 +120,7 @@ def setUp(self): self.op_type = "reduce_max" self.use_mkldnn = True self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} - self.attrs = { - 'dim': [-1], - 'use_mkldnn' : self.use_mkldnn - } + self.attrs = {'dim': [-1], 'use_mkldnn': self.use_mkldnn} self.outputs = { 'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim'])) } @@ -157,21 +129,20 @@ def setUp(self): @skip_check_grad_ci( reason="reduce_max is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") -class TestReduceMax4DNegativeAndPositiveDimsONEDNNOp(TestReduceSumDefaultONEDNNOp): +class TestReduceMax4DNegativeAndPositiveDimsONEDNNOp( + TestReduceSumDefaultONEDNNOp): """Remove Max with subgradient from gradient check to confirm the success of CI.""" def setUp(self): self.op_type = "reduce_max" self.use_mkldnn = True self.inputs = {'X': np.random.random((5, 6, 10, 9)).astype("float32")} - self.attrs = { - 'dim': [-1, 0, 1], - 'use_mkldnn' : self.use_mkldnn - } + self.attrs = {'dim': [-1, 0, 1], 'use_mkldnn': self.use_mkldnn} self.outputs = { 'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim'])) } + @skip_check_grad_ci( reason="reduce_min is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") @@ -182,22 +153,20 @@ def setUp(self): self.op_type = "reduce_min" self.use_mkldnn = True self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} - self.attrs = { - 'dim': [2], - 'use_mkldnn': self.use_mkldnn - } + self.attrs = {'dim': [2], 'use_mkldnn': self.use_mkldnn} self.outputs = { 'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim'])) } -@skip_check_grad_ci( - reason="not implemented") +@skip_check_grad_ci(reason="not implemented") class TestReduceMean3DONEDNNOp(TestReduceSumDefaultONEDNNOp): def setUp(self): self.op_type = "reduce_mean" self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} - self.outputs = {'Out': self.inputs['X'].sum(axis=0) / self.inputs['X'].shape[0]} + self.outputs = { + 'Out': self.inputs['X'].sum(axis=0) / self.inputs['X'].shape[0] + } if __name__ == '__main__': From f2555e5d37b00b155d7916e46be099286d3e9e9b Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Tue, 23 Mar 2021 09:49:51 +0100 Subject: [PATCH 06/37] removed unnecessary imports and comments --- .../fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h | 9 +++------ .../tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py | 8 ++------ .../tests/unittests/mkldnn/test_reduce_mkldnn_op.py | 8 +------- 3 files changed, 6 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h index a7e6b8c85685d..7073288a9ed86 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -76,12 +76,9 @@ class ReduceMKLDNNKernel : public framework::OpKernel { std::vector output_dims(framework::vectorize(input->dims())); for (size_t i = 0; i < reduce_dims.size(); ++i) { - reduce_dims[i] = - (reduce_dims[i] >= 0) - ? reduce_dims[i] - : input->dims().size() + reduce_dims[i]; // dims can be counted - // backwards, "-1" = - // last dimension + reduce_dims[i] = (reduce_dims[i] >= 0) + ? 
reduce_dims[i] + : input->dims().size() + reduce_dims[i]; output_dims[reduce_dims[i]] = 1; } diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py index 844408516a547..c741d43cb92bf 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py @@ -16,12 +16,9 @@ import unittest import numpy as np -from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16 -import paddle -import paddle.fluid.core as core +from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci import paddle.fluid as fluid -from paddle.fluid import compiler, Program, program_guard -from paddle.fluid.framework import convert_np_dtype_to_dtype_ +import paddle @skip_check_grad_ci(reason="not implemented") @@ -179,6 +176,5 @@ def setUp(self): if __name__ == '__main__': - import paddle paddle.enable_static() unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py index c0cb593662828..c2ec1bfc36668 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py @@ -12,16 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import print_function - import unittest import numpy as np from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci -import paddle -import paddle.fluid.core as core import paddle.fluid as fluid -from paddle.fluid import compiler, Program, program_guard -from paddle.fluid.framework import convert_np_dtype_to_dtype_ +import paddle @skip_check_grad_ci(reason="not implemented") @@ -170,6 +165,5 @@ def setUp(self): if __name__ == '__main__': - import paddle paddle.enable_static() unittest.main() From 8f80eb5120e480bc8b9d447ac56c558b6f63be46 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Tue, 23 Mar 2021 10:28:10 +0100 Subject: [PATCH 07/37] minor change --- .../mkldnn/test_reduce_bf16_mkldnn_op.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py index c741d43cb92bf..a61b70ec773fb 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py @@ -16,13 +16,13 @@ import unittest import numpy as np -from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci +from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16 import paddle.fluid as fluid import paddle @skip_check_grad_ci(reason="not implemented") -class TestReduceSumBF16DefaultONEDNNOp(OpTest): +class TestReduceSumDefaultBF16ONEDNNOp(OpTest): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -37,7 +37,7 @@ def test_check_output(self): @skip_check_grad_ci(reason="not implemented") -class TestReduceSum4DBF16ONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +class TestReduceSum4DBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -50,7 +50,7 @@ def setUp(self): 
@skip_check_grad_ci(reason="not implemented") class TestReduceSum4DReduceAllWithoutReduceAllAttributeBF16ONEDNNOp( - TestReduceSumBF16DefaultONEDNNOp): + TestReduceSumDefaultBF16ONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -63,7 +63,7 @@ def setUp(self): @skip_check_grad_ci(reason="not implemented") class TestReduceSum4DReduceAllWithoutReduceAllAttributeNegativeDimsBF16ONEDNNOp( - TestReduceSumBF16DefaultONEDNNOp): + TestReduceSumDefaultBF16ONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -75,7 +75,7 @@ def setUp(self): @skip_check_grad_ci(reason="not implemented") -class TestReduceSum5DKeepDimsONEDNNOp(TestReduceSumBF16DefaultBF16ONEDNNOp): +class TestReduceSum5DKeepDimsONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -91,7 +91,7 @@ def setUp(self): @skip_check_grad_ci(reason="not implemented") class TestReduceSum5DReduceAllKeepDimsBF16ONEDNNOp( - TestReduceSumBF16DefaultONEDNNOp): + TestReduceSumDefaultBF16ONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -103,7 +103,7 @@ def setUp(self): @skip_check_grad_ci(reason="not implemented") -class TestReduceSum4DReduceAllBF16ONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +class TestReduceSum4DReduceAllBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -117,7 +117,7 @@ def setUp(self): @skip_check_grad_ci( reason="reduce_max is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") -class TestReduceMax3DBF16ONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +class TestReduceMax3DBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp): """Remove Max with subgradient from gradient check to confirm the success of CI.""" def setUp(self): @@ -134,7 +134,7 @@ def setUp(self): reason="reduce_max is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") class TestReduceMax4DNegativeAndPositiveDimsBF16ONEDNNOp( - TestReduceSumBF16DefaultONEDNNOp): + TestReduceSumDefaultBF16ONEDNNOp): """Remove Max with subgradient from gradient check to confirm the success of CI.""" def setUp(self): @@ -150,7 +150,7 @@ def setUp(self): @skip_check_grad_ci( reason="reduce_min is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") -class TestReduceMin3DBF16ONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +class TestReduceMin3DBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp): """Remove Min with subgradient from gradient check to confirm the success of CI.""" def setUp(self): @@ -164,7 +164,7 @@ def setUp(self): @skip_check_grad_ci(reason="not implemented") -class TestReduceMean3DBF16ONEDNNOp(TestReduceSumBF16DefaultONEDNNOp): +class TestReduceMean3DBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp): def setUp(self): self.op_type = "reduce_mean" self.use_mkldnn = True From 3dfabd99e16b4a74f381ce38a884f460c3288e8c Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Thu, 25 Mar 2021 21:50:40 +0100 Subject: [PATCH 08/37] merged with develop --- .github/ISSUE_TEMPLATE/---document-issue-.md | 2 +- paddle/fluid/inference/api/demo_ci/clean.sh | 14 +++++++++++ paddle/fluid/train/demo/run.sh | 14 +++++++++++ paddle/fluid/train/imdb_demo/run.sh | 13 ++++++++++ paddle/scripts/build_docker_images.sh | 15 ++++++++++++ .../docker/root/.scripts/git-completion.sh | 15 ++++++++++++ paddle/scripts/fast_install.sh | 14 +++++++++++ 
patches/eigen/TensorBlock.h | 14 +++++++++++ python/paddle/fluid/dataloader/fetcher.py | 7 +++--- .../incubate/fleet/tests/cluster_train.sh | 14 +++++++++++ .../test_squared_mat_sub_fuse_pass.py | 6 +++-- .../unittests/ir/inference/test_trt_matmul.py | 23 +++++++----------- .../fluid/tests/unittests/parallel_test.sh | 15 ++++++++++++ .../fluid/tests/unittests/test_bce_loss.py | 12 ++++++--- .../unittests/test_bce_with_logits_loss.py | 6 +++-- .../tests/unittests/test_c_comm_init_op.sh | 15 ++++++++++++ .../tests/unittests/test_dist_fleet_ps10.py | 1 - .../test_flatten_contiguous_range_op.py | 3 ++- .../fluid/tests/unittests/test_l1_loss.py | 12 ++++++--- .../tests/unittests/test_listen_and_serv.sh | 15 ++++++++++++ .../fluid/tests/unittests/test_mse_loss.py | 18 +++++++++----- ...ess_dataloader_iterable_dataset_dynamic.py | 1 + .../tests/unittests/test_pixel_shuffle.py | 12 ++++++--- .../fluid/tests/unittests/test_prod_op.py | 6 +++-- .../fluid/tests/unittests/test_selu_op.py | 9 ++++--- .../unittests/test_sigmoid_focal_loss.py | 6 +++-- .../tests/unittests/test_transpose_op.py | 8 ++++-- tools/check_sequence_op.sh | 14 +++++++++++ tools/cudaError/start.sh | 15 ++++++++++++ tools/diff_api.py | 15 ++++++++++++ tools/diff_unittest.py | 15 ++++++++++++ tools/dockerfile/icode.sh | 14 +++++++++++ tools/document_preview.sh | 15 ++++++++++++ tools/get_cpu_info.sh | 14 +++++++++++ tools/static_mode_white_list.pyc | Bin 21803 -> 22152 bytes 35 files changed, 341 insertions(+), 51 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/---document-issue-.md b/.github/ISSUE_TEMPLATE/---document-issue-.md index 7c464ac584bc8..ffc2fcd7817b6 100644 --- a/.github/ISSUE_TEMPLATE/---document-issue-.md +++ b/.github/ISSUE_TEMPLATE/---document-issue-.md @@ -56,4 +56,4 @@ For example: no sample code; The sample code is not helpful; The sample code not For example:Chinese API in this doc is inconsistent with English API, including params, description, sample code, formula, etc. #### Other -For example: The doc link is broken; The doc page is missing; Dead link in docs. \ No newline at end of file +For example: The doc link is broken; The doc page is missing; Dead link in docs. diff --git a/paddle/fluid/inference/api/demo_ci/clean.sh b/paddle/fluid/inference/api/demo_ci/clean.sh index 0d9f3d2aa237a..c265721db5775 100755 --- a/paddle/fluid/inference/api/demo_ci/clean.sh +++ b/paddle/fluid/inference/api/demo_ci/clean.sh @@ -1,3 +1,17 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + set -x cd `dirname $0` rm -rf build/ data/ diff --git a/paddle/fluid/train/demo/run.sh b/paddle/fluid/train/demo/run.sh index 2955e7574daa2..c45a3528febdd 100755 --- a/paddle/fluid/train/demo/run.sh +++ b/paddle/fluid/train/demo/run.sh @@ -1,5 +1,19 @@ #!/bin/bash +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + set -x PADDLE_ROOT=$1 diff --git a/paddle/fluid/train/imdb_demo/run.sh b/paddle/fluid/train/imdb_demo/run.sh index f71b4bac602a9..6de1df27e0035 100644 --- a/paddle/fluid/train/imdb_demo/run.sh +++ b/paddle/fluid/train/imdb_demo/run.sh @@ -1,3 +1,16 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. set -exu build/demo_trainer --flagfile="train.cfg" diff --git a/paddle/scripts/build_docker_images.sh b/paddle/scripts/build_docker_images.sh index a90f0885294a9..2b584cdca6b4c 100644 --- a/paddle/scripts/build_docker_images.sh +++ b/paddle/scripts/build_docker_images.sh @@ -1,4 +1,19 @@ #!/bin/sh + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + set -xe REPO="${REPO:-paddlepaddle}" diff --git a/paddle/scripts/docker/root/.scripts/git-completion.sh b/paddle/scripts/docker/root/.scripts/git-completion.sh index bdddef5ac2faf..c43e88a4acd73 100755 --- a/paddle/scripts/docker/root/.scripts/git-completion.sh +++ b/paddle/scripts/docker/root/.scripts/git-completion.sh @@ -1,4 +1,19 @@ #!bash + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # # bash/zsh completion support for core Git. # diff --git a/paddle/scripts/fast_install.sh b/paddle/scripts/fast_install.sh index 1034b1c5c1043..cacec55d3bc22 100644 --- a/paddle/scripts/fast_install.sh +++ b/paddle/scripts/fast_install.sh @@ -1,5 +1,19 @@ #!/bin/bash +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + ## purple to echo function purple(){ echo -e "\033[35m$1\033[0m" diff --git a/patches/eigen/TensorBlock.h b/patches/eigen/TensorBlock.h index 1e55d12c42fc2..1b7bfed9ec89e 100644 --- a/patches/eigen/TensorBlock.h +++ b/patches/eigen/TensorBlock.h @@ -1,3 +1,17 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + // This file is part of Eigen, a lightweight C++ template library // for linear algebra. // diff --git a/python/paddle/fluid/dataloader/fetcher.py b/python/paddle/fluid/dataloader/fetcher.py index 9382a70422370..41e12fbc68ec1 100644 --- a/python/paddle/fluid/dataloader/fetcher.py +++ b/python/paddle/fluid/dataloader/fetcher.py @@ -27,8 +27,8 @@ def fetch(self, batch_indices): class _IterableDatasetFetcher(_DatasetFetcher): def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last): - super(_IterableDatasetFetcher, self).__init__(dataset, auto_collate_batch, - collate_fn, drop_last) + super(_IterableDatasetFetcher, self).__init__( + dataset, auto_collate_batch, collate_fn, drop_last) self.dataset_iter = iter(dataset) def fetch(self, batch_indices): @@ -53,7 +53,8 @@ def fetch(self, batch_indices): class _MapDatasetFetcher(_DatasetFetcher): def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last): - super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch, collate_fn, drop_last) + super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch, + collate_fn, drop_last) def fetch(self, batch_indices): if self.auto_collate_batch: diff --git a/python/paddle/fluid/incubate/fleet/tests/cluster_train.sh b/python/paddle/fluid/incubate/fleet/tests/cluster_train.sh index 1df6b0618de8d..cac2f7234bdf2 100644 --- a/python/paddle/fluid/incubate/fleet/tests/cluster_train.sh +++ b/python/paddle/fluid/incubate/fleet/tests/cluster_train.sh @@ -1,5 +1,19 @@ #!/bin/bash +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + # start pserver0 python fleet_deep_ctr.py \ --role pserver \ diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py index 95cff4de6f6b0..69a9ae3c0ad2c 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py @@ -40,9 +40,11 @@ def setUp(self): matmul_ab_square = paddle.square(matmul_ab) matmul_square_ab = paddle.matmul(data_a_square, data_b_square) - scale = paddle.fluid.layers.fill_constant(shape=[1], value=0.5, dtype='float32') + scale = paddle.fluid.layers.fill_constant( + shape=[1], value=0.5, dtype='float32') - sub_val = paddle.fluid.layers.elementwise_sub(matmul_ab_square, matmul_square_ab) + sub_val = paddle.fluid.layers.elementwise_sub(matmul_ab_square, + matmul_square_ab) squared_mat_sub_out = fluid.layers.elementwise_mul(sub_val, scale) self.feeds = { diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py index 94434f4043448..080d1ccc9054b 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py @@ -25,19 +25,16 @@ class TensorRTMatMulDims2Test(InferencePassTest): def setUp(self): self.set_params() with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data( - name="data", shape=[24, 24], dtype="float32") + data = fluid.data(name="data", shape=[24, 24], dtype="float32") matmul_out = fluid.layers.matmul( x=data, y=data, - transpose_x = self.transpose_x, - transpose_y = self.transpose_y, - alpha = self.alpha) + transpose_x=self.transpose_x, + transpose_y=self.transpose_y, + alpha=self.alpha) out = fluid.layers.batch_norm(matmul_out, is_test=True) - self.feeds = { - "data": np.ones([24, 24]).astype("float32"), - } + self.feeds = {"data": np.ones([24, 24]).astype("float32"), } self.enable_trt = True self.trt_parameters = TensorRTMatMulDims2Test.TensorRTParam( 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) @@ -65,14 +62,12 @@ def setUp(self): matmul_out = fluid.layers.matmul( x=data, y=data, - transpose_x = self.transpose_x, - transpose_y = self.transpose_y, - alpha = self.alpha) + transpose_x=self.transpose_x, + transpose_y=self.transpose_y, + alpha=self.alpha) out = fluid.layers.batch_norm(matmul_out, is_test=True) - self.feeds = { - "data": np.ones([1, 6, 24, 24]).astype("float32"), - } + self.feeds = {"data": np.ones([1, 6, 24, 24]).astype("float32"), } self.enable_trt = True self.trt_parameters = TensorRTMatMulTest.TensorRTParam( 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) diff --git a/python/paddle/fluid/tests/unittests/parallel_test.sh b/python/paddle/fluid/tests/unittests/parallel_test.sh index 9da4f035345d7..551b7cdb7a43c 100644 --- a/python/paddle/fluid/tests/unittests/parallel_test.sh +++ b/python/paddle/fluid/tests/unittests/parallel_test.sh @@ -1,4 +1,19 @@ #!/bin/bash + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + unset https_proxy http_proxy export FLAGS_rpc_disable_reuse_port=1 diff --git a/python/paddle/fluid/tests/unittests/test_bce_loss.py b/python/paddle/fluid/tests/unittests/test_bce_loss.py index 4b39436842b89..ea1a22780f093 100644 --- a/python/paddle/fluid/tests/unittests/test_bce_loss.py +++ b/python/paddle/fluid/tests/unittests/test_bce_loss.py @@ -27,8 +27,10 @@ def test_static_layer(place, prog = paddle.static.Program() startup_prog = paddle.static.Program() with paddle.static.program_guard(prog, startup_prog): - input = paddle.fluid.data(name='input', shape=input_np.shape, dtype='float64') - label = paddle.fluid.data(name='label', shape=label_np.shape, dtype='float64') + input = paddle.fluid.data( + name='input', shape=input_np.shape, dtype='float64') + label = paddle.fluid.data( + name='label', shape=label_np.shape, dtype='float64') if weight_np is not None: weight = paddle.fluid.data( name='weight', shape=weight_np.shape, dtype='float64') @@ -58,8 +60,10 @@ def test_static_functional(place, prog = paddle.static.Program() startup_prog = paddle.static.Program() with paddle.static.program_guard(prog, startup_prog): - input = paddle.fluid.data(name='input', shape=input_np.shape, dtype='float64') - label = paddle.fluid.data(name='label', shape=label_np.shape, dtype='float64') + input = paddle.fluid.data( + name='input', shape=input_np.shape, dtype='float64') + label = paddle.fluid.data( + name='label', shape=label_np.shape, dtype='float64') if weight_np is not None: weight = paddle.fluid.data( name='weight', shape=weight_np.shape, dtype='float64') diff --git a/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py b/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py index a6175aa471d69..153b8fd3e7f6b 100644 --- a/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py +++ b/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py @@ -48,8 +48,10 @@ def test_static(place, prog = paddle.static.Program() startup_prog = paddle.static.Program() with paddle.static.program_guard(prog, startup_prog): - logit = paddle.fluid.data(name='logit', shape=logit_np.shape, dtype='float64') - label = paddle.fluid.data(name='label', shape=label_np.shape, dtype='float64') + logit = paddle.fluid.data( + name='logit', shape=logit_np.shape, dtype='float64') + label = paddle.fluid.data( + name='label', shape=label_np.shape, dtype='float64') feed_dict = {"logit": logit_np, "label": label_np} pos_weight = None diff --git a/python/paddle/fluid/tests/unittests/test_c_comm_init_op.sh b/python/paddle/fluid/tests/unittests/test_c_comm_init_op.sh index a9d450e223f1e..aba95a68ab790 100644 --- a/python/paddle/fluid/tests/unittests/test_c_comm_init_op.sh +++ b/python/paddle/fluid/tests/unittests/test_c_comm_init_op.sh @@ -1,4 +1,19 @@ #!/bin/bash + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + set -e # use default values # FIXME: random fails on Unknown command lines -c (or -m). diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps10.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps10.py index 16584ee50081a..a82866a797db1 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps10.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps10.py @@ -23,7 +23,6 @@ paddle.enable_static() - # For Net base_lr = 0.2 emb_lr = base_lr * 3 diff --git a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py index aa85eb3df3527..28803f5ac6232 100644 --- a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py +++ b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py @@ -170,7 +170,8 @@ def test_type(): x2 = np.arange(image_shape[0] * image_shape[1] * image_shape[2] * image_shape[3]).reshape(image_shape) / 100. x2 = x2.astype('float16') - x2_var = paddle.fluid.data(name='x2', shape=[3, 2, 4, 5], dtype='float16') + x2_var = paddle.fluid.data( + name='x2', shape=[3, 2, 4, 5], dtype='float16') paddle.flatten(x2_var) self.assertRaises(TypeError, test_type) diff --git a/python/paddle/fluid/tests/unittests/test_l1_loss.py b/python/paddle/fluid/tests/unittests/test_l1_loss.py index fba16959901a8..c35188623b440 100644 --- a/python/paddle/fluid/tests/unittests/test_l1_loss.py +++ b/python/paddle/fluid/tests/unittests/test_l1_loss.py @@ -44,8 +44,10 @@ def run_imperative(self): self.assertTrue(dy_result.shape, [10, 10, 5]) def run_static(self, use_gpu=False): - input = paddle.fluid.data(name='input', shape=[10, 10, 5], dtype='float32') - label = paddle.fluid.data(name='label', shape=[10, 10, 5], dtype='float32') + input = paddle.fluid.data( + name='input', shape=[10, 10, 5], dtype='float32') + label = paddle.fluid.data( + name='label', shape=[10, 10, 5], dtype='float32') result0 = paddle.nn.functional.l1_loss(input, label) result1 = paddle.nn.functional.l1_loss(input, label, reduction='sum') result2 = paddle.nn.functional.l1_loss(input, label, reduction='none') @@ -127,8 +129,10 @@ def run_imperative(self): self.assertTrue(dy_result.shape, [10, 10, 5]) def run_static(self, use_gpu=False): - input = paddle.fluid.data(name='input', shape=[10, 10, 5], dtype='float32') - label = paddle.fluid.data(name='label', shape=[10, 10, 5], dtype='float32') + input = paddle.fluid.data( + name='input', shape=[10, 10, 5], dtype='float32') + label = paddle.fluid.data( + name='label', shape=[10, 10, 5], dtype='float32') l1_loss = paddle.nn.loss.L1Loss() result0 = l1_loss(input, label) l1_loss = paddle.nn.loss.L1Loss(reduction='sum') diff --git a/python/paddle/fluid/tests/unittests/test_listen_and_serv.sh b/python/paddle/fluid/tests/unittests/test_listen_and_serv.sh index bee230fba5a7e..d9d64e4dfa693 100644 --- a/python/paddle/fluid/tests/unittests/test_listen_and_serv.sh +++ b/python/paddle/fluid/tests/unittests/test_listen_and_serv.sh @@ -1,4 +1,19 @@ #!/bin/bash + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + unset https_proxy http_proxy nohup python -u test_listen_and_serv_op.py > test_listen_and_serv_op.log 2>&1 & diff --git a/python/paddle/fluid/tests/unittests/test_mse_loss.py b/python/paddle/fluid/tests/unittests/test_mse_loss.py index bc5d35d3254bc..89eef6ca24243 100644 --- a/python/paddle/fluid/tests/unittests/test_mse_loss.py +++ b/python/paddle/fluid/tests/unittests/test_mse_loss.py @@ -191,8 +191,10 @@ def test_NNFunctionalMseLoss_mean(self): place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( ) else paddle.CPUPlace() with paddle.static.program_guard(prog, startup_prog): - input = paddle.fluid.data(name='input', shape=dim, dtype='float32') - target = paddle.fluid.data(name='target', shape=dim, dtype='float32') + input = paddle.fluid.data( + name='input', shape=dim, dtype='float32') + target = paddle.fluid.data( + name='target', shape=dim, dtype='float32') mse_loss = paddle.nn.functional.mse_loss(input, target, 'mean') exe = paddle.static.Executor(place) @@ -225,8 +227,10 @@ def test_NNFunctionalMseLoss_sum(self): place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( ) else paddle.CPUPlace() with paddle.static.program_guard(prog, startup_prog): - input = paddle.fluid.data(name='input', shape=dim, dtype='float32') - target = paddle.fluid.data(name='target', shape=dim, dtype='float32') + input = paddle.fluid.data( + name='input', shape=dim, dtype='float32') + target = paddle.fluid.data( + name='target', shape=dim, dtype='float32') mse_loss = paddle.nn.functional.mse_loss(input, target, 'sum') exe = paddle.static.Executor(place) @@ -259,8 +263,10 @@ def test_NNFunctionalMseLoss_none(self): place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( ) else paddle.CPUPlace() with paddle.static.program_guard(prog, startup_prog): - input = paddle.fluid.data(name='input', shape=dim, dtype='float32') - target = paddle.fluid.data(name='target', shape=dim, dtype='float32') + input = paddle.fluid.data( + name='input', shape=dim, dtype='float32') + target = paddle.fluid.data( + name='target', shape=dim, dtype='float32') mse_loss = paddle.nn.functional.mse_loss(input, target, 'none') exe = paddle.static.Executor(place) diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py index 0533a0d09fa0d..3bb3e843b1b11 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py @@ -160,5 +160,6 @@ def run_main(self, num_workers, places): print("time cost", ret['time'], 'step_list', ret['step']) return ret + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py b/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py index f75d6e9df540b..f1a409c712fc3 100644 --- 
a/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py +++ b/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py @@ -97,8 +97,10 @@ def test_static_graph_functional(self): place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() paddle.enable_static() - x_1 = paddle.fluid.data(name="x", shape=[2, 9, 4, 4], dtype="float64") - x_2 = paddle.fluid.data(name="x2", shape=[2, 4, 4, 9], dtype="float64") + x_1 = paddle.fluid.data( + name="x", shape=[2, 9, 4, 4], dtype="float64") + x_2 = paddle.fluid.data( + name="x2", shape=[2, 4, 4, 9], dtype="float64") out_1 = F.pixel_shuffle(x_1, 3) out_2 = F.pixel_shuffle(x_2, 3, "NHWC") @@ -123,8 +125,10 @@ def test_static_graph_layer(self): place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() paddle.enable_static() - x_1 = paddle.fluid.data(name="x", shape=[2, 9, 4, 4], dtype="float64") - x_2 = paddle.fluid.data(name="x2", shape=[2, 4, 4, 9], dtype="float64") + x_1 = paddle.fluid.data( + name="x", shape=[2, 9, 4, 4], dtype="float64") + x_2 = paddle.fluid.data( + name="x2", shape=[2, 4, 4, 9], dtype="float64") # init instance ps_1 = paddle.nn.PixelShuffle(3) ps_2 = paddle.nn.PixelShuffle(3, "NHWC") diff --git a/python/paddle/fluid/tests/unittests/test_prod_op.py b/python/paddle/fluid/tests/unittests/test_prod_op.py index 15fd79542d608..cdfcbb4e4e735 100644 --- a/python/paddle/fluid/tests/unittests/test_prod_op.py +++ b/python/paddle/fluid/tests/unittests/test_prod_op.py @@ -55,7 +55,8 @@ def run_imperative(self): self.assertTrue(np.allclose(dy_result.numpy(), expected_result)) def run_static(self, use_gpu=False): - input = paddle.fluid.data(name='input', shape=[10, 10, 5], dtype='float32') + input = paddle.fluid.data( + name='input', shape=[10, 10, 5], dtype='float32') result0 = paddle.prod(input) result1 = paddle.prod(input, axis=1) result2 = paddle.prod(input, axis=-1) @@ -114,7 +115,8 @@ def test_error(self): with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()): x = paddle.fluid.data(name='x', shape=[2, 2, 4], dtype='float32') - bool_x = paddle.fluid.data(name='bool_x', shape=[2, 2, 4], dtype='bool') + bool_x = paddle.fluid.data( + name='bool_x', shape=[2, 2, 4], dtype='bool') # The argument x shoule be a Tensor self.assertRaises(TypeError, paddle.prod, [1]) diff --git a/python/paddle/fluid/tests/unittests/test_selu_op.py b/python/paddle/fluid/tests/unittests/test_selu_op.py index 95ae1eecc6614..e71adae8d9b6e 100644 --- a/python/paddle/fluid/tests/unittests/test_selu_op.py +++ b/python/paddle/fluid/tests/unittests/test_selu_op.py @@ -128,15 +128,18 @@ def test_errors(self): # The input type must be Variable. self.assertRaises(TypeError, F.selu, 1) # The input dtype must be float16, float32, float64. 
- x_int32 = paddle.fluid.data(name='x_int32', shape=[12, 10], dtype='int32') + x_int32 = paddle.fluid.data( + name='x_int32', shape=[12, 10], dtype='int32') self.assertRaises(TypeError, F.selu, x_int32) # The scale must be greater than 1.0 - x_fp32 = paddle.fluid.data(name='x_fp32', shape=[12, 10], dtype='float32') + x_fp32 = paddle.fluid.data( + name='x_fp32', shape=[12, 10], dtype='float32') self.assertRaises(ValueError, F.selu, x_fp32, -1.0) # The alpha must be no less than 0 self.assertRaises(ValueError, F.selu, x_fp32, 1.6, -1.0) # support the input dtype is float16 - x_fp16 = paddle.fluid.data(name='x_fp16', shape=[12, 10], dtype='float16') + x_fp16 = paddle.fluid.data( + name='x_fp16', shape=[12, 10], dtype='float16') F.selu(x_fp16) diff --git a/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py b/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py index 85f9501e53f4a..2ef04d9cbfa73 100644 --- a/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py +++ b/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py @@ -42,8 +42,10 @@ def test_static(place, prog = paddle.static.Program() startup_prog = paddle.static.Program() with paddle.static.program_guard(prog, startup_prog): - logit = paddle.fluid.data(name='logit', shape=logit_np.shape, dtype='float64') - label = paddle.fluid.data(name='label', shape=label_np.shape, dtype='float64') + logit = paddle.fluid.data( + name='logit', shape=logit_np.shape, dtype='float64') + label = paddle.fluid.data( + name='label', shape=label_np.shape, dtype='float64') feed_dict = {"logit": logit_np, "label": label_np} normalizer = None diff --git a/python/paddle/fluid/tests/unittests/test_transpose_op.py b/python/paddle/fluid/tests/unittests/test_transpose_op.py index f72df8cbe4640..59b4afdf8b02d 100644 --- a/python/paddle/fluid/tests/unittests/test_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_transpose_op.py @@ -23,6 +23,7 @@ paddle.enable_static() + class TestTransposeOp(OpTest): def setUp(self): self.init_op_type() @@ -151,6 +152,7 @@ def test_each_elem_value_check(): self.assertRaises(ValueError, test_each_elem_value_check) + class TestTransposeApi(unittest.TestCase): def test_static_out(self): paddle.enable_static() @@ -161,10 +163,11 @@ def test_static_out(self): place = paddle.CPUPlace() exe = paddle.static.Executor(place) x_np = np.random.random([2, 3, 4]).astype("float32") - result1, result2 = exe.run(feed={"x": x_np}, fetch_list=[x_trans1, x_trans2]) + result1, result2 = exe.run(feed={"x": x_np}, + fetch_list=[x_trans1, x_trans2]) expected_result1 = np.transpose(x_np, [1, 0, 2]) expected_result2 = np.transpose(x_np, (2, 1, 0)) - + np.testing.assert_array_equal(result1, expected_result1) np.testing.assert_array_equal(result2, expected_result2) @@ -185,6 +188,7 @@ def test_dygraph_out(self): # dygraph test paddle.enable_static() + class TestTAPI(unittest.TestCase): def test_out(self): with fluid.program_guard(fluid.Program()): diff --git a/tools/check_sequence_op.sh b/tools/check_sequence_op.sh index ada96750eaad8..a263b046b258b 100644 --- a/tools/check_sequence_op.sh +++ b/tools/check_sequence_op.sh @@ -1,5 +1,19 @@ #!/bin/bash +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + PADDLE_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}")/../" && pwd )" function check_sequnece_op_unitests(){ diff --git a/tools/cudaError/start.sh b/tools/cudaError/start.sh index 3c0e57ffe7ec1..66e56b8485d8c 100644 --- a/tools/cudaError/start.sh +++ b/tools/cudaError/start.sh @@ -1,4 +1,19 @@ #!/usr/bin/env bash + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + set -ex SYSTEM=`uname -s` rm -f protoc-3.11.3-linux-x86_64.* diff --git a/tools/diff_api.py b/tools/diff_api.py index 8a2acbb3d0acc..f086598945afe 100644 --- a/tools/diff_api.py +++ b/tools/diff_api.py @@ -1,4 +1,19 @@ #!/usr/bin/env python + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from __future__ import print_function import difflib import sys diff --git a/tools/diff_unittest.py b/tools/diff_unittest.py index 382fbdd0b0c29..fa70be0990ec0 100644 --- a/tools/diff_unittest.py +++ b/tools/diff_unittest.py @@ -1,4 +1,19 @@ #!/usr/bin/env python + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import difflib import sys diff --git a/tools/dockerfile/icode.sh b/tools/dockerfile/icode.sh index da3ffb8c77db7..973975fe7f737 100755 --- a/tools/dockerfile/icode.sh +++ b/tools/dockerfile/icode.sh @@ -1,5 +1,19 @@ #!/bin/bash +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + function install_gcc(){ sed -i 's##RUN apt-get update \ diff --git a/tools/document_preview.sh b/tools/document_preview.sh index 10f486f8fd4f6..83c758d0aa8b8 100755 --- a/tools/document_preview.sh +++ b/tools/document_preview.sh @@ -1,4 +1,19 @@ #!/bin/bash + +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + PADDLE_ROOT=/home mkdir ${PADDLE_ROOT} cd ${PADDLE_ROOT} diff --git a/tools/get_cpu_info.sh b/tools/get_cpu_info.sh index 81eb19dc0661e..bce338a8619e6 100755 --- a/tools/get_cpu_info.sh +++ b/tools/get_cpu_info.sh @@ -1,5 +1,19 @@ #!/bin/bash +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + if [ "`uname -s`" != "Linux" ]; then echo "Current scenario only support in Linux yet!" exit 0 diff --git a/tools/static_mode_white_list.pyc b/tools/static_mode_white_list.pyc index e9012c233595b6844f54e625972360f5aeeb0d3b..fdb2a15d7a94ae66e014844fb40917f41d91847c 100644 GIT binary patch delta 429 zcmYk1OG^S#6vywJRD7nEV@Qe+712g42-=yC77@%MXc5tM^h_`3Va^Ce5VULErEL$( zo)pM@l$Jr%7f7Nn&{Y$<+{N$Q|2hBjzfX_C$+Hk2_z8GF+UvzK`8zEjB*WQExLfft zR|e1r=m!jNAm}GJ2pHs-1Hlkr7;p$ML@+>b7%;+14g^KOC|8Mng8^}v$uS0To}Cq+kTGd*X2xsT(C7!PDArbSixK&esP zsA+VwZYs1Yo2IQ#jQWFOzon*Ny{xc>Y{b_?+VNqbT*N+28)0Fby=P;rxfo;vI#o?eqgq8pU8hFXO7_?Mm%Bj^^Kag?73%A%td;lx z-LKZIk^ZA*prX>U%zt7TyE zmor5%Fr>_53f7oE*_Pw-W*I=;jr!7q}*` u3^10o0lGuKC_leMKRG`owOGG6u`D${CqFS|^6LOg79UmyhRp_nKCA%dMl&)1 From 895f948b37a082955ee2dd84a11e7c9a2081aba5 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Thu, 25 Mar 2021 21:53:18 +0100 Subject: [PATCH 09/37] Revert "merged with develop" This reverts commit 3dfabd99e16b4a74f381ce38a884f460c3288e8c. 
--- .github/ISSUE_TEMPLATE/---document-issue-.md | 2 +- paddle/fluid/inference/api/demo_ci/clean.sh | 14 ----------- paddle/fluid/train/demo/run.sh | 14 ----------- paddle/fluid/train/imdb_demo/run.sh | 13 ---------- paddle/scripts/build_docker_images.sh | 15 ------------ .../docker/root/.scripts/git-completion.sh | 15 ------------ paddle/scripts/fast_install.sh | 14 ----------- patches/eigen/TensorBlock.h | 14 ----------- python/paddle/fluid/dataloader/fetcher.py | 7 +++--- .../incubate/fleet/tests/cluster_train.sh | 14 ----------- .../test_squared_mat_sub_fuse_pass.py | 6 ++--- .../unittests/ir/inference/test_trt_matmul.py | 23 +++++++++++------- .../fluid/tests/unittests/parallel_test.sh | 15 ------------ .../fluid/tests/unittests/test_bce_loss.py | 12 +++------ .../unittests/test_bce_with_logits_loss.py | 6 ++--- .../tests/unittests/test_c_comm_init_op.sh | 15 ------------ .../tests/unittests/test_dist_fleet_ps10.py | 1 + .../test_flatten_contiguous_range_op.py | 3 +-- .../fluid/tests/unittests/test_l1_loss.py | 12 +++------ .../tests/unittests/test_listen_and_serv.sh | 15 ------------ .../fluid/tests/unittests/test_mse_loss.py | 18 +++++--------- ...ess_dataloader_iterable_dataset_dynamic.py | 1 - .../tests/unittests/test_pixel_shuffle.py | 12 +++------ .../fluid/tests/unittests/test_prod_op.py | 6 ++--- .../fluid/tests/unittests/test_selu_op.py | 9 +++---- .../unittests/test_sigmoid_focal_loss.py | 6 ++--- .../tests/unittests/test_transpose_op.py | 8 ++---- tools/check_sequence_op.sh | 14 ----------- tools/cudaError/start.sh | 15 ------------ tools/diff_api.py | 15 ------------ tools/diff_unittest.py | 15 ------------ tools/dockerfile/icode.sh | 14 ----------- tools/document_preview.sh | 15 ------------ tools/get_cpu_info.sh | 14 ----------- tools/static_mode_white_list.pyc | Bin 22152 -> 21803 bytes 35 files changed, 51 insertions(+), 341 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/---document-issue-.md b/.github/ISSUE_TEMPLATE/---document-issue-.md index ffc2fcd7817b6..7c464ac584bc8 100644 --- a/.github/ISSUE_TEMPLATE/---document-issue-.md +++ b/.github/ISSUE_TEMPLATE/---document-issue-.md @@ -56,4 +56,4 @@ For example: no sample code; The sample code is not helpful; The sample code not For example:Chinese API in this doc is inconsistent with English API, including params, description, sample code, formula, etc. #### Other -For example: The doc link is broken; The doc page is missing; Dead link in docs. +For example: The doc link is broken; The doc page is missing; Dead link in docs. \ No newline at end of file diff --git a/paddle/fluid/inference/api/demo_ci/clean.sh b/paddle/fluid/inference/api/demo_ci/clean.sh index c265721db5775..0d9f3d2aa237a 100755 --- a/paddle/fluid/inference/api/demo_ci/clean.sh +++ b/paddle/fluid/inference/api/demo_ci/clean.sh @@ -1,17 +1,3 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- set -x cd `dirname $0` rm -rf build/ data/ diff --git a/paddle/fluid/train/demo/run.sh b/paddle/fluid/train/demo/run.sh index c45a3528febdd..2955e7574daa2 100755 --- a/paddle/fluid/train/demo/run.sh +++ b/paddle/fluid/train/demo/run.sh @@ -1,19 +1,5 @@ #!/bin/bash -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - set -x PADDLE_ROOT=$1 diff --git a/paddle/fluid/train/imdb_demo/run.sh b/paddle/fluid/train/imdb_demo/run.sh index 6de1df27e0035..f71b4bac602a9 100644 --- a/paddle/fluid/train/imdb_demo/run.sh +++ b/paddle/fluid/train/imdb_demo/run.sh @@ -1,16 +1,3 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. set -exu build/demo_trainer --flagfile="train.cfg" diff --git a/paddle/scripts/build_docker_images.sh b/paddle/scripts/build_docker_images.sh index 2b584cdca6b4c..a90f0885294a9 100644 --- a/paddle/scripts/build_docker_images.sh +++ b/paddle/scripts/build_docker_images.sh @@ -1,19 +1,4 @@ #!/bin/sh - -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - set -xe REPO="${REPO:-paddlepaddle}" diff --git a/paddle/scripts/docker/root/.scripts/git-completion.sh b/paddle/scripts/docker/root/.scripts/git-completion.sh index c43e88a4acd73..bdddef5ac2faf 100755 --- a/paddle/scripts/docker/root/.scripts/git-completion.sh +++ b/paddle/scripts/docker/root/.scripts/git-completion.sh @@ -1,19 +1,4 @@ #!bash - -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - # # bash/zsh completion support for core Git. # diff --git a/paddle/scripts/fast_install.sh b/paddle/scripts/fast_install.sh index cacec55d3bc22..1034b1c5c1043 100644 --- a/paddle/scripts/fast_install.sh +++ b/paddle/scripts/fast_install.sh @@ -1,19 +1,5 @@ #!/bin/bash -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - ## purple to echo function purple(){ echo -e "\033[35m$1\033[0m" diff --git a/patches/eigen/TensorBlock.h b/patches/eigen/TensorBlock.h index 1b7bfed9ec89e..1e55d12c42fc2 100644 --- a/patches/eigen/TensorBlock.h +++ b/patches/eigen/TensorBlock.h @@ -1,17 +1,3 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - // This file is part of Eigen, a lightweight C++ template library // for linear algebra. // diff --git a/python/paddle/fluid/dataloader/fetcher.py b/python/paddle/fluid/dataloader/fetcher.py index 41e12fbc68ec1..9382a70422370 100644 --- a/python/paddle/fluid/dataloader/fetcher.py +++ b/python/paddle/fluid/dataloader/fetcher.py @@ -27,8 +27,8 @@ def fetch(self, batch_indices): class _IterableDatasetFetcher(_DatasetFetcher): def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last): - super(_IterableDatasetFetcher, self).__init__( - dataset, auto_collate_batch, collate_fn, drop_last) + super(_IterableDatasetFetcher, self).__init__(dataset, auto_collate_batch, + collate_fn, drop_last) self.dataset_iter = iter(dataset) def fetch(self, batch_indices): @@ -53,8 +53,7 @@ def fetch(self, batch_indices): class _MapDatasetFetcher(_DatasetFetcher): def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last): - super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch, - collate_fn, drop_last) + super(_MapDatasetFetcher, self).__init__(dataset, auto_collate_batch, collate_fn, drop_last) def fetch(self, batch_indices): if self.auto_collate_batch: diff --git a/python/paddle/fluid/incubate/fleet/tests/cluster_train.sh b/python/paddle/fluid/incubate/fleet/tests/cluster_train.sh index cac2f7234bdf2..1df6b0618de8d 100644 --- a/python/paddle/fluid/incubate/fleet/tests/cluster_train.sh +++ b/python/paddle/fluid/incubate/fleet/tests/cluster_train.sh @@ -1,19 +1,5 @@ #!/bin/bash -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - # start pserver0 python fleet_deep_ctr.py \ --role pserver \ diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py index 69a9ae3c0ad2c..95cff4de6f6b0 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_squared_mat_sub_fuse_pass.py @@ -40,11 +40,9 @@ def setUp(self): matmul_ab_square = paddle.square(matmul_ab) matmul_square_ab = paddle.matmul(data_a_square, data_b_square) - scale = paddle.fluid.layers.fill_constant( - shape=[1], value=0.5, dtype='float32') + scale = paddle.fluid.layers.fill_constant(shape=[1], value=0.5, dtype='float32') - sub_val = paddle.fluid.layers.elementwise_sub(matmul_ab_square, - matmul_square_ab) + sub_val = paddle.fluid.layers.elementwise_sub(matmul_ab_square, matmul_square_ab) squared_mat_sub_out = fluid.layers.elementwise_mul(sub_val, scale) self.feeds = { diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py index 080d1ccc9054b..94434f4043448 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_matmul.py @@ -25,16 +25,19 @@ class TensorRTMatMulDims2Test(InferencePassTest): def setUp(self): self.set_params() with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data(name="data", shape=[24, 24], dtype="float32") + data = fluid.data( + name="data", shape=[24, 24], dtype="float32") matmul_out = fluid.layers.matmul( x=data, y=data, - transpose_x=self.transpose_x, - transpose_y=self.transpose_y, - alpha=self.alpha) + transpose_x = self.transpose_x, + transpose_y = self.transpose_y, + alpha = self.alpha) out = fluid.layers.batch_norm(matmul_out, is_test=True) - self.feeds = {"data": np.ones([24, 24]).astype("float32"), } + self.feeds = { + "data": np.ones([24, 24]).astype("float32"), + } self.enable_trt = True self.trt_parameters = TensorRTMatMulDims2Test.TensorRTParam( 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) @@ -62,12 +65,14 @@ def setUp(self): matmul_out = fluid.layers.matmul( x=data, y=data, - transpose_x=self.transpose_x, - transpose_y=self.transpose_y, - alpha=self.alpha) + transpose_x = self.transpose_x, + transpose_y = self.transpose_y, + alpha = self.alpha) out = fluid.layers.batch_norm(matmul_out, is_test=True) - self.feeds = {"data": np.ones([1, 6, 24, 24]).astype("float32"), } + self.feeds = { + "data": np.ones([1, 6, 24, 24]).astype("float32"), + } self.enable_trt = True self.trt_parameters = TensorRTMatMulTest.TensorRTParam( 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) diff --git a/python/paddle/fluid/tests/unittests/parallel_test.sh b/python/paddle/fluid/tests/unittests/parallel_test.sh index 551b7cdb7a43c..9da4f035345d7 100644 
--- a/python/paddle/fluid/tests/unittests/parallel_test.sh +++ b/python/paddle/fluid/tests/unittests/parallel_test.sh @@ -1,19 +1,4 @@ #!/bin/bash - -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - unset https_proxy http_proxy export FLAGS_rpc_disable_reuse_port=1 diff --git a/python/paddle/fluid/tests/unittests/test_bce_loss.py b/python/paddle/fluid/tests/unittests/test_bce_loss.py index ea1a22780f093..4b39436842b89 100644 --- a/python/paddle/fluid/tests/unittests/test_bce_loss.py +++ b/python/paddle/fluid/tests/unittests/test_bce_loss.py @@ -27,10 +27,8 @@ def test_static_layer(place, prog = paddle.static.Program() startup_prog = paddle.static.Program() with paddle.static.program_guard(prog, startup_prog): - input = paddle.fluid.data( - name='input', shape=input_np.shape, dtype='float64') - label = paddle.fluid.data( - name='label', shape=label_np.shape, dtype='float64') + input = paddle.fluid.data(name='input', shape=input_np.shape, dtype='float64') + label = paddle.fluid.data(name='label', shape=label_np.shape, dtype='float64') if weight_np is not None: weight = paddle.fluid.data( name='weight', shape=weight_np.shape, dtype='float64') @@ -60,10 +58,8 @@ def test_static_functional(place, prog = paddle.static.Program() startup_prog = paddle.static.Program() with paddle.static.program_guard(prog, startup_prog): - input = paddle.fluid.data( - name='input', shape=input_np.shape, dtype='float64') - label = paddle.fluid.data( - name='label', shape=label_np.shape, dtype='float64') + input = paddle.fluid.data(name='input', shape=input_np.shape, dtype='float64') + label = paddle.fluid.data(name='label', shape=label_np.shape, dtype='float64') if weight_np is not None: weight = paddle.fluid.data( name='weight', shape=weight_np.shape, dtype='float64') diff --git a/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py b/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py index 153b8fd3e7f6b..a6175aa471d69 100644 --- a/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py +++ b/python/paddle/fluid/tests/unittests/test_bce_with_logits_loss.py @@ -48,10 +48,8 @@ def test_static(place, prog = paddle.static.Program() startup_prog = paddle.static.Program() with paddle.static.program_guard(prog, startup_prog): - logit = paddle.fluid.data( - name='logit', shape=logit_np.shape, dtype='float64') - label = paddle.fluid.data( - name='label', shape=label_np.shape, dtype='float64') + logit = paddle.fluid.data(name='logit', shape=logit_np.shape, dtype='float64') + label = paddle.fluid.data(name='label', shape=label_np.shape, dtype='float64') feed_dict = {"logit": logit_np, "label": label_np} pos_weight = None diff --git a/python/paddle/fluid/tests/unittests/test_c_comm_init_op.sh b/python/paddle/fluid/tests/unittests/test_c_comm_init_op.sh index aba95a68ab790..a9d450e223f1e 100644 --- a/python/paddle/fluid/tests/unittests/test_c_comm_init_op.sh +++ b/python/paddle/fluid/tests/unittests/test_c_comm_init_op.sh @@ -1,19 
+1,4 @@ #!/bin/bash - -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - set -e # use default values # FIXME: random fails on Unknown command lines -c (or -m). diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps10.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps10.py index a82866a797db1..16584ee50081a 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps10.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps10.py @@ -23,6 +23,7 @@ paddle.enable_static() + # For Net base_lr = 0.2 emb_lr = base_lr * 3 diff --git a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py index 28803f5ac6232..aa85eb3df3527 100644 --- a/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py +++ b/python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py @@ -170,8 +170,7 @@ def test_type(): x2 = np.arange(image_shape[0] * image_shape[1] * image_shape[2] * image_shape[3]).reshape(image_shape) / 100. x2 = x2.astype('float16') - x2_var = paddle.fluid.data( - name='x2', shape=[3, 2, 4, 5], dtype='float16') + x2_var = paddle.fluid.data(name='x2', shape=[3, 2, 4, 5], dtype='float16') paddle.flatten(x2_var) self.assertRaises(TypeError, test_type) diff --git a/python/paddle/fluid/tests/unittests/test_l1_loss.py b/python/paddle/fluid/tests/unittests/test_l1_loss.py index c35188623b440..fba16959901a8 100644 --- a/python/paddle/fluid/tests/unittests/test_l1_loss.py +++ b/python/paddle/fluid/tests/unittests/test_l1_loss.py @@ -44,10 +44,8 @@ def run_imperative(self): self.assertTrue(dy_result.shape, [10, 10, 5]) def run_static(self, use_gpu=False): - input = paddle.fluid.data( - name='input', shape=[10, 10, 5], dtype='float32') - label = paddle.fluid.data( - name='label', shape=[10, 10, 5], dtype='float32') + input = paddle.fluid.data(name='input', shape=[10, 10, 5], dtype='float32') + label = paddle.fluid.data(name='label', shape=[10, 10, 5], dtype='float32') result0 = paddle.nn.functional.l1_loss(input, label) result1 = paddle.nn.functional.l1_loss(input, label, reduction='sum') result2 = paddle.nn.functional.l1_loss(input, label, reduction='none') @@ -129,10 +127,8 @@ def run_imperative(self): self.assertTrue(dy_result.shape, [10, 10, 5]) def run_static(self, use_gpu=False): - input = paddle.fluid.data( - name='input', shape=[10, 10, 5], dtype='float32') - label = paddle.fluid.data( - name='label', shape=[10, 10, 5], dtype='float32') + input = paddle.fluid.data(name='input', shape=[10, 10, 5], dtype='float32') + label = paddle.fluid.data(name='label', shape=[10, 10, 5], dtype='float32') l1_loss = paddle.nn.loss.L1Loss() result0 = l1_loss(input, label) l1_loss = paddle.nn.loss.L1Loss(reduction='sum') diff --git a/python/paddle/fluid/tests/unittests/test_listen_and_serv.sh b/python/paddle/fluid/tests/unittests/test_listen_and_serv.sh index d9d64e4dfa693..bee230fba5a7e 100644 --- 
a/python/paddle/fluid/tests/unittests/test_listen_and_serv.sh +++ b/python/paddle/fluid/tests/unittests/test_listen_and_serv.sh @@ -1,19 +1,4 @@ #!/bin/bash - -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - unset https_proxy http_proxy nohup python -u test_listen_and_serv_op.py > test_listen_and_serv_op.log 2>&1 & diff --git a/python/paddle/fluid/tests/unittests/test_mse_loss.py b/python/paddle/fluid/tests/unittests/test_mse_loss.py index 89eef6ca24243..bc5d35d3254bc 100644 --- a/python/paddle/fluid/tests/unittests/test_mse_loss.py +++ b/python/paddle/fluid/tests/unittests/test_mse_loss.py @@ -191,10 +191,8 @@ def test_NNFunctionalMseLoss_mean(self): place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( ) else paddle.CPUPlace() with paddle.static.program_guard(prog, startup_prog): - input = paddle.fluid.data( - name='input', shape=dim, dtype='float32') - target = paddle.fluid.data( - name='target', shape=dim, dtype='float32') + input = paddle.fluid.data(name='input', shape=dim, dtype='float32') + target = paddle.fluid.data(name='target', shape=dim, dtype='float32') mse_loss = paddle.nn.functional.mse_loss(input, target, 'mean') exe = paddle.static.Executor(place) @@ -227,10 +225,8 @@ def test_NNFunctionalMseLoss_sum(self): place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( ) else paddle.CPUPlace() with paddle.static.program_guard(prog, startup_prog): - input = paddle.fluid.data( - name='input', shape=dim, dtype='float32') - target = paddle.fluid.data( - name='target', shape=dim, dtype='float32') + input = paddle.fluid.data(name='input', shape=dim, dtype='float32') + target = paddle.fluid.data(name='target', shape=dim, dtype='float32') mse_loss = paddle.nn.functional.mse_loss(input, target, 'sum') exe = paddle.static.Executor(place) @@ -263,10 +259,8 @@ def test_NNFunctionalMseLoss_none(self): place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda( ) else paddle.CPUPlace() with paddle.static.program_guard(prog, startup_prog): - input = paddle.fluid.data( - name='input', shape=dim, dtype='float32') - target = paddle.fluid.data( - name='target', shape=dim, dtype='float32') + input = paddle.fluid.data(name='input', shape=dim, dtype='float32') + target = paddle.fluid.data(name='target', shape=dim, dtype='float32') mse_loss = paddle.nn.functional.mse_loss(input, target, 'none') exe = paddle.static.Executor(place) diff --git a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py index 3bb3e843b1b11..0533a0d09fa0d 100644 --- a/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py +++ b/python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_iterable_dataset_dynamic.py @@ -160,6 +160,5 @@ def run_main(self, num_workers, places): print("time cost", ret['time'], 'step_list', ret['step']) return ret - if __name__ == '__main__': unittest.main() 
diff --git a/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py b/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py index f1a409c712fc3..f75d6e9df540b 100644 --- a/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py +++ b/python/paddle/fluid/tests/unittests/test_pixel_shuffle.py @@ -97,10 +97,8 @@ def test_static_graph_functional(self): place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() paddle.enable_static() - x_1 = paddle.fluid.data( - name="x", shape=[2, 9, 4, 4], dtype="float64") - x_2 = paddle.fluid.data( - name="x2", shape=[2, 4, 4, 9], dtype="float64") + x_1 = paddle.fluid.data(name="x", shape=[2, 9, 4, 4], dtype="float64") + x_2 = paddle.fluid.data(name="x2", shape=[2, 4, 4, 9], dtype="float64") out_1 = F.pixel_shuffle(x_1, 3) out_2 = F.pixel_shuffle(x_2, 3, "NHWC") @@ -125,10 +123,8 @@ def test_static_graph_layer(self): place = paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace() paddle.enable_static() - x_1 = paddle.fluid.data( - name="x", shape=[2, 9, 4, 4], dtype="float64") - x_2 = paddle.fluid.data( - name="x2", shape=[2, 4, 4, 9], dtype="float64") + x_1 = paddle.fluid.data(name="x", shape=[2, 9, 4, 4], dtype="float64") + x_2 = paddle.fluid.data(name="x2", shape=[2, 4, 4, 9], dtype="float64") # init instance ps_1 = paddle.nn.PixelShuffle(3) ps_2 = paddle.nn.PixelShuffle(3, "NHWC") diff --git a/python/paddle/fluid/tests/unittests/test_prod_op.py b/python/paddle/fluid/tests/unittests/test_prod_op.py index cdfcbb4e4e735..15fd79542d608 100644 --- a/python/paddle/fluid/tests/unittests/test_prod_op.py +++ b/python/paddle/fluid/tests/unittests/test_prod_op.py @@ -55,8 +55,7 @@ def run_imperative(self): self.assertTrue(np.allclose(dy_result.numpy(), expected_result)) def run_static(self, use_gpu=False): - input = paddle.fluid.data( - name='input', shape=[10, 10, 5], dtype='float32') + input = paddle.fluid.data(name='input', shape=[10, 10, 5], dtype='float32') result0 = paddle.prod(input) result1 = paddle.prod(input, axis=1) result2 = paddle.prod(input, axis=-1) @@ -115,8 +114,7 @@ def test_error(self): with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()): x = paddle.fluid.data(name='x', shape=[2, 2, 4], dtype='float32') - bool_x = paddle.fluid.data( - name='bool_x', shape=[2, 2, 4], dtype='bool') + bool_x = paddle.fluid.data(name='bool_x', shape=[2, 2, 4], dtype='bool') # The argument x shoule be a Tensor self.assertRaises(TypeError, paddle.prod, [1]) diff --git a/python/paddle/fluid/tests/unittests/test_selu_op.py b/python/paddle/fluid/tests/unittests/test_selu_op.py index e71adae8d9b6e..95ae1eecc6614 100644 --- a/python/paddle/fluid/tests/unittests/test_selu_op.py +++ b/python/paddle/fluid/tests/unittests/test_selu_op.py @@ -128,18 +128,15 @@ def test_errors(self): # The input type must be Variable. self.assertRaises(TypeError, F.selu, 1) # The input dtype must be float16, float32, float64. 
- x_int32 = paddle.fluid.data( - name='x_int32', shape=[12, 10], dtype='int32') + x_int32 = paddle.fluid.data(name='x_int32', shape=[12, 10], dtype='int32') self.assertRaises(TypeError, F.selu, x_int32) # The scale must be greater than 1.0 - x_fp32 = paddle.fluid.data( - name='x_fp32', shape=[12, 10], dtype='float32') + x_fp32 = paddle.fluid.data(name='x_fp32', shape=[12, 10], dtype='float32') self.assertRaises(ValueError, F.selu, x_fp32, -1.0) # The alpha must be no less than 0 self.assertRaises(ValueError, F.selu, x_fp32, 1.6, -1.0) # support the input dtype is float16 - x_fp16 = paddle.fluid.data( - name='x_fp16', shape=[12, 10], dtype='float16') + x_fp16 = paddle.fluid.data(name='x_fp16', shape=[12, 10], dtype='float16') F.selu(x_fp16) diff --git a/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py b/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py index 2ef04d9cbfa73..85f9501e53f4a 100644 --- a/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py +++ b/python/paddle/fluid/tests/unittests/test_sigmoid_focal_loss.py @@ -42,10 +42,8 @@ def test_static(place, prog = paddle.static.Program() startup_prog = paddle.static.Program() with paddle.static.program_guard(prog, startup_prog): - logit = paddle.fluid.data( - name='logit', shape=logit_np.shape, dtype='float64') - label = paddle.fluid.data( - name='label', shape=label_np.shape, dtype='float64') + logit = paddle.fluid.data(name='logit', shape=logit_np.shape, dtype='float64') + label = paddle.fluid.data(name='label', shape=label_np.shape, dtype='float64') feed_dict = {"logit": logit_np, "label": label_np} normalizer = None diff --git a/python/paddle/fluid/tests/unittests/test_transpose_op.py b/python/paddle/fluid/tests/unittests/test_transpose_op.py index 59b4afdf8b02d..f72df8cbe4640 100644 --- a/python/paddle/fluid/tests/unittests/test_transpose_op.py +++ b/python/paddle/fluid/tests/unittests/test_transpose_op.py @@ -23,7 +23,6 @@ paddle.enable_static() - class TestTransposeOp(OpTest): def setUp(self): self.init_op_type() @@ -152,7 +151,6 @@ def test_each_elem_value_check(): self.assertRaises(ValueError, test_each_elem_value_check) - class TestTransposeApi(unittest.TestCase): def test_static_out(self): paddle.enable_static() @@ -163,11 +161,10 @@ def test_static_out(self): place = paddle.CPUPlace() exe = paddle.static.Executor(place) x_np = np.random.random([2, 3, 4]).astype("float32") - result1, result2 = exe.run(feed={"x": x_np}, - fetch_list=[x_trans1, x_trans2]) + result1, result2 = exe.run(feed={"x": x_np}, fetch_list=[x_trans1, x_trans2]) expected_result1 = np.transpose(x_np, [1, 0, 2]) expected_result2 = np.transpose(x_np, (2, 1, 0)) - + np.testing.assert_array_equal(result1, expected_result1) np.testing.assert_array_equal(result2, expected_result2) @@ -188,7 +185,6 @@ def test_dygraph_out(self): # dygraph test paddle.enable_static() - class TestTAPI(unittest.TestCase): def test_out(self): with fluid.program_guard(fluid.Program()): diff --git a/tools/check_sequence_op.sh b/tools/check_sequence_op.sh index a263b046b258b..ada96750eaad8 100644 --- a/tools/check_sequence_op.sh +++ b/tools/check_sequence_op.sh @@ -1,19 +1,5 @@ #!/bin/bash -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - PADDLE_ROOT="$( cd "$( dirname "${BASH_SOURCE[0]}")/../" && pwd )" function check_sequnece_op_unitests(){ diff --git a/tools/cudaError/start.sh b/tools/cudaError/start.sh index 66e56b8485d8c..3c0e57ffe7ec1 100644 --- a/tools/cudaError/start.sh +++ b/tools/cudaError/start.sh @@ -1,19 +1,4 @@ #!/usr/bin/env bash - -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - set -ex SYSTEM=`uname -s` rm -f protoc-3.11.3-linux-x86_64.* diff --git a/tools/diff_api.py b/tools/diff_api.py index f086598945afe..8a2acbb3d0acc 100644 --- a/tools/diff_api.py +++ b/tools/diff_api.py @@ -1,19 +1,4 @@ #!/usr/bin/env python - -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - from __future__ import print_function import difflib import sys diff --git a/tools/diff_unittest.py b/tools/diff_unittest.py index fa70be0990ec0..382fbdd0b0c29 100644 --- a/tools/diff_unittest.py +++ b/tools/diff_unittest.py @@ -1,19 +1,4 @@ #!/usr/bin/env python - -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import difflib import sys diff --git a/tools/dockerfile/icode.sh b/tools/dockerfile/icode.sh index 973975fe7f737..da3ffb8c77db7 100755 --- a/tools/dockerfile/icode.sh +++ b/tools/dockerfile/icode.sh @@ -1,19 +1,5 @@ #!/bin/bash -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - function install_gcc(){ sed -i 's##RUN apt-get update \ diff --git a/tools/document_preview.sh b/tools/document_preview.sh index 83c758d0aa8b8..10f486f8fd4f6 100755 --- a/tools/document_preview.sh +++ b/tools/document_preview.sh @@ -1,19 +1,4 @@ #!/bin/bash - -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - PADDLE_ROOT=/home mkdir ${PADDLE_ROOT} cd ${PADDLE_ROOT} diff --git a/tools/get_cpu_info.sh b/tools/get_cpu_info.sh index bce338a8619e6..81eb19dc0661e 100755 --- a/tools/get_cpu_info.sh +++ b/tools/get_cpu_info.sh @@ -1,19 +1,5 @@ #!/bin/bash -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - if [ "`uname -s`" != "Linux" ]; then echo "Current scenario only support in Linux yet!" 
exit 0 diff --git a/tools/static_mode_white_list.pyc b/tools/static_mode_white_list.pyc index fdb2a15d7a94ae66e014844fb40917f41d91847c..e9012c233595b6844f54e625972360f5aeeb0d3b 100644 GIT binary patch delta 203 zcmeBJ%eZE zmor5%Fr>_53f7oE*_Pw-W*I=;jr!7q}*` u3^10o0lGuKC_leMKRG`owOGG6u`D${CqFS|^6LOg79UmyhRp_nKCA%dMl&)1 delta 429 zcmYk1OG^S#6vywJRD7nEV@Qe+712g42-=yC77@%MXc5tM^h_`3Va^Ce5VULErEL$( zo)pM@l$Jr%7f7Nn&{Y$<+{N$Q|2hBjzfX_C$+Hk2_z8GF+UvzK`8zEjB*WQExLfft zR|e1r=m!jNAm}GJ2pHs-1Hlkr7;p$ML@+>b7%;+14g^KOC|8Mng8^}v$uS0To}Cq+kTGd*X2xsT(C7!PDArbSixK&esP zsA+VwZYs1Yo2IQ#jQWFOzon*Ny{xc>Y{b_?+VNqbT*N+28)0Fby=P;rxfo;vI#o?eqgq8pU8hFXO7_?Mm%Bj^^Kag?73%A%td;lx z-LKZIk^ZA*prX>U%zt7Ty Date: Fri, 26 Mar 2021 00:09:56 +0100 Subject: [PATCH 10/37] minor change --- paddle/fluid/operators/reduce_ops/reduce_op.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index ab1b8de7c00e6..0f5a24d1caf53 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -493,7 +493,8 @@ class ReduceOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // choose cudnn kernel if the runtime supported. - framework::LibraryType library_{framework::LibraryType::kPlain}; + framework::LibraryType library_; + library_ = ramework::LibraryType::kPlain; auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN From 87fc5a1983c34f70d9ef403ce4e64d8a16c93b1d Mon Sep 17 00:00:00 2001 From: jakpiase <62569058+jakpiase@users.noreply.github.com> Date: Fri, 26 Mar 2021 00:11:02 +0100 Subject: [PATCH 11/37] fixed mispelling --- paddle/fluid/operators/reduce_ops/reduce_op.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 0f5a24d1caf53..e2c9518a78b91 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -494,7 +494,7 @@ class ReduceOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { // choose cudnn kernel if the runtime supported. framework::LibraryType library_; - library_ = ramework::LibraryType::kPlain; + library_ = framework::LibraryType::kPlain; auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN From a75ee12eaa542c5f9ea6b608ac72428e4e568225 Mon Sep 17 00:00:00 2001 From: jakpiase <62569058+jakpiase@users.noreply.github.com> Date: Fri, 26 Mar 2021 01:02:16 +0100 Subject: [PATCH 12/37] Minor refactoring --- paddle/fluid/operators/reduce_ops/reduce_op.h | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index e2c9518a78b91..a1172828146e4 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -493,13 +493,10 @@ class ReduceOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { // choose cudnn kernel if the runtime supported. 
- framework::LibraryType library_; - library_ = framework::LibraryType::kPlain; auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, input_data_type)) { + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); From b4428893fa8238f02dfb913a884cf5fde9a9ff3d Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Fri, 26 Mar 2021 12:39:08 +0100 Subject: [PATCH 13/37] minor change --- .../fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py index a61b70ec773fb..efa06a5e6fe70 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py @@ -20,7 +20,8 @@ import paddle.fluid as fluid import paddle - +@unittest.skipIf(not core.supports_bfloat16(), + "place does not support BF16 evaluation") @skip_check_grad_ci(reason="not implemented") class TestReduceSumDefaultBF16ONEDNNOp(OpTest): def setUp(self): From 27dec3a2460faafebc1d6dcc48bd24e5689b5919 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Fri, 26 Mar 2021 13:46:11 +0100 Subject: [PATCH 14/37] importet necessary modules --- .../fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py index efa06a5e6fe70..bb1dd79e77198 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py @@ -17,9 +17,11 @@ import unittest import numpy as np from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16 +import paddle.fluid.core as core import paddle.fluid as fluid import paddle + @unittest.skipIf(not core.supports_bfloat16(), "place does not support BF16 evaluation") @skip_check_grad_ci(reason="not implemented") From 71089feb4b84aabcef90d91033cd3b46359b791f Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Fri, 26 Mar 2021 17:21:22 +0100 Subject: [PATCH 15/37] minor change --- paddle/fluid/operators/reduce_ops/reduce_op.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index a1172828146e4..f09e5876e77d3 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -495,6 +495,9 @@ class ReduceOp : public framework::OperatorWithKernel { // choose cudnn kernel if the runtime supported. 
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + if(ctx.Input("X")->dims().size() > 5) + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + #ifdef PADDLE_WITH_MKLDNN if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), From 29097ce95e082acb51be450f9300621a81743e7a Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Fri, 26 Mar 2021 20:35:18 +0100 Subject: [PATCH 16/37] minor formatting change --- paddle/fluid/operators/reduce_ops/reduce_op.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index f09e5876e77d3..280464ea85279 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -495,7 +495,7 @@ class ReduceOp : public framework::OperatorWithKernel { // choose cudnn kernel if the runtime supported. auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - if(ctx.Input("X")->dims().size() > 5) + if (ctx.Input("X")->dims().size() > 5) return framework::OpKernelType(input_data_type, ctx.GetPlace()); #ifdef PADDLE_WITH_MKLDNN From 164043ae5f93938e23de55b236faf72b4919c452 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Mon, 29 Mar 2021 12:18:52 +0200 Subject: [PATCH 17/37] excluded cuda from bf test --- .../fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py index bb1dd79e77198..a894d042e426c 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py @@ -24,6 +24,8 @@ @unittest.skipIf(not core.supports_bfloat16(), "place does not support BF16 evaluation") +@unittest.skipIf(core.is_compiled_with_cuda(), + "core is compiled with CUDA which has no BF implementation") @skip_check_grad_ci(reason="not implemented") class TestReduceSumDefaultBF16ONEDNNOp(OpTest): def setUp(self): From be36f947fab2436539bcb46ed59a9f8a8a03e7ee Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Mon, 29 Mar 2021 19:17:49 +0200 Subject: [PATCH 18/37] fixed static mode in test_resnet_v2 --- .../unittests/dygraph_to_static/test_resnet_v2.py | 1 + .../tests/unittests/mkldnn/test_reduce_mkldnn_op.py | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py index 10346ab0cc442..8c779ec7ef62a 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py @@ -356,6 +356,7 @@ def test_resnet(self): self.verify_predict() def test_in_static_mode_mkldnn(self): + paddle.enable_static() paddle.fluid.set_flags({'FLAGS_use_mkldnn': True}) try: train(to_static=True) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py index c2ec1bfc36668..96c68ac98bd5b 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py @@ -158,11 +158,24 @@ def setUp(self): class 
TestReduceMean3DONEDNNOp(TestReduceSumDefaultONEDNNOp): def setUp(self): self.op_type = "reduce_mean" + self.use_mkldnn = True self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")} + self.attrs = {'dim': [0], 'use_mkldnn': self.use_mkldnn} self.outputs = { 'Out': self.inputs['X'].sum(axis=0) / self.inputs['X'].shape[0] } +@skip_check_grad_ci(reason="not implemented") +class TestReduceMean4DReduceAllONEDNNOp(TestReduceSumDefaultONEDNNOp): + def setUp(self): + self.op_type = "reduce_mean" + self.use_mkldnn = True + self.inputs = {'X': np.random.random((5, 6, 8, 10)).astype("float32")} + self.attrs = {'reduce_all': True, 'use_mkldnn': self.use_mkldnn} + self.outputs = { + 'Out': self.inputs['X'].sum() / np.asarray(self.inputs['X'].shape).prod() + } + if __name__ == '__main__': paddle.enable_static() From 424083f53c034a17719b84e53afca97d357db18c Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Mon, 29 Mar 2021 19:22:33 +0200 Subject: [PATCH 19/37] added formatting --- .../fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py index 96c68ac98bd5b..665d3f6cc6b82 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py @@ -165,6 +165,7 @@ def setUp(self): 'Out': self.inputs['X'].sum(axis=0) / self.inputs['X'].shape[0] } + @skip_check_grad_ci(reason="not implemented") class TestReduceMean4DReduceAllONEDNNOp(TestReduceSumDefaultONEDNNOp): def setUp(self): @@ -173,7 +174,8 @@ def setUp(self): self.inputs = {'X': np.random.random((5, 6, 8, 10)).astype("float32")} self.attrs = {'reduce_all': True, 'use_mkldnn': self.use_mkldnn} self.outputs = { - 'Out': self.inputs['X'].sum() / np.asarray(self.inputs['X'].shape).prod() + 'Out': + self.inputs['X'].sum() / np.asarray(self.inputs['X'].shape).prod() } From 87b5b38677ef7c6a7d67aaab87f140d5cc419405 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Wed, 7 Apr 2021 18:12:55 +0200 Subject: [PATCH 20/37] added support for edge case --- .../reduce_ops/mkldnn/reduce_mkldnn_op.h | 75 ++++++++++++++----- .../unittests/mkldnn/test_reduce_mkldnn_op.py | 10 +++ 2 files changed, 65 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h index 7073288a9ed86..7e09aaa126eff 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -40,26 +40,61 @@ class ReduceMKLDNNKernel : public framework::OpKernel { std::vector output_dims = CalculateOutputDims(input, output, reduce_dims, reduce_all, keep_dim); - platform::ReductionMKLDNNHandler handler( - reduction_type, 0.0f, 0.0f, dev_ctx, onednn_engine, ctx.GetPlace(), - input, output, ctx.InputName("X"), output_dims); - - auto src_memory_p = handler.AcquireSrcMemory(input); - auto dst_memory_p = handler.AcquireDstMemory(output); - - std::unordered_map reduction_args = { - {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; - - auto reduction_p = handler.AcquireForwardPrimitive(); - - auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); - reduction_p->execute(astream, reduction_args); - astream.wait(); - - output->set_layout(framework::DataLayout::kMKLDNN); - output->set_format( - 
platform::GetMKLDNNFormat(dst_memory_p->get_desc().reshape( - paddle::framework::vectorize(output->dims())))); + auto input_dims = framework::vectorize(input->dims()); + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + + // oneDNN reduce op does not support edge case in which memory is being + // copied without actual reduction. + // In that case reorder must be executed to maintain compatibility with + // PaddlePaddle reduce op + if (input_dims == output_dims) { + mkldnn::memory::data_type input_type = + framework::ToMKLDNNDataType(input->type()); + std::string key = platform::CreateKey( + dev_ctx, input_dims, input->format(), input->format(), input_type); + platform::ReorderMKLDNNHandler reorder_handler( + input_dims, input->type(), input_type, dev_ctx, onednn_engine, key); + + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + input->format(), platform::to_void_cast(input->data())); + + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + output, input->format(), ctx.GetPlace()); + + auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p, + reorder_dst_memory_p); + + platform::RecordEvent record_reorder("int_reorder", + platform::EventRole::kUniqueOp); + + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); + astream.wait(); + + output->set_layout(framework::DataLayout::kMKLDNN); + output->set_format( + platform::GetMKLDNNFormat(reorder_dst_memory_p->get_desc().reshape( + paddle::framework::vectorize(output->dims())))); + } else { + platform::ReductionMKLDNNHandler handler( + reduction_type, 0.0f, 0.0f, dev_ctx, onednn_engine, ctx.GetPlace(), + input, output, ctx.InputName("X"), output_dims); + + auto src_memory_p = handler.AcquireSrcMemory(input); + auto dst_memory_p = handler.AcquireDstMemory(output); + + std::unordered_map reduction_args = { + {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; + + auto reduction_p = handler.AcquireForwardPrimitive(); + + reduction_p->execute(astream, reduction_args); + astream.wait(); + output->set_layout(framework::DataLayout::kMKLDNN); + output->set_format( + platform::GetMKLDNNFormat(dst_memory_p->get_desc().reshape( + paddle::framework::vectorize(output->dims())))); + } } private: diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py index 665d3f6cc6b82..c913b9eeea27d 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py @@ -179,6 +179,16 @@ def setUp(self): } +@skip_check_grad_ci(reason="not implemented") +class TestReduceMeanNoReduce1DOp(TestReduceSumDefaultONEDNNOp): + def setUp(self): + self.op_type = "reduce_mean" + self.use_mkldnn = True + self.inputs = {'X': np.random.random((1)).astype("float32")} + self.attrs = {'use_mkldnn': self.use_mkldnn} + self.outputs = {'Out': self.inputs['X']} + + if __name__ == '__main__': paddle.enable_static() unittest.main() From 94e4ace1824b7db26f3bac0a869036f357dc3f42 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Tue, 13 Apr 2021 13:03:08 +0200 Subject: [PATCH 21/37] added files for reduce grad --- .../reduce_ops/mkldnn/reduce_mkldnn_op.h | 63 +++++++++++++++++- .../reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc | 12 ++++ paddle/fluid/operators/reduce_ops/reduce_op.h | 38 +++++++++-- paddle/fluid/platform/mkldnn_reuse.h | 65 +++++++++++++++++++ .../unittests/mkldnn/test_reduce_grad_tmp.py | 58 +++++++++++++++++ 5 files 
changed, 230 insertions(+), 6 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_reduce_grad_tmp.py diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h index 7e09aaa126eff..76d4818061ac3 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/platform/mkldnn_reuse.h" +#include "paddle/fluid/platform/mkldnn_helper.h" namespace paddle { namespace operators { @@ -30,8 +31,8 @@ class ReduceMKLDNNKernel : public framework::OpKernel { ctx.template device_context(); const auto& onednn_engine = dev_ctx.GetEngine(); - const auto* input = ctx.Input("X"); - auto* output = ctx.Output("Out"); + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); auto reduce_dims = ctx.Attr>("dim"); bool reduce_all = ctx.Attr("reduce_all"); @@ -121,5 +122,63 @@ class ReduceMKLDNNKernel : public framework::OpKernel { } }; +template +class ReduceGradMKLDNNKernel : public framework::OpKernel { + public: + void RunKernel(const framework::ExecutionContext& ctx, + dnnl::algorithm reduction_type) const { + //only for reduce sum for now + auto& dev_ctx = + ctx.template device_context(); + const auto& onednn_engine = dev_ctx.GetEngine(); + + auto dims = ctx.Attr>("dim"); + auto* input_x = ctx.Input("X"); + auto* input_dy = ctx.Input(framework::GradVarName("Out")); + + auto* output_dx = ctx.Output(framework::GradVarName("X")); + + output_dx->mutable_data(ctx.GetPlace()); + output_dx->set_format(getPlainFormatTag(output_dx)); + output_dx->set_layout(input_dy->layout()); + + platform::BinaryReductionGradMKLDNNHandler handler(dnnl::algorithm::binary_add, dev_ctx, onednn_engine, + ctx.GetPlace(), output_dx, input_dy, 0.0f, 1.0f, + ctx.InputName(framework::GradVarName("Out"))); + + const auto src_dx_memory = handler.AcquireSrcMemory(output_dx); + const auto src_dy_memory = handler.AcquireSecondSrcMemory(input_dy); + + const auto binary_prim = handler.AcquireForwardPrimitive(); + + const std::unordered_map args = { + {DNNL_ARG_SRC_0, *src_dx_memory}, + {DNNL_ARG_SRC_1, *src_dy_memory}, + {DNNL_ARG_DST, *src_dx_memory}}; + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + + binary_prim->execute(astream, args); + astream.wait(); + } +protected: + mkldnn::memory::format_tag getPlainFormatTag(const Tensor* tensor) const{ + switch(tensor->dims().size()){ + case 1: + return mkldnn::memory::format_tag::a; + case 2: + return mkldnn::memory::format_tag::ab; + case 3: + return mkldnn::memory::format_tag::abc; + case 4: + return mkldnn::memory::format_tag::abcd; + case 5: + return mkldnn::memory::format_tag::abcde; + default: + platform::errors::InvalidArgument("Tensor dims must be in range <1, 5>"); + } + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc index 4676589e68910..07537e4d00f42 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc @@ -25,6 +25,14 @@ class ReduceSumMKLDNNKernel : public ReduceMKLDNNKernel { } }; +template +class ReduceSumGradMKLDNNKernel : public ReduceGradMKLDNNKernel { + public: + void Compute(const 
framework::ExecutionContext& ctx) const override { + this->RunKernel(ctx, dnnl::algorithm::reduction_sum); + } +}; + } // namespace operators } // namespace paddle @@ -32,3 +40,7 @@ namespace ops = paddle::operators; REGISTER_OP_KERNEL(reduce_sum, MKLDNN, paddle::platform::CPUPlace, ops::ReduceSumMKLDNNKernel, ops::ReduceSumMKLDNNKernel); + +REGISTER_OP_KERNEL(reduce_sum_grad, MKLDNN, paddle::platform::CPUPlace, + ops::ReduceSumGradMKLDNNKernel, + ops::ReduceSumGradMKLDNNKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 280464ea85279..c3177794df248 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -495,7 +495,12 @@ class ReduceOp : public framework::OperatorWithKernel { // choose cudnn kernel if the runtime supported. auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - if (ctx.Input("X")->dims().size() > 5) + auto input_dims = framework::vectorize( + ctx.Input("X")->dims()); + auto output_dims = framework::vectorize( + ctx.Output("Out")->dims()); + + if (input_dims.size() > 5) return framework::OpKernelType(input_data_type, ctx.GetPlace()); #ifdef PADDLE_WITH_MKLDNN @@ -559,15 +564,37 @@ class ReduceGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { + + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, framework::GradVarName("Out")); + + auto CanMKLDNNReduceGradBeUsed = [&]() { + auto dx_dims = ctx.Input("X")->dims(); + auto dy_dims = ctx.Input(framework::GradVarName("Out"))->dims(); + + // Subtensor must be on rightmost part of the bigger tensor + for(size_t i = 0; i < dy_dims.size() ; ++i){ + if(dx_dims[dx_dims.size() - dy_dims.size() + i] != dy_dims[i]){ + return false; + } + } + return true; + }; + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type) && CanMKLDNNReduceGradBeUsed()) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + int in_dtype = ctx.Attr("in_dtype"); if (in_dtype >= 0) { return framework::OpKernelType( static_cast(in_dtype), ctx.GetPlace()); } - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -606,6 +633,9 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); + AddAttr("input_format", + "(int, default -1) Input memory format") + .SetDefault(-1); AddComment(string::Sprintf(R"DOC( %s Operator. 
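Illustration, not part of the patch series: ReduceGradMKLDNNKernel computes the reduce_sum gradient by broadcasting dOut over the shape of X, realized here as a oneDNN binary_add of dOut onto a zero-initialized dX buffer. A minimal NumPy sketch of the same arithmetic, with made-up shapes:

import numpy as np

x = np.random.rand(5, 6, 10).astype(np.float32)   # forward input X
dy = np.random.rand(10).astype(np.float32)        # grad of out = x.sum(axis=(0, 1))
dx = np.zeros_like(x) + dy                        # broadcast add, like binary_add onto the zeroed dX
assert dx.shape == x.shape                        # every element of X receives dOut exactly once

This also shows why the new GetExpectedKernelType check only takes the oneDNN path when dOut's dims form the rightmost part of X's dims: a plain right-aligned broadcast then reproduces the gradient.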
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 0c45da63edd70..e7157e1d766d8 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -630,6 +630,70 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT { } }; +template +class BinaryReductionGradMKLDNNHandler : public platform::MKLDNNHandlerT { + public: + BinaryReductionGradMKLDNNHandler(const dnnl::algorithm algo, + const MKLDNNDeviceContext& dev_ctx, + const mkldnn::engine engine, platform::Place cpu_place, + const Tensor* x, const Tensor* y, + float scale_x, float scale_y, + const std::string& uniq_name) + : platform::MKLDNNHandlerT( + dev_ctx, engine, cpu_place, + platform::CreateKey( + dev_ctx, framework::vectorize(x->dims()), uniq_name)) { + + if (!this->isCached()) { + PADDLE_ENFORCE_EQ( + x->layout(), DataLayout::kMKLDNN, + platform::errors::InvalidArgument("Wrong layout set for X tensor.")); + PADDLE_ENFORCE_NE( + x->format(), MKLDNNMemoryFormat::undef, + platform::errors::InvalidArgument("Wrong format set for X tensor.")); + + PADDLE_ENFORCE_EQ( + y->layout(), DataLayout::kMKLDNN, + platform::errors::InvalidArgument("Wrong layout set for Y tensor.")); + PADDLE_ENFORCE_NE( + y->format(), MKLDNNMemoryFormat::undef, + platform::errors::InvalidArgument("Wrong format set for Y tensor.")); + + auto src1_tz = framework::vectorize(y->dims()); + const auto src0_tz = framework::vectorize(x->dims()); + + // GetExpectedKernelType checks if smaller vector is a subvector with all the dims in correct order on the rightmost part of the bigger vector, f.e. a correct vector for broadcasting: + // x = 5, 7, 3, 2, 4, 8 + // y = 4, 8 + for(size_t i = src1_tz.size() ; i < src0_tz.size() ; ++i){ + src1_tz.insert(src1_tz.begin(), 1L); + } + + const auto src0_md = dnnl::memory::desc( + src0_tz, platform::MKLDNNGetDataType(), x->format()); + const auto src1_md = dnnl::memory::desc( + src1_tz, platform::MKLDNNGetDataType(), x->format());//y->format()); + + //const auto dst_md = dnnl::memory::desc( // in reduction binary op is always inplace + // output_dims, platform::MKLDNNGetDataType(), x->format()); + + dnnl::primitive_attr attributes; + attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x}); + attributes.set_scales(DNNL_ARG_SRC_1, 0, {scale_y}); + + this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md, + src1_md, src0_md); + } + } + + std::shared_ptr AcquireSecondSrcMemory( + const framework::Tensor* input) { + const T* input_data = input->data(); + return this->AcquireMemoryFromPrimitive( + this->fwd_pd_->src1_desc(), to_void_cast(input_data), "@src1_mem_p"); + } +}; + template class ReductionMKLDNNHandler : public platform::MKLDNNHandlerT { @@ -665,6 +729,7 @@ class ReductionMKLDNNHandler } }; + template class ActivationMKLDNNHandler : public MKLDNNHandlerT Date: Wed, 14 Apr 2021 12:59:47 +0200 Subject: [PATCH 22/37] added grad tests for onednn reduce --- .../mkldnn/reduce_mean_mkldnn_op.cc | 29 +++ .../reduce_ops/mkldnn/reduce_mkldnn_op.h | 11 +- .../reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc | 2 +- paddle/fluid/operators/reduce_ops/reduce_op.h | 3 + .../mkldnn/test_reduce_bf16_mkldnn_op.py | 169 ++++++++++-------- .../unittests/mkldnn/test_reduce_mkldnn_op.py | 52 ++---- 6 files changed, 147 insertions(+), 119 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc index a9eed0d7eb042..8bb931ad3c372 100644 --- 
a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc @@ -25,6 +25,31 @@ class ReduceMeanMKLDNNKernel : public ReduceMKLDNNKernel { } }; +template +class ReduceMeanGradMKLDNNKernel : public ReduceGradMKLDNNKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const auto* input_x = ctx.Input("X"); + auto input_dims = framework::vectorize(input_x->dims()); + auto reduce_dims = ctx.Attr>("dim"); + + int number_of_elements = 1; + if(!ctx.Attr("reduce_all")){ + for(size_t i = 0; i < reduce_dims.size() ; ++i){ + reduce_dims[i] = (reduce_dims[i] >= 0) + ? reduce_dims[i] + : input_dims.size() + reduce_dims[i]; + number_of_elements *= input_dims[reduce_dims[i]]; + } + } else { + for(size_t i = 0; i < input_dims.size() ; ++i) + number_of_elements *= input_dims[i]; + } + + this->RunKernel(ctx, dnnl::algorithm::binary_add, 0.0f, 1.0L / number_of_elements); + } +}; + } // namespace operators } // namespace paddle @@ -32,3 +57,7 @@ namespace ops = paddle::operators; REGISTER_OP_KERNEL(reduce_mean, MKLDNN, paddle::platform::CPUPlace, ops::ReduceMeanMKLDNNKernel, ops::ReduceMeanMKLDNNKernel); + +REGISTER_OP_KERNEL(reduce_mean_grad, MKLDNN, paddle::platform::CPUPlace, + ops::ReduceMeanGradMKLDNNKernel, + ops::ReduceMeanGradMKLDNNKernel); diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h index 76d4818061ac3..fc4607a0c7278 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -126,7 +126,7 @@ template class ReduceGradMKLDNNKernel : public framework::OpKernel { public: void RunKernel(const framework::ExecutionContext& ctx, - dnnl::algorithm reduction_type) const { + dnnl::algorithm binary_type, float scale_x, float scale_y) const { //only for reduce sum for now auto& dev_ctx = ctx.template device_context(); @@ -139,16 +139,19 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel { auto* output_dx = ctx.Output(framework::GradVarName("X")); output_dx->mutable_data(ctx.GetPlace()); + output_dx->set_format(getPlainFormatTag(output_dx)); output_dx->set_layout(input_dy->layout()); - platform::BinaryReductionGradMKLDNNHandler handler(dnnl::algorithm::binary_add, dev_ctx, onednn_engine, - ctx.GetPlace(), output_dx, input_dy, 0.0f, 1.0f, + platform::BinaryReductionGradMKLDNNHandler handler(binary_type, dev_ctx, onednn_engine, + ctx.GetPlace(), output_dx, input_dy, scale_x, scale_y, ctx.InputName(framework::GradVarName("Out"))); - const auto src_dx_memory = handler.AcquireSrcMemory(output_dx); + auto src_dx_memory = handler.AcquireSrcMemory(output_dx); const auto src_dy_memory = handler.AcquireSecondSrcMemory(input_dy); + memset(output_dx->data(), 0, src_dx_memory->get_desc().get_size()); + const auto binary_prim = handler.AcquireForwardPrimitive(); const std::unordered_map args = { diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc index 07537e4d00f42..d9bf5ad427b7c 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc @@ -29,7 +29,7 @@ template class ReduceSumGradMKLDNNKernel : public ReduceGradMKLDNNKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - this->RunKernel(ctx, 
dnnl::algorithm::reduction_sum); + this->RunKernel(ctx, dnnl::algorithm::binary_add, 0.0f, 1.0f); } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index c3177794df248..deb9e0a981677 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -569,6 +569,9 @@ class ReduceGradOp : public framework::OperatorWithKernel { auto CanMKLDNNReduceGradBeUsed = [&]() { auto dx_dims = ctx.Input("X")->dims(); + if (ctx.Attr("reduce_all") || (ctx.Attr>("dim").size() == dx_dims.size())) + return true; + auto dy_dims = ctx.Input(framework::GradVarName("Out"))->dims(); // Subtensor must be on rightmost part of the bigger tensor diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py index a894d042e426c..45284cfdc7c54 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py @@ -26,97 +26,106 @@ "place does not support BF16 evaluation") @unittest.skipIf(core.is_compiled_with_cuda(), "core is compiled with CUDA which has no BF implementation") -@skip_check_grad_ci(reason="not implemented") class TestReduceSumDefaultBF16ONEDNNOp(OpTest): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True - x_fp32 = np.random.random((5, 6, 10)).astype("float32") - x_bf16 = convert_float_to_uint16(x_fp32) - self.inputs = {'X': x_bf16} - self.outputs = {'Out': x_fp32.sum(axis=0)} + self.x_fp32 = np.random.random((5, 6, 10)).astype("float32") + self.x_bf16 = convert_float_to_uint16(self.x_fp32) + self.inputs = {'X': self.x_bf16} + self.outputs = {'Out': self.x_fp32.sum(axis=0)} self.attrs = {'use_mkldnn': self.use_mkldnn} def test_check_output(self): self.check_output(check_dygraph=False) + def calculate_grads(self): + tmp_tensor = np.zeros(self.x_fp32.shape).astype("float32") -@skip_check_grad_ci(reason="not implemented") -class TestReduceSum4DBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp): - def setUp(self): - self.op_type = "reduce_sum" - self.use_mkldnn = True - x_fp32 = np.random.random((5, 10, 5, 5)).astype("float32") - x_bf16 = convert_float_to_uint16(x_fp32) - self.inputs = {'X': x_bf16} - self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [2]} - self.outputs = {'Out': x_fp32.sum(axis=tuple(self.attrs['dim']))} + prod_of_reduced_dims = self.inputs['X'].shape[0] + axis = 0 + if "dim" in self.attrs: + prod_of_reduced_dims = 1 + axis = tuple(self.attrs['dim']) + for i in range(len(axis)): + ax = axis[i] + if axis[i] < 0: + ax = len(axis) + axis[i] + prod_of_reduced_dims *= self.inputs['X'].shape[ax] -@skip_check_grad_ci(reason="not implemented") -class TestReduceSum4DReduceAllWithoutReduceAllAttributeBF16ONEDNNOp( - TestReduceSumDefaultBF16ONEDNNOp): + if 'reduce_all' in self.attrs: + if self.attrs['reduce_all'] is True: + axis = None + prod_of_reduced_dims = np.asarray(self.inputs['X'].shape).prod() + + keepdim = False + if 'keep_dim' in self.attrs: + keepdim = True + + self.grad_Out = self.x_fp32.sum(axis=axis, keepdims=keepdim) + + self.grad_Out = np.atleast_1d(self.grad_Out) + + self.grad_X = tmp_tensor + self.grad_Out # broadcast grad + + if self.op_type == 'reduce_mean': + self.grad_X /= prod_of_reduced_dims + + +class TestReduceDefaultWithGradBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp): + def test_check_grad(self): + self.calculate_grads() + self.check_grad_with_place( + 
core.CPUPlace(), ["X"], + "Out", + check_dygraph=False, + user_defined_grads=[self.grad_X], + user_defined_grad_outputs=[convert_float_to_uint16(self.grad_Out)]) + + +class TestReduceSum4DReduceAllWithoutReduceAllAttributeBF16ONEDNNOp(TestReduceDefaultWithGradBF16ONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True - x_fp32 = np.random.normal(size=(2, 3, 5, 6)).astype('float32') - x_bf16 = convert_float_to_uint16(x_fp32) - self.inputs = {'X': x_bf16} + self.x_fp32 = np.random.normal(size=(2, 3, 5, 6)).astype('float32') + self.x_bf16 = convert_float_to_uint16(self.x_fp32) + self.inputs = {'X': self.x_bf16} self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [0, 1, 2, 3]} - self.outputs = {'Out': x_fp32.sum(axis=tuple(self.attrs['dim']))} + self.outputs = {'Out': self.x_fp32.sum(axis=tuple(self.attrs['dim']))} -@skip_check_grad_ci(reason="not implemented") -class TestReduceSum4DReduceAllWithoutReduceAllAttributeNegativeDimsBF16ONEDNNOp( - TestReduceSumDefaultBF16ONEDNNOp): +class TestReduceSum4DReduceAllWithoutReduceAllAttributeNegativeDimsBF16ONEDNNOp(TestReduceDefaultWithGradBF16ONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True - x_fp32 = np.random.normal(size=(2, 7, 3, 5)).astype('float32') - x_bf16 = convert_float_to_uint16(x_fp32) - self.inputs = {'X': x_bf16} + self.x_fp32 = np.random.normal(size=(4, 7, 6, 6)).astype('float32') + self.x_bf16 = convert_float_to_uint16(self.x_fp32) + self.inputs = {'X': self.x_bf16} self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [-1, -2, -3, -4]} - self.outputs = {'Out': x_fp32.sum(axis=tuple(self.attrs['dim']))} + self.outputs = {'Out': self.x_fp32.sum(axis=tuple(self.attrs['dim']))} -@skip_check_grad_ci(reason="not implemented") -class TestReduceSum5DKeepDimsONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp): - def setUp(self): - self.op_type = "reduce_sum" - self.use_mkldnn = True - x_fp32 = np.random.random((2, 5, 3, 2, 2)).astype("float32") - x_bf16 = convert_float_to_uint16(x_fp32) - self.inputs = {'X': x_bf16} - self.attrs = {'dim': (2, 3, 4), 'keep_dim': True, 'use_mkldnn': True} - self.outputs = { - 'Out': x_fp32.sum(axis=tuple(self.attrs['dim']), - keepdims=self.attrs['keep_dim']) - } - - -@skip_check_grad_ci(reason="not implemented") -class TestReduceSum5DReduceAllKeepDimsBF16ONEDNNOp( - TestReduceSumDefaultBF16ONEDNNOp): +class TestReduceSum5DReduceAllKeepDimsBF16ONEDNNOp(TestReduceDefaultWithGradBF16ONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True - x_fp32 = np.random.normal(size=(2, 5, 3, 2, 4)).astype('float32') - x_bf16 = convert_float_to_uint16(x_fp32) - self.inputs = {'X': x_bf16} + self.x_fp32 = np.random.normal(size=(2, 5, 3, 2, 5)).astype('float32') + self.x_bf16 = convert_float_to_uint16(self.x_fp32) + self.inputs = {'X': self.x_bf16} self.attrs = {'reduce_all': True, 'keep_dim': True, 'use_mkldnn': True} - self.outputs = {'Out': x_fp32.sum(keepdims=self.attrs['keep_dim'])} + self.outputs = {'Out': self.x_fp32.sum(keepdims=self.attrs['keep_dim'])} -@skip_check_grad_ci(reason="not implemented") -class TestReduceSum4DReduceAllBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp): +class TestReduceSum4DReduceAllBF16ONEDNNOp(TestReduceDefaultWithGradBF16ONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True - x_fp32 = np.random.normal(size=(4, 3, 2, 3)).astype('float32') - x_bf16 = convert_float_to_uint16(x_fp32) - self.inputs = {'X': x_bf16} + self.x_fp32 = np.random.normal(size=(4, 5, 4, 5)).astype('float32') + self.x_bf16 = 
convert_float_to_uint16(self.x_fp32) + self.inputs = {'X': self.x_bf16} self.attrs = {'reduce_all': True, 'use_mkldnn': self.use_mkldnn} - self.outputs = {'Out': x_fp32.sum()} + self.outputs = {'Out': self.x_fp32.sum()} @skip_check_grad_ci( @@ -128,11 +137,11 @@ class TestReduceMax3DBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp): def setUp(self): self.op_type = "reduce_max" self.use_mkldnn = True - x_fp32 = np.random.random((5, 6, 10)).astype("float32") - x_bf16 = convert_float_to_uint16(x_fp32) - self.inputs = {'X': x_bf16} + self.x_fp32 = np.random.random((5, 6, 10)).astype("float32") + self.x_bf16 = convert_float_to_uint16(self.x_fp32) + self.inputs = {'X': self.x_bf16} self.attrs = {'dim': [-1], 'use_mkldnn': self.use_mkldnn} - self.outputs = {'Out': x_fp32.max(axis=tuple(self.attrs['dim']))} + self.outputs = {'Out': self.x_fp32.max(axis=tuple(self.attrs['dim']))} @skip_check_grad_ci( @@ -145,11 +154,11 @@ class TestReduceMax4DNegativeAndPositiveDimsBF16ONEDNNOp( def setUp(self): self.op_type = "reduce_max" self.use_mkldnn = True - x_fp32 = np.random.random((5, 6, 10, 9)).astype("float32") - x_bf16 = convert_float_to_uint16(x_fp32) - self.inputs = {'X': x_bf16} + self.x_fp32 = np.random.random((5, 6, 10, 9)).astype("float32") + self.x_bf16 = convert_float_to_uint16(self.x_fp32) + self.inputs = {'X': self.x_bf16} self.attrs = {'dim': [-1, 0, 1], 'use_mkldnn': self.use_mkldnn} - self.outputs = {'Out': x_fp32.max(axis=tuple(self.attrs['dim']))} + self.outputs = {'Out': self.x_fp32.max(axis=tuple(self.attrs['dim']))} @skip_check_grad_ci( @@ -161,23 +170,33 @@ class TestReduceMin3DBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp): def setUp(self): self.op_type = "reduce_min" self.use_mkldnn = True - x_fp32 = np.random.random((5, 6, 10)).astype("float32") - x_bf16 = convert_float_to_uint16(x_fp32) - self.inputs = {'X': x_bf16} + self.x_fp32 = np.random.random((5, 6, 10)).astype("float32") + self.x_bf16 = convert_float_to_uint16(self.x_fp32) + self.inputs = {'X': self.x_bf16} self.attrs = {'dim': [2], 'use_mkldnn': self.use_mkldnn} - self.outputs = {'Out': x_fp32.min(axis=tuple(self.attrs['dim']))} + self.outputs = {'Out': self.x_fp32.min(axis=tuple(self.attrs['dim']))} -@skip_check_grad_ci(reason="not implemented") -class TestReduceMean3DBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp): +class TestReduceMean3DBF16ONEDNNOp(TestReduceDefaultWithGradBF16ONEDNNOp): def setUp(self): self.op_type = "reduce_mean" self.use_mkldnn = True - x_fp32 = np.random.random((5, 6, 10)).astype("float32") - x_bf16 = convert_float_to_uint16(x_fp32) - self.inputs = {'X': x_bf16} + self.x_fp32 = np.random.random((5, 6, 10)).astype("float32") + self.x_bf16 = convert_float_to_uint16(self.x_fp32) + self.inputs = {'X': self.x_bf16} self.attrs = {'use_mkldnn': self.use_mkldnn} - self.outputs = {'Out': x_fp32.sum(axis=0) / x_fp32.shape[0]} + self.outputs = {'Out': self.x_fp32.sum(axis=0) / self.x_fp32.shape[0]} + + +class TestReduceMean4DBF16ONEDNNOp(TestReduceDefaultWithGradBF16ONEDNNOp): + def setUp(self): + self.op_type = "reduce_mean" + self.use_mkldnn = True + self.x_fp32 = np.random.random((5, 6, 3, 5)).astype("float32") + self.x_bf16 = convert_float_to_uint16(self.x_fp32) + self.inputs = {'X': self.x_bf16} + self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [0, 1]} + self.outputs = {'Out': self.x_fp32.sum(axis=tuple(self.attrs['dim'])) / (self.x_fp32.shape[0] * self.x_fp32.shape[1])} if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py 
b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py index c913b9eeea27d..df2fe04f7f16f 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py @@ -19,7 +19,6 @@ import paddle -@skip_check_grad_ci(reason="not implemented") class TestReduceSumDefaultONEDNNOp(OpTest): def setUp(self): self.op_type = "reduce_sum" @@ -32,8 +31,12 @@ def test_check_output(self): self.check_output() -@skip_check_grad_ci(reason="not implemented") -class TestReduceSum4DONEDNNOp(TestReduceSumDefaultONEDNNOp): +class TestReduceDefaultWithGradONEDNNOp(TestReduceSumDefaultONEDNNOp): + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + +class TestReduceSum4DONEDNNOp(TestReduceDefaultWithGradONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -44,34 +47,19 @@ def setUp(self): } -@skip_check_grad_ci(reason="not implemented") class TestReduceSum4DReduceAllWithoutReduceAllAttributeONEDNNOp( - TestReduceSumDefaultONEDNNOp): + TestReduceDefaultWithGradONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True - self.inputs = {'X': np.random.random((5, 10, 5, 5)).astype("float32")} + self.inputs = {'X': np.random.random((5, 10, 5, 6)).astype("float32")} self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [0, 1, 2, 3]} self.outputs = { 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) } -@skip_check_grad_ci(reason="not implemented") -class TestReduceSum4DReduceAllWithoutReduceAllAttributeNegativeDimsONEDNNOp( - TestReduceSumDefaultONEDNNOp): - def setUp(self): - self.op_type = "reduce_sum" - self.use_mkldnn = True - self.inputs = {'X': np.random.random((5, 10, 5, 5)).astype("float32")} - self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [-1, -2, -3, -4]} - self.outputs = { - 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) - } - - -@skip_check_grad_ci(reason="not implemented") -class TestReduceSum5DKeepDimsONEDNNOp(TestReduceSumDefaultONEDNNOp): +class TestReduceSum5DKeepDimsONEDNNOp(TestReduceDefaultWithGradONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -83,8 +71,7 @@ def setUp(self): } -@skip_check_grad_ci(reason="not implemented") -class TestReduceSum5DReduceAllKeepDimsONEDNNOp(TestReduceSumDefaultONEDNNOp): +class TestReduceSum5DReduceAllKeepDimsONEDNNOp(TestReduceDefaultWithGradONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -95,8 +82,7 @@ def setUp(self): } -@skip_check_grad_ci(reason="not implemented") -class TestReduceSum4DReduceAllONEDNNOp(TestReduceSumDefaultONEDNNOp): +class TestReduceSum4DReduceAllONEDNNOp(TestReduceDefaultWithGradONEDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -154,8 +140,7 @@ def setUp(self): } -@skip_check_grad_ci(reason="not implemented") -class TestReduceMean3DONEDNNOp(TestReduceSumDefaultONEDNNOp): +class TestReduceMean3DONEDNNOp(TestReduceDefaultWithGradONEDNNOp): def setUp(self): self.op_type = "reduce_mean" self.use_mkldnn = True @@ -166,8 +151,7 @@ def setUp(self): } -@skip_check_grad_ci(reason="not implemented") -class TestReduceMean4DReduceAllONEDNNOp(TestReduceSumDefaultONEDNNOp): +class TestReduceMean4DReduceAllONEDNNOp(TestReduceDefaultWithGradONEDNNOp): def setUp(self): self.op_type = "reduce_mean" self.use_mkldnn = True @@ -179,16 +163,6 @@ def setUp(self): } -@skip_check_grad_ci(reason="not implemented") -class TestReduceMeanNoReduce1DOp(TestReduceSumDefaultONEDNNOp): - def 
setUp(self): - self.op_type = "reduce_mean" - self.use_mkldnn = True - self.inputs = {'X': np.random.random((1)).astype("float32")} - self.attrs = {'use_mkldnn': self.use_mkldnn} - self.outputs = {'Out': self.inputs['X']} - - if __name__ == '__main__': paddle.enable_static() unittest.main() From 7d3797f63cc7f9d7c762b5ce7bb1045f28e1e1e2 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Wed, 14 Apr 2021 13:17:32 +0200 Subject: [PATCH 23/37] added formatting --- .../mkldnn/reduce_mean_mkldnn_op.cc | 17 ++++----- .../reduce_ops/mkldnn/reduce_mkldnn_op.h | 25 +++++++------ .../reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc | 2 +- paddle/fluid/operators/reduce_ops/reduce_op.h | 16 +++++---- paddle/fluid/platform/mkldnn_reuse.h | 35 ++++++++++--------- .../mkldnn/test_reduce_bf16_mkldnn_op.py | 19 ++++++---- .../unittests/mkldnn/test_reduce_mkldnn_op.py | 3 +- 7 files changed, 67 insertions(+), 50 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc index 8bb931ad3c372..1386c0252aaea 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc @@ -34,19 +34,20 @@ class ReduceMeanGradMKLDNNKernel : public ReduceGradMKLDNNKernel { auto reduce_dims = ctx.Attr>("dim"); int number_of_elements = 1; - if(!ctx.Attr("reduce_all")){ - for(size_t i = 0; i < reduce_dims.size() ; ++i){ + if (!ctx.Attr("reduce_all")) { + for (size_t i = 0; i < reduce_dims.size(); ++i) { reduce_dims[i] = (reduce_dims[i] >= 0) - ? reduce_dims[i] - : input_dims.size() + reduce_dims[i]; + ? reduce_dims[i] + : input_dims.size() + reduce_dims[i]; number_of_elements *= input_dims[reduce_dims[i]]; } } else { - for(size_t i = 0; i < input_dims.size() ; ++i) - number_of_elements *= input_dims[i]; + for (size_t i = 0; i < input_dims.size(); ++i) + number_of_elements *= input_dims[i]; } - this->RunKernel(ctx, dnnl::algorithm::binary_add, 0.0f, 1.0L / number_of_elements); + this->RunKernel(ctx, dnnl::algorithm::binary_add, 0.0f, + 1.0L / number_of_elements); } }; @@ -60,4 +61,4 @@ REGISTER_OP_KERNEL(reduce_mean, MKLDNN, paddle::platform::CPUPlace, REGISTER_OP_KERNEL(reduce_mean_grad, MKLDNN, paddle::platform::CPUPlace, ops::ReduceMeanGradMKLDNNKernel, - ops::ReduceMeanGradMKLDNNKernel); + ops::ReduceMeanGradMKLDNNKernel); diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h index 9ea1a2d285ae7..39eee9c2ddc7d 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -125,8 +125,8 @@ template class ReduceGradMKLDNNKernel : public framework::OpKernel { public: void RunKernel(const framework::ExecutionContext& ctx, - dnnl::algorithm binary_type, float scale_x, float scale_y) const { - //only for reduce sum for now + dnnl::algorithm binary_type, float scale_x, + float scale_y) const { auto& dev_ctx = ctx.template device_context(); const auto& onednn_engine = dev_ctx.GetEngine(); @@ -138,13 +138,14 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel { auto* output_dx = ctx.Output(framework::GradVarName("X")); output_dx->mutable_data(ctx.GetPlace()); - + output_dx->set_format(getPlainFormatTag(output_dx)); output_dx->set_layout(input_dy->layout()); - platform::BinaryReductionGradMKLDNNHandler handler(binary_type, dev_ctx, onednn_engine, - ctx.GetPlace(), output_dx, input_dy, scale_x, 
scale_y, - ctx.InputName(framework::GradVarName("Out"))); + platform::BinaryReductionGradMKLDNNHandler handler( + binary_type, dev_ctx, onednn_engine, ctx.GetPlace(), output_dx, + input_dy, scale_x, scale_y, + ctx.InputName(framework::GradVarName("Out"))); auto src_dx_memory = handler.AcquireSrcMemory(output_dx); const auto src_dy_memory = handler.AcquireSecondSrcMemory(input_dy); @@ -163,9 +164,10 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel { binary_prim->execute(astream, args); astream.wait(); } -protected: - mkldnn::memory::format_tag getPlainFormatTag(const Tensor* tensor) const{ - switch(tensor->dims().size()){ + + protected: + mkldnn::memory::format_tag getPlainFormatTag(const Tensor* tensor) const { + switch (tensor->dims().size()) { case 1: return mkldnn::memory::format_tag::a; case 2: @@ -177,8 +179,9 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel { case 5: return mkldnn::memory::format_tag::abcde; default: - platform::errors::InvalidArgument("Tensor dims must be in range <1, 5>"); - } + platform::errors::InvalidArgument( + "Tensor dims must be in range <1, 5>"); + } } }; diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc index d9bf5ad427b7c..e62edcf559677 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc @@ -43,4 +43,4 @@ REGISTER_OP_KERNEL(reduce_sum, MKLDNN, paddle::platform::CPUPlace, REGISTER_OP_KERNEL(reduce_sum_grad, MKLDNN, paddle::platform::CPUPlace, ops::ReduceSumGradMKLDNNKernel, - ops::ReduceSumGradMKLDNNKernel); + ops::ReduceSumGradMKLDNNKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index e04405d86c38b..c43c7e5e76f28 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -559,27 +559,29 @@ class ReduceGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - - auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, framework::GradVarName("Out")); + auto input_data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); auto CanMKLDNNReduceGradBeUsed = [&]() { auto dx_dims = ctx.Input("X")->dims(); - if (ctx.Attr("reduce_all") || (ctx.Attr>("dim").size() == dx_dims.size())) + if (ctx.Attr("reduce_all") || + (ctx.Attr>("dim").size() == dx_dims.size())) return true; auto dy_dims = ctx.Input(framework::GradVarName("Out"))->dims(); // Subtensor must be on rightmost part of the bigger tensor - for(size_t i = 0; i < dy_dims.size() ; ++i){ - if(dx_dims[dx_dims.size() - dy_dims.size() + i] != dy_dims[i]){ + for (size_t i = 0; i < dy_dims.size(); ++i) { + if (dx_dims[dx_dims.size() - dy_dims.size() + i] != dy_dims[i]) { return false; } } return true; }; - + #ifdef PADDLE_WITH_MKLDNN - if (this->CanMKLDNNBeUsed(ctx, input_data_type) && CanMKLDNNReduceGradBeUsed()) { + if (this->CanMKLDNNBeUsed(ctx, input_data_type) && + CanMKLDNNReduceGradBeUsed()) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index e7157e1d766d8..403cd0836a4ed 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ 
b/paddle/fluid/platform/mkldnn_reuse.h @@ -631,19 +631,19 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT { }; template -class BinaryReductionGradMKLDNNHandler : public platform::MKLDNNHandlerT { +class BinaryReductionGradMKLDNNHandler + : public platform::MKLDNNHandlerT { public: BinaryReductionGradMKLDNNHandler(const dnnl::algorithm algo, - const MKLDNNDeviceContext& dev_ctx, - const mkldnn::engine engine, platform::Place cpu_place, - const Tensor* x, const Tensor* y, - float scale_x, float scale_y, - const std::string& uniq_name) + const MKLDNNDeviceContext& dev_ctx, + const mkldnn::engine engine, + platform::Place cpu_place, const Tensor* x, + const Tensor* y, float scale_x, + float scale_y, const std::string& uniq_name) : platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, - platform::CreateKey( - dev_ctx, framework::vectorize(x->dims()), uniq_name)) { - + platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), + uniq_name)) { if (!this->isCached()) { PADDLE_ENFORCE_EQ( x->layout(), DataLayout::kMKLDNN, @@ -662,20 +662,24 @@ class BinaryReductionGradMKLDNNHandler : public platform::MKLDNNHandlerTdims()); const auto src0_tz = framework::vectorize(x->dims()); - // GetExpectedKernelType checks if smaller vector is a subvector with all the dims in correct order on the rightmost part of the bigger vector, f.e. a correct vector for broadcasting: + // GetExpectedKernelType checks if smaller vector is a subvector with all + // the dims in correct order on the rightmost part of the bigger vector, + // f.e. a correct vector for broadcasting: // x = 5, 7, 3, 2, 4, 8 // y = 4, 8 - for(size_t i = src1_tz.size() ; i < src0_tz.size() ; ++i){ + for (size_t i = src1_tz.size(); i < src0_tz.size(); ++i) { src1_tz.insert(src1_tz.begin(), 1L); } const auto src0_md = dnnl::memory::desc( src0_tz, platform::MKLDNNGetDataType(), x->format()); - const auto src1_md = dnnl::memory::desc( - src1_tz, platform::MKLDNNGetDataType(), x->format());//y->format()); + const auto src1_md = + dnnl::memory::desc(src1_tz, platform::MKLDNNGetDataType(), + x->format()); // y->format()); - //const auto dst_md = dnnl::memory::desc( // in reduction binary op is always inplace - // output_dims, platform::MKLDNNGetDataType(), x->format()); + // const auto dst_md = dnnl::memory::desc( // in reduction binary op is + // always inplace + // output_dims, platform::MKLDNNGetDataType(), x->format()); dnnl::primitive_attr attributes; attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x}); @@ -729,7 +733,6 @@ class ReductionMKLDNNHandler } }; - template class ActivationMKLDNNHandler : public MKLDNNHandlerT Date: Wed, 14 Apr 2021 13:22:02 +0200 Subject: [PATCH 24/37] minor changes --- paddle/fluid/platform/mkldnn_reuse.h | 9 ++------- .../tests/unittests/dygraph_to_static/test_resnet_v2.py | 1 - 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 403cd0836a4ed..48ae76d59709c 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -673,13 +673,8 @@ class BinaryReductionGradMKLDNNHandler const auto src0_md = dnnl::memory::desc( src0_tz, platform::MKLDNNGetDataType(), x->format()); - const auto src1_md = - dnnl::memory::desc(src1_tz, platform::MKLDNNGetDataType(), - x->format()); // y->format()); - - // const auto dst_md = dnnl::memory::desc( // in reduction binary op is - // always inplace - // output_dims, platform::MKLDNNGetDataType(), x->format()); + const auto src1_md = dnnl::memory::desc( 
+ src1_tz, platform::MKLDNNGetDataType(), x->format()); dnnl::primitive_attr attributes; attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x}); diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py index 8c779ec7ef62a..10346ab0cc442 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_resnet_v2.py @@ -356,7 +356,6 @@ def test_resnet(self): self.verify_predict() def test_in_static_mode_mkldnn(self): - paddle.enable_static() paddle.fluid.set_flags({'FLAGS_use_mkldnn': True}) try: train(to_static=True) From bd6927085d2c9234e31020061b3884f9f091344a Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Wed, 14 Apr 2021 14:04:04 +0200 Subject: [PATCH 25/37] minor change --- paddle/fluid/operators/reduce_ops/reduce_op.h | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index c43c7e5e76f28..1f852f97728b3 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -562,24 +562,23 @@ class ReduceGradOp : public framework::OperatorWithKernel { auto input_data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); +#ifdef PADDLE_WITH_MKLDNN auto CanMKLDNNReduceGradBeUsed = [&]() { auto dx_dims = ctx.Input("X")->dims(); if (ctx.Attr("reduce_all") || - (ctx.Attr>("dim").size() == dx_dims.size())) + ((int)ctx.Attr>("dim").size() == dx_dims.size())) return true; auto dy_dims = ctx.Input(framework::GradVarName("Out"))->dims(); // Subtensor must be on rightmost part of the bigger tensor - for (size_t i = 0; i < dy_dims.size(); ++i) { + for (int i = 0; i < dy_dims.size(); ++i) { if (dx_dims[dx_dims.size() - dy_dims.size() + i] != dy_dims[i]) { return false; } } return true; }; - -#ifdef PADDLE_WITH_MKLDNN if (this->CanMKLDNNBeUsed(ctx, input_data_type) && CanMKLDNNReduceGradBeUsed()) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), From ffe61565e68eee582c899d3fdbdc9fc630b70914 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Wed, 14 Apr 2021 14:55:05 +0200 Subject: [PATCH 26/37] minor formatting change --- .../reduce_ops/mkldnn/reduce_mkldnn_op.h | 2 +- .../unittests/mkldnn/test_reduce_grad_tmp.py | 58 ------------------- 2 files changed, 1 insertion(+), 59 deletions(-) delete mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_reduce_grad_tmp.py diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h index 39eee9c2ddc7d..f13606ba23a1e 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -132,7 +132,6 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel { const auto& onednn_engine = dev_ctx.GetEngine(); auto dims = ctx.Attr>("dim"); - auto* input_x = ctx.Input("X"); auto* input_dy = ctx.Input(framework::GradVarName("Out")); auto* output_dx = ctx.Output(framework::GradVarName("X")); @@ -181,6 +180,7 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel { default: platform::errors::InvalidArgument( "Tensor dims must be in range <1, 5>"); + return mkldnn:memory:format_tag::a; } } }; diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_grad_tmp.py 
b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_grad_tmp.py deleted file mode 100644 index b8199d193a3f1..0000000000000 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_grad_tmp.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import unittest -import numpy as np -from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci -import paddle -import paddle.fluid.core as core -import paddle.fluid as fluid -from paddle.fluid import compiler, Program, program_guard -from paddle.fluid.framework import convert_np_dtype_to_dtype_ - - -class TestSumOpDefault(OpTest): - def setUp(self): - self.op_type = "reduce_sum" - self.inputs = {'X': np.random.random((5, 5, 4)).astype("float32")} - self.attrs = {'dim': (0, 1), 'use_mkldnn' : True} - self.outputs = {'Out': self.inputs['X'].sum(axis=(0, 1))} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out') - - -class TestSum4DOp(OpTest): - def setUp(self): - self.op_type = "reduce_sum" - self.inputs = {'X': np.random.random((3, 5, 6, 10)).astype("float32")} - self.attrs = {'dim': (0, 2, 3), 'use_mkldnn' : True} - self.outputs = {'Out': self.inputs['X'].sum(axis=(0, 2, 3))} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out') - - -if __name__ == '__main__': - import paddle - paddle.enable_static() - unittest.main() From 27f8bb72431ac4fc1da5a5acb5b63522163f7110 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Wed, 14 Apr 2021 15:21:45 +0200 Subject: [PATCH 27/37] minor change --- paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h index f13606ba23a1e..48dd006547e7d 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -180,7 +180,7 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel { default: platform::errors::InvalidArgument( "Tensor dims must be in range <1, 5>"); - return mkldnn:memory:format_tag::a; + return mkldnn::memory::format_tag::a; } } }; From 27355d0cac88c94a98b8cf1448fefd18dea8512d Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Wed, 14 Apr 2021 16:04:22 +0200 Subject: [PATCH 28/37] changed test --- .../fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py index 193c7b5b45514..3b99c07008625 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py @@ -52,7 +52,7 
@@ class TestReduceSum4DReduceAllWithoutReduceAllAttributeONEDNNOp( def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True - self.inputs = {'X': np.random.random((5, 10, 5, 6)).astype("float32")} + self.inputs = {'X': np.random.random((5, 10, 5, 7)).astype("float32")} self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [0, 1, 2, 3]} self.outputs = { 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) From 1445bd6b3ec2b0846021980afe49cf0bf4f22a02 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Wed, 14 Apr 2021 18:24:17 +0200 Subject: [PATCH 29/37] minor changes --- paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h | 6 +----- paddle/fluid/operators/reduce_ops/reduce_op.h | 2 ++ 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h index 48dd006547e7d..54b13901a1e90 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -175,12 +175,8 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel { return mkldnn::memory::format_tag::abc; case 4: return mkldnn::memory::format_tag::abcd; - case 5: - return mkldnn::memory::format_tag::abcde; default: - platform::errors::InvalidArgument( - "Tensor dims must be in range <1, 5>"); - return mkldnn::memory::format_tag::a; + return mkldnn::memory::format_tag::abcde; } } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 1f852f97728b3..24719e7c34320 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -569,6 +569,8 @@ class ReduceGradOp : public framework::OperatorWithKernel { ((int)ctx.Attr>("dim").size() == dx_dims.size())) return true; + if(dx_dims.size() > 5) return false; // max 5D tensor is supported + auto dy_dims = ctx.Input(framework::GradVarName("Out"))->dims(); // Subtensor must be on rightmost part of the bigger tensor From aa5dccdcf4471106ad90560abf47894601a51244 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Wed, 14 Apr 2021 18:37:24 +0200 Subject: [PATCH 30/37] added formatting --- paddle/fluid/operators/reduce_ops/reduce_op.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 24719e7c34320..b6d92ef201cf3 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -569,7 +569,7 @@ class ReduceGradOp : public framework::OperatorWithKernel { ((int)ctx.Attr>("dim").size() == dx_dims.size())) return true; - if(dx_dims.size() > 5) return false; // max 5D tensor is supported + if (dx_dims.size() > 5) return false; // max 5D tensor is supported auto dy_dims = ctx.Input(framework::GradVarName("Out"))->dims(); From 9f9eea92d1f10c61db4fb72b92b0610a122d47ce Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Wed, 14 Apr 2021 20:15:24 +0200 Subject: [PATCH 31/37] minor change --- .../fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py index 3b99c07008625..9247a04e954b9 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py +++ 
b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py @@ -52,7 +52,7 @@ class TestReduceSum4DReduceAllWithoutReduceAllAttributeONEDNNOp( def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True - self.inputs = {'X': np.random.random((5, 10, 5, 7)).astype("float32")} + self.inputs = {'X': np.random.random((5, 10, 5, 3)).astype("float32")} self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [0, 1, 2, 3]} self.outputs = { 'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim'])) From 996b81e932566a910912f7a0009fc744149d8239 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Thu, 15 Apr 2021 21:59:45 +0200 Subject: [PATCH 32/37] added suggested changes --- .../reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc | 3 +-- .../operators/reduce_ops/mkldnn/reduce_mkldnn_op.h | 2 +- paddle/fluid/platform/mkldnn_reuse.h | 11 +++++++++-- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc index 1386c0252aaea..33daeea8599c6 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc @@ -42,8 +42,7 @@ class ReduceMeanGradMKLDNNKernel : public ReduceGradMKLDNNKernel { number_of_elements *= input_dims[reduce_dims[i]]; } } else { - for (size_t i = 0; i < input_dims.size(); ++i) - number_of_elements *= input_dims[i]; + number_of_elements = input_x->numel(); } this->RunKernel(ctx, dnnl::algorithm::binary_add, 0.0f, diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h index 54b13901a1e90..e9caab6a2660d 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -141,7 +141,7 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel { output_dx->set_format(getPlainFormatTag(output_dx)); output_dx->set_layout(input_dy->layout()); - platform::BinaryReductionGradMKLDNNHandler handler( + platform::BroadcastDataMKLDNNHandler handler( binary_type, dev_ctx, onednn_engine, ctx.GetPlace(), output_dx, input_dy, scale_x, scale_y, ctx.InputName(framework::GradVarName("Out"))); diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 48ae76d59709c..b00cab5b7dae8 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -631,10 +631,10 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT { }; template -class BinaryReductionGradMKLDNNHandler +class BroadcastDataMKLDNNHandler : public platform::MKLDNNHandlerT { public: - BinaryReductionGradMKLDNNHandler(const dnnl::algorithm algo, + BroadcastDataMKLDNNHandler(const dnnl::algorithm algo, const MKLDNNDeviceContext& dev_ctx, const mkldnn::engine engine, platform::Place cpu_place, const Tensor* x, @@ -685,6 +685,13 @@ class BinaryReductionGradMKLDNNHandler } } + std::shared_ptr AcquireSrcMemory(framework::Tensor* input) { + T* input_data = input->data(); + memset(input_data, 0, this->fwd_pd_->src_desc().get_size()); + return this->AcquireMemoryFromPrimitive( + this->fwd_pd_->src_desc(), to_void_cast(input_data), "@src_mem_p"); + } + std::shared_ptr AcquireSecondSrcMemory( const framework::Tensor* input) { const T* input_data = input->data(); From fce4eb42efbeb4a2af7bb78ff86d6a211e4d2254 Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Thu, 15 Apr 2021 22:16:20 +0200 Subject: 
[PATCH 33/37] added formatting --- paddle/fluid/platform/mkldnn_reuse.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index b00cab5b7dae8..de36ee28b0d02 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -635,11 +635,11 @@ class BroadcastDataMKLDNNHandler : public platform::MKLDNNHandlerT { public: BroadcastDataMKLDNNHandler(const dnnl::algorithm algo, - const MKLDNNDeviceContext& dev_ctx, - const mkldnn::engine engine, - platform::Place cpu_place, const Tensor* x, - const Tensor* y, float scale_x, - float scale_y, const std::string& uniq_name) + const MKLDNNDeviceContext& dev_ctx, + const mkldnn::engine engine, + platform::Place cpu_place, const Tensor* x, + const Tensor* y, float scale_x, float scale_y, + const std::string& uniq_name) : platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), From 24af4d3d926b28cc08fc77853f584d8053dfe5ff Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Fri, 16 Apr 2021 10:51:40 +0200 Subject: [PATCH 34/37] removed doubled memset --- paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h index e9caab6a2660d..fabb30216833c 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -149,8 +149,6 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel { auto src_dx_memory = handler.AcquireSrcMemory(output_dx); const auto src_dy_memory = handler.AcquireSecondSrcMemory(input_dy); - memset(output_dx->data(), 0, src_dx_memory->get_desc().get_size()); - const auto binary_prim = handler.AcquireForwardPrimitive(); const std::unordered_map args = { From 02dc16d1591433de08a127dc83c1bb4274d228cd Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Mon, 19 Apr 2021 13:43:33 +0200 Subject: [PATCH 35/37] added suggested changes --- .../reduce_ops/mkldnn/reduce_mkldnn_op.h | 20 ++++++----- paddle/fluid/operators/reduce_ops/reduce_op.h | 5 +-- paddle/fluid/platform/mkldnn_reuse.h | 8 +++-- .../mkldnn/test_reduce_bf16_mkldnn_op.py | 34 +++++++++---------- .../unittests/mkldnn/test_reduce_mkldnn_op.py | 30 ++++++++-------- 5 files changed, 50 insertions(+), 47 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h index fabb30216833c..e4026d18fbfb9 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -127,17 +127,15 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel { void RunKernel(const framework::ExecutionContext& ctx, dnnl::algorithm binary_type, float scale_x, float scale_y) const { - auto& dev_ctx = + const auto& dev_ctx = ctx.template device_context(); const auto& onednn_engine = dev_ctx.GetEngine(); auto dims = ctx.Attr>("dim"); auto* input_dy = ctx.Input(framework::GradVarName("Out")); - auto* output_dx = ctx.Output(framework::GradVarName("X")); output_dx->mutable_data(ctx.GetPlace()); - output_dx->set_format(getPlainFormatTag(output_dx)); output_dx->set_layout(input_dy->layout()); @@ -146,9 +144,8 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel { input_dy, scale_x, scale_y, 
ctx.InputName(framework::GradVarName("Out"))); - auto src_dx_memory = handler.AcquireSrcMemory(output_dx); + const auto src_dx_memory = handler.AcquireSrcMemory(output_dx); const auto src_dy_memory = handler.AcquireSecondSrcMemory(input_dy); - const auto binary_prim = handler.AcquireForwardPrimitive(); const std::unordered_map args = { @@ -157,14 +154,19 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel { {DNNL_ARG_DST, *src_dx_memory}}; auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - binary_prim->execute(astream, args); astream.wait(); } protected: mkldnn::memory::format_tag getPlainFormatTag(const Tensor* tensor) const { - switch (tensor->dims().size()) { + + auto tensor_dims_size = tensor->dims().size(); + PADDLE_ENFORCE_EQ( + tensor_dims_size <= 5 && tensor_dims_size >= 1, true, + platform::errors::InvalidArgument("Dims for reduction_grad oneDNN op must be in range <1, 5>")); + + switch (tensor_dims_size) { case 1: return mkldnn::memory::format_tag::a; case 2: @@ -173,9 +175,9 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel { return mkldnn::memory::format_tag::abc; case 4: return mkldnn::memory::format_tag::abcd; - default: - return mkldnn::memory::format_tag::abcde; } + + return mkldnn::memory::format_tag::abcde; } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index b6d92ef201cf3..913d941df8810 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -565,12 +565,13 @@ class ReduceGradOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN auto CanMKLDNNReduceGradBeUsed = [&]() { auto dx_dims = ctx.Input("X")->dims(); + + if (dx_dims.size() > 5) return false; // max 5D tensor is supported + if (ctx.Attr("reduce_all") || ((int)ctx.Attr>("dim").size() == dx_dims.size())) return true; - if (dx_dims.size() > 5) return false; // max 5D tensor is supported - auto dy_dims = ctx.Input(framework::GradVarName("Out"))->dims(); // Subtensor must be on rightmost part of the bigger tensor diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index de36ee28b0d02..448f4e177c99c 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -664,9 +664,11 @@ class BroadcastDataMKLDNNHandler // GetExpectedKernelType checks if smaller vector is a subvector with all // the dims in correct order on the rightmost part of the bigger vector, - // f.e. a correct vector for broadcasting: + // i.e. 
a correct vector for broadcasting: // x = 5, 7, 3, 2, 4, 8 // y = 4, 8 + src1_tz.reserve(src0_tz.size()); + for (size_t i = src1_tz.size(); i < src0_tz.size(); ++i) { src1_tz.insert(src1_tz.begin(), 1L); } @@ -674,7 +676,7 @@ class BroadcastDataMKLDNNHandler const auto src0_md = dnnl::memory::desc( src0_tz, platform::MKLDNNGetDataType(), x->format()); const auto src1_md = dnnl::memory::desc( - src1_tz, platform::MKLDNNGetDataType(), x->format()); + src1_tz, platform::MKLDNNGetDataType(), y->format()); dnnl::primitive_attr attributes; attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x}); @@ -689,7 +691,7 @@ class BroadcastDataMKLDNNHandler T* input_data = input->data(); memset(input_data, 0, this->fwd_pd_->src_desc().get_size()); return this->AcquireMemoryFromPrimitive( - this->fwd_pd_->src_desc(), to_void_cast(input_data), "@src_mem_p"); + this->fwd_pd_->src_desc(), to_void_cast(input_data), "@src0_mem_p"); } std::shared_ptr AcquireSecondSrcMemory( diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py index a10215a2bb787..1d7ab4f6b3369 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py @@ -26,7 +26,7 @@ "place does not support BF16 evaluation") @unittest.skipIf(core.is_compiled_with_cuda(), "core is compiled with CUDA which has no BF implementation") -class TestReduceSumDefaultBF16ONEDNNOp(OpTest): +class TestReduceSumDefaultBF16OneDNNOp(OpTest): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -64,16 +64,14 @@ def calculate_grads(self): keepdim = True self.grad_Out = self.x_fp32.sum(axis=axis, keepdims=keepdim) - self.grad_Out = np.atleast_1d(self.grad_Out) - self.grad_X = tmp_tensor + self.grad_Out # broadcast grad if self.op_type == 'reduce_mean': self.grad_X /= prod_of_reduced_dims -class TestReduceDefaultWithGradBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp): +class TestReduceDefaultWithGradBF16OneDNNOp(TestReduceSumDefaultBF16OneDNNOp): def test_check_grad(self): self.calculate_grads() self.check_grad_with_place( @@ -84,8 +82,8 @@ def test_check_grad(self): user_defined_grad_outputs=[convert_float_to_uint16(self.grad_Out)]) -class TestReduceSum4DReduceAllWithoutReduceAllAttributeBF16ONEDNNOp( - TestReduceDefaultWithGradBF16ONEDNNOp): +class TestReduceSum4DReduceAllDimAttributeBF16OneDNNOp( + TestReduceDefaultWithGradBF16OneDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -96,8 +94,8 @@ def setUp(self): self.outputs = {'Out': self.x_fp32.sum(axis=tuple(self.attrs['dim']))} -class TestReduceSum4DReduceAllWithoutReduceAllAttributeNegativeDimsBF16ONEDNNOp( - TestReduceDefaultWithGradBF16ONEDNNOp): +class TestReduceSum4DReduceAllWithoutReduceAllAttributeNegativeDimsBF16OneDNNOp( + TestReduceDefaultWithGradBF16OneDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -108,8 +106,8 @@ def setUp(self): self.outputs = {'Out': self.x_fp32.sum(axis=tuple(self.attrs['dim']))} -class TestReduceSum5DReduceAllKeepDimsBF16ONEDNNOp( - TestReduceDefaultWithGradBF16ONEDNNOp): +class TestReduceSum5DReduceAllKeepDimsBF16OneDNNOp( + TestReduceDefaultWithGradBF16OneDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -120,8 +118,8 @@ def setUp(self): self.outputs = {'Out': self.x_fp32.sum(keepdims=self.attrs['keep_dim'])} -class TestReduceSum4DReduceAllBF16ONEDNNOp( - 
TestReduceDefaultWithGradBF16ONEDNNOp): +class TestReduceSum4DReduceAllBF16OneDNNOp( + TestReduceDefaultWithGradBF16OneDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -135,7 +133,7 @@ def setUp(self): @skip_check_grad_ci( reason="reduce_max is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") -class TestReduceMax3DBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp): +class TestReduceMax3DBF16OneDNNOp(TestReduceSumDefaultBF16OneDNNOp): """Remove Max with subgradient from gradient check to confirm the success of CI.""" def setUp(self): @@ -151,8 +149,8 @@ def setUp(self): @skip_check_grad_ci( reason="reduce_max is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") -class TestReduceMax4DNegativeAndPositiveDimsBF16ONEDNNOp( - TestReduceSumDefaultBF16ONEDNNOp): +class TestReduceMax4DNegativeAndPositiveDimsBF16OneDNNOp( + TestReduceSumDefaultBF16OneDNNOp): """Remove Max with subgradient from gradient check to confirm the success of CI.""" def setUp(self): @@ -168,7 +166,7 @@ def setUp(self): @skip_check_grad_ci( reason="reduce_min is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") -class TestReduceMin3DBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp): +class TestReduceMin3DBF16OneDNNOp(TestReduceSumDefaultBF16OneDNNOp): """Remove Min with subgradient from gradient check to confirm the success of CI.""" def setUp(self): @@ -181,7 +179,7 @@ def setUp(self): self.outputs = {'Out': self.x_fp32.min(axis=tuple(self.attrs['dim']))} -class TestReduceMean3DBF16ONEDNNOp(TestReduceDefaultWithGradBF16ONEDNNOp): +class TestReduceMean3DBF16OneDNNOp(TestReduceDefaultWithGradBF16OneDNNOp): def setUp(self): self.op_type = "reduce_mean" self.use_mkldnn = True @@ -192,7 +190,7 @@ def setUp(self): self.outputs = {'Out': self.x_fp32.sum(axis=0) / self.x_fp32.shape[0]} -class TestReduceMean4DBF16ONEDNNOp(TestReduceDefaultWithGradBF16ONEDNNOp): +class TestReduceMean4DBF16OneDNNOp(TestReduceDefaultWithGradBF16OneDNNOp): def setUp(self): self.op_type = "reduce_mean" self.use_mkldnn = True diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py index 9247a04e954b9..46ee2a14a2018 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py @@ -19,7 +19,7 @@ import paddle -class TestReduceSumDefaultONEDNNOp(OpTest): +class TestReduceSumDefaultOneDNNOp(OpTest): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -31,12 +31,12 @@ def test_check_output(self): self.check_output() -class TestReduceDefaultWithGradONEDNNOp(TestReduceSumDefaultONEDNNOp): +class TestReduceDefaultWithGradOneDNNOp(TestReduceSumDefaultOneDNNOp): def test_check_grad(self): self.check_grad(['X'], 'Out') -class TestReduceSum4DONEDNNOp(TestReduceDefaultWithGradONEDNNOp): +class TestReduceSum4DOneDNNOp(TestReduceDefaultWithGradOneDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -47,8 +47,8 @@ def setUp(self): } -class TestReduceSum4DReduceAllWithoutReduceAllAttributeONEDNNOp( - TestReduceDefaultWithGradONEDNNOp): +class TestReduceSum4DReduceAllDimAttributeBF16OneDNNOp( + TestReduceDefaultWithGradOneDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -59,7 +59,7 @@ def setUp(self): } -class 
TestReduceSum5DKeepDimsONEDNNOp(TestReduceDefaultWithGradONEDNNOp): +class TestReduceSum5DKeepDimsOneDNNOp(TestReduceDefaultWithGradOneDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -71,8 +71,8 @@ def setUp(self): } -class TestReduceSum5DReduceAllKeepDimsONEDNNOp( - TestReduceDefaultWithGradONEDNNOp): +class TestReduceSum5DReduceAllKeepDimsOneDNNOp( + TestReduceDefaultWithGradOneDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -83,7 +83,7 @@ def setUp(self): } -class TestReduceSum4DReduceAllONEDNNOp(TestReduceDefaultWithGradONEDNNOp): +class TestReduceSum4DReduceAllOneDNNOp(TestReduceDefaultWithGradOneDNNOp): def setUp(self): self.op_type = "reduce_sum" self.use_mkldnn = True @@ -95,7 +95,7 @@ def setUp(self): @skip_check_grad_ci( reason="reduce_max is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") -class TestReduceMax3DONEDNNOp(TestReduceSumDefaultONEDNNOp): +class TestReduceMax3DOneDNNOp(TestReduceSumDefaultOneDNNOp): """Remove Max with subgradient from gradient check to confirm the success of CI.""" def setUp(self): @@ -111,8 +111,8 @@ def setUp(self): @skip_check_grad_ci( reason="reduce_max is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") -class TestReduceMax4DNegativeAndPositiveDimsONEDNNOp( - TestReduceSumDefaultONEDNNOp): +class TestReduceMax4DNegativeAndPositiveDimsOneDNNOp( + TestReduceSumDefaultOneDNNOp): """Remove Max with subgradient from gradient check to confirm the success of CI.""" def setUp(self): @@ -128,7 +128,7 @@ def setUp(self): @skip_check_grad_ci( reason="reduce_min is discontinuous non-derivable function," " its gradient check is not supported by unittest framework.") -class TestReduceMin3DONEDNNOp(TestReduceSumDefaultONEDNNOp): +class TestReduceMin3DOneDNNOp(TestReduceSumDefaultOneDNNOp): """Remove Min with subgradient from gradient check to confirm the success of CI.""" def setUp(self): @@ -141,7 +141,7 @@ def setUp(self): } -class TestReduceMean3DONEDNNOp(TestReduceDefaultWithGradONEDNNOp): +class TestReduceMean3DOneDNNOp(TestReduceDefaultWithGradOneDNNOp): def setUp(self): self.op_type = "reduce_mean" self.use_mkldnn = True @@ -152,7 +152,7 @@ def setUp(self): } -class TestReduceMean4DReduceAllONEDNNOp(TestReduceDefaultWithGradONEDNNOp): +class TestReduceMean4DReduceAllOneDNNOp(TestReduceDefaultWithGradOneDNNOp): def setUp(self): self.op_type = "reduce_mean" self.use_mkldnn = True From 646444288bd72c3d9be764131ef452efbe36d11e Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Mon, 19 Apr 2021 13:54:26 +0200 Subject: [PATCH 36/37] reverted one change --- paddle/fluid/platform/mkldnn_reuse.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 448f4e177c99c..54efa55cc4cd9 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -676,7 +676,7 @@ class BroadcastDataMKLDNNHandler const auto src0_md = dnnl::memory::desc( src0_tz, platform::MKLDNNGetDataType(), x->format()); const auto src1_md = dnnl::memory::desc( - src1_tz, platform::MKLDNNGetDataType(), y->format()); + src1_tz, platform::MKLDNNGetDataType(), x->format()); dnnl::primitive_attr attributes; attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x}); From 73875932138f18a2a4c253e424374737043bb01a Mon Sep 17 00:00:00 2001 From: Jakub Piasecki Date: Mon, 19 Apr 2021 14:30:28 +0200 Subject: [PATCH 37/37] 
changed formatting --- paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h index e4026d18fbfb9..58416f479c043 100644 --- a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -160,11 +160,11 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel { protected: mkldnn::memory::format_tag getPlainFormatTag(const Tensor* tensor) const { - auto tensor_dims_size = tensor->dims().size(); PADDLE_ENFORCE_EQ( tensor_dims_size <= 5 && tensor_dims_size >= 1, true, - platform::errors::InvalidArgument("Dims for reduction_grad oneDNN op must be in range <1, 5>")); + platform::errors::InvalidArgument( + "Dims for reduction_grad oneDNN op must be in range <1, 5>")); switch (tensor_dims_size) { case 1: