diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 5274408e4403..b61686738391 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -50,6 +50,7 @@ from ..ndarray import numpy as _mx_nd_np from ..ndarray.numpy import _internal as _npi from ..ndarray.ndarray import _storage_type +from ..dlpack import ndarray_from_numpy from .utils import _get_np_op from .fallback import * # pylint: disable=wildcard-import,unused-wildcard-import from . import fallback @@ -179,21 +180,20 @@ def _reshape_view(a, *shape): # pylint: disable=redefined-outer-name ctypes.byref(handle))) return ndarray(handle=handle, writable=a.writable) - -def _as_mx_np_array(object, ctx=None): - """Convert object to mxnet.numpy.ndarray.""" - if isinstance(object, ndarray): +def _as_mx_np_array(object, ctx=None, zero_copy=False): + """Convert arrays or any array member of container to mxnet.numpy.ndarray on ctx.""" + if object is None or isinstance(object, ndarray): return object elif isinstance(object, _np.ndarray): - np_dtype = _np.dtype(object.dtype).type - return array(object, dtype=np_dtype, ctx=ctx) + from_numpy = ndarray_from_numpy(ndarray, array) + return from_numpy(object, zero_copy and object.flags['C_CONTIGUOUS']) elif isinstance(object, (integer_types, numeric_types)): return object - elif isinstance(object, (list, tuple)): - tmp = [_as_mx_np_array(arr) for arr in object] - return object.__class__(tmp) elif isinstance(object, (_np.bool_, _np.bool)): return array(object, dtype=_np.bool_, ctx=ctx) + elif isinstance(object, (list, tuple)): + tmp = [_as_mx_np_array(arr, ctx=ctx, zero_copy=zero_copy) for arr in object] + return object.__class__(tmp) else: raise TypeError('Does not support converting {} to mx.np.ndarray.'.format(str(type(object)))) @@ -392,11 +392,18 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad- out = func(*new_args, **new_kwargs) return _as_mx_np_array(out, ctx=cur_ctx) else: - # Note: this allows subclasses that don't override - # __array_function__ to handle mxnet.numpy.ndarray objects - if not py_all(issubclass(t, ndarray) for t in types): - return NotImplemented - return mx_np_func(*args, **kwargs) + if py_all(issubclass(t, ndarray) for t in types): + return mx_np_func(*args, **kwargs) + else: + try: + cur_ctx = next(a.ctx for a in args if hasattr(a, 'ctx')) + except StopIteration: + cur_ctx = next(a.ctx for a in kwargs.values() if hasattr(a, 'ctx')) + new_args = _as_mx_np_array(args, ctx=cur_ctx, + zero_copy=func_name in {'may_share_memory', 'shares_memory'}) + new_kwargs = {k: _as_mx_np_array(v, cur_ctx) for k, v in kwargs.items()} + return mx_np_func(*new_args, **new_kwargs) + def _get_np_basic_indexing(self, key): """ diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index f1cd9b38621b..74f6af33d4ad 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -1401,3 +1401,23 @@ def test_from_numpy_exception(): np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32") mx_array = mx.npx.from_numpy(np_array, zero_copy=False) np_array[2, 1] = 0 # no error + +def test_mixed_array_types(): + np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32") + mx_array = mx.np.ones((3, 1)) + assert_almost_equal(mx_array + np_array, 1+np_array) + +def test_mixed_array_types_share_memory(): + np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32") + mx_array = mx.npx.from_numpy(np_array) + assert _np.may_share_memory(np_array, mx_array) + assert _np.shares_memory(np_array, mx_array) + + np_array_slice = np_array[:2] + mx_array_slice = mx_array[1:] + assert _np.may_share_memory(np_array_slice, mx_array) + assert _np.shares_memory(np_array_slice, mx_array) + + mx_pinned_array = mx_array.as_in_ctx(mx.cpu_pinned(0)) + assert not _np.may_share_memory(np_array, mx_pinned_array) + assert not _np.shares_memory(np_array, mx_pinned_array)