Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add backward and refactor the code
Browse files Browse the repository at this point in the history
  • Loading branch information
stu1130 committed Mar 14, 2019
1 parent 54e746b commit 6e101b7
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 33 deletions.
18 changes: 9 additions & 9 deletions python/mxnet/gluon/data/vision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,14 @@ def forward(self, x):
return image.random_size_crop(x, *self._args)[0]


class Crop(HybridBlock):
class CropResize(HybridBlock):
"""Crop the input image with and optionally resize it.
Makes a crop of the original image, then optionally resizes it to the specified size.
Parameters
----------
x0 : int
x : int
Left boundary of the cropping area
y0 : int
y : int
Top boundary of the cropping area
w : int
Width of the cropping area
Expand All @@ -261,18 +261,18 @@ class Crop(HybridBlock):
>>> transformer(image)
<NDArray 3x50x50 @cpu(0)>
"""
def __init__(self, x0, y0, width, height, size=None, interpolation=None):
super(Crop, self).__init__()
self._x0 = x0
self._y0 = y0
def __init__(self, x, y, width, height, size=None, interpolation=None):
super(CropResize, self).__init__()
self._x = x
self._y = y
self._width = width
self._height = height
self._size = size
self._interpolation = interpolation

def hybrid_forward(self, F, x):
out = F.image.crop(x, self._x0, self._y0, self._width, self._height)
if self._size is not None:
out = F.image.crop(x, self._x, self._y, self._width, self._height)
if self._size:
out = F.image.resize(out, self._size, False, self._interpolation)
return out

Expand Down
69 changes: 56 additions & 13 deletions src/operator/image/crop-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ struct CropParam : public dmlc::Parameter<CropParam> {
DMLC_DECLARE_FIELD(width)
.describe("Width of the cropping area.");
DMLC_DECLARE_FIELD(height)
.describe("Top boundary of the cropping area");
.describe("Height of the cropping area.");
}
};

Expand All @@ -73,6 +73,10 @@ inline bool CropShape(const nnvm::NodeAttrs& attrs,

CHECK((param.height > 0) && (param.width > 0))
<< "Input height and width must be greater than 0";
CHECK(param.x + param.width <= ishape[ishape.ndim() - 2])
<< " x + width should not be greater than input width";
CHECK(param.y + param.height <= ishape[ishape.ndim() - 3])
<< " y + height should not be greater than input height";
if (ishape.ndim() == 3) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({param.height, param.width, ishape[C]}));
} else {
Expand All @@ -94,10 +98,6 @@ inline void CropImpl(int x,
const TBlob& data = inputs[0];
const TBlob& out = outputs[0];
MXNET_NDIM_SWITCH(data.ndim(), ndim, {
CHECK(x + width <= data.shape_[ndim - 2])
<< " x + width should not be greater than input width";
CHECK(y + height <= data.shape_[ndim - 3])
<< " y + height should not be greater than input height";
Stream<cpu>* s = ctx.get_stream<cpu>();
common::StaticArray<index_t, ndim> begin = {0}, step = {1};
if (ndim == 3) {
Expand All @@ -108,30 +108,73 @@ inline void CropImpl(int x,
begin[2] = x;
}
MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
Tensor<cpu, ndim, DType> input_tensor = data.get<cpu, ndim, DType>(s);
Tensor<cpu, ndim, DType> output_tensor = out.get<cpu, ndim, DType>(s);
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
size_t num_threads = out.shape_.FlatTo2D()[0];
mxnet_op::Kernel<slice_forward<ndim, Req, cpu>, cpu>::Launch(s, num_threads,
output_tensor.dptr_, input_tensor.dptr_,
input_tensor.shape_, output_tensor.shape_, begin, step);
out.dptr<DType>(), data.dptr<DType>(),
data.shape_.get<ndim>(), out.shape_.get<ndim>(), begin, step);
})
})
})
}

