diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 8dceafd4875c..b2a2a95952ee 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -48,7 +48,7 @@ from ..context import current_context from ..ndarray import numpy as _mx_nd_np from ..ndarray.numpy import _internal as _npi -from ..ndarray.ndarray import _storage_type, from_numpy +from ..ndarray.ndarray import _storage_type from .utils import _get_np_op from .fallback import * # pylint: disable=wildcard-import,unused-wildcard-import from . import fallback @@ -148,10 +148,10 @@ def _reshape_view(a, *shape): # pylint: disable=redefined-outer-name def _as_mx_np_array(object, ctx=None): """Convert object to mxnet.numpy.ndarray.""" - if isinstance(object, _np.ndarray): - if not object.flags['C_CONTIGUOUS']: - object = _np.ascontiguousarray(object, dtype=object.dtype) - ret = from_numpy(object) + if isinstance(object, ndarray): + return object + elif isinstance(object, _np.ndarray): + ret = array(object, dtype=object.dtype, ctx=ctx) return ret if ctx is None else ret.as_in_ctx(ctx=ctx) elif isinstance(object, (integer_types, numeric_types)): return object @@ -358,11 +358,17 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad- out = func(*new_args, **new_kwargs) return _as_mx_np_array(out, ctx=cur_ctx) else: - # Note: this allows subclasses that don't override - # __array_function__ to handle mxnet.numpy.ndarray objects - if not py_all(issubclass(t, ndarray) for t in types): - return NotImplemented - return mx_np_func(*args, **kwargs) + if py_all(issubclass(t, ndarray) for t in types): + return mx_np_func(*args, **kwargs) + else: + try: + cur_ctx = next(a.ctx for a in args if hasattr(a, 'ctx')) + except StopIteration: + cur_ctx = next(a.ctx for a in kwargs.values() if hasattr(a, 'ctx')) + new_args = _as_mx_np_array(args, ctx=cur_ctx) + new_kwargs = {k: _as_mx_np_array(v, cur_ctx) for k, v in kwargs.items()} + return mx_np_func(*new_args, **new_kwargs) + def _get_np_basic_indexing(self, key): """ diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index a8ec5dcf998b..c109023870c2 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -1385,6 +1385,7 @@ def test_from_numpy(np_array, zero_copy): mx_array = mx.npx.from_numpy(np_array, zero_copy=zero_copy) mx.test_utils.assert_almost_equal(np_array, mx_array.asnumpy()) +@use_np def test_from_numpy_exception(): np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32") mx_array = mx.npx.from_numpy(np_array) @@ -1397,3 +1398,9 @@ def test_from_numpy_exception(): assert not np_array.flags["C_CONTIGUOUS"] with pytest.raises(ValueError): mx_array = mx.nd.from_numpy(np_array) + +@use_np +def test_mixed_array_types(): + np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32") + mx_array = mx.np.ones((3, 1)) + assert_almost_equal(mx_array + np_array, 1+np_array)