From 45806e75cd22e9f7faaf542a6d5b19b81f76c31c Mon Sep 17 00:00:00 2001 From: Sandeep Krishnamurthy Date: Mon, 4 Feb 2019 14:59:13 -0800 Subject: [PATCH] Image ToTensor operator - GPU support, 3D/4D inputs (#13837) * Add CPU implementation of ToTensor * Add tests for cpu * Add gpu implementation and tests * Fix lint issues * Cleanup includes * Move back changes to original image operators files * Add 4D example * resolve merge conflicts * Fix failing tests * parallelize on channel in kernel launch --- python/mxnet/gluon/data/vision/transforms.py | 9 +- src/operator/image/image_random-inl.h | 97 +++++++++++++++---- src/operator/image/image_random.cc | 63 +++++++++++- src/operator/image/image_random.cu | 44 +++++---- tests/python/gpu/test_gluon_transforms.py | 36 ++++++- .../python/unittest/test_gluon_data_vision.py | 14 ++- 6 files changed, 213 insertions(+), 50 deletions(-) diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index aa4a3e3d8957..9310e15f5133 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -96,17 +96,20 @@ def hybrid_forward(self, F, x): class ToTensor(HybridBlock): - """Converts an image NDArray to a tensor NDArray. + """Converts an image NDArray or batch of image NDArray to a tensor NDArray. Converts an image NDArray of shape (H x W x C) in the range [0, 255] to a float32 tensor NDArray of shape (C x H x W) in the range [0, 1). + If batch input, converts a batch image NDArray of shape (N x H x W x C) in the + range [0, 255] to a float32 tensor NDArray of shape (N x C x H x W). + Inputs: - - **data**: input tensor with (H x W x C) shape and uint8 type. + - **data**: input tensor with (H x W x C) or (N x H x W x C) shape and uint8 type. Outputs: - - **out**: output tensor with (C x H x W) shape and float32 type. + - **out**: output tensor with (C x H x W) or (N x H x W x C) shape and float32 type. Examples -------- diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h index aeea0bcf9fec..c9dd85af616f 100644 --- a/src/operator/image/image_random-inl.h +++ b/src/operator/image/image_random-inl.h @@ -43,16 +43,28 @@ namespace mxnet { namespace op { namespace image { +// There are no parameters for this operator. +// Hence, no arameter registration. 
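For reference, a minimal Python sketch of the conversion the updated ToTensor docstring above describes, checked against the same numpy transpose that the unit tests later in this patch use (only numpy and mxnet are assumed):

import numpy as np
import mxnet as mx
from mxnet.gluon.data.vision import transforms

# 3D input: (H x W x C) uint8 in [0, 255] -> (C x H x W) float32 in [0, 1)
img = np.random.uniform(0, 255, (4, 2, 3)).astype(np.uint8)
out = transforms.ToTensor()(mx.nd.array(img, dtype='uint8'))
np.testing.assert_allclose(
    out.asnumpy(),
    np.transpose(img.astype(np.float32) / 255.0, (2, 0, 1)),
    atol=1e-6)

# 4D batch input: (N x H x W x C) uint8 -> (N x C x H x W) float32
batch = np.random.uniform(0, 255, (2, 4, 2, 3)).astype(np.uint8)
out_b = transforms.ToTensor()(mx.nd.array(batch, dtype='uint8'))
np.testing.assert_allclose(
    out_b.asnumpy(),
    np.transpose(batch.astype(np.float32) / 255.0, (0, 3, 1, 2)),
    atol=1e-6)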
+ +// Shape and Type inference for image to tensor operator inline bool ToTensorShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); + TShape &shp = (*in_attrs)[0]; if (!shp.ndim()) return false; - CHECK_EQ(shp.ndim(), 3) - << "Input image must have shape (height, width, channels), but got " << shp; - SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[2], shp[0], shp[1]})); + + CHECK((shp.ndim() == 3) || (shp.ndim() == 4)) + << "Input image must have shape (height, width, channels), or " + << "(N, height, width, channels) but got " << shp; + if (shp.ndim() == 3) { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[2], shp[0], shp[1]})); + } else if (shp.ndim() == 4) { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[0], shp[3], shp[1], shp[2]})); + } + return true; } @@ -65,31 +77,74 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs, return (*in_attrs)[0] != -1; } -inline void ToTensor(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - CHECK_EQ(req[0], kWriteTo) - << "`to_tensor` does not support inplace"; +// Operator Implementation - int length = inputs[0].shape_[0] * inputs[0].shape_[1]; - int channel = inputs[0].shape_[2]; +template +struct totensor_forward { + template + MSHADOW_XINLINE static void Map(uint32_t c, float* out_data, const DType* in_data, + const int length, const int channel, const int step, + const float normalize_factor = 255.0f) { + #pragma omp parallel for + for (int i = 0; i < length; ++i) { + KERNEL_ASSIGN(out_data[step + c*length + i], req, + (in_data[step + i*channel + c]) / normalize_factor); + } + } +}; + +template +void ToTensorImpl(const OpContext &ctx, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &req, + const int length, + const uint32_t channel, + const int step = 0) { + mshadow::Stream *s = ctx.get_stream(); MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { - float* output = outputs[0].dptr(); - DType* input = inputs[0].dptr(); + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + float* output = outputs[0].dptr(); + DType* input = inputs[0].dptr(); + mxnet_op::Kernel, xpu>::Launch( + s, channel, output, input, length, channel, step); + }); + }); +} - for (int l = 0; l < length; ++l) { - for (int c = 0; c < channel; ++c) { - output[c*length + l] = static_cast(input[l*channel + c]) / 255.0f; - } +template +void ToTensorOpForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + + CHECK_EQ(req[0], kWriteTo) + << "`to_tensor` does not support inplace updates"; + + // 3D Input - (h, w, c) + if (inputs[0].ndim() == 3) { + const int length = inputs[0].shape_[0] * inputs[0].shape_[1]; + const uint32_t channel = inputs[0].shape_[2]; + ToTensorImpl(ctx, inputs, outputs, req, length, channel); + } else if (inputs[0].ndim() == 4) { + // 4D input (n, h, w, c) + const int batch_size = inputs[0].shape_[0]; + const int length = inputs[0].shape_[1] * inputs[0].shape_[2]; + const uint32_t channel = inputs[0].shape_[3]; + const int step = channel * length; + + #pragma omp parallel for + for (auto n = 0; n < batch_size; ++n) { + ToTensorImpl(ctx, inputs, outputs, req, length, channel, n*step); } - }); + } } -// Normalize Operator -// Parameter registration for image Normalize 
operator struct NormalizeParam : public dmlc::Parameter { nnvm::Tuple mean; nnvm::Tuple std; diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc index 7901747c8ea1..fc6b17c0b1ca 100644 --- a/src/operator/image/image_random.cc +++ b/src/operator/image/image_random.cc @@ -39,14 +39,71 @@ DMLC_REGISTER_PARAMETER(RandomLightingParam); DMLC_REGISTER_PARAMETER(RandomColorJitterParam); NNVM_REGISTER_OP(_image_to_tensor) -.describe(R"code()code" ADD_FILELINE) +.describe(R"code(Converts an image NDArray of shape (H x W x C) or (N x H x W x C) +with values in the range [0, 255] to a tensor NDArray of shape (C x H x W) or (N x C x H x W) +with values in the range [0, 1) + +Example: + .. code-block:: python + image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8) + to_tensor(image) + [[[ 0.85490197 0.72156864] + [ 0.09019608 0.74117649] + [ 0.61960787 0.92941177] + [ 0.96470588 0.1882353 ]] + [[ 0.6156863 0.73725492] + [ 0.46666667 0.98039216] + [ 0.44705883 0.45490196] + [ 0.01960784 0.8509804 ]] + [[ 0.39607844 0.03137255] + [ 0.72156864 0.52941179] + [ 0.16470589 0.7647059 ] + [ 0.05490196 0.70588237]]] + + + image = mx.nd.random.uniform(0, 255, (2, 4, 2, 3)).astype(dtype=np.uint8) + to_tensor(image) + [[[[0.11764706 0.5803922 ] + [0.9411765 0.10588235] + [0.2627451 0.73333335] + [0.5647059 0.32156864]] + [[0.7176471 0.14117648] + [0.75686276 0.4117647 ] + [0.18431373 0.45490196] + [0.13333334 0.6156863 ]] + [[0.6392157 0.5372549 ] + [0.52156866 0.47058824] + [0.77254903 0.21568628] + [0.01568628 0.14901961]]] + [[[0.6117647 0.38431373] + [0.6784314 0.6117647 ] + [0.69411767 0.96862745] + [0.67058825 0.35686275]] + [[0.21960784 0.9411765 ] + [0.44705883 0.43529412] + [0.09803922 0.6666667 ] + [0.16862746 0.1254902 ]] + [[0.6156863 0.9019608 ] + [0.35686275 0.9019608 ] + [0.05882353 0.6509804 ] + [0.20784314 0.7490196 ]]]] + +)code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) .set_attr("FInferShape", ToTensorShape) .set_attr("FInferType", ToTensorType) -.set_attr("FCompute", ToTensor) +.set_attr("FCompute", ToTensorOpForward) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) .set_attr("FGradient", ElemwiseGradUseNone{ "_copy" }) -.add_argument("data", "NDArray-or-Symbol", "The input."); +.add_argument("data", "NDArray-or-Symbol", "Input ndarray"); NNVM_REGISTER_OP(_image_normalize) .describe(R"code(Normalize an tensor of shape (C x H x W) or (N x C x H x W) with mean and diff --git a/src/operator/image/image_random.cu b/src/operator/image/image_random.cu index 404c3d25477a..5f9aff27e85b 100644 --- a/src/operator/image/image_random.cu +++ b/src/operator/image/image_random.cu @@ -1,26 +1,26 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ /*! - * \file image_random.cu - * \brief GPU Implementation of image transformation operators - */ +* \file image_random.cu +* \brief GPU Implementation of image transformation operators +*/ #include "./image_random-inl.h" #include "../elemwise_op_common.h" @@ -28,13 +28,15 @@ namespace mxnet { namespace op { namespace image { +NNVM_REGISTER_OP(_image_to_tensor) +.set_attr("FCompute", ToTensorOpForward); + NNVM_REGISTER_OP(_image_normalize) .set_attr("FCompute", NormalizeOpForward); NNVM_REGISTER_OP(_backward_image_normalize) .set_attr("FCompute", NormalizeOpBackward); - } // namespace image } // namespace op } // namespace mxnet diff --git a/tests/python/gpu/test_gluon_transforms.py b/tests/python/gpu/test_gluon_transforms.py index 4a1017b538ac..3927d4c1f094 100644 --- a/tests/python/gpu/test_gluon_transforms.py +++ b/tests/python/gpu/test_gluon_transforms.py @@ -71,6 +71,41 @@ def test_normalize(): normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1)) assertRaises(MXNetError, normalize_transformer, invalid_data_in) +@with_seed() +def test_to_tensor(): + # 3D Input + data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8) + out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8')) + assert_almost_equal(out_nd.asnumpy(), np.transpose( + data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1))) + + # 4D Input + data_in_4d = nd.random.uniform(0, 1, (2, 3, 300, 300)) + out_nd_4d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_4d) + data_expected_4d = data_in_4d.asnumpy() + data_expected_4d[0][:][:][0] = data_expected_4d[0][:][:][0] / 3.0 + data_expected_4d[0][:][:][1] = (data_expected_4d[0][:][:][1] - 1.0) / 2.0 + data_expected_4d[0][:][:][2] = data_expected_4d[0][:][:][2] - 2.0 + data_expected_4d[1][:][:][0] = data_expected_4d[1][:][:][0] / 3.0 + data_expected_4d[1][:][:][1] = (data_expected_4d[1][:][:][1] - 1.0) / 2.0 + data_expected_4d[1][:][:][2] = data_expected_4d[1][:][:][2] - 2.0 + assert_almost_equal(data_expected_4d, out_nd_4d.asnumpy()) + + # Default normalize values i.e., mean=0, std=1 + data_in_3d_def = nd.random.uniform(0, 1, (3, 300, 300)) + out_nd_3d_def = transforms.Normalize()(data_in_3d_def) + data_expected_3d_def = data_in_3d_def.asnumpy() + assert_almost_equal(data_expected_3d_def, out_nd_3d_def.asnumpy()) + + # Invalid Input - Neither 3D or 4D input + invalid_data_in = nd.random.uniform(0, 1, (5, 5, 3, 300, 300)) + normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1)) + assertRaises(MXNetError, normalize_transformer, invalid_data_in) + + # Invalid Input - Channel neither 1 or 3 + invalid_data_in = 
nd.random.uniform(0, 1, (5, 4, 300, 300)) + normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1)) + assertRaises(MXNetError, normalize_transformer, invalid_data_in) @with_seed() def test_resize(): @@ -128,4 +163,3 @@ def py_bilinear_resize_nhwc(x, outputHeight, outputWidth): data_in_4d = nd.random.uniform(0, 255, (2, 300, 300, 3)).astype('uint8') out_nd_4d = transforms.Resize((100, 100))(data_in_4d) assert_almost_equal(out_nd_4d.asnumpy(), py_bilinear_resize_nhwc(data_in_4d.asnumpy(), 100, 100), atol=1.0) - diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py index f10f0ae4fe19..a855fc8cf1df 100644 --- a/tests/python/unittest/test_gluon_data_vision.py +++ b/tests/python/unittest/test_gluon_data_vision.py @@ -29,10 +29,22 @@ @with_seed() def test_to_tensor(): + # 3D Input data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8) out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8')) assert_almost_equal(out_nd.asnumpy(), np.transpose( - data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1))) + data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1))) + + # 4D Input + data_in = np.random.uniform(0, 255, (5, 300, 300, 3)).astype(dtype=np.uint8) + out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8')) + assert_almost_equal(out_nd.asnumpy(), np.transpose( + data_in.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2))) + + # Invalid Input + invalid_data_in = nd.random.uniform(0, 255, (5, 5, 300, 300, 3)).astype(dtype=np.uint8) + transformer = transforms.ToTensor() + assertRaises(MXNetError, transformer, invalid_data_in) @with_seed()
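With the GPU FCompute registration above in place, the transform can be exercised end to end on a CUDA device. A minimal usage sketch, assuming a CUDA build of MXNet and a GPU visible at index 0:

import numpy as np
import mxnet as mx
from mxnet.gluon.data.vision import transforms

# Batch of two 300x300 RGB images in (N x H x W x C) uint8 layout, copied to the GPU.
data = mx.nd.random.uniform(0, 255, (2, 300, 300, 3)).astype('uint8').copyto(mx.gpu(0))
out = transforms.ToTensor()(data)  # dispatches to the GPU kernel added by this patch

assert out.context == mx.gpu(0)
assert out.shape == (2, 3, 300, 300)
assert out.dtype == np.float32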