Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Random rotation #16794

Merged
merged 8 commits into from
Feb 11, 2020
74 changes: 74 additions & 0 deletions python/mxnet/gluon/data/vision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
# pylint: disable= arguments-differ
"Image transforms."

import numpy as np

from ...block import Block, HybridBlock
from ...nn import Sequential, HybridSequential
from .... import image
Expand Down Expand Up @@ -197,6 +199,78 @@ def hybrid_forward(self, F, x):
return F.image.normalize(x, self._mean, self._std)


class Rotate(Block):
    """Rotate the input image by a given angle. Keeps the original image shape.

    Parameters
    ----------
    rotation_degrees : float32
        Desired rotation angle in degrees, in the closed range [-90, 90].
    zoom_in : bool
        Zoom in image so that no padding is present in final output.
    zoom_out : bool
        Zoom out image so that the entire original image is present in final output.

    Raises
    ------
    ValueError
        At construction time, if `rotation_degrees` is outside [-90, 90].
    TypeError
        At call time, if the input tensor is not float32.


    Inputs:
        - **data**: input tensor with (C x H x W) or (N x C x H x W) shape.

    Outputs:
        - **out**: output tensor with (C x H x W) or (N x C x H x W) shape.
    """
    def __init__(self, rotation_degrees, zoom_in=False, zoom_out=False):
        super(Rotate, self).__init__()
        if rotation_degrees < -90 or rotation_degrees > 90:
            raise ValueError("Rotation angle should be between -90 and 90 degrees")
        # Stored in the positional order expected by `image.imrotate`.
        self._args = (rotation_degrees, zoom_in, zoom_out)

    def forward(self, x):
        # Compare dtypes by equality, not identity: an identity test only
        # succeeds when `x.dtype` is the very same `np.float32` type object and
        # would reject an equal `np.dtype('float32')` instance.
        if x.dtype != np.float32:
            raise TypeError("This transformation only supports float32. "
                            "Consider calling it after ToTensor")
        return image.imrotate(x, *self._args)


class RandomRotation(Block):
    """Random rotate the input image by a random angle.
       Keeps the original image shape and aspect ratio.

    Parameters
    ----------
    angle_limits: tuple
        Tuple of 2 elements `(lower, upper)` containing the lower and upper
        limit for rotation angles in degree. Both must lie in [-90, 90] and
        `lower` must be strictly smaller than `upper`.
    zoom_in : bool
        Zoom in image so that no padding is present in final output.
    zoom_out : bool
        Zoom out image so that the entire original image is present in final output.
    rotate_with_proba : float32
        Probability of rotating the image, in [0, 1].

    Raises
    ------
    ValueError
        At construction time, if a limit is outside [-90, 90], if the limits
        are not strictly increasing, or if `rotate_with_proba` is outside [0, 1].
    TypeError
        At call time, if the input tensor is not float32.


    Inputs:
        - **data**: input tensor with (C x H x W) or (N x C x H x W) shape.

    Outputs:
        - **out**: output tensor with (C x H x W) or (N x C x H x W) shape.
    """
    def __init__(self, angle_limits, zoom_in=False, zoom_out=False, rotate_with_proba=1.0):
        super(RandomRotation, self).__init__()
        lower, upper = angle_limits
        if any(i < -90 for i in angle_limits) or any(i > 90 for i in angle_limits):
            raise ValueError("rotation angles should be between -90 and 90 degrees")
        if lower >= upper:
            raise ValueError("`angle_limits` must be an ordered tuple")
        if rotate_with_proba < 0 or rotate_with_proba > 1:
            raise ValueError("Probability of rotating the image should be between 0 and 1")
        # Stored in the positional order expected by `image.random_rotate`.
        self._args = (angle_limits, zoom_in, zoom_out, rotate_with_proba)

    def forward(self, x):
        # Compare dtypes by equality, not identity: an identity test only
        # succeeds when `x.dtype` is the very same `np.float32` type object and
        # would reject an equal `np.dtype('float32')` instance.
        if x.dtype != np.float32:
            raise TypeError("This transformation only supports float32. "
                            "Consider calling it after ToTensor")
        return image.random_rotate(x, *self._args)


