From 19bd76c895620cb2dcca94be6f2d649fd8790aa4 Mon Sep 17 00:00:00 2001 From: reminisce Date: Thu, 9 May 2019 20:15:18 -0700 Subject: [PATCH] [numpy] Some np ops for d2l (#14924) * Add np transpose More ops and namespaces for submodules Add relu and sigmoid Add reshape Fix symbolic name mismatch Add maximum and minimum * Add convenience fluent method * Add ndarray.item() * Fix CI * Fix lint * Fix lint * Fix reshape gpu * Add example * Remove python notebook outputs * Remove notebook output * Add one more example --- example/numpy/demo.ipynb | 415 ++++++++++++++++++ include/mxnet/tuple.h | 8 + python/mxnet/base.py | 9 +- python/mxnet/ndarray/numpy/__init__.py | 3 + python/mxnet/ndarray/numpy/_op.py | 90 +++- python/mxnet/ndarray/numpy/ext.py | 20 + python/mxnet/ndarray/numpy/linalg.py | 20 + python/mxnet/ndarray/numpy/random.py | 20 + python/mxnet/numpy/__init__.py | 5 +- python/mxnet/numpy/ext.py | 20 + python/mxnet/numpy/linalg.py | 2 +- python/mxnet/numpy/multiarray.py | 112 ++++- python/mxnet/numpy/random.py | 2 +- python/mxnet/symbol/numpy/__init__.py | 3 + python/mxnet/symbol/numpy/_symbol.py | 92 +++- python/mxnet/symbol/numpy/ext.py | 20 + python/mxnet/symbol/numpy/linalg.py | 20 + python/mxnet/symbol/numpy/random.py | 20 + src/c_api/c_api_common.h | 6 +- .../numpy/np_elemwise_broadcast_op.cc | 18 + .../numpy/np_elemwise_broadcast_op.cu | 15 +- .../numpy/np_elemwise_unary_op_basic.cc | 63 +++ .../numpy/np_elemwise_unary_op_basic.cu | 39 ++ src/operator/numpy/np_matrix_op-inl.h | 65 +++ src/operator/numpy/np_matrix_op.cc | 218 +++++++++ src/operator/numpy/np_matrix_op.cu | 37 ++ .../tensor/elemwise_binary_broadcast_op.h | 1 + src/operator/tensor/matrix_op-inl.h | 8 +- tests/python/unittest/test_numpy_ndarray.py | 1 - tests/python/unittest/test_numpy_op.py | 120 +++++ 30 files changed, 1428 insertions(+), 44 deletions(-) create mode 100644 example/numpy/demo.ipynb create mode 100644 python/mxnet/ndarray/numpy/ext.py create mode 100644 python/mxnet/ndarray/numpy/linalg.py create mode 100644 python/mxnet/ndarray/numpy/random.py create mode 100644 python/mxnet/numpy/ext.py create mode 100644 python/mxnet/symbol/numpy/ext.py create mode 100644 python/mxnet/symbol/numpy/linalg.py create mode 100644 python/mxnet/symbol/numpy/random.py create mode 100644 src/operator/numpy/np_elemwise_unary_op_basic.cc create mode 100644 src/operator/numpy/np_elemwise_unary_op_basic.cu create mode 100644 src/operator/numpy/np_matrix_op-inl.h create mode 100644 src/operator/numpy/np_matrix_op.cc create mode 100644 src/operator/numpy/np_matrix_op.cu diff --git a/example/numpy/demo.ipynb b/example/numpy/demo.ipynb new file mode 100644 index 000000000000..d8e6e06e1818 --- /dev/null +++ b/example/numpy/demo.ipynb @@ -0,0 +1,415 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Fundamentals of MXNet Numpy Module\n", + "\n", + "## Operator Namespaces for Imperative Programming\n", + "- `mxnet.numpy`: Regular NumPy operators\n", + "- `mxnet.numpy.random`: NumPy random operators\n", + "- `mxnet.numpy.linalg`: NumPy linear algebra operators\n", + "- `mxnet.numpy.ext`: Operators implemented in MXNet that do not exist in official NumPy\n", + "\n", + "## Operator Namespaces for Gluon\n", + "`F` can be either `mxnet.ndarray` or `mxnet.symbol`.\n", + "- `F.np`: Regular NumPy operators\n", + "- `F.np.random`: NumPy random operators\n", + "- `F.np.linalg`: NumPy linear algebra operators\n", + "- `F.np.ext`: Operators implemented in MXNet that do not exist in official NumPy\n", + "\n", + "## New `ndarray` and `symbol`\n", + "`mxnet.numpy.ndarray` and `mxnet.symbol.numpy._NumpySymbol` (not visible to users)\n", + "- Same name as in the official NumPy package\n", + "- Dispatch convience fluent method calls to MXNet Numpy operators\n", + "- Override many convenience fluent methods that do not exist in the official NumPy ndarray\n", + "- Make the behavior of built-in methods consistent with the official NumPy\n", + " - Indexing: `__getitem__` and `__setitem__`\n", + " - Many binary element-wise with broadcasting, not supported in `mxnet.symbol.Symbol`\n", + " \n", + "## Examples of ndarray and symbol Basics\n", + "### Scalar and zero-size tensors" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import mxnet as mx\n", + "from mxnet import numpy as np\n", + "\n", + "# use numpy-compatible semantics\n", + "mx.set_np_compat(True)\n", + "\n", + "# create a scalar tensor\n", + "x = np.array(3.14)\n", + "print(x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "s = x.item() # copy the element from the scalar tensor to a python scalar\n", + "print('s = {}'.format(str(s)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create a scalar tensors with only one element 1.0\n", + "y = np.ones(())\n", + "print(y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create a zero-size tensor\n", + "x = np.ones((5, 4, 0, 6))\n", + "print(x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# transpose the zero-size tensor\n", + "y = np.transpose(x)\n", + "print(y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Conversion between classic and numpy ndarrays" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# create a classic MXNet NDArray\n", + "x = mx.nd.random.uniform(shape=(2, 3))\n", + "print(x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# convert classic NDArray type to mxnet.numpy.ndarray with zero-copy\n", + "y = x.as_np_ndarray()\n", + "print(y)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# changing y's content changes x's content too\n", + "y[:] = 1\n", + "print(x)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# convert mxnet.numpy.ndarray to classic NDArray with zero-copy\n", + "z = y.as_classic_ndarray()\n", + "print(z)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# changing z's content changes y's content too\n", + "z[:] = 2\n", + "print(y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Binary element-wise operations with broadcasting in new and old symbols" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from mxnet import gluon\n", + "class TestBinaryBroadcast(gluon.HybridBlock):\n", + " def hybrid_forward(self, F, x1, x2):\n", + " print(\"x1 type:\", str(type(x1)))\n", + " print(\"x2 type:\", str(type(x2)))\n", + " return x1 + x2\n", + "\n", + "net = TestBinaryBroadcast()\n", + "x1 = mx.nd.ones((2, 1))\n", + "x2 = mx.nd.ones((1, 3))\n", + "out = net(x1, x2) # ok: imperative execution supports broadcasting\n", + "print(out)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "net.hybridize() # mark the block for execution using a computational graph\n", + "try:\n", + " out = net(x1, x2) # error: old symbol `+` operation does not support broadcasting\n", + " assert False # should not reach here\n", + "except mx.MXNetError:\n", + " print(\"ERROR: cannot perform broadcast add for two symbols of mxnet.sym.Symbol\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class TestBinaryBroadcast2(gluon.HybridBlock):\n", + " def hybrid_forward(self, F, x1, x2):\n", + " print(\"x1 type:\", str(type(x1)))\n", + " print(\"x2 type:\", str(type(x2)))\n", + " return x1.as_np_ndarray() + x2 # convert x1 to new numpy ndarray/symbol\n", + "\n", + "net2 = TestBinaryBroadcast2()\n", + "net2.hybridize()\n", + "\n", + "out =net2(x1, x2)\n", + "print(out)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "net = TestBinaryBroadcast() # Create a new block object to clear the graph\n", + "net.hybridize() # mark the block for execution using a computational graph\n", + "\n", + "x1 = x1.as_np_ndarray() # convert x1 to np.ndarray so that _NumpySymbol will be used in graph construction\n", + "x2 = x2.as_np_ndarray() # convert x2 to np.ndarray so that _NumpySymbol will be used in graph construction\n", + "out = net(x1, x2) # ok: `+` operation supports broadcasting for _NumpySymbol\n", + "print(out) # mxnet.numpy.ndarray type, because it's from a np operator" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## A Simple Linear Regression Model\n", + "Let's consider a simple linear regression model as the following.\n", + "Given dataset `{x, y}`, where `x`s represent input examples and `y`s represent observed data, find the parameters `w1` and `w2` for the following model.\n", + "```\n", + "y_pred = np.dot(np.maximum(np.dot(x, w1), 0), w2)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MXNet Numpy Operators in Imperative Programming" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import mxnet as mx\n", + "from mxnet import numpy as np\n", + "from mxnet import autograd\n", + "try:\n", + " from mxboard import SummaryWriter\n", + "except ImportError:\n", + " SummaryWriter = None\n", + "\n", + "# create a summary writer for visualization\n", + "sw = SummaryWriter(logdir='./logs', flush_secs=2) if SummaryWriter is not None else None\n", + "\n", + "# Use numpy-compatible semantics to support scalar tensors\n", + "mx.set_np_compat(True)\n", + "\n", + "# N is number of examples; D_in is input dimension;\n", + "# H is hidden dimension; D_out is output dimension.\n", + "N, D_in, H, D_out = 64, 1000, 100, 10\n", + "\n", + "# Create random input and output data\n", + "x = mx.nd.random.normal(shape=(N, D_in)).as_np_ndarray() # x is of type mxnet.numpy.ndarray\n", + "y = mx.nd.random.normal(shape=(N, D_out)).as_np_ndarray() # y is of type mxnet.numpy.ndarray\n", + "\n", + "# Randomly initialize weights\n", + "w1 = mx.nd.random.normal(shape=(D_in, H)).as_np_ndarray() # w1 is of type mxnet.numpy.ndarray\n", + "w1.attach_grad() # w1.grad is of type mxnet.numpy.ndarray\n", + "w2 = mx.nd.random.normal(shape=(H, D_out)).as_np_ndarray() # w2 is of type mxnet.numpy.ndarray\n", + "w2.attach_grad() # w2.grad is of type mxnet.numpy.ndarray\n", + "\n", + "learning_rate = 1e-6\n", + "\n", + "\n", + "for t in range(1000):\n", + " with autograd.record():\n", + " # Forward pass: compute predicted y\n", + " h = x.dot(w1) # equivalent to np.dot(x, w1)\n", + " h_relu = np.ext.relu(h) # equivalent to mx.nd.relu(h)\n", + " y_pred = h_relu.dot(w2) # equivalent to np.dot(h_relu, w2)\n", + "\n", + " # Compute loss\n", + " # (y_pred - y) ** 2 calls np.ndarray.__pow__\n", + " # sum() calls np.sum() which should return a scalar tensor\n", + " loss = ((y_pred - y) ** 2).sum()\n", + " # Note that the print function will invoke loss.asnumpy()\n", + " print(t, loss) # loss is a scalar tensor of type mxnet.numpy.ndarray\n", + " loss.backward()\n", + "\n", + " # Update weights\n", + " w1 -= learning_rate * w1.grad\n", + " w2 -= learning_rate * w2.grad\n", + "\n", + " if sw is not None:\n", + " sw.add_scalar('loss', loss.item(), global_step=t) # loss.item() copies the tensor element to a python scalar\n", + " if t % 50 == 0:\n", + " sw.add_histogram(tag='w1', values=w1, global_step=t)\n", + " sw.add_histogram(tag='w2', values=w2, global_step=t)\n", + "\n", + "if sw is not None:\n", + " sw.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MXNet Numpy Operators in Gluon `HybridBlock`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import mxnet as mx\n", + "from mxnet import gluon, autograd\n", + "try:\n", + " from mxboard import SummaryWriter\n", + "except ImportError:\n", + " SummaryWriter = None\n", + "\n", + "# create a summary writer for visualization\n", + "sw = SummaryWriter(logdir='./logs', flush_secs=2) if SummaryWriter is not None else None\n", + "\n", + "# Use numpy-compatible semantics to support scalar tensors\n", + "mx.set_np_compat(True)\n", + "\n", + "\n", + "class LinearRegression(gluon.HybridBlock):\n", + " def __init__(self, num_input_dim=1000, num_hidden_dim=100, num_output_dim=10):\n", + " super(LinearRegression, self).__init__()\n", + " with self.name_scope():\n", + " self.w1 = self.params.get('w1', shape=(num_input_dim, num_hidden_dim),\n", + " allow_deferred_init=True)\n", + " self.w2 = self.params.get('w2', shape=(num_hidden_dim, num_output_dim),\n", + " allow_deferred_init=True)\n", + "\n", + " def hybrid_forward(self, F, x, w1, w2):\n", + " h = x.dot(w1) # equivalent to F.np.dot(x, w1)\n", + " h_relu = F.np.ext.relu(h) # equivalent to F.relu(h)\n", + " y_pred = h_relu.dot(w2) # equivalent to F.np.dot(h_relu, w2)\n", + " return y_pred\n", + "\n", + "\n", + "class TotalLoss(gluon.HybridBlock):\n", + " def hybrid_forward(self, F, pred, label):\n", + " return ((pred - label) ** 2).sum() # equivalent to F.np.sum(F.np.square(pred - label))\n", + "\n", + "\n", + "regressor = LinearRegression()\n", + "regressor.initialize(mx.init.Normal())\n", + "regressor.hybridize()\n", + "\n", + "# Create random input and output data\n", + "x = mx.nd.random.normal(shape=(64, 1000)).as_np_ndarray() # x is of type mxnet.numpy.ndarray\n", + "y = mx.nd.random.normal(shape=(64, 10)).as_np_ndarray() # y is of type mxnet.numpy.ndarray\n", + "\n", + "total_loss = TotalLoss()\n", + "trainer = gluon.Trainer(regressor.collect_params(), 'sgd', {'learning_rate': 1e-3, 'momentum': 0.9})\n", + "\n", + "for t in range(1000):\n", + " with autograd.record():\n", + " output = regressor(x) # output is a type of np.ndarray because np.dot is the last op in the network\n", + " loss = total_loss(output, y) # loss is a scalar np.ndarray\n", + " loss.backward()\n", + " print(t, loss) # note that loss.asnumpy() is called\n", + " trainer.step(1)\n", + " if sw is not None:\n", + " sw.add_scalar('loss', loss.item(), global_step=t) # loss.item() copies the tensor element to a python scalar\n", + " if t % 50 == 0:\n", + " for k, v in regressor.collect_params().items():\n", + " sw.add_histogram(tag=k, values=v.data(), global_step=t)\n", + "\n", + "if sw is not None:\n", + " sw.close()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/include/mxnet/tuple.h b/include/mxnet/tuple.h index bc630f153744..08381e2152df 100644 --- a/include/mxnet/tuple.h +++ b/include/mxnet/tuple.h @@ -272,6 +272,14 @@ class Tuple { is.get(); if (ch == '(' || ch == '[') break; if (!isspace(ch)) { + if (ch == 'N') { + std::string tmp_val; + is >> tmp_val; + if (tmp_val == "one") { // is stores "None" + t.SetDim(-1); + return is; + } + } is.setstate(std::ios::failbit); return is; } diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 429d293ad10f..df5e6a615f00 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -749,7 +749,7 @@ def _sanity_check_params(func_name, unsupported_params, param_dict): .format(func_name, param_name)) -_NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_'] +_NP_OP_SUBMODULE_LIST = ['_ext_', '_random_', '_linalg_'] _NP_OP_PREFIX = '_numpy_' @@ -798,10 +798,9 @@ def _init_np_op_module(root_namespace, module_name, make_op_func): submodule_pattern = "%s.%s.numpy.%s" module_np_op = sys.modules[module_pattern % (root_namespace, module_name)] submodule_dict = {} - # TODO(junwu): uncomment the following lines when adding numpy ops in submodules, e.g. np.random - # for submodule_name in _NP_OP_SUBMODULE_LIST: - # submodule_dict[submodule_name] = \ - # sys.modules[submodule_pattern % (root_namespace, module_name, submodule_name[1:-1])] + for submodule_name in _NP_OP_SUBMODULE_LIST: + submodule_dict[submodule_name] = \ + sys.modules[submodule_pattern % (root_namespace, module_name, submodule_name[1:-1])] for name in op_names: hdl = OpHandle() check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) diff --git a/python/mxnet/ndarray/numpy/__init__.py b/python/mxnet/ndarray/numpy/__init__.py index a714a4b19fa4..d97e8086e8c3 100644 --- a/python/mxnet/ndarray/numpy/__init__.py +++ b/python/mxnet/ndarray/numpy/__init__.py @@ -17,6 +17,9 @@ """numpy module for numpy ops under mxnet.ndarray.""" +from . import ext +from . import random +from . import linalg from . import _op from . import _register from ._op import * # pylint: disable=wildcard-import diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 383bf2fdb792..9b32c314df7c 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -19,11 +19,12 @@ from __future__ import absolute_import import numpy as _np -from ...base import _sanity_check_params, use_np_compat +from ...base import _sanity_check_params, use_np_compat, numeric_types from ...context import current_context from .. import _internal +from ..ndarray import NDArray -__all__ = ['zeros', 'ones'] +__all__ = ['zeros', 'ones', 'maximum', 'minimum'] @use_np_compat @@ -86,3 +87,88 @@ def ones(shape, dtype=None, **kwargs): ctx = current_context() dtype = _np.float32 if dtype is None else dtype return _internal._np_ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs) + + +#pylint: disable= too-many-arguments, no-member, protected-access +def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None): + """ Helper function for element-wise operation. + The function will perform numpy-like broadcasting if needed and call different functions. + + Parameters + -------- + lhs : NDArray or numeric value + Left-hand side operand. + + rhs : NDArray or numeric value + Right-hand operand, + + fn_array : function + Function to be called if both lhs and rhs are of ``NDArray`` type. + + fn_scalar : function + Function to be called if both lhs and rhs are numeric values. + + lfn_scalar : function + Function to be called if lhs is ``NDArray`` while rhs is numeric value + + rfn_scalar : function + Function to be called if lhs is numeric value while rhs is ``NDArray``; + if none is provided, then the function is commutative, so rfn_scalar is equal to lfn_scalar + + Returns + -------- + mxnet.numpy.ndarray + result array + """ + if isinstance(lhs, numeric_types): + if isinstance(rhs, numeric_types): + return fn_scalar(lhs, rhs, out=out) + else: + if rfn_scalar is None: + # commutative function + return lfn_scalar(rhs, float(lhs), out=out) + else: + return rfn_scalar(rhs, float(lhs), out=out) + elif isinstance(rhs, numeric_types): + return lfn_scalar(lhs, float(rhs), out=out) + elif isinstance(rhs, NDArray): + return fn_array(lhs, rhs, out=out) + else: + raise TypeError('type %s not supported' % str(type(rhs))) +#pylint: enable= too-many-arguments, no-member, protected-access + + +@use_np_compat +def maximum(x1, x2, out=None): + """Returns element-wise maximum of the input arrays with broadcasting. + + Parameters + ---------- + x1, x2 : scalar or mxnet.numpy.ndarray + The arrays holding the elements to be compared. They must have the same shape, + or shapes that can be broadcast to a single shape. + + Returns + ------- + out : mxnet.numpy.ndarray or scalar + The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.""" + return _ufunc_helper(x1, x2, _internal._np_maximum, _np.maximum, + _internal._np_maximum_scalar, None, out) + + +@use_np_compat +def minimum(x1, x2, out=None): + """Returns element-wise minimum of the input arrays with broadcasting. + + Parameters + ---------- + x1, x2 : scalar or mxnet.numpy.ndarray + The arrays holding the elements to be compared. They must have the same shape, + or shapes that can be broadcast to a single shape. + + Returns + ------- + out : mxnet.numpy.ndarray or scalar + The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.""" + return _ufunc_helper(x1, x2, _internal._np_minimum, _np.minimum, + _internal._np_minimum_scalar, None, out) diff --git a/python/mxnet/ndarray/numpy/ext.py b/python/mxnet/ndarray/numpy/ext.py new file mode 100644 index 000000000000..e13423f82535 --- /dev/null +++ b/python/mxnet/ndarray/numpy/ext.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""numpy.ext namespace for operators used in Gluon APIs dispatched by F=ndarray module.""" + +__all__ = [] diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py new file mode 100644 index 000000000000..b8f10b343430 --- /dev/null +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""numpy.linalg namespace for operators used in Gluon APIs dispatched by F=symbol module.""" + +__all__ = [] diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py new file mode 100644 index 000000000000..60908b5c8098 --- /dev/null +++ b/python/mxnet/ndarray/numpy/random.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""numpy.random namespace for operators used in Gluon APIs dispatched by F=ndarray module.""" + +__all__ = [] diff --git a/python/mxnet/numpy/__init__.py b/python/mxnet/numpy/__init__.py index c4dea9e5663e..2a58f270b96d 100644 --- a/python/mxnet/numpy/__init__.py +++ b/python/mxnet/numpy/__init__.py @@ -20,10 +20,11 @@ """numpy module for imperative programming.""" from __future__ import absolute_import -from .multiarray import * # pylint: disable=wildcard-import -from . import _op from . import random from . import linalg +from . import ext +from .multiarray import * # pylint: disable=wildcard-import +from . import _op from . import _register from ._op import * # pylint: disable=wildcard-import diff --git a/python/mxnet/numpy/ext.py b/python/mxnet/numpy/ext.py new file mode 100644 index 000000000000..e4c82518d474 --- /dev/null +++ b/python/mxnet/numpy/ext.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""namespace for registering numpy.ext ops for imperative programming.""" + +__all__ = [] diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py index 1527c61f1ad9..96c7ddc06612 100644 --- a/python/mxnet/numpy/linalg.py +++ b/python/mxnet/numpy/linalg.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. -"""namespace for registering numpy ops of linear algebra.""" +"""namespace for registering numpy.linalg ops for imperative programming.""" __all__ = [] diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 9f47ce15fc81..6c414b4c6266 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -27,14 +27,14 @@ import numpy as _np from ..ndarray import NDArray, _DTYPE_NP_TO_MX from ..ndarray._internal import _set_np_ndarray_class -from . import _op +from . import _op as _mx_np_op from ..base import use_np_compat, check_call, _LIB, NDArrayHandle, _sanity_check_params from ..base import mx_real_t, c_array_buf, mx_uint, numeric_types from ..context import current_context from ..ndarray import numpy as _mx_nd_np from ..ndarray import _internal as _nd_internal -__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones'] +__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum'] # This function is copied from ndarray.py since pylint @@ -73,7 +73,7 @@ def _np_ndarray_cls(handle, writable=True, stype=0): _set_np_ndarray_class(_np_ndarray_cls) -class ndarray(NDArray): +class ndarray(NDArray): # pylint: disable=invalid-name """An array object represents a multidimensional, homogeneous array of fixed-size items. An associated data-type object describes the format of each element in the array (its byte-order, how many bytes it occupies in memory, whether it is an integer, a @@ -104,7 +104,15 @@ def __add__(self, other): @use_np_compat def __iadd__(self, other): - raise NotImplementedError + """x.__iadd__(y) <=> x += y""" + if not self.writable: + raise ValueError('trying to add to a readonly ndarray') + if isinstance(other, NDArray): + return _nd_internal._np_add(self, other, out=self) + elif isinstance(other, numeric_types): + return _nd_internal._np_add_scalar(self, float(other), out=self) + else: + raise TypeError('type {} is not supported'.format(str(type(other)))) @use_np_compat def __sub__(self, other): @@ -118,7 +126,15 @@ def __sub__(self, other): @use_np_compat def __isub__(self, other): - raise NotImplementedError + """x.__isub__(y) <=> x -= y""" + if not self.writable: + raise ValueError('trying to subtract from a readonly ndarray') + if isinstance(other, NDArray): + return _nd_internal._np_subtract(self, other, out=self) + elif isinstance(other, numeric_types): + return _nd_internal._np_subtract_scalar(self, float(other), out=self) + else: + raise TypeError('type {} is not supported'.format(str(type(other)))) @use_np_compat def __rsub__(self, other): @@ -285,6 +301,36 @@ def __len__(self): def __reduce__(self): return ndarray, (None,), self.__getstate__() + def item(self, *args): + """Copy an element of an array to a standard Python scalar and return it. + + Parameters + ---------- + *args : Arguments (variable number and type) + none: in this case, the method only works for arrays with one element (a.size == 1), + which element is copied into a standard Python scalar object and returned. + + int_type: this argument is interpreted as a flat index into the array, specifying which + element to copy and return. + + tuple of int_types: functions as does a single int_type argument, except that the + argument is interpreted as an nd-index into the array. + + Returns + ------- + z : Standard Python scalar object + A copy of the specified element of the array as a suitable Python scalar. + """ + # TODO(junwu): no need to call asnumpy() on the whole array. + return self.asnumpy().item(*args) + + @property + # pylint: disable= invalid-name, undefined-variable + def T(self): + """Same as self.transpose(). This always returns a copy of self.""" + return self.transpose() + # pylint: enable= invalid-name, undefined-variable + @use_np_compat def _slice(self, start, stop): raise NotImplementedError @@ -380,9 +426,16 @@ def copy(self, order='C'): # pylint: disable=arguments-differ return super(ndarray, self).copy().as_np_ndarray() @use_np_compat - def reshape(self, *shape, **kwargs): + def dot(self, b, out=None): + return _mx_np_op.dot(self, b, out=out) + + @use_np_compat + def reshape(self, shape, order='C'): # pylint: disable=arguments-differ """Returns an array containing the same data with a new shape.""" - raise NotImplementedError + if order != 'C': + raise NotImplementedError('reshape only supports C-order,' + ' while received {}'.format(order)) + return _mx_np_op.reshape(self, shape=shape, order=order) def reshape_like(self, *args, **kwargs): """Convenience fluent method for :py:func:`reshape_like`. @@ -626,13 +679,13 @@ def tile(self, *args, **kwargs): raise AttributeError('mxnet.numpy.ndarray object has no attribute tile') @use_np_compat - def transpose(self, *args, **kwargs): + def transpose(self, *axes): # pylint: disable=arguments-differ """Convenience fluent method for :py:func:`transpose`. The arguments are the same as for :py:func:`transpose`, with this array as data. """ - raise NotImplementedError + return _mx_np_op.transpose(self, axes=axes if len(axes) != 0 else None) def flip(self, *args, **kwargs): """Convenience fluent method for :py:func:`flip`. @@ -667,13 +720,13 @@ def diag(self, k=0, **kwargs): raise AttributeError('mxnet.numpy.ndarray object has no attribute diag') @use_np_compat - def sum(self, *args, **kwargs): + def sum(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ """Convenience fluent method for :py:func:`sum`. The arguments are the same as for :py:func:`sum`, with this array as data. """ - return _op.sum(self, *args, **kwargs) + return _mx_np_op.sum(self, axis=axis, dtype=dtype, out=out, keepdims=keepdims) def nansum(self, *args, **kwargs): """Convenience fluent method for :py:func:`nansum`. @@ -1069,11 +1122,6 @@ def size(self): def stype(self): raise AttributeError('mxnet.numpy.ndarray object has no attribute stype') - @property - @use_np_compat - def T(self): - raise NotImplementedError - def tostype(self, stype): raise AttributeError('mxnet.numpy.ndarray object has no attribute tostype') @@ -1198,3 +1246,35 @@ def ones(shape, dtype=None, **kwargs): Array of zeros with the given shape, dtype, and ctx. """ return _mx_nd_np.ones(shape, dtype, **kwargs) + + +def maximum(x1, x2, out=None): + """Returns element-wise maximum of the input arrays with broadcasting. + + Parameters + ---------- + x1, x2 : scalar or mxnet.numpy.ndarray + The arrays holding the elements to be compared. They must have the same shape, + or shapes that can be broadcast to a single shape. + + Returns + ------- + out : mxnet.numpy.ndarray or scalar + The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.""" + return _mx_nd_np.maximum(x1, x2, out=out) + + +def minimum(x1, x2, out=None): + """Returns element-wise minimum of the input arrays with broadcasting. + + Parameters + ---------- + x1, x2 : scalar or mxnet.numpy.ndarray + The arrays holding the elements to be compared. They must have the same shape, + or shapes that can be broadcast to a single shape. + + Returns + ------- + out : mxnet.numpy.ndarray or scalar + The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.""" + return _mx_nd_np.minimum(x1, x2, out=out) diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index 461da667b2d1..b1f4b02e5a71 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. -"""namespace for registering numpy random operators.""" +"""namespace for registering numpy.random ops for imperative programming.""" __all__ = [] diff --git a/python/mxnet/symbol/numpy/__init__.py b/python/mxnet/symbol/numpy/__init__.py index d63daa2c1400..1f20c037a0ec 100644 --- a/python/mxnet/symbol/numpy/__init__.py +++ b/python/mxnet/symbol/numpy/__init__.py @@ -17,6 +17,9 @@ """numpy module for numpy ops under mxnet.symbol.""" +from . import random +from . import linalg +from . import ext from . import _op, _symbol from ._symbol import _NumpySymbol from . import _register diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 087f11827010..8cf6e3039d98 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. +# pylint: disable=too-many-lines """numpy namespace for operators used in Gluon APIs dispatched by F=symbol module.""" from __future__ import absolute_import import ctypes import numpy as _np -from . import _op as _np_op +from . import _op as _mx_np_op from ...base import _sanity_check_params, use_np_compat, check_call, _LIB, SymbolHandle from ...base import numeric_types from ...context import current_context @@ -29,7 +30,7 @@ from .._internal import _set_np_symbol_class from .. import _internal as _sym_internal -__all__ = ['zeros', 'ones'] +__all__ = ['zeros', 'ones', 'maximum', 'minimum'] class _NumpySymbol(Symbol): @@ -237,13 +238,27 @@ def as_classic_ndarray(self): check_call(_LIB.MXShallowCopySymbol(self.handle, ctypes.byref(hdl))) return Symbol(handle=hdl) + @property + # pylint: disable= invalid-name, undefined-variable + def T(self): + """Same as self.transpose().""" + return self.transpose() + # pylint: enable= invalid-name, undefined-variable + @use_np_compat def astype(self, dtype, **kwargs): # pylint: disable=arguments-differ raise NotImplementedError @use_np_compat - def reshape(self, *shape, **kwargs): - raise NotImplementedError + def dot(self, b, out=None): + return _mx_np_op.dot(self, b, out=out) + + @use_np_compat + def reshape(self, shape, order='C'): # pylint: disable=arguments-differ + if order != 'C': + raise NotImplementedError('ndarray.copy only supports order=\'C\', while ' + 'received {}'.format(str(order))) + return _mx_np_op.reshape(self, shape=shape, order=order) def reshape_like(self, *args, **kwargs): """Convenience fluent method for :py:func:`reshape_like`. @@ -487,13 +502,13 @@ def tile(self, *args, **kwargs): raise AttributeError('_NumpySymbol object has no attribute tile') @use_np_compat - def transpose(self, *args, **kwargs): + def transpose(self, *axes): # pylint: disable=arguments-differ """Convenience fluent method for :py:func:`transpose`. The arguments are the same as for :py:func:`transpose`, with this array as data. """ - raise NotImplementedError + return _mx_np_op.transpose(self, axes=axes if len(axes) != 0 else None) def flip(self, *args, **kwargs): """Convenience fluent method for :py:func:`flip`. @@ -528,13 +543,13 @@ def diag(self, k=0, **kwargs): raise AttributeError('_NumpySymbol object has no attribute diag') @use_np_compat - def sum(self, *args, **kwargs): + def sum(self, axis=None, dtype=None, out=None, keepdims=False): # pylint: disable=arguments-differ """Convenience fluent method for :py:func:`sum`. The arguments are the same as for :py:func:`sum`, with this array as data. """ - return _np_op.sum(self, *args, **kwargs) + return _mx_np_op.sum(self, axis=axis, dtype=dtype, out=out, keepdims=keepdims) def nansum(self, *args, **kwargs): """Convenience fluent method for :py:func:`nansum`. @@ -971,4 +986,65 @@ def ones(shape, dtype=None, **kwargs): return _internal._np_ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs) +#pylint: disable= too-many-arguments, no-member, protected-access +def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None): + """ Helper function for element-wise operation. + The function will perform numpy-like broadcasting if needed and call different functions. + + Parameters + -------- + lhs : Symbol or numeric value + Left-hand side operand. + + rhs : Symbol or numeric value + Right-hand operand, + + fn_array : function + Function to be called if both lhs and rhs are of ``Symbol`` type. + + fn_scalar : function + Function to be called if both lhs and rhs are numeric values. + + lfn_scalar : function + Function to be called if lhs is ``Symbol`` while rhs is numeric value + + rfn_scalar : function + Function to be called if lhs is numeric value while rhs is ``Symbol``; + if none is provided, then the function is commutative, so rfn_scalar is equal to lfn_scalar + + Returns + -------- + mxnet.numpy.ndarray + result array + """ + if isinstance(lhs, numeric_types): + if isinstance(rhs, numeric_types): + return fn_scalar(lhs, rhs, out=out) + else: + if rfn_scalar is None: + # commutative function + return lfn_scalar(rhs, float(lhs), out=out) + else: + return rfn_scalar(rhs, float(lhs), out=out) + elif isinstance(rhs, numeric_types): + return lfn_scalar(lhs, float(rhs), out=out) + elif isinstance(rhs, Symbol): + return fn_array(lhs, rhs, out=out) + else: + raise TypeError('type %s not supported' % str(type(rhs))) +#pylint: enable= too-many-arguments, no-member, protected-access + + +@use_np_compat +def maximum(x1, x2, out=None): + return _ufunc_helper(x1, x2, _internal._np_maximum, _np.maximum, + _internal._np_maximum_scalar, None, out) + + +@use_np_compat +def minimum(x1, x2, out=None): + return _ufunc_helper(x1, x2, _internal._np_minimum, _np.minimum, + _internal._np_minimum_scalar, None, out) + + _set_np_symbol_class(_NumpySymbol) diff --git a/python/mxnet/symbol/numpy/ext.py b/python/mxnet/symbol/numpy/ext.py new file mode 100644 index 000000000000..12c5f15cba55 --- /dev/null +++ b/python/mxnet/symbol/numpy/ext.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""numpy.ext namespace for operators used in Gluon APIs dispatched by F=symbol module.""" + +__all__ = [] diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py new file mode 100644 index 000000000000..b8f10b343430 --- /dev/null +++ b/python/mxnet/symbol/numpy/linalg.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""numpy.linalg namespace for operators used in Gluon APIs dispatched by F=symbol module.""" + +__all__ = [] diff --git a/python/mxnet/symbol/numpy/random.py b/python/mxnet/symbol/numpy/random.py new file mode 100644 index 000000000000..79c73d871dd8 --- /dev/null +++ b/python/mxnet/symbol/numpy/random.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""numpy.random namespace for operators used in Gluon APIs dispatched by F=symbol module.""" + +__all__ = [] diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index ab1f5f71da99..82fe28bf38bc 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -177,11 +177,7 @@ extern const std::vector kHiddenKeys; inline bool IsNumpyCompatOp(const nnvm::Op* op) { static const auto& is_np_compat = nnvm::Op::GetAttr("TIsNumpyCompatible"); - if (is_np_compat.get(op, false)) { - return true; - } - static const std::string prefix = "_numpy_"; - return op->name.find(prefix.c_str(), 0, prefix.size()) != std::string::npos; + return is_np_compat.get(op, false); } #endif // MXNET_C_API_C_API_COMMON_H_ diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index e8988c80455e..5d36c29fc331 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -161,6 +161,16 @@ Example:: .set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"}) .set_attr("TIsNumpyCompatible", true); +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_maximum) +.describe(R"code()code" ADD_FILELINE) +.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("TIsNumpyCompatible", true); + +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_np_minimum) +.describe(R"code()code" ADD_FILELINE) +.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("TIsNumpyCompatible", true); + MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseNone{"_copy"}); @@ -193,5 +203,13 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_rpower_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"}); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_maximum_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_maximum_scalar"}); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_np_minimum_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_minimum_scalar"}); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index 186bd1baac5b..26e2fceb839f 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -42,6 +42,12 @@ NNVM_REGISTER_OP(_np_mod) NNVM_REGISTER_OP(_np_power) .set_attr("FCompute", BinaryBroadcastCompute); +NNVM_REGISTER_OP(_np_maximum) +.set_attr("FCompute", BinaryBroadcastCompute); + +NNVM_REGISTER_OP(_np_minimum) +.set_attr("FCompute", BinaryBroadcastCompute); + NNVM_REGISTER_OP(_np_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); @@ -52,8 +58,7 @@ NNVM_REGISTER_OP(_np_rsubtract_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); NNVM_REGISTER_OP(_np_multiply_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FComputeEx", BinaryScalarOp::ComputeEx); +.set_attr("FCompute", BinaryScalarOp::Compute); NNVM_REGISTER_OP(_np_mod_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); @@ -67,5 +72,11 @@ NNVM_REGISTER_OP(_np_power_scalar) NNVM_REGISTER_OP(_np_rpower_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); +NNVM_REGISTER_OP(_np_maximum_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_np_minimum_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cc b/src/operator/numpy/np_elemwise_unary_op_basic.cc new file mode 100644 index 000000000000..f31ed5e11f15 --- /dev/null +++ b/src/operator/numpy/np_elemwise_unary_op_basic.cc @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_elemwise_unary_op_basic.cc + * \brief CPU Implementation of numpy elementwise unary function. + */ +#include +#include "../tensor/elemwise_unary_op.h" + +namespace mxnet { +namespace op { + +MXNET_OPERATOR_REGISTER_UNARY(_numpy__ext_relu) +.describe(R"code(Computes rectified linear activation. + +.. math:: + max(features, 0) + +)code" ADD_FILELINE) +.set_attr("FCompute", UnaryOp::Compute) +.set_attr("FGradient", ElemwiseGradUseOut{"_backward_relu"}) +.set_attr("TIsNumpyCompatible", true); + +MXNET_OPERATOR_REGISTER_UNARY(_numpy__ext_sigmoid) +.describe(R"code(Computes sigmoid of x element-wise. + +.. math:: + y = 1 / (1 + exp(-x)) + +)code" ADD_FILELINE) +.set_attr("FCompute", UnaryOp::Compute) +.set_attr("FGradient", ElemwiseGradUseOut{"_backward_sigmoid"}) +.set_attr("TIsNumpyCompatible", true); + +MXNET_OPERATOR_REGISTER_UNARY(_np_copy) +.MXNET_DESCRIBE("Returns a copy of the input.") +.set_attr("FCompute", UnaryOp::IdentityCompute) +.set_attr("FInplaceIdentity", + [](const NodeAttrs& attrs){ + return std::vector{true}; + }) +.set_attr("FGradient", ElemwiseGradUseNone{"_copy"}) +.set_attr("TIsNumpyCompatible", true); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_unary_op_basic.cu b/src/operator/numpy/np_elemwise_unary_op_basic.cu new file mode 100644 index 000000000000..9f108f75fc15 --- /dev/null +++ b/src/operator/numpy/np_elemwise_unary_op_basic.cu @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_elemwise_unary_op_basic.cu + * \brief GPU Implementation of numpy unary functions. + */ +#include "../tensor/elemwise_binary_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_numpy__ext_relu) +.set_attr("FCompute", UnaryOp::Compute); + +NNVM_REGISTER_OP(_numpy__ext_sigmoid) +.set_attr("FCompute", UnaryOp::Compute); + +NNVM_REGISTER_OP(_np_copy) +.set_attr("FCompute", UnaryOp::IdentityCompute); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h new file mode 100644 index 000000000000..44a6c909c9cf --- /dev/null +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_matrix_op-inl.h + * \brief Function definition of matrix related operators + */ +#ifndef MXNET_OPERATOR_NUMPY_NP_MATRIX_OP_INL_H_ +#define MXNET_OPERATOR_NUMPY_NP_MATRIX_OP_INL_H_ + +#include +#include "../tensor/matrix_op-inl.h" + +namespace mxnet { +namespace op { + +struct NumpyTransposeParam : public dmlc::Parameter { + mxnet::TShape axes; + DMLC_DECLARE_PARAMETER(NumpyTransposeParam) { + DMLC_DECLARE_FIELD(axes).set_default(mxnet::TShape(-1, 0)) + .describe("By default, reverse the dimensions, otherwise permute " + "the axes according to the values given."); + } +}; + +template +void NumpyTranspose(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const NumpyTransposeParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace"; + if (ndim_is_known(param.axes)) { + TransposeImpl(ctx.run_ctx, inputs[0], outputs[0], param.axes); + } else { + mxnet::TShape axes(inputs[0].ndim(), -1); + for (int i = 0; i < axes.ndim(); ++i) { + axes[i] = axes.ndim() - 1 - i; + } + TransposeImpl(ctx.run_ctx, inputs[0], outputs[0], axes); + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_MATRIX_OP_INL_H_ diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc new file mode 100644 index 000000000000..215b1c5a8c87 --- /dev/null +++ b/src/operator/numpy/np_matrix_op.cc @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_matrix_op.cc + * \brief CPU Implementation of numpy matrix operations + */ + +#include "./np_matrix_op-inl.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(NumpyTransposeParam); + +bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + const NumpyTransposeParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + mxnet::TShape& shp = (*in_attrs)[0]; + CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions"; + mxnet::TShape ret(shp.ndim(), -1); + if (ndim_is_known(param.axes)) { + CHECK_EQ(shp.ndim(), param.axes.ndim()); + for (int i = 0; i < shp.ndim(); ++i) { + CHECK(param.axes[i] < static_cast(shp.ndim())); + ret[i] = shp[param.axes[i]]; + } + } else { + for (int i = 0; i < shp.ndim(); ++i) { + ret[i] = shp[shp.ndim()-1-i]; + } + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret); + return shape_is_known(ret); +} + +NNVM_REGISTER_OP(_numpy_transpose) +.describe(R"code(Permute the dimensions of an array. + +Examples:: + + x = [[ 1, 2], + [ 3, 4]] + + transpose(x) = [[ 1., 3.], + [ 2., 4.]] + + x = [[[ 1., 2.], + [ 3., 4.]], + + [[ 5., 6.], + [ 7., 8.]]] + + transpose(x) = [[[ 1., 5.], + [ 3., 7.]], + + [[ 2., 6.], + [ 4., 8.]]] + + transpose(x, axes=(1,0,2)) = [[[ 1., 2.], + [ 5., 6.]], + + [[ 3., 4.], + [ 7., 8.]]] +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", NumpyTransposeShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FGradient", + [](const nnvm::NodePtr& n, const std::vector& ograds) { + const NumpyTransposeParam& param = nnvm::get(n->attrs.parsed); + if (ndim_is_known(param.axes)) { + mxnet::TShape axes = mxnet::TShape(param.axes.ndim(), -1); + for (int i = 0; i < axes.ndim(); ++i) { + axes[param.axes[i]] = i; + } + std::ostringstream os; + os << axes; + return MakeNonlossGradNode("transpose", n, ograds, {}, {{"axes", os.str()}}); + } else { + return MakeNonlossGradNode("transpose", n, ograds, {}, + std::unordered_map()); + } + }) +.set_attr("FCompute", NumpyTranspose) +.set_attr("TIsNumpyCompatible", true) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a"}; + }) +.add_argument("a", "NDArray-or-Symbol", "Source input") +.add_arguments(NumpyTransposeParam::__FIELDS__()); + +struct NumpyReshapeParam : public dmlc::Parameter { + mxnet::TShape newshape; + std::string order; + DMLC_DECLARE_PARAMETER(NumpyReshapeParam) { + DMLC_DECLARE_FIELD(newshape) + .describe("The new shape should be compatible with the original shape." + " If an integer, then the result will be a 1-D array of that length." + " One shape dimension can be -1. In this case, the value is inferred" + " from the length of the array and remaining dimensions."); + DMLC_DECLARE_FIELD(order) + .set_default("C") + .describe("Read the elements of a using this index order, and place the elements into" + " the reshaped array using this index order. 'C' means to read/write the elements" + " using C-like index order, with the last axis index changing fastest, back to the" + " first axis index changing slowest. Note that currently only C-like order is" + " supported"); + } +}; + +DMLC_REGISTER_PARAMETER(NumpyReshapeParam); + +bool NumpyReshapeInferShape(const mxnet::TShape& src, mxnet::TShape* dst) { + if (shape_is_known(src) && shape_is_known(*dst)) { + CHECK_EQ(src.Size(), dst->Size()) << "Cannot reshape array of size " + << src.Size() << " into shape " << *dst; + return true; + } else if (!shape_is_known(src) || !ndim_is_known(*dst)) { + return false; + } else { + int unknown_axis = -1; + dim_t known_dim_size_prod = 1; + for (int i = 0; i < dst->ndim(); ++i) { + if (!dim_size_is_known(*dst, i)) { + if (unknown_axis == -1) { + unknown_axis = i; + } else { + return false; // more than one unknown dim + } + } else { + known_dim_size_prod *= (*dst)[i]; + } + } + CHECK_NE(known_dim_size_prod, 0) << "Cannot reshape array of size " + << src.Size() << " into shape " << *dst; + CHECK_EQ(src.Size() % known_dim_size_prod, 0) << "Cannot reshape array of size " + << src.Size() << " into shape " << *dst; + (*dst)[unknown_axis] = src.Size() / known_dim_size_prod; + return true; + } +} + +bool NumpyReshapeShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]"; + CHECK_EQ(out_attrs->size(), 1U); + const NumpyReshapeParam& param = nnvm::get(attrs.parsed); + // sanity check + bool has_unknown_dim_size = false; + for (int i = 0; i < param.newshape.ndim(); ++i) { + if (param.newshape[i] < 0) { + CHECK_EQ(param.newshape[i], -1) << "The shape dimension size to inferred must be -1"; + CHECK(!has_unknown_dim_size) << "Can only specify one unknown dimension"; + has_unknown_dim_size = true; + } + } + + mxnet::TShape target_shape = param.newshape; + bool success = NumpyReshapeInferShape(in_attrs->at(0), &target_shape); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, target_shape); + if (!success) { + success = NumpyReshapeInferShape(out_attrs->at(0), &in_attrs->at(0)); + } + return success; +} + +NNVM_REGISTER_OP(_numpy_reshape) +.describe(R"code()code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", NumpyReshapeShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_reshape"}) +.set_attr("FCompute", UnaryOp::IdentityCompute) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) +.set_attr("FInplaceIdentity", + [](const NodeAttrs& attrs){ + return std::vector{true}; + }) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a"}; + }) +.set_attr("TIsNumpyCompatible", true) +.add_argument("a", "NDArray-or-Symbol", "Array to be reshaped.") +.add_arguments(NumpyReshapeParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu new file mode 100644 index 000000000000..9753566aebe9 --- /dev/null +++ b/src/operator/numpy/np_matrix_op.cu @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_matrix_op.cu + * \brief GPU Implementation of numpy matrix operations + */ +#include "./np_matrix_op-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_numpy_transpose) +.set_attr("FCompute", NumpyTranspose); + +NNVM_REGISTER_OP(_numpy_reshape) +.set_attr("FCompute", UnaryOp::IdentityCompute); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index f84767dd4b2f..8a81bbc1c475 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -292,6 +292,7 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, const std::vector& inputs, const std::vector& req, const std::vector& outputs) { + if (outputs[0].shape_.Size() == 0U) return; mxnet::TShape new_lshape, new_rshape, new_oshape; int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, &new_lshape, &new_rshape, &new_oshape); diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 5cd7bf6652d3..4e13354a4be6 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -265,11 +265,17 @@ void TransposeImpl(RunContext ctx, using namespace mshadow; using namespace mshadow::expr; CHECK_EQ(src.type_flag_, ret.type_flag_); + // zero-size tensor, no need to compute + if (src.shape_.Size() == 0U) return; Stream *s = ctx.get_stream(); MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, { switch (axes.ndim()) { - case 0: + case 0: { + Tensor in = src.get_with_shape(mshadow::Shape1(1), s); + Tensor out = ret.get_with_shape(mshadow::Shape1(1), s); + Copy(out, in, s); break; + } case 1: { Tensor in = src.get(s); Tensor out = ret.get(s); diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 88e56acd04b8..141d153a207f 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -24,7 +24,6 @@ from mxnet.gluon import HybridBlock from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, assert_exception from common import with_seed -import random @with_seed() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 024c893880e7..8c13227584ae 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -192,6 +192,126 @@ def is_int(dtype): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) +@with_seed() +@mx.use_np_compat +def test_np_transpose(): + # TODO(junwu): Add more test cases + data = mx.sym.var('a') + ret = mx.sym.np.transpose(data) + assert type(ret) == mx.sym.np._NumpySymbol + + dtypes = ['float32', 'int32'] + for dtype in dtypes: + for ndim in [0, 1, 2, 3, 4, 5, 6]: + shape = rand_shape_nd(ndim, dim=5, allow_zero_size=True) + np_data = _np.random.uniform(low=-100, high=100, size=shape).astype(dtype) + mx_data = np.array(np_data, dtype=dtype) + axes = [None] + if ndim == 0: + axes += [()] + else: + axis = [i for i in range(ndim)] + axes.append(tuple(axis)) + random.shuffle(axis) + axes.append(tuple(axis)) + for axis in axes: + np_out = _np.transpose(np_data, axes=axis) + mx_out = np.transpose(mx_data, axes=axis) + assert np_out.dtype == mx_out.dtype + assert same(mx_out.asnumpy(), np_out) + # TODO(junwu): Add numerical gradient test and Gluon API test. + + +@with_seed() +@mx.use_np_compat +def test_relu(): + # TODO(junwu): Add more test cases + data = mx.sym.var('data') + ret = mx.sym.np.ext.relu(data) + assert type(ret) == mx.sym.np._NumpySymbol + + shapes = [(), (0, 2, 0)] + shapes.extend([rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)]) + for shape in shapes: + data = np.array(_np.random.uniform(size=shape).astype('float32')) + ret = np.ext.relu(data) + assert type(ret) == np.ndarray + + +@with_seed() +@mx.use_np_compat +def test_sigmoid(): + # TODO(junwu): Add more test cases + data = mx.sym.var('data') + ret = mx.sym.np.ext.sigmoid(data) + assert type(ret) == mx.sym.np._NumpySymbol + + shapes = [(), (0, 2, 0)] + shapes.extend([rand_shape_nd(ndim, allow_zero_size=True) for ndim in range(5)]) + for shape in shapes: + data = np.array(_np.random.uniform(size=shape).astype('float32')) + ret = np.ext.sigmoid(data) + assert type(ret) == np.ndarray + + +@with_seed() +@mx.use_np_compat +def test_np_reshape(): + # TODO(junwu): Add more test cases + data = mx.sym.var('a') + ret = mx.sym.np.reshape(data, newshape=()) + assert type(ret) == mx.sym.np._NumpySymbol + + data = np.ones((1, 1, 1)) + ret = np.reshape(data, ()) + assert ret.shape == () + ret = np.reshape(ret, (1, 1, 1, 1)) + assert ret.shape == (1, 1, 1, 1) + assert type(ret) == np.ndarray + + +@with_seed() +@mx.use_np_compat +def test_np_maximum(): + # TODO(junwu): Add more test cases + x1, x2 = mx.sym.var('x1'), mx.sym.var('x2') + ret = mx.sym.np.maximum(x1, x2) + assert type(ret) == mx.sym.np._NumpySymbol + + def check_maximum(x1, x2): + mx_out = np.maximum(x1, x2) + if isinstance(x1, np.ndarray) or isinstance(x2, np.ndarray): + assert type(mx_out) == np.ndarray + np_out = _np.maximum(x1.asnumpy() if isinstance(x1, np.ndarray) else x1, + x2.asnumpy() if isinstance(x2, np.ndarray) else x2) + assert same(mx_out.asnumpy() if isinstance(mx_out, np.ndarray) else mx_out, np_out) + + check_maximum(np.zeros((2, 1)), np.ones((5, 1, 4))) + check_maximum(np.zeros((2, 0)), np.ones((5, 1, 1))) + check_maximum(np.zeros(()), np.ones((5, 1, 4))) + + +@with_seed() +@mx.use_np_compat +def test_np_minimum(): + # TODO(junwu): Add more test cases + x1, x2 = mx.sym.var('x1'), mx.sym.var('x2') + ret = mx.sym.np.minimum(x1, x2) + assert type(ret) == mx.sym.np._NumpySymbol + + def check_minimum(x1, x2): + mx_out = np.minimum(x1, x2) + if isinstance(x1, np.ndarray) or isinstance(x2, np.ndarray): + assert type(mx_out) == np.ndarray + np_out = _np.minimum(x1.asnumpy() if isinstance(x1, np.ndarray) else x1, + x2.asnumpy() if isinstance(x2, np.ndarray) else x2) + assert same(mx_out.asnumpy() if isinstance(mx_out, np.ndarray) else mx_out, np_out) + + check_minimum(np.zeros((2, 1)), np.ones((5, 1, 4))) + check_minimum(np.zeros((2, 0)), np.ones((5, 1, 1))) + check_minimum(np.zeros(()), np.ones((5, 1, 4))) + + if __name__ == '__main__': import nose nose.runmodule()