Commit 27b3e52
image crop gpu (#16464)
stu1130 authored and roywei committed Oct 19, 2019
1 parent 2d4c3a4 commit 27b3e52
Showing 6 changed files with 76 additions and 87 deletions.
8 changes: 6 additions & 2 deletions src/operator/contrib/bilinear_resize-inl.cuh
@@ -62,7 +62,9 @@ static unsigned getNumThreads(int nElem, const bool smaller) {

// caffe_gpu_interp2_kernel overloading with Tensor<xpu, 3, DType>
template<typename xpu, typename Dtype, typename Acctype>
-__global__ void caffe_gpu_interp2_kernel(const int n,
+__global__ void
+__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
+caffe_gpu_interp2_kernel(const int n,
const Acctype rheight, const Acctype rwidth,
const Tensor<xpu, 3, Dtype> data1,
Tensor<xpu, 3, Dtype> data2,
@@ -111,7 +113,9 @@ __global__ void caffe_gpu_interp2_kernel(const int n,

// caffe_gpu_interp2_kernel overloading with Tensor<xpu, 4, DType>
template<typename xpu, typename Dtype, typename Acctype>
-__global__ void caffe_gpu_interp2_kernel(const int n,
+__global__ void
+__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
+caffe_gpu_interp2_kernel(const int n,
const Acctype rheight, const Acctype rwidth,
const Tensor<xpu, 4, Dtype> data1,
Tensor<xpu, 4, Dtype> data2,
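Note on the __launch_bounds__ change above: the qualifier promises nvcc that caffe_gpu_interp2_kernel is never launched with more than cuda::kMaxThreadsPerBlock threads per block (and asks it to keep at least one block resident per SM), so the compiler can budget registers for that block size and a launch at the full size cannot fail with "too many resources requested for launch". A minimal, self-contained sketch of the same pattern; the kernel name and the 512 constant below are illustrative stand-ins, not MXNet code:

#include <cstdio>
#include <cuda_runtime.h>

constexpr int kMaxThreadsPerBlock = 512;  // stand-in for cuda::kMaxThreadsPerBlock (assumption)

// The launch-bounds promise: at most kMaxThreadsPerBlock threads per block,
// and request at least 1 resident block per SM.
__global__ void
__launch_bounds__(kMaxThreadsPerBlock, 1)
scale_kernel(const int n, float* data, const float factor) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) data[idx] *= factor;
}

int main() {
  const int n = 1 << 20;
  float* d = nullptr;
  cudaMalloc(&d, n * sizeof(float));
  // Launching at exactly the promised maximum block size is now safe
  // register-wise; the compiler has already capped register usage for it.
  scale_kernel<<<(n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock,
                 kMaxThreadsPerBlock>>>(n, d, 2.0f);
  std::printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(d);
  return 0;
}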
22 changes: 16 additions & 6 deletions src/operator/image/crop-inl.h
@@ -94,6 +94,7 @@ inline bool CropShape(const nnvm::NodeAttrs& attrs,
return true;
}

+template<typename xpu>
inline void CropImpl(int x,
int y,
int width,
@@ -106,7 +107,7 @@ inline void CropImpl(int x,
const TBlob& data = inputs[0];
const TBlob& out = outputs[0];
MXNET_NDIM_SWITCH(data.ndim(), ndim, {
-Stream<cpu>* s = ctx.get_stream<cpu>();
+Stream<xpu>* s = ctx.get_stream<xpu>();
common::StaticArray<index_t, ndim> begin = {0}, step = {1};
if (ndim == 3) {
begin[0] = y;
@@ -118,14 +119,18 @@ inline void CropImpl(int x,
MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
size_t num_threads = out.shape_.FlatTo2D()[0];
-mxnet_op::Kernel<slice_forward<ndim, Req, cpu>, cpu>::Launch(s, num_threads,
+if (std::is_same<xpu, gpu>::value) {
+  num_threads *= out.shape_.get<ndim>()[ndim - 1];
+}
+mxnet_op::Kernel<slice_forward<ndim, Req, xpu>, xpu>::Launch(s, num_threads,
out.dptr<DType>(), data.dptr<DType>(),
data.shape_.get<ndim>(), out.shape_.get<ndim>(), begin, step);
})
})
})
}

+template<typename xpu>
inline void CropBackwardImpl(int x,
int y,
int width,
@@ -138,7 +143,7 @@ inline void CropBackwardImpl(int x,
if (req[0] == kNullOp) return;
const TBlob& output_grad = inputs[0];
const TBlob& input_grad = outputs[0];
-Stream<cpu>* s = ctx.get_stream<cpu>();
+Stream<xpu>* s = ctx.get_stream<xpu>();
if (req[0] == kWriteTo) {
Fill(s, input_grad, req[0], 0);
} else if (req[0] == kWriteInplace) {
@@ -156,32 +161,37 @@ inline void CropBackwardImpl(int x,
MSHADOW_TYPE_SWITCH(output_grad.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
size_t num_threads = output_grad.shape_.FlatTo2D()[0];
-mxnet_op::Kernel<slice_assign<ndim, Req, cpu>, cpu>::Launch(s, num_threads,
+if (std::is_same<xpu, gpu>::value) {
+  num_threads *= output_grad.shape_.get<ndim>()[ndim - 1];
+}
+mxnet_op::Kernel<slice_assign<ndim, Req, xpu>, xpu>::Launch(s, num_threads,
input_grad.dptr<DType>(), output_grad.dptr<DType>(),
input_grad.shape_.get<ndim>(), output_grad.shape_.get<ndim>(), begin, step);
})
})
})
}

+template<typename xpu>
inline void CropOpForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
CHECK_EQ(outputs.size(), 1U);
const CropParam& param = nnvm::get<CropParam>(attrs.parsed);
-CropImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req);
+CropImpl<xpu>(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req);
}

+template<typename xpu>
inline void CropOpBackward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
CHECK_EQ(outputs.size(), 1U);
const CropParam& param = nnvm::get<CropParam>(attrs.parsed);
-CropBackwardImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req);
+CropBackwardImpl<xpu>(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req);
}
} // namespace image
} // namespace op
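Note on the num_threads changes in CropImpl and CropBackwardImpl above: the CPU path launches one worker per output row (out.shape_.FlatTo2D()[0]), each copying a contiguous row, while the GPU path multiplies the count by the innermost dimension so that each CUDA thread handles a single element. Since std::is_same<xpu, gpu>::value is a compile-time constant, the branch folds away at template instantiation. A minimal sketch of the two mappings; the copy functions below are illustrative stand-ins, not the MXNet slice_forward/slice_assign kernels:

#include <cstddef>
#include <cstdio>
#include <vector>

// CPU-style mapping: worker i copies a whole row of row_len elements.
void copy_row(std::size_t i, float* out, const float* in, std::size_t row_len) {
  for (std::size_t j = 0; j < row_len; ++j)
    out[i * row_len + j] = in[i * row_len + j];
}

// GPU-style mapping: thread i copies a single element, so the launch size
// becomes rows * row_len instead of rows.
void copy_elem(std::size_t i, float* out, const float* in) {
  out[i] = in[i];
}

int main() {
  const std::size_t rows = 4, row_len = 3;
  std::vector<float> in(rows * row_len, 1.0f), out(rows * row_len, 0.0f);
  for (std::size_t i = 0; i < rows; ++i)            // CPU: 4 workers
    copy_row(i, out.data(), in.data(), row_len);
  for (std::size_t i = 0; i < rows * row_len; ++i)  // GPU: 12 threads
    copy_elem(i, out.data(), in.data());
  std::printf("out[0] = %.1f\n", out[0]);
  return 0;
}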
4 changes: 2 additions & 2 deletions src/operator/image/crop.cc
@@ -69,7 +69,7 @@ to the given size.
.set_attr_parser(ParamParser<CropParam>)
.set_attr<mxnet::FInferShape>("FInferShape", CropShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", CropOpForward)
.set_attr<FCompute>("FCompute<cpu>", CropOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_backward_image_crop" })
.add_argument("data", "NDArray-or-Symbol", "The input.")
.add_arguments(CropParam::__FIELDS__());
@@ -79,7 +79,7 @@ NNVM_REGISTER_OP(_backward_image_crop)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", CropOpBackward);
.set_attr<FCompute>("FCompute<cpu>", CropOpBackward<cpu>);

} // namespace image
} // namespace op
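Note on _backward_image_crop, registered just above: the backward pass of a crop scatters the output gradient back into the matching window of the input gradient, which CropBackwardImpl zero-fills first when req is kWriteTo (the Fill call) and accumulates into when req is kAddTo. A single-channel 2D sketch of that zero-fill-and-scatter under those assumptions; this is illustrative, not the MXNet slice_assign kernel:

#include <algorithm>
#include <cstdio>
#include <vector>

// in_grad is the full image's gradient (in_w columns); out_grad is the
// gradient of the out_w x out_h crop taken at offset (x, y).
void crop_backward(std::vector<float>* in_grad, int in_w,
                   const std::vector<float>& out_grad, int out_w, int out_h,
                   int x, int y, bool add_to) {
  if (!add_to)  // req == kWriteTo: gradient is zero outside the window
    std::fill(in_grad->begin(), in_grad->end(), 0.0f);
  for (int i = 0; i < out_h; ++i)
    for (int j = 0; j < out_w; ++j)
      (*in_grad)[(y + i) * in_w + (x + j)] += out_grad[i * out_w + j];
}

int main() {
  std::vector<float> in_grad(5 * 4, 0.0f);   // 5x4 input image
  std::vector<float> out_grad(2 * 2, 1.0f);  // 2x2 crop at (x=1, y=1)
  crop_backward(&in_grad, 4, out_grad, 2, 2, 1, 1, false);
  std::printf("in_grad[1*4+1] = %.1f\n", in_grad[1 * 4 + 1]);  // 1.0: inside the window
  return 0;
}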
34 changes: 34 additions & 0 deletions src/operator/image/crop.cu
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "crop-inl.h"

namespace mxnet {
namespace op {
namespace image {

NNVM_REGISTER_OP(_image_crop)
.set_attr<FCompute>("FCompute<gpu>", CropOpForward<gpu>);

NNVM_REGISTER_OP(_backward_image_crop)
.set_attr<FCompute>("FCompute<gpu>", CropOpBackward<gpu>);

} // namespace image
} // namespace op
} // namespace mxnet
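Note on the new crop.cu: GPU FCompute functions must live in a translation unit compiled by nvcc, so the same operator names are registered twice — FCompute<cpu> in crop.cc and FCompute<gpu> here — and the executor picks the entry that matches the input array's context at run time. A toy sketch of that per-device function-table idea; the registry below is hypothetical, not the NNVM implementation:

#include <cstdio>
#include <functional>
#include <map>
#include <string>

using FCompute = std::function<void()>;  // the real FCompute takes attrs/ctx/blobs; elided here

int main() {
  // One attribute map per operator, filled in from different translation units.
  std::map<std::string, FCompute> crop_registry;
  crop_registry["FCompute<cpu>"] = [] { std::puts("run crop CPU kernel"); };  // from crop.cc
  crop_registry["FCompute<gpu>"] = [] { std::puts("run crop GPU kernel"); };  // from crop.cu
  const bool on_gpu = true;  // in MXNet this comes from the input NDArray's context
  crop_registry[on_gpu ? "FCompute<gpu>" : "FCompute<cpu>"]();
  return 0;
}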
74 changes: 11 additions & 63 deletions tests/python/gpu/test_gluon_transforms.py
@@ -28,79 +28,22 @@
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import assertRaises, setup_module, with_seed, teardown

+from test_gluon_data_vision import test_to_tensor, test_normalize, test_crop_resize

set_default_context(mx.gpu(0))

@with_seed()
-def test_normalize():
-    # 3D Input
-    data_in_3d = nd.random.uniform(0, 1, (3, 300, 300))
-    out_nd_3d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_3d)
-    data_expected_3d = data_in_3d.asnumpy()
-    data_expected_3d[:][:][0] = data_expected_3d[:][:][0] / 3.0
-    data_expected_3d[:][:][1] = (data_expected_3d[:][:][1] - 1.0) / 2.0
-    data_expected_3d[:][:][2] = data_expected_3d[:][:][2] - 2.0
-    assert_almost_equal(data_expected_3d, out_nd_3d.asnumpy())

-    # 4D Input
-    data_in_4d = nd.random.uniform(0, 1, (2, 3, 300, 300))
-    out_nd_4d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_4d)
-    data_expected_4d = data_in_4d.asnumpy()
-    data_expected_4d[0][:][:][0] = data_expected_4d[0][:][:][0] / 3.0
-    data_expected_4d[0][:][:][1] = (data_expected_4d[0][:][:][1] - 1.0) / 2.0
-    data_expected_4d[0][:][:][2] = data_expected_4d[0][:][:][2] - 2.0
-    data_expected_4d[1][:][:][0] = data_expected_4d[1][:][:][0] / 3.0
-    data_expected_4d[1][:][:][1] = (data_expected_4d[1][:][:][1] - 1.0) / 2.0
-    data_expected_4d[1][:][:][2] = data_expected_4d[1][:][:][2] - 2.0
-    assert_almost_equal(data_expected_4d, out_nd_4d.asnumpy())

-    # Default normalize values i.e., mean=0, std=1
-    data_in_3d_def = nd.random.uniform(0, 1, (3, 300, 300))
-    out_nd_3d_def = transforms.Normalize()(data_in_3d_def)
-    data_expected_3d_def = data_in_3d_def.asnumpy()
-    assert_almost_equal(data_expected_3d_def, out_nd_3d_def.asnumpy())
+def test_normalize_gpu():
+    test_normalize()

-    # Invalid Input - Neither 3D or 4D input
-    invalid_data_in = nd.random.uniform(0, 1, (5, 5, 3, 300, 300))
-    normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
-    assertRaises(MXNetError, normalize_transformer, invalid_data_in)

-    # Invalid Input - Channel neither 1 or 3
-    invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300))
-    normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
-    assertRaises(MXNetError, normalize_transformer, invalid_data_in)

@with_seed()
-def test_to_tensor():
-    # 3D Input
-    data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
-    out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
-    assert_almost_equal(out_nd.asnumpy(), np.transpose(
-        data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1)))

-    # 4D Input
-    data_in = np.random.uniform(0, 255, (5, 300, 300, 3)).astype(dtype=np.uint8)
-    out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
-    assert_almost_equal(out_nd.asnumpy(), np.transpose(
-        data_in.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2)))
+def test_to_tensor_gpu():
+    test_to_tensor()

-    # Invalid Input
-    invalid_data_in = nd.random.uniform(0, 255, (5, 5, 300, 300, 3)).astype(dtype=np.uint8)
-    transformer = transforms.ToTensor()
-    assertRaises(MXNetError, transformer, invalid_data_in)

-    # Bounds (0->0, 255->1)
-    data_in = np.zeros((10, 20, 3)).astype(dtype=np.uint8)
-    out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
-    assert same(out_nd.asnumpy(), np.transpose(np.zeros(data_in.shape, dtype=np.float32), (2, 0, 1)))

-    data_in = np.full((10, 20, 3), 255).astype(dtype=np.uint8)
-    out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
-    assert same(out_nd.asnumpy(), np.transpose(np.ones(data_in.shape, dtype=np.float32), (2, 0, 1)))

@with_seed()
-def test_resize():
+def test_resize_gpu():
# Test with normal case 3D input float type
data_in_3d = nd.random.uniform(0, 255, (300, 300, 3))
out_nd_3d = transforms.Resize((100, 100))(data_in_3d)
@@ -155,3 +98,8 @@ def py_bilinear_resize_nhwc(x, outputHeight, outputWidth):
data_in_4d = nd.random.uniform(0, 255, (2, 300, 300, 3)).astype('uint8')
out_nd_4d = transforms.Resize((100, 100))(data_in_4d)
assert_almost_equal(out_nd_4d.asnumpy(), py_bilinear_resize_nhwc(data_in_4d.asnumpy(), 100, 100), atol=1.0)


+@with_seed()
+def test_crop_resize_gpu():
+    test_crop_resize()
21 changes: 7 additions & 14 deletions tests/python/unittest/test_gluon_data_vision.py
@@ -146,17 +146,18 @@ def _test_crop_resize_with_diff_type(dtype):
assert((out_batch_np[0:2,0:4,1,1].flatten() == [37, 52, 67, 82, 127, 142, 157, 172]).all())
# test normal case with resize
data_in = nd.random.uniform(0, 255, (300, 200, 3)).astype(dtype)
-    out_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_in)
-    data_expected = image.imresize(nd.slice(data_in, (0, 0, 0), (50, 100 , 3)), 25, 25, 2)
+    out_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 1)(data_in)
+    data_expected = transforms.Resize(size=25, interpolation=1)(nd.slice(data_in, (0, 0, 0), (50, 100, 3)))
assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy())
# test 4D input with resize
data_bath_in = nd.random.uniform(0, 255, (3, 300, 200, 3)).astype(dtype)
-    out_batch_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_bath_in)
+    out_batch_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 1)(data_bath_in)
for i in range(len(out_batch_nd)):
-        assert_almost_equal(image.imresize(nd.slice(data_bath_in[i], (0, 0, 0), (50, 100, 3)), 25, 25, 2).asnumpy(),
-                            out_batch_nd[i].asnumpy())
+        actual = transforms.Resize(size=25, interpolation=1)(nd.slice(data_bath_in[i], (0, 0, 0), (50, 100, 3))).asnumpy()
+        expected = out_batch_nd[i].asnumpy()
+        assert_almost_equal(expected, actual)
# test with resize height and width should be greater than 0
-    transformer = transforms.CropResize(0, 0, 100, 50, (-25, 25), 2)
+    transformer = transforms.CropResize(0, 0, 100, 50, (-25, 25), 1)
assertRaises(MXNetError, transformer, data_in)
# test height and width should be greater than 0
transformer = transforms.CropResize(0, 0, -100, -50)
@@ -188,14 +189,6 @@ def test_crop_backward(test_nd_arr, TestCase):
data_in = nd.arange(60).reshape((5, 4, 3)).astype(dtype)
for test_case in test_list:
test_crop_backward(data_in, test_case)



-    # check numeric gradient of nd.image.crop
-    # in_data = np.arange(36).reshape(3, 4, 3)
-    # data = mx.sym.Variable('data')
-    # image_crop_sym = mx.sym.image.crop(data, 0, 0, 2, 2)
-    # check_numeric_gradient(image_crop_sym, [in_data])


@with_seed()
