diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 3ae61298de8e..55416355d8aa 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -180,3 +180,4 @@ List of Contributors
 * [Per Goncalves da Silva](https://github.com/perdasilva)
 * [Zhijingcheng Yu](https://github.com/jasonyu1996)
 * [Cheng-Che Lee](https://github.com/stu1130)
+* [Chaitanya Bapat](https://github.com/ChaiBapchya)
diff --git a/src/operator/nn/pool.h b/src/operator/nn/pool.h
index 8f7a5edc8324..33005c8e5f0f 100644
--- a/src/operator/nn/pool.h
+++ b/src/operator/nn/pool.h
@@ -73,7 +73,7 @@ namespace pool_enum {
 enum PoolingOpInputs {kData};
 enum PoolingOpOutputs {kOut, kMask};
 enum PoolingOpType {kMaxPooling, kAvgPooling, kSumPooling, kLpPooling};
-enum PoolingOpPadConventionType {kValid, kFull};
+enum PoolingOpPadConventionType {kValid, kFull, kSame};
 }  // namespace pool_enum
 
 /*!
diff --git a/src/operator/nn/pooling-inl.h b/src/operator/nn/pooling-inl.h
index ad74a8feae39..71d85da9ba52 100644
--- a/src/operator/nn/pooling-inl.h
+++ b/src/operator/nn/pooling-inl.h
@@ -74,6 +74,7 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
     DMLC_DECLARE_FIELD(pooling_convention).set_default(pool_enum::kValid)
     .add_enum("full", pool_enum::kFull)
     .add_enum("valid", pool_enum::kValid)
+    .add_enum("same", pool_enum::kSame)
     .describe("Pooling convention to be applied.");
 
     DMLC_DECLARE_FIELD(stride).set_default(TShape())
diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc
index 558722edb202..611568807a9a 100644
--- a/src/operator/nn/pooling.cc
+++ b/src/operator/nn/pooling.cc
@@ -96,6 +96,19 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
     CHECK(param.p_value.has_value());
   }
   const TShape &dshape = (*in_shape)[0];
+  if (param.pooling_convention == pool_enum::kSame) {
+    // The 'same' convention is only accepted for 3D (batch, channel, x)
+    // input, and it refuses any explicit padding.
+    CHECK_EQ(dshape.ndim(), 3U)
+      << "Pooling: Input data should be 3D in (batch, channel, x)"
+      << ". Currently 'same' supports Max Pooling 1-D";
+    // Check only the pad entries that exist: pad may hold fewer than three
+    // dimensions, so indexing pad[1] and pad[2] unconditionally is unsafe.
+    for (const auto pad_i : param.pad) {
+      CHECK_EQ(pad_i, 0)
+        << "Same pooling convention disables the use of pad parameter.";
+    }
+  }
   CHECK_GE(dshape.ndim(), 3U)
       << "Pooling: Input data should be 3D in (batch, channel, x)"
       << " Or 4D in (batch, channel, y, x) "
@@ -126,11 +139,16 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
       oshape[2] = 1 +
                   (dshape[2] + 2 * param.pad[0] - param.kernel[0]) /
                       param.stride[0];
-    } else {
+    } else if (param.pooling_convention == pool_enum::kFull) {
       oshape[2] = 1 + static_cast<int>(std::ceil(
                           static_cast<float>(dshape[2] + 2 * param.pad[0] -
                                              param.kernel[0]) /
                           param.stride[0]));
+    } else {
+      // kSame: output length is ceil(input / stride); the kernel is not used.
+      oshape[2] = static_cast<int>(std::ceil(
+                          static_cast<float>(dshape[2] + 2 * param.pad[0]) /
+                          param.stride[0]));
     }
     out_shape->clear();
     out_shape->push_back(oshape);  // save output shape
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 43c357808f1f..a7f484e81b38 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -6975,6 +6975,41 @@ def test_valid_kernel_size():
         mx.nd.array(np.random.rand(1, 1, 28, 28)),
         kernel_size=valid_kernel_size)
 
+@with_seed()
+def test_valid_max_pooling_pad_type_same():
+    # With the 'same' convention the output length must be ceil(input / stride).
+    import math
+    input_data = mx.nd.array(np.random.rand(1,1,10))
+    stride = 2
+    kernel = 2
+    output_data=mx.nd.Pooling(
+        input_data,
+        kernel=kernel,
+        stride=stride,
+        pad=(0,0,0),
+        pool_type='max',
+        name='pooling',
+        pooling_convention="same")
+    assert(math.ceil(input_data.shape[2] / float(stride)) == output_data.shape[2])
+
+@with_seed()
+def test_invalid_max_pooling_pad_type_same():
+    # A non-zero pad combined with the 'same' convention must raise MXNetError.
+    input_data = mx.nd.array(np.random.rand(1,1,10))
+    stride = 2
+    kernel = 2
+    pad = 2
+    assert_exception(
+        mx.nd.Pooling,
+        MXNetError,
+        input_data,
+        stride=stride,
+        kernel=kernel,
+        pad=pad,
+        pool_type='max',
+        name='pooling',
+        pooling_convention="same")
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()