From 94256121f90c71d5a893a57a2759c41884d23e65 Mon Sep 17 00:00:00 2001 From: Sandeep Krishnamurthy Date: Tue, 8 Jan 2019 15:20:12 -0800 Subject: [PATCH 01/10] Add CPU implementation of ToTensor --- src/operator/image/image_random-inl.h | 47 ---------- src/operator/image/image_random.cc | 10 -- src/operator/image/totensor_op-inl.h | 126 ++++++++++++++++++++++++++ src/operator/image/totensor_op.cc | 71 +++++++++++++++ src/operator/image/totensor_op.cu | 18 ++++ 5 files changed, 215 insertions(+), 57 deletions(-) create mode 100644 src/operator/image/totensor_op-inl.h create mode 100644 src/operator/image/totensor_op.cc create mode 100644 src/operator/image/totensor_op.cu diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h index aeea0bcf9fec..b922db7019d1 100644 --- a/src/operator/image/image_random-inl.h +++ b/src/operator/image/image_random-inl.h @@ -43,53 +43,6 @@ namespace mxnet { namespace op { namespace image { -inline bool ToTensorShape(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - TShape &shp = (*in_attrs)[0]; - if (!shp.ndim()) return false; - CHECK_EQ(shp.ndim(), 3) - << "Input image must have shape (height, width, channels), but got " << shp; - SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[2], shp[0], shp[1]})); - return true; -} - -inline bool ToTensorType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32); - return (*in_attrs)[0] != -1; -} - -inline void ToTensor(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - CHECK_EQ(req[0], kWriteTo) - << "`to_tensor` does not support inplace"; - - int length = inputs[0].shape_[0] * inputs[0].shape_[1]; - int channel = inputs[0].shape_[2]; - - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { - float* output = outputs[0].dptr(); - DType* input = inputs[0].dptr(); - - for (int l = 0; l < length; ++l) { - for (int c = 0; c < channel; ++c) { - output[c*length + l] = static_cast(input[l*channel + c]) / 255.0f; - } - } - }); -} - -// Normalize Operator -// Parameter registration for image Normalize operator struct NormalizeParam : public dmlc::Parameter { nnvm::Tuple mean; nnvm::Tuple std; diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc index 7901747c8ea1..56df9e37a3bb 100644 --- a/src/operator/image/image_random.cc +++ b/src/operator/image/image_random.cc @@ -38,16 +38,6 @@ DMLC_REGISTER_PARAMETER(AdjustLightingParam); DMLC_REGISTER_PARAMETER(RandomLightingParam); DMLC_REGISTER_PARAMETER(RandomColorJitterParam); -NNVM_REGISTER_OP(_image_to_tensor) -.describe(R"code()code" ADD_FILELINE) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr("FInferShape", ToTensorShape) -.set_attr("FInferType", ToTensorType) -.set_attr("FCompute", ToTensor) -.set_attr("FGradient", ElemwiseGradUseNone{ "_copy" }) -.add_argument("data", "NDArray-or-Symbol", "The input."); - NNVM_REGISTER_OP(_image_normalize) .describe(R"code(Normalize an tensor of shape (C x H x W) or (N x C x H x W) with mean and standard deviation. 
diff --git a/src/operator/image/totensor_op-inl.h b/src/operator/image/totensor_op-inl.h new file mode 100644 index 000000000000..7ac6f36aa6c6 --- /dev/null +++ b/src/operator/image/totensor_op-inl.h @@ -0,0 +1,126 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ + +/*! + * Copyright (c) 2018 by Contributors + * \file totensor_op-inl.h + * \brief Image to tensor operator +*/ +#ifndef MXNET_OPERATOR_IMAGE_TOTENSOR_OP_INL_H_ +#define MXNET_OPERATOR_IMAGE_TOTENSOR_OP_INL_H_ + + +#include +#include +#include +#include +#include +#include +#include "../mxnet_op.h" +#include "../operator_common.h" +#include "../elemwise_op_common.h" + +namespace mxnet { +namespace op { +namespace image { + +// There are no parameters for this operator. +// Hence, no arameter registration. + +// Shape and Type inference for image to tensor operator +inline bool ToTensorShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + TShape &shp = (*in_attrs)[0]; + if (!shp.ndim()) return false; + + CHECK((shp.ndim() == 3) || (shp.ndim() == 4)) + << "Input image must have shape (height, width, channels), or " + << "(N, height, width, channels) but got " << shp; + if (shp.ndim() == 3) { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[2], shp[0], shp[1]})); + } else if (shp.ndim() == 4) { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[0], shp[3], shp[1], shp[2]})); + } + + return true; +} + +inline bool ToTensorType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32); + return (*in_attrs)[0] != -1; +} + +// Operator Implementation +void ToTensorImpl(const std::vector &inputs, + const std::vector &outputs, + const int length, + const int channel, + const int step = 0) { + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + float* output = outputs[0].dptr(); + DType* input = inputs[0].dptr(); + + for (int l = 0; l < length; ++l) { + for (int c = 0; c < channel; ++c) { + output[step + c*length + l] = static_cast(input[step + l*channel + c]) / 255.0f; + } + } + }); +} + +void ToTensor(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_EQ(req[0], kWriteTo) + << "`to_tensor` does not support inplace"; + + // 3D Input - 1 image + if (inputs[0].ndim() == 3) { + const int length = inputs[0].shape_[0] * inputs[0].shape_[1]; + const int channel = inputs[0].shape_[2]; + ToTensorImpl(inputs, outputs, length, channel); + } else if (inputs[0].ndim() == 4) { + // 4D input batch of images + const int batch_size = inputs[0].shape_[0]; + const int 
length = inputs[0].shape_[1] * inputs[0].shape_[2]; + const int channel = inputs[0].shape_[3]; + const int step = channel * length; + + #pragma omp parallel for + for (auto n = 0; n < batch_size; ++n) { + ToTensorImpl(inputs, outputs, length, channel, n*step); + } + } +} + +} // namespace image +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_IMAGE_TOTENSOR_OP_INL_H_ diff --git a/src/operator/image/totensor_op.cc b/src/operator/image/totensor_op.cc new file mode 100644 index 000000000000..1f85454bd3f0 --- /dev/null +++ b/src/operator/image/totensor_op.cc @@ -0,0 +1,71 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ + +/*! + * \file totensor_op.cc + * \brief CPU Implementation of ToTensor op + */ +#include "./totensor_op-inl.h" + +namespace mxnet { +namespace op { +namespace image { + +NNVM_REGISTER_OP(_image_to_tensor) +.describe(R"code(Converts an image NDArray of shape (H x W x C) or (N x H x W x C) +with values in the range [0, 255] to a tensor NDArray of shape (C x H x W) or (N x C x H x W) +with values in the range [0, 1) + +Example: + .. code-block:: python + image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8) + to_tensor(image) + [[[ 0.85490197 0.72156864] + [ 0.09019608 0.74117649] + [ 0.61960787 0.92941177] + [ 0.96470588 0.1882353 ]] + [[ 0.6156863 0.73725492] + [ 0.46666667 0.98039216] + [ 0.44705883 0.45490196] + [ 0.01960784 0.8509804 ]] + [[ 0.39607844 0.03137255] + [ 0.72156864 0.52941179] + [ 0.16470589 0.7647059 ] + [ 0.05490196 0.70588237]]] + +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) +.set_attr("FInferShape", ToTensorShape) +.set_attr("FInferType", ToTensorType) +.set_attr("FCompute", ToTensor) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) +.set_attr("FGradient", ElemwiseGradUseNone{ "_copy" }) +.add_argument("data", "NDArray-or-Symbol", "Input ndarray"); + +} // namespace image +} // namespace op +} // namespace mxnet diff --git a/src/operator/image/totensor_op.cu b/src/operator/image/totensor_op.cu new file mode 100644 index 000000000000..447cf434b320 --- /dev/null +++ b/src/operator/image/totensor_op.cu @@ -0,0 +1,18 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ From 69198746162b713dbfde470ac8b2b93f48183bf7 Mon Sep 17 00:00:00 2001 From: Sandeep Krishnamurthy Date: Tue, 8 Jan 2019 15:27:05 -0800 Subject: [PATCH 02/10] Add tests for cpu --- python/mxnet/gluon/data/vision/transforms.py | 9 ++++++--- tests/python/unittest/test_gluon_data_vision.py | 12 ++++++++++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index aa4a3e3d8957..a37054b59244 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -96,17 +96,20 @@ def hybrid_forward(self, F, x): class ToTensor(HybridBlock): - """Converts an image NDArray to a tensor NDArray. + """Converts an image NDArray or batch of image NDArray to a tensor NDArray. Converts an image NDArray of shape (H x W x C) in the range [0, 255] to a float32 tensor NDArray of shape (C x H x W) in the range [0, 1). + If batch input, converts a batch image NDArray of shape (N x H x W x C) in the + range [0, 255] to a float32 tensor NDArray of shape (N x C x H x W). + Inputs: - - **data**: input tensor with (H x W x C) shape and uint8 type. + - **data**: input tensor with (H x W x C) or (N x H x W x C) shape and uint8 type. Outputs: - - **out**: output tensor with (C x H x W) shape and float32 type. + - **out**: output tensor with (C x H x W) or (N x C x H x W) shape and float32 type. 
Examples -------- diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py index f10f0ae4fe19..a92a8cb82bc7 100644 --- a/tests/python/unittest/test_gluon_data_vision.py +++ b/tests/python/unittest/test_gluon_data_vision.py @@ -29,11 +29,23 @@ @with_seed() def test_to_tensor(): + # 3D Input data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8) out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8')) assert_almost_equal(out_nd.asnumpy(), np.transpose( data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1))) + # 4D Input + data_in = np.random.uniform(0, 255, (5, 300, 300, 3)).astype(dtype=np.uint8) + out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8')) + assert_almost_equal(out_nd.asnumpy(), np.transpose( + data_in.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2))) + + # Invalid Input + invalid_data_in = nd.random.uniform(0, 255, (5, 5, 300, 300, 3)).astype(dtype=np.uint8) + transformer = transforms.ToTensor() + assertRaises(MXNetError, transformer, invalid_data_in) + @with_seed() def test_normalize(): From a9edea97e47d75428c4bb06c997e2be0d3dae657 Mon Sep 17 00:00:00 2001 From: Sandeep Krishnamurthy Date: Thu, 10 Jan 2019 13:22:57 -0800 Subject: [PATCH 03/10] Add gpu implementation and tests --- src/operator/image/totensor_op-inl.h | 61 ++++++++++++++++------- src/operator/image/totensor_op.cc | 2 +- src/operator/image/totensor_op.cu | 12 +++++ tests/python/gpu/test_gluon_transforms.py | 22 +++++--- 4 files changed, 69 insertions(+), 28 deletions(-) diff --git a/src/operator/image/totensor_op-inl.h b/src/operator/image/totensor_op-inl.h index 7ac6f36aa6c6..eb69edcfd24d 100644 --- a/src/operator/image/totensor_op-inl.h +++ b/src/operator/image/totensor_op-inl.h @@ -75,38 +75,61 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs, } // Operator Implementation -void ToTensorImpl(const std::vector &inputs, - const std::vector &outputs, - const int length, - const int channel, - const int step = 0) { + +template +struct totensor_forward { + template + MSHADOW_XINLINE static void Map(int l, float* out_data, const DType* in_data, + const int c, const int length, const int channel, + const int step, const float normalize_factor = 255.0f) { + KERNEL_ASSIGN(out_data[step + c*length + l], req, + (in_data[step + l*channel + c]) / normalize_factor); + } +}; + +template +void ToTensorImpl(const OpContext &ctx, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &req, + const int length, + const int channel, + const int step = 0) { + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { float* output = outputs[0].dptr(); DType* input = inputs[0].dptr(); - for (int l = 0; l < length; ++l) { - for (int c = 0; c < channel; ++c) { - output[step + c*length + l] = static_cast(input[step + l*channel + c]) / 255.0f; - } + for (int c = 0; c < channel; ++c) { + mxnet_op::Kernel, xpu>::Launch( + s, length, output, input, c, length, channel, step); } }); + }); } -void ToTensor(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { +template +void ToTensorOpForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_EQ(req[0], kWriteTo) - << 
"`to_tensor` does not support inplace"; + << "`to_tensor` does not support inplace updates"; - // 3D Input - 1 image + // 3D Input - (h, w, c) if (inputs[0].ndim() == 3) { const int length = inputs[0].shape_[0] * inputs[0].shape_[1]; const int channel = inputs[0].shape_[2]; - ToTensorImpl(inputs, outputs, length, channel); + ToTensorImpl(ctx, inputs, outputs, req, length, channel); } else if (inputs[0].ndim() == 4) { - // 4D input batch of images + // 4D input (n, h, w, c) const int batch_size = inputs[0].shape_[0]; const int length = inputs[0].shape_[1] * inputs[0].shape_[2]; const int channel = inputs[0].shape_[3]; @@ -114,7 +137,7 @@ void ToTensor(const nnvm::NodeAttrs &attrs, #pragma omp parallel for for (auto n = 0; n < batch_size; ++n) { - ToTensorImpl(inputs, outputs, length, channel, n*step); + ToTensorImpl(ctx, inputs, outputs, req, length, channel, n*step); } } } diff --git a/src/operator/image/totensor_op.cc b/src/operator/image/totensor_op.cc index 1f85454bd3f0..1369a6815d71 100644 --- a/src/operator/image/totensor_op.cc +++ b/src/operator/image/totensor_op.cc @@ -58,7 +58,7 @@ with values in the range [0, 1) }) .set_attr("FInferShape", ToTensorShape) .set_attr("FInferType", ToTensorType) -.set_attr("FCompute", ToTensor) +.set_attr("FCompute", ToTensorOpForward) .set_attr("FInplaceOption", [](const NodeAttrs& attrs) { return std::vector >{{0, 0}}; diff --git a/src/operator/image/totensor_op.cu b/src/operator/image/totensor_op.cu index 447cf434b320..55084ac0d9c5 100644 --- a/src/operator/image/totensor_op.cu +++ b/src/operator/image/totensor_op.cu @@ -16,3 +16,15 @@ * specific language governing permissions and limitations * under the License. */ +#include "./totensor_op-inl.h" + +namespace mxnet { +namespace op { +namespace image { + +NNVM_REGISTER_OP(_image_to_tensor) +.set_attr("FCompute", ToTensorOpForward); + +} // namespace image +} // namespace op +} // namespace mxnet diff --git a/tests/python/gpu/test_gluon_transforms.py b/tests/python/gpu/test_gluon_transforms.py index 4a1017b538ac..f4bdce4c5522 100644 --- a/tests/python/gpu/test_gluon_transforms.py +++ b/tests/python/gpu/test_gluon_transforms.py @@ -33,15 +33,12 @@ set_default_context(mx.gpu(0)) @with_seed() -def test_normalize(): +def test_to_tensor(): # 3D Input - data_in_3d = nd.random.uniform(0, 1, (3, 300, 300)) - out_nd_3d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_3d) - data_expected_3d = data_in_3d.asnumpy() - data_expected_3d[:][:][0] = data_expected_3d[:][:][0] / 3.0 - data_expected_3d[:][:][1] = (data_expected_3d[:][:][1] - 1.0) / 2.0 - data_expected_3d[:][:][2] = data_expected_3d[:][:][2] - 2.0 - assert_almost_equal(data_expected_3d, out_nd_3d.asnumpy()) + data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8) + out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8')) + assert_almost_equal(out_nd.asnumpy(), np.transpose( + data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1))) # 4D Input data_in_4d = nd.random.uniform(0, 1, (2, 3, 300, 300)) @@ -129,3 +126,12 @@ def py_bilinear_resize_nhwc(x, outputHeight, outputWidth): out_nd_4d = transforms.Resize((100, 100))(data_in_4d) assert_almost_equal(out_nd_4d.asnumpy(), py_bilinear_resize_nhwc(data_in_4d.asnumpy(), 100, 100), atol=1.0) + data_in = np.random.uniform(0, 255, (5, 300, 300, 3)).astype(dtype=np.uint8) + out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8')) + assert_almost_equal(out_nd.asnumpy(), np.transpose( + data_in.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2))) + + # Invalid Input + 
invalid_data_in = nd.random.uniform(0, 255, (5, 5, 300, 300, 3)).astype(dtype=np.uint8) + transformer = transforms.ToTensor() + assertRaises(MXNetError, transformer, invalid_data_in) From 7c573db4bab464bf303340a8e4dc5b3987f1a0bc Mon Sep 17 00:00:00 2001 From: Sandeep Krishnamurthy Date: Thu, 10 Jan 2019 14:29:10 -0800 Subject: [PATCH 04/10] Fix lint issues --- python/mxnet/gluon/data/vision/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index a37054b59244..9310e15f5133 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -102,7 +102,7 @@ class ToTensor(HybridBlock): [0, 255] to a float32 tensor NDArray of shape (C x H x W) in the range [0, 1). - If batch input, converts a batch image NDArray of shape (N x H x W x C) in the + If batch input, converts a batch image NDArray of shape (N x H x W x C) in the range [0, 255] to a float32 tensor NDArray of shape (N x C x H x W). Inputs: From f9e5bdd47921b55e3a51c07ddf59552326df7dbf Mon Sep 17 00:00:00 2001 From: Sandeep Krishnamurthy Date: Thu, 17 Jan 2019 14:51:58 -0800 Subject: [PATCH 05/10] Cleanup includes --- src/operator/image/totensor_op-inl.h | 7 ------- tests/python/unittest/test_gluon_data_vision.py | 4 ++-- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/src/operator/image/totensor_op-inl.h b/src/operator/image/totensor_op-inl.h index eb69edcfd24d..14d7e0bf0126 100644 --- a/src/operator/image/totensor_op-inl.h +++ b/src/operator/image/totensor_op-inl.h @@ -26,14 +26,7 @@ #define MXNET_OPERATOR_IMAGE_TOTENSOR_OP_INL_H_ -#include -#include #include -#include -#include -#include -#include "../mxnet_op.h" -#include "../operator_common.h" #include "../elemwise_op_common.h" namespace mxnet { diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py index a92a8cb82bc7..a855fc8cf1df 100644 --- a/tests/python/unittest/test_gluon_data_vision.py +++ b/tests/python/unittest/test_gluon_data_vision.py @@ -33,13 +33,13 @@ def test_to_tensor(): data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8) out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8')) assert_almost_equal(out_nd.asnumpy(), np.transpose( - data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1))) + data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1))) # 4D Input data_in = np.random.uniform(0, 255, (5, 300, 300, 3)).astype(dtype=np.uint8) out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8')) assert_almost_equal(out_nd.asnumpy(), np.transpose( - data_in.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2))) + data_in.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2))) # Invalid Input invalid_data_in = nd.random.uniform(0, 255, (5, 5, 300, 300, 3)).astype(dtype=np.uint8) From 2f5f3210999b7de19243d55a9f8e95669c9df6a9 Mon Sep 17 00:00:00 2001 From: Sandeep Krishnamurthy Date: Thu, 24 Jan 2019 15:44:22 -0800 Subject: [PATCH 06/10] Move back changes to original image operators files --- src/operator/image/image_random-inl.h | 102 ++++++++++++++++++ src/operator/image/image_random.cc | 39 +++++++ src/operator/image/image_random.cu | 44 ++++---- src/operator/image/totensor_op-inl.h | 142 -------------------------- src/operator/image/totensor_op.cc | 71 ------------- src/operator/image/totensor_op.cu | 30 ------ 6 files changed, 164 insertions(+), 264 deletions(-) delete mode 100644 src/operator/image/totensor_op-inl.h 
delete mode 100644 src/operator/image/totensor_op.cc delete mode 100644 src/operator/image/totensor_op.cu diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h index b922db7019d1..5cd4a82df21d 100644 --- a/src/operator/image/image_random-inl.h +++ b/src/operator/image/image_random-inl.h @@ -43,6 +43,108 @@ namespace mxnet { namespace op { namespace image { +// There are no parameters for this operator. +// Hence, no parameter registration. + +// Shape and Type inference for image to tensor operator +inline bool ToTensorShape(const nnvm::NodeAttrs& attrs, + std::vector<TShape> *in_attrs, + std::vector<TShape> *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + TShape &shp = (*in_attrs)[0]; + if (!shp.ndim()) return false; + + CHECK((shp.ndim() == 3) || (shp.ndim() == 4)) + << "Input image must have shape (height, width, channels), or " + << "(N, height, width, channels) but got " << shp; + if (shp.ndim() == 3) { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[2], shp[0], shp[1]})); + } else if (shp.ndim() == 4) { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[0], shp[3], shp[1], shp[2]})); + } + + return true; +} + +inline bool ToTensorType(const nnvm::NodeAttrs& attrs, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32); + return (*in_attrs)[0] != -1; } + +// Operator Implementation + +template<int req> +struct totensor_forward { + template<typename DType> + MSHADOW_XINLINE static void Map(int l, float* out_data, const DType* in_data, + const int c, const int length, const int channel, + const int step, const float normalize_factor = 255.0f) { + KERNEL_ASSIGN(out_data[step + c*length + l], req, + (in_data[step + l*channel + c]) / normalize_factor); + } +}; + +template<typename xpu> +void ToTensorImpl(const OpContext &ctx, + const std::vector<TBlob> &inputs, + const std::vector<TBlob> &outputs, + const std::vector<OpReqType> &req, + const int length, + const int channel, + const int step = 0) { + mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); + + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + float* output = outputs[0].dptr<float>(); + DType* input = inputs[0].dptr<DType>(); + + for (int c = 0; c < channel; ++c) { + mxnet_op::Kernel<totensor_forward<req_type>, xpu>::Launch( + s, length, output, input, c, length, channel, step); + } + }); + }); +} + +template<typename xpu> +void ToTensorOpForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector<TBlob> &inputs, + const std::vector<OpReqType> &req, + const std::vector<TBlob> &outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + + CHECK_EQ(req[0], kWriteTo) + << "`to_tensor` does not support inplace updates"; + + // 3D Input - (h, w, c) + if (inputs[0].ndim() == 3) { + const int length = inputs[0].shape_[0] * inputs[0].shape_[1]; + const int channel = inputs[0].shape_[2]; + ToTensorImpl<xpu>(ctx, inputs, outputs, req, length, channel); + } else if (inputs[0].ndim() == 4) { + // 4D input (n, h, w, c) + const int batch_size = inputs[0].shape_[0]; + const int length = inputs[0].shape_[1] * inputs[0].shape_[2]; + const int channel = inputs[0].shape_[3]; + const int step = channel * length; + + #pragma omp parallel for + for (auto n = 0; n < batch_size; ++n) { + ToTensorImpl<xpu>(ctx, inputs, outputs, req, length, channel, n*step); + } + } +} + struct NormalizeParam : public dmlc::Parameter<NormalizeParam> { nnvm::Tuple<float> mean; nnvm::Tuple<float> std; diff --git a/src/operator/image/image_random.cc 
b/src/operator/image/image_random.cc index 56df9e37a3bb..f68a43aa8af5 100644 --- a/src/operator/image/image_random.cc +++ b/src/operator/image/image_random.cc @@ -38,6 +38,45 @@ DMLC_REGISTER_PARAMETER(AdjustLightingParam); DMLC_REGISTER_PARAMETER(RandomLightingParam); DMLC_REGISTER_PARAMETER(RandomColorJitterParam); +NNVM_REGISTER_OP(_image_to_tensor) +.describe(R"code(Converts an image NDArray of shape (H x W x C) or (N x H x W x C) +with values in the range [0, 255] to a tensor NDArray of shape (C x H x W) or (N x C x H x W) +with values in the range [0, 1) + +Example: + .. code-block:: python + image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8) + to_tensor(image) + [[[ 0.85490197 0.72156864] + [ 0.09019608 0.74117649] + [ 0.61960787 0.92941177] + [ 0.96470588 0.1882353 ]] + [[ 0.6156863 0.73725492] + [ 0.46666667 0.98039216] + [ 0.44705883 0.45490196] + [ 0.01960784 0.8509804 ]] + [[ 0.39607844 0.03137255] + [ 0.72156864 0.52941179] + [ 0.16470589 0.7647059 ] + [ 0.05490196 0.70588237]]] + +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) +.set_attr("FInferShape", ToTensorShape) +.set_attr("FInferType", ToTensorType) +.set_attr("FCompute", ToTensorOpForward) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) +.set_attr("FGradient", ElemwiseGradUseNone{ "_copy" }) +.add_argument("data", "NDArray-or-Symbol", "Input ndarray"); + NNVM_REGISTER_OP(_image_normalize) .describe(R"code(Normalize an tensor of shape (C x H x W) or (N x C x H x W) with mean and standard deviation. diff --git a/src/operator/image/image_random.cu b/src/operator/image/image_random.cu index 404c3d25477a..5f9aff27e85b 100644 --- a/src/operator/image/image_random.cu +++ b/src/operator/image/image_random.cu @@ -1,26 +1,26 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ /*! 
- * \file image_random.cu - * \brief GPU Implementation of image transformation operators - */ +* \file image_random.cu +* \brief GPU Implementation of image transformation operators +*/ #include "./image_random-inl.h" #include "../elemwise_op_common.h" @@ -28,13 +28,15 @@ namespace mxnet { namespace op { namespace image { +NNVM_REGISTER_OP(_image_to_tensor) +.set_attr("FCompute", ToTensorOpForward); + NNVM_REGISTER_OP(_image_normalize) .set_attr("FCompute", NormalizeOpForward); NNVM_REGISTER_OP(_backward_image_normalize) .set_attr("FCompute", NormalizeOpBackward); - } // namespace image } // namespace op } // namespace mxnet diff --git a/src/operator/image/totensor_op-inl.h b/src/operator/image/totensor_op-inl.h deleted file mode 100644 index 14d7e0bf0126..000000000000 --- a/src/operator/image/totensor_op-inl.h +++ /dev/null @@ -1,142 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. See the NOTICE file -* distributed with this work for additional information -* regarding copyright ownership. The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. -*/ - -/*! - * Copyright (c) 2018 by Contributors - * \file totensor_op-inl.h - * \brief Image to tensor operator -*/ -#ifndef MXNET_OPERATOR_IMAGE_TOTENSOR_OP_INL_H_ -#define MXNET_OPERATOR_IMAGE_TOTENSOR_OP_INL_H_ - - -#include -#include "../elemwise_op_common.h" - -namespace mxnet { -namespace op { -namespace image { - -// There are no parameters for this operator. -// Hence, no arameter registration. 
- -// Shape and Type inference for image to tensor operator -inline bool ToTensorShape(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - - TShape &shp = (*in_attrs)[0]; - if (!shp.ndim()) return false; - - CHECK((shp.ndim() == 3) || (shp.ndim() == 4)) - << "Input image must have shape (height, width, channels), or " - << "(N, height, width, channels) but got " << shp; - if (shp.ndim() == 3) { - SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[2], shp[0], shp[1]})); - } else if (shp.ndim() == 4) { - SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[0], shp[3], shp[1], shp[2]})); - } - - return true; -} - -inline bool ToTensorType(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32); - return (*in_attrs)[0] != -1; -} - -// Operator Implementation - -template -struct totensor_forward { - template - MSHADOW_XINLINE static void Map(int l, float* out_data, const DType* in_data, - const int c, const int length, const int channel, - const int step, const float normalize_factor = 255.0f) { - KERNEL_ASSIGN(out_data[step + c*length + l], req, - (in_data[step + l*channel + c]) / normalize_factor); - } -}; - -template -void ToTensorImpl(const OpContext &ctx, - const std::vector &inputs, - const std::vector &outputs, - const std::vector &req, - const int length, - const int channel, - const int step = 0) { - mshadow::Stream *s = ctx.get_stream(); - - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - float* output = outputs[0].dptr(); - DType* input = inputs[0].dptr(); - - for (int c = 0; c < channel; ++c) { - mxnet_op::Kernel, xpu>::Launch( - s, length, output, input, c, length, channel, step); - } - }); - }); -} - -template -void ToTensorOpForward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { - CHECK_EQ(inputs.size(), 1U); - CHECK_EQ(outputs.size(), 1U); - CHECK_EQ(req.size(), 1U); - - CHECK_EQ(req[0], kWriteTo) - << "`to_tensor` does not support inplace updates"; - - // 3D Input - (h, w, c) - if (inputs[0].ndim() == 3) { - const int length = inputs[0].shape_[0] * inputs[0].shape_[1]; - const int channel = inputs[0].shape_[2]; - ToTensorImpl(ctx, inputs, outputs, req, length, channel); - } else if (inputs[0].ndim() == 4) { - // 4D input (n, h, w, c) - const int batch_size = inputs[0].shape_[0]; - const int length = inputs[0].shape_[1] * inputs[0].shape_[2]; - const int channel = inputs[0].shape_[3]; - const int step = channel * length; - - #pragma omp parallel for - for (auto n = 0; n < batch_size; ++n) { - ToTensorImpl(ctx, inputs, outputs, req, length, channel, n*step); - } - } -} - -} // namespace image -} // namespace op -} // namespace mxnet - -#endif // MXNET_OPERATOR_IMAGE_TOTENSOR_OP_INL_H_ diff --git a/src/operator/image/totensor_op.cc b/src/operator/image/totensor_op.cc deleted file mode 100644 index 1369a6815d71..000000000000 --- a/src/operator/image/totensor_op.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. See the NOTICE file -* distributed with this work for additional information -* regarding copyright ownership. 
The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. -*/ - -/*! - * \file totensor_op.cc - * \brief CPU Implementation of ToTensor op - */ -#include "./totensor_op-inl.h" - -namespace mxnet { -namespace op { -namespace image { - -NNVM_REGISTER_OP(_image_to_tensor) -.describe(R"code(Converts an image NDArray of shape (H x W x C) or (N x H x W x C) -with values in the range [0, 255] to a tensor NDArray of shape (C x H x W) or (N x C x H x W) -with values in the range [0, 1) - -Example: - .. code-block:: python - image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8) - to_tensor(image) - [[[ 0.85490197 0.72156864] - [ 0.09019608 0.74117649] - [ 0.61960787 0.92941177] - [ 0.96470588 0.1882353 ]] - [[ 0.6156863 0.73725492] - [ 0.46666667 0.98039216] - [ 0.44705883 0.45490196] - [ 0.01960784 0.8509804 ]] - [[ 0.39607844 0.03137255] - [ 0.72156864 0.52941179] - [ 0.16470589 0.7647059 ] - [ 0.05490196 0.70588237]]] - -)code" ADD_FILELINE) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"data"}; - }) -.set_attr("FInferShape", ToTensorShape) -.set_attr("FInferType", ToTensorType) -.set_attr("FCompute", ToTensorOpForward) -.set_attr("FInplaceOption", - [](const NodeAttrs& attrs) { - return std::vector >{{0, 0}}; - }) -.set_attr("FGradient", ElemwiseGradUseNone{ "_copy" }) -.add_argument("data", "NDArray-or-Symbol", "Input ndarray"); - -} // namespace image -} // namespace op -} // namespace mxnet diff --git a/src/operator/image/totensor_op.cu b/src/operator/image/totensor_op.cu deleted file mode 100644 index 55084ac0d9c5..000000000000 --- a/src/operator/image/totensor_op.cu +++ /dev/null @@ -1,30 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one -* or more contributor license agreements. See the NOTICE file -* distributed with this work for additional information -* regarding copyright ownership. The ASF licenses this file -* to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance -* with the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, -* software distributed under the License is distributed on an -* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -* KIND, either express or implied. See the License for the -* specific language governing permissions and limitations -* under the License. 
-*/ -#include "./totensor_op-inl.h" - -namespace mxnet { -namespace op { -namespace image { - -NNVM_REGISTER_OP(_image_to_tensor) -.set_attr("FCompute", ToTensorOpForward); - -} // namespace image -} // namespace op -} // namespace mxnet From b856d775f7719fc4f643d3c8f96087a2075a6f9a Mon Sep 17 00:00:00 2001 From: Sandeep Krishnamurthy Date: Mon, 28 Jan 2019 11:53:54 -0800 Subject: [PATCH 07/10] Add 4D example --- src/operator/image/image_random.cc | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc index f68a43aa8af5..fc6b17c0b1ca 100644 --- a/src/operator/image/image_random.cc +++ b/src/operator/image/image_random.cc @@ -60,6 +60,34 @@ with values in the range [0, 1) [ 0.16470589 0.7647059 ] [ 0.05490196 0.70588237]]] + + image = mx.nd.random.uniform(0, 255, (2, 4, 2, 3)).astype(dtype=np.uint8) + to_tensor(image) + [[[[0.11764706 0.5803922 ] + [0.9411765 0.10588235] + [0.2627451 0.73333335] + [0.5647059 0.32156864]] + [[0.7176471 0.14117648] + [0.75686276 0.4117647 ] + [0.18431373 0.45490196] + [0.13333334 0.6156863 ]] + [[0.6392157 0.5372549 ] + [0.52156866 0.47058824] + [0.77254903 0.21568628] + [0.01568628 0.14901961]]] + [[[0.6117647 0.38431373] + [0.6784314 0.6117647 ] + [0.69411767 0.96862745] + [0.67058825 0.35686275]] + [[0.21960784 0.9411765 ] + [0.44705883 0.43529412] + [0.09803922 0.6666667 ] + [0.16862746 0.1254902 ]] + [[0.6156863 0.9019608 ] + [0.35686275 0.9019608 ] + [0.05882353 0.6509804 ] + [0.20784314 0.7490196 ]]]] + )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) From f7361a79c3274534e552ad55fbef0841c42d03a1 Mon Sep 17 00:00:00 2001 From: Sandeep Krishnamurthy Date: Mon, 28 Jan 2019 22:24:23 -0800 Subject: [PATCH 08/10] resolve merge conflicts --- tests/python/gpu/test_gluon_transforms.py | 39 +++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/python/gpu/test_gluon_transforms.py b/tests/python/gpu/test_gluon_transforms.py index f4bdce4c5522..165f9329ad0d 100644 --- a/tests/python/gpu/test_gluon_transforms.py +++ b/tests/python/gpu/test_gluon_transforms.py @@ -32,6 +32,45 @@ set_default_context(mx.gpu(0)) +@with_seed() +def test_normalize(): + # 3D Input + data_in_3d = nd.random.uniform(0, 1, (3, 300, 300)) + out_nd_3d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_3d) + data_expected_3d = data_in_3d.asnumpy() + data_expected_3d[:][:][0] = data_expected_3d[:][:][0] / 3.0 + data_expected_3d[:][:][1] = (data_expected_3d[:][:][1] - 1.0) / 2.0 + data_expected_3d[:][:][2] = data_expected_3d[:][:][2] - 2.0 + assert_almost_equal(data_expected_3d, out_nd_3d.asnumpy()) + + # 4D Input + data_in_4d = nd.random.uniform(0, 1, (2, 3, 300, 300)) + out_nd_4d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_4d) + data_expected_4d = data_in_4d.asnumpy() + data_expected_4d[0][:][:][0] = data_expected_4d[0][:][:][0] / 3.0 + data_expected_4d[0][:][:][1] = (data_expected_4d[0][:][:][1] - 1.0) / 2.0 + data_expected_4d[0][:][:][2] = data_expected_4d[0][:][:][2] - 2.0 + data_expected_4d[1][:][:][0] = data_expected_4d[1][:][:][0] / 3.0 + data_expected_4d[1][:][:][1] = (data_expected_4d[1][:][:][1] - 1.0) / 2.0 + data_expected_4d[1][:][:][2] = data_expected_4d[1][:][:][2] - 2.0 + assert_almost_equal(data_expected_4d, out_nd_4d.asnumpy()) + + # Default normalize values i.e., mean=0, std=1 + data_in_3d_def = nd.random.uniform(0, 1, (3, 300, 300)) + out_nd_3d_def = transforms.Normalize()(data_in_3d_def) + data_expected_3d_def = 
data_in_3d_def.asnumpy() + assert_almost_equal(data_expected_3d_def, out_nd_3d_def.asnumpy()) + + # Invalid Input - Neither 3D or 4D input + invalid_data_in = nd.random.uniform(0, 1, (5, 5, 3, 300, 300)) + normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1)) + assertRaises(MXNetError, normalize_transformer, invalid_data_in) + + # Invalid Input - Channel neither 1 or 3 + invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300)) + normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1)) + assertRaises(MXNetError, normalize_transformer, invalid_data_in) + @with_seed() def test_to_tensor(): # 3D Input From d2988fa12ca27a82ba57638beb5c73716a0818db Mon Sep 17 00:00:00 2001 From: Sandeep Krishnamurthy Date: Fri, 1 Feb 2019 09:34:12 -0800 Subject: [PATCH 09/10] Fix failing tests --- tests/python/gpu/test_gluon_transforms.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/python/gpu/test_gluon_transforms.py b/tests/python/gpu/test_gluon_transforms.py index 165f9329ad0d..3927d4c1f094 100644 --- a/tests/python/gpu/test_gluon_transforms.py +++ b/tests/python/gpu/test_gluon_transforms.py @@ -107,7 +107,6 @@ def test_to_tensor(): normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1)) assertRaises(MXNetError, normalize_transformer, invalid_data_in) - @with_seed() def test_resize(): # Test with normal case 3D input float type @@ -164,13 +163,3 @@ def py_bilinear_resize_nhwc(x, outputHeight, outputWidth): data_in_4d = nd.random.uniform(0, 255, (2, 300, 300, 3)).astype('uint8') out_nd_4d = transforms.Resize((100, 100))(data_in_4d) assert_almost_equal(out_nd_4d.asnumpy(), py_bilinear_resize_nhwc(data_in_4d.asnumpy(), 100, 100), atol=1.0) - - data_in = np.random.uniform(0, 255, (5, 300, 300, 3)).astype(dtype=np.uint8) - out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8')) - assert_almost_equal(out_nd.asnumpy(), np.transpose( - data_in.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2))) - - # Invalid Input - invalid_data_in = nd.random.uniform(0, 255, (5, 5, 300, 300, 3)).astype(dtype=np.uint8) - transformer = transforms.ToTensor() - assertRaises(MXNetError, transformer, invalid_data_in) From a770ce7c26ae94f4193638627cf369d2dcdb53de Mon Sep 17 00:00:00 2001 From: Sandeep Krishnamurthy Date: Fri, 1 Feb 2019 18:44:24 -0800 Subject: [PATCH 10/10] parallelize on channel in kernel launch --- src/operator/image/image_random-inl.h | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h index 5cd4a82df21d..c9dd85af616f 100644 --- a/src/operator/image/image_random-inl.h +++ b/src/operator/image/image_random-inl.h @@ -82,11 +82,14 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs, template<int req> struct totensor_forward { template<typename DType> - MSHADOW_XINLINE static void Map(int l, float* out_data, const DType* in_data, - const int c, const int length, const int channel, - const int step, const float normalize_factor = 255.0f) { - KERNEL_ASSIGN(out_data[step + c*length + l], req, - (in_data[step + l*channel + c]) / normalize_factor); + MSHADOW_XINLINE static void Map(uint32_t c, float* out_data, const DType* in_data, + const int length, const int channel, const int step, + const float normalize_factor = 255.0f) { + #pragma omp parallel for + for (int i = 0; i < length; ++i) { + KERNEL_ASSIGN(out_data[step + c*length + i], req, + (in_data[step + i*channel + c]) / normalize_factor); + } } }; @@ -96,7 +99,7 @@ void ToTensorImpl(const OpContext &ctx, const std::vector<TBlob> &outputs, const std::vector<OpReqType> &req, const int length, - const int channel, + const uint32_t channel, const int step = 0) { mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); @@ -104,11 +107,8 @@ void ToTensorImpl(const OpContext &ctx, MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { float* output = outputs[0].dptr<float>(); DType* input = inputs[0].dptr<DType>(); - - for (int c = 0; c < channel; ++c) { - mxnet_op::Kernel<totensor_forward<req_type>, xpu>::Launch( - s, length, output, input, c, length, channel, step); - } + mxnet_op::Kernel<totensor_forward<req_type>, xpu>::Launch( + s, channel, output, input, length, channel, step); }); }); } @@ -129,13 +129,13 @@ void ToTensorOpForward(const nnvm::NodeAttrs &attrs, // 3D Input - (h, w, c) if (inputs[0].ndim() == 3) { const int length = inputs[0].shape_[0] * inputs[0].shape_[1]; - const int channel = inputs[0].shape_[2]; + const uint32_t channel = inputs[0].shape_[2]; ToTensorImpl<xpu>(ctx, inputs, outputs, req, length, channel); } else if (inputs[0].ndim() == 4) { // 4D input (n, h, w, c) const int batch_size = inputs[0].shape_[0]; const int length = inputs[0].shape_[1] * inputs[0].shape_[2]; - const int channel = inputs[0].shape_[3]; + const uint32_t channel = inputs[0].shape_[3]; const int step = channel * length; #pragma omp parallel for