Skip to content
This repository has been archived by the owner on Feb 9, 2021. It is now read-only.

Commit

Permalink
Added optional parameters to BilinearResize2D to do relative scaling (a…
Browse files Browse the repository at this point in the history
…pache#13985)

* Added optional parameters to BilinearResize2D to do relative scaling

* Removed unnecessary params in unit tests.

* Fixed deprecated casting style
  • Loading branch information
ifeherva authored and Gordon Reid committed Jan 27, 2019
1 parent 1c284f3 commit 1e2909c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
28 changes: 22 additions & 6 deletions src/operator/contrib/bilinear_resize-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,17 @@ namespace op {
struct BilinearSampleParam : public dmlc::Parameter<BilinearSampleParam> {
int height;
int width;
dmlc::optional<float> scale_height;
dmlc::optional<float> scale_width;
DMLC_DECLARE_PARAMETER(BilinearSampleParam) {
DMLC_DECLARE_FIELD(height).set_range(1, 10000)
.describe("output height (required)");
DMLC_DECLARE_FIELD(width).set_range(1, 10000)
.describe("output width (required)");
DMLC_DECLARE_FIELD(height).set_default(1).set_range(1, 10000)
.describe("output height (required, but ignored if scale_height is defined)");
DMLC_DECLARE_FIELD(width).set_default(1).set_range(1, 10000)
.describe("output width (required, but ignored if scale_width is defined)");
DMLC_DECLARE_FIELD(scale_height).set_default(dmlc::optional<float>())
.describe("sampling scale of the height (optional, ignores height if defined)");
DMLC_DECLARE_FIELD(scale_width).set_default(dmlc::optional<float>())
.describe("sampling scale of the scale_width (optional, ignores width if defined)");
}
};

Expand Down Expand Up @@ -129,8 +135,18 @@ static bool BilinearSampleOpInferShape(const nnvm::NodeAttrs& attrs,
const BilinearSampleParam& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
TShape dshape(in_shape->at(0));
if (dshape.ndim() == 0) return false;
dshape[2] = param.height;
dshape[3] = param.width;
if (param.scale_height.has_value()) {
dshape[2] = static_cast<int>(param.scale_height.value() * in_shape->at(0)[2]);
} else {
dshape[2] = param.height;
}

if (param.scale_height.has_value()) {
dshape[3] = static_cast<int>(param.scale_width.value() * in_shape->at(0)[3]);
} else {
dshape[3] = param.width;
}

out_shape->clear();
out_shape->push_back(dshape);
return true;
Expand Down
5 changes: 5 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6533,6 +6533,11 @@ def check_bilinear_resize_op(shape, height, width):
x = mx.nd.random.uniform(shape=shape)
y = mx.nd.contrib.BilinearResize2D(x, height=height, width=width)
assert_almost_equal(y.asnumpy(), py_bilinear_resize(x.asnumpy(), height, width))

x_scale = width / shape[-1]
y_scale = height / shape[-2]
y = mx.nd.contrib.BilinearResize2D(x, scale_height=y_scale, scale_width=x_scale)
assert_almost_equal(y.asnumpy(), py_bilinear_resize(x.asnumpy(), height, width))
shape = (2, 2, 10, 10)
check_bilinear_resize_op(shape, 5, 5)
check_bilinear_resize_op(shape, 10, 10)
Expand Down

0 comments on commit 1e2909c

Please sign in to comment.