From 0a1ef48a33c0cdc8f4d8823ce79fca0b7ab72c2c Mon Sep 17 00:00:00 2001 From: reminisce Date: Sat, 18 May 2019 13:30:29 -0700 Subject: [PATCH] [numpy] Refactor np modules (#14989) * Refactor * Initial refactoring * Fix notebook * Move numpy op check from backend to frontend * Add homogeneous ndarray check * Fix grouping inhomogeneous types of symbols * Improve error handling of different types of symbols as outputs * Fix test * Fix numpy test * Fix ci * Try to fix gpu ci failure --- example/numpy/demo.ipynb | 73 ++--- include/mxnet/c_api.h | 17 -- include/mxnet/op_attr_types.h | 9 - python/mxnet/__init__.py | 3 + python/mxnet/_ctypes/ndarray.py | 19 +- python/mxnet/_ctypes/symbol.py | 10 +- python/mxnet/base.py | 119 +++++--- python/mxnet/gluon/block.py | 6 +- python/mxnet/gluon/utils.py | 22 ++ python/mxnet/ndarray/__init__.py | 3 +- python/mxnet/ndarray/ndarray.py | 48 +--- python/mxnet/ndarray/numpy/__init__.py | 5 +- .../ext.py => ndarray/numpy/_internal.py} | 2 +- python/mxnet/ndarray/numpy/_op.py | 20 +- python/mxnet/ndarray/numpy/_register.py | 8 +- python/mxnet/ndarray/numpy/linalg.py | 2 +- python/mxnet/ndarray/numpy/random.py | 2 +- .../mxnet/ndarray/numpy_extension/__init__.py | 24 ++ python/mxnet/ndarray/numpy_extension/_op.py | 21 ++ .../ndarray/numpy_extension/_register.py | 25 ++ python/mxnet/ndarray/register.py | 66 ++++- python/mxnet/numpy/__init__.py | 4 +- python/mxnet/numpy/_op.py | 2 +- python/mxnet/numpy/_register.py | 5 +- python/mxnet/numpy/linalg.py | 2 +- python/mxnet/numpy/multiarray.py | 185 ++++++++----- python/mxnet/numpy/random.py | 2 +- python/mxnet/numpy_extension/__init__.py | 28 ++ .../numpy/ext.py => numpy_extension/_op.py} | 2 +- python/mxnet/numpy_extension/_register.py | 27 ++ python/mxnet/symbol/__init__.py | 4 +- python/mxnet/symbol/numpy/__init__.py | 7 +- .../ext.py => symbol/numpy/_internal.py} | 2 +- python/mxnet/symbol/numpy/_op.py | 2 +- python/mxnet/symbol/numpy/_register.py | 9 +- python/mxnet/symbol/numpy/_symbol.py | 258 +++++++++--------- python/mxnet/symbol/numpy/linalg.py | 2 +- python/mxnet/symbol/numpy/random.py | 2 +- .../mxnet/symbol/numpy_extension/__init__.py | 24 ++ python/mxnet/symbol/numpy_extension/_op.py | 21 ++ .../mxnet/symbol/numpy_extension/_register.py | 24 ++ python/mxnet/symbol/register.py | 74 ++++- python/mxnet/symbol/symbol.py | 57 +--- python/mxnet/test_utils.py | 6 + src/c_api/c_api_common.h | 17 -- src/c_api/c_api_ndarray.cc | 16 -- src/operator/numpy/np_broadcast_reduce_op.h | 1 + .../numpy/np_broadcast_reduce_op_value.cc | 14 +- .../numpy/np_broadcast_reduce_op_value.cu | 8 +- src/operator/numpy/np_dot-inl.h | 11 +- src/operator/numpy/np_dot.cc | 2 +- src/operator/numpy/np_dot.cu | 2 +- .../numpy/np_elemwise_broadcast_op.cc | 56 ++-- .../numpy/np_elemwise_broadcast_op.cu | 34 +-- .../numpy/np_elemwise_unary_op_basic.cc | 28 +- .../numpy/np_elemwise_unary_op_basic.cu | 4 +- src/operator/numpy/np_init_op.cc | 64 ++++- src/operator/numpy/np_init_op.cu | 10 +- src/operator/numpy/np_matrix_op.cc | 6 +- src/operator/numpy/np_matrix_op.cu | 4 +- src/operator/numpy/np_true_divide.cc | 9 +- src/operator/numpy/np_true_divide.cu | 6 +- tests/python/unittest/test_numpy_ndarray.py | 95 ++++--- tests/python/unittest/test_numpy_op.py | 78 +++--- 64 files changed, 1052 insertions(+), 666 deletions(-) rename python/mxnet/{numpy/ext.py => ndarray/numpy/_internal.py} (91%) create mode 100644 python/mxnet/ndarray/numpy_extension/__init__.py create mode 100644 python/mxnet/ndarray/numpy_extension/_op.py create mode 100644 
python/mxnet/ndarray/numpy_extension/_register.py create mode 100644 python/mxnet/numpy_extension/__init__.py rename python/mxnet/{symbol/numpy/ext.py => numpy_extension/_op.py} (89%) create mode 100644 python/mxnet/numpy_extension/_register.py rename python/mxnet/{ndarray/numpy/ext.py => symbol/numpy/_internal.py} (89%) create mode 100644 python/mxnet/symbol/numpy_extension/__init__.py create mode 100644 python/mxnet/symbol/numpy_extension/_op.py create mode 100644 python/mxnet/symbol/numpy_extension/_register.py diff --git a/example/numpy/demo.ipynb b/example/numpy/demo.ipynb index d8e6e06e1818..7ba184dad43f 100644 --- a/example/numpy/demo.ipynb +++ b/example/numpy/demo.ipynb @@ -6,21 +6,21 @@ "source": [ "# Fundamentals of MXNet Numpy Module\n", "\n", - "## Operator Namespaces for Imperative Programming\n", + "## Namespaces for Imperative Programming\n", "- `mxnet.numpy`: Regular NumPy operators\n", "- `mxnet.numpy.random`: NumPy random operators\n", "- `mxnet.numpy.linalg`: NumPy linear algebra operators\n", - "- `mxnet.numpy.ext`: Operators implemented in MXNet that do not exist in official NumPy\n", + "- `mxnet.numpy_extension`: Operators implemented in MXNet that do not exist in the official NumPy\n", "\n", "## Operator Namespaces for Gluon\n", - "`F` can be either `mxnet.ndarray` or `mxnet.symbol`.\n", + "`F` can be either `mxnet.ndarray` or `mxnet.symbol`. Note that `np` and `npe` are aliases of `numpy` and `numpy_extension`, respectively.\n", "- `F.np`: Regular NumPy operators\n", "- `F.np.random`: NumPy random operators\n", "- `F.np.linalg`: NumPy linear algebra operators\n", - "- `F.np.ext`: Operators implemented in MXNet that do not exist in official NumPy\n", + "- `F.npe`: Operators implemented in MXNet that do not exist in the official NumPy\n", "\n", "## New `ndarray` and `symbol`\n", - "`mxnet.numpy.ndarray` and `mxnet.symbol.numpy._NumpySymbol` (not visible to users)\n", + "`mxnet.numpy.ndarray` (visible to users) and `mxnet.symbol.numpy._Symbol` (not visible to users)\n", "- Same name as in the official NumPy package\n", "- Dispatch convenience fluent method calls to MXNet Numpy operators\n", "- Override many convenience fluent methods that do not exist in the official NumPy ndarray\n", @@ -46,7 +46,7 @@ "\n", "# create a scalar tensor\n", "x = np.array(3.14)\n", - "print(x)" + "print(x) # x is actually an ndarray, but a scalar value will be printed" ] }, { @@ -170,13 +170,15 @@ "from mxnet import gluon\n", "class TestBinaryBroadcast(gluon.HybridBlock):\n", " def hybrid_forward(self, F, x1, x2):\n", - " print(\"x1 type:\", str(type(x1)))\n", - " print(\"x2 type:\", str(type(x2)))\n", + " print(\"x1 type in hybrid_forward:\", str(type(x1)))\n", + " print(\"x2 type in hybrid_forward:\", str(type(x2)))\n", " return x1 + x2\n", "\n", "net = TestBinaryBroadcast()\n", "x1 = mx.nd.ones((2, 1))\n", "x2 = mx.nd.ones((1, 3))\n", + "print('x1 input tensor type: ', str(type(x1)))\n", + "print('x2 input tensor type: ', str(type(x2)))\n", "out = net(x1, x2) # ok: imperative execution supports broadcasting\n", "print(out)" ] }, @@ -203,13 +205,15 @@ "source": [ "class TestBinaryBroadcast2(gluon.HybridBlock):\n", " def hybrid_forward(self, F, x1, x2):\n", - " print(\"x1 type:\", str(type(x1)))\n", - " print(\"x2 type:\", str(type(x2)))\n", + " print(\"x1 type in hybrid_forward:\", str(type(x1)))\n", + " print(\"x2 type in hybrid_forward:\", str(type(x2)))\n", " return x1.as_np_ndarray() + x2 # convert x1 to new numpy ndarray/symbol\n", "\n", "net2 = TestBinaryBroadcast2()\n",
"net2.hybridize()\n", "\n", + "print('x1 input tensor type: ', str(type(x1)))\n", + "print('x2 input tensor type: ', str(type(x2)))\n", "out =net2(x1, x2)\n", "print(out)" ] @@ -224,7 +228,9 @@ "net.hybridize() # mark the block for execution using a computational graph\n", "\n", "x1 = x1.as_np_ndarray() # convert x1 to np.ndarray so that _NumpySymbol will be used in graph construction\n", + "print('x1 input tensor type: ', str(type(x1)))\n", "x2 = x2.as_np_ndarray() # convert x2 to np.ndarray so that _NumpySymbol will be used in graph construction\n", + "print('x2 input tensor type: ', str(type(x2)))\n", "out = net(x1, x2) # ok: `+` operation supports broadcasting for _NumpySymbol\n", "print(out) # mxnet.numpy.ndarray type, because it's from a np operator" ] @@ -245,7 +251,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## MXNet Numpy Operators in Imperative Programming" + "### MXNet Numpy Operators in Imperative Programming" ] }, { @@ -255,15 +261,9 @@ "outputs": [], "source": [ "import mxnet as mx\n", - "from mxnet import numpy as np\n", + "from mxnet import numpy as np, numpy_extension as npe\n", "from mxnet import autograd\n", - "try:\n", - " from mxboard import SummaryWriter\n", - "except ImportError:\n", - " SummaryWriter = None\n", "\n", - "# create a summary writer for visualization\n", - "sw = SummaryWriter(logdir='./logs', flush_secs=2) if SummaryWriter is not None else None\n", "\n", "# Use numpy-compatible semantics to support scalar tensors\n", "mx.set_np_compat(True)\n", @@ -285,11 +285,11 @@ "learning_rate = 1e-6\n", "\n", "\n", - "for t in range(1000):\n", + "for t in range(50):\n", " with autograd.record():\n", " # Forward pass: compute predicted y\n", " h = x.dot(w1) # equivalent to np.dot(x, w1)\n", - " h_relu = np.ext.relu(h) # equivalent to mx.nd.relu(h)\n", + " h_relu = npe.relu(h) # equivalent to mx.nd.relu(h)\n", " y_pred = h_relu.dot(w2) # equivalent to np.dot(h_relu, w2)\n", "\n", " # Compute loss\n", @@ -302,23 +302,14 @@ "\n", " # Update weights\n", " w1 -= learning_rate * w1.grad\n", - " w2 -= learning_rate * w2.grad\n", - "\n", - " if sw is not None:\n", - " sw.add_scalar('loss', loss.item(), global_step=t) # loss.item() copies the tensor element to a python scalar\n", - " if t % 50 == 0:\n", - " sw.add_histogram(tag='w1', values=w1, global_step=t)\n", - " sw.add_histogram(tag='w2', values=w2, global_step=t)\n", - "\n", - "if sw is not None:\n", - " sw.close()" + " w2 -= learning_rate * w2.grad" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## MXNet Numpy Operators in Gluon `HybridBlock`" + "### MXNet Numpy Operators in Gluon `HybridBlock`" ] }, { @@ -329,13 +320,7 @@ "source": [ "import mxnet as mx\n", "from mxnet import gluon, autograd\n", - "try:\n", - " from mxboard import SummaryWriter\n", - "except ImportError:\n", - " SummaryWriter = None\n", "\n", - "# create a summary writer for visualization\n", - "sw = SummaryWriter(logdir='./logs', flush_secs=2) if SummaryWriter is not None else None\n", "\n", "# Use numpy-compatible semantics to support scalar tensors\n", "mx.set_np_compat(True)\n", @@ -352,7 +337,7 @@ "\n", " def hybrid_forward(self, F, x, w1, w2):\n", " h = x.dot(w1) # equivalent to F.np.dot(x, w1)\n", - " h_relu = F.np.ext.relu(h) # equivalent to F.relu(h)\n", + " h_relu = F.npe.relu(h) # equivalent to F.relu(h)\n", " y_pred = h_relu.dot(w2) # equivalent to F.np.dot(h_relu, w2)\n", " return y_pred\n", "\n", @@ -373,21 +358,13 @@ "total_loss = TotalLoss()\n", "trainer = gluon.Trainer(regressor.collect_params(), 
'sgd', {'learning_rate': 1e-3, 'momentum': 0.9})\n", "\n", - "for t in range(1000):\n", + "for t in range(50):\n", " with autograd.record():\n", " output = regressor(x) # output is a type of np.ndarray because np.dot is the last op in the network\n", " loss = total_loss(output, y) # loss is a scalar np.ndarray\n", " loss.backward()\n", " print(t, loss) # note that loss.asnumpy() is called\n", - " trainer.step(1)\n", - " if sw is not None:\n", - " sw.add_scalar('loss', loss.item(), global_step=t) # loss.item() copies the tensor element to a python scalar\n", - " if t % 50 == 0:\n", - " for k, v in regressor.collect_params().items():\n", - " sw.add_histogram(tag=k, values=v.data(), global_step=t)\n", - "\n", - "if sw is not None:\n", - " sw.close()" + " trainer.step(1)" ] } ], diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 6be371419853..ddd66cd7fae0 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -2902,14 +2902,6 @@ MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param, EngineVarHandle mutable_vars_handle, int num_mutable_vars, EngineFnPropertyHandle prop_handle DEFAULT(NULL), int priority DEFAULT(0), const char* opr_name DEFAULT(NULL)); -/*! - * \brief Determines if an op is a Numpy op by its name prefix. - * Every Numpy op starts with a prefix string "_numpy_". - * \param creator Operator handle - * \param is_np_op Indicator of whether creator is a numpy op handle - */ -MXNET_DLL int MXIsNumpyCompatOp(AtomicSymbolCreator creator, - int* is_np_op); /*! * \brief Create an NDArray from source sharing the same data chunk. * \param src source NDArray @@ -2922,15 +2914,6 @@ MXNET_DLL int MXShallowCopyNDArray(NDArrayHandle src, NDArrayHandle* out); * \param out new Symbol sharing the same graph structure with src */ MXNET_DLL int MXShallowCopySymbol(SymbolHandle src, SymbolHandle * out); -/*! - * \brief Checks if an output of CachedOp is from a numpy op. - * \param handle CachedOp shared ptr - * \param output_idx index of the output of the CachedOp - * \param is_from_np_op indicator of whether the output is from a numpy op - */ -MXNET_DLL int MXIsCachedOpOutputFromNumpyCompatOp(CachedOpHandle handle, - int output_idx, - int* is_from_np_op); /*! * \brief Push an asynchronous operation to the engine. diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index 0e4e3229c195..889b5028a460 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -319,15 +319,6 @@ using FNeedRequantize = std::function; using FAvoidQuantizeInput = std::function; -/*! - * \brief Indicates whether this operator is NumPy compatible. - * It is for distinguishing the operator from classic MXNet operators - * which do not support zero-dim and zero-size tensors. - * In Python, it is used to determine whether to output numpy ndarrays - * or symbols that are NumPy compatible. - */ -using TIsNumpyCompatible = bool; - } // namespace mxnet #endif // MXNET_OP_ATTR_TYPES_H_ diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index 7c8150bbcaab..883e84604132 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -30,6 +30,9 @@ from . import ndarray from . import ndarray as nd from . import numpy +from . import numpy_extension +from . import numpy as np +from . import numpy_extension as npe from . import name # use mx.sym as short for symbol from . 
import symbol as sym diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py index 60ec248c18be..6404d895b884 100644 --- a/python/mxnet/_ctypes/ndarray.py +++ b/python/mxnet/_ctypes/ndarray.py @@ -26,7 +26,7 @@ from ..base import _LIB from ..base import c_str_array, c_handle_array from ..base import NDArrayHandle, CachedOpHandle -from ..base import check_call, _is_np_compat_op +from ..base import check_call class NDArrayBase(object): @@ -70,7 +70,7 @@ def _set_np_ndarray_class(cls): _np_ndarray_cls = cls -def _imperative_invoke(handle, ndargs, keys, vals, out): +def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op): """ctypes implementation of imperative invoke wrapper""" if out is not None: original_output = out @@ -99,9 +99,9 @@ def _imperative_invoke(handle, ndargs, keys, vals, out): c_str_array([str(s) for s in vals]), ctypes.byref(out_stypes))) + create_ndarray_fn = _np_ndarray_cls if is_np_op else _ndarray_cls if original_output is not None: return original_output - create_ndarray_fn = _np_ndarray_cls if _is_np_compat_op(handle) else _ndarray_cls if num_output.value == 1: return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle), stype=out_stypes[0]) @@ -112,11 +112,14 @@ def _imperative_invoke(handle, ndargs, keys, vals, out): class CachedOp(object): """Cached operator handle.""" - __slots__ = ["handle"] + __slots__ = ["handle", "is_np_sym"] def __init__(self, sym, flags=()): self.handle = CachedOpHandle() + from ..symbol.numpy._symbol import _Symbol + self.is_np_sym = True if isinstance(sym, _Symbol) else False + check_call(_LIB.MXCreateCachedOpEx( sym.handle, len(flags), @@ -167,12 +170,10 @@ def __call__(self, *args, **kwargs): if original_output is not None: return original_output + create_ndarray_fn = _np_ndarray_cls if self.is_np_sym else _ndarray_cls if num_output.value == 1: - create_ndarray_fn = _np_ndarray_cls if self._is_from_np_compat_op(0) else _ndarray_cls return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle), stype=out_stypes[0]) else: - return [_np_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), stype=out_stypes[i]) - if self._is_from_np_compat_op(i) else - _ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), stype=out_stypes[i]) - for i in range(num_output.value)] + return [create_ndarray_fn(ctypes.cast(output_vars[i], NDArrayHandle), + stype=out_stypes[i]) for i in range(num_output.value)] diff --git a/python/mxnet/_ctypes/symbol.py b/python/mxnet/_ctypes/symbol.py index 7aea0a251f87..fc159f86854d 100644 --- a/python/mxnet/_ctypes/symbol.py +++ b/python/mxnet/_ctypes/symbol.py @@ -22,7 +22,7 @@ import ctypes from ..base import _LIB -from ..base import c_str_array, c_handle_array, c_str, mx_uint, _is_np_compat_op +from ..base import c_str_array, c_handle_array, c_str, mx_uint from ..base import SymbolHandle from ..base import check_call @@ -122,7 +122,7 @@ def _set_np_symbol_class(cls): _np_symbol_cls = cls -def _symbol_creator(handle, args, kwargs, keys, vals, name): +def _symbol_creator(handle, args, kwargs, keys, vals, name, is_np_op): sym_handle = SymbolHandle() check_call(_LIB.MXSymbolCreateAtomicSymbol( ctypes.c_void_p(handle), @@ -135,10 +135,8 @@ def _symbol_creator(handle, args, kwargs, keys, vals, name): raise TypeError( 'Operators with variable length input can only accept input' 'Symbols either as positional or keyword arguments, not both') - if _is_np_compat_op(handle): - s = _np_symbol_cls(sym_handle) - else: - s = _symbol_cls(sym_handle) + create_symbol_fn = _np_symbol_cls if 
is_np_op else _symbol_cls + s = create_symbol_fn(sym_handle) if args: s._compose(*args, name=name) elif kwargs: diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 9bce08a58a1c..4013b1408637 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -16,7 +16,7 @@ # under the License. # coding: utf-8 -# pylint: disable=invalid-name, no-member, trailing-comma-tuple, bad-mcs-classmethod-argument, unnecessary-pass, wrong-import-position +# pylint: disable=invalid-name, no-member, trailing-comma-tuple, bad-mcs-classmethod-argument, unnecessary-pass, too-many-lines, wrong-import-position """ctypes library of mxnet and helper functions.""" from __future__ import absolute_import @@ -598,7 +598,9 @@ def _init_op_module(root_namespace, module_name, make_op_func): ctypes.byref(plist))) op_names = [] for i in range(size.value): - op_names.append(py_str(plist[i])) + op_name = py_str(plist[i]) + if not _is_np_op(op_name): + op_names.append(op_name) module_op = sys.modules["%s.%s.op" % (root_namespace, module_name)] module_internal = sys.modules["%s.%s._internal" % (root_namespace, module_name)] @@ -692,7 +694,9 @@ def write_all_str(module_file, module_all_list): ctypes.byref(plist))) op_names = [] for i in range(size.value): - op_names.append(py_str(plist[i])) + op_name = py_str(plist[i]) + if not _is_np_op(op_name): + op_names.append(op_name) module_op_file = get_module_file("%s.%s.op" % (root_namespace, module_name)) module_op_all = [] @@ -748,19 +752,28 @@ def _sanity_check_params(func_name, unsupported_params, param_dict): .format(func_name, param_name)) -_NP_OP_SUBMODULE_LIST = ['_ext_', '_random_', '_linalg_'] -_NP_OP_PREFIX = '_numpy_' +_NP_OP_PREFIX = '_np_' +_NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_'] +_NP_EXT_OP_PREFIX = '_npe_' -def _get_np_op_submodule_name(op_name): - assert op_name.startswith(_NP_OP_PREFIX) - for name in _NP_OP_SUBMODULE_LIST: - if op_name[len(_NP_OP_PREFIX):].startswith(name): - return name +_NP_INTERNAL_OP_PREFIX = '_npi_' + + +def _is_np_op(op_name): + return op_name.startswith(_NP_OP_PREFIX) or op_name.startswith(_NP_EXT_OP_PREFIX)\ + or op_name.startswith(_NP_INTERNAL_OP_PREFIX) + + +def _get_op_submodule_name(op_name, op_name_prefix, submodule_name_list): + assert op_name.startswith(op_name_prefix) + for submodule_name in submodule_name_list: + if op_name[len(op_name_prefix):].startswith(submodule_name): + return submodule_name return "" -def _init_np_op_module(root_namespace, module_name, make_op_func): +def _init_np_op_module(root_module_name, np_module_name, mx_module_name, make_op_func): """ Register numpy operators in namespaces `mxnet.numpy`, `mxnet.ndarray.numpy` and `mxnet.symbol.numpy`. They are used in imperative mode, Gluon APIs w/o hybridization, @@ -770,51 +783,89 @@ def _init_np_op_module(root_namespace, module_name, make_op_func): Parameters ---------- - root_namespace : str + root_module_name : str Top level module name, `mxnet` in the current cases. - module_name : str - Second level module name, `ndarray` or `symbol` in the current case. + np_module_name : str + Second level module name, `numpy` or `numpy_extension` in the current case. make_op_func : function Function for creating op functions. 
""" + if np_module_name == 'numpy': + op_name_prefix = _NP_OP_PREFIX + submodule_name_list = _NP_OP_SUBMODULE_LIST + elif np_module_name == 'numpy_extension': + op_name_prefix = _NP_EXT_OP_PREFIX + submodule_name_list = [] + elif np_module_name == 'numpy._internal': + op_name_prefix = _NP_INTERNAL_OP_PREFIX + submodule_name_list = [] + else: + raise ValueError('unsupported np module name {}'.format(np_module_name)) + plist = ctypes.POINTER(ctypes.c_char_p)() size = ctypes.c_uint() - check_call(_LIB.MXListAllOpNames(ctypes.byref(size), ctypes.byref(plist))) op_names = [] for i in range(size.value): name = py_str(plist[i]) - if name.startswith(_NP_OP_PREFIX): + if name.startswith(op_name_prefix): op_names.append(name) - if module_name == 'numpy': - # register ops for mxnet.numpy - module_pattern = "%s.%s._op" - submodule_pattern = "%s.%s.%s" + if mx_module_name is None: + # register np/npe ops for imperative programming + op_module_name = "%s.%s._op" % (root_module_name, np_module_name) # e.g. mxnet.numpy._op + op_submodule_name = "%s.%s" % (root_module_name, np_module_name) # e.g. mxnet.numpy.random + elif mx_module_name == 'ndarray' or mx_module_name == 'symbol': + # register numpy internal ops and np/npe ops for use in Gluon + # np internal ops are registered in mxnet.ndarray/symbol.numpy._internal + # np ops are registered in mxnet.ndarray/symbol.numpy._op + # npe ops are registered in mxnet.ndarray/symbol.numpy_extension._op + op_module_name = "%s.%s.%s" % (root_module_name, mx_module_name, np_module_name) + if op_name_prefix != _NP_INTERNAL_OP_PREFIX: + op_module_name += '._op' + # e.g. mxnet.symbol.numpy.random + op_submodule_name = "%s.%s.%s" % (root_module_name, mx_module_name, np_module_name) else: - # register ops for mxnet.ndarray.numpy or mxnet.symbol.numpy - module_pattern = "%s.%s.numpy._op" - submodule_pattern = "%s.%s.numpy.%s" - module_np_op = sys.modules[module_pattern % (root_namespace, module_name)] + raise ValueError('unsupported mxnet module {}'.format(mx_module_name)) + op_submodule_name += '.%s' + + op_module = sys.modules[op_module_name] submodule_dict = {} - for submodule_name in _NP_OP_SUBMODULE_LIST: - submodule_dict[submodule_name] = \ - sys.modules[submodule_pattern % (root_namespace, module_name, submodule_name[1:-1])] + for submodule_name in submodule_name_list: + submodule_dict[submodule_name] = sys.modules[op_submodule_name % submodule_name[1:-1]] for name in op_names: hdl = OpHandle() check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) - submodule_name = _get_np_op_submodule_name(name) - module_name_local = module_name + submodule_name = _get_op_submodule_name(name, op_name_prefix, submodule_name_list) if len(submodule_name) > 0: - func_name = name[(len(_NP_OP_PREFIX) + len(submodule_name)):] + func_name = name[(len(op_name_prefix) + len(submodule_name)):] cur_module = submodule_dict[submodule_name] - module_name_local = submodule_pattern % (root_namespace, - module_name, submodule_name[1:-1]) + module_name_local = op_submodule_name % submodule_name[1:-1] else: - func_name = name[len(_NP_OP_PREFIX):] - cur_module = module_np_op + func_name = name[len(op_name_prefix):] + cur_module = op_module + module_name_local =\ + op_module_name[:-len('._op')] if op_module_name.endswith('._op') else op_module_name function = make_op_func(hdl, name, func_name) function.__module__ = module_name_local setattr(cur_module, function.__name__, function) cur_module.__all__.append(function.__name__) + + +def set_module(module): + """Decorator for overriding __module__ 
on a function or class. + + Example usage:: + + @set_module('mxnet.numpy') + def example(): + pass + + assert example.__module__ == 'mxnet.numpy' + """ + def decorator(func): + if module is not None: + func.__module__ = module + return func + return decorator diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index c4c4595c6126..6b4f4b609d13 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -33,7 +33,7 @@ from ..ndarray import NDArray from .. import name as _name from .parameter import Parameter, ParameterDict, DeferredInitializationError -from .utils import _indent, _brief_print_list, HookHandle +from .utils import _indent, _brief_print_list, HookHandle, _check_same_symbol_type from .. import numpy as _mx_np @@ -754,7 +754,7 @@ def _get_graph(self, *args): out = self.hybrid_forward(symbol, *grouped_inputs, **params) # pylint: disable=no-value-for-parameter out, self._out_format = _flatten(out, "output") - self._cached_graph = inputs, symbol.Group(out) + self._cached_graph = inputs, symbol.Group(out, _check_same_symbol_type(out)) return self._cached_graph @@ -1063,7 +1063,7 @@ def __init__(self, outputs, inputs, params=None): syms, self._in_format = _flatten(inputs, "input") out, self._out_format = _flatten(outputs, "output") - out = symbol.Group(out) + out = symbol.Group(out, _check_same_symbol_type(out)) input_names = set() for i in syms: diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 2060f61a0212..241baf415818 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -430,3 +430,25 @@ def shape_is_known(shape): assert dim_size > unknown_dim_size, "shape dimension size cannot be less than {}, while " \ "received {}".format(unknown_dim_size, dim_size) return True + +def _check_same_symbol_type(symbols): + """Check whether all the symbols in the list are of the same type. + Raise a TypeError if the types are different. Return the class of + the symbols.""" + from ..symbol.numpy import _Symbol as np_symbol + from ..symbol import Symbol as classic_symbol + is_np_sym = True if isinstance(symbols[0], np_symbol) else False + for s in symbols[1:]: + if is_np_sym != isinstance(s, np_symbol): + raise TypeError('Found both classic symbol (mx.sym.Symbol) and numpy symbol ' + '(mx.sym.np._Symbol) in outputs. This will prevent you from building ' + 'a computation graph by grouping them since different types of symbols ' + 'are not allowed to be grouped in Gluon to form a computation graph. ' + 'You will need to convert them to the same type of symbols, either ' + 'classic or numpy following this rule: if you want numpy ndarray ' + 'output(s) from the computation graph, please convert all the classic ' + 'symbols in the list to numpy symbols by calling `as_np_ndarray()` ' + 'on each of them; if you want classic ndarray output(s) from the ' + 'computation graph, please convert all the numpy symbols in the list ' + 'to classic symbols by calling `as_classic_ndarray()` on each of them.') + return np_symbol if is_np_sym else classic_symbol diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py index f0e6edb0c748..c326850ec3e4 100644 --- a/python/mxnet/ndarray/__init__.py +++ b/python/mxnet/ndarray/__init__.py @@ -31,6 +31,7 @@ from .sparse import _ndarray_cls from .ndarray import _GRAD_REQ_MAP, _DTYPE_MX_TO_NP, _DTYPE_NP_TO_MX, _new_empty_handle from . import numpy as np +from . 
import numpy_extension as npe __all__ = op.__all__ + ndarray.__all__ + utils.__all__ + \ - ['contrib', 'linalg', 'random', 'sparse', 'image'] + ['contrib', 'linalg', 'random', 'sparse', 'image', 'numpy', 'numpy_extension'] diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 23a239c54ca6..d835ab6c6d87 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -187,15 +187,15 @@ class NDArray(NDArrayBase): def as_np_ndarray(self): """Convert mxnet.ndarray.NDArray to mxnet.numpy.ndarray.""" + storage_type = self.stype + if storage_type != 'default': + raise ValueError('cannot convert ndarray of stype {} to numpy ndarray' + .format(str(storage_type))) from ..numpy import ndarray hdl = NDArrayHandle() check_call(_LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(hdl))) return ndarray(handle=hdl, writable=self.writable) - def _is_np_compat(self): - """Always returns False except for mxnet.numpy.ndarray.""" - return False - @property def _tvm_handle(self): return self.handle.value @@ -220,8 +220,6 @@ def _to_shared_mem(self): def __add__(self, other): """x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """ # other may be the type of mxnet.numpy.ndarray - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__add__(self) return add(self, other) def __iadd__(self, other): @@ -236,15 +234,11 @@ def __iadd__(self, other): raise TypeError('type %s not supported' % str(type(other))) def __radd__(self, other): - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__add__(self) return self.__add__(other) def __sub__(self, other): """x.__sub__(y) <=> x-y <=> mx.nd.subtract(x, y) """ # other may be the type of mxnet.numpy.ndarray - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__rsub__(self) return subtract(self, other) def __isub__(self, other): @@ -260,14 +254,10 @@ def __isub__(self, other): def __rsub__(self, other): """x.__rsub__(y) <=> y-x <=> mx.nd.subtract(y, x) """ - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__sub__(self) return subtract(other, self) def __mul__(self, other): """x.__mul__(y) <=> x*y <=> mx.nd.multiply(x, y) """ - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__mul__(self) return multiply(self, other) def __neg__(self): @@ -286,20 +276,14 @@ def __imul__(self, other): raise TypeError('type %s not supported' % str(type(other))) def __rmul__(self, other): - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__mul__(self) return self.__mul__(other) def __div__(self, other): """x.__div__(y) <=> x/y <=> mx.nd.divide(x, y) """ - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__rtruediv__(self) return divide(self, other) def __rdiv__(self, other): """x.__rdiv__(y) <=> y/x <=> mx.nd.divide(y, x) """ - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__truediv__(self) return divide(other, self) def __idiv__(self, other): @@ -314,13 +298,9 @@ def __idiv__(self, other): raise TypeError('type %s not supported' % str(type(other))) def __truediv__(self, other): - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__rtruediv__(self) return divide(self, other) def __rtruediv__(self, other): - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__truediv__(self) return divide(other, self) def __itruediv__(self, other): @@ -328,14 +308,10 @@ def __itruediv__(self, other): def __mod__(self, other): """x.__mod__(y) <=> x%y <=> mx.nd.modulo(x, y) """ - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__rmod__(self) return modulo(self, other) def __rmod__(self, other): """x.__rmod__(y) <=> y%x <=> mx.nd.modulo(y, x) """ - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__mod__(self) return modulo(other, self) def __imod__(self, other): @@ -351,20 +327,14 @@ def __imod__(self, other): def __pow__(self, other): """x.__pow__(y) <=> x**y <=> mx.nd.power(x,y) """ - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__rpow__(self) return power(self, other) def __rpow__(self, other): """x.__rpow__(y) <=> y**x <=> mx.nd.power(y,x) """ - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__pow__(self) return power(other, self) def __eq__(self, other): """x.__eq__(y) <=> x==y <=> mx.nd.equal(x, y) """ - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__eq__(self) return equal(self, other) def __hash__(self): @@ -373,32 +343,22 @@ def __hash__(self): def __ne__(self, other): """x.__ne__(y) <=> x!=y <=> mx.nd.not_equal(x, y) """ - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__ne__(self) return not_equal(self, other) def __gt__(self, other): """x.__gt__(y) <=> x>y <=> mx.nd.greater(x, y) """ - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__lt__(self) return greater(self, other) def __ge__(self, other): """x.__ge__(y) <=> x>=y <=> mx.nd.greater_equal(x, y) """ - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__le__(self) return greater_equal(self, other) def __lt__(self, other): """x.__lt__(y) <=> x<y <=> mx.nd.lesser(x, y) """ - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__gt__(self) return lesser(self, other) def __le__(self, other): """x.__le__(y) <=> x<=y <=> mx.nd.lesser_equal(x, y) """ - if isinstance(other, NDArray) and other._is_np_compat(): - return other.__ge__(self) return lesser_equal(self, other) def __bool__(self): diff --git a/python/mxnet/ndarray/numpy/__init__.py b/python/mxnet/ndarray/numpy/__init__.py index d97e8086e8c3..7eb478f792f5 100644 --- a/python/mxnet/ndarray/numpy/__init__.py +++ b/python/mxnet/ndarray/numpy/__init__.py @@ -15,12 +15,11 @@ # specific language governing permissions and limitations # under the License. -"""numpy module for numpy ops under mxnet.ndarray.""" +"""Module for numpy ops under mxnet.ndarray.""" -from . import ext from . import random from . import linalg -from . import _op +from . import _op, _internal from . import _register from ._op import * # pylint: disable=wildcard-import diff --git a/python/mxnet/numpy/ext.py b/python/mxnet/ndarray/numpy/_internal.py similarity index 91% rename from python/mxnet/numpy/ext.py rename to python/mxnet/ndarray/numpy/_internal.py index e4c82518d474..c5f292842b3b 100644 --- a/python/mxnet/numpy/ext.py +++ b/python/mxnet/ndarray/numpy/_internal.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. -"""namespace for registering numpy.ext ops for imperative programming.""" +"""Namespace for numpy internal ops.""" __all__ = [] diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 9b32c314df7c..e905fdf9dac6 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -15,18 +15,19 @@ # specific language governing permissions and limitations # under the License. 
-"""numpy namespace for operators used in Gluon APIs dispatched by F=ndarray module.""" +"""Namespace for numpy operators used in Gluon dispatched by F=ndarray.""" from __future__ import absolute_import import numpy as _np -from ...base import _sanity_check_params, use_np_compat, numeric_types +from ...base import _sanity_check_params, use_np_compat, numeric_types, set_module from ...context import current_context -from .. import _internal +from . import _internal as _npi from ..ndarray import NDArray __all__ = ['zeros', 'ones', 'maximum', 'minimum'] +@set_module('mxnet.ndarray.numpy') @use_np_compat def zeros(shape, dtype=_np.float32, **kwargs): """Return a new array of given shape and type, filled with zeros. @@ -55,9 +56,10 @@ def zeros(shape, dtype=_np.float32, **kwargs): if ctx is None: ctx = current_context() dtype = _np.float32 if dtype is None else dtype - return _internal._np_zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs) + return _npi.zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs) +@set_module('mxnet.ndarray.numpy') @use_np_compat def ones(shape, dtype=None, **kwargs): """Return a new array of given shape and type, filled with ones. @@ -86,7 +88,7 @@ def ones(shape, dtype=None, **kwargs): if ctx is None: ctx = current_context() dtype = _np.float32 if dtype is None else dtype - return _internal._np_ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs) + return _npi.ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs) #pylint: disable= too-many-arguments, no-member, protected-access @@ -138,6 +140,7 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou #pylint: enable= too-many-arguments, no-member, protected-access +@set_module('mxnet.ndarray.numpy') @use_np_compat def maximum(x1, x2, out=None): """Returns element-wise maximum of the input arrays with broadcasting. @@ -152,10 +155,10 @@ def maximum(x1, x2, out=None): ------- out : mxnet.numpy.ndarray or scalar The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.""" - return _ufunc_helper(x1, x2, _internal._np_maximum, _np.maximum, - _internal._np_maximum_scalar, None, out) + return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out) +@set_module('mxnet.ndarray.numpy') @use_np_compat def minimum(x1, x2, out=None): """Returns element-wise minimum of the input arrays with broadcasting. @@ -170,5 +173,4 @@ def minimum(x1, x2, out=None): ------- out : mxnet.numpy.ndarray or scalar The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.""" - return _ufunc_helper(x1, x2, _internal._np_minimum, _np.minimum, - _internal._np_minimum_scalar, None, out) + return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out) diff --git a/python/mxnet/ndarray/numpy/_register.py b/python/mxnet/ndarray/numpy/_register.py index 840797f8c952..3ac464e24217 100644 --- a/python/mxnet/ndarray/numpy/_register.py +++ b/python/mxnet/ndarray/numpy/_register.py @@ -15,10 +15,14 @@ # specific language governing permissions and limitations # under the License. 
-"""module for registering numpy ops under mxnet.ndarray.numpy.""" +"""Registering numpy ops.""" from ...base import _init_np_op_module from ..register import _make_ndarray_function -_init_np_op_module('mxnet', 'ndarray', _make_ndarray_function) +_init_np_op_module(root_module_name='mxnet', np_module_name='numpy', + mx_module_name='ndarray', make_op_func=_make_ndarray_function) + +_init_np_op_module(root_module_name='mxnet', np_module_name='numpy._internal', + mx_module_name='ndarray', make_op_func=_make_ndarray_function) diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py index b8f10b343430..8f521fd0d456 100644 --- a/python/mxnet/ndarray/numpy/linalg.py +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. -"""numpy.linalg namespace for operators used in Gluon APIs dispatched by F=symbol module.""" +"""Namespace for operators used in Gluon dispatched by F=ndarray.""" __all__ = [] diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 60908b5c8098..8f521fd0d456 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. -"""numpy.random namespace for operators used in Gluon APIs dispatched by F=ndarray module.""" +"""Namespace for operators used in Gluon dispatched by F=ndarray.""" __all__ = [] diff --git a/python/mxnet/ndarray/numpy_extension/__init__.py b/python/mxnet/ndarray/numpy_extension/__init__.py new file mode 100644 index 000000000000..a718274ae9ed --- /dev/null +++ b/python/mxnet/ndarray/numpy_extension/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Module for the ops not belonging to the official numpy package.""" + +from . import _op +from . import _register +from ._op import * # pylint: disable=wildcard-import + +__all__ = _op.__all__ diff --git a/python/mxnet/ndarray/numpy_extension/_op.py b/python/mxnet/ndarray/numpy_extension/_op.py new file mode 100644 index 000000000000..22738a0f1950 --- /dev/null +++ b/python/mxnet/ndarray/numpy_extension/_op.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Namespace for the operators not belonging to the official numpy package +used in Gluon dispatched by F=ndarray module.""" + +__all__ = [] diff --git a/python/mxnet/ndarray/numpy_extension/_register.py b/python/mxnet/ndarray/numpy_extension/_register.py new file mode 100644 index 000000000000..32cd0686551c --- /dev/null +++ b/python/mxnet/ndarray/numpy_extension/_register.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Registering numpy_extension ops.""" + +from ...base import _init_np_op_module +from ..register import _make_ndarray_function + + +_init_np_op_module(root_module_name='mxnet', np_module_name='numpy_extension', + mx_module_name='ndarray', make_op_func=_make_ndarray_function) diff --git a/python/mxnet/ndarray/register.py b/python/mxnet/ndarray/register.py index 1ccf228698ba..a285e508e04c 100644 --- a/python/mxnet/ndarray/register.py +++ b/python/mxnet/ndarray/register.py @@ -24,12 +24,60 @@ from ._internal import NDArrayBase, _imperative_invoke # pylint: disable=unused-import from ..ndarray_doc import _build_doc -from ..base import mx_uint, check_call, _LIB, py_str, _init_op_module, _Null # pylint: disable=unused-import +from ..base import mx_uint, check_call, _LIB, py_str, _init_op_module, _Null, _is_np_op # pylint: disable=unused-import + + +def _verify_all_np_ndarrays(op_name, func_name, *array_list): + """Verify if all the arrays are numpy ndarrays. + + Parameters + ---------- + op_name : str + Operator full name registered in backend. + func_name : str + Operator name exposed to users. This is usually obtained by stripping + the prefix off the full operator name registered in the backend. + array_list : list of arrays + """ + from ..numpy import ndarray as np_ndarray + for array in array_list: + if (array is not None) and (not isinstance(array, np_ndarray)): + raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. ' + 'This is a numpy operator which can only accept ' + 'MXNet numpy ndarrays, while received a classic ndarray. ' + 'Please call `as_np_ndarray()` upon the classic ndarray to ' + 'convert it to an MXNet numpy ndarray, and then feed the converted ' + 'array to this operator.' + .format(op_name, func_name)) + + +def _verify_all_classic_ndarrays(op_name, func_name, *array_list): + """Verify if all the arrays are classic ndarrays. + + Parameters + ---------- + op_name : str + Operator full name registered in backend. + func_name : str + Operator name exposed to users. This is usually obtained by stripping + the prefix off the full operator name registered in the backend. + array_list : list of arrays + """ + from ..numpy import ndarray as np_ndarray + for array in array_list: + if (array is not None) and (isinstance(array, np_ndarray)): + raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. ' + 'This is a classic operator which can only accept ' + 'classic ndarrays, while received an MXNet numpy ndarray. ' + 'Please call `as_classic_ndarray()` upon the numpy ndarray to ' + 'convert it to a classic ndarray, and then feed the converted ' + 'array to this operator.' + .format(op_name, func_name)) # pylint: disable=too-many-locals -def _generate_ndarray_function_code(handle, name, func_name, signature_only=False): +def _generate_ndarray_function_code(handle, op_name, func_name, signature_only=False): """Generate function for ndarray op by handle and function name.""" real_name = ctypes.c_char_p() desc = ctypes.c_char_p() num_args = mx_uint() @@ -52,7 +100,7 @@ def _generate_ndarray_function_code(handle, name, func_name, signature_only=False): arg_types = [py_str(arg_types[i]) for i in range(narg)] key_var_num_args = py_str(key_var_num_args.value) ret_type = py_str(ret_type.value) if ret_type.value is not None else '' - doc_str = _build_doc(name, + doc_str = _build_doc(op_name, py_str(desc.value), arg_names, arg_types, @@ -139,10 +187,16 @@ def %s(%s):"""%(func_name, ', '.join(signature))) keys.append('%s') vals.append(_np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name)) + is_np_op = _is_np_op(op_name) + verify_ndarrays_fn =\ + _verify_all_np_ndarrays.__name__ if is_np_op else _verify_all_classic_ndarrays.__name__ if not signature_only: code.append(""" - return _imperative_invoke(%d, ndargs, keys, vals, out)"""%( - handle.value)) + {}("{}", "{}", out, *ndargs) + """.format(verify_ndarrays_fn, op_name, func_name)) + code.append(""" + return _imperative_invoke(%d, ndargs, keys, vals, out, %s)"""%( + handle.value, str(is_np_op))) else: code.append(""" return (0,)""") diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py index 2a58f270b96d..0f3c3c72504e 100644 --- a/python/mxnet/numpy/__init__.py +++ b/python/mxnet/numpy/__init__.py @@ -17,15 +17,15 @@ # specific language governing permissions and limitations # under the License. -"""numpy module for imperative programming.""" +"""Module for numpy ops used in imperative programming.""" from __future__ import absolute_import from . import random from . import linalg -from . import ext from .multiarray import * # pylint: disable=wildcard-import from . import _op from . import _register from ._op import * # pylint: disable=wildcard-import +from ..base import use_np_compat, set_np_compat, np_compat __all__ = [] diff --git a/python/mxnet/numpy/_op.py b/python/mxnet/numpy/_op.py index e6a918c97be4..8f6f9cc053e4 100644 --- a/python/mxnet/numpy/_op.py +++ b/python/mxnet/numpy/_op.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-"""namespace for registering numpy ops for imperative programming.""" +"""Namespace for registering numpy ops for imperative programming.""" __all__ = [] diff --git a/python/mxnet/numpy/_register.py b/python/mxnet/numpy/_register.py index 53ceecd92478..8a2d2ea61c24 100644 --- a/python/mxnet/numpy/_register.py +++ b/python/mxnet/numpy/_register.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -"""Register backend ops in mxnet.ndarray namespace.""" +"""Registering ops in mxnet.numpy for imperative programming.""" from __future__ import absolute_import @@ -23,4 +23,5 @@ from ..ndarray.register import _make_ndarray_function -_init_np_op_module('mxnet', 'numpy', _make_ndarray_function) +_init_np_op_module(root_module_name='mxnet', np_module_name='numpy', + mx_module_name=None, make_op_func=_make_ndarray_function) diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py index 96c7ddc06612..e49bfcf6a97c 100644 --- a/python/mxnet/numpy/linalg.py +++ b/python/mxnet/numpy/linalg.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. -"""namespace for registering numpy.linalg ops for imperative programming.""" +"""Namespace for ops used in imperative programming.""" __all__ = [] diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 6c414b4c6266..dfcce0b9a671 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -25,14 +25,14 @@ from array import array as native_array import ctypes import numpy as _np -from ..ndarray import NDArray, _DTYPE_NP_TO_MX +from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _GRAD_REQ_MAP from ..ndarray._internal import _set_np_ndarray_class from . import _op as _mx_np_op from ..base import use_np_compat, check_call, _LIB, NDArrayHandle, _sanity_check_params -from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types +from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types, set_module from ..context import current_context from ..ndarray import numpy as _mx_nd_np -from ..ndarray import _internal as _nd_internal +from ..ndarray.numpy import _internal as _npi __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum'] @@ -73,16 +73,14 @@ def _np_ndarray_cls(handle, writable=True, stype=0): _set_np_ndarray_class(_np_ndarray_cls) -class ndarray(NDArray): # pylint: disable=invalid-name +@set_module('mxnet.numpy') # pylint: disable=invalid-name +class ndarray(NDArray): """An array object represents a multidimensional, homogeneous array of fixed-size items. An associated data-type object describes the format of each element in the array (its byte-order, how many bytes it occupies in memory, whether it is an integer, a floating point number, or something else, etc.). Arrays should be constructed using `array`, `zeros` or `empty`. 
Currently, only c-contiguous arrays are supported.""" - def _is_np_compat(self): - return True - @use_np_compat def __getitem__(self, item): # TODO(junwu): make output shape of integer indexing correct @@ -90,15 +88,15 @@ def __getitem__(self, item): @use_np_compat def __setitem__(self, key, value): - super(ndarray, self).__setitem__(key, value) + self.as_classic_ndarray().__setitem__(key, value) @use_np_compat def __add__(self, other): """x.__add__(y) <=> x + y""" - if isinstance(other, NDArray): - return _nd_internal._np_add(self, other) + if isinstance(other, ndarray): + return _npi.add(self, other) elif isinstance(other, numeric_types): - return _nd_internal._np_add_scalar(self, float(other)) + return _npi.add_scalar(self, float(other)) else: raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) @@ -107,20 +105,20 @@ def __iadd__(self, other): """x.__iadd__(y) <=> x += y""" if not self.writable: raise ValueError('trying to add to a readonly ndarray') - if isinstance(other, NDArray): - return _nd_internal._np_add(self, other, out=self) + if isinstance(other, ndarray): + return _npi.add(self, other, out=self) elif isinstance(other, numeric_types): - return _nd_internal._np_add_scalar(self, float(other), out=self) + return _npi.add_scalar(self, float(other), out=self) else: raise TypeError('type {} is not supported'.format(str(type(other)))) @use_np_compat def __sub__(self, other): """x.__sub__(y) <=> x - y""" - if isinstance(other, NDArray): - return _nd_internal._np_subtract(self, other) + if isinstance(other, ndarray): + return _npi.subtract(self, other) elif isinstance(other, numeric_types): - return _nd_internal._np_subtract_scalar(self, float(other)) + return _npi.subtract_scalar(self, float(other)) else: raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) @@ -129,30 +127,30 @@ def __isub__(self, other): """x.__isub__(y) <=> x -= y""" if not self.writable: raise ValueError('trying to subtract from a readonly ndarray') - if isinstance(other, NDArray): - return _nd_internal._np_subtract(self, other, out=self) + if isinstance(other, ndarray): + return _npi.subtract(self, other, out=self) elif isinstance(other, numeric_types): - return _nd_internal._np_subtract_scalar(self, float(other), out=self) + return _npi.subtract_scalar(self, float(other), out=self) else: raise TypeError('type {} is not supported'.format(str(type(other)))) @use_np_compat def __rsub__(self, other): """x.__rsub__(y) <=> y - x""" - if isinstance(other, NDArray): - return _nd_internal._np_subtract(other, self) + if isinstance(other, ndarray): + return _npi.subtract(other, self) elif isinstance(other, numeric_types): - return _nd_internal._np_rsubtract_scalar(self, float(other)) + return _npi.rsubtract_scalar(self, float(other)) else: raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) @use_np_compat def __mul__(self, other): """x.__mul__(y) <=> x * y""" - if isinstance(other, NDArray): - return _nd_internal._np_multiply(self, other) + if isinstance(other, ndarray): + return _npi.multiply(self, other) elif isinstance(other, numeric_types): - return _nd_internal._np_multiply_scalar(self, float(other)) + return _npi.multiply_scalar(self, float(other)) else: raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) @@ -190,20 +188,20 @@ def __idiv__(self, other): @use_np_compat def __truediv__(self, other): """x.__truediv__(y) <=> x / y""" - if isinstance(other, NDArray): - return 
_nd_internal._true_divide(self, other) + if isinstance(other, ndarray): + return _npi.true_divide(self, other) elif isinstance(other, numeric_types): - return _nd_internal._true_divide_scalar(self, float(other)) + return _npi.true_divide_scalar(self, float(other)) else: raise TypeError("ndarray does not support type {} as divisor".format(str(type(other)))) @use_np_compat def __rtruediv__(self, other): """x.__rtruediv__(y) <=> y / x""" - if isinstance(other, NDArray): - return _nd_internal._true_divide(other, self) + if isinstance(other, ndarray): + return _npi.true_divide(other, self) elif isinstance(other, numeric_types): - return _nd_internal._rtrue_divide_scalar(self, float(other)) + return _npi.rtrue_divide_scalar(self, float(other)) else: raise TypeError("ndarray does not support type {} as dividend".format(str(type(other)))) @@ -214,20 +212,20 @@ def __itruediv__(self, other): @use_np_compat def __mod__(self, other): """x.__mod__(y) <=> x % y""" - if isinstance(other, NDArray): - return _nd_internal._np_mod(self, other) + if isinstance(other, ndarray): + return _npi.mod(self, other) elif isinstance(other, numeric_types): - return _nd_internal._np_mod_scalar(self, float(other)) + return _npi.mod_scalar(self, float(other)) else: raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) @use_np_compat def __rmod__(self, other): """x.__rmod__(y) <=> y % x""" - if isinstance(other, NDArray): - return _nd_internal._np_mod(other, self) + if isinstance(other, ndarray): + return _npi.mod(other, self) elif isinstance(other, numeric_types): - return _nd_internal._np_rmod_scalar(self, float(other)) + return _npi.rmod_scalar(self, float(other)) else: raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) @@ -238,20 +236,20 @@ def __imod__(self, other): @use_np_compat def __pow__(self, other): """x.__pow__(y) <=> x ** y""" - if isinstance(other, NDArray): - return _nd_internal._np_power(self, other) + if isinstance(other, ndarray): + return _npi.power(self, other) elif isinstance(other, numeric_types): - return _nd_internal._np_power_scalar(self, float(other)) + return _npi.power_scalar(self, float(other)) else: raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) @use_np_compat def __rpow__(self, other): """x.__rpow__(y) <=> y ** x""" - if isinstance(other, NDArray): - return _nd_internal._np_power(other, self) + if isinstance(other, ndarray): + return _npi.power(other, self) elif isinstance(other, numeric_types): - return _nd_internal._np_rpower_scalar(self, float(other)) + return _npi.rpower_scalar(self, float(other)) else: raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) @@ -355,15 +353,41 @@ def as_classic_ndarray(self): @use_np_compat def __repr__(self): - """Returns a string representation of the array.""" - return '%s\n<%s shape=%s ctx=%s>' % (str(self.asnumpy()), self.__class__.__name__, - self.shape, self.context) + """Returns a string representation of the array using the following rules: + 1. If the `ndarray` is a scalar tensor, only the string of the scalar is returned. + 2. Else if the `ndarray` is allocated on cpu, the string of its numpy form, class name, + and shape is returned. + 3. 
Else (the `ndarray` is allocated on gpu), the string of its numpy form, class name, + shape, and context is returned.""" + array_str = str(self.asnumpy()) + if self.ndim == 0: # scalar tensor + return array_str + context = self.context + if context.device_type == 'gpu': + return '%s\n<%s shape=%s ctx=%s>' % (array_str, self.__class__.__name__, self.shape, + context) + else: + return '%s\n<%s shape=%s>' % (array_str, self.__class__.__name__, self.shape) @use_np_compat - def attach_grad(self, grad_req='write', stype=None): - if stype is not None: - raise NotImplementedError('mxnet.numpy.ndarray currently does not support stype') - super(ndarray, self).attach_grad(grad_req, stype) + def attach_grad(self, grad_req='write'): # pylint: disable=arguments-differ + """Attach a gradient buffer to this ndarray, so that `backward` + can compute gradient with respect to it. + + Parameters + ---------- + grad_req : {'write', 'add', 'null'} + How the gradient will be accumulated. + - 'write': gradient will be overwritten on every backward. + - 'add': gradient will be added to existing value on every backward. + - 'null': do not compute gradient for this ndarray. + """ + grad = _mx_np_op.zeros_like(self) # pylint: disable=undefined-variable + grad_req = _GRAD_REQ_MAP[grad_req] + check_call(_LIB.MXAutogradMarkVariables( + 1, ctypes.pointer(self.handle), + ctypes.pointer(mx_uint(grad_req)), + ctypes.pointer(grad.handle))) @property def grad(self): @@ -412,6 +436,43 @@ def astype(self, dtype, *args, **kwargs): # pylint: disable=arguments-differ,un self.copyto(res) return res + @use_np_compat + def copyto(self, other): + """Copies the value of this array to another array. + + If ``other`` is a ``ndarray`` object, then ``other.shape`` and + ``self.shape`` should be the same. This function copies the value from + ``self`` to ``other``. + + If ``other`` is a context, a new ``ndarray`` will first be created on + the target context, and the value of ``self`` is copied. + + Parameters + ---------- + other : ndarray or Context + The destination array or context. + + Returns + ------- + ndarray + The copied array. If ``other`` is an ``ndarray``, then the return value + and ``other`` will point to the same ``ndarray``. + + Examples + -------- + >>> x = np.ones((2,3)) + >>> y = np.zeros((2,3), ctx=mx.gpu(0)) + >>> z = x.copyto(y) + >>> z is y + True + >>> y.asnumpy() + array([[ 1., 1., 1.], + [ 1., 1., 1.]], dtype=float32) + """ + if isinstance(other, ndarray): + other = other.as_classic_ndarray() + return self.as_classic_ndarray().copyto(other).as_np_ndarray() + def asscalar(self): raise AttributeError('mxnet.numpy.ndarray object has no attribute asscalar') @@ -435,7 +496,7 @@ def reshape(self, shape, order='C'): # pylint: disable=arguments-differ if order != 'C': raise NotImplementedError('reshape only supports C-order,' ' while received {}'.format(order)) - return _mx_np_op.reshape(self, shape=shape, order=order) + return _mx_np_op.reshape(self, newshape=shape, order=order) def reshape_like(self, *args, **kwargs): """Convenience fluent method for :py:func:`reshape_like`.
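A minimal imperative sketch of the reimplemented `attach_grad` above (a hedged example: it assumes the `mxnet.numpy` front end from this patch is importable and that `sum`/`backward` are available on the new ndarray, as elsewhere in this diff):

    import mxnet as mx
    from mxnet import numpy as np

    x = np.array([1.0, 2.0, 3.0])
    x.attach_grad()                 # allocates the grad buffer via zeros_like and marks x for autograd
    with mx.autograd.record():
        y = ((x * x) * 2.0).sum()   # `*` dispatches to _npi.multiply / _npi.multiply_scalar
    y.backward()
    print(x.grad)                   # d(2*x^2)/dx = 4*x -> [ 4.  8. 12.]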
@@ -1117,15 +1178,11 @@ def size(self): """Number of elements in the array.""" return super(ndarray, self).size - @property - @use_np_compat - def stype(self): - raise AttributeError('mxnet.numpy.ndarray object has no attribute stype') - def tostype(self, stype): raise AttributeError('mxnet.numpy.ndarray object has no attribute tostype') +@set_module('mxnet.numpy') @use_np_compat def empty(shape, dtype=None, **kwargs): """Return a new array of given shape and type, without initializing entries. @@ -1158,6 +1215,7 @@ def empty(shape, dtype=None, **kwargs): return ndarray(handle=_new_alloc_handle(shape, ctx, False, dtype)) +@set_module('mxnet.numpy') @use_np_compat def array(object, dtype=None, **kwargs): """ @@ -1169,10 +1227,7 @@ def array(object, dtype=None, **kwargs): An array, any object exposing the array interface, an object whose __array__ method returns an array, or any (nested) sequence. dtype : data-type, optional - The desired data-type for the array. If not given, then the type will - be determined as the minimum type required to hold the objects in the - sequence. This argument can only be used to 'upcast' the array. For - downcasting, use the .astype(t) method. + The desired data-type for the array. Default is `float32`. ctx : device context, optional Device context on which the memory is allocated. Default is `mxnet.context.current_context()`. @@ -1186,18 +1241,19 @@ def array(object, dtype=None, **kwargs): ctx = kwargs.get('ctx', current_context()) if ctx is None: ctx = current_context() + if dtype is None: + dtype = _np.float32 if not isinstance(object, (ndarray, NDArray, _np.ndarray)): try: object = _np.array(object, dtype=dtype) except: raise TypeError('source array must be an array like object') - if dtype is None: - dtype = object.dtype ret = empty(object.shape, dtype=dtype, ctx=ctx) ret[:] = object return ret +@set_module('mxnet.numpy') def zeros(shape, dtype=_np.float32, **kwargs): """Return a new array of given shape and type, filled with zeros. This function currently only supports storing multi-dimensional data @@ -1223,6 +1279,7 @@ def zeros(shape, dtype=_np.float32, **kwargs): return _mx_nd_np.zeros(shape, dtype, **kwargs) +@set_module('mxnet.numpy') def ones(shape, dtype=None, **kwargs): """Return a new array of given shape and type, filled with ones. This function currently only supports storing multi-dimensional data @@ -1248,6 +1305,7 @@ def ones(shape, dtype=None, **kwargs): return _mx_nd_np.ones(shape, dtype, **kwargs) +@set_module('mxnet.numpy') def maximum(x1, x2, out=None): """Returns element-wise maximum of the input arrays with broadcasting. @@ -1264,6 +1322,7 @@ def maximum(x1, x2, out=None): return _mx_nd_np.maximum(x1, x2, out=out) +@set_module('mxnet.numpy') def minimum(x1, x2, out=None): """Returns element-wise minimum of the input arrays with broadcasting. diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index b1f4b02e5a71..e49bfcf6a97c 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License.
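The `array` change above makes dtype default to `float32` rather than inferring it from the input, a deliberate departure from official NumPy. A hedged sketch of the resulting behavior (assuming the `mxnet.numpy` front end is importable):

    from mxnet import numpy as np

    a = np.array([[1, 2], [3, 4]])   # float32 by default, unlike official NumPy's int inference
    b = np.ones(a.shape)             # ones/zeros also default to float32
    print(np.maximum(a, b))          # broadcasting element-wise maximum
    print(np.minimum(a, 0.5))        # a scalar operand takes the *_scalar operator path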
-"""namespace for registering numpy.random ops for imperative programming.""" +"""Namespace for ops used in imperative programming.""" __all__ = [] diff --git a/python/mxnet/numpy_extension/__init__.py b/python/mxnet/numpy_extension/__init__.py new file mode 100644 index 000000000000..bd5117528e7d --- /dev/null +++ b/python/mxnet/numpy_extension/__init__.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Module for ops not belonging to the official numpy package for imperative programming.""" + +from __future__ import absolute_import +from . import _op +from . import _register +from ._op import * # pylint: disable=wildcard-import +from ..context import * # pylint: disable=wildcard-import + +__all__ = [] diff --git a/python/mxnet/symbol/numpy/ext.py b/python/mxnet/numpy_extension/_op.py similarity index 89% rename from python/mxnet/symbol/numpy/ext.py rename to python/mxnet/numpy_extension/_op.py index 12c5f15cba55..a995e480221a 100644 --- a/python/mxnet/symbol/numpy/ext.py +++ b/python/mxnet/numpy_extension/_op.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. -"""numpy.ext namespace for operators used in Gluon APIs dispatched by F=symbol module.""" +"""Namespace for registering numpy_extension ops for imperative programming.""" __all__ = [] diff --git a/python/mxnet/numpy_extension/_register.py b/python/mxnet/numpy_extension/_register.py new file mode 100644 index 000000000000..8abb7254057c --- /dev/null +++ b/python/mxnet/numpy_extension/_register.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Registering ops in mxnet.numpy_extension for imperative programming.""" + +from __future__ import absolute_import + +from ..base import _init_np_op_module +from ..ndarray.register import _make_ndarray_function + + +_init_np_op_module(root_module_name='mxnet', np_module_name='numpy_extension', + mx_module_name=None, make_op_func=_make_ndarray_function) diff --git a/python/mxnet/symbol/__init__.py b/python/mxnet/symbol/__init__.py index ae9477aaf86f..1cd805792b41 100644 --- a/python/mxnet/symbol/__init__.py +++ b/python/mxnet/symbol/__init__.py @@ -28,5 +28,7 @@ from .symbol import * # pylint: enable=wildcard-import from . import numpy as np +from . import numpy_extension as npe -__all__ = op.__all__ + symbol.__all__ + ['contrib', 'linalg', 'random', 'sparse', 'image'] +__all__ = op.__all__ + symbol.__all__\ + + ['contrib', 'linalg', 'random', 'sparse', 'image', 'numpy', 'numpy_extension'] diff --git a/python/mxnet/symbol/numpy/__init__.py b/python/mxnet/symbol/numpy/__init__.py index 1f20c037a0ec..857849c4ae62 100644 --- a/python/mxnet/symbol/numpy/__init__.py +++ b/python/mxnet/symbol/numpy/__init__.py @@ -15,13 +15,12 @@ # specific language governing permissions and limitations # under the License. -"""numpy module for numpy ops under mxnet.symbol.""" +"""Module for numpy ops under mxnet.symbol.""" from . import random from . import linalg -from . import ext -from . import _op, _symbol -from ._symbol import _NumpySymbol +from . import _op, _symbol, _internal +from ._symbol import _Symbol from . import _register from ._op import * # pylint: disable=wildcard-import from ._symbol import * # pylint: disable=wildcard-import diff --git a/python/mxnet/ndarray/numpy/ext.py b/python/mxnet/symbol/numpy/_internal.py similarity index 89% rename from python/mxnet/ndarray/numpy/ext.py rename to python/mxnet/symbol/numpy/_internal.py index e13423f82535..c5f292842b3b 100644 --- a/python/mxnet/ndarray/numpy/ext.py +++ b/python/mxnet/symbol/numpy/_internal.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. -"""numpy.ext namespace for operators used in Gluon APIs dispatched by F=ndarray module.""" +"""Namespace for numpy internal ops.""" __all__ = [] diff --git a/python/mxnet/symbol/numpy/_op.py b/python/mxnet/symbol/numpy/_op.py index 96da828ecbbb..a4a979f30b18 100644 --- a/python/mxnet/symbol/numpy/_op.py +++ b/python/mxnet/symbol/numpy/_op.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. -"""numpy namespace for operators used in Gluon APIs dispatched by F=symbol module.""" +"""Namespace for operators used in Gluon dispatched by F=symbol module.""" __all__ = [] diff --git a/python/mxnet/symbol/numpy/_register.py b/python/mxnet/symbol/numpy/_register.py index 36dfd7842112..3245c8d6d638 100644 --- a/python/mxnet/symbol/numpy/_register.py +++ b/python/mxnet/symbol/numpy/_register.py @@ -15,9 +15,14 @@ # specific language governing permissions and limitations # under the License. 
-"""module for registering numpy ops under mxnet.symbol.numpy.""" +"""Registering numpy ops.""" from ...base import _init_np_op_module from ..register import _make_symbol_function -_init_np_op_module('mxnet', 'symbol', _make_symbol_function) +_init_np_op_module(root_module_name='mxnet', np_module_name='numpy', + mx_module_name='symbol', make_op_func=_make_symbol_function) + + +_init_np_op_module(root_module_name='mxnet', np_module_name='numpy._internal', + mx_module_name='symbol', make_op_func=_make_symbol_function) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 8cf6e3039d98..0bbd96b3b2bb 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -23,21 +23,17 @@ import numpy as _np from . import _op as _mx_np_op from ...base import _sanity_check_params, use_np_compat, check_call, _LIB, SymbolHandle -from ...base import numeric_types +from ...base import numeric_types, set_module from ...context import current_context -from .. import _internal from ..symbol import Symbol from .._internal import _set_np_symbol_class -from .. import _internal as _sym_internal +from . import _internal as _npi __all__ = ['zeros', 'ones', 'maximum', 'minimum'] -class _NumpySymbol(Symbol): - - def _is_np_compat(self): - return True - +@set_module('mxnet.symbol.numpy') +class _Symbol(Symbol): def __getitem__(self, item): raise NotImplementedError @@ -45,72 +41,72 @@ def __setitem__(self, key, value): raise NotImplementedError def __iter__(self): - raise AttributeError('_NumpySymbol object has no attribute __iter__') + raise AttributeError('_Symbol object has no attribute __iter__') @use_np_compat def __add__(self, other): """x.__add__(y) <=> x + y""" - if isinstance(other, Symbol): - return _sym_internal._np_add(self, other) + if isinstance(other, _Symbol): + return _npi.add(self, other) elif isinstance(other, numeric_types): - return _sym_internal._np_add_scalar(self, float(other)) + return _npi.add_scalar(self, float(other)) else: - raise TypeError("_NumpySymbol does not support type {} as operand" + raise TypeError("_Symbol does not support type {} as operand" .format(str(type(other)))) @use_np_compat def __sub__(self, other): """x.__sub__(y) <=> x - y""" - if isinstance(other, Symbol): - return _sym_internal._np_subtract(self, other) + if isinstance(other, _Symbol): + return _npi.subtract(self, other) elif isinstance(other, numeric_types): - return _sym_internal._np_subtract_scalar(self, float(other)) + return _npi.subtract_scalar(self, float(other)) else: - raise TypeError("_NumpySymbol does not support type {} as operand" + raise TypeError("_Symbol does not support type {} as operand" .format(str(type(other)))) @use_np_compat def __rsub__(self, other): """x.__rsub__(y) <=> y - x""" - if isinstance(other, Symbol): - return _sym_internal._np_subtract(other, self) + if isinstance(other, _Symbol): + return _npi.subtract(other, self) elif isinstance(other, numeric_types): - return _sym_internal._np_rsubtract_scalar(self, float(other)) + return _npi.rsubtract_scalar(self, float(other)) else: - raise TypeError("_NumpySymbol does not support type {} as operand" + raise TypeError("_Symbol does not support type {} as operand" .format(str(type(other)))) @use_np_compat def __mul__(self, other): """x.__mul__(y) <=> x * y""" - if isinstance(other, Symbol): - return _sym_internal._np_multiply(self, other) + if isinstance(other, _Symbol): + return _npi.multiply(self, other) elif isinstance(other, numeric_types): - return 
_sym_internal._np_multiply_scalar(self, float(other)) + return _npi.multiply_scalar(self, float(other)) else: - raise TypeError("_NumpySymbol does not support type {} as operand" + raise TypeError("_Symbol does not support type {} as operand" .format(str(type(other)))) @use_np_compat def __rmul__(self, other): """x.__rmul__(y) <=> y * x""" - if isinstance(other, Symbol): - return _sym_internal._np_multiply(self, other) + if isinstance(other, _Symbol): + return _npi.multiply(self, other) elif isinstance(other, numeric_types): - return _sym_internal._np_multiply_scalar(self, float(other)) + return _npi.multiply_scalar(self, float(other)) else: - raise TypeError("_NumpySymbol does not support type {} as operand" + raise TypeError("_Symbol does not support type {} as operand" .format(str(type(other)))) def __div__(self, other): - raise AttributeError('_NumpySymbol.__div__ is replaced by __truediv__. If you are using' + raise AttributeError('_Symbol.__div__ is replaced by __truediv__. If you are using' ' Python2, please use the statement from __future__ import division' ' to change the / operator to mean true division throughout the' ' module. If you are using Python3, this error should not have' ' been encountered.') def __rdiv__(self, other): - raise AttributeError('_NumpySymbol.__rdiv__ is replaced by __rtruediv__. If you are using' + raise AttributeError('_Symbol.__rdiv__ is replaced by __rtruediv__. If you are using' ' Python2, please use the statement from __future__ import division' ' to change the / operator to mean true division throughout the' ' module. If you are using Python3, this error should not have' @@ -119,23 +115,23 @@ def __rdiv__(self, other): @use_np_compat def __mod__(self, other): """x.__mod__(y) <=> x % y""" - if isinstance(other, Symbol): - return _sym_internal._np_mod(self, other) + if isinstance(other, _Symbol): + return _npi.mod(self, other) elif isinstance(other, numeric_types): - return _sym_internal._np_mod_scalar(self, float(other)) + return _npi.mod_scalar(self, float(other)) else: - raise TypeError("_NumpySymbol does not support type {} as operand" + raise TypeError("_Symbol does not support type {} as operand" .format(str(type(other)))) @use_np_compat def __rmod__(self, other): """x.__rmod__(y) <=> y % x""" - if isinstance(other, Symbol): - return _sym_internal._np_mod(other, self) + if isinstance(other, _Symbol): + return _npi.mod(other, self) elif isinstance(other, numeric_types): - return _sym_internal._np_rmod_scalar(self, float(other)) + return _npi.rmod_scalar(self, float(other)) else: - raise TypeError("_NumpySymbol does not support type {} as operand" + raise TypeError("_Symbol does not support type {} as operand" .format(str(type(other)))) @use_np_compat @@ -145,23 +141,23 @@ def __idiv__(self, other): @use_np_compat def __truediv__(self, other): """x.__truediv__(y) <=> x / y""" - if isinstance(other, Symbol): - return _sym_internal._true_divide(self, other) + if isinstance(other, _Symbol): + return _npi.true_divide(self, other) elif isinstance(other, numeric_types): - return _sym_internal._true_divide_scalar(self, float(other)) + return _npi.true_divide_scalar(self, float(other)) else: - raise TypeError("_NumpySymbol does not support type {} as divisor" + raise TypeError("_Symbol does not support type {} as divisor" .format(str(type(other)))) @use_np_compat def __rtruediv__(self, other): """x.__rtruediv__(y) <=> y / x""" - if isinstance(other, Symbol): - return _sym_internal._true_divide(other, self) + if isinstance(other, _Symbol): + return 
_npi.true_divide(other, self) elif isinstance(other, numeric_types): - return _sym_internal._rtrue_divide_scalar(self, float(other)).as_np_ndarray() + return _npi.rtrue_divide_scalar(self, float(other)).as_np_ndarray() else: - raise TypeError("_NumpySymbol does not support type {} as dividend" + raise TypeError("_Symbol does not support type {} as dividend" .format(str(type(other)))) @use_np_compat @@ -171,23 +167,23 @@ def __itruediv__(self, other): @use_np_compat def __pow__(self, other): """x.__pow__(y) <=> x ** y""" - if isinstance(other, Symbol): - return _sym_internal._np_power(self, other) + if isinstance(other, _Symbol): + return _npi.power(self, other) elif isinstance(other, numeric_types): - return _sym_internal._np_power_scalar(self, float(other)) + return _npi.power_scalar(self, float(other)) else: - raise TypeError("_NumpySymbol does not support type {} as operand" + raise TypeError("_Symbol does not support type {} as operand" .format(str(type(other)))) @use_np_compat def __rpow__(self, other): """x.__rpow__(y) <=> y ** x""" - if isinstance(other, Symbol): - return _sym_internal._np_power(other, self) + if isinstance(other, _Symbol): + return _npi.power(other, self) elif isinstance(other, numeric_types): - return _sym_internal._np_rpower_scalar(self, float(other)) + return _npi.rpower_scalar(self, float(other)) else: - raise TypeError("_NumpySymbol does not support type {} as operand" + raise TypeError("_Symbol does not support type {} as operand" .format(str(type(other)))) @use_np_compat @@ -197,7 +193,7 @@ def __neg__(self): @use_np_compat def __deepcopy__(self, _): - return super(_NumpySymbol, self).as_np_ndarray() + return super(_Symbol, self).as_np_ndarray() @use_np_compat def __eq__(self, other): @@ -233,7 +229,7 @@ def __len__(self): raise NotImplementedError def as_classic_ndarray(self): - """Convert _NumpySymbol to mxnet.symbol.Symbol to use its convenience fluent methods.""" + """Convert _Symbol to mxnet.symbol.Symbol to use its convenience fluent methods.""" hdl = SymbolHandle() check_call(_LIB.MXShallowCopySymbol(self.handle, ctypes.byref(hdl))) return Symbol(handle=hdl) @@ -258,7 +254,7 @@ def reshape(self, shape, order='C'): # pylint: disable=arguments-differ if order != 'C': raise NotImplementedError('ndarray.copy only supports order=\'C\', while ' 'received {}'.format(str(order))) - return _mx_np_op.reshape(self, shape=shape, order=order) + return _mx_np_op.reshape(self, newshape=shape, order=order) def reshape_like(self, *args, **kwargs): """Convenience fluent method for :py:func:`reshape_like`. @@ -266,7 +262,7 @@ def reshape_like(self, *args, **kwargs): The arguments are the same as for :py:func:`reshape_like`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute reshape_like') + raise AttributeError('_Symbol object has no attribute reshape_like') def zeros_like(self, *args, **kwargs): """Convenience fluent method for :py:func:`zeros_like`. @@ -274,7 +270,7 @@ def zeros_like(self, *args, **kwargs): The arguments are the same as for :py:func:`zeros_like`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute zeros_like') + raise AttributeError('_Symbol object has no attribute zeros_like') def ones_like(self, *args, **kwargs): """Convenience fluent method for :py:func:`ones_like`. @@ -282,7 +278,7 @@ def ones_like(self, *args, **kwargs): The arguments are the same as for :py:func:`ones_like`, with this array as data. 
""" - raise AttributeError('_NumpySymbol object has no attribute ones_like') + raise AttributeError('_Symbol object has no attribute ones_like') def broadcast_axes(self, *args, **kwargs): """Convenience fluent method for :py:func:`broadcast_axes`. @@ -290,7 +286,7 @@ def broadcast_axes(self, *args, **kwargs): The arguments are the same as for :py:func:`broadcast_axes`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute broadcast_like') + raise AttributeError('_Symbol object has no attribute broadcast_like') @use_np_compat def repeat(self, *args, **kwargs): @@ -307,7 +303,7 @@ def pad(self, *args, **kwargs): The arguments are the same as for :py:func:`pad`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute pad') + raise AttributeError('_Symbol object has no attribute pad') @use_np_compat def swapaxes(self, *args, **kwargs): @@ -324,7 +320,7 @@ def split(self, *args, **kwargs): The arguments are the same as for :py:func:`split`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute split') + raise AttributeError('_Symbol object has no attribute split') def split_v2(self, *args, **kwargs): """Convenience fluent method for :py:func:`split_v2`. @@ -332,7 +328,7 @@ def split_v2(self, *args, **kwargs): The arguments are the same as for :py:func:`split_v2`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute split_v2') + raise AttributeError('_Symbol object has no attribute split_v2') def slice(self, *args, **kwargs): """Convenience fluent method for :py:func:`slice`. @@ -340,7 +336,7 @@ def slice(self, *args, **kwargs): The arguments are the same as for :py:func:`slice`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute slice') + raise AttributeError('_Symbol object has no attribute slice') def slice_axis(self, *args, **kwargs): """Convenience fluent method for :py:func:`slice_axis`. @@ -348,7 +344,7 @@ def slice_axis(self, *args, **kwargs): The arguments are the same as for :py:func:`slice_axis`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute slice_axis') + raise AttributeError('_Symbol object has no attribute slice_axis') def slice_like(self, *args, **kwargs): """Convenience fluent method for :py:func:`slice_like`. @@ -356,7 +352,7 @@ def slice_like(self, *args, **kwargs): The arguments are the same as for :py:func:`slice_like`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute slice_like') + raise AttributeError('_Symbol object has no attribute slice_like') @use_np_compat def take(self, *args, **kwargs): @@ -373,7 +369,7 @@ def one_hot(self, *args, **kwargs): The arguments are the same as for :py:func:`one_hot`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute one_hot') + raise AttributeError('_Symbol object has no attribute one_hot') def pick(self, *args, **kwargs): """Convenience fluent method for :py:func:`pick`. @@ -381,7 +377,7 @@ def pick(self, *args, **kwargs): The arguments are the same as for :py:func:`pick`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute pick') + raise AttributeError('_Symbol object has no attribute pick') @use_np_compat def sort(self, *args, **kwargs): @@ -398,7 +394,7 @@ def topk(self, *args, **kwargs): The arguments are the same as for :py:func:`topk`, with this array as data. 
""" - raise AttributeError('_NumpySymbol object has no attribute topk') + raise AttributeError('_Symbol object has no attribute topk') @use_np_compat def argsort(self, *args, **kwargs): @@ -424,7 +420,7 @@ def argmax_channel(self, *args, **kwargs): The arguments are the same as for :py:func:`argmax_channel`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute argmax_channel') + raise AttributeError('_Symbol object has no attribute argmax_channel') @use_np_compat def argmin(self, *args, **kwargs): @@ -450,7 +446,7 @@ def abs(self, *args, **kwargs): The arguments are the same as for :py:func:`abs`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute abs') + raise AttributeError('_Symbol object has no attribute abs') def sign(self, *args, **kwargs): """Convenience fluent method for :py:func:`sign`. @@ -458,7 +454,7 @@ def sign(self, *args, **kwargs): The arguments are the same as for :py:func:`sign`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute abs') + raise AttributeError('_Symbol object has no attribute abs') @use_np_compat def flatten(self, *args, **kwargs): @@ -475,7 +471,7 @@ def shape_array(self, *args, **kwargs): The arguments are the same as for :py:func:`shape_array`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute shape_array') + raise AttributeError('_Symbol object has no attribute shape_array') def size_array(self, *args, **kwargs): """Convenience fluent method for :py:func:`size_array`. @@ -483,7 +479,7 @@ def size_array(self, *args, **kwargs): The arguments are the same as for :py:func:`size_array`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute size_array') + raise AttributeError('_Symbol object has no attribute size_array') def expand_dims(self, *args, **kwargs): """Convenience fluent method for :py:func:`expand_dims`. @@ -491,7 +487,7 @@ def expand_dims(self, *args, **kwargs): The arguments are the same as for :py:func:`expand_dims`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute expand_dims') + raise AttributeError('_Symbol object has no attribute expand_dims') def tile(self, *args, **kwargs): """Convenience fluent method for :py:func:`tile`. @@ -499,7 +495,7 @@ def tile(self, *args, **kwargs): The arguments are the same as for :py:func:`tile`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute tile') + raise AttributeError('_Symbol object has no attribute tile') @use_np_compat def transpose(self, *axes): # pylint: disable=arguments-differ @@ -516,7 +512,7 @@ def flip(self, *args, **kwargs): The arguments are the same as for :py:func:`flip`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute flip') + raise AttributeError('_Symbol object has no attribute flip') def depth_to_space(self, *args, **kwargs): """Convenience fluent method for :py:func:`depth_to_space`. @@ -524,7 +520,7 @@ def depth_to_space(self, *args, **kwargs): The arguments are the same as for :py:func:`depth_to_space`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute depth_to_space') + raise AttributeError('_Symbol object has no attribute depth_to_space') def space_to_depth(self, *args, **kwargs): """Convenience fluent method for :py:func:`space_to_depth`. 
@@ -532,7 +528,7 @@ def space_to_depth(self, *args, **kwargs): The arguments are the same as for :py:func:`space_to_depth`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute space_to_depth') + raise AttributeError('_Symbol object has no attribute space_to_depth') def diag(self, k=0, **kwargs): """Convenience fluent method for :py:func:`diag`. @@ -540,7 +536,7 @@ def diag(self, k=0, **kwargs): The arguments are the same as for :py:func:`diag`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute diag') + raise AttributeError('_Symbol object has no attribute diag') @use_np_compat def sum(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ @@ -557,7 +553,7 @@ def nansum(self, *args, **kwargs): The arguments are the same as for :py:func:`nansum`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute nansum') + raise AttributeError('_Symbol object has no attribute nansum') @use_np_compat def prod(self, *args, **kwargs): @@ -574,7 +570,7 @@ def nanprod(self, *args, **kwargs): The arguments are the same as for :py:func:`nanprod`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute nanprod') + raise AttributeError('_Symbol object has no attribute nanprod') @use_np_compat def mean(self, *args, **kwargs): @@ -609,7 +605,7 @@ def norm(self, *args, **kwargs): The arguments are the same as for :py:func:`norm`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute norm') + raise AttributeError('_Symbol object has no attribute norm') @use_np_compat def round(self, *args, **kwargs): @@ -626,7 +622,7 @@ def rint(self, *args, **kwargs): The arguments are the same as for :py:func:`rint`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute rint') + raise AttributeError('_Symbol object has no attribute rint') def fix(self, *args, **kwargs): """Convenience fluent method for :py:func:`fix`. @@ -634,7 +630,7 @@ def fix(self, *args, **kwargs): The arguments are the same as for :py:func:`fix`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute fix') + raise AttributeError('_Symbol object has no attribute fix') def floor(self, *args, **kwargs): """Convenience fluent method for :py:func:`floor`. @@ -642,7 +638,7 @@ def floor(self, *args, **kwargs): The arguments are the same as for :py:func:`floor`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute floor') + raise AttributeError('_Symbol object has no attribute floor') def ceil(self, *args, **kwargs): """Convenience fluent method for :py:func:`ceil`. @@ -650,7 +646,7 @@ def ceil(self, *args, **kwargs): The arguments are the same as for :py:func:`ceil`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute ceil') + raise AttributeError('_Symbol object has no attribute ceil') def trunc(self, *args, **kwargs): """Convenience fluent method for :py:func:`trunc`. @@ -658,7 +654,7 @@ def trunc(self, *args, **kwargs): The arguments are the same as for :py:func:`trunc`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute trunc') + raise AttributeError('_Symbol object has no attribute trunc') def sin(self, *args, **kwargs): """Convenience fluent method for :py:func:`sin`. 
@@ -666,7 +662,7 @@ def sin(self, *args, **kwargs): The arguments are the same as for :py:func:`sin`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute sin') + raise AttributeError('_Symbol object has no attribute sin') def cos(self, *args, **kwargs): """Convenience fluent method for :py:func:`cos`. @@ -674,7 +670,7 @@ def cos(self, *args, **kwargs): The arguments are the same as for :py:func:`cos`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute cos') + raise AttributeError('_Symbol object has no attribute cos') def tan(self, *args, **kwargs): """Convenience fluent method for :py:func:`tan`. @@ -682,7 +678,7 @@ def tan(self, *args, **kwargs): The arguments are the same as for :py:func:`tan`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute tan') + raise AttributeError('_Symbol object has no attribute tan') def arcsin(self, *args, **kwargs): """Convenience fluent method for :py:func:`arcsin`. @@ -690,7 +686,7 @@ def arcsin(self, *args, **kwargs): The arguments are the same as for :py:func:`arcsin`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute arcsin') + raise AttributeError('_Symbol object has no attribute arcsin') def arccos(self, *args, **kwargs): """Convenience fluent method for :py:func:`arccos`. @@ -698,7 +694,7 @@ def arccos(self, *args, **kwargs): The arguments are the same as for :py:func:`arccos`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute arccos') + raise AttributeError('_Symbol object has no attribute arccos') def arctan(self, *args, **kwargs): """Convenience fluent method for :py:func:`arctan`. @@ -706,7 +702,7 @@ def arctan(self, *args, **kwargs): The arguments are the same as for :py:func:`arctan`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute arctan') + raise AttributeError('_Symbol object has no attribute arctan') def degrees(self, *args, **kwargs): """Convenience fluent method for :py:func:`degrees`. @@ -714,7 +710,7 @@ def degrees(self, *args, **kwargs): The arguments are the same as for :py:func:`degrees`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute degrees') + raise AttributeError('_Symbol object has no attribute degrees') def radians(self, *args, **kwargs): """Convenience fluent method for :py:func:`radians`. @@ -722,7 +718,7 @@ def radians(self, *args, **kwargs): The arguments are the same as for :py:func:`radians`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute radians') + raise AttributeError('_Symbol object has no attribute radians') def sinh(self, *args, **kwargs): """Convenience fluent method for :py:func:`sinh`. @@ -730,7 +726,7 @@ def sinh(self, *args, **kwargs): The arguments are the same as for :py:func:`sinh`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute sinh') + raise AttributeError('_Symbol object has no attribute sinh') def cosh(self, *args, **kwargs): """Convenience fluent method for :py:func:`cosh`. @@ -738,7 +734,7 @@ def cosh(self, *args, **kwargs): The arguments are the same as for :py:func:`cosh`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute cosh') + raise AttributeError('_Symbol object has no attribute cosh') def tanh(self, *args, **kwargs): """Convenience fluent method for :py:func:`tanh`. 
@@ -746,7 +742,7 @@ def tanh(self, *args, **kwargs): The arguments are the same as for :py:func:`tanh`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute tanh') + raise AttributeError('_Symbol object has no attribute tanh') def arcsinh(self, *args, **kwargs): """Convenience fluent method for :py:func:`arcsinh`. @@ -754,7 +750,7 @@ def arcsinh(self, *args, **kwargs): The arguments are the same as for :py:func:`arcsinh`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute arcsinh') + raise AttributeError('_Symbol object has no attribute arcsinh') def arccosh(self, *args, **kwargs): """Convenience fluent method for :py:func:`arccosh`. @@ -762,7 +758,7 @@ def arccosh(self, *args, **kwargs): The arguments are the same as for :py:func:`arccosh`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute arccosh') + raise AttributeError('_Symbol object has no attribute arccosh') def arctanh(self, *args, **kwargs): """Convenience fluent method for :py:func:`arctanh`. @@ -770,7 +766,7 @@ def arctanh(self, *args, **kwargs): The arguments are the same as for :py:func:`arctanh`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute arctanh') + raise AttributeError('_Symbol object has no attribute arctanh') def exp(self, *args, **kwargs): """Convenience fluent method for :py:func:`exp`. @@ -778,7 +774,7 @@ def exp(self, *args, **kwargs): The arguments are the same as for :py:func:`exp`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute exp') + raise AttributeError('_Symbol object has no attribute exp') def expm1(self, *args, **kwargs): """Convenience fluent method for :py:func:`expm1`. @@ -786,7 +782,7 @@ def expm1(self, *args, **kwargs): The arguments are the same as for :py:func:`expm1`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute expm1') + raise AttributeError('_Symbol object has no attribute expm1') def log(self, *args, **kwargs): """Convenience fluent method for :py:func:`log`. @@ -794,7 +790,7 @@ def log(self, *args, **kwargs): The arguments are the same as for :py:func:`log`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute log') + raise AttributeError('_Symbol object has no attribute log') def log10(self, *args, **kwargs): """Convenience fluent method for :py:func:`log10`. @@ -802,7 +798,7 @@ def log10(self, *args, **kwargs): The arguments are the same as for :py:func:`log10`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute log10') + raise AttributeError('_Symbol object has no attribute log10') def log2(self, *args, **kwargs): """Convenience fluent method for :py:func:`log2`. @@ -810,7 +806,7 @@ def log2(self, *args, **kwargs): The arguments are the same as for :py:func:`log2`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute log2') + raise AttributeError('_Symbol object has no attribute log2') def log1p(self, *args, **kwargs): """Convenience fluent method for :py:func:`log1p`. @@ -818,7 +814,7 @@ def log1p(self, *args, **kwargs): The arguments are the same as for :py:func:`log1p`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute log1p') + raise AttributeError('_Symbol object has no attribute log1p') def sqrt(self, *args, **kwargs): """Convenience fluent method for :py:func:`sqrt`. 
@@ -826,7 +822,7 @@ def sqrt(self, *args, **kwargs): The arguments are the same as for :py:func:`sqrt`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute sqrt') + raise AttributeError('_Symbol object has no attribute sqrt') def rsqrt(self, *args, **kwargs): """Convenience fluent method for :py:func:`rsqrt`. @@ -834,7 +830,7 @@ def rsqrt(self, *args, **kwargs): The arguments are the same as for :py:func:`rsqrt`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute rsqrt') + raise AttributeError('_Symbol object has no attribute rsqrt') def cbrt(self, *args, **kwargs): """Convenience fluent method for :py:func:`cbrt`. @@ -842,7 +838,7 @@ def cbrt(self, *args, **kwargs): The arguments are the same as for :py:func:`cbrt`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute cqrt') + raise AttributeError('_Symbol object has no attribute cbrt') def rcbrt(self, *args, **kwargs): """Convenience fluent method for :py:func:`rcbrt`. @@ -850,7 +846,7 @@ def rcbrt(self, *args, **kwargs): The arguments are the same as for :py:func:`rcbrt`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute rcqrt') + raise AttributeError('_Symbol object has no attribute rcbrt') def square(self, *args, **kwargs): """Convenience fluent method for :py:func:`square`. @@ -858,7 +854,7 @@ def square(self, *args, **kwargs): The arguments are the same as for :py:func:`square`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute square') + raise AttributeError('_Symbol object has no attribute square') def reciprocal(self, *args, **kwargs): """Convenience fluent method for :py:func:`reciprocal`. @@ -866,7 +862,7 @@ def reciprocal(self, *args, **kwargs): The arguments are the same as for :py:func:`reciprocal`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute reciprocal') + raise AttributeError('_Symbol object has no attribute reciprocal') def relu(self, *args, **kwargs): """Convenience fluent method for :py:func:`relu`. @@ -874,7 +870,7 @@ def relu(self, *args, **kwargs): The arguments are the same as for :py:func:`relu`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute relu') + raise AttributeError('_Symbol object has no attribute relu') def sigmoid(self, *args, **kwargs): """Convenience fluent method for :py:func:`sigmoid`. @@ -882,7 +878,7 @@ def sigmoid(self, *args, **kwargs): The arguments are the same as for :py:func:`sigmoid`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute sigmoid') + raise AttributeError('_Symbol object has no attribute sigmoid') def softmax(self, *args, **kwargs): """Convenience fluent method for :py:func:`softmax`. @@ -890,7 +886,7 @@ def softmax(self, *args, **kwargs): The arguments are the same as for :py:func:`softmax`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute softmax') + raise AttributeError('_Symbol object has no attribute softmax') def log_softmax(self, *args, **kwargs): """Convenience fluent method for :py:func:`log_softmax`. @@ -898,7 +894,7 @@ def log_softmax(self, *args, **kwargs): The arguments are the same as for :py:func:`log_softmax`, with this array as data.
""" - raise AttributeError('_NumpySymbol object has no attribute log_softmax') + raise AttributeError('_Symbol object has no attribute log_softmax') def softmin(self, *args, **kwargs): """Convenience fluent method for :py:func:`softmin`. @@ -906,7 +902,7 @@ def softmin(self, *args, **kwargs): The arguments are the same as for :py:func:`softmin`, with this array as data. """ - raise AttributeError('_NumpySymbol object has no attribute softmin') + raise AttributeError('_Symbol object has no attribute softmin') @use_np_compat def squeeze(self, *args, **kwargs): @@ -918,12 +914,13 @@ def squeeze(self, *args, **kwargs): raise NotImplementedError def broadcast_to(self, *args, **kwargs): - raise AttributeError('_NumpySymbol object has no attribute broadcast_to') + raise AttributeError('_Symbol object has no attribute broadcast_to') def broadcast_like(self, *args, **kwargs): - raise AttributeError('_NumpySymbol object has no attribute broadcast_like') + raise AttributeError('_Symbol object has no attribute broadcast_like') +@set_module('mxnet.symbol.numpy') @use_np_compat def zeros(shape, dtype=_np.float32, **kwargs): """Return a new array of given shape and type, filled with zeros. @@ -952,9 +949,10 @@ def zeros(shape, dtype=_np.float32, **kwargs): if ctx is None: ctx = current_context() dtype = _np.float32 if dtype is None else dtype - return _internal._np_zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs) + return _npi.zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs) +@set_module('mxnet.symbol.numpy') @use_np_compat def ones(shape, dtype=None, **kwargs): """Return a new array of given shape and type, filled with zeros. @@ -983,7 +981,7 @@ def ones(shape, dtype=None, **kwargs): if ctx is None: ctx = current_context() dtype = _np.float32 if dtype is None else dtype - return _internal._np_ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs) + return _npi.ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs) #pylint: disable= too-many-arguments, no-member, protected-access @@ -1035,16 +1033,16 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou #pylint: enable= too-many-arguments, no-member, protected-access +@set_module('mxnet.symbol.numpy') @use_np_compat def maximum(x1, x2, out=None): - return _ufunc_helper(x1, x2, _internal._np_maximum, _np.maximum, - _internal._np_maximum_scalar, None, out) + return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out) +@set_module('mxnet.symbol.numpy') @use_np_compat def minimum(x1, x2, out=None): - return _ufunc_helper(x1, x2, _internal._np_minimum, _np.minimum, - _internal._np_minimum_scalar, None, out) + return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out) -_set_np_symbol_class(_NumpySymbol) +_set_np_symbol_class(_Symbol) diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py index b8f10b343430..869fdeb276b9 100644 --- a/python/mxnet/symbol/numpy/linalg.py +++ b/python/mxnet/symbol/numpy/linalg.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-"""numpy.linalg namespace for operators used in Gluon APIs dispatched by F=symbol module.""" +"""Namespace for operators used in Gluon dispatched by F=symbol.""" __all__ = [] diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py index 79c73d871dd8..869fdeb276b9 100644 --- a/python/mxnet/symbol/numpy/random.py +++ b/python/mxnet/symbol/numpy/random.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. -"""numpy.random namespace for operators used in Gluon APIs dispatched by F=symbol module.""" +"""Namespace for operators used in Gluon dispatched by F=symbol.""" __all__ = [] diff --git a/python/mxnet/symbol/numpy_extension/__init__.py b/python/mxnet/symbol/numpy_extension/__init__.py new file mode 100644 index 000000000000..a718274ae9ed --- /dev/null +++ b/python/mxnet/symbol/numpy_extension/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Module for the ops not belonging to the official numpy package.""" + +from . import _op +from . import _register +from ._op import * # pylint: disable=wildcard-import + +__all__ = _op.__all__ diff --git a/python/mxnet/symbol/numpy_extension/_op.py b/python/mxnet/symbol/numpy_extension/_op.py new file mode 100644 index 000000000000..82eaa8e6ec9f --- /dev/null +++ b/python/mxnet/symbol/numpy_extension/_op.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Namespace for operators not belonging to the official numpy package +used in Gluon APIs dispatched by F=symbol module.""" + +__all__ = [] diff --git a/python/mxnet/symbol/numpy_extension/_register.py b/python/mxnet/symbol/numpy_extension/_register.py new file mode 100644 index 000000000000..b118987b1fd3 --- /dev/null +++ b/python/mxnet/symbol/numpy_extension/_register.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Registering numpy_extension ops.""" + +from ...base import _init_np_op_module +from ..register import _make_symbol_function + +_init_np_op_module(root_module_name='mxnet', np_module_name='numpy_extension', + mx_module_name='symbol', make_op_func=_make_symbol_function) diff --git a/python/mxnet/symbol/register.py b/python/mxnet/symbol/register.py index ac59f8b97f15..a835e2e4d339 100644 --- a/python/mxnet/symbol/register.py +++ b/python/mxnet/symbol/register.py @@ -27,12 +27,58 @@ from ..attribute import AttrScope from ..base import mx_uint, check_call, _LIB, py_str from ..symbol_doc import _build_doc -from ..base import _Null, _init_op_module +from ..base import _Null, _init_op_module, _is_np_op from ..name import NameManager # pylint: enable=unused-import -def _generate_symbol_function_code(handle, name, func_name, signature_only=False): +def _verify_np_symbol(op_name, func_name, sym): + """Verify that `sym` is a numpy symbol. + + Parameters + ---------- + op_name : str + Operator full name registered in backend. + func_name : str + Operator name exposed to users. This is usually obtained by stripping + the prefix off the full operator name registered in backend. + sym : symbol to be verified + """ + from .numpy._symbol import _Symbol as np_symbol + if not isinstance(sym, np_symbol): + raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. ' + 'This is a numpy operator which can only accept ' + 'MXNet numpy symbols, but received a classic symbol. ' + 'Please call `as_np_ndarray()` upon the classic symbol to ' + 'convert it to an MXNet numpy symbol, and then feed the converted ' + 'symbol to this operator.' + .format(op_name, func_name)) + + +def _verify_classic_symbol(op_name, func_name, sym): + """Verify that `sym` is a classic symbol. + + Parameters + ---------- + op_name : str + Operator full name registered in backend. + func_name : str + Operator name exposed to users. This is usually obtained by stripping + the prefix off the full operator name registered in backend. + sym : symbol to be verified + """ + from .numpy._symbol import _Symbol as np_symbol + if isinstance(sym, np_symbol): + raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. ' + 'This is a classic operator which can only accept ' + 'classic symbols, but received an MXNet numpy symbol. ' + 'Please call `as_classic_ndarray()` upon the numpy symbol to ' + 'convert it to a classic symbol, and then feed the converted ' + 'symbol to this operator.'
+ .format(op_name, func_name)) + + +def _generate_symbol_function_code(handle, op_name, func_name, signature_only=False): """Generate function for symbol op by handle and function name.""" real_name = ctypes.c_char_p() desc = ctypes.c_char_p() @@ -56,7 +102,7 @@ def _generate_symbol_function_code(handle, name, func_name, signature_only=False arg_types = [py_str(arg_types[i]) for i in range(narg)] key_var_num_args = py_str(key_var_num_args.value) ret_type = py_str(ret_type.value) if ret_type.value is not None else '' - doc_str = _build_doc(name, + doc_str = _build_doc(op_name, py_str(desc.value), arg_names, arg_types, @@ -95,6 +141,8 @@ def _generate_symbol_function_code(handle, name, func_name, signature_only=False signature.append('**kwargs') signature = ndsignature + signature + is_np_op = _is_np_op(op_name) + verify_symbol_fn = _verify_np_symbol.__name__ if is_np_op else _verify_classic_symbol.__name__ code = [] if arr_name: code.append(""" @@ -106,7 +154,8 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name)) assert isinstance(i, SymbolBase), \\ "Positional arguments must be Symbol instances, " \\ "but got %s"%str(i) - sym_args.append(i)""".format(arr_name)) + {}('{}', '{}', i) + sym_args.append(i)""".format(arr_name, verify_symbol_fn, op_name, func_name)) if dtype_name is not None: code.append(""" if '%s' in kwargs: @@ -128,9 +177,10 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name)) for k, v in kwargs.items(): if isinstance(v, SymbolBase): sym_kwargs[k] = v + %s('%s', '%s', v) else: keys.append(k) - vals.append(v)"""%(func_name.lower())) + vals.append(v)"""%(func_name.lower(), verify_symbol_fn, op_name, func_name)) if key_var_num_args: # pylint: disable=using-constant-test code.append(""" if '%s' not in kwargs: @@ -139,8 +189,8 @@ def %s(*%s, **kwargs):"""%(func_name, arr_name)) key_var_num_args, key_var_num_args)) code.append(""" - return _symbol_creator(%d, sym_args, sym_kwargs, keys, vals, name)"""%( - handle.value)) + return _symbol_creator(%d, sym_args, sym_kwargs, keys, vals, name, %s)"""%( + handle.value, str(is_np_op))) else: code.append(""" def %s(%s):"""%(func_name, ', '.join(signature))) @@ -155,9 +205,10 @@ def %s(%s):"""%(func_name, ', '.join(signature))) for _k, _v in kwargs.items(): if isinstance(_v, SymbolBase): sym_kwargs[_k] = _v + {}('{}', '{}', _v) else: _keys.append(_k) - _vals.append(_v)""") + _vals.append(_v)""".format(verify_symbol_fn, op_name, func_name)) # NDArray args for name in ndarg_names: # pylint: disable=redefined-argument-from-local code.append(""" @@ -165,6 +216,9 @@ def %s(%s):"""%(func_name, ', '.join(signature))) assert isinstance({name}, SymbolBase), \\ "Argument {name} must be Symbol instances, but got %s"%str({name}) sym_kwargs['{name}'] = {name}""".format(name=name)) + code.append(""" + {}('{}', '{}', {name}) + """.format(verify_symbol_fn, op_name, func_name, name=name)) # kwargs for name in kwarg_names: # pylint: disable=redefined-argument-from-local code.append(""" @@ -182,8 +236,8 @@ def %s(%s):"""%(func_name, ', '.join(signature))) if not hasattr(NameManager._current, "value"): NameManager._current.value = NameManager() name = NameManager._current.value.get(name, '%s') - return _symbol_creator(%d, None, sym_kwargs, _keys, _vals, name)"""%( - func_name.lower(), handle.value)) + return _symbol_creator(%d, None, sym_kwargs, _keys, _vals, name, %s)"""%( + func_name.lower(), handle.value, str(is_np_op))) if signature_only: code.append(""" diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py index 7be042cd7671..96397f68cc06 
100644 --- a/python/mxnet/symbol/symbol.py +++ b/python/mxnet/symbol/symbol.py @@ -62,15 +62,11 @@ class Symbol(SymbolBase): __array_priority__ = 1000.0 def as_np_ndarray(self): - """Convert mxnet.symbol.Symbol to _NumpySymbol.""" - from .numpy import _NumpySymbol + """Convert mx.sym.Symbol to mx.sym.np._Symbol.""" + from .numpy import _Symbol hdl = SymbolHandle() check_call(_LIB.MXShallowCopySymbol(self.handle, ctypes.byref(hdl))) - return _NumpySymbol(hdl) - - def _is_np_compat(self): - """Always returns False except for mxnet.symbol.numpy._NumpySymbol.""" - return False + return _Symbol(hdl) def __repr__(self): """Gets a string representation of the symbol.""" @@ -110,8 +106,6 @@ def __add__(self, other): Scalar input is supported. Broadcasting is not supported. Use `broadcast_add` instead. """ if isinstance(other, Symbol): - if other._is_np_compat(): - return other.__add__(self) return _internal._Plus(self, other) if isinstance(other, Number): return _internal._PlusScalar(self, scalar=other) @@ -127,8 +121,6 @@ def __iadd__(self, other): raise NotImplementedForSymbol(self.__iadd__, '+=', other, 1) def __radd__(self, other): - if isinstance(other, Symbol) and other._is_np_compat(): - return other.__add__(self) return self.__add__(other) def __sub__(self, other): @@ -137,8 +129,6 @@ def __sub__(self, other): Scalar input is supported. Broadcasting is not supported. Use `broadcast_sub` instead. """ if isinstance(other, Symbol): - if other._is_np_compat(): - return other.__rsub__(self) return _internal._Minus(self, other) if isinstance(other, Number): return _internal._MinusScalar(self, scalar=other) @@ -161,7 +151,7 @@ def __rsub__(self, other): array([[-2., -2., -2.], [-2., -2., -2.]], dtype=float32) """ - if isinstance(other, Symbol) and other._is_np_compat(): + if isinstance(other, Symbol): return other.__sub__(self) if isinstance(other, Number): return _internal._RMinusScalar(self, scalar=other) @@ -174,8 +164,6 @@ def __mul__(self, other): Scalar input is supported. Broadcasting is not supported. Use `broadcast_mul` instead. """ if isinstance(other, Symbol): - if other._is_np_compat(): - return other.__mul__(self) return _internal._Mul(self, other) if isinstance(other, Number): return _internal._MulScalar(self, scalar=other) @@ -186,8 +174,6 @@ def __imul__(self, other): raise NotImplementedForSymbol(self.__imul__, '*=', other) def __rmul__(self, other): - if isinstance(other, Symbol) and other._is_np_compat(): - return other.__mul__(self) return self.__mul__(other) def __div__(self, other): @@ -196,8 +182,6 @@ def __div__(self, other): Scalar input is supported. Broadcasting is not supported. Use `broadcast_div` instead. """ if isinstance(other, Symbol): - if other._is_np_compat(): - return other.__rtruediv__(self) return _internal._Div(self, other) if isinstance(other, Number): return _internal._DivScalar(self, scalar=other) @@ -217,7 +201,7 @@ def __rdiv__(self, other): array([[ 0.33333334, 0.33333334, 0.33333334], [ 0.33333334, 0.33333334, 0.33333334]], dtype=float32) """ - if isinstance(other, Symbol) and other._is_np_compat(): + if isinstance(other, Symbol): return other.__truediv__(self) if isinstance(other, Number): return _internal._RDivScalar(self, scalar=other) @@ -230,8 +214,6 @@ def __mod__(self, other): Scalar input is supported. Broadcasting is not supported. Use `broadcast_mod` instead. 
""" if isinstance(other, Symbol): - if other._is_np_compat(): - return other.__rmod__(self) return _internal._Mod(self, other) if isinstance(other, Number): return _internal._ModScalar(self, scalar=other) @@ -251,7 +233,7 @@ def __rmod__(self, other): array([[ 1., 1., 1., [ 1., 1., 1., dtype=float32) """ - if isinstance(other, Symbol) and other._is_np_compat(): + if isinstance(other, Symbol): return other.__mod__(self) if isinstance(other, Number): return _internal._RModScalar(self, scalar=other) @@ -276,8 +258,6 @@ def __pow__(self, other): Scalar input is supported. Broadcasting is not supported. Use `broadcast_pow` instead. """ if isinstance(other, Symbol): - if other._is_np_compat(): - return other.__rpow__(self) return _internal._Power(self, other) if isinstance(other, Number): return _internal._PowerScalar(self, scalar=other) @@ -287,8 +267,6 @@ def __pow__(self, other): def __rpow__(self, other): """x.__rpow__(y) <=> y ** x""" if isinstance(other, Symbol): - if other._is_np_compat(): - return other.__pow__(self) return other.__pow__(self) elif isinstance(other, Number): return _internal._rpower_scalar(self, scalar=other) @@ -348,8 +326,6 @@ def __eq__(self, other): Scalar input is supported. Broadcasting is not supported. Use `broadcast_equal` instead. """ if isinstance(other, Symbol): - if other._is_np_compat(): - return other.__eq__(self) return _internal._equal(self, other) if isinstance(other, numeric_types): return _internal._equal_scalar(self, scalar=other) @@ -362,8 +338,6 @@ def __ne__(self, other): Scalar input is supported. Broadcasting is not supported. Use `broadcast_not_equal` instead. """ if isinstance(other, Symbol): - if other._is_np_compat(): - return other.__ne__(self) return _internal._not_equal(self, other) if isinstance(other, numeric_types): return _internal._not_equal_scalar(self, scalar=other) @@ -376,8 +350,6 @@ def __gt__(self, other): Scalar input is supported. Broadcasting is not supported. Use `broadcast_greater` instead. """ if isinstance(other, Symbol): - if other._is_np_compat(): - return other.__lt__(self) return _internal._greater(self, other) if isinstance(other, numeric_types): return _internal._greater_scalar(self, scalar=other) @@ -390,8 +362,6 @@ def __ge__(self, other): Scalar input is supported. Broadcasting is not supported. Use `broadcast_greater_equal` instead. """ if isinstance(other, Symbol): - if other._is_np_compat(): - return other.__le__(self) return _internal._greater_equal(self, other) if isinstance(other, numeric_types): return _internal._greater_equal_scalar(self, scalar=other) @@ -404,8 +374,6 @@ def __lt__(self, other): Scalar input is supported. Broadcasting is not supported. Use `broadcast_lesser` instead. """ if isinstance(other, Symbol): - if other._is_np_compat(): - return other.__gt__(self) return _internal._lesser(self, other) if isinstance(other, numeric_types): return _internal._lesser_scalar(self, scalar=other) @@ -418,8 +386,6 @@ def __le__(self, other): Scalar input is supported. Broadcasting is not supported. Use `broadcast_lesser_equal` instead. 
""" if isinstance(other, Symbol): - if other._is_np_compat(): - return other.__ge__(self) return _internal._lesser_equal(self, other) if isinstance(other, numeric_types): return _internal._lesser_equal_scalar(self, scalar=other) @@ -2720,8 +2686,12 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, Variable = var -def Group(symbols): +def Group(symbols, create_fn=Symbol): """Creates a symbol that contains a collection of other symbols, grouped together. + A classic symbol (`mx.sym.Symbol`) will be returned if all the symbols in the list + are of that type; a numpy symbol (`mx.sym.np._Symbol`) will be returned if all the + symbols in the list are of that type. A type error will be raised if a list of mixed + classic and numpy symbols are provided. Example ------- @@ -2735,6 +2705,9 @@ def Group(symbols): symbols : list List of symbols to be grouped. + create_fn : mx.sym.Symbol or mx.sym.np._Symbol + Symbol class for creating the grouped symbol. + Returns ------- sym : Symbol @@ -2746,7 +2719,7 @@ def Group(symbols): check_call(_LIB.MXSymbolCreateGroup( mx_uint(len(symbols)), c_handle_array(symbols), ctypes.byref(handle))) - return Symbol(handle) + return create_fn(handle) def load(fname): diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 91f38ff176e3..925007ddd2f0 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -48,6 +48,7 @@ from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID from .ndarray import array from .symbol import Symbol +from .symbol.numpy import _Symbol as np_symbol def default_context(): @@ -946,7 +947,12 @@ def random_projection(shape): input_shape = {k: v.shape for k, v in location.items()} _, out_shape, _ = sym.infer_shape(**input_shape) proj = mx.sym.Variable("__random_proj") + is_np_sym = True if isinstance(sym, np_symbol) else False + if is_np_sym: # convert to np symbol for using element-wise multiplication + proj = proj.as_np_ndarray() out = sym * proj + if is_np_sym: # convert to classic symbol so that make_loss can be used + out = out.as_classic_ndarray() out = mx.sym.make_loss(out) location = dict(list(location.items()) + diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index 82fe28bf38bc..233acc85f36b 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -163,21 +163,4 @@ inline void CopyAttr(const nnvm::IndexedGraph& idx, extern const std::vector kHiddenKeys; } // namespace mxnet -/*! - * An operator is considered as numpy compatible if it satisfies either one - * of the following conditions. - * 1. The op has the attribute mxnet::TIsNumpyCompatible> registered as True. - * 2. The op's name starts with the prefix _numpy_. - * The first condition is usually for the ops registered as internal ops, such - * as _np_add, _true_divide, etc. They are wrapped by some user-facing op - * APIs in the Python end. - * The second condition is for the ops registered in the backend while exposed - * directly to users as is, such as _numpy_sum etc. 
- */ -inline bool IsNumpyCompatOp(const nnvm::Op* op) { - static const auto& is_np_compat = - nnvm::Op::GetAttr("TIsNumpyCompatible"); - return is_np_compat.get(op, false); -} - #endif // MXNET_C_API_C_API_COMMON_H_ diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index f65c804070b7..c9c6000e2f6f 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -378,19 +378,3 @@ int MXAutogradGetSymbol(NDArrayHandle handle, SymbolHandle *out) { *out = reinterpret_cast(sym); API_END(); } - -int MXIsCachedOpOutputFromNumpyCompatOp(CachedOpHandle handle, - int output_idx, - int* is_from_np_op) { - API_BEGIN(); - CachedOpPtr op = *static_cast(handle); - const auto& output_entries = op->GetForwardSym().outputs; - CHECK_LT(output_idx, static_cast(output_entries.size())); - const nnvm::NodePtr& node_ptr = output_entries[output_idx].node; - if (node_ptr->is_variable()) { - *is_from_np_op = 0; - } else { - *is_from_np_op = (IsNumpyCompatOp(node_ptr->op()) ? 1 : 0); - } - API_END(); -} diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h index 2c4d579f9f44..0f3d71dc1ed5 100644 --- a/src/operator/numpy/np_broadcast_reduce_op.h +++ b/src/operator/numpy/np_broadcast_reduce_op.h @@ -169,6 +169,7 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs, if (param.initial.has_value()) { LOG(FATAL) << "initial is not supported yet"; } + if (outputs[0].shape_.Size() == 0U) return; // zero-size tensor if (param.axis.has_value() && param.axis.value().ndim() == 0) { UnaryOp::IdentityCompute(attrs, ctx, inputs, req, outputs); } diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc index c1c11324a9aa..a72efd9a4d23 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cc +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc @@ -47,7 +47,7 @@ inline bool NumpySumType(const nnvm::NodeAttrs& attrs, return out_attrs->at(0) != -1 && in_attrs->at(0) != -1; } -NNVM_REGISTER_OP(_numpy_sum) +NNVM_REGISTER_OP(_np_sum) .describe(R"code()code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) @@ -61,14 +61,13 @@ NNVM_REGISTER_OP(_numpy_sum) .add_argument("a", "NDArray-or-Symbol", "The input") .add_arguments(NumpyReduceAxesParam::__FIELDS__()) .set_attr("FCompute", NumpyReduceAxesCompute) -.set_attr("TIsNumpyCompatible", true) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"}); +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_np_sum"}); -NNVM_REGISTER_OP(_backward_numpy_sum) +NNVM_REGISTER_OP(_backward_np_sum) .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("TIsBackward", true) @@ -102,7 +101,7 @@ inline bool NumpyMeanType(const nnvm::NodeAttrs& attrs, return out_attrs->at(0) != -1 && in_attrs->at(0) != -1; } -NNVM_REGISTER_OP(_numpy_mean) +NNVM_REGISTER_OP(_np_mean) .describe(R"code()code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) @@ -116,14 +115,13 @@ NNVM_REGISTER_OP(_numpy_mean) .add_argument("a", "NDArray-or-Symbol", "The input") .add_arguments(NumpyReduceAxesParam::__FIELDS__()) .set_attr("FCompute", NumpyReduceAxesCompute) -.set_attr("TIsNumpyCompatible", true) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_numpy_mean"}); +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_np_mean"}); 
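The `_numpy_sum`/`_numpy_mean` registrations above are renamed to `_np_sum`/`_np_mean` to follow this refactor's prefix scheme (`_np_*` for NumPy ops, `_npi_*` for internal helpers, `_npe_*` for MXNet-only extensions). A minimal sketch of how the renamed reduce ops are reached from the refactored frontend, mirroring the updated tests; the input values and printed results are illustrative only:

    from mxnet import np  # numpy-compatible frontend introduced by this refactor

    x = np.array([[1., 2.], [3., 4.]])
    s = np.sum(x, axis=0)   # dispatches to the backend op now registered as _np_sum
    m = np.mean(x, axis=1)  # dispatches to _np_mean
    print(s)  # expected: [4. 6.]
    print(m)  # expected: [1.5 3.5]
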
-NNVM_REGISTER_OP(_backward_numpy_mean) +NNVM_REGISTER_OP(_backward_np_mean) .set_num_outputs(1) .set_attr_parser(ParamParser) .set_attr("TIsBackward", true) diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cu b/src/operator/numpy/np_broadcast_reduce_op_value.cu index f16745d4c8b4..2f50738832fe 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cu +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cu @@ -27,16 +27,16 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_numpy_sum) +NNVM_REGISTER_OP(_np_sum) .set_attr("FCompute", NumpyReduceAxesCompute); -NNVM_REGISTER_OP(_backward_numpy_sum) +NNVM_REGISTER_OP(_backward_np_sum) .set_attr("FCompute", NumpyReduceAxesBackwardUseNone); -NNVM_REGISTER_OP(_numpy_mean) +NNVM_REGISTER_OP(_np_mean) .set_attr("FCompute", NumpyReduceAxesCompute); -NNVM_REGISTER_OP(_backward_numpy_mean) +NNVM_REGISTER_OP(_backward_np_mean) .set_attr("FCompute", NumpyReduceAxesBackwardUseNone); diff --git a/src/operator/numpy/np_dot-inl.h b/src/operator/numpy/np_dot-inl.h index 8fc7d5d89fee..2f7c589c0127 100644 --- a/src/operator/numpy/np_dot-inl.h +++ b/src/operator/numpy/np_dot-inl.h @@ -95,6 +95,7 @@ inline void NumpyDotForward(const nnvm::NodeAttrs& attrs, const TBlob& a = inputs[0]; const TBlob& b = inputs[1]; const TBlob& out = outputs[0]; + if (out.shape_.Size() == 0U) return; // zero-size tensor, no need to launch kernel const mxnet::TShape a_shape = a.shape_; const mxnet::TShape b_shape = b.shape_; @@ -107,7 +108,13 @@ inline void NumpyDotForward(const nnvm::NodeAttrs& attrs, (out.type_flag_ == kFloat16 && ctx.run_ctx.ctx.dev_mask() == mshadow::gpu::kDevMask)) << "dot only supports float32/float64 for CPU, and float16/float32/float64 for GPU"; MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, { - if (a_shape.ndim() == 1 && b_shape.ndim() == 1) { + if (a_shape.Size() == 0U || b_shape.Size() == 0U) { + if (req[0] != kAddTo) { + Tensor out_data = out.get_with_shape( + Shape1(out.shape_.Size()), s); + out_data = static_cast(0); + } + } else if (a_shape.ndim() == 1 && b_shape.ndim() == 1) { // Case 1: both 1-D arrays, inner product of vectors if (out.type_flag_ == kFloat16) { MMImpl(ctx, a, b, out, req[0]); @@ -158,12 +165,14 @@ inline void NumpyDotBackward(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 2U); const TBlob& ograd = inputs[0]; + if (ograd.shape_.Size() == 0U) return; const TBlob& a = inputs[1]; const TBlob& b = inputs[2]; const TBlob& grad_a = outputs[0]; const TBlob& grad_b = outputs[1]; const mxnet::TShape a_shape = a.shape_; const mxnet::TShape b_shape = b.shape_; + if (a_shape.Size() == 0U || b_shape.Size() == 0U) return; Stream *s = ctx.get_stream(); MSHADOW_REAL_TYPE_SWITCH(ograd.type_flag_, DType, { diff --git a/src/operator/numpy/np_dot.cc b/src/operator/numpy/np_dot.cc index c25953f219d9..bcb310fda4b6 100644 --- a/src/operator/numpy/np_dot.cc +++ b/src/operator/numpy/np_dot.cc @@ -71,7 +71,7 @@ inline bool NumpyDotShape(const nnvm::NodeAttrs& attrs, return true; } -NNVM_REGISTER_OP(_numpy_dot) +NNVM_REGISTER_OP(_np_dot) .describe(R"doc(Dot product of two arrays. Specifically, - If both a and b are 1-D arrays, it is inner product of vectors. 
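With the zero-size guards added to `NumpyDotForward`/`NumpyDotBackward` above, `np.dot` now fills the output with zeros (unless `req` is `kAddTo`) instead of launching a kernel on empty inputs. A quick usage sketch of the intended behavior, matching the `((3, 0), (0, 4))` shape pair added to `test_np_dot` later in this patch; illustrative only:

    from mxnet import np

    # Zero-size operands; assumes numpy-shape semantics are active
    # (the tests below enable this via the np.use_np_compat decorator).
    a = np.zeros((3, 0))
    b = np.zeros((0, 4))
    out = np.dot(a, b)
    # The contracted dimension is empty, so the result is a (3, 4) array of zeros.
    assert out.shape == (3, 4)
    assert (out.asnumpy() == 0).all()
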
diff --git a/src/operator/numpy/np_dot.cu b/src/operator/numpy/np_dot.cu index 2accd9d8badb..9a9c69aa98e5 100644 --- a/src/operator/numpy/np_dot.cu +++ b/src/operator/numpy/np_dot.cu @@ -27,7 +27,7 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_numpy_dot) +NNVM_REGISTER_OP(_np_dot) .set_attr("FCompute", NumpyDotForward); NNVM_REGISTER_OP(_backward_np_dot) diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index 5d36c29fc331..2ffa3b8f85aa 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -57,12 +57,11 @@ bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs, [](const NodeAttrs& attrs){ \ return std::vector >{{0, 0}}; \ }) \ - .set_attr("TIsNumpyCompatible", true) \ .add_argument("data", "NDArray-or-Symbol", "source input") \ .add_argument("scalar", "float", "scalar input") -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_add) +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_add) .describe(R"code(Add arguments element-wise with broadcasting if necessary. Example:: @@ -78,10 +77,9 @@ Example:: )code" ADD_FILELINE) .set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"}) -.set_attr("TIsNumpyCompatible", true); +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"}); -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_subtract) +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_subtract) .describe(R"code(Subtract arguments element-wise with broadcasting if necessary. Example:: @@ -97,10 +95,9 @@ Example:: )code" ADD_FILELINE) .set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"}) -.set_attr("TIsNumpyCompatible", true); +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"}); -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_multiply) +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_multiply) .describe(R"code(Multiply arguments with broadcasting if necessary. Example:: @@ -116,10 +113,9 @@ Example:: )code" ADD_FILELINE) .set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"}) -.set_attr("TIsNumpyCompatible", true); +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"}); -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_mod) +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod) .describe(R"code(Return element-wise remainder of division. It is equivalent to the Python modulus operator``x1 % x2`` and has the same sign as the divisor x2. @@ -136,10 +132,9 @@ Example:: )code" ADD_FILELINE) .set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mod"}) -.set_attr("TIsNumpyCompatible", true); +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mod"}); -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_power) +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_power) .describe(R"code(First array elements raised to powers from second array, element-wise. Raise each base in x1 to the positionally-corresponding power in x2. 
x1 and x2 must be @@ -158,56 +153,53 @@ Example:: )code" ADD_FILELINE) .set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"}) -.set_attr("TIsNumpyCompatible", true); +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"}); -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_maximum) +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_maximum) .describe(R"code()code" ADD_FILELINE) -.set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("TIsNumpyCompatible", true); +.set_attr("FCompute", BinaryBroadcastCompute); -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_minimum) +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_minimum) .describe(R"code()code" ADD_FILELINE) -.set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("TIsNumpyCompatible", true); +.set_attr("FCompute", BinaryBroadcastCompute); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_add_scalar) +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseNone{"_copy"}); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_subtract_scalar) +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_subtract_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseNone{"_copy"}); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_rsubtract_scalar) +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rsubtract_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseNone{"negative"}); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_multiply_scalar) +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_multiply_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_mul_scalar"}); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_mod_scalar) +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_mod_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_mod_scalar"}); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_rmod_scalar) +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rmod_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_rmod_scalar"}); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_power_scalar) +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_power_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_power_scalar"}); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_rpower_scalar) +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rpower_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"}); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_maximum_scalar) +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_maximum_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_maximum_scalar"}); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_minimum_scalar) +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_minimum_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_minimum_scalar"}); diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index 26e2fceb839f..c858b3a4987a 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -27,55 +27,55 @@ namespace mxnet { namespace op { 
-NNVM_REGISTER_OP(_np_add) +NNVM_REGISTER_OP(_npi_add) .set_attr("FCompute", BinaryBroadcastCompute); -NNVM_REGISTER_OP(_np_subtract) +NNVM_REGISTER_OP(_npi_subtract) .set_attr("FCompute", BinaryBroadcastCompute); -NNVM_REGISTER_OP(_np_multiply) +NNVM_REGISTER_OP(_npi_multiply) .set_attr("FCompute", BinaryBroadcastCompute); -NNVM_REGISTER_OP(_np_mod) +NNVM_REGISTER_OP(_npi_mod) .set_attr("FCompute", BinaryBroadcastCompute); -NNVM_REGISTER_OP(_np_power) +NNVM_REGISTER_OP(_npi_power) .set_attr("FCompute", BinaryBroadcastCompute); -NNVM_REGISTER_OP(_np_maximum) +NNVM_REGISTER_OP(_npi_maximum) .set_attr("FCompute", BinaryBroadcastCompute); -NNVM_REGISTER_OP(_np_minimum) +NNVM_REGISTER_OP(_npi_minimum) .set_attr("FCompute", BinaryBroadcastCompute); -NNVM_REGISTER_OP(_np_add_scalar) +NNVM_REGISTER_OP(_npi_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); -NNVM_REGISTER_OP(_np_subtract_scalar) +NNVM_REGISTER_OP(_npi_subtract_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); -NNVM_REGISTER_OP(_np_rsubtract_scalar) +NNVM_REGISTER_OP(_npi_rsubtract_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); -NNVM_REGISTER_OP(_np_multiply_scalar) +NNVM_REGISTER_OP(_npi_multiply_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); -NNVM_REGISTER_OP(_np_mod_scalar) +NNVM_REGISTER_OP(_npi_mod_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); -NNVM_REGISTER_OP(_np_rmod_scalar) +NNVM_REGISTER_OP(_npi_rmod_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); -NNVM_REGISTER_OP(_np_power_scalar) +NNVM_REGISTER_OP(_npi_power_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); -NNVM_REGISTER_OP(_np_rpower_scalar) +NNVM_REGISTER_OP(_npi_rpower_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); -NNVM_REGISTER_OP(_np_maximum_scalar) +NNVM_REGISTER_OP(_npi_maximum_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); -NNVM_REGISTER_OP(_np_minimum_scalar) +NNVM_REGISTER_OP(_npi_minimum_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); } // namespace op diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cc b/src/operator/numpy/np_elemwise_unary_op_basic.cc index f31ed5e11f15..a64356e2a1aa 100644 --- a/src/operator/numpy/np_elemwise_unary_op_basic.cc +++ b/src/operator/numpy/np_elemwise_unary_op_basic.cc @@ -27,7 +27,7 @@ namespace mxnet { namespace op { -MXNET_OPERATOR_REGISTER_UNARY(_numpy__ext_relu) +MXNET_OPERATOR_REGISTER_UNARY(_npe_relu) .describe(R"code(Computes rectified linear activation. .. math:: @@ -35,10 +35,9 @@ MXNET_OPERATOR_REGISTER_UNARY(_numpy__ext_relu) )code" ADD_FILELINE) .set_attr("FCompute", UnaryOp::Compute) -.set_attr("FGradient", ElemwiseGradUseOut{"_backward_relu"}) -.set_attr("TIsNumpyCompatible", true); +.set_attr("FGradient", ElemwiseGradUseOut{"_backward_relu"}); -MXNET_OPERATOR_REGISTER_UNARY(_numpy__ext_sigmoid) +MXNET_OPERATOR_REGISTER_UNARY(_npe_sigmoid) .describe(R"code(Computes sigmoid of x element-wise. .. 
math:: @@ -46,18 +45,29 @@ MXNET_OPERATOR_REGISTER_UNARY(_numpy__ext_sigmoid) )code" ADD_FILELINE) .set_attr("FCompute", UnaryOp::Compute) -.set_attr("FGradient", ElemwiseGradUseOut{"_backward_sigmoid"}) -.set_attr("TIsNumpyCompatible", true); +.set_attr("FGradient", ElemwiseGradUseOut{"_backward_sigmoid"}); -MXNET_OPERATOR_REGISTER_UNARY(_np_copy) -.MXNET_DESCRIBE("Returns a copy of the input.") +NNVM_REGISTER_OP(_np_copy) +.describe(R"code(Return an array copy of the given object.)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; + }) .set_attr("FCompute", UnaryOp::IdentityCompute) .set_attr("FInplaceIdentity", [](const NodeAttrs& attrs){ return std::vector{true}; }) .set_attr("FGradient", ElemwiseGradUseNone{"_copy"}) -.set_attr("TIsNumpyCompatible", true); +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a"}; + }) +.add_argument("a", "NDArray-or-Symbol", "The input"); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cu b/src/operator/numpy/np_elemwise_unary_op_basic.cu index 9f108f75fc15..600f19880e0c 100644 --- a/src/operator/numpy/np_elemwise_unary_op_basic.cu +++ b/src/operator/numpy/np_elemwise_unary_op_basic.cu @@ -26,10 +26,10 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_numpy__ext_relu) +NNVM_REGISTER_OP(_npe_relu) .set_attr("FCompute", UnaryOp::Compute); -NNVM_REGISTER_OP(_numpy__ext_sigmoid) +NNVM_REGISTER_OP(_npe_sigmoid) .set_attr("FCompute", UnaryOp::Compute); NNVM_REGISTER_OP(_np_copy) diff --git a/src/operator/numpy/np_init_op.cc b/src/operator/numpy/np_init_op.cc index 0abd010dfe73..83a44c8ae280 100644 --- a/src/operator/numpy/np_init_op.cc +++ b/src/operator/numpy/np_init_op.cc @@ -28,7 +28,7 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_np_zeros) +NNVM_REGISTER_OP(_npi_zeros) .describe("Return a new array of given shape, type, and context, filled with zeros.") .set_num_inputs(0) .set_num_outputs(1) @@ -37,10 +37,9 @@ NNVM_REGISTER_OP(_np_zeros) .set_attr("FInferType", InitType) .set_attr("FInferStorageType", InitStorageType) .set_attr("FCompute", FillCompute) -.set_attr("TIsNumpyCompatible", true) .add_arguments(InitOpParam::__FIELDS__()); -NNVM_REGISTER_OP(_np_ones) +NNVM_REGISTER_OP(_npi_ones) .describe("Return a new array of given shape, type, and context, filled with ones.") .set_num_inputs(0) .set_num_outputs(1) @@ -48,8 +47,65 @@ NNVM_REGISTER_OP(_np_ones) .set_attr("FInferShape", InitShape) .set_attr("FInferType", InitType) .set_attr("FCompute", FillCompute) -.set_attr("TIsNumpyCompatible", true) .add_arguments(InitOpParam::__FIELDS__()); +NNVM_REGISTER_OP(_np_zeros_like) +.describe(R"code(Return an array of zeros with the same shape and type as a given array. 
+ +Examples:: + + x = [[ 1., 1., 1.], + [ 1., 1., 1.]] + + zeros_like(x) = [[ 0., 0., 0.], + [ 0., 0., 0.]] + +)code") +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FIgnoreInputs", + [](const NodeAttrs& attrs) { + return std::vector(1, 0); + }) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a"}; + }) +.set_attr("FCompute", FillCompute) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("a", "NDArray-or-Symbol", + "The shape and data-type of a define these same attributes of the returned array."); + +NNVM_REGISTER_OP(_np_ones_like) +.describe(R"code(Return an array of ones with the same shape and type as a given array. + +Examples:: + + x = [[ 0., 0., 0.], + [ 0., 0., 0.]] + + ones_like(x) = [[ 1., 1., 1.], + [ 1., 1., 1.]] + +)code") +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FIgnoreInputs", + [](const NodeAttrs& attrs) { + return std::vector(1, 0); + }) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a"}; + }) +.set_attr("FCompute", FillCompute) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("a", "NDArray-or-Symbol", + "The shape and data-type of a define these same attributes of the returned array."); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_init_op.cu b/src/operator/numpy/np_init_op.cu index 4e6f81d48b45..2eb8ed6d83b7 100644 --- a/src/operator/numpy/np_init_op.cu +++ b/src/operator/numpy/np_init_op.cu @@ -28,10 +28,16 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_np_zeros) +NNVM_REGISTER_OP(_npi_zeros) .set_attr("FCompute", FillCompute); -NNVM_REGISTER_OP(_np_ones) +NNVM_REGISTER_OP(_npi_ones) +.set_attr("FCompute", FillCompute); + +NNVM_REGISTER_OP(_np_zeros_like) +.set_attr("FCompute", FillCompute); + +NNVM_REGISTER_OP(_np_ones_like) .set_attr("FCompute", FillCompute); } // namespace op diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 215b1c5a8c87..6e93442619b5 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -54,7 +54,7 @@ bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs, return shape_is_known(ret); } -NNVM_REGISTER_OP(_numpy_transpose) +NNVM_REGISTER_OP(_np_transpose) .describe(R"code(Permute the dimensions of an array. 
Examples:: @@ -105,7 +105,6 @@ Examples:: } }) .set_attr("FCompute", NumpyTranspose) -.set_attr("TIsNumpyCompatible", true) .set_attr("FListInputNames", [](const NodeAttrs& attrs) { return std::vector{"a"}; @@ -189,7 +188,7 @@ bool NumpyReshapeShape(const nnvm::NodeAttrs& attrs, return success; } -NNVM_REGISTER_OP(_numpy_reshape) +NNVM_REGISTER_OP(_np_reshape) .describe(R"code()code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) @@ -210,7 +209,6 @@ NNVM_REGISTER_OP(_numpy_reshape) [](const NodeAttrs& attrs) { return std::vector{"a"}; }) -.set_attr("TIsNumpyCompatible", true) .add_argument("a", "NDArray-or-Symbol", "Array to be reshaped.") .add_arguments(NumpyReshapeParam::__FIELDS__()); diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index 9753566aebe9..5bf36e54e098 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -27,10 +27,10 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_numpy_transpose) +NNVM_REGISTER_OP(_np_transpose) .set_attr("FCompute", NumpyTranspose); -NNVM_REGISTER_OP(_numpy_reshape) +NNVM_REGISTER_OP(_np_reshape) .set_attr("FCompute", UnaryOp::IdentityCompute); } // namespace op diff --git a/src/operator/numpy/np_true_divide.cc b/src/operator/numpy/np_true_divide.cc index 3bafa261e20f..429762778700 100644 --- a/src/operator/numpy/np_true_divide.cc +++ b/src/operator/numpy/np_true_divide.cc @@ -54,7 +54,7 @@ bool TrueDivideType(const nnvm::NodeAttrs& attrs, return true; } -NNVM_REGISTER_OP(_true_divide) +NNVM_REGISTER_OP(_npi_true_divide) .describe(R"code( Returns a true division of the inputs, element-wise. @@ -86,11 +86,10 @@ Example:: }) .set_attr("FCompute", BinaryBroadcastCompute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_div"}) -.set_attr("TIsNumpyCompatible", true) .add_argument("lhs", "NDArray-or-Symbol", "Dividend array") .add_argument("rhs", "NDArray-or-Symbol", "Divisor array"); -NNVM_REGISTER_OP(_true_divide_scalar) +NNVM_REGISTER_OP(_npi_true_divide_scalar) .set_num_inputs(1) .set_num_outputs(1) .set_attr_parser([](NodeAttrs* attrs) { @@ -104,11 +103,10 @@ NNVM_REGISTER_OP(_true_divide_scalar) }) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_div_scalar"}) -.set_attr("TIsNumpyCompatible", true) .add_argument("data", "NDArray-or-Symbol", "source input") .add_argument("scalar", "float", "scalar input"); -NNVM_REGISTER_OP(_rtrue_divide_scalar) +NNVM_REGISTER_OP(_npi_rtrue_divide_scalar) .set_num_inputs(1) .set_num_outputs(1) .set_attr_parser([](NodeAttrs* attrs) { @@ -122,7 +120,6 @@ NNVM_REGISTER_OP(_rtrue_divide_scalar) }) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_rdiv_scalar"}) -.set_attr("TIsNumpyCompatible", true) .add_argument("data", "NDArray-or-Symbol", "source input") .add_argument("scalar", "float", "scalar input"); diff --git a/src/operator/numpy/np_true_divide.cu b/src/operator/numpy/np_true_divide.cu index cbc7cf94c109..be10c44f92a1 100644 --- a/src/operator/numpy/np_true_divide.cu +++ b/src/operator/numpy/np_true_divide.cu @@ -28,13 +28,13 @@ namespace mxnet { namespace op { -NNVM_REGISTER_OP(_true_divide) +NNVM_REGISTER_OP(_npi_true_divide) .set_attr("FCompute", BinaryBroadcastCompute); -NNVM_REGISTER_OP(_true_divide_scalar) +NNVM_REGISTER_OP(_npi_true_divide_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); -NNVM_REGISTER_OP(_rtrue_divide_scalar) +NNVM_REGISTER_OP(_npi_rtrue_divide_scalar) .set_attr("FCompute", 
BinaryScalarOp::Compute); } // namespace op diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 141d153a207f..eb452346f6eb 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -20,7 +20,7 @@ from __future__ import division import numpy as _np import mxnet as mx -from mxnet import numpy as np +from mxnet import np from mxnet.gluon import HybridBlock from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, assert_exception from common import with_seed @@ -37,15 +37,15 @@ def test_array_creation(): mx_arr = np.array(src, dtype=dtype) assert mx_arr.context == mx.current_context() if isinstance(src, mx.nd.NDArray): - np_arr = _np.array(src.asnumpy(), dtype=dtype) + np_arr = _np.array(src.asnumpy(), dtype=dtype if dtype is not None else _np.float32) else: - np_arr = _np.array(src, dtype=dtype) - assert same(mx_arr.asnumpy(), np_arr) + np_arr = _np.array(src, dtype=dtype if dtype is not None else _np.float32) assert mx_arr.dtype == np_arr.dtype + assert same(mx_arr.asnumpy(), np_arr) @with_seed() -@mx.use_np_compat +@np.use_np_compat def test_zeros(): # test np.zeros in Gluon class TestZeros(HybridBlock): @@ -76,7 +76,7 @@ def check_zero_array_creation(shape, dtype): for shape in shapes: for dtype in dtypes: check_zero_array_creation(shape, dtype) - x = mx.nd.array(_np.random.uniform(size=shape), dtype=dtype) + x = np.array(_np.random.uniform(size=shape), dtype=dtype) if dtype is None: x = x.astype('float32') for hybridize in [True, False]: @@ -93,7 +93,7 @@ def check_zero_array_creation(shape, dtype): @with_seed() -@mx.use_np_compat +@np.use_np_compat def test_ones(): # test np.ones in Gluon class TestOnes(HybridBlock): @@ -141,7 +141,7 @@ def check_ones_array_creation(shape, dtype): @with_seed() -@mx.use_np_compat +@np.use_np_compat def test_ndarray_binary_element_wise_ops(): # Cannot test operators like >, because boolean arrays are not supported yet. np_op_map = {'+': _np.add, '*': _np.multiply, '-': _np.subtract, '/': _np.divide, @@ -241,23 +241,22 @@ def check_binary_op_result(shape1, shape2, op, dtype=None): np_out = get_np_ret(np_input1, np_input2, op) for hybridize in [True, False]: if scalar is None: - get_mx_ret = TestBinaryElementWiseOp(op) + get_mx_ret_np = TestBinaryElementWiseOp(op) + get_mx_ret_classic = TestBinaryElementWiseOp(op) if hybridize: - get_mx_ret.hybridize() - mx_out = get_mx_ret(mx_input1.as_np_ndarray(), mx_input2.as_np_ndarray()) + get_mx_ret_np.hybridize() + get_mx_ret_classic.hybridize() + mx_out = get_mx_ret_np(mx_input1.as_np_ndarray(), mx_input2.as_np_ndarray()) assert type(mx_out) == np.ndarray assert np_out.shape == mx_out.shape assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5) - mx_out = get_mx_ret(mx_input1, mx_input2.as_np_ndarray()) - assert type(mx_out) == np.ndarray - assert np_out.shape == mx_out.shape - assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5) - - mx_out = get_mx_ret(mx_input1.as_np_ndarray(), mx_input2) - assert type(mx_out) == np.ndarray - assert np_out.shape == mx_out.shape - assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5) + if mx_input1.shape == mx_input2.shape: + # classic symbol does not support element-wise binary broadcast. 
+ mx_out = get_mx_ret_classic(mx_input1, mx_input2) + assert type(mx_out) == mx.nd.NDArray + assert np_out.shape == mx_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5) else: get_mx_ret = TestBinaryElementWiseOp(op, scalar=scalar, reverse=reverse) if hybridize: @@ -291,29 +290,42 @@ def check_binary_op_result(shape1, shape2, op, dtype=None): @with_seed() -def test_np_op_output_type(): - # test imperative invoke - data = np.array([1., 3.], dtype='float32') - ret = np.sum(data) - assert type(ret) == np.ndarray - ret = mx.nd.sin(data) - assert type(ret) == mx.nd.NDArray - - # test cached op - class TestCachedOpOutputType(HybridBlock): - @mx.use_np_compat +def test_hybrid_block_multiple_outputs(): + class TestAllNumpyOutputs(HybridBlock): + @np.use_np_compat + def hybrid_forward(self, F, x, *args, **kwargs): + return F.npe.relu(x), F.np.sum(x) + + class TestAllClassicOutputs(HybridBlock): + @np.use_np_compat + def hybrid_forward(self, F, x, *args, **kwargs): + return F.relu(x.as_classic_ndarray()), F.sum(x.as_classic_ndarray()) + + class TestMixedTypeOutputsSuccess(HybridBlock): + @np.use_np_compat + def hybrid_forward(self, F, x, *args, **kwargs): + return F.relu(x.as_classic_ndarray()).as_np_ndarray(), F.np.sum(x) + + data_np = np.ones((2, 3)) + for block, expected_out_type in [(TestAllClassicOutputs, mx.nd.NDArray), + (TestAllNumpyOutputs, np.ndarray), + (TestMixedTypeOutputsSuccess, np.ndarray)]: + net = block() + for hybridize in [True, False]: + if hybridize: + net.hybridize() + out1, out2 = net(data_np) + assert type(out1) is expected_out_type + assert type(out2) is expected_out_type + + class TestMixedTypeOutputsFailure(HybridBlock): + @np.use_np_compat def hybrid_forward(self, F, x, *args, **kwargs): - ret1 = F.sin(x) - ret2 = F.np.sum(x) - return ret1, ret2 + return F.relu(x.as_classic_ndarray()), F.np.sum(x) - net = TestCachedOpOutputType() - for hybridize in [True, False]: - if hybridize: - net.hybridize() - ret1, ret2 = net(data) - assert type(ret1) == mx.nd.NDArray - assert type(ret2) == np.ndarray + net = TestMixedTypeOutputsFailure() + net.hybridize() + assert_exception(net, TypeError, data_np) @with_seed() @@ -331,6 +343,7 @@ def test_np_ndarray_astype(): def check_astype_equal(dtype, copy, expect_zero_copy=False): mx_ret = mx_data.astype(dtype=dtype, copy=copy) + assert type(mx_ret) is np.ndarray np_ret = np_data.astype(dtype=dtype, copy=copy) assert mx_ret.dtype == np_ret.dtype assert same(mx_ret.asnumpy(), np_ret) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 8c13227584ae..34b2cbe82353 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -19,7 +19,7 @@ from __future__ import absolute_import import numpy as _np import mxnet as mx -from mxnet import numpy as np +from mxnet import np, npe from mxnet.gluon import HybridBlock from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray from mxnet.test_utils import check_numeric_gradient @@ -27,7 +27,7 @@ import random -@mx.use_np_compat +@np.use_np_compat @with_seed() def test_np_sum(): class TestSum(HybridBlock): @@ -38,7 +38,7 @@ def __init__(self, axis=None, dtype=None, keepdims=False): self._keepdims = keepdims def hybrid_forward(self, F, a, *args, **kwargs): - return F.numpy.sum(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims) + return F.np.sum(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims) def is_int(dtype): return 'int' in dtype @@ -63,6 
+63,7 @@ def is_int(dtype): x = mx.nd.array(x) else: x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype) + x = x.as_np_ndarray() x.attach_grad() expected_ret = _np.sum(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims) expected_ret = expected_ret.astype(dtype) @@ -77,8 +78,8 @@ def is_int(dtype): # test numeric if itype == 'float32' and dtype == 'float32': - x_sym = mx.sym.Variable("x") - mx_sym = mx.sym.numpy.sum(x_sym, axis=axis, dtype=dtype, keepdims=keepdims) + x_sym = mx.sym.Variable("x").as_np_ndarray() + mx_sym = mx.sym.np.sum(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_classic_ndarray() check_numeric_gradient(mx_sym, [x], numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32) # test imperative @@ -87,10 +88,11 @@ def is_int(dtype): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) -@mx.use_np_compat +@np.use_np_compat @with_seed() def test_np_dot(): shapes = [ + ((3, 0), (0, 4)), ((3,), (3,)), # Case 1 ((3, 4), (4, 5)), # Case 2 ((), ()), # Case 3 @@ -102,7 +104,6 @@ def test_np_dot(): eps = 1e-3 for shape_a, shape_b in shapes: - print(shape_a, shape_b) np_a = _np.random.uniform(-1.0, 1.0, shape_a) np_a[abs(np_a) < eps] = 2 * eps; np_b = _np.random.uniform(-1.0, 1.0, shape_b) @@ -110,12 +111,12 @@ def test_np_dot(): a = mx.nd.array(np_a) b = mx.nd.array(np_b) np_res = _np.dot(np_a, np_b) - mx_res = np.dot(a, b) + mx_res = np.dot(a.as_np_ndarray(), b.as_np_ndarray()) assert mx_res.shape == np_res.shape assert_almost_equal(np_res, mx_res.asnumpy(), rtol=1e-5, atol=1e-5) mx_a = mx.sym.Variable("a") mx_b = mx.sym.Variable("b") - mx_sym = mx.sym.numpy.dot(mx_a, mx_b) + mx_sym = mx.sym.np.dot(mx_a.as_np_ndarray(), mx_b.as_np_ndarray()).as_classic_ndarray() check_numeric_gradient(mx_sym, {"a": a, "b": b}, numeric_eps=eps, rtol=1e-2, atol=1e-3) bad_shapes = [((4, 5), (2, 3)), ((3, 4, 5), (6, ))] @@ -124,13 +125,13 @@ def test_np_dot(): a = mx.nd.array(random.random()) if len(shape_a) == 0 else rand_ndarray(shape_a) b = mx.nd.array(random.random()) if len(shape_b) == 0 else rand_ndarray(shape_b) try: - mx_res = np.dot(a, b) + mx_res = np.dot(a.as_np_ndarray(), b.as_np_ndarray()) except mx.base.MXNetError: continue assert False -@mx.use_np_compat +@np.use_np_compat @with_seed() def test_np_mean(): class TestMean(HybridBlock): @@ -141,7 +142,7 @@ def __init__(self, axis=None, dtype=None, keepdims=False): self._keepdims = keepdims def hybrid_forward(self, F, a, *args, **kwargs): - return F.numpy.mean(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims) + return F.np.mean(a, axis=self._axis, dtype=self._dtype, keepdims=self._keepdims) def is_int(dtype): return 'int' in dtype @@ -167,6 +168,7 @@ def is_int(dtype): x = mx.nd.array(x, dtype=itype) else: x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype) + x = x.as_np_ndarray() x.attach_grad() expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims) expected_ret = expected_ret.astype(dtype) @@ -182,8 +184,8 @@ def is_int(dtype): # test numeric if itype == 'float32' and dtype == 'float32': - x_sym = mx.sym.Variable("x") - mx_sym = mx.sym.numpy.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims) + x_sym = mx.sym.Variable("x").as_np_ndarray() + mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_classic_ndarray() check_numeric_gradient(mx_sym, [x], numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32) # test imperative @@ -193,12 +195,12 @@ def is_int(dtype): @with_seed() -@mx.use_np_compat +@np.use_np_compat 
def test_np_transpose(): # TODO(junwu): Add more test cases - data = mx.sym.var('a') - ret = mx.sym.np.transpose(data) - assert type(ret) == mx.sym.np._NumpySymbol + data = mx.sym.var('a').as_np_ndarray() + ret = data.transpose() + assert type(ret) == mx.sym.np._Symbol dtypes = ['float32', 'int32'] for dtype in dtypes: @@ -223,44 +225,44 @@ def test_np_transpose(): @with_seed() -@mx.use_np_compat +@np.use_np_compat def test_relu(): # TODO(junwu): Add more test cases - data = mx.sym.var('data') - ret = mx.sym.np.ext.relu(data) - assert type(ret) == mx.sym.np._NumpySymbol + data = mx.sym.var('data').as_np_ndarray() + ret = mx.sym.npe.relu(data) + assert type(ret) == mx.sym.np._Symbol shapes = [(), (0, 2, 0)] shapes.extend([rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)]) for shape in shapes: data = np.array(_np.random.uniform(size=shape).astype('float32')) - ret = np.ext.relu(data) + ret = npe.relu(data) assert type(ret) == np.ndarray @with_seed() -@mx.use_np_compat +@np.use_np_compat def test_sigmoid(): # TODO(junwu): Add more test cases - data = mx.sym.var('data') - ret = mx.sym.np.ext.sigmoid(data) - assert type(ret) == mx.sym.np._NumpySymbol + data = mx.sym.var('data').as_np_ndarray() + ret = mx.sym.npe.sigmoid(data) + assert type(ret) == mx.sym.np._Symbol shapes = [(), (0, 2, 0)] shapes.extend([rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)]) for shape in shapes: data = np.array(_np.random.uniform(size=shape).astype('float32')) - ret = np.ext.sigmoid(data) + ret = npe.sigmoid(data) assert type(ret) == np.ndarray @with_seed() -@mx.use_np_compat +@np.use_np_compat def test_np_reshape(): # TODO(junwu): Add more test cases - data = mx.sym.var('a') - ret = mx.sym.np.reshape(data, newshape=()) - assert type(ret) == mx.sym.np._NumpySymbol + data = mx.sym.var('a').as_np_ndarray() + ret = data.reshape(shape=()) + assert type(ret) == mx.sym.np._Symbol data = np.ones((1, 1, 1)) ret = np.reshape(data, ()) @@ -271,12 +273,12 @@ def test_np_reshape(): @with_seed() -@mx.use_np_compat +@np.use_np_compat def test_np_maximum(): # TODO(junwu): Add more test cases - x1, x2 = mx.sym.var('x1'), mx.sym.var('x2') + x1, x2 = mx.sym.var('x1').as_np_ndarray(), mx.sym.var('x2').as_np_ndarray() ret = mx.sym.np.maximum(x1, x2) - assert type(ret) == mx.sym.np._NumpySymbol + assert type(ret) == mx.sym.np._Symbol def check_maximum(x1, x2): mx_out = np.maximum(x1, x2) @@ -292,12 +294,12 @@ def check_maximum(x1, x2): @with_seed() -@mx.use_np_compat +@np.use_np_compat def test_np_minimum(): # TODO(junwu): Add more test cases - x1, x2 = mx.sym.var('x1'), mx.sym.var('x2') + x1, x2 = mx.sym.var('x1').as_np_ndarray(), mx.sym.var('x2').as_np_ndarray() ret = mx.sym.np.minimum(x1, x2) - assert type(ret) == mx.sym.np._NumpySymbol + assert type(ret) == mx.sym.np._Symbol def check_minimum(x1, x2): mx_out = np.minimum(x1, x2)