diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index f3723f98c736..f0ec80e2725c 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -218,6 +218,7 @@ List of Contributors
 * [Dang Trung Kien](https://github.com/kiendang)
 * [Zach Boldyga](https://github.com/zboldyga)
 * [Gordon Reid](https://github.com/gordon1992)
+* [Mikhail Lobanov](https://github.com/lobanov-m)
 * [Ming Yang](http://ufoym.com)
 * [Satya Krishna Gorti](https://github.com/satyakrishnagorti)
 * [Neo Chien](https://github.com/cchung100m)
diff --git a/python/mxnet/visualization.py b/python/mxnet/visualization.py
index 4101f749a583..bc8309a4a6ba 100644
--- a/python/mxnet/visualization.py
+++ b/python/mxnet/visualization.py
@@ -384,6 +384,12 @@ def looks_like_weight(name):
                 continue
             else:
                 inputs = node["inputs"]
+
+            # In "like" mode the second input only supplies the target shape;
+            # keep just the data edge so the plotted graph stays readable.
+            if node['op'] == '_contrib_BilinearResize2D':
+                inputs = [inputs[0]]
+
             for item in inputs:
                 input_node = nodes[item[0]]
                 input_name = input_node["name"]
diff --git a/src/operator/contrib/bilinear_resize-inl.h b/src/operator/contrib/bilinear_resize-inl.h
index ce9c6c83504c..4da12cbdf280 100644
--- a/src/operator/contrib/bilinear_resize-inl.h
+++ b/src/operator/contrib/bilinear_resize-inl.h
@@ -44,6 +44,12 @@
 #include "../mxnet_op.h"
 #include "../mshadow_op.h"
 
+namespace bilinear_resize {
+enum BilinearResizeOpMode{simple, odd_scale, like, to_even_down, to_even_up,
+                          to_odd_down, to_odd_up};
+}  // namespace bilinear_resize
+
+
 namespace mxnet {
 namespace op {
 
@@ -52,15 +58,45 @@ struct BilinearSampleParam : public dmlc::Parameter<BilinearSampleParam> {
   int width;
   dmlc::optional<float> scale_height;
   dmlc::optional<float> scale_width;
+  int mode;
   DMLC_DECLARE_PARAMETER(BilinearSampleParam) {
     DMLC_DECLARE_FIELD(height).set_default(1).set_range(1, 10000)
-    .describe("output height (required, but ignored if scale_height is defined)");
+    .describe("output height (required, but ignored if scale_height is defined or mode is not "
+              "\"size\")");
     DMLC_DECLARE_FIELD(width).set_default(1).set_range(1, 10000)
-    .describe("output width (required, but ignored if scale_width is defined)");
+    .describe("output width (required, but ignored if scale_width is defined or mode is not "
+              "\"size\")");
     DMLC_DECLARE_FIELD(scale_height).set_default(dmlc::optional<float>())
-    .describe("sampling scale of the height (optional, ignores height if defined)");
+    .describe("sampling scale of the height (optional, used in modes \"size\" and \"odd_scale\")");
     DMLC_DECLARE_FIELD(scale_width).set_default(dmlc::optional<float>())
-    .describe("sampling scale of the scale_width (optional, ignores width if defined)");
+    .describe("sampling scale of the width (optional, used in modes \"size\" and \"odd_scale\")");
+    DMLC_DECLARE_FIELD(mode)
+    .add_enum("size", bilinear_resize::simple)
+    .add_enum("odd_scale", bilinear_resize::odd_scale)
+    .add_enum("like", bilinear_resize::like)
+    .add_enum("to_even_down", bilinear_resize::to_even_down)
+    .add_enum("to_even_up", bilinear_resize::to_even_up)
+    .add_enum("to_odd_down", bilinear_resize::to_odd_down)
+    .add_enum("to_odd_up", bilinear_resize::to_odd_up)
+    .set_default(bilinear_resize::simple)
+    .describe("resizing mode. \"size\" - output height equals parameter \"height\" if "
+              "\"scale_height\" parameter is not defined or input height multiplied by "
+              "\"scale_height\" otherwise. Same for width; "
+              "\"odd_scale\" - if original height or width is odd, then result height is "
+              "calculated like result_h = (original_h - 1) * scale + 1; "
+              "for scale > 1 the result shape would be like if we did deconvolution with kernel "
+              "= (1, 1) and stride = (height_scale, width_scale); and for scale < 1 shape "
+              "would be like we did convolution with kernel = (1, 1) and "
+              "stride = (int(1 / height_scale), int(1 / width_scale)); "
+              "\"like\" - resize first input to the height and width of second input; "
+              "\"to_even_down\" - resize input to nearest lower even height and width "
+              "(if original height is odd then result height = original height - 1); "
+              "\"to_even_up\" - resize input to nearest bigger even height and width "
+              "(if original height is odd then result height = original height + 1); "
+              "\"to_odd_down\" - resize input to nearest odd height and width "
+              "(if original height is even then result height = original height - 1); "
+              "\"to_odd_up\" - resize input to nearest odd height and width "
+              "(if original height is even then result height = original height + 1);");
   }
 };
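For reference, the modes declared above are selected with the new `mode` string on the Python side. A minimal usage sketch (NCHW inputs, illustrative shapes, assuming a build that includes this patch):

    import mxnet as mx

    x = mx.nd.random.uniform(shape=(1, 3, 21, 21))      # NCHW, odd H and W
    like = mx.nd.zeros((1, 3, 50, 50))

    # "size" (the default) keeps the old behaviour: explicit height/width.
    y = mx.nd.contrib.BilinearResize2D(x, height=42, width=42)    # -> (1, 3, 42, 42)
    # "odd_scale" maps an odd side s to (s - 1) * scale + 1.
    y = mx.nd.contrib.BilinearResize2D(x, scale_height=2., scale_width=2.,
                                       mode='odd_scale')          # -> (1, 3, 41, 41)
    # "like" copies the spatial shape of a second input.
    y = mx.nd.contrib.BilinearResize2D(x, like, mode='like')      # -> (1, 3, 50, 50)
    # The parity modes need no size arguments at all.
    y = mx.nd.contrib.BilinearResize2D(x, mode='to_even_up')      # -> (1, 3, 22, 22)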
@@ -76,7 +112,8 @@ void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<cpu> *s,
 template<typename xpu, typename DType, typename AccReal>
 void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
                                               const std::vector<TBlob> &input,
-                                              const std::vector<TBlob> &output);
+                                              const std::vector<TBlob> &output,
+                                              bool modeLike);
 
 #if MXNET_USE_CUDA
 template<typename xpu, typename DType, typename AccReal>
 void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<gpu> *s,
@@ -87,7 +124,8 @@ void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<gpu> *s,
 template<typename xpu, typename DType, typename AccReal>
 void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<gpu> *s,
                                               const std::vector<TBlob> &input,
-                                              const std::vector<TBlob> &output);
+                                              const std::vector<TBlob> &output,
+                                              bool modeLike);
 #endif  // MXNET_USE_CUDA
 
 template<typename xpu>
@@ -96,7 +134,9 @@ inline void BilinearSampleOpForward(const nnvm::NodeAttrs& attrs,
                                     const std::vector<TBlob> &inputs,
                                     const std::vector<OpReqType> &req,
                                     const std::vector<TBlob> &outputs) {
-  CHECK_EQ(inputs.size(), 1U);
+  const BilinearSampleParam& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
+  size_t expected = param.mode == bilinear_resize::like ? 2 : 1;
+  CHECK_EQ(inputs.size(), expected);
   CHECK_EQ(outputs.size(), 1U);
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
@@ -111,8 +151,11 @@ inline void BilinearSampleOpBackward(const nnvm::NodeAttrs& attrs,
                                      const std::vector<TBlob> &inputs,
                                      const std::vector<OpReqType> &req,
                                      const std::vector<TBlob> &outputs) {
+  const BilinearSampleParam& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
   CHECK_EQ(inputs.size(), 1U);
-  CHECK_EQ(outputs.size(), 1U);
+  bool modeLike = param.mode == bilinear_resize::like;
+  size_t expected = modeLike ? 2 : 1;
+  CHECK_EQ(outputs.size(), expected);
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   if (IsWriting(req[0])) {
     // zero grad before backwarding
@@ -121,7 +164,7 @@ inline void BilinearSampleOpBackward(const nnvm::NodeAttrs& attrs,
     })
   }
   MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
-    SpatialUpSamplingBilinearUpdateGradInput<xpu, DType, AccReal>(s, inputs, outputs);
+    SpatialUpSamplingBilinearUpdateGradInput<xpu, DType, AccReal>(s, inputs, outputs, modeLike);
   });
 }
@@ -130,28 +173,120 @@ static bool BilinearSampleOpInferShape(const nnvm::NodeAttrs& attrs,
                                        mxnet::ShapeVector *in_shape,
                                        mxnet::ShapeVector *out_shape) {
   using namespace mshadow;
-  CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
   CHECK_EQ(out_shape->size(), 1U) << "Output:[data]";
   const BilinearSampleParam& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
+  size_t expected = param.mode == bilinear_resize::like ? 2 : 1;
+  CHECK_EQ(in_shape->size(), expected);
   mxnet::TShape dshape(in_shape->at(0));
   if (mxnet::op::shape_is_none(dshape)) return false;
-  if (param.scale_height.has_value()) {
-    dshape[2] = static_cast<int>(param.scale_height.value() * in_shape->at(0)[2]);
-  } else {
-    dshape[2] = param.height;
+  int16_t new_height = -1;
+  int16_t new_width = -1;
+  switch (param.mode) {
+    case bilinear_resize::simple:
+    {
+      if (param.scale_height.has_value()) {
+        new_height = static_cast<int16_t>(param.scale_height.value() * in_shape->at(0)[2]);
+      } else {
+        new_height = param.height;
+      }
+      if (param.scale_width.has_value()) {
+        new_width = static_cast<int16_t>(param.scale_width.value() * in_shape->at(0)[3]);
+      } else {
+        new_width = param.width;
+      }
+      break;
+    }
+    case bilinear_resize::odd_scale:
+    {
+      new_height = ((dshape[2] % 2) == 0) ? (int16_t) (dshape[2] * param.scale_height.value()) :
+                   (int16_t) ((dshape[2] - 1) * param.scale_height.value()) + 1;
+      new_width = ((dshape[3] % 2) == 0) ? (int16_t) (dshape[3] * param.scale_width.value()) :
+                  (int16_t) ((dshape[3] - 1) * param.scale_width.value()) + 1;
+      break;
+    }
+    case bilinear_resize::like:
+    {
+      TShape like_shape(in_shape->at(1));
+      if (mxnet::op::shape_is_none(like_shape)) return false;
+      new_height = like_shape[2];
+      new_width = like_shape[3];
+      break;
+    }
+    case bilinear_resize::to_even_down:
+    {
+      new_height = ((dshape[2] % 2) == 0) ? dshape[2] : dshape[2] - 1;
+      new_width = ((dshape[3] % 2) == 0) ? dshape[3] : dshape[3] - 1;
+      break;
+    }
+    case bilinear_resize::to_even_up:
+    {
+      new_height = ((dshape[2] % 2) == 0) ? dshape[2] : dshape[2] + 1;
+      new_width = ((dshape[3] % 2) == 0) ? dshape[3] : dshape[3] + 1;
+      break;
+    }
+    case bilinear_resize::to_odd_down:
+    {
+      new_height = ((dshape[2] % 2) == 1) ? dshape[2] : dshape[2] - 1;
+      new_width = ((dshape[3] % 2) == 1) ? dshape[3] : dshape[3] - 1;
+      break;
+    }
+    case bilinear_resize::to_odd_up:
+    {
+      new_height = ((dshape[2] % 2) == 1) ? dshape[2] : dshape[2] + 1;
+      new_width = ((dshape[3] % 2) == 1) ? dshape[3] : dshape[3] + 1;
+      break;
+    }
+    default:
+    {
+      LOG(FATAL) << "Invalid mode " << param.mode;
+    }
   }
-  if (param.scale_height.has_value()) {
-    dshape[3] = static_cast<int>(param.scale_width.value() * in_shape->at(0)[3]);
-  } else {
-    dshape[3] = param.width;
-  }
+  dshape[2] = new_height;
+  dshape[3] = new_width;
   out_shape->clear();
   out_shape->push_back(dshape);
   return true;
 }
+
+inline uint16_t BilinearSampleOpNumInputs(const NodeAttrs& attrs) {
+  auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
+  if (param.mode == bilinear_resize::like) {
+    return 2;
+  } else {
+    return 1;
+  }
+}
+
+inline uint16_t BilinearSampleOpNumBackwardInputs(const NodeAttrs& attrs) {
+  auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
+  if (param.mode == bilinear_resize::like) {
+    return 3;
+  } else {
+    return 1;
+  }
+}
+
+inline uint16_t BilinearSampleOpNumBackwardOutputs(const NodeAttrs& attrs) {
+  auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
+  if (param.mode == bilinear_resize::like) {
+    return 2;
+  } else {
+    return 1;
+  }
+}
+
+inline std::vector<std::string> BilinearSampleOpInputNames(const NodeAttrs& attrs) {
+  auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
+  if (param.mode == bilinear_resize::like) {
+    return std::vector<std::string>{"data", "like"};
+  } else {
+    return std::vector<std::string>{"data"};
+  }
+}
+
 }  // namespace op
 }  // namespace mxnet
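The switch in BilinearSampleOpInferShape above is the entire shape contract of the operator. As a cross-check, here is a self-contained Python mirror of the height rule (the helper name is ours, for illustration only; width follows the same rule):

    def infer_height(h, mode, scale=None, like_h=None, height=None):
        """Mirror of the height branch of BilinearSampleOpInferShape."""
        if mode == 'size':
            return int(h * scale) if scale is not None else height
        if mode == 'odd_scale':
            return int(h * scale) if h % 2 == 0 else int((h - 1) * scale) + 1
        if mode == 'like':
            return like_h
        if mode == 'to_even_down':
            return h - (h % 2)
        if mode == 'to_even_up':
            return h + (h % 2)
        if mode == 'to_odd_down':
            return h if h % 2 == 1 else h - 1
        return h if h % 2 == 1 else h + 1  # to_odd_up

    assert infer_height(21, 'odd_scale', scale=0.5) == 11   # (21 - 1) * 0.5 + 1
    assert infer_height(20, 'odd_scale', scale=0.5) == 10
    assert infer_height(21, 'to_even_up') == 22
    assert infer_height(21, 'to_odd_down') == 21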
diff --git a/src/operator/contrib/bilinear_resize.cc b/src/operator/contrib/bilinear_resize.cc
index 1288e9d22691..441ea53ad9c6 100644
--- a/src/operator/contrib/bilinear_resize.cc
+++ b/src/operator/contrib/bilinear_resize.cc
@@ -97,7 +97,8 @@ void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<cpu> *s,
 template<typename xpu, typename DType, typename AccReal>
 void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
                                               const std::vector<TBlob> &input,
-                                              const std::vector<TBlob> &output) {
+                                              const std::vector<TBlob> &output,
+                                              bool modeLike) {
   Tensor<cpu, 4, DType> gradOutput = input[0].get<cpu, 4, DType>(s);
   Tensor<cpu, 4, DType> gradInput = output[0].get<cpu, 4, DType>(s);
 
@@ -108,8 +109,8 @@ void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
   int inputHeight = gradInput.size(2);
   int inputWidth = gradInput.size(3);
 
-  DType *data1 = gradInput.dptr_;
-  DType *data2 = gradOutput.dptr_;
+  DType *dataInput = gradInput.dptr_;
+  DType *dataOutput = gradOutput.dptr_;
 
   channels = nbatch * channels;
 
   // special case: same-size matching grids
@@ -118,8 +119,8 @@ void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
       const int h1 = h2;
       for (int w2 = 0; w2 < outputWidth; ++w2) {
         const int w1 = w2;
-        DType* pos1 = &data1[h1 * inputWidth + w1];
-        const DType* pos2 = &data2[h2 * outputWidth + w2];
+        DType* pos1 = &dataInput[h1 * inputWidth + w1];
+        const DType* pos2 = &dataOutput[h2 * outputWidth + w2];
         for (int c = 0; c < channels; ++c) {
           pos1[0] += pos2[0];
           pos1 += inputWidth * inputHeight;
@@ -145,15 +146,33 @@ void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
       const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
       const DType w1lambda = w1r - w1;
       const DType w0lambda = (DType)1. - w1lambda;
-      DType* pos1 = &data1[h1 * inputWidth + w1];
-      const DType* pos2 = &data2[h2 * outputWidth + w2];
+      DType* posInput = &dataInput[h1 * inputWidth + w1];
+      const DType* posOutput = &dataOutput[h2 * outputWidth + w2];
       for (int c = 0; c < channels; ++c) {
-        pos1[0] += h0lambda * w0lambda * pos2[0];
-        pos1[w1p] += h0lambda * w1lambda * pos2[0];
-        pos1[h1p * inputWidth] += h1lambda * w0lambda * pos2[0];
-        pos1[h1p * inputWidth + w1p] += h1lambda * w1lambda * pos2[0];
-        pos1 += inputWidth * inputHeight;
-        pos2 += outputWidth * outputHeight;
+        posInput[0] += h0lambda * w0lambda * posOutput[0];
+        posInput[w1p] += h0lambda * w1lambda * posOutput[0];
+        posInput[h1p * inputWidth] += h1lambda * w0lambda * posOutput[0];
+        posInput[h1p * inputWidth + w1p] += h1lambda * w1lambda * posOutput[0];
+        posInput += inputWidth * inputHeight;
+        posOutput += outputWidth * outputHeight;
+      }
+    }
+  }
+
+  if (modeLike) {
+    // The "like" input only supplies the target shape, so its gradient is identically zero.
+    Tensor<cpu, 4, DType> gradInputLike = output[1].get<cpu, 4, DType>(s);
+    int inputHeightLike = gradInputLike.size(2);
+    int inputWidthLike = gradInputLike.size(3);
+    DType *dataInputLike = gradInputLike.dptr_;
+    int channelsLike = nbatch * gradInputLike.size(1);
+    for (int h_like = 0; h_like < inputHeightLike; ++h_like) {
+      for (int w_like = 0; w_like < inputWidthLike; ++w_like) {
+        DType *posInput = &dataInputLike[h_like * inputWidthLike + w_like];
+        for (int c = 0; c < channelsLike; ++c) {
+          posInput[0] = 0;
+          posInput += inputWidthLike * inputHeightLike;
+        }
       }
     }
   }
@@ -174,19 +192,21 @@ first in one direction, and then again in the other direction. See the
 wikipedia for more details.
 )code" ADD_FILELINE)
 .set_attr_parser(ParamParser<BilinearSampleParam>)
-.set_num_inputs(1)
+.set_num_inputs(BilinearSampleOpNumInputs)
 .set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames", BilinearSampleOpInputNames)
 .set_attr<mxnet::FInferShape>("FInferShape", BilinearSampleOpInferShape)
 .set_attr<FCompute>("FCompute", BilinearSampleOpForward<cpu>)
 .set_attr<nnvm::FGradient>("FGradient",
   ElemwiseGradUseNone{"_backward_contrib_BilinearResize2D"})
 .add_argument("data", "NDArray-or-Symbol", "Input data")
+.add_argument("like", "NDArray-or-Symbol", "Resize data to its shape")
 .add_arguments(BilinearSampleParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_contrib_BilinearResize2D)
 .set_attr_parser(ParamParser<BilinearSampleParam>)
-.set_num_inputs(1)
-.set_num_outputs(1)
+.set_num_inputs(BilinearSampleOpNumBackwardInputs)
+.set_num_outputs(BilinearSampleOpNumBackwardOutputs)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FCompute>("FCompute", BilinearSampleOpBackward<cpu>);
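As the CPU backward pass above makes explicit, in "like" mode the second input only dictates the output shape and its gradient is written as all zeros. A quick autograd check, sketched against a build containing this patch:

    import mxnet as mx
    from mxnet import autograd

    x = mx.nd.random.uniform(shape=(1, 1, 4, 4))
    like = mx.nd.random.uniform(shape=(1, 1, 8, 8))
    x.attach_grad()
    like.attach_grad()

    with autograd.record():
        y = mx.nd.contrib.BilinearResize2D(x, like, mode='like')
    y.backward()

    print(y.shape)                      # (1, 1, 8, 8), taken from `like`
    print(like.grad.sum().asscalar())   # 0.0 -- `like` contributes no gradient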
diff --git a/src/operator/contrib/bilinear_resize.cu b/src/operator/contrib/bilinear_resize.cu
index b0a4c4b316d9..0753c47a4bd7 100644
--- a/src/operator/contrib/bilinear_resize.cu
+++ b/src/operator/contrib/bilinear_resize.cu
@@ -32,6 +32,28 @@ namespace op {
 
 using namespace mshadow;
 
+template<typename xpu, typename Dtype, typename Acctype>
+__global__ void like_mode_kernel_backward(const int n,
+                                          Tensor<xpu, 4, Dtype> dataLike) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  const int batchsize = dataLike.size(0);
+  const int channels = dataLike.size(1);
+  const int height = dataLike.size(2);
+  const int width = dataLike.size(3);
+  // One thread per spatial position: zero that position across all batches
+  // and channels, because the "like" input receives no gradient.
+  if (index < n) {
+    const int w = index % width;
+    const int h = index / width;
+    for (int b = 0; b < batchsize; ++b) {
+      for (int c = 0; c < channels; ++c) {
+        dataLike[b][c][h][w] = 0;
+      }
+    }
+    return;
+  }
+}
+
 // Backward (adjoint) operation 1 <- 2 (accumulates)
 template<typename Dtype, typename Acctype>
 __global__ void caffe_gpu_interp2_kernel_backward(const int n,
@@ -118,7 +138,8 @@ void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<gpu> *s,
 template<typename xpu, typename DType, typename AccReal>
 void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<gpu> *s,
                                               const std::vector<TBlob> &input,
-                                              const std::vector<TBlob> &output) {
+                                              const std::vector<TBlob> &output,
+                                              bool modeLike) {
   Tensor<xpu, 4, DType> data1 = output[0].get<xpu, 4, DType>(s);
   Tensor<xpu, 4, DType> data2 = input[0].get<xpu, 4, DType>(s);
   int height1 = data1.size(2);
@@ -135,6 +156,21 @@ void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<gpu> *s,
   caffe_gpu_interp2_kernel_backward<DType, AccReal>
   <<<blocks, threads, 0, stream>>>(
     num_kernels, rheight, rwidth, data1, data2);
+
+  if (modeLike) {
+    // Zero the gradient of the "like" input; it only provides the target shape.
+    Tensor<xpu, 4, DType> dataLike = output[1].get<xpu, 4, DType>(s);
+    int heightLike = dataLike.size(2);
+    int widthLike = dataLike.size(3);
+    const int num_kernels_like = heightLike * widthLike;
+    const int num_threads_like = getNumThreads(num_kernels_like, false);
+    dim3 blocksLike(static_cast<int>(num_kernels_like / num_threads_like) + 1);
+    dim3 threadsLike(num_threads_like);
+    like_mode_kernel_backward<xpu, DType, AccReal>
+      <<<blocksLike, threadsLike, 0, stream>>>(
+        num_kernels_like, dataLike);
+  }
+
   MSHADOW_CUDA_POST_KERNEL_CHECK(SpatialUpSamplingBilinearUpdateGradInput);
 }
diff --git a/tests/python/gpu/test_gluon_transforms.py b/tests/python/gpu/test_gluon_transforms.py
index 23b34d334888..599a02c7a4f4 100644
--- a/tests/python/gpu/test_gluon_transforms.py
+++ b/tests/python/gpu/test_gluon_transforms.py
@@ -96,14 +96,14 @@ def test_resize():
     data_in_3d = nd.random.uniform(0, 255, (300, 300, 3))
     out_nd_3d = transforms.Resize((100, 100))(data_in_3d)
     data_in_4d_nchw = nd.moveaxis(nd.expand_dims(data_in_3d, axis=0), 3, 1)
-    data_expected_3d = (nd.moveaxis(nd.contrib.BilinearResize2D(data_in_4d_nchw, 100, 100), 1, 3))[0]
+    data_expected_3d = (nd.moveaxis(nd.contrib.BilinearResize2D(data_in_4d_nchw, height=100, width=100), 1, 3))[0]
     assert_almost_equal(out_nd_3d.asnumpy(), data_expected_3d.asnumpy())
 
     # Test with normal case 4D input float type
     data_in_4d = nd.random.uniform(0, 255, (2, 300, 300, 3))
     out_nd_4d = transforms.Resize((100, 100))(data_in_4d)
     data_in_4d_nchw = nd.moveaxis(data_in_4d, 3, 1)
-    data_expected_4d = nd.moveaxis(nd.contrib.BilinearResize2D(data_in_4d_nchw, 100, 100), 1, 3)
+    data_expected_4d = nd.moveaxis(nd.contrib.BilinearResize2D(data_in_4d_nchw, height=100, width=100), 1, 3)
     assert_almost_equal(out_nd_4d.asnumpy(), data_expected_4d.asnumpy())
 
     # Test invalid interp
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index e8bfaba4736d..2406a1c2f761 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -7164,6 +7164,47 @@ def py_bilinear_resize(x, outputHeight, outputWidth):
                     h1lambda*((1-w1lambda)*x[b][c][h1+h1p][w1] + \
                     w1lambda*x[b][c][h1+h1p][w1+w1p])
         return y
+    def py_bilinear_resize_backward(x, incoming_grads, mode='size'):
+        # NumPy reference gradient: scatter each incoming gradient onto the
+        # four input pixels it was bilinearly interpolated from.
+        data1 = np.zeros_like(x)
+        data2 = incoming_grads
+        batchsize = data1.shape[0]
+        channels = data1.shape[1]
+        height1 = data1.shape[2]
+        width1 = data1.shape[3]
+        height2 = data2.shape[2]
+        width2 = data2.shape[3]
+        rheight = float(height1 - 1) / (height2 - 1) if (height2 > 1) else 0
+        rwidth = float(width1 - 1) / (width2 - 1) if (width2 > 1) else 0
+        # special case: just copy
+        if height1 == height2 and width1 == width2:
+            data1 += data2
+            if mode == 'like':
+                return data1, np.zeros_like(incoming_grads)
+            return [data1]
+        for h2 in range(0, height2):
+            for w2 in range(0, width2):
+                h1r = rheight * h2
+                h1 = int(h1r)
+                h1p = 1 if (h1 < height1 - 1) else 0
+                h1lambda = h1r - h1
+                h0lambda = 1 - h1lambda
+                #
+                w1r = rwidth * w2
+                w1 = int(w1r)
+                w1p = 1 if (w1 < width1 - 1) else 0
+                w1lambda = w1r - w1
+                w0lambda = 1 - w1lambda
+                #
+                for n in range(0, batchsize):
+                    for c in range(0, channels):
+                        d2val = data2[n][c][h2][w2]
+                        data1[n][c][h1][w1] += h0lambda * w0lambda * d2val
+                        data1[n][c][h1][w1 + w1p] += h0lambda * w1lambda * d2val
+                        data1[n][c][h1 + h1p][w1] += h1lambda * w0lambda * d2val
+                        data1[n][c][h1 + h1p][w1 + w1p] += h1lambda * w1lambda * d2val
+        if mode == 'like':
+            return data1, np.zeros_like(incoming_grads)
+        return [data1]
     def check_bilinear_resize_op(shape, height, width):
         x = mx.nd.random.uniform(shape=shape)
         y = mx.nd.contrib.BilinearResize2D(x, height=height, width=width)
@@ -7173,12 +7216,89 @@ def check_bilinear_resize_op(shape, height, width):
         y_scale = height / shape[-2]
         y = mx.nd.contrib.BilinearResize2D(x, scale_height=y_scale, scale_width=x_scale)
         assert_almost_equal(y.asnumpy(), py_bilinear_resize(x.asnumpy(), height, width))
+    def check_bilinear_resize_modes_op(shape, scale_height=None, scale_width=None, shape_1=None, mode=None):
+        x = mx.nd.random.uniform(shape=shape)
+        original_h = shape[2]
+        original_w = shape[3]
+        if mode == 'odd_scale':
+            assert scale_height is not None and scale_width is not None
+            new_h = int(original_h * scale_height) if (original_h % 2) == 0 else \
+                int((original_h - 1) * scale_height) + 1
+            new_w = int(original_w * scale_width) if (original_w % 2) == 0 \
+                else int((original_w - 1) * scale_width) + 1
+            y = mx.nd.contrib.BilinearResize2D(x, scale_height=scale_height,
+                                               scale_width=scale_width,
+                                               mode='odd_scale')
+        elif mode == 'to_even_down':
+            new_h = original_h if (original_h % 2) == 0 else original_h - 1
+            new_w = original_w if (original_w % 2) == 0 else original_w - 1
+            y = mx.nd.contrib.BilinearResize2D(x, mode='to_even_down')
+        elif mode == 'to_even_up':
+            new_h = original_h if (original_h % 2) == 0 else original_h + 1
+            new_w = original_w if (original_w % 2) == 0 else original_w + 1
+            y = mx.nd.contrib.BilinearResize2D(x, mode='to_even_up')
+        elif mode == 'to_odd_down':
+            new_h = original_h if (original_h % 2) == 1 else original_h - 1
+            new_w = original_w if (original_w % 2) == 1 else original_w - 1
+            y = mx.nd.contrib.BilinearResize2D(x, mode='to_odd_down')
+        elif mode == 'to_odd_up':
+            new_h = original_h if (original_h % 2) == 1 else original_h + 1
+            new_w = original_w if (original_w % 2) == 1 else original_w + 1
+            y = mx.nd.contrib.BilinearResize2D(x, mode='to_odd_up')
+        elif mode == 'like':
+            x_1 = mx.nd.random.uniform(shape=shape_1)
+            new_h = x_1.shape[2]
+            new_w = x_1.shape[3]
+            y = mx.nd.contrib.BilinearResize2D(x, x_1, mode='like')
+        new_shape_desired = np.array([shape[0], shape[1], new_h, new_w], dtype='int')
+        new_shape_got = np.array(y.shape, dtype='int')
+        data_sym = mx.sym.var('data')
+        data_np = x.asnumpy()
+        expected = py_bilinear_resize(data_np, new_h, new_w)
+        out_grads = np.ones([shape[0], shape[1], new_h, new_w])
+        expected_backward = py_bilinear_resize_backward(data_np, out_grads, mode)
+        assert_array_equal(new_shape_desired, new_shape_got, "Desired and got shapes are not equal. {} vs {}".format(
+            str(new_shape_desired.tolist()), str(new_shape_got.tolist())))
+        assert_almost_equal(y.asnumpy(), expected, 1e-3, 0)
+        if mode != 'like':
+            resize_sym = mx.sym.contrib.BilinearResize2D(data_sym, None, scale_height=scale_height, scale_width=scale_width, mode=mode)
+            check_symbolic_forward(resize_sym, [data_np], [expected], rtol=1e-3)
+            check_symbolic_backward(resize_sym, [data_np], [out_grads], expected_backward, rtol=1e-3)
+            check_numeric_gradient(resize_sym, [data_np])
+        else:
+            data_sym_like = mx.sym.var('data_like')
+            resize_sym = mx.sym.contrib.BilinearResize2D(data_sym, data_sym_like, mode=mode)
+            data_np_like = x_1.asnumpy()
+            check_symbolic_forward(resize_sym, [data_np, data_np_like], [expected], rtol=1e-3)
+            check_symbolic_backward(resize_sym, [data_np, data_np_like], [out_grads], expected_backward, rtol=1e-3)
+            check_numeric_gradient(resize_sym, [data_np, data_np_like])
{} vs {}".format( + str(new_shape_desired.tolist()), str(new_shape_got.tolist()))) + assert_almost_equal(y.asnumpy(), expected, 1e-3, 0) + if mode != 'like': + resize_sym = mx.sym.contrib.BilinearResize2D(data_sym, None, scale_height=scale_height, scale_width=scale_width, mode=mode) + check_symbolic_forward(resize_sym, [data_np], [expected], rtol=1e-3) + check_symbolic_backward(resize_sym, [data_np], [out_grads], expected_backward, rtol=1e-3) + check_numeric_gradient(resize_sym, [data_np]) + else: + data_sym_like = mx.sym.var('data_like') + resize_sym = mx.sym.contrib.BilinearResize2D(data_sym, data_sym_like, mode=mode) + date_np_like = x_1.asnumpy() + check_symbolic_forward(resize_sym, [data_np, date_np_like], [expected], rtol=1e-3) + check_symbolic_backward(resize_sym, [data_np, date_np_like], [out_grads], expected_backward, rtol=1e-3) + check_numeric_gradient(resize_sym, [data_np, date_np_like]) + shape = (2, 2, 10, 10) check_bilinear_resize_op(shape, 5, 5) check_bilinear_resize_op(shape, 10, 10) check_bilinear_resize_op(shape, 15, 15) check_bilinear_resize_op(shape, 3, 7) check_bilinear_resize_op(shape, 13, 17) + shape = (2, 2, 20, 20) + check_bilinear_resize_modes_op(shape, scale_height=0.5, scale_width=0.5, mode='odd_scale') + check_bilinear_resize_modes_op(shape, scale_height=5, scale_width=10, mode='odd_scale') + check_bilinear_resize_modes_op(shape, scale_height=0.1, scale_width=0.2, mode='odd_scale') + check_bilinear_resize_modes_op(shape, mode='to_even_down') + check_bilinear_resize_modes_op(shape, mode='to_even_up') + check_bilinear_resize_modes_op(shape, mode='to_odd_down') + check_bilinear_resize_modes_op(shape, mode='to_odd_up') + shape = (2, 2, 21, 21) + check_bilinear_resize_modes_op(shape, scale_height=0.5, scale_width=0.5, mode='odd_scale') + check_bilinear_resize_modes_op(shape, scale_height=5, scale_width=10, mode='odd_scale') + check_bilinear_resize_modes_op(shape, scale_height=0.1, scale_width=0.2, mode='odd_scale') + check_bilinear_resize_modes_op(shape, mode='to_even_down') + check_bilinear_resize_modes_op(shape, mode='to_even_up') + check_bilinear_resize_modes_op(shape, mode='to_odd_down') + check_bilinear_resize_modes_op(shape, mode='to_odd_up') + shape_0 = (2, 2, 21, 21) + shape_1 = (2, 2, 10, 10) + check_bilinear_resize_modes_op(shape_0, shape_1=shape_1, mode='like') + check_bilinear_resize_modes_op(shape_1, shape_1=shape_0, mode='like') def test_multi_proposal_op(): # paramters