Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[numpy] Some np ops for d2l (#14924)
Browse files Browse the repository at this point in the history
* Add np transpose

More ops and namespaces for submodules

Add relu and sigmoid

Add reshape

Fix symbolic name mismatch

Add maximum and minimum

* Add convenience fluent method

* Add ndarray.item()

* Fix CI

* Fix lint

* Fix lint

* Fix reshape gpu

* Add example

* Remove python notebook outputs

* Remove notebook output

* Add one more example
  • Loading branch information
reminisce authored and haojin2 committed Jul 26, 2019
1 parent 8251b76 commit 3e1929a
Show file tree
Hide file tree
Showing 30 changed files with 1,428 additions and 44 deletions.
415 changes: 415 additions & 0 deletions example/numpy/demo.ipynb

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions include/mxnet/tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,14 @@ class Tuple {
is.get();
if (ch == '(' || ch == '[') break;
if (!isspace(ch)) {
if (ch == 'N') {
std::string tmp_val;
is >> tmp_val;
if (tmp_val == "one") { // is stores "None"
t.SetDim(-1);
return is;
}
}
is.setstate(std::ios::failbit);
return is;
}
Expand Down
9 changes: 4 additions & 5 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def _sanity_check_params(func_name, unsupported_params, param_dict):
.format(func_name, param_name))


_NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_']
_NP_OP_SUBMODULE_LIST = ['_ext_', '_random_', '_linalg_']
_NP_OP_PREFIX = '_numpy_'


Expand Down Expand Up @@ -798,10 +798,9 @@ def _init_np_op_module(root_namespace, module_name, make_op_func):
submodule_pattern = "%s.%s.numpy.%s"
module_np_op = sys.modules[module_pattern % (root_namespace, module_name)]
submodule_dict = {}
# TODO(junwu): uncomment the following lines when adding numpy ops in submodules, e.g. np.random
# for submodule_name in _NP_OP_SUBMODULE_LIST:
# submodule_dict[submodule_name] = \
# sys.modules[submodule_pattern % (root_namespace, module_name, submodule_name[1:-1])]
for submodule_name in _NP_OP_SUBMODULE_LIST:
submodule_dict[submodule_name] = \
sys.modules[submodule_pattern % (root_namespace, module_name, submodule_name[1:-1])]
for name in op_names:
hdl = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
Expand Down
3 changes: 3 additions & 0 deletions python/mxnet/ndarray/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

"""numpy module for numpy ops under mxnet.ndarray."""

from . import ext
from . import random
from . import linalg
from . import _op
from . import _register
from ._op import * # pylint: disable=wildcard-import
Expand Down
90 changes: 88 additions & 2 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@

from __future__ import absolute_import
import numpy as _np
from ...base import _sanity_check_params, use_np_compat
from ...base import _sanity_check_params, use_np_compat, numeric_types
from ...context import current_context
from .. import _internal
from ..ndarray import NDArray

__all__ = ['zeros', 'ones']
__all__ = ['zeros', 'ones', 'maximum', 'minimum']


@use_np_compat
Expand Down Expand Up @@ -86,3 +87,88 @@ def ones(shape, dtype=None, **kwargs):
ctx = current_context()
dtype = _np.float32 if dtype is None else dtype
return _internal._np_ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs)


#pylint: disable= too-many-arguments, no-member, protected-access
def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None):
""" Helper function for element-wise operation.
The function will perform numpy-like broadcasting if needed and call different functions.
Parameters
--------
lhs : NDArray or numeric value
Left-hand side operand.
rhs : NDArray or numeric value
Right-hand operand,
fn_array : function
Function to be called if both lhs and rhs are of ``NDArray`` type.
fn_scalar : function
Function to be called if both lhs and rhs are numeric values.
lfn_scalar : function
Function to be called if lhs is ``NDArray`` while rhs is numeric value
rfn_scalar : function
Function to be called if lhs is numeric value while rhs is ``NDArray``;
if none is provided, then the function is commutative, so rfn_scalar is equal to lfn_scalar
Returns
--------
mxnet.numpy.ndarray
result array
"""
if isinstance(lhs, numeric_types):
if isinstance(rhs, numeric_types):
return fn_scalar(lhs, rhs, out=out)
else:
if rfn_scalar is None:
# commutative function
return lfn_scalar(rhs, float(lhs), out=out)
else:
return rfn_scalar(rhs, float(lhs), out=out)
elif isinstance(rhs, numeric_types):
return lfn_scalar(lhs, float(rhs), out=out)
elif isinstance(rhs, NDArray):
return fn_array(lhs, rhs, out=out)
else:
raise TypeError('type %s not supported' % str(type(rhs)))
#pylint: enable= too-many-arguments, no-member, protected-access


@use_np_compat
def maximum(x1, x2, out=None):
"""Returns element-wise maximum of the input arrays with broadcasting.
Parameters
----------
x1, x2 : scalar or mxnet.numpy.ndarray
The arrays holding the elements to be compared. They must have the same shape,
or shapes that can be broadcast to a single shape.
Returns
-------
out : mxnet.numpy.ndarray or scalar
The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
return _ufunc_helper(x1, x2, _internal._np_maximum, _np.maximum,
_internal._np_maximum_scalar, None, out)


@use_np_compat
def minimum(x1, x2, out=None):
"""Returns element-wise minimum of the input arrays with broadcasting.
Parameters
----------
x1, x2 : scalar or mxnet.numpy.ndarray
The arrays holding the elements to be compared. They must have the same shape,
or shapes that can be broadcast to a single shape.
Returns
-------
out : mxnet.numpy.ndarray or scalar
The minimum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars."""
return _ufunc_helper(x1, x2, _internal._np_minimum, _np.minimum,
_internal._np_minimum_scalar, None, out)
20 changes: 20 additions & 0 deletions python/mxnet/ndarray/numpy/ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""numpy.ext namespace for operators used in Gluon APIs dispatched by F=ndarray module."""

__all__ = []
20 changes: 20 additions & 0 deletions python/mxnet/ndarray/numpy/linalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""numpy.linalg namespace for operators used in Gluon APIs dispatched by F=symbol module."""

__all__ = []
20 changes: 20 additions & 0 deletions python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""numpy.random namespace for operators used in Gluon APIs dispatched by F=ndarray module."""

__all__ = []
5 changes: 3 additions & 2 deletions python/mxnet/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
"""numpy module for imperative programming."""

from __future__ import absolute_import
from .multiarray import * # pylint: disable=wildcard-import
from . import _op
from . import random
from . import linalg
from . import ext
from .multiarray import * # pylint: disable=wildcard-import
from . import _op
from . import _register
from ._op import * # pylint: disable=wildcard-import

Expand Down
20 changes: 20 additions & 0 deletions python/mxnet/numpy/ext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""namespace for registering numpy.ext ops for imperative programming."""

__all__ = []
2 changes: 1 addition & 1 deletion python/mxnet/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.

"""namespace for registering numpy ops of linear algebra."""
"""namespace for registering numpy.linalg ops for imperative programming."""

__all__ = []
Loading

0 comments on commit 3e1929a

Please sign in to comment.