Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-867] Pooling1D with "same" padding #12594

Merged
merged 10 commits into from
Sep 25, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,4 @@ List of Contributors
* [Per Goncalves da Silva](https://github.com/perdasilva)
* [Zhijingcheng Yu](https://github.com/jasonyu1996)
* [Cheng-Che Lee](https://github.com/stu1130)
* [Chaitanya Bapat](https://github.com/ChaiBapchya)
2 changes: 1 addition & 1 deletion src/operator/nn/pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ namespace pool_enum {
enum PoolingOpInputs {kData};
enum PoolingOpOutputs {kOut, kMask};
enum PoolingOpType {kMaxPooling, kAvgPooling, kSumPooling, kLpPooling};
enum PoolingOpPadConventionType {kValid, kFull};
enum PoolingOpPadConventionType {kValid, kFull, kSame};
} // namespace pool_enum

/*!
Expand Down
1 change: 1 addition & 0 deletions src/operator/nn/pooling-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
DMLC_DECLARE_FIELD(pooling_convention).set_default(pool_enum::kValid)
.add_enum("full", pool_enum::kFull)
.add_enum("valid", pool_enum::kValid)
.add_enum("same", pool_enum::kSame)
.describe("Pooling convention to be applied.");

DMLC_DECLARE_FIELD(stride).set_default(TShape())
Expand Down
13 changes: 12 additions & 1 deletion src/operator/nn/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
CHECK(param.p_value.has_value());
}
const TShape &dshape = (*in_shape)[0];
if (param.pooling_convention == pool_enum::kSame) {
CHECK_EQ(dshape.ndim(), 3U)
<< "Pooling: Input data should be 3D in (batch, channel, x)"
<< ". Currently 'same' supports Max Pooling 1-D";
CHECK(param.pad[0] == 0 && param.pad[1] == 0 && param.pad[2] == 0)
<< "Same pooling convention disables the use of pad parameter.";
}
CHECK_GE(dshape.ndim(), 3U)
<< "Pooling: Input data should be 3D in (batch, channel, x)"
<< " Or 4D in (batch, channel, y, x) "
Expand Down Expand Up @@ -126,11 +133,15 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
oshape[2] = 1 +
(dshape[2] + 2 * param.pad[0] - param.kernel[0]) /
param.stride[0];
} else {
} else if (param.pooling_convention == pool_enum::kFull) {
oshape[2] = 1 + static_cast<int>(std::ceil(
static_cast<float>(dshape[2] + 2 * param.pad[0] -
param.kernel[0]) /
param.stride[0]));
} else {
oshape[2] = static_cast<int>(std::ceil(
static_cast<float>(dshape[2] + 2 * param.pad[0]) /
param.stride[0]));
}
out_shape->clear();
out_shape->push_back(oshape); // save output shape
Expand Down
34 changes: 34 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6975,6 +6975,40 @@ def test_valid_kernel_size():
mx.nd.array(np.random.rand(1, 1, 28, 28)),
kernel_size=valid_kernel_size)

@with_seed()
def test_valid_max_pooling_pad_type_same():
import math
input_data = mx.nd.array(np.random.rand(1,1,10))
stride = 2
kernel = 2
output_data=mx.nd.Pooling(
input_data,
kernel=kernel,
stride=stride,
pad=(0,0,0),
pool_type='max',
name='pooling',
pooling_convention="same")
assert(math.ceil(input_data.shape[2]/stride) == output_data.shape[2])

@with_seed()
def test_invalid_max_pooling_pad_type_same():
import math
input_data = mx.nd.array(np.random.rand(1,1,10))
stride = 2
kernel = 2
pad = 2
assert_exception(
mx.nd.Pooling,
MXNetError,
input_data,
stride=stride,
kernel=kernel,
pad=pad,
pool_type='max',
name='pooling',
pooling_convention="same")

if __name__ == '__main__':
import nose
nose.runmodule()