From d02011e032574b14b0667598515d5ae13c9f4e70 Mon Sep 17 00:00:00 2001 From: Istvan Fehervari Date: Thu, 24 Jan 2019 18:13:54 -0800 Subject: [PATCH 1/3] Added optional parameters to BilinearResize2D to do relative scaling --- src/operator/contrib/bilinear_resize-inl.h | 28 +++++++++++++++++----- tests/python/unittest/test_operator.py | 6 +++++ 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/src/operator/contrib/bilinear_resize-inl.h b/src/operator/contrib/bilinear_resize-inl.h index ff3f794d167d..5eb5337aa223 100644 --- a/src/operator/contrib/bilinear_resize-inl.h +++ b/src/operator/contrib/bilinear_resize-inl.h @@ -50,11 +50,17 @@ namespace op { struct BilinearSampleParam : public dmlc::Parameter { int height; int width; + dmlc::optional scale_height; + dmlc::optional scale_width; DMLC_DECLARE_PARAMETER(BilinearSampleParam) { - DMLC_DECLARE_FIELD(height).set_range(1, 10000) - .describe("output height (required)"); - DMLC_DECLARE_FIELD(width).set_range(1, 10000) - .describe("output width (required)"); + DMLC_DECLARE_FIELD(height).set_default(1).set_range(1, 10000) + .describe("output height (required, but ignored if scale_height is defined)"); + DMLC_DECLARE_FIELD(width).set_default(1).set_range(1, 10000) + .describe("output width (required, but ignored if scale_width is defined)"); + DMLC_DECLARE_FIELD(scale_height).set_default(dmlc::optional()) + .describe("sampling scale of the height (optional, ignores height if defined)"); + DMLC_DECLARE_FIELD(scale_width).set_default(dmlc::optional()) + .describe("sampling scale of the scale_width (optional, ignores width if defined)"); } }; @@ -129,8 +135,18 @@ static bool BilinearSampleOpInferShape(const nnvm::NodeAttrs& attrs, const BilinearSampleParam& param = nnvm::get(attrs.parsed); TShape dshape(in_shape->at(0)); if (dshape.ndim() == 0) return false; - dshape[2] = param.height; - dshape[3] = param.width; + if (param.scale_height.has_value()) { + dshape[2] = int(param.scale_height.value() * in_shape->at(0)[2]); + } else { + dshape[2] = param.height; + } + + if (param.scale_height.has_value()) { + dshape[3] = int(param.scale_width.value() * in_shape->at(0)[3]); + } else { + dshape[3] = param.width; + } + out_shape->clear(); out_shape->push_back(dshape); return true; diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 67aeddf19c44..708f9f4d367a 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6533,6 +6533,12 @@ def check_bilinear_resize_op(shape, height, width): x = mx.nd.random.uniform(shape=shape) y = mx.nd.contrib.BilinearResize2D(x, height=height, width=width) assert_almost_equal(y.asnumpy(), py_bilinear_resize(x.asnumpy(), height, width)) + + x_scale = width / shape[-1] + y_scale = height / shape[-2] + y = mx.nd.contrib.BilinearResize2D(x, height=1, width=1, scale_height=y_scale, + scale_width=x_scale) + assert_almost_equal(y.asnumpy(), py_bilinear_resize(x.asnumpy(), height, width)) shape = (2, 2, 10, 10) check_bilinear_resize_op(shape, 5, 5) check_bilinear_resize_op(shape, 10, 10) From 9f03c1d31c833a3d4a76c32e202e2c6baafa19e5 Mon Sep 17 00:00:00 2001 From: Istvan Fehervari Date: Thu, 24 Jan 2019 18:19:04 -0800 Subject: [PATCH 2/3] Removed unnecessary params in unit tests. --- tests/python/unittest/test_operator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 708f9f4d367a..3f34ade448dc 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -6536,8 +6536,7 @@ def check_bilinear_resize_op(shape, height, width): x_scale = width / shape[-1] y_scale = height / shape[-2] - y = mx.nd.contrib.BilinearResize2D(x, height=1, width=1, scale_height=y_scale, - scale_width=x_scale) + y = mx.nd.contrib.BilinearResize2D(x, scale_height=y_scale, scale_width=x_scale) assert_almost_equal(y.asnumpy(), py_bilinear_resize(x.asnumpy(), height, width)) shape = (2, 2, 10, 10) check_bilinear_resize_op(shape, 5, 5) From d045d3b0bb6c01fcf5ea7959d83ce8d8b7df824c Mon Sep 17 00:00:00 2001 From: Istvan Fehervari Date: Thu, 24 Jan 2019 21:45:51 -0800 Subject: [PATCH 3/3] Fixed deprecated casting style --- src/operator/contrib/bilinear_resize-inl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/contrib/bilinear_resize-inl.h b/src/operator/contrib/bilinear_resize-inl.h index 5eb5337aa223..5a653d8a175c 100644 --- a/src/operator/contrib/bilinear_resize-inl.h +++ b/src/operator/contrib/bilinear_resize-inl.h @@ -136,13 +136,13 @@ static bool BilinearSampleOpInferShape(const nnvm::NodeAttrs& attrs, TShape dshape(in_shape->at(0)); if (dshape.ndim() == 0) return false; if (param.scale_height.has_value()) { - dshape[2] = int(param.scale_height.value() * in_shape->at(0)[2]); + dshape[2] = static_cast(param.scale_height.value() * in_shape->at(0)[2]); } else { dshape[2] = param.height; } if (param.scale_height.has_value()) { - dshape[3] = int(param.scale_width.value() * in_shape->at(0)[3]); + dshape[3] = static_cast(param.scale_width.value() * in_shape->at(0)[3]); } else { dshape[3] = param.width; }