Skip to content

Commit

Permalink
Add warning message for fallback operators (apache#17697)
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 authored and MoisesHer committed Apr 10, 2020
1 parent 02e3c66 commit b558966
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ def _np_ndarray_cls(handle, writable=True, stype=0):

_NUMPY_ARRAY_FUNCTION_DICT = {}
_NUMPY_ARRAY_UFUNC_DICT = {}
_FALLBACK_ARRAY_FUNCTION_WARNED_RECORD = {}
_FALLBACK_ARRAY_UFUNC_WARNED_RECORD = {}


@set_module('mxnet.numpy') # pylint: disable=invalid-name
Expand Down Expand Up @@ -263,6 +265,11 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # pylint: disable=
.format(name)
onp_op = _get_np_op(name)
new_inputs = [arg.asnumpy() if isinstance(arg, ndarray) else arg for arg in inputs]
if onp_op not in _FALLBACK_ARRAY_UFUNC_WARNED_RECORD:
import logging
logging.warning("np.%s is a fallback operator, "
"which is actually using official numpy's implementation", name)
_FALLBACK_ARRAY_UFUNC_WARNED_RECORD[onp_op] = True
out = onp_op(*new_inputs, **kwargs)
return _as_mx_np_array(out, ctx=inputs[0].ctx)
else:
Expand All @@ -277,6 +284,7 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad-
this function.
"""
mx_np_func = _NUMPY_ARRAY_FUNCTION_DICT.get(func, None)
func_name = func.__name__
if mx_np_func is None:
# try to fallback to official NumPy op
if is_recording():
Expand All @@ -290,6 +298,11 @@ def __array_function__(self, func, types, args, kwargs): # pylint: disable=bad-
new_kwargs = {}
for k, v in kwargs.items():
new_kwargs[k] = v.asnumpy() if isinstance(v, ndarray) else v
if func not in _FALLBACK_ARRAY_FUNCTION_WARNED_RECORD:
import logging
logging.warning("np.%s is a fallback operator, "
"which is actually using official numpy's implementation.", func_name)
_FALLBACK_ARRAY_FUNCTION_WARNED_RECORD[func] = True
out = func(*new_args, **new_kwargs)
return _as_mx_np_array(out, ctx=cur_ctx)
else:
Expand Down

0 comments on commit b558966

Please sign in to comment.