From 3bb434f0b30f2e34582799dea8f08d60f8729a8a Mon Sep 17 00:00:00 2001 From: wuxun-zhang Date: Wed, 25 Mar 2020 15:28:01 +0800 Subject: [PATCH] fix UT & address comments --- src/operator/nn/mkldnn/mkldnn_base-inl.h | 11 +++++++--- src/operator/nn/mkldnn/mkldnn_base.cc | 9 ++++++-- src/operator/nn/mkldnn/mkldnn_pooling-inl.h | 16 +++++++++----- src/operator/nn/mkldnn/mkldnn_pooling.cc | 22 +++++++++---------- src/operator/nn/pooling.cc | 10 +++++---- .../python/quantization/test_quantization.py | 8 +++++-- 6 files changed, 49 insertions(+), 27 deletions(-) diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index d5060925e5b4..65a0a6918558 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -154,7 +154,7 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) { return false; } return (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16) && - (ndim >= 1 && ndim <= 5); + (ndim == 1 || ndim == 2 || ndim == 4); } static inline bool SupportMKLDNNQuantize(int dtype) { @@ -327,8 +327,13 @@ inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr, CHECK((ndim == 3) || (ndim == 4) || (ndim == 5)) << "MKL-DNN weight currently supports 3d or 4d or 5d layout"; auto tz = mkldnn::memory::dims{0}; - const int D = (ndim == 5) ? 2 : 1; - const int N = 0, C = 1, H = D + 1, W = D + 2; + int N = 0, C = 1, H = 2, W = 3; + int D = -1; + if (ndim == 5) { + D = 2; + H = 3; + W = 4; + } switch (ndim) { case 3: tz = mkldnn::memory::dims{ diff --git a/src/operator/nn/mkldnn/mkldnn_base.cc b/src/operator/nn/mkldnn/mkldnn_base.cc index d790d73896b6..7aeb21b494ea 100644 --- a/src/operator/nn/mkldnn/mkldnn_base.cc +++ b/src/operator/nn/mkldnn/mkldnn_base.cc @@ -241,8 +241,13 @@ const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) { auto format_tag = mkldnn::memory::format_tag::undef; auto engine = CpuEngine::Get()->get_engine(); const int ndim = arr.shape().ndim(); - const int D = (ndim == 5) ? 2 : 1; - const int O = 0, I = 1, H = D + 1, W = D + 2; + int O = 0, I = 1, H = 2, W = 3; + int D = -1; + if (ndim == 5) { + D = 2; + H = 3; + W = 4; + } if (ndim == 2) { tz = mkldnn::memory::dims{arr.shape()[O], arr.shape()[I]}; format_tag = mkldnn::memory::format_tag::oi; diff --git a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h index d23ce051a695..ae1e23ed4363 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_pooling-inl.h @@ -106,19 +106,25 @@ inline bool SupportMKLDNNPooling(const PoolingParam ¶m) { } inline bool SupportMKLDNNPooling(const PoolingParam ¶m, - const mxnet::TShape &dshape) { - bool ret = SupportMKLDNNPooling(param); - if (!ret) + const NDArray &input) { + const auto dshape = input.shape(); + const auto ndim = dshape.ndim(); + const auto dtype = input.dtype(); + + if (!(SupportStorageMKLDNN(input.storage_type()) && (ndim == 3 || ndim == 4 || ndim == 5) && + (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16))) + return false; + + if (!SupportMKLDNNPooling(param)) return false; if (param.pooling_convention == pool_enum::kValid) { return true; } else { if (param.pool_type == pool_enum::kAvgPooling) { - CHECK(dshape.ndim() == 3 || dshape.ndim() == 4 || dshape.ndim() == 5); // mkldnn works differently when padding is asymmetric, so let's skip this case. bool is_symmetric = true; - switch (dshape.ndim()) { + switch (ndim) { case 5: is_symmetric = is_symmetric && (param.pad[2] == GetPaddingSizeFull(dshape[4], param.pad[2], param.pad[2], param.kernel[2], param.stride[2])); diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc index a0d212328c98..bb1a75eb3e5f 100644 --- a/src/operator/nn/mkldnn/mkldnn_pooling.cc +++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc @@ -118,15 +118,15 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam ¶m) { void InitPoolingPrimitiveParams(const PoolingParam ¶m, const mkldnn::memory::desc &data_md, - mkldnn::memory::dims *new_kernel, - mkldnn::memory::dims *new_strides, - mkldnn::memory::dims *new_pad_l, - mkldnn::memory::dims *new_pad_r) { + const mkldnn::memory::dims &new_kernel, + const mkldnn::memory::dims &new_strides, + const mkldnn::memory::dims &new_pad_l, + const mkldnn::memory::dims &new_pad_r) { const int kernel_ndims = param.kernel.ndim(); - mkldnn::memory::dims& kernel = *new_kernel; - mkldnn::memory::dims& strides = *new_strides; - mkldnn::memory::dims& pad_l = *new_pad_l; - mkldnn::memory::dims& pad_r = *new_pad_r; + mkldnn::memory::dims& kernel = const_cast(new_kernel); + mkldnn::memory::dims& strides = const_cast(new_strides); + mkldnn::memory::dims& pad_l = const_cast(new_pad_l); + mkldnn::memory::dims& pad_r = const_cast(new_pad_r); if (kernel_ndims == 1) { CHECK_GE(param.pad.ndim(), 1); CHECK_GE(param.stride.ndim(), 1); @@ -238,7 +238,7 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc( mkldnn::memory::dims pad_l(kernel_ndims); mkldnn::memory::dims pad_r(kernel_ndims); - InitPoolingPrimitiveParams(param, data_md, &kernel, &strides, &pad_l, &pad_r); + InitPoolingPrimitiveParams(param, data_md, kernel, strides, pad_l, pad_r); const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param); mkldnn::prop_kind kind = mkldnn::prop_kind::forward_scoring; @@ -283,7 +283,7 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam ¶m, mkldnn::memory::dims strides(kernel_ndims); mkldnn::memory::dims pad_l(kernel_ndims); mkldnn::memory::dims pad_r(kernel_ndims); - InitPoolingPrimitiveParams(param, data_md, &kernel, &strides, &pad_l, &pad_r); + InitPoolingPrimitiveParams(param, data_md, kernel, strides, pad_l, pad_r); const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param); MKLDNNPoolingFwd fwd(data, output, kernel, strides, @@ -353,7 +353,7 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam ¶m, mkldnn::memory::dims pad_l(kernel_ndims); mkldnn::memory::dims pad_r(kernel_ndims); - InitPoolingPrimitiveParams(param, data_md, &kernel, &strides, &pad_l, &pad_r); + InitPoolingPrimitiveParams(param, data_md, kernel, strides, pad_l, pad_r); const mkldnn::pooling_backward::desc desc( alg, diff_in_md, diff_md, strides, kernel, pad_l, pad_r); diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index 75c410270591..a2e48eb783ef 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -274,12 +274,12 @@ void PoolingComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx, // Pooling does not currently support working with views if (inputs[0].IsView() || outputs[0].IsView()) { + std::cout << "Fall back to Pooling backward pass..." << std::endl; FallBackCompute(PoolingCompute, attrs, ctx, inputs, req, outputs); return; } - if (SupportMKLDNN(inputs[0]) - && SupportMKLDNNPooling(param, inputs[0].shape())) { + if (SupportMKLDNNPooling(param, inputs[0])) { if (MKLDNNRequireWorkspace(param)) { CHECK_GT(outputs.size(), 1U); workspace = &outputs[1]; @@ -289,6 +289,7 @@ void PoolingComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx, MKLDNN_OPCHECK_RUN(PoolingCompute, attrs, ctx, inputs, req, outputs); return; } + std::cout << "Fall back to Pooling forward pass..." << std::endl; FallBackCompute(PoolingCompute, attrs, ctx, inputs, req, outputs); } @@ -300,13 +301,13 @@ void PoolingGradComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx, // Pooling does not currently support working with views if (inputs[0].IsView() || outputs[0].IsView()) { + std::cout << "Fall back to Pooling backward pass..." << std::endl; FallBackCompute(PoolingGradCompute, attrs, ctx, inputs, req, outputs); return; } - if (SupportMKLDNN(inputs[0]) - && SupportMKLDNNPooling(param, inputs[0].shape())) { + if (SupportMKLDNNPooling(param, inputs[0])) { const NDArray &out_grad = inputs[0]; const NDArray *workspace = nullptr; const NDArray *in_data = nullptr; @@ -329,6 +330,7 @@ void PoolingGradComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx, outputs); return; } + std::cout << "Fall back to Pooling backward pass..." << std::endl; FallBackCompute(PoolingGradCompute, attrs, ctx, inputs, req, outputs); } diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index d3a69c87e126..8c6100d50765 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -207,6 +207,9 @@ def check_quantized_conv(data_shape, kernel, num_filter, pad, stride, dilate, no elif qdtype == 'uint8' and is_test_for_gpu(): print('skipped testing quantized_conv for gpu uint8 since it is not supported yet') return + elif is_test_for_gpu() and len(data_shape) != 4: + print('skipped testing quantized_conv for gpu 5d layout since it is not supported yet') + return # run fp32 conv data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32') @@ -276,8 +279,6 @@ def check_quantized_conv(data_shape, kernel, num_filter, pad, stride, dilate, no for qdtype in ['int8', 'uint8']: check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), (1, 1), True, qdtype) check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), (1, 1), False, qdtype) - check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), (2, 2), True, qdtype) - check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), (2, 2), False, qdtype) check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (1, 1, 1), False, qdtype) check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (1, 1, 1), True, qdtype) check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (2, 2, 2), False, qdtype) @@ -416,6 +417,9 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p elif qdtype == 'uint8' and is_test_for_gpu(): print('skipped testing quantized_pooling for gpu uint8 since it is not supported yet') return + elif is_test_for_gpu() and len(data_shape) != 4: + print('skipped testing quantized_pooling for gpu 5d layout since it is not supported yet') + return data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32') pooling_fp32 = mx.sym.Pooling(data=data, kernel=kernel, pad=pad, stride=stride,