Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Op Unravel_index PR [Numpy] (#16862)
Browse files Browse the repository at this point in the history
* unravel_index pr

* fix line too long
  • Loading branch information
Tommliu authored and haojin2 committed Nov 25, 2019
1 parent 6bff547 commit f11592d
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 15 deletions.
48 changes: 43 additions & 5 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax',
'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity',
'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory',
'diff', 'resize', 'nan_to_num', 'where']
'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman',
'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril',
'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory',
'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where']

@set_module('mxnet.ndarray.numpy')
def shape(a):
Expand Down Expand Up @@ -3828,6 +3828,44 @@ def ravel(x, order='C'):
raise TypeError('type {} not supported'.format(str(type(x))))


def unravel_index(indices, shape, order='C'): # pylint: disable=redefined-outer-name
"""
Converts a flat index or array of flat indices into a tuple of coordinate arrays.
Parameters:
-------------
indices : array_like
An integer array whose elements are indices into the flattened version of an array of dimensions shape.
Before version 1.6.0, this function accepted just one index value.
shape : tuple of ints
The shape of the array to use for unraveling indices.
Returns:
-------------
unraveled_coords : ndarray
Each row in the ndarray has the same shape as the indices array.
Each column in the ndarray represents the unravelled index
Examples:
-------------
>>> np.unravel_index([22, 41, 37], (7,6))
([3. 6. 6.]
[4. 5. 1.])
>>> np.unravel_index(1621, (6,7,8,9))
(3, 1, 4, 1)
"""
if order == 'C':
if isinstance(indices, numeric_types):
return _np.unravel_index(indices, shape)
ret = _npi.unravel_index_fallback(indices, shape=shape)
ret_list = []
for item in ret:
ret_list += [item]
return tuple(ret_list)
else:
raise NotImplementedError('Do not support column-major (Fortran-style) order at this moment')


@set_module('mxnet.ndarray.numpy')
def hanning(M, dtype=_np.float32, ctx=None):
r"""Return the Hanning window.
Expand Down
40 changes: 35 additions & 5 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append',
'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange',
'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming',
'blackman', 'flip', 'around', 'arctan2', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal',
'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero',
'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where']
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'bitwise_xor', 'bitwise_or',
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum',
'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -5759,6 +5759,36 @@ def ravel(x, order='C'):
return _mx_nd_np.ravel(x, order)


def unravel_index(indices, shape, order='C'): # pylint: disable=redefined-outer-name
"""
Converts a flat index or array of flat indices into a tuple of coordinate arrays.
Parameters:
-------------
indices : array_like
An integer array whose elements are indices into the flattened version of an array of dimensions shape.
Before version 1.6.0, this function accepted just one index value.
shape : tuple of ints
The shape of the array to use for unraveling indices.
order : Only row-major is supported currently.
Returns:
-------------
unraveled_coords : ndarray
Each row in the ndarray has the same shape as the indices array.
Each column in the ndarray represents the unravelled index
Examples:
-------------
>>> np.unravel_index([22, 41, 37], (7,6))
[[3. 6. 6.]
[4. 5. 1.]]
>>> np.unravel_index(1621, (6,7,8,9))
[3, 1, 4, 1]
"""
return _mx_nd_np.unravel_index(indices, shape, order=order)


@set_module('mxnet.numpy')
def hanning(M, dtype=_np.float32, ctx=None):
r"""Return the Hanning window.
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'tile',
'transpose',
'unique',
'unravel_index',
'var',
'vdot',
'vstack',
Expand Down
34 changes: 34 additions & 0 deletions python/mxnet/numpy_op_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,37 @@ def infer_shape(self, in_shape):

def create_operator(self, ctx, in_shapes, in_dtypes):
return Resize(self._new_shape)


@use_np
class Unravel_index(operator.CustomOp):
"""Fallback to NumPy Unravel_index operator."""
def __init__(self, shape):
super(Unravel_index, self).__init__()
self._shape = shape

def forward(self, is_train, req, in_data, out_data, aux):
out = np.unravel_index(in_data[0].asnumpy(), self._shape)
self.assign(out_data[0], req[0], _mx_np.array(out, dtype=out[0].dtype, ctx=out_data[0].ctx))

def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
raise NotImplementedError('Operator Unravel_index does not support gradient computation')


@register('unravel_index_fallback')
class Unravel_indexProp(operator.CustomOpProp):
"""Fallback unravel_index operator properties."""
def __init__(self, shape):
super(Unravel_indexProp, self).__init__(need_top_grad=True)
self._shape = ast.literal_eval(shape)

def list_arguments(self):
return ['indices']

def infer_shape(self, in_shape):
dim_list = (1,) if np.isscalar(self._shape) else (len(self._shape),)
out_shape = dim_list + tuple(in_shape[0])
return (in_shape[0],), (out_shape,), ()

def create_operator(self, ctx, in_shapes, in_dtypes):
return Unravel_index(self._shape)
42 changes: 37 additions & 5 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax',
'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity',
'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff',
'resize', 'nan_to_num', 'where']
'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman',
'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril',
'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory',
'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where']

