From 4ca32b4f1e950c2342a5217d2184c1085e18bce0 Mon Sep 17 00:00:00 2001
From: rongzha1
Date: Wed, 18 Sep 2019 22:28:07 +0800
Subject: [PATCH 1/4] add mkldnn bn

---
 src/operator/nn/batch_norm.cc                 |  14 +-
 .../nn/mkldnn/mkldnn_batch_norm-inl.h         | 222 ++++++------------
 2 files changed, 77 insertions(+), 159 deletions(-)

diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 3214e3b9b9ac..da042a1d21fb 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -28,7 +28,7 @@
 #include <nnvm/op_attr_types.h>
 #include "../elemwise_op_common.h"
 #include "../operator_common.h"
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "./mkldnn/mkldnn_batch_norm-inl.h"
 #endif
 
@@ -379,7 +379,7 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
   mxnet::TShape shape = input.shape();
   return SupportMKLDNN(input) && shape.ndim() == 4
@@ -454,7 +454,7 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
 
   bool dispatched = false;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   if (!dispatched) {
     dispatched = MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
                                    in_attrs, out_attrs);
@@ -592,11 +592,11 @@ then set ``gamma`` to 1 and its gradient to 0.
 .set_attr<nnvm::FInferType>("FInferType", BatchNormType)
 .set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
 .set_attr<FCompute>("FCompute<cpu>", BatchNormCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormComputeExCPU)
 #endif
 .set_attr<nnvm::FGradient>("FGradient", BatchNormGrad)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
@@ -623,13 +623,13 @@ NNVM_REGISTER_OP(_backward_BatchNorm)
 .set_num_outputs(3)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
 #endif
 .set_attr_parser(ParamParser<BatchNormParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormGradComputeExCPU)
 #endif
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 61de08fdde23..f2a9b3228403 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -26,7 +26,7 @@
 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_NORM_INL_H_
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_BATCH_NORM_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <mkldnn.hpp>
 #include <vector>
 #include <utility>
@@ -44,54 +44,44 @@
 typedef mkldnn::batch_normalization_forward::desc t_bn_f_desc;
 typedef mkldnn::batch_normalization_backward::primitive_desc t_bn_b_pdesc;
 typedef mkldnn::batch_normalization_backward::desc t_bn_b_desc;
 
-using mkldnn::use_global_stats;
-using mkldnn::use_scale_shift;
-using mkldnn::forward_training;
-using mkldnn::forward_inference;
-
-inline static unsigned _GetFlags(const std::vector<NDArray> &in_data,
+inline static mkldnn::normalization_flags _GetFlags(const std::vector<NDArray> &in_data,
                                  const std::vector<NDArray> &aux_states,
                                  const BatchNormParam &param,
                                  bool is_train_and_not_global_stats) {
-  unsigned flags = 0U;
+  mkldnn::normalization_flags flags = static_cast<mkldnn::normalization_flags>(0U);
   if (in_data.size() == 3U) {
-    flags |= use_scale_shift;
+    flags |= mkldnn::normalization_flags::use_scale_shift;
   }
 
   // aux_states[0]: inMean
   // aux_states[1]: inVariance
   if (aux_states.size() == 2U && !is_train_and_not_global_stats) {
-    flags |= use_global_stats;
+    flags |= mkldnn::normalization_flags::use_global_stats;
   }
   return flags;
 }
 
-template <typename DType>
 inline static t_bn_f_pdesc _GetFwd(const mkldnn::memory &data_mem,
                                    bool is_train,
-                                   DType eps,
-                                   unsigned flags) {
-  auto data_mpd = data_mem.get_primitive_desc();
-  auto data_md = data_mpd.desc();
-  auto engine = CpuEngine::Get()->get_engine();
+                                   float eps,
+                                   mkldnn::normalization_flags flags) {
+  auto data_md = data_mem.get_desc();
+  auto engine = CpuEngine::Get()->get_engine();
 
   if (is_train) {
-    t_bn_f_desc bnFwd_desc(forward_training, data_md, eps, flags);
+    t_bn_f_desc bnFwd_desc(mkldnn::prop_kind::forward_training, data_md, eps, flags);
     return t_bn_f_pdesc(bnFwd_desc, engine);
   } else {
-    t_bn_f_desc bnFwd_desc(forward_inference, data_md, eps, flags);
+    t_bn_f_desc bnFwd_desc(mkldnn::prop_kind::forward_inference, data_md, eps, flags);
     return t_bn_f_pdesc(bnFwd_desc, engine);
   }
 }
 
-template <typename DType>
 inline static t_bn_b_pdesc _GetBwd(const mkldnn::memory &data_mem,
                                    const mkldnn::memory &diff_mem,
-                                   DType eps,
-                                   unsigned flags) {
-  auto data_mpd = data_mem.get_primitive_desc();
-  auto data_md = data_mpd.desc();
-  auto diff_mpd = diff_mem.get_primitive_desc();
-  auto diff_md = diff_mpd.desc();
+                                   float eps,
+                                   mkldnn::normalization_flags flags) {
+  auto data_md = data_mem.get_desc();
+  auto diff_md = diff_mem.get_desc();
   auto engine = CpuEngine::Get()->get_engine();
 
   t_bn_b_desc bnBwd_desc(mkldnn::prop_kind::backward, diff_md, data_md, eps, flags);
@@ -101,18 +91,15 @@ inline static t_bn_b_pdesc _GetBwd(const mkldnn::memory &data_mem,
 typedef ParamOpSign<BatchNormParam> MKLDNNBNSignature;
 
 class MKLDNNBNForward {
-  std::shared_ptr<const mkldnn::memory> data_m;
   std::shared_ptr<const mkldnn::memory> weight_m;
-  std::shared_ptr<const mkldnn::memory> out_m;
-  std::shared_ptr<const mkldnn::memory> mean_m;
-  std::shared_ptr<const mkldnn::memory> var_m;
   std::shared_ptr<mkldnn::batch_normalization_forward> fwd;
   bool is_train_and_not_global_stats;
   t_bn_f_pdesc pd;
 
 public:
  MKLDNNBNForward(const t_bn_f_pdesc &_pd, bool is_train_and_not_global_stats): pd(_pd) {
-    weight_m.reset(new mkldnn::memory(pd.weights_primitive_desc()));
+    weight_m.reset(new mkldnn::memory(pd.weights_desc(), CpuEngine::Get()->get_engine()));
+    fwd.reset(new mkldnn::batch_normalization_forward(pd));
     this->is_train_and_not_global_stats = is_train_and_not_global_stats;
  }
 
@@ -124,59 +111,6 @@ class MKLDNNBNForward {
     return pd;
   }
 
-  const mkldnn::memory &GetMean() const {
-    return *mean_m;
-  }
-
-  const mkldnn::memory &GetVar() const {
-    return *var_m;
-  }
-
-  void SetDataHandle(const mkldnn::memory *data, const mkldnn::memory *mean,
-                     const mkldnn::memory *var, const mkldnn::memory *out) {
-    if (data_m) {
-      data_m->set_data_handle(data->get_data_handle());
-    } else {
-      data_m.reset(new mkldnn::memory(data->get_primitive_desc(),
-                                      data->get_data_handle()));
-    }
-    if (out_m) {
-      out_m->set_data_handle(out->get_data_handle());
-    } else {
-      out_m.reset(new mkldnn::memory(out->get_primitive_desc(),
-                                     out->get_data_handle()));
-    }
-    if (mean_m) {
-      mean_m->set_data_handle(mean->get_data_handle());
-    } else {
-      mean_m.reset(new mkldnn::memory(mean->get_primitive_desc(),
-                                      mean->get_data_handle()));
-    }
-    if (var_m) {
-      var_m->set_data_handle(var->get_data_handle());
-    } else {
-      var_m.reset(new mkldnn::memory(var->get_primitive_desc(),
-                                     var->get_data_handle()));
-    }
-
-    if (fwd == nullptr) {
-      if (!is_train_and_not_global_stats)
-        fwd.reset(new mkldnn::batch_normalization_forward(
-            pd, *data_m, mkldnn::primitive::at(*mean_m),
-            mkldnn::primitive::at(*var_m), *weight_m, *out_m));
-      else
-        fwd.reset(new mkldnn::batch_normalization_forward(
-            pd, mkldnn::primitive::at(*data_m),
-            mkldnn::primitive::at(*weight_m), *out_m,
-            *mean_m, *var_m));
-    }
-  }
-
-  void SetDataHandle(const NDArray &data, const NDArray &mean,
-                     const NDArray &var, const mkldnn::memory &out) {
-    SetDataHandle(data.GetMKLDNNData(), mean.GetMKLDNNData(), var.GetMKLDNNData(), &out);
-  }
-
   const mkldnn::batch_normalization_forward &GetFwd() const {
     return *fwd;
   }
@@ -185,7 +119,7 @@ class MKLDNNBNForward {
 template <typename DType>
 static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
                                      const OpContext &ctx, const mkldnn::memory *data_mem,
-                                     unsigned flags) {
+                                     mkldnn::normalization_flags flags) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNForward, OpHash> fwds;
 #else
@@ -193,13 +127,12 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
 #endif
   MKLDNNBNSignature key(param);
   key.AddSign(ctx.is_train);
-  key.AddSign(param.use_global_stats);
   key.AddSign(*data_mem);
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
     auto fwd_pd = _GetFwd(*data_mem, ctx.is_train,
-                          (DType) param.eps, flags);
+                          param.eps, flags);
     MKLDNNBNForward fwd(fwd_pd, ctx.is_train && !param.use_global_stats);
     it = AddToCache(&fwds, key, fwd);
   }
@@ -209,7 +142,7 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
 template <typename DType>
 static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
                                      const OpContext &ctx, const NDArray &in_data,
-                                     unsigned flags) {
+                                     mkldnn::normalization_flags flags) {
   return GetBNForward<DType>(param, ctx, in_data.GetMKLDNNData(), flags);
 }
 
@@ -220,18 +153,20 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
                             const std::vector<NDArray>    &out_data,
                             const std::vector<NDArray>    &aux_states) {
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
-  unsigned flags = _GetFlags(in_data, aux_states, param,
-                             ctx.is_train && !param.use_global_stats);
+  mkldnn::normalization_flags flags = _GetFlags(in_data,
+                                                aux_states,
+                                                param,
+                                                ctx.is_train && !param.use_global_stats);
 
   const NDArray &data = in_data[batchnorm::kData];
   auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
-  const NDArray &out = out_data[batchnorm::kOut];
+  const NDArray &out = out_data[batchnorm::kOut];
 
   // for output memory
-  auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_primitive_desc());
+  auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
 
   // mxnet will always use scale shift.
   // But if fix_gamma is true, then all scale elements will be set to 1.0f
-  if (flags & use_scale_shift) {
+  if (static_cast<int>(flags) & static_cast<int>(mkldnn::normalization_flags::use_scale_shift)) {
     const NDArray &gamma = in_data[batchnorm::kGamma];
     const NDArray &beta = in_data[batchnorm::kBeta];
     CHECK_EQ(gamma.storage_type(), mxnet::kDefaultStorage);
@@ -241,7 +176,7 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
     DType* weight_buf = reinterpret_cast<DType *>(weight_mem.get_data_handle());
 
     nnvm::dim_t channels_ = data.shape()[1];
-    CHECK(weight_mem.get_primitive_desc().get_size() == channels_ * sizeof(DType) * 2);
+    CHECK(weight_mem.get_desc().get_size() == channels_ * sizeof(DType) * 2);
 
     DType* weight_ptr = gamma.data().dptr<DType>();
     DType* bias_ptr = beta.data().dptr<DType>();
     if (!param.fix_gamma) {
@@ -260,6 +195,11 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
       }
     }
 
+    std::unordered_map<int, mkldnn::memory> net_args;
+    net_args.insert({MKLDNN_ARG_SRC, *(data.GetMKLDNNData())});
+    net_args.insert({MKLDNN_ARG_SCALE_SHIFT, weight_mem});
+    net_args.insert({MKLDNN_ARG_DST, *out_mem});
+
     if (!ctx.is_train || param.use_global_stats) {
       DType* omean = out_data[batchnorm::kMean].data().dptr<DType>();
       DType* ovar = out_data[batchnorm::kVar].data().dptr<DType>();
@@ -271,25 +211,22 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
        ovar[i] = VARIANCE_TO_INVSTD(invar[i], param.eps);
      }
 
-      fwd.SetDataHandle(data, aux_states[batchnorm::kMovingMean],
-                        aux_states[batchnorm::kMovingVar],
-                        *out_mem);
-      MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+      net_args.insert({MKLDNN_ARG_MEAN, *(aux_states[batchnorm::kMovingMean].GetMKLDNNData())});
+      net_args.insert({MKLDNN_ARG_VARIANCE, *(aux_states[batchnorm::kMovingVar].GetMKLDNNData())});
+      MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
       MKLDNNStream::Get()->Submit();
     } else {  // training
       const NDArray &outMean = out_data[batchnorm::kMean];
       const NDArray &outVar = out_data[batchnorm::kVar];
-      DType* omean = outMean.data().dptr<DType>();
-      DType* ovar = outVar.data().dptr<DType>();
 
-      fwd.SetDataHandle(data, outMean, outVar, *out_mem);
-      MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+      net_args.insert({MKLDNN_ARG_MEAN, *(outMean.GetMKLDNNData())});
+      net_args.insert({MKLDNN_ARG_VARIANCE, *(outVar.GetMKLDNNData())});
+      MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
       MKLDNNStream::Get()->Submit();
-      DType* mean_mem_ptr = reinterpret_cast<DType*>(fwd.GetMean().get_data_handle());
-      DType* var_mem_ptr = reinterpret_cast<DType*>(fwd.GetVar().get_data_handle());
+
+      DType* ovar = outVar.data().dptr<DType>();
       for (int i = 0; i < channels_; i++) {
-        omean[i] = mean_mem_ptr[i];
-        ovar[i] = VARIANCE_TO_INVSTD(var_mem_ptr[i], param.eps);
+        ovar[i] = VARIANCE_TO_INVSTD(ovar[i], param.eps);
       }
     }
   } else {  // no input gamma and beta
@@ -299,11 +236,6 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
 
 class MKLDNNBNBackward {
   std::shared_ptr<mkldnn::batch_normalization_backward> bwd;
-  std::shared_ptr<mkldnn::memory> data_m;
-  std::shared_ptr<mkldnn::memory> diff_m;
-  std::shared_ptr<mkldnn::memory> gradi_m;
-  std::shared_ptr<mkldnn::memory> mean_m;
-  std::shared_ptr<mkldnn::memory> var_m;
   const std::shared_ptr<mkldnn::memory> weight_m;
   const std::shared_ptr<mkldnn::memory> gradw_m;
 
 public:
  const t_bn_b_pdesc pd;
 
  explicit MKLDNNBNBackward(const t_bn_b_pdesc &_pd)
-      : weight_m(new mkldnn::memory(_pd.weights_primitive_desc())),
-        gradw_m(new mkldnn::memory(_pd.diff_weights_primitive_desc())),
-        pd(_pd) {}
+      : weight_m(new mkldnn::memory(_pd.weights_desc(), CpuEngine::Get()->get_engine())),
+        gradw_m(new mkldnn::memory(_pd.diff_weights_desc(), CpuEngine::Get()->get_engine())),
+        pd(_pd) {
+    bwd.reset(new mkldnn::batch_normalization_backward(pd));
+  }
 
  const mkldnn::memory &GetWeight() const { return *weight_m; }
 
  const mkldnn::memory &GetGradw() const { return *gradw_m; }
 
-  void SetDataHandle(const mkldnn::memory &data, const mkldnn::memory &diff,
-                     const NDArray &mean, const mkldnn::memory &var,
-                     const mkldnn::memory &gradi) {
-    auto mean_ptr = mean.data().dptr_;
-    if (bwd == nullptr) {
-      data_m.reset(new mkldnn::memory(data.get_primitive_desc(),
-                                      data.get_data_handle()));
-      diff_m.reset(new mkldnn::memory(diff.get_primitive_desc(),
-                                      diff.get_data_handle()));
-      gradi_m.reset(new mkldnn::memory(gradi.get_primitive_desc(),
-                                       gradi.get_data_handle()));
-      mean_m.reset(new mkldnn::memory(pd.mean_primitive_desc(), mean_ptr));
-      var_m.reset(new mkldnn::memory(pd.variance_primitive_desc(),
-                                     var.get_data_handle()));
-      bwd.reset(new mkldnn::batch_normalization_backward(
-          pd, *data_m, mkldnn::primitive::at(*mean_m),
-          mkldnn::primitive::at(*var_m), *diff_m, *weight_m, *gradi_m,
-          *gradw_m));
-    } else {
-      data_m->set_data_handle(data.get_data_handle());
-      diff_m->set_data_handle(diff.get_data_handle());
-      gradi_m->set_data_handle(gradi.get_data_handle());
-      mean_m->set_data_handle(mean_ptr);
-      var_m->set_data_handle(var.get_data_handle());
-    }
-  }
-
  const mkldnn::batch_normalization_backward &GetBwd() const { return *bwd; }
 };
 
@@ -353,7 +260,7 @@ template <typename DType>
 static MKLDNNBNBackward &GetBNBackward(
     const BatchNormParam &param, const OpContext &ctx, const NDArray &in_data,
     const mkldnn::memory &in_mem, const NDArray &diff_data,
-    const mkldnn::memory &diff_mem, unsigned flags) {
+    const mkldnn::memory &diff_mem, mkldnn::normalization_flags flags) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<MKLDNNBNSignature, MKLDNNBNBackward, OpHash> bwds;
 #else
@@ -385,7 +292,10 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
   CHECK_EQ(in_data.size(), 3U);
   CHECK_EQ(out_data.size(), 3U);
   CHECK_EQ(in_grad.size(), 3U);
-  unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats);
+  mkldnn::normalization_flags flags = _GetFlags(in_data,
+                                                aux_states,
+                                                param,
+                                                ctx.is_train && !param.use_global_stats);
 
   const NDArray &data = in_data[batchnorm::kData];
   const NDArray &diff = out_grad[batchnorm::kOut];
@@ -405,13 +315,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
 
   // MKLDNN batchnorm should run on special layouts. If one of them isn't, we
   // should reorder them.
   if (data.IsDefaultData())
-    data_mem = data.GetMKLDNNDataReorder(diff_mem->get_primitive_desc());
+    data_mem = data.GetMKLDNNDataReorder(diff_mem->get_desc());
   else if (diff.IsDefaultData())
-    diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_primitive_desc());
+    diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc());
   auto &bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
-  auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_primitive_desc());
+  auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_desc());
 
-  if (flags & use_scale_shift) {
+  if (static_cast<int>(flags) & static_cast<int>(mkldnn::normalization_flags::use_scale_shift)) {
     const NDArray &gamma = in_data[batchnorm::kGamma];
     const NDArray &beta = in_data[batchnorm::kBeta];
     DType *weight_buf = reinterpret_cast<DType *>(bwd.GetWeight().get_data_handle());
@@ -427,13 +337,20 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
       weight_buf[channels_ + i] = (beta.data().dptr<DType>())[i];  // bias
     }
 
+    std::unordered_map<int, mkldnn::memory> net_args;
+    net_args.insert({MKLDNN_ARG_SRC, *data_mem});
+    net_args.insert({MKLDNN_ARG_DIFF_SRC, *gradi_mem});
+    net_args.insert({MKLDNN_ARG_SCALE_SHIFT, bwd.GetWeight()});
+    net_args.insert({MKLDNN_ARG_DIFF_SCALE_SHIFT, bwd.GetGradw()});
+    net_args.insert({MKLDNN_ARG_DIFF_DST, *diff_mem});
+
     // training but no input mean and variance
     if (ctx.is_train && !param.use_global_stats) {
       DType* moving_mean_ptr = reinterpret_cast<DType *>(moving_mean.data().dptr<DType>());
       DType* moving_var_ptr = reinterpret_cast<DType *>(moving_var.data().dptr<DType>());
       DType* out_mean_ptr = reinterpret_cast<DType *>(out_mean.data().dptr<DType>());
       DType* out_var_ptr = reinterpret_cast<DType *>(out_var.data().dptr<DType>());
-      mkldnn::memory var_mem(bwd.pd.variance_primitive_desc());
+      mkldnn::memory var_mem(bwd.pd.variance_desc(), CpuEngine::Get()->get_engine());
       DType *tmp_var_ptr = reinterpret_cast<DType *>(var_mem.get_data_handle());
 
       DType minus_mom = (1.0f - param.momentum);
@@ -445,13 +362,14 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
         moving_var_ptr[i] = moving_var_ptr[i] * param.momentum + variance * minus_mom;
       }
 
-      bwd.SetDataHandle(*data_mem, *diff_mem, out_mean, var_mem, *gradi_mem);
-      MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
+      net_args.insert({MKLDNN_ARG_MEAN, *(out_mean.GetMKLDNNData())});
+      net_args.insert({MKLDNN_ARG_VARIANCE, var_mem});
+      MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
       MKLDNNStream::Get()->Submit();
     } else {
-      bwd.SetDataHandle(*data_mem, *diff_mem, moving_mean,
-                        *moving_var.GetMKLDNNData(), *gradi_mem);
-      MKLDNNStream::Get()->RegisterPrim(bwd.GetBwd());
+      net_args.insert({MKLDNN_ARG_MEAN, *(moving_mean.GetMKLDNNData())});
+      net_args.insert({MKLDNN_ARG_VARIANCE, *(moving_var.GetMKLDNNData())});
+      MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
       MKLDNNStream::Get()->Submit();
     }
 
From 8816cdb4c63cd78bd32dd0e863f5e274202f3d69 Mon Sep 17 00:00:00 2001
From: rongzha1
Date: Thu, 19 Sep 2019 13:52:16 +0800
Subject: [PATCH 2/4] add static_cast to transform data type

---
 src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index f2a9b3228403..0fdd1ac66545 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -184,13 +184,13 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
       memcpy(&weight_buf[channels_], bias_ptr, sizeof(weight_buf[0]) * channels_);
     } else if (IsBNWriting(req[batchnorm::kGamma])) {
       for (int i = 0; i < channels_; i++) {
-        weight_buf[i] = (DType)1.0f;
-        weight_ptr[i] = (DType)1.0f;
+        weight_buf[i] = static_cast<DType>(1.0f);
+        weight_ptr[i] = static_cast<DType>(1.0f);
         weight_buf[channels_ + i] = bias_ptr[i];  // bias
       }
     } else {
       for (int i = 0; i < channels_; i++) {
-        weight_buf[i] = (DType)1.0f;
+        weight_buf[i] = static_cast<DType>(1.0f);
         weight_buf[channels_ + i] = bias_ptr[i];  // bias
       }
     }
@@ -224,7 +224,7 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
       MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
       MKLDNNStream::Get()->Submit();
 
-      DType* ovar = outVar.data().dptr<DType>();
+      DType* ovar = outVar.data().dptr<DType>();
       for (int i = 0; i < channels_; i++) {
         ovar[i] = VARIANCE_TO_INVSTD(ovar[i], param.eps);
       }
@@ -330,7 +330,7 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
       if (!param.fix_gamma)
         weight_buf[i] = (gamma.data().dptr<DType>())[i];   // weight
       else
-        weight_buf[i] = (DType)1.0f;
+        weight_buf[i] = static_cast<DType>(1.0f);
     }
 
     for (int i = 0; i < channels_; i++) {

From 1eb2c81ce8be955f88c3d1899190ee29fc737585 Mon Sep 17 00:00:00 2001
From: rongzha1
Date: Thu, 19 Sep 2019 22:04:19 +0800
Subject: [PATCH 3/4] change mkldnn_args_map_t

---
 .../nn/mkldnn/mkldnn_batch_norm-inl.h         | 38 +++++++++----------
 1 file changed, 18 insertions(+), 20 deletions(-)

diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 0fdd1ac66545..ef5886e5c86d 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -195,10 +195,10 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
       }
     }
 
-    std::unordered_map<int, mkldnn::memory> net_args;
-    net_args.insert({MKLDNN_ARG_SRC, *(data.GetMKLDNNData())});
-    net_args.insert({MKLDNN_ARG_SCALE_SHIFT, weight_mem});
-    net_args.insert({MKLDNN_ARG_DST, *out_mem});
+    mkldnn_args_map_t net_args;
+    net_args[MKLDNN_ARG_SRC] = *data.GetMKLDNNData();
+    net_args[MKLDNN_ARG_SCALE_SHIFT] = weight_mem;
+    net_args[MKLDNN_ARG_DST] = *out_mem;
 
     if (!ctx.is_train || param.use_global_stats) {
       DType* omean = out_data[batchnorm::kMean].data().dptr<DType>();
@@ -210,17 +210,15 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
        omean[i] = inmean[i];
        ovar[i] = VARIANCE_TO_INVSTD(invar[i], param.eps);
      }
-
-      net_args.insert({MKLDNN_ARG_MEAN, *(aux_states[batchnorm::kMovingMean].GetMKLDNNData())});
-      net_args.insert({MKLDNN_ARG_VARIANCE, *(aux_states[batchnorm::kMovingVar].GetMKLDNNData())});
+      net_args[MKLDNN_ARG_MEAN] = *(aux_states[batchnorm::kMovingMean].GetMKLDNNData());
+      net_args[MKLDNN_ARG_VARIANCE] = *(aux_states[batchnorm::kMovingVar].GetMKLDNNData());
       MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
       MKLDNNStream::Get()->Submit();
     } else {  // training
       const NDArray &outMean = out_data[batchnorm::kMean];
       const NDArray &outVar = out_data[batchnorm::kVar];
-
-      net_args.insert({MKLDNN_ARG_MEAN, *(outMean.GetMKLDNNData())});
-      net_args.insert({MKLDNN_ARG_VARIANCE, *(outVar.GetMKLDNNData())});
+      net_args[MKLDNN_ARG_MEAN] = *(outMean.GetMKLDNNData());
+      net_args[MKLDNN_ARG_VARIANCE] = *(outVar.GetMKLDNNData());
       MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
       MKLDNNStream::Get()->Submit();
 
@@ -337,12 +335,12 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
       weight_buf[channels_ + i] = (beta.data().dptr<DType>())[i];  // bias
     }
 
-    std::unordered_map<int, mkldnn::memory> net_args;
-    net_args.insert({MKLDNN_ARG_SRC, *data_mem});
-    net_args.insert({MKLDNN_ARG_DIFF_SRC, *gradi_mem});
-    net_args.insert({MKLDNN_ARG_SCALE_SHIFT, bwd.GetWeight()});
-    net_args.insert({MKLDNN_ARG_DIFF_SCALE_SHIFT, bwd.GetGradw()});
-    net_args.insert({MKLDNN_ARG_DIFF_DST, *diff_mem});
+    mkldnn_args_map_t net_args;
+    net_args[MKLDNN_ARG_SRC] = *data_mem;
+    net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem;
+    net_args[MKLDNN_ARG_SCALE_SHIFT] = bwd.GetWeight();
+    net_args[MKLDNN_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw();
+    net_args[MKLDNN_ARG_DIFF_DST] = *diff_mem;
 
     // training but no input mean and variance
     if (ctx.is_train && !param.use_global_stats) {
@@ -362,13 +360,13 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
         moving_var_ptr[i] = moving_var_ptr[i] * param.momentum + variance * minus_mom;
       }
 
-      net_args.insert({MKLDNN_ARG_MEAN, *(out_mean.GetMKLDNNData())});
-      net_args.insert({MKLDNN_ARG_VARIANCE, var_mem});
+      net_args[MKLDNN_ARG_MEAN] = *(out_mean.GetMKLDNNData());
+      net_args[MKLDNN_ARG_VARIANCE] = var_mem;
       MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
       MKLDNNStream::Get()->Submit();
     } else {
-      net_args.insert({MKLDNN_ARG_MEAN, *(moving_mean.GetMKLDNNData())});
-      net_args.insert({MKLDNN_ARG_VARIANCE, *(moving_var.GetMKLDNNData())});
+      net_args[MKLDNN_ARG_MEAN] = *(moving_mean.GetMKLDNNData());
+      net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData());
       MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
       MKLDNNStream::Get()->Submit();
     }

From 0e0812b55c7544ee00ccdb45165017b27aea7c2a Mon Sep 17 00:00:00 2001
From: rongzha1
Date: Mon, 23 Sep 2019 16:34:50 +0800
Subject: [PATCH 4/4] retrigger CI
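
The three code patches above follow one MKL-DNN 1.0 migration pattern: *_primitive_desc() queries become *_desc(), a primitive is constructed from its primitive_desc alone, and the memory objects are bound at execution time through an argument map (RegisterPrimArgs with mkldnn_args_map_t) instead of being baked into the primitive at construction. The snippet below is a minimal, self-contained sketch of that execution model against plain MKL-DNN 1.0, not MXNet code; the engine/stream names, tensor shape, and epsilon are illustrative assumptions rather than values taken from this PR.

#include <mkldnn.hpp>
#include <unordered_map>

int main() {
  // Assumed setup: a CPU engine and an in-order stream.
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream strm(eng);

  // NCHW fp32 input; the shape is an arbitrary example.
  mkldnn::memory::desc data_md({2, 8, 4, 4}, mkldnn::memory::data_type::f32,
                               mkldnn::memory::format_tag::nchw);

  // Same flags/prop_kind style as _GetFlags/_GetFwd above.
  auto flags = mkldnn::normalization_flags::use_scale_shift;
  mkldnn::batch_normalization_forward::desc bn_d(
      mkldnn::prop_kind::forward_training, data_md, 1e-5f, flags);
  mkldnn::batch_normalization_forward::primitive_desc bn_pd(bn_d, eng);
  mkldnn::batch_normalization_forward bn(bn_pd);  // no memory bound yet

  // Memory objects are allocated from the descriptors the primitive_desc reports.
  mkldnn::memory src(bn_pd.src_desc(), eng);
  mkldnn::memory dst(bn_pd.dst_desc(), eng);
  mkldnn::memory scale_shift(bn_pd.weights_desc(), eng);
  mkldnn::memory mean(bn_pd.mean_desc(), eng);
  mkldnn::memory variance(bn_pd.variance_desc(), eng);

  // Counterpart of the net_args map built in MKLDNNBatchNormForward().
  std::unordered_map<int, mkldnn::memory> args = {
      {MKLDNN_ARG_SRC, src},
      {MKLDNN_ARG_SCALE_SHIFT, scale_shift},
      {MKLDNN_ARG_DST, dst},
      {MKLDNN_ARG_MEAN, mean},
      {MKLDNN_ARG_VARIANCE, variance}};

  // Arguments are supplied per execution, not per primitive.
  bn.execute(strm, args);
  strm.wait();
  return 0;
}

Because the primitive no longer captures memory handles when it is built, a cached MKLDNNBNForward or MKLDNNBNBackward object can be reused across calls that carry different input buffers, which is why the old SetDataHandle plumbing could be deleted in PATCH 1/4.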