From 6e101b7d9b15c0e169635634bed76f461fb32ccb Mon Sep 17 00:00:00 2001 From: stu1130 Date: Thu, 14 Mar 2019 14:03:16 -0700 Subject: [PATCH] add. backword and refactor the code --- python/mxnet/gluon/data/vision/transforms.py | 18 ++--- src/operator/image/crop-inl.h | 69 +++++++++++++++---- src/operator/image/crop.cc | 9 ++- .../python/unittest/test_gluon_data_vision.py | 22 +++--- 4 files changed, 85 insertions(+), 33 deletions(-) diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index 598b3eb8db92..cad601c956d8 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -228,14 +228,14 @@ def forward(self, x): return image.random_size_crop(x, *self._args)[0] -class Crop(HybridBlock): +class CropResize(HybridBlock): """Crop the input image with and optionally resize it. Makes a crop of the original image then optionally resize it to the specified size. Parameters ---------- - x0 : int + x : int Left boundary of the cropping area - y0 : int + y : int Top boundary of the cropping area w : int Width of the cropping area @@ -261,18 +261,18 @@ class Crop(HybridBlock): >>> transformer(image) """ - def __init__(self, x0, y0, width, height, size=None, interpolation=None): - super(Crop, self).__init__() - self._x0 = x0 - self._y0 = y0 + def __init__(self, x, y, width, height, size=None, interpolation=None): + super(CropResize, self).__init__() + self._x = x + self._y = y self._width = width self._height = height self._size = size self._interpolation = interpolation def hybrid_forward(self, F, x): - out = F.image.crop(x, self._x0, self._y0, self._width, self._height) - if self._size is not None: + out = F.image.crop(x, self._x, self._y, self._width, self._height) + if self._size: out = F.image.resize(out, self._size, False, self._interpolation) return out diff --git a/src/operator/image/crop-inl.h b/src/operator/image/crop-inl.h index 69a4f525d342..d8072f42b681 100644 --- a/src/operator/image/crop-inl.h +++ b/src/operator/image/crop-inl.h @@ -56,7 +56,7 @@ struct CropParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(width) .describe("Width of the cropping area."); DMLC_DECLARE_FIELD(height) - .describe("Top boundary of the cropping area"); + .describe("Height of the cropping area."); } }; @@ -73,6 +73,10 @@ inline bool CropShape(const nnvm::NodeAttrs& attrs, CHECK((param.height > 0) && (param.width > 0)) << "Input height and width must be greater than 0"; + CHECK(param.x + param.width <= ishape[ishape.ndim() - 2]) + << " x + width should not be greater than input width"; + CHECK(param.y + param.height <= ishape[ishape.ndim() - 3]) + << " y + height should not be greater than input height"; if (ishape.ndim() == 3) { SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({param.height, param.width, ishape[C]})); } else { @@ -94,10 +98,6 @@ inline void CropImpl(int x, const TBlob& data = inputs[0]; const TBlob& out = outputs[0]; MXNET_NDIM_SWITCH(data.ndim(), ndim, { - CHECK(x + width <= data.shape_[ndim - 2]) - << " x + width should not be greater than input width"; - CHECK(y + height <= data.shape_[ndim - 3]) - << " y + height should not be greater than input height"; Stream* s = ctx.get_stream(); common::StaticArray begin = {0}, step = {1}; if (ndim == 3) { @@ -108,30 +108,73 @@ inline void CropImpl(int x, begin[2] = x; } MSHADOW_TYPE_SWITCH(out.type_flag_, DType, { - Tensor input_tensor = data.get(s); - Tensor output_tensor = out.get(s); MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { size_t num_threads = out.shape_.FlatTo2D()[0]; mxnet_op::Kernel, cpu>::Launch(s, num_threads, - output_tensor.dptr_, input_tensor.dptr_, - input_tensor.shape_, output_tensor.shape_, begin, step); + out.dptr(), data.dptr(), + data.shape_.get(), out.shape_.get(), begin, step); + }) + }) + }) +} + +inline void CropBackwardImpl(int x, + int y, + int width, + int height, + const std::vector &inputs, + const std::vector &outputs, + const OpContext &ctx, + const std::vector &req) { + using namespace mshadow; + if (req[0] == kNullOp) return; + const TBlob& output_grad = inputs[0]; + const TBlob& input_grad = outputs[0]; + Stream* s = ctx.get_stream(); + if (req[0] == kWriteTo) { + Fill(s, input_grad, req[0], 0); + } else if (req[0] == kWriteInplace) { + LOG(FATAL) << "_backward_image_crop does not support kWriteInplace"; + } + MXNET_NDIM_SWITCH(output_grad.ndim(), ndim, { + common::StaticArray begin = {0}, step = {1}; + if (ndim == 3) { + begin[0] = y; + begin[1] = x; + } else { + begin[1] = y; + begin[2] = x; + } + MSHADOW_TYPE_SWITCH(input_grad.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + size_t num_threads = output_grad.shape_.FlatTo2D()[0]; + mxnet_op::Kernel, cpu>::Launch(s, num_threads, + input_grad.dptr(), output_grad.dptr(), + output_grad.shape_.get(), input_grad.shape_.get(), begin, step); }) }) }) } -inline void Crop(const nnvm::NodeAttrs &attrs, +inline void CropOpForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { CHECK_EQ(outputs.size(), 1U); const CropParam& param = nnvm::get(attrs.parsed); - CHECK((param.height > 0) && (param.width > 0)) - << "Input height and width must be greater than 0"; - CropImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req); } + +inline void CropOpBackward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_EQ(outputs.size(), 1U); + const CropParam& param = nnvm::get(attrs.parsed); + CropBackwardImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req); +} } // namespace image } // namespace op } // namespace mxnet diff --git a/src/operator/image/crop.cc b/src/operator/image/crop.cc index 4e7ef4359328..7b9d857cb668 100644 --- a/src/operator/image/crop.cc +++ b/src/operator/image/crop.cc @@ -68,11 +68,16 @@ to the given size. .set_attr_parser(ParamParser) .set_attr("FInferShape", CropShape) .set_attr("FInferType", ElemwiseType<1, 1>) -.set_attr("FCompute", Crop) -.set_attr("FGradient", ElemwiseGradUseNone{ "_copy" }) +.set_attr("FCompute", CropOpForward) +.set_attr("FGradient", ElemwiseGradUseNone{ "_backward_image_crop" }) .add_argument("data", "NDArray-or-Symbol", "The input.") .add_arguments(CropParam::__FIELDS__()); +NNVM_REGISTER_OP(_backward_image_crop) +.set_attr_parser(ParamParser) +.set_attr("TIsBackward", true) +.set_attr("FCompute", CropOpBackward); + } // namespace image } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py index b621b734184a..ee15057d4e67 100644 --- a/tests/python/unittest/test_gluon_data_vision.py +++ b/tests/python/unittest/test_gluon_data_vision.py @@ -120,40 +120,44 @@ def _test_resize_with_diff_type(dtype): @with_seed() -def test_crop(): - def _test_crop_with_diff_type(dtype): +def test_crop_resize(): + def _test_crop_resize_with_diff_type(dtype): # test normal case data_in = nd.arange(60).reshape((5, 4, 3)).astype('uint8') - out_nd = transforms.Crop(0, 0, 3, 2)(data_in) + out_nd = transforms.CropResize(0, 0, 3, 2)(data_in) out_np = out_nd.asnumpy() assert(out_np.sum() == 180) assert((out_np[0:2,1,1].flatten() == [4, 16]).all()) # test 4D input data_bath_in = nd.arange(180).reshape((2, 6, 5, 3)).astype('uint8') - out_batch_nd = transforms.Crop(1, 2, 3, 4)(data_bath_in) + out_batch_nd = transforms.CropResize(1, 2, 3, 4)(data_bath_in) out_batch_np = out_batch_nd.asnumpy() assert(out_batch_np.sum() == 7524) assert((out_batch_np[0:2,0:4,1,1].flatten() == [37, 52, 67, 82, 127, 142, 157, 172]).all()) # test normal case with resize data_in = nd.random.uniform(0, 255, (300, 200, 3)).astype('uint8') - out_nd = transforms.Crop(0, 0, 100, 50, (25, 25), 2)(data_in) + out_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_in) data_expected = image.imresize(nd.slice(data_in, (0, 0, 0), (50, 100 , 3)), 25, 25, 2) assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy()) # test 4D input with resize data_bath_in = nd.random.uniform(0, 255, (3, 300, 200, 3)).astype('uint8') - out_batch_nd = transforms.Crop(0, 0, 100, 50, (25, 25), 2)(data_bath_in) + out_batch_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_bath_in) for i in range(len(out_batch_nd)): assert_almost_equal(image.imresize(nd.slice(data_bath_in[i], (0, 0, 0), (50, 100, 3)), 25, 25, 2).asnumpy(), out_batch_nd[i].asnumpy()) # test with resize height and width should be greater than 0 - transformer = transforms.Crop(0, 0, 100, 50, (-25, 25), 2) + transformer = transforms.CropResize(0, 0, 100, 50, (-25, 25), 2) assertRaises(MXNetError, transformer, data_in) # test height and width should be greater than 0 - transformer = transforms.Crop(0, 0, -100, -50) + transformer = transforms.CropResize(0, 0, -100, -50) assertRaises(MXNetError, transformer, data_in) + # test cropped area is bigger than input data + transformer = transforms.CropResize(150, 200, 200, 500) + assertRaises(MXNetError, transformer, data_in) + assertRaises(MXNetError, transformer, data_bath_in) for dtype in ['uint8', 'float32', 'float64']: - _test_crop_with_diff_type(dtype) + _test_crop_resize_with_diff_type(dtype) @with_seed() def test_flip_left_right():