Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[numpy] Infra for supporting numpy ops in imperative mode and Gluon A…
Browse files Browse the repository at this point in the history
…PIs (#14758)

* Infra of new ndarray and symbol types for numpy operators

* Rename

* Fix import problem

* Refactor

* Remove redundant code

* Add docstring

* More on numpy ndarray and symbol

* Override unimplemented methods for ndarray and _NumpySymbol

* Fix built-in methods of ndarray and _NumpySymbol

* Fix test and sanity check

* Fix pylint

* Address cr comments

* Add unit tests for ndarray and _NumpySymbol

* Add _true_divide

* Fix gpu build

* Add future import division

* More correct way of checking if an output is from a np compat op

* Fix gpu build

* Fix output ndarray/symbol types with at least one new ndarray/symbol

* Modify true_divide doc

* Fix flaky copying zero-size arrays via gpus

* Fix zero size in gluon hybridize and zeros/ones symbol not creating new symbol type

* Fix doc
  • Loading branch information
reminisce authored and haojin2 committed Jul 24, 2019
1 parent dd449b5 commit bfbdd4d
Show file tree
Hide file tree
Showing 42 changed files with 3,689 additions and 59 deletions.
29 changes: 29 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2902,6 +2902,35 @@ MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
EngineVarHandle mutable_vars_handle, int num_mutable_vars,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));
/*!
 * \brief Determines if an op is a Numpy op by its name prefix.
 * Every Numpy op starts with the prefix string "_numpy_".
 * \param creator Operator handle
 * \param is_np_op Indicator of whether creator is a numpy op handle
 */
MXNET_DLL int MXIsNumpyCompatOp(AtomicSymbolCreator creator,
                                int* is_np_op);
/*!
 * \brief Create an NDArray from source sharing the same data chunk.
 * \param src source NDArray
 * \param out new NDArray sharing the same data chunk with src
 */
MXNET_DLL int MXShallowCopyNDArray(NDArrayHandle src, NDArrayHandle* out);
/*!
 * \brief Create a Symbol from source sharing the same graph structure.
 * \param src source Symbol
 * \param out new Symbol sharing the same graph structure with src
 */
MXNET_DLL int MXShallowCopySymbol(SymbolHandle src, SymbolHandle * out);
/*!
 * \brief Checks if an output of CachedOp is from a numpy op.
 * \param handle CachedOp shared ptr
 * \param output_idx index of the output of the CachedOp
 * \param is_from_np_op indicator of whether the output is from a numpy op
 */
MXNET_DLL int MXIsCachedOpOutputFromNumpyCompatOp(CachedOpHandle handle,
                                                  int output_idx,
                                                  int* is_from_np_op);

/*!
* \brief Push an asynchronous operation to the engine.
Expand Down
9 changes: 9 additions & 0 deletions include/mxnet/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,15 @@ using FNeedRequantize = std::function<bool (const NodeAttrs& attrs)>;
using FAvoidQuantizeInput = std::function<bool (const NodeAttrs& attrs,
size_t index)>;

/*!
* \brief Indicates whether this operator is NumPy compatible.
* It is for distinguishing the operator from classic MXNet operators
* which do not support zero-dim and zero-size tensors.
* In Python, it is used to determine whether to output numpy ndarrays
* or symbols that are NumPy compatible.
*/
using TIsNumpyCompatible = bool;

} // namespace mxnet

#endif // MXNET_OP_ATTR_TYPES_H_
2 changes: 1 addition & 1 deletion python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
from .base import MXNetError
from .util import is_np_shape, set_np_shape, np_shape, use_np_shape
from . import base
from . import numpy
from . import contrib
from . import ndarray
from . import ndarray as nd
from . import numpy
from . import name
# use mx.sym as short for symbol
from . import symbol as sym
Expand Down
38 changes: 28 additions & 10 deletions python/mxnet/_ctypes/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ..base import _LIB
from ..base import c_str_array, c_handle_array
from ..base import NDArrayHandle, CachedOpHandle
from ..base import check_call
from ..base import check_call, _is_np_compat_op


class NDArrayBase(object):
Expand Down Expand Up @@ -55,13 +55,21 @@ def __reduce__(self):


_ndarray_cls = None
_np_ndarray_cls = None


