From 3bb21b1b1e9690ac280b37f78dabf0408a08a1d7 Mon Sep 17 00:00:00 2001 From: "Huang, Guangtai" Date: Thu, 24 Oct 2019 13:17:46 +0800 Subject: [PATCH] add dispatch --- python/mxnet/numpy_dispatch_protocol.py | 4 +++- .../unittest/test_numpy_interoperability.py | 15 +++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 3682334ebbea..af28fc30f31f 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -123,7 +123,9 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'tril', 'meshgrid', 'outer', - 'einsum' + 'einsum', + 'shares_memory', + 'may_share_memory', ] diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index e16be469aa3b..863c2ff6f046 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1070,6 +1070,8 @@ def _check_interoperability_helper(op_name, *args, **kwargs): def check_interoperability(op_list): for name in op_list: + if name in ['shares_memory', 'may_share_memory']: # skip list + continue print('Dispatch test:', name) workloads = OpArgMngr.get_workloads(name) assert workloads is not None, 'Workloads for operator `{}` has not been ' \ @@ -1079,6 +1081,19 @@ def check_interoperability(op_list): _check_interoperability_helper(name, *workload['args'], **workload['kwargs']) +@with_seed() +@use_np +@with_array_function_protocol +def test_np_memory_array_function(): + ops = [_np.shares_memory, _np.may_share_memory] + for op in ops: + data_mx = np.zeros([13, 21, 23, 22], dtype=np.float32) + data_np = _np.zeros([13, 21, 23, 22], dtype=np.float32) + assert op(data_mx[0,:,:,:], data_mx[1,:,:,:]) == op(data_np[0,:,:,:], data_np[1,:,:,:]) + assert op(data_mx[0,0,0,2:5], data_mx[0,0,0,4:7]) == op(data_np[0,0,0,2:5], data_np[0,0,0,4:7]) + assert op(data_mx, np.ones((5, 0))) == op(data_np, _np.ones((5, 0))) + + @with_seed() @use_np @with_array_function_protocol