diff --git a/src/operator/contrib/bilinear_resize-inl.cuh b/src/operator/contrib/bilinear_resize-inl.cuh index b8dacb1c4f31..0f1605549d0b 100644 --- a/src/operator/contrib/bilinear_resize-inl.cuh +++ b/src/operator/contrib/bilinear_resize-inl.cuh @@ -62,7 +62,9 @@ static unsigned getNumThreads(int nElem, const bool smaller) { // caffe_gpu_interp2_kernel overloading with Tensor template -__global__ void caffe_gpu_interp2_kernel(const int n, +__global__ void +__launch_bounds__(cuda::kMaxThreadsPerBlock, 1) +caffe_gpu_interp2_kernel(const int n, const Acctype rheight, const Acctype rwidth, const Tensor data1, Tensor data2, @@ -111,7 +113,9 @@ __global__ void caffe_gpu_interp2_kernel(const int n, // caffe_gpu_interp2_kernel overloading with Tensor template -__global__ void caffe_gpu_interp2_kernel(const int n, +__global__ void +__launch_bounds__(cuda::kMaxThreadsPerBlock, 1) +caffe_gpu_interp2_kernel(const int n, const Acctype rheight, const Acctype rwidth, const Tensor data1, Tensor data2, diff --git a/src/operator/image/crop-inl.h b/src/operator/image/crop-inl.h index a1a4b23f658e..c13049685dea 100644 --- a/src/operator/image/crop-inl.h +++ b/src/operator/image/crop-inl.h @@ -94,6 +94,7 @@ inline bool CropShape(const nnvm::NodeAttrs& attrs, return true; } +template inline void CropImpl(int x, int y, int width, @@ -106,7 +107,7 @@ inline void CropImpl(int x, const TBlob& data = inputs[0]; const TBlob& out = outputs[0]; MXNET_NDIM_SWITCH(data.ndim(), ndim, { - Stream* s = ctx.get_stream(); + Stream* s = ctx.get_stream(); common::StaticArray begin = {0}, step = {1}; if (ndim == 3) { begin[0] = y; @@ -118,7 +119,10 @@ inline void CropImpl(int x, MSHADOW_TYPE_SWITCH(out.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { size_t num_threads = out.shape_.FlatTo2D()[0]; - mxnet_op::Kernel, cpu>::Launch(s, num_threads, + if (std::is_same::value) { + num_threads *= out.shape_.get()[ndim - 1]; + } + mxnet_op::Kernel, xpu>::Launch(s, num_threads, out.dptr(), data.dptr(), data.shape_.get(), out.shape_.get(), begin, step); }) @@ -126,6 +130,7 @@ inline void CropImpl(int x, }) } +template inline void CropBackwardImpl(int x, int y, int width, @@ -138,7 +143,7 @@ inline void CropBackwardImpl(int x, if (req[0] == kNullOp) return; const TBlob& output_grad = inputs[0]; const TBlob& input_grad = outputs[0]; - Stream* s = ctx.get_stream(); + Stream* s = ctx.get_stream(); if (req[0] == kWriteTo) { Fill(s, input_grad, req[0], 0); } else if (req[0] == kWriteInplace) { @@ -156,7 +161,10 @@ inline void CropBackwardImpl(int x, MSHADOW_TYPE_SWITCH(output_grad.type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { size_t num_threads = output_grad.shape_.FlatTo2D()[0]; - mxnet_op::Kernel, cpu>::Launch(s, num_threads, + if (std::is_same::value) { + num_threads *= output_grad.shape_.get()[ndim - 1]; + } + mxnet_op::Kernel, xpu>::Launch(s, num_threads, input_grad.dptr(), output_grad.dptr(), input_grad.shape_.get(), output_grad.shape_.get(), begin, step); }) @@ -164,6 +172,7 @@ inline void CropBackwardImpl(int x, }) } +template inline void CropOpForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, @@ -171,9 +180,10 @@ inline void CropOpForward(const nnvm::NodeAttrs &attrs, const std::vector &outputs) { CHECK_EQ(outputs.size(), 1U); const CropParam& param = nnvm::get(attrs.parsed); - CropImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req); + CropImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req); } +template inline void CropOpBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, @@ -181,7 +191,7 @@ inline void CropOpBackward(const nnvm::NodeAttrs &attrs, const std::vector &outputs) { CHECK_EQ(outputs.size(), 1U); const CropParam& param = nnvm::get(attrs.parsed); - CropBackwardImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req); + CropBackwardImpl(param.x, param.y, param.width, param.height, inputs, outputs, ctx, req); } } // namespace image } // namespace op diff --git a/src/operator/image/crop.cc b/src/operator/image/crop.cc index 6067f89d7033..9a7aad38b486 100644 --- a/src/operator/image/crop.cc +++ b/src/operator/image/crop.cc @@ -69,7 +69,7 @@ to the given size. .set_attr_parser(ParamParser) .set_attr("FInferShape", CropShape) .set_attr("FInferType", ElemwiseType<1, 1>) -.set_attr("FCompute", CropOpForward) +.set_attr("FCompute", CropOpForward) .set_attr("FGradient", ElemwiseGradUseNone{ "_backward_image_crop" }) .add_argument("data", "NDArray-or-Symbol", "The input.") .add_arguments(CropParam::__FIELDS__()); @@ -79,7 +79,7 @@ NNVM_REGISTER_OP(_backward_image_crop) .set_num_inputs(1) .set_num_outputs(1) .set_attr("TIsBackward", true) -.set_attr("FCompute", CropOpBackward); +.set_attr("FCompute", CropOpBackward); } // namespace image } // namespace op diff --git a/src/operator/image/crop.cu b/src/operator/image/crop.cu new file mode 100644 index 000000000000..71fde06dacc0 --- /dev/null +++ b/src/operator/image/crop.cu @@ -0,0 +1,34 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ + +#include "crop-inl.h" + +namespace mxnet { +namespace op { +namespace image { + +NNVM_REGISTER_OP(_image_crop) +.set_attr("FCompute", CropOpForward); + +NNVM_REGISTER_OP(_backward_image_crop) +.set_attr("FCompute", CropOpBackward); + +} // namespace image +} // namespace op +} // namespace mxnet diff --git a/tests/python/gpu/test_gluon_transforms.py b/tests/python/gpu/test_gluon_transforms.py index e303008dee9a..acdd22ac39ec 100644 --- a/tests/python/gpu/test_gluon_transforms.py +++ b/tests/python/gpu/test_gluon_transforms.py @@ -28,79 +28,22 @@ curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) sys.path.insert(0, os.path.join(curr_path, '../unittest')) from common import assertRaises, setup_module, with_seed, teardown - +from test_gluon_data_vision import test_to_tensor, test_normalize, test_crop_resize set_default_context(mx.gpu(0)) @with_seed() -def test_normalize(): - # 3D Input - data_in_3d = nd.random.uniform(0, 1, (3, 300, 300)) - out_nd_3d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_3d) - data_expected_3d = data_in_3d.asnumpy() - data_expected_3d[:][:][0] = data_expected_3d[:][:][0] / 3.0 - data_expected_3d[:][:][1] = (data_expected_3d[:][:][1] - 1.0) / 2.0 - data_expected_3d[:][:][2] = data_expected_3d[:][:][2] - 2.0 - assert_almost_equal(data_expected_3d, out_nd_3d.asnumpy()) - - # 4D Input - data_in_4d = nd.random.uniform(0, 1, (2, 3, 300, 300)) - out_nd_4d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_4d) - data_expected_4d = data_in_4d.asnumpy() - data_expected_4d[0][:][:][0] = data_expected_4d[0][:][:][0] / 3.0 - data_expected_4d[0][:][:][1] = (data_expected_4d[0][:][:][1] - 1.0) / 2.0 - data_expected_4d[0][:][:][2] = data_expected_4d[0][:][:][2] - 2.0 - data_expected_4d[1][:][:][0] = data_expected_4d[1][:][:][0] / 3.0 - data_expected_4d[1][:][:][1] = (data_expected_4d[1][:][:][1] - 1.0) / 2.0 - data_expected_4d[1][:][:][2] = data_expected_4d[1][:][:][2] - 2.0 - assert_almost_equal(data_expected_4d, out_nd_4d.asnumpy()) - - # Default normalize values i.e., mean=0, std=1 - data_in_3d_def = nd.random.uniform(0, 1, (3, 300, 300)) - out_nd_3d_def = transforms.Normalize()(data_in_3d_def) - data_expected_3d_def = data_in_3d_def.asnumpy() - assert_almost_equal(data_expected_3d_def, out_nd_3d_def.asnumpy()) +def test_normalize_gpu(): + test_normalize() - # Invalid Input - Neither 3D or 4D input - invalid_data_in = nd.random.uniform(0, 1, (5, 5, 3, 300, 300)) - normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1)) - assertRaises(MXNetError, normalize_transformer, invalid_data_in) - - # Invalid Input - Channel neither 1 or 3 - invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300)) - normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1)) - assertRaises(MXNetError, normalize_transformer, invalid_data_in) @with_seed() -def test_to_tensor(): - # 3D Input - data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8) - out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8')) - assert_almost_equal(out_nd.asnumpy(), np.transpose( - data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1))) - - # 4D Input - data_in = np.random.uniform(0, 255, (5, 300, 300, 3)).astype(dtype=np.uint8) - out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8')) - assert_almost_equal(out_nd.asnumpy(), np.transpose( - data_in.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2))) +def test_to_tensor_gpu(): + test_to_tensor() - # Invalid Input - invalid_data_in = nd.random.uniform(0, 255, (5, 5, 300, 300, 3)).astype(dtype=np.uint8) - transformer = transforms.ToTensor() - assertRaises(MXNetError, transformer, invalid_data_in) - - # Bounds (0->0, 255->1) - data_in = np.zeros((10, 20, 3)).astype(dtype=np.uint8) - out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8')) - assert same(out_nd.asnumpy(), np.transpose(np.zeros(data_in.shape, dtype=np.float32), (2, 0, 1))) - - data_in = np.full((10, 20, 3), 255).astype(dtype=np.uint8) - out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8')) - assert same(out_nd.asnumpy(), np.transpose(np.ones(data_in.shape, dtype=np.float32), (2, 0, 1))) @with_seed() -def test_resize(): +def test_resize_gpu(): # Test with normal case 3D input float type data_in_3d = nd.random.uniform(0, 255, (300, 300, 3)) out_nd_3d = transforms.Resize((100, 100))(data_in_3d) @@ -155,3 +98,8 @@ def py_bilinear_resize_nhwc(x, outputHeight, outputWidth): data_in_4d = nd.random.uniform(0, 255, (2, 300, 300, 3)).astype('uint8') out_nd_4d = transforms.Resize((100, 100))(data_in_4d) assert_almost_equal(out_nd_4d.asnumpy(), py_bilinear_resize_nhwc(data_in_4d.asnumpy(), 100, 100), atol=1.0) + + +@with_seed() +def test_crop_resize_gpu(): + test_crop_resize() diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py index 627567ca4244..8bc0f8072260 100644 --- a/tests/python/unittest/test_gluon_data_vision.py +++ b/tests/python/unittest/test_gluon_data_vision.py @@ -146,17 +146,18 @@ def _test_crop_resize_with_diff_type(dtype): assert((out_batch_np[0:2,0:4,1,1].flatten() == [37, 52, 67, 82, 127, 142, 157, 172]).all()) # test normal case with resize data_in = nd.random.uniform(0, 255, (300, 200, 3)).astype(dtype) - out_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_in) - data_expected = image.imresize(nd.slice(data_in, (0, 0, 0), (50, 100 , 3)), 25, 25, 2) + out_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 1)(data_in) + data_expected = transforms.Resize(size=25, interpolation=1)(nd.slice(data_in, (0, 0, 0), (50, 100, 3))) assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy()) # test 4D input with resize data_bath_in = nd.random.uniform(0, 255, (3, 300, 200, 3)).astype(dtype) - out_batch_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 2)(data_bath_in) + out_batch_nd = transforms.CropResize(0, 0, 100, 50, (25, 25), 1)(data_bath_in) for i in range(len(out_batch_nd)): - assert_almost_equal(image.imresize(nd.slice(data_bath_in[i], (0, 0, 0), (50, 100, 3)), 25, 25, 2).asnumpy(), - out_batch_nd[i].asnumpy()) + actual = transforms.Resize(size=25, interpolation=1)(nd.slice(data_bath_in[i], (0, 0, 0), (50, 100, 3))).asnumpy() + expected = out_batch_nd[i].asnumpy() + assert_almost_equal(expected, actual) # test with resize height and width should be greater than 0 - transformer = transforms.CropResize(0, 0, 100, 50, (-25, 25), 2) + transformer = transforms.CropResize(0, 0, 100, 50, (-25, 25), 1) assertRaises(MXNetError, transformer, data_in) # test height and width should be greater than 0 transformer = transforms.CropResize(0, 0, -100, -50) @@ -188,14 +189,6 @@ def test_crop_backward(test_nd_arr, TestCase): data_in = nd.arange(60).reshape((5, 4, 3)).astype(dtype) for test_case in test_list: test_crop_backward(data_in, test_case) - - - - # check numeric gradient of nd.image.crop - # in_data = np.arange(36).reshape(3, 4, 3) - # data = mx.sym.Variable('data') - # image_crop_sym = mx.sym.image.crop(data, 0, 0, 2, 2) - # check_numeric_gradient(image_crop_sym, [in_data]) @with_seed()