Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Added function for image rotation (imrotate) using BilinearSampler
Browse files Browse the repository at this point in the history
Added unit tests for imrotate

Added Rotate to transforms

Added RandomRotation to transforms
Added transforms tests
  • Loading branch information
Luca Kubin committed Dec 10, 2019
1 parent 0c5677e commit 0e68b29
Show file tree
Hide file tree
Showing 4 changed files with 342 additions and 0 deletions.
74 changes: 74 additions & 0 deletions python/mxnet/gluon/data/vision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
# pylint: disable= arguments-differ
"Image transforms."

import numpy as np

from ...block import Block, HybridBlock
from ...nn import Sequential, HybridSequential
from .... import image
Expand Down Expand Up @@ -197,6 +199,78 @@ def hybrid_forward(self, F, x):
return F.image.normalize(x, self._mean, self._std)


class Rotate(Block):
    """Rotate the input image by a given angle. Keeps the original image shape.

    Parameters
    ----------
    rotation_degrees : float32
        Desired rotation angle in degrees, in range [-90, 90] (both ends
        are accepted).
    zoom_in : bool
        Zoom in image so that no padding is present in final output.
    zoom_out : bool
        Zoom out image so that the entire original image is present in final output.
        Mutually exclusive with `zoom_in`.

    Raises
    ------
    ValueError
        At construction, if the angle is outside [-90, 90] or if both
        `zoom_in` and `zoom_out` are True.
    TypeError
        On call, if the input is not float32.


    Inputs:
        - **data**: input tensor with (C x H x W) or (N x C x H x W) shape.

    Outputs:
        - **out**: output tensor with (C x H x W) or (N x C x H x W) shape.
    """
    def __init__(self, rotation_degrees, zoom_in=False, zoom_out=False):
        super(Rotate, self).__init__()
        if rotation_degrees < -90 or rotation_degrees > 90:
            raise ValueError("Rotation angle should be between -90 and 90 degrees")
        # Fail at construction rather than on the first forward pass: the
        # rotation routine rejects this flag combination anyway.
        if zoom_in and zoom_out:
            raise ValueError("`zoom_in` and `zoom_out` cannot be both True")
        self._args = (rotation_degrees, zoom_in, zoom_out)

    def forward(self, x):
        # Compare by equality, not identity: a dtype may be reported either as
        # the `np.float32` type itself or as an equivalent `np.dtype` instance,
        # and `is not` would wrongly reject the latter.
        if x.dtype != np.float32:
            raise TypeError("This transformation only supports float32. "
                            "Consider calling it after ToTensor")
        return image.imrotate(x, *self._args)


class RandomRotation(Block):
    """Random rotate the input image by a random angle.
    Keeps the original image shape and aspect ratio.

    Parameters
    ----------
    angle_limits: tuple
        Tuple of 2 elements ``(lower, upper)`` containing the lower and upper
        limit for rotation angles in degree. Both must lie in [-90, 90] and
        ``lower`` must be strictly smaller than ``upper``.
    zoom_in : bool
        Zoom in image so that no padding is present in final output.
    zoom_out : bool
        Zoom out image so that the entire original image is present in final output.
        Mutually exclusive with `zoom_in`.
    rotate_with_proba : float32
        Probability of rotating the image, in [0, 1].

    Raises
    ------
    ValueError
        At construction, if a limit is outside [-90, 90], if the limits are
        not strictly ordered, if the probability is outside [0, 1], or if
        both `zoom_in` and `zoom_out` are True.
    TypeError
        On call, if the input is not float32.


    Inputs:
        - **data**: input tensor with (C x H x W) or (N x C x H x W) shape.

    Outputs:
        - **out**: output tensor with (C x H x W) or (N x C x H x W) shape.
    """
    def __init__(self, angle_limits, zoom_in=False, zoom_out=False, rotate_with_proba=1.0):
        super(RandomRotation, self).__init__()
        lower, upper = angle_limits
        if any(i < -90 for i in angle_limits) or any(i > 90 for i in angle_limits):
            raise ValueError("rotation angles should be between -90 and 90 degrees")
        if lower >= upper:
            raise ValueError("`angle_limits` must be an ordered tuple")
        if rotate_with_proba < 0 or rotate_with_proba > 1:
            raise ValueError("Probability of rotating the image should be between 0 and 1")
        # Fail at construction rather than on the first forward pass that
        # actually rotates: the rotation routine rejects this combination.
        if zoom_in and zoom_out:
            raise ValueError("`zoom_in` and `zoom_out` cannot be both True")
        self._args = (angle_limits, zoom_in, zoom_out, rotate_with_proba)

    def forward(self, x):
        # Compare by equality, not identity: a dtype may be reported either as
        # the `np.float32` type itself or as an equivalent `np.dtype` instance,
        # and `is not` would wrongly reject the latter.
        if x.dtype != np.float32:
            raise TypeError("This transformation only supports float32. "
                            "Consider calling it after ToTensor")
        return image.random_rotate(x, *self._args)


