diff --git a/paddle/fluid/operators/mkldnn/cast_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/cast_mkldnn_op.cc deleted file mode 100644 index 7335693053fa0..0000000000000 --- a/paddle/fluid/operators/mkldnn/cast_mkldnn_op.cc +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/platform/mkldnn_reuse.h" - -namespace paddle { -namespace operators { - -using paddle::framework::Tensor; - -template -class CastMKLDNNKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - this->RunKernel(ctx); - } - - void RunKernel(const framework::ExecutionContext& ctx) const { - const auto& dev_ctx = - ctx.template device_context(); - - auto* x = ctx.Input("X"); - auto* out = ctx.Output("Out"); - - int in_dtype = ctx.Attr("in_dtype"); - int out_dtype = ctx.Attr("out_dtype"); - - auto x_paddle_type = framework::proto::VarType::Type(in_dtype); - auto out_paddle_type = framework::proto::VarType::Type(out_dtype); - - dnnl::memory::data_type x_type = framework::ToMKLDNNDataType(x_paddle_type); - dnnl::memory::data_type out_type = - framework::ToMKLDNNDataType(out_paddle_type); - - auto x_tz = phi::vectorize(x->dims()); - - platform::ReorderMKLDNNHandler reorder_handler(x_tz, - x_paddle_type, - x_type, - out_paddle_type, - out_type, - dev_ctx.GetEngine()); - - auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( - x->mem_desc(), platform::to_void_cast(x->data())); - auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( - out, x->mem_desc(), dev_ctx.GetPlace()); - auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, - reorder_src_memory_p); - - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); - astream.wait(); - - out->set_layout(framework::DataLayout::kMKLDNN); - out->set_mem_desc(reorder_dst_memory_p->get_desc()); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_KERNEL(cast, - MKLDNN, - paddle::platform::CPUPlace, - ops::CastMKLDNNKernel, - ops::CastMKLDNNKernel); diff --git a/paddle/fluid/operators/mkldnn/clip_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/clip_mkldnn_op.cc deleted file mode 100644 index f1a7ade2b4809..0000000000000 --- a/paddle/fluid/operators/mkldnn/clip_mkldnn_op.cc +++ /dev/null @@ -1,109 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/platform/mkldnn_reuse.h" - -namespace { - -using paddle::framework::Tensor; - -template -class ClipMKLDNNKernel : public paddle::framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext& ctx) const override { - this->RunKernel(ctx); - } - - void RunKernel(const paddle::framework::ExecutionContext& ctx) const { - const auto& dev_ctx = - ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); - - auto* x = ctx.Input("X"); - auto* out = ctx.Output("Out"); - - paddle::platform::ActivationMKLDNNHandler handler( - dnnl::algorithm::eltwise_clip_v2, - ctx, - mkldnn_engine, - ctx.GetPlace(), - x); - - auto src_memory_p = handler.AcquireSrcMemory(x); - auto dst_memory_p = handler.AcquireDstMemory(out); - auto activation_p = handler.AcquireForwardPrimitive(); - - auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); - activation_p->execute( - astream, - {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}}); - astream.wait(); - - out->set_mem_desc(dst_memory_p->get_desc()); - } -}; - -template -class ClipGradMKLDNNKernel : public paddle::framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext& ctx) const override { - this->RunKernel(ctx); - } - - void RunKernel(const paddle::framework::ExecutionContext& ctx) const { - const auto& dev_ctx = - ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); - - auto* x = ctx.Input("X"); - auto* dx = ctx.Output(paddle::framework::GradVarName("X")); - auto* dout = ctx.Input(paddle::framework::GradVarName("Out")); - - paddle::platform::ActivationMKLDNNHandler handler( - dnnl::algorithm::eltwise_clip_v2, - ctx, - mkldnn_engine, - ctx.GetPlace(), - x, - dout); - - auto src_memory_p = handler.AcquireBackwardSrcMemory(x); - auto diff_dst_memory_p = handler.AcquireDiffDstMemory(dout); - auto diff_src_memory_p = handler.AcquireDiffSrcMemory(dx); - auto activation_backward_p = handler.AcquireBackwardPrimitive(); - - auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); - activation_backward_p->execute(astream, - {{DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_DIFF_DST, *diff_dst_memory_p}, - {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}}); - astream.wait(); - - dx->set_mem_desc(diff_dst_memory_p->get_desc()); - } -}; - -} // anonymous namespace - -REGISTER_OP_KERNEL(clip, - MKLDNN, - paddle::platform::CPUPlace, - ClipMKLDNNKernel, - ClipMKLDNNKernel); - -REGISTER_OP_KERNEL(clip_grad, - MKLDNN, - paddle::platform::CPUPlace, - ClipGradMKLDNNKernel, - ClipGradMKLDNNKernel); diff --git a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc deleted file mode 100644 index 00ae785bca95d..0000000000000 --- a/paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc +++ /dev/null @@ -1,403 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/mkldnn_helper.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" -#include "paddle/phi/kernels/funcs/pooling.h" - -namespace paddle { -namespace operators { - -using dnnl::memory; -using dnnl::pooling_backward; -using dnnl::pooling_forward; -using dnnl::primitive; -using dnnl::reorder; -using dnnl::stream; -using framework::DataLayout; -using framework::Tensor; -using platform::to_void_cast; - -template -class PoolingMKLDNNHandler - : public platform::MKLDNNHandlerNoCachingT { - public: - PoolingMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, - const dnnl::engine mkldnn_engine, - const Tensor* input, - Tensor* output) - : platform::MKLDNNHandlerNoCachingT( - mkldnn_engine, ctx.GetPlace()) { - const std::string pooling_type = ctx.Attr("pooling_type"); - - std::vector ksize_temp = ctx.Attr>("ksize"); - std::vector ksize(begin(ksize_temp), end(ksize_temp)); - - std::vector strides_temp = ctx.Attr>("strides"); - std::vector strides(begin(strides_temp), end(strides_temp)); - - std::vector paddings_temp = ctx.Attr>("paddings"); - std::vector paddings(begin(paddings_temp), end(paddings_temp)); - - const bool global_pooling = ctx.Attr("global_pooling"); - const std::string padding_algorithm = - ctx.Attr("padding_algorithm"); - - // Only 2D pooling is supported now - PADDLE_ENFORCE_EQ( - ksize.size(), - 2, - platform::errors::InvalidArgument( - "The ksize must be 2D, i.e. 2D pooling, but received %dD.", - ksize.size())); - PADDLE_ENFORCE_EQ( - pooling_type == "max" || pooling_type == "avg", - true, - platform::errors::InvalidArgument( - "The pooling_type must be 'max' or 'avg', but received %s.", - pooling_type)); - PADDLE_ENFORCE_EQ( - input->dims().size(), - 4, - platform::errors::InvalidArgument( - "Input dim must be with 4, i.e. NCHW, but received %d.", - input->dims().size())); - - const auto input_dims = input->dims(); - framework::DDim data_dims = - phi::slice_ddim(input_dims, 2, input_dims.size()); - - if (global_pooling) { - phi::funcs::UpdateKernelSize(&ksize, data_dims); - } - - phi::funcs::UpdatePadding(&paddings, - global_pooling, - 0, - padding_algorithm, - data_dims, - strides, - ksize); - - const auto is_test = ctx.Attr("is_test"); - const bool ceil_mode = ctx.Attr("ceil_mode"); - const auto exclude_padding = ctx.Attr("exclusive"); - auto mkldnn_paddings = platform::ToMkldnnPadding(paddings); - - const auto dt = framework::ToMKLDNNDataType( - framework::TransToProtoVarType(input->dtype())); - const auto src_tz = phi::vectorize(input->dims()); - const auto dst_tz = phi::vectorize(output->dims()); - const auto dst_md = - platform::MKLDNNMemDesc(dst_tz, dt, MKLDNNMemoryFormat::any); - - if (ceil_mode) { - CorrectOutputSize( - src_tz, dst_tz, ksize, paddings, strides, mkldnn_paddings[1]); - } - - ComputeAdaptivePoolParameters(ctx, src_tz, &ksize, &strides); - - this->AcquireForwardPrimitiveDescriptor( - is_test ? dnnl::prop_kind::forward_inference - : dnnl::prop_kind::forward_training, - pooling_type == "max" - ? dnnl::algorithm::pooling_max - : (exclude_padding ? dnnl::algorithm::pooling_avg_exclude_padding - : dnnl::algorithm::pooling_avg_include_padding), - input->mem_desc(), - dst_md, - strides, - ksize, - mkldnn_paddings[0], - mkldnn_paddings[1]); - } - - PoolingMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, - const dnnl::engine mkldnn_engine, - const Tensor* in_x, - const Tensor* out_grad, - Tensor* in_x_grad) - - : platform::MKLDNNHandlerNoCachingT( - mkldnn_engine, ctx.GetPlace()) { - PADDLE_ENFORCE_EQ( - ctx.Attr("is_test"), - false, - platform::errors::InvalidArgument( - "is_test attribute should be set to False in training phase.")); - - std::string pooling_type = ctx.Attr("pooling_type"); - - std::vector ksize_temp = ctx.Attr>("ksize"); - std::vector ksize(begin(ksize_temp), end(ksize_temp)); - - std::vector strides_temp = ctx.Attr>("strides"); - std::vector strides(begin(strides_temp), end(strides_temp)); - - std::vector paddings_temp = ctx.Attr>("paddings"); - std::vector paddings(begin(paddings_temp), end(paddings_temp)); - - bool global_pooling = ctx.Attr("global_pooling"); - std::string padding_algorithm = ctx.Attr("padding_algorithm"); - - auto in_x_dims = in_x->dims(); - framework::DDim data_dims = phi::slice_ddim(in_x_dims, 2, in_x_dims.size()); - - if (global_pooling) { - phi::funcs::UpdateKernelSize(&ksize, data_dims); - } - - phi::funcs::UpdatePadding(&paddings, - global_pooling, - 0, - padding_algorithm, - data_dims, - strides, - ksize); - - auto src_tz = phi::vectorize(in_x->dims()); - auto diff_src_tz = phi::vectorize(in_x_grad->dims()); - auto diff_dst_tz = phi::vectorize(out_grad->dims()); - - const auto dt = framework::ToMKLDNNDataType( - framework::TransToProtoVarType(in_x->dtype())); - auto dst_md = dnnl::memory::desc(diff_dst_tz, dt, MKLDNNMemoryFormat::any); - auto diff_src_md = dnnl::memory::desc( - diff_src_tz, platform::MKLDNNGetDataType(), MKLDNNMemoryFormat::any); - - auto mkldnn_paddings = platform::ToMkldnnPadding(paddings); - const bool ceil_mode = ctx.Attr("ceil_mode"); - - if (ceil_mode) { - CorrectOutputSize( - src_tz, diff_dst_tz, ksize, paddings, strides, mkldnn_paddings[1]); - } - ComputeAdaptivePoolParameters(ctx, diff_src_tz, &ksize, &strides); - - const auto exclude_padding = ctx.Attr("exclusive"); - - this->AcquireForwardPrimitiveDescriptor( - dnnl::prop_kind::forward_training, - pooling_type == "max" - ? dnnl::algorithm::pooling_max - : (exclude_padding ? dnnl::algorithm::pooling_avg_exclude_padding - : dnnl::algorithm::pooling_avg_include_padding), - in_x->mem_desc(), - dst_md, - strides, - ksize, - mkldnn_paddings[0], - mkldnn_paddings[1]); - - this->AcquireBackwardPrimitiveDescriptor( - pooling_type == "max" - ? dnnl::algorithm::pooling_max - : (exclude_padding ? dnnl::algorithm::pooling_avg_exclude_padding - : dnnl::algorithm::pooling_avg_include_padding), - diff_src_md, - out_grad->mem_desc(), - strides, - ksize, - mkldnn_paddings[0], - mkldnn_paddings[1]); - } - - std::shared_ptr AcquireWorkspaceMemory( - const platform::MKLDNNDeviceContext& dev_ctx, - const std::string& unique_name) { - dnnl::memory::desc workspace_md = this->fwd_pd_->workspace_desc(); - // Pooling Workspace has to be passed to Grad op that - // may be executed by diffrent thread, hence - // for that one we use key that does not contain TID - std::string workspace_key = platform::CreateKey(dev_ctx, - workspace_md.dims(), - workspace_md.data_type(), - unique_name, - "@wrk"); - auto mem_p = - std::static_pointer_cast(dev_ctx.GetBlob(workspace_key)); - if (mem_p == nullptr) { - static std::mutex acquire_barrier; - std::lock_guard block_threads_until_finish_this_job( - acquire_barrier); - mem_p = std::static_pointer_cast( - dev_ctx.GetBlob(workspace_key)); - if (mem_p == nullptr) { - mem_p = std::make_shared(workspace_md, this->engine_); - dev_ctx.SetBlob(workspace_key, mem_p); - } - } - return mem_p; - } - - static void ComputeAdaptivePoolParameters( - const paddle::framework::ExecutionContext& ctx, - const std::vector& src_tz, - std::vector* ksize, - std::vector* strides) { - if (ctx.Attr("adaptive")) { - // https://github.com/oneapi-src/oneDNN/tree/bkocot/adaptive-pooling/rfcs/20200818-adaptive-pooling - auto IH = static_cast(src_tz[src_tz.size() - 2]); - auto IW = static_cast(src_tz[src_tz.size() - 1]); - auto OH = static_cast(ksize->at(0)); - auto OW = static_cast(ksize->at(1)); - - strides->at(0) = - static_cast(floor((IH * 2.0) / OH) - floor(IH / OH)); - strides->at(1) = - static_cast(floor((IW * 2.0) / OW) - floor(IW / OW)); - ksize->at(0) = - static_cast(ceil((IH * 2.0) / OH) - floor(IH / OH)); - ksize->at(1) = - static_cast(ceil((IW * 2.0) / OW) - floor(IW / OW)); - } - } - - private: - static inline int ComputeCeiledOutput(int input_size, - int kernel_size, - int padding, - int stride) { - return (input_size - kernel_size + 2 * padding) / stride + 1; - } - - static inline void CorrectOutputSize( - const std::vector& src_tz, - const std::vector& dst_tz, - const std::vector& kernel_size, - const std::vector& paddings, - const std::vector& strides, - std::vector& right_bot_padding) { // NOLINT - for (size_t i = 0; i < right_bot_padding.size(); i++) { - int desired_size = ComputeCeiledOutput( - src_tz[i + 2], kernel_size[i], paddings[i], strides[i]); - if (desired_size != dst_tz[i + 2]) { - right_bot_padding[i] += strides[i] - 1; - } - } - } -}; - -template -class PoolMKLDNNOpKernel : public paddle::framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), - true, - paddle::platform::errors::PreconditionNotMet( - "Operator DNNL Pool must use CPUPlace")); - auto& dev_ctx = - ctx.template device_context(); - - const Tensor* input = ctx.Input("X"); - Tensor* output = ctx.Output("Out"); - - PoolingMKLDNNHandler handler(ctx, dev_ctx.GetEngine(), input, output); - - auto src_memory = handler.AcquireSrcMemory(input); - auto dst_memory = handler.AcquireDstMemory(output); - - auto pool_p = handler.AcquireForwardPrimitive(); - - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - if ((ctx.Attr("is_test") == false) && - (ctx.Attr("pooling_type") == "max")) { - // Training - auto workspace_memory = - handler.AcquireWorkspaceMemory(dev_ctx, ctx.OutputName("Out")); - pool_p->execute(astream, - {{DNNL_ARG_SRC, *src_memory}, - {DNNL_ARG_DST, *dst_memory}, - {DNNL_ARG_WORKSPACE, *workspace_memory}}); - } else { - // Inference - pool_p->execute( - astream, {{DNNL_ARG_SRC, *src_memory}, {DNNL_ARG_DST, *dst_memory}}); - } - astream.wait(); - - output->set_mem_desc(dst_memory->get_desc()); - } -}; - -template -class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), - true, - paddle::platform::errors::PreconditionNotMet( - "Operator DNNL PoolGrad must use CPUPlace")); - const Tensor* in_x = ctx.Input("X"); - const Tensor* out_grad = ctx.Input(framework::GradVarName("Out")); - Tensor* in_x_grad = ctx.Output(framework::GradVarName("X")); - - auto& dev_ctx = - ctx.template device_context(); - - PoolingMKLDNNHandler handler( - ctx, dev_ctx.GetEngine(), in_x, out_grad, in_x_grad); - - auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad); - auto diff_src_memory = handler.AcquireDiffSrcMemory(in_x_grad); - - auto pool_bwd_p = handler.AcquireBackwardPrimitive(); - - auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); - if (ctx.Attr("pooling_type") == "max") { - // Max - pooling needs Workspace - auto workspace_memory = - handler.AcquireWorkspaceMemory(dev_ctx, ctx.InputName("Out")); - pool_bwd_p->execute(astream, - {{DNNL_ARG_DIFF_SRC, *diff_src_memory}, - {DNNL_ARG_DIFF_DST, *diff_dst_memory}, - {DNNL_ARG_WORKSPACE, *workspace_memory}}); - } else { - // Average Pooling - pool_bwd_p->execute(astream, - {{DNNL_ARG_DIFF_SRC, *diff_src_memory}, - {DNNL_ARG_DIFF_DST, *diff_dst_memory}}); - } - astream.wait(); - - in_x_grad->set_mem_desc(diff_src_memory->get_desc()); - } // Compute() -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_KERNEL(pool2d, - MKLDNN, - ::paddle::platform::CPUPlace, - ops::PoolMKLDNNOpKernel, - ops::PoolMKLDNNOpKernel, - ops::PoolMKLDNNOpKernel, - ops::PoolMKLDNNOpKernel); - -REGISTER_OP_KERNEL(pool2d_grad, - MKLDNN, - ::paddle::platform::CPUPlace, - ops::PoolMKLDNNGradOpKernel, - ops::PoolMKLDNNGradOpKernel); diff --git a/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc b/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc index db590807179d9..18c3e40280a2a 100644 --- a/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc +++ b/paddle/fluid/operators/mkldnn/test_mkldnn_op_nhwc.cc @@ -28,7 +28,7 @@ #include "paddle/phi/core/kernel_registry.h" USE_OP_ITSELF(pool2d); -USE_OP_DEVICE_KERNEL(pool2d, MKLDNN); +PD_DECLARE_KERNEL(pool2d, OneDNN, ALL_LAYOUT); USE_OP_ITSELF(relu); PD_DECLARE_KERNEL(relu, OneDNN, ALL_LAYOUT); USE_OP_ITSELF(transpose); diff --git a/paddle/phi/backends/onednn/onednn_reuse.h b/paddle/phi/backends/onednn/onednn_reuse.h index 4a540ec884d93..66376dd883543 100644 --- a/paddle/phi/backends/onednn/onednn_reuse.h +++ b/paddle/phi/backends/onednn/onednn_reuse.h @@ -24,9 +24,12 @@ limitations under the License. */ #include "paddle/phi/backends/onednn/onednn_context.h" #include "paddle/phi/backends/onednn/onednn_helper.h" #include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/place.h" +#include "paddle/phi/common/scalar.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/funcs/data_layout_transform.h" +#include "paddle/phi/kernels/funcs/pooling.h" namespace phi { namespace funcs { @@ -947,5 +950,313 @@ class ReductionOneDNNHandler algo, x->mem_desc(), out_md, p, eps); } }; + +template +class ClipOneDNNHandler + : public OneDNNHandlerNoCachingT { + public: + ClipOneDNNHandler(const Scalar& min, + const Scalar& max, + const dnnl::engine engine, + Place cpu_place, + const DenseTensor* x) + : OneDNNHandlerNoCachingT(engine, cpu_place) { + float alpha = min.to(); + float beta = max.to(); + + this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training, + dnnl::algorithm::eltwise_clip_v2, + x->mem_desc(), + alpha, + beta); + } + + ClipOneDNNHandler(const Scalar& min, + const Scalar& max, + const dnnl::engine engine, + Place cpu_place, + const DenseTensor* x, + const DenseTensor* dout) + : OneDNNHandlerNoCachingT(engine, cpu_place) { + float alpha = min.to(); + float beta = max.to(); + + this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training, + dnnl::algorithm::eltwise_clip_v2, + x->mem_desc(), + alpha, + beta); + this->AcquireBackwardPrimitiveDescriptor(dnnl::algorithm::eltwise_clip_v2, + dout->mem_desc(), + x->mem_desc(), + alpha, + beta); + } + std::shared_ptr AcquireBackwardSrcMemory( + const DenseTensor* input) { + const T* input_data = input->data(); + return this->AcquireMemoryFromPrimitive(this->bwd_pd_->src_desc(), + to_void_cast(input_data)); + } +}; + +template +class PoolingOneDNNHandler + : public OneDNNHandlerNoCachingT { + public: + PoolingOneDNNHandler(const std::string& pooling_type, + const IntArray& kernel_size, + const std::vector& strides, + const std::vector& paddings, + bool global_pooling, + const std::string& padding_algorithm, + bool ceil_mode, + bool exclusive, + bool adaptive, + const dnnl::engine engine, + Place cpu_place, + const DenseTensor* input, + DenseTensor* output) + : OneDNNHandlerNoCachingT(engine, cpu_place) { + std::vector copied_kernel_size(kernel_size.GetData().begin(), + kernel_size.GetData().end()); + std::vector copied_strides(strides.begin(), strides.end()); + std::vector copied_paddings(paddings.begin(), paddings.end()); + // Only 2D pooling is supported now + PADDLE_ENFORCE_EQ( + copied_kernel_size.size(), + 2, + errors::InvalidArgument("The copied_kernel_size must be 2D, i.e. 2D " + "pooling, but received %dD.", + copied_kernel_size.size())); + PADDLE_ENFORCE_EQ( + pooling_type == "max" || pooling_type == "avg", + true, + errors::InvalidArgument( + "The pooling_type must be 'max' or 'avg', but received %s.", + pooling_type)); + PADDLE_ENFORCE_EQ( + input->dims().size(), + 4, + errors::InvalidArgument( + "Input dim must be with 4, i.e. NCHW, but received %d.", + input->dims().size())); + + const auto input_dims = input->dims(); + DDim data_dims = slice_ddim(input_dims, 2, input_dims.size()); + + if (global_pooling) { + UpdateKernelSize(&copied_kernel_size, data_dims); + } + + UpdatePadding(&copied_paddings, + global_pooling, + 0, + padding_algorithm, + data_dims, + copied_strides, + copied_kernel_size); + + auto onednn_paddings = ToOneDNNPadding(copied_paddings); + + const auto dt = ToOneDNNDataType(input->dtype()); + const auto src_tz = vectorize(input->dims()); + const auto dst_tz = vectorize(output->dims()); + const auto dst_md = OneDNNMemDesc(dst_tz, dt, OneDNNMemoryFormat::any); + + if (ceil_mode) { + CorrectOutputSize(src_tz, + dst_tz, + copied_kernel_size, + copied_paddings, + copied_strides, + onednn_paddings[1]); + } + + if (adaptive) { + ComputeAdaptivePoolParameters( + src_tz, &copied_kernel_size, &copied_strides); + } + this->AcquireForwardPrimitiveDescriptor( + dnnl::prop_kind::forward_training, + pooling_type == "max" + ? dnnl::algorithm::pooling_max + : (exclusive ? dnnl::algorithm::pooling_avg_exclude_padding + : dnnl::algorithm::pooling_avg_include_padding), + input->mem_desc(), + dst_md, + copied_strides, + copied_kernel_size, + onednn_paddings[0], + onednn_paddings[1]); + } + + PoolingOneDNNHandler(const std::string& pooling_type, + const IntArray& kernel_size, + const std::vector& strides, + const std::vector& paddings, + bool global_pooling, + const std::string& padding_algorithm, + bool ceil_mode, + bool exclusive, + bool adaptive, + const dnnl::engine engine, + Place cpu_place, + const DenseTensor* in_x, + const DenseTensor* out_grad, + DenseTensor* in_x_grad) + + : OneDNNHandlerNoCachingT(engine, cpu_place) { + std::vector copied_kernel_size(kernel_size.GetData().begin(), + kernel_size.GetData().end()); + std::vector copied_strides(strides.begin(), strides.end()); + std::vector copied_paddings(paddings.begin(), paddings.end()); + auto in_x_dims = in_x->dims(); + DDim data_dims = slice_ddim(in_x_dims, 2, in_x_dims.size()); + if (global_pooling) { + UpdateKernelSize(&copied_kernel_size, data_dims); + } + + UpdatePadding(&copied_paddings, + global_pooling, + 0, + padding_algorithm, + data_dims, + copied_strides, + copied_kernel_size); + + auto src_tz = vectorize(in_x->dims()); + auto diff_src_tz = vectorize(in_x_grad->dims()); + auto diff_dst_tz = vectorize(out_grad->dims()); + + const auto dt = ToOneDNNDataType(in_x->dtype()); + auto dst_md = dnnl::memory::desc(diff_dst_tz, dt, OneDNNMemoryFormat::any); + auto diff_src_md = dnnl::memory::desc( + diff_src_tz, oneDNNGetDataType(), OneDNNMemoryFormat::any); + + auto onednn_paddings = ToOneDNNPadding(copied_paddings); + + if (ceil_mode) { + CorrectOutputSize(src_tz, + diff_dst_tz, + copied_kernel_size, + copied_paddings, + copied_strides, + onednn_paddings[1]); + } + + if (adaptive) { + ComputeAdaptivePoolParameters( + diff_src_tz, &copied_kernel_size, &copied_strides); + } + + this->AcquireForwardPrimitiveDescriptor( + dnnl::prop_kind::forward_training, + pooling_type == "max" + ? dnnl::algorithm::pooling_max + : (exclusive ? dnnl::algorithm::pooling_avg_exclude_padding + : dnnl::algorithm::pooling_avg_include_padding), + in_x->mem_desc(), + dst_md, + copied_strides, + copied_kernel_size, + onednn_paddings[0], + onednn_paddings[1]); + + this->AcquireBackwardPrimitiveDescriptor( + pooling_type == "max" + ? dnnl::algorithm::pooling_max + : (exclusive ? dnnl::algorithm::pooling_avg_exclude_padding + : dnnl::algorithm::pooling_avg_include_padding), + diff_src_md, + out_grad->mem_desc(), + copied_strides, + copied_kernel_size, + onednn_paddings[0], + onednn_paddings[1]); + } + + std::shared_ptr AcquireWorkspaceMemory( + const OneDNNContext& dev_ctx, const std::string& unique_name) { + dnnl::memory::desc workspace_md = this->fwd_pd_->workspace_desc(); + // Pooling Workspace has to be passed to Grad op that + // may be executed by diffrent thread, hence + // for that one we use key that does not contain TID + std::string workspace_key = CreateKey(dev_ctx, + workspace_md.dims(), + workspace_md.data_type(), + unique_name, + "@wrk"); + auto mem_p = + std::static_pointer_cast(dev_ctx.GetBlob(workspace_key)); + if (mem_p == nullptr) { + static std::mutex acquire_barrier; + std::lock_guard block_threads_until_finish_this_job( + acquire_barrier); + mem_p = std::static_pointer_cast( + dev_ctx.GetBlob(workspace_key)); + if (mem_p == nullptr) { + mem_p = std::make_shared(workspace_md, this->engine_); + dev_ctx.SetBlob(workspace_key, mem_p); + } + } + return mem_p; + } + + static void ComputeAdaptivePoolParameters(const std::vector& src_tz, + std::vector* kernel_size, + std::vector* strides) { + // https://github.com/oneapi-src/oneDNN/tree/bkocot/adaptive-pooling/rfcs/20200818-adaptive-pooling + auto IH = static_cast(src_tz[src_tz.size() - 2]); + auto IW = static_cast(src_tz[src_tz.size() - 1]); + auto OH = static_cast(kernel_size->at(0)); + auto OW = static_cast(kernel_size->at(1)); + + strides->at(0) = + static_cast(floor((IH * 2.0) / OH) - floor(IH / OH)); + strides->at(1) = + static_cast(floor((IW * 2.0) / OW) - floor(IW / OW)); + kernel_size->at(0) = + static_cast(ceil((IH * 2.0) / OH) - floor(IH / OH)); + kernel_size->at(1) = + static_cast(ceil((IW * 2.0) / OW) - floor(IW / OW)); + } + + private: + static inline int ComputeCeiledOutput(int input_size, + int kernel_size, + int padding, + int stride) { + return (input_size - kernel_size + 2 * padding) / stride + 1; + } + + static inline void CorrectOutputSize( + const std::vector& src_tz, + const std::vector& dst_tz, + const std::vector& kernel_size, + const std::vector& paddings, + const std::vector& strides, + std::vector& right_bot_padding) { // NOLINT + for (size_t i = 0; i < right_bot_padding.size(); i++) { + int desired_size = ComputeCeiledOutput( + src_tz[i + 2], kernel_size[i], paddings[i], strides[i]); + if (desired_size != dst_tz[i + 2]) { + right_bot_padding[i] += strides[i] - 1; + } + } + } +}; } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/onednn/cast_kernel.cc b/paddle/phi/kernels/onednn/cast_kernel.cc new file mode 100644 index 0000000000000..166db43db665d --- /dev/null +++ b/paddle/phi/kernels/onednn/cast_kernel.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/cast_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void CastKernel(const Context& dev_ctx, + const DenseTensor& x, + DataType out_dtype, + DenseTensor* out) { + DataType in_dtype = x.dtype(); + + dnnl::memory::data_type in_dnnl_dtype = funcs::ToOneDNNDataType(in_dtype); + dnnl::memory::data_type out_dnnl_dtype = funcs::ToOneDNNDataType(out_dtype); + + auto x_tz = phi::vectorize(x.dims()); + + funcs::ReorderOneDNNHandler reorder_handler(x_tz, + in_dtype, + in_dnnl_dtype, + out_dtype, + out_dnnl_dtype, + dev_ctx.GetEngine()); + + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + x.mem_desc(), funcs::to_void_cast(x.data())); + auto reorder_dst_memory_p = + reorder_handler.AcquireDstMemory(out, x.mem_desc(), dev_ctx.GetPlace()); + auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, + reorder_src_memory_p); + + auto& astream = OneDNNContext::tls().get_stream(); + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); + astream.wait(); + + out->set_layout(DataLayout::ONEDNN); + out->set_mem_desc(reorder_dst_memory_p->get_desc()); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + cast, OneDNN, ALL_LAYOUT, phi::CastKernel, float, phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/onednn/clip_grad_kernel.cc b/paddle/phi/kernels/onednn/clip_grad_kernel.cc new file mode 100644 index 0000000000000..aded64616b124 --- /dev/null +++ b/paddle/phi/kernels/onednn/clip_grad_kernel.cc @@ -0,0 +1,54 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/clip_grad_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void ClipGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out_grad, + const Scalar& min, + const Scalar& max, + DenseTensor* x_grad) { + const auto& onednn_engine = dev_ctx.GetEngine(); + + funcs::ClipOneDNNHandler handler( + min, max, onednn_engine, dev_ctx.GetPlace(), &x, &out_grad); + + auto src_memory_p = handler.AcquireBackwardSrcMemory(&x); + auto diff_dst_memory_p = handler.AcquireDiffDstMemory(&out_grad); + auto diff_src_memory_p = handler.AcquireDiffSrcMemory(x_grad); + auto activation_backward_p = handler.AcquireBackwardPrimitive(); + + auto& astream = OneDNNContext::tls().get_stream(); + activation_backward_p->execute(astream, + {{DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_DIFF_DST, *diff_dst_memory_p}, + {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}}); + astream.wait(); + + x_grad->set_mem_desc(diff_dst_memory_p->get_desc()); +} +} // namespace phi + +PD_REGISTER_KERNEL(clip_grad, + OneDNN, + ALL_LAYOUT, + phi::ClipGradKernel, + float, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/onednn/clip_kernel.cc b/paddle/phi/kernels/onednn/clip_kernel.cc new file mode 100644 index 0000000000000..7538dd9708a93 --- /dev/null +++ b/paddle/phi/kernels/onednn/clip_kernel.cc @@ -0,0 +1,46 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/clip_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void ClipKernel(const Context& dev_ctx, + const DenseTensor& x, + const Scalar& min, + const Scalar& max, + DenseTensor* out) { + const auto& onednn_engine = dev_ctx.GetEngine(); + + funcs::ClipOneDNNHandler handler( + min, max, onednn_engine, dev_ctx.GetPlace(), &x); + + auto src_memory_p = handler.AcquireSrcMemory(&x); + auto dst_memory_p = handler.AcquireDstMemory(out); + auto activation_p = handler.AcquireForwardPrimitive(); + + auto& astream = OneDNNContext::tls().get_stream(); + activation_p->execute( + astream, {{DNNL_ARG_FROM, *src_memory_p}, {DNNL_ARG_TO, *dst_memory_p}}); + astream.wait(); + + out->set_mem_desc(dst_memory_p->get_desc()); +} +} // namespace phi + +PD_REGISTER_KERNEL( + clip, OneDNN, ALL_LAYOUT, phi::ClipKernel, float, phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/onednn/pool_grad_kernel.cc b/paddle/phi/kernels/onednn/pool_grad_kernel.cc new file mode 100644 index 0000000000000..0104cd53ae9cb --- /dev/null +++ b/paddle/phi/kernels/onednn/pool_grad_kernel.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/pool_grad_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void Pool2dGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& out, + const DenseTensor& dout, + const IntArray& kernel_size, + const std::vector& strides, + const std::vector& paddings, + bool ceil_mode, + bool exclusive, + const std::string& data_format, + const std::string& pooling_type, + bool global_pooling, + bool adaptive, + const std::string& padding_algorithm, + DenseTensor* dx) { + funcs::PoolingOneDNNHandler handler(pooling_type, + kernel_size, + strides, + paddings, + global_pooling, + padding_algorithm, + ceil_mode, + exclusive, + adaptive, + dev_ctx.GetEngine(), + dev_ctx.GetPlace(), + &x, + &dout, + dx); + + auto diff_dst_memory = handler.AcquireDiffDstMemory(&dout); + auto diff_src_memory = handler.AcquireDiffSrcMemory(dx); + + auto pool_bwd_p = handler.AcquireBackwardPrimitive(); + + auto& astream = OneDNNContext::tls().get_stream(); + if (pooling_type == "max") { + // Max - pooling needs Workspace + auto workspace_memory = handler.AcquireWorkspaceMemory(dev_ctx, "Out"); + pool_bwd_p->execute(astream, + {{DNNL_ARG_DIFF_SRC, *diff_src_memory}, + {DNNL_ARG_DIFF_DST, *diff_dst_memory}, + {DNNL_ARG_WORKSPACE, *workspace_memory}}); + } else { + // Average Pooling + pool_bwd_p->execute(astream, + {{DNNL_ARG_DIFF_SRC, *diff_src_memory}, + {DNNL_ARG_DIFF_DST, *diff_dst_memory}}); + } + astream.wait(); + + dx->set_mem_desc(diff_src_memory->get_desc()); +} +} // namespace phi + +PD_REGISTER_KERNEL(pool2d_grad, + OneDNN, + ALL_LAYOUT, + phi::Pool2dGradKernel, + float, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/onednn/pool_kernel.cc b/paddle/phi/kernels/onednn/pool_kernel.cc new file mode 100644 index 0000000000000..fd9f40802c269 --- /dev/null +++ b/paddle/phi/kernels/onednn/pool_kernel.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/pool_kernel.h" + +#include "paddle/phi/backends/onednn/onednn_reuse.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +template +void Pool2dKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& kernel_size, + const std::vector& strides, + const std::vector& paddings, + bool ceil_mode, + bool exclusive, + const std::string& data_format, + const std::string& pooling_type, + bool global_pooling, + bool adaptive, + const std::string& padding_algorithm, + DenseTensor* out) { + funcs::PoolingOneDNNHandler handler(pooling_type, + kernel_size, + strides, + paddings, + global_pooling, + padding_algorithm, + ceil_mode, + exclusive, + adaptive, + dev_ctx.GetEngine(), + dev_ctx.GetPlace(), + &x, + out); + + auto src_memory = handler.AcquireSrcMemory(&x); + auto dst_memory = handler.AcquireDstMemory(out); + + auto pool_p = handler.AcquireForwardPrimitive(); + + auto& astream = OneDNNContext::tls().get_stream(); + if (pooling_type == "max") { + // Training + auto workspace_memory = handler.AcquireWorkspaceMemory(dev_ctx, "Out"); + pool_p->execute(astream, + {{DNNL_ARG_SRC, *src_memory}, + {DNNL_ARG_DST, *dst_memory}, + {DNNL_ARG_WORKSPACE, *workspace_memory}}); + } else { + // Inference + pool_p->execute(astream, + {{DNNL_ARG_SRC, *src_memory}, {DNNL_ARG_DST, *dst_memory}}); + } + astream.wait(); + + out->set_mem_desc(dst_memory->get_desc()); +} +} // namespace phi + +PD_REGISTER_KERNEL(pool2d, + OneDNN, + ALL_LAYOUT, + phi::Pool2dKernel, + float, + int8_t, + uint8_t, + phi::dtype::bfloat16) {}