Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Numpy dispatch test of ...... #16422

Merged
merged 3 commits into from
Oct 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal',
'hsplit', 'rot90', 'einsum']
'hsplit', 'rot90', 'einsum', 'true_divide']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -337,10 +337,10 @@ def take(a, indices, axis=None, mode='raise', out=None):
if mode not in ('wrap', 'clip', 'raise'):
raise NotImplementedError(
"function take does not support mode '{}'".format(mode))
if axis:
return _npi.take(a, indices, axis, mode, out)
else:
if axis is None:
return _npi.take(_npi.reshape(a, -1), indices, 0, mode, out)
else:
return _npi.take(a, indices, axis, mode, out)
# pylint: enable=redefined-outer-name


Expand Down Expand Up @@ -495,7 +495,11 @@ def unique(ar, return_index=False, return_inverse=False, return_counts=False, ax
>>> u[indices]
array([1., 2., 6., 4., 2., 3., 2.])
"""
return _npi.unique(ar, return_index, return_inverse, return_counts, axis)
ret = _npi.unique(ar, return_index, return_inverse, return_counts, axis)
if isinstance(ret, list):
return tuple(ret)
else:
return ret


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -604,6 +608,36 @@ def divide(x1, x2, out=None, **kwargs):
_npi.rtrue_divide_scalar, out)


@set_module('mxnet.ndarray.numpy')
def true_divide(x1, x2, out=None):
"""Returns a true division of the inputs, element-wise.

Instead of the Python traditional 'floor division', this returns a true
division. True division adjusts the output type to present the best
answer, regardless of input types.

Parameters
----------
x1 : ndarray or scalar
Dividend array.

x2 : ndarray or scalar
Divisor array.

out : ndarray
A location into which the result is stored. If provided, it must have a shape
that the inputs broadcast to. If not provided or None, a freshly-allocated array
is returned.

Returns
-------
out : ndarray or scalar
This is a scalar if both x1 and x2 are scalars.
"""
return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar,
_npi.rtrue_divide_scalar, out)


@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def mod(x1, x2, out=None, **kwargs):
Expand Down
31 changes: 30 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming',
'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril',
'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum']
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -2216,6 +2216,35 @@ def divide(x1, x2, out=None, **kwargs):
return _mx_nd_np.divide(x1, x2, out=out)


@set_module('mxnet.numpy')
def true_divide(x1, x2, out=None):
"""Returns a true division of the inputs, element-wise.

Instead of the Python traditional 'floor division', this returns a true
division. True division adjusts the output type to present the best
answer, regardless of input types.

Parameters
----------
x1 : ndarray or scalar
Dividend array.

x2 : ndarray or scalar
Divisor array.

out : ndarray
A location into which the result is stored. If provided, it must have a shape
that the inputs broadcast to. If not provided or None, a freshly-allocated array
is returned.

Returns
-------
out : ndarray or scalar
This is a scalar if both x1 and x2 are scalars.
"""
return _mx_nd_np.true_divide(x1, x2, out=out)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def mod(x1, x2, out=None, **kwargs):
Expand Down
4 changes: 3 additions & 1 deletion python/mxnet/numpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import numpy as onp

__all__ = ['float16', 'float32', 'float64', 'uint8', 'int32', 'int8', 'int64',
'bool', 'bool_', 'pi', 'inf', 'nan']
'bool', 'bool_', 'pi', 'inf', 'nan', 'PZERO', 'NZERO']

float16 = onp.float16
float32 = onp.float32
Expand All @@ -38,3 +38,5 @@
pi = onp.pi
inf = onp.inf
nan = onp.nan
PZERO = onp.PZERO
NZERO = onp.NZERO
15 changes: 13 additions & 2 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):

_NUMPY_ARRAY_FUNCTION_LIST = [
'argmax',
'around',
'broadcast_arrays',
'broadcast_to',
'clip',
Expand All @@ -93,6 +94,8 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'dot',
'expand_dims',
'fix',
'flip',
'inner',
'max',
'mean',
'min',
Expand All @@ -108,9 +111,11 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'std',
'sum',
'swapaxes',
'take',
'tensordot',
'tile',
'transpose',
'unique',
'var',
'zeros_like',
'meshgrid',
Expand Down Expand Up @@ -161,11 +166,17 @@ def _register_array_function():

# https://docs.scipy.org/doc/numpy/reference/ufuncs.html#available-ufuncs
_NUMPY_ARRAY_UFUNC_LIST = [
'abs',
'add',
'arctan2',
'copysign',
'degrees',
'hypot',
'lcm',
# 'ldexp',
'subtract',
'multiply',
# Uncomment divide when mxnet.numpy.true_divide is added
# 'divide',
'true_divide',
'negative',
'power',
'mod',
Expand Down
14 changes: 10 additions & 4 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
'less_equal', 'hsplit', 'rot90', 'einsum']
'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide']


def _num_outputs(sym):
Expand Down Expand Up @@ -1082,10 +1082,10 @@ def take(a, indices, axis=None, mode='raise', out=None):
if mode not in ('wrap', 'clip', 'raise'):
raise NotImplementedError(
"function take does not support mode '{}'".format(mode))
if axis:
return _npi.take(a, indices, axis, mode, out)
else:
if axis is None:
return _npi.take(_npi.reshape(a, -1), indices, 0, mode, out)
else:
return _npi.take(a, indices, axis, mode, out)
# pylint: enable=redefined-outer-name


Expand Down Expand Up @@ -1164,6 +1164,12 @@ def divide(x1, x2, out=None, **kwargs):
_npi.rtrue_divide_scalar, out)


@set_module('mxnet.ndarray.numpy')
def true_divide(x1, x2, out=None):
return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar,
_npi.rtrue_divide_scalar, out)


@set_module('mxnet.symbol.numpy')
@wrap_np_binary_func
def mod(x1, x2, out=None, **kwargs):
Expand Down
Loading