diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index 175076925332..2f557f591f60 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -135,7 +135,7 @@ def hybrid_forward(self, F, x): class Normalize(HybridBlock): - """Normalize an tensor of shape (C x H x W) with mean and + """Normalize an tensor of shape (C x H x W) or (N x C x H x W) with mean and standard deviation. Given mean `(m1, ..., mn)` and std `(s1, ..., sn)` for `n` channels, @@ -154,12 +154,31 @@ class Normalize(HybridBlock): Inputs: - - **data**: input tensor with (C x H x W) shape. + - **data**: input tensor with (C x H x W) or (N x C x H x W) shape. Outputs: - **out**: output tensor with the shape as `data`. + + Examples + -------- + >>> transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1)) + >>> image = mx.nd.random.uniform(0, 1, (3, 4, 2)) + >>> transformer(image) + [[[ 0.18293785 0.19761486] + [ 0.23839645 0.28142193] + [ 0.20092112 0.28598186] + [ 0.18162774 0.28241724]] + [[-0.2881726 -0.18821815] + [-0.17705294 -0.30780914] + [-0.2812064 -0.3512327 ] + [-0.05411351 -0.4716435 ]] + [[-1.0363373 -1.7273437 ] + [-1.6165586 -1.5223348 ] + [-1.208275 -1.1878313 ] + [-1.4711051 -1.5200229 ]]] + """ - def __init__(self, mean, std): + def __init__(self, mean=0.0, std=1.0): super(Normalize, self).__init__() self._mean = mean self._std = std diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h index c64ed28ecc2d..74807b9b681e 100644 --- a/src/operator/image/image_random-inl.h +++ b/src/operator/image/image_random-inl.h @@ -31,7 +31,6 @@ #include #include #include -#include #include #include "../mxnet_op.h" #include "../operator_common.h" @@ -62,7 +61,7 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs, return (*in_attrs)[0] != -1; } -void ToTensor(const nnvm::NodeAttrs &attrs, +inline void ToTensor(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, @@ -85,32 +84,53 @@ void ToTensor(const nnvm::NodeAttrs &attrs, }); } +// Normalize Operator +// Parameter registration for image Normalize operator struct NormalizeParam : public dmlc::Parameter { nnvm::Tuple mean; nnvm::Tuple std; + DMLC_DECLARE_PARAMETER(NormalizeParam) { DMLC_DECLARE_FIELD(mean) - .describe("Sequence of mean for each channel."); + .set_default(nnvm::Tuple {0.0f, 0.0f, 0.0f, 0.0f}) + .describe("Sequence of means for each channel. " + "Default value is 0."); DMLC_DECLARE_FIELD(std) - .describe("Sequence of standard deviations for each channel."); + .set_default(nnvm::Tuple {1.0f, 1.0f, 1.0f, 1.0f}) + .describe("Sequence of standard deviations for each channel. 
" + "Default value is 1."); } }; -inline bool NormalizeShape(const nnvm::NodeAttrs& attrs, +// Shape and Type inference for image Normalize operator + +// Shape inference +inline bool NormalizeOpShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { const NormalizeParam ¶m = nnvm::get(attrs.parsed); + const auto& dshape = (*in_attrs)[0]; if (!dshape.ndim()) return false; - CHECK_EQ(dshape.ndim(), 3) - << "Input tensor must have shape (channels, height, width), but got " - << dshape; - auto nchannels = dshape[0]; - CHECK(nchannels == 3 || nchannels == 1) + CHECK((dshape.ndim() == 3) || (dshape.ndim() == 4)) + << "Input tensor must have shape (channels, height, width), or " + << "(N, channels, height, width), but got " << dshape; + + uint32_t nchannels; + if (dshape.ndim() == 3) { + nchannels = dshape[0]; + CHECK(nchannels == 3 || nchannels == 1) << "The first dimension of input tensor must be the channel dimension with " << "either 1 or 3 elements, but got input with shape " << dshape; - CHECK(param.mean.ndim() == 1 || param.mean.ndim() == nchannels) + } else if (dshape.ndim() == 4) { + nchannels = dshape[1]; + CHECK(nchannels == 3 || nchannels == 1) + << "The second dimension of input tensor must be the channel dimension with " + << "either 1 or 3 elements, but got input with shape " << dshape; + } + + CHECK((param.mean.ndim() == 1) || (param.mean.ndim() == nchannels)) << "Invalid mean for input with shape " << dshape << ". mean must have either 1 or " << nchannels << " elements, but got " << param.mean; @@ -123,28 +143,156 @@ inline bool NormalizeShape(const nnvm::NodeAttrs& attrs, return true; } -void Normalize(const nnvm::NodeAttrs &attrs, +// Type Inference +inline bool NormalizeOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + return out_attrs->at(0) != -1; +} + +template +struct normalize_forward { + template + MSHADOW_XINLINE static void Map(int j, DType* out_data, const DType* in_data, + const int i, const int length, const int step, + const DType mean, const DType std_dev) { + KERNEL_ASSIGN(out_data[step + i*length + j], req, + (in_data[step + i*length + j] - mean) / std_dev); + } +}; + +template +void NormalizeImpl(const OpContext &ctx, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &req, + const NormalizeParam ¶m, + const int length, + const uint32_t channel, + const int step = 0) { + mshadow::Stream *s = ctx.get_stream(); + + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + DType* input = inputs[0].dptr(); + DType* output = outputs[0].dptr(); + + for (uint32_t i = 0; i < channel; ++i) { + DType mean = param.mean[param.mean.ndim() > i ? i : 0]; + DType std_dev = param.std[param.std.ndim() > i ? 
i : 0]; + mxnet_op::Kernel, xpu>::Launch( + s, length, output, input, + i, length, step, mean, std_dev); + } + }); + }); +} + +template +void NormalizeOpForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + const NormalizeParam ¶m = nnvm::get(attrs.parsed); - int nchannels = inputs[0].shape_[0]; - int length = inputs[0].shape_[1] * inputs[0].shape_[2]; + // 3D input (c, h, w) + if (inputs[0].ndim() == 3) { + const int length = inputs[0].shape_[1] * inputs[0].shape_[2]; + const uint32_t channel = inputs[0].shape_[0]; + NormalizeImpl(ctx, inputs, outputs, req, param, length, channel); + } else if (inputs[0].ndim() == 4) { + // 4D input (n, c, h, w) + const int batch_size = inputs[0].shape_[0]; + const int length = inputs[0].shape_[2] * inputs[0].shape_[3]; + const uint32_t channel = inputs[0].shape_[1]; + const int step = channel * length; + + #pragma omp parallel for + for (auto n = 0; n < batch_size; ++n) { + NormalizeImpl(ctx, inputs, outputs, req, param, length, channel, n*step); + } + } +} - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - DType* input = inputs[0].dptr(); - DType* output = outputs[0].dptr(); +// Backward function +template +struct normalize_backward { + template + MSHADOW_XINLINE static void Map(int j, DType* in_grad, const DType* out_grad, + const int i, const int length, + const int step, const DType std_dev) { + // d/dx{(x - mean) / std_dev} => (1 / std_dev) + KERNEL_ASSIGN(in_grad[step + i*length + j], req, + out_grad[step + i*length + j] * (1.0 / std_dev)); + } +}; - for (int i = 0; i < nchannels; ++i) { - DType mean = param.mean[param.mean.ndim() > 1 ? i : 0]; - DType std = param.std[param.std.ndim() > 1 ? i : 0]; - for (int j = 0; j < length; ++j) { - output[i*length + j] = (input[i*length + j] - mean) / std; - } +template +void NormalizeBackwardImpl(const OpContext &ctx, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &req, + const NormalizeParam ¶m, + const int length, + const uint32_t channel, + const int step = 0) { + mshadow::Stream *s = ctx.get_stream(); + const TBlob& out_grad = inputs[0]; + const TBlob& in_grad = outputs[0]; + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + for (uint32_t i = 0; i < channel; ++i) { + DType std_dev = param.std[param.std.ndim() > i ? 
i : 0]; + mxnet_op::Kernel, xpu>::Launch( + s, length, in_grad.dptr(), out_grad.dptr(), + i, length, step, std_dev); + } + }); + }); +} + +template +void NormalizeOpBackward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + + const NormalizeParam ¶m = nnvm::get(attrs.parsed); + + // Note: inputs[0] is out_grad + const TBlob& in_data = inputs[1]; + + // 3D input (c, h, w) + if (in_data.ndim() == 3) { + const int length = in_data.shape_[1] * in_data.shape_[2]; + const uint32_t channel = in_data.shape_[0]; + NormalizeBackwardImpl(ctx, inputs, outputs, req, param, length, channel); + } else if (in_data.ndim() == 4) { + // 4D input (n, c, h, w) + const int batch_size = in_data.shape_[0]; + const int length = in_data.shape_[2] * in_data.shape_[3]; + const uint32_t channel = in_data.shape_[1]; + const int step = channel * length; + + #pragma omp parallel for + for (auto n = 0; n < batch_size; ++n) { + NormalizeBackwardImpl(ctx, inputs, outputs, req, param, length, channel, n*step); } - }); + } } template @@ -190,7 +338,7 @@ void FlipImpl(const TShape &shape, DType *src, DType *dst) { } } -void FlipLeftRight(const nnvm::NodeAttrs &attrs, +inline void FlipLeftRight(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, @@ -202,7 +350,7 @@ void FlipLeftRight(const nnvm::NodeAttrs &attrs, }); } -void FlipTopBottom(const nnvm::NodeAttrs &attrs, +inline void FlipTopBottom(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, @@ -214,7 +362,7 @@ void FlipTopBottom(const nnvm::NodeAttrs &attrs, }); } -void RandomFlipLeftRight( +inline void RandomFlipLeftRight( const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, @@ -235,7 +383,7 @@ void RandomFlipLeftRight( }); } -void RandomFlipTopBottom( +inline void RandomFlipTopBottom( const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, @@ -287,7 +435,7 @@ inline void AdjustBrightnessImpl(const float& alpha_b, }); } -void RandomBrightness(const nnvm::NodeAttrs &attrs, +inline void RandomBrightness(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, @@ -405,7 +553,7 @@ inline void RandomSaturation(const nnvm::NodeAttrs &attrs, AdjustSaturationImpl(alpha_s, ctx, inputs, req, outputs); } -void RGB2HLSConvert(const float& src_r, +inline void RGB2HLSConvert(const float& src_r, const float& src_g, const float& src_b, float *dst_h, @@ -443,7 +591,7 @@ void RGB2HLSConvert(const float& src_r, *dst_s = s; } -void HLS2RGBConvert(const float& src_h, +inline void HLS2RGBConvert(const float& src_h, const float& src_l, const float& src_s, float *dst_r, @@ -494,7 +642,7 @@ void HLS2RGBConvert(const float& src_h, *dst_r = r * 255.f; } -void AdjustHueImpl(float alpha, +inline void AdjustHueImpl(float alpha, const OpContext &ctx, const std::vector &inputs, const std::vector &req, @@ -521,7 +669,7 @@ void AdjustHueImpl(float alpha, }); } -void RandomHue(const nnvm::NodeAttrs &attrs, +inline void RandomHue(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, @@ -554,7 +702,7 @@ struct RandomColorJitterParam : public dmlc::Parameter { } }; -void RandomColorJitter(const nnvm::NodeAttrs &attrs, +inline void RandomColorJitter(const nnvm::NodeAttrs &attrs, const 
OpContext &ctx, const std::vector &inputs, const std::vector &req, @@ -623,7 +771,7 @@ struct RandomLightingParam : public dmlc::Parameter { } }; -void AdjustLightingImpl(const nnvm::Tuple& alpha, +inline void AdjustLightingImpl(const nnvm::Tuple& alpha, const OpContext &ctx, const std::vector &inputs, const std::vector &req, @@ -658,7 +806,7 @@ void AdjustLightingImpl(const nnvm::Tuple& alpha, }); } -void AdjustLighting(const nnvm::NodeAttrs &attrs, +inline void AdjustLighting(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, @@ -668,7 +816,7 @@ void AdjustLighting(const nnvm::NodeAttrs &attrs, AdjustLightingImpl(param.alpha, ctx, inputs, req, outputs); } -void RandomLighting(const nnvm::NodeAttrs &attrs, +inline void RandomLighting(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, diff --git a/src/operator/image/image_random.cc b/src/operator/image/image_random.cc index 26f520bb8c5f..7901747c8ea1 100644 --- a/src/operator/image/image_random.cc +++ b/src/operator/image/image_random.cc @@ -49,21 +49,92 @@ NNVM_REGISTER_OP(_image_to_tensor) .add_argument("data", "NDArray-or-Symbol", "The input."); NNVM_REGISTER_OP(_image_normalize) -.describe(R"code()code" ADD_FILELINE) +.describe(R"code(Normalize a tensor of shape (C x H x W) or (N x C x H x W) with mean and + standard deviation. + + Given mean `(m1, ..., mn)` and std `(s1, ..., sn)` for `n` channels, + this transform normalizes each channel of the input tensor with: + +.. math:: + + output[i] = (input[i] - m_i) / s_i + + If mean or std is scalar, the same value will be applied to all channels. + + Default value for mean is 0.0 and standard deviation is 1.0. + +Example: + + .. 
code-block:: python + image = mx.nd.random.uniform(0, 1, (3, 4, 2)) + normalize(image, mean=(0, 1, 2), std=(3, 2, 1)) + [[[ 0.18293785 0.19761486] + [ 0.23839645 0.28142193] + [ 0.20092112 0.28598186] + [ 0.18162774 0.28241724]] + [[-0.2881726 -0.18821815] + [-0.17705294 -0.30780914] + [-0.2812064 -0.3512327 ] + [-0.05411351 -0.4716435 ]] + [[-1.0363373 -1.7273437 ] + [-1.6165586 -1.5223348 ] + [-1.208275 -1.1878313 ] + [-1.4711051 -1.5200229 ]]] + + + image = mx.nd.random.uniform(0, 1, (2, 3, 4, 2)) + normalize(image, mean=(0, 1, 2), std=(3, 2, 1)) + [[[[ 0.18934818 0.13092826] + [ 0.3085322 0.27869293] + [ 0.02367868 0.11246539] + [ 0.0290431 0.2160573 ]] + [[-0.4898908 -0.31587923] + [-0.08369008 -0.02142242] + [-0.11092162 -0.42982462] + [-0.06499392 -0.06495637]] + [[-1.0213816 -1.526392 ] + [-1.2008414 -1.1990893 ] + [-1.5385206 -1.4795225 ] + [-1.2194707 -1.3211205 ]]] + [[[ 0.03942481 0.24021089] + [ 0.21330701 0.1940066 ] + [ 0.04778443 0.17912441] + [ 0.31488964 0.25287187]] + [[-0.23907584 -0.4470462 ] + [-0.29266903 -0.2631998 ] + [-0.3677222 -0.40683383] + [-0.11288315 -0.13154092]] + [[-1.5438497 -1.7834496 ] + [-1.431566 -1.8647819 ] + [-1.9812102 -1.675859 ] + [-1.3823645 -1.8503251 ]]]] + +)code" ADD_FILELINE) +.set_attr_parser(ParamParser) .set_num_inputs(1) .set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", NormalizeShape) -.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) +.set_attr("FInferShape", NormalizeOpShape) +.set_attr("FInferType", NormalizeOpType) +.set_attr("FCompute", NormalizeOpForward) .set_attr("FInplaceOption", - [](const NodeAttrs& attrs){ + [](const NodeAttrs& attrs) { return std::vector >{{0, 0}}; }) -.set_attr("FCompute", Normalize) -.set_attr("FGradient", ElemwiseGradUseNone{ "_copy" }) -.add_argument("data", "NDArray-or-Symbol", "The input.") +.set_attr("FGradient", ElemwiseGradUseIn{ "_backward_image_normalize"}) +.add_argument("data", "NDArray-or-Symbol", "Input ndarray") .add_arguments(NormalizeParam::__FIELDS__()); +NNVM_REGISTER_OP(_backward_image_normalize) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("TIsBackward", true) +.set_attr("FCompute", NormalizeOpBackward); + MXNET_REGISTER_IMAGE_AUG_OP(_image_flip_left_right) .describe(R"code()code" ADD_FILELINE) .set_attr("FCompute", FlipLeftRight); diff --git a/src/operator/image/image_random.cu b/src/operator/image/image_random.cu new file mode 100644 index 000000000000..404c3d25477a --- /dev/null +++ b/src/operator/image/image_random.cu @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file image_random.cu + * \brief GPU Implementation of image transformation operators + */ +#include "./image_random-inl.h" +#include "../elemwise_op_common.h" + +namespace mxnet { +namespace op { +namespace image { + +NNVM_REGISTER_OP(_image_normalize) +.set_attr("FCompute", NormalizeOpForward); + +NNVM_REGISTER_OP(_backward_image_normalize) +.set_attr("FCompute", NormalizeOpBackward); + + +} // namespace image +} // namespace op +} // namespace mxnet diff --git a/tests/python/gpu/test_gluon_transforms.py b/tests/python/gpu/test_gluon_transforms.py new file mode 100644 index 000000000000..c7afc762bd80 --- /dev/null +++ b/tests/python/gpu/test_gluon_transforms.py @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import print_function +import os +import sys +import mxnet as mx +import mxnet.ndarray as nd +import numpy as np +from mxnet import gluon +from mxnet.base import MXNetError +from mxnet.gluon.data.vision import transforms +from mxnet.test_utils import assert_almost_equal, set_default_context +from mxnet.test_utils import almost_equal +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.insert(0, os.path.join(curr_path, '../unittest')) +from common import assertRaises, setup_module, with_seed, teardown + + +set_default_context(mx.gpu(0)) + +@with_seed() +def test_normalize(): + # 3D Input + data_in_3d = nd.random.uniform(0, 1, (3, 300, 300)) + out_nd_3d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_3d) + data_expected_3d = data_in_3d.asnumpy() + data_expected_3d[:][:][0] = data_expected_3d[:][:][0] / 3.0 + data_expected_3d[:][:][1] = (data_expected_3d[:][:][1] - 1.0) / 2.0 + data_expected_3d[:][:][2] = data_expected_3d[:][:][2] - 2.0 + assert_almost_equal(data_expected_3d, out_nd_3d.asnumpy()) + + # 4D Input + data_in_4d = nd.random.uniform(0, 1, (2, 3, 300, 300)) + out_nd_4d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_4d) + data_expected_4d = data_in_4d.asnumpy() + data_expected_4d[0][:][:][0] = data_expected_4d[0][:][:][0] / 3.0 + data_expected_4d[0][:][:][1] = (data_expected_4d[0][:][:][1] - 1.0) / 2.0 + data_expected_4d[0][:][:][2] = data_expected_4d[0][:][:][2] - 2.0 + data_expected_4d[1][:][:][0] = data_expected_4d[1][:][:][0] / 3.0 + data_expected_4d[1][:][:][1] = (data_expected_4d[1][:][:][1] - 1.0) / 2.0 + data_expected_4d[1][:][:][2] = data_expected_4d[1][:][:][2] - 2.0 + assert_almost_equal(data_expected_4d, out_nd_4d.asnumpy()) + + # Default normalize values i.e., mean=0, std=1 + data_in_3d_def = nd.random.uniform(0, 1, (3, 300, 300)) + out_nd_3d_def = transforms.Normalize()(data_in_3d_def) + data_expected_3d_def = data_in_3d_def.asnumpy() + assert_almost_equal(data_expected_3d_def, out_nd_3d_def.asnumpy()) + + # Invalid Input - Neither 
3D or 4D input + invalid_data_in = nd.random.uniform(0, 1, (5, 5, 3, 300, 300)) + normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1)) + assertRaises(MXNetError, normalize_transformer, invalid_data_in) + + # Invalid Input - Channel neither 1 or 3 + invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300)) + normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1)) + assertRaises(MXNetError, normalize_transformer, invalid_data_in) \ No newline at end of file diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py index 2ff9c5cb2a1d..c83778fefc65 100644 --- a/tests/python/unittest/test_gluon_data_vision.py +++ b/tests/python/unittest/test_gluon_data_vision.py @@ -19,10 +19,11 @@ import mxnet.ndarray as nd import numpy as np from mxnet import gluon +from mxnet.base import MXNetError from mxnet.gluon.data.vision import transforms from mxnet.test_utils import assert_almost_equal from mxnet.test_utils import almost_equal -from common import setup_module, with_seed, teardown +from common import assertRaises, setup_module, with_seed, teardown @with_seed() @@ -35,14 +36,36 @@ def test_to_tensor(): @with_seed() def test_normalize(): - data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8) - data_in = transforms.ToTensor()(nd.array(data_in, dtype='uint8')) - out_nd = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in) - data_expected = data_in.asnumpy() - data_expected[:][:][0] = data_expected[:][:][0] / 3.0 - data_expected[:][:][1] = (data_expected[:][:][1] - 1.0) / 2.0 - data_expected[:][:][2] = data_expected[:][:][2] - 2.0 - assert_almost_equal(data_expected, out_nd.asnumpy()) + # 3D Input + data_in_3d = nd.random.uniform(0, 1, (3, 300, 300)) + out_nd_3d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_3d) + data_expected_3d = data_in_3d.asnumpy() + data_expected_3d[:][:][0] = data_expected_3d[:][:][0] / 3.0 + data_expected_3d[:][:][1] = (data_expected_3d[:][:][1] - 1.0) / 2.0 + data_expected_3d[:][:][2] = data_expected_3d[:][:][2] - 2.0 + assert_almost_equal(data_expected_3d, out_nd_3d.asnumpy()) + + # 4D Input + data_in_4d = nd.random.uniform(0, 1, (2, 3, 300, 300)) + out_nd_4d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_4d) + data_expected_4d = data_in_4d.asnumpy() + data_expected_4d[0][:][:][0] = data_expected_4d[0][:][:][0] / 3.0 + data_expected_4d[0][:][:][1] = (data_expected_4d[0][:][:][1] - 1.0) / 2.0 + data_expected_4d[0][:][:][2] = data_expected_4d[0][:][:][2] - 2.0 + data_expected_4d[1][:][:][0] = data_expected_4d[1][:][:][0] / 3.0 + data_expected_4d[1][:][:][1] = (data_expected_4d[1][:][:][1] - 1.0) / 2.0 + data_expected_4d[1][:][:][2] = data_expected_4d[1][:][:][2] - 2.0 + assert_almost_equal(data_expected_4d, out_nd_4d.asnumpy()) + + # Invalid Input - Neither 3D or 4D input + invalid_data_in = nd.random.uniform(0, 1, (5, 5, 3, 300, 300)) + normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1)) + assertRaises(MXNetError, normalize_transformer, invalid_data_in) + + # Invalid Input - Channel neither 1 or 3 + invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300)) + normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1)) + assertRaises(MXNetError, normalize_transformer, invalid_data_in) @with_seed() diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 670cc7eb15e0..ce61beb125d3 100644 --- a/tests/python/unittest/test_operator.py +++ 
b/tests/python/unittest/test_operator.py @@ -7326,6 +7326,73 @@ def test_invalid_max_pooling_pad_type_same(): name='pooling', pooling_convention="same") + +@with_seed() +def test_image_normalize(): + # Part 1 - Test 3D Input + shape_3d = (3, 28, 28) + mean = (0, 1, 2) + std = (3, 2, 1) + + data_in_3d = mx.nd.random.uniform(0, 1, shape_3d) + data_expected_3d = data_in_3d.asnumpy() + data_expected_3d[:][:][0] = data_expected_3d[:][:][0] / 3.0 + data_expected_3d[:][:][1] = (data_expected_3d[:][:][1] - 1.0) / 2.0 + data_expected_3d[:][:][2] = data_expected_3d[:][:][2] - 2.0 + + data = mx.symbol.Variable('data') + img_norm_sym = mx.sym.image.normalize(data=data, mean=mean, std=std) + + # check forward + check_symbolic_forward(img_norm_sym, [data_in_3d], [data_expected_3d], + rtol=1e-5, atol=1e-5) + + # Gradient is 1/std_dev + grad_expected_3d = np.ones(shape_3d) + grad_expected_3d[:][:][0] = 1 / 3.0 + grad_expected_3d[:][:][1] = 1 / 2.0 + grad_expected_3d[:][:][2] = 1 / 1.0 + + # check backward + check_symbolic_backward(img_norm_sym, location=[data_in_3d], out_grads=[mx.nd.ones(shape_3d)], + expected=[grad_expected_3d], rtol=1e-5, atol=1e-5) + + # check backward using finite difference + check_numeric_gradient(img_norm_sym, [data_in_3d], atol=0.001) + + # Part 2 - Test 4D Input + shape_4d = (2, 3, 28, 28) + + data_in_4d = mx.nd.random.uniform(0, 1, shape_4d) + data_expected_4d = data_in_4d.asnumpy() + data_expected_4d[0][:][:][0] = data_expected_4d[0][:][:][0] / 3.0 + data_expected_4d[0][:][:][1] = (data_expected_4d[0][:][:][1] - 1.0) / 2.0 + data_expected_4d[0][:][:][2] = data_expected_4d[0][:][:][2] - 2.0 + data_expected_4d[1][:][:][0] = data_expected_4d[1][:][:][0] / 3.0 + data_expected_4d[1][:][:][1] = (data_expected_4d[1][:][:][1] - 1.0) / 2.0 + data_expected_4d[1][:][:][2] = data_expected_4d[1][:][:][2] - 2.0 + + # check forward + check_symbolic_forward(img_norm_sym, [data_in_4d], [data_expected_4d], + rtol=1e-5, atol=1e-5) + + # Gradient is 1/std_dev + grad_expected_4d = np.ones(shape_4d) + grad_expected_4d[0][:][:][0] = 1 / 3.0 + grad_expected_4d[0][:][:][1] = 1 / 2.0 + grad_expected_4d[0][:][:][2] = 1 / 1.0 + grad_expected_4d[1][:][:][0] = 1 / 3.0 + grad_expected_4d[1][:][:][1] = 1 / 2.0 + grad_expected_4d[1][:][:][2] = 1 / 1.0 + + # check backward + check_symbolic_backward(img_norm_sym, location=[data_in_4d], out_grads=[mx.nd.ones(shape_4d)], + expected=[grad_expected_4d], rtol=1e-5, atol=1e-5) + + # check backward using finite difference + check_numeric_gradient(img_norm_sym, [data_in_4d], atol=0.001) + + if __name__ == '__main__': import nose nose.runmodule()
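The per-channel arithmetic that the new forward and backward kernels implement can be summarized with a small NumPy sketch. This is illustrative only and not part of the patch; the helper name normalize_reference and the reliance on NumPy broadcasting are assumptions of the sketch, not the operator's actual implementation.

import numpy as np

def normalize_reference(data, mean=0.0, std=1.0):
    """Reference for image normalize: out = (data - mean) / std, applied per channel.

    Accepts (C, H, W) or (N, C, H, W) input, mirroring the shapes the operator supports.
    A scalar (or length-1) mean/std is shared across all channels.
    """
    data = np.asarray(data, dtype=np.float32)
    mean = np.atleast_1d(np.asarray(mean, dtype=np.float32))
    std = np.atleast_1d(np.asarray(std, dtype=np.float32))
    if data.ndim == 3:        # (C, H, W): channel axis is 0
        bshape = (-1, 1, 1)
    elif data.ndim == 4:      # (N, C, H, W): channel axis is 1
        bshape = (1, -1, 1, 1)
    else:
        raise ValueError("expected a 3D or 4D input, got %dD" % data.ndim)
    out = (data - mean.reshape(bshape)) / std.reshape(bshape)
    # d/dx of (x - mean) / std is 1 / std, so the gradient of the output with
    # respect to the input is simply 1 / std broadcast per channel; the backward
    # kernel scales the incoming gradient by this factor.
    grad_of_ones = np.broadcast_to(1.0 / std.reshape(bshape), data.shape)
    return out, grad_of_ones

# Example use (hypothetical): compare against the operator on a random batch.
x = np.random.uniform(0, 1, (2, 3, 28, 28)).astype(np.float32)
out_ref, grad_ref = normalize_reference(x, mean=(0, 1, 2), std=(3, 2, 1))

Checking out_ref and grad_ref against mx.nd.image.normalize and the check_symbolic_backward calls in test_image_normalize above is essentially what the hand-unrolled channel slices in the new tests verify.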