Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MKLDNN] enable MaxPooling with full pooling convention #16860

Merged
merged 12 commits into from
Dec 16, 2019
20 changes: 19 additions & 1 deletion src/operator/nn/mkldnn/mkldnn_pooling-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ class MKLDNNPoolingBwd {
const mkldnn::pooling_backward::primitive_desc &GetPd();
};

inline int GetPaddingSizeFull(dim_t x, int padl, int padr, int k, int s) {
if ((x + padl + padr - k) % s != 0) {
return (padr + s - ((x + padl + padr - k) % s));
} else {
return padr;
}
}

inline bool SupportMKLDNNPooling(const PoolingParam &param) {
return param.kernel.ndim() == 2 &&
(param.pool_type == pool_enum::kMaxPooling ||
Expand All @@ -105,7 +113,17 @@ inline bool SupportMKLDNNPooling(const PoolingParam &param,
if (param.pooling_convention == pool_enum::kValid) {
return true;
} else {
// currently, only max-pooling is supported for full convention
if (param.pool_type == pool_enum::kAvgPooling) {
CHECK_EQ(dshape.ndim(), 4);
// mkldnn works differently when padding is asymmetric, so let's skip this case.
if (param.pad[0] == GetPaddingSizeFull(dshape[2], param.pad[0], param.pad[0], param.kernel[0],
param.stride[0]) &&
param.pad[1] == GetPaddingSizeFull(dshape[3], param.pad[1], param.pad[1], param.kernel[1],
param.stride[1])) {
return true;
}
return false;
}
return param.pool_type == pool_enum::kMaxPooling;
}
}
Expand Down
7 changes: 0 additions & 7 deletions src/operator/nn/mkldnn/mkldnn_pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,6 @@ mkldnn::algorithm GetMKLDNNPoolAlgo(const PoolingParam &param) {
}
}

static inline int GetPaddingSizeFull(dim_t x, int padl, int padr, int k, int s) {
if ((x + padl + padr - k) % s != 0) {
return (padr + s - ((x + padl + padr - k) % s));
} else {
return padr;
}
}

mkldnn::pooling_forward::primitive_desc GetPoolingFwdPdesc(
const PoolingParam &param, const bool is_train, const mkldnn::memory::desc &data_md,
Expand Down
5 changes: 2 additions & 3 deletions src/operator/nn/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,9 +278,8 @@ void PoolingComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
return;
}


if (SupportMKLDNN(inputs[0]) &&
SupportMKLDNNPooling(param, inputs[0].shape())) {
if (SupportMKLDNN(inputs[0])
&& SupportMKLDNNPooling(param, inputs[0].shape())) {
if (MKLDNNRequireWorkspace(param)) {
CHECK_GT(outputs.size(), 1U);
workspace = &outputs[1];
Expand Down