Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Image crop gpu support #16464

Merged
merged 1 commit into from
Oct 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/operator/contrib/bilinear_resize-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ static unsigned getNumThreads(int nElem, const bool smaller) {

// caffe_gpu_interp2_kernel overloading with Tensor<xpu, 3, DType>
template<typename xpu, typename Dtype, typename Acctype>
__global__ void caffe_gpu_interp2_kernel(const int n,
__global__ void
__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
caffe_gpu_interp2_kernel(const int n,
const Acctype rheight, const Acctype rwidth,
const Tensor<xpu, 3, Dtype> data1,
Tensor<xpu, 3, Dtype> data2,
Expand Down Expand Up @@ -111,7 +113,9 @@ __global__ void caffe_gpu_interp2_kernel(const int n,

// caffe_gpu_interp2_kernel overloading with Tensor<xpu, 4, DType>
template<typename xpu, typename Dtype, typename Acctype>
__global__ void caffe_gpu_interp2_kernel(const int n,
__global__ void
__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
caffe_gpu_interp2_kernel(const int n,
const Acctype rheight, const Acctype rwidth,
const Tensor<xpu, 4, Dtype> data1,
Tensor<xpu, 4, Dtype> data2,
Expand Down
22 changes: 16 additions & 6 deletions src/operator/image/crop-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ inline bool CropShape(const nnvm::NodeAttrs& attrs,
return true;
}

template<typename xpu>
inline void CropImpl(int x,
int y,
int width,
Expand All @@ -106,7 +107,7 @@ inline void CropImpl(int x,
const TBlob& data = inputs[0];
const TBlob& out = outputs[0];
MXNET_NDIM_SWITCH(data.ndim(), ndim, {
Stream<cpu>* s = ctx.get_stream<cpu>();
Stream<xpu>* s = ctx.get_stream<xpu>();
common::StaticArray<index_t, ndim> begin = {0}, step = {1};
if (ndim == 3) {
begin[0] = y;
Expand All @@ -118,14 +119,18 @@ inline void CropImpl(int x,
MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
size_t num_threads = out.shape_.FlatTo2D()[0];
mxnet_op::Kernel<slice_forward<ndim, Req, cpu>, cpu>::Launch(s, num_threads,
if (std::is_same<xpu, gpu>::value) {
num_threads *= out.shape_.get<ndim>()[ndim - 1];
}
mxnet_op::Kernel<slice_forward<ndim, Req, xpu>, xpu>::Launch(s, num_threads,
out.dptr<DType>(), data.dptr<DType>(),
data.shape_.get<ndim>(), out.shape_.get<ndim>(), begin, step);
})
})
})
}

template<typename xpu>
inline void CropBackwardImpl(int x,
int y,
int width,
Expand All @@ -138,7 +143,7 @@ inline void CropBackwardImpl(int x,
if (req[0] == kNullOp) return;
const TBlob& output_grad = inputs[0];
const TBlob& input_grad = outputs[0];
Stream<cpu>* s = ctx.get_stream<cpu>();
Stream<xpu>* s = ctx.get_stream<xpu>();
if (req[0] == kWriteTo) {
Fill(s, input_grad, req[0], 0);
} else if (req[0] == kWriteInplace) {
Expand All @@ -156,32 +161,37 @@ inline void CropBackwardImpl(int x,
MSHADOW_TYPE_SWITCH(output_grad.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
size_t num_threads = output_grad.shape_.FlatTo2D()[0];
mxnet_op::Kernel<slice_assign<ndim, Req, cpu>, cpu>::Launch(s, num_threads,
if (std::is_same<xpu, gpu>::value) {
num_threads *= output_grad.shape_.get<ndim>()[ndim - 1];
}
mxnet_op::Kernel<slice_assign<ndim, Req, xpu>, xpu>::Launch(s, num_threads,
input_grad.dptr<DType>(), output_grad.dptr<DType>(),
input_grad.shape_.get<ndim>(), output_grad.shape_.get<ndim>(), begin, step);
})
})
})
}

template<typename xpu>
inline void CropOpForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
CHECK_EQ(outputs.size(), 1U);
const CropParam& param = nnvm::get<CropParam>(attrs.parsed);
CropImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req);
CropImpl<xpu>(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req);
}

template<typename xpu>
inline void CropOpBackward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
CHECK_EQ(outputs.size(), 1U);
const CropParam& param = nnvm::get<CropParam>(attrs.parsed);
CropBackwardImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req);
CropBackwardImpl<xpu>(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req);
}
} // namespace image
} // namespace op
Expand Down
4 changes: 2 additions & 2 deletions src/operator/image/crop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ to the given size.
.set_attr_parser(ParamParser<CropParam>)
.set_attr<mxnet::FInferShape>("FInferShape", CropShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", CropOpForward)
.set_attr<FCompute>("FCompute<cpu>", CropOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_backward_image_crop" })
.add_argument("data", "NDArray-or-Symbol", "The input.")
.add_arguments(CropParam::__FIELDS__());
Expand All @@ -79,7 +79,7 @@ NNVM_REGISTER_OP(_backward_image_crop)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", CropOpBackward);
.set_attr<FCompute>("FCompute<cpu>", CropOpBackward<cpu>);

} // namespace image
} // namespace op
Expand Down
34 changes: 34 additions & 0 deletions src/operator/image/crop.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "crop-inl.h"

