diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index f9164855dfe9..80a99148d21c 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -28,7 +28,7 @@ from . import _internal as _npi from ..ndarray import NDArray -__all__ = ['zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', +__all__ = ['zeros', 'ones', 'empty_like', 'full', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', @@ -110,6 +110,74 @@ def ones(shape, dtype=_np.float32, order='C', ctx=None): return _npi.ones(shape=shape, ctx=ctx, dtype=dtype) +@set_module('mxnet.ndarray.numpy') +def empty_like(prototype, dtype=None, order='K', subok=True, shape=None): + """ + Return a new array with the same shape and type as a given array. + + Parameters + ---------- + prototype : ndarray + The shape and data-type of `prototype` define these same attributes + of the returned array. + dtype : data-type, optional + Overrides the data type of the result. + + order : {'C', 'F', 'A', or 'K'}, optional + Overrides the memory layout of the result. 'C' means C-order, + 'F' means F-order, 'A' means 'F' if ``prototype`` is Fortran + contiguous, 'C' otherwise. 'K' means match the layout of ``prototype`` + as closely as possible. + + subok : bool, optional. + If True, then the newly created array will use the sub-class + type of 'a', otherwise it will be a base-class array. Defaults + to True. + shape : int or sequence of ints, optional. + Overrides the shape of the result. If order='K' and the number of + dimensions is unchanged, will try to keep order, otherwise, + order='C' is implied. + (Not supported at this moment) + + Returns + ------- + out : ndarray + Array of uninitialized (arbitrary) data with the same + shape and type as `prototype`. + + See Also + -------- + ones_like : Return an array of ones with shape and type of input. + zeros_like : Return an array of zeros with shape and type of input. + full_like : Return a new array with shape of input filled with value. + empty : Return a new uninitialized array. + + Notes + ----- + This function does *not* initialize the returned array; to do that use + `zeros_like` or `ones_like` instead. It may be marginally faster than + the functions that do set the array values. + + Examples + -------- + >>> a = np.array([[1,2,3], [4,5,6]]) + >>> np.empty_like(a) + array([[-5764607523034234880, -2305834244544065442, 4563075075], # uninitialized + [ 4567052944, -5764607523034234880, 844424930131968]]) + >>> a = np.array([[1., 2., 3.],[4.,5.,6.]]) + >>> np.empty_like(a) + array([[4.9e-324, 9.9e-324, 1.5e-323], # uninitialized + [2.0e-323, 2.5e-323, 3.0e-323]]) + """ + dtype_list = {None:'None', _np.int8:'int8', _np.uint8:'uint8', _np.int32:'int32', + _np.int64:'int64', _np.float16:'float16', _np.float32:'float32', + _np.float64:'float64', _np.bool_:'bool'} + try: + dtype = dtype if isinstance(dtype, str) else dtype_list[dtype] + except: + raise NotImplementedError("Do not support this dtype at this moment") + return _npi.empty_like_fallback(prototype, dtype=dtype, order=order, subok=subok, shape=shape) + @set_module('mxnet.ndarray.numpy') def full(shape, fill_value, dtype=None, order='C', ctx=None, out=None): # pylint: disable=too-many-arguments """ diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index b6816d75a98e..a6b7a78f4b08 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -49,7 +49,7 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', - 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', + 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'empty_like', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', 'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', @@ -1914,6 +1914,68 @@ def tostype(self, stype): raise AttributeError('mxnet.numpy.ndarray object has no attribute tostype') +@set_module('mxnet.numpy') +def empty_like(prototype, dtype=None, order='K', subok=True, shape=None): + """ + Return a new array with the same shape and type as a given array. + + Parameters + ---------- + prototype : ndarray + The shape and data-type of `prototype` define these same attributes + of the returned array. + dtype : data-type, optional + Overrides the data type of the result. + + order : {'C', 'F', 'A', or 'K'}, optional + Overrides the memory layout of the result. 'C' means C-order, + 'F' means F-order, 'A' means 'F' if ``prototype`` is Fortran + contiguous, 'C' otherwise. 'K' means match the layout of ``prototype`` + as closely as possible. + + subok : bool, optional. + If True, then the newly created array will use the sub-class + type of 'a', otherwise it will be a base-class array. Defaults + to True. + shape : int or sequence of ints, optional. + Overrides the shape of the result. If order='K' and the number of + dimensions is unchanged, will try to keep order, otherwise, + order='C' is implied. + (Not supported at this moment) + + Returns + ------- + out : ndarray + Array of uninitialized (arbitrary) data with the same + shape and type as `prototype`. + + See Also + -------- + ones_like : Return an array of ones with shape and type of input. + zeros_like : Return an array of zeros with shape and type of input. + full_like : Return a new array with shape of input filled with value. + empty : Return a new uninitialized array. + + Notes + ----- + This function does *not* initialize the returned array; to do that use + `zeros_like` or `ones_like` instead. It may be marginally faster than + the functions that do set the array values. + + Examples + -------- + >>> a = np.array([[1,2,3], [4,5,6]]) + >>> np.empty_like(a) + array([[-5764607523034234880, -2305834244544065442, 4563075075], # uninitialized + [ 4567052944, -5764607523034234880, 844424930131968]]) + >>> a = np.array([[1., 2., 3.],[4.,5.,6.]]) + >>> np.empty_like(a) + array([[4.9e-324, 9.9e-324, 1.5e-323], # uninitialized + [2.0e-323, 2.5e-323, 3.0e-323]]) + """ + return _mx_nd_np.empty_like(prototype, dtype=dtype, order=order, subok=subok, shape=shape) + + @set_module('mxnet.numpy') def empty(shape, dtype=_np.float32, order='C', ctx=None): """Return a new array of given shape and type, without initializing entries. diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 025982cfc7a5..b16661c68666 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -136,6 +136,8 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'may_share_memory', 'diff', 'resize', + 'empty_like', + 'nan_to_num', ] diff --git a/python/mxnet/numpy_op_fallback.py b/python/mxnet/numpy_op_fallback.py index b98a211c7169..bd07da5a1979 100644 --- a/python/mxnet/numpy_op_fallback.py +++ b/python/mxnet/numpy_op_fallback.py @@ -18,6 +18,7 @@ """Fallback-to-NumPy operator implementation.""" from __future__ import absolute_import +from distutils.version import StrictVersion import functools import ast import numpy as np @@ -49,6 +50,49 @@ def _register_helper(prop_cls): return _register_helper +@use_np # enforce np shape and array semantics for all the methods in this class +class EmptyLike(operator.CustomOp): + """Fallback to NumPy empty_like operator.""" + def __init__(self, dtype, order, subok, shape): + super(EmptyLike, self).__init__() + self._dtype = dtype + self._order = order + self._subok = subok + self._shape = shape + + def forward(self, is_train, req, in_data, out_data, aux): + np_version = np.version.version + if StrictVersion(np_version) >= StrictVersion('1.6.0'): + out = np.empty_like(in_data[0].asnumpy(), dtype=self._dtype, order=self._order, + subok=self._subok) + else: + out = np.empty_like(in_data[0].asnumpy()) + self.assign(out_data[0], req[0], _mx_np.array(out, dtype=out.dtype, ctx=out_data[0].ctx)) + + def backward(self, req, out_grad, in_data, out_data, in_grad, aux): + raise NotImplementedError('Operator empty_like does not support gradient computation') + + +@register('empty_like_fallback') +class EmptyLikeProp(operator.CustomOpProp): + """Fallback empty_like operator properties.""" + def __init__(self, dtype, order, subok, shape): + super(EmptyLikeProp, self).__init__(need_top_grad=True) + self._dtype = None if dtype == 'None' else dtype + self._order = order + self._subok = ast.literal_eval(subok) + self._shape = ast.literal_eval(shape) + + def list_arguments(self): + return ['prototype'] + + def infer_shape(self, in_shape): + return (in_shape[0],), (in_shape[0],), () + + def create_operator(self, ctx, in_shapes, in_dtypes): + return EmptyLike(self._dtype, self._order, self._subok, self._shape) + + @use_np # enforce np shape and array semantics for all the methods in this class class Resize(operator.CustomOp): """Fallback to NumPy resize operator.""" diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 0d7303865b92..59c7def9d49c 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -34,7 +34,7 @@ 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', - 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye', + 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye', 'empty_like', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', @@ -1344,6 +1344,64 @@ def eye(N, M=None, k=0, dtype=_np.float32, **kwargs): return _npi.eye(N, M, k, ctx, dtype) +@set_module('mxnet.symbol.numpy') +def empty_like(prototype, dtype=None, order='K', subok=True, shape=None): + """ + Return a new array with the same shape and type as a given array. + + Parameters + ---------- + prototype : _Symbol + The shape and data-type of `prototype` define these same attributes + of the returned array. + dtype : data-type, optional + Overrides the data type of the result. + + order : {'C', 'F', 'A', or 'K'}, optional + Overrides the memory layout of the result. 'C' means C-order, + 'F' means F-order, 'A' means 'F' if ``prototype`` is Fortran + contiguous, 'C' otherwise. 'K' means match the layout of ``prototype`` + as closely as possible. + + subok : bool, optional. + If True, then the newly created array will use the sub-class + type of 'a', otherwise it will be a base-class array. Defaults + to True. + shape : int or sequence of ints, optional. + Overrides the shape of the result. If order='K' and the number of + dimensions is unchanged, will try to keep order, otherwise, + order='C' is implied. + (Not supported at this moment) + + Returns + ------- + out : _Symbol + Array of uninitialized (arbitrary) data with the same + shape and type as `prototype`. + + See Also + -------- + ones_like : Return an array of ones with shape and type of input. + zeros_like : Return an array of zeros with shape and type of input. + full_like : Return a new array with shape of input filled with value. + empty : Return a new uninitialized array. + + Notes + ----- + This function does *not* initialize the returned array; to do that use + `zeros_like` or `ones_like` instead. It may be marginally faster than + the functions that do set the array values. + """ + dtype_list = {None:'None', _np.int8:'int8', _np.uint8:'uint8', _np.int32:'int32', + _np.int64:'int64', _np.float16:'float16', _np.float32:'float32', + _np.float64:'float64', _np.bool_:'bool'} + try: + dtype = dtype if isinstance(dtype, str) else dtype_list[dtype] + except: + raise NotImplementedError("Do not support this dtype at this moment") + return _npi.empty_like_fallback(prototype, dtype=dtype, order=order, subok=subok, shape=shape) + + @set_module('mxnet.symbol.numpy') def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments r""" diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 6d9c63f9f857..ad8d364ea95d 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1172,6 +1172,24 @@ def _add_workload_resize(): OpArgMngr.add_workload('resize', np.zeros((10, 0)), (0, 10)) OpArgMngr.add_workload('resize', np.zeros((10, 0)), (0, 100)) +def _add_workload_empty_like(): + OpArgMngr.add_workload('empty_like', np.random.uniform(low=0, high=100, size=(1,3,4), dtype='float64')) + OpArgMngr.add_workload('empty_like', np.random.uniform(low=0, high=100, size=(9,3,1)), np.int32) + OpArgMngr.add_workload('empty_like', np.random.uniform(low=0, high=100, size=(9,3)), 'float32') + OpArgMngr.add_workload('empty_like', np.random.uniform(low=0, high=100, size=(9,3,1)), np.bool_, 'K', True) + OpArgMngr.add_workload('empty_like', np.random.uniform(low=0, high=100, size=(0,3)), np.float32) + + +def _add_workload_nan_to_num(): + array1 = np.array([[-433, 0, 456, _np.inf], [-1, -_np.inf, 0, 1]]) + array2 = np.array([_np.nan, _np.inf, -_np.inf, -574, 0, 23425, 24234,-5]) + array3 = np.array(-_np.inf) + OpArgMngr.add_workload('nan_to_num', array1, True, 0, 100, -100) + OpArgMngr.add_workload('nan_to_num', array1, True, 0.00) + OpArgMngr.add_workload('nan_to_num', array2, True) + OpArgMngr.add_workload('nan_to_num', array2, True, -2000, 10000, -10000) + OpArgMngr.add_workload('nan_to_num', array3, True) + @use_np def _prepare_workloads(): @@ -1284,6 +1302,8 @@ def _prepare_workloads(): _add_workload_less_equal(array_pool) _add_workload_diff() _add_workload_resize() + _add_workload_empty_like() + _add_workload_nan_to_num() _prepare_workloads() @@ -1328,7 +1348,7 @@ def check_interoperability(op_list): for name in op_list: if name in _TVM_OPS and not is_op_runnable(): continue - if name in ['shares_memory', 'may_share_memory']: # skip list + if name in ['shares_memory', 'may_share_memory', 'empty_like']: # skip list continue print('Dispatch test:', name) workloads = OpArgMngr.get_workloads(name) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index fb022e06158c..cd2cbeb26d34 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -17,6 +17,7 @@ # pylint: skip-file from __future__ import absolute_import +from distutils.version import StrictVersion import sys import unittest import itertools @@ -4363,7 +4364,6 @@ def hybrid_forward(self, F, x, *args, **kwargs): @with_seed() @use_np def test_np_nan_to_num(): - def take_ele_grad(ele): if _np.isinf(ele) or _np.isnan(ele): return 0 @@ -4463,6 +4463,65 @@ def hybrid_forward(self, F, a): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False) +@with_seed() +@use_np +def test_np_empty_like(): + class TestEmptyLike(HybridBlock): + def __init__(self, dtype, order, subok, shape): + super(TestEmptyLike, self).__init__() + self._dtype = dtype + self._order = order + self._subok = subok + self._shape = shape + + def hybrid_forward(self, F, x, *args, **kwargs): + return F.np.empty_like(x, self._dtype, self._order, self._subok, self._shape) + + if StrictVersion(platform.python_version()) < StrictVersion('3.0.0'): + return + + dtypes = [None, 'float16', 'float32', np.int8, np.uint8, np.int32, np.int64, + np.float16, np.float32, np.float64, np.bool_] + shapes = [ + (), + (1,), + (5,), + (4, 3), + (3, 5), + (4, 4), + (4, 5), + (5, 5), + (5, 6), + (6, 6), + (0, 1), + (6, 5, 6), + (2, 3, 3, 4), + (4, 2, 1, 2), + (0, 5, 3, 3), + (5, 0, 3, 3), + (3, 3, 0, 0), + ] + orders = ["C", "F", "A", "K"] + subok_list = [True, False] + flags = [True, False] + _np_version = _np.version.version + for dtype, shape, hybridize, order, subok in itertools.product(dtypes, shapes, flags, orders, subok_list): + prototype = np.random.uniform(low=0, high=100, size=shape, dtype='float64').astype(dtype) + test = TestEmptyLike(dtype, order, subok, shape) + if StrictVersion(_np_version) >= StrictVersion('1.6.0'): + expected_ret = _np.empty_like(prototype, dtype=dtype, order=order, subok=subok) + else: + expected_ret = _np.empty_like(prototype) + if hybridize: + test.hybridize() + ret = test(prototype) + assert ret.asnumpy().shape == expected_ret.shape + + # check imperative again + ret = np.empty_like(prototype, dtype, order, subok, shape) + assert ret.asnumpy().shape == expected_ret.shape + + if __name__ == '__main__': import nose nose.runmodule()