Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add Gluon Transformer Crop (#14259)
Browse files Browse the repository at this point in the history
* implement crop

* add crop operator

* fix for linter

* add. backword and refactor the code

* fix error namespace

* fix the website build failure

* start adding the unit test of backword

* add unit test for backward

* address the comment

* add missing statement

* fix the website error

* fix the website building

* add missing doc
  • Loading branch information
stu1130 authored and nswamy committed Apr 5, 2019
1 parent d843a85 commit 49c1ccc
Show file tree
Hide file tree
Showing 5 changed files with 410 additions and 3 deletions.
61 changes: 61 additions & 0 deletions python/mxnet/gluon/data/vision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,67 @@ def forward(self, x):
return image.random_size_crop(x, *self._args)[0]


class CropResize(HybridBlock):
r"""Crop the input image with and optionally resize it.
Makes a crop of the original image then optionally resize it to the specified size.
Parameters
----------
x : int
Left boundary of the cropping area
y : int
Top boundary of the cropping area
w : int
Width of the cropping area
h : int
Height of the cropping area
size : int or tuple of (w, h)
Optional, resize to new size after cropping
interpolation : int, optional
Interpolation method for resizing. By default uses bilinear
interpolation. See OpenCV's resize function for available choices.
https://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=resize#resize
Note that the Resize on gpu use contrib.bilinearResize2D operator
which only support bilinear interpolation(1). The result would be slightly
different on gpu compared to cpu. OpenCV tend to align center while bilinearResize2D
use algorithm which aligns corner.
Inputs:
- **data**: input tensor with (H x W x C) or (N x H x W x C) shape.
Outputs:
- **out**: input tensor with (H x W x C) or (N x H x W x C) shape.
Examples
--------
>>> transformer = vision.transforms.CropResize(x=0, y=0, width=100, height=100)
>>> image = mx.nd.random.uniform(0, 255, (224, 224, 3)).astype(dtype=np.uint8)
>>> transformer(image)
<NDArray 100x100x3 @cpu(0)>
>>> image = mx.nd.random.uniform(0, 255, (3, 224, 224, 3)).astype(dtype=np.uint8)
>>> transformer(image)
<NDArray 3x100x100x3 @cpu(0)>
>>> transformer = vision.transforms.CropResize(x=0, y=0, width=100, height=100, size=(50, 50), interpolation=1)
>>> transformer(image)
<NDArray 3x50x50 @cpu(0)>
"""
def __init__(self, x, y, width, height, size=None, interpolation=None):
super(CropResize, self).__init__()
self._x = x
self._y = y
self._width = width
self._height = height
self._size = size
self._interpolation = interpolation

def hybrid_forward(self, F, x):
out = F.image.crop(x, self._x, self._y, self._width, self._height)
if self._size:
out = F.image.resize(out, self._size, False, self._interpolation)
return out

class CenterCrop(Block):
"""Crops the image `src` to the given `size` by trimming on all four
sides and preserving the center of the image. Upsamples if `src` is
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def fixed_crop(src, x0, y0, w, h, size=None, interp=2):
NDArray
An `NDArray` containing the cropped image.
"""
out = nd.crop(src, begin=(y0, x0, 0), end=(y0 + h, x0 + w, int(src.shape[2])))
out = nd.slice(src, begin=(y0, x0, 0), end=(y0 + h, x0 + w, int(src.shape[2])))
if size is not None and (w, h) != size:
sizes = (h, w, size[1], size[0])
out = imresize(out, *size, interp=_get_interp_method(interp, sizes))
Expand Down
190 changes: 190 additions & 0 deletions src/operator/image/crop-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file crop-inl.h
* \brief the image crop operator implementation
*/

#ifndef MXNET_OPERATOR_IMAGE_CROP_INL_H_
#define MXNET_OPERATOR_IMAGE_CROP_INL_H_


#include <algorithm>
#include <vector>

#include "mxnet/base.h"
#include "dmlc/optional.h"
#include "image_utils.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../../common/static_array.h"
#include "../tensor/matrix_op-inl.h"
#include "resize-inl.h"

namespace mxnet {
namespace op {
namespace image {

struct CropParam : public dmlc::Parameter<CropParam> {
int x;
int y;
int width;
int height;
DMLC_DECLARE_PARAMETER(CropParam) {
DMLC_DECLARE_FIELD(x)
.describe("Left boundary of the cropping area.");
DMLC_DECLARE_FIELD(y)
.describe("Top boundary of the cropping area.");
DMLC_DECLARE_FIELD(width)
.describe("Width of the cropping area.");
DMLC_DECLARE_FIELD(height)
.describe("Height of the cropping area.");
}
};

inline bool CropShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
// input attrs should only be (h, w, c) or (n, h, w, c)
if (in_attrs->at(0).ndim() == 3U) {
CHECK((in_attrs->at(0)[2] == 1) || (in_attrs->at(0)[2] == 3))
<< "Expect channel of the input image is 1 or 3, but got"
<< in_attrs->at(0)[2];
} else if (in_attrs->at(0).ndim() == 4U) {
CHECK((in_attrs->at(0)[3] == 1) || (in_attrs->at(0)[3] == 3))
<< "Expect channel of the input image is 1 or 3, but got"
<< in_attrs->at(0)[3];
} else {
LOG(FATAL) << "Image Crop expects inputs of 3D (h, w, c) or 4D (n, h, w, c). But got "
<< in_attrs->at(0).ndim();
}

const auto& ishape = (*in_attrs)[0];
const CropParam& param = nnvm::get<CropParam>(attrs.parsed);

CHECK((param.height > 0) && (param.width > 0))
<< "Input height and width must be greater than 0";
CHECK(param.x + param.width <= ishape[ishape.ndim() - 2])
<< " x + width should not be greater than input width";
CHECK(param.y + param.height <= ishape[ishape.ndim() - 3])
<< " y + height should not be greater than input height";
if (ishape.ndim() == 3) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({param.height, param.width, ishape[C]}));
} else {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({ishape[N], param.height, param.width, ishape[kC]}));
}
return true;
}

