From ab6a3f9c4ad0f64899adc81a7e9d0b8696445950 Mon Sep 17 00:00:00 2001 From: NathanYyc <39988193+NathanYyc@users.noreply.github.com> Date: Thu, 14 Oct 2021 11:05:45 +0800 Subject: [PATCH] [API Standardization]Standardize MXNet NumPy Statistical & Linalg Functions (#20592) * [Website] Fix website publish (#20573) * fix website publish * update * remove .asf.yaml from version/master * force include .asf.yaml * include .htaccess * add .asf.yaml check in CI * change linalg & statical funcs * add vecdot * changes made * changes made * changes made * changes made * delete test vecdot * fixed lint add radd rand ror rxor * fixed lint error * fixed lint error * fixed problems * delete 'vecdot' in __all__ * fixed acosh doc * fixed tensordot bug add vecdot notes * add line in line 58 * add line in line 4254 * add line in 5423,9080 in multiarray add line in 260 in test_numpy_op * Update python/mxnet/numpy/multiarray.py Co-authored-by: Zhenghui Jin <69359374+barry-jin@users.noreply.github.com> * solve typo * add wrap_data_api_linalg_func in line 1335 & 1205 Co-authored-by: Zhenghui Jin <69359374+barry-jin@users.noreply.github.com> --- python/mxnet/numpy/linalg.py | 72 ++- python/mxnet/numpy/multiarray.py | 634 ++++++++++++++++++++++++- python/mxnet/util.py | 48 ++ tests/python/unittest/test_numpy_op.py | 58 +-- 4 files changed, 754 insertions(+), 58 deletions(-) diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py index ea4c9d5e3d0a..a9c0f9b38313 100644 --- a/python/mxnet/numpy/linalg.py +++ b/python/mxnet/numpy/linalg.py @@ -18,12 +18,13 @@ """Namespace for ops used in imperative programming.""" from ..ndarray import numpy as _mx_nd_np +from ..util import wrap_data_api_linalg_func from .fallback_linalg import * # pylint: disable=wildcard-import,unused-wildcard-import from . import fallback_linalg __all__ = ['norm', 'svd', 'cholesky', 'qr', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve', 'pinv', 'eigvals', 'eig', 'eigvalsh', 'eigh', 'lstsq', 'matrix_rank', 'cross', 'diagonal', 'outer', - 'tensordot', 'trace', 'matrix_transpose'] + 'tensordot', 'trace', 'matrix_transpose', 'vecdot'] __all__ += fallback_linalg.__all__ @@ -373,6 +374,59 @@ def outer(a, b): return _mx_nd_np.tensordot(a.flatten(), b.flatten(), 0) +def vecdot(a, b, axis=None): + r""" + Return the dot product of two vectors. + Note that `vecdot` handles multidimensional arrays differently than `dot`: + it does *not* perform a matrix product, but flattens input arguments + to 1-D vectors first. Consequently, it should only be used for vectors. + + Notes + ---------- + `vecdot` is a alias for `vdot`. It is a standard API in + https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#vecdot-x1-x2-axis-1 + instead of an official NumPy operator. + + Parameters + ---------- + a : ndarray + First argument to the dot product. + b : ndarray + Second argument to the dot product. + axis : axis over which to compute the dot product. Must be an integer on + the interval [-N, N) , where N is the rank (number of dimensions) of + the shape determined according to Broadcasting . If specified as a + negative integer, the function must determine the axis along which + to compute the dot product by counting backward from the last dimension + (where -1 refers to the last dimension). If None , the function must + compute the dot product over the last axis. Default: None . + + Returns + ------- + output : ndarray + Dot product of `a` and `b`. + + See Also + -------- + dot : Return the dot product without using the complex conjugate of the + first argument. + + Examples + -------- + Note that higher-dimensional arrays are flattened! + + >>> a = np.array([[1, 4], [5, 6]]) + >>> b = np.array([[4, 1], [2, 2]]) + >>> np.linalg.vecdot(a, b) + array(30.) + >>> np.linalg.vecdot(b, a) + array(30.) + >>> 1*4 + 4*1 + 5*2 + 6*2 + 30 + """ + return _mx_nd_np.tensordot(a.flatten(), b.flatten(), axis) + + def lstsq(a, b, rcond='warn'): r""" Return the least-squares solution to a linear matrix equation. @@ -1148,7 +1202,8 @@ def eigvals(a): return _mx_nd_np.linalg.eigvals(a) -def eigvalsh(a, UPLO='L'): +@wrap_data_api_linalg_func +def eigvalsh(a, upper=False): r""" Compute the eigenvalues real symmetric matrix. @@ -1203,6 +1258,10 @@ def eigvalsh(a, UPLO='L'): >>> LA.eigvalsh(a, UPLO='L') array([-2.87381886, 5.10144682, 6.38623114]) # in ascending order """ + if not upper: + UPLO = 'L' + else: + UPLO = 'U' return _mx_nd_np.linalg.eigvalsh(a, UPLO) @@ -1273,7 +1332,8 @@ def eig(a): return _mx_nd_np.linalg.eig(a) -def eigh(a, UPLO='L'): +@wrap_data_api_linalg_func +def eigh(a, upper=False): r""" Return the eigenvalues and eigenvectors real symmetric matrix. @@ -1329,7 +1389,7 @@ def eigh(a, UPLO='L'): >>> a = np.array([[ 6.8189726 , -3.926585 , 4.3990498 ], ... [-0.59656644, -1.9166266 , 9.54532 ], ... [ 2.1093285 , 0.19688708, -1.1634291 ]]) - >>> w, v = LA.eigh(a, UPLO='L') + >>> w, v = LA.eigh(a, upper=False) >>> w array([-2.175445 , -1.4581827, 7.3725457]) >>> v @@ -1337,4 +1397,8 @@ def eigh(a, UPLO='L'): [ 0.8242942 , 0.56326365, -0.05721384], [-0.53661287, 0.80949366, 0.23825769]]) """ + if not upper: + UPLO = 'L' + else: + UPLO = 'U' return _mx_nd_np.linalg.eigh(a, UPLO) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index c2d9db95f471..fadac105d69b 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -45,7 +45,7 @@ from ..runtime import Features from ..context import Context from ..util import set_module, wrap_np_unary_func, wrap_np_binary_func,\ - is_np_default_dtype + is_np_default_dtype, wrap_data_api_statical_func from ..context import current_context from ..ndarray import numpy as _mx_nd_np from ..ndarray.numpy import _internal as _npi @@ -58,17 +58,17 @@ __all__ = ['ndarray', 'empty', 'empty_like', 'array', 'shape', 'median', 'zeros', 'zeros_like', 'ones', 'ones_like', 'full', 'full_like', 'all', 'any', 'broadcast_to', - 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'fmod', 'power', 'bitwise_not', + 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'fmod', 'pow', 'power', 'bitwise_not', 'delete', 'trace', 'transpose', 'copy', 'moveaxis', 'reshape', 'dot', - 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'invert', - 'sqrt', 'cbrt', 'abs', 'absolute', 'fabs', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', - 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'histogram', - 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', 'argsort', - 'sort', 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', - 'array_split', 'split', 'hsplit', 'vsplit', 'dsplit', 'flatnonzero', 'tril_indices', - 'concatenate', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack', - 'average', 'mean', 'maximum', 'fmax', 'minimum', 'fmin', 'amax', 'amin', 'max', 'min', - 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'insert', + 'arctan2', 'atan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'bitwise_invert', 'invert', + 'sqrt', 'cbrt', 'abs', 'absolute', 'fabs', 'exp', 'expm1', 'arcsin', 'asin', 'arccos', 'acos', 'arctan', + 'atan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', + 'negative', 'histogram', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'asinh', + 'arccosh', 'acosh', 'arctanh', 'atanh', 'append', 'argsort', 'sort', 'tensordot', 'eye', 'linspace', + 'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'hsplit', 'vsplit', + 'dsplit', 'flatnonzero', 'tril_indices', 'concatenate', 'concat', 'stack', 'vstack', 'row_stack', + 'column_stack', 'hstack', 'dstack', 'average', 'mean', 'maximum', 'fmax', 'minimum', 'fmin', + 'amax', 'amin', 'max', 'min', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'insert', 'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman', 'logical_and', 'logical_or', 'logical_xor', 'flip', 'flipud', 'fliplr', 'around', 'round', 'round_', 'arctan2', 'hypot', @@ -1017,6 +1017,11 @@ def __iadd__(self, other): raise ValueError('trying to add to a readonly ndarray') return add(self, other, out=self) + @wrap_mxnp_np_ufunc + def __radd__(self, other): + """x.__radd__(y) <=> y + x""" + return add(other, self) + def __invert__(self): """x.__invert__() <=> ~x""" return invert(self) @@ -1026,16 +1031,31 @@ def __and__(self, other): """x.__and__(y) <=> x & y""" return bitwise_and(self, other) + @wrap_mxnp_np_ufunc + def __rand__(self, other): + """x.__rand__(y) <=> y & x""" + return bitwise_and(other, self) + @wrap_mxnp_np_ufunc def __or__(self, other): """x.__or__(y) <=> x | y""" return bitwise_or(self, other) + @wrap_mxnp_np_ufunc + def __ror__(self, other): + """x.__ror__(y) <=> y | x""" + return bitwise_or(other, self) + @wrap_mxnp_np_ufunc def __xor__(self, other): """x.__xor__(y) <=> x ^ y""" return bitwise_xor(self, other) + @wrap_mxnp_np_ufunc + def __rxor__(self, other): + """x.__rxor__(y) <=> y ^ x""" + return bitwise_xor(other, self) + @wrap_mxnp_np_ufunc def __iand__(self, other): """x.__iand__(y) <=> x &= y""" @@ -1163,6 +1183,11 @@ def __rpow__(self, other): """x.__rpow__(y) <=> y ** x""" return power(other, self) + @wrap_mxnp_np_ufunc + def __ipow__(self, other): + """x.__ipow__(y) <=> x **= y""" + return power(self, other, out=self) + @wrap_mxnp_np_ufunc def __eq__(self, other): """x.__eq__(y) <=> x == y""" @@ -2021,13 +2046,16 @@ def mean(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disa return mean(self, axis=axis, dtype=dtype, out=out, keepdims=keepdims) # pylint: disable=too-many-arguments, arguments-differ - def std(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False): + + @wrap_data_api_statical_func + def std(self, axis=None, dtype=None, out=None, correction=0, keepdims=False): """Returns the standard deviation of the array elements along given axis.""" - return std(self, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out) + return std(self, axis=axis, dtype=dtype, correction=correction, keepdims=keepdims, out=out) - def var(self, axis=None, dtype=None, out=None, ddof=0, keepdims=False): + @wrap_data_api_statical_func + def var(self, axis=None, dtype=None, out=None, correction=0, keepdims=False): """Returns the variance of the array elements, along given axis.""" - return var(self, axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims) + return var(self, axis=axis, dtype=dtype, out=out, correction=correction, keepdims=keepdims) # pylint: enable=too-many-arguments, arguments-differ def cumsum(self, axis=None, dtype=None, out=None): @@ -3597,6 +3625,7 @@ def remainder(x1, x2, out=None, **kwargs): return _mx_nd_np.remainder(x1, x2, out=out) + @set_module('mxnet.numpy') @wrap_np_binary_func def power(x1, x2, out=None, **kwargs): @@ -3647,6 +3676,61 @@ def power(x1, x2, out=None, **kwargs): """ return _mx_nd_np.power(x1, x2, out=out) +pow = power +pow.__doc_ = """ + First array elements raised to powers from second array, element-wise. + + Notes + ----- + `pow` is an alias for `power`. It is a standard API in + https://data-apis.org/array-api/latest/API_specification/elementwise_functions.html#pow-x1-x2 + instead of an official NumPy operator. + + >>> np.pow is np.power + True + + Parameters + ---------- + x1 : ndarray or scalar + The bases. + + x2 : ndarray or scalar + The exponent. + + out : ndarray + A location into which the result is stored. If provided, it must have a shape + that the inputs broadcast to. If not provided or None, a freshly-allocated array + is returned. + + Returns + ------- + out : ndarray or scalar + The bases in x1 raised to the exponents in x2. + This is a scalar if both x1 and x2 are scalars. + + Examples + -------- + >>> x1 = np.arange(6) + >>> np.pow(x1, 3) + array([ 0., 1., 8., 27., 64., 125.]) + + Raise the bases to different exponents. + + >>> x2 = np.array([1.0, 2.0, 3.0, 3.0, 2.0, 1.0]) + >>> np.pow(x1, x2) + array([ 0., 1., 8., 27., 16., 5.]) + + The effect of broadcasting. + + >>> x2 = np.array([[1, 2, 3, 3, 2, 1], [1, 2, 3, 3, 2, 1]]) + >>> x2 + array([[1., 2., 3., 3., 2., 1.], + [1., 2., 3., 3., 2., 1.]]) + + >>> np.pow(x1, x2) + array([[ 0., 1., 8., 27., 16., 5.], + [ 0., 1., 8., 27., 16., 5.]]) + """ @set_module('mxnet.numpy') @wrap_np_binary_func @@ -4230,6 +4314,66 @@ def arcsin(x, out=None, **kwargs): """ return _mx_nd_np.arcsin(x, out=out, **kwargs) +asin = arcsin +asin.__doc__ = """ + Inverse sine, element-wise. + + >>>np.asin is np.asin + True + + Parameters + ---------- + x : ndarray or scalar + `y`-coordinate on the unit circle. + out : ndarray or None, optional + A location into which the result is stored. + If provided, it must have the same shape as the input. + If not provided or None, a freshly-allocated array is returned. + + Returns + ------- + angle : ndarray or scalar + Output array is same shape and type as x. This is a scalar if x is a scalar. + The inverse sine of each element in `x`, in radians and in the + closed interval ``[-pi/2, pi/2]``. + + Examples + -------- + >>> np.asin(1) # pi/2 + 1.5707963267948966 + >>> np.asin(-1) # -pi/2 + -1.5707963267948966 + >>> np.asin(0) + 0.0 + + .. note:: + `asin` is a alias for `arcsin`. It is a standard API in + https://data-apis.org/array-api/latest/API_specification/elementwise_functions.html#asin-x + instead of an official NumPy operator. + + `asin` is a multivalued function: for each `x` there are infinitely + many numbers `z` such that :math:`sin(z) = x`. The convention is to + return the angle `z` whose real part lies in [-pi/2, pi/2]. + For real-valued input data types, *asin* always returns real output. + For each value that cannot be expressed as a real number or infinity, + it yields ``nan`` and sets the `invalid` floating point error flag. + The inverse sine is also known as `asin` or sin^{-1}. + The output `ndarray` has the same `ctx` as the input `ndarray`. + This function differs from the original `numpy.arcsin + `_ in + the following aspects: + + * Only support ndarray or scalar now. + * `where` argument is not supported. + * Complex input is not supported. + + References + ---------- + Abramowitz, M. and Stegun, I. A., *Handbook of Mathematical Functions*, + 10th printing, New York: Dover, 1964, pp. 79ff. + http://www.math.sfu.ca/~cbm/aands/ + """ + @set_module('mxnet.numpy') @wrap_np_unary_func @@ -4269,6 +4413,47 @@ def arccos(x, out=None, **kwargs): """ return _mx_nd_np.arccos(x, out=out, **kwargs) +acos = arccos +acos.__doc__ = """ + Trigonometric inverse cosine, element-wise. + The inverse of cos so that, if y = cos(x), then x = acos(y). + + >>>np.acos is np.arccos + True + + Parameters + ---------- + x : ndarray + x-coordinate on the unit circle. For real arguments, the domain is [-1, 1]. + out : ndarray, optional + A location into which the result is stored. If provided, it must have a shape that + the inputs broadcast to. If not provided or None, a freshly-allocated array is returned. + A tuple (possible only as a keyword argument) must have length equal to the number of outputs. + + Returns + ---------- + angle : ndarray + The angle of the ray intersecting the unit circle at the given x-coordinate in radians [0, pi]. + This is a scalar if x is a scalar. + + Notes + ---------- + `acos` is a alias for `arccos`. It is a standard API in + https://data-apis.org/array-api/latest/API_specification/elementwise_functions.html#acos-x + instead of an official NumPy operator. + + acos is a multivalued function: for each x there are infinitely many numbers z such that + cos(z) = x. The convention is to return the angle z whose real part lies in [0, pi]. + For real-valued input data types, acos always returns real output. + For each value that cannot be expressed as a real number or infinity, it yields nan and sets + the invalid floating point error flag. + The inverse cos is also known as acos or cos^-1. + + Examples + ---------- + >>> np.acos([1, -1]) + array([ 0. , 3.14159265]) + """ @set_module('mxnet.numpy') @wrap_np_unary_func @@ -4314,6 +4499,54 @@ def arctan(x, out=None, **kwargs): """ return _mx_nd_np.arctan(x, out=out, **kwargs) +atan = arctan +atan.__doc__ = """ + Trigonometric inverse tangent, element-wise. + The inverse of tan, so that if ``y = tan(x)`` then ``x = atan(y)``. + + >>>np.atan is np.arctan + True + + Parameters + ---------- + x : ndarray or scalar + Input values. + out : ndarray or None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + + Returns + ------- + out : ndarray or scalar + Out has the same shape as `x`. It lies is in + ``[-pi/2, pi/2]`` (``atan(+/-inf)`` returns ``+/-pi/2``). + This is a scalar if `x` is a scalar. + + Notes + ----- + `atan` is a alias for `arctan`. It is a standard API in + https://data-apis.org/array-api/latest/API_specification/elementwise_functions.html#atan-x + instead of an official NumPy operator. + + `atan` is a multi-valued function: for each `x` there are infinitely + many numbers `z` such that tan(`z`) = `x`. The convention is to return + the angle `z` whose real part lies in [-pi/2, pi/2]. + For real-valued input data types, `atan` always returns real output. + For each value that cannot be expressed as a real number or infinity, + it yields ``nan`` and sets the `invalid` floating point error flag. + For complex-valued input, we do not have support for them yet. + The inverse tangent is also known as `atan` or tan^{-1}. + + Examples + -------- + >>> x = np.array([0, 1]) + >>> np.atan(x) + array([0. , 0.7853982]) + >>> np.pi/4 + 0.7853981633974483 + """ + @set_module('mxnet.numpy') @wrap_np_unary_func @@ -4958,6 +5191,59 @@ def floor(x, out=None, **kwargs): """ return _mx_nd_np.floor(x, out=out, **kwargs) +@set_module('mxnet.numpy') +@wrap_np_unary_func +def bitwise_invert(x, out=None, **kwargs): + r""" + Compute bit-wise inversion, or bit-wise NOT, element-wise. + Computes the bit-wise NOT of the underlying binary representation of + the integers in the input arrays. This ufunc implements the C/Python + operator ``~``. + + Parameters + ---------- + x : array_like + Only integer and boolean types are handled. + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. A tuple (possible only as a + keyword argument) must have length equal to the number of outputs. + + Returns + ------- + out : ndarray or scalar + Result. + This is a scalar if `x` is a scalar. + + See Also + -------- + bitwise_and, bitwise_or, bitwise_xor + logical_not + binary_repr : + Return the binary representation of the input number as a string. + + Examples + -------- + We've seen that 13 is represented by ``00001101``. + The invert or bit-wise NOT of 13 is then: + + >>> x = np.bitwise_invert(np.array(13, dtype=np.uint8)) + >>> x + 242 + >>> np.binary_repr(x, width=8) + '11110010' + + Notes + ----- + `bitwise_not` is an alias for `invert`: + + >>> np.bitwise_not is np.invert + True + """ + return _mx_nd_np.bitwise_not(x, out=out, **kwargs) + + @set_module('mxnet.numpy') @wrap_np_unary_func def invert(x, out=None, **kwargs): @@ -5188,6 +5474,56 @@ def arcsinh(x, out=None, **kwargs): """ return _mx_nd_np.arcsinh(x, out=out, **kwargs) +asinh = arcsinh +asinh.__doc__ = """ + Inverse hyperbolic cosine, element-wise. + + >>>np.asinh is np.arcsinh + True + + Parameters + ---------- + x : ndarray or scalar + Input array. + out : ndarray or None, optional + A location into which the result is stored. + + Returns + ------- + asinh : ndarray + Array of the same shape as `x`. + This is a scalar if `x` is a scalar. + + .. note:: + `asinh` is a alias for `arcsinh`. It is a standard API in + https://data-apis.org/array-api/latest/API_specification/elementwise_functions.html#asinh-x + instead of an official NumPy operator. + + `asinh` is a multivalued function: for each `x` there are infinitely + many numbers `z` such that `sinh(z) = x`. + + For real-valued input data types, `asinh` always returns real output. + For each value that cannot be expressed as a real number or infinity, it + yields ``nan`` and sets the `invalid` floating point error flag. + + This function differs from the original numpy.arcsinh in the following aspects: + + * Do not support `where`, a parameter in numpy which indicates where to calculate. + * Do not support complex-valued input. + * Cannot cast type automatically. DType of `out` must be same as the expected one. + * Cannot broadcast automatically. Shape of `out` must be same as the expected one. + * If `x` is plain python numeric, the result won't be stored in out. + + Examples + -------- + >>> a = np.array([3.2, 5.0]) + >>> np.asinh(a) + array([1.8309381, 2.2924316]) + + >>> np.asinh(1) + 0.0 + """ + @set_module('mxnet.numpy') @wrap_np_unary_func @@ -5235,6 +5571,55 @@ def arccosh(x, out=None, **kwargs): """ return _mx_nd_np.arccosh(x, out=out, **kwargs) +acosh = arccosh +acosh.__doc__ = """ + Inverse hyperbolic cosine, element-wise. + + >>>np.acosh is np.arccosh + True + + Parameters + ---------- + x : ndarray or scalar + Input array. + out : ndarray or None, optional + A location into which the result is stored. + + Returns + ------- + acosh : ndarray + Array of the same shape as `x`. + This is a scalar if `x` is a scalar. + + .. note:: + `acosh` is a alias for `arccosh`. It is a standard API in + https://data-apis.org/array-api/latest/API_specification/elementwise_functions.html#acosh-x + instead of an official NumPy operator. + + `acosh` is a multivalued function: for each `x` there are infinitely + many numbers `z` such that `cosh(z) = x`. + + For real-valued input data types, `acosh` always returns real output. + For each value that cannot be expressed as a real number or infinity, it + yields ``nan`` and sets the `invalid` floating point error flag. + + This function differs from the original numpy.arccosh in the following aspects: + + * Do not support `where`, a parameter in numpy which indicates where to calculate. + * Do not support complex-valued input. + * Cannot cast type automatically. Dtype of `out` must be same as the expected one. + * Cannot broadcast automatically. Shape of `out` must be same as the expected one. + * If `x` is plain python numeric, the result won't be stored in out. + + Examples + -------- + >>> a = np.array([3.2, 5.0]) + >>> np.acosh(a) + array([1.8309381, 2.2924316]) + + >>> np.acosh(1) + 0.0 + """ @set_module('mxnet.numpy') @wrap_np_unary_func @@ -5282,6 +5667,56 @@ def arctanh(x, out=None, **kwargs): """ return _mx_nd_np.arctanh(x, out=out, **kwargs) +atanh = arctanh +atanh.__doc__ = """ + Inverse hyperbolic tangent, element-wise. + + >>>np.atanh is np.arctanh + True + + Parameters + ---------- + x : ndarray or scalar + Input array. + out : ndarray or None, optional + A location into which the result is stored. + + Returns + ------- + atanh : ndarray + Array of the same shape as `x`. + This is a scalar if `x` is a scalar. + + .. note:: + `atanh` is a alias for `arctanh`. It is a standard API in + https://data-apis.org/array-api/latest/API_specification/elementwise_functions.html#atanh-x + instead of an official NumPy operator. + + `atanh` is a multivalued function: for each `x` there are infinitely + many numbers `z` such that `tanh(z) = x`. + + For real-valued input data types, `atanh` always returns real output. + For each value that cannot be expressed as a real number or infinity, it + yields ``nan`` and sets the `invalid` floating point error flag. + + This function differs from the original numpy.arctanh in the following aspects: + + * Do not support `where`, a parameter in numpy which indicates where to calculate. + * Do not support complex-valued input. + * Cannot cast type automatically. Dtype of `out` must be same as the expected one. + * Cannot broadcast automatically. Shape of `out` must be same as the expected one. + * If `x` is plain python numeric, the result won't be stored in out. + + Examples + -------- + >>> a = np.array([0.0, -0.5]) + >>> np.atanh(a) + array([0., -0.54930615]) + + >>> np.atanh(1) + 0.0 + """ + @set_module('mxnet.numpy') def argsort(a, axis=-1, kind=None, order=None): @@ -6538,6 +6973,62 @@ def dsplit(ary, indices_or_sections): """ return _mx_nd_np.dsplit(ary, indices_or_sections) +@set_module('mxnet.numpy') +def concat(seq, axis=0, out=None): + """Join a sequence of arrays along an existing axis. + + Parameters + ---------- + a1, a2, ... : sequence of array_like + The arrays must have the same shape, except in the dimension + corresponding to `axis` (the first, by default). + axis : int, optional + The axis along which the arrays will be joined. If axis is None, + arrays are flattened before use. Default is 0. + out : ndarray, optional + If provided, the destination to place the result. The shape must be + correct, matching that of what concatenate would have returned if no + out argument were specified. + + Returns + ------- + res : ndarray + The concatenated array. + + Note + -------- + `concate` is a alias for `concatante`. It is a standard API in + https://data-apis.org/array-api/latest/API_specification/manipulation_functions.html#concat-arrays-axis-0 + instead of an official NumPy operator. + + See Also + -------- + split : Split array into a list of multiple sub-arrays of equal size. + hsplit : Split array into multiple sub-arrays horizontally (column wise) + vsplit : Split array into multiple sub-arrays vertically (row wise) + dsplit : Split array into multiple sub-arrays along the 3rd axis (depth). + stack : Stack a sequence of arrays along a new axis. + hstack : Stack arrays in sequence horizontally (column wise) + vstack : Stack arrays in sequence vertically (row wise) + dstack : Stack arrays in sequence depth wise (along third dimension) + + Examples + -------- + >>> a = np.array([[1, 2], [3, 4]]) + >>> b = np.array([[5, 6]]) + >>> np.concat((a, b), axis=0) + array([[1., 2.], + [3., 4.], + [5., 6.]]) + + >>> np.concat((a, b.T), axis=1) + array([[1., 2., 5.], + [3., 4., 6.]]) + + >>> np.concat((a, b), axis=None) + array([1., 2., 3., 4., 5., 6.]) + """ + return _mx_nd_np.concatenate(seq, axis=axis, out=out) @set_module('mxnet.numpy') def concatenate(seq, axis=0, out=None): @@ -7654,7 +8145,8 @@ def mean(a, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable # pylint: disable=redefined-outer-name @set_module('mxnet.numpy') -def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: disable=too-many-arguments +@wrap_data_api_statical_func +def std(a, axis=None, dtype=None, out=None, correction=0, keepdims=False): # pylint: disable=too-many-arguments """ Compute the standard deviation along the specified axis. Returns the standard deviation, a measure of the spread of a distribution, @@ -7679,10 +8171,10 @@ def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: Alternative output array in which to place the result. It must have the same shape as the expected output but the type (of the calculated values) will be cast if necessary. - ddof : int, optional + correction : int, optional Means Delta Degrees of Freedom. The divisor used in calculations - is ``N - ddof``, where ``N`` represents the number of elements. - By default `ddof` is zero. + is ``N - correction``, where ``N`` represents the number of elements. + By default `correction` is zero. keepdims : bool, optional If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, @@ -7717,7 +8209,7 @@ def std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: >>> np.std(a, dtype=np.float64) array(0.45, dtype=float64) """ - return _mx_nd_np.std(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out) + return _mx_nd_np.std(a, axis=axis, dtype=dtype, ddof=correction, keepdims=keepdims, out=out) # pylint: enable=redefined-outer-name @@ -7772,7 +8264,8 @@ def delete(arr, obj, axis=None): # pylint: disable=redefined-outer-name @set_module('mxnet.numpy') -def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: disable=too-many-arguments +@wrap_data_api_statical_func +def var(a, axis=None, dtype=None, out=None, correction=0, keepdims=False): # pylint: disable=too-many-arguments """ Compute the variance along the specified axis. Returns the variance of the array elements, a measure of the spread of a @@ -7800,10 +8293,10 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: Alternate output array in which to place the result. It must have the same shape as the expected output, but the type is cast if necessary. - ddof : int, optional + correction : int, optional "Delta Degrees of Freedom": the divisor used in the calculation is - ``N - ddof``, where ``N`` represents the number of elements. By - default `ddof` is zero. + ``N - correction``, where ``N`` represents the number of elements. By + default `correction` is zero. keepdims : bool, optional If this is set to True, the axes which are reduced are left in the result as dimensions with size one. With this option, @@ -7840,7 +8333,7 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: >>> ((1-0.55)**2 + (0.1-0.55)**2)/2 0.2025 """ - return _mx_nd_np.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out) + return _mx_nd_np.var(a, axis=axis, dtype=dtype, ddof=correction, keepdims=keepdims, out=out) # pylint: disable=redefined-outer-name @@ -8679,6 +9172,97 @@ def arctan2(x1, x2, out=None, **kwargs): """ return _mx_nd_np.arctan2(x1, x2, out=out) +atan2 = arctan2 +atan2.__doc__ = """ + Element-wise arc tangent of ``x1/x2`` choosing the quadrant correctly. + + The quadrant (i.e., branch) is chosen so that ``atan2(x1, x2)`` is + the signed angle in radians between the ray ending at the origin and + passing through the point (1,0), and the ray ending at the origin and + passing through the point (`x2`, `x1`). (Note the role reversal: the + "`y`-coordinate" is the first function parameter, the "`x`-coordinate" + is the second.) By IEEE convention, this function is defined for + `x2` = +/-0 and for either or both of `x1` and `x2` = +/-inf (see + Notes for specific values). + + This function is not defined for complex-valued arguments; for the + so-called argument of complex values, use `angle`. + + >>>np.atan2 is np.arctan2 + True + + Parameters + ---------- + x1 : ndarray or scalar + `y`-coordinates. + x2 : ndarray or scalar + `x`-coordinates. `x2` must be broadcastable to match the shape of + `x1` or vice versa. + out : ndarray or None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + + Returns + ------- + out : ndarray or scalar + Array of angles in radians, in the range ``[-pi, pi]``. This is a scalar if + `x1` and `x2` are scalars. + + .. notes:: + `atan2` is a alias for `arctan2`. It is a standard API in + https://data-apis.org/array-api/latest/API_specification/elementwise_functions.html#atan2-x1-x2 + instead of an official NumPy operator. + + *atan2* is identical to the ``atan2`` function of the underlying + C library. The following special values are defined in the C + standard: [1]_ + + +========+========+==================+ + | `x1` | `x2` | `atan2(x1,x2)` | + +========+========+==================+ + | +/- 0 | +0 | +/- 0 | + +========+========+==================+ + | +/- 0 | -0 | +/- pi | + +========+========+==================+ + | > 0 | +/-inf | +0 / +pi | + +========+========+==================+ + | < 0 | +/-inf | -0 / -pi | + +========+========+==================+ + | +/-inf | +inf | +/- (pi/4) | + +========+========+==================+ + | +/-inf | -inf | +/- (3*pi/4) | + +========+========+==================+ + + Note that +0 and -0 are distinct floating point numbers, as are +inf + and -inf. + + This function differs from the original numpy.arange in the following aspects: + + * Only support float16, float32 and float64. + + References + ---------- + .. [1] ISO/IEC standard 9899:1999, "Programming language C." + + Examples + -------- + Consider four points in different quadrants: + + >>> x = np.array([-1, +1, +1, -1]) + >>> y = np.array([-1, -1, +1, +1]) + >>> np.atan2(y, x) * 180 / np.pi + array([-135., -45., 45., 135.]) + + Note the order of the parameters. `atan2` is defined also when `x2` = 0 + and at several other special points, obtaining values in + the range ``[-pi, pi]``: + + >>> x = np.array([1, -1]) + >>> y = np.array([0, 0]) + >>> np.atan2(x, y) + array([ 1.5707964, -1.5707964]) + """ @set_module('mxnet.numpy') @wrap_np_binary_func diff --git a/python/mxnet/util.py b/python/mxnet/util.py index ea75030614be..733d4843a76a 100644 --- a/python/mxnet/util.py +++ b/python/mxnet/util.py @@ -645,6 +645,54 @@ def _wrap_np_binary_func(x1, x2, out=None, **kwargs): return func(x1, x2, out=out) return _wrap_np_binary_func +def wrap_data_api_statical_func(func): + """ + A convenience decorator for wrapping data apis standardized statical functions to provide + context keyward backward compatibility + Parameters + ---------- + func : a numpy-compatible array statical function to be wrapped for context keyward change. + Returns + ------- + Function + A function wrapped with context keyward changes. + """ + + @functools.wraps(func) + def _wrap_api_creation_func(*args, **kwargs): + if len(kwargs) != 0: + correction = kwargs.pop('ddof', None) + if correction is not None: + kwargs['correction'] = correction + return func(*args, **kwargs) + + return _wrap_api_creation_func + +def wrap_data_api_linalg_func(func): + """ + A convenience decorator for wrapping data apis standardized linalg functions to provide + context keyward backward compatibility + Parameters + ---------- + func : a numpy-compatible array linalg function to be wrapped for context keyward change. + Returns + ------- + Function + A function wrapped with context keyward changes. + """ + + @functools.wraps(func) + def _wrap_api_creation_func(*args, **kwargs): + if len(kwargs) != 0: + upper = kwargs.pop('UPLO', None) + if upper is not None: + if upper == 'U': + kwargs['upper'] = True + else: + kwargs['upper'] = False + return func(*args, **kwargs) + + return _wrap_api_creation_func # pylint: disable=exec-used def numpy_fallback(func): diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 488f1a80285d..880e617522fd 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -6918,16 +6918,16 @@ def check_eigvals(x, a_np): @use_np def test_np_linalg_eigvalsh(): class TestEigvalsh(HybridBlock): - def __init__(self, UPLO): + def __init__(self, upper): super(TestEigvalsh, self).__init__() - self._UPLO = UPLO + self._upper = upper def forward(self, a): - return np.linalg.eigvalsh(a, UPLO=self._UPLO) + return np.linalg.eigvalsh(a, upper=self._upper) - def check_eigvalsh(w, a_np, UPLO): + def check_eigvalsh(w, a_np, upper): try: - w_expected = onp.linalg.eigvalsh(a_np, UPLO) + w_expected = onp.linalg.eigvalsh(a_np, upper) except Exception as e: print("a:", a_np) print("a shape:", a_np.shape) @@ -6936,7 +6936,7 @@ def check_eigvalsh(w, a_np, UPLO): assert w.shape == w_expected.shape assert_almost_equal(w, w_expected, rtol=rtol, atol=atol) - def new_matrix_from_sym_matrix_nd(sym_a, UPLO): + def new_matrix_from_sym_matrix_nd(sym_a, upper): shape = sym_a.shape if 0 in shape: return sym_a @@ -6945,7 +6945,7 @@ def new_matrix_from_sym_matrix_nd(sym_a, UPLO): for idx in range(n): for i in range(shape[-2]): for j in range(shape[-1]): - if ((UPLO == 'U' and i > j) or (UPLO == 'L' and i < j)): + if ((upper == True and i > j) or (upper == False and i < j)): a[idx][i][j] = onp.random.uniform(-10., 10.) return a.reshape(shape) @@ -6964,12 +6964,12 @@ def new_matrix_from_sym_matrix_nd(sym_a, UPLO): (2, 3, 4, 4) ] dtypes = ['float32', 'float64', 'uint8', 'int8', 'int32', 'int64'] - UPLOs = ['L', 'U'] + uppers = [True, False] for hybridize in [True, False]: - for shape, dtype, UPLO in itertools.product(shapes, dtypes, UPLOs): + for shape, dtype, upper in itertools.product(shapes, dtypes, uppers): rtol = 1e-2 if dtype == 'float32' else 1e-3 atol = 1e-4 if dtype == 'float32' else 1e-5 - test_eigvalsh = TestEigvalsh(UPLO) + test_eigvalsh = TestEigvalsh(upper) if hybridize: test_eigvalsh.hybridize() if 0 in shape: @@ -6980,15 +6980,15 @@ def new_matrix_from_sym_matrix_nd(sym_a, UPLO): a_np = onp.array([onp.diag(onp.random.randint(1, 10, size=shape[-1])) for i in range(n)], dtype=dtype).reshape(shape) else: a_np = new_sym_matrix_with_real_eigvals_nd(shape) - a_np = new_matrix_from_sym_matrix_nd(a_np, UPLO) + a_np = new_matrix_from_sym_matrix_nd(a_np, upper) a = np.array(a_np, dtype=dtype) # check eigvalsh validity mx_out = test_eigvalsh(a) - check_eigvalsh(mx_out, a.asnumpy(), UPLO) + check_eigvalsh(mx_out, a.asnumpy(), upper) # check imperative once again mx_out = test_eigvalsh(a) - check_eigvalsh(mx_out, a.asnumpy(), UPLO) + check_eigvalsh(mx_out, a.asnumpy(), upper) @use_np @@ -7073,16 +7073,16 @@ def check_eig(w, v, a_np): @use_np def test_np_linalg_eigh(): class TestEigh(HybridBlock): - def __init__(self, UPLO): + def __init__(self, upper): super(TestEigh, self).__init__() - self._UPLO = UPLO + self.upper = uppers def forward(self, a): - return np.linalg.eigh(a, UPLO=self._UPLO) + return np.linalg.eigh(a, upper=self.upper) - def check_eigh(w, v, a_np, UPLO): + def check_eigh(w, v, a_np, upper): try: - w_expected, v_expected = onp.linalg.eigh(a_np, UPLO) + w_expected, v_expected = onp.linalg.eigh(a_np, upper) except Exception as e: print("a:", a_np) print("a shape:", a_np.shape) @@ -7093,7 +7093,7 @@ def check_eigh(w, v, a_np, UPLO): # check eigenvalues. assert_almost_equal(w, w_expected, rtol=rtol, atol=atol) # check eigenvectors. - w_shape, v_shape, a_sym_np = get_sym_matrix_nd(a_np, UPLO) + w_shape, v_shape, a_sym_np = get_sym_matrix_nd(a_np, upper) w_np = w.asnumpy() v_np = v.asnumpy() if 0 not in a_np.shape: @@ -7104,7 +7104,7 @@ def check_eigh(w, v, a_np, UPLO): for j in range(w_shape[1]): assert_almost_equal(onp.dot(a_sym_np[i], v_np[i][:, j]), w_np[i][j] * v_np[i][:, j], rtol=rtol, atol=atol) - def get_sym_matrix_nd(a_np, UPLO): + def get_sym_matrix_nd(a_np, upper): a_res_np = a_np shape = a_np.shape if 0 not in a_np.shape: @@ -7115,13 +7115,13 @@ def get_sym_matrix_nd(a_np, UPLO): for idx in range(n): for i in range(nrow): for j in range(ncol): - if ((UPLO == 'L' and i < j) or (UPLO == 'U' and i > j)): + if ((upper == False and i < j) or (upper == True and i > j)): a_res_np[idx][i][j] = a_np[idx][j][i] return (n, nrow), (n, nrow, ncol), a_res_np.reshape(shape) else : return (0, 0), (0, 0, 0), a_res_np.reshape(shape) - def new_matrix_from_sym_matrix_nd(sym_a, UPLO): + def new_matrix_from_sym_matrix_nd(sym_a, upper): shape = sym_a.shape if 0 in shape: return sym_a @@ -7130,7 +7130,7 @@ def new_matrix_from_sym_matrix_nd(sym_a, UPLO): for idx in range(n): for i in range(shape[-2]): for j in range(shape[-1]): - if ((UPLO == 'U' and i > j) or (UPLO == 'L' and i < j)): + if ((upper == True and i > j) or (upper == False and i < j)): a[idx][i][j] = onp.random.uniform(-10., 10.) return a.reshape(shape) @@ -7148,12 +7148,12 @@ def new_matrix_from_sym_matrix_nd(sym_a, UPLO): (2, 3, 4, 4) ] dtypes = ['float32', 'float64', 'uint8', 'int8', 'int32', 'int64'] - UPLOs = ['L', 'U'] + uppers = [True, False] for hybridize in [True, False]: - for shape, dtype, UPLO in itertools.product(shapes, dtypes, UPLOs): + for shape, dtype, upper in itertools.product(shapes, dtypes, uppers): rtol = 1e-2 if dtype == 'float32' else 1e-3 atol = 1e-4 if dtype == 'float32' else 1e-5 - test_eigh = TestEigh(UPLO) + test_eigh = TestEigh(upper) if hybridize: test_eigh.hybridize() if 0 in shape: @@ -7164,15 +7164,15 @@ def new_matrix_from_sym_matrix_nd(sym_a, UPLO): a_np = onp.array([onp.diag(onp.random.randint(1, 10, size=shape[-1])) for i in range(n)], dtype=dtype).reshape(shape) else: a_np = new_sym_matrix_with_real_eigvals_nd(shape) - a_np = new_matrix_from_sym_matrix_nd(a_np, UPLO) + a_np = new_matrix_from_sym_matrix_nd(a_np, upper) a = np.array(a_np, dtype=dtype) # check eigh validity w, v = test_eigh(a) - check_eigh(w, v, a.asnumpy(), UPLO) + check_eigh(w, v, a.asnumpy(), upper) # check imperative once again w, v = test_eigh(a) - check_eigh(w, v, a.asnumpy(), UPLO) + check_eigh(w, v, a.asnumpy(), upper) @use_np