def _num_outputs(sym):
return len(sym.as_nd_ndarray())
Expand Down Expand Up @@ -3657,6 +3657,38 @@ def ravel(x, order='C'):
raise TypeError('type {} not supported'.format(str(type(x))))


def unravel_index(indices, shape, order='C'): # pylint: disable=redefined-outer-name
"""
Converts a flat index or array of flat indices into a tuple of coordinate arrays.
Parameters:
-------------
indices : array_like
An integer array whose elements are indices into the flattened version of an array of dimensions shape.
Before version 1.6.0, this function accepted just one index value.
shape : tuple of ints
The shape of the array to use for unraveling indices.
Returns:
-------------
unraveled_coords : ndarray
Each row in the ndarray has the same shape as the indices array.
Each column in the ndarray represents the unravelled index
Examples:
-------------
>>> np.unravel_index([22, 41, 37], (7,6))
([3. 6. 6.]
[4. 5. 1.])
>>> np.unravel_index(1621, (6,7,8,9))
(3, 1, 4, 1)
"""
if order == 'C':
return _npi.unravel_index_fallback(indices, shape=shape)
else:
raise NotImplementedError('Don not support column-major (Fortran-style) order at this moment')


@set_module('mxnet.symbol.numpy')
def hanning(M, dtype=_np.float32, ctx=None):
r"""Return the Hanning window.
Expand Down
10 changes: 10 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ def get_workloads(name):
return OpArgMngr._args.get(name, None)


def _add_workload_unravel_index():
OpArgMngr.add_workload('unravel_index', indices=np.array([2],dtype=_np.int64), shape=(2, 2))
OpArgMngr.add_workload('unravel_index', np.array([(2*3 + 1)*6 + 4], dtype=_np.int64), (4, 3, 6))
OpArgMngr.add_workload('unravel_index', np.array([22, 41, 37], dtype=_np.int32), (7, 6))
OpArgMngr.add_workload('unravel_index', np.array([1621],dtype=_np.uint8), (6, 7, 8, 9))
OpArgMngr.add_workload('unravel_index', np.array([],dtype=_np.int64), (10, 3, 5))
OpArgMngr.add_workload('unravel_index', np.array([3], dtype=_np.int32), (2,2))


def _add_workload_diag():
def get_mat(n):
data = _np.arange(n)
Expand Down Expand Up @@ -1310,6 +1319,7 @@ def _prepare_workloads():
_add_workload_copy()
_add_workload_cumsum()
_add_workload_ravel()
_add_workload_unravel_index()
_add_workload_diag()
_add_workload_diagflat()
_add_workload_dot()
Expand Down
52 changes: 52 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4998,6 +4998,58 @@ def hybrid_forward(self, F, x):
assert_almost_equal(y.asnumpy(), expected, use_broadcast=False)


@with_seed()
@use_np
def test_np_unravel_index():
class TestUnravel_index(HybridBlock):
def __init__(self, shape, order='C') :
super(TestUnravel_index, self).__init__()
self._shape = shape
self._order = order

def hybrid_forward(self, F, a):
return F.np.unravel_index(a, self._shape, self._order)

in_shapes = [
2, 5,
(), (1,), (4,),
(2, 2), (2, 4), (3, 5),
(2, 2, 2), (2, 3, 2), (2, 3, 4),
]
unravel_shapes = [
10, (15,),
(3, 4), (4, 5),
(2,3,4)
]
dtypes = [np.uint8, np.int8, np.int32, np.int64]
for hybridize, ishape, dtype, rshape in itertools.product([False, True], in_shapes, dtypes, unravel_shapes):
rtol = 1e-2 if dtype == np.float16 else 1e-3
atol = 1e-4 if dtype == np.float16 else 1e-5
test_unravel_index = TestUnravel_index(rshape)
if hybridize:
test_unravel_index.hybridize()
if type(ishape) == int and hybridize:
x = np.array([ishape], dtype=dtype)
np_out = _np.unravel_index(x.asnumpy(), rshape)
else:
x = np.random.uniform(0, 8, size=ishape).astype(dtype)
np_out = _np.unravel_index(x.asnumpy(), rshape)
mx_out = test_unravel_index(x)
assert len(mx_out) == len(np_out)
for elem_mx, elem_np in zip(mx_out, np_out):
assert elem_mx.asnumpy().shape == elem_np.shape
assert_almost_equal(elem_mx.asnumpy(), elem_np, rtol=rtol, atol=atol)
# no backward function for unravel_index operator

# Test imperative once again
mx_out = np.unravel_index(x, rshape)
np_out = _np.unravel_index(x.asnumpy(), rshape)
assert len(mx_out) == len(np_out)
for elem_mx, elem_np in zip(mx_out, np_out):
assert elem_mx.asnumpy().shape == elem_np.shape
assert_almost_equal(elem_mx.asnumpy(), elem_np, rtol=rtol, atol=atol)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit f11592d

Please sign in to comment.