inline void CropImpl(int x,
int y,
int width,
int height,
const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
const OpContext &ctx,
const std::vector<OpReqType> &req) {
using namespace mshadow;
const TBlob& data = inputs[0];
const TBlob& out = outputs[0];
MXNET_NDIM_SWITCH(data.ndim(), ndim, {
Stream<cpu>* s = ctx.get_stream<cpu>();
common::StaticArray<index_t, ndim> begin = {0}, step = {1};
if (ndim == 3) {
begin[0] = y;
begin[1] = x;
} else {
begin[1] = y;
begin[2] = x;
}
MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
size_t num_threads = out.shape_.FlatTo2D()[0];
mxnet_op::Kernel<slice_forward<ndim, Req, cpu>, cpu>::Launch(s, num_threads,
out.dptr<DType>(), data.dptr<DType>(),
data.shape_.get<ndim>(), out.shape_.get<ndim>(), begin, step);
})
})
})
}

inline void CropBackwardImpl(int x,
int y,
int width,
int height,
const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
const OpContext &ctx,
const std::vector<OpReqType> &req) {
using namespace mshadow;
if (req[0] == kNullOp) return;
const TBlob& output_grad = inputs[0];
const TBlob& input_grad = outputs[0];
Stream<cpu>* s = ctx.get_stream<cpu>();
if (req[0] == kWriteTo) {
Fill(s, input_grad, req[0], 0);
} else if (req[0] == kWriteInplace) {
LOG(FATAL) << "_backward_image_crop does not support kWriteInplace";
}
MXNET_NDIM_SWITCH(output_grad.ndim(), ndim, {
common::StaticArray<index_t, ndim> begin = {0}, step = {1};
if (ndim == 3) {
begin[0] = y;
begin[1] = x;
} else {
begin[1] = y;
begin[2] = x;
}
MSHADOW_TYPE_SWITCH(output_grad.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
size_t num_threads = output_grad.shape_.FlatTo2D()[0];
mxnet_op::Kernel<slice_assign<ndim, Req, cpu>, cpu>::Launch(s, num_threads,
input_grad.dptr<DType>(), output_grad.dptr<DType>(),
input_grad.shape_.get<ndim>(), output_grad.shape_.get<ndim>(), begin, step);
})
})
})
}

inline void CropOpForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
CHECK_EQ(outputs.size(), 1U);
const CropParam& param = nnvm::get<CropParam>(attrs.parsed);
CropImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req);
}

inline void CropOpBackward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
CHECK_EQ(outputs.size(), 1U);
const CropParam& param = nnvm::get<CropParam>(attrs.parsed);
CropBackwardImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req);
}
} // namespace image
} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_IMAGE_CROP_INL_H_
85 changes: 85 additions & 0 deletions src/operator/image/crop.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file crop-cc.h
* \brief the image crop operator registration
*/

#include "mxnet/base.h"
#include "crop-inl.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"

namespace mxnet {
namespace op {
namespace image {

DMLC_REGISTER_PARAMETER(CropParam);

NNVM_REGISTER_OP(_image_crop)
.describe(R"code(Crop an image NDArray of shape (H x W x C) or (N x H x W x C)
to the given size.
Example:
.. code-block:: python
image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8)
mx.nd.image.crop(image, 1, 1, 2, 2)
[[[144 34 4]
[ 82 157 38]]
[[156 111 230]
[177 25 15]]]
<NDArray 2x2x3 @cpu(0)>
image = mx.nd.random.uniform(0, 255, (2, 4, 2, 3)).astype(dtype=np.uint8)
mx.nd.image.crop(image, 1, 1, 2, 2)
[[[[ 35 198 50]
[242 94 168]]
[[223 119 129]
[249 14 154]]]
[[[137 215 106]
[ 79 174 133]]
[[116 142 109]
[ 35 239 50]]]]
<NDArray 2x2x2x3 @cpu(0)>
)code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<CropParam>)
.set_attr<mxnet::FInferShape>("FInferShape", CropShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", CropOpForward)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_backward_image_crop" })
.add_argument("data", "NDArray-or-Symbol", "The input.")
.add_arguments(CropParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_image_crop)
.set_attr_parser(ParamParser<CropParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", CropOpBackward);

} // namespace image
} // namespace op
} // namespace mxnet
Loading

0 comments on commit 49c1ccc

Please sign in to comment.