Skip to content

Commit

Permalink
[numpy] Refactor np modules (apache#14989)
Browse files Browse the repository at this point in the history
* Refactor

* Initial refactoring

* Fix notebook

* Move numpy op check from backend to frontend

* Add homogeneous ndarray check

* Fix grouping inhomogeneous types of symbols

* Improve error handling of different types of symbols as outputs

* Fix test

* Fix numpy test

* Fix ci

* Try to fix gpu ci failure
  • Loading branch information
reminisce authored and Ying committed Jul 18, 2019
1 parent 1c6c212 commit 04ea773
Show file tree
Hide file tree
Showing 16 changed files with 103 additions and 93 deletions.
9 changes: 0 additions & 9 deletions include/mxnet/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,15 +319,6 @@ using FNeedRequantize = std::function<bool (const NodeAttrs& attrs)>;
using FAvoidQuantizeInput = std::function<bool (const NodeAttrs& attrs,
size_t index)>;

/*!
* \brief Indicates whether this operator is NumPy compatible.
* It is for distinguishing the operator from classic MXNet operators
* which do not support zero-dim and zero-size tensors.
* In Python, it is used to determine whether to output numpy ndarrays
* or symbols that are NumPy compatible.
*/
using TIsNumpyCompatible = bool;

} // namespace mxnet

#endif // MXNET_OP_ATTR_TYPES_H_
3 changes: 1 addition & 2 deletions python/mxnet/_ctypes/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ..base import _LIB
from ..base import c_str_array, c_handle_array
from ..base import NDArrayHandle, CachedOpHandle
from ..base import check_call, _is_np_compat_op
from ..base import check_call


class NDArrayBase(object):
Expand Down Expand Up @@ -102,7 +102,6 @@ def _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op):
create_ndarray_fn = _np_ndarray_cls if is_np_op else _ndarray_cls
if original_output is not None:
return original_output
create_ndarray_fn = _np_ndarray_cls if _is_np_compat_op(handle) else _ndarray_cls
if num_output.value == 1:
return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle),
stype=out_stypes[0])
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/_ctypes/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import ctypes
from ..base import _LIB
from ..base import c_str_array, c_handle_array, c_str, mx_uint, _is_np_compat_op
from ..base import c_str_array, c_handle_array, c_str, mx_uint
from ..base import SymbolHandle
from ..base import check_call

Expand Down
1 change: 1 addition & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,7 @@ def write_all_str(module_file, module_all_list):

_NP_INTERNAL_OP_PREFIX = '_npi_'

_NP_EXT_OP_PREFIX = '_npe_'

def _is_np_op(op_name):
return op_name.startswith(_NP_OP_PREFIX) or op_name.startswith(_NP_EXT_OP_PREFIX)\
Expand Down
1 change: 0 additions & 1 deletion python/mxnet/gluon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,6 @@ def shape_is_known(shape):
"received {}".format(unknown_dim_size, dim_size)
return True


