Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[FFI] Adopt PackedFunc on Numpy Imperative Invoke #20006

Closed
wants to merge 16 commits into the base branch from the author's branch
4 changes: 2 additions & 2 deletions python/mxnet/_ctypes/_api_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.

"""CachedOp APIs exposed from C++."""
"""NDArray APIs exposed from C++."""

import mxnet._ffi

mxnet._ffi._init_api("cached_op", __name__)
mxnet._ffi._init_api("ndarray", __name__)
10 changes: 5 additions & 5 deletions python/mxnet/_ctypes/cached_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ def __init__(self, sym, flags=(), thread_safe=False):
self.is_np_sym = bool(isinstance(sym, _Symbol))

flags = {key: str(value) for key, value in flags}
self.handle = CachedOpHandle(_api_internal.create(
self.handle = CachedOpHandle(_api_internal.cached_op_create(
sym.handle,
flags,
thread_safe
))

def __del__(self):
_api_internal.free(self.handle)
_api_internal.cached_op_free(self.handle)

def get_optimized_symbol(self):
"""Get an optimized version of the symbol from the cached op.
Expand All @@ -66,7 +66,7 @@ def get_optimized_symbol(self):
Optimized symbol from the executor.
"""
from ..symbol import Symbol
sym_handle = SymbolHandle(_api_internal.get_optimized_symbol(self.handle))
sym_handle = SymbolHandle(_api_internal.cached_op_get_optimized_symbol(self.handle))
ret = Symbol(sym_handle)
return ret

Expand All @@ -85,7 +85,7 @@ def __call__(self, *args, **kwargs):
type_id = default_ctx.device_typeid if default_ctx else None
device_id = default_ctx.device_id if default_ctx else None
out_arg = out if out is not None and not isinstance(out, NDArrayBase) else (out, )
output_vars = _api_internal.invoke(
output_vars = _api_internal.cached_op_invoke(
self.handle,
len(args),
*args,
Expand Down Expand Up @@ -157,4 +157,4 @@ def _register_op_hook(self, callback, monitor_all=False):
if callback:
self._monitor_callback = cb_type(_monitor_callback_wrapper(callback))
callback_ptr = ctypes.cast(self._monitor_callback, ctypes.c_void_p)
_api_internal.register_op_hook(self.handle, callback_ptr, monitor_all)
_api_internal.cached_op_register_op_hook(self.handle, callback_ptr, monitor_all)
18 changes: 15 additions & 3 deletions python/mxnet/amp/amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,13 @@ def _cast_symbol_NDArray(s, dtype, is_numpy_module=False):
amp_cast = symbol.numpy._internal.amp_cast if is_numpy_module else symbol.amp_cast
return amp_cast(s, dtype=dtype)
if isinstance(s, NDArray):
amp_cast = ndarray.numpy._internal.amp_cast if is_numpy_module else ndarray.amp_cast
if is_numpy_module:
def amp_cast(s, dtype=None): # pylint: disable=function-redefined
if not isinstance(dtype, str):
dtype = np.dtype(dtype).name
return ndarray.numpy._api_internal.amp_cast(s, dtype)
else:
amp_cast = ndarray.amp_cast
if s.dtype != dtype and (s.dtype in float_types_gpu and s.context.device_type != 'cpu' or
s.dtype in float_types_cpu and s.context.device_type == 'cpu'):
return amp_cast(s, dtype=dtype)
Expand Down Expand Up @@ -106,7 +112,13 @@ def _wrap_module_functions(module, is_numpy_module, target_dtype, get_aliases, g
get_fun_to_wrap, target_precision_ops=None, conditional_fp32_ops=None,
fp32_ops=None):

nd_mod = ndarray.numpy._internal if is_numpy_module else ndarray
if is_numpy_module:
def amp_cast(s, dtype=None): # pylint: disable=function-redefined
if not isinstance(dtype, str):
dtype = np.dtype(dtype).name
return ndarray.numpy._api_internal.amp_cast(s, dtype)
else:
amp_cast = ndarray.amp_cast
sy_mod = symbol.numpy._internal if is_numpy_module else symbol

def _ndarray_wrapper(f, target_dtype, fp32_param=None, cond_arg=None):
Expand Down Expand Up @@ -194,7 +206,7 @@ def _new_fun(*args, **kwargs):
widest_type = np.float32
for arr, index, arg in symbols:
if arg.dtype != widest_type and arg.dtype == target_dtype:
arr[index] = nd_mod.amp_cast(arg, dtype=widest_type)
arr[index] = amp_cast(arg, dtype=widest_type)
else:
# Symbol case
sym_to_check = list(map(lambda x: x[2], symbols))
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/gluon/data/vision/transforms/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ def hybrid_forward(self, F, x, *args):
mat = F.np.concatenate((F.np.full((3, 1), 0.2989),
F.np.full((3, 1), 0.5870),
F.np.full((3, 1), 0.114)), axis=1)
x = F.npx.cast(x, dtype='float32')
x = x.astype(dtype='float32')
gray = F.np.where(self.p < F.np.random.uniform(), x, F.np.dot(x, mat))
else:
mat = F.concat(F.full((3, 1), 0.2989),
Expand Down
156 changes: 117 additions & 39 deletions python/mxnet/ndarray/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from ..base import mx_uint, check_call, _LIB, py_str, _init_op_module, _Null, _is_np_op, _output_is_list # pylint: disable=unused-import
from ..util import use_np_shape # pylint: disable=unused-import
from .._ctypes import _api_internal # pylint: disable=unused-import


def _verify_all_np_ndarrays(op_name, func_name, args, out):
Expand Down Expand Up @@ -111,6 +112,17 @@ def _verify_all_legacy_ndarrays(op_name, func_name, args, out):
.format(op_name, func_name))


def _np_imperative_invoke(handle, ndargs, out):
    """PackedFunc based numpy operator invocation call"""
    # Dispatch through the PackedFunc FFI; `out` is forwarded so the backend
    # can write results in place when output arrays were supplied.
    result = _api_internal.invoke(handle, *ndargs, out)
    if out is not None:
        # Caller provided output array(s): return them, not the FFI result.
        return out
    # Single output comes back as an NDArray; multiple outputs as a sequence.
    return result if isinstance(result, NDArrayBase) else list(result)


# pylint: disable=too-many-locals
def _generate_ndarray_function_code(handle, op_name, func_name, signature_only=False):
"""Generate function for ndarray op by handle and function op_name."""
Expand Down Expand Up @@ -176,62 +188,131 @@ def _generate_ndarray_function_code(handle, op_name, func_name, signature_only=F
code = []
is_np_op = _is_np_op(op_name)
output_is_list = _output_is_list(op_name)
doc_str_idx = 1
if is_np_op:
doc_str_idx = 2
if arr_name:
code.append("""
def %s(*%s, **kwargs):"""%(func_name, arr_name))
if not signature_only:
if arr_name:
code.append("""
def %s(*%s, **kwargs):"""%(func_name, arr_name))
if not signature_only:
code.append("""
ndargs = []
for i in {}:
assert isinstance(i, NDArrayBase), \\
"Positional arguments must have NDArray type, " \\
"but got %s"%str(i)
ndargs.append(i)""".format(arr_name))
if dtype_name is not None:
code.append("""
if dtype_name is not None:
code.append("""
if '%s' in kwargs:
if _np.dtype(kwargs['%s']).names:
kwargs['%s'] = _np.dtype(kwargs['%s']).names[0]
else:
kwargs['%s'] = _np.dtype(kwargs['%s']).name """%(
dtype_name, dtype_name, dtype_name, dtype_name, dtype_name, dtype_name))
code.append("""
_ = kwargs.pop('name', None)
out = kwargs.pop('out', None)""")
if not signature_only:
code.append("""
_verify_all_np_ndarrays("{op_name}", "{func_name}", ndargs, out)
""".format(op_name=op_name, func_name=func_name))
code.append("""
return _imperative_invoke(%d, ndargs, kwargs.keys(), kwargs.values(), out, True, %s)"""%(
handle.value, str(output_is_list)))
else:
code.append("""
return (0,)""")
else:
code.append("""
def %s(%s):"""%(func_name, ', '.join(signature)))
if not signature_only:
code.append("""
ndargs = []""")
# NDArray args
for name in ndarg_names: # pylint: disable=redefined-argument-from-local
code.append("""
if {name} is not None:
assert isinstance({name}, NDArrayBase), \\
"Argument {name} must have NDArray type, but got %s"%str({name})
ndargs.append({name})""".format(name=name))
# kwargs
if not kwarg_names:
code.append("""
_verify_all_np_ndarrays("{op_name}", "{func_name}", ndargs, out)
""".format(op_name=op_name, func_name=func_name))
if not signature_only:
code.append("""
return _np_imperative_invoke(%d, ndargs, out)"""%(handle.value))
else:
code.append("""
return (0,)""")
else:
for name in kwarg_names: # pylint: disable=redefined-argument-from-local
code.append("""
if %s is not _Null:
kwargs['%s'] = %s"""%(name, name, name))
# dtype
if dtype_name is not None:
code.append("""
if %s is not _Null and %s is not None:
kwargs['%s'] = _np.dtype(%s).name"""%(dtype_name, dtype_name, dtype_name, dtype_name))
if not signature_only:
code.append("""
_verify_all_np_ndarrays("{op_name}", "{func_name}", ndargs, out)
""".format(op_name=op_name, func_name=func_name))
code.append("""
return _imperative_invoke(%d, ndargs, kwargs.keys(), kwargs.values(), out, True, %s)"""%(
handle.value, str(output_is_list)))
else:
code.append("""
return (0,)""")
else:
if arr_name:
code.append("""
def %s(*%s, **kwargs):"""%(func_name, arr_name))
if not signature_only:
code.append("""
ndargs = []
for i in {}:
assert isinstance(i, NDArrayBase), \\
"Positional arguments must have NDArray type, " \\
"but got %s"%str(i)
ndargs.append(i)""".format(arr_name))
if dtype_name is not None:
code.append("""
if '%s' in kwargs:
if _np.dtype(kwargs['%s']).names:
kwargs['%s'] = _np.dtype(kwargs['%s']).names[0]
else:
kwargs['%s'] = _np.dtype(kwargs['%s']).name """%(
dtype_name, dtype_name, dtype_name, dtype_name, dtype_name, dtype_name))
code.append("""
_ = kwargs.pop('name', None)
out = kwargs.pop('out', None)
keys = list(kwargs.keys())
vals = list(kwargs.values())""")
else:
code.append("""
def %s(%s):"""%(func_name, ', '.join(signature)))
if not signature_only:
else:
code.append("""
def %s(%s):"""%(func_name, ', '.join(signature)))
if not signature_only:
code.append("""
ndargs = []
keys = list(kwargs.keys())
vals = list(kwargs.values())""")
# NDArray args
for name in ndarg_names: # pylint: disable=redefined-argument-from-local
code.append("""
# NDArray args
for name in ndarg_names: # pylint: disable=redefined-argument-from-local
code.append("""
if {name} is not None:
assert isinstance({name}, NDArrayBase), \\
"Argument {name} must have NDArray type, but got %s"%str({name})
ndargs.append({name})""".format(name=name))
# kwargs
for name in kwarg_names: # pylint: disable=redefined-argument-from-local
code.append("""
# kwargs
for name in kwarg_names: # pylint: disable=redefined-argument-from-local
code.append("""
if %s is not _Null:
keys.append('%s')
vals.append(%s)"""%(name, name, name))
# dtype
if dtype_name is not None:
if is_np_op:
code.append("""
if %s is not _Null and %s is not None:
keys.append('%s')
vals.append(_np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name, dtype_name))
else:
# dtype
if dtype_name is not None:
code.append("""
if %s is not _Null:
keys.append('%s')
Expand All @@ -240,24 +321,21 @@ def %s(%s):"""%(func_name, ', '.join(signature)))
else:
vals.append(_np.dtype(%s).name) """%(dtype_name, dtype_name, dtype_name,
dtype_name, dtype_name))

verify_ndarrays_fn =\
_verify_all_np_ndarrays.__name__ if is_np_op else _verify_all_legacy_ndarrays.__name__
if not signature_only:
code.append("""
{verify_fn}("{op_name}", "{func_name}", ndargs, out)
""".format(verify_fn=verify_ndarrays_fn, op_name=op_name, func_name=func_name))
code.append("""
return _imperative_invoke(%d, ndargs, keys, vals, out, %s, %s)"""%(
handle.value, str(is_np_op), str(output_is_list)))
else:
code.append("""
if not signature_only:
code.append("""
_verify_all_legacy_ndarrays("{op_name}", "{func_name}", ndargs, out)
""".format(op_name=op_name, func_name=func_name))
code.append("""
return _imperative_invoke(%d, ndargs, keys, vals, out, False, %s)"""%(
handle.value, str(output_is_list)))
else:
code.append("""
return (0,)""")

doc_str_lines = _os.linesep+''.join([' '+s if s.strip() else s
for s in 'r"""{doc_str}"""'.format(doc_str=doc_str)
.splitlines(True)])
code.insert(doc_str_idx, doc_str_lines)
code.insert(1, doc_str_lines)
return ''.join(code), doc_str


Expand Down
3 changes: 2 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
is_np_default_dtype
from ..context import current_context
from ..ndarray import numpy as _mx_nd_np
from ..ndarray.numpy import _api_internal
from ..ndarray.numpy import _internal as _npi
from ..ndarray.ndarray import _storage_type
from ..dlpack import ndarray_from_numpy
Expand Down Expand Up @@ -1478,7 +1479,7 @@ def astype(self, dtype, order='K', casting='unsafe', subok=True, copy=True): #
if not copy and _np.dtype(dtype) == self.dtype:
return self

return _npi.cast(self, dtype=dtype)
return _api_internal.cast(self, _np.dtype(dtype).name)

def copyto(self, other):
"""Copies the value of this array to another array.
Expand Down
10 changes: 5 additions & 5 deletions src/api/cached_op_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

namespace mxnet {

MXNET_REGISTER_GLOBAL("cached_op.invoke")
MXNET_REGISTER_GLOBAL("ndarray.cached_op_invoke")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
CachedOpPtr op_shared = *static_cast<CachedOpPtr*>(args[0].value().v_handle);
// CachedOp* points to CachedOpThreadSafe object if CreateCachedOpEX
Expand Down Expand Up @@ -88,7 +88,7 @@ MXNET_REGISTER_GLOBAL("cached_op.invoke")
}
});

MXNET_REGISTER_GLOBAL("cached_op.create")
MXNET_REGISTER_GLOBAL("ndarray.cached_op_create")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(args[0].value().v_handle);
Object* flags_ptr = static_cast<Object*>(args[1].value().v_handle);
Expand All @@ -110,21 +110,21 @@ MXNET_REGISTER_GLOBAL("cached_op.create")
*ret = static_cast<void*>(out);
});

MXNET_REGISTER_GLOBAL("cached_op.free")
MXNET_REGISTER_GLOBAL("ndarray.cached_op_free")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
CachedOpPtr* g = static_cast<CachedOpPtr*>(args[0].value().v_handle);
delete g;
});

MXNET_REGISTER_GLOBAL("cached_op.get_optimized_symbol")
MXNET_REGISTER_GLOBAL("ndarray.cached_op_get_optimized_symbol")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
auto s = new nnvm::Symbol();
CachedOpPtr op = *static_cast<CachedOpPtr*>(args[0].value().v_handle);
*s = op->GetOptimizedSymbol();
*ret = static_cast<void*>(static_cast<SymbolHandle>(s));
});

MXNET_REGISTER_GLOBAL("cached_op.register_op_hook")
MXNET_REGISTER_GLOBAL("ndarray.cached_op_register_op_hook")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
CachedOpHandle handle = static_cast<CachedOpHandle>(args[0].value().v_handle);
CachedOpMonitorCallback callback = reinterpret_cast<CachedOpMonitorCallback>(
Expand Down
Loading