diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 3214e3b9b9ac..da042a1d21fb 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -28,7 +28,7 @@
 #include <nnvm/op_attr_types.h>
 #include "../elemwise_op_common.h"
 #include "../operator_common.h"
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "./mkldnn/mkldnn_batch_norm-inl.h"
 #endif
 
@@ -379,7 +379,7 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
   mxnet::TShape shape = input.shape();
   return SupportMKLDNN(input) && shape.ndim() == 4
@@ -454,7 +454,7 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
   bool dispatched = false;
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   if (!dispatched) {
     dispatched = MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
                                    in_attrs, out_attrs);
@@ -592,11 +592,11 @@ then set ``gamma`` to 1 and its gradient to 0.
 .set_attr<nnvm::FInferType>("FInferType", BatchNormType)
 .set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
 .set_attr<FCompute>("FCompute<cpu>", BatchNormCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormComputeExCPU)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", BatchNormGrad)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
@@ -623,13 +623,13 @@ NNVM_REGISTER_OP(_backward_BatchNorm)
 .set_num_outputs(3)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
 #endif
 .set_attr_parser(ParamParser<BatchNormParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormGradComputeExCPU)
 #endif
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 61de08fdde23..ef5886e5c86d 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -26,7 +26,7 @@
 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_NORM_INL_H_
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_NORM_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <mkldnn.hpp>
 #include <vector>
 #include <utility>
@@ -44,54 +44,44 @@ typedef mkldnn::batch_normalization_forward::desc t_bn_f_desc;
 typedef mkldnn::batch_normalization_backward::primitive_desc t_bn_b_pdesc;
 typedef mkldnn::batch_normalization_backward::desc t_bn_b_desc;
 
-using mkldnn::use_global_stats;
-using mkldnn::use_scale_shift;
-using mkldnn::forward_training;
-using mkldnn::forward_inference;
-
-inline static unsigned _GetFlags(const std::vector<NDArray> &in_data,
+inline static mkldnn::normalization_flags _GetFlags(const std::vector<NDArray> &in_data,
                                  const std::vector<NDArray> &aux_states,
                                  const BatchNormParam &param,
                                  bool is_train_and_not_global_stats) {
-  unsigned flags = 0U;
+  mkldnn::normalization_flags flags = static_cast<mkldnn::normalization_flags>(0U);
   if (in_data.size() == 3U) {
-    flags |= use_scale_shift;
+    flags |= mkldnn::normalization_flags::use_scale_shift;
   }
 
   // aux_states[0]: inMean
   // aux_states[1]: inVariance
   if (aux_states.size() == 2U && !is_train_and_not_global_stats) {
-    flags |= use_global_stats;
+    flags |= mkldnn::normalization_flags::use_global_stats;
   }
   return flags;
 }
 
-template<typename DType>
 inline static t_bn_f_pdesc _GetFwd(const mkldnn::memory &data_mem,
                                    bool is_train,
-                                   DType eps,
-                                   unsigned flags) {
-  auto data_mpd = data_mem.get_primitive_desc();
-  auto data_md = data_mpd.desc();
-  auto engine = CpuEngine::Get()->get_engine();
+                                   float eps,
+                                   mkldnn::normalization_flags flags) {
+  auto data_md = data_mem.get_desc();
+  auto engine = CpuEngine::Get()->get_engine();
 
   if (is_train) {
-    t_bn_f_desc bnFwd_desc(forward_training, data_md, eps, flags);
+    t_bn_f_desc bnFwd_desc(mkldnn::prop_kind::forward_training, data_md, eps, flags);
     return t_bn_f_pdesc(bnFwd_desc, engine);
   } else {
-    t_bn_f_desc bnFwd_desc(forward_inference, data_md, eps, flags);
+    t_bn_f_desc bnFwd_desc(mkldnn::prop_kind::forward_inference, data_md, eps, flags);
     return t_bn_f_pdesc(bnFwd_desc, engine);
   }
 }
 
-template<typename DType>
 inline static t_bn_b_pdesc _GetBwd(const mkldnn::memory &data_mem,
                                    const mkldnn::memory &diff_mem,
-                                   DType eps,
-                                   unsigned flags) {
-  auto data_mpd = data_mem.get_primitive_desc();
-  auto data_md = data_mpd.desc();
-  auto diff_mpd = diff_mem.get_primitive_desc();
-  auto diff_md = diff_mpd.desc();
+                                   float eps,
+                                   mkldnn::normalization_flags flags) {
+  auto data_md = data_mem.get_desc();
+  auto diff_md = diff_mem.get_desc();
   auto engine = CpuEngine::Get()->get_engine();
 
   t_bn_b_desc bnBwd_desc(mkldnn::prop_kind::backward, diff_md, data_md, eps, flags);
@@ -101,18 +91,15 @@ inline static t_bn_b_pdesc _GetBwd(const mkldnn::memory &data_mem,
 typedef ParamOpSign<BatchNormParam> MKLDNNBNSignature;
 
 class MKLDNNBNForward {
-  std::shared_ptr<mkldnn::memory> data_m;
   std::shared_ptr<mkldnn::memory> weight_m;
-  std::shared_ptr<mkldnn::memory> out_m;
-  std::shared_ptr<mkldnn::memory> mean_m;
-  std::shared_ptr<mkldnn::memory> var_m;
   std::shared_ptr<mkldnn::batch_normalization_forward> fwd;
   bool is_train_and_not_global_stats;
   t_bn_f_pdesc pd;
 
 public:
  MKLDNNBNForward(const t_bn_f_pdesc &_pd, bool is_train_and_not_global_stats): pd(_pd) {
-    weight_m.reset(new mkldnn::memory(pd.weights_primitive_desc()));
+    weight_m.reset(new mkldnn::memory(pd.weights_desc(), CpuEngine::Get()->get_engine()));
+    fwd.reset(new mkldnn::batch_normalization_forward(pd));
     this->is_train_and_not_global_stats = is_train_and_not_global_stats;
   }
 
@@ -124,59 +111,6 @@ class MKLDNNBNForward {
     return pd;
   }
 
-  const mkldnn::memory &GetMean() const {
-    return *mean_m;
-  }
-
-  const mkldnn::memory &GetVar() const {
-    return *var_m;
-  }
-
-  void SetDataHandle(const mkldnn::memory *data, const mkldnn::memory *mean,
-                     const mkldnn::memory *var, const mkldnn::memory *out) {
-    if (data_m) {
-      data_m->set_data_handle(data->get_data_handle());
-    } else {
-      data_m.reset(new mkldnn::memory(data->get_primitive_desc(),
-                                      data->get_data_handle()));
-    }
-    if (out_m) {
-      out_m->set_data_handle(out->get_data_handle());
-    } else {
-      out_m.reset(new mkldnn::memory(out->get_primitive_desc(),
-                                     out->get_data_handle()));
-    }
-    if (mean_m) {
-      mean_m->set_data_handle(mean->get_data_handle());
-    } else {
-      mean_m.reset(new mkldnn::memory(mean->get_primitive_desc(),
-                                      mean->get_data_handle()));
-    }
-    if (var_m) {
-      var_m->set_data_handle(var->get_data_handle());
-    } else {
-      var_m.reset(new mkldnn::memory(var->get_primitive_desc(),
-                                     var->get_data_handle()));
-    }
-
-    if (fwd == nullptr) {
-      if (!is_train_and_not_global_stats)
-        fwd.reset(new mkldnn::batch_normalization_forward(
-                pd, *data_m, mkldnn::primitive::at(*mean_m),
-                mkldnn::primitive::at(*var_m), *weight_m, *out_m));
-      else
-        fwd.reset(new mkldnn::batch_normalization_forward(
-                pd, mkldnn::primitive::at(*data_m),
-                mkldnn::primitive::at(*weight_m), *out_m,
-                *mean_m, *var_m));
-    }
-  }
-
-  void SetDataHandle(const NDArray &data, const NDArray &mean,
-                     const NDArray &var, const mkldnn::memory &out) {
-    SetDataHandle(data.GetMKLDNNData(), mean.GetMKLDNNData(), var.GetMKLDNNData(), &out);
-  }
-
   const mkldnn::batch_normalization_forward &GetFwd() const {
     return *fwd;
   }
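The `_GetFlags` hunk above carries the core of the flags migration: MKL-DNN 1.0 drops the plain `unsigned` bitmask and the unscoped `mkldnn::use_scale_shift` / `mkldnn::use_global_stats` enumerators in favor of the scoped `mkldnn::normalization_flags` enum class, which does not convert to an integer implicitly. A minimal standalone sketch, not part of this patch, of how such flags are combined and tested (it assumes an MKL-DNN 1.0 build with <mkldnn.hpp> on the include path):

#include <mkldnn.hpp>
#include <cassert>

int main() {
  using mkldnn::normalization_flags;

  // Start from an empty flag set, as _GetFlags() does.
  normalization_flags flags = static_cast<normalization_flags>(0U);

  // mkldnn.hpp provides operator|= for the flag enums, so ORing works...
  flags |= normalization_flags::use_scale_shift;

  // ...but a scoped enum has no implicit conversion to bool, so a bit
  // test goes through explicit integral casts, matching the
  // static_cast(flags) & static_cast(use_scale_shift) style used in
  // the hunks further down.
  bool has_scale_shift = static_cast<unsigned>(flags) &
      static_cast<unsigned>(normalization_flags::use_scale_shift);
  assert(has_scale_shift);
  return 0;
}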
@@ -185,7 +119,7 @@ class MKLDNNBNForward {
 template<typename DType>
 static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
                                      const OpContext &ctx, const mkldnn::memory *data_mem,
-                                     unsigned flags) {
+                                     mkldnn::normalization_flags flags) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNForward, OpHash> fwds;
 #else
@@ -193,13 +127,12 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
 #endif
   MKLDNNBNSignature key(param);
   key.AddSign(ctx.is_train);
-  key.AddSign(param.use_global_stats);
   key.AddSign(*data_mem);
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
     auto fwd_pd = _GetFwd(*data_mem, ctx.is_train,
-                          (DType) param.eps, flags);
+                          param.eps, flags);
     MKLDNNBNForward fwd(fwd_pd, ctx.is_train && !param.use_global_stats);
     it = AddToCache(&fwds, key, fwd);
   }
@@ -209,7 +142,7 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
 template<typename DType>
 static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
                                      const OpContext &ctx, const NDArray &in_data,
-                                     unsigned flags) {
+                                     mkldnn::normalization_flags flags) {
   return GetBNForward<DType>(param, ctx, in_data.GetMKLDNNData(), flags);
 }
@@ -220,18 +153,20 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
                             const std::vector<NDArray> &out_data,
                             const std::vector<NDArray> &aux_states) {
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
-  unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats);
+  mkldnn::normalization_flags flags = _GetFlags(in_data,
+                                                aux_states,
+                                                param,
+                                                ctx.is_train && !param.use_global_stats);
 
   const NDArray &data = in_data[batchnorm::kData];
-
   auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
-  const NDArray &out = out_data[batchnorm::kOut];
+  const NDArray &out = out_data[batchnorm::kOut];  // for output memory
 
-  auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_primitive_desc());
+  auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
 
   // mxnet will always use scale shift.
   // But if fix_gamma is true, then all scale elements will be set to 1.0f
-  if (flags & use_scale_shift) {
+  if (static_cast<int>(flags) & static_cast<int>(mkldnn::normalization_flags::use_scale_shift)) {
     const NDArray &gamma = in_data[batchnorm::kGamma];
     const NDArray &beta = in_data[batchnorm::kBeta];
     CHECK_EQ(gamma.storage_type(), mxnet::kDefaultStorage);
@@ -241,7 +176,7 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
 
     DType* weight_buf = reinterpret_cast<DType *>(weight_mem.get_data_handle());
     nnvm::dim_t channels_ = data.shape()[1];
-    CHECK(weight_mem.get_primitive_desc().get_size() == channels_ * sizeof(DType) * 2);
+    CHECK(weight_mem.get_desc().get_size() == channels_ * sizeof(DType) * 2);
     DType* weight_ptr = gamma.data().dptr<DType>();
     DType* bias_ptr = beta.data().dptr<DType>();
     if (!param.fix_gamma) {
@@ -249,17 +184,22 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
       memcpy(&weight_buf[channels_], bias_ptr, sizeof(weight_buf[0]) * channels_);
     } else if (IsBNWriting(req[batchnorm::kGamma])) {
       for (int i = 0; i < channels_; i++) {
-        weight_buf[i] = (DType)1.0f;
-        weight_ptr[i] = (DType)1.0f;
+        weight_buf[i] = static_cast<DType>(1.0f);
+        weight_ptr[i] = static_cast<DType>(1.0f);
         weight_buf[channels_ + i] = bias_ptr[i];  // bias
       }
     } else {
       for (int i = 0; i < channels_; i++) {
-        weight_buf[i] = (DType)1.0f;
+        weight_buf[i] = static_cast<DType>(1.0f);
        weight_buf[channels_ + i] = bias_ptr[i];  // bias
       }
     }
 
+    mkldnn_args_map_t net_args;
+    net_args[MKLDNN_ARG_SRC] = *data.GetMKLDNNData();
+    net_args[MKLDNN_ARG_SCALE_SHIFT] = weight_mem;
+    net_args[MKLDNN_ARG_DST] = *out_mem;
+
     if (!ctx.is_train || param.use_global_stats) {
       DType* omean = out_data[batchnorm::kMean].data().dptr<DType>();
       DType* ovar = out_data[batchnorm::kVar].data().dptr<DType>();
@@ -270,26 +210,21 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
         omean[i] = inmean[i];
         ovar[i] = VARIANCE_TO_INVSTD(invar[i], param.eps);
       }
-
-      fwd.SetDataHandle(data, aux_states[batchnorm::kMovingMean],
-                        aux_states[batchnorm::kMovingVar],
-                        *out_mem);
-      MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+      net_args[MKLDNN_ARG_MEAN] = *(aux_states[batchnorm::kMovingMean].GetMKLDNNData());
+      net_args[MKLDNN_ARG_VARIANCE] = *(aux_states[batchnorm::kMovingVar].GetMKLDNNData());
+      MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
       MKLDNNStream::Get()->Submit();
     } else {  // training
       const NDArray &outMean = out_data[batchnorm::kMean];
       const NDArray &outVar = out_data[batchnorm::kVar];
-      DType* omean = outMean.data().dptr<DType>();
-      DType* ovar = outVar.data().dptr<DType>();
-
-      fwd.SetDataHandle(data, outMean, outVar, *out_mem);
-      MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+      net_args[MKLDNN_ARG_MEAN] = *(outMean.GetMKLDNNData());
+      net_args[MKLDNN_ARG_VARIANCE] = *(outVar.GetMKLDNNData());
+      MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
       MKLDNNStream::Get()->Submit();
-      DType* mean_mem_ptr = reinterpret_cast<DType *>(fwd.GetMean().get_data_handle());
-      DType* var_mem_ptr = reinterpret_cast<DType *>(fwd.GetVar().get_data_handle());
+
+      DType* ovar = outVar.data().dptr<DType>();
       for (int i = 0; i < channels_; i++) {
-        omean[i] = mean_mem_ptr[i];
-        ovar[i] = VARIANCE_TO_INVSTD(var_mem_ptr[i], param.eps);
+        ovar[i] = VARIANCE_TO_INVSTD(ovar[i], param.eps);
       }
     }
   } else {  // no input gamma and beta
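The forward hunks above replace the v0.x pattern of binding memory into the primitive at construction time (the deleted `SetDataHandle` plumbing) with the MKL-DNN 1.0 execution model: the primitive is built once from its primitive descriptor, and each run supplies a map from `MKLDNN_ARG_*` tags to memory objects (`mkldnn_args_map_t` is MXNet's alias for that map), which `MKLDNNStream::RegisterPrimArgs` hands to the primitive's `execute()`. A self-contained sketch of the same inference-mode flow against the raw 1.0 API; the 1x8x4x4 shape and the unfilled buffers are illustrative only:

#include <mkldnn.hpp>

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  // NCHW f32 input: batch 1, 8 channels, 4x4 spatial.
  memory::desc data_md({1, 8, 4, 4}, memory::data_type::f32,
                       memory::format_tag::nchw);
  const float eps = 1e-5f;
  normalization_flags flags = normalization_flags::use_scale_shift;
  flags |= normalization_flags::use_global_stats;

  // Primitive descriptor, then the primitive itself; no data is bound yet.
  batch_normalization_forward::primitive_desc fwd_pd(
      {prop_kind::forward_inference, data_md, eps, flags}, eng);
  batch_normalization_forward fwd(fwd_pd);

  // Memory objects; a real caller would fill src, the 2 x C weights
  // (gamma then beta), mean, and variance before executing.
  memory src(data_md, eng), dst(fwd_pd.dst_desc(), eng);
  memory weights(fwd_pd.weights_desc(), eng);
  memory mean(fwd_pd.mean_desc(), eng), var(fwd_pd.variance_desc(), eng);

  // Execution-time argument map: this is what net_args mirrors above.
  fwd.execute(strm, {{MKLDNN_ARG_SRC, src},
                     {MKLDNN_ARG_SCALE_SHIFT, weights},
                     {MKLDNN_ARG_MEAN, mean},
                     {MKLDNN_ARG_VARIANCE, var},
                     {MKLDNN_ARG_DST, dst}});
  strm.wait();
  return 0;
}

Because the primitive no longer owns its buffers, the cached MKLDNNBNForward only has to keep the weights memory and the primitive itself, which is why the data_m/out_m/mean_m/var_m members could be deleted.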
@@ -299,11 +234,6 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
 }
 
 class MKLDNNBNBackward {
   std::shared_ptr<mkldnn::batch_normalization_backward> bwd;
-  std::shared_ptr<mkldnn::memory> data_m;
-  std::shared_ptr<mkldnn::memory> diff_m;
-  std::shared_ptr<mkldnn::memory> gradi_m;
-  std::shared_ptr<mkldnn::memory> mean_m;
-  std::shared_ptr<mkldnn::memory> var_m;
   const std::shared_ptr<mkldnn::memory> weight_m;
   const std::shared_ptr<mkldnn::memory> gradw_m;
 
@@ -311,41 +241,16 @@ class MKLDNNBNBackward {
   const t_bn_b_pdesc pd;
 
   explicit MKLDNNBNBackward(const t_bn_b_pdesc &_pd)
-      : weight_m(new mkldnn::memory(_pd.weights_primitive_desc())),
-        gradw_m(new mkldnn::memory(_pd.diff_weights_primitive_desc())),
-        pd(_pd) {}
+      : weight_m(new mkldnn::memory(_pd.weights_desc(), CpuEngine::Get()->get_engine())),
+        gradw_m(new mkldnn::memory(_pd.diff_weights_desc(), CpuEngine::Get()->get_engine())),
+        pd(_pd) {
+    bwd.reset(new mkldnn::batch_normalization_backward(pd));
+  }
 
   const mkldnn::memory &GetWeight() const { return *weight_m; }
 
   const mkldnn::memory &GetGradw() const { return *gradw_m; }
 
-  void SetDataHandle(const mkldnn::memory &data, const mkldnn::memory &diff,
-                     const NDArray &mean, const mkldnn::memory &var,
-                     const mkldnn::memory &gradi) {
-    auto mean_ptr = mean.data().dptr_;
-    if (bwd == nullptr) {
-      data_m.reset(new mkldnn::memory(data.get_primitive_desc(),
-                                      data.get_data_handle()));
-      diff_m.reset(new mkldnn::memory(diff.get_primitive_desc(),
-                                      diff.get_data_handle()));
-      gradi_m.reset(new mkldnn::memory(gradi.get_primitive_desc(),
-                                       gradi.get_data_handle()));
-      mean_m.reset(new mkldnn::memory(pd.mean_primitive_desc(), mean_ptr));
-      var_m.reset(new mkldnn::memory(pd.variance_primitive_desc(),
-                                     var.get_data_handle()));
-      bwd.reset(new mkldnn::batch_normalization_backward(
-          pd, *data_m, mkldnn::primitive::at(*mean_m),
-          mkldnn::primitive::at(*var_m), *diff_m, *weight_m, *gradi_m,
-          *gradw_m));
-    } else {
-      data_m->set_data_handle(data.get_data_handle());
-      diff_m->set_data_handle(diff.get_data_handle());
-      gradi_m->set_data_handle(gradi.get_data_handle());
-      mean_m->set_data_handle(mean_ptr);
-      var_m->set_data_handle(var.get_data_handle());
-    }
-  }
-
   const mkldnn::batch_normalization_backward &GetBwd() const { return *bwd; }
 };
 
@@ -353,7 +258,7 @@ template <typename DType>
 static MKLDNNBNBackward &GetBNBackward(
     const BatchNormParam &param, const OpContext &ctx, const NDArray &in_data,
     const mkldnn::memory &in_mem, const NDArray &diff_data,
-    const mkldnn::memory &diff_mem, unsigned flags) {
+    const mkldnn::memory &diff_mem, mkldnn::normalization_flags flags) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNBackward, OpHash> bwds;
 #else
@@ -385,7 +290,10 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
   CHECK_EQ(in_data.size(), 3U);
   CHECK_EQ(out_data.size(), 3U);
   CHECK_EQ(in_grad.size(), 3U);
-  unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats);
+  mkldnn::normalization_flags flags = _GetFlags(in_data,
+                                                aux_states,
+                                                param,
+                                                ctx.is_train && !param.use_global_stats);
 
   const NDArray &data = in_data[batchnorm::kData];
   const NDArray &diff = out_grad[batchnorm::kOut];
@@ -405,13 +313,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
   // MKLDNN batchnorm should run on special layouts. If one of them isn't, we
   // should reorder them.
   if (data.IsDefaultData())
-    data_mem = data.GetMKLDNNDataReorder(diff_mem->get_primitive_desc());
+    data_mem = data.GetMKLDNNDataReorder(diff_mem->get_desc());
   else if (diff.IsDefaultData())
-    diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_primitive_desc());
+    diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc());
   auto &bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
-  auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_primitive_desc());
+  auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_desc());
 
-  if (flags & use_scale_shift) {
+  if (static_cast<int>(flags) & static_cast<int>(mkldnn::normalization_flags::use_scale_shift)) {
     const NDArray &gamma = in_data[batchnorm::kGamma];
     const NDArray &beta = in_data[batchnorm::kBeta];
     DType *weight_buf = reinterpret_cast<DType *>(bwd.GetWeight().get_data_handle());
@@ -420,20 +328,27 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
       if (!param.fix_gamma)
         weight_buf[i] = (gamma.data().dptr<DType>())[i];   // weight
       else
-        weight_buf[i] = (DType)1.0f;
+        weight_buf[i] = static_cast<DType>(1.0f);
     }
 
     for (int i = 0; i < channels_; i++) {
       weight_buf[channels_ + i] = (beta.data().dptr<DType>())[i];  // bias
     }
 
+    mkldnn_args_map_t net_args;
+    net_args[MKLDNN_ARG_SRC] = *data_mem;
+    net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem;
+    net_args[MKLDNN_ARG_SCALE_SHIFT] = bwd.GetWeight();
+    net_args[MKLDNN_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw();
+    net_args[MKLDNN_ARG_DIFF_DST] = *diff_mem;
+
     // training but no input mean and variance
     if (ctx.is_train && !param.use_global_stats) {
       DType* moving_mean_ptr = reinterpret_cast<DType *>(moving_mean.data().dptr<DType>());
       DType* moving_var_ptr = reinterpret_cast<DType *>(moving_var.data().dptr<DType>());
       DType* out_mean_ptr = reinterpret_cast<DType *>(out_mean.data().dptr<DType>());
       DType* out_var_ptr = reinterpret_cast<DType *>(out_var.data().dptr<DType>());
-      mkldnn::memory var_mem(bwd.pd.variance_primitive_desc());
+      mkldnn::memory var_mem(bwd.pd.variance_desc(), CpuEngine::Get()->get_engine());
       DType *tmp_var_ptr = reinterpret_cast<DType *>(var_mem.get_data_handle());
 
       DType minus_mom = (1.0f - param.momentum);
@@ -445,13 +360,14 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
         moving_var_ptr[i] = moving_var_ptr[i] * param.momentum +
                             variance * minus_mom;
       }
 
-      bwd.SetDataHandle(*data_mem, *diff_mem, out_mean, var_mem, *gradi_mem);
-      MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
+      net_args[MKLDNN_ARG_MEAN] = *(out_mean.GetMKLDNNData());
+      net_args[MKLDNN_ARG_VARIANCE] = var_mem;
+      MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
       MKLDNNStream::Get()->Submit();
     } else {
-      bwd.SetDataHandle(*data_mem, *diff_mem, moving_mean,
-                        *moving_var.GetMKLDNNData(), *gradi_mem);
-      MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
+      net_args[MKLDNN_ARG_MEAN] = *(moving_mean.GetMKLDNNData());
+      net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData());
+      MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
       MKLDNNStream::Get()->Submit();
     }
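The backward path follows the same scheme: mean/variance and the scale-shift weights travel through the argument map rather than through the removed `SetDataHandle`. One detail worth noting when reading `_GetBwd`: in the MKL-DNN 1.0 API a backward batch-norm primitive descriptor is constructed with the corresponding forward primitive descriptor as a hint, and the backward desc takes the diff memory descriptor before the data descriptor, as the `bnBwd_desc` hunk shows. A standalone sketch under those assumptions (illustrative shape, buffers left unfilled):

#include <mkldnn.hpp>

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  memory::desc data_md({1, 8, 4, 4}, memory::data_type::f32,
                       memory::format_tag::nchw);
  const float eps = 1e-5f;
  normalization_flags flags = normalization_flags::use_scale_shift;

  // The backward primitive descriptor takes the forward one as a hint.
  batch_normalization_forward::primitive_desc fwd_pd(
      {prop_kind::forward_training, data_md, eps, flags}, eng);
  batch_normalization_backward::primitive_desc bwd_pd(
      {prop_kind::backward, data_md, data_md, eps, flags}, eng, fwd_pd);

  memory src(data_md, eng), diff_dst(data_md, eng), diff_src(data_md, eng);
  memory weights(bwd_pd.weights_desc(), eng);            // gamma then beta
  memory diff_weights(bwd_pd.diff_weights_desc(), eng);  // dgamma, dbeta
  memory mean(bwd_pd.mean_desc(), eng), var(bwd_pd.variance_desc(), eng);

  // mean/var must hold the batch statistics produced by the forward pass.
  batch_normalization_backward bwd(bwd_pd);
  bwd.execute(strm, {{MKLDNN_ARG_SRC, src},
                     {MKLDNN_ARG_MEAN, mean},
                     {MKLDNN_ARG_VARIANCE, var},
                     {MKLDNN_ARG_DIFF_DST, diff_dst},
                     {MKLDNN_ARG_SCALE_SHIFT, weights},
                     {MKLDNN_ARG_DIFF_SRC, diff_src},
                     {MKLDNN_ARG_DIFF_SCALE_SHIFT, diff_weights}});
  strm.wait();
  return 0;
}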