diff --git a/example/numpy/demo.ipynb b/example/numpy/demo.ipynb index 7ba184dad43f..1f0627563159 100644 --- a/example/numpy/demo.ipynb +++ b/example/numpy/demo.ipynb @@ -4,13 +4,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Fundamentals of MXNet Numpy Module\n", + "# Fundamentals of MXNet-NumPy Module\n", "\n", "## Namespaces for Imperative Programming\n", "- `mxnet.numpy`: Regular NumPy operators\n", "- `mxnet.numpy.random`: NumPy random operators\n", "- `mxnet.numpy.linalg`: NumPy linear algebra operators\n", - "- `mxnet.numpy_extension`: Operators implemented in MXNet that do not exist in the official NumPy\n", + "- `mxnet.numpy_extension`: Operators implemented in MXNet that do not exist in the official NumPy and some utils (e.g. context related functions).\n", "\n", "## Operator Namespaces for Gluon\n", "`F` can be either `mxnet.ndarray` or `mxnet.symbol`. Note that `np` and `npe` are aliases of `numpy` and `numpy_extension`, respectively.\n", @@ -20,7 +20,7 @@ "- `F.npe`: Operators implemented in MXNet that do not exist in official NumPy\n", "\n", "## New `ndarray` and `symbol`\n", - "`mxnet.numpy.ndarray` (visible to users) and `mxnet.symbol.numpy._Symbol` (not visible to users)\n", + "`mxnet.numpy.ndarray` (visible to users) and `mxnet.symbol.numpy._Symbol` (not directly visible to users)\n", "- Same name as in the official NumPy package\n", "- Dispatch convience fluent method calls to MXNet Numpy operators\n", "- Override many convenience fluent methods that do not exist in the official NumPy ndarray\n", @@ -28,7 +28,19 @@ " - Indexing: `__getitem__` and `__setitem__`\n", " - Many binary element-wise with broadcasting, not supported in `mxnet.symbol.Symbol`\n", " \n", - "## Examples of ndarray and symbol Basics\n", + "## User Experience of Module Importing (In Progress)\n", + "**Legacy**\n", + "```python\n", + "import mxnet as mx\n", + "from mxnet import gluon\n", + "```\n", + "**Numpy**\n", + "```python\n", + "from mxnet import np, npe, gluon\n", + "```\n", + "\n", + " \n", + "## MXNet NumPy in Action\n", "### Scalar and zero-size tensors" ] }, @@ -41,9 +53,6 @@ "import mxnet as mx\n", "from mxnet import numpy as np\n", "\n", - "# use numpy-compatible semantics\n", - "mx.set_np_compat(True)\n", - "\n", "# create a scalar tensor\n", "x = np.array(3.14)\n", "print(x) # x is actually an ndarray, but a scalar value will be printed" @@ -158,7 +167,63 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Binary element-wise operations with broadcasting in new and old symbols" + "### There is a line between classic operators and numpy operators...\n", + "- Numpy operators can only accept numpy `ndarray`s/`_Symbol`s as inputs\n", + "- Classic operators can only accept classic `NDArray`s/`Symbol`s as inputs\n", + "- Explicit conversions must be performed if users want to leverage operators on both sides\n", + "- The layer inheriting from `HybridBlock` must have the same type of outputs, i.e., either all classic `NDArray`s or all numpy `ndarray`s, before hybridization\n", + "\n", + "#### Imperative" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a = mx.nd.ones((2, 3)) # create a classic NDArray\n", + "print(a)\n", + "out = np.sum(a) # feeding it to a numpy operator would result in failure" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "b = a.as_np_ndarray() # convert `a` to a numpy ndarray sharing the same data memory\n", + "print(b)\n", + "out = np.sum(b) # feed the numpy ndarray to a numpy operator\n", + "print('np.sum(b) =', out)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "out = mx.nd.sum(b) # feeding `b` to a classic operator would reuslt in failure" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "c = b.as_classic_ndarray() # convert `b` to a classic ndarray\n", + "out = mx.nd.sum(c) # feed the classic ndarray to a classic operator\n", + "print('mx.nd.sum(c) =', str(out))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Gluon" ] }, { @@ -168,19 +233,15 @@ "outputs": [], "source": [ "from mxnet import gluon\n", - "class TestBinaryBroadcast(gluon.HybridBlock):\n", - " def hybrid_forward(self, F, x1, x2):\n", - " print(\"x1 type in hybrid_forward:\", str(type(x1)))\n", - " print(\"x2 type in hybrid_forward:\", str(type(x2)))\n", - " return x1 + x2\n", + "class TestMultipleOutputs(gluon.HybridBlock):\n", + " def hybrid_forward(self, F, x):\n", + " ret1 = F.sum(x) # a classic operator produces a classic NDArray\n", + " ret2 = F.np.sum(x) # a numpy operator produces a numpy NDArray\n", + " return ret1, ret2\n", "\n", - "net = TestBinaryBroadcast()\n", - "x1 = mx.nd.ones((2, 1))\n", - "x2 = mx.nd.ones((1, 3))\n", - "print('x1 input tensor type: ', str(type(x1)))\n", - "print('x2 input tensor type: ', str(type(x2)))\n", - "out = net(x1, x2) # ok: imperative execution supports broadcasting\n", - "print(out)" + "net = TestMultipleOutputs()\n", + "net.hybridize()\n", + "out = net(a) # `a` is a classic NDArray and will cause an error on `F.np.sum` which is a numpy operator" ] }, { @@ -189,12 +250,9 @@ "metadata": {}, "outputs": [], "source": [ - "net.hybridize() # mark the block for execution using a computational graph\n", - "try:\n", - " out = net(x1, x2) # error: old symbol `+` operation does not support broadcasting\n", - " assert False # should not reach here\n", - "except mx.MXNetError:\n", - " print(\"ERROR: cannot perform broadcast add for two symbols of mxnet.sym.Symbol\")" + "net = TestMultipleOutputs() # redefine a net with no pre-built graph\n", + "net.hybridize()\n", + "out = net(b) # `b` is a numpy ndarray and will cause an error on `F.sum` which is a classic operator" ] }, { @@ -203,19 +261,15 @@ "metadata": {}, "outputs": [], "source": [ - "class TestBinaryBroadcast2(gluon.HybridBlock):\n", - " def hybrid_forward(self, F, x1, x2):\n", - " print(\"x1 type in hybrid_forward:\", str(type(x1)))\n", - " print(\"x2 type in hybrid_forward:\", str(type(x2)))\n", - " return x1.as_np_ndarray() + x2 # convert x1 to new numpy ndarray/symbol\n", - "\n", - "net2 = TestBinaryBroadcast2()\n", - "net2.hybridize()\n", + "class TestMultipleOutputs2(gluon.HybridBlock):\n", + " def hybrid_forward(self, F, x): # x is known to be a numpy ndarray\n", + " ret1 = F.sum(x.as_classic_ndarray()) # a classic operator produces a classic NDArray\n", + " ret2 = F.np.sum() # a numpy operator produces a numpy NDArray\n", + " return ret1, ret2 # two outputs of the layer with different types would result in failure in building the graph\n", "\n", - "print('x1 input tensor type: ', str(type(x1)))\n", - "print('x2 input tensor type: ', str(type(x2)))\n", - "out =net2(x1, x2)\n", - "print(out)" + "net = TestMultipleOutputs2()\n", + "net.hybridize()\n", + "out = net(b)" ] }, { @@ -224,34 +278,45 @@ "metadata": {}, "outputs": [], "source": [ - "net = TestBinaryBroadcast() # Create a new block object to clear the graph\n", - "net.hybridize() # mark the block for execution using a computational graph\n", + "class TestMultipleOutputs3(gluon.HybridBlock):\n", + " def hybrid_forward(self, F, x): # x is known to be a numpy ndarray\n", + " ret1 = F.sum(x.as_classic_ndarray()) # a classic operator produces a classic NDArray\n", + " ret2 = F.np.sum(x) # a numpy operator produces a numpy NDArray\n", + " return ret1.as_np_ndarray(), ret2 # two outputs of the layer with different types would result in failure in building the graph\n", "\n", - "x1 = x1.as_np_ndarray() # convert x1 to np.ndarray so that _NumpySymbol will be used in graph construction\n", - "print('x1 input tensor type: ', str(type(x1)))\n", - "x2 = x2.as_np_ndarray() # convert x2 to np.ndarray so that _NumpySymbol will be used in graph construction\n", - "print('x2 input tensor type: ', str(type(x2)))\n", - "out = net(x1, x2) # ok: `+` operation supports broadcasting for _NumpySymbol\n", - "print(out) # mxnet.numpy.ndarray type, because it's from a np operator" + "net = TestMultipleOutputs3()\n", + "net.hybridize()\n", + "out = net(b)\n", + "print('classic operator output: ', out[0])\n", + "print('numpy operator output: ', out[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## A Simple Linear Regression Model\n", - "Let's consider a simple linear regression model as the following.\n", - "Given dataset `{x, y}`, where `x`s represent input examples and `y`s represent observed data, find the parameters `w1` and `w2` for the following model.\n", - "```\n", - "y_pred = np.dot(np.maximum(np.dot(x, w1), 0), w2)\n", - "```" + "### Binary element-wise operations with broadcasting in new and old symbols" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "### MXNet Numpy Operators in Imperative Programming" + "class TestBinaryBroadcast(gluon.HybridBlock):\n", + " def hybrid_forward(self, F, x1, x2):\n", + " print(\"x1 type in hybrid_forward:\", str(type(x1)))\n", + " print(\"x2 type in hybrid_forward:\", str(type(x2)))\n", + " return x1 + x2\n", + "\n", + "net = TestBinaryBroadcast()\n", + "x1 = mx.nd.ones((2, 1))\n", + "x2 = mx.nd.ones((1, 3))\n", + "print('x1 input tensor type: ', str(type(x1)))\n", + "print('x2 input tensor type: ', str(type(x2)))\n", + "out = net(x1, x2) # ok: imperative execution supports broadcasting\n", + "print(out)" ] }, { @@ -260,56 +325,41 @@ "metadata": {}, "outputs": [], "source": [ - "import mxnet as mx\n", - "from mxnet import numpy as np, numpy_extension as npe\n", - "from mxnet import autograd\n", - "\n", - "\n", - "# Use numpy-compatible semantics to support scalar tensors\n", - "mx.set_np_compat(True)\n", - "\n", - "# N is number of examples; D_in is input dimension;\n", - "# H is hidden dimension; D_out is output dimension.\n", - "N, D_in, H, D_out = 64, 1000, 100, 10\n", - "\n", - "# Create random input and output data\n", - "x = mx.nd.random.normal(shape=(N, D_in)).as_np_ndarray() # x is of type mxnet.numpy.ndarray\n", - "y = mx.nd.random.normal(shape=(N, D_out)).as_np_ndarray() # y is of type mxnet.numpy.ndarray\n", - "\n", - "# Randomly initialize weights\n", - "w1 = mx.nd.random.normal(shape=(D_in, H)).as_np_ndarray() # w1 is of type mxnet.numpy.ndarray\n", - "w1.attach_grad() # w1.grad is of type mxnet.numpy.ndarray\n", - "w2 = mx.nd.random.normal(shape=(H, D_out)).as_np_ndarray() # w2 is of type mxnet.numpy.ndarray\n", - "w2.attach_grad() # w2.grad is of type mxnet.numpy.ndarray\n", - "\n", - "learning_rate = 1e-6\n", - "\n", - "\n", - "for t in range(50):\n", - " with autograd.record():\n", - " # Forward pass: compute predicted y\n", - " h = x.dot(w1) # equivalent to np.dot(x, w1)\n", - " h_relu = npe.relu(h) # equivalent to mx.nd.relu(h)\n", - " y_pred = h_relu.dot(w2) # equivalent to np.dot(h_relu, w2)\n", - "\n", - " # Compute loss\n", - " # (y_pred - y) ** 2 calls np.ndarray.__pow__\n", - " # sum() calls np.sum() which should return a scalar tensor\n", - " loss = ((y_pred - y) ** 2).sum()\n", - " # Note that the print function will invoke loss.asnumpy()\n", - " print(t, loss) # loss is a scalar tensor of type mxnet.numpy.ndarray\n", - " loss.backward()\n", + "net.hybridize() # mark the block for execution using a computational graph\n", + "try:\n", + " out = net(x1, x2) # error: old symbol `+` operation does not support broadcasting\n", + " assert False # should not reach here\n", + "except mx.MXNetError:\n", + " print(\"ERROR: cannot perform broadcast add for two symbols of type mx.sym.Symbol\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "net = TestBinaryBroadcast() # redefine a net to clear the pre-built graph cache\n", + "net.hybridize()\n", "\n", - " # Update weights\n", - " w1 -= learning_rate * w1.grad\n", - " w2 -= learning_rate * w2.grad" + "x1 = x1.as_np_ndarray() # convert x1 to np.ndarray\n", + "x2 = x2.as_np_ndarray() # convert x2 to np.ndarray\n", + "print('x1 input tensor type: ', str(type(x1)))\n", + "print('x2 input tensor type: ', str(type(x2)))\n", + "out = net(x1, x2) # ok: a graph is built with numpy symbols which supports broadcasting, because inputs are np.ndarray's, \n", + "print(out)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### MXNet Numpy Operators in Gluon `HybridBlock`" + "## A Simple Linear Regression Model\n", + "Let's consider a simple linear regression model as the following.\n", + "Given dataset `{x, y}`, where `x`s represent input examples and `y`s represent observed data, find the parameters `w1` and `w2` for the following model.\n", + "```\n", + "y_pred = np.dot(np.maximum(np.dot(x, w1), 0), w2)\n", + "```" ] }, { @@ -319,13 +369,10 @@ "outputs": [], "source": [ "import mxnet as mx\n", - "from mxnet import gluon, autograd\n", - "\n", - "\n", - "# Use numpy-compatible semantics to support scalar tensors\n", - "mx.set_np_compat(True)\n", + "from mxnet import gluon, autograd, np\n", "\n", "\n", + "@np.use_np_compat\n", "class LinearRegression(gluon.HybridBlock):\n", " def __init__(self, num_input_dim=1000, num_hidden_dim=100, num_output_dim=10):\n", " super(LinearRegression, self).__init__()\n", @@ -337,7 +384,7 @@ "\n", " def hybrid_forward(self, F, x, w1, w2):\n", " h = x.dot(w1) # equivalent to F.np.dot(x, w1)\n", - " h_relu = F.npe.relu(h) # equivalent to F.relu(h)\n", + " h_relu = F.npe.relu(h) # equivalent to F.relu(h) but generating np.ndarray\n", " y_pred = h_relu.dot(w2) # equivalent to F.np.dot(h_relu, w2)\n", " return y_pred\n", "\n", @@ -356,7 +403,9 @@ "y = mx.nd.random.normal(shape=(64, 10)).as_np_ndarray() # y is of type mxnet.numpy.ndarray\n", "\n", "total_loss = TotalLoss()\n", - "trainer = gluon.Trainer(regressor.collect_params(), 'sgd', {'learning_rate': 1e-3, 'momentum': 0.9})\n", + "trainer = gluon.Trainer(regressor.collect_params(),\n", + " 'sgd',\n", + " {'learning_rate': 1e-3, 'momentum': 0.9, 'allow_np': True})\n", "\n", "for t in range(50):\n", " with autograd.record():\n", diff --git a/include/mxnet/tuple.h b/include/mxnet/tuple.h index 08381e2152df..f018c8faabea 100644 --- a/include/mxnet/tuple.h +++ b/include/mxnet/tuple.h @@ -661,6 +661,13 @@ inline bool shape_is_known(const TShape& x) { return true; } +inline bool shape_is_known(const std::vector& shapes) { + for (const TShape& shape : shapes) { + if (!shape_is_known(shape)) return false; + } + return true; +} + /*! \brief helper function to cast type of container elements */ template inline DstIter ShapeTypeCast(const SrcIter begin, diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 92b45e5a2c6e..5393c511ce07 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -852,21 +852,3 @@ def _init_np_op_module(root_module_name, np_module_name, mx_module_name, make_op function.__module__ = module_name_local setattr(cur_module, function.__name__, function) cur_module.__all__.append(function.__name__) - - -def set_module(module): - """Decorator for overriding __module__ on a function or class. - - Example usage:: - - @set_module('mxnet.numpy') - def example(): - pass - - assert example.__module__ == 'numpy' - """ - def decorator(func): - if module is not None: - func.__module__ = module - return func - return decorator diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py index 6b4f4b609d13..807f160baa56 100644 --- a/python/mxnet/gluon/block.py +++ b/python/mxnet/gluon/block.py @@ -33,7 +33,8 @@ from ..ndarray import NDArray from .. import name as _name from .parameter import Parameter, ParameterDict, DeferredInitializationError -from .utils import _indent, _brief_print_list, HookHandle, _check_same_symbol_type +from .utils import _indent, _brief_print_list, HookHandle +from .utils import _check_same_symbol_type, _check_all_np_ndarrays from .. import numpy as _mx_np @@ -550,7 +551,8 @@ def __call__(self, *args): for hook in self._forward_hooks.values(): hook(self, args, out) - + if _mx_np.is_np_compat(): + _check_all_np_ndarrays(_flatten(out, "output")[0]) return out def forward(self, *args): diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index a174d82341af..307fb15bd1b7 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -131,7 +131,6 @@ def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t, self._grad_stype = grad_stype self._stype = stype - def __repr__(self): s = 'Parameter {name} (shape={shape}, dtype={dtype})' return s.format(name=self.name, shape=self.shape, dtype=self.dtype) @@ -189,9 +188,9 @@ def shape(self, new_shape): if self._shape is None: self._shape = new_shape return - + unknown_dim_size = -1 if is_np_compat() else 0 assert len(self._shape) == len(new_shape) and \ - all(j in (0, i) for i, j in zip(new_shape, self._shape)), \ + all(j in (unknown_dim_size, i) for i, j in zip(new_shape, self._shape)), \ "Expected shape %s is incompatible with given shape %s."%( str(new_shape), str(self._shape)) @@ -330,6 +329,9 @@ def _finish_deferred_init(self): ctx=context.cpu(), stype=self._stype) initializer.create(default_init)( initializer.InitDesc(self.name, {'__init__': init}), data) + # TODO(junwu): use np random operators when available + if is_np_compat(): + data = data.as_np_ndarray() # convert to np.ndarray self._init_impl(data, ctx) @@ -354,6 +356,9 @@ def _init_grad(self): self._grad = [ndarray.zeros(shape=i.shape, dtype=i.dtype, ctx=i.context, stype=self._grad_stype) for i in self._data] + # TODO(junwu): use np.zeros + if is_np_compat(): + self._grad = [arr.as_np_ndarray() for arr in self._grad] autograd.mark_variables(self._check_and_get(self._data, list), self._grad, self.grad_req) @@ -463,7 +468,6 @@ def reset_ctx(self, ctx): raise ValueError("Cannot reset context for Parameter '%s' because it " "has not been initialized."%self.name) - def set_data(self, data): """Sets this parameter's value on all contexts.""" self.shape = data.shape @@ -602,6 +606,8 @@ def var(self): self._var = symbol.var(self.name, shape=self.shape, dtype=self.dtype, lr_mult=self.lr_mult, wd_mult=self.wd_mult, init=self.init, stype=self._stype) + if is_np_compat(): + self._var = self._var.as_np_ndarray() return self._var def cast(self, dtype): diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 241baf415818..acfcce2ae3de 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -452,3 +452,28 @@ def _check_same_symbol_type(symbols): 'computation graph, please convert all the numpy symbols in the list ' 'to classic symbols by calling `as_classic_ndarray()` on each of them.') return np_symbol if is_np_sym else classic_symbol + + +def _check_all_np_ndarrays(out): + """Check if ndarrays in out are all np.ndarray""" + from ..numpy import ndarray as np_ndarray + assert isinstance(out, (list, tuple)) + for array in out: + if not isinstance(array, np_ndarray): + raise TypeError('Expected np.ndarray type in output, while received type ' + '{}'.format(str(type(array)))) + + +def shape_is_known(shape): + """Check whether a shape is completely known w/ or w/o np semantics.""" + if shape is None: + return False + unknown_dim_size = -1 if is_np_shape() else 0 + if len(shape) == 0: + return unknown_dim_size == -1 + for dim_size in shape: + if dim_size == unknown_dim_size: + return False + assert dim_size > unknown_dim_size, "shape dimension size cannot be less than {}, while " \ + "received {}".format(unknown_dim_size, dim_size) + return True diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 4be752eedea2..8c44afa5aabe 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -196,6 +196,12 @@ def as_np_ndarray(self): check_call(_LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(hdl))) return ndarray(handle=hdl, writable=self.writable) + def as_classic_ndarray(self): + """A convenience function for creating a classic ndarray from the current + ndarray with zero copy. For this class, it just returns itself since it is + already a classic ndarray.""" + return self + @property def _tvm_handle(self): return self.handle.value diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index e905fdf9dac6..725fba4c1cf1 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -19,16 +19,15 @@ from __future__ import absolute_import import numpy as _np -from ...base import _sanity_check_params, use_np_compat, numeric_types, set_module +from ...base import numeric_types +from ...util import _sanity_check_params, use_np_compat, set_module from ...context import current_context from . import _internal as _npi -from ..ndarray import NDArray __all__ = ['zeros', 'ones', 'maximum', 'minimum'] @set_module('mxnet.ndarray.numpy') -@use_np_compat def zeros(shape, dtype=_np.float32, **kwargs): """Return a new array of given shape and type, filled with zeros. This function currently only supports storing multi-dimensional data @@ -60,7 +59,6 @@ def zeros(shape, dtype=_np.float32, **kwargs): @set_module('mxnet.ndarray.numpy') -@use_np_compat def ones(shape, dtype=None, **kwargs): """Return a new array of given shape and type, filled with ones. This function currently only supports storing multi-dimensional data @@ -92,6 +90,7 @@ def ones(shape, dtype=None, **kwargs): #pylint: disable= too-many-arguments, no-member, protected-access +@use_np_compat def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None): """ Helper function for element-wise operation. The function will perform numpy-like broadcasting if needed and call different functions. @@ -122,6 +121,7 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou mxnet.numpy.ndarray result array """ + from ...numpy import ndarray if isinstance(lhs, numeric_types): if isinstance(rhs, numeric_types): return fn_scalar(lhs, rhs, out=out) @@ -133,7 +133,7 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou return rfn_scalar(rhs, float(lhs), out=out) elif isinstance(rhs, numeric_types): return lfn_scalar(lhs, float(rhs), out=out) - elif isinstance(rhs, NDArray): + elif isinstance(rhs, ndarray): return fn_array(lhs, rhs, out=out) else: raise TypeError('type %s not supported' % str(type(rhs))) @@ -141,7 +141,6 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou @set_module('mxnet.ndarray.numpy') -@use_np_compat def maximum(x1, x2, out=None): """Returns element-wise maximum of the input arrays with broadcasting. @@ -159,7 +158,6 @@ def maximum(x1, x2, out=None): @set_module('mxnet.ndarray.numpy') -@use_np_compat def minimum(x1, x2, out=None): """Returns element-wise minimum of the input arrays with broadcasting. diff --git a/python/mxnet/ndarray/register.py b/python/mxnet/ndarray/register.py index a285e508e04c..e93a74c5bf17 100644 --- a/python/mxnet/ndarray/register.py +++ b/python/mxnet/ndarray/register.py @@ -25,9 +25,10 @@ from ..ndarray_doc import _build_doc from ..base import mx_uint, check_call, _LIB, py_str, _init_op_module, _Null, _is_np_op # pylint: disable=unused-import +from ..util import use_np_compat # pylint: disable=unused-import -def _verify_all_np_ndarrays(op_name, func_name, *array_list): +def _verify_all_np_ndarrays(op_name, func_name, args, out): """Verify if all the arrays are numpy ndarrays. Parameters @@ -37,11 +38,14 @@ def _verify_all_np_ndarrays(op_name, func_name, *array_list): func_name : str Operator name exposed to users. This is usually the name by stripping off the prefix of the full operator names registered in backend. - array_list : list of arrays + args : list of arrays + Input ndarray arguments to be checked. + out : ndarray or None or list of ndarrays + User-provided output ndarrays. """ from ..numpy import ndarray as np_ndarray - for array in array_list: - if (array is not None) and (not isinstance(array, np_ndarray)): + for arr in args: + if (arr is not None) and (not isinstance(arr, np_ndarray)): raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. ' 'This is a numpy operator which can only accept ' 'MXNet numpy ndarrays, while received a classic ndarray. ' @@ -49,9 +53,22 @@ def _verify_all_np_ndarrays(op_name, func_name, *array_list): 'convert it to an MXNet numpy ndarray, and then feed the converted ' 'array to this operator.' .format(op_name, func_name)) + if out is None: + return + if not isinstance(out, (list, tuple)): + out = [out] + for arr in out: + if (arr is not None) and (not isinstance(arr, np_ndarray)): + raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. ' + 'This is a numpy operator which can only write to MXNet numpy ' + 'ndarrays, while received a classic ndarray. ' + 'Please call `as_np_ndarray()` upon the classic ndarray to ' + 'convert it to an MXNet numpy ndarray, and then feed the converted ' + 'array to this operator.' + .format(op_name, func_name)) -def _verify_all_classic_ndarrays(op_name, func_name, *array_list): +def _verify_all_classic_ndarrays(op_name, func_name, args, out): """Verify if all the arrays are classic ndarrays. Parameters @@ -61,11 +78,14 @@ def _verify_all_classic_ndarrays(op_name, func_name, *array_list): func_name : str Operator name exposed to users. This is usually the name by stripping off the prefix of the full operator names registered in backend. - array_list : list of arrays + args : list of arrays + Input ndarray arguments to be checked. + out : ndarray or None or list of ndarrays + User-provided output ndarrays. """ from ..numpy import ndarray as np_ndarray - for array in array_list: - if (array is not None) and (isinstance(array, np_ndarray)): + for arr in args: + if (arr is not None) and (isinstance(arr, np_ndarray)): raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. ' 'This is a classic operator which can only accept ' 'classic ndarrays, while received an MXNet numpy ndarray. ' @@ -73,6 +93,19 @@ def _verify_all_classic_ndarrays(op_name, func_name, *array_list): 'convert it to a classic ndarray, and then feed the converted ' 'array to this operator.' .format(op_name, func_name)) + if out is None: + return + if not isinstance(out, (list, tuple)): + out = [out] + for arr in out: + if (arr is not None) and (isinstance(arr, np_ndarray)): + raise TypeError('Operator `{}` registered in backend is known as `{}` in Python. ' + 'This is a classic operator which can only write to ' + 'classic ndarrays, while received an MXNet numpy ndarray. ' + 'Please call `as_classic_ndarray()` upon the numpy ndarray to ' + 'convert it to a classic ndarray, and then feed the converted ' + 'array to this operator.' + .format(op_name, func_name)) # pylint: disable=too-many-locals @@ -138,6 +171,12 @@ def _generate_ndarray_function_code(handle, op_name, func_name, signature_only=F signature = ndsignature + signature code = [] + is_np_op = _is_np_op(op_name) + doc_str_idx = 1 + if is_np_op: + doc_str_idx = 2 + code.append(""" +@use_np_compat""") if arr_name: code.append(""" def %s(*%s, **kwargs):"""%(func_name, arr_name)) @@ -187,13 +226,12 @@ def %s(%s):"""%(func_name, ', '.join(signature))) keys.append('%s') vals.append(_np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name)) - is_np_op = _is_np_op(op_name) verify_ndarrays_fn =\ _verify_all_np_ndarrays.__name__ if is_np_op else _verify_all_classic_ndarrays.__name__ if not signature_only: code.append(""" - {}("{}", "{}", out, *ndargs) - """.format(verify_ndarrays_fn, op_name, func_name)) + {verify_fn}("{op_name}", "{func_name}", ndargs, out) + """.format(verify_fn=verify_ndarrays_fn, op_name=op_name, func_name=func_name)) code.append(""" return _imperative_invoke(%d, ndargs, keys, vals, out, %s)"""%( handle.value, str(is_np_op))) @@ -204,7 +242,7 @@ def %s(%s):"""%(func_name, ', '.join(signature))) doc_str_lines = _os.linesep+''.join([' '+s if s.strip() else s for s in 'r"""{doc_str}"""'.format(doc_str=doc_str) .splitlines(True)]) - code.insert(1, doc_str_lines) + code.insert(doc_str_idx, doc_str_lines) return ''.join(code), doc_str diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py index 0f3c3c72504e..6d6ac6ad465c 100644 --- a/python/mxnet/numpy/__init__.py +++ b/python/mxnet/numpy/__init__.py @@ -26,6 +26,6 @@ from . import _op from . import _register from ._op import * # pylint: disable=wildcard-import -from ..base import use_np_compat, set_np_compat, np_compat +from ..util import use_np_compat, set_np_compat, np_compat, is_np_compat __all__ = [] diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index dfcce0b9a671..f5a3b83ba485 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -28,8 +28,9 @@ from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _GRAD_REQ_MAP from ..ndarray._internal import _set_np_ndarray_class from . import _op as _mx_np_op -from ..base import use_np_compat, check_call, _LIB, NDArrayHandle, _sanity_check_params -from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types, set_module +from ..base import check_call, _LIB, NDArrayHandle +from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types +from ..util import _sanity_check_params, set_module, use_np_compat from ..context import current_context from ..ndarray import numpy as _mx_nd_np from ..ndarray.numpy import _internal as _npi @@ -74,6 +75,7 @@ def _np_ndarray_cls(handle, writable=True, stype=0): @set_module('mxnet.numpy') # pylint: disable=invalid-name +@use_np_compat class ndarray(NDArray): """An array object represents a multidimensional, homogeneous array of fixed-size items. An associated data-type object describes the format of each element in the array @@ -81,16 +83,24 @@ class ndarray(NDArray): floating point number, or something else, etc.). Arrays should be constructed using `array`, `zeros` or `empty`. Currently, only c-contiguous arrays are supported.""" - @use_np_compat def __getitem__(self, item): # TODO(junwu): make output shape of integer indexing correct raise NotImplementedError - @use_np_compat def __setitem__(self, key, value): - self.as_classic_ndarray().__setitem__(key, value) + if self.size == 0: + return + if self.ndim == 0: + if key != (): + raise IndexError('scalar tensor can only accept `()` as index') + # TODO(junwu): Better handling of this situation + hdl = NDArrayHandle() + check_call(_LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(hdl))) + classic_ndarray = NDArray(handle=hdl, writable=self.writable) + classic_ndarray.__setitem__(slice(None), value) + return + self._as_classic_ndarray().__setitem__(key, value) - @use_np_compat def __add__(self, other): """x.__add__(y) <=> x + y""" if isinstance(other, ndarray): @@ -100,7 +110,6 @@ def __add__(self, other): else: raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) - @use_np_compat def __iadd__(self, other): """x.__iadd__(y) <=> x += y""" if not self.writable: @@ -112,7 +121,6 @@ def __iadd__(self, other): else: raise TypeError('type {} is not supported'.format(str(type(other)))) - @use_np_compat def __sub__(self, other): """x.__sub__(y) <=> x - y""" if isinstance(other, ndarray): @@ -122,7 +130,6 @@ def __sub__(self, other): else: raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) - @use_np_compat def __isub__(self, other): """x.__isub__(y) <=> x -= y""" if not self.writable: @@ -134,7 +141,6 @@ def __isub__(self, other): else: raise TypeError('type {} is not supported'.format(str(type(other)))) - @use_np_compat def __rsub__(self, other): """x.__rsub__(y) <=> y - x""" if isinstance(other, ndarray): @@ -144,7 +150,6 @@ def __rsub__(self, other): else: raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) - @use_np_compat def __mul__(self, other): """x.__mul__(y) <=> x * y""" if isinstance(other, ndarray): @@ -154,15 +159,12 @@ def __mul__(self, other): else: raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) - @use_np_compat def __neg__(self): return self.__mul__(-1.0) - @use_np_compat def __imul__(self, other): raise NotImplementedError - @use_np_compat def __rmul__(self, other): """x.__rmul__(y) <=> y * x""" return self.__mul__(other) @@ -181,11 +183,9 @@ def __rdiv__(self, other): ' module. If you are using Python3, this error should not have' ' been encountered.') - @use_np_compat def __idiv__(self, other): raise NotImplementedError - @use_np_compat def __truediv__(self, other): """x.__truediv__(y) <=> x / y""" if isinstance(other, ndarray): @@ -195,7 +195,6 @@ def __truediv__(self, other): else: raise TypeError("ndarray does not support type {} as divisor".format(str(type(other)))) - @use_np_compat def __rtruediv__(self, other): """x.__rtruediv__(y) <=> y / x""" if isinstance(other, ndarray): @@ -205,11 +204,9 @@ def __rtruediv__(self, other): else: raise TypeError("ndarray does not support type {} as dividend".format(str(type(other)))) - @use_np_compat def __itruediv__(self, other): raise NotImplementedError - @use_np_compat def __mod__(self, other): """x.__mod__(y) <=> x % y""" if isinstance(other, ndarray): @@ -219,7 +216,6 @@ def __mod__(self, other): else: raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) - @use_np_compat def __rmod__(self, other): """x.__rmod__(y) <=> y % x""" if isinstance(other, ndarray): @@ -229,11 +225,9 @@ def __rmod__(self, other): else: raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) - @use_np_compat def __imod__(self, other): raise NotImplementedError - @use_np_compat def __pow__(self, other): """x.__pow__(y) <=> x ** y""" if isinstance(other, ndarray): @@ -243,7 +237,6 @@ def __pow__(self, other): else: raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) - @use_np_compat def __rpow__(self, other): """x.__rpow__(y) <=> y ** x""" if isinstance(other, ndarray): @@ -253,45 +246,36 @@ def __rpow__(self, other): else: raise TypeError("ndarray does not support type {} as operand".format(str(type(other)))) - @use_np_compat def __eq__(self, other): """x.__eq__(y) <=> x == y""" raise NotImplementedError - @use_np_compat def __hash__(self): raise NotImplementedError - @use_np_compat def __ne__(self, other): """x.__ne__(y) <=> x != y""" raise NotImplementedError - @use_np_compat def __gt__(self, other): """x.__gt__(y) <=> x > y""" raise NotImplementedError - @use_np_compat def __ge__(self, other): """x.__ge__(y) <=> x >= y""" raise NotImplementedError - @use_np_compat def __lt__(self, other): """x.__lt__(y) <=> x < y""" raise NotImplementedError - @use_np_compat def __le__(self, other): """x.__le__(y) <=> x <= y""" raise NotImplementedError - @use_np_compat def __bool__(self): raise NotImplementedError - @use_np_compat def __len__(self): """Number of elements along the first axis.""" return self.shape[0] @@ -329,29 +313,38 @@ def T(self): return self.transpose() # pylint: enable= invalid-name, undefined-variable - @use_np_compat def _slice(self, start, stop): raise NotImplementedError - @use_np_compat def _at(self, idx): raise NotImplementedError - @use_np_compat def all(self, axis=None, out=None, keepdims=False): raise NotImplementedError - @use_np_compat def any(self, axis=None, out=None, keepdims=False): raise NotImplementedError - def as_classic_ndarray(self): - """Convert mxnet.numpy.ndarray to mxnet.ndarray.NDArray to use its fluent methods.""" + def _as_classic_ndarray(self): + """This is not a user-facing API.""" hdl = NDArrayHandle() check_call(_LIB.MXShallowCopyNDArray(self.handle, ctypes.byref(hdl))) return NDArray(handle=hdl, writable=self.writable) - @use_np_compat + def as_classic_ndarray(self): + """Convert mxnet.numpy.ndarray to mxnet.ndarray.NDArray to use its fluent methods.""" + if self.ndim == 0: # TODO(junwu): this costs ~10ns, can be moved to backend + raise ValueError('cannot convert a scalar np.ndarray to mx.nd.NDArray') + if self.size == 0: # TODO(junwu): this costs ~10ns, can be moved to backend + raise ValueError('cannot convert a zero-size np.ndarray to mx.nd.NDArray') + return self._as_classic_ndarray() + + def as_np_ndarray(self): + """A convenience function for creating a numpy ndarray from the current ndarray + with zero copy. For this class, it just returns itself since it's already a + numpy ndarray.""" + return self + def __repr__(self): """Returns a string representation of the array using the following rules: 1. If the `ndarray` is a scalar tensor, only the string of the scalar is returned. @@ -369,7 +362,6 @@ def __repr__(self): else: return '%s\n<%s shape=%s>' % (array_str, self.__class__.__name__, self.shape) - @use_np_compat def attach_grad(self, grad_req='write'): # pylint: disable=arguments-differ """Attach a gradient buffer to this ndarray, so that `backward` can compute gradient with respect to it. @@ -398,14 +390,12 @@ def grad(self): return None return _np_ndarray_cls(hdl) - @use_np_compat def detach(self): """Returns a new ndarray, detached from the current graph.""" hdl = NDArrayHandle() check_call(_LIB.MXNDArrayDetach(self.handle, ctypes.byref(hdl))) return _np_ndarray_cls(hdl) - @use_np_compat def astype(self, dtype, *args, **kwargs): # pylint: disable=arguments-differ,unused-argument """ Copy of the array, cast to a specified type. @@ -436,7 +426,6 @@ def astype(self, dtype, *args, **kwargs): # pylint: disable=arguments-differ,un self.copyto(res) return res - @use_np_compat def copyto(self, other): """Copies the value of this array to another array. @@ -470,8 +459,8 @@ def copyto(self, other): [ 1., 1., 1.]], dtype=float32) """ if isinstance(other, ndarray): - other = other.as_classic_ndarray() - return self.as_classic_ndarray().copyto(other).as_np_ndarray() + other = other._as_classic_ndarray() + return self._as_classic_ndarray().copyto(other).as_np_ndarray() def asscalar(self): raise AttributeError('mxnet.numpy.ndarray object has no attribute as_scalar') @@ -479,18 +468,15 @@ def asscalar(self): def as_in_context(self, context): return super(ndarray, self).as_in_context(context).as_np_ndarray() - @use_np_compat def copy(self, order='C'): # pylint: disable=arguments-differ if order != 'C': raise NotImplementedError('ndarray.copy only supports order=\'C\', while ' 'received {}'.format(str(order))) return super(ndarray, self).copy().as_np_ndarray() - @use_np_compat def dot(self, b, out=None): return _mx_np_op.dot(self, b, out=out) - @use_np_compat def reshape(self, shape, order='C'): # pylint: disable=arguments-differ """Returns an array containing the same data with a new shape.""" if order != 'C': @@ -530,7 +516,6 @@ def broadcast_axes(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute broadcast_like') - @use_np_compat def repeat(self, *args, **kwargs): """Convenience fluent method for :py:func:`repeat`. @@ -547,7 +532,6 @@ def pad(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute pad') - @use_np_compat def swapaxes(self, *args, **kwargs): """Convenience fluent method for :py:func:`swapaxes`. @@ -596,7 +580,6 @@ def slice_like(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute slice_like') - @use_np_compat def take(self, *args, **kwargs): """Convenience fluent method for :py:func:`take`. @@ -621,7 +604,6 @@ def pick(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute pick') - @use_np_compat def sort(self, *args, **kwargs): """Convenience fluent method for :py:func:`sort`. @@ -638,7 +620,6 @@ def topk(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute topk') - @use_np_compat def argsort(self, *args, **kwargs): """Convenience fluent method for :py:func:`argsort`. @@ -647,7 +628,6 @@ def argsort(self, *args, **kwargs): """ raise NotImplementedError - @use_np_compat def argmax(self, *args, **kwargs): """Convenience fluent method for :py:func:`argmax`. @@ -664,7 +644,6 @@ def argmax_channel(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute argmax_channel') - @use_np_compat def argmin(self, *args, **kwargs): """Convenience fluent method for :py:func:`argmin`. @@ -673,7 +652,6 @@ def argmin(self, *args, **kwargs): """ raise NotImplementedError - @use_np_compat def clip(self, *args, **kwargs): """Convenience fluent method for :py:func:`clip`. @@ -698,7 +676,6 @@ def sign(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute abs') - @use_np_compat def flatten(self, *args, **kwargs): """Convenience fluent method for :py:func:`flatten`. @@ -739,7 +716,6 @@ def tile(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute tile') - @use_np_compat def transpose(self, *axes): # pylint: disable=arguments-differ """Convenience fluent method for :py:func:`transpose`. @@ -780,7 +756,6 @@ def diag(self, k=0, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute diag') - @use_np_compat def sum(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ """Convenience fluent method for :py:func:`sum`. @@ -797,7 +772,6 @@ def nansum(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute nansum') - @use_np_compat def prod(self, *args, **kwargs): """Convenience fluent method for :py:func:`prod`. @@ -814,7 +788,6 @@ def nanprod(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute nanprod') - @use_np_compat def mean(self, *args, **kwargs): """Convenience fluent method for :py:func:`mean`. @@ -823,7 +796,6 @@ def mean(self, *args, **kwargs): """ raise NotImplementedError - @use_np_compat def max(self, *args, **kwargs): """Convenience fluent method for :py:func:`max`. @@ -832,7 +804,6 @@ def max(self, *args, **kwargs): """ raise NotImplementedError - @use_np_compat def min(self, *args, **kwargs): """Convenience fluent method for :py:func:`min`. @@ -849,7 +820,6 @@ def norm(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute norm') - @use_np_compat def round(self, *args, **kwargs): """Convenience fluent method for :py:func:`round`. @@ -1146,7 +1116,6 @@ def softmin(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute softmin') - @use_np_compat def squeeze(self, *args, **kwargs): """Convenience fluent method for :py:func:`squeeze`. @@ -1162,12 +1131,10 @@ def broadcast_like(self, other): raise AttributeError('mxnet.numpy.ndarray object has no attribute broadcast_like') @property - @use_np_compat def shape(self): return super(ndarray, self).shape @property - @use_np_compat def ndim(self): """Number of array dimensions.""" return len(self.shape) @@ -1249,7 +1216,10 @@ def array(object, dtype=None, **kwargs): except: raise TypeError('source array must be an array like object') ret = empty(object.shape, dtype=dtype, ctx=ctx) - ret[:] = object + if len(object.shape) == 0: + ret[()] = object + else: + ret[:] = object return ret diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index c2c1aa6a76f4..5b433eee7a59 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -18,6 +18,7 @@ # pylint: disable=too-many-lines """Weight updating functions.""" +from __future__ import absolute_import import logging import math import pickle @@ -94,7 +95,7 @@ class Optimizer(object): def __init__(self, rescale_grad=1., param_idx2name=None, wd=0., clip_gradient=None, learning_rate=0.01, lr_scheduler=None, sym=None, begin_num_update=0, - multi_precision=False, param_dict=None): + multi_precision=False, param_dict=None, allow_np=False): self.rescale_grad = rescale_grad self.lr = learning_rate self.lr_scheduler = lr_scheduler @@ -119,6 +120,7 @@ def __init__(self, rescale_grad=1., param_idx2name=None, wd=0., self.idx2name = param_idx2name.copy() self.sym_info = (sym.attr_dict(), sym.list_arguments()) if sym is not None else () self.param_dict = param_dict if param_dict else {} + self.allow_np = allow_np self.set_lr_mult({}) self.set_wd_mult({}) @@ -1644,6 +1646,25 @@ def update(self, index, weight, grad, state): # backward compatibility wrapper for Optimizer.CreateOptimizer create = Optimizer.create_optimizer # pylint: disable=invalid-name + +def _as_classic(a, allow_np): + from ..numpy import ndarray as np_ndarray + if isinstance(a, (tuple, list)): + if any(isinstance(x, np_ndarray) for x in a): + if allow_np: + return [x.as_classic_ndarray() for x in a] + else: + raise ValueError('Converting np.ndarray to mx.nd.NDArray is not allowed') + else: + if isinstance(a, np_ndarray): + if allow_np: + return a.as_classic_ndarray() + else: + raise ValueError('Converting np.ndarray to mx.nd.NDArray is not allowed') + return a + + + class Updater(object): """Updater for kvstore.""" def __init__(self, optimizer): @@ -1654,14 +1675,15 @@ def __init__(self, optimizer): def __call__(self, index, grad, weight): """Updates weight given gradient and index.""" + allow_np = self.optimizer.allow_np if not isinstance(index, (list, tuple)): indices = [index] - grads = [grad] - weights = [weight] + grads = [_as_classic(grad, allow_np)] + weights = [_as_classic(weight, allow_np)] else: indices = index - grads = grad - weights = weight + grads = _as_classic(grad, allow_np) + weights = _as_classic(weight, allow_np) if weights: self.optimizer._set_current_context(weights[0].context.device_id) for i, idx in enumerate(indices): diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 0bbd96b3b2bb..6a03cdb9be45 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -22,8 +22,8 @@ import ctypes import numpy as _np from . import _op as _mx_np_op -from ...base import _sanity_check_params, use_np_compat, check_call, _LIB, SymbolHandle -from ...base import numeric_types, set_module +from ...base import _LIB, SymbolHandle, numeric_types +from ...util import _sanity_check_params, check_call, set_module from ...context import current_context from ..symbol import Symbol from .._internal import _set_np_symbol_class @@ -43,7 +43,6 @@ def __setitem__(self, key, value): def __iter__(self): raise AttributeError('_Symbol object has no attribute __iter__') - @use_np_compat def __add__(self, other): """x.__add__(y) <=> x + y""" if isinstance(other, _Symbol): @@ -54,7 +53,6 @@ def __add__(self, other): raise TypeError("_Symbol does not support type {} as operand" .format(str(type(other)))) - @use_np_compat def __sub__(self, other): """x.__sub__(y) <=> x - y""" if isinstance(other, _Symbol): @@ -65,7 +63,6 @@ def __sub__(self, other): raise TypeError("_Symbol does not support type {} as operand" .format(str(type(other)))) - @use_np_compat def __rsub__(self, other): """x.__rsub__(y) <=> y - x""" if isinstance(other, _Symbol): @@ -76,7 +73,6 @@ def __rsub__(self, other): raise TypeError("_Symbol does not support type {} as operand" .format(str(type(other)))) - @use_np_compat def __mul__(self, other): """x.__mul__(y) <=> x * y""" if isinstance(other, _Symbol): @@ -87,7 +83,6 @@ def __mul__(self, other): raise TypeError("_Symbol does not support type {} as operand" .format(str(type(other)))) - @use_np_compat def __rmul__(self, other): """x.__rmul__(y) <=> y * x""" if isinstance(other, _Symbol): @@ -112,7 +107,6 @@ def __rdiv__(self, other): ' module. If you are using Python3, this error should not have' ' been encountered.') - @use_np_compat def __mod__(self, other): """x.__mod__(y) <=> x % y""" if isinstance(other, _Symbol): @@ -123,7 +117,6 @@ def __mod__(self, other): raise TypeError("_Symbol does not support type {} as operand" .format(str(type(other)))) - @use_np_compat def __rmod__(self, other): """x.__rmod__(y) <=> y % x""" if isinstance(other, _Symbol): @@ -134,11 +127,9 @@ def __rmod__(self, other): raise TypeError("_Symbol does not support type {} as operand" .format(str(type(other)))) - @use_np_compat def __idiv__(self, other): raise NotImplementedError - @use_np_compat def __truediv__(self, other): """x.__truediv__(y) <=> x / y""" if isinstance(other, _Symbol): @@ -149,7 +140,6 @@ def __truediv__(self, other): raise TypeError("_Symbol does not support type {} as divisor" .format(str(type(other)))) - @use_np_compat def __rtruediv__(self, other): """x.__rtruediv__(y) <=> y / x""" if isinstance(other, _Symbol): @@ -160,11 +150,9 @@ def __rtruediv__(self, other): raise TypeError("_Symbol does not support type {} as dividend" .format(str(type(other)))) - @use_np_compat def __itruediv__(self, other): raise NotImplementedError - @use_np_compat def __pow__(self, other): """x.__pow__(y) <=> x ** y""" if isinstance(other, _Symbol): @@ -175,7 +163,6 @@ def __pow__(self, other): raise TypeError("_Symbol does not support type {} as operand" .format(str(type(other)))) - @use_np_compat def __rpow__(self, other): """x.__rpow__(y) <=> y ** x""" if isinstance(other, _Symbol): @@ -186,41 +173,33 @@ def __rpow__(self, other): raise TypeError("_Symbol does not support type {} as operand" .format(str(type(other)))) - @use_np_compat def __neg__(self): """x.__neg__() <=> - x""" return self.__mul__(-1.0) - @use_np_compat def __deepcopy__(self, _): return super(_Symbol, self).as_np_ndarray() - @use_np_compat def __eq__(self, other): """x.__eq__(y) <=> x == y""" raise NotImplementedError - @use_np_compat def __ne__(self, other): """x.__ne__(y) <=> x != y""" raise NotImplementedError - @use_np_compat def __gt__(self, other): """x.__gt__(y) <=> x > y""" raise NotImplementedError - @use_np_compat def __ge__(self, other): """x.__ge__(y) <=> x >= y""" raise NotImplementedError - @use_np_compat def __lt__(self, other): """x.__lt__(y) <=> x < y""" raise NotImplementedError - @use_np_compat def __le__(self, other): """x.__le__(y) <=> x <= y""" raise NotImplementedError @@ -241,15 +220,12 @@ def T(self): return self.transpose() # pylint: enable= invalid-name, undefined-variable - @use_np_compat def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ raise NotImplementedError - @use_np_compat def dot(self, b, out=None): return _mx_np_op.dot(self, b, out=out) - @use_np_compat def reshape(self, shape, order='C'): # pylint: disable=arguments-differ if order != 'C': raise NotImplementedError('ndarray.copy only supports order=\'C\', while ' @@ -288,7 +264,6 @@ def broadcast_axes(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute broadcast_like') - @use_np_compat def repeat(self, *args, **kwargs): """Convenience fluent method for :py:func:`repeat`. @@ -305,7 +280,6 @@ def pad(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute pad') - @use_np_compat def swapaxes(self, *args, **kwargs): """Convenience fluent method for :py:func:`swapaxes`. @@ -354,7 +328,6 @@ def slice_like(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute slice_like') - @use_np_compat def take(self, *args, **kwargs): """Convenience fluent method for :py:func:`take`. @@ -379,7 +352,6 @@ def pick(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute pick') - @use_np_compat def sort(self, *args, **kwargs): """Convenience fluent method for :py:func:`sort`. @@ -396,7 +368,6 @@ def topk(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute topk') - @use_np_compat def argsort(self, *args, **kwargs): """Convenience fluent method for :py:func:`argsort`. @@ -405,7 +376,6 @@ def argsort(self, *args, **kwargs): """ raise NotImplementedError - @use_np_compat def argmax(self, *args, **kwargs): """Convenience fluent method for :py:func:`argmax`. @@ -422,7 +392,6 @@ def argmax_channel(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute argmax_channel') - @use_np_compat def argmin(self, *args, **kwargs): """Convenience fluent method for :py:func:`argmin`. @@ -431,7 +400,6 @@ def argmin(self, *args, **kwargs): """ raise NotImplementedError - @use_np_compat def clip(self, *args, **kwargs): """Convenience fluent method for :py:func:`clip`. @@ -456,7 +424,6 @@ def sign(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute abs') - @use_np_compat def flatten(self, *args, **kwargs): """Convenience fluent method for :py:func:`flatten`. @@ -497,7 +464,6 @@ def tile(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute tile') - @use_np_compat def transpose(self, *axes): # pylint: disable=arguments-differ """Convenience fluent method for :py:func:`transpose`. @@ -538,7 +504,6 @@ def diag(self, k=0, **kwargs): """ raise AttributeError('_Symbol object has no attribute diag') - @use_np_compat def sum(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ """Convenience fluent method for :py:func:`sum`. @@ -555,7 +520,6 @@ def nansum(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute nansum') - @use_np_compat def prod(self, *args, **kwargs): """Convenience fluent method for :py:func:`prod`. @@ -572,7 +536,6 @@ def nanprod(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute nanprod') - @use_np_compat def mean(self, *args, **kwargs): """Convenience fluent method for :py:func:`mean`. @@ -581,7 +544,6 @@ def mean(self, *args, **kwargs): """ raise NotImplementedError - @use_np_compat def max(self, *args, **kwargs): """Convenience fluent method for :py:func:`max`. @@ -590,7 +552,6 @@ def max(self, *args, **kwargs): """ raise NotImplementedError - @use_np_compat def min(self, *args, **kwargs): """Convenience fluent method for :py:func:`min`. @@ -607,7 +568,6 @@ def norm(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute norm') - @use_np_compat def round(self, *args, **kwargs): """Convenience fluent method for :py:func:`round`. @@ -904,7 +864,6 @@ def softmin(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute softmin') - @use_np_compat def squeeze(self, *args, **kwargs): """Convenience fluent method for :py:func:`squeeze`. @@ -921,7 +880,6 @@ def broadcast_like(self, *args, **kwargs): @set_module('mxnet.symbol.numpy') -@use_np_compat def zeros(shape, dtype=_np.float32, **kwargs): """Return a new array of given shape and type, filled with zeros. This function currently only supports storing multi-dimensional data @@ -953,7 +911,6 @@ def zeros(shape, dtype=_np.float32, **kwargs): @set_module('mxnet.symbol.numpy') -@use_np_compat def ones(shape, dtype=None, **kwargs): """Return a new array of given shape and type, filled with zeros. This function currently only supports storing multi-dimensional data @@ -1034,13 +991,11 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou @set_module('mxnet.symbol.numpy') -@use_np_compat def maximum(x1, x2, out=None): return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out) @set_module('mxnet.symbol.numpy') -@use_np_compat def minimum(x1, x2, out=None): return _ufunc_helper(x1, x2, _npi.minimum, _np.minimum, _npi.minimum_scalar, None, out) diff --git a/python/mxnet/util.py b/python/mxnet/util.py index 5bc1dc809c88..d41137142a70 100644 --- a/python/mxnet/util.py +++ b/python/mxnet/util.py @@ -20,6 +20,8 @@ import os import sys import functools +import itertools +import inspect from .base import _LIB, check_call @@ -213,39 +215,111 @@ def np_shape(active=True): return _NumpyShapeScope(active) -def use_np_shape(func): - """Wraps a function with an activated NumPy-shape scope. This ensures - that the execution of the function is guaranteed with the support of - scalar and zero-size tensors as in NumPy. +def wraps_safely(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS): + """This function is safe version of `functools.wraps` in Python2 which skips wrapping functions + for the attributes that do not exist.""" + if sys.version_info[0] > 2: + return functools.wraps(wrapped) + else: + return functools.wraps(wrapped, + assigned=itertools.ifilter( + functools.partial(hasattr, wrapped), assigned)) - Please note that this is designed as an infrastructure for the incoming - MXNet-NumPy operators. Legacy operators registered in the modules - `mx.nd` and `mx.sym` are not guaranteed to behave like their counterparts - in NumPy even within this scope. +def use_np_shape(func): + """A decorator wrapping a function or class with activated NumPy-shape semantics. + When `func` is a function, this ensures that the execution of the function is scoped with NumPy + shape semantics, such as the support for zero-dim and zero size tensors. When + `func` is a class, it ensures that all the methods, static functions, and properties + of the class are executed with the NumPy shape semantics. + + Example:: + import mxnet as mx + @mx.use_np_shape + def scalar_one(): + return mx.nd.ones(()) + print(scalar_one()) + + @np.use_np_shape + class ScalarTensor(object): + def __init__(self, val=None): + if val is None: + val = ScalarTensor.random().value + self._scalar = mx.nd.ones(()) * val + + def __repr__(self): + print("Is __repr__ in np_shape semantics? {}!".format(str(np.is_np_shape()))) + return str(self._scalar.asnumpy()) + + @staticmethod + def random(): + val = mx.nd.random.uniform().asnumpy().item() + return ScalarTensor(val) + + @property + def value(self): + print("Is value property in np_shape semantics? {}!".format(str(np.is_np_shape()))) + return self._scalar.asnumpy().item() + + + print("Is global scope of np_shape activated? {}!".format(str(np.is_np_shape()))) + scalar_tensor = ScalarTensor() + print(scalar_tensor) Parameters ---------- - func : a user-provided callable function to be scoped by the NumPy-shape semantics. + func : a user-provided callable function or class to be scoped by the NumPy compatibility state. Returns ------- - Function - A function for wrapping the user functions in the NumPy-shape semantics. + Function or class + A function or class wrapped in the NumPy compatibility scope. + """ + if inspect.isclass(func): + for name, method in inspect.getmembers( + func, + predicate= + lambda f: inspect.isfunction(f) or inspect.ismethod(f) or isinstance(f, property)): + if isinstance(method, property): + setattr(func, name, property(use_np_shape(method.__get__), + method.__set__, + method.__delattr__, + method.__doc__)) + else: + setattr(func, name, use_np_shape(method)) + return func + elif callable(func): + @wraps_safely(func) + def _with_np_shape(*args, **kwargs): + with np_shape(active=True): + return func(*args, **kwargs) + return _with_np_shape + else: + raise TypeError('use_np_shape can only decorate classes and callable objects, ' + 'while received a {}'.format(str(type(func)))) + + +def _sanity_check_params(func_name, unsupported_params, param_dict): + for param_name in unsupported_params: + if param_name in param_dict: + raise NotImplementedError("function {} does not support parameter {}" + .format(func_name, param_name)) - Examples - -------- - >>> import mxnet as mx - >>> @mx.use_np_shape - ... def scalar_one(): - ... return mx.nd.ones(()) - ... - >>> print(scalar_one()) - """ - @functools.wraps(func) - def _with_np_shape(*args, **kwargs): - with np_shape(active=True): - return func(*args, **kwargs) - return _with_np_shape +def set_module(module): + """Decorator for overriding __module__ on a function or class. + + Example usage:: + + @set_module('mxnet.numpy') + def example(): + pass + + assert example.__module__ == 'numpy' + """ + def decorator(func): + if module is not None: + func.__module__ = module + return func + return decorator diff --git a/src/operator/numpy/np_dot.cc b/src/operator/numpy/np_dot.cc index bcb310fda4b6..992bef086d2b 100644 --- a/src/operator/numpy/np_dot.cc +++ b/src/operator/numpy/np_dot.cc @@ -36,29 +36,43 @@ inline bool NumpyDotShape(const nnvm::NodeAttrs& attrs, const mxnet::TShape& a_shape = in_attrs->at(0); const mxnet::TShape& b_shape = in_attrs->at(1); - if (!shape_is_known(a_shape) || !shape_is_known(b_shape)) { + if (!ndim_is_known(a_shape) || !ndim_is_known(b_shape)) { return false; } if (a_shape.ndim() == 1 && b_shape.ndim() == 1) { // Case 1: both 1-D arrays, inner product of vectors - CHECK_EQ(a_shape[0], b_shape[0]); + SHAPE_ASSIGN_CHECK(*in_attrs, 0, in_attrs->at(1)); + SHAPE_ASSIGN_CHECK(*in_attrs, 1, in_attrs->at(0)); SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(0, 0)); } else if (a_shape.ndim() == 2 && b_shape.ndim() == 2) { // Case 2: both 2-D arrays, matrix multiplication - CHECK_EQ(a_shape[1], b_shape[0]); - mxnet::TShape mm_shape(2, 0); - mm_shape[0] = a_shape[0]; - mm_shape[1] = b_shape[1]; - SHAPE_ASSIGN_CHECK(*out_attrs, 0, mm_shape); + mxnet::TShape tmp_shape(2, -1); + tmp_shape[1] = b_shape[0]; + SHAPE_ASSIGN_CHECK(*in_attrs, 0, tmp_shape); + + tmp_shape[0] = a_shape[1]; + tmp_shape[1] = -1; + SHAPE_ASSIGN_CHECK(*in_attrs, 1, tmp_shape); + + tmp_shape[0] = a_shape[0]; + tmp_shape[1] = b_shape[1]; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, tmp_shape); } else if (a_shape.ndim() == 0 || b_shape.ndim() == 0) { // Case 3 + 3.5: either of them is a scalar, just scale by one of them mxnet::TShape oshape = (a_shape.ndim() == 0) ? b_shape : a_shape; SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); } else if (b_shape.ndim() == 1) { // Case 4: a is N-D array and b is 1-D array, sum product over the last axis - CHECK_EQ(a_shape[a_shape.ndim() - 1], b_shape[0]); - mxnet::TShape out_shape(a_shape.ndim() - 1, 0); + TShape tmp_shape(a_shape.ndim(), -1); + tmp_shape[a_shape.ndim() - 1] = b_shape[0]; + SHAPE_ASSIGN_CHECK(*in_attrs, 0, tmp_shape); + + tmp_shape = TShape(1, -1); + tmp_shape[0] = a_shape[a_shape.ndim() - 1]; + SHAPE_ASSIGN_CHECK(*in_attrs, 1, tmp_shape); + + mxnet::TShape out_shape(a_shape.ndim() - 1, -1); for (int i = 0; i < a_shape.ndim() - 1; ++i) { out_shape[i] = a_shape[i]; } @@ -68,7 +82,7 @@ inline bool NumpyDotShape(const nnvm::NodeAttrs& attrs, // of a and the 2nd-to-last axis of b LOG(FATAL) << "Case 5 not implemented yet..."; } - return true; + return shape_is_known(*in_attrs) && shape_is_known(*out_attrs); } NNVM_REGISTER_OP(_np_dot) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index ab7114ba3732..69a1d7e5a88d 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -37,6 +37,7 @@ from test_operator import * from test_numpy_op import * from test_numpy_ndarray import * +from test_numpy_gluon import * from test_optimizer import * from test_random import * from test_exc_handling import * diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py new file mode 100644 index 000000000000..446f5b8c9672 --- /dev/null +++ b/tests/python/unittest/test_numpy_gluon.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: skip-file +from __future__ import absolute_import +from __future__ import division +import mxnet as mx +from mxnet import gluon, autograd, np + + +def test_create_np_param(): + M, K, N = 10, 9, 20 + + def check_block_params(x, TestBlock, hybridize, expected_type): + net = TestBlock() + net.initialize() + if hybridize: + net.hybridize() + net(x) + params = net.collect_params() + for k, v in params.items(): + assert type(v.data()) is expected_type + + class TestBlock1(gluon.HybridBlock): + def __init__(self): + super(TestBlock1, self).__init__() + with self.name_scope(): + self.w = self.params.get('w', shape=(K, N), allow_deferred_init=True) + + def hybrid_forward(self, F, x, w): + return F.dot(x, w) + + @np.use_np_compat + class TestBlock2(gluon.HybridBlock): + def __init__(self): + super(TestBlock2, self).__init__() + with self.name_scope(): + self.w = self.params.get('w', shape=(K, N), allow_deferred_init=True) + + def hybrid_forward(self, F, x, w): + return F.np.dot(x, w) + + x = mx.nd.random.uniform(shape=(M, K)) + check_block_params(x, TestBlock1, False, mx.nd.NDArray) + check_block_params(x, TestBlock1, True, mx.nd.NDArray) + check_block_params(x.as_np_ndarray(), TestBlock2, False, np.ndarray) + check_block_params(x.as_np_ndarray(), TestBlock2, True, np.ndarray) + + +def test_optimizer_with_np_ndarrays(): + @np.use_np_compat + class LinearRegression(gluon.HybridBlock): + def __init__(self, num_input_dim=-1, num_hidden_dim=100, num_output_dim=10): + super(LinearRegression, self).__init__() + with self.name_scope(): + self.w1 = self.params.get('w1', shape=(num_input_dim, num_hidden_dim), + allow_deferred_init=True) + self.w2 = self.params.get('w2', shape=(num_hidden_dim, num_output_dim), + allow_deferred_init=True) + + def hybrid_forward(self, F, x, w1, w2): + h = x.dot(w1) # equivalent to F.np.dot(x, w1) + h_relu = F.npe.relu(h) # equivalent to F.relu(h) but generating np.ndarray + y_pred = h_relu.dot(w2) # equivalent to F.np.dot(h_relu, w2) + return y_pred + + @np.use_np_compat + class TotalLoss(gluon.HybridBlock): + def hybrid_forward(self, F, pred, label): + return ((pred - label) ** 2).sum() # equivalent to F.np.sum(F.np.square(pred - label)) + + regressor = LinearRegression() + regressor.initialize(mx.init.Normal()) + regressor.hybridize() + + # Create random input and output data + x = mx.nd.random.normal(shape=(64, 1000)).as_np_ndarray() # x is of type mxnet.numpy.ndarray + regressor(x) + y = mx.nd.random.normal(shape=(64, 10)).as_np_ndarray() # y is of type mxnet.numpy.ndarray + + total_loss = TotalLoss() + total_loss.hybridize() + + trainer = gluon.Trainer(regressor.collect_params(), + 'sgd', + {'learning_rate': 1e-3, 'momentum': 0.9, 'allow_np': True}) + + for t in range(5): + with autograd.record(): + output = regressor(x) # output is a type of np.ndarray because np.dot is the last op in the network + loss = total_loss(output, y) # loss is a scalar np.ndarray + loss.backward() + trainer.step(1) + + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index eb452346f6eb..7ffa77438e64 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -45,9 +45,9 @@ def test_array_creation(): @with_seed() -@np.use_np_compat def test_zeros(): # test np.zeros in Gluon + @np.use_np_compat class TestZeros(HybridBlock): def __init__(self, shape, dtype=None): super(TestZeros, self).__init__() @@ -57,11 +57,13 @@ def __init__(self, shape, dtype=None): def hybrid_forward(self, F, x, *args, **kwargs): return x + F.np.zeros(shape, dtype) + @np.use_np_compat class TestZerosOutputType(HybridBlock): def hybrid_forward(self, F, x, *args, **kwargs): return x, F.np.zeros(shape=()) # test np.zeros in imperative + @np.use_np_compat def check_zero_array_creation(shape, dtype): np_out = _np.zeros(shape=shape, dtype=dtype) mx_out = np.zeros(shape=shape, dtype=dtype) @@ -93,9 +95,9 @@ def check_zero_array_creation(shape, dtype): @with_seed() -@np.use_np_compat def test_ones(): # test np.ones in Gluon + @np.use_np_compat class TestOnes(HybridBlock): def __init__(self, shape, dtype=None): super(TestOnes, self).__init__() @@ -105,11 +107,13 @@ def __init__(self, shape, dtype=None): def hybrid_forward(self, F, x, *args, **kwargs): return x * F.np.ones(shape, dtype) + @np.use_np_compat class TestOnesOutputType(HybridBlock): def hybrid_forward(self, F, x, *args, **kwargs): return x, F.np.ones(shape=()) # test np.ones in imperative + @np.use_np_compat def check_ones_array_creation(shape, dtype): np_out = _np.ones(shape=shape, dtype=dtype) mx_out = np.ones(shape=shape, dtype=dtype) @@ -141,7 +145,6 @@ def check_ones_array_creation(shape, dtype): @with_seed() -@np.use_np_compat def test_ndarray_binary_element_wise_ops(): # Cannot test operators like >, because boolean arrays are not supported yet. np_op_map = {'+': _np.add, '*': _np.multiply, '-': _np.subtract, '/': _np.divide, @@ -153,6 +156,7 @@ def test_ndarray_binary_element_wise_ops(): def get_np_ret(x1, x2, op): return np_op_map[op](x1, x2) + @np.use_np_compat class TestBinaryElementWiseOp(HybridBlock): def __init__(self, op, scalar=None, reverse=False): super(TestBinaryElementWiseOp, self).__init__() @@ -215,6 +219,7 @@ def hybrid_forward(self, F, x, *args): print(self._op) assert False + @np.use_np_compat def check_binary_op_result(shape1, shape2, op, dtype=None): if shape1 is None: mx_input1 = abs(_np.random.uniform()) + 1 @@ -250,13 +255,6 @@ def check_binary_op_result(shape1, shape2, op, dtype=None): assert type(mx_out) == np.ndarray assert np_out.shape == mx_out.shape assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5) - - if mx_input1.shape == mx_input2.shape: - # classic symbol does not support element-wise binary broadcast. - mx_out = get_mx_ret_classic(mx_input1, mx_input2) - assert type(mx_out) == mx.nd.NDArray - assert np_out.shape == mx_out.shape - assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5) else: get_mx_ret = TestBinaryElementWiseOp(op, scalar=scalar, reverse=reverse) if hybridize: @@ -291,25 +289,18 @@ def check_binary_op_result(shape1, shape2, op, dtype=None): @with_seed() def test_hybrid_block_multiple_outputs(): + @np.use_np_compat class TestAllNumpyOutputs(HybridBlock): - @np.use_np_compat def hybrid_forward(self, F, x, *args, **kwargs): return F.npe.relu(x), F.np.sum(x) class TestAllClassicOutputs(HybridBlock): - @np.use_np_compat def hybrid_forward(self, F, x, *args, **kwargs): return F.relu(x.as_classic_ndarray()), F.sum(x.as_classic_ndarray()) - class TestMixedTypeOutputsSuccess(HybridBlock): - @np.use_np_compat - def hybrid_forward(self, F, x, *args, **kwargs): - return F.relu(x.as_classic_ndarray()).as_np_ndarray(), F.np.sum(x) - data_np = np.ones((2, 3)) for block, expected_out_type in [(TestAllClassicOutputs, mx.nd.NDArray), - (TestAllNumpyOutputs, np.ndarray), - (TestMixedTypeOutputsSuccess, np.ndarray)]: + (TestAllNumpyOutputs, np.ndarray)]: net = block() for hybridize in [True, False]: if hybridize: @@ -318,12 +309,13 @@ def hybrid_forward(self, F, x, *args, **kwargs): assert type(out1) is expected_out_type assert type(out2) is expected_out_type + @np.use_np_compat class TestMixedTypeOutputsFailure(HybridBlock): - @np.use_np_compat def hybrid_forward(self, F, x, *args, **kwargs): return F.relu(x.as_classic_ndarray()), F.np.sum(x) net = TestMixedTypeOutputsFailure() + assert_exception(net, TypeError, data_np) net.hybridize() assert_exception(net, TypeError, data_np) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 34b2cbe82353..e1993923ebf5 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -27,7 +27,6 @@ import random -@np.use_np_compat @with_seed() def test_np_sum(): class TestSum(HybridBlock): @@ -88,8 +87,8 @@ def is_int(dtype): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) -@np.use_np_compat @with_seed() +@np.use_np_compat def test_np_dot(): shapes = [ ((3, 0), (0, 4)), @@ -131,9 +130,9 @@ def test_np_dot(): assert False -@np.use_np_compat @with_seed() def test_np_mean(): + @np.use_np_compat class TestMean(HybridBlock): def __init__(self, axis=None, dtype=None, keepdims=False): super(TestMean, self).__init__()