class RandomResizedCrop(Block):
"""Crop the input image with random scale and aspect ratio.

Expand Down
146 changes: 146 additions & 0 deletions python/mxnet/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
import logging
import json
import warnings

from numbers import Number

lkubin marked this conversation as resolved.
Show resolved Hide resolved
import numpy as np

from .. import numpy as _mx_np # pylint: disable=reimported


Expand Down Expand Up @@ -612,6 +616,148 @@ def random_size_crop(src, size, area, ratio, interp=2, **kwargs):
return center_crop(src, size, interp)


def imrotate(src, rotation_degrees, zoom_in=False, zoom_out=False):
    """Rotates the input image(s) by the given rotation angle(s).

    Parameters
    ----------
    src : NDArray
        Input image (format CHW) or batch of images (format NCHW);
        in both cases a float32 data type is required.
    rotation_degrees: scalar or NDArray
        Wanted rotation in degrees, within [-90, 90]. When `src` is a single
        image a scalar is required; for a batch either a scalar (applied to
        every image) or a one-dimensional vector with one angle per image.
    zoom_in: bool
        If True input image(s) will be zoomed in a way so that no padding
        will be shown in the output result.
    zoom_out: bool
        If True input image(s) will be zoomed in a way so that the whole
        original image will be contained in the output result.

    Returns
    -------
    NDArray
        An `NDArray` containing the rotated image(s), with the same shape
        as `src`.

    Raises
    ------
    ValueError
        If `zoom_in` and `zoom_out` are both True, if `src` is neither 3D nor
        4D, if any angle is outside [-90, 90], or if the number of angles does
        not match the number of images.
    TypeError
        If `src` is not float32, or if a non-scalar angle is given for a
        single image.
    """
    if zoom_in and zoom_out:
        raise ValueError("`zoom_in` and `zoom_out` cannot be both True")
    # Equality (not identity) so that any dtype equal to float32 is accepted.
    if src.dtype != np.float32:
        raise TypeError("Only `float32` images are supported by this function")
    # handles the case in which a single image is passed to this function
    expanded = False
    if src.ndim == 3:
        if not isinstance(rotation_degrees, Number):
            raise TypeError("When a single image is passed the rotation angle is "
                            "required to be a scalar.")
        expanded = True
        src = src.expand_dims(axis=0)
    elif src.ndim != 4:
        raise ValueError("Only 3D and 4D are supported by this function")

    # when a scalar is passed we wrap it into an array
    if isinstance(rotation_degrees, Number):
        rotation_degrees = nd.array([rotation_degrees] * len(src),
                                    ctx=src.context)

    if ((rotation_degrees < -90).sum().asscalar() > 0.0 or
            (rotation_degrees > 90).sum().asscalar() > 0.0):
        raise ValueError("Rotation angles should be between -90 and 90 degrees")

    if len(src) != len(rotation_degrees):
        raise ValueError(
            "The number of images must be equal to the number of rotation angles"
        )

    # Force float32 so the angle tensor matches the float32 sampling grid built
    # below, whatever dtype the caller used for the angles.
    rotation_degrees = rotation_degrees.as_in_context(src.context).astype('float32')
    rotation_rad = np.pi * rotation_degrees / 180
    # reshape the rotation angles in order to be broadcasted
    # over the `src` tensor
    rotation_rad = rotation_rad.expand_dims(axis=1).expand_dims(axis=2)
    _, _, h, w = src.shape

    # Generate a grid centered at the center of the image
    hscale = float(h - 1) / 2
    wscale = float(w - 1) / 2
    h_matrix = (
        nd.repeat(nd.arange(h, ctx=src.context).astype('float32').reshape(h, 1), w, axis=1) - hscale
    ).expand_dims(axis=0)
    w_matrix = (
        nd.repeat(nd.arange(w, ctx=src.context).astype('float32').reshape(1, w), h, axis=0) - wscale
    ).expand_dims(axis=0)
    # perform rotation on the grid
    c_alpha = nd.cos(rotation_rad)
    s_alpha = nd.sin(rotation_rad)
    w_matrix_rot = w_matrix * c_alpha - h_matrix * s_alpha
    h_matrix_rot = w_matrix * s_alpha + h_matrix * c_alpha
    # NOTE: grid normalization must be performed after the rotation
    #       to keep the aspect ratio
    w_matrix_rot = w_matrix_rot / wscale
    h_matrix_rot = h_matrix_rot / hscale

    # compute the scale factor in case `zoom_in` or `zoom_out` are True
    if zoom_in or zoom_out:
        # Keep `h` / `w` as Python ints for branching and scaling; only the
        # trigonometry is done on tensors. This avoids relying on the
        # truthiness of one-element NDArrays in `if` and `min`.
        h_nd = nd.array([h], ctx=src.context)
        w_nd = nd.array([w], ctx=src.context)
        diagonal = nd.sqrt(h_nd * h_nd + w_nd * w_nd)
        diagonal_angle = nd.arctan(h_nd / w_nd)
        abs_rad = nd.abs(rotation_rad)
        if h < w:
            roth = diagonal * nd.sin(diagonal_angle + abs_rad)
        else:
            roth = diagonal * nd.cos(diagonal_angle - abs_rad)
        short_side = float(min(h, w))
        if zoom_in:
            globalscale = short_side / roth
        else:
            globalscale = roth / short_side
        # one scale factor per image, broadcast over (2, H, W) of the grid
        globalscale = globalscale.expand_dims(axis=3)
    else:
        globalscale = 1
    grid = nd.concat(w_matrix_rot.expand_dims(axis=1),
                     h_matrix_rot.expand_dims(axis=1), dim=1)
    grid = grid * globalscale
    rot_img = nd.BilinearSampler(src, grid)
    if expanded:
        # drop the batch axis that was added for the single-image case
        return rot_img[0]
    return rot_img


def random_rotate(src, angle_limits, zoom_in=False,
                  zoom_out=False, rotate_with_proba=1.0):
    """Randomly rotates `src` by an angle included in `angle_limits`.

    Parameters
    ----------
    src : NDArray
        Input image (format CHW) or batch of images (format NCHW);
        in both cases a float32 data type is required.
    angle_limits: tuple
        Tuple of 2 elements `(lower, upper)` containing the lower and upper
        limit for rotation angles in degree.
    zoom_in: bool
        If True input image(s) will be zoomed in a way so that no padding
        will be shown in the output result.
    zoom_out: bool
        If True input image(s) will be zoomed in a way so that the whole
        original image will be contained in the output result.
    rotate_with_proba: float in [0., 1]
        Probability of rotating the image.

    Returns
    -------
    NDArray
        An `NDArray` containing the rotated image(s), or `src` itself
        (unchanged) when the random draw decides not to rotate.

    Raises
    ------
    ValueError
        If `rotate_with_proba` is outside [0, 1].
    """
    if rotate_with_proba < 0 or rotate_with_proba > 1:
        raise ValueError('Probability of rotating the image should be between 0 and 1')
    # `np.random.random()` draws from [0, 1). Using `>=` makes both endpoints
    # exact: probability 0 never rotates (a draw of exactly 0.0 no longer slips
    # through) and probability 1 always rotates.
    if np.random.random() >= rotate_with_proba:
        return src
    if src.ndim == 3:
        # single image: one scalar angle
        rotation_degrees = np.random.uniform(*angle_limits)
    else:
        # batch: one independent angle per image
        n = src.shape[0]
        rotation_degrees = nd.array(np.random.uniform(
            *angle_limits,
            size=n
        ))
    return imrotate(src, rotation_degrees,
                    zoom_in=zoom_in, zoom_out=zoom_out)


class Augmenter(object):
"""Image Augmenter base class"""
def __init__(self, **kwargs):
Expand Down
28 changes: 28 additions & 0 deletions tests/python/unittest/test_gluon_data_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,39 @@ def test_transformer():
transforms.RandomHue(0.1),
transforms.RandomLighting(0.1),
transforms.ToTensor(),
transforms.RandomRotation([-10., 10.]),
transforms.Normalize([0, 0, 0], [1, 1, 1])])

transform(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read()


@with_seed()
def test_rotate():
    """Check `transforms.Rotate` argument validation, dtype gate and output shapes."""
    # angles beyond +/-90 degrees are rejected when the transform is built
    for bad_angle in (-100., 100.):
        assertRaises(ValueError, transforms.Rotate, bad_angle)

    rotate = transforms.Rotate(10.)

    # only float32 tensors are accepted
    assertRaises(TypeError, rotate, mx.nd.ones((3, 30, 60), dtype='uint8'))

    # a single CHW image and an NCHW batch must both keep their shape
    for shape in ((3, 30, 60), (3, 3, 30, 60)):
        rotated = rotate(mx.nd.ones(shape, dtype='float32'))
        assert same(rotated.shape, shape)


@with_seed()
def test_random_rotation():
    """Check `transforms.RandomRotation` argument validation, dtype gate and output shapes."""
    # out-of-range limits, then out-of-range and unordered limits
    for bad_limits in ([-100., 100.], [100., -100]):
        assertRaises(ValueError, transforms.RandomRotation, bad_limits)

    random_rotation = transforms.RandomRotation([-10, 10.])

    # only float32 tensors are accepted
    assertRaises(TypeError, random_rotation, mx.nd.ones((3, 30, 60), dtype='uint8'))

    # a single CHW image first, then an NCHW batch; both must keep their shape
    for shape in ((3, 30, 60), (3, 3, 30, 60)):
        rotated = random_rotation(mx.nd.ones(shape, dtype='float32'))
        assert same(rotated.shape, shape)


if __name__ == '__main__':
import nose
Expand Down
94 changes: 94 additions & 0 deletions tests/python/unittest/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import mxnet as mx
import numpy as np
import scipy.ndimage
from mxnet.test_utils import *
from common import assertRaises, with_seed
import shutil
Expand Down Expand Up @@ -369,6 +370,99 @@ def test_random_size_crop(self):
assert ratio[0] - epsilon <= float(new_w)/new_h <= ratio[1] + epsilon, \
'ration of new width and height out of the bound{}/{}={}'.format(new_w, new_h, float(new_w)/new_h)

@with_seed()
def test_imrotate(self):
    """Exercise `mx.image.imrotate`: numeric agreement with scipy, every
    legal zoom mode for batch and single inputs, and the error paths."""
    batch_shape, single_shape = (5, 3, 30, 60), (3, 30, 60)
    # legal (zoom_in, zoom_out) combinations, in the order they are exercised
    legal_zoom_modes = ((False, False), (False, True), (True, False))
    # interior window used for comparisons; edges differ between backends
    interior = (slice(10, 20), slice(20, 40), slice(None))

    # ---- correctness against scipy on a smooth HWC gradient image ----
    rows = np.expand_dims(np.linspace(0, 0.5, 30), axis=1)
    cols = np.expand_dims(np.linspace(0, 0.5, 60), axis=0)
    np_img = np.expand_dims(rows + cols, axis=2)
    nd_img = mx.nd.array(np_img.transpose((2, 0, 1)))  # HWC -> CHW
    rot_angle = 6
    nd_rot = mx.image.imrotate(src=nd_img, rotation_degrees=rot_angle,
                               zoom_in=False, zoom_out=False)
    npnd_rot = nd_rot.asnumpy().transpose((1, 2, 0))  # back to HWC
    scipy_rot = scipy.ndimage.rotate(np_img, rot_angle, axes=(1, 0), reshape=False,
                                     order=1, mode='constant', prefilter=False)
    assert_almost_equal(scipy_rot[interior], npnd_rot[interior])

    # ---- smoke test: every legal mode must run without raising ----
    # batch mode, one angle per image
    img_in = mx.nd.random.uniform(0, 1, batch_shape, dtype=np.float32)
    nd_rots = mx.nd.array([1, 2, 3, 4, 5], dtype=np.float32)
    for zoom_in, zoom_out in legal_zoom_modes:
        _ = mx.image.imrotate(src=img_in, rotation_degrees=nd_rots,
                              zoom_in=zoom_in, zoom_out=zoom_out)
    # single image mode, scalar angle
    nd_rots = 11
    img_in = mx.nd.random.uniform(0, 1, single_shape, dtype=np.float32)
    for zoom_in, zoom_out in legal_zoom_modes:
        _ = mx.image.imrotate(src=img_in, rotation_degrees=nd_rots,
                              zoom_in=zoom_in, zoom_out=zoom_out)

    # ---- error paths ----
    # batch: zoom_in and zoom_out both requested
    img_in = mx.nd.random.uniform(0, 1, batch_shape, dtype=np.float32)
    nd_rots = mx.nd.array([1, 2, 3, 4, 5], dtype=np.float32)
    with self.assertRaises(ValueError):
        mx.image.imrotate(src=img_in, rotation_degrees=nd_rots,
                          zoom_in=True, zoom_out=True)

    # single image: zoom_in and zoom_out both requested
    img_in = mx.nd.random.uniform(0, 1, single_shape, dtype=np.float32)
    nd_rots = 11
    with self.assertRaises(ValueError):
        mx.image.imrotate(src=img_in, rotation_degrees=nd_rots,
                          zoom_in=True, zoom_out=True)

    # ---- batch of identical images rotated by one scalar angle ----
    img_in = mx.nd.stack(nd_img, nd_img, nd_img)
    nd_rots = 6
    out = mx.image.imrotate(src=img_in, rotation_degrees=nd_rots,
                            zoom_in=False, zoom_out=False)
    for rotated in out:
        rotated_hwc = rotated.asnumpy().transpose((1, 2, 0))
        assert_almost_equal(scipy_rot[interior], rotated_hwc[interior])

    # ---- single image combined with a vector of angles is a type error ----
    img_in = mx.nd.random.uniform(0, 1, single_shape, dtype=np.float32)
    nd_rots = mx.nd.array([1, 2, 3, 4, 5], dtype=np.float32)
    with self.assertRaises(TypeError):
        mx.image.imrotate(src=img_in, rotation_degrees=nd_rots,
                          zoom_in=False, zoom_out=False)

@with_seed()
def test_random_rotate(self):
    """Check that `mx.image.random_rotate` preserves shapes and rejects
    probabilities outside [0, 1]."""
    angle_limits = [-5., 5.]

    # single CHW image keeps its shape
    single = mx.nd.random.uniform(0, 1, (3, 30, 60), dtype=np.float32)
    rotated_single = mx.image.random_rotate(single, angle_limits)
    self.assertEqual(rotated_single.shape, (3, 30, 60))

    # NCHW batch built from three copies keeps its shape
    batch = mx.nd.stack(single, single, single)
    rotated_batch = mx.image.random_rotate(batch, angle_limits)
    self.assertEqual(rotated_batch.shape, (3, 3, 30, 60))

    # probabilities above 1 and below 0 are both invalid
    for bad_proba in (1.1, -0.3):
        with self.assertRaises(ValueError):
            mx.image.random_rotate(src=single,
                                   angle_limits=angle_limits,
                                   zoom_in=False, zoom_out=False,
                                   rotate_with_proba=bad_proba)


if __name__ == '__main__':
import nose
Expand Down