/*!
 * \brief Backward pass of the image crop operator on CPU.
 *
 * Takes the gradient w.r.t. the cropped output (inputs[0]) and launches the
 * slice_assign kernel to write it into the gradient w.r.t. the uncropped
 * input (outputs[0]), using (y, x) as the begin offsets of the crop window.
 *
 * \param x left boundary of the cropping area (used as a begin offset)
 * \param y top boundary of the cropping area (used as a begin offset)
 * \param width width of the cropping area; not referenced in this function
 * \param height height of the cropping area; not referenced in this function
 * \param inputs inputs[0] is read as the output gradient
 * \param outputs outputs[0] is written as the input gradient
 * \param ctx operator context; supplies the CPU stream
 * \param req only req[0] is consulted; kNullOp returns immediately and
 *            kWriteInplace aborts with LOG(FATAL)
 */
inline void CropBackwardImpl(int x,
int y,
int width,
int height,
const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
const OpContext &ctx,
const std::vector<OpReqType> &req) {
using namespace mshadow;
// Nothing was requested for the input gradient, so there is nothing to do.
if (req[0] == kNullOp) return;
const TBlob& output_grad = inputs[0];
const TBlob& input_grad = outputs[0];
Stream<cpu>* s = ctx.get_stream<cpu>();
if (req[0] == kWriteTo) {
// Zero the whole input gradient before the kernel runs; presumably so that
// elements outside the crop window end up with a zero gradient.
Fill(s, input_grad, req[0], 0);
} else if (req[0] == kWriteInplace) {
LOG(FATAL) << "_backward_image_crop does not support kWriteInplace";
}
// No Fill happens for any other request type (e.g. an accumulate request),
// so the existing contents of input_grad are left in place for the kernel.
MXNET_NDIM_SWITCH(output_grad.ndim(), ndim, {
// NOTE(review): each brace list supplies a single initializer; how the
// remaining entries of `step` are initialized depends on StaticArray, which
// is not visible here - confirm they are the intended stride of 1.
common::StaticArray<index_t, ndim> begin = {0}, step = {1};
// 3-D data places the offsets on axes 0 and 1; any other rank places them on
// axes 1 and 2 (presumably a 4-D batch with a leading batch axis - confirm).
if (ndim == 3) {
begin[0] = y;
begin[1] = x;
} else {
begin[1] = y;
begin[2] = x;
}
MSHADOW_TYPE_SWITCH(input_grad.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
// Launch size is the first dimension of output_grad flattened to 2-D.
size_t num_threads = output_grad.shape_.FlatTo2D()[0];
// NOTE(review): the data pointers are passed as (input_grad, output_grad)
// while the shapes are passed as (output_grad, input_grad). Confirm this
// order against the slice_assign kernel signature, which is defined elsewhere.
mxnet_op::Kernel<slice_assign<ndim, Req, cpu>, cpu>::Launch(s, num_threads,
input_grad.dptr<DType>(), output_grad.dptr<DType>(),
output_grad.shape_.get<ndim>(), input_grad.shape_.get<ndim>(), begin, step);
})
})
})
}

inline void Crop(const nnvm::NodeAttrs &attrs,
inline void CropOpForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
CHECK_EQ(outputs.size(), 1U);
const CropParam& param = nnvm::get<CropParam>(attrs.parsed);
CHECK((param.height > 0) && (param.width > 0))
<< "Input height and width must be greater than 0";

CropImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req);
}

/*!
 * \brief Compute entry point for the backward pass of the image crop operator.
 *
 * Checks that exactly one output blob is present, reads the parsed CropParam
 * from the node attributes and delegates to CropBackwardImpl.
 *
 * \param attrs node attributes; attrs.parsed must hold a CropParam
 * \param ctx operator context, forwarded unchanged
 * \param inputs input blobs, forwarded unchanged
 * \param req write requests, forwarded unchanged
 * \param outputs output blobs; must contain exactly one entry
 */
