diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index d4b03ae3fc17..1c55f3602164 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -420,10 +420,14 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
 
 #if MXNET_USE_MKLDNN == 1
 static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
-  mxnet::TShape shape = input.shape();
-  return SupportMKLDNN(input) && shape.ndim() == 4
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
-      && !mxnet::op::batchnorm::disable_mkl;
+  if (mxnet::op::batchnorm::disable_mkl) return false;
+  const mxnet::TShape shape = input.shape();
+  const int ndim = shape.ndim();
+  if (ndim == 0 || shape.Size() == 0) return false;
+  const int dtype = input.dtype();
+  return (dtype == mshadow::kFloat32 ||
+          dtype == mshadow::kBfloat16) &&
+         SupportStorageMKLDNN(input.storage_type());
 }
 
 void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index 0875f05e669d..c7e991f98d18 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -698,8 +698,7 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
 
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
+  if (!param.use_global_stats && !param.cudnn_off) {
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
       GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states);
     })
@@ -727,8 +726,7 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
 
   param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
 #if MXNET_USE_CUDNN == 1
-  if (!param.use_global_stats && !param.cudnn_off
-      && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
+  if (!param.use_global_stats && !param.cudnn_off) {
     MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
       GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs);
     })
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
index 13db44d518b3..340c2f3494f2 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -262,15 +262,27 @@ class CuDNNBatchNormOp {
 
  private:
   void Init(const TBlob &in_data) {
-    if (in_data.ndim() == 4) {
-      for (int i = 0; i < 4; ++i)
-        shape_[i] = in_data.shape_[i];
+    CHECK_GE(param_.axis, 0);
+    CHECK_LT(param_.axis, in_data.ndim());
+    if (param_.axis == 1) {
+      if (in_data.ndim() == 4) {
+        for (int i = 0; i < 4; ++i)
+          shape_[i] = in_data.shape_[i];
+      } else {
+        // when in_data.ndim() != 4
+        shape_[0] = in_data.shape_[0];
+        shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+        shape_[2] = 1;
+        shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(2,
+                                                                in_data.ndim()));
+      }
     } else {
-      // when in_data.ndim() != 4
-      shape_[0] = in_data.shape_[0];
-      shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
+      // reshape to (N, C, 1, D), C is the `param_.axis` dimension
+      shape_[0] = static_cast<dim_t>(in_data.shape_.ProdShape(0, param_.axis));
+      shape_[1] = in_data.shape_[param_.axis];
       shape_[2] = 1;
-      shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim());
+      shape_[3] = static_cast<dim_t>(in_data.shape_.ProdShape(param_.axis + 1,
+                                                              in_data.ndim()));
     }
 
     CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_,
diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
index da4fd97e82da..0a29a6d87de6 100644
--- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -157,7 +157,25 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                             const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req,
                             const std::vector<NDArray> &outputs, bool fuse_relu) {
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-  const std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
+  std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
+
+  mxnet::TShape shape = inputs[batchnorm::kData].shape();
+  const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
+  CHECK_LT(real_axis, shape.ndim());
+  NDArray out = outputs[batchnorm::kOut];
+  if (param.axis != 1 || shape.ndim() != 4) {
+    // reshape to (N, C, 1, D)
+    mxnet::TShape new_shape{
+      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      shape[real_axis],
+      1,
+      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+                                         static_cast<int>(shape.ndim())))
+    };
+    in_data[batchnorm::kData] = in_data[batchnorm::kData].Reshape(new_shape);
+    out = out.Reshape(new_shape);
+  }
+
   const std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
   mkldnn::normalization_flags flags = _GetFlags(in_data,
@@ -166,7 +184,6 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                                                 fuse_relu);
   const NDArray &data = in_data[batchnorm::kData];
   auto &fwd = GetBNForward(param, ctx, data, flags);
-  const NDArray &out = outputs[batchnorm::kOut];
 
   // for output memory
   auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
@@ -325,9 +342,9 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                                                 ctx.is_train && !param.use_global_stats,
                                                 fuse_relu);
 
-  const NDArray &data = in_data[batchnorm::kData];
-  const NDArray &diff = out_grad[batchnorm::kOut];
-  const NDArray &gradIn = in_grad[batchnorm::kData];
+  NDArray data = in_data[batchnorm::kData];
+  NDArray diff = out_grad[batchnorm::kOut];
+  NDArray gradIn = in_grad[batchnorm::kData];
   const NDArray &moving_mean = aux_states[batchnorm::kMovingMean];
   const NDArray &moving_var = aux_states[batchnorm::kMovingVar];
   const NDArray &out_mean = out_data[batchnorm::kMean];
@@ -338,6 +355,23 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
   CHECK(moving_mean.IsDefaultData());
   CHECK(moving_var.IsDefaultData());
 
+  mxnet::TShape shape = data.shape();
+  const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
+  CHECK_LT(real_axis, shape.ndim());
+  if (param.axis != 1 || shape.ndim() != 4) {
+    // reshape to (N, C, 1, D)
+    mxnet::TShape new_shape{
+      static_cast<dim_t>(shape.ProdShape(0, real_axis)),
+      shape[real_axis],
+      1,
+      static_cast<dim_t>(shape.ProdShape(real_axis + 1,
+                                         static_cast<int>(shape.ndim())))
+    };
+    data = data.Reshape(new_shape);
+    diff = diff.Reshape(new_shape);
+    gradIn = gradIn.Reshape(new_shape);
+  }
+
   auto data_mem = data.GetMKLDNNData();
   auto diff_mem = diff.GetMKLDNNData();
   // MKLDNN batchnorm should run on special layouts. If one of them isn't, we
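
For reference, the reshape rule shared by the cuDNN and MKLDNN paths above can be sketched in isolation: every dimension before the (real) channel axis is folded into N, the axis itself becomes C, and everything after it is folded into D, giving a 4-D (N, C, 1, D) view that the existing axis-1 kernels already handle. The snippet below is a minimal standalone illustration using plain std::vector shapes; ProdRange and ToNC1D are illustrative helper names, not MXNet APIs.

// Standalone sketch (not MXNet code): collapse an arbitrary-rank shape into the
// (N, C, 1, D) layout expected by the 4-D, axis-1 batch-norm kernels.
#include <cstdint>
#include <iostream>
#include <vector>

// Product of the dimensions in [begin, end); plays the role TShape::ProdShape plays above.
int64_t ProdRange(const std::vector<int64_t>& shape, int begin, int end) {
  int64_t prod = 1;
  for (int i = begin; i < end; ++i) prod *= shape[i];
  return prod;
}

// N = prod(dims before axis), C = dims[axis], H = 1, D = prod(dims after axis).
std::vector<int64_t> ToNC1D(std::vector<int64_t> shape, int axis) {
  const int ndim = static_cast<int>(shape.size());
  if (axis < 0) axis += ndim;  // same fixup GetRealAxis performs for negative axes
  return {ProdRange(shape, 0, axis), shape[axis], 1, ProdRange(shape, axis + 1, ndim)};
}

int main() {
  // Channels-last NHWC input with axis = -1: (2, 8, 8, 32) -> (128, 32, 1, 1).
  for (int64_t d : ToNC1D({2, 8, 8, 32}, -1)) std::cout << d << ' ';
  std::cout << '\n';
  // 3-D input with the default axis = 1: (2, 16, 100) -> (2, 16, 1, 100).
  for (int64_t d : ToNC1D({2, 16, 100}, 1)) std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}

Because the per-channel mean and variance are computed over all non-channel elements either way, this view only changes the layout, not the statistics, so the forward and backward results are unaffected.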