Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
break down interoperability test
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Jul 23, 2020
1 parent a94c98a commit 48a25a8
Showing 1 changed file with 31 additions and 47 deletions.
78 changes: 31 additions & 47 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -3232,30 +3232,38 @@ def _check_interoperability_helper(op_name, rel_tol, abs_tol, *args, **kwargs):
_np.testing.assert_equal(out, expected_out)


def check_interoperability(op_list):
@with_seed()
@use_np
@with_array_function_protocol
@with_array_ufunc_protocol
@pytest.mark.parametrize('name',
_NUMPY_ARRAY_FUNCTION_LIST \
+_NUMPY_ARRAY_UFUNC_LIST \
+np.fallback.__all__ \
+['linalg.{}'.format(op_name) for op_name in np.fallback_linalg.__all__])
def test_interoperability(name):
OpArgMngr.randomize_workloads()
for name in op_list:
if name in _TVM_OPS and not is_op_runnable():
continue
if name in ['shares_memory', 'may_share_memory', 'empty_like',
'__version__', 'dtype', '_NoValue']: # skip list
continue
if name in ['delete']: # https://github.com/apache/incubator-mxnet/issues/18600
continue
if name in ['full_like', 'zeros_like', 'ones_like'] and \
StrictVersion(platform.python_version()) < StrictVersion('3.0.0'):
continue
default_tols = (1e-3, 1e-4)
tols = {'linalg.tensorinv': (1e-2, 5e-3),
'linalg.solve': (1e-3, 5e-2)}
(rel_tol, abs_tol) = tols.get(name, default_tols)
print('Dispatch test:', name)
workloads = OpArgMngr.get_workloads(name)
assert workloads is not None, 'Workloads for operator `{}` has not been ' \
'added for checking interoperability with ' \
'the official NumPy.'.format(name)
for workload in workloads:
_check_interoperability_helper(name, rel_tol, abs_tol, *workload['args'], **workload['kwargs'])
if name in _TVM_OPS and not is_op_runnable():
return
if name in ['shares_memory', 'may_share_memory', 'empty_like',
'__version__', 'dtype', '_NoValue']: # skip list
return
if name in ['delete']: # https://github.com/apache/incubator-mxnet/issues/18600
return
if name in ['full_like', 'zeros_like', 'ones_like'] and \
StrictVersion(platform.python_version()) < StrictVersion('3.0.0'):
return
default_tols = (1e-3, 1e-4)
tols = {'linalg.tensorinv': (1e-2, 5e-3),
'linalg.solve': (1e-3, 5e-2)}
(rel_tol, abs_tol) = tols.get(name, default_tols)
print('Dispatch test:', name)
workloads = OpArgMngr.get_workloads(name)
assert workloads is not None, 'Workloads for operator `{}` has not been ' \
'added for checking interoperability with ' \
'the official NumPy.'.format(name)
for workload in workloads:
_check_interoperability_helper(name, rel_tol, abs_tol, *workload['args'], **workload['kwargs'])


@with_seed()
Expand All @@ -3270,27 +3278,3 @@ def test_np_memory_array_function():
assert op(data_mx[0,:,:,:], data_mx[1,:,:,:]) == op(data_np[0,:,:,:], data_np[1,:,:,:])
assert op(data_mx[0,0,0,2:5], data_mx[0,0,0,4:7]) == op(data_np[0,0,0,2:5], data_np[0,0,0,4:7])
assert op(data_mx, np.ones((5, 0))) == op(data_np, _np.ones((5, 0)))


@with_seed()
@use_np
@with_array_function_protocol
@pytest.mark.serial
def test_np_array_function_protocol():
check_interoperability(_NUMPY_ARRAY_FUNCTION_LIST)


@with_seed()
@use_np
@with_array_ufunc_protocol
@pytest.mark.serial
def test_np_array_ufunc_protocol():
check_interoperability(_NUMPY_ARRAY_UFUNC_LIST)


@with_seed()
@use_np
@pytest.mark.serial
def test_np_fallback_ops():
op_list = np.fallback.__all__ + ['linalg.{}'.format(op_name) for op_name in np.fallback_linalg.__all__]
check_interoperability(op_list)

0 comments on commit 48a25a8

Please sign in to comment.