Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-867] Pooling1D with "same" padding (#12594)
Browse files Browse the repository at this point in the history
* fix max pool same padding

* syntax test fix

* added my name to contributors

* merge residue cleaned

* add enum

* fixed for all dimension and check for edge cases

* indentation, invalid unittest

* fixes

* and replaced by &&
  • Loading branch information
ChaiBapchya authored and anirudh2290 committed Sep 25, 2018
1 parent 76ae725 commit 29ac191
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 2 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,4 @@ List of Contributors
* [Per Goncalves da Silva](https://github.com/perdasilva)
* [Zhijingcheng Yu](https://github.com/jasonyu1996)
* [Cheng-Che Lee](https://github.com/stu1130)
* [Chaitanya Bapat](https://github.com/ChaiBapchya)
2 changes: 1 addition & 1 deletion src/operator/nn/pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ namespace pool_enum {
enum PoolingOpInputs {kData};
enum PoolingOpOutputs {kOut, kMask};
enum PoolingOpType {kMaxPooling, kAvgPooling, kSumPooling, kLpPooling};
enum PoolingOpPadConventionType {kValid, kFull};
enum PoolingOpPadConventionType {kValid, kFull, kSame};
} // namespace pool_enum

/*!
Expand Down
1 change: 1 addition & 0 deletions src/operator/nn/pooling-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
DMLC_DECLARE_FIELD(pooling_convention).set_default(pool_enum::kValid)
.add_enum("full", pool_enum::kFull)
.add_enum("valid", pool_enum::kValid)
.add_enum("same", pool_enum::kSame)
.describe("Pooling convention to be applied.");

DMLC_DECLARE_FIELD(stride).set_default(TShape())
Expand Down
13 changes: 12 additions & 1 deletion src/operator/nn/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
CHECK(param.p_value.has_value());
}
const TShape &dshape = (*in_shape)[0];
if (param.pooling_convention == pool_enum::kSame) {
CHECK_EQ(dshape.ndim(), 3U)
<< "Pooling: Input data should be 3D in (batch, channel, x)"
<< ". Currently 'same' supports Max Pooling 1-D";
CHECK(param.pad[0] == 0 && param.pad[1] == 0 && param.pad[2] == 0)
<< "Same pooling convention disables the use of pad parameter.";
}
CHECK_GE(dshape.ndim(), 3U)
<< "Pooling: Input data should be 3D in (batch, channel, x)"
<< " Or 4D in (batch, channel, y, x) "
Expand Down Expand Up @@ -126,11 +133,15 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
oshape[2] = 1 +
(dshape[2] + 2 * param.pad[0] - param.kernel[0]) /
param.stride[0];
} else {
} else if (param.pooling_convention == pool_enum::kFull) {
oshape[2] = 1 + static_cast<int>(std::ceil(
static_cast<float>(dshape[2] + 2 * param.pad[0] -
param.kernel[0]) /
param.stride[0]));
} else {
oshape[2] = static_cast<int>(std::ceil(
static_cast<float>(dshape[2] + 2 * param.pad[0]) /
param.stride[0]));
}
out_shape->clear();
out_shape->push_back(oshape); // save output shape
Expand Down
34 changes: 34 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6975,6 +6975,40 @@ def test_valid_kernel_size():
mx.nd.array(np.random.rand(1, 1, 28, 28)),
kernel_size=valid_kernel_size)

@with_seed()
def test_valid_max_pooling_pad_type_same():
import math
input_data = mx.nd.array(np.random.rand(1,1,10))
stride = 2
kernel = 2
output_data=mx.nd.Pooling(
input_data,
kernel=kernel,
stride=stride,
pad=(0,0,0),
pool_type='max',
name='pooling',
pooling_convention="same")
assert(math.ceil(input_data.shape[2]/stride) == output_data.shape[2])

@with_seed()
def test_invalid_max_pooling_pad_type_same():
import math
input_data = mx.nd.array(np.random.rand(1,1,10))
stride = 2
kernel = 2
pad = 2
assert_exception(
mx.nd.Pooling,
MXNetError,
input_data,
stride=stride,
kernel=kernel,
pad=pad,
pool_type='max',
name='pooling',
pooling_convention="same")

if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 29ac191

Please sign in to comment.