Skip to content

Commit

Permalink
[NumPy] allow mixed array types (apache#18562)
Browse files Browse the repository at this point in the history
* allow mixed types in array func protocol

* fix apache#18746

* add support for memory share check
  • Loading branch information
szha authored Aug 1, 2020
1 parent 08a5ee3 commit 5a22193
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 14 deletions.
35 changes: 21 additions & 14 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from ..ndarray import numpy as _mx_nd_np
from ..ndarray.numpy import _internal as _npi
from ..ndarray.ndarray import _storage_type
from ..dlpack import ndarray_from_numpy
from .utils import _get_np_op
from .fallback import * # pylint: disable=wildcard-import,unused-wildcard-import
from . import fallback
Expand Down Expand Up @@ -179,21 +180,20 @@ def _reshape_view(a, *shape): # pylint: disable=redefined-outer-name
ctypes.byref(handle)))
return ndarray(handle=handle, writable=a.writable)


def _as_mx_np_array(object, ctx=None):
"""Convert object to mxnet.numpy.ndarray."""
if isinstance(object, ndarray):
def _as_mx_np_array(object, ctx=None, zero_copy=False):
"""Convert arrays or any array member of container to mxnet.numpy.ndarray on ctx."""
if object is None or isinstance(object, ndarray):
return object
elif isinstance(object, _np.ndarray):
np_dtype = _np.dtype(object.dtype).type
return array(object, dtype=np_dtype, ctx=ctx)
from_numpy = ndarray_from_numpy(ndarray, array)
return from_numpy(object, zero_copy and object.flags['C_CONTIGUOUS'])
elif isinstance(object, (integer_types, numeric_types)):
return object
elif isinstance(object, (list, tuple)):
tmp = [_as_mx_np_array(arr) for arr in object]
return object.__class__(tmp)
elif isinstance(object, (_np.bool_, _np.bool)):
return array(object, dtype=_np.bool_, ctx=ctx)
elif isinstance(object, (list, tuple)):
tmp = [_as_mx_np_array(arr, ctx=ctx, zero_copy=zero_copy) for arr in object]
return object.__class__(tmp)
else:
raise TypeError('Does not support converting {} to mx.np.ndarray.'.format(str(type(object))))

Expand Down Expand Up @@ -392,11 +392,18 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad-
out = func(*new_args, **new_kwargs)
return _as_mx_np_array(out, ctx=cur_ctx)
else:
# Note: this allows subclasses that don't override
# __array_function__ to handle mxnet.numpy.ndarray objects
if not py_all(issubclass(t, ndarray) for t in types):
return NotImplemented
return mx_np_func(*args, **kwargs)
if py_all(issubclass(t, ndarray) for t in types):
return mx_np_func(*args, **kwargs)
else:
try:
cur_ctx = next(a.ctx for a in args if hasattr(a, 'ctx'))
except StopIteration:
cur_ctx = next(a.ctx for a in kwargs.values() if hasattr(a, 'ctx'))
new_args = _as_mx_np_array(args, ctx=cur_ctx,
zero_copy=func_name in {'may_share_memory', 'shares_memory'})
new_kwargs = {k: _as_mx_np_array(v, cur_ctx) for k, v in kwargs.items()}
return mx_np_func(*new_args, **new_kwargs)


def _get_np_basic_indexing(self, key):
"""
Expand Down
20 changes: 20 additions & 0 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,3 +1401,23 @@ def test_from_numpy_exception():
np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32")
mx_array = mx.npx.from_numpy(np_array, zero_copy=False)
np_array[2, 1] = 0 # no error

def test_mixed_array_types():
np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32")
mx_array = mx.np.ones((3, 1))
assert_almost_equal(mx_array + np_array, 1+np_array)

def test_mixed_array_types_share_memory():
np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32")
mx_array = mx.npx.from_numpy(np_array)
assert _np.may_share_memory(np_array, mx_array)
assert _np.shares_memory(np_array, mx_array)

np_array_slice = np_array[:2]
mx_array_slice = mx_array[1:]
assert _np.may_share_memory(np_array_slice, mx_array)
assert _np.shares_memory(np_array_slice, mx_array)

mx_pinned_array = mx_array.as_in_ctx(mx.cpu_pinned(0))
assert not _np.may_share_memory(np_array, mx_pinned_array)
assert not _np.shares_memory(np_array, mx_pinned_array)

0 comments on commit 5a22193

Please sign in to comment.