Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix type casting for binary operators mixing mxnet.numpy and official NumPy ndarrays (#18313)
Browse files Browse the repository at this point in the history
  • Loading branch information
yijunc committed May 19, 2020
1 parent 3140c55 commit b214477
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 1 deletion.
77 changes: 76 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from builtins import slice as py_slice

from array import array as native_array
import functools
import ctypes
import warnings
import numpy as _np
Expand Down Expand Up @@ -203,6 +204,26 @@ def _np_ndarray_cls(handle, writable=True, stype=0):
_FALLBACK_ARRAY_FUNCTION_WARNED_RECORD = {}
_FALLBACK_ARRAY_UFUNC_WARNED_RECORD = {}

def wrap_mxnp_np_ufunc(func):
    """Decorator for Python overloadable binary operators that casts a mixed
    operand pair of mx_np and official-NumPy arrays to mx_np.

    Parameters
    ----------
    func : callable
        A binary operator method ``func(x1, x2)`` defined on
        ``mxnet.numpy.ndarray`` (``x1`` is always an mxnet ndarray).

    Returns
    -------
    callable
        A wrapper around ``func`` that first converts an official NumPy
        ``x2`` operand to an mxnet ndarray on ``x1``'s context.
    """
    @functools.wraps(func)
    def _wrap_mxnp_np_ufunc(x1, x2):
        # Only the right-hand operand may need casting: this decorator is
        # applied to ndarray methods, so x1 is already an mxnet ndarray.
        if isinstance(x2, _np.ndarray):
            x2 = _as_mx_np_array(x2, ctx=x1.ctx)
        return func(x1, x2)
    return _wrap_mxnp_np_ufunc

@set_module('mxnet.numpy') # pylint: disable=invalid-name
class ndarray(NDArray):
Expand Down Expand Up @@ -256,7 +277,17 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # pylint: disable=
Dispatch official NumPy unary/binary operator calls on mxnet.numpy.ndarray
to this function. The operators must comply with the ufunc definition in NumPy.
The following code is adapted from CuPy.
Casting rules for operator with mx_np and onp (inplace op will keep its type)
| Expression | a type | b type | out type|
| --- | --- | --- | --- |
| `a += b` | onp | mx_np | onp |
| `a += b` | mx_np | onp | mx_np |
| `c = a + b` | onp | mx_np | mx_np |
| `c = a + b` | mx_np | onp | mx_np |
"""
ufunc_list = ["add", "subtract", "multiply", "divide", "true_divide", "floor_divide", "power",
"remainder", "bitwise_and", "bitwise_or", "bitwise_xor", "left_shift", "right_shift",
"greater", "greater_equal", "less", "less_equal", "not_equal", "equal", "matmul"]
if 'out' in kwargs:
# need to unfold tuple argument in kwargs
out = kwargs['out']
Expand All @@ -267,13 +298,13 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # pylint: disable=
if method == '__call__':
name = ufunc.__name__
mx_ufunc = _NUMPY_ARRAY_UFUNC_DICT.get(name, None)
onp_op = _get_np_op(name)
if mx_ufunc is None:
# try to fallback to official NumPy op
if is_recording():
raise ValueError("Falling back to NumPy operator {} with autograd active is not supported."
"Please consider moving the operator to the outside of the autograd scope.")\
.format(name)
onp_op = _get_np_op(name)
new_inputs = [arg.asnumpy() if isinstance(arg, ndarray) else arg for arg in inputs]
if onp_op not in _FALLBACK_ARRAY_UFUNC_WARNED_RECORD:
import logging
Expand All @@ -282,6 +313,16 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # pylint: disable=
_FALLBACK_ARRAY_UFUNC_WARNED_RECORD[onp_op] = True
out = onp_op(*new_inputs, **kwargs)
return _as_mx_np_array(out, ctx=inputs[0].ctx)
# ops with np mx_np
elif name in ufunc_list and isinstance(inputs[0], _np.ndarray):
# inplace
if 'out' in kwargs:
new_inputs = [arg.asnumpy() if isinstance(arg, ndarray) else arg for arg in inputs]
return onp_op(*new_inputs, **kwargs)
else:
new_inputs = [_as_mx_np_array(arg, ctx=inputs[1].ctx)
if isinstance(arg, _np.ndarray) else arg for arg in inputs]
return mx_ufunc(*new_inputs, **kwargs)
else:
return mx_ufunc(*inputs, **kwargs)
else:
Expand Down Expand Up @@ -854,10 +895,12 @@ def _prepare_value_nd(self, value, bcast_shape, squeeze_axes=None):
value_nd = value_nd.broadcast_to(bcast_shape)
return value_nd

@wrap_mxnp_np_ufunc
def __add__(self, other):
    """Elementwise addition: ``x + y``."""
    out = add(self, other)
    return out

@wrap_mxnp_np_ufunc
def __iadd__(self, other):
"""x.__iadd__(y) <=> x += y"""
if not self.writable:
Expand All @@ -868,26 +911,32 @@ def __invert__(self):
"""x.__invert__() <=> ~x"""
return invert(self)

@wrap_mxnp_np_ufunc
def __and__(self, other):
    """Elementwise bitwise AND: ``x & y``."""
    out = bitwise_and(self, other)
    return out

@wrap_mxnp_np_ufunc
def __or__(self, other):
    """Elementwise bitwise OR: ``x | y``."""
    out = bitwise_or(self, other)
    return out

@wrap_mxnp_np_ufunc
def __xor__(self, other):
    """Elementwise bitwise XOR: ``x ^ y``."""
    out = bitwise_xor(self, other)
    return out

@wrap_mxnp_np_ufunc
def __iand__(self, other):
    """In-place bitwise AND: ``x &= y``."""
    result = bitwise_and(self, other, out=self)
    return result

@wrap_mxnp_np_ufunc
def __ior__(self, other):
    """In-place bitwise OR: ``x |= y``."""
    result = bitwise_or(self, other, out=self)
    return result

@wrap_mxnp_np_ufunc
def __ixor__(self, other):
    """In-place bitwise XOR: ``x ^= y``."""
    result = bitwise_xor(self, other, out=self)
    return result
Expand All @@ -912,116 +961,142 @@ def __trunc__(self):
"""x.__trunc__()"""
return trunc(self)

@wrap_mxnp_np_ufunc
def __sub__(self, other):
    """Elementwise subtraction: ``x - y``."""
    out = subtract(self, other)
    return out

@wrap_mxnp_np_ufunc
def __isub__(self, other):
    """In-place subtraction: ``x -= y``.

    Raises ValueError when this array is read-only.
    """
    if self.writable:
        return subtract(self, other, out=self)
    raise ValueError('trying to subtract from a readonly ndarray')

@wrap_mxnp_np_ufunc
def __rsub__(self, other):
    """Reflected subtraction: ``y - x``."""
    out = subtract(other, self)
    return out

@wrap_mxnp_np_ufunc
def __mul__(self, other):
    """Elementwise multiplication: ``x * y``."""
    out = multiply(self, other)
    return out

def __neg__(self):
    """Elementwise negation: ``-x`` (implemented as multiplication by -1)."""
    return self.__mul__(-1.0)

@wrap_mxnp_np_ufunc
def __imul__(self, other):
    """In-place multiplication: ``x *= y``.

    Raises
    ------
    ValueError
        If this array is read-only.
    """
    if not self.writable:
        # Fixed copy-paste bug: the message previously said 'add', but this
        # operator performs multiplication.
        raise ValueError('trying to multiply to a readonly ndarray')
    return multiply(self, other, out=self)

@wrap_mxnp_np_ufunc
def __rmul__(self, other):
    """Reflected multiplication: ``y * x`` (delegates to ``__mul__``)."""
    return self.__mul__(other)

@wrap_mxnp_np_ufunc
def __div__(self, other):
    """Elementwise division (Python 2 protocol): ``x / y``."""
    out = divide(self, other)
    return out

@wrap_mxnp_np_ufunc
def __rdiv__(self, other):
    """Reflected division (Python 2 protocol): ``y / x``."""
    out = divide(other, self)
    return out

@wrap_mxnp_np_ufunc
def __idiv__(self, other):
    """In-place division (Python 2 protocol): ``x /= y``."""
    result = divide(self, other, out=self)
    return result

@wrap_mxnp_np_ufunc
def __truediv__(self, other):
    """Elementwise true division: ``x / y``."""
    out = divide(self, other)
    return out

@wrap_mxnp_np_ufunc
def __rtruediv__(self, other):
    """Reflected true division: ``y / x``."""
    out = divide(other, self)
    return out

@wrap_mxnp_np_ufunc
def __itruediv__(self, other):
    """In-place true division: ``x /= y``."""
    result = divide(self, other, out=self)
    return result

@wrap_mxnp_np_ufunc
def __mod__(self, other):
    """Elementwise modulo: ``x % y``."""
    out = mod(self, other)
    return out

@wrap_mxnp_np_ufunc
def __rmod__(self, other):
    """Reflected modulo: ``y % x``."""
    out = mod(other, self)
    return out

@wrap_mxnp_np_ufunc
def __imod__(self, other):
    """In-place modulo: ``x %= y``."""
    result = mod(self, other, out=self)
    return result

@wrap_mxnp_np_ufunc
def __pow__(self, other):
    """Elementwise exponentiation: ``x ** y``."""
    out = power(self, other)
    return out

@wrap_mxnp_np_ufunc
def __rpow__(self, other):
    """Reflected exponentiation: ``y ** x``."""
    out = power(other, self)
    return out

@wrap_mxnp_np_ufunc
def __eq__(self, other):
    """Elementwise equality comparison: ``x == y``."""
    out = equal(self, other)
    return out

def __hash__(self):
    """Hashing is unsupported: ndarrays are mutable containers."""
    raise NotImplementedError

@wrap_mxnp_np_ufunc
def __ne__(self, other):
    """Elementwise inequality comparison: ``x != y``."""
    out = not_equal(self, other)
    return out

@wrap_mxnp_np_ufunc
def __gt__(self, other):
    """Elementwise greater-than comparison: ``x > y``."""
    out = greater(self, other)
    return out

@wrap_mxnp_np_ufunc
def __ge__(self, other):
    """Elementwise greater-or-equal comparison: ``x >= y``."""
    out = greater_equal(self, other)
    return out

@wrap_mxnp_np_ufunc
def __lt__(self, other):
    """Elementwise less-than comparison: ``x < y``."""
    out = less(self, other)
    return out

@wrap_mxnp_np_ufunc
def __le__(self, other):
    """Elementwise less-or-equal comparison: ``x <= y``."""
    out = less_equal(self, other)
    return out

@wrap_mxnp_np_ufunc
def __matmul__(self, other):
    """Matrix multiplication: ``x @ y``."""
    out = matmul(self, other)
    return out

@wrap_mxnp_np_ufunc
def __rmatmul__(self, other):
    """Reflected matrix multiplication: ``y @ x``."""
    out = matmul(other, self)
    return out

@wrap_mxnp_np_ufunc
def __imatmul__(self, other):
    """In-place matrix multiplication: ``x @= y``."""
    result = matmul(self, other, out=self)
    return result
Expand Down
40 changes: 40 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2659,6 +2659,46 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
continue
check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)

@with_seed()
@use_np
def test_np_mixed_mxnp_op_funcs():
    """Verify result types of binary ops mixing official NumPy and mx.np arrays.

    In-place ops keep the left operand's type; out-of-place mixed ops
    return mx.np ndarrays regardless of operand order.
    """
    def check_is_onp(arr):
        # result must remain an official NumPy array
        assert isinstance(arr, _np.ndarray)

    def check_is_mxnp(arr):
        # result must be an mxnet.numpy array
        assert isinstance(arr, mx.np.ndarray)

    # generate onp & mx_np in same type
    onp = _np.array([1, 2, 3, 4, 5]).astype("int64")
    mx_np = mx.np.array([1, 2, 3, 4, 5]).astype("int64")
    # in-place with onp on the left keeps onp
    onp += mx_np
    check_is_onp(onp)
    onp -= mx_np
    check_is_onp(onp)
    onp *= mx_np
    check_is_onp(onp)
    # in-place with mx_np on the left keeps mx_np
    mx_np ^= onp
    check_is_mxnp(mx_np)
    mx_np |= onp
    check_is_mxnp(mx_np)
    mx_np &= onp
    check_is_mxnp(mx_np)
    # out-of-place: mx_np (op) onp -> mx_np
    check_is_mxnp(mx_np << onp)
    check_is_mxnp(mx_np >> onp)
    check_is_mxnp(mx_np != onp)
    # out-of-place: onp (op) mx_np -> mx_np
    check_is_mxnp(onp == mx_np)
    check_is_mxnp(onp >= mx_np)
    check_is_mxnp(onp < mx_np)
    # matmul and true division require floating point inputs
    onp = _np.array([1, 2, 3, 4, 5]).astype("float32")
    mx_np = mx.np.array([1, 2, 3, 4, 5]).astype("float32")
    check_is_mxnp(onp @ mx_np)
    check_is_mxnp(onp / mx_np)

@with_seed()
@use_np
Expand Down

0 comments on commit b214477

Please sign in to comment.