def _set_ndarray_class(cls):
    """Set the class used to wrap NDArray handles returned by classic (non-numpy) ops."""
    global _ndarray_cls
    _ndarray_cls = cls


def _set_np_ndarray_class(cls):
    """Set the class used to wrap NDArray handles returned by numpy-compatible ops."""
    global _np_ndarray_cls
    _np_ndarray_cls = cls


def _imperative_invoke(handle, ndargs, keys, vals, out):
"""ctypes implementation of imperative invoke wrapper"""
if out is not None:
Expand Down Expand Up @@ -93,18 +101,19 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):

if original_output is not None:
return original_output
create_ndarray_fn = _np_ndarray_cls if _is_np_compat_op(handle) else _ndarray_cls
if num_output.value == 1:
return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle),
stype=out_stypes[0])
return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle),
stype=out_stypes[0])
else:
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle),
stype=out_stypes[i])
for i in range(num_output.value)]
return [create_ndarray_fn(ctypes.cast(output_vars[i], NDArrayHandle),
stype=out_stypes[i]) for i in range(num_output.value)]


class CachedOp(object):
"""Cached operator handle."""
__slots__ = ["handle"]

def __init__(self, sym, flags=()):
self.handle = CachedOpHandle()

Expand All @@ -118,6 +127,13 @@ def __init__(self, sym, flags=()):
def __del__(self):
    """Release the backend CachedOp handle when this wrapper is garbage-collected."""
    check_call(_LIB.MXFreeCachedOp(self.handle))

def _is_from_np_compat_op(self, idx):
    """Check if the CachedOp's idx-th output is directly from a numpy op."""
    flag = ctypes.c_int(0)
    check_call(_LIB.MXIsCachedOpOutputFromNumpyCompatOp(
        self.handle, ctypes.c_int(idx), ctypes.byref(flag)))
    return bool(flag.value)

def __call__(self, *args, **kwargs):
"""ctypes implementation of imperative invoke wrapper"""
out = kwargs.pop('out', None)
Expand Down Expand Up @@ -152,9 +168,11 @@ def __call__(self, *args, **kwargs):
if original_output is not None:
return original_output
if num_output.value == 1:
return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle),
stype=out_stypes[0])
create_ndarray_fn = _np_ndarray_cls if self._is_from_np_compat_op(0) else _ndarray_cls
return create_ndarray_fn(ctypes.cast(output_vars[0], NDArrayHandle),
stype=out_stypes[0])
else:
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle),
stype=out_stypes[i])
return [_np_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), stype=out_stypes[i])
if self._is_from_np_compat_op(i) else
_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), stype=out_stypes[i])
for i in range(num_output.value)]
14 changes: 12 additions & 2 deletions python/mxnet/_ctypes/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@

import ctypes
from ..base import _LIB
from ..base import c_str_array, c_handle_array, c_str, mx_uint
from ..base import c_str_array, c_handle_array, c_str, mx_uint, _is_np_compat_op
from ..base import SymbolHandle
from ..base import check_call

_symbol_cls = None
_np_symbol_cls = None

class SymbolBase(object):
"""Symbol is symbolic graph."""
Expand Down Expand Up @@ -115,6 +116,12 @@ def _set_symbol_class(cls):
_symbol_cls = cls


def _set_np_symbol_class(cls):
    """Set the class used to wrap Symbol handles produced by numpy-compatible ops."""
    global _np_symbol_cls
    _np_symbol_cls = cls


