From eb3ce7f34cf657496fd6699218667137e96a1ecc Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Sun, 14 Jun 2020 21:23:02 -0700 Subject: [PATCH 1/3] allow mixed types in array func protocol --- python/mxnet/numpy/multiarray.py | 16 +++++++++++----- tests/python/unittest/test_numpy_ndarray.py | 5 +++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 5274408e4403..cd4b3b8fae5c 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -392,11 +392,17 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad- out = func(*new_args, **new_kwargs) return _as_mx_np_array(out, ctx=cur_ctx) else: - # Note: this allows subclasses that don't override - # __array_function__ to handle mxnet.numpy.ndarray objects - if not py_all(issubclass(t, ndarray) for t in types): - return NotImplemented - return mx_np_func(*args, **kwargs) + if py_all(issubclass(t, ndarray) for t in types): + return mx_np_func(*args, **kwargs) + else: + try: + cur_ctx = next(a.ctx for a in args if hasattr(a, 'ctx')) + except StopIteration: + cur_ctx = next(a.ctx for a in kwargs.values() if hasattr(a, 'ctx')) + new_args = _as_mx_np_array(args, ctx=cur_ctx) + new_kwargs = {k: _as_mx_np_array(v, cur_ctx) for k, v in kwargs.items()} + return mx_np_func(*new_args, **new_kwargs) + def _get_np_basic_indexing(self, key): """ diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index f1cd9b38621b..f21ab81b8691 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -1401,3 +1401,8 @@ def test_from_numpy_exception(): np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32") mx_array = mx.npx.from_numpy(np_array, zero_copy=False) np_array[2, 1] = 0 # no error + +def test_mixed_array_types(): + np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32") + mx_array = mx.np.ones((3, 1)) + assert_almost_equal(mx_array + np_array, 1+np_array) From 266a10a56d4a81edf888805816938ffc40479e07 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Sun, 19 Jul 2020 14:42:41 -0700 Subject: [PATCH 2/3] fix #18746 --- python/mxnet/numpy/multiarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index cd4b3b8fae5c..87bfa7a06355 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -182,7 +182,7 @@ def _reshape_view(a, *shape): # pylint: disable=redefined-outer-name def _as_mx_np_array(object, ctx=None): """Convert object to mxnet.numpy.ndarray.""" - if isinstance(object, ndarray): + if object is None or isinstance(object, ndarray): return object elif isinstance(object, _np.ndarray): np_dtype = _np.dtype(object.dtype).type From fc0bc3c1b6f6ff756b1479a5d815e24441862572 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Fri, 31 Jul 2020 12:04:39 -0700 Subject: [PATCH 3/3] add support for memory share check --- python/mxnet/numpy/multiarray.py | 19 ++++++++++--------- tests/python/unittest/test_numpy_ndarray.py | 15 +++++++++++++++ 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 87bfa7a06355..b61686738391 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -50,6 +50,7 @@ from ..ndarray import numpy as _mx_nd_np from ..ndarray.numpy import _internal as _npi from ..ndarray.ndarray import _storage_type +from ..dlpack import ndarray_from_numpy from .utils import _get_np_op from .fallback import * # pylint: disable=wildcard-import,unused-wildcard-import from . import fallback @@ -179,21 +180,20 @@ def _reshape_view(a, *shape): # pylint: disable=redefined-outer-name ctypes.byref(handle))) return ndarray(handle=handle, writable=a.writable) - -def _as_mx_np_array(object, ctx=None): - """Convert object to mxnet.numpy.ndarray.""" +def _as_mx_np_array(object, ctx=None, zero_copy=False): + """Convert arrays or any array member of container to mxnet.numpy.ndarray on ctx.""" if object is None or isinstance(object, ndarray): return object elif isinstance(object, _np.ndarray): - np_dtype = _np.dtype(object.dtype).type - return array(object, dtype=np_dtype, ctx=ctx) + from_numpy = ndarray_from_numpy(ndarray, array) + return from_numpy(object, zero_copy and object.flags['C_CONTIGUOUS']) elif isinstance(object, (integer_types, numeric_types)): return object - elif isinstance(object, (list, tuple)): - tmp = [_as_mx_np_array(arr) for arr in object] - return object.__class__(tmp) elif isinstance(object, (_np.bool_, _np.bool)): return array(object, dtype=_np.bool_, ctx=ctx) + elif isinstance(object, (list, tuple)): + tmp = [_as_mx_np_array(arr, ctx=ctx, zero_copy=zero_copy) for arr in object] + return object.__class__(tmp) else: raise TypeError('Does not support converting {} to mx.np.ndarray.'.format(str(type(object)))) @@ -399,7 +399,8 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad- cur_ctx = next(a.ctx for a in args if hasattr(a, 'ctx')) except StopIteration: cur_ctx = next(a.ctx for a in kwargs.values() if hasattr(a, 'ctx')) - new_args = _as_mx_np_array(args, ctx=cur_ctx) + new_args = _as_mx_np_array(args, ctx=cur_ctx, + zero_copy=func_name in {'may_share_memory', 'shares_memory'}) new_kwargs = {k: _as_mx_np_array(v, cur_ctx) for k, v in kwargs.items()} return mx_np_func(*new_args, **new_kwargs) diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index f21ab81b8691..74f6af33d4ad 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -1406,3 +1406,18 @@ def test_mixed_array_types(): np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32") mx_array = mx.np.ones((3, 1)) assert_almost_equal(mx_array + np_array, 1+np_array) + +def test_mixed_array_types_share_memory(): + np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32") + mx_array = mx.npx.from_numpy(np_array) + assert _np.may_share_memory(np_array, mx_array) + assert _np.shares_memory(np_array, mx_array) + + np_array_slice = np_array[:2] + mx_array_slice = mx_array[1:] + assert _np.may_share_memory(np_array_slice, mx_array) + assert _np.shares_memory(np_array_slice, mx_array) + + mx_pinned_array = mx_array.as_in_ctx(mx.cpu_pinned(0)) + assert not _np.may_share_memory(np_array, mx_pinned_array) + assert not _np.shares_memory(np_array, mx_pinned_array)