From 48a25a89f163325a128c2c3a3888d33e9a973381 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Mon, 15 Jun 2020 10:47:51 -0700 Subject: [PATCH] break down interoperability test --- .../unittest/test_numpy_interoperability.py | 78 ++++++++----------- 1 file changed, 31 insertions(+), 47 deletions(-) diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index d6b5595036ad..b42268a5ffca 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -3232,30 +3232,38 @@ def _check_interoperability_helper(op_name, rel_tol, abs_tol, *args, **kwargs): _np.testing.assert_equal(out, expected_out) -def check_interoperability(op_list): +@with_seed() +@use_np +@with_array_function_protocol +@with_array_ufunc_protocol +@pytest.mark.parametrize('name', + _NUMPY_ARRAY_FUNCTION_LIST \ + +_NUMPY_ARRAY_UFUNC_LIST \ + +np.fallback.__all__ \ + +['linalg.{}'.format(op_name) for op_name in np.fallback_linalg.__all__]) +def test_interoperability(name): OpArgMngr.randomize_workloads() - for name in op_list: - if name in _TVM_OPS and not is_op_runnable(): - continue - if name in ['shares_memory', 'may_share_memory', 'empty_like', - '__version__', 'dtype', '_NoValue']: # skip list - continue - if name in ['delete']: # https://github.com/apache/incubator-mxnet/issues/18600 - continue - if name in ['full_like', 'zeros_like', 'ones_like'] and \ - StrictVersion(platform.python_version()) < StrictVersion('3.0.0'): - continue - default_tols = (1e-3, 1e-4) - tols = {'linalg.tensorinv': (1e-2, 5e-3), - 'linalg.solve': (1e-3, 5e-2)} - (rel_tol, abs_tol) = tols.get(name, default_tols) - print('Dispatch test:', name) - workloads = OpArgMngr.get_workloads(name) - assert workloads is not None, 'Workloads for operator `{}` has not been ' \ - 'added for checking interoperability with ' \ - 'the official NumPy.'.format(name) - for workload in workloads: - _check_interoperability_helper(name, rel_tol, abs_tol, *workload['args'], **workload['kwargs']) + if name in _TVM_OPS and not is_op_runnable(): + return + if name in ['shares_memory', 'may_share_memory', 'empty_like', + '__version__', 'dtype', '_NoValue']: # skip list + return + if name in ['delete']: # https://github.com/apache/incubator-mxnet/issues/18600 + return + if name in ['full_like', 'zeros_like', 'ones_like'] and \ + StrictVersion(platform.python_version()) < StrictVersion('3.0.0'): + return + default_tols = (1e-3, 1e-4) + tols = {'linalg.tensorinv': (1e-2, 5e-3), + 'linalg.solve': (1e-3, 5e-2)} + (rel_tol, abs_tol) = tols.get(name, default_tols) + print('Dispatch test:', name) + workloads = OpArgMngr.get_workloads(name) + assert workloads is not None, 'Workloads for operator `{}` has not been ' \ + 'added for checking interoperability with ' \ + 'the official NumPy.'.format(name) + for workload in workloads: + _check_interoperability_helper(name, rel_tol, abs_tol, *workload['args'], **workload['kwargs']) @with_seed() @@ -3270,27 +3278,3 @@ def test_np_memory_array_function(): assert op(data_mx[0,:,:,:], data_mx[1,:,:,:]) == op(data_np[0,:,:,:], data_np[1,:,:,:]) assert op(data_mx[0,0,0,2:5], data_mx[0,0,0,4:7]) == op(data_np[0,0,0,2:5], data_np[0,0,0,4:7]) assert op(data_mx, np.ones((5, 0))) == op(data_np, _np.ones((5, 0))) - - -@with_seed() -@use_np -@with_array_function_protocol -@pytest.mark.serial -def test_np_array_function_protocol(): - check_interoperability(_NUMPY_ARRAY_FUNCTION_LIST) - - -@with_seed() -@use_np -@with_array_ufunc_protocol -@pytest.mark.serial -def test_np_array_ufunc_protocol(): - check_interoperability(_NUMPY_ARRAY_UFUNC_LIST) - - -@with_seed() -@use_np -@pytest.mark.serial -def test_np_fallback_ops(): - op_list = np.fallback.__all__ + ['linalg.{}'.format(op_name) for op_name in np.fallback_linalg.__all__] - check_interoperability(op_list)