def _check_same_symbol_type(symbols):
"""Check whether all the symbols in the list are of the same type.
Raise type error if the types are different. Return the class of
Expand Down
36 changes: 0 additions & 36 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,6 @@ def __iadd__(self, other):
raise TypeError('type %s not supported' % str(type(other)))

def __radd__(self, other):
if isinstance(other, NDArray) and other._is_np_compat():
return other.__add__(self)
return self.__add__(other)

def __sub__(self, other):
Expand All @@ -262,14 +260,10 @@ def __isub__(self, other):

def __rsub__(self, other):
"""x.__rsub__(y) <=> y-x <=> mx.nd.subtract(y, x) """
if isinstance(other, NDArray) and other._is_np_compat():
return other.__sub__(self)
return subtract(other, self)

def __mul__(self, other):
"""x.__mul__(y) <=> x*y <=> mx.nd.multiply(x, y) """
if isinstance(other, NDArray) and other._is_np_compat():
return other.__mul__(self)
return multiply(self, other)

def __neg__(self):
Expand All @@ -288,20 +282,14 @@ def __imul__(self, other):
raise TypeError('type %s not supported' % str(type(other)))

def __rmul__(self, other):
if isinstance(other, NDArray) and other._is_np_compat():
return other.__mul__(self)
return self.__mul__(other)

def __div__(self, other):
"""x.__div__(y) <=> x/y <=> mx.nd.divide(x, y) """
if isinstance(other, NDArray) and other._is_np_compat():
return other.__rtruediv__(self)
return divide(self, other)

def __rdiv__(self, other):
"""x.__rdiv__(y) <=> y/x <=> mx.nd.divide(y, x) """
if isinstance(other, NDArray) and other._is_np_compat():
return other.__truediv__(self)
return divide(other, self)

def __idiv__(self, other):
Expand All @@ -316,28 +304,20 @@ def __idiv__(self, other):
raise TypeError('type %s not supported' % str(type(other)))

def __truediv__(self, other):
if isinstance(other, NDArray) and other._is_np_compat():
return other.__rtruediv__(self)
return divide(self, other)

def __rtruediv__(self, other):
if isinstance(other, NDArray) and other._is_np_compat():
return other.__truediv__(self)
return divide(other, self)

def __itruediv__(self, other):
return self.__idiv__(other)

def __mod__(self, other):
"""x.__mod__(y) <=> x%y <=> mx.nd.modulo(x, y) """
if isinstance(other, NDArray) and other._is_np_compat():
return other.__rmod__(self)
return modulo(self, other)

def __rmod__(self, other):
"""x.__rmod__(y) <=> y%x <=> mx.nd.modulo(y, x) """
if isinstance(other, NDArray) and other._is_np_compat():
return other.__mod__(self)
return modulo(other, self)

def __imod__(self, other):
Expand All @@ -353,20 +333,14 @@ def __imod__(self, other):

def __pow__(self, other):
"""x.__pow__(y) <=> x**y <=> mx.nd.power(x,y) """
if isinstance(other, NDArray) and other._is_np_compat():
return other.__rpow__(self)
return power(self, other)

def __rpow__(self, other):
"""x.__pow__(y) <=> y**x <=> mx.nd.power(y,x) """
if isinstance(other, NDArray) and other._is_np_compat():
return other.__pow__(self)
return power(other, self)

def __eq__(self, other):
"""x.__eq__(y) <=> x==y <=> mx.nd.equal(x, y) """
if isinstance(other, NDArray) and other._is_np_compat():
return other.__eq__(self)
return equal(self, other)

def __hash__(self):
Expand All @@ -375,32 +349,22 @@ def __hash__(self):

def __ne__(self, other):
"""x.__ne__(y) <=> x!=y <=> mx.nd.not_equal(x, y) """
if isinstance(other, NDArray) and other._is_np_compat():
return other.__ne__(self)
return not_equal(self, other)

def __gt__(self, other):
"""x.__gt__(y) <=> x>y <=> mx.nd.greater(x, y) """
if isinstance(other, NDArray) and other._is_np_compat():
return other.__lt__(self)
return greater(self, other)

def __ge__(self, other):
"""x.__ge__(y) <=> x>=y <=> mx.nd.greater_equal(x, y) """
if isinstance(other, NDArray) and other._is_np_compat():
return other.__le__(self)
return greater_equal(self, other)

def __lt__(self, other):
"""x.__lt__(y) <=> x<y <=> mx.nd.lesser(x, y) """
if isinstance(other, NDArray) and other._is_np_compat():
return other.__gt__(self)
return lesser(self, other)

def __le__(self, other):
"""x.__le__(y) <=> x<=y <=> mx.nd.less_equal(x, y) """
if isinstance(other, NDArray) and other._is_np_compat():
return other.__ge__(self)
return lesser_equal(self, other)

def __bool__(self):
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Namespace for registering numpy_extension ops for imperative programming."""

__all__ = []
20 changes: 20 additions & 0 deletions python/mxnet/symbol/numpy/_internal.py~HEAD
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Namespace for numpy internal ops."""

__all__ = []
20 changes: 20 additions & 0 deletions python/mxnet/symbol/numpy/_internal.py~HEAD_0
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Namespace for numpy internal ops."""

__all__ = []
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Namespace for numpy internal ops."""

__all__ = []
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Namespace for numpy internal ops."""

__all__ = []
Loading

0 comments on commit 04ea773

Please sign in to comment.