diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index ad76e43b2a90..cf66e29d6205 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -38,7 +38,7 @@ 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', - 'hsplit', 'rot90', 'einsum'] + 'hsplit', 'rot90', 'einsum', 'true_divide'] @set_module('mxnet.ndarray.numpy') @@ -337,10 +337,10 @@ def take(a, indices, axis=None, mode='raise', out=None): if mode not in ('wrap', 'clip', 'raise'): raise NotImplementedError( "function take does not support mode '{}'".format(mode)) - if axis: - return _npi.take(a, indices, axis, mode, out) - else: + if axis is None: return _npi.take(_npi.reshape(a, -1), indices, 0, mode, out) + else: + return _npi.take(a, indices, axis, mode, out) # pylint: enable=redefined-outer-name @@ -495,7 +495,11 @@ def unique(ar, return_index=False, return_inverse=False, return_counts=False, ax >>> u[indices] array([1., 2., 6., 4., 2., 3., 2.]) """ - return _npi.unique(ar, return_index, return_inverse, return_counts, axis) + ret = _npi.unique(ar, return_index, return_inverse, return_counts, axis) + if isinstance(ret, list): + return tuple(ret) + else: + return ret @set_module('mxnet.ndarray.numpy') @@ -604,6 +608,36 @@ def divide(x1, x2, out=None, **kwargs): _npi.rtrue_divide_scalar, out) +@set_module('mxnet.ndarray.numpy') +def true_divide(x1, x2, out=None): + """Returns a true division of the inputs, element-wise. + + Instead of the Python traditional 'floor division', this returns a true + division. True division adjusts the output type to present the best + answer, regardless of input types. + + Parameters + ---------- + x1 : ndarray or scalar + Dividend array. + + x2 : ndarray or scalar + Divisor array. + + out : ndarray + A location into which the result is stored. If provided, it must have a shape + that the inputs broadcast to. If not provided or None, a freshly-allocated array + is returned. + + Returns + ------- + out : ndarray or scalar + This is a scalar if both x1 and x2 are scalars. + """ + return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar, + _npi.rtrue_divide_scalar, out) + + @set_module('mxnet.ndarray.numpy') @wrap_np_binary_func def mod(x1, x2, out=None, **kwargs): diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index e507b17d68d8..8cca82bcf827 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -55,7 +55,7 @@ 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', - 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum'] + 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -2216,6 +2216,35 @@ def divide(x1, x2, out=None, **kwargs): return _mx_nd_np.divide(x1, x2, out=out) +@set_module('mxnet.numpy') +def true_divide(x1, x2, out=None): + """Returns a true division of the inputs, element-wise. + + Instead of the Python traditional 'floor division', this returns a true + division. True division adjusts the output type to present the best + answer, regardless of input types. + + Parameters + ---------- + x1 : ndarray or scalar + Dividend array. + + x2 : ndarray or scalar + Divisor array. + + out : ndarray + A location into which the result is stored. If provided, it must have a shape + that the inputs broadcast to. If not provided or None, a freshly-allocated array + is returned. + + Returns + ------- + out : ndarray or scalar + This is a scalar if both x1 and x2 are scalars. + """ + return _mx_nd_np.true_divide(x1, x2, out=out) + + @set_module('mxnet.numpy') @wrap_np_binary_func def mod(x1, x2, out=None, **kwargs): diff --git a/python/mxnet/numpy/utils.py b/python/mxnet/numpy/utils.py index b2d0dd96d324..b2335e29855d 100644 --- a/python/mxnet/numpy/utils.py +++ b/python/mxnet/numpy/utils.py @@ -23,7 +23,7 @@ import numpy as onp __all__ = ['float16', 'float32', 'float64', 'uint8', 'int32', 'int8', 'int64', - 'bool', 'bool_', 'pi', 'inf', 'nan'] + 'bool', 'bool_', 'pi', 'inf', 'nan', 'PZERO', 'NZERO'] float16 = onp.float16 float32 = onp.float32 @@ -38,3 +38,5 @@ pi = onp.pi inf = onp.inf nan = onp.nan +PZERO = onp.PZERO +NZERO = onp.NZERO diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index a241d2687ee4..9fd3976dead4 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -84,6 +84,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): _NUMPY_ARRAY_FUNCTION_LIST = [ 'argmax', + 'around', 'broadcast_arrays', 'broadcast_to', 'clip', @@ -93,6 +94,8 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'dot', 'expand_dims', 'fix', + 'flip', + 'inner', 'max', 'mean', 'min', @@ -108,9 +111,11 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'std', 'sum', 'swapaxes', + 'take', 'tensordot', 'tile', 'transpose', + 'unique', 'var', 'zeros_like', 'meshgrid', @@ -161,11 +166,17 @@ def _register_array_function(): # https://docs.scipy.org/doc/numpy/reference/ufuncs.html#available-ufuncs _NUMPY_ARRAY_UFUNC_LIST = [ + 'abs', 'add', + 'arctan2', + 'copysign', + 'degrees', + 'hypot', + 'lcm', + # 'ldexp', 'subtract', 'multiply', - # Uncomment divide when mxnet.numpy.true_divide is added - # 'divide', + 'true_divide', 'negative', 'power', 'mod', diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 1945c5b0e695..4b4319f45373 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -40,7 +40,7 @@ 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', - 'less_equal', 'hsplit', 'rot90', 'einsum'] + 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide'] def _num_outputs(sym): @@ -1082,10 +1082,10 @@ def take(a, indices, axis=None, mode='raise', out=None): if mode not in ('wrap', 'clip', 'raise'): raise NotImplementedError( "function take does not support mode '{}'".format(mode)) - if axis: - return _npi.take(a, indices, axis, mode, out) - else: + if axis is None: return _npi.take(_npi.reshape(a, -1), indices, 0, mode, out) + else: + return _npi.take(a, indices, axis, mode, out) # pylint: enable=redefined-outer-name @@ -1164,6 +1164,12 @@ def divide(x1, x2, out=None, **kwargs): _npi.rtrue_divide_scalar, out) +@set_module('mxnet.ndarray.numpy') +def true_divide(x1, x2, out=None): + return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar, + _npi.rtrue_divide_scalar, out) + + @set_module('mxnet.symbol.numpy') @wrap_np_binary_func def mod(x1, x2, out=None, **kwargs): diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index f22d42bb678b..7495c9002f26 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -17,6 +17,8 @@ # pylint: skip-file from __future__ import absolute_import +from __future__ import division +import itertools import numpy as _np from mxnet import np from mxnet.test_utils import assert_almost_equal @@ -25,7 +27,11 @@ from mxnet.numpy_dispatch_protocol import with_array_function_protocol, with_array_ufunc_protocol from mxnet.numpy_dispatch_protocol import _NUMPY_ARRAY_FUNCTION_LIST, _NUMPY_ARRAY_UFUNC_LIST -import itertools + +_INT_DTYPES = [np.int8, np.int32, np.int64, np.uint8] +_FLOAT_DTYPES = [np.float16, np.float32, np.float64] +_DTYPES = _INT_DTYPES + _FLOAT_DTYPES + class OpArgMngr(object): """Operator argument manager for storing operator workloads.""" @@ -114,50 +120,158 @@ def _add_workload_einsum(): OpArgMngr.add_workload('einsum', subscripts, *args, optimize=optimize) -@use_np -def _prepare_workloads(): - array_pool = { - '4x1': np.random.uniform(size=(4, 1)) + 2, - '1x2': np.random.uniform(size=(1, 2)) + 2, - '1x1x0': np.array([[[]]]) - } +def _add_workload_argmax(): + OpArgMngr.add_workload('argmax', np.random.uniform(size=(4, 5, 6, 7, 8)), 0) + OpArgMngr.add_workload('argmax', np.random.uniform(size=(4, 5, 6, 7, 8)), 1) + OpArgMngr.add_workload('argmax', np.random.uniform(size=(4, 5, 6, 7, 8)), 2) + OpArgMngr.add_workload('argmax', np.random.uniform(size=(4, 5, 6, 7, 8)), 3) + OpArgMngr.add_workload('argmax', np.random.uniform(size=(4, 5, 6, 7, 8)), 4) + # OpArgMngr.add_workload('argmax', np.array([0, 1, 2, 3, np.nan])) + # OpArgMngr.add_workload('argmax', np.array([0, 1, 2, np.nan, 3])) + # OpArgMngr.add_workload('argmax', np.array([np.nan, 0, 1, 2, 3])) + # OpArgMngr.add_workload('argmax', np.array([np.nan, 0, np.nan, 2, 3])) + OpArgMngr.add_workload('argmax', np.array([False, False, False, False, True])) + OpArgMngr.add_workload('argmax', np.array([False, False, False, True, False])) + OpArgMngr.add_workload('argmax', np.array([True, False, False, False, False])) + OpArgMngr.add_workload('argmax', np.array([True, False, True, False, False])) + + +def _add_workload_around(): + OpArgMngr.add_workload('around', np.array([1.56, 72.54, 6.35, 3.25]), decimals=1) - dt_int = [np.int8, np.int32, np.int64, np.uint8] - dt_float = [np.float16, np.float32, np.float64] - dt = dt_int + dt_float - # workloads for array function protocol - OpArgMngr.add_workload('argmax', array_pool['4x1']) +def _add_workload_broadcast_arrays(array_pool): OpArgMngr.add_workload('broadcast_arrays', array_pool['4x1'], array_pool['1x2']) - OpArgMngr.add_workload('broadcast_to', array_pool['4x1'], (4, 2)) - OpArgMngr.add_workload('clip', array_pool['4x1'], 0.2, 0.8) + + +def _add_workload_broadcast_to(): + OpArgMngr.add_workload('broadcast_to', np.array(0), (0,)) + OpArgMngr.add_workload('broadcast_to', np.array(0), (1,)) + OpArgMngr.add_workload('broadcast_to', np.array(0), (3,)) + OpArgMngr.add_workload('broadcast_to', np.ones(1), (1,)) + OpArgMngr.add_workload('broadcast_to', np.ones(1), (2,)) + OpArgMngr.add_workload('broadcast_to', np.ones(1), (1, 2, 3)) + OpArgMngr.add_workload('broadcast_to', np.arange(3), (3,)) + OpArgMngr.add_workload('broadcast_to', np.arange(3), (1, 3)) + OpArgMngr.add_workload('broadcast_to', np.arange(3), (2, 3)) + OpArgMngr.add_workload('broadcast_to', np.ones(0), 0) + OpArgMngr.add_workload('broadcast_to', np.ones(1), 1) + OpArgMngr.add_workload('broadcast_to', np.ones(1), 2) + OpArgMngr.add_workload('broadcast_to', np.ones(1), (0,)) + OpArgMngr.add_workload('broadcast_to', np.ones((1, 2)), (0, 2)) + OpArgMngr.add_workload('broadcast_to', np.ones((2, 1)), (2, 0)) + + +def _add_workload_clip(): + OpArgMngr.add_workload('clip', (np.random.normal(size=(1000,)) * 1024).astype("float"), -12.8, 100.2) + OpArgMngr.add_workload('clip', (np.random.normal(size=(1000,)) * 1024).astype("float"), 0, 0) + OpArgMngr.add_workload('clip', (np.random.normal(size=(1000,)) * 1024).astype("int"), -120, 100) + OpArgMngr.add_workload('clip', (np.random.normal(size=(1000,)) * 1024).astype("int"), 0.0, 2.0) + OpArgMngr.add_workload('clip', (np.random.normal(size=(1000,)) * 1024).astype("int"), 0, 0) + OpArgMngr.add_workload('clip', (np.random.normal(size=(1000,)) * 1024).astype("uint8"), 0, 0) + OpArgMngr.add_workload('clip', (np.random.normal(size=(1000,)) * 1024).astype("uint8"), 0.0, 2.0) + OpArgMngr.add_workload('clip', (np.random.normal(size=(1000,)) * 1024).astype("uint8"), -120, 100) + # OpArgMngr.add_workload('clip', np.random.normal(size=(1000,)), np.zeros((1000,))+0.5, 1) + # OpArgMngr.add_workload('clip', np.random.normal(size=(1000,)), 0, np.zeros((1000,))+0.5) + # OpArgMngr.add_workload('clip', np.array([0, 1, 2, 3, 4, 5, 6, 7]), 3) + # OpArgMngr.add_workload('clip', np.array([0, 1, 2, 3, 4, 5, 6, 7]), a_min=3) + # OpArgMngr.add_workload('clip', np.array([0, 1, 2, 3, 4, 5, 6, 7]), a_max=4) + OpArgMngr.add_workload('clip', np.array([-2., np.nan, 0.5, 3., 0.25, np.nan]), -1, 1) + + +def _add_workload_concatenate(array_pool): OpArgMngr.add_workload('concatenate', [array_pool['4x1'], array_pool['4x1']]) OpArgMngr.add_workload('concatenate', [array_pool['4x1'], array_pool['4x1']], axis=1) + + +def _add_workload_copy(array_pool): OpArgMngr.add_workload('copy', array_pool['4x1']) - for ctype in dt: + +def _add_workload_cumsum(): + for ctype in _DTYPES: OpArgMngr.add_workload('cumsum', np.array([1, 2, 10, 11, 6, 5, 4], dtype=ctype)) OpArgMngr.add_workload('cumsum', np.array([[1, 2, 3, 4], [5, 6, 7, 9], [10, 3, 4, 5]], dtype=ctype), axis=0) OpArgMngr.add_workload('cumsum', np.array([[1, 2, 3, 4], [5, 6, 7, 9], [10, 3, 4, 5]], dtype=ctype), axis=1) + +def _add_workload_ravel(): OpArgMngr.add_workload('ravel', np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])) - OpArgMngr.add_workload('dot', array_pool['4x1'], array_pool['4x1'].T) + +def _add_workload_dot(): + OpArgMngr.add_workload('dot', np.random.normal(size=(2, 4)), np.random.normal(size=(4, 2))) + OpArgMngr.add_workload('dot', np.random.normal(size=(4, 2)), np.random.normal(size=(2, 1))) + OpArgMngr.add_workload('dot', np.random.normal(size=(4, 2)), np.random.normal(size=(2,))) + OpArgMngr.add_workload('dot', np.random.normal(size=(1, 2)), np.random.normal(size=(2, 4))) + OpArgMngr.add_workload('dot', np.random.normal(size=(2, 4)), np.random.normal(size=(4,))) + OpArgMngr.add_workload('dot', np.random.normal(size=(1, 2)), np.random.normal(size=(2, 1))) + OpArgMngr.add_workload('dot', np.ones((3, 1)), np.array([5.3])) + OpArgMngr.add_workload('dot', np.array([5.3]), np.ones((1, 3))) + OpArgMngr.add_workload('dot', np.random.normal(size=(1, 1)), np.random.normal(size=(1, 4))) + OpArgMngr.add_workload('dot', np.random.normal(size=(4, 1)), np.random.normal(size=(1, 1))) + + dims = [(), (1,), (1, 1)] + for (dim1, dim2) in itertools.product(dims, dims): + b1 = np.zeros(dim1) + b2 = np.zeros(dim2) + OpArgMngr.add_workload('dot', b1, b2) + OpArgMngr.add_workload('dot', np.array([[1, 2], [3, 4]], dtype=float), np.array([[1, 0], [1, 1]], dtype=float)) + OpArgMngr.add_workload('dot', np.random.normal(size=(1024, 16)), np.random.normal(size=(16, 32))) + + +def _add_workload_expand_dims(array_pool): OpArgMngr.add_workload('expand_dims', array_pool['4x1'], -1) - OpArgMngr.add_workload('fix', array_pool['4x1']) + + +def _add_workload_fix(): + OpArgMngr.add_workload('fix', np.array([[1.0, 1.1, 1.5, 1.8], [-1.0, -1.1, -1.5, -1.8]])) + OpArgMngr.add_workload('fix', np.array([3.14])) + + +def _add_workload_flip(): + OpArgMngr.add_workload('flip', np.random.normal(size=(4, 4)), 1) + OpArgMngr.add_workload('flip', np.array([[0, 1, 2], [3, 4, 5]]), 1) + OpArgMngr.add_workload('flip', np.random.normal(size=(4, 4)), 0) + OpArgMngr.add_workload('flip', np.array([[0, 1, 2], [3, 4, 5]]), 0) + OpArgMngr.add_workload('flip', np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]), 0) + OpArgMngr.add_workload('flip', np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]), 1) + OpArgMngr.add_workload('flip', np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]), 2) + for i in range(4): + OpArgMngr.add_workload('flip', np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5), i) + OpArgMngr.add_workload('flip', np.array([[1, 2, 3], [4, 5, 6]])) + OpArgMngr.add_workload('flip', np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]), ()) + OpArgMngr.add_workload('flip', np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]), (0, 2)) + OpArgMngr.add_workload('flip', np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]), (1, 2)) + + +def _add_workload_max(array_pool): OpArgMngr.add_workload('max', array_pool['4x1']) + + +def _add_workload_min(array_pool): OpArgMngr.add_workload('min', array_pool['4x1']) + + +def _add_workload_mean(array_pool): OpArgMngr.add_workload('mean', array_pool['4x1']) OpArgMngr.add_workload('mean', array_pool['4x1'], axis=0, keepdims=True) OpArgMngr.add_workload('mean', np.array([[1, 2, 3], [4, 5, 6]])) OpArgMngr.add_workload('mean', np.array([[1, 2, 3], [4, 5, 6]]), axis=0) OpArgMngr.add_workload('mean', np.array([[1, 2, 3], [4, 5, 6]]), axis=1) + + +def _add_workload_ones_like(array_pool): OpArgMngr.add_workload('ones_like', array_pool['4x1']) + + +def _add_workload_prod(array_pool): OpArgMngr.add_workload('prod', array_pool['4x1']) + +def _add_workload_repeat(array_pool): OpArgMngr.add_workload('repeat', array_pool['4x1'], 3) OpArgMngr.add_workload('repeat', np.array(_np.arange(12).reshape(4, 3)[:, 2]), 3) - m = _np.array([1, 2, 3, 4, 5, 6]) m_rect = m.reshape((2, 3)) @@ -175,13 +289,15 @@ def _prepare_workloads(): OpArgMngr.add_workload('repeat', np.array(a), 2, axis=axis) # OpArgMngr.add_workload('repeat', np.array(a), [2], axis=axis) # Argument "repeats" only supports int + +def _add_workload_reshape(): arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]) OpArgMngr.add_workload('reshape', arr, (2, 6)) OpArgMngr.add_workload('reshape', arr, (3, 4)) # OpArgMngr.add_workload('reshape', arr, (3, 4), order='F') # Items are not equal with order='F' OpArgMngr.add_workload('reshape', arr, (3, 4), order='C') OpArgMngr.add_workload('reshape', np.array(_np.ones(100)), (100, 1, 1)) - + # test_reshape_order a = np.array(_np.arange(6)) # OpArgMngr.add_workload('reshape', a, (2, 3), order='F') # Items are not equal with order='F' @@ -192,9 +308,13 @@ def _prepare_workloads(): a = np.array(_np.ones((0, 2))) OpArgMngr.add_workload('reshape', a, -1, 2) + +def _add_workload_rint(array_pool): OpArgMngr.add_workload('rint', np.array(4607998452777363968)) OpArgMngr.add_workload('rint', array_pool['4x1']) + +def _add_workload_roll(): # test_roll1d(self) OpArgMngr.add_workload('roll', np.array(_np.arange(10)), 2) @@ -219,38 +339,221 @@ def _prepare_workloads(): OpArgMngr.add_workload('roll', x2, -4, axis=1) # # test_roll_empty OpArgMngr.add_workload('roll', np.array([]), 1) - + + +def _add_workload_split(array_pool): OpArgMngr.add_workload('split', array_pool['4x1'], 2) + + +def _add_workload_squeeze(array_pool): OpArgMngr.add_workload('squeeze', array_pool['4x1']) + + +def _add_workload_stack(array_pool): OpArgMngr.add_workload('stack', [array_pool['4x1']] * 2) + + +def _add_workload_std(array_pool): OpArgMngr.add_workload('std', array_pool['4x1']) - OpArgMngr.add_workload('sum', array_pool['4x1']) + + +def _add_workload_sum(): + # OpArgMngr.add_workload('sum', np.ones(101, dtype=bool)) + OpArgMngr.add_workload('sum', np.arange(1, 10).reshape((3, 3)), axis=1, keepdims=True) + OpArgMngr.add_workload('sum', np.ones(500, dtype=np.float32)/10.) + OpArgMngr.add_workload('sum', np.ones(500, dtype=np.float64)/10.) + for dt in (np.float16, np.float32, np.float64): + for v in (0, 1, 2, 7, 8, 9, 15, 16, 19, 127, + 128, 1024, 1235): + d = np.arange(1, v + 1, dtype=dt) + OpArgMngr.add_workload('sum', d) + d = np.ones(500, dtype=dt) + OpArgMngr.add_workload('sum', d[::2]) + OpArgMngr.add_workload('sum', d[1::2]) + OpArgMngr.add_workload('sum', d[::3]) + OpArgMngr.add_workload('sum', d[1::3]) + OpArgMngr.add_workload('sum', d[::-2]) + OpArgMngr.add_workload('sum', d[-1::-2]) + OpArgMngr.add_workload('sum', d[::-3]) + OpArgMngr.add_workload('sum', d[-1::-3]) + d = np.ones((1,), dtype=dt) + d += d + OpArgMngr.add_workload('sum', d) + # OpArgMngr.add_workload('sum', np.array([3]), initial=2) + # OpArgMngr.add_workload('sum', np.array([0.2]), initial=0.1) + + +def _add_workload_swapaxes(array_pool): OpArgMngr.add_workload('swapaxes', array_pool['4x1'], 0, 1) + + +def _add_workload_take(): + OpArgMngr.add_workload('take', np.array([[1, 2], [3, 4]], dtype=int), np.array([], int)) + for mode in ['wrap', 'clip']: + OpArgMngr.add_workload('take', np.array([[1, 2], [3, 4]], dtype=int), np.array(-1, int), mode=mode) + OpArgMngr.add_workload('take', np.array([[1, 2], [3, 4]], dtype=int), np.array(4, int), mode=mode) + OpArgMngr.add_workload('take', np.array([[1, 2], [3, 4]], dtype=int), np.array([-1], int), mode=mode) + OpArgMngr.add_workload('take', np.array([[1, 2], [3, 4]], dtype=int), np.array([4], int), mode=mode) + x = (np.random.normal(size=24)*100).reshape((2, 3, 4)) + # OpArgMngr.add_workload('take', x, np.array([-1], int), axis=0) + OpArgMngr.add_workload('take', x, np.array([-1], int), axis=0, mode='clip') + OpArgMngr.add_workload('take', x, np.array([2], int), axis=0, mode='clip') + OpArgMngr.add_workload('take', x, np.array([-1], int), axis=0, mode='wrap') + OpArgMngr.add_workload('take', x, np.array([2], int), axis=0, mode='wrap') + OpArgMngr.add_workload('take', x, np.array([3], int), axis=0, mode='wrap') + + +def _add_workload_tensordot(array_pool): OpArgMngr.add_workload('tensordot', array_pool['4x1'], array_pool['4x1']) + + +def _add_workload_tile(array_pool): OpArgMngr.add_workload('tile', array_pool['4x1'], 2) OpArgMngr.add_workload('tile', np.array([[[]]]), (3, 2, 5)) + + +def _add_workload_transpose(array_pool): OpArgMngr.add_workload('transpose', array_pool['4x1']) + + +def _add_workload_unique(): + OpArgMngr.add_workload('unique', np.array([5, 7, 1, 2, 1, 5, 7]*10), True, True, True) + OpArgMngr.add_workload('unique', np.array([]), True, True, True) + OpArgMngr.add_workload('unique', np.array([[0, 1, 0], [0, 1, 0]])) + OpArgMngr.add_workload('unique', np.array([[0, 1, 0], [0, 1, 0]]), axis=0) + OpArgMngr.add_workload('unique', np.array([[0, 1, 0], [0, 1, 0]]), axis=1) + # OpArgMngr.add_workload('unique', np.arange(10, dtype=np.uint8).reshape(-1, 2).astype(bool), axis=1) + + +def _add_workload_var(array_pool): OpArgMngr.add_workload('var', array_pool['4x1']) + + +def _add_workload_zeros_like(array_pool): OpArgMngr.add_workload('zeros_like', array_pool['4x1']) + + +def _add_workload_outer(): OpArgMngr.add_workload('outer', np.ones((5)), np.ones((2))) + + +def _add_workload_meshgrid(): OpArgMngr.add_workload('meshgrid', np.array([1, 2, 3])) OpArgMngr.add_workload('meshgrid', np.array([1, 2, 3]), np.array([4, 5, 6, 7])) OpArgMngr.add_workload('meshgrid', np.array([1, 2, 3]), np.array([4, 5, 6, 7]), indexing='ij') - _add_workload_einsum() - # workloads for array ufunc protocol + +def _add_workload_abs(): + OpArgMngr.add_workload('abs', np.random.uniform(size=(11,)).astype(np.float32)) + OpArgMngr.add_workload('abs', np.random.uniform(size=(5,)).astype(np.float64)) + OpArgMngr.add_workload('abs', np.array([np.inf, -np.inf, np.nan])) + + +def _add_workload_add(array_pool): OpArgMngr.add_workload('add', array_pool['4x1'], array_pool['1x2']) OpArgMngr.add_workload('add', array_pool['4x1'], 2) OpArgMngr.add_workload('add', 2, array_pool['4x1']) OpArgMngr.add_workload('add', array_pool['4x1'], array_pool['1x1x0']) + + +def _add_workload_arctan2(): + OpArgMngr.add_workload('arctan2', np.array([1, -1, 1]), np.array([1, 1, -1])) + OpArgMngr.add_workload('arctan2', np.array([np.PZERO, np.NZERO]), np.array([np.NZERO, np.NZERO])) + OpArgMngr.add_workload('arctan2', np.array([np.PZERO, np.NZERO]), np.array([np.PZERO, np.PZERO])) + OpArgMngr.add_workload('arctan2', np.array([np.PZERO, np.NZERO]), np.array([-1, -1])) + OpArgMngr.add_workload('arctan2', np.array([np.PZERO, np.NZERO]), np.array([1, 1])) + OpArgMngr.add_workload('arctan2', np.array([-1, -1]), np.array([np.PZERO, np.NZERO])) + OpArgMngr.add_workload('arctan2', np.array([1, 1]), np.array([np.PZERO, np.NZERO])) + OpArgMngr.add_workload('arctan2', np.array([1, -1, 1, -1]), np.array([-np.inf, -np.inf, np.inf, np.inf])) + OpArgMngr.add_workload('arctan2', np.array([np.inf, -np.inf]), np.array([1, 1])) + OpArgMngr.add_workload('arctan2', np.array([np.inf, -np.inf]), np.array([-np.inf, -np.inf])) + OpArgMngr.add_workload('arctan2', np.array([np.inf, -np.inf]), np.array([np.inf, np.inf])) + + +def _add_workload_copysign(): + OpArgMngr.add_workload('copysign', np.array([1, 0, 0]), np.array([-1, -1, 1])) + OpArgMngr.add_workload('copysign', np.array([-2, 5, 1, 4, 3], dtype=np.float16), np.array([0, 1, 2, 4, 2], dtype=np.float16)) + + +def _add_workload_degrees(): + OpArgMngr.add_workload('degrees', np.array(np.pi)) + OpArgMngr.add_workload('degrees', np.array(-0.5*np.pi)) + + +def _add_workload_true_divide(): + for dt in [np.float32, np.float64, np.float16]: + OpArgMngr.add_workload('true_divide', np.array([10, 10, -10, -10], dt), np.array([20, -20, 20, -20], dt)) + + +def _add_workload_inner(): + OpArgMngr.add_workload('inner', np.zeros(shape=(1, 80), dtype=np.float64), np.zeros(shape=(1, 80), dtype=np.float64)) + for dt in [np.float32, np.float64]: + # OpArgMngr.add_workload('inner', np.array(3, dtype=dt)[()], np.array([1, 2], dtype=dt)) + # OpArgMngr.add_workload('inner', np.array([1, 2], dtype=dt), np.array(3, dtype=dt)[()]) + A = np.array([[1, 2], [3, 4]], dtype=dt) + B = np.array([[1, 3], [2, 4]], dtype=dt) + C = np.array([1, 1], dtype=dt) + OpArgMngr.add_workload('inner', A.T, C) + OpArgMngr.add_workload('inner', C, A.T) + OpArgMngr.add_workload('inner', B, C) + OpArgMngr.add_workload('inner', C, B) + OpArgMngr.add_workload('inner', A, B) + OpArgMngr.add_workload('inner', A, A) + OpArgMngr.add_workload('inner', A, A.copy()) + a = np.arange(5).astype(dt) + b = a[::-1] + OpArgMngr.add_workload('inner', b, a) + a = np.arange(24).reshape(2,3,4).astype(dt) + b = np.arange(24, 48).reshape(2,3,4).astype(dt) + OpArgMngr.add_workload('inner', a, b) + OpArgMngr.add_workload('inner', b, a) + + +def _add_workload_hypot(): + OpArgMngr.add_workload('hypot', np.array(1), np.array(1)) + OpArgMngr.add_workload('hypot', np.array(0), np.array(0)) + OpArgMngr.add_workload('hypot', np.array(np.nan), np.array(np.nan)) + OpArgMngr.add_workload('hypot', np.array(np.nan), np.array(1)) + OpArgMngr.add_workload('hypot', np.array(np.nan), np.array(np.inf)) + OpArgMngr.add_workload('hypot', np.array(np.inf), np.array(np.nan)) + OpArgMngr.add_workload('hypot', np.array(np.inf), np.array(0)) + OpArgMngr.add_workload('hypot', np.array(0), np.array(np.inf)) + OpArgMngr.add_workload('hypot', np.array(np.inf), np.array(np.inf)) + OpArgMngr.add_workload('hypot', np.array(np.inf), np.array(23.0)) + + +def _add_workload_lcm(): + OpArgMngr.add_workload('lcm', np.array([12, 120], dtype=np.int8), np.array([20, 200], dtype=np.int8)) + OpArgMngr.add_workload('lcm', np.array([12, 120], dtype=np.uint8), np.array([20, 200], dtype=np.uint8)) + OpArgMngr.add_workload('lcm', np.array(195225786*2, dtype=np.int32), np.array(195225786*5, dtype=np.int32)) + + +def _add_workload_ldexp(): + OpArgMngr.add_workload('ldexp', np.array(2., np.float32), np.array(3, np.int8)) + OpArgMngr.add_workload('ldexp', np.array(2., np.float64), np.array(3, np.int8)) + OpArgMngr.add_workload('ldexp', np.array(2., np.float32), np.array(3, np.int32)) + OpArgMngr.add_workload('ldexp', np.array(2., np.float64), np.array(3, np.int32)) + OpArgMngr.add_workload('ldexp', np.array(2., np.float32), np.array(3, np.int64)) + OpArgMngr.add_workload('ldexp', np.array(2., np.float64), np.array(3, np.int64)) + OpArgMngr.add_workload('ldexp', np.array(2., np.float64), np.array(9223372036854775807, np.int64)) + OpArgMngr.add_workload('ldexp', np.array(2., np.float64), np.array(-9223372036854775808, np.int64)) + + +def _add_workload_subtract(array_pool): OpArgMngr.add_workload('subtract', array_pool['4x1'], array_pool['1x2']) OpArgMngr.add_workload('subtract', array_pool['4x1'], 2) OpArgMngr.add_workload('subtract', 2, array_pool['4x1']) OpArgMngr.add_workload('subtract', array_pool['4x1'], array_pool['1x1x0']) + + +def _add_workload_multiply(array_pool): OpArgMngr.add_workload('multiply', array_pool['4x1'], array_pool['1x2']) OpArgMngr.add_workload('multiply', array_pool['4x1'], 2) OpArgMngr.add_workload('multiply', 2, array_pool['4x1']) OpArgMngr.add_workload('multiply', array_pool['4x1'], array_pool['1x1x0']) + + +def _add_workload_power(array_pool): OpArgMngr.add_workload('power', array_pool['4x1'], array_pool['1x2']) OpArgMngr.add_workload('power', array_pool['4x1'], 2) OpArgMngr.add_workload('power', 2, array_pool['4x1']) @@ -258,14 +561,19 @@ def _prepare_workloads(): OpArgMngr.add_workload('power', np.array([1, 2, 3], np.int32), 2.00001) OpArgMngr.add_workload('power', np.array([15, 15], np.int64), np.array([15, 15], np.int64)) OpArgMngr.add_workload('power', 0, np.arange(1, 10)) + + +def _add_workload_mod(array_pool): OpArgMngr.add_workload('mod', array_pool['4x1'], array_pool['1x2']) OpArgMngr.add_workload('mod', array_pool['4x1'], 2) OpArgMngr.add_workload('mod', 2, array_pool['4x1']) OpArgMngr.add_workload('mod', array_pool['4x1'], array_pool['1x1x0']) + +def _add_workload_remainder(): # test remainder basic OpArgMngr.add_workload('remainder', np.array([0, 1, 2, 4, 2], dtype=np.float16), - np.array([-2, 5, 1, 4, 3], dtype=np.float16)) + np.array([-2, 5, 1, 4, 3], dtype=np.float16)) def _signs(dt): if dt in [np.uint8]: @@ -273,7 +581,7 @@ def _signs(dt): else: return (+1, -1) - for ct in dt: + for ct in _DTYPES: for sg1, sg2 in itertools.product(_signs(ct), _signs(ct)): a = np.array(sg1*71, dtype=ct) b = np.array(sg2*19, dtype=ct) @@ -293,9 +601,9 @@ def _signs(dt): fa = a.astype(dt) fb = b.astype(dt) OpArgMngr.add_workload('remainder', fa, fb) - + # test_float_remainder_roundoff - for ct in dt_float: + for ct in _FLOAT_DTYPES: for sg1, sg2 in itertools.product((+1, -1), (+1, -1)): a = np.array(sg1*78*6e-8, dtype=ct) b = np.array(sg2*6e-8, dtype=ct) @@ -303,80 +611,262 @@ def _signs(dt): # test_float_remainder_corner_cases # Check remainder magnitude. - for ct in dt_float: + for ct in _FLOAT_DTYPES: b = _np.array(1.0) a = np.array(_np.nextafter(_np.array(0.0), -b), dtype=ct) b = np.array(b, dtype=ct) OpArgMngr.add_workload('remainder', a, b) OpArgMngr.add_workload('remainder', -a, -b) - # Check nans, inf - for ct in [np.float16, np.float32, np.float64]: - fone = np.array(1.0, dtype=ct) - fzer = np.array(0.0, dtype=ct) - finf = np.array(np.inf, dtype=ct) - fnan = np.array(np.nan, dtype=ct) - # OpArgMngr.add_workload('remainder', fone, fzer) # failed - OpArgMngr.add_workload('remainder', fone, fnan) - OpArgMngr.add_workload('remainder', finf, fone) + # Check nans, inf + for ct in [np.float16, np.float32, np.float64]: + fone = np.array(1.0, dtype=ct) + fzer = np.array(0.0, dtype=ct) + finf = np.array(np.inf, dtype=ct) + fnan = np.array(np.nan, dtype=ct) + # OpArgMngr.add_workload('remainder', fone, fzer) # failed + OpArgMngr.add_workload('remainder', fone, fnan) + OpArgMngr.add_workload('remainder', finf, fone) + +def _add_workload_maximum(array_pool): OpArgMngr.add_workload('maximum', array_pool['4x1'], array_pool['1x2']) OpArgMngr.add_workload('maximum', array_pool['4x1'], 2) OpArgMngr.add_workload('maximum', 2, array_pool['4x1']) OpArgMngr.add_workload('maximum', array_pool['4x1'], array_pool['1x1x0']) + + +def _add_workload_minimum(array_pool): OpArgMngr.add_workload('minimum', array_pool['4x1'], array_pool['1x2']) OpArgMngr.add_workload('minimum', array_pool['4x1'], 2) OpArgMngr.add_workload('minimum', 2, array_pool['4x1']) OpArgMngr.add_workload('minimum', array_pool['4x1'], array_pool['1x1x0']) + + +def _add_workload_negative(array_pool): OpArgMngr.add_workload('negative', array_pool['4x1']) + + +def _add_workload_absolute(array_pool): OpArgMngr.add_workload('absolute', array_pool['4x1']) - + + +def _add_workload_sign(array_pool): OpArgMngr.add_workload('sign', array_pool['4x1']) OpArgMngr.add_workload('sign', np.array([-2, 5, 1, 4, 3], dtype=np.float16)) OpArgMngr.add_workload('sign', np.array([-.1, 0, .1])) # OpArgMngr.add_workload('sign', np.array(_np.array([_np.nan]))) # failed + +def _add_workload_exp(array_pool): OpArgMngr.add_workload('exp', array_pool['4x1']) + + +def _add_workload_log(array_pool): OpArgMngr.add_workload('log', array_pool['4x1']) + + +def _add_workload_log2(array_pool): OpArgMngr.add_workload('log2', array_pool['4x1']) OpArgMngr.add_workload('log2', np.array(2.**65)) OpArgMngr.add_workload('log2', np.array(np.inf)) OpArgMngr.add_workload('log2', np.array(1.)) + + +def _add_workload_log1p(): OpArgMngr.add_workload('log1p', np.array(-1.)) OpArgMngr.add_workload('log1p', np.array(np.inf)) OpArgMngr.add_workload('log1p', np.array(1e-6)) + + +def _add_workload_log10(array_pool): OpArgMngr.add_workload('log10', array_pool['4x1']) + + +def _add_workload_expm1(array_pool): OpArgMngr.add_workload('expm1', array_pool['4x1']) - OpArgMngr.add_workload('sqrt', array_pool['4x1']) - OpArgMngr.add_workload('square', array_pool['4x1']) - OpArgMngr.add_workload('cbrt', array_pool['4x1']) + +def _add_workload_sqrt(): + OpArgMngr.add_workload('sqrt', np.array([1, np.PZERO, np.NZERO, np.inf, np.nan])) + + +def _add_workload_square(): + OpArgMngr.add_workload('square', np.array([-2, 5, 1, 4, 3], dtype=np.float16)) + + +def _add_workload_cbrt(): + OpArgMngr.add_workload('cbrt', np.array(-2.5**3, dtype=np.float32)) + OpArgMngr.add_workload('cbrt', np.array([1., 2., -3., np.inf, -np.inf])**3) + OpArgMngr.add_workload('cbrt', np.array([np.inf, -np.inf, np.nan])) + + +def _add_workload_reciprocal(): for ctype in [np.float16, np.float32, np.float64]: OpArgMngr.add_workload('reciprocal', np.array([-2, 5, 1, 4, 3], dtype=ctype)) OpArgMngr.add_workload('reciprocal', np.array([-2, 0, 1, 0, 3], dtype=ctype)) OpArgMngr.add_workload('reciprocal', np.array([0], dtype=ctype)) + +def _add_workload_sin(array_pool): OpArgMngr.add_workload('sin', array_pool['4x1']) + + +def _add_workload_cos(array_pool): OpArgMngr.add_workload('cos', array_pool['4x1']) + + +def _add_workload_tan(array_pool): OpArgMngr.add_workload('tan', array_pool['4x1']) + + +def _add_workload_sinh(array_pool): OpArgMngr.add_workload('sinh', array_pool['4x1']) + + +def _add_workload_cosh(array_pool): OpArgMngr.add_workload('cosh', array_pool['4x1']) + + +def _add_workload_tanh(array_pool): OpArgMngr.add_workload('tanh', array_pool['4x1']) + + +def _add_workload_arcsin(array_pool): OpArgMngr.add_workload('arcsin', array_pool['4x1'] - 2) + + +def _add_workload_arccos(array_pool): OpArgMngr.add_workload('arccos', array_pool['4x1'] - 2) + + +def _add_workload_arctan(array_pool): OpArgMngr.add_workload('arctan', array_pool['4x1']) + + +def _add_workload_arcsinh(array_pool): OpArgMngr.add_workload('arcsinh', array_pool['4x1']) + + +def _add_workload_arccosh(array_pool): OpArgMngr.add_workload('arccosh', array_pool['4x1']) + + +def _add_workload_arctanh(array_pool): OpArgMngr.add_workload('arctanh', array_pool['4x1'] - 2) + + +def _add_workload_ceil(array_pool): OpArgMngr.add_workload('ceil', array_pool['4x1']) + + +def _add_workload_turnc(array_pool): OpArgMngr.add_workload('trunc', array_pool['4x1']) + + +def _add_workload_floor(array_pool): OpArgMngr.add_workload('floor', array_pool['4x1']) + + +def _add_workload_logical_not(array_pool): OpArgMngr.add_workload('logical_not', np.ones(10, dtype=np.int32)) OpArgMngr.add_workload('logical_not', array_pool['4x1']) OpArgMngr.add_workload('logical_not', np.array([True, False, True, False], dtype=np.bool)) - +@use_np +def _prepare_workloads(): + array_pool = { + '4x1': np.random.uniform(size=(4, 1)) + 2, + '1x2': np.random.uniform(size=(1, 2)) + 2, + '1x1x0': np.array([[[]]]) + } + + _add_workload_argmax() + _add_workload_around() + _add_workload_broadcast_arrays(array_pool) + _add_workload_broadcast_to() + _add_workload_clip() + _add_workload_concatenate(array_pool) + _add_workload_copy(array_pool) + _add_workload_cumsum() + _add_workload_ravel() + _add_workload_dot() + _add_workload_expand_dims(array_pool) + _add_workload_fix() + _add_workload_flip() + _add_workload_max(array_pool) + _add_workload_min(array_pool) + _add_workload_mean(array_pool) + _add_workload_ones_like(array_pool) + _add_workload_prod(array_pool) + _add_workload_repeat(array_pool) + _add_workload_reshape() + _add_workload_rint(array_pool) + _add_workload_roll() + _add_workload_split(array_pool) + _add_workload_squeeze(array_pool) + _add_workload_stack(array_pool) + _add_workload_std(array_pool) + _add_workload_sum() + _add_workload_swapaxes(array_pool) + _add_workload_take() + _add_workload_tensordot(array_pool) + _add_workload_tile(array_pool) + _add_workload_transpose(array_pool) + _add_workload_unique() + _add_workload_var(array_pool) + _add_workload_zeros_like(array_pool) + _add_workload_outer() + _add_workload_meshgrid() + _add_workload_einsum() + _add_workload_abs() + _add_workload_add(array_pool) + _add_workload_arctan2() + _add_workload_copysign() + _add_workload_degrees() + _add_workload_true_divide() + _add_workload_inner() + _add_workload_hypot() + _add_workload_lcm() + _add_workload_ldexp() + _add_workload_subtract(array_pool) + _add_workload_multiply(array_pool) + _add_workload_power(array_pool) + _add_workload_mod(array_pool) + _add_workload_remainder() + _add_workload_maximum(array_pool) + _add_workload_minimum(array_pool) + _add_workload_negative(array_pool) + _add_workload_absolute(array_pool) + _add_workload_sign(array_pool) + _add_workload_exp(array_pool) + _add_workload_log(array_pool) + _add_workload_log2(array_pool) + _add_workload_log1p() + _add_workload_log10(array_pool) + _add_workload_expm1(array_pool) + _add_workload_sqrt() + _add_workload_square() + _add_workload_cbrt() + _add_workload_reciprocal() + _add_workload_sin(array_pool) + _add_workload_cos(array_pool) + _add_workload_tan(array_pool) + _add_workload_sinh(array_pool) + _add_workload_cosh(array_pool) + _add_workload_tanh(array_pool) + _add_workload_arcsin(array_pool) + _add_workload_arccos(array_pool) + _add_workload_arctan(array_pool) + _add_workload_arcsinh(array_pool) + _add_workload_arccosh(array_pool) + _add_workload_arctanh(array_pool) + _add_workload_ceil(array_pool) + _add_workload_turnc(array_pool) + _add_workload_floor(array_pool) + _add_workload_logical_not(array_pool) + _prepare_workloads() @@ -416,6 +906,7 @@ def _check_interoperability_helper(op_name, *args, **kwargs): def check_interoperability(op_list): for name in op_list: + print('Dispatch test:', name) workloads = OpArgMngr.get_workloads(name) assert workloads is not None, 'Workloads for operator `{}` has not been ' \ 'added for checking interoperability with ' \