
Commit

fix UT & address comments
wuxun-zhang committed Mar 25, 2020
1 parent 1e256ee commit 3bb434f
Showing 6 changed files with 49 additions and 27 deletions.
11 changes: 8 additions & 3 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -154,7 +154,7 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
return false;
}
return (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16) &&
(ndim >= 1 && ndim <= 5);
(ndim == 1 || ndim == 2 || ndim == 4);
}

static inline bool SupportMKLDNNQuantize(int dtype) {
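
Illustrative sketch (not part of the commit): the two SupportMKLDNN return lines above differ only in which input ranks they accept. A standalone snippet, assuming nothing beyond standard C++, that prints where the range form and the enumerated form disagree:

    #include <cstdio>

    static bool RankInRange(int ndim)    { return ndim >= 1 && ndim <= 5; }           // 1d..5d
    static bool RankEnumerated(int ndim) { return ndim == 1 || ndim == 2 || ndim == 4; }  // 1d, 2d, 4d only

    int main() {
      // The two predicates agree on ndim 1, 2 and 4 and differ on 3 and 5.
      for (int ndim = 1; ndim <= 5; ++ndim)
        std::printf("ndim=%d  range=%d  enumerated=%d\n",
                    ndim, RankInRange(ndim), RankEnumerated(ndim));
      return 0;
    }
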
@@ -327,8 +327,13 @@ inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr,
CHECK((ndim == 3) || (ndim == 4) || (ndim == 5))
<< "MKL-DNN weight currently supports 3d or 4d or 5d layout";
auto tz = mkldnn::memory::dims{0};
const int D = (ndim == 5) ? 2 : 1;
const int N = 0, C = 1, H = D + 1, W = D + 2;
int N = 0, C = 1, H = 2, W = 3;
int D = -1;
if (ndim == 5) {
D = 2;
H = 3;
W = 4;
}
switch (ndim) {
case 3:
tz = mkldnn::memory::dims{
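
Illustrative sketch (not from the patch): both index computations shown in the GetWeightDesc hunk resolve to the same N/C/H/W positions, and for 5-d weights a depth axis is inserted at index 2, pushing H and W to 3 and 4. A minimal standalone mapping with hypothetical names (DimIndices, WeightDims); the 3-d (oiw) case is handled by its own switch branch in the real code:

    #include <cstdio>

    struct DimIndices { int n, c, d, h, w; };  // -1 marks a dimension that is not present

    // 4-d weights keep N,C,H,W at 0..3; 5-d weights insert a depth axis after C.
    static DimIndices WeightDims(int ndim) {
      DimIndices idx{0, 1, -1, 2, 3};
      if (ndim == 5) {
        idx.d = 2;
        idx.h = 3;
        idx.w = 4;
      }
      return idx;
    }

    int main() {
      const int ranks[] = {4, 5};
      for (int ndim : ranks) {
        DimIndices i = WeightDims(ndim);
        std::printf("ndim=%d: N=%d C=%d D=%d H=%d W=%d\n", ndim, i.n, i.c, i.d, i.h, i.w);
      }
      return 0;
    }

The same mapping appears again in GetWeights in mkldnn_base.cc below, with O/I in place of N/C.
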
9 changes: 7 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_base.cc
@@ -241,8 +241,13 @@ const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) {
auto format_tag = mkldnn::memory::format_tag::undef;
auto engine = CpuEngine::Get()->get_engine();
const int ndim = arr.shape().ndim();
const int D = (ndim == 5) ? 2 : 1;
const int O = 0, I = 1, H = D + 1, W = D + 2;
int O = 0, I = 1, H = 2, W = 3;
int D = -1;
if (ndim == 5) {
D = 2;
H = 3;
W = 4;
}
if (ndim == 2) {
tz = mkldnn::memory::dims{arr.shape()[O], arr.shape()[I]};
format_tag = mkldnn::memory::format_tag::oi;
16 changes: 11 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_pooling-inl.h
@@ -106,19 +106,25 @@ inline bool SupportMKLDNNPooling(const PoolingParam &param) {
}

inline bool SupportMKLDNNPooling(const PoolingParam &param,
const mxnet::TShape &dshape) {
bool ret = SupportMKLDNNPooling(param);
if (!ret)
const NDArray &input) {
const auto dshape = input.shape();
const auto ndim = dshape.ndim();
const auto dtype = input.dtype();

if (!(SupportStorageMKLDNN(input.storage_type()) && (ndim == 3 || ndim == 4 || ndim == 5) &&
(dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16)))
return false;

if (!SupportMKLDNNPooling(param))
return false;

if (param.pooling_convention == pool_enum::kValid) {
return true;
} else {
if (param.pool_type == pool_enum::kAvgPooling) {
CHECK(dshape.ndim() == 3 || dshape.ndim() == 4 || dshape.ndim() == 5);
// mkldnn works differently when padding is asymmetric, so let's skip this case.
bool is_symmetric = true;
switch (dshape.ndim()) {
switch (ndim) {
case 5:
is_symmetric = is_symmetric && (param.pad[2] == GetPaddingSizeFull(dshape[4],
param.pad[2], param.pad[2], param.kernel[2], param.stride[2]));
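
Sketch of the idea behind the symmetry check above (an assumption, not MXNet code): for pooling_convention='full' the output length is ceil-divided, so the padding actually needed on the right edge can exceed the configured pad, and MKL-DNN is skipped when the two no longer match. PaddingSizeFullSketch below is a hypothetical stand-in for GetPaddingSizeFull, and the ceiling-based output formula is an assumption about the 'full' convention:

    #include <cmath>
    #include <cstdio>

    // Right padding required so the last (ceil-divided) pooling window still fits.
    static int PaddingSizeFullSketch(int dim, int pad_l, int pad_r, int kernel, int stride) {
      int out = static_cast<int>(std::ceil(
          static_cast<double>(dim + pad_l + pad_r - kernel) / stride)) + 1;
      return (out - 1) * stride + kernel - dim - pad_l;
    }

    int main() {
      // Symmetric case: the configured pad already equals the required right pad.
      int dim = 28, pad = 1, kernel = 3, stride = 1;
      bool symmetric = (pad == PaddingSizeFullSketch(dim, pad, pad, kernel, stride));
      std::printf("symmetric=%d\n", symmetric);
      return 0;
    }
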
22 changes: 11 additions & 11 deletions src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -118,15 +118,15 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam &param) {

void InitPoolingPrimitiveParams(const PoolingParam &param,
const mkldnn::memory::desc &data_md,
mkldnn::memory::dims *new_kernel,
mkldnn::memory::dims *new_strides,
mkldnn::memory::dims *new_pad_l,
mkldnn::memory::dims *new_pad_r) {
const mkldnn::memory::dims &new_kernel,
const mkldnn::memory::dims &new_strides,
const mkldnn::memory::dims &new_pad_l,
const mkldnn::memory::dims &new_pad_r) {
const int kernel_ndims = param.kernel.ndim();
mkldnn::memory::dims& kernel = *new_kernel;
mkldnn::memory::dims& strides = *new_strides;
mkldnn::memory::dims& pad_l = *new_pad_l;
mkldnn::memory::dims& pad_r = *new_pad_r;
mkldnn::memory::dims& kernel = const_cast<mkldnn::memory::dims&>(new_kernel);
mkldnn::memory::dims& strides = const_cast<mkldnn::memory::dims&>(new_strides);
mkldnn::memory::dims& pad_l = const_cast<mkldnn::memory::dims&>(new_pad_l);
mkldnn::memory::dims& pad_r = const_cast<mkldnn::memory::dims&>(new_pad_r);
if (kernel_ndims == 1) {
CHECK_GE(param.pad.ndim(), 1);
CHECK_GE(param.stride.ndim(), 1);
@@ -238,7 +238,7 @@ mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
mkldnn::memory::dims pad_l(kernel_ndims);
mkldnn::memory::dims pad_r(kernel_ndims);

InitPoolingPrimitiveParams(param, data_md, &kernel, &strides, &pad_l, &pad_r);
InitPoolingPrimitiveParams(param, data_md, kernel, strides, pad_l, pad_r);

const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
mkldnn::prop_kind kind = mkldnn::prop_kind::forward_scoring;
@@ -283,7 +283,7 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
mkldnn::memory::dims strides(kernel_ndims);
mkldnn::memory::dims pad_l(kernel_ndims);
mkldnn::memory::dims pad_r(kernel_ndims);
InitPoolingPrimitiveParams(param, data_md, &kernel, &strides, &pad_l, &pad_r);
InitPoolingPrimitiveParams(param, data_md, kernel, strides, pad_l, pad_r);

const mkldnn::algorithm alg = GetMKLDNNPoolAlgo(param);
MKLDNNPoolingFwd fwd(data, output, kernel, strides,
@@ -353,7 +353,7 @@ MKLDNNPoolingBwd &GetPoolingBwd(const PoolingParam &param,
mkldnn::memory::dims pad_l(kernel_ndims);
mkldnn::memory::dims pad_r(kernel_ndims);

InitPoolingPrimitiveParams(param, data_md, &kernel, &strides, &pad_l, &pad_r);
InitPoolingPrimitiveParams(param, data_md, kernel, strides, pad_l, pad_r);

const mkldnn::pooling_backward::desc desc(
alg, diff_in_md, diff_md, strides, kernel, pad_l, pad_r);
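
Sketch only: one variant of InitPoolingPrimitiveParams above takes the output dims as const references and writes them through const_cast, so the call sites pass the dims objects directly instead of their addresses. A self-contained illustration of that pattern with placeholder names (FillDims, dims); writing through the cast is well-defined here because the caller's objects are not actually const:

    #include <cassert>
    #include <vector>

    using dims = std::vector<long>;

    // Out-parameters taken as const references, mirroring the shape of the patched signature.
    static void FillDims(const dims &new_kernel, const dims &new_strides) {
      dims &kernel = const_cast<dims &>(new_kernel);
      dims &strides = const_cast<dims &>(new_strides);
      kernel = {2, 2, 2};    // visible to the caller after the call
      strides = {1, 1, 1};
    }

    int main() {
      dims kernel, strides;
      FillDims(kernel, strides);   // caller sees the filled values
      assert(kernel.size() == 3 && strides.size() == 3);
      return 0;
    }

Plain non-const reference parameters would express the same in/out intent without the cast; the sketch only mirrors the signature used here.
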
10 changes: 6 additions & 4 deletions src/operator/nn/pooling.cc
@@ -274,12 +274,12 @@ void PoolingComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx,

// Pooling does not currently support working with views
if (inputs[0].IsView() || outputs[0].IsView()) {
std::cout << "Fall back to Pooling backward pass..." << std::endl;
FallBackCompute(PoolingCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}

if (SupportMKLDNN(inputs[0])
&& SupportMKLDNNPooling(param, inputs[0].shape())) {
if (SupportMKLDNNPooling(param, inputs[0])) {
if (MKLDNNRequireWorkspace(param)) {
CHECK_GT(outputs.size(), 1U);
workspace = &outputs[1];
@@ -289,6 +289,7 @@ void PoolingComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
MKLDNN_OPCHECK_RUN(PoolingCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
std::cout << "Fall back to Pooling forward pass..." << std::endl;
FallBackCompute(PoolingCompute<cpu>, attrs, ctx, inputs, req, outputs);
}

@@ -300,13 +301,13 @@ void PoolingGradComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx,

// Pooling does not currently support working with views
if (inputs[0].IsView() || outputs[0].IsView()) {
std::cout << "Fall back to Pooling backward pass..." << std::endl;
FallBackCompute(PoolingGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}


if (SupportMKLDNN(inputs[0])
&& SupportMKLDNNPooling(param, inputs[0].shape())) {
if (SupportMKLDNNPooling(param, inputs[0])) {
const NDArray &out_grad = inputs[0];
const NDArray *workspace = nullptr;
const NDArray *in_data = nullptr;
@@ -329,6 +330,7 @@ void PoolingGradComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
outputs);
return;
}
std::cout << "Fall back to Pooling backward pass..." << std::endl;
FallBackCompute(PoolingGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
}

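
Illustrative only: with the single-argument SupportMKLDNNPooling check, both PoolingComputeExCPU and PoolingGradComputeExCPU follow the same dispatch shape — run the MKL-DNN primitive when the input qualifies, otherwise log and fall back to the native CPU path. A generic standalone sketch of that control flow with placeholder callables (none of these names are MXNet APIs apart from the log text):

    #include <functional>
    #include <iostream>

    static void ComputeWithFallback(bool supported,
                                    const std::function<void()> &accelerated,
                                    const std::function<void()> &fallback) {
      if (supported) {
        accelerated();   // optimized primitive
        return;
      }
      std::cout << "Fall back to Pooling forward pass..." << std::endl;
      fallback();        // reference CPU implementation
    }

    int main() {
      ComputeWithFallback(false,
                          [] { std::cout << "mkldnn pooling\n"; },
                          [] { std::cout << "native cpu pooling\n"; });
      return 0;
    }
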
8 changes: 6 additions & 2 deletions tests/python/quantization/test_quantization.py
@@ -207,6 +207,9 @@ def check_quantized_conv(data_shape, kernel, num_filter, pad, stride, dilate, no
elif qdtype == 'uint8' and is_test_for_gpu():
print('skipped testing quantized_conv for gpu uint8 since it is not supported yet')
return
elif is_test_for_gpu() and len(data_shape) != 4:
print('skipped testing quantized_conv for gpu 5d layout since it is not supported yet')
return

# run fp32 conv
data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
@@ -276,8 +279,6 @@ def check_quantized_conv(data_shape, kernel, num_filter, pad, stride, dilate, no
for qdtype in ['int8', 'uint8']:
check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), (1, 1), True, qdtype)
check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), (1, 1), False, qdtype)
check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), (2, 2), True, qdtype)
check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), (2, 2), False, qdtype)
check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (1, 1, 1), False, qdtype)
check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (1, 1, 1), True, qdtype)
check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (2, 2, 2), False, qdtype)
@@ -416,6 +417,9 @@ def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_p
elif qdtype == 'uint8' and is_test_for_gpu():
print('skipped testing quantized_pooling for gpu uint8 since it is not supported yet')
return
elif is_test_for_gpu() and len(data_shape) != 4:
print('skipped testing quantized_pooling for gpu 5d layout since it is not supported yet')
return

data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
pooling_fp32 = mx.sym.Pooling(data=data, kernel=kernel, pad=pad, stride=stride,