def _symbol_creator(handle, args, kwargs, keys, vals, name):
sym_handle = SymbolHandle()
check_call(_LIB.MXSymbolCreateAtomicSymbol(
Expand All @@ -128,7 +135,10 @@ def _symbol_creator(handle, args, kwargs, keys, vals, name):
raise TypeError(
'Operators with variable length input can only accept input'
'Symbols either as positional or keyword arguments, not both')
s = _symbol_cls(sym_handle)
if _is_np_compat_op(handle):
s = _np_symbol_cls(sym_handle)
else:
s = _symbol_cls(sym_handle)
if args:
s._compose(*args, name=name)
elif kwargs:
Expand Down
102 changes: 82 additions & 20 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ def _as_list(obj):
return [obj]


_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_', '_numpy_']
_OP_NAME_PREFIX_LIST = ['_contrib_', '_linalg_', '_sparse_', '_image_', '_random_']


def _get_op_name_prefix(op_name):
Expand Down Expand Up @@ -607,15 +607,6 @@ def _init_op_module(root_namespace, module_name, make_op_func):
# use mx.nd.contrib or mx.sym.contrib from now on
contrib_module_name_old = "%s.contrib.%s" % (root_namespace, module_name)
contrib_module_old = sys.modules[contrib_module_name_old]
# special handling of registering numpy ops
# only expose mxnet.numpy.op_name to users for imperative mode.
# Symbolic mode should be used in Gluon.
if module_name == 'ndarray':
numpy_module_name = "%s.numpy" % root_namespace
numpy_module = sys.modules[numpy_module_name]
else:
numpy_module_name = None
numpy_module = None
submodule_dict = {}
for op_name_prefix in _OP_NAME_PREFIX_LIST:
submodule_dict[op_name_prefix] =\
Expand Down Expand Up @@ -654,16 +645,6 @@ def _init_op_module(root_namespace, module_name, make_op_func):
function.__module__ = contrib_module_name_old
setattr(contrib_module_old, function.__name__, function)
contrib_module_old.__all__.append(function.__name__)
elif op_name_prefix == '_numpy_' and numpy_module_name is not None:
# only register numpy ops under mxnet.numpy in imperative mode
hdl = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
# TODO(reminisce): Didn't consider third level module here, e.g. mxnet.numpy.random.
func_name = name[len(op_name_prefix):]
function = make_op_func(hdl, name, func_name)
function.__module__ = numpy_module_name
setattr(numpy_module, function.__name__, function)
numpy_module.__all__.append(function.__name__)


def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func):
Expand Down Expand Up @@ -754,7 +735,88 @@ def write_all_str(module_file, module_all_list):
ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p


from .runtime import Features
if Features().is_enabled("TVM_OP"):
_LIB_TVM_OP = libinfo.find_lib_path("libtvmop")
check_call(_LIB.MXLoadTVMOp(c_str(_LIB_TVM_OP[0])))


def _sanity_check_params(func_name, unsupported_params, param_dict):
for param_name in unsupported_params:
if param_name in param_dict:
raise NotImplementedError("function {} does not support parameter {}"
.format(func_name, param_name))


_NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_']
_NP_OP_PREFIX = '_numpy_'


def _get_np_op_submodule_name(op_name):
assert op_name.startswith(_NP_OP_PREFIX)
for name in _NP_OP_SUBMODULE_LIST:
if op_name[len(_NP_OP_PREFIX):].startswith(name):
return name
return ""


def _init_np_op_module(root_namespace, module_name, make_op_func):
    """
    Register numpy operators in namespaces `mxnet.numpy`, `mxnet.ndarray.numpy`
    and `mxnet.symbol.numpy`. They are used in imperative mode, Gluon APIs w/o hybridization,
    and Gluon APIs w/ hybridization, respectively. Essentially, operators with the same name
    registered in three namespaces, respectively share the same functionality in C++ backend.
    Different namespaces are needed for dispatching operator calls in Gluon's `HybridBlock` by `F`.

    Parameters
    ----------
    root_namespace : str
        Top level module name, `mxnet` in the current cases.
    module_name : str
        Second level module name, `ndarray` or `symbol` in the current case.
    make_op_func : function
        Function for creating op functions.
    """
    plist = ctypes.POINTER(ctypes.c_char_p)()
    size = ctypes.c_uint()

    # Ask the backend for every registered op name, then keep only numpy ops
    # (identified by the '_numpy_' name prefix).
    check_call(_LIB.MXListAllOpNames(ctypes.byref(size), ctypes.byref(plist)))
    op_names = []
    for i in range(size.value):
        name = py_str(plist[i])
        if name.startswith(_NP_OP_PREFIX):
            op_names.append(name)

    if module_name == 'numpy':
        # register ops for mxnet.numpy
        module_pattern = "%s.%s._op"
        submodule_pattern = "%s.%s.%s"
    else:
        # register ops for mxnet.ndarray.numpy or mxnet.symbol.numpy
        module_pattern = "%s.%s.numpy._op"
        submodule_pattern = "%s.%s.numpy.%s"
    module_np_op = sys.modules[module_pattern % (root_namespace, module_name)]
    submodule_dict = {}
    # TODO(junwu): uncomment the following lines when adding numpy ops in submodules, e.g. np.random
    # for submodule_name in _NP_OP_SUBMODULE_LIST:
    #     submodule_dict[submodule_name] = \
    #         sys.modules[submodule_pattern % (root_namespace, module_name, submodule_name[1:-1])]
    for name in op_names:
        hdl = OpHandle()
        check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
        submodule_name = _get_np_op_submodule_name(name)
        module_name_local = module_name
        if len(submodule_name) > 0:
            # op belongs to a submodule (e.g. '_numpy__random_*'):
            # strip both the numpy prefix and the submodule prefix.
            func_name = name[(len(_NP_OP_PREFIX) + len(submodule_name)):]
            cur_module = submodule_dict[submodule_name]
            module_name_local = submodule_pattern % (root_namespace,
                                                     module_name, submodule_name[1:-1])
        else:
            # top-level numpy op: strip only the numpy prefix
            func_name = name[len(_NP_OP_PREFIX):]
            cur_module = module_np_op

        # Generate the Python wrapper for the op and expose it in the target module.
        function = make_op_func(hdl, name, func_name)
        function.__module__ = module_name_local
        setattr(cur_module, function.__name__, function)
        cur_module.__all__.append(function.__name__)
9 changes: 7 additions & 2 deletions python/mxnet/gluon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .. import name as _name
from .parameter import Parameter, ParameterDict, DeferredInitializationError
from .utils import _indent, _brief_print_list, HookHandle
from .. import numpy as _mx_np


class _BlockScope(object):
Expand Down Expand Up @@ -739,9 +740,13 @@ def _get_graph(self, *args):
if not self._cached_graph:
args, self._in_format = _flatten(args, "input")
if len(args) > 1:
inputs = [symbol.var('data%d'%i) for i in range(len(args))]
inputs = [symbol.var('data%d' % i).as_np_ndarray()
if isinstance(args[i], _mx_np.ndarray)
else symbol.var('data%d' % i) for i in range(len(args))]
else:
inputs = [symbol.var('data')]
inputs = [symbol.var('data').as_np_ndarray()
if isinstance(args[0], _mx_np.ndarray)
else symbol.var('data')]
grouped_inputs = _regroup(inputs, self._in_format)[0]

params = {i: j.var() for i, j in self._reg_params.items()}
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/ndarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .utils import load, load_frombuffer, save, zeros, empty, array
from .sparse import _ndarray_cls
from .ndarray import _GRAD_REQ_MAP, _DTYPE_MX_TO_NP, _DTYPE_NP_TO_MX, _new_empty_handle
from . import numpy as np

__all__ = op.__all__ + ndarray.__all__ + utils.__all__ + \
['contrib', 'linalg', 'random', 'sparse', 'image']
11 changes: 6 additions & 5 deletions python/mxnet/ndarray/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,24 @@
try:
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
from .._ctypes.ndarray import NDArrayBase, CachedOp
from .._ctypes.ndarray import _set_ndarray_class, _imperative_invoke
from .._ctypes.ndarray import _set_ndarray_class, _imperative_invoke, _set_np_ndarray_class
elif _sys.version_info >= (3, 0):
from .._cy3.ndarray import NDArrayBase, CachedOp
from .._cy3.ndarray import _set_ndarray_class, _imperative_invoke
from .._cy3.ndarray import _set_ndarray_class, _imperative_invoke, _set_np_ndarray_class
else:
from .._cy2.ndarray import NDArrayBase, CachedOp
from .._cy2.ndarray import _set_ndarray_class, _imperative_invoke
from .._cy2.ndarray import _set_ndarray_class, _imperative_invoke, _set_np_ndarray_class
except ImportError:
if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
from .._ctypes.ndarray import NDArrayBase, CachedOp
from .._ctypes.ndarray import _set_ndarray_class, _imperative_invoke
from .._ctypes.ndarray import _set_ndarray_class, _imperative_invoke, _set_np_ndarray_class

from ..base import _Null
try:
from .gen__internal import * # pylint: disable=unused-wildcard-import
except ImportError:
pass

__all__ = ['NDArrayBase', 'CachedOp', '_imperative_invoke', '_set_ndarray_class']
__all__ = ['NDArrayBase', 'CachedOp', '_imperative_invoke', '_set_ndarray_class',
'_set_np_ndarray_class']
Loading

0 comments on commit bfbdd4d

Please sign in to comment.