diff --git a/src/operator/nn/deconvolution-inl.h b/src/operator/nn/deconvolution-inl.h index b6d522b9e6f9..352ca66fb690 100644 --- a/src/operator/nn/deconvolution-inl.h +++ b/src/operator/nn/deconvolution-inl.h @@ -163,6 +163,31 @@ struct DeconvolutionParam : public dmlc::Parameter { this->cudnn_off == other.cudnn_off && this->layout == other.layout; } +#if MXNET_USE_MKLDNN == 1 + static uint64_t ComputeHash(const TShape &shape) { + uint64_t hash = 0; + for (size_t i = 0; i < shape.ndim(); i++) + hash = hash * 2 + shape[i]; + return hash; + } + + uint64_t GetHash() const { + uint64_t hash = 0; + hash = hash * 2 + ComputeHash(kernel); + hash = hash * 2 + ComputeHash(stride); + hash = hash * 2 + ComputeHash(dilate); + hash = hash * 2 + ComputeHash(pad); + hash = hash * 2 + ComputeHash(adj); + hash = hash * 2 + ComputeHash(target_shape); + hash = hash * 2 + num_filter; + hash = hash * 2 + num_group; + hash = hash * 2 + workspace; + hash = hash * 2 + no_bias; + if (layout.has_value()) + hash = hash * 2 + layout.value(); + return hash; + } +#endif }; } // namespace op diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc index dc22437e68df..db0c90d7f9a8 100644 --- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc @@ -20,14 +20,15 @@ /*! * \file mkldnn_deconvolution.cc * \brief - * \author Da Zheng + * \author Da Zheng, Rong Zhang (rong.a.zhang@intel.com) */ +#if MXNET_USE_MKLDNN == 1 + #include "../deconvolution-inl.h" #include "./mkldnn_ops-inl.h" #include "./mkldnn_base-inl.h" -#if MXNET_USE_MKLDNN == 1 namespace mxnet { namespace op { @@ -59,7 +60,7 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwd_( } } -static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwd( +static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl( const DeconvolutionParam& param, const NDArray &data, const NDArray &weights, bool has_bias, const NDArray &output) { auto data_md = GetMemDesc(data); @@ -70,11 +71,21 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwd( if (param.stride.ndim() == 2) { strides[0] = param.stride[0]; strides[1] = param.stride[1]; + } else if (param.stride.ndim() == 1) { + strides[0] = param.stride[0]; + strides[1] = param.stride[0]; + } else { + LOG(FATAL) << "Unsupported stride dim"; } mkldnn::memory::dims padding{0, 0}; if (param.pad.ndim() == 2) { padding[0] = param.pad[0]; padding[1] = param.pad[1]; + } else if (param.pad.ndim() == 1) { + padding[0] = param.pad[0]; + padding[1] = param.pad[0]; + } else { + LOG(FATAL) << "Unsupported pad dim"; } mkldnn::memory::dims dilate{0, 0}; if (param.dilate.ndim() == 2) { @@ -100,11 +111,21 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData( if (param.stride.ndim() == 2) { strides[0] = param.stride[0]; strides[1] = param.stride[1]; + } else if (param.stride.ndim() == 1) { + strides[0] = param.stride[0]; + strides[1] = param.stride[0]; + } else { + LOG(FATAL) << "Unsupported stride dim"; } mkldnn::memory::dims padding{0, 0}; if (param.pad.ndim() == 2) { padding[0] = param.pad[0]; padding[1] = param.pad[1]; + } else if (param.pad.ndim() == 1) { + padding[0] = param.pad[0]; + padding[1] = param.pad[0]; + } else { + LOG(FATAL) << "Unsupported pad dim"; } mkldnn::memory::dims dilate{0, 0}; if (param.dilate.ndim() == 2) { @@ -127,11 +148,21 @@ static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights( if (param.stride.ndim() == 2) { strides[0] = param.stride[0]; strides[1] = param.stride[1]; + } else if (param.stride.ndim() == 1) { + strides[0] = param.stride[0]; + strides[1] = param.stride[0]; + } else { + LOG(FATAL) << "Unsupported stride dim"; } mkldnn::memory::dims padding{0, 0}; if (param.pad.ndim() == 2) { padding[0] = param.pad[0]; padding[1] = param.pad[1]; + } else if (param.pad.ndim() == 1) { + padding[0] = param.pad[0]; + padding[1] = param.pad[0]; + } else { + LOG(FATAL) << "Unsupported pad dim"; } mkldnn::memory::dims dilate{0, 0}; if (param.dilate.ndim() == 2) { @@ -151,18 +182,58 @@ static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights( } } -void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, - const std::vector &in_data, - const std::vector &req, - const std::vector &out_data) { - TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]); - const DeconvolutionParam& param = nnvm::get(attrs.parsed); +class MKLDNNDeconvForward { + std::shared_ptr fwd; + std::shared_ptr data; + std::shared_ptr weight; + std::shared_ptr bias; + std::shared_ptr out; + OutDataOp data_op; + + public: + MKLDNNDeconvForward(const DeconvolutionParam& param, + const NDArray &data, + const NDArray &weights, + bool has_bias, + const NDArray &output); + void SetDataHandle(const DeconvolutionParam& param, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data); + + void Execute(const std::vector &out_data); + + private: + mkldnn::convolution_backward_data::primitive_desc fwd_pd; +}; // class MKLDNNDeconvForward - mkldnn::convolution_backward_data::primitive_desc deconvFwd_pd = GetDeconvFwd( - param, in_data[deconv::kData], in_data[deconv::kWeight], false, - out_data[deconv::kOut]); +MKLDNNDeconvForward::MKLDNNDeconvForward(const DeconvolutionParam& param, + const NDArray &data, + const NDArray &weights, + bool has_bias, + const NDArray &output) + :fwd_pd(GetDeconvFwdImpl(param, data, weights, has_bias, output)) { + this->data = std::shared_ptr(new mkldnn::memory( + fwd_pd.diff_dst_primitive_desc())); + this->weight = std::shared_ptr(new mkldnn::memory( + fwd_pd.weights_primitive_desc())); + this->out = std::shared_ptr(new mkldnn::memory( + fwd_pd.diff_src_primitive_desc())); + this->fwd = std::shared_ptr( + new mkldnn::convolution_backward_data(fwd_pd, + mkldnn::primitive::at(*this->data), + mkldnn::primitive::at(*this->weight), + *this->out)); +} + +void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { auto data_mem = in_data[deconv::kData].GetMKLDNNDataReorder( - deconvFwd_pd.diff_dst_primitive_desc()); + fwd_pd.diff_dst_primitive_desc()); const mkldnn::memory *weight_mem; if (ctx.is_train) { // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it @@ -170,23 +241,34 @@ void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &c if (in_data[deconv::kWeight].IsMKLDNN()) const_cast(in_data[deconv::kWeight]).Reorder2Default(); weight_mem = GetWeights(in_data[deconv::kWeight], - deconvFwd_pd.weights_primitive_desc(), + fwd_pd.weights_primitive_desc(), param.num_group); } else { // For inference, we want to reorder the weight array so we don't need to // reorder data every time. const_cast(in_data[deconv::kWeight]).Reorder( - deconvFwd_pd.weights_primitive_desc()); + fwd_pd.weights_primitive_desc()); weight_mem = in_data[deconv::kWeight].GetMKLDNNData(); } auto out_mem = CreateMKLDNNMem(out_data[deconv::kOut], - deconvFwd_pd.diff_src_primitive_desc(), - req[deconv::kOut]); + fwd_pd.diff_src_primitive_desc(), req[deconv::kOut]); + auto output = out_mem.second; + this->data->set_data_handle(data_mem->get_data_handle()); + this->weight->set_data_handle(weight_mem->get_data_handle()); + this->out->set_data_handle(output->get_data_handle()); + this->data_op = out_mem.first; +} - MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_data( - deconvFwd_pd, *data_mem, *weight_mem, *out_mem.second)); - CommitOutput(out_data[deconv::kOut], out_mem); +void MKLDNNDeconvForward::Execute(const std::vector &out_data) { + MKLDNNStream::Get()->RegisterPrim(*fwd); + CommitOutput(out_data[deconv::kOut], mkldnn_output_t(this->data_op, this->out.get())); MKLDNNStream::Get()->Submit(); +} + +static void MKLDNNDeconvFwdBiasPostProcess(const DeconvolutionParam& param, + const OpContext &ctx, + const std::vector &in_data, + const std::vector &out_data) { // add bias, broadcast bias to dim 1: channel if (!param.no_bias) { // MKLDNN only supports float right now. @@ -201,6 +283,55 @@ void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &c } } +typedef MKLDNNParamOpSign MKLDNNDeconvSignature; + +static inline MKLDNNDeconvForward &GetDeconvFwd( + const nnvm::NodeAttrs& attrs, const NDArray &data, + const NDArray &weights, const NDArray *bias, + const NDArray &output) { + static thread_local + std::unordered_map fwds; + const DeconvolutionParam& param = nnvm::get(attrs.parsed); + MKLDNNDeconvSignature key(param); + // Here we can sign the conv op with NDArray because conv primitive will + // decide the right layout for the, so we only need to get the shape and the + // data type of the arrays. + key.AddSign(data); + key.AddSign(weights); + key.AddSign(output); + if (bias) + key.AddSign(*bias); + + auto it = fwds.find(key); + if (it == fwds.end()) { + bool has_bias = (bias != nullptr); + MKLDNNDeconvForward fwd(param, data, weights, has_bias, output); + auto ins_ret = fwds.insert( + std::pair(key, fwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + +void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data) { + TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]); + const DeconvolutionParam& param = nnvm::get(attrs.parsed); + + MKLDNNDeconvForward &deconvFwd = GetDeconvFwd( + attrs, in_data[deconv::kData], in_data[deconv::kWeight], + param.no_bias ? nullptr : &in_data[deconv::kBias], out_data[deconv::kOut]); + + deconvFwd.SetDataHandle(param, ctx, in_data, req, out_data); + + deconvFwd.Execute(out_data); + + MKLDNNDeconvFwdBiasPostProcess(param, ctx, in_data, out_data); +} + void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector& inputs, const std::vector& req,