Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
change is_op_runnable's location
Browse files Browse the repository at this point in the history
  • Loading branch information
JiangZhaoh committed Oct 21, 2019
1 parent afb6dab commit 4c82f3d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
14 changes: 6 additions & 8 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from __future__ import absolute_import
import functools
import numpy as _np
from mxnet.test_utils import is_op_runnable
from . import numpy as mx_np # pylint: disable=reimported
from .numpy.multiarray import _NUMPY_ARRAY_FUNCTION_DICT, _NUMPY_ARRAY_UFUNC_DICT

Expand Down Expand Up @@ -214,6 +213,12 @@ def _register_array_function():
'trunc',
'floor',
'logical_not',
'equal',
'not_equal',
'less',
'less_equal',
'greater',
'greater_equal'
]


Expand All @@ -225,13 +230,6 @@ def _register_array_ufunc():
----------
https://numpy.org/neps/nep-0013-ufunc-overrides.html
"""
if is_op_runnable():
_NUMPY_ARRAY_UFUNC_LIST.extend(['equal',
'not_equal',
'less',
'less_equal',
'greater',
'greater_equal'])
dup = _find_duplicate(_NUMPY_ARRAY_UFUNC_LIST)
if dup is not None:
raise ValueError('Duplicate operator name {} in _NUMPY_ARRAY_UFUNC_LIST'.format(dup))
Expand Down
14 changes: 8 additions & 6 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from mxnet import np
from mxnet.test_utils import assert_almost_equal
from mxnet.test_utils import use_np
from mxnet.test_utils import is_op_runnable
from common import assertRaises, with_seed
from mxnet.numpy_dispatch_protocol import with_array_function_protocol, with_array_ufunc_protocol
from mxnet.numpy_dispatch_protocol import _NUMPY_ARRAY_FUNCTION_LIST, _NUMPY_ARRAY_UFUNC_LIST
Expand Down Expand Up @@ -934,12 +935,13 @@ def _prepare_workloads():
_add_workload_logical_not(array_pool)
_add_workload_vdot()
_add_workload_vstack(array_pool)
_add_workload_equal(array_pool)
_add_workload_not_equal(array_pool)
_add_workload_greater(array_pool)
_add_workload_greater_equal(array_pool)
_add_workload_less(array_pool)
_add_workload_less_equal(array_pool)
if is_op_runnable():
_add_workload_equal(array_pool)
_add_workload_not_equal(array_pool)
_add_workload_greater(array_pool)
_add_workload_greater_equal(array_pool)
_add_workload_less(array_pool)
_add_workload_less_equal(array_pool)


_prepare_workloads()
Expand Down

0 comments on commit 4c82f3d

Please sign in to comment.