diff --git a/onnxruntime/core/providers/cpu/activation/activations.h b/onnxruntime/core/providers/cpu/activation/activations.h
index 7b5e213d8c9b2..ef2bc78a849d8 100644
--- a/onnxruntime/core/providers/cpu/activation/activations.h
+++ b/onnxruntime/core/providers/cpu/activation/activations.h
@@ -90,7 +90,7 @@ class ParametricSoftplus final : public OpKernel {
 };
 
 template <typename T>
-class Relu final : public OpKernel {
+class Relu : public OpKernel {
  public:
   Relu(const OpKernelInfo& info) : OpKernel(info) {}
 
diff --git a/onnxruntime/core/providers/mkldnn/activation/activations.cc b/onnxruntime/core/providers/mkldnn/activation/activations.cc
new file mode 100644
index 0000000000000..0764f04e0e1f3
--- /dev/null
+++ b/onnxruntime/core/providers/mkldnn/activation/activations.cc
@@ -0,0 +1,240 @@
+// Copyright(C) 2018 Intel Corporation
+// Licensed under the MIT License
+
+#ifdef _WIN32
+#pragma warning(disable : 4244)
+#endif
+
+#include "core/providers/mkldnn/mkldnn_common.h"
+#include "core/providers/mkldnn/activation/activations.h"
+#include "core/providers/mkldnn/mkldnn_fwd.h"
+
+namespace onnxruntime {
+namespace mkl_dnn {
+
+namespace {
+// Parameters describing one MKLDNN Relu primitive. The dims are held by
+// reference, so a ReluParams must not outlive the vectors it was built from.
+struct ReluParams {
+  mkldnn::memory::dims& src_dims;
+  mkldnn::memory::dims& dst_dims;
+  size_t num_dimensions;
+
+  ReluParams(mkldnn::memory::dims& src_dims, mkldnn::memory::dims& dst_dims,
+             size_t dimensions = 0)
+      : src_dims(src_dims),
+        dst_dims(dst_dims),
+        num_dimensions(dimensions) {}
+
+  // Builds the lookup key used by ReluPrimitivePool: "Relu_" followed by the
+  // source and destination dims as appended by AddDimsToKey.
+  std::string ToString() const {
+    std::string key;
+    key.reserve(64);
+    key.append("Relu_");
+    AddDimsToKey(key, src_dims);
+    AddDimsToKey(key, dst_dims);
+    return key;
+  }
+};
+
+// Owns the MKLDNN objects (descriptors, memory wrappers, primitive, stream)
+// needed to run a forward-inference Relu for one fixed input shape.
+template <typename T>
+class ReluPrimitive final : public PrimitiveBase {
+ public:
+  explicit ReluPrimitive(const ReluParams& params)
+      : cpu_engine_(GetEngine()) {
+    context_.stream.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
+    if (context_.relu_fwd == nullptr) {
+      Initialize(params);
+    }
+  }
+
+  ~ReluPrimitive() = default;
+
+  // Runs Relu from src_data into dst_data. The buffers are attached to the
+  // cached memory objects only for the duration of the call and detached
+  // again before returning.
+  void Compute(const T* src_data, T* dst_data) {
+    // set_data_handle takes a non-const void*, hence the const_cast on the
+    // read-only source buffer.
+    context_.src_mem->set_data_handle(
+        static_cast<void*>(const_cast<T*>(src_data)));
+    context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
+    context_.stream->submit(context_.net);
+
+    context_.src_mem->set_data_handle(nullptr);
+    context_.dst_mem->set_data_handle(nullptr);
+  }
+
+  // Non-owning views of the cached descriptors; the primitive keeps ownership.
+  const mkldnn::memory::desc* GetDstMemoryDesc() const {
+    return context_.dst_md.get();
+  }
+
+  const mkldnn::eltwise_forward::primitive_desc* GetPrimitiveDesc() const {
+    return context_.relu_fwd_pd.get();
+  }
+
+ private:
+  struct ReluContext {
+    mkldnn::memory::format src_fmt;
+    mkldnn::memory::format dst_fmt;
+
+    std::unique_ptr<mkldnn::memory> src_mem;
+    std::unique_ptr<mkldnn::memory> dst_mem;
+
+    std::unique_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
+    std::unique_ptr<mkldnn::eltwise_forward::primitive_desc> relu_fwd_pd;
+    std::unique_ptr<mkldnn::primitive> relu_fwd;
+
+    std::unique_ptr<mkldnn::memory::desc> src_md;
+    std::unique_ptr<mkldnn::memory::desc> dst_md;
+
+    std::unique_ptr<mkldnn::stream> stream;
+    std::vector<mkldnn::primitive> net;
+
+    ReluContext()
+        : src_fmt(mkldnn::memory::format::any),
+          dst_fmt(mkldnn::memory::format::any),
+          src_mem(nullptr),
+          dst_mem(nullptr),
+          fwd_desc(nullptr),
+          relu_fwd_pd(nullptr),
+          relu_fwd(nullptr),
+          src_md(nullptr),
+          dst_md(nullptr),
+          stream(nullptr) {}
+  };
+
+  // Creates the descriptors, memory objects and eltwise primitive for
+  // params and queues the primitive in context_.net.
+  void Initialize(const ReluParams& params) {
+    // Pick the plain layout matching the tensor rank; other ranks keep
+    // format::any.
+    mkldnn::memory::format fmt = mkldnn::memory::format::any;
+    switch (params.num_dimensions) {
+      case 1: { fmt = mkldnn::memory::format::x; break; }
+      case 2: { fmt = mkldnn::memory::format::nc; break; }
+      case 3: { fmt = mkldnn::memory::format::ntc; break; }
+      case 4: { fmt = mkldnn::memory::format::nchw; break; }
+      case 5: { fmt = mkldnn::memory::format::ncdhw; break; }
+      default: { fmt = mkldnn::memory::format::any; break; }
+    }
+
+    context_.src_md.reset(new mkldnn::memory::desc(
+        { params.src_dims }, MklDnnType<T>(), fmt));
+    context_.dst_md.reset(new mkldnn::memory::desc(
+        { params.dst_dims }, MklDnnType<T>(), fmt));
+
+    context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
+        mkldnn::prop_kind::forward_inference, mkldnn::algorithm::eltwise_relu,
+        *context_.src_md, 0, 0));
+
+    context_.relu_fwd_pd.reset(
+        new mkldnn::eltwise_forward::primitive_desc(*context_.fwd_desc,
+                                                    cpu_engine_));
+
+    // Source and destination both use the primitive's destination
+    // descriptor.
+    auto dst_pd = context_.relu_fwd_pd->dst_primitive_desc();
+    context_.src_fmt =
+        static_cast<mkldnn::memory::format>(dst_pd.desc().data.format);
+    context_.dst_fmt = context_.src_fmt;
+
+    // Memory objects are created without a buffer; Compute() attaches one.
+    context_.src_mem.reset(new mkldnn::memory(dst_pd, nullptr));
+    context_.dst_mem.reset(new mkldnn::memory(dst_pd, nullptr));
+    context_.relu_fwd.reset(
+        new mkldnn::eltwise_forward(*context_.relu_fwd_pd, *context_.src_mem,
+                                    *context_.dst_mem));
+    context_.net.push_back(*context_.relu_fwd);
+  }
+
+  ReluContext context_;
+  mkldnn::engine& cpu_engine_;
+};
+
+// Pool which allows for reuse of MKLDNN Relu primitives which are expensive
+// to instantiate. To address thread safety, the primitives are stored in a map
+// on thread local storage.
+template <typename T>
+class ReluPrimitivePool : public PrimitivePool<T> {
+ public:
+  // Returns the cached primitive for params, creating and registering it on
+  // first use. Ownership of a newly created primitive is moved into the pool.
+  static ReluPrimitive<T>* Get(const ReluParams& params) {
+    // Build the key once; it is needed for both the lookup and the insert.
+    const std::string key = params.ToString();
+    ReluPrimitive<T>* primitive = dynamic_cast<ReluPrimitive<T>*>(
+        ReluPrimitivePool<T>::GetInstance().GetPrimitive(key));
+
+    if (primitive == nullptr) {
+      auto relu_primitive = std::make_unique<ReluPrimitive<T>>(params);
+      primitive = relu_primitive.get();
+      ReluPrimitivePool<T>::GetInstance().SetPrimitive(
+          key, std::move(relu_primitive));
+    }
+    return primitive;
+  }
+
+ private:
+  ReluPrimitivePool() = default;
+  ~ReluPrimitivePool() = default;
+
+  static ReluPrimitivePool& GetInstance() {
+    static ReluPrimitivePool pool;
+    return pool;
+  }
+};
+}  // namespace
+
+// Computes Relu with a pooled MKLDNN primitive. Ranks that get no concrete
+// MKLDNN layout here (0 or more than 5) are delegated to the CPU kernel.
+template <typename T>
+Status Relu<T>::Compute(OpKernelContext* context) const {
+  const Tensor* X = context->Input<Tensor>(0);
+  const TensorShape& x_shape = X->Shape();
+
+  // ReluPrimitive::Initialize only maps ranks 1-5 to a concrete memory
+  // format, so every other rank falls back to the CPU kernel.
+  const size_t rank = x_shape.NumDimensions();
+  if (rank == 0 || rank > 5) {
+    return onnxruntime::Relu<T>::Compute(context);
+  }
+
+  Tensor* Y = context->Output(0, x_shape);
+  const auto& x_dims = x_shape.GetDims();
+  const TensorShape& y_shape = Y->Shape();
+  const auto& y_dims = y_shape.GetDims();
+
+  const T* src_data = X->template Data<T>();
+  T* dst_data = Y->template MutableData<T>();
+
+  mkldnn::memory::dims src_dims_mkl(x_dims.begin(), x_dims.end());
+  mkldnn::memory::dims dst_dims_mkl(y_dims.begin(), y_dims.end());
+
+  try {
+    ReluParams pool_params(src_dims_mkl, dst_dims_mkl, rank);
+    ReluPrimitive<T>* relu_primitive = ReluPrimitivePool<T>::Get(pool_params);
+
+    relu_primitive->Compute(src_data, dst_data);
+  } catch (const mkldnn::error& e) {
+    return ONNXRUNTIME_MAKE_STATUS(ONNXRUNTIME, FAIL, "Status: ", e.status,
+                                   ", message: ", e.message.c_str());
+  }
+
+  return Status::OK();
+}
+
+ONNX_OPERATOR_KERNEL_EX(
+    Relu,
+    kOnnxDomain,
+    6,
+    kMklDnnExecutionProvider,
+    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
+    Relu<float>);
+
+}  // namespace mkl_dnn
+}  // namespace onnxruntime
diff --git
a/onnxruntime/core/providers/mkldnn/activation/activations.h b/onnxruntime/core/providers/mkldnn/activation/activations.h
new file mode 100644
index 0000000000000..be8ba3537b91c
--- /dev/null
+++ b/onnxruntime/core/providers/mkldnn/activation/activations.h
@@ -0,0 +1,22 @@
+// Copyright(C) 2018 Intel Corporation
+// Licensed under the MIT License
+
+#pragma once
+#include "core/framework/op_kernel.h"
+#include "core/providers/cpu/activation/activations.h"
+
+namespace onnxruntime {
+namespace mkl_dnn {
+
+// Relu kernel for the MKLDNN execution provider. It derives from the CPU
+// Relu kernel and overrides Compute.
+template <typename T>
+class Relu : public onnxruntime::Relu<T> {
+ public:
+  explicit Relu(const OpKernelInfo& info) : onnxruntime::Relu<T>(info) {}
+
+  Status Compute(OpKernelContext* context) const override;
+};
+
+}  // namespace mkl_dnn
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc b/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc
index f98ce0c68bd48..6e99ffd83fb03 100644
--- a/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc
+++ b/onnxruntime/core/providers/mkldnn/mkldnn_execution_provider.cc
@@ -65,6 +65,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 1,
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 7, Gemm);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 1, MemcpyFromHost);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 1, MemcpyToHost);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 6, Relu);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 7, 8, float, AveragePool);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 1, 8, float, GlobalAveragePool);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kMklDnnExecutionProvider, kOnnxDomain, 1, 7, float, MaxPool);
@@ -77,6 +78,7 @@ void RegisterMKLDNNKernels(std::function fn) {
fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); + fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel()); fn(BuildKernel());