class RandomResizedCrop(Block):
"""Crop the input image with random scale and aspect ratio.
Expand Down
146 changes: 146 additions & 0 deletions python/mxnet/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
import logging
import json
import warnings

from numbers import Number

import numpy as np

from .. import numpy as _mx_np # pylint: disable=reimported


Expand Down Expand Up @@ -612,6 +616,148 @@ def random_size_crop(src, size, area, ratio, interp=2, **kwargs):
return center_crop(src, size, interp)


def imrotate(src, rotation_degrees, zoom_in=False, zoom_out=False):
    """Rotates the input image(s) of a specific rotation degree.

    Parameters
    ----------
    src : NDArray
        Input image (format CHW) or batch of images (format NCHW),
        in both cases a float32 data type is required.
    rotation_degrees: scalar or NDArray
        Wanted rotation in degrees, within [-90, 90]. When `src` is a single
        image a scalar is required; for a batch either a scalar (applied to
        every image) or a mono-dimensional vector with one angle per image.
    zoom_in: bool
        If True input image(s) will be zoomed in a way so that no padding
        will be shown in the output result.
    zoom_out: bool
        If True input image(s) will be zoomed in a way so that the whole
        original image will be contained in the output result.

    Returns
    -------
    NDArray
        An `NDArray` containing the rotated image(s), with the same shape
        as `src`.

    Raises
    ------
    ValueError
        If both `zoom_in` and `zoom_out` are True, if `src` is neither 3D
        nor 4D, if an angle is outside [-90, 90], or if the number of angles
        does not match the number of images.
    TypeError
        If `src` is not float32, or if a single image is given together
        with a non-scalar angle.
    """
    if zoom_in and zoom_out:
        raise ValueError("`zoom_in` and `zoom_out` cannot be both True")
    # Equality rather than identity: a dtype may be the `np.float32` type or
    # an equivalent `np.dtype` instance; `is not` would reject the latter.
    if src.dtype != np.float32:
        raise TypeError("Only `float32` images are supported by this function")
    # handles the case in which a single image is passed to this function
    expanded = False
    if src.ndim == 3:
        # validate before touching `src` so a bad call fails without extra work
        if not isinstance(rotation_degrees, Number):
            raise TypeError("When a single image is passed the rotation angle is "
                            "required to be a scalar.")
        expanded = True
        src = src.expand_dims(axis=0)
    elif src.ndim != 4:
        raise ValueError("Only 3D and 4D are supported by this function")

    num_images, _, h, w = src.shape

    if isinstance(rotation_degrees, Number):
        # a scalar can be range-checked in plain Python before it is
        # wrapped into one angle per image
        if rotation_degrees < -90 or rotation_degrees > 90:
            raise ValueError("Rotation angles should be between -90 and 90 degrees")
        rotation_degrees = nd.array([rotation_degrees] * num_images,
                                    ctx=src.context)
    elif ((rotation_degrees < -90).sum().asscalar() > 0.0 or
          (rotation_degrees > 90).sum().asscalar() > 0.0):
        raise ValueError("Rotation angles should be between -90 and 90 degrees")

    if num_images != len(rotation_degrees):
        raise ValueError(
            "The number of images must be equal to the number of rotation angles"
        )

    # `src` is float32, so keep the angles (and therefore the sampling grid
    # derived from them) in float32 too; this also prevents integer angle
    # arrays from being used as-is in the radian conversion below.
    rotation_degrees = rotation_degrees.as_in_context(src.context).astype('float32')
    rotation_rad = np.pi * rotation_degrees / 180
    # reshape the rotations angle in order to be broadcasted
    # over the `src` tensor
    rotation_rad = rotation_rad.expand_dims(axis=1).expand_dims(axis=2)

    # Generate a grid centered at the center of the image
    hscale = (float(h - 1) / 2)
    wscale = (float(w - 1) / 2)
    h_matrix = (
        nd.repeat(nd.arange(h, ctx=src.context).astype('float32').reshape(h, 1), w, axis=1) - hscale
    ).expand_dims(axis=0)
    w_matrix = (
        nd.repeat(nd.arange(w, ctx=src.context).astype('float32').reshape(1, w), h, axis=0) - wscale
    ).expand_dims(axis=0)
    # perform rotation on the grid
    c_alpha = nd.cos(rotation_rad)
    s_alpha = nd.sin(rotation_rad)
    w_matrix_rot = w_matrix * c_alpha - h_matrix * s_alpha
    h_matrix_rot = w_matrix * s_alpha + h_matrix * c_alpha
    # NOTE: grid normalization must be performed after the rotation
    # to keep the aspect ratio
    w_matrix_rot = w_matrix_rot / wscale
    h_matrix_rot = h_matrix_rot / hscale

    # compute the scale factor in case `zoom_in` or `zoom_out` are True
    if zoom_in or zoom_out:
        # Image geometry is plain Python scalars: only the per-image angle
        # needs to be an NDArray. `float(...)` keeps NumPy scalar types out
        # of the NDArray arithmetic.
        diagonal = float(np.hypot(h, w))
        diagonal_angle = float(np.arctan2(h, w))  # == arctan(h / w)
        abs_rad = nd.abs(rotation_rad)
        if h < w:
            roth = diagonal * nd.sin(diagonal_angle + abs_rad)
        else:
            roth = diagonal * nd.cos(diagonal_angle - abs_rad)
        short_side = float(min(h, w))
        if zoom_in:
            globalscale = short_side / roth
        else:
            globalscale = roth / short_side
        globalscale = globalscale.expand_dims(axis=3)
    else:
        globalscale = 1
    grid = nd.concat(w_matrix_rot.expand_dims(axis=1),
                     h_matrix_rot.expand_dims(axis=1), dim=1)
    grid = grid * globalscale
    rot_img = nd.BilinearSampler(src, grid)
    if expanded:
        return rot_img[0]
    return rot_img


def random_rotate(src, angle_limits, zoom_in=False,
                  zoom_out=False, rotate_with_proba=1.0):
    """Random rotates `src` by an angle included in angle limits.

    Parameters
    ----------
    src : NDArray
        Input image (format CHW) or batch of images (format NCHW),
        in both cases a float32 data type is required.
    angle_limits: tuple
        Tuple of 2 elements ``(lower, upper)`` containing the lower and upper
        limit for rotation angles in degree.
    zoom_in: bool
        If True input image(s) will be zoomed in a way so that no padding
        will be shown in the output result.
    zoom_out: bool
        If True input image(s) will be zoomed in a way so that the whole
        original image will be contained in the output result.
    rotate_with_proba: float in [0., 1]
        Probability of rotating the image. A single draw decides for the
        whole call: a batch is either rotated entirely (with one random
        angle per image) or returned untouched.

    Returns
    -------
    NDArray
        An `NDArray` containing the rotated image(s), or `src` itself when
        the random draw decides not to rotate.

    Raises
    ------
    ValueError
        If `rotate_with_proba` is outside [0, 1].
    """
    if rotate_with_proba < 0 or rotate_with_proba > 1:
        raise ValueError('Probability of rotating the image should be between 0 and 1')
    # np.random.random() draws from [0, 1): with `>=` a probability of 0 never
    # rotates and a probability of 1 always does.
    if np.random.random() >= rotate_with_proba:
        return src
    if src.ndim == 3:
        rotation_degrees = np.random.uniform(*angle_limits)
    else:
        n = src.shape[0]
        # one angle per image, created directly on the images' context
        rotation_degrees = nd.array(
            np.random.uniform(*angle_limits, size=n),
            ctx=src.context, dtype=np.float32)
    return imrotate(src, rotation_degrees,
                    zoom_in=zoom_in, zoom_out=zoom_out)


class Augmenter(object):
"""Image Augmenter base class"""
def __init__(self, **kwargs):
Expand Down
28 changes: 28 additions & 0 deletions tests/python/unittest/test_gluon_data_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,39 @@ def test_transformer():
transforms.RandomHue(0.1),
transforms.RandomLighting(0.1),
transforms.ToTensor(),
transforms.RandomRotation([-10., 10.]),
transforms.Normalize([0, 0, 0], [1, 1, 1])])

transform(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read()


@with_seed()
def test_rotate():
    """Check `transforms.Rotate` argument validation, dtype check and output shapes."""
    # angles outside [-90, 90] are rejected when the transform is built
    for out_of_range in (-100., 100.):
        assertRaises(ValueError, transforms.Rotate, out_of_range)

    transformer = transforms.Rotate(10.)
    # only float32 input is accepted
    uint8_image = mx.nd.ones((3, 30, 60), dtype='uint8')
    assertRaises(TypeError, transformer, uint8_image)

    # the shape is preserved for a single CHW image and for an NCHW batch
    for shape in ((3, 30, 60), (3, 3, 30, 60)):
        rotated = transformer(mx.nd.ones(shape, dtype='float32'))
        assert same(rotated.shape, shape)


@with_seed()
def test_random_rotation():
    """Check `transforms.RandomRotation` argument validation, dtype check and output shapes."""
    # limits outside [-90, 90]
    assertRaises(ValueError, transforms.RandomRotation, [-100., 100.])
    assertRaises(ValueError, transforms.RandomRotation, [100., -100])
    # In-range limits that are reversed or equal: unlike the case above these
    # pass the range check, so they actually exercise the ordering check.
    assertRaises(ValueError, transforms.RandomRotation, [10., -10.])
    assertRaises(ValueError, transforms.RandomRotation, [5., 5.])
    # rotation probability outside [0, 1] (zoom flags passed positionally)
    assertRaises(ValueError, transforms.RandomRotation, [-10., 10.], False, False, 1.1)
    assertRaises(ValueError, transforms.RandomRotation, [-10., 10.], False, False, -0.3)

    transformer = transforms.RandomRotation([-10, 10.])
    # only float32 input is accepted
    assertRaises(TypeError, transformer, mx.nd.ones((3, 30, 60), dtype='uint8'))
    # the shape is preserved for a single CHW image and for an NCHW batch
    single_image = mx.nd.ones((3, 30, 60), dtype='float32')
    single_output = transformer(single_image)
    assert same(single_output.shape, (3, 30, 60))
    batch_image = mx.nd.ones((3, 3, 30, 60), dtype='float32')
    batch_output = transformer(batch_image)
    assert same(batch_output.shape, (3, 3, 30, 60))


if __name__ == '__main__':
import nose
Expand Down
94 changes: 94 additions & 0 deletions tests/python/unittest/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import mxnet as mx
import numpy as np
import scipy.ndimage
from mxnet.test_utils import *
from common import assertRaises, with_seed
import shutil
Expand Down Expand Up @@ -369,6 +370,99 @@ def test_random_size_crop(self):
assert ratio[0] - epsilon <= float(new_w)/new_h <= ratio[1] + epsilon, \
'ration of new width and height out of the bound{}/{}={}'.format(new_w, new_h, float(new_w)/new_h)

@with_seed()
def test_imrotate(self):
    """Check `mx.image.imrotate` against scipy and exercise its argument handling."""
    # every legal (zoom_in, zoom_out) pair, in the order they are exercised
    zoom_modes = ((False, False), (False, True), (True, False))
    # cannot compare the edges (where image ends) because of different
    # behavior, so only an interior window is checked
    interior = (slice(10, 20), slice(20, 40), slice(None))

    # test correctness on a smooth ramp image (HWC for scipy, CHW for mxnet)
    row_ramp = np.expand_dims(np.linspace(0, 0.5, 30), axis=1)
    col_ramp = np.expand_dims(np.linspace(0, 0.5, 60), axis=0)
    np_img = np.expand_dims(row_ramp + col_ramp, axis=2)
    nd_img = mx.nd.array(np_img.transpose((2, 0, 1)))
    rot_angle = 6
    nd_rot = mx.image.imrotate(src=nd_img, rotation_degrees=rot_angle,
                               zoom_in=False, zoom_out=False)
    npnd_rot = nd_rot.asnumpy().transpose((1, 2, 0))
    scipy_rot = scipy.ndimage.rotate(np_img, rot_angle, axes=(1, 0), reshape=False,
                                     order=1, mode='constant', prefilter=False)
    assert_almost_equal(scipy_rot[interior], npnd_rot[interior])

    # test if execution raises exceptions in any allowed mode
    # batch mode: one angle per image
    batch = mx.nd.random.uniform(0, 1, (5, 3, 30, 60), dtype=np.float32)
    batch_angles = mx.nd.array([1, 2, 3, 4, 5], dtype=np.float32)
    for zoom_in, zoom_out in zoom_modes:
        _ = mx.image.imrotate(src=batch, rotation_degrees=batch_angles,
                              zoom_in=zoom_in, zoom_out=zoom_out)
    # single image mode: scalar angle
    scalar_angle = 11
    single = mx.nd.random.uniform(0, 1, (3, 30, 60), dtype=np.float32)
    for zoom_in, zoom_out in zoom_modes:
        _ = mx.image.imrotate(src=single, rotation_degrees=scalar_angle,
                              zoom_in=zoom_in, zoom_out=zoom_out)

    # test if exceptions are correctly raised
    # batch exception - zoom_in=zoom_out=True
    batch = mx.nd.random.uniform(0, 1, (5, 3, 30, 60), dtype=np.float32)
    batch_angles = mx.nd.array([1, 2, 3, 4, 5], dtype=np.float32)
    self.assertRaises(ValueError, mx.image.imrotate,
                      src=batch, rotation_degrees=batch_angles,
                      zoom_in=True, zoom_out=True)

    # single image exception - zoom_in=zoom_out=True
    single = mx.nd.random.uniform(0, 1, (3, 30, 60), dtype=np.float32)
    scalar_angle = 11
    self.assertRaises(ValueError, mx.image.imrotate,
                      src=single, rotation_degrees=scalar_angle,
                      zoom_in=True, zoom_out=True)

    # batch of images with scalar rotation: every image must match scipy
    stacked = mx.nd.stack(nd_img, nd_img, nd_img)
    out = mx.image.imrotate(src=stacked, rotation_degrees=6,
                            zoom_in=False, zoom_out=False)
    for rotated in out:
        rotated_hwc = rotated.asnumpy().transpose((1, 2, 0))
        assert_almost_equal(scipy_rot[interior], rotated_hwc[interior])

    # single image exception - single image with vector rotation
    single = mx.nd.random.uniform(0, 1, (3, 30, 60), dtype=np.float32)
    batch_angles = mx.nd.array([1, 2, 3, 4, 5], dtype=np.float32)
    self.assertRaises(TypeError, mx.image.imrotate,
                      src=single, rotation_degrees=batch_angles,
                      zoom_in=False, zoom_out=False)

@with_seed()
def test_random_rotate(self):
    """Check output shapes and probability validation of `mx.image.random_rotate`."""
    angle_limits = [-5., 5.]

    # single CHW image keeps its shape
    single = mx.nd.random.uniform(0, 1, (3, 30, 60), dtype=np.float32)
    rotated_single = mx.image.random_rotate(single, angle_limits)
    self.assertEqual(rotated_single.shape, (3, 30, 60))

    # NCHW batch keeps its shape
    batch = mx.nd.stack(single, single, single)
    rotated_batch = mx.image.random_rotate(batch, angle_limits)
    self.assertEqual(rotated_batch.shape, (3, 3, 30, 60))

    # test exceptions for probability input outside of [0,1]
    for bad_proba in (1.1, -0.3):
        self.assertRaises(ValueError, mx.image.random_rotate,
                          src=single,
                          angle_limits=angle_limits,
                          zoom_in=False, zoom_out=False,
                          rotate_with_proba=bad_proba)


if __name__ == '__main__':
import nose
Expand Down

0 comments on commit 0e68b29

Please sign in to comment.