Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Image ToTensor operator - GPU support, 3D/4D inputs #13837

Merged
9 changes: 6 additions & 3 deletions python/mxnet/gluon/data/vision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,20 @@ def hybrid_forward(self, F, x):


class ToTensor(HybridBlock):
"""Converts an image NDArray to a tensor NDArray.
"""Converts an image NDArray or batch of image NDArray to a tensor NDArray.

Converts an image NDArray of shape (H x W x C) in the range
[0, 255] to a float32 tensor NDArray of shape (C x H x W) in
the range [0, 1] (an input value of 255 maps to exactly 1.0).

If batch input, converts a batch image NDArray of shape (N x H x W x C) in the
range [0, 255] to a float32 tensor NDArray of shape (N x C x H x W).

Inputs:
- **data**: input tensor with (H x W x C) shape and uint8 type.
- **data**: input tensor with (H x W x C) or (N x H x W x C) shape and uint8 type.

Outputs:
- **out**: output tensor with (C x H x W) shape and float32 type.
- **out**: output tensor with (C x H x W) or (N x C x H x W) shape and float32 type.

Examples
--------
Expand Down
97 changes: 76 additions & 21 deletions src/operator/image/image_random-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,28 @@ namespace mxnet {
namespace op {
namespace image {

// There are no parameters for this operator.
// Hence, no parameter registration.

// Shape and Type inference for image to tensor operator

// Shape inference for to_tensor.
//
// Accepts exactly one input and one output shape.
//   - 3-D input (H, W, C)    -> output (C, H, W)
//   - 4-D input (N, H, W, C) -> output (N, C, H, W)
//
// Returns false when the input shape is still unknown (ndim == 0) so that
// inference can be retried later; any other rank fails the CHECK below.
inline bool ToTensorShape(const nnvm::NodeAttrs& attrs,
                          std::vector<TShape> *in_attrs,
                          std::vector<TShape> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);

  TShape &shp = (*in_attrs)[0];
  if (!shp.ndim()) return false;

  // Only a single rank check is performed here. A preceding 3-D-only check
  // would reject batched input before the 4-D branch could ever be reached.
  CHECK((shp.ndim() == 3) || (shp.ndim() == 4))
      << "Input image must have shape (height, width, channels), or "
      << "(N, height, width, channels) but got " << shp;
  if (shp.ndim() == 3) {
    // (H, W, C) -> (C, H, W)
    SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[2], shp[0], shp[1]}));
  } else {
    // (N, H, W, C) -> (N, C, H, W); rank is 4 here because of the CHECK above.
    SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({shp[0], shp[3], shp[1], shp[2]}));
  }

  return true;
}

Expand All @@ -65,31 +77,74 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs,
return (*in_attrs)[0] != -1;
}

inline void ToTensor(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
CHECK_EQ(req[0], kWriteTo)
<< "`to_tensor` does not support inplace";
// Operator Implementation

int length = inputs[0].shape_[0] * inputs[0].shape_[1];
int channel = inputs[0].shape_[2];
// Kernel functor for to_tensor.
//
// Map() handles one channel index `c`: it walks `length` pixel positions,
// reads the interleaved value at in_data[step + px*channel + c], divides it by
// `normalize_factor` and stores the result in the channel-major slot
// out_data[step + c*length + px] according to the `req` write mode.
// `step` is an element offset applied to both the input and the output.
template<int req>
struct totensor_forward {
  template<typename DType>
  MSHADOW_XINLINE static void Map(uint32_t c, float* out_data, const DType* in_data,
                                  const int length, const int channel, const int step,
                                  const float normalize_factor = 255.0f) {
    // Index of the first output element of channel plane `c`; it does not
    // depend on the pixel position, so it is computed once up front.
    const uint32_t plane_begin = step + c*length;
    #pragma omp parallel for
    for (int px = 0; px < length; ++px) {
      KERNEL_ASSIGN(out_data[plane_begin + px], req,
                    (in_data[step + px*channel + c]) / normalize_factor);
    }
  }
};

template<typename xpu>
void ToTensorImpl(const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
const std::vector<OpReqType> &req,
const int length,
const uint32_t channel,
const int step = 0) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();

MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
float* output = outputs[0].dptr<float>();
DType* input = inputs[0].dptr<DType>();
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
float* output = outputs[0].dptr<float>();
DType* input = inputs[0].dptr<DType>();
mxnet_op::Kernel<totensor_forward<req_type>, xpu>::Launch(
s, channel, output, input, length, channel, step);
});
});
}

