diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index 0be5578360df..0fd1c89575d7 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -20,6 +20,8 @@ "Image transforms." import random +import numpy as np + from ...block import Block, HybridBlock from ...nn import Sequential, HybridSequential from .... import image @@ -198,6 +200,77 @@ def hybrid_forward(self, F, x): return F.image.normalize(x, self._mean, self._std) +class Rotate(Block): + """Rotate the input image by a given angle. Keeps the original image shape. + + Parameters + ---------- + rotation_degrees : float32 + Desired rotation angle in degrees. + zoom_in : bool + Zoom in image so that no padding is present in final output. + zoom_out : bool + Zoom out image so that the entire original image is present in final output. + + + Inputs: + - **data**: input tensor with (C x H x W) or (N x C x H x W) shape. + + Outputs: + - **out**: output tensor with (C x H x W) or (N x C x H x W) shape. + """ + def __init__(self, rotation_degrees, zoom_in=False, zoom_out=False): + super(Rotate, self).__init__() + self._args = (rotation_degrees, zoom_in, zoom_out) + + def forward(self, x): + if x.dtype is not np.float32: + raise TypeError("This transformation only supports float32. " + "Consider calling it after ToTensor") + return image.imrotate(x, *self._args) + + +class RandomRotation(Block): + """Random rotate the input image by a random angle. + Keeps the original image shape and aspect ratio. + + Parameters + ---------- + angle_limits: tuple + Tuple of 2 elements containing the upper and lower limit + for rotation angles in degree. + zoom_in : bool + Zoom in image so that no padding is present in final output. + zoom_out : bool + Zoom out image so that the entire original image is present in final output. + rotate_with_proba : float32 + + + Inputs: + - **data**: input tensor with (C x H x W) or (N x C x H x W) shape. + + Outputs: + - **out**: output tensor with (C x H x W) or (N x C x H x W) shape. + """ + def __init__(self, angle_limits, zoom_in=False, zoom_out=False, rotate_with_proba=1.0): + super(RandomRotation, self).__init__() + lower, upper = angle_limits + if lower >= upper: + raise ValueError("`angle_limits` must be an ordered tuple") + if rotate_with_proba < 0 or rotate_with_proba > 1: + raise ValueError("Probability of rotating the image should be between 0 and 1") + self._args = (angle_limits, zoom_in, zoom_out) + self._rotate_with_proba = rotate_with_proba + + def forward(self, x): + if np.random.random() > self._rotate_with_proba: + return x + if x.dtype is not np.float32: + raise TypeError("This transformation only supports float32. " + "Consider calling it after ToTensor") + return image.random_rotate(x, *self._args) + + class RandomResizedCrop(Block): """Crop the input image with random scale and aspect ratio. diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index 5236027bfa3b..95b45af5efd8 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -27,7 +27,11 @@ import logging import json import warnings + +from numbers import Number + import numpy as np + from .. import numpy as _mx_np # pylint: disable=reimported @@ -612,6 +616,145 @@ def random_size_crop(src, size, area, ratio, interp=2, **kwargs): return center_crop(src, size, interp) +def imrotate(src, rotation_degrees, zoom_in=False, zoom_out=False): + """Rotates the input image(s) of a specific rotation degree. + + Parameters + ---------- + src : NDArray + Input image (format CHW) or batch of images (format NCHW), + in both case is required a float32 data type. + rotation_degrees: scalar or NDArray + Wanted rotation in degrees. In case of `src` being a single image + a scalar is needed, otherwise a mono-dimensional vector of angles + or a scalar. + zoom_in: bool + If True input image(s) will be zoomed in a way so that no padding + will be shown in the output result. + zoom_out: bool + If True input image(s) will be zoomed in a way so that the whole + original image will be contained in the output result. + Returns + ------- + NDArray + An `NDArray` containing the rotated image(s). + """ + if zoom_in and zoom_out: + raise ValueError("`zoom_in` and `zoom_out` cannot be both True") + if src.dtype is not np.float32: + raise TypeError("Only `float32` images are supported by this function") + # handles the case in which a single image is passed to this function + expanded = False + if src.ndim == 3: + expanded = True + src = src.expand_dims(axis=0) + if not isinstance(rotation_degrees, Number): + raise TypeError("When a single image is passed the rotation angle is " + "required to be a scalar.") + elif src.ndim != 4: + raise ValueError("Only 3D and 4D are supported by this function") + + # when a scalar is passed we wrap it into an array + if isinstance(rotation_degrees, Number): + rotation_degrees = nd.array([rotation_degrees] * len(src), + ctx=src.context) + + if len(src) != len(rotation_degrees): + raise ValueError( + "The number of images must be equal to the number of rotation angles" + ) + + rotation_degrees = rotation_degrees.as_in_context(src.context) + rotation_rad = np.pi * rotation_degrees / 180 + # reshape the rotations angle in order to be broadcasted + # over the `src` tensor + rotation_rad = rotation_rad.expand_dims(axis=1).expand_dims(axis=2) + _, _, h, w = src.shape + + # Generate a grid centered at the center of the image + hscale = (float(h - 1) / 2) + wscale = (float(w - 1) / 2) + h_matrix = ( + nd.repeat(nd.arange(h, ctx=src.context).astype('float32').reshape(h, 1), w, axis=1) - hscale + ).expand_dims(axis=0) + w_matrix = ( + nd.repeat(nd.arange(w, ctx=src.context).astype('float32').reshape(1, w), h, axis=0) - wscale + ).expand_dims(axis=0) + # perform rotation on the grid + c_alpha = nd.cos(rotation_rad) + s_alpha = nd.sin(rotation_rad) + w_matrix_rot = w_matrix * c_alpha - h_matrix * s_alpha + h_matrix_rot = w_matrix * s_alpha + h_matrix * c_alpha + # NOTE: grid normalization must be performed after the rotation + # to keep the aspec ratio + w_matrix_rot = w_matrix_rot / wscale + h_matrix_rot = h_matrix_rot / hscale + + h, w = nd.array([h], ctx=src.context), nd.array([w], ctx=src.context) + # compute the scale factor in case `zoom_in` or `zoom_out` are True + if zoom_in or zoom_out: + rho_corner = nd.sqrt(h * h + w * w) + ang_corner = nd.arctan(h / w) + corner1_x_pos = nd.abs(rho_corner * nd.cos(ang_corner + nd.abs(rotation_rad))) + corner1_y_pos = nd.abs(rho_corner * nd.sin(ang_corner + nd.abs(rotation_rad))) + corner2_x_pos = nd.abs(rho_corner * nd.cos(ang_corner - nd.abs(rotation_rad))) + corner2_y_pos = nd.abs(rho_corner * nd.sin(ang_corner - nd.abs(rotation_rad))) + max_x = nd.maximum(corner1_x_pos, corner2_x_pos) + max_y = nd.maximum(corner1_y_pos, corner2_y_pos) + if zoom_out: + scale_x = max_x / w + scale_y = max_y / h + globalscale = nd.maximum(scale_x, scale_y) + else: + scale_x = w / max_x + scale_y = h / max_y + globalscale = nd.minimum(scale_x, scale_y) + globalscale = globalscale.expand_dims(axis=3) + else: + globalscale = 1 + grid = nd.concat(w_matrix_rot.expand_dims(axis=1), + h_matrix_rot.expand_dims(axis=1), dim=1) + grid = grid * globalscale + rot_img = nd.BilinearSampler(src, grid) + if expanded: + return rot_img[0] + return rot_img + + +def random_rotate(src, angle_limits, zoom_in=False, zoom_out=False): + """Random rotates `src` by an angle included in angle limits. + + Parameters + ---------- + src : NDArray + Input image (format CHW) or batch of images (format NCHW), + in both case is required a float32 data type. + angle_limits: tuple + Tuple of 2 elements containing the upper and lower limit + for rotation angles in degree. + zoom_in: bool + If True input image(s) will be zoomed in a way so that no padding + will be shown in the output result. + zoom_out: bool + If True input image(s) will be zoomed in a way so that the whole + original image will be contained in the output result. + Returns + ------- + NDArray + An `NDArray` containing the rotated image(s). + """ + if src.ndim == 3: + rotation_degrees = np.random.uniform(*angle_limits) + else: + n = src.shape[0] + rotation_degrees = nd.array(np.random.uniform( + *angle_limits, + size=n + )) + return imrotate(src, rotation_degrees, + zoom_in=zoom_in, zoom_out=zoom_out) + + class Augmenter(object): """Image Augmenter base class""" def __init__(self, **kwargs): diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py index 71efb72b9ce5..b53dbf0687d3 100644 --- a/tests/python/unittest/test_gluon_data_vision.py +++ b/tests/python/unittest/test_gluon_data_vision.py @@ -224,11 +224,67 @@ def test_transformer(): transforms.RandomHue(0.1), transforms.RandomLighting(0.1), transforms.ToTensor(), + transforms.RandomRotation([-10., 10.]), transforms.Normalize([0, 0, 0], [1, 1, 1])]) transform(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read() +@with_seed() +def test_rotate(): + transformer = transforms.Rotate(10.) + assertRaises(TypeError, transformer, mx.nd.ones((3, 30, 60), dtype='uint8')) + single_image = mx.nd.ones((3, 30, 60), dtype='float32') + single_output = transformer(single_image) + assert same(single_output.shape, (3, 30, 60)) + batch_image = mx.nd.ones((3, 3, 30, 60), dtype='float32') + batch_output = transformer(batch_image) + assert same(batch_output.shape, (3, 3, 30, 60)) + + input_image = nd.array([[[0., 0., 0.], + [0., 0., 1.], + [0., 0., 0.]]]) + rotation_angles_expected_outs = [ + (90., nd.array([[[0., 1., 0.], + [0., 0., 0.], + [0., 0., 0.]]])), + (180., nd.array([[[0., 0., 0.], + [1., 0., 0.], + [0., 0., 0.]]])), + (270., nd.array([[[0., 0., 0.], + [0., 0., 0.], + [0., 1., 0.]]])), + (360., nd.array([[[0., 0., 0.], + [0., 0., 1.], + [0., 0., 0.]]])), + ] + for rot_angle, expected_result in rotation_angles_expected_outs: + transformer = transforms.Rotate(rot_angle) + ans = transformer(input_image) + print(ans, expected_result) + assert_almost_equal(ans, expected_result, atol=1e-6) + + +@with_seed() +def test_random_rotation(): + # test exceptions for probability input outside of [0,1] + assertRaises(ValueError, transforms.RandomRotation, [-10, 10.], rotate_with_proba=1.1) + assertRaises(ValueError, transforms.RandomRotation, [-10, 10.], rotate_with_proba=-0.3) + # test `forward` + transformer = transforms.RandomRotation([-10, 10.]) + assertRaises(TypeError, transformer, mx.nd.ones((3, 30, 60), dtype='uint8')) + single_image = mx.nd.ones((3, 30, 60), dtype='float32') + single_output = transformer(single_image) + assert same(single_output.shape, (3, 30, 60)) + batch_image = mx.nd.ones((3, 3, 30, 60), dtype='float32') + batch_output = transformer(batch_image) + assert same(batch_output.shape, (3, 3, 30, 60)) + # test identity (rotate_with_proba = 0) + transformer = transforms.RandomRotation([-100., 100.], rotate_with_proba=0.0) + data = mx.nd.random_normal(shape=(3, 30, 60)) + assert_almost_equal(data, transformer(data)) + + @with_seed() def test_random_transforms(): from mxnet.gluon.data.vision import transforms diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index 8a276c351d00..033b8e5aab04 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -17,6 +17,7 @@ import mxnet as mx import numpy as np +import scipy.ndimage from mxnet.test_utils import * from common import assertRaises, with_seed import shutil @@ -369,6 +370,87 @@ def test_random_size_crop(self): assert ratio[0] - epsilon <= float(new_w)/new_h <= ratio[1] + epsilon, \ 'ration of new width and height out of the bound{}/{}={}'.format(new_w, new_h, float(new_w)/new_h) + @with_seed() + def test_imrotate(self): + # test correctness + xlin = np.expand_dims(np.linspace(0, 0.5, 30), axis=1) + ylin = np.expand_dims(np.linspace(0, 0.5, 60), axis=0) + np_img = np.expand_dims(xlin + ylin, axis=2) + # rotate with imrotate + nd_img = mx.nd.array(np_img.transpose((2, 0, 1))) # convert to CHW + rot_angle = 6 + args = {'src': nd_img, 'rotation_degrees': rot_angle, 'zoom_in': False, 'zoom_out': False} + nd_rot = mx.image.imrotate(**args) + npnd_rot = nd_rot.asnumpy().transpose((1, 2, 0)) + # rotate with scipy + scipy_rot = scipy.ndimage.rotate(np_img, rot_angle, axes=(1, 0), reshape=False, + order=1, mode='constant', prefilter=False) + # cannot compare the edges (where image ends) because of different behavior + assert_almost_equal(scipy_rot[10:20, 20:40, :], npnd_rot[10:20, 20:40, :]) + + # test if execution raises exceptions in any allowed mode + # batch mode + img_in = mx.nd.random.uniform(0, 1, (5, 3, 30, 60), dtype=np.float32) + nd_rots = mx.nd.array([1, 2, 3, 4, 5], dtype=np.float32) + args = {'src': img_in, 'rotation_degrees': nd_rots, 'zoom_in': False, 'zoom_out': False} + _ = mx.image.imrotate(**args) + args = {'src': img_in, 'rotation_degrees': nd_rots, 'zoom_in': False, 'zoom_out': True} + _ = mx.image.imrotate(**args) + args = {'src': img_in, 'rotation_degrees': nd_rots, 'zoom_in': True, 'zoom_out': False} + _ = mx.image.imrotate(**args) + # single image mode + nd_rots = 11 + img_in = mx.nd.random.uniform(0, 1, (3, 30, 60), dtype=np.float32) + args = {'src': img_in, 'rotation_degrees': nd_rots, 'zoom_in': False, 'zoom_out': False} + _ = mx.image.imrotate(**args) + args = {'src': img_in, 'rotation_degrees': nd_rots, 'zoom_in': False, 'zoom_out': True} + _ = mx.image.imrotate(**args) + args = {'src': img_in, 'rotation_degrees': nd_rots, 'zoom_in': True, 'zoom_out': False} + _ = mx.image.imrotate(**args) + + # test if exceptions are correctly raised + # batch exception - zoom_in=zoom_out=True + img_in = mx.nd.random.uniform(0, 1, (5, 3, 30, 60), dtype=np.float32) + nd_rots = mx.nd.array([1, 2, 3, 4, 5], dtype=np.float32) + args={'src': img_in, 'rotation_degrees': nd_rots, 'zoom_in': True, 'zoom_out': True} + self.assertRaises(ValueError, mx.image.imrotate, **args) + + # single image exception - zoom_in=zoom_out=True + img_in = mx.nd.random.uniform(0, 1, (3, 30, 60), dtype=np.float32) + nd_rots = 11 + args = {'src': img_in, 'rotation_degrees': nd_rots, 'zoom_in': True, 'zoom_out': True} + self.assertRaises(ValueError, mx.image.imrotate, **args) + + # batch of images with scalar rotation + img_in = mx.nd.stack(nd_img, nd_img, nd_img) + nd_rots = 6 + args = {'src': img_in, 'rotation_degrees': nd_rots, 'zoom_in': False, 'zoom_out': False} + out = mx.image.imrotate(**args) + for img in out: + img = img.asnumpy().transpose((1, 2, 0)) + assert_almost_equal(scipy_rot[10:20, 20:40, :], img[10:20, 20:40, :]) + + # single image exception - single image with vector rotation + img_in = mx.nd.random.uniform(0, 1, (3, 30, 60), dtype=np.float32) + nd_rots = mx.nd.array([1, 2, 3, 4, 5], dtype=np.float32) + args = {'src': img_in, 'rotation_degrees': nd_rots, 'zoom_in': False, 'zoom_out': False} + self.assertRaises(TypeError, mx.image.imrotate, **args) + + @with_seed() + def test_random_rotate(self): + angle_limits = [-5., 5.] + src_single_image = mx.nd.random.uniform(0, 1, (3, 30, 60), + dtype=np.float32) + out_single_image = mx.image.random_rotate(src_single_image, + angle_limits) + self.assertEqual(out_single_image.shape, (3, 30, 60)) + src_batch_image = mx.nd.stack(src_single_image, + src_single_image, + src_single_image) + out_batch_image = mx.image.random_rotate(src_batch_image, + angle_limits) + self.assertEqual(out_batch_image.shape, (3, 3, 30, 60)) + if __name__ == '__main__': import nose