Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[WIP] scalar support for unary ndarray, ie. nd.sqrt(2), nd.log(3) #4877

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions python/mxnet/_ctypes/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..base import NDArrayHandle, OpHandle
from ..base import check_call
from ..ndarray_doc import _build_doc
from ..context import current_context

_ndarray_cls = None

Expand Down Expand Up @@ -39,6 +40,29 @@ def __reduce__(self):
return (_ndarray_cls, (None,), self.__getstate__())


def _convert2ndarray(data):
# convert to numpy array
src = np.array(data)
# get the information
shape = src.shape
ctx = current_context()

# create empty ndarray
hdl = NDArrayHandle()
check_call(_LIB.MXNDArrayCreate(
c_array(mx_uint,shape),
mx_uint(len(shape)),
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
ctypes.c_int(int(False)),
ctypes.byref(hdl)))

# assignment
arr = _ndarray_cls(handle=hdl)
arr[:] = src
return arr


# pylint: disable=too-many-locals, invalid-name
def _make_ndarray_function(handle, name):
"""Create a NDArray function from the FunctionHandle."""
Expand Down Expand Up @@ -97,6 +121,11 @@ def generic_ndarray_function(*args, **kwargs):
for i in args:
if isinstance(i, NDArrayBase):
ndargs.append(i)
elif num_args.value == 1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't valid. It breaks mx.nd.arange()
I think you should fix this by falling back on minpy side

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or at least check if this argument type is NDArray|Symbol

data = i
if np.isscalar(data):
data = [data]
ndargs.append(_convert2ndarray(data))
else:
pos_args.append(str(i))

Expand Down