Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[NumPy] allow mixed array types #18562

Merged
merged 3 commits into from
Aug 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from ..ndarray import numpy as _mx_nd_np
from ..ndarray.numpy import _internal as _npi
from ..ndarray.ndarray import _storage_type
from ..dlpack import ndarray_from_numpy
from .utils import _get_np_op
from .fallback import * # pylint: disable=wildcard-import,unused-wildcard-import
from . import fallback
Expand Down Expand Up @@ -179,21 +180,20 @@ def _reshape_view(a, *shape): # pylint: disable=redefined-outer-name
ctypes.byref(handle)))
return ndarray(handle=handle, writable=a.writable)


def _as_mx_np_array(object, ctx=None):
"""Convert object to mxnet.numpy.ndarray."""
if isinstance(object, ndarray):
def _as_mx_np_array(object, ctx=None, zero_copy=False):
"""Convert arrays or any array member of container to mxnet.numpy.ndarray on ctx."""
if object is None or isinstance(object, ndarray):
return object
elif isinstance(object, _np.ndarray):
np_dtype = _np.dtype(object.dtype).type
return array(object, dtype=np_dtype, ctx=ctx)
from_numpy = ndarray_from_numpy(ndarray, array)
return from_numpy(object, zero_copy and object.flags['C_CONTIGUOUS'])
elif isinstance(object, (integer_types, numeric_types)):
return object
elif isinstance(object, (list, tuple)):
tmp = [_as_mx_np_array(arr) for arr in object]
return object.__class__(tmp)
elif isinstance(object, (_np.bool_, _np.bool)):
return array(object, dtype=_np.bool_, ctx=ctx)
elif isinstance(object, (list, tuple)):
tmp = [_as_mx_np_array(arr, ctx=ctx, zero_copy=zero_copy) for arr in object]
return object.__class__(tmp)
else:
raise TypeError('Does not support converting {} to mx.np.ndarray.'.format(str(type(object))))

Expand Down Expand Up @@ -392,11 +392,18 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad-
out = func(*new_args, **new_kwargs)
return _as_mx_np_array(out, ctx=cur_ctx)
else:
# Note: this allows subclasses that don't override
# __array_function__ to handle mxnet.numpy.ndarray objects
if not py_all(issubclass(t, ndarray) for t in types):
return NotImplemented
return mx_np_func(*args, **kwargs)
if py_all(issubclass(t, ndarray) for t in types):
return mx_np_func(*args, **kwargs)
else:
try:
cur_ctx = next(a.ctx for a in args if hasattr(a, 'ctx'))
except StopIteration:
cur_ctx = next(a.ctx for a in kwargs.values() if hasattr(a, 'ctx'))
new_args = _as_mx_np_array(args, ctx=cur_ctx,
zero_copy=func_name in {'may_share_memory', 'shares_memory'})
new_kwargs = {k: _as_mx_np_array(v, cur_ctx) for k, v in kwargs.items()}
return mx_np_func(*new_args, **new_kwargs)


def _get_np_basic_indexing(self, key):
"""
Expand Down
20 changes: 20 additions & 0 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1401,3 +1401,23 @@ def test_from_numpy_exception():
np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32")
mx_array = mx.npx.from_numpy(np_array, zero_copy=False)
np_array[2, 1] = 0 # no error

def test_mixed_array_types():
np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32")
mx_array = mx.np.ones((3, 1))
assert_almost_equal(mx_array + np_array, 1+np_array)

def test_mixed_array_types_share_memory():
np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32")
mx_array = mx.npx.from_numpy(np_array)
assert _np.may_share_memory(np_array, mx_array)
eric-haibin-lin marked this conversation as resolved.
Show resolved Hide resolved
assert _np.shares_memory(np_array, mx_array)

np_array_slice = np_array[:2]
mx_array_slice = mx_array[1:]
assert _np.may_share_memory(np_array_slice, mx_array)
assert _np.shares_memory(np_array_slice, mx_array)

mx_pinned_array = mx_array.as_in_ctx(mx.cpu_pinned(0))
assert not _np.may_share_memory(np_array, mx_pinned_array)
assert not _np.shares_memory(np_array, mx_pinned_array)