From 47d27afad175986bc45c18f0ac38fbd726d06623 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Tue, 26 Feb 2019 15:14:49 -0800 Subject: [PATCH 01/13] implement crop --- python/mxnet/gluon/data/vision/transforms.py | 48 +++++++++++++++++++ python/mxnet/image/image.py | 2 +- .../python/unittest/test_gluon_data_vision.py | 37 ++++++++++++++ 3 files changed, 86 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index 9310e15f5133..58f24606ff93 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -228,6 +228,54 @@ def forward(self, x): return image.random_size_crop(x, *self._args)[0] +class Crop(HybridBlock): + """Crop the input image with and optionally resize it. + Makes a crop of the original image then optionally resize it to the specified size. + Parameters + ---------- + x0 : int + Left boundary of the cropping area + y0 : int + Top boundary of the cropping area + w : int + Width of the cropping area + h : int + Height of the cropping area + size : int or tuple of (w, h) + Optional, resize to new size after cropping + interp : int, optional + Optional, interpolation method. See opencv for details. + Inputs: + - **data**: input tensor with (H x W x C) or (N x H x W x C) shape. + Outputs: + - **out**: output tensor with (H x W x C) or (N x H x W x C) shape. + Examples + -------- + >>> transformer = vision.transforms.Crop(0, 0, 100, 100) + >>> image = mx.nd.random.uniform(0, 255, (224, 224, 3)).astype(dtype=np.uint8) + >>> transformer(image) + + >>> image = mx.nd.random.uniform(0, 255, (3, 224, 224, 3)).astype(dtype=np.uint8) + + >>> transformer = vision.transforms.Crop(0, 0, 100, 100, (50, 50), 1) + >>> transformer(image) + + """ + def __init__(self, x0, y0, width, height, size=None, interpolation=None): + super(Crop, self).__init__() + self._x0 = x0 + self._y0 = y0 + self._width = width + self._height = height + self._size = size + self._interpolation = interpolation + + def hybrid_forward(self, F, x): + out = F.image.crop(x, self._x0, self._y0, self._width, self._height) + if self._size is not None: + out = F.image.resize(out, self._size, False, self._interpolation) + return out + class CenterCrop(Block): """Crops the image `src` to the given `size` by trimming on all four sides and preserving the center of the image. Upsamples if `src` is diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index 1dd665607597..d2631e810529 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -428,7 +428,7 @@ def fixed_crop(src, x0, y0, w, h, size=None, interp=2): NDArray An `NDArray` containing the cropped image. 
""" - out = nd.crop(src, begin=(y0, x0, 0), end=(y0 + h, x0 + w, int(src.shape[2]))) + out = nd.slice(src, begin=(y0, x0, 0), end=(y0 + h, x0 + w, int(src.shape[2]))) if size is not None and (w, h) != size: sizes = (h, w, size[1], size[0]) out = imresize(out, *size, interp=_get_interp_method(interp, sizes)) diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py index a855fc8cf1df..b621b734184a 100644 --- a/tests/python/unittest/test_gluon_data_vision.py +++ b/tests/python/unittest/test_gluon_data_vision.py @@ -21,6 +21,7 @@ from mxnet import gluon from mxnet.base import MXNetError from mxnet.gluon.data.vision import transforms +from mxnet import image from mxnet.test_utils import assert_almost_equal from mxnet.test_utils import almost_equal from common import assertRaises, setup_module, with_seed, teardown @@ -118,6 +119,42 @@ def _test_resize_with_diff_type(dtype): _test_resize_with_diff_type(dtype) +@with_seed() +def test_crop(): + def _test_crop_with_diff_type(dtype): + # test normal case + data_in = nd.arange(60).reshape((5, 4, 3)).astype('uint8') + out_nd = transforms.Crop(0, 0, 3, 2)(data_in) + out_np = out_nd.asnumpy() + assert(out_np.sum() == 180) + assert((out_np[0:2,1,1].flatten() == [4, 16]).all()) + # test 4D input + data_bath_in = nd.arange(180).reshape((2, 6, 5, 3)).astype('uint8') + out_batch_nd = transforms.Crop(1, 2, 3, 4)(data_bath_in) + out_batch_np = out_batch_nd.asnumpy() + assert(out_batch_np.sum() == 7524) + assert((out_batch_np[0:2,0:4,1,1].flatten() == [37, 52, 67, 82, 127, 142, 157, 172]).all()) + # test normal case with resize + data_in = nd.random.uniform(0, 255, (300, 200, 3)).astype('uint8') + out_nd = transforms.Crop(0, 0, 100, 50, (25, 25), 2)(data_in) + data_expected = image.imresize(nd.slice(data_in, (0, 0, 0), (50, 100 , 3)), 25, 25, 2) + assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy()) + # test 4D input with resize + data_bath_in = nd.random.uniform(0, 255, (3, 300, 200, 3)).astype('uint8') + out_batch_nd = transforms.Crop(0, 0, 100, 50, (25, 25), 2)(data_bath_in) + for i in range(len(out_batch_nd)): + assert_almost_equal(image.imresize(nd.slice(data_bath_in[i], (0, 0, 0), (50, 100, 3)), 25, 25, 2).asnumpy(), + out_batch_nd[i].asnumpy()) + # test with resize height and width should be greater than 0 + transformer = transforms.Crop(0, 0, 100, 50, (-25, 25), 2) + assertRaises(MXNetError, transformer, data_in) + # test height and width should be greater than 0 + transformer = transforms.Crop(0, 0, -100, -50) + assertRaises(MXNetError, transformer, data_in) + + for dtype in ['uint8', 'float32', 'float64']: + _test_crop_with_diff_type(dtype) + @with_seed() def test_flip_left_right(): data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8) From 8c1254b159d8c176c778102c574004c1f68bed00 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Tue, 26 Feb 2019 15:21:05 -0800 Subject: [PATCH 02/13] add crop operator --- src/operator/image/crop-inl.h | 139 ++++++++++++++++++++++++++++++++++ src/operator/image/crop.cc | 78 +++++++++++++++++++ 2 files changed, 217 insertions(+) create mode 100644 src/operator/image/crop-inl.h create mode 100644 src/operator/image/crop.cc diff --git a/src/operator/image/crop-inl.h b/src/operator/image/crop-inl.h new file mode 100644 index 000000000000..69a4f525d342 --- /dev/null +++ b/src/operator/image/crop-inl.h @@ -0,0 +1,139 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. 
See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ + +/*! + * Copyright (c) 2019 by Contributors + * \file crop-inl.h + * \brief the image crop operator implementation + */ + +#ifndef MXNET_OPERATOR_IMAGE_CROP_INL_H_ +#define MXNET_OPERATOR_IMAGE_CROP_INL_H_ + + +#include +#include + +#include "mxnet/base.h" +#include "dmlc/optional.h" +#include "image_utils.h" +#include "../mxnet_op.h" +#include "../operator_common.h" +#include "../../common/static_array.h" +#include "../tensor/matrix_op-inl.h" +#include "resize-inl.h" + +namespace mxnet { +namespace op { +namespace image { + +struct CropParam : public dmlc::Parameter { + int x; + int y; + int width; + int height; + DMLC_DECLARE_PARAMETER(CropParam) { + DMLC_DECLARE_FIELD(x) + .describe("Left boundary of the cropping area."); + DMLC_DECLARE_FIELD(y) + .describe("Top boundary of the cropping area."); + DMLC_DECLARE_FIELD(width) + .describe("Width of the cropping area."); + DMLC_DECLARE_FIELD(height) + .describe("Top boundary of the cropping area"); + } +}; + +inline bool CropShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + // input attrs should only be (h, w, c) or (n, h, w, c) + CHECK((in_attrs->at(0).ndim() == 3U) || (in_attrs->at(0).ndim() == 4U)) + << "Input image dimension should be 3 or 4 but got " + << in_attrs->at(0).ndim(); + + const auto& ishape = (*in_attrs)[0]; + const CropParam& param = nnvm::get(attrs.parsed); + + CHECK((param.height > 0) && (param.width > 0)) + << "Input height and width must be greater than 0"; + if (ishape.ndim() == 3) { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({param.height, param.width, ishape[C]})); + } else { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({ishape[N], param.height, param.width, ishape[kC]})); + } + return true; +} + +inline void CropImpl(int x, + int y, + int width, + int height, + const std::vector &inputs, + const std::vector &outputs, + const OpContext &ctx, + const std::vector &req) { + using namespace mshadow; + // invalid param + const TBlob& data = inputs[0]; + const TBlob& out = outputs[0]; + MXNET_NDIM_SWITCH(data.ndim(), ndim, { + CHECK(x + width <= data.shape_[ndim - 2]) + << " x + width should not be greater than input width"; + CHECK(y + height <= data.shape_[ndim - 3]) + << " y + height should not be greater than input height"; + Stream* s = ctx.get_stream(); + common::StaticArray begin = {0}, step = {1}; + if (ndim == 3) { + begin[0] = y; + begin[1] = x; + } else { + begin[1] = y; + begin[2] = x; + } + MSHADOW_TYPE_SWITCH(out.type_flag_, DType, { + Tensor input_tensor = data.get(s); + Tensor output_tensor = out.get(s); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + size_t num_threads = out.shape_.FlatTo2D()[0]; + mxnet_op::Kernel, cpu>::Launch(s, num_threads, + output_tensor.dptr_, input_tensor.dptr_, + input_tensor.shape_, output_tensor.shape_, begin, step); + }) + }) + }) +} + +inline void Crop(const 
nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_EQ(outputs.size(), 1U); + const CropParam& param = nnvm::get(attrs.parsed); + CHECK((param.height > 0) && (param.width > 0)) + << "Input height and width must be greater than 0"; + + CropImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req); +} +} // namespace image +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_IMAGE_CROP_INL_H_ diff --git a/src/operator/image/crop.cc b/src/operator/image/crop.cc new file mode 100644 index 000000000000..4e7ef4359328 --- /dev/null +++ b/src/operator/image/crop.cc @@ -0,0 +1,78 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ + +/*! + * Copyright (c) 2019 by Contributors + * \file crop-cc.h + * \brief the image crop operator registration + */ + +#include "mxnet/base.h" +#include "crop-inl.h" +#include "../operator_common.h" +#include "../elemwise_op_common.h" + +namespace mxnet { +namespace op { +namespace image { + +DMLC_REGISTER_PARAMETER(CropParam); + +NNVM_REGISTER_OP(_image_crop) +.describe(R"code(Crop an image NDArray of shape (H x W x C) or (N x H x W x C) +to the given size. +Example: + .. 
code-block:: python + image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8) + mx.nd.image.crop(image, 1, 1, 2, 2) + [[[144 34 4] + [ 82 157 38]] + + [[156 111 230] + [177 25 15]]] + + image = mx.nd.random.uniform(0, 255, (2, 4, 2, 3)).astype(dtype=np.uint8) + mx.nd.image.crop(image, 1, 1, 2, 2) + [[[[ 35 198 50] + [242 94 168]] + + [[223 119 129] + [249 14 154]]] + + + [[[137 215 106] + [ 79 174 133]] + + [[116 142 109] + [ 35 239 50]]]] + +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", CropShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FCompute", Crop) +.set_attr("FGradient", ElemwiseGradUseNone{ "_copy" }) +.add_argument("data", "NDArray-or-Symbol", "The input.") +.add_arguments(CropParam::__FIELDS__()); + +} // namespace image +} // namespace op +} // namespace mxnet From 460694aca674206768746506d8f33cb622501196 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Tue, 26 Feb 2019 15:38:22 -0800 Subject: [PATCH 03/13] fix for linter --- python/mxnet/gluon/data/vision/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index 58f24606ff93..598b3eb8db92 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -274,7 +274,7 @@ def hybrid_forward(self, F, x): out = F.image.crop(x, self._x0, self._y0, self._width, self._height) if self._size is not None: out = F.image.resize(out, self._size, False, self._interpolation) - return out + return out class CenterCrop(Block): """Crops the image `src` to the given `size` by trimming on all four From 9ceda613dcee78dda55412fec9e9b72d8b959a83 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Thu, 14 Mar 2019 14:03:16 -0700 Subject: [PATCH 04/13] add. backword and refactor the code --- python/mxnet/gluon/data/vision/transforms.py | 18 ++--- src/operator/image/crop-inl.h | 69 +++++++++++++++---- src/operator/image/crop.cc | 9 ++- .../python/unittest/test_gluon_data_vision.py | 22 +++--- 4 files changed, 85 insertions(+), 33 deletions(-) diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index 598b3eb8db92..cad601c956d8 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -228,14 +228,14 @@ def forward(self, x): return image.random_size_crop(x, *self._args)[0] -class Crop(HybridBlock): +class CropResize(HybridBlock): """Crop the input image with and optionally resize it. Makes a crop of the original image then optionally resize it to the specified size. 
Parameters ---------- - x0 : int + x : int Left boundary of the cropping area - y0 : int + y : int Top boundary of the cropping area w : int Width of the cropping area @@ -261,18 +261,18 @@ class Crop(HybridBlock): >>> transformer(image) """ - def __init__(self, x0, y0, width, height, size=None, interpolation=None): - super(Crop, self).__init__() - self._x0 = x0 - self._y0 = y0 + def __init__(self, x, y, width, height, size=None, interpolation=None): + super(CropResize, self).__init__() + self._x = x + self._y = y self._width = width self._height = height self._size = size self._interpolation = interpolation def hybrid_forward(self, F, x): - out = F.image.crop(x, self._x0, self._y0, self._width, self._height) - if self._size is not None: + out = F.image.crop(x, self._x, self._y, self._width, self._height) + if self._size: out = F.image.resize(out, self._size, False, self._interpolation) return out diff --git a/src/operator/image/crop-inl.h b/src/operator/image/crop-inl.h index 69a4f525d342..d8072f42b681 100644 --- a/src/operator/image/crop-inl.h +++ b/src/operator/image/crop-inl.h @@ -56,7 +56,7 @@ struct CropParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(width) .describe("Width of the cropping area."); DMLC_DECLARE_FIELD(height) - .describe("Top boundary of the cropping area"); + .describe("Height of the cropping area."); } }; @@ -73,6 +73,10 @@ inline bool CropShape(const nnvm::NodeAttrs& attrs, CHECK((param.height > 0) && (param.width > 0)) << "Input height and width must be greater than 0"; + CHECK(param.x + param.width <= ishape[ishape.ndim() - 2]) + << " x + width should not be greater than input width"; + CHECK(param.y + param.height <= ishape[ishape.ndim() - 3]) + << " y + height should not be greater than input height"; if (ishape.ndim() == 3) { SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({param.height, param.width, ishape[C]})); } else { @@ -94,10 +98,6 @@ inline void CropImpl(int x, const TBlob& data = inputs[0]; const TBlob& out = outputs[0]; MXNET_NDIM_SWITCH(data.ndim(), ndim, { - CHECK(x + width <= data.shape_[ndim - 2]) - << " x + width should not be greater than input width"; - CHECK(y + height <= data.shape_[ndim - 3]) - << " y + height should not be greater than input height"; Stream* s = ctx.get_stream(); common::StaticArray begin = {0}, step = {1}; if (ndim == 3) { @@ -108,30 +108,73 @@ inline void CropImpl(int x, begin[2] = x; } MSHADOW_TYPE_SWITCH(out.type_flag_, DType, { - Tensor input_tensor = data.get(s); - Tensor output_tensor = out.get(s); MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { size_t num_threads = out.shape_.FlatTo2D()[0]; mxnet_op::Kernel, cpu>::Launch(s, num_threads, - output_tensor.dptr_, input_tensor.dptr_, - input_tensor.shape_, output_tensor.shape_, begin, step); + out.dptr(), data.dptr(), + data.shape_.get(), out.shape_.get(), begin, step); + }) + }) + }) +} + +inline void CropBackwardImpl(int x, + int y, + int width, + int height, + const std::vector &inputs, + const std::vector &outputs, + const OpContext &ctx, + const std::vector &req) { + using namespace mshadow; + if (req[0] == kNullOp) return; + const TBlob& output_grad = inputs[0]; + const TBlob& input_grad = outputs[0]; + Stream* s = ctx.get_stream(); + if (req[0] == kWriteTo) { + Fill(s, input_grad, req[0], 0); + } else if (req[0] == kWriteInplace) { + LOG(FATAL) << "_backward_image_crop does not support kWriteInplace"; + } + MXNET_NDIM_SWITCH(output_grad.ndim(), ndim, { + common::StaticArray begin = {0}, step = {1}; + if (ndim == 3) { + begin[0] = y; + begin[1] = x; + } else { + begin[1] = y; 
+ begin[2] = x; + } + MSHADOW_TYPE_SWITCH(input_grad.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + size_t num_threads = output_grad.shape_.FlatTo2D()[0]; + mxnet_op::Kernel, cpu>::Launch(s, num_threads, + input_grad.dptr(), output_grad.dptr(), + output_grad.shape_.get(), input_grad.shape_.get(), begin, step); }) }) }) } -inline void Crop(const nnvm::NodeAttrs &attrs, +inline void CropOpForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { CHECK_EQ(outputs.size(), 1U); const CropParam& param = nnvm::get(attrs.parsed); - CHECK((param.height > 0) && (param.width > 0)) - << "Input height and width must be greater than 0"; - CropImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req); } + +inline void CropOpBackward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_EQ(outputs.size(), 1U); + const CropParam& param = nnvm::get(attrs.parsed); + CropBackwardImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req); +} } // namespace image } // namespace op } // namespace mxnet diff --git a/src/operator/image/crop.cc b/src/operator/image/crop.cc index 4e7ef4359328..7b9d857cb668 100644 --- a/src/operator/image/crop.cc +++ b/src/operator/image/crop.cc @@ -68,11 +68,16 @@ to the given size. .set_attr_parser(ParamParser) .set_attr("FInferShape", CropShape) .set_attr("FInferType", ElemwiseType<1, 1>) -.set_attr("FCompute", Crop) -.set_attr("FGradient", ElemwiseGradUseNone{ "_copy" }) +.set_attr("FCompute", CropOpForward) +.set_attr("FGradient", ElemwiseGradUseNone{ "_backward_image_crop" }) .add_argument("data", "NDArray-or-Symbol", "The input.") .add_arguments(CropParam::__FIELDS__()); +NNVM_REGISTER_OP(_backward_image_crop) +.set_attr_parser(ParamParser) +.set_attr("TIsBackward", true) +.set_attr("FCompute", CropOpBackward); + } // namespace image } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py index b621b734184a..ee15057d4e67 100644 --- a/tests/python/unittest/test_gluon_data_vision.py +++ b/tests/python/unittest/test_gluon_data_vision.py @@ -120,40 +120,44 @@ def _test_resize_with_diff_type(dtype): @with_seed() -def test_crop(): - def _test_crop_with_diff_type(dtype): +def test_crop_resize(): + def _test_crop_resize_with_diff_type(dtype): # test normal case data_in = nd.arange(60).reshape((5, 4, 3)).astype('uint8') - out_nd = transforms.Crop(0, 0, 3, 2)(data_in) + out_nd = transforms.CropResize(0, 0, 3, 2)(data_in) out_np = out_nd.asnumpy() assert(out_np.sum() == 180) assert((out_np[0:2,1,1].flatten() == [4, 16]).all()) # test 4D input data_bath_in = nd.arange(180).reshape((2, 6, 5, 3)).astype('uint8') - out_batch_nd = transforms.Crop(1, 2, 3, 4)(data_bath_in) + out_batch_nd = transforms.CropResize(1, 2, 3, 4)(data_bath_in) out_batch_np = out_batch_nd.asnumpy() assert(out_batch_np.sum() == 7524) assert((out_batch_np[0:2,0:4,1,1].flatten() == [37, 52, 67, 82, 127, 142, 157, 172]).all()) # test normal case with resize data_in = nd.random.uniform(0, 255, (300, 200, 3)).astype('uint8') - out_nd = transforms.Crop(0, 0, 100, 50, (25, 25), 2)(data_in) + out_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_in) data_expected = image.imresize(nd.slice(data_in, (0, 0, 0), (50, 100 , 3)), 25, 25, 2) assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy()) 
# test 4D input with resize data_bath_in = nd.random.uniform(0, 255, (3, 300, 200, 3)).astype('uint8') - out_batch_nd = transforms.Crop(0, 0, 100, 50, (25, 25), 2)(data_bath_in) + out_batch_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_bath_in) for i in range(len(out_batch_nd)): assert_almost_equal(image.imresize(nd.slice(data_bath_in[i], (0, 0, 0), (50, 100, 3)), 25, 25, 2).asnumpy(), out_batch_nd[i].asnumpy()) # test with resize height and width should be greater than 0 - transformer = transforms.Crop(0, 0, 100, 50, (-25, 25), 2) + transformer = transforms.CropResize(0, 0, 100, 50, (-25, 25), 2) assertRaises(MXNetError, transformer, data_in) # test height and width should be greater than 0 - transformer = transforms.Crop(0, 0, -100, -50) + transformer = transforms.CropResize(0, 0, -100, -50) assertRaises(MXNetError, transformer, data_in) + # test cropped area is bigger than input data + transformer = transforms.CropResize(150, 200, 200, 500) + assertRaises(MXNetError, transformer, data_in) + assertRaises(MXNetError, transformer, data_bath_in) for dtype in ['uint8', 'float32', 'float64']: - _test_crop_with_diff_type(dtype) + _test_crop_resize_with_diff_type(dtype) @with_seed() def test_flip_left_right(): From c015d4794f6eebc1e6576b2ec25ac1e268a62b7d Mon Sep 17 00:00:00 2001 From: stu1130 Date: Thu, 14 Mar 2019 14:41:26 -0700 Subject: [PATCH 05/13] fix error namespace --- src/operator/image/crop.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/operator/image/crop.cc b/src/operator/image/crop.cc index 7b9d857cb668..52d2f11a464b 100644 --- a/src/operator/image/crop.cc +++ b/src/operator/image/crop.cc @@ -66,7 +66,7 @@ to the given size. .set_num_inputs(1) .set_num_outputs(1) .set_attr_parser(ParamParser) -.set_attr("FInferShape", CropShape) +.set_attr("FInferShape", CropShape) .set_attr("FInferType", ElemwiseType<1, 1>) .set_attr("FCompute", CropOpForward) .set_attr("FGradient", ElemwiseGradUseNone{ "_backward_image_crop" }) @@ -75,6 +75,8 @@ to the given size. NNVM_REGISTER_OP(_backward_image_crop) .set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) .set_attr("TIsBackward", true) .set_attr("FCompute", CropOpBackward); From d0da9a87967bac644cda325067e8b9f9d55d083c Mon Sep 17 00:00:00 2001 From: stu1130 Date: Thu, 14 Mar 2019 18:22:00 -0700 Subject: [PATCH 06/13] fix the website build failure --- python/mxnet/gluon/data/vision/transforms.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index cad601c956d8..4ed59ce91ed8 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -229,7 +229,8 @@ def forward(self, x): class CropResize(HybridBlock): - """Crop the input image with and optionally resize it. + r"""Crop the input image with and optionally resize it. + Makes a crop of the original image then optionally resize it to the specified size. Parameters ---------- @@ -244,11 +245,15 @@ class CropResize(HybridBlock): size : int or tuple of (w, h) Optional, resize to new size after cropping interp : int, optional - Optional, interpolation method. See opencv for details. + Optional, interpolation method. See opencv for details + + Inputs: - **data**: input tensor with (H x W x C) or (N x H x W x C) shape. + Outputs: - - **out**: output tensor with (H x W x C) or (N x H x W x C) shape. + - **out**: input tensor with (H x W x C) or (N x H x W x C) shape. 
+ Examples -------- >>> transformer = vision.transforms.Crop(0, 0, 100, 100) From 6dc6bcc1a0378dcc5b14c87b4f33f8f6fe79099f Mon Sep 17 00:00:00 2001 From: stu1130 Date: Mon, 18 Mar 2019 13:05:21 -0700 Subject: [PATCH 07/13] start adding the unit test of backword --- tests/python/unittest/test_gluon_data_vision.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py index ee15057d4e67..4716a7dcaa34 100644 --- a/tests/python/unittest/test_gluon_data_vision.py +++ b/tests/python/unittest/test_gluon_data_vision.py @@ -123,24 +123,24 @@ def _test_resize_with_diff_type(dtype): def test_crop_resize(): def _test_crop_resize_with_diff_type(dtype): # test normal case - data_in = nd.arange(60).reshape((5, 4, 3)).astype('uint8') + data_in = nd.arange(60).reshape((5, 4, 3)).astype(dtype) out_nd = transforms.CropResize(0, 0, 3, 2)(data_in) out_np = out_nd.asnumpy() assert(out_np.sum() == 180) assert((out_np[0:2,1,1].flatten() == [4, 16]).all()) # test 4D input - data_bath_in = nd.arange(180).reshape((2, 6, 5, 3)).astype('uint8') + data_bath_in = nd.arange(180).reshape((2, 6, 5, 3)).astype(dtype) out_batch_nd = transforms.CropResize(1, 2, 3, 4)(data_bath_in) out_batch_np = out_batch_nd.asnumpy() assert(out_batch_np.sum() == 7524) assert((out_batch_np[0:2,0:4,1,1].flatten() == [37, 52, 67, 82, 127, 142, 157, 172]).all()) # test normal case with resize - data_in = nd.random.uniform(0, 255, (300, 200, 3)).astype('uint8') + data_in = nd.random.uniform(0, 255, (300, 200, 3)).astype(dtype) out_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_in) data_expected = image.imresize(nd.slice(data_in, (0, 0, 0), (50, 100 , 3)), 25, 25, 2) assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy()) # test 4D input with resize - data_bath_in = nd.random.uniform(0, 255, (3, 300, 200, 3)).astype('uint8') + data_bath_in = nd.random.uniform(0, 255, (3, 300, 200, 3)).astype(dtype) out_batch_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_bath_in) for i in range(len(out_batch_nd)): assert_almost_equal(image.imresize(nd.slice(data_bath_in[i], (0, 0, 0), (50, 100, 3)), 25, 25, 2).asnumpy(), @@ -157,7 +157,13 @@ def _test_crop_resize_with_diff_type(dtype): assertRaises(MXNetError, transformer, data_bath_in) for dtype in ['uint8', 'float32', 'float64']: - _test_crop_resize_with_diff_type(dtype) + _test_crop_resize_with_diff_type(dtype) + # test for gradient + data = mx.sym.Variable('data') + slice_sym = mx.sym.slice(data, begin=begin, end=end, step=step) + expected_in_grad = np.zeros_like(a_np) + expected_in_grad[index] = b_np + check_symbolic_backward(slice_sym, [a_np], [b_np], [expected_in_grad]) @with_seed() def test_flip_left_right(): From ac738c093075e2925fc2f0037e8656080807d691 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Mon, 1 Apr 2019 15:20:47 -0700 Subject: [PATCH 08/13] add unit test for backward --- src/operator/image/crop-inl.h | 4 +- .../python/unittest/test_gluon_data_vision.py | 40 +++++++++++++++---- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/src/operator/image/crop-inl.h b/src/operator/image/crop-inl.h index d8072f42b681..847863ba0d54 100644 --- a/src/operator/image/crop-inl.h +++ b/src/operator/image/crop-inl.h @@ -145,12 +145,12 @@ inline void CropBackwardImpl(int x, begin[1] = y; begin[2] = x; } - MSHADOW_TYPE_SWITCH(input_grad.type_flag_, DType, { + MSHADOW_TYPE_SWITCH(output_grad.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], 
Req, { size_t num_threads = output_grad.shape_.FlatTo2D()[0]; mxnet_op::Kernel, cpu>::Launch(s, num_threads, input_grad.dptr(), output_grad.dptr(), - output_grad.shape_.get(), input_grad.shape_.get(), begin, step); + input_grad.shape_.get(), output_grad.shape_.get(), begin, step); }) }) }) diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py index 4716a7dcaa34..cc15bec5dee9 100644 --- a/tests/python/unittest/test_gluon_data_vision.py +++ b/tests/python/unittest/test_gluon_data_vision.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. from __future__ import print_function +from collections import namedtuple + import mxnet as mx import mxnet.ndarray as nd from mxnet.base import MXNetError @@ -22,8 +24,7 @@ from mxnet.base import MXNetError from mxnet.gluon.data.vision import transforms from mxnet import image -from mxnet.test_utils import assert_almost_equal -from mxnet.test_utils import almost_equal +from mxnet.test_utils import * from common import assertRaises, setup_module, with_seed, teardown import numpy as np @@ -158,12 +159,35 @@ def _test_crop_resize_with_diff_type(dtype): for dtype in ['uint8', 'float32', 'float64']: _test_crop_resize_with_diff_type(dtype) - # test for gradient - data = mx.sym.Variable('data') - slice_sym = mx.sym.slice(data, begin=begin, end=end, step=step) - expected_in_grad = np.zeros_like(a_np) - expected_in_grad[index] = b_np - check_symbolic_backward(slice_sym, [a_np], [b_np], [expected_in_grad]) + + # test nd.image.crop backward + def test_crop_backward(test_nd_arr, TestCase): + a_np = test_nd_arr.asnumpy() + b_np = a_np[(slice(TestCase.y, TestCase.y + TestCase.height), slice(TestCase.x, TestCase.x + TestCase.width), slice(0, 3))] + + data = mx.sym.Variable('data') + crop_sym = mx.sym.image.crop(data, TestCase.x, TestCase.y, TestCase.width, TestCase.height) + + expected_in_grad = np.zeros_like(a_np) + expected_in_grad[(slice(TestCase.y, TestCase.y + TestCase.height), slice(TestCase.x, TestCase.x + TestCase.width), slice(0, 3))] = b_np + check_symbolic_backward(crop_sym, [a_np], [b_np], [expected_in_grad]) + + TestCase = namedtuple('TestCase', ['x', 'y', 'width', 'height']) + test_list = [TestCase(0, 0, 3, 3), TestCase(2, 1, 1, 2), TestCase(0, 1, 3, 2)] + + for dtype in ['uint8', 'float32', 'float64']: + data_in = nd.arange(60).reshape((5, 4, 3)).astype(dtype) + for test_case in test_list: + test_crop_backward(data_in, test_case) + + + + # check numeric gradient of nd.image.crop + # in_data = np.arange(36).reshape(3, 4, 3) + # data = mx.sym.Variable('data') + # image_crop_sym = mx.sym.image.crop(data, 0, 0, 2, 2) + # check_numeric_gradient(image_crop_sym, [in_data]) + @with_seed() def test_flip_left_right(): From 11dbda755d9331615307025d968ad2fe400846e9 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Wed, 3 Apr 2019 10:01:06 -0700 Subject: [PATCH 09/13] address the comment --- python/mxnet/gluon/data/vision/transforms.py | 18 ++++++++++------ src/operator/image/crop-inl.h | 22 +++++++++++++------- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index 4ed59ce91ed8..fe680181aceb 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -244,8 +244,14 @@ class CropResize(HybridBlock): Height of the cropping area size : int or tuple of (w, h) Optional, resize to new size after cropping - interp : int, 
optional - Optional, interpolation method. See opencv for details + interpolation : int, optional + Optional, interpolation method for resizing. By default uses bilinear + interpolation. See OpenCV's resize function for available choices. + https://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=resize#resize + Note that the Resize on gpu use contrib.bilinearResize2D operator + which only support bilinear interpolation(1). The result would be slightly + different on gpu compared to cpu. OpenCV tend to align center while bilinearResize2D + use algorithm which aligns corner. Inputs: @@ -256,13 +262,13 @@ class CropResize(HybridBlock): Examples -------- - >>> transformer = vision.transforms.Crop(0, 0, 100, 100) + >>> transformer = vision.transforms.CropResize(x=0, y=0, width=100, height=100) >>> image = mx.nd.random.uniform(0, 255, (224, 224, 3)).astype(dtype=np.uint8) >>> transformer(image) - + >>> image = mx.nd.random.uniform(0, 255, (3, 224, 224, 3)).astype(dtype=np.uint8) - - >>> transformer = vision.transforms.Crop(0, 0, 100, 100, (50, 50), 1) + + >>> transformer = vision.transforms.CropResize(x=0, y=0, width=100, height=100, size=(50, 50), interpolation=1) >>> transformer(image) """ diff --git a/src/operator/image/crop-inl.h b/src/operator/image/crop-inl.h index 847863ba0d54..c1638c0ea0aa 100644 --- a/src/operator/image/crop-inl.h +++ b/src/operator/image/crop-inl.h @@ -64,19 +64,28 @@ inline bool CropShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { // input attrs should only be (h, w, c) or (n, h, w, c) - CHECK((in_attrs->at(0).ndim() == 3U) || (in_attrs->at(0).ndim() == 4U)) - << "Input image dimension should be 3 or 4 but got " - << in_attrs->at(0).ndim(); + if (in_attrs->at(0).ndim() == 3U) { + CHECK((in_attrs->at(0)[2] == 1) || (in_attrs->at(0)[2] == 3)) + << "Expect channel of the input image is 1 or 3, but got" + << in_attrs->at(0)[2]; + } else if { + CHECK((in_attrs->at(0)[3] == 1) || (in_attrs->at(0)[3] == 3)) + << "Expect channel of the input image is 1 or 3, but got" + << in_attrs->at(0)[3]; + } else { + LOG(FATAL) << "Image Crop expects inputs of 3D (h, w, c) or 4D (n, h, w, c). 
But got " + << in_attrs->at(0).ndim(); + } const auto& ishape = (*in_attrs)[0]; const CropParam& param = nnvm::get(attrs.parsed); CHECK((param.height > 0) && (param.width > 0)) - << "Input height and width must be greater than 0"; + << "Input height and width must be greater than 0"; CHECK(param.x + param.width <= ishape[ishape.ndim() - 2]) - << " x + width should not be greater than input width"; + << " x + width should not be greater than input width"; CHECK(param.y + param.height <= ishape[ishape.ndim() - 3]) - << " y + height should not be greater than input height"; + << " y + height should not be greater than input height"; if (ishape.ndim() == 3) { SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({param.height, param.width, ishape[C]})); } else { @@ -94,7 +103,6 @@ inline void CropImpl(int x, const OpContext &ctx, const std::vector &req) { using namespace mshadow; - // invalid param const TBlob& data = inputs[0]; const TBlob& out = outputs[0]; MXNET_NDIM_SWITCH(data.ndim(), ndim, { From e7388693ec5e7368fb5dcdeb6819d95321f8e484 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Wed, 3 Apr 2019 10:06:36 -0700 Subject: [PATCH 10/13] add missing statement --- src/operator/image/crop-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/image/crop-inl.h b/src/operator/image/crop-inl.h index c1638c0ea0aa..a1a4b23f658e 100644 --- a/src/operator/image/crop-inl.h +++ b/src/operator/image/crop-inl.h @@ -68,7 +68,7 @@ inline bool CropShape(const nnvm::NodeAttrs& attrs, CHECK((in_attrs->at(0)[2] == 1) || (in_attrs->at(0)[2] == 3)) << "Expect channel of the input image is 1 or 3, but got" << in_attrs->at(0)[2]; - } else if { + } else if (in_attrs->at(0).ndim() == 4U) { CHECK((in_attrs->at(0)[3] == 1) || (in_attrs->at(0)[3] == 3)) << "Expect channel of the input image is 1 or 3, but got" << in_attrs->at(0)[3]; From 7aabda616cd71fff3e32bb2c2720a60ec822918e Mon Sep 17 00:00:00 2001 From: stu1130 Date: Wed, 3 Apr 2019 10:37:50 -0700 Subject: [PATCH 11/13] fix the website error --- python/mxnet/gluon/data/vision/transforms.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index fe680181aceb..e06a5cf272f4 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -245,9 +245,8 @@ class CropResize(HybridBlock): size : int or tuple of (w, h) Optional, resize to new size after cropping interpolation : int, optional - Optional, interpolation method for resizing. By default uses bilinear + Interpolation method for resizing. By default uses bilinear interpolation. See OpenCV's resize function for available choices. - https://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=resize#resize Note that the Resize on gpu use contrib.bilinearResize2D operator which only support bilinear interpolation(1). The result would be slightly different on gpu compared to cpu. 
OpenCV tend to align center while bilinearResize2D From 352dedecb4837bb146e94a568b632fe612ab9fd1 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Wed, 3 Apr 2019 12:06:53 -0700 Subject: [PATCH 12/13] fix the website building --- python/mxnet/gluon/data/vision/transforms.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index e06a5cf272f4..c7822456dc61 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -232,6 +232,7 @@ class CropResize(HybridBlock): r"""Crop the input image with and optionally resize it. Makes a crop of the original image then optionally resize it to the specified size. + Parameters ---------- x : int @@ -247,6 +248,7 @@ class CropResize(HybridBlock): interpolation : int, optional Interpolation method for resizing. By default uses bilinear interpolation. See OpenCV's resize function for available choices. + https://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=resize#resize Note that the Resize on gpu use contrib.bilinearResize2D operator which only support bilinear interpolation(1). The result would be slightly different on gpu compared to cpu. OpenCV tend to align center while bilinearResize2D From 29264953bcb0e3e17e0649354d9a49bf6cc6f014 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Wed, 3 Apr 2019 14:12:43 -0700 Subject: [PATCH 13/13] add missing doc --- python/mxnet/gluon/data/vision/transforms.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index c7822456dc61..dff7f66b032d 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -268,6 +268,7 @@ class CropResize(HybridBlock): >>> transformer(image) >>> image = mx.nd.random.uniform(0, 255, (3, 224, 224, 3)).astype(dtype=np.uint8) + >>> transformer(image) >>> transformer = vision.transforms.CropResize(x=0, y=0, width=100, height=100, size=(50, 50), interpolation=1) >>> transformer(image)
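
Taken together, the patches above leave the user-facing API as `gluon.data.vision.transforms.CropResize(x, y, width, height, size=None, interpolation=None)`, backed by the `_image_crop` operator with a registered backward pass (`_backward_image_crop`). The following is a minimal usage sketch of that final API; it assumes an MXNet build that includes these commits, and the shapes noted in the comments follow from the parameters chosen here rather than from any output shown in the patches.

    import numpy as np
    import mxnet as mx
    from mxnet.gluon.data.vision import transforms

    # Crop a 100x100 window at the top-left corner, then resize it to 50x50.
    # interpolation=1 requests bilinear resizing, matching the docstring example above.
    transformer = transforms.CropResize(x=0, y=0, width=100, height=100,
                                        size=(50, 50), interpolation=1)

    # Works on a single HWC image ...
    image = mx.nd.random.uniform(0, 255, (224, 224, 3)).astype(np.uint8)
    print(transformer(image).shape)   # (50, 50, 3)

    # ... and on an NHWC batch.
    batch = mx.nd.random.uniform(0, 255, (4, 224, 224, 3)).astype(np.uint8)
    print(transformer(batch).shape)   # (4, 50, 50, 3)

    # Because the series also registers a backward pass for the crop operator,
    # gradients flow through a bare crop: the incoming gradient is scattered
    # back into the cropped window and the rest of the input gradient stays zero.
    data = mx.nd.arange(60).reshape((5, 4, 3)).astype('float32')
    data.attach_grad()
    with mx.autograd.record():
        out = mx.nd.image.crop(data, x=1, y=1, width=2, height=3)
    out.backward(mx.nd.ones_like(out))
    # data.grad is 1 inside the 3x2 cropped region and 0 elsewhere.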