inline void CropOpBackward(const nnvm::NodeAttrs &attrs,
                           const OpContext &ctx,
                           const std::vector<TBlob> &inputs,
                           const std::vector<OpReqType> &req,
                           const std::vector<TBlob> &outputs) {
  // The backward op produces a single blob.
  CHECK_EQ(outputs.size(), 1U);

  const CropParam& crop_args = nnvm::get<CropParam>(attrs.parsed);
  CropBackwardImpl(crop_args.x,
                   crop_args.y,
                   crop_args.width,
                   crop_args.height,
                   inputs,
                   outputs,
                   ctx,
                   req);
}
} // namespace image
} // namespace op
} // namespace mxnet
Expand Down
9 changes: 7 additions & 2 deletions src/operator/image/crop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,16 @@ to the given size.
.set_attr_parser(ParamParser<CropParam>)
.set_attr<nnvm::FInferShape>("FInferShape", CropShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", Crop)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
.set_attr<FCompute>("FCompute<cpu>", CropOpForward)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_backward_image_crop" })
.add_argument("data", "NDArray-or-Symbol", "The input.")
.add_arguments(CropParam::__FIELDS__());

// Registers the backward operator "_backward_image_crop", the node name given
// to ElemwiseGradUseNone in the FGradient attribute of the registration above.
// It parses the same CropParam fields, is flagged as a backward node via
// TIsBackward, and only a CPU FCompute implementation (CropOpBackward) is
// registered here.
NNVM_REGISTER_OP(_backward_image_crop)
.set_attr_parser(ParamParser<CropParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", CropOpBackward);

} // namespace image
} // namespace op
} // namespace mxnet
22 changes: 13 additions & 9 deletions tests/python/unittest/test_gluon_data_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,40 +120,44 @@ def _test_resize_with_diff_type(dtype):


@with_seed()
def test_crop():
def _test_crop_with_diff_type(dtype):
def test_crop_resize():
def _test_crop_resize_with_diff_type(dtype):
# test normal case
data_in = nd.arange(60).reshape((5, 4, 3)).astype('uint8')
out_nd = transforms.Crop(0, 0, 3, 2)(data_in)
out_nd = transforms.CropResize(0, 0, 3, 2)(data_in)
out_np = out_nd.asnumpy()
assert(out_np.sum() == 180)
assert((out_np[0:2,1,1].flatten() == [4, 16]).all())
# test 4D input
data_bath_in = nd.arange(180).reshape((2, 6, 5, 3)).astype('uint8')
out_batch_nd = transforms.Crop(1, 2, 3, 4)(data_bath_in)
out_batch_nd = transforms.CropResize(1, 2, 3, 4)(data_bath_in)
out_batch_np = out_batch_nd.asnumpy()
assert(out_batch_np.sum() == 7524)
assert((out_batch_np[0:2,0:4,1,1].flatten() == [37, 52, 67, 82, 127, 142, 157, 172]).all())
# test normal case with resize
data_in = nd.random.uniform(0, 255, (300, 200, 3)).astype('uint8')
out_nd = transforms.Crop(0, 0, 100, 50, (25, 25), 2)(data_in)
out_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_in)
data_expected = image.imresize(nd.slice(data_in, (0, 0, 0), (50, 100 , 3)), 25, 25, 2)
assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy())
# test 4D input with resize
data_bath_in = nd.random.uniform(0, 255, (3, 300, 200, 3)).astype('uint8')
out_batch_nd = transforms.Crop(0, 0, 100, 50, (25, 25), 2)(data_bath_in)
out_batch_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_bath_in)
for i in range(len(out_batch_nd)):
assert_almost_equal(image.imresize(nd.slice(data_bath_in[i], (0, 0, 0), (50, 100, 3)), 25, 25, 2).asnumpy(),
out_batch_nd[i].asnumpy())
# test with resize height and width should be greater than 0
transformer = transforms.Crop(0, 0, 100, 50, (-25, 25), 2)
transformer = transforms.CropResize(0, 0, 100, 50, (-25, 25), 2)
assertRaises(MXNetError, transformer, data_in)
# test height and width should be greater than 0
transformer = transforms.Crop(0, 0, -100, -50)
transformer = transforms.CropResize(0, 0, -100, -50)
assertRaises(MXNetError, transformer, data_in)
# test cropped area is bigger than input data
transformer = transforms.CropResize(150, 200, 200, 500)
assertRaises(MXNetError, transformer, data_in)
assertRaises(MXNetError, transformer, data_bath_in)

for dtype in ['uint8', 'float32', 'float64']:
_test_crop_with_diff_type(dtype)
_test_crop_resize_with_diff_type(dtype)

@with_seed()
def test_flip_left_right():
Expand Down

0 comments on commit 6e101b7

Please sign in to comment.