Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
all changes
Browse files Browse the repository at this point in the history
fix sanity problem

change is_op_runnable's location

Fix
  • Loading branch information
JiangZhaoh authored and reminisce committed Oct 21, 2019
1 parent 746cbc5 commit f0e7911
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 0 deletions.
8 changes: 8 additions & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'transpose',
'unique',
'var',
'vdot',
'vstack',
'zeros_like',
'linalg.norm',
'trace',
Expand Down Expand Up @@ -214,6 +216,12 @@ def _register_array_function():
'trunc',
'floor',
'logical_not',
'equal',
'not_equal',
'less',
'less_equal',
'greater',
'greater_equal'
]


Expand Down
84 changes: 84 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@
_INT_DTYPES = [np.int8, np.int32, np.int64, np.uint8]
_FLOAT_DTYPES = [np.float16, np.float32, np.float64]
_DTYPES = _INT_DTYPES + _FLOAT_DTYPES
_TVM_OPS = [
'equal',
'not_equal',
'less',
'less_equal',
'greater',
'greater_equal'
]


class OpArgMngr(object):
Expand Down Expand Up @@ -535,6 +543,13 @@ def _add_workload_roll():

def _add_workload_stack(array_pool):
OpArgMngr.add_workload('stack', [array_pool['4x1']] * 2)
OpArgMngr.add_workload('stack', [array_pool['4x1']] * 2, 1)
OpArgMngr.add_workload('stack', [array_pool['4x1']] * 2, -1)
OpArgMngr.add_workload('stack', [array_pool['4x1']] * 2, -2)
OpArgMngr.add_workload('stack', np.random.normal(size=(2, 4, 3)), 2)
OpArgMngr.add_workload('stack', np.random.normal(size=(2, 4, 3)), -3)
OpArgMngr.add_workload('stack', np.array([[], [], []]), 1)
OpArgMngr.add_workload('stack', np.array([[], [], []]))


def _add_workload_sum():
Expand Down Expand Up @@ -590,10 +605,22 @@ def _add_workload_unique():

def _add_workload_var(array_pool):
OpArgMngr.add_workload('var', array_pool['4x1'])
OpArgMngr.add_workload('var', np.array([np.float16(1.)]))
OpArgMngr.add_workload('var', np.array([1]))
OpArgMngr.add_workload('var', np.array([1.]))
OpArgMngr.add_workload('var', np.array([[1, 2, 3], [4, 5, 6]]))
OpArgMngr.add_workload('var', np.array([[1, 2, 3], [4, 5, 6]]), 0)
OpArgMngr.add_workload('var', np.array([[1, 2, 3], [4, 5, 6]]), 1)
OpArgMngr.add_workload('var', np.array([np.nan]))
OpArgMngr.add_workload('var', np.array([1, -1, 1, -1]))
OpArgMngr.add_workload('var', np.array([1,2,3,4], dtype='f8'))


def _add_workload_zeros_like(array_pool):
OpArgMngr.add_workload('zeros_like', array_pool['4x1'])
OpArgMngr.add_workload('zeros_like', np.random.uniform(size=(3, 3)).astype(np.float64))
OpArgMngr.add_workload('zeros_like', np.random.uniform(size=(3, 3)).astype(np.float32))
OpArgMngr.add_workload('zeros_like', np.random.randint(2, size = (3, 3)))


def _add_workload_outer():
Expand Down Expand Up @@ -933,6 +960,53 @@ def _add_workload_logical_not(array_pool):
OpArgMngr.add_workload('logical_not', np.array([True, False, True, False], dtype=np.bool))


def _add_workload_vdot():
OpArgMngr.add_workload('vdot', np.random.normal(size=(2, 4)), np.random.normal(size=(4, 2)))
OpArgMngr.add_workload('vdot', np.random.normal(size=(2, 4)).astype(np.float64), np.random.normal(size=(2, 4)).astype(np.float64))


def _add_workload_vstack(array_pool):
OpArgMngr.add_workload('vstack', (array_pool['4x1'], np.random.uniform(size=(5, 1))))
OpArgMngr.add_workload('vstack', array_pool['4x1'])
OpArgMngr.add_workload('vstack', array_pool['1x1x0'])


def _add_workload_equal(array_pool):
OpArgMngr.add_workload('equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('equal', np.array([np.nan]), np.array([np.nan]))
OpArgMngr.add_workload('equal', array_pool['4x1'], array_pool['1x2'])


def _add_workload_not_equal(array_pool):
OpArgMngr.add_workload('not_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('not_equal', np.array([np.nan]), np.array([np.nan]))
OpArgMngr.add_workload('not_equal', array_pool['4x1'], array_pool['1x2'])


def _add_workload_greater(array_pool):
OpArgMngr.add_workload('greater', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('greater', array_pool['4x1'], array_pool['1x2'])
OpArgMngr.add_workload('greater', np.array([np.nan]), np.array([np.nan]))


def _add_workload_greater_equal(array_pool):
OpArgMngr.add_workload('greater_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('greater_equal', array_pool['4x1'], array_pool['1x2'])
OpArgMngr.add_workload('greater_equal', np.array([np.nan]), np.array([np.nan]))


def _add_workload_less(array_pool):
OpArgMngr.add_workload('less', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('less', array_pool['4x1'], array_pool['1x2'])
OpArgMngr.add_workload('less', np.array([np.nan]), np.array([np.nan]))


def _add_workload_less_equal(array_pool):
OpArgMngr.add_workload('less_equal', np.array([0, 1, 2, 4, 2], dtype=np.float16), np.array([-2, 5, 1, 4, 3], dtype=np.float16))
OpArgMngr.add_workload('less_equal', array_pool['4x1'], array_pool['1x2'])
OpArgMngr.add_workload('less_equal', np.array([np.nan]), np.array([np.nan]))


@use_np
def _prepare_workloads():
array_pool = {
Expand Down Expand Up @@ -1028,6 +1102,14 @@ def _prepare_workloads():
_add_workload_turnc(array_pool)
_add_workload_floor(array_pool)
_add_workload_logical_not(array_pool)
_add_workload_vdot()
_add_workload_vstack(array_pool)
_add_workload_equal(array_pool)
_add_workload_not_equal(array_pool)
_add_workload_greater(array_pool)
_add_workload_greater_equal(array_pool)
_add_workload_less(array_pool)
_add_workload_less_equal(array_pool)


_prepare_workloads()
Expand Down Expand Up @@ -1070,6 +1152,8 @@ def _check_interoperability_helper(op_name, *args, **kwargs):

def check_interoperability(op_list):
for name in op_list:
if name in _TVM_OPS and not is_op_runnable():
continue
print('Dispatch test:', name)
workloads = OpArgMngr.get_workloads(name)
assert workloads is not None, 'Workloads for operator `{}` has not been ' \
Expand Down

0 comments on commit f0e7911

Please sign in to comment.