From 17c7733a0fc0435a4f51cd5332bef6ed7ff61ddc Mon Sep 17 00:00:00 2001 From: Tao Lv Date: Tue, 15 Oct 2019 22:04:26 +0800 Subject: [PATCH 1/2] change slice to mkldnn v1.0 --- src/operator/nn/mkldnn/mkldnn_slice-inl.h | 6 ++--- src/operator/nn/mkldnn/mkldnn_slice.cc | 30 ++++++++++++----------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_slice-inl.h b/src/operator/nn/mkldnn/mkldnn_slice-inl.h index f41db01a9837..22334665b8bb 100644 --- a/src/operator/nn/mkldnn/mkldnn_slice-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_slice-inl.h @@ -26,7 +26,7 @@ #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_ #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include #include @@ -45,7 +45,7 @@ class MKLDNNSliceFwd { const NDArray &in, const NDArray &out); void SetNewMem(const mkldnn::memory &input, const mkldnn::memory &output); - const mkldnn::reorder &GetPd() const; + void Register(); private: std::shared_ptr data_; @@ -62,5 +62,5 @@ void MKLDNNSlice(const SliceParam ¶m, const OpContext& ctx, } // namespace op } // namespace mxnet -#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_USE_MKLDNN == 100 #endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_slice.cc b/src/operator/nn/mkldnn/mkldnn_slice.cc index 2a817a25a5b8..bb894e3fce5e 100644 --- a/src/operator/nn/mkldnn/mkldnn_slice.cc +++ b/src/operator/nn/mkldnn/mkldnn_slice.cc @@ -23,7 +23,7 @@ * \author Zhiyuan Huang */ -#if MXNET_USE_MKLDNN == 1 +#if MXNET_USE_MKLDNN == 100 #include "./mkldnn_ops-inl.h" #include "./mkldnn_base-inl.h" @@ -49,13 +49,15 @@ MKLDNNSliceFwd::MKLDNNSliceFwd(const SliceParam ¶m, dims[i] = oshape[i]; offsets[i] = s; } - auto in_mem_pd = in.GetMKLDNNData()->get_primitive_desc(); - auto out_mem_pd = out.GetMKLDNNData()->get_primitive_desc(); - auto view_pd = mkldnn::view::primitive_desc(in_mem_pd, dims, offsets); - auto reorder_pd = reorder::primitive_desc(view_pd.dst_primitive_desc(), out_mem_pd); - this->data_ = std::make_shared(view_pd.dst_primitive_desc(), nullptr); - this->out_ = std::make_shared(view_pd.dst_primitive_desc(), nullptr); - this->fwd_ = std::make_shared(reorder_pd, *this->data_, *this->out_); + + auto in_md = in.GetMKLDNNData()->get_desc(); + auto out_md = out.GetMKLDNNData()->get_desc(); + auto sub_md = in_md.submemory_desc(dims, offsets); + + auto engine = CpuEngine::Get()->get_engine(); + this->data_ = std::make_shared(sub_md, engine, nullptr); + this->out_ = std::make_shared(out_md, engine, nullptr); + this->fwd_ = std::make_shared(*this->data_, *this->out_); } void MKLDNNSliceFwd::SetNewMem(const mkldnn::memory &input, const mkldnn::memory &output) { @@ -63,8 +65,8 @@ void MKLDNNSliceFwd::SetNewMem(const mkldnn::memory &input, const mkldnn::memory this->out_->set_data_handle(output.get_data_handle()); } -const mkldnn::reorder &MKLDNNSliceFwd::GetPd() const { - return *fwd_; +void MKLDNNSliceFwd::Register() { + MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, {{MKLDNN_ARG_FROM, *(this->data_)}, {MKLDNN_ARG_TO, *(this->out_)}}); } MKLDNNSliceFwd &GetSliceForward(const SliceParam ¶m, const bool is_train, @@ -91,14 +93,14 @@ void MKLDNNSlice(const SliceParam ¶m, const OpContext& ctx, const NDArray &in, OpReqType req, const NDArray &out) { MKLDNNSliceFwd &fwd = GetSliceForward(param, ctx.is_train, in, out); auto in_mem = in.GetMKLDNNData(); - auto out_mem_pd = out.GetMKLDNNData()->get_primitive_desc(); - auto out_mem = CreateMKLDNNMem(out, out_mem_pd, req); + auto out_md = out.GetMKLDNNData()->get_desc(); + auto out_mem = CreateMKLDNNMem(out, out_md, req); fwd.SetNewMem(*in_mem, *out_mem.second); - MKLDNNStream::Get()->RegisterPrim(fwd.GetPd()); + fwd.Register(); CommitOutput(out, out_mem); MKLDNNStream::Get()->Submit(); } } // namespace op } // namespace mxnet -#endif // MXNET_USE_MKLDNN == 1 +#endif // MXNET_USE_MKLDNN == 100 From 1d467c381c103333652553cb86a3026147095f41 Mon Sep 17 00:00:00 2001 From: Tao Lv Date: Wed, 16 Oct 2019 00:25:46 +0800 Subject: [PATCH 2/2] fix lint --- src/operator/nn/mkldnn/mkldnn_slice.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/operator/nn/mkldnn/mkldnn_slice.cc b/src/operator/nn/mkldnn/mkldnn_slice.cc index bb894e3fce5e..d9d98b36c6a0 100644 --- a/src/operator/nn/mkldnn/mkldnn_slice.cc +++ b/src/operator/nn/mkldnn/mkldnn_slice.cc @@ -66,7 +66,8 @@ void MKLDNNSliceFwd::SetNewMem(const mkldnn::memory &input, const mkldnn::memory } void MKLDNNSliceFwd::Register() { - MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, {{MKLDNN_ARG_FROM, *(this->data_)}, {MKLDNN_ARG_TO, *(this->out_)}}); + MKLDNNStream::Get()->RegisterPrimArgs(*fwd_, + {{MKLDNN_ARG_FROM, *(this->data_)}, {MKLDNN_ARG_TO, *(this->out_)}}); } MKLDNNSliceFwd &GetSliceForward(const SliceParam ¶m, const bool is_train,