for (int l = 0; l < length; ++l) {
for (int c = 0; c < channel; ++c) {
output[c*length + l] = static_cast<float>(input[l*channel + c]) / 255.0f;
}
// Forward entry point of the to_tensor operator for device `xpu`.
//
// Expects exactly one input, one output and one request, and only accepts
// the kWriteTo request type. A 3-D input (h, w, c) is converted with a single
// ToTensorImpl call; a 4-D input (n, h, w, c) is converted image by image,
// each call offset by n * (c * h * w) elements. Any other rank is left
// untouched here.
template<typename xpu>
void ToTensorOpForward(const nnvm::NodeAttrs &attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &outputs) {
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);

  CHECK_EQ(req[0], kWriteTo)
    << "`to_tensor` does not support inplace updates";

  // 3D Input - (h, w, c)
  if (inputs[0].ndim() == 3) {
    const int length = inputs[0].shape_[0] * inputs[0].shape_[1];
    const uint32_t channel = inputs[0].shape_[2];
    ToTensorImpl<xpu>(ctx, inputs, outputs, req, length, channel);
  } else if (inputs[0].ndim() == 4) {
    // 4D input (n, h, w, c)
    const int batch_size = inputs[0].shape_[0];
    const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
    const uint32_t channel = inputs[0].shape_[3];
    // Number of elements in one image; the same stride applies to the input
    // and the output because both hold c * h * w values per image.
    const int step = channel * length;

    #pragma omp parallel for
    for (auto n = 0; n < batch_size; ++n) {
      ToTensorImpl<xpu>(ctx, inputs, outputs, req, length, channel, n*step);
    }
  }
}

// Normalize Operator
// Parameter registration for image Normalize operator
struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
nnvm::Tuple<float> mean;
nnvm::Tuple<float> std;
Expand Down
63 changes: 60 additions & 3 deletions src/operator/image/image_random.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,71 @@ DMLC_REGISTER_PARAMETER(RandomLightingParam);
DMLC_REGISTER_PARAMETER(RandomColorJitterParam);

// CPU registration of the image to_tensor operator: one input ("data"), one
// output, shape/type inference via ToTensorShape/ToTensorType and the CPU
// forward function ToTensorOpForward<cpu>.
NNVM_REGISTER_OP(_image_to_tensor)
.describe(R"code(Converts an image NDArray of shape (H x W x C) or (N x H x W x C)
with values in the range [0, 255] to a tensor NDArray of shape (C x H x W) or (N x C x H x W)
with values in the range [0, 1]

Example:
.. code-block:: python
image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8)
to_tensor(image)
[[[ 0.85490197 0.72156864]
[ 0.09019608 0.74117649]
[ 0.61960787 0.92941177]
[ 0.96470588 0.1882353 ]]
[[ 0.6156863 0.73725492]
[ 0.46666667 0.98039216]
[ 0.44705883 0.45490196]
[ 0.01960784 0.8509804 ]]
[[ 0.39607844 0.03137255]
[ 0.72156864 0.52941179]
[ 0.16470589 0.7647059 ]
[ 0.05490196 0.70588237]]]
<NDArray 3x4x2 @cpu(0)>

image = mx.nd.random.uniform(0, 255, (2, 4, 2, 3)).astype(dtype=np.uint8)
to_tensor(image)
[[[[0.11764706 0.5803922 ]
[0.9411765 0.10588235]
[0.2627451 0.73333335]
[0.5647059 0.32156864]]
[[0.7176471 0.14117648]
[0.75686276 0.4117647 ]
[0.18431373 0.45490196]
[0.13333334 0.6156863 ]]
[[0.6392157 0.5372549 ]
[0.52156866 0.47058824]
[0.77254903 0.21568628]
[0.01568628 0.14901961]]]
[[[0.6117647 0.38431373]
[0.6784314 0.6117647 ]
[0.69411767 0.96862745]
[0.67058825 0.35686275]]
[[0.21960784 0.9411765 ]
[0.44705883 0.43529412]
[0.09803922 0.6666667 ]
[0.16862746 0.1254902 ]]
[[0.6156863 0.9019608 ]
[0.35686275 0.9019608 ]
[0.05882353 0.6509804 ]
[0.20784314 0.7490196 ]]]]
<NDArray 2x3x4x2 @cpu(0)>
)code" ADD_FILELINE)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"data"};
  })
