Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] op empty_like, add nan_to_num to dispatch #17169

Merged
merged 1 commit into from
Dec 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 73 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from . import _internal as _npi
from ..ndarray import NDArray

__all__ = ['shape', 'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'invert', 'delete',
__all__ = ['shape', 'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'empty_like', 'invert', 'delete',
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'bitwise_not',
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs',
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
Expand Down Expand Up @@ -372,6 +372,78 @@ def full_like(a, fill_value, dtype=None, order='C', ctx=None, out=None): # pylin
return _npi.full_like(a, fill_value=fill_value, dtype=dtype, ctx=ctx, out=out)


@set_module('mxnet.ndarray.numpy')
def empty_like(prototype, dtype=None, order='C', subok=False, shape=None): # pylint: disable=W0621
"""
Return a new array with the same shape and type as a given array.
Parameters
----------
prototype : ndarray
The shape and data-type of `prototype` define these same attributes
of the returned array.
dtype : data-type, optional
Overrides the data type of the result.
order : {'C'}, optional
Whether to store multidimensional data in C- or Fortran-contiguous
(row- or column-wise) order in memory. Currently only supports C order.
subok : {False}, optional
If True, then the newly created array will use the sub-class
type of 'a', otherwise it will be a base-class array. Defaults
to False.
(Only support False at this moment)
shape : int or sequence of ints, optional.
Overrides the shape of the result. If order='K' and the number of
dimensions is unchanged, will try to keep order, otherwise,
order='C' is implied.
(Not supported at this moment)
Returns
-------
out : ndarray
Array of uninitialized (arbitrary) data with the same
shape and type as `prototype`.
See Also
--------
ones_like : Return an array of ones with shape and type of input.
zeros_like : Return an array of zeros with shape and type of input.
full_like : Return a new array with shape of input filled with value.
empty : Return a new uninitialized array.
Notes
-----
This function does *not* initialize the returned array; to do that use
`zeros_like` or `ones_like` instead. It may be marginally faster than
the functions that do set the array values.
Examples
--------
>>> a = np.array([[1,2,3], [4,5,6]])
>>> np.empty_like(a)
array([[-5764607523034234880, -2305834244544065442, 4563075075], # uninitialized
[ 4567052944, -5764607523034234880, 844424930131968]])
>>> a = np.array([[1., 2., 3.],[4.,5.,6.]])
>>> np.empty_like(a)
array([[4.9e-324, 9.9e-324, 1.5e-323], # uninitialized
[2.0e-323, 2.5e-323, 3.0e-323]])
"""
dtype_list = {None:'None', _np.int8:'int8', _np.uint8:'uint8', _np.int32:'int32',
_np.int64:'int64', _np.float16:'float16', _np.float32:'float32',
_np.float64:'float64', _np.bool_:'bool'}
if order != 'C':
raise NotImplementedError("Only support C-order at this moment")
if subok:
raise NotImplementedError("Creating array by using sub-class is not supported at this moment")
if shape is not None:
raise NotImplementedError("Assigning new shape is not supported at this moment")
try:
dtype = dtype if isinstance(dtype, str) else dtype_list[dtype]
except:
raise NotImplementedError("Do not support this dtype at this moment")
return _npi.empty_like_fallback(prototype, dtype=dtype, order=order, subok=subok, shape=shape)
Alicia1529 marked this conversation as resolved.
Show resolved Hide resolved


@set_module('mxnet.ndarray.numpy')
def arange(start, stop=None, step=1, dtype=None, ctx=None):
"""Return evenly spaced values within a given interval.
Expand Down
62 changes: 61 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@
from ..ndarray.numpy import _internal as _npi
from ..ndarray.ndarray import _storage_type

__all__ = ['ndarray', 'empty', 'array', 'shape', 'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like',
__all__ = ['ndarray', 'empty', 'empty_like', 'array', 'shape',
'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like',
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'bitwise_not', 'delete',
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'invert',
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
Expand Down Expand Up @@ -2202,6 +2203,65 @@ def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None):
# pylint: enable=too-many-arguments, redefined-outer-name


@set_module('mxnet.numpy')
def empty_like(prototype, dtype=None, order='C', subok=False, shape=None): # pylint: disable=W0621
"""
Return a new array with the same shape and type as a given array.
Parameters
----------
prototype : ndarray
The shape and data-type of `prototype` define these same attributes
of the returned array.
dtype : data-type, optional
Overrides the data type of the result.
order : {'C'}, optional
Whether to store multidimensional data in C- or Fortran-contiguous
(row- or column-wise) order in memory. Currently only supports C order.
subok : {False}, optional
If True, then the newly created array will use the sub-class
type of 'a', otherwise it will be a base-class array. Defaults
to False.
(Only support False at this moment)
shape : int or sequence of ints, optional.
Overrides the shape of the result. If order='K' and the number of
dimensions is unchanged, will try to keep order, otherwise,
order='C' is implied.
(Not supported at this moment)
Returns
-------
out : ndarray
Array of uninitialized (arbitrary) data with the same
shape and type as `prototype`.
See Also
--------
ones_like : Return an array of ones with shape and type of input.
zeros_like : Return an array of zeros with shape and type of input.
full_like : Return a new array with shape of input filled with value.
empty : Return a new uninitialized array.
Notes
-----
This function does *not* initialize the returned array; to do that use
`zeros_like` or `ones_like` instead. It may be marginally faster than
the functions that do set the array values.
Examples
--------
>>> a = np.array([[1,2,3], [4,5,6]])
>>> np.empty_like(a)
array([[-5764607523034234880, -2305834244544065442, 4563075075], # uninitialized
[ 4567052944, -5764607523034234880, 844424930131968]])
>>> a = np.array([[1., 2., 3.],[4.,5.,6.]])
>>> np.empty_like(a)
array([[4.9e-324, 9.9e-324, 1.5e-323], # uninitialized
[2.0e-323, 2.5e-323, 3.0e-323]])
"""
return _mx_nd_np.empty_like(prototype, dtype=dtype, order=order, subok=subok, shape=shape)


@set_module('mxnet.numpy')
def identity(n, dtype=None, ctx=None):
"""
Expand Down
4 changes: 3 additions & 1 deletion python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'resize',
'where',
'full_like',
'bincount'
'bincount',
'empty_like',
'nan_to_num',
]


Expand Down
44 changes: 44 additions & 0 deletions python/mxnet/numpy_op_fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Fallback-to-NumPy operator implementation."""

from __future__ import absolute_import
from distutils.version import StrictVersion
import functools
import ast
import numpy as np
Expand Down Expand Up @@ -49,6 +50,49 @@ def _register_helper(prop_cls):
return _register_helper


@use_np # enforce np shape and array semantics for all the methods in this class
class EmptyLike(operator.CustomOp):
"""Fallback to NumPy empty_like operator."""
def __init__(self, dtype, order, subok, shape):
super(EmptyLike, self).__init__()
self._dtype = dtype
self._order = order
self._subok = subok
self._shape = shape

def forward(self, is_train, req, in_data, out_data, aux):
np_version = np.version.version
if StrictVersion(np_version) >= StrictVersion('1.6.0'):
out = np.empty_like(in_data[0].asnumpy(), dtype=self._dtype, order=self._order,
subok=self._subok)
else:
out = np.empty_like(in_data[0].asnumpy())
self.assign(out_data[0], req[0], _mx_np.array(out, dtype=out.dtype, ctx=out_data[0].ctx))

def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
raise NotImplementedError('Operator empty_like does not support gradient computation')


@register('empty_like_fallback')
class EmptyLikeProp(operator.CustomOpProp):
"""Fallback empty_like operator properties."""
def __init__(self, dtype, order, subok, shape):
super(EmptyLikeProp, self).__init__(need_top_grad=True)
self._dtype = None if dtype == 'None' else dtype
self._order = order
self._subok = ast.literal_eval(subok)
self._shape = ast.literal_eval(shape)

def list_arguments(self):
return ['prototype']

def infer_shape(self, in_shape):
return (in_shape[0],), (in_shape[0],), ()

def create_operator(self, ctx, in_shapes, in_dtypes):
return EmptyLike(self._dtype, self._order, self._subok, self._shape)


@use_np # enforce np shape and array semantics for all the methods in this class
class Resize(operator.CustomOp):
"""Fallback to NumPy resize operator."""
Expand Down
63 changes: 62 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
except ImportError:
from builtins import slice as py_slice

__all__ = ['zeros', 'zeros_like', 'ones', 'ones_like', 'full_like', 'bitwise_not', 'invert', 'delete',
__all__ = ['zeros', 'zeros_like', 'ones', 'ones_like', 'full_like', 'empty_like', 'bitwise_not', 'invert', 'delete',
'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2',
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp',
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
Expand Down Expand Up @@ -1719,6 +1719,67 @@ def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
return _npi.eye(N, M, k, ctx, dtype)


@set_module('mxnet.symbol.numpy')
def empty_like(prototype, dtype=None, order='C', subok=False, shape=None): # pylint: disable=W0621
"""
Return a new array with the same shape and type as a given array.
Parameters
----------
prototype : _Symbol
The shape and data-type of `prototype` define these same attributes
of the returned array.
dtype : data-type, optional
Overrides the data type of the result.
order : {'C'}, optional
Whether to store multidimensional data in C- or Fortran-contiguous
(row- or column-wise) order in memory. Currently only supports C order.
subok : bool, optional.
If True, then the newly created array will use the sub-class
type of 'a', otherwise it will be a base-class array. Defaults
to False.
(Only support False at this moment)
shape : int or sequence of ints, optional.
Overrides the shape of the result. If order='K' and the number of
dimensions is unchanged, will try to keep order, otherwise,
order='C' is implied.
(This parameter is not supported at this moment)
Returns
-------
out : _Symbol
Array of uninitialized (arbitrary) data with the same
shape and type as `prototype`.
See Also
--------
ones_like : Return an array of ones with shape and type of input.
zeros_like : Return an array of zeros with shape and type of input.
full_like : Return a new array with shape of input filled with value.
empty : Return a new uninitialized array.
Notes
-----
This function does *not* initialize the returned array; to do that use
`zeros_like` or `ones_like` instead. It may be marginally faster than
the functions that do set the array values.
"""
dtype_list = {None:'None', _np.int8:'int8', _np.uint8:'uint8', _np.int32:'int32',
_np.int64:'int64', _np.float16:'float16', _np.float32:'float32',
_np.float64:'float64', _np.bool_:'bool'}
if order != 'C':
raise NotImplementedError("Only support C order at this moment")
if subok:
raise NotImplementedError("Creating array by using sub-class is not supported at this moment")
if shape is not None:
raise NotImplementedError("Parameter 'shape' is not supported at this moment")
try:
dtype = dtype if isinstance(dtype, str) else dtype_list[dtype]
except:
raise NotImplementedError("Do not support this dtype at this moment")
return _npi.empty_like_fallback(prototype, dtype=dtype, order=order, subok=subok, shape=shape)


@set_module('mxnet.symbol.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
r"""
Expand Down
22 changes: 21 additions & 1 deletion tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -1555,6 +1555,24 @@ def _add_workload_resize():
OpArgMngr.add_workload('resize', np.zeros((10, 0)), (0, 10))
OpArgMngr.add_workload('resize', np.zeros((10, 0)), (0, 100))

def _add_workload_empty_like():
OpArgMngr.add_workload('empty_like', np.random.uniform(low=0, high=100, size=(1,3,4), dtype='float64'))
OpArgMngr.add_workload('empty_like', np.random.uniform(low=0, high=100, size=(9,3,1)), np.int32)
OpArgMngr.add_workload('empty_like', np.random.uniform(low=0, high=100, size=(9,3)), 'float32')
OpArgMngr.add_workload('empty_like', np.random.uniform(low=0, high=100, size=(9,3,1)), np.bool_)
OpArgMngr.add_workload('empty_like', np.random.uniform(low=0, high=100, size=(0,3)), np.float32)


def _add_workload_nan_to_num():
array1 = np.array([[-433, 0, 456, _np.inf], [-1, -_np.inf, 0, 1]])
array2 = np.array([_np.nan, _np.inf, -_np.inf, -574, 0, 23425, 24234,-5])
array3 = np.array(-_np.inf)
OpArgMngr.add_workload('nan_to_num', array1, True, 0, 100, -100)
OpArgMngr.add_workload('nan_to_num', array1, True, 0.00)
OpArgMngr.add_workload('nan_to_num', array2, True)
OpArgMngr.add_workload('nan_to_num', array2, True, -2000, 10000, -10000)
OpArgMngr.add_workload('nan_to_num', array3, True)


@use_np
def _prepare_workloads():
Expand Down Expand Up @@ -1691,6 +1709,8 @@ def _prepare_workloads():
_add_workload_diff()
_add_workload_resize()
_add_workload_full_like(array_pool)
_add_workload_empty_like()
_add_workload_nan_to_num()


_prepare_workloads()
Expand Down Expand Up @@ -1735,7 +1755,7 @@ def check_interoperability(op_list):
for name in op_list:
if name in _TVM_OPS and not is_op_runnable():
continue
if name in ['shares_memory', 'may_share_memory']: # skip list
if name in ['shares_memory', 'may_share_memory', 'empty_like']: # skip list
continue
if name in ['full_like', 'zeros_like', 'ones_like'] and \
StrictVersion(platform.python_version()) < StrictVersion('3.0.0'):
Expand Down
Loading