diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 3e36559c0a7c..a59f8ba21705 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -422,10 +422,14 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs, #if MXNET_USE_MKLDNN == 1 static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam ¶m) { - mxnet::TShape shape = input.shape(); - return SupportMKLDNN(input) && shape.ndim() == 4 - && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS - && !mxnet::op::batchnorm::disable_mkl; + if (mxnet::op::batchnorm::disable_mkl) return false; + const mxnet::TShape shape = input.shape(); + const int ndim = shape.ndim(); + if (ndim == 0 || shape.Size() == 0) return false; + const int dtype = input.dtype(); + return (dtype == mshadow::kFloat32 || + dtype == mshadow::kBfloat16) && + SupportStorageMKLDNN(input.storage_type()); } void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs, diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 7b36d25e7496..40d677a8f16d 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -704,8 +704,7 @@ void BatchNormCompute(const nnvm::NodeAttrs& attrs, param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis); #if MXNET_USE_CUDNN == 1 - if (!param.use_global_stats && !param.cudnn_off - && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) { + if (!param.use_global_stats && !param.cudnn_off) { MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { GetCuDNNOp(param).Forward(ctx, in_data, req, outputs, aux_states); }) @@ -733,8 +732,7 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis); #if MXNET_USE_CUDNN == 1 - if (!param.use_global_stats && !param.cudnn_off - && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) { + if (!param.use_global_stats && !param.cudnn_off) { MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { GetCuDNNOp(param).Backward(ctx, inputs, req, outputs); }) diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index fc91212fab37..797234c58f62 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -260,15 +260,27 @@ class CuDNNBatchNormOp { private: void Init(const TBlob &in_data) { - if (in_data.ndim() == 4) { - for (int i = 0; i < 4; ++i) - shape_[i] = in_data.shape_[i]; + CHECK_GE(param_.axis, 0); + CHECK_LT(param_.axis, in_data.ndim()); + if (param_.axis == 1) { + if (in_data.ndim() == 4) { + for (int i = 0; i < 4; ++i) + shape_[i] = in_data.shape_[i]; + } else { + // when in_data.ndim() != 4 + shape_[0] = in_data.shape_[0]; + shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1; + shape_[2] = 1; + shape_[3] = static_cast(in_data.shape_.ProdShape(2, + in_data.ndim())); + } } else { - // when in_data.ndim() != 4 - shape_[0] = in_data.shape_[0]; - shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1; + // reshape to (N, C, 1, D), C is the `param_.axis` dimension + shape_[0] = static_cast(in_data.shape_.ProdShape(0, param_.axis)); + shape_[1] = in_data.shape_[param_.axis]; shape_[2] = 1; - shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim()); + shape_[3] = static_cast(in_data.shape_.ProdShape(param_.axis + 1, + in_data.ndim())); } CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_, diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index da4fd97e82da..0a29a6d87de6 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -157,7 +157,25 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs, bool fuse_relu) { const BatchNormParam ¶m = nnvm::get(attrs.parsed); - const std::vector in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean); + std::vector in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean); + + mxnet::TShape shape = inputs[batchnorm::kData].shape(); + const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis); + CHECK_LT(real_axis, shape.ndim()); + NDArray out = outputs[batchnorm::kOut]; + if (param.axis != 1 || shape.ndim() != 4) { + // reshape to (N, C, 1, D) + mxnet::TShape new_shape{ + static_cast(shape.ProdShape(0, real_axis)), + shape[real_axis], + 1, + static_cast(shape.ProdShape(real_axis + 1, + static_cast(shape.ndim()))) + }; + in_data[batchnorm::kData] = in_data[batchnorm::kData].Reshape(new_shape); + out = out.Reshape(new_shape); + } + const std::vector aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end()); TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]); mkldnn::normalization_flags flags = _GetFlags(in_data, @@ -166,7 +184,6 @@ void MKLDNNBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, fuse_relu); const NDArray &data = in_data[batchnorm::kData]; auto &fwd = GetBNForward(param, ctx, data, flags); - const NDArray &out = outputs[batchnorm::kOut]; // for output memory auto out_mem = const_cast(out).CreateMKLDNNData(fwd.GetPd().dst_desc()); @@ -325,9 +342,9 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, ctx.is_train && !param.use_global_stats, fuse_relu); - const NDArray &data = in_data[batchnorm::kData]; - const NDArray &diff = out_grad[batchnorm::kOut]; - const NDArray &gradIn = in_grad[batchnorm::kData]; + NDArray data = in_data[batchnorm::kData]; + NDArray diff = out_grad[batchnorm::kOut]; + NDArray gradIn = in_grad[batchnorm::kData]; const NDArray &moving_mean = aux_states[batchnorm::kMovingMean]; const NDArray &moving_var = aux_states[batchnorm::kMovingVar]; const NDArray &out_mean = out_data[batchnorm::kMean]; @@ -338,6 +355,23 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, CHECK(moving_mean.IsDefaultData()); CHECK(moving_var.IsDefaultData()); + mxnet::TShape shape = data.shape(); + const int real_axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis); + CHECK_LT(real_axis, shape.ndim()); + if (param.axis != 1 || shape.ndim() != 4) { + // reshape to (N, C, 1, D) + mxnet::TShape new_shape{ + static_cast(shape.ProdShape(0, real_axis)), + shape[real_axis], + 1, + static_cast(shape.ProdShape(real_axis + 1, + static_cast(shape.ndim()))) + }; + data = data.Reshape(new_shape); + diff = diff.Reshape(new_shape); + gradIn = gradIn.Reshape(new_shape); + } + auto data_mem = data.GetMKLDNNData(); auto diff_mem = diff.GetMKLDNNData(); // MKLDNN batchnorm should run on special layouts. If one of them isn't, we diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 97c7d8675495..475ff0243290 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1541,7 +1541,7 @@ def _test_batchnorm_impl(shape, fix_gamma, cudnn_off, output_mean_var, assert_almost_equal( bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) - shapes = [(24, 2), (24, 3, 4), (24, 8, 4, 5), (24, 5, 6, 4, 5)] + shapes = [(4, 2), (4, 3, 4), (4, 6, 4, 5), (4, 5, 6, 4, 5)] bools = [False, True] for shape, fix_gamma, cudnn_off, output_mean_var in itertools.product( shapes, bools, bools, bools): diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index e6db0e9fc864..4e736e5dc0ab 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1964,7 +1964,7 @@ def _test_batchnorm_impl(op_name, shape, fix_gamma, cudnn_off, output_mean_var, bn_beta.grad.asnumpy(), adb.asnumpy(), atol=atol, rtol=rtol) op_names = ['BatchNorm', 'SyncBatchNorm'] - shapes = [(24, 2), (24, 3, 4), (24, 8, 4, 5), (24, 5, 6, 4, 5)] + shapes = [(4, 2), (4, 3, 4), (4, 6, 4, 5), (4, 5, 6, 4, 5)] bools = [False, True] for op_name, shape, fix_gamma, cudnn_off, output_mean_var in itertools.product( op_names, shapes, bools, bools, bools): diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index 8e4fe11905cf..793b920bc1f8 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -272,36 +272,6 @@ def check_symbol_consistency(sym1, sym2, ctx, skip_grad=False, equal_nan=False): grad_req='null' if skip_grad else 'write', equal_nan=equal_nan) -def test_load_000800(): - with mx.AttrScope(ctx_group='stage1'): - data = mx.symbol.Variable('data', lr_mult=0.2) - weight = mx.sym.Variable(name='fc1_weight', lr_mult=1.2) - fc1 = mx.symbol.FullyConnected(data = data, weight=weight, name='fc1', num_hidden=128, wd_mult=0.3) - act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu") - - set_stage1 = set(act1.list_arguments()) - with mx.AttrScope(ctx_group='stage2'): - fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64, lr_mult=0.01) - act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu") - fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10) - fc3 = mx.symbol.BatchNorm(fc3, name='batchnorm0') - sym1 = mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax') - - curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - sym2 = mx.sym.load(os.path.join(curr_path, 'save_000800.json')) - - attr1 = sym1.attr_dict() - attr2 = sym2.attr_dict() - for k, v1 in attr1.items(): - assert k in attr2, k - v2 = attr2[k] - for kk, vv1 in v1.items(): - if kk.startswith('__') and kk.endswith('__'): - assert kk in v2 and v2[kk] == vv1, k + str(v1) + str(v2) - - check_symbol_consistency(sym1, sym2, - {'ctx': mx.cpu(0), 'group2ctx': {'stage1' : mx.cpu(1), 'stage2' : mx.cpu(2)}, 'data': (1,200)}) - def test_blockgrad(): a = mx.sym.Variable('a')