.set_attr<nnvm::FInferShape>("FInferShape", ToTensorShape)
.set_attr<nnvm::FInferType>("FInferType", ToTensorType)
.set_attr<FCompute>("FCompute<cpu>", ToTensorOpForward<cpu>)
// NOTE(review): this advertises input 0 / output 0 as an in-place pair, while
// ToTensorOpForward CHECKs that req[0] is kWriteTo ("does not support inplace
// updates") - confirm the option is intended.
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
  [](const NodeAttrs& attrs) {
    return std::vector<std::pair<int, int> >{{0, 0}};
  })
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })
.add_argument("data", "NDArray-or-Symbol", "Input ndarray");

NNVM_REGISTER_OP(_image_normalize)
.describe(R"code(Normalize an tensor of shape (C x H x W) or (N x C x H x W) with mean and
Expand Down
44 changes: 23 additions & 21 deletions src/operator/image/image_random.cu
Original file line number Diff line number Diff line change
@@ -1,40 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file image_random.cu
* \brief GPU Implementation of image transformation operators
*/
* \file image_random.cu
* \brief GPU Implementation of image transformation operators
*/
#include "./image_random-inl.h"
#include "../elemwise_op_common.h"

namespace mxnet {
namespace op {
namespace image {

// GPU compute registrations for the image operators. Each statement attaches
// the gpu instantiation of a forward/backward function template as the
// "FCompute<gpu>" attribute of the named operator.

// to_tensor: forward pass.
NNVM_REGISTER_OP(_image_to_tensor)
.set_attr<FCompute>("FCompute<gpu>", ToTensorOpForward<gpu>);

// normalize: forward pass.
NNVM_REGISTER_OP(_image_normalize)
.set_attr<FCompute>("FCompute<gpu>", NormalizeOpForward<gpu>);

// normalize: backward pass.
NNVM_REGISTER_OP(_backward_image_normalize)
.set_attr<FCompute>("FCompute<gpu>", NormalizeOpBackward<gpu>);


} // namespace image
} // namespace op
} // namespace mxnet
36 changes: 35 additions & 1 deletion tests/python/gpu/test_gluon_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,41 @@ def test_normalize():
normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
assertRaises(MXNetError, normalize_transformer, invalid_data_in)

@with_seed()
def test_to_tensor():
    """Check transforms.ToTensor on 3D, 4D and invalid-rank uint8 inputs.

    Every section exercises ToTensor itself: the output must equal the input
    divided by 255 with the channel axis moved in front of height and width.
    """
    # 3D Input: (H, W, C) -> (C, H, W)
    data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
    out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
    assert_almost_equal(out_nd.asnumpy(), np.transpose(
        data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1)))

    # 4D Input: (N, H, W, C) -> (N, C, H, W)
    data_in = np.random.uniform(0, 255, (5, 300, 300, 3)).astype(dtype=np.uint8)
    out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
    assert_almost_equal(out_nd.asnumpy(), np.transpose(
        data_in.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2)))

    # Invalid Input - neither 3D nor 4D
    invalid_data_in = nd.random.uniform(0, 255, (5, 5, 300, 300, 3)).astype(dtype=np.uint8)
    transformer = transforms.ToTensor()
    assertRaises(MXNetError, transformer, invalid_data_in)

@with_seed()
def test_resize():
Expand Down Expand Up @@ -128,4 +163,3 @@ def py_bilinear_resize_nhwc(x, outputHeight, outputWidth):
data_in_4d = nd.random.uniform(0, 255, (2, 300, 300, 3)).astype('uint8')
out_nd_4d = transforms.Resize((100, 100))(data_in_4d)
assert_almost_equal(out_nd_4d.asnumpy(), py_bilinear_resize_nhwc(data_in_4d.asnumpy(), 100, 100), atol=1.0)

14 changes: 13 additions & 1 deletion tests/python/unittest/test_gluon_data_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,22 @@

@with_seed()
def test_to_tensor():
    """Check transforms.ToTensor on 3D, 4D and invalid-rank uint8 inputs."""
    # 3D Input: (H, W, C) -> (C, H, W), values divided by 255
    data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
    out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
    assert_almost_equal(out_nd.asnumpy(), np.transpose(
        data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1)))

    # 4D Input: (N, H, W, C) -> (N, C, H, W), values divided by 255
    data_in = np.random.uniform(0, 255, (5, 300, 300, 3)).astype(dtype=np.uint8)
    out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
    assert_almost_equal(out_nd.asnumpy(), np.transpose(
        data_in.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2)))

    # Invalid Input - a 5D array must be rejected
    invalid_data_in = nd.random.uniform(0, 255, (5, 5, 300, 300, 3)).astype(dtype=np.uint8)
    transformer = transforms.ToTensor()
    assertRaises(MXNetError, transformer, invalid_data_in)


@with_seed()
Expand Down