Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
allow mixed types in array func protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Jul 16, 2020
1 parent 6901325 commit 299326c
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 11 deletions.
1 change: 1 addition & 0 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
np.float16: 2,
np.uint8: 3,
np.int32: 4,
np.intc: 4,
np.int8: 5,
np.int64: 6,
np.bool_: 7,
Expand Down
27 changes: 16 additions & 11 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from ..context import current_context
from ..ndarray import numpy as _mx_nd_np
from ..ndarray.numpy import _internal as _npi
from ..ndarray.ndarray import _storage_type, from_numpy
from ..ndarray.ndarray import _storage_type
from .utils import _get_np_op
from .fallback import * # pylint: disable=wildcard-import,unused-wildcard-import
from . import fallback
Expand Down Expand Up @@ -148,11 +148,10 @@ def _reshape_view(a, *shape): # pylint: disable=redefined-outer-name

def _as_mx_np_array(object, ctx=None):
"""Convert object to mxnet.numpy.ndarray."""
if isinstance(object, _np.ndarray):
if not object.flags['C_CONTIGUOUS']:
object = _np.ascontiguousarray(object, dtype=object.dtype)
ret = from_numpy(object, array_cls=ndarray)
return ret if ctx is None else ret.as_in_ctx(ctx=ctx)
if isinstance(object, ndarray):
return object
elif isinstance(object, _np.ndarray):
return array(object, dtype=object.dtype, ctx=ctx)
elif isinstance(object, (integer_types, numeric_types)):
return object
elif isinstance(object, (list, tuple)):
Expand Down Expand Up @@ -358,11 +357,17 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad-
out = func(*new_args, **new_kwargs)
return _as_mx_np_array(out, ctx=cur_ctx)
else:
# Note: this allows subclasses that don't override
# __array_function__ to handle mxnet.numpy.ndarray objects
if not py_all(issubclass(t, ndarray) for t in types):
return NotImplemented
return mx_np_func(*args, **kwargs)
if py_all(issubclass(t, ndarray) for t in types):
return mx_np_func(*args, **kwargs)
else:
try:
cur_ctx = next(a.ctx for a in args if hasattr(a, 'ctx'))
except StopIteration:
cur_ctx = next(a.ctx for a in kwargs.values() if hasattr(a, 'ctx'))
new_args = _as_mx_np_array(args, ctx=cur_ctx)
new_kwargs = {k: _as_mx_np_array(v, cur_ctx) for k, v in kwargs.items()}
return mx_np_func(*new_args, **new_kwargs)


def _get_np_basic_indexing(self, key):
"""
Expand Down
6 changes: 6 additions & 0 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,3 +1369,9 @@ def test_dlpack(dtype, size):
same(a_np+1, b)
same(a_np+2, c)
same(a_np+2, a_copy)

@use_np
def test_mixed_array_types():
np_array = _np.array([[1, 2], [3, 4], [5, 6]], dtype="float32")
mx_array = mx.np.ones((3, 1))
assert_almost_equal(mx_array + np_array, 1+np_array)

0 comments on commit 299326c

Please sign in to comment.