namespace mxnet {
namespace op {
namespace image {

NNVM_REGISTER_OP(_image_crop)
.set_attr<FCompute>("FCompute<gpu>", CropOpForward<gpu>);

NNVM_REGISTER_OP(_backward_image_crop)
.set_attr<FCompute>("FCompute<gpu>", CropOpBackward<gpu>);

} // namespace image
} // namespace op
} // namespace mxnet
74 changes: 11 additions & 63 deletions tests/python/gpu/test_gluon_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,79 +28,22 @@
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import assertRaises, setup_module, with_seed, teardown

from test_gluon_data_vision import test_to_tensor, test_normalize, test_crop_resize

set_default_context(mx.gpu(0))

@with_seed()
def test_normalize():
# 3D Input
data_in_3d = nd.random.uniform(0, 1, (3, 300, 300))
out_nd_3d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_3d)
data_expected_3d = data_in_3d.asnumpy()
data_expected_3d[:][:][0] = data_expected_3d[:][:][0] / 3.0
data_expected_3d[:][:][1] = (data_expected_3d[:][:][1] - 1.0) / 2.0
data_expected_3d[:][:][2] = data_expected_3d[:][:][2] - 2.0
assert_almost_equal(data_expected_3d, out_nd_3d.asnumpy())

# 4D Input
data_in_4d = nd.random.uniform(0, 1, (2, 3, 300, 300))
out_nd_4d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_4d)
data_expected_4d = data_in_4d.asnumpy()
data_expected_4d[0][:][:][0] = data_expected_4d[0][:][:][0] / 3.0
data_expected_4d[0][:][:][1] = (data_expected_4d[0][:][:][1] - 1.0) / 2.0
data_expected_4d[0][:][:][2] = data_expected_4d[0][:][:][2] - 2.0
data_expected_4d[1][:][:][0] = data_expected_4d[1][:][:][0] / 3.0
data_expected_4d[1][:][:][1] = (data_expected_4d[1][:][:][1] - 1.0) / 2.0
data_expected_4d[1][:][:][2] = data_expected_4d[1][:][:][2] - 2.0
assert_almost_equal(data_expected_4d, out_nd_4d.asnumpy())

# Default normalize values i.e., mean=0, std=1
data_in_3d_def = nd.random.uniform(0, 1, (3, 300, 300))
out_nd_3d_def = transforms.Normalize()(data_in_3d_def)
data_expected_3d_def = data_in_3d_def.asnumpy()
assert_almost_equal(data_expected_3d_def, out_nd_3d_def.asnumpy())
def test_normalize_gpu():
test_normalize()

# Invalid Input - Neither 3D or 4D input
invalid_data_in = nd.random.uniform(0, 1, (5, 5, 3, 300, 300))
normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
assertRaises(MXNetError, normalize_transformer, invalid_data_in)

# Invalid Input - Channel neither 1 or 3
invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300))
normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
assertRaises(MXNetError, normalize_transformer, invalid_data_in)

@with_seed()
def test_to_tensor():
# 3D Input
data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
assert_almost_equal(out_nd.asnumpy(), np.transpose(
data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1)))

