From a613d039f0b44afb877643f0ead21c0bba2a11e5 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Sun, 14 Jun 2020 21:23:02 -0700 Subject: [PATCH] allow mixed types in array func protocol --- python/mxnet/ndarray/ndarray.py | 1 + python/mxnet/numpy/multiarray.py | 27 ++++++++++++--------- tests/python/unittest/test_numpy_ndarray.py | 6 +++++ 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 9cc8b8942c1d..28e251db153d 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -67,6 +67,7 @@ np.float16: 2, np.uint8: 3, np.int32: 4, + np.intc: 4, np.int8: 5, np.int64: 6, np.bool_: 7, diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index f6b4c8081645..1f6d6f121e22 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -48,7 +48,7 @@ from ..context import current_context from ..ndarray import numpy as _mx_nd_np from ..ndarray.numpy import _internal as _npi -from ..ndarray.ndarray import _storage_type, from_numpy +from ..ndarray.ndarray import _storage_type from .utils import _get_np_op from .fallback import * # pylint: disable=wildcard-import,unused-wildcard-import from . import fallback @@ -148,11 +148,10 @@ def _reshape_view(a, *shape): # pylint: disable=redefined-outer-name def _as_mx_np_array(object, ctx=None): """Convert object to mxnet.numpy.ndarray.""" - if isinstance(object, _np.ndarray): - if not object.flags['C_CONTIGUOUS']: - object = _np.ascontiguousarray(object, dtype=object.dtype) - ret = from_numpy(object, array_cls=ndarray) - return ret if ctx is None else ret.as_in_ctx(ctx=ctx) + if isinstance(object, ndarray): + return object + elif isinstance(object, _np.ndarray): + return array(object, dtype=object.dtype, ctx=ctx) elif isinstance(object, (integer_types, numeric_types)): return object elif isinstance(object, (list, tuple)): @@ -358,11 +357,17 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad- out = func(*new_args, **new_kwargs) return _as_mx_np_array(out, ctx=cur_ctx) else: - # Note: this allows subclasses that don't override - # __array_function__ to handle mxnet.numpy.ndarray objects - if not py_all(issubclass(t, ndarray) for t in types): - return NotImplemented - return mx_np_func(*args, **kwargs) + if py_all(issubclass(t, ndarray) for t in types): + return mx_np_func(*args, **kwargs) + else: + try: + cur_ctx = next(a.ctx for a in args if hasattr(a, 'ctx')) + except StopIteration: + cur_ctx = next(a.ctx for a in kwargs.values() if hasattr(a, 'ctx')) + new_args = _as_mx_np_array(args, ctx=cur_ctx) + new_kwargs = {k: _as_mx_np_array(v, cur_ctx) for k, v in kwargs.items()} + return mx_np_func(*new_args, **new_kwargs) + def _get_np_basic_indexing(self, key): """ diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 966b26d7e2d2..257e94e35ec3 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -1369,3 +1369,9 @@ def test_dlpack(dtype, size): same(a_np+1, b) same(a_np+2, c) same(a_np+2, a_copy) + +@use_np +def test_mixed_array_types(): + np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32") + mx_array = mx.np.ones((3, 1)) + assert_almost_equal(mx_array + np_array, 1+np_array)