Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-867] Pooling1D with "same" padding #12594

Merged
merged 10 commits into from
Sep 25, 2018
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,4 @@ List of Contributors
* [Per Goncalves da Silva](https://github.com/perdasilva)
* [Zhijingcheng Yu](https://github.com/jasonyu1996)
* [Cheng-Che Lee](https://github.com/stu1130)
* [Chaitanya Bapat](https://github.com/ChaiBapchya)
2 changes: 1 addition & 1 deletion src/operator/nn/pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ namespace pool_enum {
enum PoolingOpInputs {kData};
enum PoolingOpOutputs {kOut, kMask};
enum PoolingOpType {kMaxPooling, kAvgPooling, kSumPooling, kLpPooling};
enum PoolingOpPadConventionType {kValid, kFull};
enum PoolingOpPadConventionType {kValid, kFull, kSame};
} // namespace pool_enum

/*!
Expand Down
1 change: 1 addition & 0 deletions src/operator/nn/pooling-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ struct PoolingParam : public dmlc::Parameter<PoolingParam> {
DMLC_DECLARE_FIELD(pooling_convention).set_default(pool_enum::kValid)
.add_enum("full", pool_enum::kFull)
.add_enum("valid", pool_enum::kValid)
.add_enum("same", pool_enum::kSame)
.describe("Pooling convention to be applied.");

DMLC_DECLARE_FIELD(stride).set_default(TShape())
Expand Down
13 changes: 11 additions & 2 deletions src/operator/nn/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
CHECK(param.p_value.has_value());
}
const TShape &dshape = (*in_shape)[0];
if (param.pooling_convention == pool_enum::kSame) {
CHECK_EQ(dshape.ndim(), 3U)
<< "Pooling: Input data should be 3D in (batch, channel, x)"
<< ", Currently 'same' supports Max Pooling 1-D";
ChaiBapchya marked this conversation as resolved.
Show resolved Hide resolved
}
CHECK_GE(dshape.ndim(), 3U)
<< "Pooling: Input data should be 3D in (batch, channel, x)"
<< " Or 4D in (batch, channel, y, x) "
Expand Down Expand Up @@ -126,11 +131,15 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
oshape[2] = 1 +
(dshape[2] + 2 * param.pad[0] - param.kernel[0]) /
param.stride[0];
} else {
oshape[2] = 1 + static_cast<int>(std::ceil(
} else if (param.pooling_convention == pool_enum::kFull) {
ChaiBapchya marked this conversation as resolved.
Show resolved Hide resolved
oshape[2] = 1 + static_cast<int>(ceil(
static_cast<float>(dshape[2] + 2 * param.pad[0] -
param.kernel[0]) /
param.stride[0]));
} else {
oshape[2] = static_cast<int>(ceil(
static_cast<float>(dshape[2] + 2 * param.pad[0]) /
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: indentation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unable to find the indentation issue. Also is there a similar linting tool for python to ensure I don't miss out on those. Do we use pylint?

param.stride[0]));
}
out_shape->clear();
out_shape->push_back(oshape); // save output shape
Expand Down
15 changes: 15 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6957,6 +6957,21 @@ def test_invalid_depth_dim():
test_invalid_block_size()
test_invalid_depth_dim()

@with_seed()
def test_max_pooling_pad_type_same():
import math
input_data = mx.nd.array(np.random.rand(1,1,10))
stride = 2
kernel = 2
output_data=mx.nd.Pooling(
input_data,
Copy link
Contributor

@apeforest apeforest Sep 21, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you also want to test the case where user selects "same" pad type but also supplies pad values?

kernel=kernel,
stride=stride,
pool_type='max',
name='pooling',
pooling_convention="same")
assert(math.ceil(input_data.shape[2]/stride) == output_data.shape[2])

@with_seed()
def test_invalid_kernel_size():
invalid_kernel_size = 28
Expand Down