# 4D Input
data_in = np.random.uniform(0, 255, (5, 300, 300, 3)).astype(dtype=np.uint8)
out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
assert_almost_equal(out_nd.asnumpy(), np.transpose(
data_in.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2)))
def test_to_tensor_gpu():
test_to_tensor()

# Invalid Input
invalid_data_in = nd.random.uniform(0, 255, (5, 5, 300, 300, 3)).astype(dtype=np.uint8)
transformer = transforms.ToTensor()
assertRaises(MXNetError, transformer, invalid_data_in)

# Bounds (0->0, 255->1)
data_in = np.zeros((10, 20, 3)).astype(dtype=np.uint8)
out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
assert same(out_nd.asnumpy(), np.transpose(np.zeros(data_in.shape, dtype=np.float32), (2, 0, 1)))

data_in = np.full((10, 20, 3), 255).astype(dtype=np.uint8)
out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
assert same(out_nd.asnumpy(), np.transpose(np.ones(data_in.shape, dtype=np.float32), (2, 0, 1)))

@with_seed()
def test_resize():
def test_resize_gpu():
# Test with normal case 3D input float type
data_in_3d = nd.random.uniform(0, 255, (300, 300, 3))
out_nd_3d = transforms.Resize((100, 100))(data_in_3d)
Expand Down Expand Up @@ -155,3 +98,8 @@ def py_bilinear_resize_nhwc(x, outputHeight, outputWidth):
data_in_4d = nd.random.uniform(0, 255, (2, 300, 300, 3)).astype('uint8')
out_nd_4d = transforms.Resize((100, 100))(data_in_4d)
assert_almost_equal(out_nd_4d.asnumpy(), py_bilinear_resize_nhwc(data_in_4d.asnumpy(), 100, 100), atol=1.0)


@with_seed()
def test_crop_resize_gpu():
test_crop_resize()
21 changes: 7 additions & 14 deletions tests/python/unittest/test_gluon_data_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,17 +146,18 @@ def _test_crop_resize_with_diff_type(dtype):
assert((out_batch_np[0:2,0:4,1,1].flatten() == [37, 52, 67, 82, 127, 142, 157, 172]).all())
# test normal case with resize
data_in = nd.random.uniform(0, 255, (300, 200, 3)).astype(dtype)
out_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_in)
data_expected = image.imresize(nd.slice(data_in, (0, 0, 0), (50, 100 , 3)), 25, 25, 2)
out_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 1)(data_in)
data_expected = transforms.Resize(size=25, interpolation=1)(nd.slice(data_in, (0, 0, 0), (50, 100, 3)))
assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy())
# test 4D input with resize
data_bath_in = nd.random.uniform(0, 255, (3, 300, 200, 3)).astype(dtype)
out_batch_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_bath_in)
out_batch_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 1)(data_bath_in)
for i in range(len(out_batch_nd)):
assert_almost_equal(image.imresize(nd.slice(data_bath_in[i], (0, 0, 0), (50, 100, 3)), 25, 25, 2).asnumpy(),
out_batch_nd[i].asnumpy())
actual = transforms.Resize(size=25, interpolation=1)(nd.slice(data_bath_in[i], (0, 0, 0), (50, 100, 3))).asnumpy()
expected = out_batch_nd[i].asnumpy()
assert_almost_equal(expected, actual)
# test with resize height and width should be greater than 0
transformer = transforms.CropResize(0, 0, 100, 50, (-25, 25), 2)
transformer = transforms.CropResize(0, 0, 100, 50, (-25, 25), 1)
assertRaises(MXNetError, transformer, data_in)
# test height and width should be greater than 0
transformer = transforms.CropResize(0, 0, -100, -50)
Expand Down Expand Up @@ -188,14 +189,6 @@ def test_crop_backward(test_nd_arr, TestCase):
data_in = nd.arange(60).reshape((5, 4, 3)).astype(dtype)
for test_case in test_list:
test_crop_backward(data_in, test_case)



# check numeric gradient of nd.image.crop
# in_data = np.arange(36).reshape(3, 4, 3)
# data = mx.sym.Variable('data')
# image_crop_sym = mx.sym.image.crop(data, 0, 0, 2, 2)
# check_numeric_gradient(image_crop_sym, [in_data])


@with_seed()
Expand Down