From a273dc903f585ecf16431ddae2f5a567506f40ec Mon Sep 17 00:00:00 2001
From: JackieWu
Date: Thu, 9 Jul 2020 08:01:35 +0800
Subject: [PATCH] [Improvement] Invoke mkldnn and cudnn BatchNorm when axis != 1 (#18504)

* fix batch norm when fix_gamma is True
* support gradient accumulation for batch norm
* mkldnn batchnorm support grad add
* unittest for bn
* fix bn arg
* fix lint
* fix mkldnn
* fix mkldnn bn
* fix grad when fixing gamma
* fix naive gpu bn
* fix lint
* invoke mkldnn and cudnn batchnorm when axis != 1
* backport 18500
* change condition
* fix
* fix
* add mkldnn_off for bn
* remove mkldnn_off
* recover save_000800.json
* cast
---
 src/operator/nn/batch_norm.cc                | 12 +++--
 src/operator/nn/batch_norm.cu                |  6 +--
 src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 26 ++++++++---
 .../nn/mkldnn/mkldnn_batch_norm-inl.h        | 44 ++++++++++++++++---
 4 files changed, 68 insertions(+), 20 deletions(-)

diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index b8961df1b782..42cb6c22a5ea 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -422,10 +422,14 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
 
 #if MXNET_USE_MKLDNN == 1
 static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
-  mxnet::TShape shape = input.shape();
-  return SupportMKLDNN(input) && shape.ndim() == 4
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
-      && !mxnet::op::batchnorm::disable_mkl;
+  if (mxnet::op::batchnorm::disable_mkl) return false;
+  const mxnet::TShape shape = input.shape();
+  const int ndim = shape.ndim();
+  if (ndim == 0 || shape.Size() == 0) return false;
+  const int dtype = input.dtype();
+  return (dtype == mshadow::kFloat32 ||
+          dtype == mshadow::kBfloat16) &&
+         SupportStorageMKLDNN(input.storage_type());
 }
 
 void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index 7b36d25e7496..40d677a8f16d 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -704,8 +704,7 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
 
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
+  if (!param.use_global_stats && !param.cudnn_off) {
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
       GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states);
     })
@@ -733,8 +732,7 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
 
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
+  if (!param.use_global_stats && !param.cudnn_off) {
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
       GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs);
     })
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
index fc91212fab37..797234c58f62 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -260,15 +260,27 @@ class CuDNNBatchNormOp {
 
  private:
   void Init(const TBlob &in_data) {
-    if (in_data.ndim() == 4) {
-      for (int i = 0; i < 4; ++i)
-        shape_[i] = in_data.shape_[i];
+    CHECK_GE(param_.axis, 0);
+    CHECK_LT(param_.axis, in_data.ndim());
+    if (param_.axis == 1) {
+      if (in_data.ndim() == 4) {
+        for (int i = 0; i < 4; ++i)
+          shape_[i] = in_data.shape_[i];
+      } else {
+        // when in_data.ndim() != 4
+        shape_[0] = in_data.shape_[0];
+        shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+        shape_[2] = 1;
+        shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(2,
+                                                                in_data.ndim()));
+      }
     } else {
-      // when in_data.ndim() != 4
-      shape_[0] = in_data.shape_[0];
-      shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+      // reshape to (N, C, 1, D), C is the `param_.axis` dimension
+      shape_[0] = static_cast<dim_t>(in_data.shape_.ProdShape(0, param_.axis));
+      shape_[1] = in_data.shape_[param_.axis];
       shape_[2] = 1;
-      shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim());
+      shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(param_.axis + 1,
+                                                              in_data.ndim()));
     }
 
     CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_,
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index 2021ba02c144..18055caaaeba 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -151,7 +151,25 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                             const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
                             const std::vector<NDArray> &outputs) {
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-  const std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
+  std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
+
+  mxnet::TShape shape = inputs[batchnorm::kData].shape();
+  const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
+  CHECK_LT(real_axis, shape.ndim());
+  NDArray out = outputs[batchnorm::kOut];
+  if (param.axis != 1 || shape.ndim() != 4) {
+    // reshape to (N, C, 1, D)
+    mxnet::TShape new_shape{
+      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      shape[real_axis],
+      1,
+      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+                                         static_cast<int>(shape.ndim())))
+    };
+    in_data[batchnorm::kData] = in_data[batchnorm::kData].Reshape(new_shape);
+    out = out.Reshape(new_shape);
+  }
+
   const std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
   mkldnn::normalization_flags flags = _GetFlags(in_data,
@@ -160,7 +178,6 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                                                 ctx.is_train && !param.use_global_stats);
   const NDArray &data = in_data[batchnorm::kData];
   auto &fwd = GetBNForward(param, ctx, data, flags);
-  const NDArray &out = outputs[batchnorm::kOut];
 
   // for output memory
   auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
@@ -304,9 +321,9 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                                                 param, ctx.is_train && !param.use_global_stats);
-  const NDArray &data = in_data[batchnorm::kData];
-  const NDArray &diff = out_grad[batchnorm::kOut];
-  const NDArray &gradIn = in_grad[batchnorm::kData];
+  NDArray data = in_data[batchnorm::kData];
+  NDArray diff = out_grad[batchnorm::kOut];
+  NDArray gradIn = in_grad[batchnorm::kData];
   const NDArray &moving_mean = aux_states[batchnorm::kMovingMean];
   const NDArray &moving_var = aux_states[batchnorm::kMovingVar];
   const NDArray &out_mean = out_data[batchnorm::kMean];
@@ -317,6 +334,23 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   CHECK(moving_mean.IsDefaultData());
   CHECK(moving_var.IsDefaultData());
 
+  mxnet::TShape shape = data.shape();
+  const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
+  CHECK_LT(real_axis, shape.ndim());
+  if (param.axis != 1 || shape.ndim() != 4) {
+    // reshape to (N, C, 1, D)
+    mxnet::TShape new_shape{
+      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      shape[real_axis],
+      1,
+      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+                                         static_cast<int>(shape.ndim())))
+    };
+    data = data.Reshape(new_shape);
+    diff = diff.Reshape(new_shape);
+    gradIn = gradIn.Reshape(new_shape);
+  }
+
   auto data_mem = data.GetMKLDNNData();
   auto diff_mem = diff.GetMKLDNNData();
   // MKLDNN batchnorm should run on special layouts. If one of them isn't, we
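
The reshape used by both backends collapses every dimension in front of the normalized
axis into N and every dimension behind it into D, so the existing 4-D (N, C, H, W)
kernels can be reused for any axis while still keeping one statistic per slice of that
axis. A minimal standalone sketch of the index arithmetic follows; it uses plain
std::vector instead of mxnet::TShape, and prod_shape is only a stand-in for
TShape::ProdShape, so the names here are illustrative rather than MXNet API.

#include <cstdint>
#include <iostream>
#include <vector>

using dim_t = std::int64_t;  // stand-in for MXNet's dim_t

// Stand-in for mxnet::TShape::ProdShape: product of dims in [begin, end).
dim_t prod_shape(const std::vector<dim_t> &shape, int begin, int end) {
  dim_t p = 1;
  for (int i = begin; i < end; ++i) p *= shape[i];
  return p;
}

// Collapse dims before `axis` into N and dims after it into D, matching the
// (N, C, 1, D) layout the patch feeds to the 4-D cuDNN/MKLDNN kernels.
std::vector<dim_t> to_nc1d(const std::vector<dim_t> &shape, int axis) {
  return {prod_shape(shape, 0, axis),                                    // N
          shape[axis],                                                   // C
          1,
          prod_shape(shape, axis + 1, static_cast<int>(shape.size()))};  // D
}

int main() {
  // NHWC input (2, 5, 7, 3) normalized over axis 3 becomes (70, 3, 1, 1).
  for (dim_t d : to_nc1d({2, 5, 7, 3}, 3)) std::cout << d << ' ';
  std::cout << '\n';  // prints: 70 3 1 1
  return 0;
}

Because a row-major contiguous tensor keeps each element in the same channel slice
after this collapse, the per-axis mean and variance computed on the reshaped tensor
match what an axis-aware kernel would produce.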